{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "klGNgWREsvQv" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-12-22T13:54:57.504900Z", "iopub.status.busy": "2023-12-22T13:54:57.504513Z", "iopub.status.idle": "2023-12-22T13:54:57.508233Z", "shell.execute_reply": "2023-12-22T13:54:57.507658Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "pmDI-h7cI0tI" }, "source": [ "# Train a Deep Q Network with TF-Agents\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "lsaQlK8fFQqH" }, "source": [ "## Introduction\n" ] }, { "cell_type": "markdown", "metadata": { "id": "cKOCZlhUgXVK" }, "source": [ "This example shows how to train a [DQN (Deep Q Networks)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) agent on the Cartpole environment using the TF-Agents library.\n", "\n", "![Cartpole environment](https://raw.githubusercontent.com/tensorflow/agents/master/docs/tutorials/images/cartpole.png)\n", "\n", "It will walk you through all the components in a Reinforcement Learning (RL) pipeline for training, evaluation and data collection.\n", "\n", "\n", "To run this code live, click the 'Run in Google Colab' link above.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1u9QVVsShC9X" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "kNrNXKI7bINP" }, "source": [ "If you haven't installed the following dependencies, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:54:57.511794Z", "iopub.status.busy": "2023-12-22T13:54:57.511269Z", "iopub.status.idle": "2023-12-22T13:55:17.436360Z", "shell.execute_reply": "2023-12-22T13:55:17.435427Z" }, "id": "KEHR2Ui-lo8O" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]\r", " \r", "Hit:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (91.189.91.83)] [Connecting to apt.llvm.o\r", " \r", "Hit:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease\r\n", "\r", " \r", "Hit:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (91.189.91.83)] [Waiting for headers] [Co\r", "0% [Connecting to security.ubuntu.com (91.189.91.83)] [Connecting to developer." 
, "Get:5 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64 InRelease [1484 B]\r\n", "Hit:6 https://download.docker.com/linux/ubuntu focal InRelease\r\n", "Hit:7 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64 InRelease\r\n", "Hit:8 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64 InRelease\r\n", "Hit:9 http://security.ubuntu.com/ubuntu focal-security InRelease\r\n", "Hit:4 https://apt.llvm.org/focal llvm-toolchain-focal-17 InRelease\r\n", "Hit:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\r\n", "Hit:11 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\r\n", "Hit:12 http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease\r\n", "Hit:13 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease\r\n", "Fetched 1484 B in 1s (1072 B/s)\r\n", "Reading package lists... Done\r", "\r\n", "Building dependency tree \r", "\r\n", "Reading state information... 
Done\r", "\r\n", "freeglut3-dev is already the newest version (2.8.1-3).\r\n", "ffmpeg is already the newest version (7:4.2.7-0ubuntu0.1).\r\n", "xvfb is already the newest version (2:1.20.13-1ubuntu1~20.04.12).\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The following packages were automatically installed and are no longer required:\r\n", " libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2\r\n", " libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2\r\n", " libparted-fs-resize0 libxmlb2\r\n", "Use 'sudo apt autoremove' to remove them.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 upgraded, 0 newly installed, 0 to remove and 115 not upgraded.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting imageio==2.4.0\r\n", " Using cached imageio-2.4.0-py3-none-any.whl\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from imageio==2.4.0) (1.26.2)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from imageio==2.4.0) (10.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: imageio\r\n", " Attempting uninstall: imageio\r\n", " Found existing installation: imageio 2.33.1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling imageio-2.33.1:\r\n", " Successfully uninstalled imageio-2.33.1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\r\n", "scikit-image 0.22.0 requires imageio>=2.27, but you have imageio 2.4.0 which is incompatible.\u001b[0m\u001b[31m\r\n", "\u001b[0mSuccessfully installed imageio-2.4.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pyvirtualdisplay\r\n", " Using cached PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: pyvirtualdisplay\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed pyvirtualdisplay-3.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-agents[reverb]\r\n", " Using cached tf_agents-0.19.0-py3-none-any.whl.metadata (12 kB)\r\n", "Requirement already satisfied: absl-py>=0.6.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.4.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting cloudpickle>=1.3 (from tf-agents[reverb])\r\n", " Using cached cloudpickle-3.0.0-py3-none-any.whl.metadata (7.0 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gin-config>=0.4.0 (from tf-agents[reverb])\r\n", " Using cached gin_config-0.5.0-py3-none-any.whl (61 kB)\r\n", "Collecting gym<=0.23.0,>=0.17.0 (from tf-agents[reverb])\r\n", " Using cached gym-0.23.0-py3-none-any.whl\r\n", "Requirement already satisfied: numpy>=1.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.26.2)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (10.1.0)\r\n", "Requirement already satisfied: six>=1.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.16.0)\r\n", "Requirement already satisfied: protobuf>=3.11.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (3.20.3)\r\n", "Requirement 
already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.14.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting typing-extensions==4.5.0 (from tf-agents[reverb])\r\n", " Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pygame==2.1.3 (from tf-agents[reverb])\r\n", " Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-probability~=0.23.0 (from tf-agents[reverb])\r\n", " Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl.metadata (13 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting rlds (from tf-agents[reverb])\r\n", " Using cached rlds-0.1.8-py3-none-manylinux2010_x86_64.whl (48 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting dm-reverb~=0.14.0 (from tf-agents[reverb])\r\n", " Using cached dm_reverb-0.14.0-cp39-cp39-manylinux2014_x86_64.whl.metadata (17 kB)\r\n", "Requirement already satisfied: tensorflow~=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (2.15.0.post1)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from dm-reverb~=0.14.0->tf-agents[reverb]) (0.1.8)\r\n", "Collecting portpicker (from dm-reverb~=0.14.0->tf-agents[reverb])\r\n", " Using cached portpicker-1.6.0-py3-none-any.whl.metadata (1.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gym-notices>=0.0.4 (from gym<=0.23.0,>=0.17.0->tf-agents[reverb])\r\n", " Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n", "Requirement already satisfied: importlib-metadata>=4.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (7.0.0)\r\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (23.5.26)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.2.0)\r\n", "Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (3.10.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (16.0.6)\r\n", "Requirement already satisfied: ml-dtypes~=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.2.0)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (23.2)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (69.0.2)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.4.0)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
tensorflow~=2.15.0->tf-agents[reverb]) (0.35.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (1.60.0)\r\n", "Requirement already satisfied: tensorboard<2.16,>=2.15 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.1)\r\n", "Requirement already satisfied: tensorflow-estimator<2.16,>=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.0)\r\n", "Requirement already satisfied: keras<2.16,>=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: decorator in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents[reverb]) (5.1.1)\r\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow~=2.15.0->tf-agents[reverb]) (0.41.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.10.0->gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (3.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.25.2)\r\n", "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (1.2.0)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.5.1)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.31.0)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: psutil in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from portpicker->dm-reverb~=0.14.0->tf-agents[reverb]) (5.9.7)\r\n", "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (5.3.2)\r\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.3.0)\r\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (4.9)\r\n", "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (1.3.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.6)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.1.0)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2023.11.17)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.1.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.5.1)\r\n", "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.2.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached cloudpickle-3.0.0-py3-none-any.whl (20 kB)\r\n", "Using cached dm_reverb-0.14.0-cp39-cp39-manylinux2014_x86_64.whl (6.4 MB)\r\n", "Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl (6.9 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached tf_agents-0.19.0-py3-none-any.whl (1.4 MB)\r\n", "Using cached portpicker-1.6.0-py3-none-any.whl (16 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"Installing collected packages: gym-notices, gin-config, typing-extensions, rlds, pygame, portpicker, cloudpickle, tensorflow-probability, gym, dm-reverb, tf-agents\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: typing-extensions\r\n", " Found existing installation: typing_extensions 4.9.0\r\n", " Uninstalling typing_extensions-4.9.0:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled typing_extensions-4.9.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed cloudpickle-3.0.0 dm-reverb-0.14.0 gin-config-0.5.0 gym-0.23.0 gym-notices-0.0.8 portpicker-1.6.0 pygame-2.1.3 rlds-0.1.8 tensorflow-probability-0.23.0 tf-agents-0.19.0 typing-extensions-4.5.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyglet in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.0.10)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-keras\r\n", " Using cached tf_keras-2.15.0-py3-none-any.whl.metadata (1.6 kB)\r\n", "Using cached tf_keras-2.15.0-py3-none-any.whl (1.7 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: tf-keras\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tf-keras-2.15.0\r\n" ] } ], "source": [ "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n", "!pip install 'imageio==2.4.0'\n", "!pip install pyvirtualdisplay\n", "!pip install tf-agents[reverb]\n", "!pip install pyglet\n", "!pip install tf-keras" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:17.440509Z", "iopub.status.busy": "2023-12-22T13:55:17.440250Z", "iopub.status.idle": "2023-12-22T13:55:17.444088Z", "shell.execute_reply": "2023-12-22T13:55:17.443523Z" }, "id": "UX0aSKBCYmj2" }, "outputs": [], "source": [ "import os\n", "# 
Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:17.447507Z", "iopub.status.busy": "2023-12-22T13:55:17.447082Z", "iopub.status.idle": "2023-12-22T13:55:20.751418Z", "shell.execute_reply": "2023-12-22T13:55:20.750703Z" }, "id": "sMitx5qSgJk1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-12-22 13:55:18.305379: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-12-22 13:55:18.305427: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-12-22 13:55:18.307063: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "import reverb\n", "\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import py_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import sequential\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import reverb_replay_buffer\n", 
"from tf_agents.replay_buffers import reverb_utils\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.utils import common" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:20.756250Z", "iopub.status.busy": "2023-12-22T13:55:20.755326Z", "iopub.status.idle": "2023-12-22T13:55:20.843808Z", "shell.execute_reply": "2023-12-22T13:55:20.842659Z" }, "id": "J6HsdS5GbSjd" }, "outputs": [], "source": [ "# Set up a virtual display for rendering OpenAI gym environments.\n", "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:20.848485Z", "iopub.status.busy": "2023-12-22T13:55:20.847776Z", "iopub.status.idle": "2023-12-22T13:55:20.855875Z", "shell.execute_reply": "2023-12-22T13:55:20.855228Z" }, "id": "NspmzG4nP3b9" }, "outputs": [ { "data": { "text/plain": [ "'2.15.0'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.version.VERSION" ] }, { "cell_type": "markdown", "metadata": { "id": "LmC0NDhdLIKY" }, "source": [ "## Hyperparameters" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:20.859412Z", "iopub.status.busy": "2023-12-22T13:55:20.858836Z", "iopub.status.idle": "2023-12-22T13:55:20.862532Z", "shell.execute_reply": "2023-12-22T13:55:20.861983Z" }, "id": "HC1kNrOsLSIZ" }, "outputs": [], "source": [ "num_iterations = 20000 # @param {type:\"integer\"}\n", "\n", "initial_collect_steps = 100 # @param {type:\"integer\"}\n", "collect_steps_per_iteration = 1# @param {type:\"integer\"}\n", "replay_buffer_max_length = 100000 # @param {type:\"integer\"}\n", "\n", "batch_size = 64 # @param {type:\"integer\"}\n", "learning_rate = 1e-3 # @param {type:\"number\"}\n", "log_interval = 200 # @param 
{type:\"integer\"}\n", "\n", "num_eval_episodes = 10 # @param {type:\"integer\"}\n", "eval_interval = 1000 # @param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "VMsJC3DEgI0x" }, "source": [ "## Environment\n", "\n", "In Reinforcement Learning (RL), an environment represents the task or problem to be solved. Standard environments can be created in TF-Agents using `tf_agents.environments` suites. TF-Agents has suites for loading environments from sources such as the OpenAI Gym, Atari, and DM Control.\n", "\n", "Load the CartPole environment from the OpenAI Gym suite." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:20.865902Z", "iopub.status.busy": "2023-12-22T13:55:20.865429Z", "iopub.status.idle": "2023-12-22T13:55:20.896073Z", "shell.execute_reply": "2023-12-22T13:55:20.895421Z" }, "id": "pYEz-S9gEv2-" }, "outputs": [], "source": [ "env_name = 'CartPole-v0'\n", "env = suite_gym.load(env_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "IIHYVBkuvPNw" }, "source": [ "You can render this environment to see how it looks. A free-swinging pole is attached to a cart. The goal is to move the cart right or left in order to keep the pole pointing up." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:20.899550Z", "iopub.status.busy": "2023-12-22T13:55:20.899088Z", "iopub.status.idle": "2023-12-22T13:55:21.042230Z", "shell.execute_reply": "2023-12-22T13:55:21.041616Z" }, "id": "RlO7WIQHu_7D" }, "outputs": [ { "data": { "image/jpeg": "", "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAIAAAD9V4nPAAAVSElEQVR4Ae3dzW4kZxUG4Iw9joefJWKTIPYIiQ1JWLBHbBB/CySE2AUkrohwCyABG26ABSgsuQEUwgKxhGQ8Hts4cZLxzNjtLnd9p0/V+0RIeLqr6jvnOV/0pnps14OLi4tX/EOAAAECBFIFDlIb1zcBAgQIEPhIQBDaBwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrngABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrngABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrngABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrngABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrng
ABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrngABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBAQhPYAAQIECEQLCMLo8WueAAECBAShPUCAAAEC0QKCMHr8midAgAABQWgPECBAgEC0gCCMHr/mCRAgQEAQ2gMECBAgEC0gCKPHr3kCBAgQEIT2AAECBAhECwjC6PFrngABAgQEoT1AgAABAtECgjB6/JonQIAAAUFoDxAgQIBAtIAgjB6/5gkQIEBAENoDBAgQIBAtIAijx695AgQIEBCE9gABAgQIRAsIwujxa54AAQIEBKE9QIAAAQLRAoIwevyaJ0CAAAFBaA8QIECAQLSAIIwev+YJECBAQBDaAwQIECAQLSAIo8eveQIECBB4iIAAgXEC7/767c0Xf+MX72w+wLsECIwWcEc4Wtj1CWwSODs92fS29wgQGC8gCMcbW4HA7QLngvB2HO8QqBEQhDXOViFws8DZ6eOb3/AqAQJVAoKwSto6BG4ScEd4k4rXCJQKCMJSbosReEHAHeELIP5IoF5AENabW5HAMwHfLPPMwlcE9iQgCPcEb1kCHwuc+ztCO4HAvgUE4b4nYP1sAXeE2fPXfQsBQdhiDIqIFTh/6ucIY4ev8S4CgrDLJNSRKXD2xI9PZE5e140EBGGjYShlfQKvv/XDzU29/+7vNx/gXQIERgsIwtHCrh8tcPDwOLp/zRNYgoAgXMKU1LhYgcNXHy22doUTSBEQhCmT1udeBA6P3BHuBd6iBCYICMIJWA4lMFXAR6NTxRxPoF5AENabWzFI4ODIR6NB49bqQgUE4UIHp+xlCPhodBlzUmW2gCDMnr/uBwscuiMcLOzyBHYXEIS7G7oCgVsFDnyzzK023iDQRUAQdpmEOlYp4I5wlWPV1MoEBOHKBqqdXgIHD4/uLOji4uLOYxxAgMA4AUE4ztaVCWwl4CH1WzE5iMAwAUE4jNaFCWwn4CH12zk5isAoAUE4StZ1CWwp4I5wSyiHERgkIAgHwbosgW0F3BFuK+U4AmMEBOEYV1clsLWAZ/NuTeVAAkMEBOEQVhclsL2AZ/Nub+VIAiMEBOEIVdckMEHg7PRkwtEOJUBgbgFBOLeo6xGYKOCj0YlgDicws4AgnBnU5QhMFfBdo1PFHE9gXgFBOK+nqxGYLOC7RieTOYHArAKCcFZOFyPwksDrb/3opdeee+H9d//w3J/9gQCBWgFBWOtttTwBjyTMm7mOFyYgCBc2MOUuTsBD6hc3MgWnCQjCtInrt1rAHWG1uPUITBQQhBPBHE5gooA7wolgDidQLSAIq8WtlybgjjBt4vpdnIAgXNzIFLwwAQ+pX9jAlJsnIAjzZq7jWoGDo+PaBa1GgMA0AUE4zcvRBKYK+Gh0qpjjCRQLCMJicMvFCfhmmbiRa3hpAoJwaRNT79IEDg4f3lnyxfn5ncc4gACBQQKCcBCsyxKYIODXjU7AciiBuQUE4dyirkdguoAnMU03cwaB2QQE4WyULkTg3gIeUn9vOicS2F1AEO5u6AoEdhVwR7iroPMJ7CAgCHfAcyqBmQTOTk
9mupLLECAwWUAQTiZzAoHZBc5PH89+TRckQGBLAUG4JZTDCAwUcEc4ENelCdwlIAjvEvI+gfEC5z4aHY9sBQK3CQjC22S8TqBOwB1hnbWVCLwkIAhfIvECgXIBP1BfTm5BAs8EBOEzC18RGCTwlW/9ePOV//W3P24+wLsECIwTEITjbF2ZwCcCfu+2rUCgs4Ag7Dwdta1EwJOYVjJIbaxUQBCudLDa6iTgjrDTNNRC4EUBQfiiiD8TmF3AHeHspC5IYEYBQTgjpksRuFng4Oj45je8SoBAAwFB2GAISli7wOHRo7W3qD8CCxYQhAsentKXIuCOcCmTUmemgCDMnLuuSwX8HWEpt8UITBQQhBPBHE5guoCPRqebOYNAnYAgrLO2UqzAg4PDO3s/P3t65zEOIEBghIAgHKHqmgQmC3hI/WQyJxCYSUAQzgTpMgR2E/AAit38nE3g/gKC8P52ziQwo4CH1M+I6VIEJgkIwklcDiYwSsAd4ShZ1yVwl4AgvEvI+wRKBDykvoTZIgRuEBCEN6B4iUC9gGfz1ptbkcCVgCC0Ewi0EHBH2GIMiogUEISRY9d0PwF/R9hvJipKERCEKZPWZ3MB3zXafEDKW7GAIFzxcLXWSOC1b35vczXv/eW3mw/wLgECgwQE4SBYlyXwnICH1D/H4Q8EOgkIwk7TUMt6BTyAYr2z1dniBQTh4keogUUIuCNcxJgUmSkgCDPnrutqAXeE1eLWI7C1gCDcmsqBBHYQ8JD6HfCcSmCsgCAc6+vqBK4E3BHaCQTaCgjCtqNR2KoE/B3hqsapmXUJCMJ1zVM3XQXcEXadjLoIvCIIbQICFQIHD48rlrEGAQLTBQThdDNnEJgucPjqo+knOYMAgQoBQVihbA0CDx7c/e/a+dkpKAIE6gXu/pezviYrEsgU8CSmzLnreu8CgnDvI1AAgU8EPInJViCwFwFBuBd2ixK4QcBD6m9A8RKB8QKCcLyxFQhsJ+Cj0e2cHEVgZgFBODOoyxG4t4CPRu9N50QCuwgIwl30nEtgTgEPqZ9T07UIbC0gCLemciCBwQLuCAcDuzyBmwUE4c0uXiVQL+COsN7cigQuBQShbUCgi4A7wi6TUEeYgCAMG7h29yfw2hvf37z4P//6u80HeJcAgRECgnCEqmsSuEHAAyhuQPESgQYCgrDBEJSQIeAh9Rlz1uXyBATh8mam4oUKHB55AMVCR6fslQsIwpUPWHt9BNwR9pmFSghcFxCE1zV8TWCggL8jHIjr0gR2EBCEO+A5lcAUAQ+pn6LlWAJ1AoKwztpK4QIeUh++AbTfVkAQth2NwtYm4I5wbRPVz1oEBOFaJqmP9gK+a7T9iBQYKiAIQwev7XoBH43Wm1uRwDYCgnAbJccQKBI4f/qkaCXLECDwqYAg/FTC/xNoIOD3bjcYghLiBARh3Mg13Fng/PSkc3lqI7BKAUG4yrFqaqkCZ6ePl1q6ugksVkAQLnZ0Cl+jgI9G1zhVPXUXEITdJ6S+KAEPqY8at2abCAjCJoNQBoGPBNwR2gcE6gUEYb25FRcv8OC+/9zZ+c9++pP7Xvuj8+68vgMIEHhZQBC+bOIVAnsT+Pzx0d7WtjCBVAFBmDp5fbcU+Nzxw5Z1KYrAmgUeXFxcrLk/vREYIHDvDyF//p1v/OoHb15V9Pf/fvvfT756cv6F44P/ffnVf3z9i3++ev3NX/7m3iX71/nedE5MFvCfn8nT13u1wIcnT6+W/NN/3v5s7cssfO/x1y7/990vvfPZi74gQKBMwEejZdQWIvDKByenlwrXU/A6ym2vXz/G1wQIzC4gCGcndUECtwp88Ph0c9ptfvfW63qDAIEdBAThDnhOJTBR4OqOcOJJDidAYKyAIBzr6+oErgt8+PFHo9df8TUBAnsXEIR7H4ECggTcEQYNW6vLERCEy5mVSpcvcPl3hMtvQgcE1iYgCNc2Uf10Frj88YnNPyOx+d3OramNwHIFBOFyZ6fy5QlcfT
R6W9rd9vry+lQxgUUJ+IH6RY1LsQsXePzkkx+ov8y8236zzMJbVD6B5Qn4FWvLm5mK9y5w71+xNrpyv2JttLDrEyBAgAABAgTWJuCOcG0T1U+BgDvCAmRLECgT8M0yZdQWIkCAAIGOAoKw41TURIAAAQJlAoKwjNpCBAgQINBRQBB2nIqaCBAgQKBMQBCWUVuIAAECBDoKCMKOU1ETAQIECJQJCMIyagsRIECAQEcBQdhxKmoiQIAAgTIBQVhGbSECBAgQ6CggCDtORU0ECBAgUCYgCMuoLUSAAAECHQUEYcepqIkAAQIEygQEYRm1hQgQIECgo4Ag7DgVNREgQIBAmYDHMJVRW4gAAQIEOgq4I+w4FTURIECAQJmAICyjthABAgQIdBQQhB2noiYCBAgQKBMQhGXUFiJAgACBjgKCsONU1ESAAAECZQKCsIzaQgQIECDQUUAQdpyKmggQIECgTEAQllFbiAABAgQ6CgjCjlNREwECBAiUCQjCMmoLESBAgEBHAUHYcSpqIkCAAIEyAUFYRm0hAgQIEOgoIAg7TkVNBAgQIFAmIAjLqC1EgAABAh0FBGHHqaiJAAECBMoEBGEZtYUIECBAoKOAIOw4FTURIECAQJmAICyjthABAgQIdBQQhB2noiYCBAgQKBMQhGXUFiJAgACBjgKCsONU1ESAAAECZQKCsIzaQgQIECDQUUAQdpyKmggQIECgTEAQllFbiAABAgQ6CgjCjlNREwECBAiUCQjCMmoLESBAgEBHAUHYcSpqIkCAAIEyAUFYRm0hAgQIEOgoIAg7TkVNBAgQIFAmIAjLqC1EgAABAh0FBGHHqaiJAAECBMoEBGEZtYUIECBAoKOAIOw4FTURIECAQJmAICyjthABAgQIdBQQhB2noiYCBAgQKBMQhGXUFiJAgACBjgKCsONU1ESAAAECZQKCsIzaQgQIECDQUUAQdpyKmggQIECgTEAQllFbiAABAgQ6CgjCjlNREwECBAiUCQjCMmoLESBAgEBHAUHYcSpqIkCAAIEyAUFYRm0hAgQIEOgoIAg7TkVNBAgQIFAmIAjLqC1EgAABAh0FBGHHqaiJAAECBMoEBGEZtYUIECBAoKOAIOw4FTURIECAQJmAICyjthABAgQIdBQQhB2noiYCBAgQKBMQhGXUFiJAgACBjgKCsONU1ESAAAECZQKCsIzaQgQIECDQUUAQdpyKmggQIECgTEAQllFbiAABAgQ6CgjCjlNREwECBAiUCQjCMmoLESBAgEBHAUHYcSpqIkCAAIEyAUFYRm0hAgQIEOgoIAg7TkVNBAgQIFAmIAjLqC1EgAABAh0FBGHHqaiJAAECBMoEBGEZtYUIECBAoKOAIOw4FTURIECAQJmAICyjthABAgQIdBQQhB2noiYCBAgQKBMQhGXUFiJAgACBjgKCsONU1ESAAAECZQKCsIzaQgQIECDQUUAQdpyKmggQIECgTEAQllFbiAABAgQ6CgjCjlNREwECBAiUCQjCMmoLESBAgEBHAUHYcSpqIkCAAIEyAUFYRm0hAgQIEOgoIAg7TkVNBAgQIFAmIAjLqC1EgAABAh0FBGHHqaiJAAECBMoEBGEZtYUIECBAoKOAIOw4FTURIECAQJmAICyjthABAgQIdBQQhB2noiYCBAgQKBMQhGXUFiJAgACBjgL/B9eOJ5dEX8BPAAAAAElFTkSuQmCC", "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@test {\"skip\": true}\n", "env.reset()\n", "PIL.Image.fromarray(env.render())" ] }, { "cell_type": "markdown", "metadata": { "id": "B9_lskPOey18" }, "source": [ "The 
`environment.step` method takes an `action` in the environment and returns a `TimeStep` tuple containing the next observation of the environment and the reward for the action.\n", "\n", "The `time_step_spec()` method returns the specification for the `TimeStep` tuple. Its `observation` attribute shows the shape of observations, the data types, and the ranges of allowed values. The `reward` attribute shows the same details for the reward.\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.045996Z", "iopub.status.busy": "2023-12-22T13:55:21.045342Z", "iopub.status.idle": "2023-12-22T13:55:21.052853Z", "shell.execute_reply": "2023-12-22T13:55:21.052224Z" }, "id": "exDv57iHfwQV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Observation Spec:\n", "BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])\n" ] } ], "source": [ "print('Observation Spec:')\n", "print(env.time_step_spec().observation)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.056164Z", "iopub.status.busy": "2023-12-22T13:55:21.055699Z", "iopub.status.idle": "2023-12-22T13:55:21.059703Z", "shell.execute_reply": "2023-12-22T13:55:21.059058Z" }, "id": "UxiSyCbBUQPi" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reward Spec:\n", "ArraySpec(shape=(), dtype=dtype('float32'), name='reward')\n" ] } ], "source": [ "print('Reward Spec:')\n", "print(env.time_step_spec().reward)" ] }, { "cell_type": "markdown", "metadata": { "id": "b_lHcIcqUaqB" }, "source": [ "The `action_spec()` method returns the shape, data types, and allowed values of valid actions." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.062912Z", "iopub.status.busy": "2023-12-22T13:55:21.062530Z", "iopub.status.idle": "2023-12-22T13:55:21.066031Z", "shell.execute_reply": "2023-12-22T13:55:21.065367Z" }, "id": "bttJ4uxZUQBr" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Action Spec:\n", "BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)\n" ] } ], "source": [ "print('Action Spec:')\n", "print(env.action_spec())" ] }, { "cell_type": "markdown", "metadata": { "id": "eJCgJnx3g0yY" }, "source": [ "In the Cartpole environment:\n", "\n", "- `observation` is an array of 4 floats:\n", " - the position and velocity of the cart\n", " - the angular position and velocity of the pole\n", "- `reward` is a scalar float value\n", "- `action` is a scalar integer with only two possible values:\n", " - `0` — \"move left\"\n", " - `1` — \"move right\"\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.069391Z", "iopub.status.busy": "2023-12-22T13:55:21.068881Z", "iopub.status.idle": "2023-12-22T13:55:21.075210Z", "shell.execute_reply": "2023-12-22T13:55:21.074587Z" }, "id": "V2UGR5t_iZX-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time step:\n", "TimeStep(\n", "{'step_type': array(0, dtype=int32),\n", " 'reward': array(0., dtype=float32),\n", " 'discount': array(1., dtype=float32),\n", " 'observation': array([ 0.0365577 , -0.00826731, -0.02852953, -0.02371309], dtype=float32)})\n", "Next time step:\n", "TimeStep(\n", "{'step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32),\n", " 'observation': array([ 0.03639235, 0.18725191, -0.02900379, -0.32525912], dtype=float32)})\n" ] } ], "source": [ "time_step = env.reset()\n", "print('Time step:')\n", "print(time_step)\n", 
"\n", "action = np.array(1, dtype=np.int32)\n", "\n", "next_time_step = env.step(action)\n", "print('Next time step:')\n", "print(next_time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "4JSc9GviWUBK" }, "source": [ "Usually two environments are instantiated: one for training and one for evaluation." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.078434Z", "iopub.status.busy": "2023-12-22T13:55:21.078196Z", "iopub.status.idle": "2023-12-22T13:55:21.083380Z", "shell.execute_reply": "2023-12-22T13:55:21.082812Z" }, "id": "N7brXNIGWXjC" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "zuUqXAVmecTU" }, "source": [ "The Cartpole environment, like most environments, is written in pure Python. This is converted to TensorFlow using the `TFPyEnvironment` wrapper.\n", "\n", "The original environment's API uses Numpy arrays. The `TFPyEnvironment` converts these to `Tensors` to make it compatible with Tensorflow agents and policies.\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.086745Z", "iopub.status.busy": "2023-12-22T13:55:21.086262Z", "iopub.status.idle": "2023-12-22T13:55:21.095262Z", "shell.execute_reply": "2023-12-22T13:55:21.094705Z" }, "id": "Xp-Y4mD6eDhF" }, "outputs": [], "source": [ "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "E9lW_OZYFR8A" }, "source": [ "## Agent\n", "\n", "The algorithm used to solve an RL problem is represented by an `Agent`. 
TF-Agents provides standard implementations of a variety of `Agents`, including:\n", "\n", "- [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) (used in this tutorial)\n", "- [REINFORCE](https://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)\n", "- [DDPG](https://arxiv.org/pdf/1509.02971.pdf)\n", "- [TD3](https://arxiv.org/pdf/1802.09477.pdf)\n", "- [PPO](https://arxiv.org/abs/1707.06347)\n", "- [SAC](https://arxiv.org/abs/1801.01290)\n", "\n", "The DQN agent can be used in any environment that has a discrete action space.\n", "\n", "At the heart of a DQN Agent is a `QNetwork`, a neural network model that can learn to predict `QValues` (expected returns) for all actions, given an observation from the environment.\n", "\n", "We will use the `tf_agents.networks` package to create a `QNetwork`. The network will consist of a sequence of `tf.keras.layers.Dense` layers, where the final layer will have one output for each possible action." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.098814Z", "iopub.status.busy": "2023-12-22T13:55:21.098224Z", "iopub.status.idle": "2023-12-22T13:55:21.120409Z", "shell.execute_reply": "2023-12-22T13:55:21.119824Z" }, "id": "TgkdEPg_muzV" }, "outputs": [], "source": [ "fc_layer_params = (100, 50)\n", "action_tensor_spec = tensor_spec.from_spec(env.action_spec())\n", "num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1\n", "\n", "# Define a helper function to create Dense layers configured with the right\n", "# activation and kernel initializer.\n", "def dense_layer(num_units):\n", " return tf.keras.layers.Dense(\n", " num_units,\n", " activation=tf.keras.activations.relu,\n", " kernel_initializer=tf.keras.initializers.VarianceScaling(\n", " scale=2.0, mode='fan_in', distribution='truncated_normal'))\n", "\n", "# QNetwork consists of a sequence of Dense layers followed by a dense layer\n", "# with `num_actions` 
units to generate one q_value per available action as\n", "# its output.\n", "dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]\n", "q_values_layer = tf.keras.layers.Dense(\n", " num_actions,\n", " activation=None,\n", " kernel_initializer=tf.keras.initializers.RandomUniform(\n", " minval=-0.03, maxval=0.03),\n", " bias_initializer=tf.keras.initializers.Constant(-0.2))\n", "q_net = sequential.Sequential(dense_layers + [q_values_layer])" ] }, { "cell_type": "markdown", "metadata": { "id": "z62u55hSmviJ" }, "source": [ "Now use `tf_agents.agents.dqn.dqn_agent` to instantiate a `DqnAgent`. In addition to the `time_step_spec`, `action_spec`, and the `QNetwork`, the agent constructor also requires an optimizer (in this case, `tf.keras.optimizers.Adam`), a loss function, and an integer step counter." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:21.123474Z", "iopub.status.busy": "2023-12-22T13:55:21.123220Z", "iopub.status.idle": "2023-12-22T13:55:23.609895Z", "shell.execute_reply": "2023-12-22T13:55:23.609126Z" }, "id": "jbY4yrjTEyc9" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n", "\n", "train_step_counter = tf.Variable(0)\n", "\n", "agent = dqn_agent.DqnAgent(\n", " train_env.time_step_spec(),\n", " train_env.action_spec(),\n", " q_network=q_net,\n", " optimizer=optimizer,\n", " td_errors_loss_fn=common.element_wise_squared_loss,\n", " train_step_counter=train_step_counter)\n", "\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "I0KLrEPwkn5x" }, "source": [ "## Policies\n", "\n", "A policy defines the way an agent acts in an environment. 
Typically, the goal of reinforcement learning is to train the underlying model until the policy produces the desired outcome.\n", "\n", "In this tutorial:\n", "\n", "- The desired outcome is keeping the pole balanced upright over the cart.\n", "- The policy returns an action (left or right) for each `time_step` observation.\n", "\n", "Agents contain two policies:\n", "\n", "- `agent.policy` — The main policy that is used for evaluation and deployment.\n", "- `agent.collect_policy` — A second policy that is used for data collection.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:23.614035Z", "iopub.status.busy": "2023-12-22T13:55:23.613778Z", "iopub.status.idle": "2023-12-22T13:55:23.617153Z", "shell.execute_reply": "2023-12-22T13:55:23.616470Z" }, "id": "BwY7StuMkuV4" }, "outputs": [], "source": [ "eval_policy = agent.policy\n", "collect_policy = agent.collect_policy" ] }, { "cell_type": "markdown", "metadata": { "id": "2Qs1Fl3dV0ae" }, "source": [ "Policies can be created independently of agents. For example, use `tf_agents.policies.random_tf_policy` to create a policy which will randomly select an action for each `time_step`." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:23.620353Z", "iopub.status.busy": "2023-12-22T13:55:23.620132Z", "iopub.status.idle": "2023-12-22T13:55:23.623537Z", "shell.execute_reply": "2023-12-22T13:55:23.622958Z" }, "id": "HE37-UCIrE69" }, "outputs": [], "source": [ "random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),\n", " train_env.action_spec())" ] }, { "cell_type": "markdown", "metadata": { "id": "dOlnlRRsUbxP" }, "source": [ "To get an action from a policy, call the `policy.action(time_step)` method. The `time_step` contains the observation from the environment. 
This method returns a `PolicyStep`, which is a named tuple with three components:\n", "\n", "- `action` — the action to be taken (in this case, `0` or `1`)\n", "- `state` — used for stateful (that is, RNN-based) policies\n", "- `info` — auxiliary data, such as log probabilities of actions" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:23.626567Z", "iopub.status.busy": "2023-12-22T13:55:23.626324Z", "iopub.status.idle": "2023-12-22T13:55:23.632855Z", "shell.execute_reply": "2023-12-22T13:55:23.632276Z" }, "id": "5gCcpXswVAxk" }, "outputs": [], "source": [ "example_environment = tf_py_environment.TFPyEnvironment(\n", " suite_gym.load('CartPole-v0'))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:23.636090Z", "iopub.status.busy": "2023-12-22T13:55:23.635847Z", "iopub.status.idle": "2023-12-22T13:55:23.641638Z", "shell.execute_reply": "2023-12-22T13:55:23.641020Z" }, "id": "D4DHZtq3Ndis" }, "outputs": [], "source": [ "time_step = example_environment.reset()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:23.644790Z", "iopub.status.busy": "2023-12-22T13:55:23.644526Z", "iopub.status.idle": "2023-12-22T13:55:24.302461Z", "shell.execute_reply": "2023-12-22T13:55:24.301812Z" }, "id": "PRFqAUzpNaAW" }, "outputs": [ { "data": { "text/plain": [ "PolicyStep(action=, state=(), info=())" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_policy.action(time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## Metrics and Evaluation\n", "\n", "The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode. 
Several episodes are run, and their returns are averaged.\n", "\n", "The following function computes the average return of a policy, given the policy, environment, and a number of episodes.\n" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:24.306055Z", "iopub.status.busy": "2023-12-22T13:55:24.305783Z", "iopub.status.idle": "2023-12-22T13:55:24.310552Z", "shell.execute_reply": "2023-12-22T13:55:24.309978Z" }, "id": "bitzHo5_UbXy" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", "  total_return = 0.0\n", "  for _ in range(num_episodes):\n", "\n", "    time_step = environment.reset()\n", "    episode_return = 0.0\n", "\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = environment.step(action_step.action)\n", "      episode_return += time_step.reward\n", "    total_return += episode_return\n", "\n", "  avg_return = total_return / num_episodes\n", "  return avg_return.numpy()[0]\n", "\n", "\n", "# See also the metrics module for standard implementations of different metrics.\n", "# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "_snCVvq5Z8lJ" }, "source": [ "Running this computation on the `random_policy` shows a baseline performance in the environment." 
] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:24.313642Z", "iopub.status.busy": "2023-12-22T13:55:24.313416Z", "iopub.status.idle": "2023-12-22T13:55:25.361689Z", "shell.execute_reply": "2023-12-22T13:55:25.361087Z" }, "id": "9bgU6Q6BZ8Bp" }, "outputs": [ { "data": { "text/plain": [ "23.5" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compute_avg_return(eval_env, random_policy, num_eval_episodes)" ] }, { "cell_type": "markdown", "metadata": { "id": "NLva6g2jdWgr" }, "source": [ "## Replay Buffer\n", "\n", "In order to keep track of the data collected from the environment, we will use [Reverb](https://deepmind.com/research/open-source/Reverb), an efficient, extensible, and easy-to-use replay system by DeepMind. It stores experience data as we collect trajectories, and that data is consumed during training.\n", "\n", "This replay buffer is constructed using specs describing the tensors that are to be stored, which can be obtained from the agent using `agent.collect_data_spec`.\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:25.365091Z", "iopub.status.busy": "2023-12-22T13:55:25.364852Z", "iopub.status.idle": "2023-12-22T13:55:25.381868Z", "shell.execute_reply": "2023-12-22T13:55:25.381207Z" }, "id": "vX2zGUWJGWAl" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/platform/tfrecord_checkpointer.cc:162] Initializing TFRecordCheckpointer in /tmpfs/tmp/tmpcvnrrkpg.\n", "[reverb/cc/platform/tfrecord_checkpointer.cc:565] Loading latest checkpoint from /tmpfs/tmp/tmpcvnrrkpg\n", "[reverb/cc/platform/default/server.cc:71] Started replay server on port 46351\n" ] } ], "source": [ "table_name = 'uniform_table'\n", "replay_buffer_signature = tensor_spec.from_spec(\n", " agent.collect_data_spec)\n", "replay_buffer_signature = tensor_spec.add_outer_dim(\n", " 
replay_buffer_signature)\n", "\n", "table = reverb.Table(\n", " table_name,\n", " max_size=replay_buffer_max_length,\n", " sampler=reverb.selectors.Uniform(),\n", " remover=reverb.selectors.Fifo(),\n", " rate_limiter=reverb.rate_limiters.MinSize(1),\n", " signature=replay_buffer_signature)\n", "\n", "reverb_server = reverb.Server([table])\n", "\n", "replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(\n", " agent.collect_data_spec,\n", " table_name=table_name,\n", " sequence_length=2,\n", " local_server=reverb_server)\n", "\n", "rb_observer = reverb_utils.ReverbAddTrajectoryObserver(\n", " replay_buffer.py_client,\n", " table_name,\n", " sequence_length=2)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZGNTDJpZs4NN" }, "source": [ "For most agents, `collect_data_spec` is a named tuple called `Trajectory`, containing the specs for observations, actions, rewards, and other items." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:25.385484Z", "iopub.status.busy": "2023-12-22T13:55:25.384903Z", "iopub.status.idle": "2023-12-22T13:55:25.390729Z", "shell.execute_reply": "2023-12-22T13:55:25.390168Z" }, "id": "_IZ-3HcqgE1z" }, "outputs": [ { "data": { "text/plain": [ "Trajectory(\n", "{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),\n", " 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],\n", " dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],\n", " dtype=float32)),\n", " 'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)),\n", " 'policy_info': (),\n", " 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),\n", " 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),\n", " 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, 
name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))})" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent.collect_data_spec" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:25.394024Z", "iopub.status.busy": "2023-12-22T13:55:25.393464Z", "iopub.status.idle": "2023-12-22T13:55:25.397672Z", "shell.execute_reply": "2023-12-22T13:55:25.397125Z" }, "id": "sy6g1tGcfRlw" }, "outputs": [ { "data": { "text/plain": [ "('step_type',\n", " 'observation',\n", " 'action',\n", " 'policy_info',\n", " 'next_step_type',\n", " 'reward',\n", " 'discount')" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent.collect_data_spec._fields" ] }, { "cell_type": "markdown", "metadata": { "id": "rVD5nQ9ZGo8_" }, "source": [ "## Data Collection\n", "\n", "Now execute the random policy in the environment for a few steps, recording the data in the replay buffer.\n", "\n", "Here we use `PyDriver` to run the experience-collection loop. You can learn more about TF-Agents drivers in our [drivers tutorial](https://www.tensorflow.org/agents/tutorials/4_drivers_tutorial)." 
] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:25.401107Z", "iopub.status.busy": "2023-12-22T13:55:25.400539Z", "iopub.status.idle": "2023-12-22T13:55:25.704624Z", "shell.execute_reply": "2023-12-22T13:55:25.703936Z" }, "id": "wr1KSAEGG4h9" }, "outputs": [ { "data": { "text/plain": [ "(TimeStep(\n", " {'step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32),\n", " 'observation': array([-0.03368392, 0.18694404, -0.00172193, -0.24534112], dtype=float32)}),\n", " ())" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@test {\"skip\": true}\n", "py_driver.PyDriver(\n", " env,\n", " py_tf_eager_policy.PyTFEagerPolicy(\n", " random_policy, use_tf_function=True),\n", " [rb_observer],\n", " max_steps=initial_collect_steps).run(train_py_env.reset())" ] }, { "cell_type": "markdown", "metadata": { "id": "84z5pQJdoKxo" }, "source": [ "The replay buffer is now a collection of Trajectories." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:25.708600Z", "iopub.status.busy": "2023-12-22T13:55:25.707980Z", "iopub.status.idle": "2023-12-22T13:55:25.710942Z", "shell.execute_reply": "2023-12-22T13:55:25.710384Z" }, "id": "4wZnLu2ViO4E" }, "outputs": [], "source": [ "# For the curious:\n", "# Uncomment to peel one of these off and inspect it.\n", "# iter(replay_buffer.as_dataset()).next()" ] }, { "cell_type": "markdown", "metadata": { "id": "TujU-PMUsKjS" }, "source": [ "The agent needs access to the replay buffer. This is provided by creating an iterable `tf.data.Dataset` pipeline which will feed data to the agent.\n", "\n", "Each row of the replay buffer only stores a single observation step. 
But since the DQN Agent needs both the current and next observation to compute the loss, the dataset pipeline will sample two adjacent rows for each item in the batch (`num_steps=2`).\n", "\n", "This dataset is also optimized by running parallel calls and prefetching data." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:25.714431Z", "iopub.status.busy": "2023-12-22T13:55:25.713882Z", "iopub.status.idle": "2023-12-22T13:55:26.034185Z", "shell.execute_reply": "2023-12-22T13:55:26.033540Z" }, "id": "ba7bilizt_qW" }, "outputs": [ { "data": { "text/plain": [ "<_PrefetchDataset element_spec=(Trajectory(\n", "{'step_type': TensorSpec(shape=(64, 2), dtype=tf.int32, name=None),\n", " 'observation': TensorSpec(shape=(64, 2, 4), dtype=tf.float32, name=None),\n", " 'action': TensorSpec(shape=(64, 2), dtype=tf.int64, name=None),\n", " 'policy_info': (),\n", " 'next_step_type': TensorSpec(shape=(64, 2), dtype=tf.int32, name=None),\n", " 'reward': TensorSpec(shape=(64, 2), dtype=tf.float32, name=None),\n", " 'discount': TensorSpec(shape=(64, 2), dtype=tf.float32, name=None)}), SampleInfo(key=TensorSpec(shape=(64, 2), dtype=tf.uint64, name=None), probability=TensorSpec(shape=(64, 2), dtype=tf.float64, name=None), table_size=TensorSpec(shape=(64, 2), dtype=tf.int64, name=None), priority=TensorSpec(shape=(64, 2), dtype=tf.float64, name=None), times_sampled=TensorSpec(shape=(64, 2), dtype=tf.int32, name=None)))>" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Dataset generates trajectories with shape [Bx2x...]\n", "dataset = replay_buffer.as_dataset(\n", " num_parallel_calls=3,\n", " sample_batch_size=batch_size,\n", " num_steps=2).prefetch(3)\n", "\n", "dataset" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:26.037724Z", "iopub.status.busy": "2023-12-22T13:55:26.037127Z", 
"iopub.status.idle": "2023-12-22T13:55:26.122254Z", "shell.execute_reply": "2023-12-22T13:55:26.121618Z" }, "id": "K13AST-2ppOq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "iterator = iter(dataset)\n", "print(iterator)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:26.125434Z", "iopub.status.busy": "2023-12-22T13:55:26.125207Z", "iopub.status.idle": "2023-12-22T13:55:26.128430Z", "shell.execute_reply": "2023-12-22T13:55:26.127753Z" }, "id": "Th5w5Sff0b16" }, "outputs": [], "source": [ "# For the curious:\n", "# Uncomment to see what the dataset iterator is feeding to the agent.\n", "# Compare this representation of replay data\n", "# to the collection of individual trajectories shown earlier.\n", "\n", "# iterator.next()" ] }, { "cell_type": "markdown", "metadata": { "id": "hBc9lj9VWWtZ" }, "source": [ "## Training the agent\n", "\n", "Two things must happen during the training loop:\n", "\n", "- collect data from the environment\n", "- use that data to train the agent's neural network(s)\n", "\n", "This example also periodicially evaluates the policy and prints the current score.\n", "\n", "The following will take ~5 minutes to run." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T13:55:26.131645Z", "iopub.status.busy": "2023-12-22T13:55:26.131384Z", "iopub.status.idle": "2023-12-22T14:02:44.359198Z", "shell.execute_reply": "2023-12-22T14:02:44.358222Z" }, "id": "0pTbJ3PeyF-u" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. 
Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (42980) so Table uniform_table is accessed directly without gRPC.\n", "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (42980) so Table uniform_table is accessed directly without gRPC.\n", "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (42980) so Table uniform_table is accessed directly without gRPC.\n", "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (42980) so Table uniform_table is accessed directly without gRPC.\n", "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (42980) so Table uniform_table is accessed directly without gRPC.\n", "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (42980) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1703253329.256450 44311 device_compiler.h:186] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 200: loss = 168.9337615966797\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 400: loss = 2.769679069519043\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 600: loss = 20.378292083740234\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 800: loss = 2.9951205253601074\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1000: loss = 3.985201358795166\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1000: Average Return = 41.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1200: loss = 27.128450393676758\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1400: loss = 5.9545087814331055\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1600: loss = 30.321374893188477\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1800: loss = 4.8639116287231445\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2000: loss = 77.69764709472656\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2000: Average Return = 189.3000030517578\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2200: loss = 38.41033935546875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2400: loss = 73.83688354492188\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2600: loss = 89.96795654296875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2800: loss = 318.172119140625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3000: loss = 119.87837219238281\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3000: Average Return = 183.1999969482422\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3200: 
loss = 348.0591125488281\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3400: loss = 306.32928466796875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3600: loss = 2720.41943359375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3800: loss = 1241.906982421875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4000: loss = 259.3073425292969\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4000: Average Return = 177.60000610351562\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4200: loss = 411.57086181640625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4400: loss = 96.17520141601562\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4600: loss = 293.4364318847656\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4800: loss = 115.97804260253906\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5000: loss = 135.9969482421875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5000: Average Return = 184.10000610351562\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5200: loss = 108.25897216796875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5400: loss = 117.57241821289062\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5600: loss = 203.2187957763672\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5800: loss = 107.27171325683594\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6000: loss = 89.8726806640625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6000: Average Return = 196.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6200: loss = 719.5379638671875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6400: loss = 671.7078247070312\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "step = 6600: loss = 605.4098510742188\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6800: loss = 118.79557800292969\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7000: loss = 1082.111572265625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7200: loss = 377.11651611328125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7400: loss = 135.56011962890625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7600: loss = 155.7529296875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7800: loss = 162.6855926513672\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8000: loss = 160.82798767089844\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8200: loss = 162.89614868164062\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8400: loss = 167.7406005859375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8600: loss = 108.040771484375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8800: loss = 545.4006958007812\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9000: loss = 176.59364318847656\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9200: loss = 808.9935913085938\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9400: loss = 179.5496063232422\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9600: loss = 115.72040557861328\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 
9800: loss = 110.83393096923828\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10000: loss = 1168.90380859375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10200: loss = 387.125244140625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10400: loss = 3282.5703125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10600: loss = 4486.83642578125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10800: loss = 5873.224609375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11000: loss = 4588.74462890625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11200: loss = 233958.21875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11400: loss = 3961.323486328125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11600: loss = 9469.7607421875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11800: loss = 79834.6953125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12000: loss = 6522.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12200: loss = 4317.1884765625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12400: loss = 187011.5625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12600: loss = 2300.244873046875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12800: loss = 2199.23193359375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13000: loss = 4176.35888671875\n" ] }, { "name": "stdout", "output_type": "stream", "text": 
[ "step = 13000: Average Return = 154.10000610351562\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13200: loss = 3100.556640625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13400: loss = 114706.8125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13600: loss = 1447.1259765625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13800: loss = 11129.3818359375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14000: loss = 1454.640380859375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14200: loss = 1165.739990234375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14400: loss = 1011.5919189453125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14600: loss = 1090.4755859375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14800: loss = 1562.9677734375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15000: loss = 1205.5361328125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15200: loss = 913.7637939453125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15400: loss = 8834.7216796875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15600: loss = 318027.15625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15800: loss = 5136.9150390625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 16000: loss = 374743.65625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 16000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 16200: loss = 4737.19287109375\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "step = 16400: loss = 5279.40478515625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 16600: loss = 4674.5009765625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 16800: loss = 3743.15087890625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 17000: loss = 15105.62109375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 17000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 17200: loss = 938550.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 17400: loss = 9318.6015625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 17600: loss = 10585.978515625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 17800: loss = 8195.138671875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 18000: loss = 288772.40625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 18000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 18200: loss = 6771.6826171875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 18400: loss = 3363.34326171875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 18600: loss = 611807.75\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 18800: loss = 6124.15966796875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 19000: loss = 1373558.5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 19000: Average Return = 200.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 19200: loss = 764662.625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 19400: loss = 342950.84375\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 19600: loss = 10324.072265625\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "step = 19800: loss = 13140.9892578125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 20000: loss = 55873.1328125\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 20000: Average Return = 200.0\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "try:\n", " %%time\n", "except:\n", " pass\n", "\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "# Reset the train step.\n", "agent.train_step_counter.assign(0)\n", "\n", "# Evaluate the agent's policy once before training.\n", "avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", "returns = [avg_return]\n", "\n", "# Reset the environment.\n", "time_step = train_py_env.reset()\n", "\n", "# Create a driver to collect experience.\n", "collect_driver = py_driver.PyDriver(\n", " env,\n", " py_tf_eager_policy.PyTFEagerPolicy(\n", " agent.collect_policy, use_tf_function=True),\n", " [rb_observer],\n", " max_steps=collect_steps_per_iteration)\n", "\n", "for _ in range(num_iterations):\n", "\n", " # Collect a few steps and save to the replay buffer.\n", " time_step, _ = collect_driver.run(time_step)\n", "\n", " # Sample a batch of data from the buffer and update the agent's network.\n", " experience, unused_info = next(iterator)\n", " train_loss = agent.train(experience).loss\n", "\n", " step = agent.train_step_counter.numpy()\n", "\n", " if step % log_interval == 0:\n", " print('step = {0}: loss = {1}'.format(step, train_loss))\n", "\n", " if step % eval_interval == 0:\n", " avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", " print('step = {0}: Average Return = {1}'.format(step, avg_return))\n", " returns.append(avg_return)" ] }, { "cell_type": "markdown", "metadata": { "id": "68jNcA_TiJDq" }, "source": [ "## Visualization\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aO-LWCdbbOIC" }, "source": 
[ "### Plots\n", "\n", "Use `matplotlib.pyplot` to chart how the policy improved during training.\n", "\n", "One episode of `Cartpole-v0` lasts at most 200 time steps. The environment gives a reward of `+1` for each step the pole stays up, so the maximum return for one episode is 200. The chart shows the return increasing towards that maximum each time it is evaluated during training. (It may be a little unstable and not increase monotonically each time.)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:02:44.363116Z", "iopub.status.busy": "2023-12-22T14:02:44.362813Z", "iopub.status.idle": "2023-12-22T14:02:44.580901Z", "shell.execute_reply": "2023-12-22T14:02:44.580147Z" }, "id": "NxtL1mbOYCVO" }, "outputs": [ { "data": { "text/plain": [ "(0.08000040054321289, 250.0)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#@test {\"skip\": true}\n", "\n", "iterations = range(0, num_iterations + 1, eval_interval)\n", "plt.plot(iterations, returns)\n", "plt.ylabel('Average Return')\n", "plt.xlabel('Iterations')\n", "plt.ylim(top=250)" ] }, { "cell_type": "markdown", "metadata": { "id": "M7-XpPP99Cy7" }, "source": [ "### Videos" ] }, { "cell_type": "markdown", "metadata": { "id": "9pGfGxSH32gn" }, "source": [ "Charts are nice. But more exciting is seeing an agent actually performing a task in an environment.\n", "\n", "First, create a function to embed videos in the notebook." ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:02:44.584768Z", "iopub.status.busy": "2023-12-22T14:02:44.584312Z", "iopub.status.idle": "2023-12-22T14:02:44.588924Z", "shell.execute_reply": "2023-12-22T14:02:44.588085Z" }, "id": "ULaGr8pvOKbl" }, "outputs": [], "source": [ "def embed_mp4(filename):\n", " \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n", " video = open(filename,'rb').read()\n", " b64 = base64.b64encode(video)\n", " tag = '''\n", " '''.format(b64.decode())\n", "\n", " return IPython.display.HTML(tag)" ] }, { "cell_type": "markdown", "metadata": { "id": "9c_PH-pX4Pr5" }, "source": [ "Now iterate through a few episodes of the Cartpole game with the agent. The underlying Python environment (the one \"inside\" the TensorFlow environment wrapper) provides a `render()` method, which outputs an image of the environment state. These can be collected into a video." 
] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:02:44.592390Z", "iopub.status.busy": "2023-12-22T14:02:44.591933Z", "iopub.status.idle": "2023-12-22T14:02:55.961972Z", "shell.execute_reply": "2023-12-22T14:02:55.961152Z" }, "id": "owOVWB158NlF" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[swscaler @ 0x555a5d3cf880] Warning: data is not aligned! This can lead to a speed loss\n" ] }, { "data": { "text/html": [ "\n", " " ], "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):\n", "  filename = filename + \".mp4\"\n", "  with imageio.get_writer(filename, fps=fps) as video:\n", "    for _ in range(num_episodes):\n", "      time_step = eval_env.reset()\n", "      video.append_data(eval_py_env.render())\n", "      while not time_step.is_last():\n", "        action_step = policy.action(time_step)\n", "        time_step = eval_env.step(action_step.action)\n", "        video.append_data(eval_py_env.render())\n", "  return embed_mp4(filename)\n", "\n", "create_policy_eval_video(agent.policy, \"trained-agent\")" ] }, { "cell_type": "markdown", "metadata": { "id": "povaAOcZygLw" }, "source": [ "For fun, compare the trained agent (above) to an agent moving randomly. 
(It does not do as well.)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:02:55.966374Z", "iopub.status.busy": "2023-12-22T14:02:55.966121Z", "iopub.status.idle": "2023-12-22T14:02:57.039849Z", "shell.execute_reply": "2023-12-22T14:02:57.038874Z" }, "id": "pJZIdC37yNH4" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[swscaler @ 0x55f466934880] Warning: data is not aligned! This can lead to a speed loss\n" ] }, { "data": { "text/html": [ "\n", " " ], "text/plain": [ "" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "create_policy_eval_video(random_policy, \"random-agent\")" ] } ], "metadata": { "colab": { "name": "DQN Tutorial.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 0 }