{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "_DDaAex5Q7u-" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "W1dWWdNHQ9L0" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "6Y8E0lw5eYWm" }, "source": [ "# 训练后整数量化" ] }, { "cell_type": "markdown", "metadata": { "id": "CIGrZZPTZVeO" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 查看在 Google Colab 中运行 在 GitHub 上查看源代码下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "BTC1rDAuei_1" }, "source": [ "## 概述\n", "\n", "整数量化是一种优化策略,可将 32 位浮点数(如权重和激活输出)转换为 8 位定点数。这样可以缩减模型大小并加快推理速度,这对低功耗设备(如[微控制器](https://tensorflow.google.cn/lite/microcontrollers))很有价值。仅支持整数的加速器(如 [Edge TPU](https://coral.ai/))也需要使用此数据格式。\n", "\n", "在本教程中,您将从头开始训练一个 MNIST 模型、将其转换为 TensorFlow Lite 文件,并使用[训练后量化](https://tensorflow.google.cn/lite/performance/post_training_quantization)对其进行量化。最后,您将检查转换后模型的准确率并将其与原始浮点模型进行比较。\n", "\n", "实际上,对模型进行量化的程度有几种选项。在本教程中,您将执行“全整数量化”,它会将所有权重和激活输出转换为 8 位整数数据,而其他策略可能会将部分数据保留为浮点。\n", "\n", "要详细了解各种量化策略,请阅读 [TensorFlow Lite 模型优化](https://tensorflow.google.cn/lite/performance/model_optimization)。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "dDqqUIZjZjac" }, "source": [ "## 设置" ] }, { "cell_type": "markdown", "metadata": { "id": "I0nR5AMEWq0H" }, "source": [ "为了量化输入和输出张量,我们需要使用 TensorFlow 2.3 中新添加的 API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WsN6s5L1ieNl" }, "outputs": [], "source": [ "import logging\n", "logging.getLogger(\"tensorflow\").setLevel(logging.DEBUG)\n", "\n", "import tensorflow as tf\n", "import numpy as np\n", "print(\"TensorFlow version: \", tf.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "2XsEP17Zelz9" }, "source": [ "## 生成 TensorFlow 模型" ] }, { "cell_type": "markdown", "metadata": { "id": "5NMaNZQCkW9X" }, "source": [ "我们将构建一个简单的模型来对 [MNIST 数据集](https://tensorflow.google.cn/datasets/catalog/mnist)中的数字进行分类。\n", "\n", "此训练不会花很长时间,因为只对模型进行 5 个周期的训练,训练到约 98% 的准确率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eMsw_6HujaqM" }, "outputs": [], "source": [ "# Load MNIST dataset\n", "mnist = tf.keras.datasets.mnist\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", "\n", "# Normalize the input image so that each pixel value is between 0 to 1.\n", "train_images = train_images.astype(np.float32) / 255.0\n", "test_images = test_images.astype(np.float32) / 255.0\n", "\n", "# Define the model architecture\n", "model = tf.keras.Sequential([\n", " tf.keras.layers.InputLayer(input_shape=(28, 28)),\n", " tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),\n", " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(10)\n", "])\n", "\n", "# Train the digit classification model\n", "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(\n", " from_logits=True),\n", " metrics=['accuracy'])\n", "model.fit(\n", " train_images,\n", " train_labels,\n", " epochs=5,\n", " validation_data=(test_images, test_labels)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "KuTEoGFYd8aM" }, "source": [ "## 转换为 TensorFlow Lite 模型" ] }, { "cell_type": "markdown", "metadata": { "id": "xl8_fzVAZwOh" }, "source": [ "现在,您可以使用TensorFlow Lite [Converter](https://tensorflow.google.cn/lite/models/convert) 将训练后的模型转换为 TensorFlow Lite 格式,并应用不同程度的量化。\n", "\n", "请注意,某些版本的量化会将部分数据保留为浮点格式。因此,以下各个部分将以量化程度不断增加的顺序展示每个选项,直到获得完全由 int8 或 uint8 数据组成的模型。(请注意,我们在每个部分中重复了一些代码,使您能够看到每个选项的全部量化步骤。)\n", "\n", "首先,下面是一个没有量化的转换后模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_i8B2nDZmAgQ" }, "outputs": [], "source": [ "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "\n", "tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "7BONhYtYocQY" }, "source": [ "它现在是一个 TensorFlow Lite 模型,但所有参数数据仍使用 32 位浮点值。" ] }, { "cell_type": "markdown", "metadata": { "id": "jPYZwgZTwJMT" }, "source": [ "### 使用动态范围量化进行转换\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Hjvq1vpJd4U_" }, "source": [ "现在,我们启用默认的 `optimizations` 标记来量化所有固定参数(例如权重):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HEZ6ET1AHAS3" }, "outputs": [], "source": [ "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "\n", "tflite_model_quant = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "o5wuE-RcdX_3" }, "source": [ "现在,进行了权重量化的模型要略小一些,但其他变量数据仍为浮点格式。" ] }, { "cell_type": "markdown", "metadata": { "id": "UgKDdnHQEhpb" }, "source": [ "### 使用浮点回退量化进行转换" ] }, { "cell_type": "markdown", "metadata": { "id": "rTe8avZJHMDO" }, "source": [ "要量化可变数据(例如模型输入/输出和层之间的中间体),您需要提供 [`RepresentativeDataset`](https://tensorflow.google.cn/api_docs/python/tf/lite/RepresentativeDataset)。这是一个生成器函数,它提供一组足够大的输入数据来代表典型值。转换器可以通过该函数估算所有可变数据的动态范围。(相比训练或评估数据集,此数据集不必唯一。)为了支持多个输入,每个代表性数据点都是一个列表,并且列表中的元素会根据其索引被馈送到模型。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FiwiWU3gHdkW" }, "outputs": [], "source": [ "def representative_data_gen():\n", " for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):\n", " # Model has only one input so each data point has one element.\n", " yield [input_value]\n", "\n", "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "converter.representative_dataset = representative_data_gen\n", "\n", "tflite_model_quant = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "_GC3HFlptf7x" }, "source": [ "现在,所有权重和可变数据都已量化,并且与原始 TensorFlow Lite 模型相比,该模型要小得多。\n", "\n", "但是,为了与传统上使用浮点模型输入和输出张量的应用保持兼容,TensorFlow Lite 转换器将模型的输入和输出张量保留为浮点:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "id1OEKFELQwp" }, "outputs": [], "source": [ "interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)\n", "input_type = interpreter.get_input_details()[0]['dtype']\n", "print('input: ', input_type)\n", "output_type = interpreter.get_output_details()[0]['dtype']\n", "print('output: ', output_type)" ] }, { "cell_type": "markdown", "metadata": { "id": "RACBJuj2XO8x" }, "source": [ "这通常对兼容性有利,但它无法兼容执行全整数运算的设备(如 Edge TPU)。\n", "\n", "此外,如果 TensorFlow Lite 不包括某个运算的量化实现,则上述过程可能会将该运算保留为浮点格式。您仍能通过此策略完成转换,并得到一个更小、更高效的模型,但它还是不兼容仅支持整数的硬件。(此 MNIST 模型中的所有算子都有量化的实现。)\n", "\n", "因此,为了确保端到端全整数模型,您还需要几个参数…" ] }, { "cell_type": "markdown", "metadata": { "id": "FQgTqbvPvxGJ" }, "source": [ "### 使用仅整数量化进行转换" ] }, { "cell_type": "markdown", "metadata": { "id": "mwR9keYAwArA" }, "source": [ "为了量化输入和输出张量,并让转换器在遇到无法量化的运算时引发错误,使用一些附加参数再次转换模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kzjEjcDs3BHa" }, "outputs": [], "source": [ "def representative_data_gen():\n", " for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):\n", " yield [input_value]\n", "\n", "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "converter.representative_dataset = representative_data_gen\n", "# Ensure that if any ops can't be quantized, the converter throws an error\n", "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", "# Set the input and output tensors to uint8 (APIs added in r2.3)\n", "converter.inference_input_type = tf.uint8\n", "converter.inference_output_type = tf.uint8\n", "\n", "tflite_model_quant = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "wYd6NxD03yjB" }, "source": [ "内部量化与上文相同,但您可以看到输入和输出张量现在是整数格式:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PaNkOS-twz4k" }, "outputs": [], "source": [ "interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)\n", "input_type = interpreter.get_input_details()[0]['dtype']\n", "print('input: ', input_type)\n", "output_type = interpreter.get_output_details()[0]['dtype']\n", "print('output: ', output_type)" ] }, { "cell_type": "markdown", "metadata": { "id": "TO17AP84wzBb" }, "source": [ "现在,您有了一个整数量化模型,该模型使用整数数据作为模型的输入和输出张量,因此它兼容仅支持整数的硬件(如 [Edge TPU](https://coral.ai))。" ] }, { "cell_type": "markdown", "metadata": { "id": "sse224YJ4KMm" }, "source": [ "### 将模型另存为文件" ] }, { "cell_type": "markdown", "metadata": { "id": "4_9nZ4nv4b9P" }, "source": [ "您需要 `.tflite` 文件才能在其他设备上部署模型。因此,我们将转换的模型保存为文件,然后在下面运行推断时加载它们。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BEY59dC14uRv" }, "outputs": [], "source": [ "import pathlib\n", "\n", "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n", "tflite_models_dir.mkdir(exist_ok=True, parents=True)\n", "\n", "# Save the unquantized/float model:\n", "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n", "tflite_model_file.write_bytes(tflite_model)\n", "# Save the quantized model:\n", "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n", "tflite_model_quant_file.write_bytes(tflite_model_quant)" ] }, { "cell_type": "markdown", "metadata": { "id": "9t9yaTeF9fyM" }, "source": [ "## 运行 TensorFlow Lite 模型" ] }, { "cell_type": "markdown", "metadata": { "id": "L8lQHMp_asCq" }, "source": [ "现在,我们使用 TensorFlow Lite [`Interpreter`](https://tensorflow.google.cn/api_docs/python/tf/lite/Interpreter) 运行推断来比较模型的准确率。\n", "\n", "首先,我们需要一个函数,该函数使用给定的模型和图像运行推断,然后返回预测值:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X092SbeWfd1A" }, "outputs": [], "source": [ "# Helper function to run inference on a TFLite model\n", "def run_tflite_model(tflite_file, test_image_indices):\n", " global test_images\n", "\n", " # Initialize the interpreter\n", " interpreter = tf.lite.Interpreter(model_path=str(tflite_file))\n", " interpreter.allocate_tensors()\n", "\n", " input_details = interpreter.get_input_details()[0]\n", " output_details = interpreter.get_output_details()[0]\n", "\n", " predictions = np.zeros((len(test_image_indices),), dtype=int)\n", " for i, test_image_index in enumerate(test_image_indices):\n", " test_image = test_images[test_image_index]\n", " test_label = test_labels[test_image_index]\n", "\n", " # Check if the input type is quantized, then rescale input data to uint8\n", " if input_details['dtype'] == np.uint8:\n", " input_scale, input_zero_point = input_details[\"quantization\"]\n", " test_image = test_image / input_scale + input_zero_point\n", "\n", " test_image = np.expand_dims(test_image, axis=0).astype(input_details[\"dtype\"])\n", " interpreter.set_tensor(input_details[\"index\"], test_image)\n", " interpreter.invoke()\n", " output = interpreter.get_tensor(output_details[\"index\"])[0]\n", "\n", " predictions[i] = output.argmax()\n", "\n", " return predictions\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2opUt_JTdyEu" }, "source": [ "### 在单个图像上测试模型\n" ] }, { "cell_type": "markdown", "metadata": { "id": "QpPpFPaz7eEM" }, "source": [ "现在,我们来比较一下浮点模型和量化模型的性能:\n", "\n", "- `tflite_model_file` 是使用浮点数据的原始 TensorFlow Lite 模型。\n", "- `tflite_model_quant_file` 是我们使用全整数量化转换的上一个模型(它使用 uint8 数据作为输入和输出)。\n", "\n", "我们来创建另一个函数打印预测值:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zR2cHRUcUZ6e" }, "outputs": [], "source": [ "import matplotlib.pylab as plt\n", "\n", "# Change this to test a different image\n", "test_image_index = 1\n", "\n", "## Helper function to test the models on one image\n", "def test_model(tflite_file, test_image_index, model_type):\n", " global test_labels\n", "\n", " predictions = run_tflite_model(tflite_file, [test_image_index])\n", "\n", " plt.imshow(test_images[test_image_index])\n", " template = model_type + \" Model \\n True:{true}, Predicted:{predict}\"\n", " _ = plt.title(template.format(true= str(test_labels[test_image_index]), predict=str(predictions[0])))\n", " plt.grid(False)" ] }, { "cell_type": "markdown", "metadata": { "id": "A5OTJ_6Vcslt" }, "source": [ "现在测试浮点模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iTK0x980coto" }, "outputs": [], "source": [ "test_model(tflite_model_file, test_image_index, model_type=\"Float\")" ] }, { "cell_type": "markdown", "metadata": { "id": "o3N6-UGl1dfE" }, "source": [ "然后测试量化模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rc1i9umMcp0t" }, "outputs": [], "source": [ "test_model(tflite_model_quant_file, test_image_index, model_type=\"Quantized\")" ] }, { "cell_type": "markdown", "metadata": { "id": "LwN7uIdCd8Gw" }, "source": [ "### 在所有图像上评估模型" ] }, { "cell_type": "markdown", "metadata": { "id": "RFKOD4DG8XmU" }, "source": [ "现在,我们使用在本教程开始时加载的所有测试图像来运行两个模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "05aeAuWjvjPx" }, "outputs": [], "source": [ "# Helper function to evaluate a TFLite model on all images\n", "def evaluate_model(tflite_file, model_type):\n", " global test_images\n", " global test_labels\n", "\n", " test_image_indices = range(test_images.shape[0])\n", " predictions = run_tflite_model(tflite_file, test_image_indices)\n", "\n", " accuracy = (np.sum(test_labels== predictions) * 100) / len(test_images)\n", "\n", " print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (\n", " model_type, accuracy, len(test_images)))" ] }, { "cell_type": "markdown", "metadata": { "id": "xnFilQpBuMh5" }, "source": [ "评估浮点模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T5mWkSbMcU5z" }, "outputs": [], "source": [ "evaluate_model(tflite_model_file, model_type=\"Float\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Km3cY9ry8ZlG" }, "source": [ "评估量化模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-9cnwiPp6EGm" }, "outputs": [], "source": [ "evaluate_model(tflite_model_quant_file, model_type=\"Quantized\")" ] }, { "cell_type": "markdown", "metadata": { "id": "L7lfxkor8pgv" }, "source": [ "现在您有了一个整数量化模型,该模型的准确率与浮点模型相比几乎没有差别。\n", "\n", "要详细了解其他量化策略,请阅读 [TensorFlow Lite 模型优化](https://tensorflow.google.cn/lite/performance/model_optimization)。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "post_training_integer_quant.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }