{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "1Z6Wtb_jisbA" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T19:05:32.448616Z", "iopub.status.busy": "2023-11-07T19:05:32.448399Z", "iopub.status.idle": "2023-11-07T19:05:32.452280Z", "shell.execute_reply": "2023-11-07T19:05:32.451614Z" }, "id": "QUyRGn9riopB" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "H1yCdGFW4j_F" }, "source": [ "# Premade Estimators" ] }, { "cell_type": "markdown", "metadata": { "id": "PS6_yKSoyLAl" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "stQiPWL6ni6_" }, "source": [ "> 警告:不建议将 Estimator 用于新代码。Estimator 运行 `v1.Session` 风格的代码,此类代码更加难以正确编写,并且可能会出现意外行为,尤其是与 TF 2 代码结合使用时。Estimator 确实在我们的[兼容性保证范围](https://tensorflow.org/guide/versions)内,但除了安全漏洞之外不会得到任何修复。请参阅[迁移指南](https://tensorflow.org/guide/migrate)以了解详情。" ] }, { "cell_type": "markdown", "metadata": { "id": "R4YZ_ievcY7p" }, "source": [ "本教程向您展示了如何使用 Estimator 在 TensorFlow 中解决鸢尾花分类问题。Estimator 是完整模型在旧版 TensorFlow 中的高级表示。有关更多详细信息,请参阅 [Estimator](https://tensorflow.google.cn/guide/estimator)。\n", "\n", "注:在 TensorFlow 2.0 中,[Keras API](https://tensorflow.google.cn/guide/keras) 可以完成这些相同的任务,并且被认为是一个更容易学习的 API。如果您刚入门,建议您从 Keras 开始。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "8IFct0yedsTy" }, "source": [ "## 首先要做的事\n", "\n", "为了开始,您将首先导入 Tensorflow 和一系列您需要的库。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:32.455857Z", "iopub.status.busy": "2023-11-07T19:05:32.455638Z", "iopub.status.idle": "2023-11-07T19:05:34.784119Z", "shell.execute_reply": "2023-11-07T19:05:34.783350Z" }, "id": "jPo5bQwndr9P" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 19:05:32.884593: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-11-07 19:05:32.884639: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-11-07 19:05:32.886222: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": 
{ "id": "c5w4m5gncnGh" }, "source": [ "## The data set\n", "\n", "The sample program in this document builds and tests a model that classifies Iris flowers into three species based on the size of their [sepals](https://en.wikipedia.org/wiki/Sepal) and [petals](https://en.wikipedia.org/wiki/Petal).\n", "\n", "You will train the model on the Iris data set. The data set contains four features and one [label](https://developers.google.com/machine-learning/glossary/#label). The four features identify the following botanical characteristics of an individual Iris flower:\n", "\n", "- sepal length\n", "- sepal width\n", "- petal length\n", "- petal width\n", "\n", "Based on this information, you can define a few helpful constants for parsing the data:\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:34.788213Z", "iopub.status.busy": "2023-11-07T19:05:34.787822Z", "iopub.status.idle": "2023-11-07T19:05:34.791492Z", "shell.execute_reply": "2023-11-07T19:05:34.790846Z" }, "id": "lSyrXp_He_UE" }, "outputs": [], "source": [ "CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']\n", "SPECIES = ['Setosa', 'Versicolor', 'Virginica']" ] }, { "cell_type": "markdown", "metadata": { "id": "j6mTfIQzfC9w" }, "source": [ "Next, download and parse the Iris data set using Keras and Pandas. Note that you keep separate datasets for training and testing." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:34.794457Z", "iopub.status.busy": "2023-11-07T19:05:34.794194Z", "iopub.status.idle": "2023-11-07T19:05:34.963721Z", "shell.execute_reply": "2023-11-07T19:05:34.963102Z" }, "id": "PumyCN8VdGGc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "2194/2194 [==============================] - 0s 0us/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "573/573 [==============================] - 0s 0us/step\n" ] } ], "source": [ "train_path = tf.keras.utils.get_file(\n", "    
\"iris_training.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\")\n", "test_path = tf.keras.utils.get_file(\n", " \"iris_test.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\")\n", "\n", "train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)\n", "test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "wHFxNLszhQjz" }, "source": [ "通过检查数据您可以发现有四列浮点型特征和一列 int32 型标签。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:34.966976Z", "iopub.status.busy": "2023-11-07T19:05:34.966760Z", "iopub.status.idle": "2023-11-07T19:05:34.979314Z", "shell.execute_reply": "2023-11-07T19:05:34.978759Z" }, "id": "WOJt-ML4hAwI" }, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"  <thead>\n",
"    <tr><th></th><th>SepalLength</th><th>SepalWidth</th><th>PetalLength</th><th>PetalWidth</th><th>Species</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>6.4</td><td>2.8</td><td>5.6</td><td>2.2</td><td>2</td></tr>\n",
"    <tr><th>1</th><td>5.0</td><td>2.3</td><td>3.3</td><td>1.0</td><td>1</td></tr>\n",
"    <tr><th>2</th><td>4.9</td><td>2.5</td><td>4.5</td><td>1.7</td><td>2</td></tr>\n",
"    <tr><th>3</th><td>4.9</td><td>3.1</td><td>1.5</td><td>0.1</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>5.7</td><td>3.8</td><td>1.7</td><td>0.3</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n", "
" ], "text/plain": [ " SepalLength SepalWidth PetalLength PetalWidth Species\n", "0 6.4 2.8 5.6 2.2 2\n", "1 5.0 2.3 3.3 1.0 1\n", "2 4.9 2.5 4.5 1.7 2\n", "3 4.9 3.1 1.5 0.1 0\n", "4 5.7 3.8 1.7 0.3 0" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "jQJEYfVvfznP" }, "source": [ "对于每个数据集都分割出标签,模型将被训练来预测这些标签。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:34.982696Z", "iopub.status.busy": "2023-11-07T19:05:34.982106Z", "iopub.status.idle": "2023-11-07T19:05:34.990658Z", "shell.execute_reply": "2023-11-07T19:05:34.990112Z" }, "id": "zM0wz2TueuA6" }, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"  <thead>\n",
"    <tr><th></th><th>SepalLength</th><th>SepalWidth</th><th>PetalLength</th><th>PetalWidth</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>6.4</td><td>2.8</td><td>5.6</td><td>2.2</td></tr>\n",
"    <tr><th>1</th><td>5.0</td><td>2.3</td><td>3.3</td><td>1.0</td></tr>\n",
"    <tr><th>2</th><td>4.9</td><td>2.5</td><td>4.5</td><td>1.7</td></tr>\n",
"    <tr><th>3</th><td>4.9</td><td>3.1</td><td>1.5</td><td>0.1</td></tr>\n",
"    <tr><th>4</th><td>5.7</td><td>3.8</td><td>1.7</td><td>0.3</td></tr>\n",
"  </tbody>\n",
"</table>\n", "
" ], "text/plain": [ " SepalLength SepalWidth PetalLength PetalWidth\n", "0 6.4 2.8 5.6 2.2\n", "1 5.0 2.3 3.3 1.0\n", "2 4.9 2.5 4.5 1.7\n", "3 4.9 3.1 1.5 0.1\n", "4 5.7 3.8 1.7 0.3" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_y = train.pop('Species')\n", "test_y = test.pop('Species')\n", "\n", "# The label column has now been removed from the features.\n", "train.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "jZx1L_1Vcmxv" }, "source": [ "## Estimator 编程概述\n", "\n", "现在您已经设置了数据,可以使用 TensorFlow Estimator 定义模型。 Estimator 是从 `tf.estimator.Estimator` 派生的任何类。TensorFlow 提供了一组 `tf.estimator`(例如 `LinearRegressor`)来实现常见的 ML 算法。除此之外,您可以编写自己的[自定义 Estimator](https://tensorflow.google.cn/guide/estimator#custom_estimators)。建议在刚开始时使用预制的 Estimator。\n", "\n", "为了编写基于预创建的 Estimator 的 Tensorflow 项目,您必须完成以下工作:\n", "\n", "- 创建一个或多个输入函数\n", "- 定义模型的特征列\n", "- 实例化一个 Estimator,指定特征列和各种超参数。\n", "- 在 Estimator 对象上调用一个或多个方法,传递合适的输入函数以作为数据源。\n", "\n", "我们来看看这些任务是如何在鸢尾花分类中实现的。" ] }, { "cell_type": "markdown", "metadata": { "id": "2OcguDfBcmmg" }, "source": [ "## 创建输入函数\n", "\n", "您必须创建输入函数来提供用于训练、评估和预测的数据。\n", "\n", "**输入函数**是一个返回 `tf.data.Dataset` 对象的函数,此对象会输出下列含两个元素的元组:\n", "\n", "- [`features`](https://developers.google.com/machine-learning/glossary/#feature)——Python字典,其中:\n", " - 每个键都是特征名称\n", " - 每个值都是包含此特征所有值的数组\n", "- `label` 包含每个样本的[标签](https://developers.google.com/machine-learning/glossary/#label)的值的数组。\n", "\n", "为了向您展示输入函数的格式,请查看下面这个简单的实现:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:34.993885Z", "iopub.status.busy": "2023-11-07T19:05:34.993668Z", "iopub.status.idle": "2023-11-07T19:05:34.997832Z", "shell.execute_reply": "2023-11-07T19:05:34.997166Z" }, "id": "nzr5vRr5caGF" }, "outputs": [], "source": [ "def input_evaluation_set():\n", " features = {'SepalLength': np.array([6.4, 5.0]),\n", " 'SepalWidth': np.array([2.8, 2.3]),\n", " 
'PetalLength': np.array([5.6, 3.3]),\n", " 'PetalWidth': np.array([2.2, 1.0])}\n", " labels = np.array([2, 1])\n", " return features, labels" ] }, { "cell_type": "markdown", "metadata": { "id": "NpXvGjfnjHgY" }, "source": [ "您的输入函数可以用您喜欢的任何方式生成 `features`字典和`label` 列表。但是,推荐使用 TensorFlow 的 [Dataset API](https://tensorflow.google.cn/guide/datasets),它可以解析各种数据。\n", "\n", "Dataset API 可以为您处理很多常见情况。例如,使用 Dataset API,您可以轻松地从大量文件中并行读取记录,并将它们合并为单个数据流。\n", "\n", "为了简化此示例,我们将使用 [pandas](https://pandas.pydata.org/) 加载数据,并利用此内存数据构建输入管道。\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:35.001032Z", "iopub.status.busy": "2023-11-07T19:05:35.000451Z", "iopub.status.idle": "2023-11-07T19:05:35.004564Z", "shell.execute_reply": "2023-11-07T19:05:35.003946Z" }, "id": "T20u1anCi8NP" }, "outputs": [], "source": [ "def input_fn(features, labels, training=True, batch_size=256):\n", " \"\"\"An input function for training or evaluating\"\"\"\n", " # Convert the inputs to a Dataset.\n", " dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))\n", "\n", " # Shuffle and repeat if you are in training mode.\n", " if training:\n", " dataset = dataset.shuffle(1000).repeat()\n", " \n", " return dataset.batch(batch_size)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xIwcFT4MlZEi" }, "source": [ "## 定义特征列(feature columns)\n", "\n", "[**特征列(feature columns)**](https://developers.google.com/machine-learning/glossary/#feature_columns)是一个对象,用于描述模型应该如何使用特征字典中的原始输入数据。当您构建一个 Estimator 模型的时候,您会向其传递一个特征列的列表,其中包含您希望模型使用的每个特征。`tf.feature_column` 模块提供了许多为模型表示数据的选项。\n", "\n", "对于鸢尾花,4 个原始特征是数值,因此您将构建一个特征列列表来告诉 Estimator 模型将四个特征中的每一个表示为 32 位浮点值。因此,创建特征列的代码为:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:35.008601Z", "iopub.status.busy": "2023-11-07T19:05:35.008372Z", "iopub.status.idle": "2023-11-07T19:05:35.012560Z", "shell.execute_reply": 
"2023-11-07T19:05:35.011780Z" }, "id": "ZTTriO8FlSML" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_420919/1593920324.py:4: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] } ], "source": [ "# Feature columns describe how to use the input.\n", "my_feature_columns = []\n", "for key in train.keys():\n", " my_feature_columns.append(tf.feature_column.numeric_column(key=key))" ] }, { "cell_type": "markdown", "metadata": { "id": "jpKkhMoZljco" }, "source": [ "特征列可能比这里显示的要复杂得多。您可以在[此指南](https://tensorflow.google.cn/guide/feature_columns)中阅读有关特征列的更多信息。\n", "\n", "我们已经介绍了如何使模型表示原始特征,现在您可以构建 Estimator 了。" ] }, { "cell_type": "markdown", "metadata": { "id": "kuE59XHEl22K" }, "source": [ "## 实例化 Estimator\n", "\n", "鸢尾花为题是一个经典的分类问题。幸运的是,Tensorflow 提供了几个预创建的 Estimator 分类器,其中包括:\n", "\n", "- `tf.estimator.DNNClassifier` 用于多类别分类的深度模型\n", "- `tf.estimator.DNNLinearCombinedClassifier` 用于广度与深度模型\n", "- `tf.estimator.LinearClassifier` 用于基于线性模型的分类器\n", "\n", "对于鸢尾花问题,`tf.estimator.DNNClassifier` 似乎是最好的选择。您可以这样实例化该 Estimator:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:35.015611Z", "iopub.status.busy": "2023-11-07T19:05:35.015393Z", "iopub.status.idle": "2023-11-07T19:05:37.321097Z", "shell.execute_reply": "2023-11-07T19:05:37.320409Z" }, "id": "qnf4o2V5lcPn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_420919/2221267581.py:2: DNNClassifierV2.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is 
deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/head_utils.py:59: MultiClassHead.__init__ (from tensorflow_estimator.python.estimator.head.multi_class_head) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:759: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpoulu7cx6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpoulu7cx6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 
10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] } ], "source": [ "# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.\n", "classifier = tf.estimator.DNNClassifier(\n", "    feature_columns=my_feature_columns,\n", "    # Two hidden layers of 30 and 10 nodes respectively.\n", "    hidden_units=[30, 10],\n", "    # The model must choose between 3 classes.\n", "    n_classes=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "tzzt5nUpmEe3" }, "source": [ "## Train, evaluate, and predict\n", "\n", "Now that you have an Estimator object, you can call methods to do the following:\n", "\n", "- Train the model.\n", "- Evaluate the trained model.\n", "- Use the trained model to make predictions." ] }, { "cell_type": "markdown", "metadata": { "id": "rnihuLdWmE75" }, "source": [ "### Train the model\n", "\n", "Train the model by calling the Estimator's `train` method as follows:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:37.324549Z", "iopub.status.busy": "2023-11-07T19:05:37.324303Z", "iopub.status.idle": "2023-11-07T19:05:46.899488Z", "shell.execute_reply": "2023-11-07T19:05:46.898815Z" }, "id": "4jW08YtPl1iS" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling 
model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/legacy/adagrad.py:93: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from 
tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 19:05:37.897085: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. 
Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:\n", "type_id: TFT_OPTIONAL\n", "args {\n", " type_id: TFT_PRODUCT\n", " args {\n", " type_id: TFT_TENSOR\n", " args {\n", " type_id: TFT_INT64\n", " }\n", " }\n", "}\n", " is neither a subtype nor a supertype of the combined inputs preceding it:\n", "type_id: TFT_OPTIONAL\n", "args {\n", " type_id: TFT_PRODUCT\n", " args {\n", " type_id: TFT_TENSOR\n", " args {\n", " type_id: TFT_INT32\n", " }\n", " }\n", "}\n", "\n", "\tfor Tuple type infernce function 0\n", "\twhile inferring type of node 'dnn/zero_fraction/cond/output/_18'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpoulu7cx6/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: 
SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 1.5073682, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 439.465\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 1.0172048, step = 100 (0.229 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 606.207\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.88746536, step = 200 (0.165 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 616.132\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.8561147, step = 300 (0.162 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 609.728\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.8380361, step = 400 (0.164 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 612.992\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.83253044, step = 500 (0.163 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 613.225\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.7997909, step = 600 (0.163 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 620.106\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.8083163, step = 700 (0.161 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 600.24\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.7587497, step = 800 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 622.992\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.76957697, step = 900 (0.160 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 608.274\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.7207211, step = 1000 (0.164 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 613.67\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.7105023, step = 1100 (0.163 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 623.223\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.7275356, step = 1200 (0.161 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 584.544\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.7416762, step = 1300 (0.171 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 595.923\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.716245, step = 1400 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 594.104\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.69990337, step = 1500 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 596.8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.69416165, step = 1600 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"INFO:tensorflow:global_step/sec: 597.293\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.67331016, step = 1700 (0.167 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 601.12\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.6699522, step = 1800 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 589.327\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.66161495, step = 1900 (0.170 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 586.803\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.6554887, step = 2000 (0.170 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 587.731\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.6613943, step = 2100 (0.170 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 601.215\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.6285989, step = 2200 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 603.741\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.64100504, step = 2300 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 604.138\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.62196255, step = 2400 (0.165 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 599.965\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.59547615, step = 2500 (0.167 sec)\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 593.705\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5903188, step = 2600 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 604.55\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.616672, step = 2700 (0.165 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 627.945\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.60870504, step = 2800 (0.159 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 623.905\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.59756136, step = 2900 (0.160 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 613.108\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5934744, step = 3000 (0.163 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 616.35\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.59139955, step = 3100 (0.162 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 599.788\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.584731, step = 3200 (0.167 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 594.577\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5786096, step = 3300 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 605.614\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 
0.58198833, step = 3400 (0.165 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 594.772\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.57257384, step = 3500 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 585.277\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5604176, step = 3600 (0.171 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 604.04\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.550858, step = 3700 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 585.579\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.57899547, step = 3800 (0.171 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 593.479\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.54325897, step = 3900 (0.168 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 605.619\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.57464546, step = 4000 (0.165 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 615.744\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.54382163, step = 4100 (0.162 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 612.668\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5404015, step = 4200 (0.163 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 589.679\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5463786, step = 4300 (0.169 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 587.872\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.54900044, step = 4400 (0.170 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 602.697\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.52490914, step = 4500 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 601.161\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.51717925, step = 4600 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 615.13\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.51362664, step = 4700 (0.163 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 620.598\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5242693, step = 4800 (0.161 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 602.749\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5284901, step = 4900 (0.166 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 5000 into /tmpfs/tmp/tmpoulu7cx6/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 
0.512492.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train the Model.\n", "classifier.train(\n", " input_fn=lambda: input_fn(train, train_y, training=True),\n", " steps=5000)" ] }, { "cell_type": "markdown", "metadata": { "id": "ybiTFDmlmes8" }, "source": [ "Note that the `input_fn` call is wrapped in a [`lambda`](https://docs.python.org/3/tutorial/controlflow.html) to capture the arguments while providing an input function that takes no arguments, as the Estimator expects. The `steps` argument tells the method to stop training after that number of training steps.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HNvJLH8hmsdf" }, "source": [ "### Evaluating the trained model\n", "\n", "Now that the model has been trained, you can get some statistics on its performance. The following code block evaluates the accuracy of the trained model on the test data:\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:46.903262Z", "iopub.status.busy": "2023-11-07T19:05:46.902625Z", "iopub.status.idle": "2023-11-07T19:05:47.980307Z", "shell.execute_reply": "2023-11-07T19:05:47.979497Z" }, "id": "A169XuO4mKxF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-11-07T19:05:47\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpoulu7cx6/model.ckpt-5000\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 0.70216s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-11-07-19:05:47\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.53333336, average_loss = 0.6654332, global_step = 5000, loss = 0.6654332\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpoulu7cx6/model.ckpt-5000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Test set accuracy: 0.533\n", "\n" ] } ], "source": [ "eval_result = classifier.evaluate(\n", " input_fn=lambda: input_fn(test, test_y, training=False))\n", "\n", "print('\\nTest set accuracy: {accuracy:0.3f}\\n'.format(**eval_result))" ] }, { "cell_type": "markdown", "metadata": { "id": "VnPMP5EHph17" }, "source": [ "Unlike the call to the `train` method, we did not pass the `steps` argument to evaluate. The `input_fn` used for evaluation yields only a single [epoch](https://developers.google.com/machine-learning/glossary/#epoch) of data.\n", "\n", "The `eval_result` dictionary also contains the `average_loss` (mean loss per sample), the `loss` (mean loss per mini-batch), and the value of the Estimator's `global_step` (the number of training iterations it underwent).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ur624ibpp52X" }, "source": [ "### Making predictions (inferring) from the trained model\n", "\n", "We now have a trained and evaluated model. We can use it to predict the species of an Iris flower from some unlabeled measurements. As with training and evaluation, we make predictions using a single function call:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:47.984574Z", "iopub.status.busy": "2023-11-07T19:05:47.983853Z", "iopub.status.idle": "2023-11-07T19:05:47.989003Z", "shell.execute_reply": "2023-11-07T19:05:47.988265Z" }, "id": "wltc0jpgng38" }, "outputs": [], "source": [ "# 
Generate predictions from the model\n", "expected = ['Setosa', 'Versicolor', 'Virginica']\n", "predict_x = {\n", " 'SepalLength': [5.1, 5.9, 6.9],\n", " 'SepalWidth': [3.3, 3.0, 3.1],\n", " 'PetalLength': [1.7, 4.2, 5.4],\n", " 'PetalWidth': [0.5, 1.5, 2.1],\n", "}\n", "\n", "def input_fn(features, batch_size=256):\n", " \"\"\"An input function for prediction.\"\"\"\n", " # Convert the inputs to a Dataset without labels.\n", " return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)\n", "\n", "predictions = classifier.predict(\n", " input_fn=lambda: input_fn(predict_x))" ] }, { "cell_type": "markdown", "metadata": { "id": "JsETKQo0rHvi" }, "source": [ "The `predict` method returns a Python iterable that yields a dictionary of prediction results for each example. The following code prints a few predictions and their probabilities:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:05:47.992580Z", "iopub.status.busy": "2023-11-07T19:05:47.992001Z", "iopub.status.idle": "2023-11-07T19:05:48.455160Z", "shell.execute_reply": "2023-11-07T19:05:48.454480Z" }, "id": "Efm4mLzkrCxp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/base_head.py:786: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/multi_class_head.py:455: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpoulu7cx6/model.ckpt-5000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Prediction is \"Setosa\" (78.5%), expected \"Setosa\"\n", "Prediction is \"Virginica\" (40.7%), expected \"Versicolor\"\n", "Prediction is \"Virginica\" (75.4%), expected \"Virginica\"\n" ] } ], "source": [ "for pred_dict, expec in zip(predictions, expected):\n", " class_id = pred_dict['class_ids'][0]\n", " probability = pred_dict['probabilities'][class_id]\n", "\n", " print('Prediction is \"{}\" ({:.1f}%), expected \"{}\"'.format(\n", " SPECIES[class_id], 100 * probability, expec))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "premade.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }