{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g_nWetWWd_ns" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-08T00:03:56.067378Z", "iopub.status.busy": "2023-11-08T00:03:56.067122Z", "iopub.status.idle": "2023-11-08T00:03:56.071576Z", "shell.execute_reply": "2023-11-08T00:03:56.070855Z" }, "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-08T00:03:56.074747Z", "iopub.status.busy": "2023-11-08T00:03:56.074510Z", "iopub.status.idle": "2023-11-08T00:03:56.078363Z", "shell.execute_reply": "2023-11-08T00:03:56.077726Z" }, "id": "N_fMsQ-N8I7j" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "pZJ3uY9O17VN" }, "source": [ "# Save and load models" ] },
{ "cell_type": "markdown", "metadata": { "id": "mBdde4YJeJKF" }, "source": [ "Model progress can be saved during and after training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:\n", "\n", "- The code used to create the model\n", "- The trained weights, or parameters, for the model\n", "\n", "Sharing this data helps others understand how the model works and try it themselves with new data.\n", "\n", "Caution: TensorFlow models are code, so be careful with untrusted code. See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for details.\n", "\n", "### Options\n", "\n", "There are different ways to save TensorFlow models depending on the API you're using. This guide uses [tf.keras](https://tensorflow.google.cn/guide/keras), a high-level API for building and training models in TensorFlow. The new, high-level `.keras` format used in this tutorial is recommended for saving Keras objects, as it provides robust, efficient name-based saving that is often easier to debug than low-level or legacy formats. For more advanced saving or serialization workflows, especially those involving custom objects, refer to the [Save and load Keras models guide](https://tensorflow.google.cn/guide/keras/save_and_serialize). For other approaches, refer to the [Using the SavedModel format guide](../../guide/saved_model.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "id": "xCUREq7WXgvg" }, "source": [ "## Setup\n", "\n", "### Installs and imports" ] }, { "cell_type": "markdown", "metadata": { "id": "7l0MiTOrXtNv" }, "source": [ "Install and import TensorFlow and dependencies:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:03:56.082437Z", "iopub.status.busy": "2023-11-08T00:03:56.081875Z", "iopub.status.idle": "2023-11-08T00:03:58.152894Z", "shell.execute_reply": "2023-11-08T00:03:58.151889Z" }, "id": "RzIOVSdnMYyO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyyaml in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (6.0.1)\r\n", "Requirement already satisfied: h5py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.10.0)\r\n", "Requirement already satisfied: numpy>=1.17.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from h5py) (1.26.1)\r\n" ] } ], "source": [ "!pip install pyyaml h5py # Required to save models in HDF5 format" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:03:58.157428Z", 
"iopub.status.busy": "2023-11-08T00:03:58.157126Z", "iopub.status.idle": "2023-11-08T00:04:00.748993Z", "shell.execute_reply": "2023-11-08T00:04:00.748205Z" }, "id": "7Nm7Tyb-gRt-" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-08 00:03:58.631423: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-11-08 00:03:58.631475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-11-08 00:03:58.633260: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2.15.0-rc1\n" ] } ], "source": [ "import os\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "print(tf.version.VERSION)" ] }, { "cell_type": "markdown", "metadata": { "id": "SbGsznErXWt6" }, "source": [ "### Get an example dataset\n", "\n", "To demonstrate how to save and load weights, you'll use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). To speed up these runs, use only the first 1000 examples:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:00.753152Z", "iopub.status.busy": "2023-11-08T00:04:00.752536Z", "iopub.status.idle": "2023-11-08T00:04:01.073866Z", "shell.execute_reply": "2023-11-08T00:04:01.072867Z" }, "id": "9rGfFwE9XVwz" }, "outputs": [], "source": [ "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", "\n", "train_labels = train_labels[:1000]\n", "test_labels = test_labels[:1000]\n", "\n", "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n", "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0" ] }, { "cell_type": 
"markdown", "metadata": { "id": "anG3iVoXyZGI" }, "source": [ "### Define a model" ] }, { "cell_type": "markdown", "metadata": { "id": "wynsOBfby0Pa" }, "source": [ "Start by building a simple sequential model:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:01.078897Z", "iopub.status.busy": "2023-11-08T00:04:01.078287Z", "iopub.status.idle": "2023-11-08T00:04:03.456375Z", "shell.execute_reply": "2023-11-08T00:04:03.455574Z" }, "id": "0HZbJIjxyX1S" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"_________________________________________________________________\n" ] } ], "source": [ "# Define a simple sequential model\n", "def create_model():\n", "  model = tf.keras.Sequential([\n", "    keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n", "    keras.layers.Dropout(0.2),\n", "    keras.layers.Dense(10)\n", "  ])\n", "\n", "  model.compile(optimizer='adam',\n", "                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", "\n", "  return model\n", "\n", "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Display the model's architecture\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "soDE0W_KH8rG" }, "source": [ "## Save checkpoints during training" ] }, { "cell_type": "markdown", "metadata": { "id": "mRyd5qQQIXZm" }, "source": [ "You can use a trained model without having to retrain it, or pick up training where you left off if the training process was interrupted. The `tf.keras.callbacks.ModelCheckpoint` callback lets you continually save the model both *during* and at *the end* of training.\n", "\n", "### Checkpoint callback usage\n", "\n", "Create a `tf.keras.callbacks.ModelCheckpoint` callback that saves only the weights during training:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:03.464469Z", "iopub.status.busy": "2023-11-08T00:04:03.463687Z", "iopub.status.idle": "2023-11-08T00:04:07.839525Z", "shell.execute_reply": "2023-11-08T00:04:07.838585Z" }, "id": "IFPuhwntH8VH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1699401845.104377 594618 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 59s - loss: 2.4154 - sparse_categorical_accuracy: 0.0938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 1.3630 - sparse_categorical_accuracy: 0.6250 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 2s 13ms/step - loss: 1.1396 - sparse_categorical_accuracy: 0.6780 - val_loss: 0.6977 - val_sparse_categorical_accuracy: 0.7850\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.5334 - sparse_categorical_accuracy: 0.8125" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] 
- ETA: 0s - loss: 0.4101 - sparse_categorical_accuracy: 0.8869" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 2: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.4015 - sparse_categorical_accuracy: 0.8870 - val_loss: 0.5432 - val_sparse_categorical_accuracy: 0.8300\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.3153 - sparse_categorical_accuracy: 0.9375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 0.3156 - sparse_categorical_accuracy: 0.9152" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 3: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.2896 - sparse_categorical_accuracy: 0.9240 - val_loss: 0.4702 - val_sparse_categorical_accuracy: 0.8460\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.2313 - sparse_categorical_accuracy: 0.9062" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.1843 - sparse_categorical_accuracy: 0.9531" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 4: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.1912 - sparse_categorical_accuracy: 0.9540 - val_loss: 0.4457 - val_sparse_categorical_accuracy: 0.8550\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0893 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] 
- ETA: 0s - loss: 0.1528 - sparse_categorical_accuracy: 0.9717" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 5: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.1541 - sparse_categorical_accuracy: 0.9670 - val_loss: 0.4262 - val_sparse_categorical_accuracy: 0.8560\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0743 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 0.1017 - sparse_categorical_accuracy: 0.9881" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 6: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.1096 - sparse_categorical_accuracy: 0.9830 - val_loss: 0.4280 - val_sparse_categorical_accuracy: 0.8540\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.1880 - sparse_categorical_accuracy: 0.9375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 0.0838 - sparse_categorical_accuracy: 0.9866" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 7: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.0869 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4128 - val_sparse_categorical_accuracy: 0.8680\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0568 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] 
- ETA: 0s - loss: 0.0612 - sparse_categorical_accuracy: 0.9926" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 8: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.0647 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4074 - val_sparse_categorical_accuracy: 0.8620\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.0419 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.0540 - sparse_categorical_accuracy: 0.9986" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 9: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 6ms/step - loss: 0.0511 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.4087 - val_sparse_categorical_accuracy: 0.8700\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.0259 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.0372 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 10: saving model to training_1/cp.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 7ms/step - loss: 0.0401 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4369 - val_sparse_categorical_accuracy: 0.8560\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint_path = \"training_1/cp.ckpt\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "# Create a callback that saves the model's weights\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n", " save_weights_only=True,\n", " verbose=1)\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", " train_labels, \n", " epochs=10,\n", " validation_data=(test_images, test_labels),\n", " callbacks=[cp_callback]) # Pass callback to training\n", "\n", "# This may generate warnings related to saving the state of the optimizer.\n", "# These warnings (and similar warnings throughout this notebook)\n", "# are in place to discourage outdated usage, and can be ignored." 
] }, { "cell_type": "markdown", "metadata": { "id": "rlM-sgyJO084" }, "source": [ "This creates a single collection of TensorFlow checkpoint files that are updated at the end of each epoch:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:07.843754Z", "iopub.status.busy": "2023-11-08T00:04:07.843045Z", "iopub.status.idle": "2023-11-08T00:04:07.848862Z", "shell.execute_reply": "2023-11-08T00:04:07.847946Z" }, "id": "gXG5FVKFOVQ3" }, "outputs": [ { "data": { "text/plain": [ "['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "wlRN_f56Pqa9" }, "source": [ "As long as two models share the same architecture, you can share weights between them. So, when restoring a model from weights only, create a model with the same architecture as the original model and then set its weights.\n", "\n", "Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (about 10% accuracy):" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:07.852691Z", "iopub.status.busy": "2023-11-08T00:04:07.852016Z", "iopub.status.idle": "2023-11-08T00:04:08.130001Z", "shell.execute_reply": "2023-11-08T00:04:08.129112Z" }, "id": "Fp5gbuiaPqCT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 2.3531 - sparse_categorical_accuracy: 0.0720 - 188ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Untrained model, accuracy: 7.20%\n" ] } ], "source": [ "# Create a basic model instance\n", "model = create_model()\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "1DTKpZssRSo3" }, "source": [ "Then load the weights from the checkpoint and re-evaluate:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:08.134014Z", "iopub.status.busy": 
"2023-11-08T00:04:08.133346Z", "iopub.status.idle": "2023-11-08T00:04:08.291014Z", "shell.execute_reply": "2023-11-08T00:04:08.290129Z" }, "id": "2IZxbwiRRSD2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4369 - sparse_categorical_accuracy: 0.8560 - 90ms/epoch - 3ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 85.60%\n" ] } ], "source": [ "# Loads the weights\n", "model.load_weights(checkpoint_path)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "bpAbKkAyVPV8" }, "source": [ "### Checkpoint callback options\n", "\n", "The callback provides several options for giving checkpoints unique names and adjusting the checkpointing frequency.\n", "\n", "Train a new model, and save uniquely named checkpoints once every five epochs:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:08.294950Z", "iopub.status.busy": "2023-11-08T00:04:08.294407Z", "iopub.status.idle": "2023-11-08T00:04:17.842195Z", "shell.execute_reply": "2023-11-08T00:04:17.841201Z" }, "id": "mQF_dlgIVOvq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 5: saving model to training_2/cp-0005.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 10: saving model to training_2/cp-0010.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 15: saving model to training_2/cp-0015.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 20: saving model to training_2/cp-0020.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 25: saving model to training_2/cp-0025.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 30: saving model to training_2/cp-0030.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\n", "Epoch 35: saving model to training_2/cp-0035.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 40: saving model to training_2/cp-0040.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 45: saving model to training_2/cp-0045.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 50: saving model to training_2/cp-0050.ckpt\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Include the epoch in the file name (uses `str.format`)\n", "checkpoint_path = \"training_2/cp-{epoch:04d}.ckpt\"\n", "checkpoint_dir = os.path.dirname(checkpoint_path)\n", "\n", "batch_size = 32\n", "\n", "# Calculate the number of batches per epoch\n", "import math\n", "n_batches = len(train_images) / batch_size\n", "n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n", "\n", "# Create a callback that saves the model's weights every 5 epochs\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", "    filepath=checkpoint_path, \n", "    verbose=1, \n", "    save_weights_only=True,\n", "    save_freq=5*n_batches)\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Save the weights using the `checkpoint_path` format\n", "model.save_weights(checkpoint_path.format(epoch=0))\n", "\n", "# Train the model with the new callback\n", "model.fit(train_images, \n", "          train_labels,\n", "          epochs=50, \n", "          batch_size=batch_size, \n", "          callbacks=[cp_callback],\n", "          validation_data=(test_images, test_labels),\n", "          verbose=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "1zFrKTjjavWI" }, "source": [ "Now, review the resulting checkpoints and choose the latest one:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:17.846493Z", "iopub.status.busy": "2023-11-08T00:04:17.845856Z", "iopub.status.idle": "2023-11-08T00:04:17.852002Z", 
"shell.execute_reply": "2023-11-08T00:04:17.851127Z" }, "id": "p64q3-V4sXt0" }, "outputs": [ { "data": { "text/plain": [ "['cp-0050.ckpt.index',\n", " 'cp-0045.ckpt.data-00000-of-00001',\n", " 'cp-0005.ckpt.data-00000-of-00001',\n", " 'cp-0000.ckpt.index',\n", " 'cp-0000.ckpt.data-00000-of-00001',\n", " 'cp-0045.ckpt.index',\n", " 'cp-0035.ckpt.index',\n", " 'cp-0015.ckpt.data-00000-of-00001',\n", " 'cp-0025.ckpt.index',\n", " 'cp-0040.ckpt.data-00000-of-00001',\n", " 'cp-0040.ckpt.index',\n", " 'cp-0005.ckpt.index',\n", " 'cp-0010.ckpt.index',\n", " 'cp-0030.ckpt.index',\n", " 'cp-0015.ckpt.index',\n", " 'cp-0035.ckpt.data-00000-of-00001',\n", " 'cp-0020.ckpt.index',\n", " 'cp-0030.ckpt.data-00000-of-00001',\n", " 'cp-0020.ckpt.data-00000-of-00001',\n", " 'cp-0010.ckpt.data-00000-of-00001',\n", " 'cp-0025.ckpt.data-00000-of-00001',\n", " 'checkpoint',\n", " 'cp-0050.ckpt.data-00000-of-00001']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "os.listdir(checkpoint_dir)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:17.855542Z", "iopub.status.busy": "2023-11-08T00:04:17.854890Z", "iopub.status.idle": "2023-11-08T00:04:17.861011Z", "shell.execute_reply": "2023-11-08T00:04:17.860225Z" }, "id": "1AN_fnuyR41H" }, "outputs": [ { "data": { "text/plain": [ "'training_2/cp-0050.ckpt'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "latest = tf.train.latest_checkpoint(checkpoint_dir)\n", "latest" ] }, { "cell_type": "markdown", "metadata": { "id": "Zk2ciGbKg561" }, "source": [ "Note: The default TensorFlow format only saves the 5 most recent checkpoints.\n", "\n", "To test, reset the model and load the latest checkpoint:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:17.864891Z", "iopub.status.busy": "2023-11-08T00:04:17.864247Z", "iopub.status.idle": "2023-11-08T00:04:18.150300Z", "shell.execute_reply": 
"2023-11-08T00:04:18.149558Z" }, "id": "3M04jyK-H3QK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4731 - sparse_categorical_accuracy: 0.8800 - 184ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 88.00%\n" ] } ], "source": [ "# Create a new model instance\n", "model = create_model()\n", "\n", "# Load the previously saved weights\n", "model.load_weights(latest)\n", "\n", "# Re-evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "c2OxsJOTHxia" }, "source": [ "## What are these files?" ] }, { "cell_type": "markdown", "metadata": { "id": "JtdYhvWnH2ib" }, "source": [ "The above code stores the weights in a collection of [checkpoint](../../guide/checkpoint.ipynb)-formatted files that contain only the trained weights in a binary format. Checkpoints contain:\n", "\n", "- One or more shards that contain your model's weights.\n", "- An index file that indicates which weights are stored in which shard.\n", "\n", "If you are training a model on a single machine, you'll have one shard with the suffix: `.data-00000-of-00001`" ] }, { "cell_type": "markdown", "metadata": { "id": "S_FA-ZvxuXQV" }, "source": [ "## Manually save weights\n", "\n", "To save weights manually, use `tf.keras.Model.save_weights`. By default, `tf.keras` (and the `Model.save_weights` method in particular) uses the TensorFlow [checkpoint](../../guide/checkpoint.ipynb) format with a `.ckpt` extension. To save in the HDF5 format with a `.h5` extension, refer to the [Save and load models](https://tensorflow.google.cn/guide/keras/save_and_serialize) guide." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:18.154276Z", "iopub.status.busy": "2023-11-08T00:04:18.153999Z", "iopub.status.idle": "2023-11-08T00:04:18.449398Z", "shell.execute_reply": "2023-11-08T00:04:18.448376Z" }, "id": "R7W5plyZ-u9X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4731 - sparse_categorical_accuracy: 0.8800 - 185ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 88.00%\n" ] } ], "source": [ "# Save the 
weights\n", "model.save_weights('./checkpoints/my_checkpoint')\n", "\n", "# Create a new model instance\n", "model = create_model()\n", "\n", "# Restore the weights\n", "model.load_weights('./checkpoints/my_checkpoint')\n", "\n", "# Evaluate the model\n", "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n", "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "kOGlxPRBEvV1" }, "source": [ "## Save the entire model\n", "\n", "Call `tf.keras.Model.save` to save a model's architecture, weights, and training configuration in a single `model.keras` zip archive.\n", "\n", "An entire model can be saved in three different file formats (the new `.keras` format and two legacy formats: `SavedModel` and `HDF5`). Saving a model as `path/to/model.keras` automatically saves in the latest format.\n", "\n", "**Note:** For Keras objects, it's recommended to use the new high-level `.keras` format for richer, name-based saving and reloading, which is easier to debug. The low-level SavedModel format and the legacy H5 format continue to be supported for existing code.\n", "\n", "You can switch to the SavedModel format by:\n", "\n", "- Passing `save_format='tf'` to `save()`\n", "- Passing a filename without an extension\n", "\n", "You can switch to the H5 format by:\n", "\n", "- Passing `save_format='h5'` to `save()`\n", "- Passing a filename that ends in `.h5`\n", "\n", "Saving a fully functional model is very useful: you can load it in TensorFlow.js ([Saved Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras)) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite ([Saved Model](https://tensorflow.google.cn/lite/models/convert/#convert_a_savedmodel_recommended_), [HDF5](https://tensorflow.google.cn/lite/models/convert/#convert_a_keras_model_)).\n", "\n", "*Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the **Saving custom objects** section below." 
] }, { "cell_type": "markdown", "metadata": { "id": "0fRGnlHMrkI7" }, "source": [ "### New high-level `.keras` format" ] }, { "cell_type": "markdown", "metadata": { "id": "eqO8jj7GsCDn" }, "source": [ "The new Keras v3 saving format, marked by the `.keras` extension, is a simpler, more efficient format that implements name-based saving, ensuring that what you load is exactly what you saved, from Python's perspective. This makes debugging much easier, and it is the recommended format for Keras.\n", "\n", "The section below illustrates how to save and restore the model in the `.keras` format." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:18.453699Z", "iopub.status.busy": "2023-11-08T00:04:18.453111Z", "iopub.status.idle": "2023-11-08T00:04:19.747619Z", "shell.execute_reply": "2023-11-08T00:04:19.746803Z" }, "id": "3f55mAXwukUX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 21s - loss: 2.3923 - sparse_categorical_accuracy: 0.0938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 1.4819 - sparse_categorical_accuracy: 0.5551 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 3ms/step - loss: 1.2257 - sparse_categorical_accuracy: 0.6420\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.4663 - sparse_categorical_accuracy: 0.8750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.4353 - sparse_categorical_accuracy: 0.8845" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.4333 - sparse_categorical_accuracy: 0.8820\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.3810 - sparse_categorical_accuracy: 0.8438" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 0.3120 - sparse_categorical_accuracy: 0.9286" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.3014 - sparse_categorical_accuracy: 0.9260\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.1449 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] - ETA: 0s - loss: 0.2208 - sparse_categorical_accuracy: 0.9524" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2073 - sparse_categorical_accuracy: 0.9580\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1947 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/32 [====================>.........] 
- ETA: 0s - loss: 0.1542 - sparse_categorical_accuracy: 0.9606" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1596 - sparse_categorical_accuracy: 0.9620\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a `.keras` zip archive.\n", "model.save('my_model.keras')" ] }, { "cell_type": "markdown", "metadata": { "id": "iHqwaun5g8lD" }, "source": [ "Reload a fresh Keras model from the `.keras` zip archive:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:19.752206Z", "iopub.status.busy": "2023-11-08T00:04:19.751546Z", "iopub.status.idle": "2023-11-08T00:04:19.912501Z", "shell.execute_reply": "2023-11-08T00:04:19.911729Z" }, "id": "HyfUMOZwux_-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_5\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_10 (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout_5 (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_11 (Dense) (None, 10) 
5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "new_model = tf.keras.models.load_model('my_model.keras')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "9Cn3pSBqvJ5f" }, "source": [ "Try running evaluate and predict with the loaded model:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:19.920373Z", "iopub.status.busy": "2023-11-08T00:04:19.920106Z", "iopub.status.idle": "2023-11-08T00:04:20.372924Z", "shell.execute_reply": "2023-11-08T00:04:20.372095Z" }, "id": "8BT4mHNIvMdW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4198 - sparse_categorical_accuracy: 0.8670 - 187ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 86.70%\n", "\r", " 1/32 [..............................] 
- ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 1ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1000, 10)\n" ] } ], "source": [ "# Evaluate the restored model\n", "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(new_model.predict(test_images).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "kPyhgcoVzqUB" }, "source": [ "### SavedModel format" ] }, { "cell_type": "markdown", "metadata": { "id": "LtcN4VIb7JkK" }, "source": [ "The SavedModel format is another way to serialize models. Models saved in this format can be restored using `tf.keras.models.load_model` and are compatible with TensorFlow Serving. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to `serve/inspect` the SavedModel. The section below illustrates the steps to save and restore the model." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:20.377032Z", "iopub.status.busy": "2023-11-08T00:04:20.376433Z", "iopub.status.idle": "2023-11-08T00:04:22.350639Z", "shell.execute_reply": "2023-11-08T00:04:22.349908Z" }, "id": "sI1YvCDFzpl3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 21s - loss: 2.4572 - sparse_categorical_accuracy: 0.0000e+00" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/32 [=================>............] 
- ETA: 0s - loss: 1.4852 - sparse_categorical_accuracy: 0.5594 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 3ms/step - loss: 1.1823 - sparse_categorical_accuracy: 0.6530\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.6145 - sparse_categorical_accuracy: 0.8438" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/32 [==================>...........] - ETA: 0s - loss: 0.4739 - sparse_categorical_accuracy: 0.8527" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.4464 - sparse_categorical_accuracy: 0.8660\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1838 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] 
- ETA: 0s - loss: 0.2747 - sparse_categorical_accuracy: 0.9276" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.2765 - sparse_categorical_accuracy: 0.9290\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1509 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.2231 - sparse_categorical_accuracy: 0.9474" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.2195 - sparse_categorical_accuracy: 0.9490\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1657 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] 
- ETA: 0s - loss: 0.1510 - sparse_categorical_accuracy: 0.9716" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.1523 - sparse_categorical_accuracy: 0.9660\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: saved_model/my_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: saved_model/my_model/assets\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model as a SavedModel.\n", "!mkdir -p saved_model\n", "model.save('saved_model/my_model') " ] }, { "cell_type": "markdown", "metadata": { "id": "iUvT_3qE8hV5" }, "source": [ "The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:22.354744Z", "iopub.status.busy": "2023-11-08T00:04:22.354450Z", "iopub.status.idle": "2023-11-08T00:04:22.669329Z", "shell.execute_reply": "2023-11-08T00:04:22.668282Z" }, "id": "sq8fPglI1RWA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "my_model\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "assets\tfingerprint.pb\tkeras_metadata.pb saved_model.pb variables\r\n" ] } ], "source": [ "# my_model directory\n", "!ls saved_model\n", "\n", "# Contains an assets folder, saved_model.pb, and variables folder.\n", "!ls saved_model/my_model" ] }, { "cell_type": "markdown", "metadata": { "id": "B7qfpvpY9HCe" }, "source": [ "Reload a fresh Keras model from the saved model:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": 
{ "iopub.execute_input": "2023-11-08T00:04:22.674215Z", "iopub.status.busy": "2023-11-08T00:04:22.673880Z", "iopub.status.idle": "2023-11-08T00:04:23.137929Z", "shell.execute_reply": "2023-11-08T00:04:23.137208Z" }, "id": "0YofwHdN0pxa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: 
(root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_6\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output 
Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_12 (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout_6 (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_13 (Dense) (None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "new_model = tf.keras.models.load_model('saved_model/my_model')\n", "\n", "# Check its architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "uWwgNaz19TH2" }, "source": [ "The restored model is compiled with the same arguments as the original model. Try running evaluate and predict with the loaded model:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:23.145460Z", "iopub.status.busy": "2023-11-08T00:04:23.145199Z", "iopub.status.idle": "2023-11-08T00:04:23.569906Z", "shell.execute_reply": "2023-11-08T00:04:23.569154Z" }, "id": "Yh5Mu0yOgE5J" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4436 - sparse_categorical_accuracy: 0.8560 - 189ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "Restored model, accuracy: 85.60%\n", "\r", " 1/32 [..............................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 1ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1000, 10)\n" ] } ], "source": [ "# Evaluate the restored model\n", "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n", "\n", "print(new_model.predict(test_images).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "SkGwf-50zLNn" }, "source": [ "### HDF5 格式\n", "\n", "Keras 使用 [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) 标准提供基本的旧版高级保存格式。 " ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:23.574213Z", "iopub.status.busy": "2023-11-08T00:04:23.573585Z", "iopub.status.idle": "2023-11-08T00:04:24.845421Z", "shell.execute_reply": "2023-11-08T00:04:24.844458Z" }, "id": "m2dkmJVCGUia" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 21s - loss: 2.3124 - sparse_categorical_accuracy: 0.0938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/32 [=================>............] 
- ETA: 0s - loss: 1.3583 - sparse_categorical_accuracy: 0.6187 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 3ms/step - loss: 1.1044 - sparse_categorical_accuracy: 0.6980\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2788 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.4147 - sparse_categorical_accuracy: 0.8821" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.3960 - sparse_categorical_accuracy: 0.8830\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2157 - sparse_categorical_accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] 
- ETA: 0s - loss: 0.2576 - sparse_categorical_accuracy: 0.9276" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2616 - sparse_categorical_accuracy: 0.9280\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1809 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] - ETA: 0s - loss: 0.1991 - sparse_categorical_accuracy: 0.9531" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.1953 - sparse_categorical_accuracy: 0.9520\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1050 - sparse_categorical_accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/32 [===================>..........] 
- ETA: 0s - loss: 0.1563 - sparse_categorical_accuracy: 0.9645" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 3ms/step - loss: 0.1439 - sparse_categorical_accuracy: 0.9670\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.\n", " saving_api.save_model(\n" ] } ], "source": [ "# Create and train a new model instance.\n", "model = create_model()\n", "model.fit(train_images, train_labels, epochs=5)\n", "\n", "# Save the entire model to an HDF5 file.\n", "# The '.h5' extension indicates that the model should be saved to HDF5.\n", "model.save('my_model.h5') " ] }, { "cell_type": "markdown", "metadata": { "id": "GWmttMOqS68S" }, "source": [ "Now, recreate the model from that file:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:24.849451Z", "iopub.status.busy": "2023-11-08T00:04:24.848669Z", "iopub.status.idle": "2023-11-08T00:04:24.927225Z", "shell.execute_reply": "2023-11-08T00:04:24.926564Z" }, "id": "5NDMO_7kS6Do" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_7\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" 
] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_14 (Dense) (None, 512) 401920 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout_7 (Dropout) (None, 512) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_15 (Dense) (None, 10) 5130 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 407050 (1.55 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "# Recreate the exact same model, including its weights and the optimizer\n", "new_model = tf.keras.models.load_model('my_model.h5')\n", "\n", "# Show the model architecture\n", "new_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "JXQpbTicTBwt" }, "source": [ "Check its accuracy:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2023-11-08T00:04:24.934763Z", "iopub.status.busy": "2023-11-08T00:04:24.934505Z", "iopub.status.idle": "2023-11-08T00:04:25.179642Z", "shell.execute_reply": "2023-11-08T00:04:25.178886Z" }, "id": "jwEaj9DnTCVA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "32/32 - 0s - loss: 0.4303 - sparse_categorical_accuracy: 0.8610 - 191ms/epoch - 6ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored model, accuracy: 86.10%\n" ] } ], "source": [ "loss, acc = 
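new_model.evaluate(test_images, test_labels, verbose=2)\n", "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "dGXqd4wWJl8O" }, "source": [ "Keras saves models by inspecting their architectures. This technique saves everything:\n", "\n", "- The weight values\n", "- The model's architecture\n", "- The model's training configuration (what you pass to the `.compile()` method)\n", "- The optimizer and its state, if any (this lets you restart training where you left off)\n", "\n", "Keras cannot save `v1.x` optimizers (from `tf.compat.v1.train`) because they are not compatible with checkpoints. For v1.x optimizers, you need to recompile the model after loading, which loses the optimizer's state.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "kAUKJQyGqTNH" }, "source": [ "### Saving custom objects\n", "\n", "If you are using the SavedModel format, you can skip this section. The key difference between the high-level `.keras`/HDF5 formats and the low-level SavedModel format is that the `.keras`/HDF5 formats use object configs to save the model architecture, while SavedModel saves the execution graph. As a result, SavedModels can save custom objects, such as subclassed models and custom layers, without requiring the original code. However, low-level SavedModels can be harder to debug for the same reason, so we recommend the high-level `.keras` format instead, because it is name-based and native to Keras.\n", "\n", "To save custom objects to `.keras` and HDF5, you must do the following:\n", "\n", "1. Define a `get_config` method in your object, and optionally a `from_config` classmethod.\n", " - `get_config(self)` returns a JSON-serializable dictionary of the parameters needed to recreate the object.\n", " - `from_config(cls, config)` uses the config returned by `get_config` to create a new object. By default, this function uses the config as initialization kwargs (`return cls(**config)`).\n", "2. 
Pass the custom objects to the model in one of three ways:\n", " - Register the custom object with the `@tf.keras.utils.register_keras_serializable` decorator. **(recommended)**\n", " - Pass the object directly to the `custom_objects` argument when loading the model. The argument must be a dictionary that maps the string class name to the Python class. For example, `tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`\n", " - Use a `tf.keras.utils.custom_object_scope` with the object included in the `custom_objects` dictionary argument, and place a `tf.keras.models.load_model(path)` call within the scope.\n", "\n", "For examples of custom objects and `get_config`, see the [Writing layers and models from scratch](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) tutorial.\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "save_and_load.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }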