{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "klGNgWREsvQv" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-12-22T14:04:42.484544Z", "iopub.status.busy": "2023-12-22T14:04:42.484329Z", "iopub.status.idle": "2023-12-22T14:04:42.487821Z", "shell.execute_reply": "2023-12-22T14:04:42.487255Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "HNtBC6Bbb1YU" }, "source": [ "# REINFORCE agent\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "ZOUOQOrFs3zn" }, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": { "id": "cKOCZlhUgXVK" }, "source": [ "This example shows how to train a [REINFORCE](https://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) agent on the Cartpole environment using the TF-Agents library, similar to the [DQN tutorial](1_dqn_tutorial.ipynb).\n", "\n", "![Cartpole environment](images/cartpole.png)\n", "\n", "We will walk you through all the components in a Reinforcement Learning (RL) pipeline for training, evaluation and data collection.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1u9QVVsShC9X" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "I5PNmEzIb9t4" }, "source": [ "If you haven't installed the following dependencies, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:04:42.491556Z", "iopub.status.busy": "2023-12-22T14:04:42.491001Z", "iopub.status.idle": "2023-12-22T14:05:02.486448Z", "shell.execute_reply": "2023-12-22T14:05:02.485617Z" }, "id": "KEHR2Ui-lo8O" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]\r", " \r", "Hit:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to apt.llvm\r", " \r", "Hit:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease\r\n", "\r", " \r", "Hit:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Waiting for headers] [\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connected to developer" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Get:5 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64 InRelease [1484 B]\r\n", "\r", "0% 
[Connected to apt.llvm.org (199.232.198.49)] [Connecting to security.ubuntu.\r", "0% [Connected to apt.llvm.org (199.232.198.49)] [Connecting to security.ubuntu.\r", " \r", "Hit:6 https://download.docker.com/linux/ubuntu focal InRelease\r\n", "\r", "0% [Connected to apt.llvm.org (199.232.198.49)] [Connecting to security.ubuntu.\r", " \r", "Hit:7 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64 InRelease\r\n", "\r", "0% [Connected to apt.llvm.org (199.232.198.49)] [Connecting to security.ubuntu.\r", " \r", "Hit:8 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64 InRelease\r\n", "\r", "0% [Connected to apt.llvm.org (199.232.198.49)] [Connecting to security.ubuntu." ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Hit:4 https://apt.llvm.org/focal llvm-toolchain-focal-17 InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Waiting for headers] [\r", " \r", "Hit:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\r\n", "\r", "0% [Connecting to security.ubuntu.com (185.125.190.36)] [Connecting to ppa.laun" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "0% [Waiting for headers] [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Hit:10 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\r\n", "\r", " \r", "0% [Waiting for headers]\r", " \r", "Hit:11 http://security.ubuntu.com/ubuntu focal-security InRelease\r\n", "\r", " \r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " \r", "Hit:12 http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease\r\n", "\r", " \r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Waiting for headers]" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\r", " \r", "Hit:13 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease\r\n", "\r", " \r", "0% [Working]\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "0% [Working]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "100% [Working]\r", " \r", "Fetched 1484 B in 1s (1075 B/s)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 2%\r", "\r", "Reading package lists... 2%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r", "\r", "Reading package lists... 4%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 28%\r", "\r", "Reading package lists... 28%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 41%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 41%\r", "\r", "Reading package lists... 
41%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 49%\r", "\r", "Reading package lists... 49%\r", "\r", "Reading package lists... 55%\r", "\r", "Reading package lists... 55%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 57%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 62%\r", "\r", "Reading package lists... 62%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 65%\r", "\r", "Reading package lists... 65%\r", "\r", "Reading package lists... 68%\r", "\r", "Reading package lists... 68%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 69%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 70%\r", "\r", "Reading package lists... 77%\r", "\r", "Reading package lists... 77%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 82%\r", "\r", "Reading package lists... 82%\r", "\r", "Reading package lists... 89%\r", "\r", "Reading package lists... 89%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 92%\r", "\r", "Reading package lists... 92%\r", "\r", "Reading package lists... 94%\r", "\r", "Reading package lists... 94%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 
95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r", "\r", "Reading package lists... 95%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... Done\r", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 100%\r", "\r", "Reading package lists... Done\r", "\r\n", "\r", "Building dependency tree... 0%\r", "\r", "Building dependency tree... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree... 50%\r", "\r", "Building dependency tree... 50%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree \r", "\r\n", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 
Done\r", "\r\n", "freeglut3-dev is already the newest version (2.8.1-3).\r\n", "ffmpeg is already the newest version (7:4.2.7-0ubuntu0.1).\r\n", "xvfb is already the newest version (2:1.20.13-1ubuntu1~20.04.12).\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The following packages were automatically installed and are no longer required:\r\n", " libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2\r\n", " libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2\r\n", " libparted-fs-resize0 libxmlb2\r\n", "Use 'sudo apt autoremove' to remove them.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 upgraded, 0 newly installed, 0 to remove and 115 not upgraded.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting imageio==2.4.0\r\n", " Using cached imageio-2.4.0-py3-none-any.whl\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from imageio==2.4.0) (1.26.2)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from imageio==2.4.0) (10.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: imageio\r\n", " Attempting uninstall: imageio\r\n", " Found existing installation: imageio 2.33.1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling imageio-2.33.1:\r\n", " Successfully uninstalled imageio-2.33.1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\r\n", "scikit-image 0.22.0 requires imageio>=2.27, but you have imageio 2.4.0 which is incompatible.\u001b[0m\u001b[31m\r\n", "\u001b[0mSuccessfully installed imageio-2.4.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pyvirtualdisplay\r\n", " Using cached PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: pyvirtualdisplay\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed pyvirtualdisplay-3.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-agents[reverb]\r\n", " Using cached tf_agents-0.19.0-py3-none-any.whl.metadata (12 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: absl-py>=0.6.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.4.0)\r\n", "Collecting cloudpickle>=1.3 (from tf-agents[reverb])\r\n", " Using cached cloudpickle-3.0.0-py3-none-any.whl.metadata (7.0 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gin-config>=0.4.0 (from tf-agents[reverb])\r\n", " Using cached gin_config-0.5.0-py3-none-any.whl (61 kB)\r\n", "Collecting gym<=0.23.0,>=0.17.0 (from tf-agents[reverb])\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Using cached gym-0.23.0-py3-none-any.whl\r\n", "Requirement already satisfied: numpy>=1.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.26.2)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (10.1.0)\r\n", "Requirement already satisfied: six>=1.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.16.0)\r\n", "Requirement already satisfied: protobuf>=3.11.3 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (3.20.3)\r\n", "Requirement already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (1.14.1)\r\n", "Collecting typing-extensions==4.5.0 (from tf-agents[reverb])\r\n", " Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pygame==2.1.3 (from tf-agents[reverb])\r\n", " Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-probability~=0.23.0 (from tf-agents[reverb])\r\n", " Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl.metadata (13 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting rlds (from tf-agents[reverb])\r\n", " Using cached rlds-0.1.8-py3-none-manylinux2010_x86_64.whl (48 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting dm-reverb~=0.14.0 (from tf-agents[reverb])\r\n", " Using cached dm_reverb-0.14.0-cp39-cp39-manylinux2014_x86_64.whl.metadata (17 kB)\r\n", "Requirement already satisfied: tensorflow~=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents[reverb]) (2.15.0.post1)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from dm-reverb~=0.14.0->tf-agents[reverb]) (0.1.8)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting portpicker (from dm-reverb~=0.14.0->tf-agents[reverb])\r\n", " Using cached portpicker-1.6.0-py3-none-any.whl.metadata (1.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gym-notices>=0.0.4 (from gym<=0.23.0,>=0.17.0->tf-agents[reverb])\r\n", " Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n", "Requirement already satisfied: importlib-metadata>=4.10.0 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (7.0.0)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (23.5.26)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.2.0)\r\n", "Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (3.10.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (16.0.6)\r\n", "Requirement already satisfied: ml-dtypes~=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.2.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (23.2)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (69.0.2)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.4.0)\r\n", "Requirement already satisfied: 
tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (0.35.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (1.60.0)\r\n", "Requirement already satisfied: tensorboard<2.16,>=2.15 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.1)\r\n", "Requirement already satisfied: tensorflow-estimator<2.16,>=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.0)\r\n", "Requirement already satisfied: keras<2.16,>=2.15.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.15.0->tf-agents[reverb]) (2.15.0)\r\n", "Requirement already satisfied: decorator in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents[reverb]) (5.1.1)\r\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow~=2.15.0->tf-agents[reverb]) (0.41.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.10.0->gym<=0.23.0,>=0.17.0->tf-agents[reverb]) (3.17.0)\r\n", "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.25.2)\r\n", "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (1.2.0)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) 
(3.5.1)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.31.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: psutil in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from portpicker->dm-reverb~=0.14.0->tf-agents[reverb]) (5.9.7)\r\n", "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (5.3.2)\r\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.3.0)\r\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (4.9)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (1.3.1)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.6)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.1.0)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2023.11.17)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (2.1.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (0.5.1)\r\n", "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow~=2.15.0->tf-agents[reverb]) (3.2.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached cloudpickle-3.0.0-py3-none-any.whl (20 kB)\r\n", "Using cached dm_reverb-0.14.0-cp39-cp39-manylinux2014_x86_64.whl (6.4 MB)\r\n", "Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl (6.9 MB)\r\n", "Using cached tf_agents-0.19.0-py3-none-any.whl (1.4 MB)\r\n", "Using cached portpicker-1.6.0-py3-none-any.whl (16 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: gym-notices, gin-config, 
typing-extensions, rlds, pygame, portpicker, cloudpickle, tensorflow-probability, gym, dm-reverb, tf-agents\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: typing-extensions\r\n", " Found existing installation: typing_extensions 4.9.0\r\n", " Uninstalling typing_extensions-4.9.0:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled typing_extensions-4.9.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed cloudpickle-3.0.0 dm-reverb-0.14.0 gin-config-0.5.0 gym-0.23.0 gym-notices-0.0.8 portpicker-1.6.0 pygame-2.1.3 rlds-0.1.8 tensorflow-probability-0.23.0 tf-agents-0.19.0 typing-extensions-4.5.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyglet in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.0.10)\r\n", "Requirement already satisfied: xvfbwrapper in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (0.2.9)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-keras\r\n", " Using cached tf_keras-2.15.0-py3-none-any.whl.metadata (1.6 kB)\r\n", "Using cached tf_keras-2.15.0-py3-none-any.whl (1.7 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: tf-keras\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tf-keras-2.15.0\r\n" ] } ], "source": [ "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n", "!pip install 'imageio==2.4.0'\n", "!pip install pyvirtualdisplay\n", "!pip install tf-agents[reverb]\n", "!pip install pyglet xvfbwrapper\n", "!pip install tf-keras" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:02.490785Z", "iopub.status.busy": "2023-12-22T14:05:02.490492Z", "iopub.status.idle": "2023-12-22T14:05:02.494190Z", "shell.execute_reply": "2023-12-22T14:05:02.493660Z" }, 
"id": "WPuD0bMEY9Iz" }, "outputs": [], "source": [ "import os\n", "# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:02.497273Z", "iopub.status.busy": "2023-12-22T14:05:02.497032Z", "iopub.status.idle": "2023-12-22T14:05:05.889818Z", "shell.execute_reply": "2023-12-22T14:05:05.888543Z" }, "id": "sMitx5qSgJk1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-12-22 14:05:03.363396: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-12-22 14:05:03.363443: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-12-22 14:05:03.365008: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "import reverb\n", "\n", "import tensorflow as tf\n", "\n", "from tf_agents.agents.reinforce import reinforce_agent\n", "from tf_agents.drivers import py_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.networks import actor_distribution_network\n", "from tf_agents.policies import py_tf_eager_policy\n", "from tf_agents.replay_buffers import reverb_replay_buffer\n", "from 
tf_agents.replay_buffers import reverb_utils\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "# Set up a virtual display for rendering OpenAI Gym environments.\n", "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" ] }, { "cell_type": "markdown", "metadata": { "id": "LmC0NDhdLIKY" }, "source": [ "## Hyperparameters" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:05.895225Z", "iopub.status.busy": "2023-12-22T14:05:05.894401Z", "iopub.status.idle": "2023-12-22T14:05:05.899775Z", "shell.execute_reply": "2023-12-22T14:05:05.899179Z" }, "id": "HC1kNrOsLSIZ" }, "outputs": [], "source": [ "env_name = \"CartPole-v0\" # @param {type:\"string\"}\n", "num_iterations = 250 # @param {type:\"integer\"}\n", "collect_episodes_per_iteration = 2 # @param {type:\"integer\"}\n", "replay_buffer_capacity = 2000 # @param {type:\"integer\"}\n", "\n", "# Sizes of the fully connected hidden layers (here, one layer of 100 units).\n", "fc_layer_params = (100,)\n", "\n", "learning_rate = 1e-3 # @param {type:\"number\"}\n", "log_interval = 25 # @param {type:\"integer\"}\n", "num_eval_episodes = 10 # @param {type:\"integer\"}\n", "eval_interval = 50 # @param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "VMsJC3DEgI0x" }, "source": [ "## Environment\n", "\n", "Environments in RL represent the task or problem that we are trying to solve. Standard environments can be created easily in TF-Agents using `suites`. TF-Agents provides different `suites` for loading environments from sources such as OpenAI Gym, Atari, and DM Control, given a string environment name.\n", "\n", "Now let us load the CartPole environment from the OpenAI Gym suite."
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:05.902973Z", "iopub.status.busy": "2023-12-22T14:05:05.902734Z", "iopub.status.idle": "2023-12-22T14:05:05.933129Z", "shell.execute_reply": "2023-12-22T14:05:05.932580Z" }, "id": "pYEz-S9gEv2-" }, "outputs": [], "source": [ "env = suite_gym.load(env_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "IIHYVBkuvPNw" }, "source": [ "We can render this environment to see how it looks. A free-swinging pole is attached to a cart. The goal is to move the cart right or left in order to keep the pole pointing up." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:05.936382Z", "iopub.status.busy": "2023-12-22T14:05:05.936138Z", "iopub.status.idle": "2023-12-22T14:05:06.079783Z", "shell.execute_reply": "2023-12-22T14:05:06.079193Z" }, "id": "RlO7WIQHu_7D" }, "outputs": [ { "data": { "image/jpeg": "", "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAIAAAD9V4nPAAAUu0lEQVR4Ae3dvY5kVxUFYFfPCAIIEYkDckSITcILkCCewjwTfgpEwguQmZAXQAgCRAgSIHcVY3tm6HZX1dTPvfvuc9ZnITFTXXXP2d86raWqvjOzOxwOH/mPAAECBAikCjykDm5uAgQIECDwlYAidA4IECBAIFpAEUbHb3gCBAgQUITOAAECBAhECyjC6PgNT4AAAQKK0BkgQIAAgWgBRRgdv+EJECBAQBE6AwQIECAQLaAIo+M3PAECBAgoQmeAAAECBKIFFGF0/IYnQIAAAUXoDBAgQIBAtIAijI7f8AQIECCgCJ0BAgQIEIgWUITR8RueAAECBBShM0CAAAEC0QKKMDp+wxMgQICAInQGCBAgQCBaQBFGx294AgQIEFCEzgABAgQIRAsowuj4DU+AAAECitAZIECAAIFoAUUYHb/hCRAgQEAROgMECBAgEC2gCKPjNzwBAgQIKEJngAABAgSiBRRhdPyGJ0CAAAFF6AwQIECAQLSAIoyO3/AECBAgoAidAQIECBCIFlCE0fEbngABAgQUoTNAgAABAtECijA6fsMTIECAgCJ0BggQIEAgWkARRsdveAIECBBQhM4AAQIECEQLKMLo+A1PgAABAorQGSBAgACBaAFFGB2/4QkQIEBAEToDBAgQIBAtoAij4zc8AQIECChCZ4AAAQIEogUUYXT8hidAgAABRegMECBAgEC0gCKMjt/wBAgQIKAInQECBAgQiBZQhNHxG54AAQIEFKEzQIAAAQLRAoowOn7DEyBAgIAidAYIECBAIFpAEUbHb3gCBAgQUITOAAECBAhECyjC6PgNT4AAAQKK0BkgQIAAgWgBRRgdv+EJECBAQBE6AwQIECAQLaAIo+M3PAECBAgoQmeAAAECBKIFFGF0/IYnQIAAAUXoDBAgQIBAtIAijI7f8AQIECCgCJ0BAgQIEIgWUITR8RueAAECBBShM0CAAAEC0QKKMDp+wxMgQICAInQGCBAgQCBaQBFGx294AgQIEFCEzgABAgQIRAsowuj4DU+AAAECitAZIECAAIFoAUUYHb/hCRAgQEAROgMECBAgEC2gCKPjNzwBAgQIKEJngAABAgSiBRRhdPyGJ0CAAAFF6AwQIECAQLSAIoyO3/AECBAgoAidAQIECBCIFlCE0fEbngABAgQUoTNAgAABAtECijA6fsMTIECAgCJ0BggQIEAgWkARRsdveAIECBBQhM4AAQIECEQLKMLo+A1PgAABAorQGSBAgACBaAFFGB2/4QkQIEBAEToDBAgQIBAtoAij4zc8AQIECChCZ4AAAQIEogUUYXT8hidAgAABRegMECBAgEC0gCKMjt/wBAgQIKAInQECBAgQiBZQhNHxG54AAQIEFKEzQIAAAQLRAoowOn7DEyBAgIAidAYIECBAIFpAEUbHb3gCBAgQUITOAAECBAhECyjC6PgNT4AAAQKK0BkgQIAAgWgBRRgdv+EJECBAQBE6AwQIECAQLaAIo+M3PAECBAgoQmeAAAECBKIFFGF0/IYnQIAAAUXoDBAgQIBAtIAijI7f8AQIECCgCJ0BAgQIEIgWUITR8RueAAECBBShM0CAAAEC0QKKMDp+wxMgQICAInQGCBAgQCBaQBFGx294AgQIEFCEzgABAgQIRAsowuj4DU+AAAECitAZIECAAIFoAUUYHb/hCRAgQEAROgMECBAgEC2gCKPjNzwBAgQIKEJngAABAgSiBRRhdPyGJ0CAAAFF6AwQIECAQLSAIoyO3/AECBAgoAidAQIECBCIFlCE0fEbngABAgQUoTNAgAABAtECijA6fsMTIECAgCJ0BggQIEAgWkARRsdveAIECBBQhM4AAQIECEQLKMLo+A1PgAABAorQGSBAgACBaAFFGB2/4QkQIEBAEToDBAgQIBAtoAij4zc8AQIECChCZ4AAAQIEogUUYXT8hidAgAA
BRegMECBAgEC0gCKMjt/wBAgQIKAInQECBAgQiBZQhNHxG54AAQIEFKEzQIAAAQLRAoowOn7DEyBAgIAidAYIECBAIFpAEUbHb3gCBAgQUITOAAECBAhECyjC6PgNT4AAAQKK0BkgQIAAgWgBRRgdv+EJECBAQBE6AwQIECAQLaAIo+M3PAECBAgoQmeAAAECBKIFFGF0/IYnQIAAAUXoDBAgQIBAtIAijI7f8AQIECCgCJ0BAgQIEIgWUITR8RueAAECBBShM0CAAAEC0QKKMDp+wxMgQICAInQGCBAgQCBaQBFGx294AgQIEFCEzgABAgQIRAsowuj4DU+AAAECitAZIECAAIFoAUUYHb/hCRAgQEAROgMECBAgEC2gCKPjNzwBAgQIKEJngAABAgSiBRRhdPyGJ0CAAIHXCAgQqBH44jefnV/ok19/fv4JvkqAwBoC3hGuoeqaBG4ROOz3t7zMawgQuE9AEd7n59UElhM4HB6Xu5grESBwqYAivFTK8wisLeAd4drCrk/gqIAiPMriQQIbCBz23hFuwG5JAorQGSDQRuDgZ4RtsrCRJAFFmJS2WXsL+Gi0dz52N62AIpw2WoMNJ+BmmeEis+E5BBThHDmaYgYBPyOcIUUzDCigCAcMzZYnFfDR6KTBGqu7gCLsnpD95Qh4R5iTtUlbCSjCVnHYTLTAwV2j0fkbfjMBRbgZvYUJfFvAX7H2bRG/J1AhoAgrlK1B4BIBd41eouQ5BBYXUISLk7oggRsF3CxzI5yXEbhPQBHe5+fVBJYTcLPMcpauROAKAUV4BZanElhVwM0yq/K6OIFTAorwlIzHCVQL+Gi0Wtx6BL4WUIQOAoE2Av49wjZR2EiUgCKMituwrQX8jLB1PDY3r4AinDdbk40m4KPR0RKz30kEFOEkQRpjAgHvCCcI0QgjCijCEVOz5zkF3DU6Z66mai+gCNtHZIMxAj4ajYnaoL0EFGGvPOwmWcBHo8npm31DAUW4Ib6lCTwX8K9PPPfwOwI1AoqwxtkqBD4s4B3hh408g8AKAopwBVSXJHCTgJtlbmLzIgL3CijCewW9nsBSAm6WWUrSdQhcJaAIr+LyZAIrCvhodEVclyZwWkARnrbxFQK1Aoqw1ttqBN4KKEJHgUAXAT8j7JKEfYQJKMKwwI27ncDHn/zq/OJ/++Pvzj/BVwkQWENAEa6h6poEjgjsHny7HWHxEIHNBXxnbh6BDaQI7Ha+3VKyNudYAr4zx8rLbgcW2D28Gnj3tk5gXgFFOG+2Jusm4B1ht0Tsh8DXAorQQSBQJOAdYRG0ZQhcKaAIrwTzdAK3CrhZ5lY5ryOwroAiXNfX1Qm8F3CzzHsKvyDQSkARtorDZmYW8NHozOmabWQBRThyevY+lIAiHCoumw0SUIRBYRt1WwE/I9zW3+oETgkowlMyHiewtMDOnyNcmtT1CCwhoAiXUHQNAhcIeEd4AZKnENhAQBFugG7JTAF3jWbmbur+Aoqwf0Z2OImAm2UmCdIY0wkowukiNVBXAR+Ndk3GvtIFFGH6CTB/nYCbZeqsrUTgCgFFeAWWpxK4R8A7wnv0vJbAegKKcD1bVybwTMDNMs84/IZAGwFF2CYKG5ldwM0ysydsvlEFFOGoydn3cAKKcLjIbDhEQBGGBG3M7QX8jHD7DOyAwDEBRXhMxWMEVhDYuWt0BVWXJHC/gCK839AVCFwm8ODb7TIozyJQK+A7s9bbasEC7hoNDt/orQUUYet4bG4mATfLzJSmWWYSUIQzpWmW1gJulmkdj80FCyjC4PCNXivgZplab6sRuFRAEV4q5XkE7hTwjvBOQC8nsJKAIlwJ1mUJvBDY+XZ7YeIBAg0EfGc2CMEWMgQeXr3OGNSUBAYTUISDBWa7cwscDoe5BzQdgYYCirBhKLaUK3DYP+YOb3ICGwkowo3gLUvgqMBhf/RhDxIgsJ6AIlzP1pUJXC1w2CvCq9G8gMCdAorwTkAvJ7CkwOHgo9E
lPV2LwCUCivASJc8hUCTgHWERtGUIPBFQhE8w/JLA1gJultk6AesnCijCxNTN3Fbg4GaZttnY2LwCinDebE02ooCbZUZMzZ4HF1CEgwdo+3MJ+Gh0rjxNM4aAIhwjJ7sMEXDXaEjQxmwloAhbxWEz6QLuGk0/AebfQkARbqFuTQInBHw0egLGwwRWFFCEK+K6NIFrBdw1eq2Y5xO4X0AR3m/oCgQWE/DR6GKULkTgYgFFeDGVJxIoEPBXrBUgW4LAcwFF+NzD7whsKuAd4ab8Fg8VUIShwRu7p4CbZXrmYldzCyjCufM13WACbpYZLDDbnUJAEU4RoyFmEfDR6CxJmmMkAUU4Ulr2Or2Aj0anj9iADQUUYcNQbClXQBHmZm/y7QQU4Xb2VibwUsA/w/TSxCMEVhZQhCsDuzyBawS8I7xGy3MJLCOgCJdxdBUCiwi4a3QRRhchcJWAIryKy5MJrCvgrtF1fV2dwDEBRXhMxWMENhLw0ehG8JaNFlCE0fEbvp2Am2XaRWJD8wsowvkzNuFAAt4RDhSWrU4joAinidIgMwi4WWaGFM0wmoAiHC0x+51awM0yU8druKYCirBpMLY1pcDHP/3l+bn++sVvzz/BVwkQWFxAES5O6oIETgs8vDr9NV8hQGAbAUW4jbtVMwV2O99xmcmburWAb8vW8djcZAI77wgnS9Q4UwgowiliNMQgArsH33GDRGWbSQK+LZPSNuvWAj4a3ToB6xM4IqAIj6B4iMBKAj4aXQnWZQncI6AI79HzWgJXCrhZ5kowTydQIKAIC5AtQeCtgHeEjgKBhgKKsGEotjStgJtlpo3WYCMLKMKR07P30QS8IxwtMfuNEFCEETEbsomAu0abBGEbBJ4KKMKnGn5NYF0B7wjX9XV1AjcJKMKb2LyIwE0CfkZ4E5sXEVhXQBGu6+vqBJ4J7Pyl2888/IZABwFF2CEFe0gR8I4wJWlzDiWgCIeKy2YHF3CzzOAB2v6cAopwzlxN1VPAzTI9c7GrcAFFGH4AjH+LwO7W/z799GcfXO/Wa3/1ug9e3BMIEHgpoAhfmniEwFoCX+73a13adQkQuFVAEd4q53UErhd43B+uf5FXECCwrsDrdS/v6gQIPBH48vH/7wj/9M+f//2/P/rP/nvfffjXD7/z5598/w9PnuiXBAjUCSjCOmsrEXh8V4S//8dn7zXedOFf/v3jN//7xQ8+f/+gXxAgUCbgo9EyagsR+Oibj0aftuBTlFOPP32OXxMgsLiAIlyc1AUJnBR43O/Pt935r568ri8QIHCHgCK8A89LCVwp8PjoZpkryTydwPoCinB9YysQeCfgj0+8k/D/BBoJKMJGYdjK9ALvb5aZflIDEhhIQBEOFJatDi/gzxEOH6EBZhRQhDOmaqauAm9uljn/ZyTOf7XrWPZFYGwBRTh2fnY/lsA3N8ucartTj481o90SGE7AH6gfLjIbHlhgf3h71+ibzvM3ywwcpK3PJbA7vPvOnGsu0xBYUaDtv/Pg23nF1F2aAAECBAgQIDClgHeEU8ZqqHUFvCNc19fVCdQKuFmm1ttqBAgQINBMQBE2C8R2CBAgQKBWQBHWeluNAAECBJoJKMJmgdgOAQIECNQKKMJab6sRIECAQDMBRdgsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQQUYbNAbIcAAQIEagUUYa231QgQIECgmYAibBaI7RAgQIBArYAirPW2GgECBAg0E1CEzQKxHQIECBCoFfDPMNV6W40AAQIEmgl4R9gsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQQUYbNAbIcAAQIEagUUYa231QgQIECgmYAibBaI7RAgQIBArYAirPW2GgECBAg0E1CEzQKxHQIECBCoFVCEtd5WI0CAAIFmAoqwWSC2Q4AAAQK1Aoqw1ttqBAgQINBMQBE2C8R2CBA
gQKBWQBHWeluNAAECBJoJKMJmgdgOAQIECNQKKMJab6sRIECAQDMBRdgsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQQUYbNAbIcAAQIEagUUYa231QgQIECgmYAibBaI7RAgQIBArYAirPW2GgECBAg0E1CEzQKxHQIECBCoFVCEtd5WI0CAAIFmAoqwWSC2Q4AAAQK1Aoqw1ttqBAgQINBMQBE2C8R2CBAgQKBWQBHWeluNAAECBJoJKMJmgdgOAQIECNQKKMJab6sRIECAQDMBRdgsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQQUYbNAbIcAAQIEagUUYa231QgQIECgmYAibBaI7RAgQIBArYAirPW2GgECBAg0E1CEzQKxHQIECBCoFVCEtd5WI0CAAIFmAoqwWSC2Q4AAAQK1Aoqw1ttqBAgQINBMQBE2C8R2CBAgQKBWQBHWeluNAAECBJoJKMJmgdgOAQIECNQKKMJab6sRIECAQDMBRdgsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQQUYbNAbIcAAQIEagUUYa231QgQIECgmYAibBaI7RAgQIBArYAirPW2GgECBAg0E1CEzQKxHQIECBCoFVCEtd5WI0CAAIFmAoqwWSC2Q4AAAQK1Aoqw1ttqBAgQINBMQBE2C8R2CBAgQKBWQBHWeluNAAECBJoJKMJmgdgOAQIECNQKKMJab6sRIECAQDMBRdgsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQQUYbNAbIcAAQIEagUUYa231QgQIECgmYAibBaI7RAgQIBArYAirPW2GgECBAg0E1CEzQKxHQIECBCoFVCEtd5WI0CAAIFmAoqwWSC2Q4AAAQK1Aoqw1ttqBAgQINBMQBE2C8R2CBAgQKBWQBHWeluNAAECBJoJKMJmgdgOAQIECNQKKMJab6sRIECAQDMBRdgsENshQIAAgVoBRVjrbTUCBAgQaCagCJsFYjsECBAgUCugCGu9rUaAAAECzQT+BwOnHqGjaJ6pAAAAAElFTkSuQmCC", "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@test {\"skip\": true}\n", "env.reset()\n", "PIL.Image.fromarray(env.render())" ] }, { "cell_type": "markdown", "metadata": { "id": "B9_lskPOey18" }, "source": [ "The `time_step = environment.step(action)` statement takes `action` in the environment. The `TimeStep` tuple returned contains the environment's next observation and reward for that action. The `time_step_spec()` and `action_spec()` methods in the environment return the specifications (types, shapes, bounds) of the `time_step` and `action` respectively." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:06.083142Z", "iopub.status.busy": "2023-12-22T14:05:06.082907Z", "iopub.status.idle": "2023-12-22T14:05:06.090439Z", "shell.execute_reply": "2023-12-22T14:05:06.089883Z" }, "id": "exDv57iHfwQV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Observation Spec:\n", "BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])\n", "Action Spec:\n", "BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)\n" ] } ], "source": [ "print('Observation Spec:')\n", "print(env.time_step_spec().observation)\n", "print('Action Spec:')\n", "print(env.action_spec())" ] }, { "cell_type": "markdown", "metadata": { "id": "eJCgJnx3g0yY" }, "source": [ "We can see that the observation is an array of 4 floats: the position and velocity of the cart, and the angular position and velocity of the pole. 
Since only two actions are possible (move left or move right), the `action_spec` is a scalar where 0 means \"move left\" and 1 means \"move right.\"" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:06.093731Z", "iopub.status.busy": "2023-12-22T14:05:06.093281Z", "iopub.status.idle": "2023-12-22T14:05:06.099811Z", "shell.execute_reply": "2023-12-22T14:05:06.099243Z" }, "id": "V2UGR5t_iZX-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time step:\n", "TimeStep(\n", "{'step_type': array(0, dtype=int32),\n", " 'reward': array(0., dtype=float32),\n", " 'discount': array(1., dtype=float32),\n", " 'observation': array([-0.00907558, 0.02627698, -0.01019297, 0.04808202], dtype=float32)})\n", "Next time step:\n", "TimeStep(\n", "{'step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32),\n", " 'observation': array([-0.00855004, 0.2215436 , -0.00923133, -0.24779937], dtype=float32)})\n" ] } ], "source": [ "time_step = env.reset()\n", "print('Time step:')\n", "print(time_step)\n", "\n", "action = np.array(1, dtype=np.int32)\n", "\n", "next_time_step = env.step(action)\n", "print('Next time step:')\n", "print(next_time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "zuUqXAVmecTU" }, "source": [ "Usually we create two environments: one for training and one for evaluation. Most environments are written in pure Python, but they can be easily converted to TensorFlow using the `TFPyEnvironment` wrapper. 
The original environment's API uses NumPy arrays; the `TFPyEnvironment` converts these to and from `Tensors` so that you can more easily interact with TensorFlow policies and agents.\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:06.103106Z", "iopub.status.busy": "2023-12-22T14:05:06.102593Z", "iopub.status.idle": "2023-12-22T14:05:06.113658Z", "shell.execute_reply": "2023-12-22T14:05:06.113116Z" }, "id": "Xp-Y4mD6eDhF" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "E9lW_OZYFR8A" }, "source": [ "## Agent\n", "\n", "The algorithm that we use to solve an RL problem is represented as an `Agent`. In addition to the REINFORCE agent, TF-Agents provides standard implementations of a variety of `Agents` such as [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf), [DDPG](https://arxiv.org/pdf/1509.02971.pdf), [TD3](https://arxiv.org/pdf/1802.09477.pdf), [PPO](https://arxiv.org/abs/1707.06347) and [SAC](https://arxiv.org/abs/1801.01290).\n", "\n", "To create a REINFORCE agent, we first need an `Actor Network` that can learn to predict a distribution over actions given an observation from the environment.\n", "\n", "We can easily create an `Actor Network` using the specs of the observations and actions. 
We can specify the layers in the network with the `fc_layer_params` argument, which in this example is a tuple of `ints` giving the size of each hidden layer (see the Hyperparameters section above).\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:06.117206Z", "iopub.status.busy": "2023-12-22T14:05:06.116765Z", "iopub.status.idle": "2023-12-22T14:05:06.161414Z", "shell.execute_reply": "2023-12-22T14:05:06.160841Z" }, "id": "TgkdEPg_muzV" }, "outputs": [], "source": [ "actor_net = actor_distribution_network.ActorDistributionNetwork(\n", "    train_env.observation_spec(),\n", "    train_env.action_spec(),\n", "    fc_layer_params=fc_layer_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "z62u55hSmviJ" }, "source": [ "We also need an `optimizer` to train the network we just created, and a `train_step_counter` variable to keep track of how many times the network has been updated.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:06.164948Z", "iopub.status.busy": "2023-12-22T14:05:06.164368Z", "iopub.status.idle": "2023-12-22T14:05:08.626259Z", "shell.execute_reply": "2023-12-22T14:05:08.625470Z" }, "id": "jbY4yrjTEyc9" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n", "\n", "train_step_counter = tf.Variable(0)\n", "\n", "tf_agent = reinforce_agent.ReinforceAgent(\n", "    train_env.time_step_spec(),\n", "    train_env.action_spec(),\n", "    actor_network=actor_net,\n", "    optimizer=optimizer,\n", "    normalize_returns=True,\n", "    train_step_counter=train_step_counter)\n", "tf_agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "I0KLrEPwkn5x" }, "source": [ "## Policies\n", "\n", "In TF-Agents, policies represent the standard notion of policies in RL: given a `time_step`, produce an action or a distribution over actions. 
The main method is `policy_step = policy.action(time_step)`, where `policy_step` is a named tuple `PolicyStep(action, state, info)`. The `policy_step.action` is the `action` to be applied to the environment, `state` represents the state for stateful (RNN) policies, and `info` may contain auxiliary information such as log probabilities of the actions.\n", "\n", "Agents contain two policies: the main policy that is used for evaluation/deployment (`agent.policy`) and another policy that is used for data collection (`agent.collect_policy`)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:08.630846Z", "iopub.status.busy": "2023-12-22T14:05:08.630328Z", "iopub.status.idle": "2023-12-22T14:05:08.633596Z", "shell.execute_reply": "2023-12-22T14:05:08.632978Z" }, "id": "BwY7StuMkuV4" }, "outputs": [], "source": [ "eval_policy = tf_agent.policy\n", "collect_policy = tf_agent.collect_policy" ] }, { "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## Metrics and Evaluation\n", "\n", "The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode, and we usually average this over a few episodes. 
We can compute the average return metric as follows.\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:08.636739Z", "iopub.status.busy": "2023-12-22T14:05:08.636497Z", "iopub.status.idle": "2023-12-22T14:05:08.641058Z", "shell.execute_reply": "2023-12-22T14:05:08.640376Z" }, "id": "bitzHo5_UbXy" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", "  total_return = 0.0\n", "  for _ in range(num_episodes):\n", "\n", "    time_step = environment.reset()\n", "    episode_return = 0.0\n", "\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = environment.step(action_step.action)\n", "      episode_return += time_step.reward\n", "    total_return += episode_return\n", "\n", "  avg_return = total_return / num_episodes\n", "  return avg_return.numpy()[0]\n", "\n", "\n", "# Please also see the metrics module for standard implementations of different\n", "# metrics." ] }, { "cell_type": "markdown", "metadata": { "id": "NLva6g2jdWgr" }, "source": [ "## Replay Buffer\n", "\n", "To keep track of the data collected from the environment, we will use [Reverb](https://deepmind.com/research/open-source/Reverb), an efficient, extensible, and easy-to-use replay system by DeepMind. It stores experience data as we collect trajectories, and that data is consumed during training.\n", "\n", "This replay buffer is constructed using specs describing the tensors that are to be stored, which can be obtained from the agent using `tf_agent.collect_data_spec`." 
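] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the `reverb.Table` settings in the next cell more concrete, the cell below is a toy, in-memory stand-in written in plain Python. It is not Reverb and not part of TF-Agents. The class `ToyFifoUniformBuffer` is hypothetical and only mimics the three settings used here. Once `max_size` items are stored, the oldest item is removed first (like `reverb.selectors.Fifo()`). Every stored item is equally likely to be sampled (like `reverb.selectors.Uniform()`). Sampling is refused until at least one item has been inserted (like `reverb.rate_limiters.MinSize(1)`, except that Reverb blocks the sampler instead of raising an error)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import collections\n", "import random\n", "\n", "\n", "class ToyFifoUniformBuffer:\n", "  \"\"\"Hypothetical stand-in for a Reverb table; for illustration only.\"\"\"\n", "\n", "  def __init__(self, max_size, min_size=1, seed=0):\n", "    # A deque with maxlen drops its oldest item when full (FIFO removal).\n", "    self._items = collections.deque(maxlen=max_size)\n", "    self._min_size = min_size\n", "    self._rng = random.Random(seed)\n", "\n", "  def insert(self, item):\n", "    self._items.append(item)\n", "\n", "  def sample(self):\n", "    # Reverb's MinSize rate limiter would block here; the toy raises instead.\n", "    if len(self._items) < self._min_size:\n", "      raise RuntimeError('Not enough items in the buffer to sample.')\n", "    # Every stored item is equally likely to be returned (uniform sampling).\n", "    return self._rng.choice(self._items)\n", "\n", "  def __len__(self):\n", "    return len(self._items)\n", "\n", "\n", "toy_buffer = ToyFifoUniformBuffer(max_size=3)\n", "for episode_id in range(5):\n", "  toy_buffer.insert('episode_%d' % episode_id)\n", "\n", "# Only the 3 most recent items remain: episode_2, episode_3, episode_4.\n", "print(len(toy_buffer), list(toy_buffer._items))\n", "print(toy_buffer.sample())"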
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:08.644420Z", "iopub.status.busy": "2023-12-22T14:05:08.643928Z", "iopub.status.idle": "2023-12-22T14:05:08.660169Z", "shell.execute_reply": "2023-12-22T14:05:08.659543Z" }, "id": "vX2zGUWJGWAl" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/platform/tfrecord_checkpointer.cc:162] Initializing TFRecordCheckpointer in /tmpfs/tmp/tmpkagdqs1n.\n", "[reverb/cc/platform/tfrecord_checkpointer.cc:565] Loading latest checkpoint from /tmpfs/tmp/tmpkagdqs1n\n", "[reverb/cc/platform/default/server.cc:71] Started replay server on port 41705\n" ] } ], "source": [ "table_name = 'uniform_table'\n", "replay_buffer_signature = tensor_spec.from_spec(\n", "    tf_agent.collect_data_spec)\n", "replay_buffer_signature = tensor_spec.add_outer_dim(\n", "    replay_buffer_signature)\n", "table = reverb.Table(\n", "    table_name,\n", "    max_size=replay_buffer_capacity,\n", "    sampler=reverb.selectors.Uniform(),\n", "    remover=reverb.selectors.Fifo(),\n", "    rate_limiter=reverb.rate_limiters.MinSize(1),\n", "    signature=replay_buffer_signature)\n", "\n", "reverb_server = reverb.Server([table])\n", "\n", "replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(\n", "    tf_agent.collect_data_spec,\n", "    table_name=table_name,\n", "    sequence_length=None,\n", "    local_server=reverb_server)\n", "\n", "rb_observer = reverb_utils.ReverbAddEpisodeObserver(\n", "    replay_buffer.py_client,\n", "    table_name,\n", "    replay_buffer_capacity\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "ZGNTDJpZs4NN" }, "source": [ "For most agents, the `collect_data_spec` is a `Trajectory` named tuple containing the observation, action, reward, etc." 
] }, { "cell_type": "markdown", "metadata": { "id": "rVD5nQ9ZGo8_" }, "source": [ "## Data Collection\n", "\n", "Because REINFORCE learns from whole episodes, we define a function that collects a given number of episodes with the data collection policy and saves the data (observations, actions, rewards, etc.) as trajectories in the replay buffer. Here we use `PyDriver` to run the experience collection loop. You can learn more about TF-Agents drivers in our [drivers tutorial](https://www.tensorflow.org/agents/tutorials/4_drivers_tutorial)." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:08.663547Z", "iopub.status.busy": "2023-12-22T14:05:08.662908Z", "iopub.status.idle": "2023-12-22T14:05:08.666866Z", "shell.execute_reply": "2023-12-22T14:05:08.666320Z" }, "id": "wr1KSAEGG4h9" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "\n", "def collect_episode(environment, policy, num_episodes):\n", "\n", "  driver = py_driver.PyDriver(\n", "      environment,\n", "      py_tf_eager_policy.PyTFEagerPolicy(\n", "          policy, use_tf_function=True),\n", "      [rb_observer],\n", "      max_episodes=num_episodes)\n", "  initial_time_step = environment.reset()\n", "  driver.run(initial_time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "hBc9lj9VWWtZ" }, "source": [ "## Training the agent\n", "\n", "The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will periodically evaluate the agent's policy to see how we are doing.\n", "\n", "The following will take ~3 minutes to run." 
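] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before running the training loop, which follows this sketch, it may help to see what REINFORCE does with each complete episode. The cell below is a minimal NumPy sketch, not the TF-Agents implementation. It computes the discounted return-to-go $G_t = r_t + \\gamma G_{t+1}$ by walking an episode backwards, then standardizes the returns to zero mean and unit variance, which is roughly what `normalize_returns=True` asks the agent to do. The helper names `discounted_returns` and `standardize`, the discount factor and the toy reward list are illustrative assumptions. The backward pass also shows why REINFORCE needs whole episodes: $G_t$ is only known once the episode has ended. The policy-gradient loss then weights the log-probability of each action taken by its (normalized) return." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def discounted_returns(rewards, gamma=0.99):\n", "  \"\"\"Hypothetical helper: return-to-go G_t = r_t + gamma * G_{t+1}.\"\"\"\n", "  returns = np.zeros(len(rewards), dtype=np.float64)\n", "  running = 0.0\n", "  # Walk the episode backwards so each step accumulates its discounted future.\n", "  for t in reversed(range(len(rewards))):\n", "    running = rewards[t] + gamma * running\n", "    returns[t] = running\n", "  return returns\n", "\n", "\n", "def standardize(values, eps=1e-8):\n", "  \"\"\"Hypothetical helper: shift to zero mean and scale to unit variance.\"\"\"\n", "  return (values - values.mean()) / (values.std() + eps)\n", "\n", "\n", "# Toy CartPole-style episode: a reward of +1 for each of 5 steps.\n", "episode_rewards = [1.0, 1.0, 1.0, 1.0, 1.0]\n", "returns_to_go = discounted_returns(episode_rewards, gamma=0.9)\n", "print(returns_to_go)  # Earlier steps have larger returns-to-go.\n", "print(standardize(returns_to_go))"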
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:05:08.670316Z", "iopub.status.busy": "2023-12-22T14:05:08.669689Z", "iopub.status.idle": "2023-12-22T14:07:49.310371Z", "shell.execute_reply": "2023-12-22T14:07:49.309563Z" }, "id": "0pTbJ3PeyF-u" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1703253913.189247 48625 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 25: loss = 1.8318419456481934\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "step = 50: loss = 0.0070743560791015625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 50: Average Return = 9.800000190734863\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 75: loss = 1.1006038188934326\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 100: loss = 0.5719594955444336\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 100: Average Return = 50.29999923706055\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 125: loss = -1.2458715438842773\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[reverb/cc/client.cc:165] Sampler and server are owned by the same process (47292) so Table uniform_table is accessed directly without gRPC.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 150: loss = 1.9363441467285156\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 150: Average Return = 98.30000305175781\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 175: loss = 0.8784818649291992\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 200: loss = 1.9726766347885132\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 200: Average Return = 143.6999969482422\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 225: loss = 2.316105842590332\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 250: loss = 2.5175299644470215\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 250: Average Return = 191.5\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "try:\n", " %%time\n", "except:\n", " pass\n", "\n", "# (Optional) Optimize by wrapping some of the code 
in a graph using TF function.\n", "tf_agent.train = common.function(tf_agent.train)\n", "\n", "# Reset the train step\n", "tf_agent.train_step_counter.assign(0)\n", "\n", "# Evaluate the agent's policy once before training.\n", "avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)\n", "returns = [avg_return]\n", "\n", "for _ in range(num_iterations):\n", "\n", "  # Collect a few episodes using collect_policy and save to the replay buffer.\n", "  collect_episode(\n", "      train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)\n", "\n", "  # Use data from the buffer and update the agent's network.\n", "  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))\n", "  trajectories, _ = next(iterator)\n", "  train_loss = tf_agent.train(experience=trajectories)\n", "\n", "  replay_buffer.clear()\n", "\n", "  step = tf_agent.train_step_counter.numpy()\n", "\n", "  if step % log_interval == 0:\n", "    print('step = {0}: loss = {1}'.format(step, train_loss.loss))\n", "\n", "  if step % eval_interval == 0:\n", "    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)\n", "    print('step = {0}: Average Return = {1}'.format(step, avg_return))\n", "    returns.append(avg_return)" ] }, { "cell_type": "markdown", "metadata": { "id": "68jNcA_TiJDq" }, "source": [ "## Visualization\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aO-LWCdbbOIC" }, "source": [ "### Plots\n", "\n", "We can plot the average return against the training step to see the performance of our agent. In `CartPole-v0`, the environment gives a reward of +1 for every time step the pole stays up, and since the maximum number of steps is 200, the maximum possible return is also 200." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:07:49.314255Z", "iopub.status.busy": "2023-12-22T14:07:49.313751Z", "iopub.status.idle": "2023-12-22T14:07:49.521328Z", "shell.execute_reply": "2023-12-22T14:07:49.520744Z" }, "id": "NxtL1mbOYCVO" }, "outputs": [ { "data": { "text/plain": [ "(0.7150002002716054, 250.0)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#@test {\"skip\": true}\n", "\n", "steps = range(0, num_iterations + 1, eval_interval)\n", "plt.plot(steps, returns)\n", "plt.ylabel('Average Return')\n", "plt.xlabel('Step')\n", "plt.ylim(top=250)" ] }, { "cell_type": "markdown", "metadata": { "id": "M7-XpPP99Cy7" }, "source": [ "### Videos" ] }, { "cell_type": "markdown", "metadata": { "id": "9pGfGxSH32gn" }, "source": [ "It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:07:49.524818Z", "iopub.status.busy": "2023-12-22T14:07:49.524272Z", "iopub.status.idle": "2023-12-22T14:07:49.528329Z", "shell.execute_reply": "2023-12-22T14:07:49.527677Z" }, "id": "ULaGr8pvOKbl" }, "outputs": [], "source": [ "def embed_mp4(filename):\n", "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n", "  video = open(filename,'rb').read()\n", "  b64 = base64.b64encode(video)\n", "  tag = '''\n", "  <video controls>\n", "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n", "  Your browser does not support the video tag.\n", "  </video>'''.format(b64.decode())\n", "\n", "  return IPython.display.HTML(tag)" ] }, { "cell_type": "markdown", "metadata": { "id": "9c_PH-pX4Pr5" }, "source": [ "The following code visualizes the agent's policy for a few episodes:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-12-22T14:07:49.531508Z", "iopub.status.busy": "2023-12-22T14:07:49.531193Z", "iopub.status.idle": "2023-12-22T14:07:56.632628Z", "shell.execute_reply": "2023-12-22T14:07:56.631727Z" }, "id": "owOVWB158NlF" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. 
To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[swscaler @ 0x5563cf186880] Warning: data is not aligned! This can lead to a speed loss\n" ] }, { "data": { "text/html": [ "\n", " " ], "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_episodes = 3\n", "video_filename = 'imageio.mp4'\n", "with imageio.get_writer(video_filename, fps=60) as video:\n", "  for _ in range(num_episodes):\n", "    time_step = eval_env.reset()\n", "    video.append_data(eval_py_env.render())\n", "    while not time_step.is_last():\n", "      action_step = tf_agent.policy.action(time_step)\n", "      time_step = eval_env.step(action_step.action)\n", "      video.append_data(eval_py_env.render())\n", "\n", "embed_mp4(video_filename)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "6_reinforce_tutorial.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }