{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "klGNgWREsvQv" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-03-09T12:47:09.993607Z", "iopub.status.busy": "2024-03-09T12:47:09.993014Z", "iopub.status.idle": "2024-03-09T12:47:09.996725Z", "shell.execute_reply": "2024-03-09T12:47:09.996144Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "oMaGpi7TciQs" }, "source": [ "# DQN C51/Rainbow" ] }, { "cell_type": "markdown", "metadata": { "id": "ZOUOQOrFs3zn" }, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": { "id": "cKOCZlhUgXVK" }, "source": [ "This example shows how to train a [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) agent on the Cartpole environment using the TF-Agents library.\n", "\n", "![Cartpole environment](https://github.com/tensorflow/agents/blob/master/docs/tutorials/images/cartpole.png?raw=1)\n", "\n", "Make sure you take a look through the [DQN tutorial](https://github.com/tensorflow/agents/blob/master/docs/tutorials/1_dqn_tutorial.ipynb) as a prerequisite. This tutorial will assume familiarity with the DQN tutorial; it will mainly focus on the differences between DQN and C51.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "lsaQlK8fFQqH" }, "source": [ "## Setup\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-NzBsZzPcyBm" }, "source": [ "If you haven't installed tf-agents yet, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:10.000701Z", "iopub.status.busy": "2024-03-09T12:47:10.000075Z", "iopub.status.idle": "2024-03-09T12:47:29.099292Z", "shell.execute_reply": "2024-03-09T12:47:29.098447Z" }, "id": "KEHR2Ui-lo8O" }, "outputs": [], "source": [ "!sudo apt-get update\n", "!sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n", "!pip install 'imageio==2.4.0'\n", "!pip install pyvirtualdisplay\n", "!pip install tf-agents\n", "!pip install pyglet\n", "!pip install tf-keras" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:29.103366Z", "iopub.status.busy": "2024-03-09T12:47:29.103101Z", "iopub.status.idle": "2024-03-09T12:47:29.106787Z", "shell.execute_reply": "2024-03-09T12:47:29.106118Z" }, "id": "WPuD0bMEY9Iz" }, "outputs": [], "source": [ "import os\n", "# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:29.109806Z", "iopub.status.busy": "2024-03-09T12:47:29.109441Z", "iopub.status.idle": "2024-03-09T12:47:32.547395Z", "shell.execute_reply": "2024-03-09T12:47:32.546038Z" }, "id": "sMitx5qSgJk1" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import base64\n", "import imageio\n", "import IPython\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import PIL.Image\n", "import pyvirtualdisplay\n", "\n", "import tensorflow as tf\n", "\n", 
"from tf_agents.agents.categorical_dqn import categorical_dqn_agent\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.eval import metric_utils\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.networks import categorical_q_network\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "from tf_agents.trajectories import trajectory\n", "from tf_agents.utils import common\n", "\n", "# Set up a virtual display for rendering OpenAI gym environments.\n", "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()" ] }, { "cell_type": "markdown", "metadata": { "id": "LmC0NDhdLIKY" }, "source": [ "## Hyperparameters" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:32.552900Z", "iopub.status.busy": "2024-03-09T12:47:32.551774Z", "iopub.status.idle": "2024-03-09T12:47:32.557770Z", "shell.execute_reply": "2024-03-09T12:47:32.557060Z" }, "id": "HC1kNrOsLSIZ" }, "outputs": [], "source": [ "env_name = \"CartPole-v1\" # @param {type:\"string\"}\n", "num_iterations = 15000 # @param {type:\"integer\"}\n", "\n", "initial_collect_steps = 1000 # @param {type:\"integer\"} \n", "collect_steps_per_iteration = 1 # @param {type:\"integer\"}\n", "replay_buffer_capacity = 100000 # @param {type:\"integer\"}\n", "\n", "fc_layer_params = (100,)\n", "\n", "batch_size = 64 # @param {type:\"integer\"}\n", "learning_rate = 1e-3 # @param {type:\"number\"}\n", "gamma = 0.99\n", "log_interval = 200 # @param {type:\"integer\"}\n", "\n", "num_atoms = 51 # @param {type:\"integer\"}\n", "min_q_value = -20 # @param {type:\"integer\"}\n", "max_q_value = 20 # @param {type:\"integer\"}\n", "n_step_update = 2 # @param {type:\"integer\"}\n", "\n", "num_eval_episodes = 10 # @param {type:\"integer\"}\n", 
"eval_interval = 1000 # @param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "VMsJC3DEgI0x" }, "source": [ "## Environment\n", "\n", "Load the environment as before, with one for training and one for evaluation. Here we use CartPole-v1 (vs. CartPole-v0 in the DQN tutorial), which has a larger max reward of 500 rather than 200." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:32.560829Z", "iopub.status.busy": "2024-03-09T12:47:32.560567Z", "iopub.status.idle": "2024-03-09T12:47:32.602216Z", "shell.execute_reply": "2024-03-09T12:47:32.601553Z" }, "id": "Xp-Y4mD6eDhF" }, "outputs": [], "source": [ "train_py_env = suite_gym.load(env_name)\n", "eval_py_env = suite_gym.load(env_name)\n", "\n", "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n", "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)" ] }, { "cell_type": "markdown", "metadata": { "id": "E9lW_OZYFR8A" }, "source": [ "## Agent\n", "\n", "C51 is a Q-learning algorithm based on DQN. Like DQN, it can be used on any environment with a discrete action space.\n", "\n", "The main difference between C51 and DQN is that rather than simply predicting the Q-value for each state-action pair, C51 predicts a histogram model for the probability distribution of the Q-value:\n", "\n", "![Example C51 Distribution](images/c51_distribution.png)\n", "\n", "By learning the distribution rather than simply the expected value, the algorithm is able to stay more stable during training, leading to improved final performance. This is particularly true in situations with bimodal or even multimodal value distributions, where a single average does not provide an accurate picture.\n", "\n", "In order to train on probability distributions rather than on values, C51 must perform some complex distributional computations in order to calculate its loss function. 
But don't worry, all of this is taken care of for you in TF-Agents!\n", "\n", "To create a C51 agent, we first need to create a `CategoricalQNetwork`. Its API is the same as that of `QNetwork`, except for an additional `num_atoms` argument, which sets the number of support points (atoms) in our probability distribution estimates. (The image above has 10 support points, each represented by a vertical blue bar.) As the name C51 suggests, the default number of atoms is 51.\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:32.605688Z", "iopub.status.busy": "2024-03-09T12:47:32.605217Z", "iopub.status.idle": "2024-03-09T12:47:32.633217Z", "shell.execute_reply": "2024-03-09T12:47:32.632572Z" }, "id": "TgkdEPg_muzV" }, "outputs": [], "source": [ "categorical_q_net = categorical_q_network.CategoricalQNetwork(\n", "    train_env.observation_spec(),\n", "    train_env.action_spec(),\n", "    num_atoms=num_atoms,\n", "    fc_layer_params=fc_layer_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "z62u55hSmviJ" }, "source": [ "We also need an `optimizer` to train the network we just created, and a `train_step_counter` variable to keep track of how many times the network has been updated.\n", "\n", "Another significant difference from the vanilla `DqnAgent` is that we now need to specify `min_q_value` and `max_q_value` as arguments. These set the two endpoints of the support, that is, the values of the outermost atoms on either side. Choose them to suit the range of returns in your particular environment; here we use -20 and 20."
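 ] }, { "cell_type": "markdown", "metadata": { "id": "c51SupportSketchIntro" }, "source": [ "To make the roles of `num_atoms`, `min_q_value`, and `max_q_value` more concrete, the next cell is a minimal NumPy sketch, not TF-Agents code, of how a categorical distribution over a fixed support can be turned into Q-values. It assumes the support is `num_atoms` evenly spaced atoms from `min_q_value` to `max_q_value`, restates the values -20, 20, and 51 so it runs on its own without modifying the notebook's variables, and feeds in made-up logits in place of real network outputs. The helper `expected_q_values` is hypothetical and only illustrates the idea." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c51SupportSketch" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Assumption: an evenly spaced support, restating min_q_value=-20,\n", "# max_q_value=20 and num_atoms=51 so this cell does not modify them.\n", "support = np.linspace(-20.0, 20.0, 51)\n", "atom_spacing = support[1] - support[0]  # (20 - (-20)) / (51 - 1) = 0.8\n", "\n", "\n", "def expected_q_values(logits, support):\n", "  # Hypothetical helper, not part of TF-Agents.\n", "  # logits: [num_actions, num_atoms]. A softmax over the atoms gives one\n", "  # probability distribution per action; its mean is that action's Q-value.\n", "  shifted = logits - logits.max(axis=-1, keepdims=True)  # Numerical stability.\n", "  probs = np.exp(shifted)\n", "  probs /= probs.sum(axis=-1, keepdims=True)\n", "  return (probs * support).sum(axis=-1)\n", "\n", "\n", "# Made-up logits for CartPole's two actions, standing in for network output.\n", "rng = np.random.default_rng(0)\n", "example_logits = rng.normal(size=(2, support.size))\n", "example_q_values = expected_q_values(example_logits, support)\n", "\n", "print('atom spacing:', atom_spacing)\n", "print('Q-values:', example_q_values)\n", "print('greedy action:', int(np.argmax(example_q_values)))" ] }, { "cell_type": "markdown", "metadata": { "id": "c51SupportSketchNote" }, "source": [ "In this sketch, acting greedily on the expected values gives the same kind of action selection as DQN; what C51 adds is learning the full distribution those expectations come from. The next cell creates the optimizer and the actual C51 agent."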
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:32.637145Z", "iopub.status.busy": "2024-03-09T12:47:32.636382Z", "iopub.status.idle": "2024-03-09T12:47:36.252323Z", "shell.execute_reply": "2024-03-09T12:47:36.251611Z" }, "id": "jbY4yrjTEyc9" }, "outputs": [], "source": [ "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n", "\n", "train_step_counter = tf.Variable(0)\n", "\n", "agent = categorical_dqn_agent.CategoricalDqnAgent(\n", " train_env.time_step_spec(),\n", " train_env.action_spec(),\n", " categorical_q_network=categorical_q_net,\n", " optimizer=optimizer,\n", " min_q_value=min_q_value,\n", " max_q_value=max_q_value,\n", " n_step_update=n_step_update,\n", " td_errors_loss_fn=common.element_wise_squared_loss,\n", " gamma=gamma,\n", " train_step_counter=train_step_counter)\n", "agent.initialize()" ] }, { "cell_type": "markdown", "metadata": { "id": "L7O7F_HqiQ1G" }, "source": [ "One last thing to note is that we also added an argument to use n-step updates with $n$ = 2. In single-step Q-learning ($n$ = 1), we only compute the error between the Q-values at the current time step and the next time step using the single-step return (based on the Bellman optimality equation). 
The single-step return is defined as:\n", "\n", "$G_t = R_{t + 1} + \\gamma V(s_{t + 1})$\n", "\n", "where we define $V(s) = \\max_a{Q(s, a)}$.\n", "\n", "N-step updates involve expanding the standard single-step return function $n$ times:\n", "\n", "$G_t^n = R_{t + 1} + \\gamma R_{t + 2} + \\gamma^2 R_{t + 3} + \\dots + \\gamma^{n - 1} R_{t + n} + \\gamma^n V(s_{t + n})$\n", "\n", "N-step updates enable the agent to bootstrap from further in the future, and with the right value of $n$, this often leads to faster learning.\n", "\n", "Although C51 and n-step updates are often combined with prioritized replay to form the core of the [Rainbow agent](https://arxiv.org/pdf/1710.02298.pdf), we saw no measurable improvement from implementing prioritized replay. Moreover, we found that when combining our C51 agent with n-step updates alone, our agent performed as well as other Rainbow agents on the sample of Atari environments we tested." ] }, { "cell_type": "markdown", "metadata": { "id": "94rCXQtbUbXv" }, "source": [ "## Metrics and Evaluation\n", "\n", "The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode, and we usually average this over a few episodes. We can compute the average return metric as follows.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:36.256573Z", "iopub.status.busy": "2024-03-09T12:47:36.256058Z", "iopub.status.idle": "2024-03-09T12:47:38.637806Z", "shell.execute_reply": "2024-03-09T12:47:38.637056Z" }, "id": "bitzHo5_UbXy" }, "outputs": [ { "data": { "text/plain": [ "37.7" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@test {\"skip\": true}\n", "def compute_avg_return(environment, policy, num_episodes=10):\n", "\n", "  total_return = 0.0\n", "  for _ in range(num_episodes):\n", "\n", "    time_step = environment.reset()\n", "    episode_return = 0.0\n", "\n", "    while not time_step.is_last():\n", "      action_step = policy.action(time_step)\n", "      time_step = environment.step(action_step.action)\n", "      episode_return += time_step.reward\n", "    total_return += episode_return\n", "\n", "  avg_return = total_return / num_episodes\n", "  return avg_return.numpy()[0]\n", "\n", "\n", "random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),\n", "                                                train_env.action_spec())\n", "\n", "compute_avg_return(eval_env, random_policy, num_eval_episodes)\n", "\n", "# Please also see the metrics module for standard implementations of different\n", "# metrics." ] }, { "cell_type": "markdown", "metadata": { "id": "NLva6g2jdWgr" }, "source": [ "## Data Collection\n", "\n", "As in the DQN tutorial, set up the replay buffer and the initial data collection with the random policy."
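, "\n",
 "\n",
 "Because the agent is trained with n-step updates, each sample drawn from the dataset in the next cell spans `n_step_update + 1` consecutive time steps. That is exactly what the n-step return $G_t^n$ defined earlier needs: the $n$ rewards $R_{t + 1}$ to $R_{t + n}$ plus the state $s_{t + n}$ to bootstrap from. The plain NumPy sketch below evaluates that return for one such window. It is an illustration only and not how TF-Agents computes its (distributional) targets internally. The function name `n_step_return` and the rewards, bootstrap value and discount are hypothetical example inputs.\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "def n_step_return(rewards, bootstrap_value, gamma):\n",
 "  # rewards holds R_{t+1}, ..., R_{t+n}; bootstrap_value stands in for V(s_{t+n}).\n",
 "  n = len(rewards)\n",
 "  discounts = gamma ** np.arange(n)  # 1, gamma, ..., gamma^(n-1)\n",
 "  return float(np.dot(discounts, rewards) + gamma ** n * bootstrap_value)\n",
 "\n",
 "# Hypothetical window with n = 2 and CartPole-style rewards of +1 per step.\n",
 "print(n_step_return([1.0, 1.0], bootstrap_value=10.0, gamma=0.99))\n",
 "\n",
 "# With n = 1 this reduces to the single-step return R_{t+1} + gamma * V(s_{t+1}).\n",
 "print(n_step_return([1.0], bootstrap_value=10.0, gamma=0.99))\n",
 "```"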
 ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:38.641633Z", "iopub.status.busy": "2024-03-09T12:47:38.640989Z", "iopub.status.idle": "2024-03-09T12:47:46.000946Z", "shell.execute_reply": "2024-03-09T12:47:46.000189Z" }, "id": "wr1KSAEGG4h9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:377: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `as_dataset(..., single_deterministic_pass=False) instead.\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "    data_spec=agent.collect_data_spec,\n", "    batch_size=train_env.batch_size,\n", "    max_length=replay_buffer_capacity)\n", "\n", "def collect_step(environment, policy):\n", "  time_step = environment.current_time_step()\n", "  action_step = policy.action(time_step)\n", "  next_time_step = environment.step(action_step.action)\n", "  traj = trajectory.from_transition(time_step, action_step, next_time_step)\n", "\n", "  # Add trajectory to the replay buffer\n", "  replay_buffer.add_batch(traj)\n", "\n", "for _ in range(initial_collect_steps):\n", "  collect_step(train_env, random_policy)\n", "\n", "# This loop is so common in RL that we provide standard implementations of\n", "# it. For more details, see the drivers module.\n", "\n", "# Dataset generates trajectories with shape [BxTx...] where\n", "# T = n_step_update + 1.\n", "dataset = replay_buffer.as_dataset(\n", "    num_parallel_calls=3, sample_batch_size=batch_size,\n", "    num_steps=n_step_update + 1).prefetch(3)\n", "\n", "iterator = iter(dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "hBc9lj9VWWtZ" }, "source": [ "## Training the agent\n", "\n", "The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will occasionally evaluate the agent's policy to see how we are doing.\n", "\n", "The following will take ~7 minutes to run." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:47:46.004870Z", "iopub.status.busy": "2024-03-09T12:47:46.004589Z", "iopub.status.idle": "2024-03-09T12:59:36.892398Z", "shell.execute_reply": "2024-03-09T12:59:36.891625Z" }, "id": "0pTbJ3PeyF-u" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 200: loss = 3.2159409523010254\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 400: loss = 2.422974109649658\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 600: loss = 1.9803032875061035\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 800: loss = 1.733839750289917\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1000: loss = 1.705157995223999\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1000: Average Return = 88.60\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1200: loss = 1.655350923538208\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1400: loss = 1.419114351272583\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1600: loss = 1.2578476667404175\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 1800: loss = 1.3189895153045654\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2000: loss = 0.9676651954650879\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2000: Average Return = 130.80\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2200: loss = 0.7909003496170044\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2400: loss = 0.9291537404060364\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2600: loss = 0.8300429582595825\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 2800: loss = 0.9739845991134644\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3000: loss = 0.5435967445373535\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"step = 3000: Average Return = 261.40\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3200: loss = 0.7065144777297974\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3400: loss = 0.8492055535316467\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3600: loss = 0.808651864528656\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 3800: loss = 0.48259130120277405\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4000: loss = 0.9187874794006348\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4000: Average Return = 280.90\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4200: loss = 0.7415772676467896\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4400: loss = 0.621947169303894\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4600: loss = 0.5226543545722961\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 4800: loss = 0.7011302709579468\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5000: loss = 0.7732619047164917\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5000: Average Return = 271.70\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5200: loss = 0.8493011593818665\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5400: loss = 0.6786139011383057\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5600: loss = 0.5639233589172363\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 5800: loss = 0.48468759655952454\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6000: loss = 0.6366198062896729\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6000: Average Return = 350.70\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6200: loss = 0.4855012893676758\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "step = 6400: loss = 0.4458327889442444\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6600: loss = 0.6745614409446716\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 6800: loss = 0.5021890997886658\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7000: loss = 0.4639193117618561\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7000: Average Return = 343.00\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7200: loss = 0.4711253345012665\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7400: loss = 0.5891958475112915\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7600: loss = 0.3957907557487488\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 7800: loss = 0.4868921637535095\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8000: loss = 0.5140666365623474\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8000: Average Return = 396.10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8200: loss = 0.6051771640777588\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8400: loss = 0.6179391741752625\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8600: loss = 0.5253893733024597\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 8800: loss = 0.3697047531604767\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9000: loss = 0.7271263599395752\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9000: Average Return = 320.20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9200: loss = 0.5285177826881409\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9400: loss = 0.4590812921524048\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "step = 9600: loss = 0.4743385910987854\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 9800: loss = 0.47938746213912964\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10000: loss = 0.5290409326553345\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10000: Average Return = 433.00\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10200: loss = 0.4573556184768677\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10400: loss = 0.352144718170166\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10600: loss = 0.39160820841789246\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 10800: loss = 0.3254846930503845\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11000: loss = 0.37145161628723145\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11000: Average Return = 414.60\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11200: loss = 0.382583349943161\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11400: loss = 0.44465434551239014\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11600: loss = 0.4484185576438904\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 11800: loss = 0.248131662607193\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12000: loss = 0.5516679883003235\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12000: Average Return = 375.40\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12200: loss = 0.3307253420352936\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12400: loss = 0.19486135244369507\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12600: loss = 0.31668007373809814\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 12800: 
loss = 0.4462052285671234\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13000: loss = 0.241848886013031\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13000: Average Return = 326.80\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13200: loss = 0.20919030904769897\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13400: loss = 0.2044396996498108\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13600: loss = 0.428558886051178\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 13800: loss = 0.1880824714899063\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14000: loss = 0.34256821870803833\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14000: Average Return = 345.50\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14200: loss = 0.22452744841575623\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14400: loss = 0.29694461822509766\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14600: loss = 0.4149337410926819\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 14800: loss = 0.41922691464424133\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15000: loss = 0.4064670205116272\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step = 15000: Average Return = 242.10\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "try:\n", " %%time\n", "except:\n", " pass\n", "\n", "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n", "agent.train = common.function(agent.train)\n", "\n", "# Reset the train step\n", "agent.train_step_counter.assign(0)\n", "\n", "# Evaluate the agent's policy once before training.\n", "avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", "returns = [avg_return]\n", "\n", "for _ in 
range(num_iterations):\n", "\n", "  # Collect a few steps using collect_policy and save to the replay buffer.\n", "  for _ in range(collect_steps_per_iteration):\n", "    collect_step(train_env, agent.collect_policy)\n", "\n", "  # Sample a batch of data from the buffer and update the agent's network.\n", "  experience, unused_info = next(iterator)\n", "  train_loss = agent.train(experience)\n", "\n", "  step = agent.train_step_counter.numpy()\n", "\n", "  if step % log_interval == 0:\n", "    print('step = {0}: loss = {1}'.format(step, train_loss.loss))\n", "\n", "  if step % eval_interval == 0:\n", "    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n", "    print('step = {0}: Average Return = {1:.2f}'.format(step, avg_return))\n", "    returns.append(avg_return)" ] }, { "cell_type": "markdown", "metadata": { "id": "68jNcA_TiJDq" }, "source": [ "## Visualization\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aO-LWCdbbOIC" }, "source": [ "### Plots\n", "\n", "We can plot the average return against the training step to see how our agent's performance evolves. In `CartPole-v1`, the environment gives a reward of +1 for every time step the pole stays up, and since an episode lasts at most 500 steps, the maximum possible return is also 500." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:59:36.896462Z", "iopub.status.busy": "2024-03-09T12:59:36.896172Z", "iopub.status.idle": "2024-03-09T12:59:37.121197Z", "shell.execute_reply": "2024-03-09T12:59:37.120473Z" }, "id": "NxtL1mbOYCVO" }, "outputs": [ { "data": { "text/plain": [ "(-11.255000400543214, 550.0)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#@test {\"skip\": true}\n", "\n", "steps = range(0, num_iterations + 1, eval_interval)\n", "plt.plot(steps, returns)\n", "plt.ylabel('Average Return')\n", "plt.xlabel('Step')\n", "plt.ylim(top=550)" ] }, { "cell_type": "markdown", "metadata": { "id": "M7-XpPP99Cy7" }, "source": [ "### Videos" ] }, { "cell_type": "markdown", "metadata": { "id": "9pGfGxSH32gn" }, "source": [ "It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:59:37.124663Z", "iopub.status.busy": "2024-03-09T12:59:37.124398Z", "iopub.status.idle": "2024-03-09T12:59:37.128794Z", "shell.execute_reply": "2024-03-09T12:59:37.128102Z" }, "id": "ULaGr8pvOKbl" }, "outputs": [], "source": [ "def embed_mp4(filename):\n", "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n", "  video = open(filename,'rb').read()\n", "  b64 = base64.b64encode(video)\n", "  tag = '''\n", "  <video width=\"640\" height=\"480\" controls>\n", "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n", "  Your browser does not support the video tag.\n", "  </video>'''.format(b64.decode())\n", "\n", "  return IPython.display.HTML(tag)" ] }, { "cell_type": "markdown", "metadata": { "id": "9c_PH-pX4Pr5" }, "source": [ "The following code visualizes the agent's policy for a few episodes:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:59:37.132462Z", "iopub.status.busy": "2024-03-09T12:59:37.131935Z", "iopub.status.idle": "2024-03-09T12:59:44.221913Z", "shell.execute_reply": "2024-03-09T12:59:44.221017Z" }, "id": "owOVWB158NlF" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[swscaler @ 0x55f48c41d880] Warning: data is not aligned! This can lead to a speed loss\n" ] }, { "data": { "text/html": [ "\n", " " ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_episodes = 3\n", "video_filename = 'imageio.mp4'\n", "with imageio.get_writer(video_filename, fps=60) as video:\n", "  for _ in range(num_episodes):\n", "    time_step = eval_env.reset()\n", "    video.append_data(eval_py_env.render())\n", "    while not time_step.is_last():\n", "      action_step = agent.policy.action(time_step)\n", "      time_step = eval_env.step(action_step.action)\n", "      video.append_data(eval_py_env.render())\n", "\n", "embed_mp4(video_filename)" ] }, { "cell_type": "markdown", "metadata": { "id": "exziB27hY8ia" }, "source": [ "C51 tends to do slightly better than DQN on CartPole-v1, but the difference between the two agents becomes more and more significant in increasingly complex environments. For example, on the full Atari 2600 benchmark, C51 demonstrates a mean score improvement of 126% over DQN after normalizing with respect to a random agent. Additional improvements can be gained by including n-step updates.\n", "\n", "For a deeper dive into the C51 algorithm, see [A Distributional Perspective on Reinforcement Learning (2017)](https://arxiv.org/pdf/1707.06887.pdf)."
] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "DQN C51/Rainbow Tutorial.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 0 }