{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g_nWetWWd_ns" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T21:23:23.962586Z", "iopub.status.busy": "2024-01-11T21:23:23.962353Z", "iopub.status.idle": "2024-01-11T21:23:23.965965Z", "shell.execute_reply": "2024-01-11T21:23:23.965399Z" }, "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T21:23:23.969104Z", "iopub.status.busy": "2024-01-11T21:23:23.968528Z", "iopub.status.idle": "2024-01-11T21:23:23.971884Z", "shell.execute_reply": "2024-01-11T21:23:23.971262Z" }, "id": "N_fMsQ-N8I7j" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "pZJ3uY9O17VN" }, "source": [ "# Save and load models" ] }, { "cell_type": "markdown", "metadata": { "id": "M4Ata7_wMul1" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "mBdde4YJeJKF" }, "source": [ "Model progress can be saved during and after training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:\n", "\n", "- the code that builds the model\n", "- the trained weights, or parameters, of the model\n", "\n", "Sharing this data helps others understand how the model works and try it themselves with new data.\n", "\n", "Caution: TensorFlow models are code, so be careful with untrusted code. See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for details.\n", "\n", "### Options\n", "\n", "There are different ways to save TensorFlow models, depending on the API you're using. This guide uses [tf.keras](https://www.tensorflow.org/guide/keras), a high-level API for building and training models in TensorFlow. The new high-level `.keras` format used in this tutorial is recommended for saving Keras objects. It provides robust, efficient name-based saving that is often easier to debug than low-level or legacy formats. For more advanced saving or serialization workflows, especially those involving custom objects, refer to the [Save and load Keras models guide](https://www.tensorflow.org/guide/keras/save_and_serialize). For other approaches, refer to the [Using the SavedModel format guide](../../guide/saved_model.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "id": "xCUREq7WXgvg" }, "source": [ "## Setup\n", "\n", "### Installs and imports" ] }, { "cell_type": "markdown", "metadata": { "id": "7l0MiTOrXtNv" }, "source": [ "Install and import TensorFlow and its dependencies:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:23.975010Z", "iopub.status.busy": "2024-01-11T21:23:23.974782Z", "iopub.status.idle": "2024-01-11T21:23:25.794620Z", "shell.execute_reply": "2024-01-11T21:23:25.793756Z" }, "id": "RzIOVSdnMYyO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyyaml in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (6.0.1)\r\n", "Requirement already satisfied: h5py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.10.0)\r\n", "Requirement already satisfied: numpy>=1.17.3 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from h5py) (1.26.3)\r\n" ] } ], "source": [ "!pip install pyyaml h5py # Required to save models in HDF5 format" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:25.798968Z", "iopub.status.busy": "2024-01-11T21:23:25.798713Z", "iopub.status.idle": "2024-01-11T21:23:28.151197Z", "shell.execute_reply": "2024-01-11T21:23:28.150498Z" }, "id": "7Nm7Tyb-gRt-" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 21:23:26.230036: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 21:23:26.230083: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 21:23:26.231707: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2.15.0\n" ] } ], "source": [ "import os\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "print(tf.version.VERSION)" ] }, { "cell_type": "markdown", "metadata": { "id": "SbGsznErXWt6" }, "source": [ "### Get an example dataset\n", "\n", "To demonstrate how to save and load weights, you'll use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). To speed up these runs, use only the first 1,000 examples:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:28.154819Z", "iopub.status.busy": "2024-01-11T21:23:28.154448Z", "iopub.status.idle": "2024-01-11T21:23:28.443899Z", "shell.execute_reply": "2024-01-11T21:23:28.443021Z" }, "id": "9rGfFwE9XVwz" }, "outputs": [], "source": [ "(train_images, 
train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", "\n", "train_labels = train_labels[:1000]\n", "test_labels = test_labels[:1000]\n", "\n", "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n", "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0" ] }, { "cell_type": "markdown", "metadata": { "id": "anG3iVoXyZGI" }, "source": [ "### Define a model" ] }, { "cell_type": "markdown", "metadata": { "id": "wynsOBfby0Pa" }, "source": [ "Start by building a simple sequential model:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:28.448317Z", "iopub.status.busy": "2024-01-11T21:23:28.448045Z", "iopub.status.idle": "2024-01-11T21:23:30.795083Z", "shell.execute_reply": "2024-01-11T21:23:30.794452Z" }, "id": "0HZbJIjxyX1S" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 
MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "# Define a simple sequential model\n", "def create_model():\n", " model = tf.keras.Sequential([\n", " keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n", " keras.layers.Dropout(0.2),\n", " keras.layers.Dense(10)\n", " ])\n", "\n", " model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", "\n", " return model\n", "\n", "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Display the model's architecture\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "soDE0W_KH8rG" }, "source": [ "## Save checkpoints during training" ] }, { "cell_type": "markdown", "metadata": { "id": "mRyd5qQQIXZm" }, "source": [ "You can use a trained model without having to retrain it, or pick up training where you left off if the training process was interrupted. The `tf.keras.callbacks.ModelCheckpoint` callback lets you continually save the model both *during* and at *the end* of training.\n", "\n", "### Checkpoint callback usage\n", "\n", "Create a `tf.keras.callbacks.ModelCheckpoint` callback that saves only the model's weights during training:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:30.802873Z", "iopub.status.busy": "2024-01-11T21:23:30.802128Z", "iopub.status.idle": "2024-01-11T21:23:34.905806Z", "shell.execute_reply": "2024-01-11T21:23:34.904977Z" }, "id": "IFPuhwntH8VH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1705008212.295762 622332 
device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 54s - loss: 2.3475 - sparse_categorical_accuracy: 0.0938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 1.3640 - sparse_categorical_accuracy: 0.6065 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 2s 13ms/step - loss: 1.1532 - sparse_categorical_accuracy: 0.6700 - val_loss: 0.7411 - val_sparse_categorical_accuracy: 0.7790\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.8073 - sparse_categorical_accuracy: 0.7500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] 
- ETA: 0s - loss: 0.4534 - sparse_categorical_accuracy: 0.8551" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 2: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.4309 - sparse_categorical_accuracy: 0.8700 - val_loss: 0.5356 - val_sparse_categorical_accuracy: 0.8340\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.5226 - sparse_categorical_accuracy: 0.8438" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.2729 - sparse_categorical_accuracy: 0.9307" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 3: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.2916 - sparse_categorical_accuracy: 0.9260 - val_loss: 0.5225 - val_sparse_categorical_accuracy: 0.8390\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.2464 - sparse_categorical_accuracy: 0.9062" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.2042 - sparse_categorical_accuracy: 0.9460" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 4: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.2157 - sparse_categorical_accuracy: 0.9470 - val_loss: 0.4547 - val_sparse_categorical_accuracy: 0.8530\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2549 - sparse_categorical_accuracy: 0.9375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] 
- ETA: 0s - loss: 0.1520 - sparse_categorical_accuracy: 0.9728" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 5: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.1518 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.4463 - val_sparse_categorical_accuracy: 0.8570\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0625 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] - ETA: 0s - loss: 0.1357 - sparse_categorical_accuracy: 0.9714" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 6: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.1311 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.4503 - val_sparse_categorical_accuracy: 0.8550\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.0902 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.0959 - sparse_categorical_accuracy: 0.9837" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 7: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.0916 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4363 - val_sparse_categorical_accuracy: 0.8600\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0531 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] 
- ETA: 0s - loss: 0.0650 - sparse_categorical_accuracy: 0.9905" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 8: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.0649 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4235 - val_sparse_categorical_accuracy: 0.8680\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0327 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.0536 - sparse_categorical_accuracy: 0.9973" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 9: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.0523 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.4353 - val_sparse_categorical_accuracy: 0.8580\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.0545 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] - ETA: 0s - loss: 0.0402 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 10: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.0393 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4228 - val_sparse_categorical_accuracy: 0.8650\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint_path = \"training_1/cp.ckpt\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "# Create a callback that saves the model's weights\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n", " save_weights_only=True,\n", " verbose=1)\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", " train_labels, \n", " epochs=10,\n", " validation_data=(test_images, test_labels),\n", " callbacks=[cp_callback]) # Pass callback to training\n", "\n", "# This may generate warnings related to saving the state of the optimizer.\n", "# These warnings (and similar warnings throughout this notebook)\n", "# are in place to discourage outdated usage, and can be ignored." 
] }, { "cell_type": "markdown", "metadata": { "id": "rlM-sgyJO084" }, "source": [ "This creates a collection of TensorFlow checkpoint files that are updated at the end of each epoch:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:34.909439Z", "iopub.status.busy": "2024-01-11T21:23:34.908744Z", "iopub.status.idle": "2024-01-11T21:23:34.913756Z", "shell.execute_reply": "2024-01-11T21:23:34.913018Z" }, "id": "gXG5FVKFOVQ3" }, "outputs": [ { "data": { "text/plain": [ "['cp.ckpt.data-00000-of-00001', 'checkpoint', 'cp.ckpt.index']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "wlRN_f56Pqa9" }, "source": [ "As long as two models share the same architecture, you can share weights between them. So, when restoring a model from weights only, create a model with the same architecture as the original model and then set its weights.\n", "\n", "Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (around 10% accuracy):" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:34.916954Z", "iopub.status.busy": "2024-01-11T21:23:34.916466Z", "iopub.status.idle": "2024-01-11T21:23:35.176543Z", "shell.execute_reply": "2024-01-11T21:23:35.175872Z" }, "id": "Fp5gbuiaPqCT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 2.3377 - sparse_categorical_accuracy: 0.1260 - 179ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Untrained model, accuracy: 12.60%\n" ] } ], "source": [ "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "1DTKpZssRSo3" }, "source": [ "Then load the weights from the checkpoint and re-evaluate:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { 
"iopub.execute_input": "2024-01-11T21:23:35.180189Z", "iopub.status.busy": "2024-01-11T21:23:35.179600Z", "iopub.status.idle": "2024-01-11T21:23:35.327907Z", "shell.execute_reply": "2024-01-11T21:23:35.327188Z" }, "id": "2IZxbwiRRSD2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4228 - sparse_categorical_accuracy: 0.8650 - 90ms/epoch - 3ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 86.50%\n" ] } ], "source": [ "# Loads the weights\n", "model.load_weights(checkpoint_path)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "bpAbKkAyVPV8" }, "source": [ "### Checkpoint callback options\n", "\n", "The callback provides options to give checkpoints unique names and to adjust how often checkpoints are saved.\n", "\n", "Train a new model, and save uniquely named checkpoints once every five epochs:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:35.331782Z", "iopub.status.busy": "2024-01-11T21:23:35.331138Z", "iopub.status.idle": "2024-01-11T21:23:44.517571Z", "shell.execute_reply": "2024-01-11T21:23:44.516775Z" }, "id": "mQF_dlgIVOvq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 5: saving model to training_2/cp-0005.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 10: saving model to training_2/cp-0010.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 15: saving model to training_2/cp-0015.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 20: saving model to training_2/cp-0020.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 25: saving model to training_2/cp-0025.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 30: saving 
model to training_2/cp-0030.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 35: saving model to training_2/cp-0035.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 40: saving model to training_2/cp-0040.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 45: saving model to training_2/cp-0045.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 50: saving model to training_2/cp-0050.ckpt\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Include the epoch in the file name (uses `str.format`)\n", "checkpoint_path = \"training_2/cp-{epoch:04d}.ckpt\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "batch_size = 32\n", "\n", "# Calculate the number of batches per epoch\n", "import math\n", "n_batches = len(train_images) / batch_size\n", "n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n", "\n", "# Create a callback that saves the model's weights every 5 epochs\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", " filepath=checkpoint_path, \n", " verbose=1, \n", " save_weights_only=True,\n", " save_freq=5*n_batches)\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Save the weights using the `checkpoint_path` format\n", "model.save_weights(checkpoint_path.format(epoch=0))\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", " train_labels,\n", " epochs=50, \n", " batch_size=batch_size, \n", " callbacks=[cp_callback],\n", " validation_data=(test_images, test_labels),\n", " verbose=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "1zFrKTjjavWI" }, "source": [ "Now, review the resulting checkpoints and choose the latest one:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:44.521490Z", 
"iopub.status.busy": "2024-01-11T21:23:44.520786Z", "iopub.status.idle": "2024-01-11T21:23:44.525920Z", "shell.execute_reply": "2024-01-11T21:23:44.525229Z" }, "id": "p64q3-V4sXt0" }, "outputs": [ { "data": { "text/plain": [ "['cp-0040.ckpt.data-00000-of-00001',\n", " 'cp-0035.ckpt.index',\n", " 'cp-0010.ckpt.data-00000-of-00001',\n", " 'cp-0015.ckpt.index',\n", " 'cp-0025.ckpt.index',\n", " 'cp-0040.ckpt.index',\n", " 'cp-0005.ckpt.data-00000-of-00001',\n", " 'cp-0005.ckpt.index',\n", " 'cp-0000.ckpt.index',\n", " 'cp-0030.ckpt.index',\n", " 'cp-0050.ckpt.index',\n", " 'cp-0020.ckpt.data-00000-of-00001',\n", " 'cp-0025.ckpt.data-00000-of-00001',\n", " 'cp-0010.ckpt.index',\n", " 'cp-0045.ckpt.data-00000-of-00001',\n", " 'cp-0045.ckpt.index',\n", " 'cp-0050.ckpt.data-00000-of-00001',\n", " 'checkpoint',\n", " 'cp-0000.ckpt.data-00000-of-00001',\n", " 'cp-0015.ckpt.data-00000-of-00001',\n", " 'cp-0020.ckpt.index',\n", " 'cp-0030.ckpt.data-00000-of-00001',\n", " 'cp-0035.ckpt.data-00000-of-00001']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:44.529217Z", "iopub.status.busy": "2024-01-11T21:23:44.528657Z", "iopub.status.idle": "2024-01-11T21:23:44.533740Z", "shell.execute_reply": "2024-01-11T21:23:44.532881Z" }, "id": "1AN_fnuyR41H" }, "outputs": [ { "data": { "text/plain": [ "'training_2/cp-0050.ckpt'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "latest = tf.train.latest_checkpoint(checkpoint_dir)\n", "latest" ] }, { "cell_type": "markdown", "metadata": { "id": "Zk2ciGbKg561" }, "source": [ "Note: As the listing above shows, all of the checkpoints written to this directory are kept; `tf.train.latest_checkpoint` returns the most recent one recorded in the directory's `checkpoint` file.\n", "\n", "To test, reset the model and load the latest checkpoint:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:44.536664Z", 
"iopub.status.busy": "2024-01-11T21:23:44.536414Z", "iopub.status.idle": "2024-01-11T21:23:44.804863Z", "shell.execute_reply": "2024-01-11T21:23:44.804218Z" }, "id": "3M04jyK-H3QK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4859 - sparse_categorical_accuracy: 0.8750 - 179ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 87.50%\n" ] } ], "source": [ "# Create a new model instance\n", "model = create_model()\n", "\n", "# Load the previously saved weights\n", "model.load_weights(latest)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "c2OxsJOTHxia" }, "source": [ "## What are these files?" ] }, { "cell_type": "markdown", "metadata": { "id": "JtdYhvWnH2ib" }, "source": [ "The code above stores the weights in a collection of [checkpoint](../../guide/checkpoint.ipynb)-formatted files that contain only the trained weights, in a binary format. Checkpoints contain:\n", "\n", "- One or more shards that contain your model's weights.\n", "- An index file that indicates which weights are stored in which shard.\n", "\n", "If you are training a model on a single machine, you'll have one shard with the suffix `.data-00000-of-00001`." ] }, { "cell_type": "markdown", "metadata": { "id": "S_FA-ZvxuXQV" }, "source": [ "## Manually save weights\n", "\n", "To save weights manually, use `tf.keras.Model.save_weights`. By default, `tf.keras`, and the `Model.save_weights` method in particular, uses the TensorFlow [Checkpoint](../../guide/checkpoint.ipynb) format with a `.ckpt` extension. To save in the HDF5 format with a `.h5` extension, refer to the [Save and load models](https://www.tensorflow.org/guide/keras/save_and_serialize) guide." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:44.808342Z", "iopub.status.busy": "2024-01-11T21:23:44.808110Z", "iopub.status.idle": "2024-01-11T21:23:45.093928Z", "shell.execute_reply": "2024-01-11T21:23:45.093138Z" }, "id": "R7W5plyZ-u9X" }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "32/32 - 0s - loss: 0.4859 - sparse_categorical_accuracy: 0.8750 - 181ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 87.50%\n" ] } ], "source": [ "# Save the weights\n", "model.save_weights('./checkpoints/my_checkpoint')\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Restore the weights\n", "model.load_weights('./checkpoints/my_checkpoint')\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "kOGlxPRBEvV1" }, "source": [ "## Save the entire model\n", "\n", "Call `tf.keras.Model.save` to save a model's architecture, weights, and training configuration in a single `model.keras` zip archive.\n", "\n", "An entire model can be saved in three different file formats: the new `.keras` format and two legacy formats, `SavedModel` and `HDF5`. Saving a model as `path/to/model.keras` automatically saves it in the latest format.\n", "\n", "**Note:** For Keras objects, the new high-level `.keras` format is recommended. It provides richer, name-based saving and reloading, which makes debugging easier. The low-level SavedModel format and the legacy H5 format continue to be supported for existing code.\n", "\n", "You can switch to the SavedModel format by:\n", "\n", "- Passing `save_format='tf'` to `save()`\n", "- Passing a filename without an extension\n", "\n", "You can switch to the H5 format by:\n", "\n", "- Passing `save_format='h5'` to `save()`\n", "- Passing a filename that ends in `.h5`\n", "\n", "Saving a fully-functional model is very useful: you can load it in TensorFlow.js ([Saved Model](https://www.tensorflow.org/js/tutorials/conversion/import_saved_model), [HDF5](https://www.tensorflow.org/js/tutorials/conversion/import_keras)) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite ([Saved Model](https://www.tensorflow.org/lite/models/convert/#convert_a_savedmodel_recommended_), [HDF5](https://www.tensorflow.org/lite/models/convert/#convert_a_keras_model_)).\n", "\n", "Custom objects (for example, subclassed models or layers) require special 
attention when saving and loading. Refer to the **Saving custom objects** section below." ] }, { "cell_type": "markdown", "metadata": { "id": "0fRGnlHMrkI7" }, "source": [ "### New high-level `.keras` format" ] }, { "cell_type": "markdown", "metadata": { "id": "eqO8jj7GsCDn" }, "source": [ "The new Keras v3 saving format, marked by the `.keras` extension, is a simpler, more efficient format that implements name-based saving, ensuring that what you load is exactly what you saved, from Python's perspective. This makes debugging much easier, and it is the recommended format for Keras.\n", "\n", "The section below illustrates how to save and restore the model in the `.keras` format." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:45.097773Z", "iopub.status.busy": "2024-01-11T21:23:45.097498Z", "iopub.status.idle": "2024-01-11T21:23:46.318749Z", "shell.execute_reply": "2024-01-11T21:23:46.318066Z" }, "id": "3f55mAXwukUX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 20s - loss: 2.2506 - sparse_categorical_accuracy: 0.1562" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] 
- ETA: 0s - loss: 1.3519 - sparse_categorical_accuracy: 0.6205 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 3ms/step - loss: 1.1083 - sparse_categorical_accuracy: 0.6940\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.6709 - sparse_categorical_accuracy: 0.7812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.4405 - sparse_categorical_accuracy: 0.8641" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.4247 - sparse_categorical_accuracy: 0.8680\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2538 - sparse_categorical_accuracy: 0.8750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] 
- ETA: 0s - loss: 0.2660 - sparse_categorical_accuracy: 0.9212" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9220\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2601 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.1887 - sparse_categorical_accuracy: 0.9633" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2004 - sparse_categorical_accuracy: 0.9550\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0934 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] 
- ETA: 0s - loss: 0.1468 - sparse_categorical_accuracy: 0.9740" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1497 - sparse_categorical_accuracy: 0.9710\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a `.keras` zip archive.\n", "model.save('my_model.keras')" ] }, { "cell_type": "markdown", "metadata": { "id": "iHqwaun5g8lD" }, "source": [ "Reload a fresh Keras model from the `.keras` zip archive:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:46.322671Z", "iopub.status.busy": "2024-01-11T21:23:46.322112Z", "iopub.status.idle": "2024-01-11T21:23:46.471610Z", "shell.execute_reply": "2024-01-11T21:23:46.470715Z" }, "id": "HyfUMOZwux_-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_5\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_10 (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout_5 (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_11 (Dense) 
(None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "new_model = tf.keras.models.load_model('my_model.keras')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "9Cn3pSBqvJ5f" }, "source": [ "Try running evaluate and predict with the loaded model:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:46.479110Z", "iopub.status.busy": "2024-01-11T21:23:46.478433Z", "iopub.status.idle": "2024-01-11T21:23:46.920587Z", "shell.execute_reply": "2024-01-11T21:23:46.919896Z" }, "id": "8BT4mHNIvMdW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4726 - sparse_categorical_accuracy: 0.8500 - 180ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 85.00%\n", "\r", " 1/32 [..............................] 
- ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 1ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1000, 10)\n" ] } ], "source": [ "# Evaluate the restored model\n", "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(new_model.predict(test_images).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "kPyhgcoVzqUB" }, "source": [ "### SavedModel format" ] }, { "cell_type": "markdown", "metadata": { "id": "LtcN4VIb7JkK" }, "source": [ "The SavedModel format is another way to serialize models. Models saved in this format can be restored using `tf.keras.models.load_model` and are compatible with TensorFlow Serving. For details on how to serve and inspect a SavedModel, refer to the [SavedModel guide](../../guide/saved_model.ipynb). The section below illustrates the steps to save and restore the model." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:46.924535Z", "iopub.status.busy": "2024-01-11T21:23:46.923898Z", "iopub.status.idle": "2024-01-11T21:23:48.799809Z", "shell.execute_reply": "2024-01-11T21:23:48.799126Z" }, "id": "sI1YvCDFzpl3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 20s - loss: 2.3876 - sparse_categorical_accuracy: 0.1875" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] 
- ETA: 0s - loss: 1.4307 - sparse_categorical_accuracy: 0.5833 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 3ms/step - loss: 1.1850 - sparse_categorical_accuracy: 0.6680\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.5525 - sparse_categorical_accuracy: 0.8438" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.4671 - sparse_categorical_accuracy: 0.8608" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.4369 - sparse_categorical_accuracy: 0.8670\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2078 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] 
- ETA: 0s - loss: 0.2872 - sparse_categorical_accuracy: 0.9321" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2892 - sparse_categorical_accuracy: 0.9270\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1791 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] - ETA: 0s - loss: 0.2137 - sparse_categorical_accuracy: 0.9401" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2175 - sparse_categorical_accuracy: 0.9430\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1610 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] 
- ETA: 0s - loss: 0.1492 - sparse_categorical_accuracy: 0.9714" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1508 - sparse_categorical_accuracy: 0.9690\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: saved_model/my_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: saved_model/my_model/assets\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a SavedModel.\n", "!mkdir -p saved_model\n", "model.save('saved_model/my_model')" ] }, { "cell_type": "markdown", "metadata": { "id": "iUvT_3qE8hV5" }, "source": [ "The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:48.803995Z", "iopub.status.busy": "2024-01-11T21:23:48.803324Z", "iopub.status.idle": "2024-01-11T21:23:49.111515Z", "shell.execute_reply": "2024-01-11T21:23:49.110353Z" }, "id": "sq8fPglI1RWA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "my_model\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "assets\tfingerprint.pb\tkeras_metadata.pb saved_model.pb variables\r\n" ] } ], "source": [ "# my_model directory\n", "!ls saved_model\n", "\n", "# Contains an assets folder, saved_model.pb, a variables folder, and metadata files.\n", "!ls saved_model/my_model" ] }, { "cell_type": "markdown", "metadata": { "id": "B7qfpvpY9HCe" }, "source": [ "Reload a fresh Keras model from the saved model:" ] }, { "cell_type": "code", "execution_count": 21, 
"metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:49.116533Z", "iopub.status.busy": "2024-01-11T21:23:49.115923Z", "iopub.status.idle": "2024-01-11T21:23:49.547010Z", "shell.execute_reply": "2024-01-11T21:23:49.546364Z" }, "id": "0YofwHdN0pxa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. 
See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in 
checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. 
See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in 
checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_6\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_12 (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout_6 (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_13 (Dense) (None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "new_model = tf.keras.models.load_model('saved_model/my_model')\n", "\n", "# Check its architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "uWwgNaz19TH2" }, "source": [ "The restored model is compiled with the same arguments as the original model. Try running evaluate and predict with the loaded model:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:49.554056Z", "iopub.status.busy": "2024-01-11T21:23:49.553458Z", "iopub.status.idle": "2024-01-11T21:23:49.970182Z", "shell.execute_reply": "2024-01-11T21:23:49.969401Z" }, "id": "Yh5Mu0yOgE5J" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4212 - sparse_categorical_accuracy: 0.8590 - 189ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 85.90%\n", "\r", " 1/32 [..............................] 
- ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 1ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1000, 10)\n" ] } ], "source": [ "# Evaluate the restored model\n", "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(new_model.predict(test_images).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "SkGwf-50zLNn" }, "source": [ "### HDF5 format\n", "\n", "Keras provides a basic legacy high-level save format using the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) standard." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:49.974140Z", "iopub.status.busy": "2024-01-11T21:23:49.973515Z", "iopub.status.idle": "2024-01-11T21:23:51.156167Z", "shell.execute_reply": "2024-01-11T21:23:51.155372Z" }, "id": "m2dkmJVCGUia" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 20s - loss: 2.3408 - sparse_categorical_accuracy: 0.0938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] 
- ETA: 0s - loss: 1.3457 - sparse_categorical_accuracy: 0.6223 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 2ms/step - loss: 1.1673 - sparse_categorical_accuracy: 0.6750\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.6018 - sparse_categorical_accuracy: 0.8438" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] - ETA: 0s - loss: 0.4276 - sparse_categorical_accuracy: 0.8789" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.4260 - sparse_categorical_accuracy: 0.8790\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.3602 - sparse_categorical_accuracy: 0.8750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] 
- ETA: 0s - loss: 0.2691 - sparse_categorical_accuracy: 0.9271" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2811 - sparse_categorical_accuracy: 0.9260\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.3469 - sparse_categorical_accuracy: 0.9375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.2042 - sparse_categorical_accuracy: 0.9552" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2110 - sparse_categorical_accuracy: 0.9510\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1238 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/32 [=====================>........] 
- ETA: 0s - loss: 0.1614 - sparse_categorical_accuracy: 0.9661" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1555 - sparse_categorical_accuracy: 0.9670\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.\n", " saving_api.save_model(\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model to an HDF5 file.\n", "# The '.h5' extension indicates that the model should be saved to HDF5.\n", "model.save('my_model.h5')" ] }, { "cell_type": "markdown", "metadata": { "id": "GWmttMOqS68S" }, "source": [ "Now, recreate the model from that file:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:51.160065Z", "iopub.status.busy": "2024-01-11T21:23:51.159457Z", "iopub.status.idle": "2024-01-11T21:23:51.232607Z", "shell.execute_reply": "2024-01-11T21:23:51.231932Z" }, "id": "5NDMO_7kS6Do" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_7\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_14 (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout_7 (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_15 (Dense) (None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "# Recreate the exact same model, including its weights and the optimizer\n", "new_model = tf.keras.models.load_model('my_model.h5')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "JXQpbTicTBwt" }, "source": [ "正解率を検査します。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:23:51.239766Z", "iopub.status.busy": "2024-01-11T21:23:51.239078Z", "iopub.status.idle": "2024-01-11T21:23:51.475400Z", "shell.execute_reply": "2024-01-11T21:23:51.474730Z" }, "id": "jwEaj9DnTCVA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4251 - sparse_categorical_accuracy: 0.8540 - 186ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, 
accuracy: 85.40%\n" ] } ], "source": [ "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "dGXqd4wWJl8O" }, "source": [ "Keras saves models by inspecting their architectures. This technique saves everything:\n", "\n", "- The weight values\n", "- The model's architecture\n", "- The model's training configuration (what you pass to the `.compile()` method)\n", "- The optimizer and its state, if any (this enables you to restart training where you left off)\n", "\n", "Keras is not able to save the `v1.x` optimizers (from `tf.compat.v1.train`) because they aren't compatible with checkpoints. For v1.x optimizers, you need to re-compile the model after loading, which loses the state of the optimizer.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "kAUKJQyGqTNH" }, "source": [ "### Saving custom objects\n", "\n", "If you are using the SavedModel format, you can skip this section. The key difference between the high-level `.keras`/HDF5 formats and the low-level SavedModel format is that the `.keras`/HDF5 formats use object configs to save the model architecture, while SavedModel saves the execution graph. Thus, SavedModels are able to save custom objects like subclassed models and custom layers without requiring the original code. However, this also makes low-level SavedModels harder to debug, so we recommend using the high-level `.keras` format instead because it is name-based and Keras-native.\n", "\n", "To save custom objects to `.keras` and HDF5, you must do the following:\n", "\n", "1. Define a `get_config` method in your object, and optionally a `from_config` classmethod.\n", " - `get_config(self)` returns a JSON-serializable dictionary of parameters needed to recreate the object.\n", " - `from_config(cls, config)` uses the config returned from `get_config` to create a new object. 
By default, this function uses the config as initialization kwargs (`return cls(**config)`).\n", "2. 
Pass the custom objects to the model in one of the following ways:\n", " - Register the custom object with the `@tf.keras.utils.register_keras_serializable` decorator. **(recommended)**\n", " - Pass the object directly to the `custom_objects` argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class, for example `tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`.\n", " - Use a `tf.keras.utils.custom_object_scope` with the object included in the `custom_objects` dictionary argument, and place a `tf.keras.models.load_model(path)` call within the scope.\n", "\n", "Refer to the [Writing layers and models from scratch](https://www.tensorflow.org/guide/keras/custom_layers_and_models) tutorial for examples of custom objects and `get_config`.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "save_and_load.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }