{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2021-01-15T02:15:00.635498Z", "iopub.status.busy": "2021-01-15T02:15:00.634286Z", "iopub.status.idle": "2021-01-15T02:15:00.637098Z", "shell.execute_reply": "2021-01-15T02:15:00.636582Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Keras 예제의 양자화 인식 훈련" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org에서 보기Run in Google ColabGitHub에서 소스 보기노트북 다운로드하기
" ] }, { "cell_type": "markdown", "metadata": { "id": "Bjmi3qZeu_xk" }, "source": [ "## 개요\n", "\n", "*양자화 인식 훈련*의 엔드 투 엔드 예제를 시작합니다.\n", "\n", "### 기타 페이지\n", "\n", "양자화 인식 훈련이 무엇인지에 대한 소개와 이를 사용해야 하는지에 대한 결정(지원되는 내용 포함)은 [개요 페이지](https://www.tensorflow.org/model_optimization/guide/quantization/training.md)를 참조하세요.\n", "\n", "사용 사례에 필요한 API를 빠르게 찾으려면(8bit로 모델을 완전히 양자화하는 것 이상), [종합 가이드](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md)를 참조하세요.\n", "\n", "### 요약\n", "\n", "이 튜토리얼에서는 다음을 수행합니다.\n", "\n", "1. MNIST용 `tf.keras` 모델을 처음부터 훈련합니다.\n", "2. 양자화 인식 학습 API를 적용하여 모델을 미세 조정하고, 정확성을 확인하고, 양자화 인식 모델을 내보냅니다.\n", "3. 모델을 사용하여 TFLite 백엔드에 대해 실제로 양자화된 모델을 만듭니다.\n", "4. TFLite와 4배 더 작아진 모델에서 정확성의 지속성을 확인합니다. 모바일에서의 지연 시간 이점을 확인하려면, [TFLite 앱 리포지토리](https://www.tensorflow.org/lite/models)에서 TFLite 예제를 사용해 보세요." ] }, { "cell_type": "markdown", "metadata": { "id": "yEAZYXvZU_XG" }, "source": [ "## 설정" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:00.648490Z", "iopub.status.busy": "2021-01-15T02:15:00.647523Z", "iopub.status.idle": "2021-01-15T02:15:26.184945Z", "shell.execute_reply": "2021-01-15T02:15:26.185428Z" }, "id": "zN4yVFK5-0Bf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found existing installation: tensorflow 2.4.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Uninstalling tensorflow-2.4.0:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled tensorflow-2.4.0\r\n" ] } ], "source": [ "! pip uninstall -y tensorflow\n", "! pip install -q tf-nightly\n", "! pip install -q tensorflow-model-optimization\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:26.190720Z", "iopub.status.busy": "2021-01-15T02:15:26.189980Z", "iopub.status.idle": "2021-01-15T02:15:32.769432Z", "shell.execute_reply": "2021-01-15T02:15:32.768714Z" }, "id": "yJwIonXEVJo6" }, "outputs": [], "source": [ "import tempfile\n", "import os\n", "\n", "import tensorflow as tf\n", "\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": { "id": "psViY5PRDurp" }, "source": [ "## 양자화 인식 훈련 없이 MNIST 모델 훈련하기" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:32.779347Z", "iopub.status.busy": "2021-01-15T02:15:32.778633Z", "iopub.status.idle": "2021-01-15T02:15:41.981279Z", "shell.execute_reply": "2021-01-15T02:15:41.981763Z" }, "id": "pbY-KGMPvbW9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/1688 [..............................] - ETA: 1:01:25 - loss: 2.2936 - accuracy: 0.0625" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 22/1688 [..............................] - ETA: 4s - loss: 2.1873 - accuracy: 0.2267 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 45/1688 [..............................] - ETA: 3s - loss: 2.0397 - accuracy: 0.3321" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 68/1688 [>.............................] - ETA: 3s - loss: 1.8877 - accuracy: 0.4111" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 91/1688 [>.............................] - ETA: 3s - loss: 1.7481 - accuracy: 0.4695" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 114/1688 [=>............................] - ETA: 3s - loss: 1.6303 - accuracy: 0.5135" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 138/1688 [=>............................] - ETA: 3s - loss: 1.5270 - accuracy: 0.5492" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 162/1688 [=>............................] - ETA: 3s - loss: 1.4395 - accuracy: 0.5782" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 185/1688 [==>...........................] - ETA: 3s - loss: 1.3690 - accuracy: 0.6011" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 208/1688 [==>...........................] - ETA: 3s - loss: 1.3080 - accuracy: 0.6204" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 231/1688 [===>..........................] - ETA: 3s - loss: 1.2543 - accuracy: 0.6371" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 255/1688 [===>..........................] - ETA: 3s - loss: 1.2049 - accuracy: 0.6523" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 279/1688 [===>..........................] - ETA: 3s - loss: 1.1607 - accuracy: 0.6657" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 302/1688 [====>.........................] - ETA: 3s - loss: 1.1229 - accuracy: 0.6771" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 326/1688 [====>.........................] - ETA: 2s - loss: 1.0874 - accuracy: 0.6876" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 349/1688 [=====>........................] - ETA: 2s - loss: 1.0565 - accuracy: 0.6968" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 372/1688 [=====>........................] - ETA: 2s - loss: 1.0280 - accuracy: 0.7053" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 394/1688 [======>.......................] - ETA: 2s - loss: 1.0030 - accuracy: 0.7127" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 418/1688 [======>.......................] - ETA: 2s - loss: 0.9779 - accuracy: 0.7201" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 441/1688 [======>.......................] - ETA: 2s - loss: 0.9557 - accuracy: 0.7268" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 465/1688 [=======>......................] - ETA: 2s - loss: 0.9341 - accuracy: 0.7331" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 489/1688 [=======>......................] - ETA: 2s - loss: 0.9140 - accuracy: 0.7391" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 513/1688 [========>.....................] - ETA: 2s - loss: 0.8951 - accuracy: 0.7446" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 536/1688 [========>.....................] - ETA: 2s - loss: 0.8782 - accuracy: 0.7496" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 559/1688 [========>.....................] - ETA: 2s - loss: 0.8622 - accuracy: 0.7543" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 583/1688 [=========>....................] - ETA: 2s - loss: 0.8465 - accuracy: 0.7589" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 607/1688 [=========>....................] - ETA: 2s - loss: 0.8316 - accuracy: 0.7632" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 630/1688 [==========>...................] - ETA: 2s - loss: 0.8182 - accuracy: 0.7671" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 654/1688 [==========>...................] - ETA: 2s - loss: 0.8049 - accuracy: 0.7709" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 677/1688 [===========>..................] - ETA: 2s - loss: 0.7927 - accuracy: 0.7744" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 700/1688 [===========>..................] - ETA: 2s - loss: 0.7812 - accuracy: 0.7778" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 723/1688 [===========>..................] - ETA: 2s - loss: 0.7701 - accuracy: 0.7809" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 746/1688 [============>.................] - ETA: 2s - loss: 0.7595 - accuracy: 0.7840" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 769/1688 [============>.................] - ETA: 2s - loss: 0.7494 - accuracy: 0.7869" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 792/1688 [=============>................] - ETA: 1s - loss: 0.7398 - accuracy: 0.7896" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 815/1688 [=============>................] - ETA: 1s - loss: 0.7304 - accuracy: 0.7923" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 839/1688 [=============>................] - ETA: 1s - loss: 0.7211 - accuracy: 0.7950" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 862/1688 [==============>...............] - ETA: 1s - loss: 0.7125 - accuracy: 0.7974" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 885/1688 [==============>...............] - ETA: 1s - loss: 0.7042 - accuracy: 0.7998" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 908/1688 [===============>..............] - ETA: 1s - loss: 0.6962 - accuracy: 0.8021" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 931/1688 [===============>..............] - ETA: 1s - loss: 0.6886 - accuracy: 0.8043" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 954/1688 [===============>..............] - ETA: 1s - loss: 0.6811 - accuracy: 0.8064" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 978/1688 [================>.............] - ETA: 1s - loss: 0.6736 - accuracy: 0.8085" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1001/1688 [================>.............] - ETA: 1s - loss: 0.6667 - accuracy: 0.8105" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1025/1688 [=================>............] - ETA: 1s - loss: 0.6597 - accuracy: 0.8125" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1048/1688 [=================>............] - ETA: 1s - loss: 0.6532 - accuracy: 0.8143" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1072/1688 [==================>...........] - ETA: 1s - loss: 0.6467 - accuracy: 0.8162" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1097/1688 [==================>...........] - ETA: 1s - loss: 0.6401 - accuracy: 0.8181" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1121/1688 [==================>...........] - ETA: 1s - loss: 0.6339 - accuracy: 0.8198" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1143/1688 [===================>..........] - ETA: 1s - loss: 0.6284 - accuracy: 0.8214" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1166/1688 [===================>..........] - ETA: 1s - loss: 0.6229 - accuracy: 0.8230" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1189/1688 [====================>.........] - ETA: 1s - loss: 0.6175 - accuracy: 0.8245" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1212/1688 [====================>.........] - ETA: 1s - loss: 0.6122 - accuracy: 0.8260" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1235/1688 [====================>.........] - ETA: 0s - loss: 0.6070 - accuracy: 0.8275" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1259/1688 [=====================>........] - ETA: 0s - loss: 0.6018 - accuracy: 0.8289" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1282/1688 [=====================>........] - ETA: 0s - loss: 0.5970 - accuracy: 0.8303" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1306/1688 [======================>.......] - ETA: 0s - loss: 0.5920 - accuracy: 0.8317" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1330/1688 [======================>.......] - ETA: 0s - loss: 0.5872 - accuracy: 0.8331" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1353/1688 [=======================>......] - ETA: 0s - loss: 0.5827 - accuracy: 0.8344" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1376/1688 [=======================>......] - ETA: 0s - loss: 0.5783 - accuracy: 0.8356" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1399/1688 [=======================>......] - ETA: 0s - loss: 0.5740 - accuracy: 0.8368" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1423/1688 [========================>.....] - ETA: 0s - loss: 0.5697 - accuracy: 0.8381" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1447/1688 [========================>.....] - ETA: 0s - loss: 0.5654 - accuracy: 0.8393" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1471/1688 [=========================>....] - ETA: 0s - loss: 0.5612 - accuracy: 0.8405" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1496/1688 [=========================>....] - ETA: 0s - loss: 0.5569 - accuracy: 0.8417" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1521/1688 [==========================>...] - ETA: 0s - loss: 0.5527 - accuracy: 0.8429" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1543/1688 [==========================>...] - ETA: 0s - loss: 0.5491 - accuracy: 0.8439" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1567/1688 [==========================>...] - ETA: 0s - loss: 0.5453 - accuracy: 0.8450" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1590/1688 [===========================>..] - ETA: 0s - loss: 0.5417 - accuracy: 0.8460" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1614/1688 [===========================>..] - ETA: 0s - loss: 0.5380 - accuracy: 0.8470" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1637/1688 [============================>.] - ETA: 0s - loss: 0.5346 - accuracy: 0.8480" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1661/1688 [============================>.] - ETA: 0s - loss: 0.5310 - accuracy: 0.8490" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1684/1688 [============================>.] - ETA: 0s - loss: 0.5277 - accuracy: 0.8499" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1688/1688 [==============================] - 7s 3ms/step - loss: 0.5270 - accuracy: 0.8501 - val_loss: 0.1138 - val_accuracy: 0.9703\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load MNIST dataset\n", "mnist = keras.datasets.mnist\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", "\n", "# Normalize the input image so that each pixel value is between 0 to 1.\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0\n", "\n", "# Define the model architecture.\n", "model = keras.Sequential([\n", " keras.layers.InputLayer(input_shape=(28, 28)),\n", " keras.layers.Reshape(target_shape=(28, 28, 1)),\n", " keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),\n", " keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", " keras.layers.Flatten(),\n", " keras.layers.Dense(10)\n", "])\n", "\n", "# Train the digit classification model\n", "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", "\n", "model.fit(\n", " train_images,\n", " train_labels,\n", " epochs=1,\n", " validation_split=0.1,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "K8747K9OE72P" }, "source": [ "## 양자화 인식 훈련으로 사전 훈련된 모델 복제 및 미세 조정\n" ] }, { "cell_type": "markdown", "metadata": { "id": "F19k7ExXF_h2" }, "source": [ "### 모델 정의하기" ] }, { "cell_type": "markdown", "metadata": { "id": "JsZROpNYMWQ0" }, "source": [ "전체 모델에 양자화 인식 훈련을 적용하고 모델 요약에서 이를 확인합니다. 이제 모든 레이어 앞에 \"quant\"가 붙습니다.\n", "\n", "결과 모델은 양자화를 인식하지만, 양자화되지는 않습니다(예: 가중치가 int8 대신 float32임). 다음 섹션에서는 양자화 인식 모델에서 양자화된 모델을 만드는 방법을 보여줍니다.\n", "\n", "[종합 가이드](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md)에서 모델 정확성의 향상을 위해 일부 레이어를 양자화하는 방법을 볼 수 있습니다." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:41.988132Z", "iopub.status.busy": "2021-01-15T02:15:41.987377Z", "iopub.status.idle": "2021-01-15T02:15:43.564967Z", "shell.execute_reply": "2021-01-15T02:15:43.564441Z" }, "id": "oq6blGjgFDCW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "quantize_layer (QuantizeLaye (None, 28, 28) 3 \n", "_________________________________________________________________\n", "quant_reshape (QuantizeWrapp (None, 28, 28, 1) 1 \n", "_________________________________________________________________\n", "quant_conv2d (QuantizeWrappe (None, 26, 26, 12) 147 \n", "_________________________________________________________________\n", "quant_max_pooling2d (Quantiz (None, 13, 13, 12) 1 \n", "_________________________________________________________________\n", "quant_flatten (QuantizeWrapp (None, 2028) 1 \n", "_________________________________________________________________\n", "quant_dense (QuantizeWrapper (None, 10) 20295 \n", "=================================================================\n", "Total params: 20,448\n", "Trainable params: 20,410\n", "Non-trainable params: 38\n", "_________________________________________________________________\n" ] } ], "source": [ "import tensorflow_model_optimization as tfmot\n", "\n", "quantize_model = tfmot.quantization.keras.quantize_model\n", "\n", "# q_aware stands for for quantization aware.\n", "q_aware_model = quantize_model(model)\n", "\n", "# `quantize_model` requires a recompile.\n", "q_aware_model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", "\n", "q_aware_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "uDr2ijwpGCI-" }, "source": [ "### 기준선과 비교하여 모델 훈련 및 평가하기" ] }, { "cell_type": "markdown", "metadata": { "id": "XUBEn94hXYB1" }, "source": [ "하나의 epoch 동안 모델을 훈련한 후 미세 조정을 시연하려면 훈련 데이터의 하위 집합에 대한 양자화 인식 훈련으로 미세 조정합니다." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:43.570312Z", "iopub.status.busy": "2021-01-15T02:15:43.569591Z", "iopub.status.idle": "2021-01-15T02:15:44.437339Z", "shell.execute_reply": "2021-01-15T02:15:44.436654Z" }, "id": "_PHDGJryE31X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/2 [==============>...............] - ETA: 0s - loss: 0.1233 - accuracy: 0.9660" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "2/2 [==============================] - 1s 248ms/step - loss: 0.1391 - accuracy: 0.9613 - val_loss: 0.1508 - val_accuracy: 0.9700\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_images_subset = train_images[0:1000] # out of 60000\n", "train_labels_subset = train_labels[0:1000]\n", "\n", "q_aware_model.fit(train_images_subset, train_labels_subset,\n", " batch_size=500, epochs=1, validation_split=0.1)" ] }, { "cell_type": "markdown", "metadata": { "id": "-byC2lYlMkfN" }, "source": [ "이 예제의 경우, 기준선과 비교하여 양자화 인식 훈련 후 테스트 정확성의 손실이 거의 없습니다." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:44.443178Z", "iopub.status.busy": "2021-01-15T02:15:44.442456Z", "iopub.status.idle": "2021-01-15T02:15:45.641554Z", "shell.execute_reply": "2021-01-15T02:15:45.640906Z" }, "id": "6bMFTKSSHyyZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Baseline test accuracy: 0.9616000056266785\n", "Quant test accuracy: 0.9625999927520752\n" ] } ], "source": [ "_, baseline_model_accuracy = model.evaluate(\n", " test_images, test_labels, verbose=0)\n", "\n", "_, q_aware_model_accuracy = q_aware_model.evaluate(\n", " test_images, test_labels, verbose=0)\n", "\n", "print('Baseline test accuracy:', baseline_model_accuracy)\n", "print('Quant test accuracy:', q_aware_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "2IepmUPSITn6" }, "source": [ "## TFLite 백엔드를 위한 양자화 모델 생성하기" ] }, { "cell_type": "markdown", "metadata": { "id": "1FgNP4rbOLH8" }, "source": [ "다음을 통해 int8 가중치 및 uint8 활성화를 사용하여 실제로 양자화된 모델을 얻게 됩니다." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:45.649613Z", "iopub.status.busy": "2021-01-15T02:15:45.648663Z", "iopub.status.idle": "2021-01-15T02:15:47.307689Z", "shell.execute_reply": "2021-01-15T02:15:47.306967Z" }, "id": "w7fztWsAOHTz" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, max_pooling2d_layer_call_and_return_conditional_losses while saving (showing 5 of 25). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpmfpzobkx/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpmfpzobkx/assets\n" ] } ], "source": [ "converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "\n", "quantized_tflite_model = converter.convert()" ] }, { "cell_type": "markdown", "metadata": { "id": "BEYsyYVqNgeY" }, "source": [ "## TF에서 TFLite까지 정확성의 지속성 확인하기" ] }, { "cell_type": "markdown", "metadata": { "id": "saadXD4JQsBK" }, "source": [ "테스트 데이터세트에 대해 TF Lite 모델을 평가하는 도우미 함수를 정의합니다." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:47.316184Z", "iopub.status.busy": "2021-01-15T02:15:47.315338Z", "iopub.status.idle": "2021-01-15T02:15:47.317871Z", "shell.execute_reply": "2021-01-15T02:15:47.317310Z" }, "id": "b8yBouuGNqls" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def evaluate_model(interpreter):\n", " input_index = interpreter.get_input_details()[0][\"index\"]\n", " output_index = interpreter.get_output_details()[0][\"index\"]\n", "\n", " # Run predictions on every image in the \"test\" dataset.\n", " prediction_digits = []\n", " for i, test_image in enumerate(test_images):\n", " if i % 1000 == 0:\n", " print('Evaluated on {n} results so far.'.format(n=i))\n", " # Pre-processing: add batch dimension and convert to float32 to match with\n", " # the model's input data format.\n", " test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n", " interpreter.set_tensor(input_index, test_image)\n", "\n", " # Run inference.\n", " interpreter.invoke()\n", "\n", " # Post-processing: remove batch dimension and find the digit with highest\n", " # probability.\n", " output = interpreter.tensor(output_index)\n", " digit = np.argmax(output()[0])\n", " prediction_digits.append(digit)\n", "\n", " print('\\n')\n", " # Compare prediction results with ground truth labels to calculate accuracy.\n", " prediction_digits = np.array(prediction_digits)\n", " accuracy = (prediction_digits == test_labels).mean()\n", " return accuracy" ] }, { "cell_type": "markdown", "metadata": { "id": "TuEFS4CIQvUw" }, "source": [ "양자화 모델을 평가하고 TensorFlow의 정확성이 TFLite 백엔드까지 유지되는지 확인합니다." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:47.323236Z", "iopub.status.busy": "2021-01-15T02:15:47.322563Z", "iopub.status.idle": "2021-01-15T02:15:51.375154Z", "shell.execute_reply": "2021-01-15T02:15:51.374545Z" }, "id": "VqQTyqz4NsWd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 0 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 1000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 2000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 3000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 4000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 5000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 6000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 7000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 8000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluated on 9000 results so far.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Quant TFLite test_accuracy: 0.9626\n", "Quant TF test accuracy: 0.9625999927520752\n" ] } ], "source": [ "interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)\n", "interpreter.allocate_tensors()\n", "\n", "test_accuracy = evaluate_model(interpreter)\n", "\n", "print('Quant TFLite test_accuracy:', test_accuracy)\n", "print('Quant TF test accuracy:', q_aware_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "z8D7WnFF5DZR" }, "source": [ "## 양자화로 4배 더 작아진 모델 확인하기" ] }, { "cell_type": "markdown", "metadata": { "id": "I1c2IecBRCdQ" }, "source": [ "float TFLite 모델을 생성한 다음 TFLite 양자화 모델이 4배 더 작아진 것을 확인합니다." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2021-01-15T02:15:51.382896Z", "iopub.status.busy": "2021-01-15T02:15:51.382151Z", "iopub.status.idle": "2021-01-15T02:15:51.978985Z", "shell.execute_reply": "2021-01-15T02:15:51.978352Z" }, "id": "jy_Lgfh8VkyX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpsmkiqoq3/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/tmpsmkiqoq3/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Float model in Mb: 0.08058547973632812\n", "Quantized model in Mb: 0.0234527587890625\n" ] } ], "source": [ "# Create float TFLite model.\n", "float_converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "float_tflite_model = float_converter.convert()\n", "\n", "# Measure sizes of models.\n", "_, float_file = tempfile.mkstemp('.tflite')\n", "_, quant_file = tempfile.mkstemp('.tflite')\n", "\n", "with open(quant_file, 'wb') as f:\n", " f.write(quantized_tflite_model)\n", "\n", "with open(float_file, 'wb') as f:\n", " f.write(float_tflite_model)\n", "\n", "print(\"Float model in Mb:\", os.path.getsize(float_file) / float(2**20))\n", "print(\"Quantized model in Mb:\", os.path.getsize(quant_file) / float(2**20))" ] }, { "cell_type": "markdown", "metadata": { "id": "0O5xuci-SonI" }, "source": [ "## 결론" ] }, { "cell_type": "markdown", "metadata": { "id": "O2I7xmyMW5QY" }, "source": [ "이 튜토리얼에서는 TensorFlow Model Optimization Toolkit API를 사용하여 양자화 인식 모델을 만든 다음 TFLite 백엔드용 양자화 모델을 만드는 방법을 살펴보았습니다.\n", "\n", "정확성 차이를 최소화하면서 MNIST 모델의 크기를 4배 압축하는 이점을 확인했습니다. 모바일에서의 지연 시간 이점을 확인하려면, [TFLite 앱 리포지토리](https://www.tensorflow.org/lite/models)에서 TFLite 예제를 사용해 보세요.\n", "\n", "이 새로운 기능은 리소스가 제한된 환경에서 배포할 때 특히 중요하므로 사용해 볼 것을 권장합니다.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "training_example.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 0 }