{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "wJcYs_ERTnnI" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:35:04.804882Z", "iopub.status.busy": "2022-12-14T22:35:04.804433Z", "iopub.status.idle": "2022-12-14T22:35:04.808309Z", "shell.execute_reply": "2022-12-14T22:35:04.807704Z" }, "id": "HMUDt0CiUJk9" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# Migrate LoggingTensorHook and StopAtStepHook to Keras callbacks" ] },
{ "cell_type": "markdown", "metadata": { "id": "meUTrR4I6m1C" }, "source": [ "In TensorFlow 1, you use `tf.estimator.LoggingTensorHook` to monitor and log tensors, while `tf.estimator.StopAtStepHook` helps stop training at a specified step when training with `tf.estimator.Estimator`. This notebook demonstrates how to migrate from these APIs to their TensorFlow 2 equivalents using custom Keras callbacks (`tf.keras.callbacks.Callback`) with `Model.fit`.\n", "\n", "Keras [callbacks](https://www.tensorflow.org/guide/keras/custom_callback) are objects that are called at different points during training/evaluation/prediction in the built-in Keras `Model.fit`/`Model.evaluate`/`Model.predict` APIs. You can learn more about callbacks in the `tf.keras.callbacks.Callback` API docs, as well as in the [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback) and [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section) guides. For migrating from `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2, check out the [Migrate training with assisted logic](sessionrunhook_callback.ipynb) guide." ] },
{ "cell_type": "markdown", "metadata": { "id": "YdZSoIXEbhg-" }, "source": [ "## Setup\n", "\n", "Start with imports and a simple dataset for demonstration purposes:" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:35:04.811924Z", "iopub.status.busy": "2022-12-14T22:35:04.811317Z", "iopub.status.idle": "2022-12-14T22:35:06.725787Z", "shell.execute_reply": "2022-12-14T22:35:06.725101Z" }, "id": "iE0vSfMXumKI" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 22:35:05.753254: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:35:05.753360: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:35:05.753369: W 
tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:35:06.729949Z", "iopub.status.busy": "2022-12-14T22:35:06.729552Z", "iopub.status.idle": "2022-12-14T22:35:06.733780Z", "shell.execute_reply": "2022-12-14T22:35:06.733235Z" }, "id": "m7rnGxsXtDkV" }, "outputs": [], "source": [ "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n", "labels = [[0.3], [0.5], [0.7]]\n", "\n", "# Define an input function.\n", "def _input_fn():\n", "  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)" ] },
{ "cell_type": "markdown", "metadata": { "id": "4uXff1BEssdE" }, "source": [ "## TensorFlow 1: Log tensors and stop training with tf.estimator APIs" ] },
{ "cell_type": "markdown", "metadata": { "id": "zW-X5cmzmkuw" }, "source": [ "In TensorFlow 1, you define various hooks to control the training behavior and then pass these hooks to `tf.estimator.EstimatorSpec`.\n", "\n", "In the example below:\n", "\n", "- To monitor/log tensors (for example, model weights or losses), you use `tf.estimator.LoggingTensorHook` (`tf.train.LoggingTensorHook` is its alias).\n", "- To stop training at a specific step, you use `tf.estimator.StopAtStepHook` (`tf.train.StopAtStepHook` is its alias)." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:35:06.737306Z", "iopub.status.busy": "2022-12-14T22:35:06.736804Z", "iopub.status.idle": "2022-12-14T22:35:10.908322Z", "shell.execute_reply": "2022-12-14T22:35:10.907690Z" }, "id": "lqe9obf7suIj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpnls8yefw\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpnls8yefw', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Variable.read_value. 
Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpnls8yefw/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.15147454, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Tensor(\"Identity:0\", shape=(2, 1), dtype=float32) = [[ 0.7482141 ]\n", " [-0.03934455]], Tensor(\"Identity_1:0\", shape=(1,), dtype=float32) = [0.]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"INFO:tensorflow:loss from LoggingTensorHook = 0.15147454\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Tensor(\"Identity:0\", shape=(2, 1), dtype=float32) = [[ 0.7018909 ]\n", " [-0.08760582]], Tensor(\"Identity_1:0\", shape=(1,), dtype=float32) = [-0.04632322] (0.028 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpnls8yefw/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.40761083.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def _model_fn(features, labels, mode):\n", "  dense = tf1.layers.Dense(1)\n", "  logits = dense(features)\n", "  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n", "  optimizer = tf1.train.AdagradOptimizer(0.05)\n", "  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n", "\n", "  # Define the stop hook.\n", "  stop_hook = tf1.train.StopAtStepHook(num_steps=2)\n", "\n", "  # Create identity tensors of the kernel and bias so that their values can be logged.\n", "  kernel = tf.identity(dense.weights[0])\n", "  bias = tf.identity(dense.weights[1])\n", "  logging_weight_hook = tf1.train.LoggingTensorHook(\n", "      tensors=[kernel, bias],\n", "      every_n_iter=1)\n", "  # Log the training loss by the tensor object.\n", "  logging_loss_hook = tf1.train.LoggingTensorHook(\n", "      {'loss from LoggingTensorHook': loss},\n", "      every_n_secs=3)\n", "\n", "  # Pass all hooks to `EstimatorSpec`.\n", "  return tf1.estimator.EstimatorSpec(mode,\n", "                                     loss=loss,\n", "                                     train_op=train_op,\n", "                                     training_hooks=[stop_hook,\n", "                                                     logging_weight_hook,\n", "                                                     logging_loss_hook])\n", "\n", "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n", "\n", "# Begin training.\n", "# The training will stop after 2 steps, and the weights/loss will also be logged.\n", "estimator.train(_input_fn)" ] },
{ "cell_type": "markdown", "metadata": { "id": "KEmzBjfnsxwT" }, "source": [ "## TensorFlow 2: Log tensors and stop training with custom callbacks and Model.fit" ] },
{ "cell_type": "markdown", "metadata": { "id": "839R9i4xheI5" }, "source": [ "In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`) for training/evaluation, you can configure tensor monitoring and training stopping by defining custom `tf.keras.callbacks.Callback` subclasses and passing them to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`). (Learn more in the [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback) guide.)\n", "\n", "In the example below:\n", "\n", "- To recreate the functionality of `StopAtStepHook`, define a custom callback (named `StopAtStepCallback` below) that overrides the `on_batch_end` method to stop training after a certain number of steps.\n", "- To recreate the `LoggingTensorHook` behavior, define a custom callback (`LoggingTensorCallback`) that records and prints the logged tensors manually, since accessing tensors by name is not supported. You can also implement the logging frequency inside the custom callback. The example below prints the weights every two steps. Other strategies, such as logging every N seconds, are also possible." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:35:10.912049Z", "iopub.status.busy": "2022-12-14T22:35:10.911500Z", "iopub.status.idle": "2022-12-14T22:35:10.917836Z", "shell.execute_reply": "2022-12-14T22:35:10.917231Z" }, "id": "atVciNgPs0fw" }, "outputs": [], "source": [ "class StopAtStepCallback(tf.keras.callbacks.Callback):\n", "  def __init__(self, stop_step):\n", "    super().__init__()\n", "    self._stop_step = stop_step\n", "\n", "  def on_batch_end(self, batch, logs=None):\n", "    # Stop once the optimizer has applied `stop_step` updates.\n", "    if self.model.optimizer.iterations >= self._stop_step:\n", "      self.model.stop_training = True\n", "      print('\\nstop training now')\n", "\n", "class LoggingTensorCallback(tf.keras.callbacks.Callback):\n", "  def __init__(self, every_n_iter):\n", "    super().__init__()\n", "    self._every_n_iter = every_n_iter\n", "    self._step = 0\n", "\n", "  def on_batch_end(self, batch, logs=None):\n", "    # Log on the first batch and then once every `every_n_iter` batches.\n", "    if self._step % self._every_n_iter == 0:\n", "      print(\"Logging Tensor Callback: dense/kernel:\",\n", "            self.model.layers[0].weights[0])\n", "      print(\"Logging Tensor Callback: dense/bias:\",\n", "            self.model.layers[0].weights[1])\n", "      print(\"Logging Tensor Callback loss:\", logs[\"loss\"])\n", "    self._step += 1" ] },
{ "cell_type": "markdown", "metadata": { "id": "30a8b71263e0" }, "source": [ "Once done, pass the new callbacks (`StopAtStepCallback` and `LoggingTensorCallback`) to the `callbacks` parameter of `Model.fit`:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:35:10.921286Z", "iopub.status.busy": "2022-12-14T22:35:10.920737Z", "iopub.status.idle": "2022-12-14T22:35:11.575247Z", "shell.execute_reply": "2022-12-14T22:35:11.574556Z" }, "id": "Kip65sYBlKiu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Logging Tensor Callback: dense/kernel: \n", "Logging Tensor Callback: dense/bias: \n", "Logging Tensor Callback loss: 1.5501126050949097\n", "\r", "1/3 [=========>....................] 
- ETA: 0s - loss: 1.5501" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "stop training now\n", "Logging Tensor Callback: dense/kernel: \n", "Logging Tensor Callback: dense/bias: \n", "Logging Tensor Callback loss: 2.257307767868042\n", "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "3/3 [==============================] - 0s 4ms/step - loss: 2.2573\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n", "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n", "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n", "model.compile(optimizer, \"mse\")\n", "\n", "# Begin training.\n", "# The training will stop after 2 steps, and the weights/loss will also be logged.\n", "model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),\n", "                              LoggingTensorCallback(every_n_iter=2)])" ] },
{ "cell_type": "markdown", "metadata": { "id": "19508f4720f5" }, "source": [ "## Next steps\n", "\n", "Learn more about callbacks in:\n", "\n", "- API docs: `tf.keras.callbacks.Callback`\n", "- Guide: [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback)\n", "- Guide: [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate) (the *Using callbacks* section)\n", "\n", "You may also find the following migration-related resources useful:\n", "\n", "- The [Early stopping migration guide](early_stopping.ipynb): `tf.keras.callbacks.EarlyStopping` is a built-in early stopping callback\n", "- The [TensorBoard migration guide](tensorboard.ipynb): TensorBoard enables tracking and displaying metrics\n", "- The [Migrate training with assisted logic](sessionrunhook_callback.ipynb) guide: from `SessionRunHook` in TensorFlow 1 to Keras callbacks in TensorFlow 2" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "logging_stop_hook.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { 
"name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }