{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "b518b04cbfe0" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:00:38.849154Z", "iopub.status.busy": "2022-12-14T22:00:38.848714Z", "iopub.status.idle": "2022-12-14T22:00:38.852260Z", "shell.execute_reply": "2022-12-14T22:00:38.851721Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "394e705afdd5" }, "source": [ "# Save and load Keras models" ] }, { "cell_type": "markdown", "metadata": { "id": "60de82f6bcea" }, "source": [
"- External losses and metrics added via `model.add_loss()` and `model.add_metric()` are not saved (unlike SavedModel). If your model has such losses and metrics and you want to resume training, you need to add these losses back yourself after loading the model. Note that this does not apply to losses/metrics created inside layers via `self.add_loss()` and `self.add_metric()`. As long as the layer gets loaded, these losses and metrics are kept, since they are part of the `call` method of the layer.\n",
"- The **computation graph of custom objects**, such as custom layers, is not included in the saved file. At loading time, Keras needs access to the Python classes/functions of these objects in order to reconstruct the model. See [Custom objects](#custom-objects).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bf78706009bf"
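},
"source": [
"## Saving the architecture\n",
"\n",
"The model's configuration (or architecture) specifies which layers the model contains and how these layers are connected*. If you have the configuration of a model, you can create the model with freshly initialized weights and no compilation information.\n",
"\n",
"*Note that this only applies to models defined using the Functional or Sequential APIs, not to subclassed models."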
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "58a708dbb5da"
},
"source": [
"### Configuration of a Sequential model or Functional API model\n",
"\n",
"These types of models are explicit graphs of layers: their configuration is always available in a structured form.\n",
"\n",
"#### APIs\n",
"\n",
"- `get_config()` and `from_config()`\n",
"- `to_json()` and `tf.keras.models.model_from_json()`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3d8b20812b50"
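},
"source": [
"#### `get_config()` and `from_config()`\n",
"\n",
"Calling `config = model.get_config()` returns a Python dict containing the configuration of the model. The same model can then be reconstructed via `Sequential.from_config(config)` (for a `Sequential` model) or `Model.from_config(config)` (for a Functional API model).\n",
"\n",
"The same workflow also works for any serializable layer.\n",
"\n",
"**Layer example:**"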
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.015953Z",
"iopub.status.busy": "2022-12-14T22:00:48.015191Z",
"iopub.status.idle": "2022-12-14T22:00:48.020819Z",
"shell.execute_reply": "2022-12-14T22:00:48.020173Z"
},
"id": "4f26b94e879a"
},
"outputs": [],
"source": [
"layer = keras.layers.Dense(3, activation=\"relu\")\n",
"layer_config = layer.get_config()\n",
"new_layer = keras.layers.Dense.from_config(layer_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7e5dd2a439c"
},
"source": [
"**Sequential model example:**"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.024342Z",
"iopub.status.busy": "2022-12-14T22:00:48.023801Z",
"iopub.status.idle": "2022-12-14T22:00:48.049651Z",
"shell.execute_reply": "2022-12-14T22:00:48.049065Z"
},
"id": "ae0842be8a2a"
},
"outputs": [],
"source": [
"model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])\n",
"config = model.get_config()\n",
"new_model = keras.Sequential.from_config(config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1e97ca5f73d7"
},
"source": [
"**Functional model example:**"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.052949Z",
"iopub.status.busy": "2022-12-14T22:00:48.052489Z",
"iopub.status.idle": "2022-12-14T22:00:48.076770Z",
"shell.execute_reply": "2022-12-14T22:00:48.076191Z"
},
"id": "da001f34e412"
},
"outputs": [],
"source": [
"inputs = keras.Input((32,))\n",
"outputs = keras.layers.Dense(1)(inputs)\n",
"model = keras.Model(inputs, outputs)\n",
"config = model.get_config()\n",
"new_model = keras.Model.from_config(config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d7c08fae3eef"
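},
"source": [
"#### `to_json()` and `tf.keras.models.model_from_json()`\n",
"\n",
"This is similar to `get_config` / `from_config`, except that it turns the model into a JSON string, which can then be loaded without the original model class. It is also specific to models and is not meant for layers.\n",
"\n",
"**Example:**"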
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.080194Z",
"iopub.status.busy": "2022-12-14T22:00:48.079615Z",
"iopub.status.idle": "2022-12-14T22:00:48.103213Z",
"shell.execute_reply": "2022-12-14T22:00:48.102646Z"
},
"id": "12885447bd35"
},
"outputs": [],
"source": [
"model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])\n",
"json_config = model.to_json()\n",
"new_model = keras.models.model_from_json(json_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "edcae4bf461c"
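},
"source": [
"### Custom objects\n",
"\n",
"**Models and layers**\n",
"\n",
"The architecture of subclassed models and layers is defined in the `__init__` and `call` methods. These are considered Python bytecode, which cannot be serialized into a JSON-compatible config. You could try serializing the bytecode (for example via `pickle`), but doing so is unsafe and means the model cannot be loaded on a different system.\n",
"\n",
"In order to save/load a model with custom-defined layers, or a subclassed model, you should override the `get_config` and, optionally, `from_config` methods. Additionally, you should register the custom object so that Keras is aware of it.\n",
"\n",
"**Custom functions**\n",
"\n",
"Custom functions (such as an activation, loss, or initializer) do not need a `get_config` method. Registering the function name as a custom object is enough for loading.\n",
"\n",
"**Loading the TensorFlow graph only**\n",
"\n",
"You can load the TensorFlow graph generated by Keras. If you do so, you do not need to provide any `custom_objects`. You can load it like this:"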
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.106663Z",
"iopub.status.busy": "2022-12-14T22:00:48.106205Z",
"iopub.status.idle": "2022-12-14T22:00:48.385595Z",
"shell.execute_reply": "2022-12-14T22:00:48.384924Z"
},
"id": "1651c6825106"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_model/assets\n"
]
}
],
"source": [
"model.save(\"my_model\")\n",
"tensorflow_graph = tf.saved_model.load(\"my_model\")\n",
"x = np.random.uniform(size=(4, 32)).astype(np.float32)\n",
"predicted = tensorflow_graph(x).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b15faa16734b"
},
"source": [
"Note that this method has a drawback:\n",
"\n",
"- The object returned by `tf.saved_model.load` is not a Keras model, so it is not as easy to use. For example, you will not have access to `.predict()` or `.fit()`.\n",
"\n",
"Even though its use is discouraged, it can help if you are in a tight spot, for example if you lost the code of your custom objects or have issues loading the model with `tf.keras.models.load_model()`.\n",
"\n",
"For more information, see [the `tf.saved_model.load` page](https://tensorflow.google.cn/api_docs/python/tf/saved_model/load)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d308bc27a04d"
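},
"source": [
"#### Defining the config methods\n",
"\n",
"Specifications:\n",
"\n",
"- `get_config` should return a JSON-serializable dictionary in order to be compatible with the Keras architecture-saving and model-saving APIs.\n",
"- `from_config(config)` (`classmethod`) should return a new layer or model object that is created from the config. The default implementation returns `cls(**config)`.\n",
"\n",
"**Example:**"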
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.389287Z",
"iopub.status.busy": "2022-12-14T22:00:48.389062Z",
"iopub.status.idle": "2022-12-14T22:00:48.399784Z",
"shell.execute_reply": "2022-12-14T22:00:48.399193Z"
},
"id": "e18c4668dadc"
},
"outputs": [],
"source": [
"class CustomLayer(keras.layers.Layer):\n",
"    def __init__(self, a):\n",
"        # Initialize the base Layer before creating any variables.\n",
"        super().__init__()\n",
"        self.var = tf.Variable(a, name=\"var_a\")\n",
"\n",
"    def call(self, inputs, training=False):\n",
"        if training:\n",
"            return inputs * self.var\n",
"        else:\n",
"            return inputs\n",
"\n",
"    def get_config(self):\n",
"        return {\"a\": self.var.numpy()}\n",
"\n",
"    # There's actually no need to define `from_config` here, since returning\n",
"    # `cls(**config)` is the default behavior.\n",
"    @classmethod\n",
"    def from_config(cls, config):\n",
"        return cls(**config)\n",
"\n",
"\n",
"layer = CustomLayer(5)\n",
"layer.var.assign(2)\n",
"\n",
"serialized_layer = keras.layers.serialize(layer)\n",
"new_layer = keras.layers.deserialize(\n",
"    serialized_layer, custom_objects={\"CustomLayer\": CustomLayer}\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "425a9baa574e"
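},
"source": [
"#### Registering the custom object\n",
"\n",
"Keras keeps a note of which class generated the config. From the example above, `tf.keras.layers.serialize` generates a serialized form of the custom layer:\n",
"\n",
"```\n",
"{'class_name': 'CustomLayer', 'config': {'a': 2}}\n",
"```\n",
"\n",
"Keras keeps a master list of all built-in layer, model, optimizer, and metric classes, which is used to find the correct class on which to call `from_config`. If the class cannot be found, an error is raised (`ValueError: Unknown layer`). There are a few ways to register custom classes to this list:\n",
"\n",
"1. Setting the `custom_objects` argument in the loading function. (See the example in the section above, \"Defining the config methods\".)\n",
"2. `tf.keras.utils.custom_object_scope` or `tf.keras.utils.CustomObjectScope`\n",
"3. `tf.keras.utils.register_keras_serializable`"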
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a047be0ba572"
},
"source": [
"#### Example:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.403341Z",
"iopub.status.busy": "2022-12-14T22:00:48.402884Z",
"iopub.status.idle": "2022-12-14T22:00:48.460812Z",
"shell.execute_reply": "2022-12-14T22:00:48.460227Z"
},
"id": "04a82ec30b5c"
},
"outputs": [],
"source": [
"class CustomLayer(keras.layers.Layer):\n",
"    def __init__(self, units=32, **kwargs):\n",
"        super(CustomLayer, self).__init__(**kwargs)\n",
"        self.units = units\n",
"\n",
"    def build(self, input_shape):\n",
"        self.w = self.add_weight(\n",
"            shape=(input_shape[-1], self.units),\n",
"            initializer=\"random_normal\",\n",
"            trainable=True,\n",
"        )\n",
"        self.b = self.add_weight(\n",
"            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
"        )\n",
"\n",
"    def call(self, inputs):\n",
"        return tf.matmul(inputs, self.w) + self.b\n",
"\n",
"    def get_config(self):\n",
"        config = super(CustomLayer, self).get_config()\n",
"        config.update({\"units\": self.units})\n",
"        return config\n",
"\n",
"\n",
"def custom_activation(x):\n",
"    return tf.nn.tanh(x) ** 2\n",
"\n",
"\n",
"# Make a model with the CustomLayer and custom_activation\n",
"inputs = keras.Input((32,))\n",
"x = CustomLayer(32)(inputs)\n",
"outputs = keras.layers.Activation(custom_activation)(x)\n",
"model = keras.Model(inputs, outputs)\n",
"\n",
"# Retrieve the config\n",
"config = model.get_config()\n",
"\n",
"# At loading time, register the custom objects with a `custom_object_scope`:\n",
"custom_objects = {\"CustomLayer\": CustomLayer, \"custom_activation\": custom_activation}\n",
"with keras.utils.custom_object_scope(custom_objects):\n",
"    new_model = keras.Model.from_config(config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "13c7f2a1be03"
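},
"source": [
"### In-memory model cloning\n",
"\n",
"You can also clone a model in memory via `tf.keras.models.clone_model()`. This is equivalent to getting the config and then recreating the model from its config (so it does not preserve compilation information or layer weight values).\n",
"\n",
"**Example:**"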
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.464278Z",
"iopub.status.busy": "2022-12-14T22:00:48.463783Z",
"iopub.status.idle": "2022-12-14T22:00:48.480353Z",
"shell.execute_reply": "2022-12-14T22:00:48.479771Z"
},
"id": "93056ffe6eb4"
},
"outputs": [],
"source": [
"with keras.utils.custom_object_scope(custom_objects):\n",
"    new_model = keras.models.clone_model(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "05c91a5a23e3"
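},
"source": [
"## Saving and loading only the model's weight values\n",
"\n",
"You can choose to save and load only a model's weights, for example to transfer weights from one layer to another in memory. This is useful when:\n",
"\n",
"- You only need the model for inference: in this case, you will not need to restart training, so you do not need the compilation information or optimizer state.\n",
"- You are doing transfer learning: in this case, you will be training a new model that reuses the state of a prior model, so you do not need the compilation information of the prior model."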
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c5229f4014f2"
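},
"source": [
"### APIs for in-memory weight transfer\n",
"\n",
"Weights can be copied between different objects by using `get_weights` and `set_weights`:\n",
"\n",
"- `tf.keras.layers.Layer.get_weights()`: Returns a list of NumPy arrays.\n",
"- `tf.keras.layers.Layer.set_weights()`: Sets the weights to the values in the `weights` argument.\n",
"\n",
"Examples below.\n",
"\n",
"***It is generally recommended to stick to the same API for building models. If you switch between Sequential and Functional models, or between Functional and subclassed models, always rebuild the pre-trained model and load the pre-trained weights into that model.***"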
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.483722Z",
"iopub.status.busy": "2022-12-14T22:00:48.483248Z",
"iopub.status.idle": "2022-12-14T22:00:48.501678Z",
"shell.execute_reply": "2022-12-14T22:00:48.501137Z"
},
"id": "c9124df19cb2"
},
"outputs": [],
"source": [
"def create_layer():\n",
"    layer = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")\n",
"    layer.build((None, 784))\n",
"    return layer\n",
"\n",
"\n",
"layer_1 = create_layer()\n",
"layer_2 = create_layer()\n",
"\n",
"# Copy weights from layer 1 to layer 2\n",
"layer_2.set_weights(layer_1.get_weights())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ff7945516c7d"
},
"source": [
"***Transferring weights from one model to another model with a compatible architecture, in memory***"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.505115Z",
"iopub.status.busy": "2022-12-14T22:00:48.504600Z",
"iopub.status.idle": "2022-12-14T22:00:48.569253Z",
"shell.execute_reply": "2022-12-14T22:00:48.568686Z"
},
"id": "11005d4023d4"
},
"outputs": [],
"source": [
"# Create a simple functional model\n",
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n",
"\n",
"# Define a subclassed model with the same architecture\n",
"class SubclassedModel(keras.Model):\n",
"    def __init__(self, output_dim, name=None):\n",
"        super(SubclassedModel, self).__init__(name=name)\n",
"        self.output_dim = output_dim\n",
"        self.dense_1 = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")\n",
"        self.dense_2 = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")\n",
"        self.dense_3 = keras.layers.Dense(output_dim, name=\"predictions\")\n",
"\n",
"    def call(self, inputs):\n",
"        x = self.dense_1(inputs)\n",
"        x = self.dense_2(x)\n",
"        x = self.dense_3(x)\n",
"        return x\n",
"\n",
"    def get_config(self):\n",
"        return {\"output_dim\": self.output_dim, \"name\": self.name}\n",
"\n",
"\n",
"subclassed_model = SubclassedModel(10)\n",
"# Call the subclassed model once to create the weights.\n",
"subclassed_model(tf.ones((1, 784)))\n",
"\n",
"# Copy weights from functional_model to subclassed_model.\n",
"subclassed_model.set_weights(functional_model.get_weights())\n",
"\n",
"assert len(functional_model.weights) == len(subclassed_model.weights)\n",
"for a, b in zip(functional_model.weights, subclassed_model.weights):\n",
"    np.testing.assert_allclose(a.numpy(), b.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bd4d08bff725"
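},
"source": [
"***The case of stateless layers***\n",
"\n",
"Because stateless layers do not change the order or number of weights, models can have compatible architectures even if there are extra or missing stateless layers."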
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.572654Z",
"iopub.status.busy": "2022-12-14T22:00:48.572183Z",
"iopub.status.idle": "2022-12-14T22:00:48.632237Z",
"shell.execute_reply": "2022-12-14T22:00:48.631669Z"
},
"id": "927dc7934d44"
},
"outputs": [],
"source": [
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model = keras.Model(inputs=inputs, outputs=outputs, name=\"3_layer_mlp\")\n",
"\n",
"inputs = keras.Input(shape=(784,), name=\"digits\")\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
"x = keras.layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
"\n",
"# Add a dropout layer, which does not contain any weights.\n",
"x = keras.layers.Dropout(0.5)(x)\n",
"outputs = keras.layers.Dense(10, name=\"predictions\")(x)\n",
"functional_model_with_dropout = keras.Model(\n",
"    inputs=inputs, outputs=outputs, name=\"3_layer_mlp\"\n",
")\n",
"\n",
"functional_model_with_dropout.set_weights(functional_model.get_weights())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "199e984872d3"
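},
"source": [
"### APIs for saving weights to disk and loading them back\n",
"\n",
"Weights can be saved to disk by calling `model.save_weights` in the following formats:\n",
"\n",
"- TensorFlow Checkpoint\n",
"- HDF5\n",
"\n",
"The default format for `model.save_weights` is TensorFlow Checkpoint. There are two ways to specify the save format:\n",
"\n",
"1. `save_format` argument: Set the value to `save_format=\"tf\"` or `save_format=\"h5\"`.\n",
"2. `path` argument: If the path ends with `.h5` or `.hdf5`, the HDF5 format is used. Other suffixes result in a TensorFlow Checkpoint unless `save_format` is set.\n",
"\n",
"There is also an option of retrieving weights as in-memory NumPy arrays. Each API has its pros and cons, which are detailed below."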
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3505dc65d6c1"
},
"source": [
"### TF Checkpoint format\n",
"\n",
"**Example:**"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:00:48.635568Z",
"iopub.status.busy": "2022-12-14T22:00:48.635322Z",
"iopub.status.idle": "2022-12-14T22:00:48.687247Z",
"shell.execute_reply": "2022-12-14T22:00:48.686592Z"
},
"id": "f92053377391"
},
"outputs": [
{
"data": {
"text/plain": [
"