{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "b518b04cbfe0" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T21:22:25.845810Z", "iopub.status.busy": "2022-12-14T21:22:25.845258Z", "iopub.status.idle": "2022-12-14T21:22:25.848896Z", "shell.execute_reply": "2022-12-14T21:22:25.848358Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "da6087fbd570" }, "source": [ "# 函数式 API" ] }, { "cell_type": "markdown", "metadata": { "id": "d169f4a559d5" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "8d4ac441b1fc" }, "source": [ "## 设置" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:25.852555Z", "iopub.status.busy": "2022-12-14T21:22:25.852015Z", "iopub.status.idle": "2022-12-14T21:22:27.793159Z", "shell.execute_reply": "2022-12-14T21:22:27.792480Z" }, "id": "ec52be14e686" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 21:22:26.826756: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 21:22:26.826867: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 21:22:26.826878: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers" ] }, { "cell_type": "markdown", "metadata": { "id": "871fbb54ea07" }, "source": [ "## 简介\n", "\n", "Keras *函数式 API* 是一种比 `tf.keras.Sequential` API 更加灵活的模型创建方式。函数式 API 可以处理具有非线性拓扑的模型、具有共享层的模型,以及具有多个输入或输出的模型。\n", "\n", "深度学习模型通常是层的有向无环图 (DAG)。因此,函数式 API 是构建*层计算图*的一种方式。\n", "\n", "请考虑以下模型:\n", "\n", "```\n", "(input: 784-dimensional vectors) ↧ [Dense (64 units, relu activation)] ↧ [Dense (64 units, relu activation)] ↧ [Dense (10 units, softmax activation)] ↧ (output: logits of a probability distribution over 10 classes)\n", "```\n", "\n", "这是一个具有三层的基本计算图。要使用函数式 API 构建此模型,请先创建一个输入节点:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:27.798180Z", "iopub.status.busy": "2022-12-14T21:22:27.797454Z", "iopub.status.idle": "2022-12-14T21:22:27.805574Z", "shell.execute_reply": "2022-12-14T21:22:27.805022Z" }, "id": "8d477c91955a" }, "outputs": [], "source": [ "inputs = keras.Input(shape=(784,))" ] }, { "cell_type": "markdown", "metadata": { "id": "13c14d993620" }, "source": [ "数据的形状设置为 784 维向量。由于仅指定了每个样本的形状,因此始终忽略批次大小。\n", "\n", "例如,如果您有一个形状为 `(32, 32, 3)` 的图像输入,则可以使用:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:27.809003Z", "iopub.status.busy": "2022-12-14T21:22:27.808768Z", "iopub.status.idle": "2022-12-14T21:22:27.813011Z", "shell.execute_reply": "2022-12-14T21:22:27.812452Z" }, "id": "e4732e8e279b" }, "outputs": [], "source": [ "# Just for demonstration purposes.\n", "img_inputs = keras.Input(shape=(32, 32, 3))" ] }, { "cell_type": "markdown", "metadata": { "id": "971bf8b5588f" }, "source": [ "返回的 `inputs` 包含馈送给模型的输入数据的形状和 `dtype`。形状如下:" ] }, { 
"cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:27.816291Z", "iopub.status.busy": "2022-12-14T21:22:27.815926Z", "iopub.status.idle": "2022-12-14T21:22:27.821946Z", "shell.execute_reply": "2022-12-14T21:22:27.821402Z" }, "id": "ee96c179846a" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([None, 784])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "866eee86d63e" }, "source": [ "dtype 如下:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:27.825709Z", "iopub.status.busy": "2022-12-14T21:22:27.825229Z", "iopub.status.idle": "2022-12-14T21:22:27.829172Z", "shell.execute_reply": "2022-12-14T21:22:27.828663Z" }, "id": "480be92067f3" }, "outputs": [ { "data": { "text/plain": [ "tf.float32" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs.dtype" ] }, { "cell_type": "markdown", "metadata": { "id": "6c93172cdfba" }, "source": [ "可以通过在此 `inputs` 对象上调用层,在层计算图中创建新的节点:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:27.832400Z", "iopub.status.busy": "2022-12-14T21:22:27.832052Z", "iopub.status.idle": "2022-12-14T21:22:31.311325Z", "shell.execute_reply": "2022-12-14T21:22:31.310539Z" }, "id": "b50da8b1c28d" }, "outputs": [], "source": [ "dense = layers.Dense(64, activation=\"relu\")\n", "x = dense(inputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "0f36afe42ff3" }, "source": [ "“层调用”操作就像从“输入”向您创建的该层绘制一个箭头。您将输入“传递”到 `dense` 层,然后得到 `x`。\n", "\n", "让我们为层计算图多添加几个层:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:31.315381Z", "iopub.status.busy": "2022-12-14T21:22:31.315126Z", "iopub.status.idle": "2022-12-14T21:22:31.338333Z", 
"shell.execute_reply": "2022-12-14T21:22:31.337760Z" }, "id": "463d5cd0c484" }, "outputs": [], "source": [ "x = layers.Dense(64, activation=\"relu\")(x)\n", "outputs = layers.Dense(10)(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "e379f089b044" }, "source": [ "model = keras.Model(inputs=inputs, outputs=outputs, name=\"mnist_model\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:31.341925Z", "iopub.status.busy": "2022-12-14T21:22:31.341339Z", "iopub.status.idle": "2022-12-14T21:22:31.349152Z", "shell.execute_reply": "2022-12-14T21:22:31.348616Z" }, "id": "7820cc2209a6" }, "outputs": [], "source": [ "model = keras.Model(inputs=inputs, outputs=outputs, name=\"mnist_model\")" ] }, { "cell_type": "markdown", "metadata": { "id": "c9aa111852d3" }, "source": [ "让我们看看模型摘要是什么样子:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:31.352537Z", "iopub.status.busy": "2022-12-14T21:22:31.352026Z", "iopub.status.idle": "2022-12-14T21:22:31.363236Z", "shell.execute_reply": "2022-12-14T21:22:31.362703Z" }, "id": "4949ab8242e8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"mnist_model\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " input_1 (InputLayer) [(None, 784)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense (Dense) (None, 64) 50240 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ " dense_1 (Dense) (None, 64) 4160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_2 (Dense) (None, 10) 650 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 55,050\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 55,050\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "99ab8535d6c3" }, "source": [ "您还可以将模型绘制为计算图:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:31.370085Z", "iopub.status.busy": "2022-12-14T21:22:31.369553Z", "iopub.status.idle": "2022-12-14T21:22:31.502884Z", "shell.execute_reply": "2022-12-14T21:22:31.502115Z" }, "id": "6872f1b1b8b8" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOcAAAFgCAYAAABANm70AAAABmJLR0QA/wD/AP+gvaeTAAAgAElEQVR4nO3de0xUZ/4G8GeAYbjPgIto8YKXWtsupZG6VStFpaKNl1EKIt5Yra7RtnY1VuvaGuKa3Zq2tt1U19a2a5u4ETSR1eq60uiaKGNqEXC1xSrGrIpQxMIycnFgvr8/GubX0wEVHGbeGZ5PchJ5zzvnfOc953HmnJk5RyciAiJSzR4/T1dARO1jOIkUxXASKYrhJFJUwC8bLBYLtmzZ4olaiHqsPXv2OLU5vXJevXoVe/fudUtB9P/27t2La9eueboMcrNr1651mDenV8427SWZuo9Op8PKlSsxa9YsT5dCbpSXl4fMzMx25/GYk0hRDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRDCeRohhOIkUxnESKYjiJFMVwPoBDhw5h2LBhCAjo8Mc93SYsLAw6nU4zvfPOO26vwxV86bm4kkvCabVa8fDDD2Pq1KmuWJzyysvLMX36dKxbtw5VVVUeqcFqtaK4uBgAYDabISJYvXq1R2p5UL70XFzJJeEUEdjtdtjtdlcsrluFhYVh7NixD7SMN998E2PGjEFRURHCw8NdVJlvc8W49zQueT8WHh6O8vJyVyzKK3z66acIDg72dBnk43jM2QUMJrnDA4czPz9fcyDf1NTUbvuVK1eQmZkJk8mEXr16YerUqZpX23feecfRt1+/fjh9+jRSUlIQHh6OkJAQjB8/HidPnnT037Rpk6P/z98uHT582NH+q1/9ymn5t2/fxsmTJx19PHEypzv1hHFvaWlBbm4uJk6ciD59+iA4OBjx8fH44IMPHIdWtbW1TieZNm3a5Hj8z9vT09Mdy66ursaKFSsQFxeHwMBAREdHIy0tDSUlJR2O8YULFzBr1iz06tXL0Xbz5s0uPz8H+YXc3Fxpp/mezGazAJDGxsZ2281msxQWForVapWCggIJDg6WkSNHOi0nISFBQkNDZfTo0Y7+p0+flieeeEICAwPl3//+t6Z/aGioPPPMM07LSUxMlF69ejm1d9S/q2JjY8Xf3/+BlwNAcnNzO/WY4uJix9j+kreN+92eyy8dOHBAAMif/vQnuXXrllRXV8tf/vIX8fPzk9WrV2v6Tpo0Sfz8/OTSpUtOyxk9erTs2rXL8XdFRYUMHDhQYmJi5ODBg1JfXy/nzp2T5ORkCQoKksLCQs3j28Y4OTlZjh07Jrdv35ZTp06Jv7+/VFdX3/N5iNw1b3luC+eBAwc07enp6QLA6UkkJCQIACkuLta0nz17VgBIQkKCpp3hvHs4vWXcOxvOcePGObXPmzdP9Hq91NXVOdr+9a9/CQBZvny5pu+JEyckNjZW7ty542jLzs4WAJrAiojcuHFDDAaDJCYmatrbxvjQoUP3rLkjdwun2445R44cqfm7f//+AICKigqnvqGhoXjyySc1bfHx8XjooYdQWlqKGzdudF+hPsYXx33q1Kk4duyYU3tCQgJsNhvOnz/vaEtNTUV8fDx27tyJmpoaR/vbb7+NV155BXq93tGWn58PPz8/p48E+/Tpg8cffxxFRUXtXr70N7/5jSuelhO3hdNoNGr+DgwMBIB2P34xmUztLqN3794AgB9++MHF1fkuXxz3uro6bNiwAfHx8YiMjHQc57322msAgIaGBk3/3//+92hoaMC2bdsAAN9//z2OHj2K3/3ud44+zc3NqKurg91uh9FodDpePXPmDADg4sWLTvWEhoZ2y/NU8mxtTU0NpJ07E7btHG07CwD4+fnhzp07Tn1ra2vbXbZOp3NRlb7HW8Z92rRp+OMf/4glS5bg+++/h91uh4jgvffeAwCn5zB37lzExMTgww8/RHNzM959911kZ2cjMjLS0cdgMMBkMiEgIAA2mw0i0u40fvx4lz2Pe1EynE1NTTh9+rSm7T//+Q8qKiqQkJCAvn37Otr79u2L69e
va/pWVlbiv//9b7vLDgkJ0exUjzzyCD7++GMXVu+9VB/3gIAAnD9/HidPnkSfPn2wYsUKREdHO4Lf2NjY7uMMBgOWL1+OH374Ae+++y527dqFV1991alfWloaWlpaNGen22zevBkDBgxAS0tLp2p+EEqG02g04g9/+AMsFgtu376Nb775BvPmzUNgYCA++OADTd/U1FRUVFTgww8/hNVqRXl5OV599VXN//I/N2LECHz//fe4evUqLBYLLl++jKSkJHc8LeV5w7j7+/tj3LhxqKysxNtvv42bN2+isbERx44dw/bt2zt83PLlyxEcHIw33ngDzz33HIYOHerU589//jOGDBmCRYsW4Z///Cfq6upw69YtfPTRR9i4cSPeeecd93701omzR+3at2+fANBMc+fOFYvF4tS+fv16kZ/ec2imKVOmOJaXkJAgsbGx8u2338qkSZMkPDxcgoODJTk5WU6cOOG0/traWlm8eLH07dtXgoODZezYsXL69GlJTEx0LH/t2rWO/mVlZZKUlCShoaHSv39/2bp1630/1zZtp/Lbm3bs2NHp5bWNSWfO1oaGhjqt++233/bKcW/vuXQ0fffdd1JdXS1Lly6V/v37i16vl5iYGPntb38rr7/+uqPfL8+siogsWbJEAMjx48c7HNeamhpZtWqVDB48WPR6vURHR0tqaqoUFBQ4+rQ3xp3JzM+55aMUV2nbSXqazobT1XrCuH/22WfthtaTlPgohcjTtm/fjlWrVnm6jPvGcJLP+uSTTzBz5kxYrVZs374dP/74o1fdxU2ZcLZ9B7O0tBTXr1+HTqfDG2+84bb1//JzrfamnJwct9XjLp4e9+6Wn5+PyMhI/PWvf8Xu3bu96rvUOhHth0Jt9wuUdj7vou6j0+mQm5vrVf+z04O7S972KPPKSURaDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRDCeRohhOIkUxnESKYjiJFNXh72cyMjLcWQcBeO+997Bnzx5Pl0Fu1N51cNs4/WTMYrFgy5Yt3V4UuVZ1dTW+++47PPvss54uhbqgnf+U9ziFk7wTf4frc/h7TiJVMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFBXi6AOq8a9euITs7G62trY62mzdvIiAgAOPGjdP0feSRR/DRRx+5uUJyBYbTC/Xr1w9XrlzB5cuXneYdP35c83dSUpK7yiIX49taL7VgwQLo9fp79ps9e7YbqqHuwHB6qblz58Jms921z2OPPYbHH3/cTRWRqzGcXmro0KF44oknoNPp2p2v1+uRnZ3t5qrIlRhOL7ZgwQL4+/u3O6+lpQWzZs1yc0XkSgynF8vKyoLdbndq1+l0ePrppxEXF+f+oshlGE4v9tBDD2HMmDHw89NuRn9/fyxYsMBDVZGrMJxebv78+U5tIoIXXnjBA9WQKzGcXi4jI0Pzyunv74/nnnsOvXv39mBV5AoMp5eLjIxEamqq48SQiGDevHkeropcgeH0AfPmzXOcGAoICMD06dM9XBG5AsPpA6ZPnw6DweD4d0REhIcrIldQ5ru1165dQ2FhoafL8FojRoxAYWEhBg0ahLy8PE+X47VU+mxYJyLi6SIAIC8vD5mZmZ4ug3o4ReIAAHuUeeVso9DgeI2MjAzY7XYMHToUmzdv9nQ5XknFFwcec/oIPz8/5OTkeLoMciGG04cEBwd7ugRyIYaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRflcOHfv3g2dTgedToegoCBPl6OssLAwxzi1TX5+foiMjERCQgKWL1+OoqIiT5fZo/lcOGfPng0RQUp
KiqdLUZrVakVxcTEAwGw2Q0Rgs9lQVlaGjRs3oqysDE899RQWLlyIhoYGD1fbM/lcOKnr/P39ERMTA7PZjKNHj2LNmjXYuXMnsrKy+DtbD2A4qUNvvfUWnn76aezfvx+7d+/2dDk9DsNJHdLpdHj55ZcBANu2bfNwNT2P14ezrKwMM2bMgNFoRGhoKJKSknDixIkO+1dXV2PFihWIi4tDYGAgoqOjkZaWhpKSEkef/Px8zYmSK1euIDMzEyaTCb169cLUqVNRXl6uWW5zczM2bNiA4cOHIyQkBFFRUZg2bRr279+vuQP1/dagirFjxwIATp06pbnlIMfRDUQRubm50tlyLl68KCaTSWJjY+XIkSNSX18vZ8+eldTUVImLixODwaDpX1FRIQMHDpSYmBg5ePCg1NfXy7lz5yQ5OVmCgoKksLBQ099sNgsAMZvNUlhYKFarVQoKCiQ4OFhGjhyp6bt48WIxGo1y5MgRaWhokMrKSlm9erUAkGPHjnW5hvuRnp4u6enpnX5ccXGx4/l1pLGxUQAIAKmoqOjSc/CGcezK/tfN8pSppiuDk5GRIQBk7969mvbr16+LwWBwCmd2drYAkF27dmnab9y4IQaDQRITEzXtbTvVgQMHNO3p6ekCQKqrqx1tgwYNkjFjxjjVOGzYMM1O1dka7kd3hrOhocEpnL44jgznXXRlcMLDwwWA1NfXO82Lj493CqfRaBQ/Pz+pq6tz6j9ixAgBIFevXnW0te1UlZWVmr4rV64UAFJaWupoW7ZsmQCQJUuWiMVikZaWlnZr7mwN96M7w1leXi4ARK/Xy507d0TEN8dRxXB67TFnc3Mz6uvrERQUhLCwMKf5v7yRT3NzM+rq6mC322E0Gp0+gD9z5gwA4OLFi07LMhqNmr8DAwMBQHNvzK1bt+KLL77A5cuXkZKSgoiICEyePBn79u1zSQ2e0nb8Pnr0aOj1eo6jG3ltOA0GA8LDw9HU1ASr1eo0/9atW079TSYTAgICYLPZICLtTuPHj+9SPTqdDvPnz8dXX32F2tpa5OfnQ0SQlpaGLVu2uKUGV7Pb7di6dSsA4KWXXgLAcXQnrw0nADz//PMAgMOHD2vab968iQsXLjj1T0tLQ0tLC06ePOk0b/PmzRgwYABaWlq6VIvJZEJZWRkAQK/XY+LEiY6zlQcPHnRLDa62bt06fP3115g5cyYyMjIc7RxHN3HXG+h76cp7/kuXLklUVJTmbO358+dl0qRJ0rt3b6djzqqqKhkyZIgMHjxYDh06JLW1tVJTUyPbt2+XkJAQyc3N1fRvO1ZqbGzUtK9du1YASHFxsaPNaDRKcnKylJaWSlNTk1RVVUlOTo4AkE2bNnW5hvvhqmPO1tZWqaqqkvz8fJkwYYIAkEWLFklDQ4Pmcb44jioecypTTVcH58KFCzJjxgyJiIhwnJr/8ssvJSUlxXGW8cUXX3T0r6mpkVWrVsngwYNFr9dLdHS0pKamSkFBgaOPxWJxPLZtWr9+vYiIU/uUKVNERKSkpESWLl0qjz76qISEhEhUVJSMGjVKduzYIXa7XVPz/dTQGV0JZ2hoqNNz0el0YjQaJT4+XpYtWyZFRUUdPt7XxlHFcCp3IyNFyvEqbW859+zZ4+FKvJeC+98erz7mJPJlDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRAZ4u4Jfy8vI8XYLXuXbtGgCO3YOwWCyeLsGJcuHMzMz0dAlei2PnW5S5hhA9GAWvgUMPhtcQIlIVw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFKUcredp3u
rrq7Gvn37NG3ffPMNAODjjz/WtIeFhWHOnDluq41ch7ed90LNzc2Ijo7G7du34e/vDwAQEYgI/Pz+/82QzWbDggUL8Pnnn3uqVOo63nbeGxkMBmRkZCAgIAA2mw02mw0tLS1obW11/G2z2QCAr5pejOH0UnPmzMGdO3fu2sdkMiElJcVNFZGrMZxeavz48YiOju5wvl6vx7x58xAQwNMK3orh9FJ+fn6YM2cOAgMD251vs9mQlZXl5qrIlRhOL5aVldXhW9u+ffti9OjRbq6IXInh9GJPP/00Bg4c6NSu1+uRnZ0NnU7ngarIVRhOLzd//nzo9XpNG9/S+gaG08vNnTvX8bFJm6FDh+KJJ57wUEXkKgynlxs+fDgee+wxx1tYvV6PhQsXergqcgWG0wcsWLDA8U0hm82GWbNmebgicgWG0wfMnj0bra2tAIDExEQMHTrUwxWRKzCcPmDgwIEYOXIkgJ9eRck3dPsX3/Py8pCZmdmdqyByOzf8XmSP277blZub665V9Uj/+9//sG3bNrz++usd9nnvvfcAACtXrnRXWT7HYrHg/fffd8u63BZOnqTofsnJyXj44Yc7nL9nzx4A3BYPyl3h5DGnD7lbMMn7MJxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUpTXhHP37t3Q6XTQ6XQICgrydDludejQIQwbNswjV28PCwtzjHvb5Ofnh8jISCQkJGD58uUoKipye109gdeEc/bs2RCRHnXvj/LyckyfPh3r1q1DVVWVR2qwWq0oLi4GAJjNZogIbDYbysrKsHHjRpSVleGpp57CwoUL0dDQ4JEafZXXhLMnevPNNzFmzBgUFRUhPDzc0+U4+Pv7IyYmBmazGUePHsWaNWuwc+dOZGVlueMKAT0G73KjsE8//RTBwcGeLuOe3nrrLRw/fhz79+/H7t27eUFrF+Erp8K8IZgAoNPp8PLLLwMAtm3b5uFqfIey4SwrK8OMGTNgNBoRGhqKpKQknDhxosP+1dXVWLFiBeLi4hAYGIjo6GikpaWhpKTE0Sc/P19zYuPKlSvIzMyEyWRCr169MHXqVJSXl2uW29zcjA0bNmD48OEICQlBVFQUpk2bhv379zsuR9mZGnzV2LFjAQCnTp3SXIGe2+UBSDfLzc2Vzq7m4sWLYjKZJDY2Vo4cOSL19fVy9uxZSU1Nlbi4ODEYDJr+FRUVMnDgQImJiZGDBw9KfX29nDt3TpKTkyUoKEgKCws1/c1mswAQs9kshYWFYrVapaCgQIKDg2XkyJGavosXLxaj0ShHjhyRhoYGqayslNWrVwsAOXbsWJdr6KzY2Fjx9/d/oGWkp6dLenp6px9XXFzsGK+ONDY2CgABIBUVFSLim9ulK/tzF+UpGc6MjAwBIHv37tW0X79+XQwGg1M4s7OzBYDs2rVL037jxg0xGAySmJioaW/bCQ4cOKBpT09PFwBSXV3taBs0aJCMGTPGqcZhw4ZpdoLO1tBZqoezoaHBKZy+uF16fDjDw8MFgNTX1zvNi4+Pdwqn0WgUPz8/qaurc+o/YsQIASBXr151tLXtBJWVlZq+K1euFABSWlrqaFu2bJkAkCVLlojFYpGWlpZ2a+5sDZ2lejjLy8sFgOj1erlz546I+OZ2cWc4lTvmbG5uRn19PYKCghAWFuY0v3fv3k796+rqYLfbYTQanT4wP3PmDADg4sWLTssyGo2av9vuEm232x1tW7duxRdffIHLly8jJSUFERERmDx5Mvbt2+eSGnxF2/mA0aNHQ6/Xc7u4gHLhNBgMCA8PR1NTE6xWq9P8W7duOfU3mUwICAiAzWaDiLQ7jR8/vkv16HQ6zJ8/H1999RVqa2uRn58PEUFaWhq2bNnilhpUZ7fbsXXrVgDASy+9BIDbxRWUCycAPP/88wCAw4cPa9pv3ryJCxcuOPVPS0tDS0sLTp486TRv8+bNGDBgAFpaWrpUi8l
kQllZGYCfbq83ceJEx9nFgwcPuqUG1a1btw5ff/01Zs6ciYyMDEc7t8sD6u43zl15j37p0iWJiorSnK09f/68TJo0SXr37u10zFlVVSVDhgyRwYMHy6FDh6S2tlZqampk+/btEhISIrm5uZr+bcc2jY2Nmva1a9cKACkuLna0GY1GSU5OltLSUmlqapKqqirJyckRALJp06Yu19BZKh1ztra2SlVVleTn58uECRMEgCxatEgaGho0j/PF7dLjTwiJiFy4cEFmzJghERERjlPpX375paSkpDjOCr744ouO/jU1NbJq1SoZPHiw6PV6iY6OltTUVCkoKHD0sVgsjse2TevXrxcRcWqfMmWKiIiUlJTI0qVL5dFHH5WQkBCJioqSUaNGyY4dO8Rut2tqvp8aOuPAgQNOdbVNO3bs6PTyuhLO0NBQp3XrdDoxGo0SHx8vy5Ytk6Kiog4f72vbxZ3hdNtdxrp5NXQf2t5ytt0zhTrPjfvzHiWPOYlI0RNCRMRwut0vP2trb8rJyfF0maQA/mTMzXjsTfeLr5xEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQot/0qRafTuWtVdA/cFt6h28M5ZswY5ObmdvdqejyLxYL333+fY+1Duv0aQuQevFaTz+E1hIhUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpy223nyXVsNhusVqum7fbt2wCAH3/8UdOu0+lgMpncVhu5DsPphWpqatCvXz+0trY6zYuKitL8PW7cOBw7dsxdpZEL8W2tF+rTpw+effZZ+PndffPpdDpkZWW5qSpyNYbTS82fPx86ne6uffz8/PDCCy+4qSJyNYbTS73wwgvw9/fvcL6/vz8mT56MXr16ubEqciWG00tFRERg8uTJCAho/7SBiGDevHluropcieH0YvPmzWv3pBAABAYGYurUqW6uiFyJ4fRi06ZNQ0hIiFN7QEAAZs6cibCwMA9URa7CcHqxoKAgpKWlQa/Xa9pbWlowd+5cD1VFrsJwerk5c+bAZrNp2iIiIjBx4kQPVUSuwnB6ueeee07zxQO9Xo/Zs2cjMDDQg1WRKzCcXi4gIACzZ892vLW12WyYM2eOh6siV2A4fUBWVpbjrW1MTAySkpI8XBG5AsPpA5555hk89NBDAH765tC9vtZH3qHbv/husViwZcuW7l5NjxceHg4AKC4uRkZGhoer8X179uzp9nV0+3+xV69exd69e7t7NT3egAEDEB4ejsjIyA77nDp1CqdOnXJjVb7n2rVrbtuf3faTMXf8T9PT5eXlYdasWR3Ob3tF5bboury8PGRmZrplXTw48SF3CyZ5H4aTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIrymnDu3r0bOp0OOp0OQUFBni6n2/3444/Yvn07JkyYgKioKAQHB+Phhx/G3LlzUVpa6rY6wsLCHOPeNvn5+SEyMhIJCQlYvnw5ioqK3FZPT+I14Zw9ezZEBCkpKZ4uxS1ee+01vPLKKzCbzfj2229RU1ODzz77DCUlJUhMTER+fr5b6rBarSguLgYAmM1miAhsNhvKysqwceNGlJWV4amnnsLChQvR0NDglpp6Cq8JZ0+0aNEivPrqq+jTpw9CQkKQlJSEv//972htbcWaNWs8Vpe/vz9iYmJgNptx9OhRrFmzBjt37kRWVhZExGN1+Rren1NRn3zySbvtCQkJCA4ORnl5OUTknncac4e33noLx48fx/79+7F7927edtBF+MrpZW7fvo3Gxkb8+te/ViKYwE/3AX355ZcBANu2bfN
wNb5D2XCWlZVhxowZMBqNCA0NRVJSEk6cONFh/+rqaqxYsQJxcXEIDAxEdHQ00tLSUFJS4uiTn5+vObFx5coVZGZmwmQyoVevXpg6dSrKy8s1y21ubsaGDRswfPhwhISEICoqCtOmTcP+/fudbiJ0PzU8qLZLjKxfv95ly3SFsWPHAvjpOkU/vwJ9T9ku3UK6WW5urnR2NRcvXhSTySSxsbFy5MgRqa+vl7Nnz0pqaqrExcWJwWDQ9K+oqJCBAwdKTEyMHDx4UOrr6+XcuXOSnJwsQUFBUlhYqOlvNpsFgJjNZiksLBSr1SoFBQUSHBwsI0eO1PRdvHixGI1GOXLkiDQ0NEhlZaWsXr1aAMixY8e6XENXVFZWSkxMjCxevLhLj09PT5f09PROP664uNgxXh1pbGwUAAJAKioqRMQ3t0tX9ucuylMynBkZGQJA9u7dq2m/fv26GAwGp3BmZ2cLANm1a5em/caNG2IwGCQxMVHT3rYTHDhwQNOenp4uAKS6utrRNmjQIBkzZoxTjcOGDdPsBJ2tobNu3rwpTz75pGRmZkpLS0uXltGd4WxoaHAKpy9ulx4fzvDwcAEg9fX1TvPi4+Odwmk0GsXPz0/q6uqc+o8YMUIAyNWrVx1tbTtBZWWlpu/KlSsFgJSWljrali1bJgBkyZIlYrFYOgxGZ2voDKvVKomJiTJnzpwuB1Oke8NZXl4uAESv18udO3dExDe3izvDqdwxZ3NzM+rr6xEUFNTu/SV79+7t1L+urg52ux1Go9HpA/MzZ84AAC5evOi0LKPRqPm77eY/drvd0bZ161Z88cUXuHz5MlJSUhx3lN63b59LariXlpYWZGRkIDY2Fp9//vldbzXvSW3nA0aPHg29Xu/z28UdlAunwWBAeHg4mpqaYLVanebfunXLqb/JZEJAQABsNhtEpN1p/PjxXapHp9Nh/vz5+Oqrr1BbW4v8/HyICNLS0hxXsu/OGpYuXYrm5mbk5eVpbjE/dOcBt98AAAJ/SURBVOhQZS4QbbfbsXXrVgDASy+9BMD3t4s7KBdOAHj++ecBAIcPH9a037x5ExcuXHDqn5aWhpaWFpw8edJp3ubNmzFgwAC0tLR0qRaTyYSysjIAP91eb+LEiY6ziwcPHuzWGnJycnD+/Hn84x//gMFg6FL97rBu3Tp8/fXXmDlzpuZWEL66Xdymu984d+U9+qVLlyQqKkpztvb8+fMyadIk6d27t9MxZ1VVlQwZMkQGDx4shw4dktraWqmpqZHt27dLSEiI5Obmavq3Hds0NjZq2teuXSsApLi42NFmNBolOTlZSktLpampSaqqqiQnJ0cAyKZNm7pcw7387W9/c5xg6WiyWCydWqarjjlbW1ulqqpK8vPzZcKECQJAFi1aJA0NDZrH+eJ26fEnhERELly4IDNmzJCIiAjHqfQvv/xSUlJSHDvniy++6OhfU1Mjq1atksGDB4ter5fo6GhJTU2VgoICRx+LxeK0g69fv15ExKl9ypQpIiJSUlIiS5culUcffVRCQkIkKipKRo0aJTt27BC73a6p+X5quF9TpkxRIpyhoaFO69XpdGI0GiU+Pl6WLVsmRUVFHT7e17aLO8OpE+neL0O23Vuim1dD94H3Snlwbtyf9yh5zElEip4QIiKG0+1++Vlbe1NOTo6nyyQF8CdjbsZjb7pffOUkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFue1XKT+/8BN5RtvV+rgtuu7atWtuW1e3h7N///5IT0/v7tXQfRg1apSnS/B6/fr1c9v+3O3XECKiLuE1hIhUxXASKYrhJFIUw0mkqP8DKFnjfb+C/sgAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "keras.utils.plot_model(model, \"my_first_model.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "6d9880136879" }, "source": [ "并且,您还可以选择在绘制的计算图中显示每层的输入和输出形状:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:31.507062Z", "iopub.status.busy": "2022-12-14T21:22:31.506380Z", "iopub.status.idle": "2022-12-14T21:22:31.636218Z", "shell.execute_reply": "2022-12-14T21:22:31.635422Z" }, "id": "aa14046d3388" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "keras.utils.plot_model(model, \"my_first_model_with_shape_info.png\", show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "71969f9c91bb" }, "source": [ "此图和代码几乎完全相同。在代码版本中,连接箭头由调用操作代替。\n", "\n", "“层计算图”是深度学习模型的直观心理图像,而函数式 API 是创建密切反映此图像的模型的方法。" ] }, { "cell_type": "markdown", "metadata": { "id": "775b997c8c28" }, "source": [ "## 训练,评估和推断\n", "\n", "对于使用函数式 API 构建的模型来说,其训练、评估和推断的工作方式与 `Sequential` 模型完全相同。\n", "\n", "`Model` 类提供了一个内置训练循环(`fit()` 方法)和一个内置评估循环(`evaluate()` 方法)。请注意,您可以轻松地[自定义这些循环](https://tensorflow.google.cn/guide/keras/customizing_what_happens_in_fit/),以实现监督学习之外的训练例程(例如 [GAN](/examples/generative/dcgan_overriding_train_step/))。\n", "\n", "如下所示,加载 MNIST 图像数据,将其改造为向量,将模型与数据拟合(同时监视验证拆分的性能),然后在测试数据上评估模型:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:31.640315Z", "iopub.status.busy": "2022-12-14T21:22:31.639723Z", "iopub.status.idle": "2022-12-14T21:22:38.163183Z", "shell.execute_reply": "2022-12-14T21:22:38.162511Z" }, "id": "e61366d54487" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/750 [..............................] 
- ETA: 15:26 - loss: 2.3278 - accuracy: 0.0469" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 25/750 [>.............................] - ETA: 1s - loss: 1.5759 - accuracy: 0.5906 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 50/750 [=>............................] - ETA: 1s - loss: 1.1976 - accuracy: 0.6925" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 75/750 [==>...........................] - ETA: 1s - loss: 0.9897 - accuracy: 0.7421" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 99/750 [==>...........................] - ETA: 1s - loss: 0.8638 - accuracy: 0.7724" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "124/750 [===>..........................] - ETA: 1s - loss: 0.7746 - accuracy: 0.7931" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "149/750 [====>.........................] 
- ETA: 1s - loss: 0.7098 - accuracy: 0.8081" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "173/750 [=====>........................] - ETA: 1s - loss: 0.6592 - accuracy: 0.8218" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "198/750 [======>.......................] - ETA: 1s - loss: 0.6250 - accuracy: 0.8311" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "223/750 [=======>......................] - ETA: 1s - loss: 0.5914 - accuracy: 0.8380" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "248/750 [========>.....................] - ETA: 1s - loss: 0.5657 - accuracy: 0.8443" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "272/750 [=========>....................] - ETA: 0s - loss: 0.5412 - accuracy: 0.8506" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "296/750 [==========>...................] 
- ETA: 0s - loss: 0.5214 - accuracy: 0.8563" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "320/750 [===========>..................] - ETA: 0s - loss: 0.5022 - accuracy: 0.8612" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "345/750 [============>.................] - ETA: 0s - loss: 0.4875 - accuracy: 0.8652" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "370/750 [=============>................] - ETA: 0s - loss: 0.4724 - accuracy: 0.8685" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "395/750 [==============>...............] - ETA: 0s - loss: 0.4587 - accuracy: 0.8721" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "420/750 [===============>..............] - ETA: 0s - loss: 0.4461 - accuracy: 0.8751" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "444/750 [================>.............] 
- ETA: 0s - loss: 0.4360 - accuracy: 0.8780" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "469/750 [=================>............] - ETA: 0s - loss: 0.4280 - accuracy: 0.8803" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "494/750 [==================>...........] - ETA: 0s - loss: 0.4189 - accuracy: 0.8827" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "518/750 [===================>..........] - ETA: 0s - loss: 0.4108 - accuracy: 0.8845" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "542/750 [====================>.........] - ETA: 0s - loss: 0.4028 - accuracy: 0.8868" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "565/750 [=====================>........] - ETA: 0s - loss: 0.3950 - accuracy: 0.8888" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "590/750 [======================>.......] 
- ETA: 0s - loss: 0.3886 - accuracy: 0.8907" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "615/750 [=======================>......] - ETA: 0s - loss: 0.3823 - accuracy: 0.8926" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "640/750 [========================>.....] - ETA: 0s - loss: 0.3765 - accuracy: 0.8942" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "665/750 [=========================>....] - ETA: 0s - loss: 0.3701 - accuracy: 0.8960" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "690/750 [==========================>...] - ETA: 0s - loss: 0.3654 - accuracy: 0.8973" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "715/750 [===========================>..] - ETA: 0s - loss: 0.3598 - accuracy: 0.8986" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "740/750 [============================>.] 
- ETA: 0s - loss: 0.3542 - accuracy: 0.9001" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "750/750 [==============================] - 3s 3ms/step - loss: 0.3525 - accuracy: 0.9006 - val_loss: 0.1963 - val_accuracy: 0.9407\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/750 [..............................] - ETA: 2s - loss: 0.1896 - accuracy: 0.9219" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 25/750 [>.............................] - ETA: 1s - loss: 0.1843 - accuracy: 0.9469" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 50/750 [=>............................] - ETA: 1s - loss: 0.1916 - accuracy: 0.9472" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 75/750 [==>...........................] - ETA: 1s - loss: 0.1875 - accuracy: 0.9442" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "101/750 [===>..........................] 
- ETA: 1s - loss: 0.1912 - accuracy: 0.9440" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "126/750 [====>.........................] - ETA: 1s - loss: 0.1906 - accuracy: 0.9443" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "152/750 [=====>........................] - ETA: 1s - loss: 0.1842 - accuracy: 0.9469" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "177/750 [======>.......................] - ETA: 1s - loss: 0.1825 - accuracy: 0.9472" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "201/750 [=======>......................] - ETA: 1s - loss: 0.1792 - accuracy: 0.9479" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "225/750 [========>.....................] - ETA: 1s - loss: 0.1795 - accuracy: 0.9483" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "251/750 [=========>....................] 
- ETA: 1s - loss: 0.1777 - accuracy: 0.9489" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "277/750 [==========>...................] - ETA: 0s - loss: 0.1748 - accuracy: 0.9492" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "303/750 [===========>..................] - ETA: 0s - loss: 0.1729 - accuracy: 0.9497" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "329/750 [============>.................] - ETA: 0s - loss: 0.1728 - accuracy: 0.9498" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "356/750 [=============>................] - ETA: 0s - loss: 0.1723 - accuracy: 0.9495" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "383/750 [==============>...............] - ETA: 0s - loss: 0.1702 - accuracy: 0.9500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "409/750 [===============>..............] 
- ETA: 0s - loss: 0.1700 - accuracy: 0.9499" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "436/750 [================>.............] - ETA: 0s - loss: 0.1676 - accuracy: 0.9504" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "462/750 [=================>............] - ETA: 0s - loss: 0.1673 - accuracy: 0.9506" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "488/750 [==================>...........] - ETA: 0s - loss: 0.1677 - accuracy: 0.9508" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "514/750 [===================>..........] - ETA: 0s - loss: 0.1664 - accuracy: 0.9508" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "541/750 [====================>.........] - ETA: 0s - loss: 0.1664 - accuracy: 0.9505" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "567/750 [=====================>........] 
- ETA: 0s - loss: 0.1662 - accuracy: 0.9506" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "594/750 [======================>.......] - ETA: 0s - loss: 0.1652 - accuracy: 0.9511" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "620/750 [=======================>......] - ETA: 0s - loss: 0.1647 - accuracy: 0.9512" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "646/750 [========================>.....] - ETA: 0s - loss: 0.1629 - accuracy: 0.9515" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "672/750 [=========================>....] - ETA: 0s - loss: 0.1633 - accuracy: 0.9514" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "699/750 [==========================>...] - ETA: 0s - loss: 0.1631 - accuracy: 0.9514" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "725/750 [============================>.] 
- ETA: 0s - loss: 0.1623 - accuracy: 0.9515" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "750/750 [==============================] - 2s 2ms/step - loss: 0.1613 - accuracy: 0.9516 - val_loss: 0.1461 - val_accuracy: 0.9572\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "313/313 - 0s - loss: 0.1391 - accuracy: 0.9577 - 472ms/epoch - 2ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Test loss: 0.13910873234272003\n", "Test accuracy: 0.9577000141143799\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", "\n", "x_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\n", "x_test = x_test.reshape(10000, 784).astype(\"float32\") / 255\n", "\n", "model.compile(\n", " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " optimizer=keras.optimizers.RMSprop(),\n", " metrics=[\"accuracy\"],\n", ")\n", "\n", "history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2)\n", "\n", "test_scores = model.evaluate(x_test, y_test, verbose=2)\n", "print(\"Test loss:\", test_scores[0])\n", "print(\"Test accuracy:\", test_scores[1])" ] }, { "cell_type": "markdown", "metadata": { "id": "2e13d7168c86" }, "source": [ "For more information, see the [training and evaluation](https://tensorflow.google.cn/guide/keras/train_and_evaluate/) guide." ] }, { "cell_type": "markdown", "metadata": { "id": "26991ef4dbbb" }, "source": [ "## Save and serialize\n", "\n", "Saving and serialization work the same way for models built with the functional API as they do for `Sequential` models. The standard way to save a functional model is to call `model.save()`, which saves the entire model as a single artifact (here, a SavedModel directory). You can later recreate the same model from this artifact, even if the code that built the model is no longer available.\n", "\n", "The saved artifact includes:\n", "\n", "- the model architecture\n", "- the model weight values (learned during training)\n", "- the model training config, if any (as passed to `compile`)\n", "- the optimizer and its state, if any (to restart training where you left off)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": 
"2022-12-14T21:22:38.167056Z", "iopub.status.busy": "2022-12-14T21:22:38.166390Z", "iopub.status.idle": "2022-12-14T21:22:38.895399Z", "shell.execute_reply": "2022-12-14T21:22:38.894359Z" }, "id": "7e5e48669225" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: path_to_my_model/assets\n" ] } ], "source": [ "model.save(\"path_to_my_model\")\n", "del model\n", "# Recreate the exact same model purely from the file:\n", "model = keras.models.load_model(\"path_to_my_model\")" ] }, { "cell_type": "markdown", "metadata": { "id": "cfe2a761139b" }, "source": [ "For details, read the model [serialization and saving](https://tensorflow.google.cn/guide/keras/save_and_serialize/) guide." ] }, { "cell_type": "markdown", "metadata": { "id": "b747517364a9" }, "source": [ "## Use the same graph of layers to define multiple models\n", "\n", "In the functional API, models are created by specifying their inputs and outputs in a graph of layers. This means a single graph of layers can be used to generate multiple models.\n", "\n", "In the example below, you use the same stack of layers to instantiate two models: an `encoder` model that turns image inputs into 16-dimensional vectors, and an end-to-end `autoencoder` model for training." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:38.899534Z", "iopub.status.busy": "2022-12-14T21:22:38.899271Z", "iopub.status.idle": "2022-12-14T21:22:39.055648Z", "shell.execute_reply": "2022-12-14T21:22:39.055014Z" }, "id": "f9924d8c9ed3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"encoder\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " img (InputLayer) [(None, 28, 28, 1)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d (Conv2D) (None, 26, 
26, 16) 160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_1 (Conv2D) (None, 24, 24, 32) 4640 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " max_pooling2d (MaxPooling2D (None, 8, 8, 32) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_2 (Conv2D) (None, 6, 6, 32) 9248 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_3 (Conv2D) (None, 4, 4, 16) 4624 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " global_max_pooling2d (Globa (None, 16) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " lMaxPooling2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 18,672\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 18,672\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"autoencoder\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " img (InputLayer) [(None, 28, 28, 1)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d (Conv2D) (None, 26, 26, 16) 160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_1 (Conv2D) (None, 24, 24, 32) 4640 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " max_pooling2d (MaxPooling2D (None, 8, 8, 32) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_2 (Conv2D) (None, 6, 6, 32) 9248 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_3 (Conv2D) (None, 4, 4, 16) 4624 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " global_max_pooling2d (Globa (None, 16) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " lMaxPooling2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " reshape (Reshape) (None, 4, 4, 1) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose (Conv2DTra (None, 6, 6, 16) 160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " nspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_1 (Conv2DT (None, 8, 8, 32) 4640 \n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " up_sampling2d (UpSampling2D (None, 24, 24, 32) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_2 (Conv2DT (None, 26, 26, 16) 4624 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_3 (Conv2DT (None, 28, 28, 1) 145 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 28,241\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 28,241\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "encoder_input = keras.Input(shape=(28, 28, 1), name=\"img\")\n", "x = layers.Conv2D(16, 3, activation=\"relu\")(encoder_input)\n", "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n", "x = layers.MaxPooling2D(3)(x)\n", "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n", "x = layers.Conv2D(16, 3, activation=\"relu\")(x)\n", "encoder_output = layers.GlobalMaxPooling2D()(x)\n", "\n", "encoder = keras.Model(encoder_input, encoder_output, name=\"encoder\")\n", "encoder.summary()\n", "\n", "x = layers.Reshape((4, 4, 1))(encoder_output)\n", "x = 
layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n", "x = layers.Conv2DTranspose(32, 3, activation=\"relu\")(x)\n", "x = layers.UpSampling2D(3)(x)\n", "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n", "decoder_output = layers.Conv2DTranspose(1, 3, activation=\"relu\")(x)\n", "\n", "autoencoder = keras.Model(encoder_input, decoder_output, name=\"autoencoder\")\n", "autoencoder.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "e87d4185b652" }, "source": [ "Here, the decoding architecture is strictly symmetrical to the encoding architecture, so the output shape is the same as the input shape `(28, 28, 1)`.\n", "\n", "The reverse of a `Conv2D` layer is a `Conv2DTranspose` layer, and the reverse of a `MaxPooling2D` layer is an `UpSampling2D` layer." ] }, { "cell_type": "markdown", "metadata": { "id": "9c746c1a0b79" }, "source": [ "## All models are callable, just like layers\n", "\n", "You can treat any model as if it were a layer by calling it on an `Input` or on the output of another layer. Calling a model reuses not only the model's architecture but also its weights.\n", "\n", "To see this in action, here is a different take on the autoencoder example: it creates an encoder model and a decoder model, then chains them in two calls to obtain the autoencoder model:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:39.070291Z", "iopub.status.busy": "2022-12-14T21:22:39.069673Z", "iopub.status.idle": "2022-12-14T21:22:39.281437Z", "shell.execute_reply": "2022-12-14T21:22:39.280735Z" }, "id": "862ac58e928b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"encoder\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " original_img (InputLayer) [(None, 28, 28, 1)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_4 (Conv2D) (None, 26, 26, 16) 160 \n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_5 (Conv2D) (None, 24, 24, 32) 4640 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " max_pooling2d_1 (MaxPooling (None, 8, 8, 32) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_6 (Conv2D) (None, 6, 6, 32) 9248 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_7 (Conv2D) (None, 4, 4, 16) 4624 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " global_max_pooling2d_1 (Glo (None, 16) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " balMaxPooling2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 18,672\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 18,672\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"decoder\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " encoded_img (InputLayer) [(None, 16)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " reshape_1 (Reshape) (None, 4, 4, 1) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_4 (Conv2DT (None, 6, 6, 16) 160 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_5 (Conv2DT (None, 8, 8, 32) 4640 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " up_sampling2d_1 (UpSampling (None, 24, 24, 32) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_6 (Conv2DT (None, 26, 26, 16) 4624 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_transpose_7 (Conv2DT (None, 28, 28, 1) 145 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ranspose) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 9,569\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 9,569\n" ] 
}, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"autoencoder\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " img (InputLayer) [(None, 28, 28, 1)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " encoder (Functional) (None, 16) 18672 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " decoder (Functional) (None, 28, 28, 1) 9569 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 28,241\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 28,241\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "encoder_input = keras.Input(shape=(28, 28, 1), name=\"original_img\")\n", "x = layers.Conv2D(16, 3, activation=\"relu\")(encoder_input)\n", "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n", "x = layers.MaxPooling2D(3)(x)\n", "x = layers.Conv2D(32, 3, 
activation=\"relu\")(x)\n", "x = layers.Conv2D(16, 3, activation=\"relu\")(x)\n", "encoder_output = layers.GlobalMaxPooling2D()(x)\n", "\n", "encoder = keras.Model(encoder_input, encoder_output, name=\"encoder\")\n", "encoder.summary()\n", "\n", "decoder_input = keras.Input(shape=(16,), name=\"encoded_img\")\n", "x = layers.Reshape((4, 4, 1))(decoder_input)\n", "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n", "x = layers.Conv2DTranspose(32, 3, activation=\"relu\")(x)\n", "x = layers.UpSampling2D(3)(x)\n", "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n", "decoder_output = layers.Conv2DTranspose(1, 3, activation=\"relu\")(x)\n", "\n", "decoder = keras.Model(decoder_input, decoder_output, name=\"decoder\")\n", "decoder.summary()\n", "\n", "autoencoder_input = keras.Input(shape=(28, 28, 1), name=\"img\")\n", "encoded_img = encoder(autoencoder_input)\n", "decoded_img = decoder(encoded_img)\n", "autoencoder = keras.Model(autoencoder_input, decoded_img, name=\"autoencoder\")\n", "autoencoder.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "0f77623d9cd5" }, "source": [ "As you can see, models can be nested: a model can contain sub-models, since a model behaves just like a layer. A common use case for model nesting is *ensembling*. For example, here is how to ensemble a set of models into a single model that averages their predictions:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:39.287205Z", "iopub.status.busy": "2022-12-14T21:22:39.286703Z", "iopub.status.idle": "2022-12-14T21:22:39.341737Z", "shell.execute_reply": "2022-12-14T21:22:39.341075Z" }, "id": "3bb36b630e5d" }, "outputs": [], "source": [ "def get_model():\n", " inputs = keras.Input(shape=(128,))\n", " outputs = layers.Dense(1)(inputs)\n", " return keras.Model(inputs, outputs)\n", "\n", "\n", "model1 = get_model()\n", "model2 = get_model()\n", "model3 = get_model()\n", "\n", "inputs = keras.Input(shape=(128,))\n", "y1 = model1(inputs)\n", "y2 = model2(inputs)\n", "y3 = model3(inputs)\n", "outputs = layers.average([y1, y2, y3])\n", "ensemble_model = keras.Model(inputs=inputs, 
outputs=outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "447a319b73a6" }, "source": [ "## Manipulate complex graph topologies\n", "\n", "### Models with multiple inputs and outputs\n", "\n", "The functional API makes it easy to manipulate multiple inputs and outputs, which the `Sequential` API cannot handle.\n", "\n", "For example, suppose you are building a system that ranks customer issue tickets by priority and routes them to the correct department. The model will then have three inputs:\n", "\n", "- the title of the ticket (text input),\n", "- the text body of the ticket (text input), and\n", "- any tags added by the user (categorical input)\n", "\n", "This model will have two outputs:\n", "\n", "- the priority score between 0 and 1 (scalar sigmoid output), and\n", "- the department that should handle the ticket (softmax output over the set of departments).\n", "\n", "You can build this model in a few lines with the functional API:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:39.345342Z", "iopub.status.busy": "2022-12-14T21:22:39.344761Z", "iopub.status.idle": "2022-12-14T21:22:39.813426Z", "shell.execute_reply": "2022-12-14T21:22:39.812752Z" }, "id": "49009e53da63" }, "outputs": [], "source": [ "num_tags = 12 # Number of unique issue tags\n", "num_words = 10000 # Size of vocabulary obtained when preprocessing text data\n", "num_departments = 4 # Number of departments for predictions\n", "\n", "title_input = keras.Input(\n", " shape=(None,), name=\"title\"\n", ") # Variable-length sequence of ints\n", "body_input = keras.Input(shape=(None,), name=\"body\") # Variable-length sequence of ints\n", "tags_input = keras.Input(\n", " shape=(num_tags,), name=\"tags\"\n", ") # Binary vectors of size `num_tags`\n", "\n", "# Embed each word in the title into a 64-dimensional vector\n", "title_features = layers.Embedding(num_words, 64)(title_input)\n", "# Embed each word in the text into a 64-dimensional vector\n", "body_features = layers.Embedding(num_words, 64)(body_input)\n", "\n", "# Reduce sequence of embedded words in the title into a single 128-dimensional vector\n", "title_features = layers.LSTM(128)(title_features)\n", "# Reduce sequence of embedded words in the body into a single 32-dimensional vector\n", "body_features = layers.LSTM(32)(body_features)\n", "\n", "# Merge all available features into a single large vector via 
concatenation\n", "x = layers.concatenate([title_features, body_features, tags_input])\n", "\n", "# Stick a logistic regression for priority prediction on top of the features\n", "priority_pred = layers.Dense(1, name=\"priority\")(x)\n", "# Stick a department classifier on top of the features\n", "department_pred = layers.Dense(num_departments, name=\"department\")(x)\n", "\n", "# Instantiate an end-to-end model predicting both priority and department\n", "model = keras.Model(\n", " inputs=[title_input, body_input, tags_input],\n", " outputs=[priority_pred, department_pred],\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "ee2735b3eff1" }, "source": [ "Now plot the model:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:39.817475Z", "iopub.status.busy": "2022-12-14T21:22:39.816965Z", "iopub.status.idle": "2022-12-14T21:22:40.011587Z", "shell.execute_reply": "2022-12-14T21:22:40.010543Z" }, "id": "52c4dc6fd93e" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "keras.utils.plot_model(model, \"multi_input_and_output_model.png\", show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "907c119d04a4" }, "source": [ "When compiling this model, you can assign a different loss to each output. You can even assign a different weight to each loss to modulate its contribution to the total training loss." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:40.016242Z", "iopub.status.busy": "2022-12-14T21:22:40.015556Z", "iopub.status.idle": "2022-12-14T21:22:40.030302Z", "shell.execute_reply": "2022-12-14T21:22:40.029717Z" }, "id": "3e1acef07668" }, "outputs": [], "source": [ "model.compile(\n", " optimizer=keras.optimizers.RMSprop(1e-3),\n", " loss=[\n", " keras.losses.BinaryCrossentropy(from_logits=True),\n", " keras.losses.CategoricalCrossentropy(from_logits=True),\n", " ],\n", " loss_weights=[1.0, 0.2],\n", ")" ] }, { "cell_type": 
"markdown", "metadata": { "id": "c4bd84048d41" }, "source": [ "Since the output layers have different names, you could also specify the losses and loss weights with the corresponding layer names:" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:40.033591Z", "iopub.status.busy": "2022-12-14T21:22:40.033369Z", "iopub.status.idle": "2022-12-14T21:22:40.042246Z", "shell.execute_reply": "2022-12-14T21:22:40.041628Z" }, "id": "37a6af4b30c8" }, "outputs": [], "source": [ "model.compile(\n", " optimizer=keras.optimizers.RMSprop(1e-3),\n", " loss={\n", " \"priority\": keras.losses.BinaryCrossentropy(from_logits=True),\n", " \"department\": keras.losses.CategoricalCrossentropy(from_logits=True),\n", " },\n", " loss_weights={\"priority\": 1.0, \"department\": 0.2},\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "845b20ca3c9d" }, "source": [ "Train the model by passing dicts of NumPy arrays for the inputs and targets, keyed by the names of the `Input` objects and the output layers:" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:40.045573Z", "iopub.status.busy": "2022-12-14T21:22:40.045086Z", "iopub.status.idle": "2022-12-14T21:22:50.291198Z", "shell.execute_reply": "2022-12-14T21:22:50.290530Z" }, "id": "ae5ff9364b19" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/40 [..............................] - ETA: 2:34 - loss: 1.2824 - priority_loss: 0.6962 - department_loss: 2.9312" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2/40 [>.............................] 
- ETA: 4s - loss: 1.2990 - priority_loss: 0.6979 - department_loss: 3.0056 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/40 [=>............................] - ETA: 4s - loss: 1.2897 - priority_loss: 0.7019 - department_loss: 2.9389" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/40 [==>...........................] - ETA: 4s - loss: 1.2748 - priority_loss: 0.6993 - department_loss: 2.8777" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/40 [==>...........................] - ETA: 4s - loss: 1.2535 - priority_loss: 0.6988 - department_loss: 2.7734" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/40 [===>..........................] 
- ETA: 4s - loss: 1.2486 - priority_loss: 0.6968 - department_loss: 2.7587" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/40 [====>.........................] - ETA: 4s - loss: 1.2364 - priority_loss: 0.6997 - department_loss: 2.6835" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8/40 [=====>........................] - ETA: 3s - loss: 1.2473 - priority_loss: 0.6999 - department_loss: 2.7369" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/40 [=====>........................] - ETA: 3s - loss: 1.2619 - priority_loss: 0.7001 - department_loss: 2.8092" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/40 [======>.......................] 
- ETA: 3s - loss: 1.2614 - priority_loss: 0.7003 - department_loss: 2.8054" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "11/40 [=======>......................] - ETA: 3s - loss: 1.2699 - priority_loss: 0.6986 - department_loss: 2.8562" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "12/40 [========>.....................] - ETA: 3s - loss: 1.2676 - priority_loss: 0.6985 - department_loss: 2.8454" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "13/40 [========>.....................] - ETA: 3s - loss: 1.2731 - priority_loss: 0.6981 - department_loss: 2.8749" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "14/40 [=========>....................] 
- ETA: 3s - loss: 1.2646 - priority_loss: 0.6981 - department_loss: 2.8325" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "16/40 [===========>..................] - ETA: 2s - loss: 1.2600 - priority_loss: 0.6987 - department_loss: 2.8065" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "17/40 [===========>..................] - ETA: 2s - loss: 1.2659 - priority_loss: 0.7000 - department_loss: 2.8294" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "18/40 [============>.................] - ETA: 2s - loss: 1.2648 - priority_loss: 0.6996 - department_loss: 2.8263" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "19/40 [=============>................] 
- ETA: 2s - loss: 1.2709 - priority_loss: 0.7000 - department_loss: 2.8546" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/40 [==============>...............] - ETA: 2s - loss: 1.2741 - priority_loss: 0.6995 - department_loss: 2.8727" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/40 [==============>...............] - ETA: 2s - loss: 1.2722 - priority_loss: 0.6994 - department_loss: 2.8638" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/40 [=================>............] - ETA: 1s - loss: 1.2778 - priority_loss: 0.6991 - department_loss: 2.8937" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "25/40 [=================>............] 
- ETA: 1s - loss: 1.2786 - priority_loss: 0.6988 - department_loss: 2.8992" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "26/40 [==================>...........] - ETA: 1s - loss: 1.2825 - priority_loss: 0.7000 - department_loss: 2.9125" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/40 [====================>.........] - ETA: 1s - loss: 1.2812 - priority_loss: 0.7006 - department_loss: 2.9030" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "29/40 [====================>.........] - ETA: 1s - loss: 1.2815 - priority_loss: 0.7002 - department_loss: 2.9066" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "30/40 [=====================>........] 
- ETA: 1s - loss: 1.2795 - priority_loss: 0.6997 - department_loss: 2.8989" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "31/40 [======================>.......] - ETA: 0s - loss: 1.2815 - priority_loss: 0.6999 - department_loss: 2.9081" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "34/40 [========================>.....] - ETA: 0s - loss: 1.2814 - priority_loss: 0.7002 - department_loss: 2.9057" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "36/40 [==========================>...] - ETA: 0s - loss: 1.2823 - priority_loss: 0.7001 - department_loss: 2.9110" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "40/40 [==============================] - 8s 91ms/step - loss: 1.2805 - priority_loss: 0.6993 - department_loss: 2.9057\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/40 [..............................] 
- ETA: 4s - loss: 1.3596 - priority_loss: 0.7339 - department_loss: 3.1284" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/40 [=>............................] - ETA: 2s - loss: 1.3326 - priority_loss: 0.7094 - department_loss: 3.1159" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/40 [==>...........................] - ETA: 3s - loss: 1.3325 - priority_loss: 0.7027 - department_loss: 3.1488" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/40 [==>...........................] - ETA: 3s - loss: 1.3318 - priority_loss: 0.6993 - department_loss: 3.1621" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/40 [===>..........................] 
- ETA: 3s - loss: 1.3281 - priority_loss: 0.6992 - department_loss: 3.1445" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/40 [====>.........................] - ETA: 3s - loss: 1.3304 - priority_loss: 0.7042 - department_loss: 3.1309" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8/40 [=====>........................] - ETA: 3s - loss: 1.3224 - priority_loss: 0.7003 - department_loss: 3.1105" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/40 [=====>........................] - ETA: 3s - loss: 1.3077 - priority_loss: 0.6969 - department_loss: 3.0541" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/40 [======>.......................] 
- ETA: 3s - loss: 1.3075 - priority_loss: 0.6976 - department_loss: 3.0497" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "11/40 [=======>......................] - ETA: 3s - loss: 1.3109 - priority_loss: 0.6991 - department_loss: 3.0589" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "12/40 [========>.....................] - ETA: 3s - loss: 1.3107 - priority_loss: 0.6994 - department_loss: 3.0567" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "15/40 [==========>...................] - ETA: 2s - loss: 1.2931 - priority_loss: 0.6985 - department_loss: 2.9729" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "16/40 [===========>..................] 
- ETA: 2s - loss: 1.2967 - priority_loss: 0.7005 - department_loss: 2.9812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/40 [==============>...............] - ETA: 1s - loss: 1.2922 - priority_loss: 0.7019 - department_loss: 2.9514" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/40 [==============>...............] - ETA: 1s - loss: 1.2916 - priority_loss: 0.7019 - department_loss: 2.9482" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/40 [===============>..............] - ETA: 1s - loss: 1.2927 - priority_loss: 0.7024 - department_loss: 2.9514" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/40 [===================>..........] 
- ETA: 1s - loss: 1.2854 - priority_loss: 0.7014 - department_loss: 2.9200" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "34/40 [========================>.....] - ETA: 0s - loss: 1.2836 - priority_loss: 0.7015 - department_loss: 2.9104" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "35/40 [=========================>....] - ETA: 0s - loss: 1.2834 - priority_loss: 0.7017 - department_loss: 2.9084" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "37/40 [==========================>...] 
- ETA: 0s - loss: 1.2857 - priority_loss: 0.7012 - department_loss: 2.9228" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "40/40 [==============================] - 3s 64ms/step - loss: 1.2832 - priority_loss: 0.7004 - department_loss: 2.9143\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Dummy input data\n", "title_data = np.random.randint(num_words, size=(1280, 10))\n", "body_data = np.random.randint(num_words, size=(1280, 100))\n", "tags_data = np.random.randint(2, size=(1280, num_tags)).astype(\"float32\")\n", "\n", "# Dummy target data\n", "priority_targets = np.random.random(size=(1280, 1))\n", "dept_targets = np.random.randint(2, size=(1280, num_departments))\n", "\n", "model.fit(\n", " {\"title\": title_data, \"body\": body_data, \"tags\": tags_data},\n", " {\"priority\": priority_targets, \"department\": dept_targets},\n", " epochs=2,\n", " batch_size=32,\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "3c87f1fbe7aa" }, "source": [ "When calling `fit()` with a `Dataset` object, the dataset should yield either a tuple of lists, such as `([title_data, body_data, tags_data], [priority_targets, dept_targets])`, or a tuple of dictionaries, such as `({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})`.\n", "\n", "For a more detailed explanation, refer to the [Training and evaluation](https://tensorflow.google.cn/guide/keras/train_and_evaluate/) guide." ] },
{ "cell_type": "markdown", "metadata": { "id": "64ada3f80484" }, "source": [ "### A toy ResNet model\n", "\n", "In addition to models with multiple inputs and outputs, the functional API makes it easy to manipulate non-linear connectivity topologies, that is, models whose layers are not connected sequentially. The `Sequential` API cannot handle these.\n", "\n", "A common use case for this is residual connections. Let's build a toy ResNet model for CIFAR10 to demonstrate this:" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "execution": { 
"iopub.execute_input": "2022-12-14T21:22:50.294844Z", "iopub.status.busy": "2022-12-14T21:22:50.294336Z", "iopub.status.idle": "2022-12-14T21:22:50.435735Z", "shell.execute_reply": "2022-12-14T21:22:50.434897Z" }, "id": "bfa8b7503813" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"toy_resnet\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # Connected to \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "==================================================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " img (InputLayer) [(None, 32, 32, 3)] 0 [] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_8 (Conv2D) (None, 30, 30, 32) 896 ['img[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_9 (Conv2D) (None, 28, 28, 64) 18496 ['conv2d_8[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " max_pooling2d_2 (MaxPooling2D) (None, 9, 9, 64) 0 ['conv2d_9[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_10 (Conv2D) (None, 9, 9, 64) 36928 ['max_pooling2d_2[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_11 (Conv2D) (None, 9, 9, 64) 36928 ['conv2d_10[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " add (Add) (None, 
9, 9, 64) 0 ['conv2d_11[0][0]', \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 'max_pooling2d_2[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_12 (Conv2D) (None, 9, 9, 64) 36928 ['add[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_13 (Conv2D) (None, 9, 9, 64) 36928 ['conv2d_12[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " add_1 (Add) (None, 9, 9, 64) 0 ['conv2d_13[0][0]', \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 'add[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " conv2d_14 (Conv2D) (None, 7, 7, 64) 36928 ['add_1[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " global_average_pooling2d (Glob (None, 64) 0 ['conv2d_14[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " alAveragePooling2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_6 (Dense) (None, 256) 16640 ['global_average_pooling2d[0][0]'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout (Dropout) (None, 256) 0 ['dense_6[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_7 (Dense) (None, 10) 2570 ['dropout[0][0]'] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"==================================================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 223,242\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 223,242\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n" ] } ], "source": [ "inputs = keras.Input(shape=(32, 32, 3), name=\"img\")\n", "x = layers.Conv2D(32, 3, activation=\"relu\")(inputs)\n", "x = layers.Conv2D(64, 3, activation=\"relu\")(x)\n", "block_1_output = layers.MaxPooling2D(3)(x)\n", "\n", "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(block_1_output)\n", "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(x)\n", "block_2_output = layers.add([x, block_1_output])\n", "\n", "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(block_2_output)\n", "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(x)\n", "block_3_output = layers.add([x, block_2_output])\n", "\n", "x = layers.Conv2D(64, 3, activation=\"relu\")(block_3_output)\n", "x = layers.GlobalAveragePooling2D()(x)\n", "x = layers.Dense(256, activation=\"relu\")(x)\n", "x = layers.Dropout(0.5)(x)\n", "outputs = layers.Dense(10)(x)\n", "\n", "model = keras.Model(inputs, outputs, name=\"toy_resnet\")\n", "model.summary()" ] },
{ "cell_type": "markdown", "metadata": { "id": "05aefc66c54f" }, "source": [ "Plot the model:" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:50.439164Z", "iopub.status.busy": "2022-12-14T21:22:50.438467Z", "iopub.status.idle": "2022-12-14T21:22:50.683526Z", "shell.execute_reply": "2022-12-14T21:22:50.682426Z" }, "id": "ef7ac19c83be" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] 
}, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "keras.utils.plot_model(model, \"mini_resnet.png\", show_shapes=True)" ] },
{ "cell_type": "markdown", "metadata": { "id": "4f0883eae520" }, "source": [ "Now train the model:" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:22:50.688539Z", "iopub.status.busy": "2022-12-14T21:22:50.687895Z", "iopub.status.idle": "2022-12-14T21:23:00.813859Z", "shell.execute_reply": "2022-12-14T21:23:00.813125Z" }, "id": "4e1c7b530071" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 8192/170498071 [..............................] - ETA: 0s" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 204800/170498071 [..............................] - ETA: 56s" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 1064960/170498071 [..............................] - ETA: 18s" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4464640/170498071 [..............................] - ETA: 6s " ] },
{ "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9093120/170498071 [>.............................] 
- ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 12345344/170498071 [=>............................] - ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 17170432/170498071 [==>...........................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 19980288/170498071 [==>...........................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 25001984/170498071 [===>..........................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 27574272/170498071 [===>..........................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 31883264/170498071 [====>.........................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 34922496/170498071 [=====>........................] 
- ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 39436288/170498071 [=====>........................] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 43843584/170498071 [======>.......................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 48111616/170498071 [=======>......................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 52445184/170498071 [========>.....................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 56418304/170498071 [========>.....................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 61153280/170498071 [=========>....................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 64839680/170498071 [==========>...................] 
- ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 70189056/170498071 [===========>..................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 73383936/170498071 [===========>..................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 78856192/170498071 [============>.................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 82214912/170498071 [=============>................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 87375872/170498071 [==============>...............] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 91176960/170498071 [===============>..............] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 95707136/170498071 [===============>..............] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 99819520/170498071 [================>.............] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "104103936/170498071 [=================>............] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "108503040/170498071 [==================>...........] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "112558080/170498071 [==================>...........] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "117284864/170498071 [===================>..........] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "120954880/170498071 [====================>.........] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "126205952/170498071 [=====================>........] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "129556480/170498071 [=====================>........] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "135045120/170498071 [======================>.......] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "138330112/170498071 [=======================>......] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "143548416/170498071 [========================>.....] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "147316736/170498071 [========================>.....] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "151945216/170498071 [=========================>....] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "156033024/170498071 [==========================>...] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "160276480/170498071 [===========================>..] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "164569088/170498071 [===========================>..] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "168796160/170498071 [============================>.] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "170498071/170498071 [==============================] - 2s 0us/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/13 [=>............................] - ETA: 35s - loss: 2.3094 - acc: 0.1094" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/13 [===================>..........] 
- ETA: 0s - loss: 2.3044 - acc: 0.1042 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "13/13 [==============================] - ETA: 0s - loss: 2.2958 - acc: 0.1013" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "13/13 [==============================] - 3s 30ms/step - loss: 2.2958 - acc: 0.1013 - val_loss: 2.2806 - val_acc: 0.1900\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n", "\n", "x_train = x_train.astype(\"float32\") / 255.0\n", "x_test = x_test.astype(\"float32\") / 255.0\n", "y_train = keras.utils.to_categorical(y_train, 10)\n", "y_test = keras.utils.to_categorical(y_test, 10)\n", "\n", "model.compile(\n", " optimizer=keras.optimizers.RMSprop(1e-3),\n", " loss=keras.losses.CategoricalCrossentropy(from_logits=True),\n", " metrics=[\"acc\"],\n", ")\n", "# We restrict the data to the first 1000 samples so as to limit execution time\n", "# on Colab. 
Try to train on the entire dataset until convergence!\n", "model.fit(x_train[:1000], y_train[:1000], batch_size=64, epochs=1, validation_split=0.2)" ] }, { "cell_type": "markdown", "metadata": { "id": "e7f35a9a1061" }, "source": [ "## 共享层\n", "\n", "函数式 API 的另一个很好的用途是使用*shared layers*的模型。共享层是在同一个模型中多次重用的层实例,它们会学习与层计算图中的多个路径相对应的特征。\n", "\n", "共享层通常用于对来自相似空间(例如,两个具有相似词汇的不同文本)的输入进行编码。它们可以实现在这些不同的输入之间共享信息,以及在更少的数据上训练这种模型。如果在其中的一个输入中看到了一个给定单词,那么将有利于处理通过共享层的所有输入。\n", "\n", "要在函数式 API 中共享层,请多次调用同一个层实例。例如,下面是一个在两个不同文本输入之间共享的 `Embedding` 层:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:00.817513Z", "iopub.status.busy": "2022-12-14T21:23:00.817024Z", "iopub.status.idle": "2022-12-14T21:23:00.834636Z", "shell.execute_reply": "2022-12-14T21:23:00.834027Z" }, "id": "4b8e6a4f3e88" }, "outputs": [], "source": [ "# Embedding for 1000 unique words mapped to 128-dimensional vectors\n", "shared_embedding = layers.Embedding(1000, 128)\n", "\n", "# Variable-length sequence of integers\n", "text_input_a = keras.Input(shape=(None,), dtype=\"int32\")\n", "\n", "# Variable-length sequence of integers\n", "text_input_b = keras.Input(shape=(None,), dtype=\"int32\")\n", "\n", "# Reuse the same layer to encode both inputs\n", "encoded_input_a = shared_embedding(text_input_a)\n", "encoded_input_b = shared_embedding(text_input_b)" ] }, { "cell_type": "markdown", "metadata": { "id": "b4f193a74581" }, "source": [ "## 提取和重用层计算图中的节点\n", "\n", "由于要处理的层计算图是静态数据结构,可以对其进行访问和检查。而这就是将函数式模型绘制为图像的方式。\n", "\n", "这也意味着您可以访问中间层的激活函数(计算图中的“节点”)并在其他地方重用它们,这对于特征提取之类的操作十分有用。\n", "\n", "让我们来看一个例子。下面是一个 VGG19 模型,其权重已在 ImageNet 上进行了预训练:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:00.838606Z", "iopub.status.busy": "2022-12-14T21:23:00.837870Z", "iopub.status.idle": "2022-12-14T21:23:05.625518Z", "shell.execute_reply": "2022-12-14T21:23:05.624773Z" }, "id": 
"8bdaa209ccbe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5\n", "574710816/574710816 [==============================] - 3s 0us/step\n" ] } ], "source": [ "vgg19 = tf.keras.applications.VGG19()" ] }, { "cell_type": "markdown", "metadata": { "id": "874ef4b4de49" }, "source": [ "And these are the intermediate activations of the model, obtained by querying the graph data structure:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:05.629775Z", "iopub.status.busy": "2022-12-14T21:23:05.629272Z", "iopub.status.idle": "2022-12-14T21:23:05.632765Z", "shell.execute_reply": "2022-12-14T21:23:05.632228Z" }, "id": "391817839937" }, "outputs": [], "source": [ "features_list = [layer.output for layer in vgg19.layers]" ] }, { "cell_type": "markdown", "metadata": { "id": "e91a9dc2f5b0" }, "source": [ "Use these features to create a new feature-extraction model that returns the values of the intermediate layer activations:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:05.636172Z", "iopub.status.busy": "2022-12-14T21:23:05.635707Z", "iopub.status.idle": "2022-12-14T21:23:05.836518Z", "shell.execute_reply": "2022-12-14T21:23:05.835882Z" }, "id": "36a450517b63" }, "outputs": [], "source": [ "feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n", "\n", "img = np.random.random((1, 224, 224, 3)).astype(\"float32\")\n", "extracted_features = feat_extraction_model(img)" ] }, { "cell_type": "markdown", "metadata": { "id": "f2ac248fe202" }, "source": [ "This comes in handy for tasks like [neural style transfer](https://tensorflow.google.cn/tutorials/generative/style_transfer), among other things." ] }, { "cell_type": "markdown", "metadata": { "id": 
"c894ba891064" }, "source": [ "## Extend the API using custom layers\n", "\n", "`tf.keras` includes a wide range of built-in layers, for example:\n", "\n", "- Convolutional layers: `Conv1D`, `Conv2D`, `Conv3D`, `Conv2DTranspose`\n", "- Pooling layers: `MaxPooling1D`, `MaxPooling2D`, `MaxPooling3D`, `AveragePooling1D`\n", "- RNN layers: `GRU`, `LSTM`, `ConvLSTM2D`\n", "- `BatchNormalization`, `Dropout`, `Embedding`, etc.\n", "\n", "But if you don't find what you need, it's easy to extend the API by creating your own layers. All layers subclass the `Layer` class and implement:\n", "\n", "- A `call` method, which specifies the computation done by the layer.\n", "- A `build` method, which creates the weights of the layer (this is just a style convention, since you can also create weights in `__init__`).\n", "\n", "To learn more about creating layers from scratch, read the [Custom layers and models](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) guide.\n", "\n", "The following is a basic implementation of `tf.keras.layers.Dense`:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:05.840498Z", "iopub.status.busy": "2022-12-14T21:23:05.839958Z", "iopub.status.idle": "2022-12-14T21:23:05.885319Z", "shell.execute_reply": "2022-12-14T21:23:05.884710Z" }, "id": "1d9faf1f622a" }, "outputs": [], "source": [ "class CustomDense(layers.Layer):\n", "    def __init__(self, units=32):\n", "        super(CustomDense, self).__init__()\n", "        self.units = units\n", "\n", "    def build(self, input_shape):\n", "        self.w = self.add_weight(\n", "            shape=(input_shape[-1], self.units),\n", "            initializer=\"random_normal\",\n", "            trainable=True,\n", "        )\n", "        self.b = self.add_weight(\n", "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n", "        )\n", "\n", "    def call(self, inputs):\n", "        return tf.matmul(inputs, self.w) + self.b\n", "\n", "\n", "inputs = keras.Input((4,))\n", "outputs = CustomDense(10)(inputs)\n", "\n", "model = keras.Model(inputs, outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "b8933568358c" }, "source": [ "For serialization support in your custom layer, define a `get_config` method that returns the constructor arguments of the layer instance:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:05.888754Z", "iopub.status.busy": "2022-12-14T21:23:05.888186Z", "iopub.status.idle": 
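"2022-12-14T21:23:05.915702Z", "shell.execute_reply": "2022-12-14T21:23:05.915099Z" }, "id": "b22a134918a2" }, "outputs": [], "source": [ "class CustomDense(layers.Layer):\n", "    def __init__(self, units=32):\n", "        super(CustomDense, self).__init__()\n", "        self.units = units\n", "\n", "    def build(self, input_shape):\n", "        self.w = self.add_weight(\n", "            shape=(input_shape[-1], self.units),\n", "            initializer=\"random_normal\",\n", "            trainable=True,\n", "        )\n", "        self.b = self.add_weight(\n", "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n", "        )\n", "\n", "    def call(self, inputs):\n", "        return tf.matmul(inputs, self.w) + self.b\n", "\n", "    def get_config(self):\n", "        return {\"units\": self.units}\n", "\n", "\n", "inputs = keras.Input((4,))\n", "outputs = CustomDense(10)(inputs)\n", "\n", "model = keras.Model(inputs, outputs)\n", "config = model.get_config()\n", "\n", "new_model = keras.Model.from_config(config, custom_objects={\"CustomDense\": CustomDense})" ] }, { "cell_type": "markdown", "metadata": { "id": "015abf7d0508" }, "source": [ "Optionally, implement the class method `from_config(cls, config)`, which is used when recreating a layer instance given its config dictionary. The default implementation of `from_config` is:\n", "\n", "```python\n", "@classmethod\n", "def from_config(cls, config):\n", "    return cls(**config)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "b4ead34e01dd" }, "source": [ "## When to use the functional API\n", "\n", "Should you use the Keras functional API to create a new model, or subclass the `Model` class directly? In general, the functional API is higher-level, easier and safer to use, and has a number of features that subclassed models do not support.\n", "\n", "However, model subclassing provides greater flexibility when building models that are not easily expressible as directed acyclic graphs of layers. For example, you could not implement a Tree-RNN with the functional API and would have to subclass `Model` directly.\n", "\n", "For an in-depth look at the differences between the functional API and model subclassing, read [What are Symbolic and Imperative APIs in TensorFlow 2.0?](https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html).\n", "\n", "### Functional API strengths:\n", "\n", "The following properties are also true for Sequential models (which are also data structures), but are not true for subclassed models (which are Python bytecode, not data structures).\n", "\n", "#### Less verbose\n", "\n", "There is no `super(MyClass, self).__init__(...)`, no `def call(self, ...):`, etc.\n", "\n", "Compare:\n", "\n", "```python\n", "inputs = keras.Input(shape=(32,))\n", "x = layers.Dense(64, activation='relu')(inputs)\n", "outputs = layers.Dense(10)(x)\n", "mlp = keras.Model(inputs, outputs)\n", "```\n", "\n", "With the subclassed version:\n", "\n", "```python\n", "class MLP(keras.Model):\n", "\n", "    def __init__(self, **kwargs):\n", "        super(MLP, self).__init__(**kwargs)\n", "        self.dense_1 = layers.Dense(64, activation='relu')\n", "        self.dense_2 = layers.Dense(10)\n", "\n", "    def call(self, inputs):\n", "        x = self.dense_1(inputs)\n", "        return self.dense_2(x)\n", "\n", "# Instantiate the model.\n", "mlp = MLP()\n", "# Necessary to create the model's state.\n", "# The model doesn't have a state until it's called at least once.\n", "_ = mlp(tf.zeros((1, 32)))\n", "```\n", "\n", "#### Model validation while defining its connectivity graph\n", "\n", "In the functional API, the input specification (shape and dtype) is created in advance (using `Input`). Every time you call a layer, the layer checks that the specification passed to it matches its assumptions, and it raises a helpful error message if not.\n", "\n", "This guarantees that any model you can build with the functional API will run. All debugging, other than convergence-related debugging, happens statically during model construction and not at execution time. This is similar to type checking in a compiler.\n", "\n", "#### A functional model is plottable and inspectable\n", "\n", "You can plot the model as a graph, and you can easily access intermediate nodes in this graph. For example, to extract and reuse the activations of intermediate layers (as seen in a previous example):\n", "\n", "```python\n", "features_list = [layer.output for layer in vgg19.layers]\n", "feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n", "```\n", "\n", "#### A functional model can be serialized or cloned\n", "\n", "Because a functional model is a data structure rather than a piece of code, it is safely serializable and can be saved as a single file that allows you to recreate the exact same model without having access to any of the original code. See the [serialization and saving](https://tensorflow.google.cn/guide/keras/save_and_serialize/) guide.\n", "\n", "To serialize a subclassed model, the implementer must specify `get_config()` and `from_config()` methods at the model level.\n", "\n", "### Functional API weakness:\n", "\n", "#### It does not support dynamic architectures\n", "\n", "The functional API treats models as DAGs of layers. This is true for most deep learning architectures, but not all: for example, recursive networks or Tree-RNNs do not follow this assumption and cannot be implemented in the functional API." ] }, { "cell_type": "markdown", "metadata": { "id": "72992d4ed462" }, "source": [ "## Mix-and-match API styles\n", "\n", "Choosing between the functional API and model subclassing isn't a binary decision that restricts you to one category of models. All models in the `tf.keras` API can interact with each other, whether they're `Sequential` models, functional models, or subclassed models written from scratch.\n", "\n", "You can always use a functional model or `Sequential` model as part of a subclassed model or layer:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:05.919621Z", "iopub.status.busy": "2022-12-14T21:23:05.918997Z", 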
"iopub.status.idle": "2022-12-14T21:23:05.976125Z", "shell.execute_reply": "2022-12-14T21:23:05.975569Z" }, "id": "3c6221508766" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 10, 32)\n" ] } ], "source": [ "units = 32\n", "timesteps = 10\n", "input_dim = 5\n", "\n", "# Define a Functional model\n", "inputs = keras.Input((None, units))\n", "x = layers.GlobalAveragePooling1D()(inputs)\n", "outputs = layers.Dense(1)(x)\n", "model = keras.Model(inputs, outputs)\n", "\n", "\n", "class CustomRNN(layers.Layer):\n", "    def __init__(self):\n", "        super(CustomRNN, self).__init__()\n", "        self.units = units\n", "        self.projection_1 = layers.Dense(units=units, activation=\"tanh\")\n", "        self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n", "        # Our previously-defined Functional model\n", "        self.classifier = model\n", "\n", "    def call(self, inputs):\n", "        outputs = []\n", "        state = tf.zeros(shape=(inputs.shape[0], self.units))\n", "        for t in range(inputs.shape[1]):\n", "            x = inputs[:, t, :]\n", "            h = self.projection_1(x)\n", "            y = h + self.projection_2(state)\n", "            state = y\n", "            outputs.append(y)\n", "        features = tf.stack(outputs, axis=1)\n", "        print(features.shape)\n", "        return self.classifier(features)\n", "\n", "\n", "rnn_model = CustomRNN()\n", "_ = rnn_model(tf.zeros((1, timesteps, input_dim)))" ] }, { "cell_type": "markdown", "metadata": { "id": "41f42eb2a9c0" }, "source": [ "You can use any subclassed layer or model in the functional API as long as it implements a `call` method that follows one of these patterns:\n", "\n", "- `call(self, inputs, **kwargs)` - where `inputs` is a tensor or a nested structure of tensors (e.g. a list of tensors), and `**kwargs` are non-tensor arguments (non-inputs).\n", "- `call(self, inputs, training=None, **kwargs)` - where `training` is a boolean indicating whether the layer should behave in training mode or in inference mode.\n", "- `call(self, inputs, mask=None, **kwargs)` - where `mask` is a boolean mask tensor (useful for RNNs, for instance).\n", "- `call(self, inputs, training=None, mask=None, **kwargs)` - of course, you can have both masking and training-specific behavior at the same time.\n", "\n", "Additionally, if you implement the `get_config` method on your custom layer or model, the functional models you create will still be serializable and cloneable.\n", "\n", "Here is a quick example of a custom RNN, written from scratch, being used in a functional model:" ] }, { "cell_type": 
"code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:23:05.979347Z", "iopub.status.busy": "2022-12-14T21:23:05.978887Z", "iopub.status.idle": "2022-12-14T21:23:06.161023Z", "shell.execute_reply": "2022-12-14T21:23:06.160378Z" }, "id": "3deb90222d05" }, "outputs": [], "source": [ "units = 32\n", "timesteps = 10\n", "input_dim = 5\n", "batch_size = 16\n", "\n", "\n", "class CustomRNN(layers.Layer):\n", "    def __init__(self):\n", "        super(CustomRNN, self).__init__()\n", "        self.units = units\n", "        self.projection_1 = layers.Dense(units=units, activation=\"tanh\")\n", "        self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n", "        self.classifier = layers.Dense(1)\n", "\n", "    def call(self, inputs):\n", "        outputs = []\n", "        state = tf.zeros(shape=(inputs.shape[0], self.units))\n", "        for t in range(inputs.shape[1]):\n", "            x = inputs[:, t, :]\n", "            h = self.projection_1(x)\n", "            y = h + self.projection_2(state)\n", "            state = y\n", "            outputs.append(y)\n", "        features = tf.stack(outputs, axis=1)\n", "        return self.classifier(features)\n", "\n", "\n", "# Note that you specify a static batch size for the inputs with the `batch_shape`\n", "# arg, because the inner computation of `CustomRNN` requires a static batch size\n", "# (when you create the `state` zeros tensor).\n", "inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))\n", "x = layers.Conv1D(32, 3)(inputs)\n", "outputs = CustomRNN()(x)\n", "\n", "model = keras.Model(inputs, outputs)\n", "\n", "rnn_model = CustomRNN()\n", "_ = rnn_model(tf.zeros((1, 10, 5)))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "functional.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" 
} }, "nbformat": 4, "nbformat_minor": 0 }