{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "d6p8EySq1zXZ" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T19:01:30.580391Z", "iopub.status.busy": "2023-11-07T19:01:30.579923Z", "iopub.status.idle": "2023-11-07T19:01:30.583514Z", "shell.execute_reply": "2023-11-07T19:01:30.582961Z" }, "id": "KsOkK8O69PyT" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "F1xIRPtY0E1w" }, "source": [ "# Create an Estimator from a Keras model" ] }, { "cell_type": "markdown", "metadata": { "id": "r61fkA2i9Y3_" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Dhcq8Ds4mCtm" }, "source": [ "> Warning: Estimators are not recommended for new code. Estimators run `v1.Session`-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our [compatibility guarantees](https://tensorflow.org/guide/versions), but they will receive no fixes other than for security vulnerabilities. See the [migration guide](https://tensorflow.org/guide/migrate) for details." ] }, { "cell_type": "markdown", "metadata": { "id": "ZaGcclVLwqDS" }, "source": [ "## Overview\n", "\n", "TensorFlow supports TensorFlow Estimators, which can be created from new and existing `tf.keras` models. This tutorial contains a complete, minimal example of that process.\n", "\n", "Note: If you have a Keras model, you can use it directly with [`tf.distribute` strategies](https://tensorflow.org/guide/migrate/guide/distributed_training) without converting it to an Estimator. As a result, `model_to_estimator` is no longer recommended." ] }, { "cell_type": "markdown", "metadata": { "id": "epgfaZmO2vF0" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:30.587509Z", "iopub.status.busy": "2023-11-07T19:01:30.587003Z", "iopub.status.idle": "2023-11-07T19:01:33.237937Z", "shell.execute_reply": "2023-11-07T19:01:33.237233Z" }, "id": "Qmq4FzaztASN" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 19:01:31.022571: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-11-07 19:01:31.022614: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-11-07 19:01:31.024113: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "import numpy as np\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "9ZUATGJGtQIU" 
}, "source": [ "### Create a simple Keras model" ] }, { "cell_type": "markdown", "metadata": { "id": "rR-zPidHyzcb" }, "source": [ "In Keras, you assemble *layers* to build *models*. A model is (usually) a graph of layers. The most common type of model is a stack of layers: the `tf.keras.Sequential` model.\n", "\n", "Build a simple, fully-connected network (that is, a multilayer perceptron):" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:33.242660Z", "iopub.status.busy": "2023-11-07T19:01:33.242269Z", "iopub.status.idle": "2023-11-07T19:01:35.536082Z", "shell.execute_reply": "2023-11-07T19:01:35.535338Z" }, "id": "p5NSx38itD1a" }, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", " tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),\n", " tf.keras.layers.Dropout(0.2),\n", " tf.keras.layers.Dense(3)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "ABgo9-8BtYNs" }, "source": [ "Compile the model and get a summary." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:35.540284Z", "iopub.status.busy": "2023-11-07T19:01:35.540013Z", "iopub.status.idle": "2023-11-07T19:01:35.562332Z", "shell.execute_reply": "2023-11-07T19:01:35.561744Z" }, "id": "nViACuBDtVEC" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense (Dense) (None, 16) 80 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout (Dropout) (None, 16) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ " dense_1 (Dense) (None, 3) 51 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 131 (524.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 131 (524.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " optimizer='adam')\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "pM3Cx5Fm_sHI" }, "source": [ "### Create an input function\n", "\n", "Use the [Datasets API](../../guide/data.md) to scale to large datasets or multi-device training.\n", "\n", "Estimators need control of when and how their input pipeline is built. To allow this, they require an \"input function\", or `input_fn`. The `Estimator` calls this function with no arguments. The `input_fn` must return a `tf.data.Dataset`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:35.568775Z", "iopub.status.busy": "2023-11-07T19:01:35.568301Z", "iopub.status.idle": "2023-11-07T19:01:35.572432Z", "shell.execute_reply": "2023-11-07T19:01:35.571792Z" }, "id": "H0DpLEop_x0o" }, "outputs": [], "source": [ "def input_fn():\n", " split = tfds.Split.TRAIN\n", " dataset = tfds.load('iris', split=split, as_supervised=True)\n", " dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))\n", " dataset = dataset.batch(32).repeat()\n", " return dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "UR1vRw1bBFjo" }, "source": [ "Test your `input_fn`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:35.575733Z", "iopub.status.busy": 
"2023-11-07T19:01:35.575201Z", "iopub.status.idle": "2023-11-07T19:01:36.824765Z", "shell.execute_reply": "2023-11-07T19:01:36.823921Z" }, "id": "WO94bGYKBKRv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'dense_input': }\n", "tf.Tensor([0 2 1 2 0 1 1 1 0 2 1 0 2 0 0 0 0 0 2 2 2 2 2 0 2 0 2 1 1 1 1 1], shape=(32,), dtype=int64)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 19:01:36.811575: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" ] } ], "source": [ "for features_batch, labels_batch in input_fn().take(1):\n", " print(features_batch)\n", " print(labels_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "svdhkQ4Otcv0" }, "source": [ "### Create an Estimator from the tf.keras model\n", "\n", "You can train a `tf.keras.Model` with the `tf.estimator` API by converting the model to a `tf.estimator.Estimator` object with `tf.keras.estimator.model_to_estimator`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:36.828806Z", "iopub.status.busy": "2023-11-07T19:01:36.828124Z", "iopub.status.idle": "2023-11-07T19:01:37.240806Z", "shell.execute_reply": "2023-11-07T19:01:37.240091Z" }, "id": "roChngg8t7il" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", 
"text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using the Keras model provided.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Using the Keras model provided.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:You are using `tf.keras.optimizers.experimental.Optimizer` in TF estimator, which only supports `tf.keras.optimizers.legacy.Optimizer`. Automatically converting your optimizer to `tf.keras.optimizers.legacy.Optimizer`.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend.py:452: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. 
To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/keras_lib.py:740: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 19:01:37.125728: W tensorflow/c/c_api.cc:305] Operation '{name:'training/Adam/dense_1/bias/m/Assign' id:190 op device:{requested: '', assigned: ''} def:{{{node training/Adam/dense_1/bias/m/Assign}} = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](training/Adam/dense_1/bias/m, training/Adam/dense_1/bias/m/Initializer/zeros)}}' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. 
Either don't modify nodes after running them or create a new session.\n", "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/keras_lib.py:740: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpqgqyc7xz', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpqgqyc7xz', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, 
'_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:2404: WarmStartSettings.__new__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:2404: WarmStartSettings.__new__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] } ], "source": [ "import tempfile\n", "model_dir = tempfile.mkdtemp()\n", "keras_estimator = tf.keras.estimator.model_to_estimator(\n", " keras_model=model, model_dir=model_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "U-8ekW5It_2w" }, "source": [ "Train and evaluate the Estimator." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T19:01:37.245161Z", "iopub.status.busy": "2023-11-07T19:01:37.244534Z", "iopub.status.idle": "2023-11-07T19:01:41.854222Z", "shell.execute_reply": "2023-11-07T19:01:41.853534Z" }, "id": "ouIkVtp9uAg5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from 
tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/keras_lib.py:400: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/keras_lib.py:400: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/keras_lib.py:419: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/keras_lib.py:419: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmpqgqyc7xz/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmpqgqyc7xz/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmpqgqyc7xz/keras/keras_model.ckpt\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmpqgqyc7xz/keras/keras_model.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-started 4 variables.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-started 4 variables.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated 
and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpqgqyc7xz/model.ckpt.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpqgqyc7xz/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 1.6025522, step = 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 1.6025522, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 531.011\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 531.011\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.8349081, step = 100 (0.190 sec)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.8349081, step = 100 (0.190 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 695.053\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 695.053\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.63622224, step = 200 (0.144 sec)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.63622224, step = 200 (0.144 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 682.461\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 
682.461\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5632105, step = 300 (0.147 sec)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.5632105, step = 300 (0.147 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 694.078\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 694.078\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.46418035, step = 400 (0.144 sec)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.46418035, step = 400 (0.144 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 500 into /tmpfs/tmp/tmpqgqyc7xz/model.ckpt.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 500 into /tmpfs/tmp/tmpqgqyc7xz/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.4637785.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.4637785.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training_v1.py:2335: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.\n", " updates = self.state_updates\n", "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-11-07T19:01:41\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-11-07T19:01:41\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpqgqyc7xz/model.ckpt-500\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpqgqyc7xz/model.ckpt-500\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [1/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [1/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [2/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [2/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [3/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [3/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [4/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [4/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [5/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [5/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [6/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [6/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [7/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [7/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [8/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [8/10]\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [9/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [9/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [10/10]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [10/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 0.78849s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 0.78849s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-11-07-19:01:41\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-11-07-19:01:41\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.37220484\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.37220484\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmpfs/tmp/tmpqgqyc7xz/model.ckpt-500\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmpfs/tmp/tmpqgqyc7xz/model.ckpt-500\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Eval result: {'loss': 0.37220484, 'global_step': 500}\n" ] } ], "source": [ "keras_estimator.train(input_fn=input_fn, steps=500)\n", "eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)\n", "print('Eval result: {}'.format(eval_result))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "keras_model_to_estimator.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }