{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "b518b04cbfe0" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:00:38.849154Z", "iopub.status.busy": "2022-12-14T22:00:38.848714Z", "iopub.status.idle": "2022-12-14T22:00:38.852260Z", "shell.execute_reply": "2022-12-14T22:00:38.851721Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "394e705afdd5" }, "source": [ "# 保存和加载 Keras 模型" ] }, { "cell_type": "markdown", "metadata": { "id": "60de82f6bcea" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 运行 在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "a6e70d8d6fa8" }, "source": [ "## 简介\n", "\n", "Keras 模型由多个组件组成:\n", "\n", "- 架构或配置,指定模型包含的层及其连接方式。\n", "- 优化器(通过编译模型来定义)。\n", "- 优化器(通过编译模型来定义)。\n", "- 一组损失和指标(通过编译模型或调用 `add_loss()` 或 `add_metric()` 定义)。\n", "\n", "您可以通过 Keras API 将这些片段一次性保存到磁盘,或仅选择性地保存其中一些片段:\n", "\n", "- 将所有内容以 TensorFlow SavedModel 格式(或较早的 Keras H5 格式)保存到单个存档。这是标准做法。\n", "- 仅保存架构/配置,通常保存为 JSON 文件。\n", "- 仅保存权重值。通常在训练模型时使用。\n", "\n", "我们来看看每个选项。什么时候使用哪个选项?它们是如何工作的?" ] }, { "cell_type": "markdown", "metadata": { "id": "ff15300e41fe" }, "source": [ "## 如何保存和加载模型\n", "\n", "如果您只有 10 秒钟来阅读本指南,则您只需了解以下内容。\n", "\n", "**保存 Keras 模型**\n", "\n", "```python\n", "model = ... # Get model (Sequential, Functional Model, or Model subclass)\n", "model.save('path/to/location')\n", "```\n", "\n", "**重新加载模型:**\n", "\n", "```python\n", "from tensorflow import keras\n", "model = keras.models.load_model('path/to/location')\n", "```\n", "\n", "现在,我们来查看详细信息。" ] }, { "cell_type": "markdown", "metadata": { "id": "41fbd6a3290a" }, "source": [ "## 安装" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:38.856221Z", "iopub.status.busy": "2022-12-14T22:00:38.855729Z", "iopub.status.idle": "2022-12-14T22:00:40.790253Z", "shell.execute_reply": "2022-12-14T22:00:40.789568Z" }, "id": "abff67cc7505" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 22:00:39.820500: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:00:39.820591: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:00:39.820601: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": { "id": "438511b14a90" }, "source": [ "## 全模型保存和加载\n", "\n", "您可以将整个模型保存到单个工件中。它将包括:\n", "\n", "- 模型的架构/配置\n", "- 模型的权重值(在训练过程中学习)\n", "- 模型的编译信息(如果调用了 `compile()`)\n", "- 优化器及其状态(如果有,这使您可以从中断的地方重新启动训练)\n", "\n", "#### API\n", "\n", "- `model.save()` 或 `tf.keras.models.save_model()`\n", "- `tf.keras.models.load_model()`\n", "\n", "您可以使用两种格式将整个模型保存到磁盘:**TensorFlow SavedModel 格式**和**较早的 Keras H5 格式**。推荐使用 SavedModel 格式。它是使用 `model.save()` 时的默认格式。\n", "\n", "您可以通过以下方式切换到 H5 格式:\n", "\n", "- 将 `save_format='h5'` 传递给 `save()`。\n", "- 将以 `.h5` 或 `.keras` 结尾的文件名传递给 `save()`。" ] }, { "cell_type": "markdown", "metadata": { "id": "812f19d9dc7c" }, "source": [ "### SavedModel 格式\n", "\n", "SavedModel 是更全面的保存格式,它可以保存模型架构、权重和调用函数的跟踪 Tensorflow 子计算图。这使 Keras 能够恢复内置层和自定义对象。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:40.794632Z", "iopub.status.busy": "2022-12-14T22:00:40.794217Z", "iopub.status.idle": "2022-12-14T22:00:46.248571Z", "shell.execute_reply": "2022-12-14T22:00:46.247608Z" }, "id": "4d910eb33378" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 2s - loss: 0.2647" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 1s 4ms/step - loss: 0.3378\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: my_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 2ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 2ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 1s - loss: 0.2572" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 3ms/step - loss: 0.2969\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_model():\n", " # Create a simple model.\n", " inputs = keras.Input(shape=(32,))\n", " outputs = keras.layers.Dense(1)(inputs)\n", " model = keras.Model(inputs, outputs)\n", " model.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n", " return model\n", "\n", "\n", "model = get_model()\n", "\n", "# Train the model.\n", "test_input = np.random.random((128, 32))\n", "test_target = np.random.random((128, 1))\n", "model.fit(test_input, test_target)\n", "\n", "# Calling `save('my_model')` creates a SavedModel folder `my_model`.\n", "model.save(\"my_model\")\n", "\n", "# It can be used to reconstruct the model identically.\n", "reconstructed_model = keras.models.load_model(\"my_model\")\n", "\n", "# Let's check:\n", "np.testing.assert_allclose(\n", " model.predict(test_input), reconstructed_model.predict(test_input)\n", ")\n", "\n", "# The reconstructed model is already compiled and has retained the optimizer\n", "# state, so training can resume:\n", "reconstructed_model.fit(test_input, test_target)" ] }, { "cell_type": "markdown", "metadata": { "id": "3f8e96c8a949" }, "source": [ "#### SavedModel 包含的内容\n", "\n", "调用 `model.save('my_model')` 会创建一个名为 `my_model` 的文件夹,其包含以下内容:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:46.252505Z", "iopub.status.busy": "2022-12-14T22:00:46.251935Z", "iopub.status.idle": "2022-12-14T22:00:46.426797Z", "shell.execute_reply": "2022-12-14T22:00:46.425996Z" }, "id": "47be41998ac5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "assets\tfingerprint.pb\tkeras_metadata.pb saved_model.pb variables\r\n" ] } ], "source": [ "!ls my_model" ] }, { "cell_type": "markdown", "metadata": { "id": "ead32a4345d1" }, "source": [ "以下示例演示了在没有重写配置方法的情况下,从 SavedModel 格式加载自定义层所发生的情况。\n", "\n", "有关 SavedModel 格式的详细信息,请参阅 [SavedModel 指南(*磁盘上的 SavedModel 格式*)](https://tensorflow.google.cn/guide/saved_model#the_savedmodel_format_on_disk)。\n", "\n", "#### SavedModel 处理自定义对象的方式\n", "\n", "保存模型和模型的层时,SavedModel 格式会存储类名称、**调用函数**、损失和权重(如果已实现,还包括配置)。调用函数会定义模型/层的计算图。\n", "\n", "在没有模型/层配置的情况下,调用函数用于创建一个与原始模型一样存在的模型,可以训练、评估该模型以及将其用于推断。\n", "\n", "不过,在编写自定义模型或层类时,定义 `get_config` 和 `from_config` 方法始终是一个好习惯。这使您可以在需要时轻松更新计算。如需了解详情,请参阅有关[自定义对象](#custom-objects)的部分。\n", "\n", "示例:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:46.430660Z", "iopub.status.busy": "2022-12-14T22:00:46.430397Z", "iopub.status.idle": "2022-12-14T22:00:47.023656Z", "shell.execute_reply": "2022-12-14T22:00:47.022779Z" }, "id": "28bbf9f611d6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: my_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Original model: <__main__.CustomModel object at 0x7f97b0aaddf0>\n", "Model Loaded with custom objects: <__main__.CustomModel object at 0x7f98920dad90>\n", "Model loaded without the custom object class: \n" ] } ], "source": [ "class CustomModel(keras.Model):\n", " def __init__(self, hidden_units):\n", " super(CustomModel, self).__init__()\n", " self.hidden_units = hidden_units\n", " self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]\n", "\n", " def call(self, inputs):\n", " x = inputs\n", " for layer in self.dense_layers:\n", " x = layer(x)\n", " return x\n", "\n", " def get_config(self):\n", " return {\"hidden_units\": self.hidden_units}\n", "\n", " @classmethod\n", " def from_config(cls, config):\n", " return cls(**config)\n", "\n", "\n", "model = CustomModel([16, 16, 10])\n", "# Build the model by calling it\n", "input_arr = tf.random.uniform((1, 5))\n", "outputs = model(input_arr)\n", "model.save(\"my_model\")\n", "\n", "# Option 1: Load with the custom_object argument.\n", "loaded_1 = keras.models.load_model(\n", " \"my_model\", custom_objects={\"CustomModel\": CustomModel}\n", ")\n", "\n", "# Option 2: Load without the CustomModel class.\n", "\n", "# Delete the custom-defined model class to ensure that the loader does not have\n", "# access to it.\n", "del CustomModel\n", "\n", "loaded_2 = keras.models.load_model(\"my_model\")\n", "np.testing.assert_allclose(loaded_1(input_arr), outputs)\n", "np.testing.assert_allclose(loaded_2(input_arr), outputs)\n", "\n", "print(\"Original model:\", model)\n", "print(\"Model Loaded with custom objects:\", loaded_1)\n", "print(\"Model loaded without the custom object class:\", loaded_2)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ba7b964eb364" }, "source": [ "第一个加载的模型是使用配置和 `CustomModel` 类加载的。第二个模型是通过动态创建类似于原始模型的模型类来加载的。" ] }, { "cell_type": "markdown", "metadata": { "id": "516e30bfbe10" }, "source": [ "#### 序贯模型示例:\n", "\n", "*TensoFlow 2.4 中的新功能*参数 `save_traces` 已添加到 `model.save`,它允许您切换 SavedModel 函数跟踪。保存函数以允许 Keras 在没有原始类定义的情况下重新加载自定义对象,因此当 `save_traces=False` 时,所有自定义对象必须已定义 `get_config`/` from_config` 方法。加载时,必须将自定义对象传递给 `custom_objects` 参数。`save_traces=False` 会减少 SavedModel 使用的磁盘空间并节省时间。" ] }, { "cell_type": "markdown", "metadata": { "id": "71d9e6b3d6af" }, "source": [ "### 如上例所示,加载器动态地创建了一个与原始模型行为类似的新模型。\n", "\n", "Keras 还支持保存单个 HDF5 文件,其中包含模型的架构、权重值和 `compile()` 信息。它是 SavedModel 的轻量化替代选择。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:47.027984Z", "iopub.status.busy": "2022-12-14T22:00:47.027135Z", "iopub.status.idle": "2022-12-14T22:00:48.011937Z", "shell.execute_reply": "2022-12-14T22:00:48.011083Z" }, "id": "1ae0912f6f9b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 1s - loss: 0.3572" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 3ms/step - loss: 0.3264\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 1ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 1ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/4 [======>.......................] - ETA: 0s - loss: 0.3183" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "4/4 [==============================] - 0s 3ms/step - loss: 0.2991\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = get_model()\n", "\n", "# Train the model.\n", "test_input = np.random.random((128, 32))\n", "test_target = np.random.random((128, 1))\n", "model.fit(test_input, test_target)\n", "\n", "# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.\n", "model.save(\"my_h5_model.h5\")\n", "\n", "# It can be used to reconstruct the model identically.\n", "reconstructed_model = keras.models.load_model(\"my_h5_model.h5\")\n", "\n", "# Let's check:\n", "np.testing.assert_allclose(\n", " model.predict(test_input), reconstructed_model.predict(test_input)\n", ")\n", "\n", "# The reconstructed model is already compiled and has retained the optimizer\n", "# state, so training can resume:\n", "reconstructed_model.fit(test_input, test_target)" ] }, { "cell_type": "markdown", "metadata": { "id": "116bd1b1b215" }, "source": [ "#### 函数式模型示例:\n", "\n", "与 SavedModel 格式相比,H5 文件不包括以下两方面内容:\n", "\n", "- 通过 model.add_loss() 和 `model.add_metric()` 添加的外部损失和指标不会被保存(这与 SavedModel 不同)。如果您的模型有此类损失和指标且您想要恢复训练,则您需要在加载模型后自行重新添加这些损失。请注意,这不适用于通过 self.add_loss() 和 `self.add_metric()` 在层创建的损失/指标。只要该层被加载,这些损失和指标就会被保留,因为它们是该层 `call` 方法的一部分。\n", "- 已保存的文件中不包含**自定义对象(如自定义层)的计算图**。在加载时,Keras 需要访问这些对象的 Python 类/函数以重建模型。请参阅[自定义对象](#custom-objects)。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "bf78706009bf" }, "source": [ "## 保存架构\n", "\n", "模型的配置(或架构)指定模型包含的层,以及这些层的连接方式*。如果您有模型的配置,则可以使用权重的新初始化状态创建模型,而无需编译信息。\n", "\n", "*请注意,这仅适用于使用函数式或序列式 API 定义的模型,不适用于子类化模型。" ] }, { "cell_type": "markdown", "metadata": { "id": "58a708dbb5da" }, "source": [ "### 序贯模型或函数式 API 模型的配置\n", "\n", "这些类型的模型是显式的层计算图:它们的配置始终以结构化形式提供。\n", "\n", "#### API\n", "\n", "- `get_config()` 和 `from_config()`\n", "- `tf.keras.models.model_to_json()` 和 `tf.keras.models.model_from_json()`" ] }, { "cell_type": "markdown", "metadata": { "id": "3d8b20812b50" }, "source": [ "#### `get_config()` 和 `from_config()`\n", "\n", "调用 `config = model.get_config()` 将返回一个包含模型配置的 Python 字典。然后可以通过 `Sequential.from_config(config)`(针对 `Sequential` 模型)或 `Model.from_config(config)`(针对函数式 API 模型)重建同一模型。\n", "\n", "相同的工作流也适用于任何可序列化的层。\n", "\n", "**层示例:**" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.015953Z", "iopub.status.busy": "2022-12-14T22:00:48.015191Z", "iopub.status.idle": "2022-12-14T22:00:48.020819Z", "shell.execute_reply": "2022-12-14T22:00:48.020173Z" }, "id": "4f26b94e879a" }, "outputs": [], "source": [ "layer = keras.layers.Dense(3, activation=\"relu\")\n", "layer_config = layer.get_config()\n", "new_layer = keras.layers.Dense.from_config(layer_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "a7e5dd2a439c" }, "source": [ "**示例:**" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.024342Z", "iopub.status.busy": "2022-12-14T22:00:48.023801Z", "iopub.status.idle": "2022-12-14T22:00:48.049651Z", "shell.execute_reply": "2022-12-14T22:00:48.049065Z" }, "id": "ae0842be8a2a" }, "outputs": [], "source": [ "model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])\n", "config = model.get_config()\n", "new_model = keras.Sequential.from_config(config)" ] }, { "cell_type": "markdown", "metadata": { "id": "1e97ca5f73d7" }, "source": [ "**函数式模型示例:**" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.052949Z", "iopub.status.busy": "2022-12-14T22:00:48.052489Z", "iopub.status.idle": "2022-12-14T22:00:48.076770Z", "shell.execute_reply": "2022-12-14T22:00:48.076191Z" }, "id": "da001f34e412" }, "outputs": [], "source": [ "inputs = keras.Input((32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = keras.Model(inputs, outputs)\n", "config = model.get_config()\n", "new_model = keras.Model.from_config(config)" ] }, { "cell_type": "markdown", "metadata": { "id": "d7c08fae3eef" }, "source": [ "#### `to_json()` 和 `tf.keras.models.model_from_json()`\n", "\n", "这与 `get_config` / `from_config` 类似,不同之处在于它会将模型转换成 JSON 字符串,之后该字符串可以在没有原始模型类的情况下进行加载。它还特定于模型,不适用于层。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.080194Z", "iopub.status.busy": "2022-12-14T22:00:48.079615Z", "iopub.status.idle": "2022-12-14T22:00:48.103213Z", "shell.execute_reply": "2022-12-14T22:00:48.102646Z" }, "id": "12885447bd35" }, "outputs": [], "source": [ "model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])\n", "json_config = model.to_json()\n", "new_model = keras.models.model_from_json(json_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "edcae4bf461c" }, "source": [ "### 仅加载 TensorFlow 计算图\n", "\n", "**模型和层**\n", "\n", "子类化模型和层的架构在 `__init__` 和 `call` 方法中进行定义。它们被视为 Python 字节码,无法将其序列化为与 JSON 兼容的配置。您可以尝试将字节码序列化(例如通过 `pickle`),但这样做极不安全,因为模型将无法在其他系统上进行加载。\n", "\n", "为了保存/加载带有自定义层的模型或子类化模型,您应该重写 `get_config` 和 `from_config`(可选)方法。此外,您还应该注册自定义对象,以便 Keras 能够感知它。\n", "\n", "**自定义函数**\n", "\n", "自定义函数(如激活损失或初始化)不需要 `get_config` 方法。只需将函数名称注册为自定义对象,就足以进行加载。\n", "\n", "**仅加载 TensorFlow 计算图**\n", "\n", "您可以加载由 Keras 生成的 TensorFlow 计算图。要进行此类加载,您无需提供任何 `custom_objects`。您可以执行以下代码进行加载:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.106663Z", "iopub.status.busy": "2022-12-14T22:00:48.106205Z", "iopub.status.idle": "2022-12-14T22:00:48.385595Z", "shell.execute_reply": "2022-12-14T22:00:48.384924Z" }, "id": "1651c6825106" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: my_model/assets\n" ] } ], "source": [ "model.save(\"my_model\")\n", "tensorflow_graph = tf.saved_model.load(\"my_model\")\n", "x = np.random.uniform(size=(4, 32)).astype(np.float32)\n", "predicted = tensorflow_graph(x).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "b15faa16734b" }, "source": [ "请注意,此方式有几个缺点:\n", "\n", "- tf.saved_model.load 返回的对象不是 Keras 模型,因此不太容易使用。例如,您将无法访问 .predict().fit()。\n", "- `tf.saved_model.load` 返回的对象不是 Keras 模型,因此不太容易使用。例如,您将无法访问 `.predict()` 或 `.fit()`。\n", "\n", "虽然不鼓励使用此方式,但当您遇到棘手问题(例如,您丢失了自定义对象的代码,或在使用 `tf.keras.models.load_model()` 加载模型时遇到问题)时,它还是能够提供帮助。\n", "\n", "有关详细信息,请参阅 [`tf.saved_model.load` 相关页面](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load)。" ] }, { "cell_type": "markdown", "metadata": { "id": "d308bc27a04d" }, "source": [ "#### 定义配置方法\n", "\n", "规范:\n", "\n", "- `get_config` 应该返回一个 JSON 可序列化字典,以便兼容 Keras 节省架构和模型的 API。\n", "- `from_config(config)` (`classmethod`) 应返回从配置创建的新层或模型对象。默认实现返回 `cls(**config)`。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.389287Z", "iopub.status.busy": "2022-12-14T22:00:48.389062Z", "iopub.status.idle": "2022-12-14T22:00:48.399784Z", "shell.execute_reply": "2022-12-14T22:00:48.399193Z" }, "id": "e18c4668dadc" }, "outputs": [], "source": [ "class CustomLayer(keras.layers.Layer):\n", " def __init__(self, a):\n", " self.var = tf.Variable(a, name=\"var_a\")\n", "\n", " def call(self, inputs, training=False):\n", " if training:\n", " return inputs * self.var\n", " else:\n", " return inputs\n", "\n", " def get_config(self):\n", " return {\"a\": self.var.numpy()}\n", "\n", " # There's actually no need to define `from_config` here, since returning\n", " # `cls(**config)` is the default behavior.\n", " @classmethod\n", " def from_config(cls, config):\n", " return cls(**config)\n", "\n", "\n", "layer = CustomLayer(5)\n", "layer.var.assign(2)\n", "\n", "serialized_layer = keras.layers.serialize(layer)\n", "new_layer = keras.layers.deserialize(\n", " serialized_layer, custom_objects={\"CustomLayer\": CustomLayer}\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "425a9baa574e" }, "source": [ "#### 注册自定义对象\n", "\n", "Keras 会记录哪个类生成了配置。在上面的示例中,`tf.keras.layers.serialize` 会生成自定义层的序列化形式:\n", "\n", "```\n", "{'class_name': 'CustomLayer', 'config': {'a': 2}}\n", "```\n", "\n", "Keras 会维护一份所有内置层、模型、优化器和指标类的主列表,用于查找正确的类以调用 `from_config`。如果找不到该类,则会引发错误 (`Value Error: Unknown layer`)。可以通过几种方式将自定义类注册到此列表中:\n", "\n", "1. 在加载函数中设置 `custom_objects` 参数。(请参阅上文”定义配置方法“部分中的示例)\n", "2. `tf.keras.utils.custom_object_scope` 或者 `tf.keras.utils.CustomObjectScope`\n", "3. `tf.keras.utils.register_keras_serializable`" ] }, { "cell_type": "markdown", "metadata": { "id": "a047be0ba572" }, "source": [ "#### 示例:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.403341Z", "iopub.status.busy": "2022-12-14T22:00:48.402884Z", "iopub.status.idle": "2022-12-14T22:00:48.460812Z", "shell.execute_reply": "2022-12-14T22:00:48.460227Z" }, "id": "04a82ec30b5c" }, "outputs": [], "source": [ "class CustomLayer(keras.layers.Layer):\n", " def __init__(self, units=32, **kwargs):\n", " super(CustomLayer, self).__init__(**kwargs)\n", " self.units = units\n", "\n", " def build(self, input_shape):\n", " self.w = self.add_weight(\n", " shape=(input_shape[-1], self.units),\n", " initializer=\"random_normal\",\n", " trainable=True,\n", " )\n", " self.b = self.add_weight(\n", " shape=(self.units,), initializer=\"random_normal\", trainable=True\n", " )\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n", "\n", " def get_config(self):\n", " config = super(CustomLayer, self).get_config()\n", " config.update({\"units\": self.units})\n", " return config\n", "\n", "\n", "def custom_activation(x):\n", " return tf.nn.tanh(x) ** 2\n", "\n", "\n", "# Make a model with the CustomLayer and custom_activation\n", "inputs = keras.Input((32,))\n", "x = CustomLayer(32)(inputs)\n", "outputs = keras.layers.Activation(custom_activation)(x)\n", "model = keras.Model(inputs, outputs)\n", "\n", "# Retrieve the config\n", "config = model.get_config()\n", "\n", "# At loading time, register the custom objects with a `custom_object_scope`:\n", "custom_objects = {\"CustomLayer\": CustomLayer, \"custom_activation\": custom_activation}\n", "with keras.utils.custom_object_scope(custom_objects):\n", " new_model = keras.Model.from_config(config)" ] }, { "cell_type": "markdown", "metadata": { "id": "13c7f2a1be03" }, "source": [ "### 内存中模型克隆\n", "\n", "您还可以通过 `tf.keras.models.clone_model()` 在内存中克隆模型。这相当于获取模型的配置,然后通过配置重建模型(因此它不会保留编译信息或层的权重值)。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.464278Z", "iopub.status.busy": "2022-12-14T22:00:48.463783Z", "iopub.status.idle": "2022-12-14T22:00:48.480353Z", "shell.execute_reply": "2022-12-14T22:00:48.479771Z" }, "id": "93056ffe6eb4" }, "outputs": [], "source": [ "with keras.utils.custom_object_scope(custom_objects):\n", " new_model = keras.models.clone_model(model)" ] }, { "cell_type": "markdown", "metadata": { "id": "05c91a5a23e3" }, "source": [ "## 您只需使用模型进行推断:在这种情况下,您无需重新开始训练,因此不需要编译信息或优化器状态。\n", "\n", "在内存中将权重从一层转移到另一层\n", "\n", "- 您只需使用模型进行推断:在这种情况下,您无需重新开始训练,因此不需要编译信息或优化器状态。\n", "- 您正在进行迁移学习:在这种情况下,您需要重用先验模型的状态来训练新模型,因此不需要先验模型的编译信息。" ] }, { "cell_type": "markdown", "metadata": { "id": "c5229f4014f2" }, "source": [ "### 用于内存中权重迁移的 API\n", "\n", "您可以使用 `get_weights` 和 `set_weights` 在不同对象之间复制权重:\n", "\n", "- `tf.keras.layers.Layer.get_weights()`:返回 Numpy 数组列表。\n", "- `tf.keras.layers.Layer.set_weights()`:将模型权重设置为 `weights` 参数中的值。\n", "\n", "示例如下。\n", "\n", "***通常建议使用相同的 API 来构建模型。如果您在序贯模型和函数式模型之间,或在函数式模型和子类化模型等之间进行切换,请始终重新构建预训练模型并将预训练权重加载到该模型。***" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.483722Z", "iopub.status.busy": "2022-12-14T22:00:48.483248Z", "iopub.status.idle": "2022-12-14T22:00:48.501678Z", "shell.execute_reply": "2022-12-14T22:00:48.501137Z" }, "id": "c9124df19cb2" }, "outputs": [], "source": [ "def create_layer():\n", " layer = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")\n", " layer.build((None, 784))\n", " return layer\n", "\n", "\n", "layer_1 = create_layer()\n", "layer_2 = create_layer()\n", "\n", "# Copy weights from layer 1 to layer 2\n", "layer_2.set_weights(layer_1.get_weights())" ] }, { "cell_type": "markdown", "metadata": { "id": "ff7945516c7d" }, "source": [ "***在内存中将权重从一个模型转移到另一个具有兼容架构的模型***" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.505115Z", "iopub.status.busy": "2022-12-14T22:00:48.504600Z", "iopub.status.idle": "2022-12-14T22:00:48.569253Z", "shell.execute_reply": "2022-12-14T22:00:48.568686Z" }, "id": "11005d4023d4" }, "outputs": [], "source": [ "# Create a simple functional model\n", "inputs = keras.Input(shape=(784,), name=\"digits\")\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", "outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n", "functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n", "\n", "# Define a subclassed model with the same architecture\n", "class SubclassedModel(keras.Model):\n", " def __init__(self, output_dim, name=None):\n", " super(SubclassedModel, self).__init__(name=name)\n", " self.output_dim = output_dim\n", " self.dense_1 = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")\n", " self.dense_2 = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")\n", " self.dense_3 = keras.layers.Dense(output_dim, name=\"predictions\")\n", "\n", " def call(self, inputs):\n", " x = self.dense_1(inputs)\n", " x = self.dense_2(x)\n", " x = self.dense_3(x)\n", " return x\n", "\n", " def get_config(self):\n", " return {\"output_dim\": self.output_dim, \"name\": self.name}\n", "\n", "\n", "subclassed_model = SubclassedModel(10)\n", "# Call the subclassed model once to create the weights.\n", "subclassed_model(tf.ones((1, 784)))\n", "\n", "# Copy weights from functional_model to subclassed_model.\n", "subclassed_model.set_weights(functional_model.get_weights())\n", "\n", "assert len(functional_model.weights) == len(subclassed_model.weights)\n", "for a, b in zip(functional_model.weights, subclassed_model.weights):\n", " np.testing.assert_allclose(a.numpy(), b.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "bd4d08bff725" }, "source": [ "***无状态层的情况***\n", "\n", "无状态层不会改变权重的顺序或数量,因此即便存在额外的/缺失的无状态层,模型也可以具有兼容架构。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.572654Z", "iopub.status.busy": "2022-12-14T22:00:48.572183Z", "iopub.status.idle": "2022-12-14T22:00:48.632237Z", "shell.execute_reply": "2022-12-14T22:00:48.631669Z" }, "id": "927dc7934d44" }, "outputs": [], "source": [ "inputs = keras.Input(shape=(784,), name=\"digits\")\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", "outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n", "functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n", "\n", "inputs = keras.Input(shape=(784,), name=\"digits\")\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", "\n", "# Add a dropout layer, which does not contain any weights.\n", "x = keras.layers.Dropout(0.5)(x)\n", "outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n", "functional_model_with_dropout = keras.Model(\n", " inputs=inputs, outputs=outputs, name=\"3_layer_mlp\"\n", ")\n", "\n", "functional_model_with_dropout.set_weights(functional_model.get_weights())" ] }, { "cell_type": "markdown", "metadata": { "id": "199e984872d3" }, "source": [ "### 用于将权重保存到磁盘并将其加载回来的 API\n", "\n", "可以用以下格式调用 `model.save_weights`,将权重保存到磁盘:\n", "\n", "- TensorFlow 检查点\n", "- HDF5\n", "\n", "`model.save_weights` 的默认格式是 TensorFlow 检查点。可以通过以下两种方式指定保存格式:\n", "\n", "1. `save_format` 参数:将值设置为 `save_format=\"tf\"` 或 `save_format=\"h5\"`。\n", "2. `path` 参数:如果路径以 `.h5` 或 `.hdf5` 结束,则使用 HDF5 格式。除非设置了 `save_format`,否则对于其他后缀,将使用 TensorFlow 检查点格式。\n", "\n", "您还可以选择将权重作为内存中 Numpy 数组取回。每个 API 都有自己的优缺点,详情如下。" ] }, { "cell_type": "markdown", "metadata": { "id": "3505dc65d6c1" }, "source": [ "### TF 检查点格式\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.635568Z", "iopub.status.busy": "2022-12-14T22:00:48.635322Z", "iopub.status.idle": "2022-12-14T22:00:48.687247Z", "shell.execute_reply": "2022-12-14T22:00:48.686592Z" }, "id": "f92053377391" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Runnable example\n", "sequential_model = keras.Sequential(\n", " [\n", " keras.Input(shape=(784,), name=\"digits\"),\n", " keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\"),\n", " keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\"),\n", " keras.layers.Dense(10, name=\"predictions\"),\n", " ]\n", ")\n", "sequential_model.save_weights(\"ckpt\")\n", "load_status = sequential_model.load_weights(\"ckpt\")\n", "\n", "# `assert_consumed` can be used as validation that all variable values have been\n", "# restored from the checkpoint. See `tf.train.Checkpoint.restore` for other\n", "# methods in the Status object.\n", "load_status.assert_consumed()" ] }, { "cell_type": "markdown", "metadata": { "id": "87f1145ac846" }, "source": [ "#### 格式详细信息\n", "\n", "TensorFlow 检查点格式使用对象特性名称来保存和恢复权重。以 `tf.keras.layers.Dense` 层为例。该层包含两个权重:`dense.kernel` 和 `dense.bias`。将层保存为 `tf` 格式后,生成的检查点会包含 `\"kernel\"` 和 `\"bias\"` 键及其对应的权重值。有关详情,请参阅 [TF 检查点指南中的“加载机制”](https://tensorflow.google.cn/guide/checkpoint#loading_mechanics)。\n", "\n", "请注意,特性/计算图边缘根据**父对象中使用的名称而非变量的名称**进行命名。请考虑下面示例中的 `CustomLayer`。变量 `CustomLayer.var` 是将 `\"var\"` 而非 `\"var_a\"` 作为键的一部分来保存的。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.690199Z", "iopub.status.busy": "2022-12-14T22:00:48.689952Z", "iopub.status.idle": "2022-12-14T22:00:48.707997Z", "shell.execute_reply": "2022-12-14T22:00:48.707411Z" }, "id": "c919189b3697" }, "outputs": [ { "data": { "text/plain": [ "{'save_counter/.ATTRIBUTES/VARIABLE_VALUE': tf.int64,\n", " 'layer/var/.ATTRIBUTES/VARIABLE_VALUE': tf.int32,\n", " '_CHECKPOINTABLE_OBJECT_GRAPH': tf.string}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomLayer(keras.layers.Layer):\n", " def __init__(self, a):\n", " self.var = tf.Variable(a, name=\"var_a\")\n", "\n", "\n", "layer = CustomLayer(5)\n", "layer_ckpt = tf.train.Checkpoint(layer=layer).save(\"custom_layer\")\n", "\n", "ckpt_reader = tf.train.load_checkpoint(layer_ckpt)\n", "\n", "ckpt_reader.get_variable_to_dtype_map()" ] }, { "cell_type": "markdown", "metadata": { "id": "c4e5a7162b13" }, "source": [ "#### 迁移学习示例\n", "\n", "本质上,只要两个模型具有相同的架构,它们就可以共享同一个检查点。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.711157Z", "iopub.status.busy": "2022-12-14T22:00:48.710701Z", "iopub.status.idle": "2022-12-14T22:00:48.857738Z", "shell.execute_reply": "2022-12-14T22:00:48.857193Z" }, "id": "78d08199d27f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"pretrained_model\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " digits (InputLayer) [(None, 784)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 64) 50240 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_2 (Dense) (None, 64) 4160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 54,400\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 54,400\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " --------------------------------------------------\n", "Model: \"new_model\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " digits (InputLayer) [(None, 784)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 64) 50240 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_2 (Dense) (None, 64) 4160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " predictions (Dense) (None, 5) 325 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 54,725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 54,725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_3\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " pretrained (Functional) (None, 64) 54400 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " predictions (Dense) (None, 5) 325 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 54,725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 54,725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = keras.Input(shape=(784,), name=\"digits\")\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", "outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n", "functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n", "\n", "# Extract a portion of the functional model defined in the Setup section.\n", "# The following lines produce a new model that excludes the final output\n", "# layer of the functional model.\n", "pretrained = keras.Model(\n", " functional_model.inputs, functional_model.layers[-1].input, name=\"pretrained_model\"\n", ")\n", "# Randomly assign \"trained\" weights.\n", "for w in pretrained.weights:\n", " w.assign(tf.random.normal(w.shape))\n", "pretrained.save_weights(\"pretrained_ckpt\")\n", "pretrained.summary()\n", "\n", "# Assume this is a separate program where only 'pretrained_ckpt' exists.\n", "# Create a new functional model with a different output dimension.\n", "inputs = keras.Input(shape=(784,), name=\"digits\")\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", "outputs = keras.layers.Dense(5, name=\"predictions\")(x)\n", "model = keras.Model(inputs=inputs, outputs=outputs, name=\"new_model\")\n", "\n", "# Load the weights from pretrained_ckpt into model.\n", "model.load_weights(\"pretrained_ckpt\")\n", "\n", "# Check that all of the pretrained weights have been loaded.\n", "for a, b in zip(pretrained.weights, model.weights):\n", " np.testing.assert_allclose(a.numpy(), b.numpy())\n", "\n", "print(\"\\n\", \"-\" * 50)\n", "model.summary()\n", "\n", "# Example 2: Sequential model\n", "# Recreate the pretrained model, and load the saved weights.\n", "inputs = keras.Input(shape=(784,), name=\"digits\")\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", "x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", "pretrained_model = keras.Model(inputs=inputs, outputs=x, name=\"pretrained\")\n", "\n", "# Sequential example:\n", "model = keras.Sequential([pretrained_model, keras.layers.Dense(5, name=\"predictions\")])\n", "model.summary()\n", "\n", "pretrained_model.load_weights(\"pretrained_ckpt\")\n", "\n", "# Warning! Calling `model.load_weights('pretrained_ckpt')` won't throw an error,\n", "# but will *not* work as expected. If you inspect the weights, you'll see that\n", "# none of the weights will have loaded. `pretrained_model.load_weights()` is the\n", "# correct method to call." ] }, { "cell_type": "markdown", "metadata": { "id": "7b07ad5fe5b0" }, "source": [ "通常建议使用相同的 API 来构建模型。如果您在序贯模型和函数式模型之间切换,或在函数式模型和子类化模型等之间切换,请始终重新构建预训练模型并将预训练权重加载到该模型。" ] }, { "cell_type": "markdown", "metadata": { "id": "2ab83c542e2d" }, "source": [ "下一个问题是,如果模型架构截然不同,如何保存权重并将其加载到不同模型?解决方案是使用 `tf.train.Checkpoint` 来保存和恢复确切的层/变量。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.860842Z", "iopub.status.busy": "2022-12-14T22:00:48.860624Z", "iopub.status.idle": "2022-12-14T22:00:48.893681Z", "shell.execute_reply": "2022-12-14T22:00:48.893143Z" }, "id": "97037b9ea265" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_131352/1562824211.py:15: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.\n", " self.kernel = self.add_variable(\"kernel\", shape=(64, 10))\n", "/tmpfs/tmp/ipykernel_131352/1562824211.py:16: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.\n", " self.bias = self.add_variable(\"bias\", shape=(10,))\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a subclassed model that essentially uses functional_model's first\n", "# and last layers.\n", "# First, save the weights of functional_model's first and last dense layers.\n", "first_dense = functional_model.layers[1]\n", "last_dense = functional_model.layers[-1]\n", "ckpt_path = tf.train.Checkpoint(\n", " dense=first_dense, kernel=last_dense.kernel, bias=last_dense.bias\n", ").save(\"ckpt\")\n", "\n", "# Define the subclassed model.\n", "class ContrivedModel(keras.Model):\n", " def __init__(self):\n", " super(ContrivedModel, self).__init__()\n", " self.first_dense = keras.layers.Dense(64)\n", " self.kernel = self.add_variable(\"kernel\", shape=(64, 10))\n", " self.bias = self.add_variable(\"bias\", shape=(10,))\n", "\n", " def call(self, inputs):\n", " x = self.first_dense(inputs)\n", " return tf.matmul(x, self.kernel) + self.bias\n", "\n", "\n", "model = ContrivedModel()\n", "# Call model on inputs to create the variables of the dense layer.\n", "_ = model(tf.ones((1, 784)))\n", "\n", "# Create a Checkpoint with the same structure as before, and load the weights.\n", "tf.train.Checkpoint(\n", " dense=model.first_dense, kernel=model.kernel, bias=model.bias\n", ").restore(ckpt_path).assert_consumed()" ] }, { "cell_type": "markdown", "metadata": { "id": "18356461e7dd" }, "source": [ "### HDF5 格式\n", "\n", "HDF5 格式包含按层名称分组的权重。权重是通过将可训练权重列表与不可训练权重列表连接起来进行排序的列表(与 `layer.weights` 相同)。因此,如果模型的层和可训练状态与保存在检查点中的相同,则可以使用 HDF5 检查点。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.897362Z", "iopub.status.busy": "2022-12-14T22:00:48.897127Z", "iopub.status.idle": "2022-12-14T22:00:48.939263Z", "shell.execute_reply": "2022-12-14T22:00:48.938571Z" }, "id": "43aec1e07913" }, "outputs": [], "source": [ "# Runnable example\n", "sequential_model = keras.Sequential(\n", " [\n", " keras.Input(shape=(784,), name=\"digits\"),\n", " keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\"),\n", " keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\"),\n", " keras.layers.Dense(10, name=\"predictions\"),\n", " ]\n", ")\n", "sequential_model.save_weights(\"weights.h5\")\n", "sequential_model.load_weights(\"weights.h5\")" ] }, { "cell_type": "markdown", "metadata": { "id": "dc63aef6e0d3" }, "source": [ "请注意,当模型包含嵌套层时,更改 `layer.trainable` 可能导致 `layer.weights` 的顺序不同。" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.942672Z", "iopub.status.busy": "2022-12-14T22:00:48.942177Z", "iopub.status.idle": "2022-12-14T22:00:48.988311Z", "shell.execute_reply": "2022-12-14T22:00:48.987713Z" }, "id": "83b70826944a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "variables: ['nested/dense_1/kernel:0', 'nested/dense_1/bias:0', 'nested/dense_2/kernel:0', 'nested/dense_2/bias:0']\n", "\n", "Changing trainable status of one of the nested layers...\n", "\n", "variables: ['nested/dense_2/kernel:0', 'nested/dense_2/bias:0', 'nested/dense_1/kernel:0', 'nested/dense_1/bias:0']\n", "variable ordering changed: True\n" ] } ], "source": [ "class NestedDenseLayer(keras.layers.Layer):\n", " def __init__(self, units, name=None):\n", " super(NestedDenseLayer, self).__init__(name=name)\n", " self.dense_1 = keras.layers.Dense(units, name=\"dense_1\")\n", " self.dense_2 = keras.layers.Dense(units, name=\"dense_2\")\n", "\n", " def call(self, inputs):\n", " return self.dense_2(self.dense_1(inputs))\n", "\n", "\n", "nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, \"nested\")])\n", "variable_names = [v.name for v in nested_model.weights]\n", "print(\"variables: {}\".format(variable_names))\n", "\n", "print(\"\\nChanging trainable status of one of the nested layers...\")\n", "nested_model.get_layer(\"nested\").dense_1.trainable = False\n", "\n", "variable_names_2 = [v.name for v in nested_model.weights]\n", "print(\"\\nvariables: {}\".format(variable_names_2))\n", "print(\"variable ordering changed:\", variable_names != variable_names_2)" ] }, { "cell_type": "markdown", "metadata": { "id": "cc261c1a31ee" }, "source": [ "#### 迁移学习示例\n", "\n", "从 HDF5 加载预训练权重时,建议将权重加载到设置了检查点的原始模型中,然后将所需的权重/层提取到新模型中。\n", "\n", "**示例:**" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:00:48.991546Z", "iopub.status.busy": "2022-12-14T22:00:48.990972Z", "iopub.status.idle": "2022-12-14T22:00:49.082460Z", "shell.execute_reply": "2022-12-14T22:00:49.081885Z" }, "id": "06cabc31494a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_6\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 64) 50240 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_2 (Dense) (None, 64) 4160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_3 (Dense) (None, 5) 325 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 54,725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 54,725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "def create_functional_model():\n", " inputs = keras.Input(shape=(784,), name=\"digits\")\n", " x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n", " x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n", " outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n", " return keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n", "\n", "\n", "functional_model = create_functional_model()\n", "functional_model.save_weights(\"pretrained_weights.h5\")\n", "\n", "# In a separate program:\n", "pretrained_model = create_functional_model()\n", "pretrained_model.load_weights(\"pretrained_weights.h5\")\n", "\n", "# Create a new model by extracting layers from the original model:\n", "extracted_layers = pretrained_model.layers[:-1]\n", "extracted_layers.append(keras.layers.Dense(5, name=\"dense_3\"))\n", "model = keras.Sequential(extracted_layers)\n", "model.summary()" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "save_and_serialize.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }