{ "cells": [ { "cell_type": "markdown", "id": "g_nWetWWd_ns", "metadata": { "id": "g_nWetWWd_ns" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "id": "2pHVBk_seED1", "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-15T01:06:26.552036Z", "iopub.status.busy": "2022-12-15T01:06:26.551825Z", "iopub.status.idle": "2022-12-15T01:06:26.555940Z", "shell.execute_reply": "2022-12-15T01:06:26.555432Z" }, "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "id": "M7vSdG6sAIQn", "metadata": { "id": "M7vSdG6sAIQn" }, "source": [ "# TensorFlow Lite의 서명" ] }, { "cell_type": "markdown", "id": "fwc5GKHBASdc", "metadata": { "id": "fwc5GKHBASdc" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org에서 보기Google Colab에서 실행GitHub에서 소스 보기노트북 다운로드
" ] }, { "cell_type": "markdown", "id": "9ee074e4", "metadata": { "id": "9ee074e4" }, "source": [ "TensorFlow Lite는 TensorFlow 모델의 입력/출력 사양을 TensorFlow Lite 모델로 변환하는 것을 지원합니다. 입/출력 사양을 \"서명\"이라고 합니다. SavedModel을 구축하거나 구체적인 기능을 생성할 때 서명을 지정할 수 있습니다.\n", "\n", "TensorFlow Lite의 서명은 다음 기능을 제공합니다.\n", "\n", "- TensorFlow 모델의 서명을 적용하여 변환된 TensorFlow Lite 모델의 입력 및 출력을 지정합니다.\n", "- 단일 TensorFlow Lite 모델이 여러 진입점을 지원할 수 있습니다.\n", "\n", "서명은 세 부분으로 구성됩니다.\n", "\n", "- 입력: 서명의 입력 이름에서 입력 텐서로의 입력에 대한 매핑입니다.\n", "- 출력: 서명의 출력 이름에서 출력 텐서로의 출력 매핑을 위한 맵입니다.\n", "- 서명 키: 그래프의 진입점을 식별하는 이름입니다.\n" ] }, { "cell_type": "markdown", "id": "UaWdLA3fQDK2", "metadata": { "id": "UaWdLA3fQDK2" }, "source": [ "## 설정" ] }, { "cell_type": "code", "execution_count": 2, "id": "9j4MGqyKQEo4", "metadata": { "execution": { "iopub.execute_input": "2022-12-15T01:06:26.559565Z", "iopub.status.busy": "2022-12-15T01:06:26.559166Z", "iopub.status.idle": "2022-12-15T01:06:28.592819Z", "shell.execute_reply": "2022-12-15T01:06:28.592091Z" }, "id": "9j4MGqyKQEo4" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-15 01:06:27.563712: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-15 01:06:27.563811: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-15 01:06:27.563821: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "id": "FN2N6hPEP-Ay", "metadata": { "id": "FN2N6hPEP-Ay" }, "source": [ "## 예제 모델\n", "\n", "TensorFlow 모델로 인코딩 및 디코딩과 같은 두 가지 작업이 있다고 가정해 보겠습니다." ] }, { "cell_type": "code", "execution_count": 3, "id": "d8577c80", "metadata": { "execution": { "iopub.execute_input": "2022-12-15T01:06:28.597095Z", "iopub.status.busy": "2022-12-15T01:06:28.596652Z", "iopub.status.idle": "2022-12-15T01:06:28.603152Z", "shell.execute_reply": "2022-12-15T01:06:28.602541Z" }, "id": "d8577c80" }, "outputs": [], "source": [ "class Model(tf.Module):\n", "\n", " @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])\n", " def encode(self, x):\n", " result = tf.strings.as_string(x)\n", " return {\n", " \"encoded_result\": result\n", " }\n", "\n", " @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])\n", " def decode(self, x):\n", " result = tf.strings.to_number(x)\n", " return {\n", " \"decoded_result\": result\n", " }" ] }, { "cell_type": "markdown", "id": "9c814c6e", "metadata": { "id": "9c814c6e" }, "source": [ "서명 측면에서 위의 TensorFlow 모델은 다음과 같이 요약될 수 있습니다.\n", "\n", "- 서명\n", "\n", " - 키: 인코딩\n", " - 입력: {\"x\"}\n", " - 출력: {\"encoded_result\"}\n", "\n", "- 서명\n", "\n", " - 키: 디코딩\n", " - 입력: {\"x\"}\n", " - 출력: {\"decoded_result\"}" ] }, { "cell_type": "markdown", "id": "c4099f20", "metadata": { "id": "c4099f20" }, "source": [ "## 서명이 있는 모델 변환\n", "\n", "TensorFlow Lite 변환기 API는 위의 서명 정보를 변환된 TensorFlow Lite 모델로 가져옵니다.\n", "\n", "이 변환 기능은 TensorFlow 버전 2.7.0부터 모든 변환기 API에서 사용할 수 있습니다. 사용 예를 참조하세요.\n" ] }, { "cell_type": "markdown", "id": "Qv0WwFQkQgnO", "metadata": { "id": "Qv0WwFQkQgnO" }, "source": [ "### SavedModel에서" ] }, { "cell_type": "code", "execution_count": 4, "id": "96c8fc79", "metadata": { "execution": { "iopub.execute_input": "2022-12-15T01:06:28.607236Z", "iopub.status.busy": "2022-12-15T01:06:28.606546Z", "iopub.status.idle": "2022-12-15T01:06:32.429562Z", "shell.execute_reply": "2022-12-15T01:06:32.428802Z" }, "id": "96c8fc79" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: content/saved_models/coding/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']}}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-12-15 01:06:32.364466: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.\n", "2022-12-15 01:06:32.364500: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.\n", "2022-12-15 01:06:32.409841: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2046] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):\n", "Flex ops: FlexAsString, FlexStringToNumber\n", "Details:\n", "\ttf.AsString(tensor) -> (tensor) : {device = \"\", fill = \"\", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}\n", "\ttf.StringToNumber(tensor) -> (tensor) : {device = \"\", out_type = f32}\n", "See instructions: https://www.tensorflow.org/lite/guide/ops_select\n", "INFO: Created TensorFlow Lite delegate for select TF ops.\n", "INFO: TfLiteFlexDelegate delegate: 1 nodes delegated out of 1 nodes with 1 partitions.\n", "\n" ] } ], "source": [ "model = Model()\n", "\n", "# Save the model\n", "SAVED_MODEL_PATH = 'content/saved_models/coding'\n", "\n", "tf.saved_model.save(\n", " model, SAVED_MODEL_PATH,\n", " signatures={\n", " 'encode': model.encode.get_concrete_function(),\n", " 'decode': model.decode.get_concrete_function()\n", " })\n", "\n", "# Convert the saved model using TFLiteConverter\n", "converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)\n", "converter.target_spec.supported_ops = [\n", " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n", " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n", "]\n", "tflite_model = converter.convert()\n", "\n", "# Print the signatures from the converted model\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "signatures = interpreter.get_signature_list()\n", "print(signatures)" ] }, { "cell_type": "markdown", "id": "5baa9f17", "metadata": { "id": "5baa9f17" }, "source": [ "### Keras 모델에서" ] }, { "cell_type": "code", "execution_count": 5, "id": "71f29229", "metadata": { "execution": { "iopub.execute_input": "2022-12-15T01:06:32.434036Z", "iopub.status.busy": "2022-12-15T01:06:32.433493Z", "iopub.status.idle": "2022-12-15T01:06:33.206056Z", "shell.execute_reply": "2022-12-15T01:06:33.205256Z" }, "id": "71f29229" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp24we75_2/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'serving_default': {'inputs': ['x_input'], 'outputs': ['output']}}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-12-15 01:06:33.126947: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.\n", "2022-12-15 01:06:33.126990: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.\n" ] } ], "source": [ "# Generate a Keras model.\n", "keras_model = tf.keras.Sequential(\n", " [\n", " tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),\n", " tf.keras.layers.Dense(1, activation='relu', name='output'),\n", " ]\n", ")\n", "\n", "# Convert the keras model using TFLiteConverter.\n", "# Keras model converter API uses the default signature automatically.\n", "converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)\n", "tflite_model = converter.convert()\n", "\n", "# Print the signatures from the converted model\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "\n", "signatures = interpreter.get_signature_list()\n", "print(signatures)" ] }, { "cell_type": "markdown", "id": "e4d30f85", "metadata": { "id": "e4d30f85" }, "source": [ "### 구체적인 기능에서" ] }, { "cell_type": "code", "execution_count": 6, "id": "c9e8a742", "metadata": { "execution": { "iopub.execute_input": "2022-12-15T01:06:33.210177Z", "iopub.status.busy": "2022-12-15T01:06:33.209593Z", "iopub.status.idle": "2022-12-15T01:06:33.380518Z", "shell.execute_reply": "2022-12-15T01:06:33.379792Z" }, "id": "c9e8a742" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpl2ma5ilo/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpl2ma5ilo/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']}}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-12-15 01:06:33.323787: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.\n", "2022-12-15 01:06:33.323827: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.\n", "2022-12-15 01:06:33.361670: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2046] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):\n", "Flex ops: FlexAsString, FlexStringToNumber\n", "Details:\n", "\ttf.AsString(tensor) -> (tensor) : {device = \"\", fill = \"\", precision = -1 : i64, scientific = false, shortest = false, width = -1 : i64}\n", "\ttf.StringToNumber(tensor) -> (tensor) : {device = \"\", out_type = f32}\n", "See instructions: https://www.tensorflow.org/lite/guide/ops_select\n" ] } ], "source": [ "model = Model()\n", "\n", "# Convert the concrete functions using TFLiteConverter\n", "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n", " [model.encode.get_concrete_function(),\n", " model.decode.get_concrete_function()], model)\n", "converter.target_spec.supported_ops = [\n", " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n", " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n", "]\n", "tflite_model = converter.convert()\n", "\n", "# Print the signatures from the converted model\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "signatures = interpreter.get_signature_list()\n", "print(signatures)" ] }, { "cell_type": "markdown", "id": "b5e85934", "metadata": { "id": "b5e85934" }, "source": [ "## 서명 실행\n", "\n", "TensorFlow 추론 API는 서명 기반 실행을 지원합니다.\n", "\n", "- 서명으로 지정된 입력 및 출력 이름을 통해 입력/출력 텐서에 액세스합니다.\n", "- 서명 키로 식별되는 그래프의 각 진입점을 별도로 실행합니다.\n", "- SavedModel의 초기화 절차를 지원합니다.\n", "\n", "Java, C++ 및 Python 언어 바인딩을 현재 사용할 수 있습니다. 아래 섹션의 예를 참조하세요.\n" ] }, { "cell_type": "markdown", "id": "ZRBMFciMQmiB", "metadata": { "id": "ZRBMFciMQmiB" }, "source": [ "### Java" ] }, { "cell_type": "markdown", "id": "04c5a4fc", "metadata": { "id": "04c5a4fc" }, "source": [ "```\n", "try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {\n", " // Run encoding signature.\n", " Map<String, Object> inputs = new HashMap<>();\n", " inputs.put(\"x\", input);\n", " Map<String, Object> outputs = new HashMap<>();\n", " outputs.put(\"encoded_result\", encoded_result);\n", " interpreter.runSignature(inputs, outputs, \"encode\");\n", "\n", " // Run decoding signature.\n", " Map<String, Object> inputs = new HashMap<>();\n", " inputs.put(\"x\", encoded_result);\n", " Map<String, Object> outputs = new HashMap<>();\n", " outputs.put(\"decoded_result\", decoded_result);\n", " interpreter.runSignature(inputs, outputs, \"decode\");\n", "}\n", "```" ] }, { "cell_type": "markdown", "id": "5ba86c64", "metadata": { "id": "5ba86c64" }, "source": [ "### C++" ] }, { "cell_type": "markdown", "id": "397ad6fd", "metadata": { "id": "397ad6fd" }, "source": [ "```\n", "SignatureRunner* encode_runner =\n", " interpreter->GetSignatureRunner(\"encode\");\n", "encode_runner->ResizeInputTensor(\"x\", {100});\n", "encode_runner->AllocateTensors();\n", "\n", "TfLiteTensor* input_tensor = encode_runner->input_tensor(\"x\");\n", "float* input = input_tensor->data.f;\n", "// Fill `input`.\n", "\n", "encode_runner->Invoke();\n", "\n", "const TfLiteTensor* output_tensor = encode_runner->output_tensor(\n", " \"encoded_result\");\n", "float* output = output_tensor->data.f;\n", "// Access `output`.\n", "```" ] }, { "cell_type": "markdown", "id": "0f4c6ad4", "metadata": { "id": "0f4c6ad4" }, "source": [ "### Python" ] }, { "cell_type": "code", "execution_count": 7, "id": "ab7b1963", "metadata": { "execution": { "iopub.execute_input": "2022-12-15T01:06:33.385543Z", "iopub.status.busy": "2022-12-15T01:06:33.384838Z", "iopub.status.idle": "2022-12-15T01:06:33.402794Z", "shell.execute_reply": "2022-12-15T01:06:33.401937Z" }, "id": "ab7b1963" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Signature: {'decode': {'inputs': ['x'], 'outputs': ['decoded_result']}, 'encode': {'inputs': ['x'], 'outputs': ['encoded_result']}}\n", "Input: tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)\n", "Encoded result: {'encoded_result': array([b'1.000000', b'2.000000', b'3.000000'], dtype=object)}\n", "Decoded result: {'decoded_result': array([1., 2., 3.], dtype=float32)}\n" ] } ], "source": [ "# Load the TFLite model in TFLite Interpreter\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "\n", "# Print the signatures from the converted model\n", "signatures = interpreter.get_signature_list()\n", "print('Signature:', signatures)\n", "\n", "# encode and decode are callable with input as arguments.\n", "encode = interpreter.get_signature_runner('encode')\n", "decode = interpreter.get_signature_runner('decode')\n", "\n", "# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.\n", "input = tf.constant([1, 2, 3], dtype=tf.float32)\n", "print('Input:', input)\n", "encoded = encode(x=input)\n", "print('Encoded result:', encoded)\n", "decoded = decode(x=encoded['encoded_result'])\n", "print('Decoded result:', decoded)" ] }, { "cell_type": "markdown", "id": "81b42e5b", "metadata": { "id": "81b42e5b" }, "source": [ "## 알려진 제한 사항\n", "\n", "- TFLite 인터프리터는 스레드 안전을 보장하지 않으므로 동일한 인터프리터의 서명 실행자는 동시에 실행되지 않습니다.\n", "- C/iOS/Swift에 대한 지원은 아직 제공되지 않습니다.\n" ] }, { "cell_type": "markdown", "id": "3032Iof6QqmJ", "metadata": { "id": "3032Iof6QqmJ" }, "source": [ "## 업데이트\n", "\n", "- 버전 2.7\n", " - 다중 서명 기능이 구현됩니다.\n", " - 버전 2의 모든 변환기 API는 서명이 지원되는 TensorFlow Lite 모델을 생성합니다.\n", "- 버전 2.5\n", " - 서명 기능은 `from_saved_model` 변환기 API를 통해 사용할 수 있습니다." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "signatures.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }