{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2021-02-13T03:02:53.247295Z", "iopub.status.busy": "2021-02-13T03:02:53.246506Z", "iopub.status.idle": "2021-02-13T03:02:53.249004Z", "shell.execute_reply": "2021-02-13T03:02:53.248485Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Quantization aware training in Keras example" ] }, { "cell_type": "markdown", "metadata": { "id": "Bjmi3qZeu_xk" }, "source": [ "## Overview\n", "\n", "Welcome to an end-to-end example for *quantization aware training*.\n", "\n", "### Other pages\n", "\n", "For an introduction to what quantization aware training is and to determine whether you should use it (including what's supported), see the [overview page](https://www.tensorflow.org/model_optimization/guide/quantization/training.md).\n", "\n", "To quickly find the APIs you need for your use case (beyond fully quantizing a model with 8 bits), see the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md).\n", "\n", "### Summary\n", "\n", "In this tutorial, you will:\n", "\n", "1. Train a `tf.keras` model for MNIST from scratch.\n", "2. Fine-tune the model by applying the quantization aware training API, check the accuracy, and export a quantization aware model.\n", "3. Use the model to create an actually quantized model for the TFLite backend.\n", "4. Confirm that the accuracy persists in TFLite and in a model that is 4x smaller. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models)." ] }, { "cell_type": "markdown", "metadata": { "id": "yEAZYXvZU_XG" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:02:53.260043Z", "iopub.status.busy": "2021-02-13T03:02:53.259328Z", "iopub.status.idle": "2021-02-13T03:03:19.366045Z", "shell.execute_reply": "2021-02-13T03:03:19.365425Z" }, "id": "zN4yVFK5-0Bf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found existing installation: tensorflow 2.4.1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Uninstalling tensorflow-2.4.1:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled tensorflow-2.4.1\r\n" ] } ], "source": [ "! pip uninstall -y tensorflow\n", "! pip install -q tf-nightly\n", "! 
pip install -q tensorflow-model-optimization\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:19.371559Z", "iopub.status.busy": "2021-02-13T03:03:19.370847Z", "iopub.status.idle": "2021-02-13T03:03:25.956453Z", "shell.execute_reply": "2021-02-13T03:03:25.956935Z" }, "id": "yJwIonXEVJo6" }, "outputs": [], "source": [ "import tempfile\n", "import os\n", "\n", "import tensorflow as tf\n", "\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": { "id": "psViY5PRDurp" }, "source": [ "## Train a model for MNIST without quantization aware training" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:25.964690Z", "iopub.status.busy": "2021-02-13T03:03:25.963989Z", "iopub.status.idle": "2021-02-13T03:03:34.754140Z", "shell.execute_reply": "2021-02-13T03:03:34.754563Z" }, "id": "pbY-KGMPvbW9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1688/1688 [==============================] - 6s 2ms/step - loss: 0.5304 - accuracy: 0.8498 - val_loss: 0.1238 - val_accuracy: 0.9680\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load MNIST dataset\n", "mnist = keras.datasets.mnist\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", "\n", "# Normalize the input image so that each pixel value is between 0 and 1.\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0\n", "\n", "# Define the model architecture.\n", "model = keras.Sequential([\n", "  keras.layers.InputLayer(input_shape=(28, 28)),\n", "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n", "  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),\n", "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", "  keras.layers.Flatten(),\n", "  keras.layers.Dense(10)\n", "])\n", "\n", "# Train the digit classification model\n", "model.compile(optimizer='adam',\n", "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "              metrics=['accuracy'])\n", "\n", "model.fit(\n", "  train_images,\n", "  train_labels,\n", "  epochs=1,\n", "  validation_split=0.1,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "K8747K9OE72P" }, "source": [ "## Clone and fine-tune the pre-trained model with quantization aware training\n" ] }, { "cell_type": "markdown", "metadata": { "id": "F19k7ExXF_h2" }, "source": [ "### Define the model" ] }, { "cell_type": "markdown", "metadata": { "id": "JsZROpNYMWQ0" }, "source": [ "Apply quantization aware training to the whole model and confirm this in the model summary: every layer name should now be prefixed with \"quant\".\n", "\n", "Note that the resulting model is quantization aware but not quantized (for example, the weights are float32 instead of int8). The next section shows how to create a quantized model from the quantization aware one.\n", "\n", "The [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md) shows how to quantize only some of the layers to improve model accuracy." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:34.760223Z", "iopub.status.busy": "2021-02-13T03:03:34.759531Z", "iopub.status.idle": "2021-02-13T03:03:36.232377Z", "shell.execute_reply": "2021-02-13T03:03:36.231883Z" }, "id": "oq6blGjgFDCW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "quantize_layer (QuantizeLaye (None, 28, 28) 3 \n", "_________________________________________________________________\n", "quant_reshape (QuantizeWrapp (None, 28, 28, 1) 1 \n", "_________________________________________________________________\n", "quant_conv2d (QuantizeWrappe (None, 26, 26, 12) 147 \n", "_________________________________________________________________\n", "quant_max_pooling2d (Quantiz (None, 13, 13, 12) 1 \n", "_________________________________________________________________\n", "quant_flatten (QuantizeWrapp (None, 2028) 1 \n", "_________________________________________________________________\n", "quant_dense (QuantizeWrapper (None, 10) 20295 \n", "=================================================================\n", "Total params: 20,448\n", "Trainable params: 20,410\n", "Non-trainable params: 38\n", "_________________________________________________________________\n" ] } ], "source": [ "import tensorflow_model_optimization as tfmot\n", "\n", "quantize_model = tfmot.quantization.keras.quantize_model\n", "\n", "# q_aware stands for quantization aware.\n", "q_aware_model = quantize_model(model)\n", "\n", "# `quantize_model` requires a 
recompile.\n", "q_aware_model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", "\n", "q_aware_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "uDr2ijwpGCI-" }, "source": [ "### モデルをベースラインに対してトレーニングおよび評価する" ] }, { "cell_type": "markdown", "metadata": { "id": "XUBEn94hXYB1" }, "source": [ "モデルを 1 エポックだけトレーニングした後のファインチューニングを実証するために、トレーニングデータのサブセットに対して、量子化認識トレーニングを使用してファインチューニングを行います。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:36.237791Z", "iopub.status.busy": "2021-02-13T03:03:36.237115Z", "iopub.status.idle": "2021-02-13T03:03:37.030206Z", "shell.execute_reply": "2021-02-13T03:03:37.029654Z" }, "id": "_PHDGJryE31X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/2 [==============>...............] - ETA: 0s - loss: 0.1671 - accuracy: 0.9520" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "2/2 [==============================] - 1s 210ms/step - loss: 0.1565 - accuracy: 0.9544 - val_loss: 0.1604 - val_accuracy: 0.9700\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_images_subset = train_images[0:1000] # out of 60000\n", "train_labels_subset = train_labels[0:1000]\n", "\n", "q_aware_model.fit(train_images_subset, train_labels_subset,\n", " batch_size=500, epochs=1, validation_split=0.1)" ] }, { "cell_type": "markdown", "metadata": { "id": "-byC2lYlMkfN" }, "source": [ "この例では、ベースラインと比較し、量子化認識トレーニング後のテスト精度の損失は、最小限あるいはゼロです。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:37.035495Z", "iopub.status.busy": 
"2021-02-13T03:03:37.034806Z", "iopub.status.idle": "2021-02-13T03:03:38.107098Z", "shell.execute_reply": "2021-02-13T03:03:38.106528Z" }, "id": "6bMFTKSSHyyZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Baseline test accuracy: 0.9642000198364258\n", "Quant test accuracy: 0.9656000137329102\n" ] } ], "source": [ "_, baseline_model_accuracy = model.evaluate(\n", " test_images, test_labels, verbose=0)\n", "\n", "_, q_aware_model_accuracy = q_aware_model.evaluate(\n", " test_images, test_labels, verbose=0)\n", "\n", "print('Baseline test accuracy:', baseline_model_accuracy)\n", "print('Quant test accuracy:', q_aware_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "2IepmUPSITn6" }, "source": [ "## Create a quantized model for the TFLite backend" ] }, { "cell_type": "markdown", "metadata": { "id": "1FgNP4rbOLH8" }, "source": [ "After this step, you have an actually quantized model with int8 weights and uint8 activations." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:38.114343Z", "iopub.status.busy": "2021-02-13T03:03:38.111524Z", "iopub.status.idle": "2021-02-13T03:03:39.719196Z", "shell.execute_reply": "2021-02-13T03:03:39.718509Z" }, "id": "w7fztWsAOHTz" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as reshape_layer_call_fn, reshape_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, max_pooling2d_layer_call_fn while saving (showing 5 of 25). 
These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpupf9tsyx/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpupf9tsyx/assets\n" ] } ], "source": [ "converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "\n", "quantized_tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "BEYsyYVqNgeY" }, "source": [ "## See the persistence of accuracy from TF to TFLite" ] }, { "cell_type": "markdown", "metadata": { "id": "saadXD4JQsBK" }, "source": [ "Define a helper function to evaluate the TFLite model on the test dataset." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:39.727587Z", "iopub.status.busy": "2021-02-13T03:03:39.726869Z", "iopub.status.idle": "2021-02-13T03:03:39.728677Z", "shell.execute_reply": "2021-02-13T03:03:39.729093Z" }, "id": "b8yBouuGNqls" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def evaluate_model(interpreter):\n", "  input_index = interpreter.get_input_details()[0][\"index\"]\n", "  output_index = interpreter.get_output_details()[0][\"index\"]\n", "\n", "  # Run predictions on every image in the \"test\" dataset.\n", "  prediction_digits = []\n", "  for i, test_image in enumerate(test_images):\n", "    if i % 1000 == 0:\n", "      print('Evaluated on {n} results so far.'.format(n=i))\n", "    # Pre-processing: add batch dimension and convert to float32 to match with\n", "    # the model's input data format.\n", "    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n", "    interpreter.set_tensor(input_index, test_image)\n", "\n", "    # Run inference.\n", "    interpreter.invoke()\n", "\n", "    # Post-processing: remove batch dimension and find the digit with highest\n", "    # probability.\n", "    output = interpreter.tensor(output_index)\n", "    digit = 
np.argmax(output()[0])\n", "    prediction_digits.append(digit)\n", "\n", "  print('\\n')\n", "  # Compare prediction results with ground truth labels to calculate accuracy.\n", "  prediction_digits = np.array(prediction_digits)\n", "  accuracy = (prediction_digits == test_labels).mean()\n", "  return accuracy" ] }, { "cell_type": "markdown", "metadata": { "id": "TuEFS4CIQvUw" }, "source": [ "Evaluate the quantized model and confirm that the accuracy from TensorFlow persists in the TFLite backend." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:39.734260Z", "iopub.status.busy": "2021-02-13T03:03:39.733611Z", "iopub.status.idle": "2021-02-13T03:03:43.556998Z", "shell.execute_reply": "2021-02-13T03:03:43.556484Z" }, "id": "VqQTyqz4NsWd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 0 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 1000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 2000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 3000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 4000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 5000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 6000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 7000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 8000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 9000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Quant TFLite test_accuracy: 0.9656\n", "Quant TF test accuracy: 0.9656000137329102\n" ] } ], "source": [ "interpreter = 
tf.lite.Interpreter(model_content=quantized_tflite_model)\n", "interpreter.allocate_tensors()\n", "\n", "test_accuracy = evaluate_model(interpreter)\n", "\n", "print('Quant TFLite test_accuracy:', test_accuracy)\n", "print('Quant TF test accuracy:', q_aware_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "z8D7WnFF5DZR" }, "source": [ "## See the roughly 4x smaller model from quantization" ] }, { "cell_type": "markdown", "metadata": { "id": "I1c2IecBRCdQ" }, "source": [ "Create a float TFLite model and compare file sizes to confirm that the quantized TFLite model is roughly 4x smaller." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2021-02-13T03:03:43.564325Z", "iopub.status.busy": "2021-02-13T03:03:43.563679Z", "iopub.status.idle": "2021-02-13T03:03:44.138100Z", "shell.execute_reply": "2021-02-13T03:03:44.137590Z" }, "id": "jy_Lgfh8VkyX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpx96l01vr/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpx96l01vr/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Float model in Mb: 0.08058547973632812\n", "Quantized model in Mb: 0.0234527587890625\n" ] } ], "source": [ "# Create float TFLite model.\n", "float_converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "float_tflite_model = float_converter.convert()\n", "\n", "# Measure sizes of models.\n", "_, float_file = tempfile.mkstemp('.tflite')\n", "_, quant_file = tempfile.mkstemp('.tflite')\n", "\n", "with open(quant_file, 'wb') as f:\n", " f.write(quantized_tflite_model)\n", "\n", "with open(float_file, 'wb') as f:\n", " f.write(float_tflite_model)\n", "\n", "print(\"Float model in Mb:\", os.path.getsize(float_file) / float(2**20))\n", "print(\"Quantized model in Mb:\", os.path.getsize(quant_file) / float(2**20))" ] }, { "cell_type": "markdown", "metadata": { "id": "0O5xuci-SonI" }, "source": [ "## Conclusion" ] }, { 
"cell_type": "markdown", "metadata": { "id": "O2I7xmyMW5QY" }, "source": [ "In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then create quantized models for the TFLite backend.\n", "\n", "For a model for MNIST, you saw roughly 4x model size compression with minimal difference in accuracy. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).\n", "\n", "We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "training_example.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 0 }