{ "cells": [ { "cell_type": "markdown", "id": "g_nWetWWd_ns", "metadata": { "id": "g_nWetWWd_ns" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "id": "2pHVBk_seED1", "metadata": { "cellView": "form", "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "id": "M7vSdG6sAIQn", "metadata": { "id": "M7vSdG6sAIQn" }, "source": [ "# Signatures in TensorFlow Lite" ] }, { "cell_type": "markdown", "id": "fwc5GKHBASdc", "metadata": { "id": "fwc5GKHBASdc" }, "source": [ "View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook
" ] }, { "cell_type": "markdown", "id": "9ee074e4", "metadata": { "id": "9ee074e4" }, "source": [ "TensorFlow Lite supports converting the input/output specifications of a TensorFlow model into a TensorFlow Lite model. The input/output specifications are called \"signatures\". Signatures can be specified when building a SavedModel or when creating concrete functions.\n", "\n", "Signatures in TensorFlow Lite provide the following features:\n", "\n", "- They specify the inputs and outputs of the converted TensorFlow Lite model by respecting the signatures of the TensorFlow model.\n", "- They allow a single TensorFlow Lite model to support multiple entry points.\n", "\n", "A signature is composed of three pieces:\n", "\n", "- Inputs: a map from each input name in the signature to an input tensor.\n", "- Outputs: a map from each output name in the signature to an output tensor.\n", "- Signature key: the name that identifies an entry point of the graph.\n" ] }, { "cell_type": "markdown", "id": "UaWdLA3fQDK2", "metadata": { "id": "UaWdLA3fQDK2" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "9j4MGqyKQEo4", "metadata": { "id": "9j4MGqyKQEo4" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "id": "FN2N6hPEP-Ay", "metadata": { "id": "FN2N6hPEP-Ay" }, "source": [ "## Example model\n", "\n", "Let's say we have a TensorFlow model with two tasks, encode and decode:" ] }, { "cell_type": "code", "execution_count": null, "id": "d8577c80", "metadata": { "id": "d8577c80" }, "outputs": [], "source": [ "class Model(tf.Module):\n", "\n", " @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])\n", " def encode(self, x):\n", " result = tf.strings.as_string(x)\n", " return {\n", " \"encoded_result\": result\n", " }\n", "\n", " @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])\n", " def decode(self, x):\n", " result = tf.strings.to_number(x)\n", " return {\n", " \"decoded_result\": result\n", " }" ] }, { "cell_type": "markdown", "id": "9c814c6e", "metadata": { "id": "9c814c6e" }, "source": [ "In terms of signatures, the TensorFlow model above can be summarized as follows:\n", "\n", "- Signature\n", "\n", " - Key: encode\n", " - Inputs: {\"x\"}\n", " - Outputs: {\"encoded_result\"}\n", "\n", "- Signature\n", "\n", " - Key: decode\n", " - Inputs: {\"x\"}\n", " - Outputs: {\"decoded_result\"}" ] }, { "cell_type": "markdown", "id": "c4099f20",
"metadata": { "id": "c4099f20" }, "source": [ "## Convert a model with signatures\n", "\n", "The TensorFlow Lite converter APIs carry the signature information above into the converted TensorFlow Lite model.\n", "\n", "This conversion functionality is available in all converter APIs starting from TensorFlow version 2.7.0. See the example usages below.\n" ] }, { "cell_type": "markdown", "id": "Qv0WwFQkQgnO", "metadata": { "id": "Qv0WwFQkQgnO" }, "source": [ "### From a SavedModel" ] }, { "cell_type": "code", "execution_count": null, "id": "96c8fc79", "metadata": { "id": "96c8fc79" }, "outputs": [], "source": [ "model = Model()\n", "\n", "# Save the model\n", "SAVED_MODEL_PATH = 'content/saved_models/coding'\n", "\n", "tf.saved_model.save(\n", " model, SAVED_MODEL_PATH,\n", " signatures={\n", " 'encode': model.encode.get_concrete_function(),\n", " 'decode': model.decode.get_concrete_function()\n", " })\n", "\n", "# Convert the saved model using TFLiteConverter\n", "converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)\n", "converter.target_spec.supported_ops = [\n", " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n", " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n", "]\n", "tflite_model = converter.convert()\n", "\n", "# Print the signatures from the converted model\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "signatures = interpreter.get_signature_list()\n", "print(signatures)" ] }, { "cell_type": "markdown", "id": "5baa9f17", "metadata": { "id": "5baa9f17" }, "source": [ "### From a Keras model" ] }, { "cell_type": "code", "execution_count": null, "id": "71f29229", "metadata": { "id": "71f29229" }, "outputs": [], "source": [ "# Generate a Keras model.\n", "keras_model = tf.keras.Sequential(\n", " [\n", " tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),\n", " tf.keras.layers.Dense(1, activation='relu', name='output'),\n", " ]\n", ")\n", "\n", "# Convert the keras model using TFLiteConverter.\n", "# Keras model converter API uses the default signature automatically.\n", "converter = 
tf.lite.TFLiteConverter.from_keras_model(keras_model)\n", "tflite_model = converter.convert()\n", "\n", "# Print the signatures from the converted model\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "\n", "signatures = interpreter.get_signature_list()\n", "print(signatures)" ] }, { "cell_type": "markdown", "id": "e4d30f85", "metadata": { "id": "e4d30f85" }, "source": [ "### From concrete functions" ] }, { "cell_type": "code", "execution_count": null, "id": "c9e8a742", "metadata": { "id": "c9e8a742" }, "outputs": [], "source": [ "model = Model()\n", "\n", "# Convert the concrete functions using TFLiteConverter\n", "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n", " [model.encode.get_concrete_function(),\n", " model.decode.get_concrete_function()], model)\n", "converter.target_spec.supported_ops = [\n", " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n", " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n", "]\n", "tflite_model = converter.convert()\n", "\n", "# Print the signatures from the converted model\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "signatures = interpreter.get_signature_list()\n", "print(signatures)" ] }, { "cell_type": "markdown", "id": "b5e85934", "metadata": { "id": "b5e85934" }, "source": [ "## Run signatures\n", "\n", "TensorFlow Lite inference APIs support signature-based execution:\n", "\n", "- Accessing the input/output tensors through the names of the inputs and outputs, as specified by the signature.\n", "- Running each entry point of the graph separately, identified by its signature key.\n", "- Support for the SavedModel's initialization procedure.\n", "\n", "Language bindings for Java, C++ and Python are currently available. See the examples in the sections below.\n" ] }, { "cell_type": "markdown", "id": "ZRBMFciMQmiB", "metadata": { "id": "ZRBMFciMQmiB" }, "source": [ "### Java" ] }, { "cell_type": "markdown", "id": "04c5a4fc", "metadata": { "id": "04c5a4fc" }, "source": [ "```\n", "try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {\n", " // Run encoding signature.\n", " Map<String, Object> inputs = new HashMap<>();\n", " inputs.put(\"x\", input);\n", " 
Map<String, Object> outputs = new HashMap<>();\n", " outputs.put(\"encoded_result\", encoded_result);\n", " interpreter.runSignature(inputs, outputs, \"encode\");\n", "\n", " // Run decoding signature, reusing the maps declared above.\n", " inputs = new HashMap<>();\n", " inputs.put(\"x\", encoded_result);\n", " outputs = new HashMap<>();\n", " outputs.put(\"decoded_result\", decoded_result);\n", " interpreter.runSignature(inputs, outputs, \"decode\");\n", "}\n", "```" ] }, { "cell_type": "markdown", "id": "5ba86c64", "metadata": { "id": "5ba86c64" }, "source": [ "### C++" ] }, { "cell_type": "markdown", "id": "397ad6fd", "metadata": { "id": "397ad6fd" }, "source": [ "```\n", "SignatureRunner* encode_runner =\n", " interpreter->GetSignatureRunner(\"encode\");\n", "encode_runner->ResizeInputTensor(\"x\", {100});\n", "encode_runner->AllocateTensors();\n", "\n", "TfLiteTensor* input_tensor = encode_runner->input_tensor(\"x\");\n", "float* input = GetTensorData<float>(input_tensor);\n", "// Fill `input`.\n", "\n", "encode_runner->Invoke();\n", "\n", "const TfLiteTensor* output_tensor = encode_runner->output_tensor(\n", " \"encoded_result\");\n", "float* output = GetTensorData<float>(output_tensor);\n", "// Access `output`.\n", "```" ] }, { "cell_type": "markdown", "id": "0f4c6ad4", "metadata": { "id": "0f4c6ad4" }, "source": [ "### Python" ] }, { "cell_type": "code", "execution_count": null, "id": "ab7b1963", "metadata": { "id": "ab7b1963" }, "outputs": [], "source": [ "# Load the TFLite model in TFLite Interpreter\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "\n", "# Print the signatures from the converted model\n", "signatures = interpreter.get_signature_list()\n", "print('Signature:', signatures)\n", "\n", "# encode and decode are callable with input as arguments.\n", "encode = interpreter.get_signature_runner('encode')\n", "decode = interpreter.get_signature_runner('decode')\n", "\n", "# 'encoded' and 'decoded' are dictionaries with all outputs 
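from the inference.\n", "\n", "# Illustrative extra check (not part of the original guide): the two entry\n", "# points can be chained, since encode() turns floats into strings and\n", "# decode() turns those strings back into floats.\n", "roundtrip = decode(x=encode(x=tf.constant([4.0], dtype=tf.float32))['encoded_result'])\n", "print('Round-trip result:', roundtrip['decoded_result'])\n", "\n", "# The demo below prints every output 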
from the inference.\n", "input = tf.constant([1, 2, 3], dtype=tf.float32)\n", "print('Input:', input)\n", "encoded = encode(x=input)\n", "print('Encoded result:', encoded)\n", "decoded = decode(x=encoded['encoded_result'])\n", "print('Decoded result:', decoded)" ] }, { "cell_type": "markdown", "id": "81b42e5b", "metadata": { "id": "81b42e5b" }, "source": [ "## Known limitations\n", "\n", "- As the TFLite interpreter does not guarantee thread safety, signature runners from the same interpreter won't be executed concurrently.\n", "- Support for C/iOS/Swift is not available yet.\n" ] }, { "cell_type": "markdown", "id": "3032Iof6QqmJ", "metadata": { "id": "3032Iof6QqmJ" }, "source": [ "## Updates\n", "\n", "- Version 2.7\n", " - The multiple-signature feature is implemented.\n", " - All converter APIs from version two generate signature-enabled TensorFlow Lite models.\n", "- Version 2.5\n", " - The signature feature is available through the `from_saved_model` converter API." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "signatures.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }