{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "W7rEsKyWcxmu" }, "source": [ "##### Copyright 2019 The TF-Agents Authors.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2021-02-12T22:19:28.086042Z", "iopub.status.busy": "2021-02-12T22:19:28.085333Z", "iopub.status.idle": "2021-02-12T22:19:28.087803Z", "shell.execute_reply": "2021-02-12T22:19:28.087280Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "G6aOV15Wc4HP" }, "source": [ "### Checkpointer and PolicySaver\n" ] }, { "cell_type": "markdown", "metadata": { "id": "M3HE5S3wsMEh" }, "source": [ "## Introduction\n", "\n", "`tf_agents.utils.common.Checkpointer` is a utility that saves and loads the training state, the policy state, and the replay_buffer state to and from local storage.\n", "\n", "`tf_agents.policies.policy_saver.PolicySaver` is a tool that saves and loads only the policy, which makes it lighter than `Checkpointer`. With `PolicySaver` you can deploy the model without any knowledge of the code that created the policy.\n", "\n", "In this tutorial, we train a model with DQN and then use `Checkpointer` and `PolicySaver` to show how to save and load the states and the model interactively. Note that `PolicySaver` uses the saved_model tooling and format introduced in TF2.0.\n" ] }, { "cell_type": "markdown", "metadata": { 
"id": "vbTrDrX4dkP_" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "Opk_cVDYdgct" }, "source": [ "If you haven't installed the following dependencies, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:28.100419Z", "iopub.status.busy": "2021-02-12T22:19:28.099728Z", "iopub.status.idle": "2021-02-12T22:19:43.083818Z", "shell.execute_reply": "2021-02-12T22:19:43.084343Z" }, "id": "Jv668dKvZmka" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 100%\r", "\r", "Reading package lists... Done\r", "\r\n", "\r", "Building dependency tree... 0%\r", "\r", "Building dependency tree... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree... 50%\r", "\r", "Building dependency tree... 50%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree \r", "\r\n", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 
Done\r", "\r\n", "ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).\r\n", "xvfb is already the newest version (2:1.19.6-1ubuntu4.8).\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The following packages were automatically installed and are no longer required:\r\n", " adwaita-icon-theme ca-certificates-java dconf-gsettings-backend\r\n", " dconf-service default-jre default-jre-headless dkms fonts-dejavu-extra\r\n", " freeglut3 freeglut3-dev g++-6 glib-networking glib-networking-common\r\n", " glib-networking-services gsettings-desktop-schemas gtk-update-icon-cache\r\n", " hicolor-icon-theme humanity-icon-theme java-common libaccinj64-9.1\r\n", " libatk-bridge2.0-0 libatk-wrapper-java libatk-wrapper-java-jni libatk1.0-0\r\n", " libatk1.0-data libatspi2.0-0 libavahi-client3 libavahi-common-data\r\n", " libavahi-common3 libcairo-gobject2 libcolord2 libcudart9.1 libcufft9.1\r\n", " libcufftw9.1 libcups2 libcurand9.1 libcusolver9.1 libcusparse9.1 libdconf1\r\n", " libdrm-dev libegl-mesa0 libegl1 libegl1-mesa libepoxy0 libgbm1 libgif7\r\n", " libgl1-mesa-dev libgles1 libgles2 libglu1-mesa libglu1-mesa-dev\r\n", " libglvnd-core-dev libglvnd-dev libgtk-3-0 libgtk-3-common libgtk2.0-0\r\n", " libgtk2.0-common libice-dev libjansson4 libjson-glib-1.0-0\r\n", " libjson-glib-1.0-common liblcms2-2 libnppc9.1 libnppial9.1 libnppicc9.1\r\n", " libnppicom9.1 libnppidei9.1 libnppif9.1 libnppig9.1 libnppim9.1 libnppist9.1\r\n", " libnppisu9.1 libnppitc9.1 libnpps9.1 libnvrtc9.1 libnvtoolsext1 libnvvm3\r\n", " libopengl0 libpcsclite1 libproxy1v5 libpthread-stubs0-dev librest-0.7-0\r\n", " libsm-dev libsoup-gnome2.4-1 libsoup2.4-1 libstdc++-6-dev libthrust-dev\r\n", " libvdpau-dev libwayland-server0 libx11-dev libx11-xcb-dev libxau-dev\r\n", " libxcb-dri2-0-dev libxcb-dri3-dev libxcb-glx0-dev libxcb-present-dev\r\n", " libxcb-randr0 libxcb-randr0-dev libxcb-render0-dev libxcb-shape0-dev\r\n", " libxcb-sync-dev libxcb-xfixes0-dev libxcb1-dev libxcomposite1 
libxdamage-dev\r\n", " libxdmcp-dev libxext-dev libxfixes-dev libxft2 libxi-dev libxmu-dev\r\n", " libxmu-headers libxnvctrl0 libxshmfence-dev libxt-dev libxtst6 libxxf86dga1\r\n", " libxxf86vm-dev linux-gcp-5.3-headers-5.3.0-1030 linux-gcp-headers-5.0.0-1026\r\n", " linux-headers-5.3.0-1030-gcp linux-image-5.3.0-1030-gcp\r\n", " linux-modules-5.3.0-1030-gcp linux-modules-extra-5.3.0-1030-gcp\r\n", " mesa-common-dev ocl-icd-libopencl1 ocl-icd-opencl-dev opencl-c-headers\r\n", " openjdk-11-jre openjdk-11-jre-headless openjdk-8-jre openjdk-8-jre-headless\r\n", " pkg-config policykit-1-gnome python3-xkit screen-resolution-extra\r\n", " ubuntu-mono x11-utils x11proto-core-dev x11proto-damage-dev x11proto-dev\r\n", " x11proto-fixes-dev x11proto-input-dev x11proto-xext-dev\r\n", " x11proto-xf86vidmode-dev xorg-sgml-doctools xserver-xorg-core-hwe-18.04\r\n", " xtrans-dev\r\n", "Use 'sudo apt autoremove' to remove them.\r\n", "0 upgraded, 0 newly installed, 0 to remove and 93 not upgraded.\r\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "!sudo apt-get install -y xvfb ffmpeg\n", "!pip install -q 'gym==0.10.11'\n", "!pip install -q 'imageio==2.4.0'\n", "!pip install -q 'pyglet==1.3.2'\n", "!pip install -q 'xvfbwrapper==0.2.9'\n", "!pip install -q tf-agents" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:43.092225Z", "iopub.status.busy": "2021-02-12T22:19:43.091447Z", "iopub.status.idle": "2021-02-12T22:19:49.844194Z", "shell.execute_reply": "2021-02-12T22:19:49.844609Z" }, "id": "bQMULMo1dCEn" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import io\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import os\n", "import shutil\n", "import tempfile\n", "import tensorflow as tf\n", "import zipfile\n", "import IPython\n", "\n", 
"try:\n", " from google.colab import files\n", "except ImportError:\n", " files = None\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import q_network\n", "from tf_agents.policies import policy_saver\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "tf.compat.v1.enable_v2_behavior()\n", "\n", "tempdir = os.getenv(\"TEST_TMPDIR\", tempfile.gettempdir())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:49.848823Z", "iopub.status.busy": "2021-02-12T22:19:49.848050Z", "iopub.status.idle": "2021-02-12T22:19:49.965490Z", "shell.execute_reply": "2021-02-12T22:19:49.964823Z" }, "id": "AwIqiLdDCX9Q" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "# Set up a virtual display for rendering OpenAI gym environments.\n", "import xvfbwrapper\n", "xvfbwrapper.Xvfb(1400, 900, 24).start()" ] }, { "cell_type": "markdown", "metadata": { "id": "AOv_kofIvWnW" }, "source": [ "## DQN agent\n", "\n", "Set up the DQN agent just as in the previous Colab. The details are hidden by default because they are not the focus of this Colab, but you can click 'Show code' to see them." ] }, { "cell_type": "markdown", "metadata": { "id": "cStmaxredFSW" }, "source": [ "### Hyperparameters" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-02-12T22:19:49.971585Z", "iopub.status.busy": "2021-02-12T22:19:49.970717Z", "iopub.status.idle": 
"2021-02-12T22:19:49.973025Z", "shell.execute_reply": "2021-02-12T22:19:49.972505Z" }, 
"id": "yxFs6QU0dGI_" }, "outputs": [], "source": [ "env_name = \"CartPole-v1\"\n", "\n", "collect_steps_per_iteration = 100\n", "replay_buffer_capacity = 100000\n", "\n", "fc_layer_params = (100,)\n", "\n", "batch_size = 64\n", "learning_rate = 1e-3\n", "log_interval = 5\n", "\n", "num_eval_episodes = 10\n", "eval_interval = 1000" ] }, { "cell_type": "markdown", "metadata": { "id": "w4GR7RDndIOR" }, "source": [ "### Environment" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:49.977525Z", "iopub.status.busy": "2021-02-12T22:19:49.976883Z", "iopub.status.idle": "2021-02-12T22:19:49.994869Z", "shell.execute_reply": "2021-02-12T22:19:49.994376Z" }, "id": "fZwK4d-bdI7Z" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "0AvYRwfkeMvo" }, "source": [ "### Agent" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-02-12T22:19:50.000563Z", "iopub.status.busy": "2021-02-12T22:19:49.999910Z", "iopub.status.idle": "2021-02-12T22:19:52.150355Z", "shell.execute_reply": "2021-02-12T22:19:52.149759Z" }, "id": "cUrFl83ieOvV" }, "outputs": [], "source": [ "#@title\n", "q_net = q_network.QNetwork(\n", " train_env.observation_spec(),\n", " train_env.action_spec(),\n", " fc_layer_params=fc_layer_params)\n", "\n", "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n", "\n", "global_step = tf.compat.v1.train.get_or_create_global_step()\n", "\n", "agent = dqn_agent.DqnAgent(\n", " train_env.time_step_spec(),\n", " train_env.action_spec(),\n", " q_network=q_net,\n", " optimizer=optimizer,\n", " td_errors_loss_fn=common.element_wise_squared_loss,\n", " 
train_step_counter=global_step)\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "p8ganoJhdsbn" }, "source": [ "### Data collection" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-02-12T22:19:52.156254Z", "iopub.status.busy": "2021-02-12T22:19:52.155578Z", "iopub.status.idle": "2021-02-12T22:19:53.981330Z", "shell.execute_reply": "2021-02-12T22:19:53.980730Z" }, "id": "XiT1p78HdtSe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tf_agents/drivers/dynamic_step_driver.py:203: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.while_loop(c, b, vars, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1218: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `as_dataset(..., single_deterministic_pass=False) instead.\n" ] } ], "source": [ "#@title\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", " data_spec=agent.collect_data_spec,\n", " batch_size=train_env.batch_size,\n", " max_length=replay_buffer_capacity)\n", "\n", "collect_driver = dynamic_step_driver.DynamicStepDriver(\n", " train_env,\n", " agent.collect_policy,\n", " observers=[replay_buffer.add_batch],\n", " num_steps=collect_steps_per_iteration)\n", "\n", "# Initial data collection\n", 
"collect_driver.run()\n", "\n", "# Dataset generates trajectories with shape [BxTx...] where\n", "# T = n_step_update + 1.\n", "dataset = replay_buffer.as_dataset(\n", " num_parallel_calls=3, sample_batch_size=batch_size,\n", " num_steps=2).prefetch(3)\n", "\n", "iterator = iter(dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "8V8bojrKdupW" }, "source": [ "### Train the agent" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-02-12T22:19:53.987864Z", "iopub.status.busy": "2021-02-12T22:19:53.987061Z", "iopub.status.idle": "2021-02-12T22:19:53.989218Z", "shell.execute_reply": "2021-02-12T22:19:53.989595Z" }, "id": "-rDC3leXdvm_" }, "outputs": [], "source": [ "#@title\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "def train_one_iteration():\n", "\n", "  # Collect steps using collect_policy and save them to the replay buffer.\n", "  # Each collect_driver.run() call takes collect_steps_per_iteration steps.\n", "  for _ in range(collect_steps_per_iteration):\n", "    collect_driver.run()\n", "\n", "  # Sample a batch of data from the buffer and update the agent's network.\n", "  experience, unused_info = next(iterator)\n", "  train_loss = agent.train(experience)\n", "\n", "  iteration = agent.train_step_counter.numpy()\n", "  print('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))" ] }, { "cell_type": "markdown", "metadata": { "id": "vgqVaPnUeDAn" }, "source": [ "### Video generation" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-02-12T22:19:53.996329Z", "iopub.status.busy": "2021-02-12T22:19:53.995617Z", "iopub.status.idle": "2021-02-12T22:19:53.997627Z", "shell.execute_reply": "2021-02-12T22:19:53.998026Z" }, "id": "ZY6w-fcieFDW" }, "outputs": [], "source": [ "#@title\n", "def embed_gif(gif_buffer):\n", 
"  \"\"\"Embeds a gif file in the notebook.\"\"\"\n", "  tag = '<img src=\"data:image/gif;base64,{0}\"/>'.format(base64.b64encode(gif_buffer).decode())\n", "  return IPython.display.HTML(tag)\n", "\n", "def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):\n", "  num_episodes = 3\n", "  frames = []\n", "  for _ in range(num_episodes):\n", "    time_step = eval_tf_env.reset()\n", "    frames.append(eval_py_env.render())\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = eval_tf_env.step(action_step.action)\n", "      frames.append(eval_py_env.render())\n", "  gif_file = io.BytesIO()\n", "  imageio.mimsave(gif_file, frames, format='gif', fps=60)\n", "  IPython.display.display(embed_gif(gif_file.getvalue()))" ] }, { "cell_type": "markdown", "metadata": { "id": "y-oA8VYJdFdj" }, "source": [ "### Generate a video\n", "\n", "Check the performance of the policy by generating a video." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:54.002279Z", "iopub.status.busy": "2021-02-12T22:19:54.001618Z", "iopub.status.idle": "2021-02-12T22:19:55.370231Z", "shell.execute_reply": "2021-02-12T22:19:55.370629Z" }, "id": "FpmPLXWbdG70" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global_step:\n", "\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print('global_step:')\n", "print(global_step)\n", "run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "7RPLExsxwnOm" }, "source": [ "## Set up Checkpointer and PolicySaver\n", "\n", "Now we are ready to use Checkpointer and PolicySaver." ] }, { "cell_type": "markdown", "metadata": { "id": "g-iyQJacfQqO" }, "source": [ "### Checkpointer\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:55.376454Z", 
"iopub.status.busy": "2021-02-12T22:19:55.375831Z", "iopub.status.idle": "2021-02-12T22:19:55.379689Z", "shell.execute_reply": 
"2021-02-12T22:19:55.379220Z" }, "id": "2DzCJZ-6YYbX" }, "outputs": [], "source": [ "checkpoint_dir = os.path.join(tempdir, 'checkpoint')\n", "train_checkpointer = common.Checkpointer(\n", " ckpt_dir=checkpoint_dir,\n", " max_to_keep=1,\n", " agent=agent,\n", " policy=agent.policy,\n", " replay_buffer=replay_buffer,\n", " global_step=global_step\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "MKpWNZM4WE8d" }, "source": [ "### Policy Saver" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:55.384397Z", "iopub.status.busy": "2021-02-12T22:19:55.383790Z", "iopub.status.idle": "2021-02-12T22:19:55.459488Z", "shell.execute_reply": "2021-02-12T22:19:55.459914Z" }, "id": "8mDZ_YMUWEY9" }, "outputs": [], "source": [ "policy_dir = os.path.join(tempdir, 'policy')\n", "tf_policy_saver = policy_saver.PolicySaver(agent.policy)" ] }, { "cell_type": "markdown", "metadata": { "id": "1OnANb1Idx8-" }, "source": [ "### Train one iteration" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:19:55.464252Z", "iopub.status.busy": "2021-02-12T22:19:55.463635Z", "iopub.status.idle": "2021-02-12T22:21:30.964340Z", "shell.execute_reply": "2021-02-12T22:21:30.963838Z" }, "id": "ql_D1iq8dl0X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training one iteration....\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. 
Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "iteration: 1 loss: 1.025463342666626\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "print('Training one iteration....')\n", "train_one_iteration()" ] }, { "cell_type": "markdown", "metadata": { "id": "eSChNSQPlySb" }, "source": [ "### Save to checkpoint" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:30.969468Z", "iopub.status.busy": "2021-02-12T22:21:30.968528Z", "iopub.status.idle": "2021-02-12T22:21:30.996920Z", "shell.execute_reply": "2021-02-12T22:21:30.996392Z" }, "id": "usDm_Wpsl0bu" }, "outputs": [], "source": [ "train_checkpointer.save(global_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "gTQUrKgihuic" }, "source": [ "### Restore from checkpoint\n", "\n", "For restoring to work, the whole set of objects must be recreated the same way as when the checkpoint was created." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:31.001884Z", "iopub.status.busy": "2021-02-12T22:21:31.000911Z", "iopub.status.idle": "2021-02-12T22:21:31.003694Z", "shell.execute_reply": "2021-02-12T22:21:31.003200Z" }, 
"id": "l6l3EB-Yhwmz" }, "outputs": [], "source": [ 
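"# Optional check (assumption: common.Checkpointer writes checkpoints to checkpoint_dir\n", "# through tf.train.CheckpointManager). tf.train.latest_checkpoint returns the path that\n", "# initialize_or_restore() will load, or None if the directory holds no checkpoint yet.\n", "print('Latest checkpoint:', tf.train.latest_checkpoint(checkpoint_dir))\n", 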
"train_checkpointer.initialize_or_restore()\n", "global_step = tf.compat.v1.train.get_global_step()" ] }, { "cell_type": "markdown", "metadata": { "id": "Nb8_MSE2XjRp" }, "source": [ "Also save the policy and export it to the specified location (`policy_dir`)." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:31.008935Z", "iopub.status.busy": "2021-02-12T22:21:31.007953Z", "iopub.status.idle": "2021-02-12T22:21:31.333352Z", "shell.execute_reply": "2021-02-12T22:21:31.332856Z" }, "id": "3xHz09WCWjwA" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as QNetwork_layer_call_and_return_conditional_losses, QNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, dense_1_layer_call_and_return_conditional_losses while saving (showing 5 of 25). These functions will not be directly callable after loading.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as QNetwork_layer_call_and_return_conditional_losses, QNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, dense_1_layer_call_and_return_conditional_losses while saving (showing 5 of 25). 
These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/policy/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/policy/assets\n" ] } ], "source": [ "tf_policy_saver.save(policy_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "Mz-xScbuh4Vo" }, "source": [ "The policy can be loaded without any knowledge of the agent or network that was used to create it, which makes deploying the policy much easier.\n", "\n", "Load the saved policy and check how it performs." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:31.342356Z", "iopub.status.busy": "2021-02-12T22:21:31.336815Z", "iopub.status.idle": "2021-02-12T22:21:32.148504Z", "shell.execute_reply": "2021-02-12T22:21:32.149034Z" }, "id": "J6T5KLTMh9ZB" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "saved_policy = tf.compat.v2.saved_model.load(policy_dir)\n", "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "MpE0KKfqjc0c" }, "source": [ "## Export and import\n", "\n", "The rest of this Colab shows how to export and import the checkpointer and policy directories, so that you can continue training later and deploy the model without training it again.\n", "\n", "Now go back to 'Train one iteration' and train a few more times so that you can see the difference later. Once you start to see slightly better results, continue below." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-02-12T22:21:32.156236Z", "iopub.status.busy": "2021-02-12T22:21:32.155574Z", "iopub.status.idle": "2021-02-12T22:21:32.157660Z", "shell.execute_reply": "2021-02-12T22:21:32.158074Z" }, "id": 
"fd5Cj7DVjfH4" }, "outputs": [], "source": [ "#@title Create zip file and upload zip file (double-click to see the code)\n", "def create_zip_file(dirname, base_filename):\n", "  return shutil.make_archive(base_filename, 'zip', dirname)\n", "\n", "def 
upload_and_unzip_file_to(dirname):\n", "  if files is None:\n", "    return\n", "  uploaded = files.upload()\n", "  for fn in uploaded.keys():\n", "    print('User uploaded file \"{name}\" with length {length} bytes'.format(\n", "        name=fn, length=len(uploaded[fn])))\n", "    shutil.rmtree(dirname)\n", "    zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')\n", "    zip_files.extractall(dirname)\n", "    zip_files.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "hgyy29doHCmL" }, "source": [ "Create a zip file from the checkpoint directory." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:32.162796Z", "iopub.status.busy": "2021-02-12T22:21:32.162148Z", "iopub.status.idle": "2021-02-12T22:21:32.245118Z", "shell.execute_reply": "2021-02-12T22:21:32.244542Z" }, "id": "nhR8NeWzF4fe" }, "outputs": [], "source": [ "train_checkpointer.save(global_step)\n", "checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))" ] }, { "cell_type": "markdown", "metadata": { "id": "VGEpntTocd2u" }, "source": [ "Download the zip file." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:32.249847Z", "iopub.status.busy": "2021-02-12T22:21:32.249211Z", "iopub.status.idle": "2021-02-12T22:21:32.251272Z", "shell.execute_reply": "2021-02-12T22:21:32.250774Z" }, "id": "upFxb5k8b4MC" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "if files is not None:\n", " files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469" ] }, { "cell_type": "markdown", "metadata": { "id": "VRaZMrn5jLmE" }, "source": [ "After training for a while (about 10-15 iterations), download the checkpoint zip file, go to Runtime > Restart and run all to reset the training, and come back to this cell. Then upload the downloaded zip file and continue training." ] }, { "cell_type": "code", "execution_count": 22, 
"metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:32.255568Z", 
"iopub.status.busy": "2021-02-12T22:21:32.254870Z", "iopub.status.idle": "2021-02-12T22:21:32.257254Z", "shell.execute_reply": "2021-02-12T22:21:32.256768Z" }, "id": "kg-bKgMsF-H_" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "upload_and_unzip_file_to(checkpoint_dir)\n", "train_checkpointer.initialize_or_restore()\n", "global_step = tf.compat.v1.train.get_global_step()" ] }, { "cell_type": "markdown", "metadata": { "id": "uXrNax5Zk3vF" }, "source": [ "Once you have uploaded the checkpoint directory, go back to 'Train one iteration' to continue training, or go back to 'Generate a video' to check the performance of the loaded policy." ] }, { "cell_type": "markdown", "metadata": { "id": "OAkvVZ-NeN2j" }, "source": [ "Alternatively, you can save the policy (model) and restore it. Unlike the checkpointer, this does not let you continue training, but you can still deploy the model. Note that the downloaded file is much smaller than the checkpointer's." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:32.262093Z", "iopub.status.busy": "2021-02-12T22:21:32.261304Z", "iopub.status.idle": "2021-02-12T22:21:32.425227Z", "shell.execute_reply": "2021-02-12T22:21:32.424715Z" }, "id": "s7qMn6D8eiIA" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as QNetwork_layer_call_and_return_conditional_losses, QNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, dense_1_layer_call_and_return_conditional_losses while saving (showing 5 of 25). These functions will not be directly callable after loading.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as QNetwork_layer_call_and_return_conditional_losses, QNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, dense_1_layer_call_and_return_conditional_losses while saving (showing 5 of 25). 
These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/policy/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmp/policy/assets\n" ] } ], "source": [ "tf_policy_saver.save(policy_dir)\n", "policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:32.429791Z", "iopub.status.busy": "2021-02-12T22:21:32.429079Z", "iopub.status.idle": "2021-02-12T22:21:32.431145Z", "shell.execute_reply": "2021-02-12T22:21:32.431603Z" }, "id": "rrGvCEXwerJj" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "if files is not None:\n", " files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469" ] }, { "cell_type": "markdown", "metadata": { "id": "DyC_O_gsgSi5" }, "source": [ "Upload the downloaded policy directory (exported_policy.zip) and check how the saved policy performs." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:32.436402Z", "iopub.status.busy": "2021-02-12T22:21:32.435752Z", "iopub.status.idle": "2021-02-12T22:21:33.172538Z", "shell.execute_reply": "2021-02-12T22:21:33.172979Z" }, "id": "bgWLimRlXy5z" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#@test {\"skip\": true}\n", "upload_and_unzip_file_to(policy_dir)\n", "saved_policy = tf.compat.v2.saved_model.load(policy_dir)\n", "run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HSehXThTm4af" }, "source": [ "## SavedModelPyTFEagerPolicy\n", "\n", "If you don't want to use the TF policy, you can also use the saved_model directly with the Python env through 
`py_tf_eager_policy.SavedModelPyTFEagerPolicy`.\n", "\n", "Note that this only works when eager mode is enabled." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T22:21:33.177906Z", "iopub.status.busy": "2021-02-12T22:21:33.177234Z", "iopub.status.idle": "2021-02-12T22:21:33.912672Z", "shell.execute_reply": "2021-02-12T22:21:33.913189Z" }, "id": "iUC5XuLf1jF7" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(\n", " policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())\n", "\n", "# Note that we're passing eval_py_env not eval_env.\n", "run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "10_checkpointer_policysaver_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 0 }