{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "beObUOFyuRjT" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-03-09T12:16:56.658634Z", "iopub.status.busy": "2024-03-09T12:16:56.658414Z", "iopub.status.idle": "2024-03-09T12:16:56.662333Z", "shell.execute_reply": "2024-03-09T12:16:56.661748Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "s6D70EeAZe-Q" }, "source": [ "# Drivers\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " View on TensorFlow.org\n", " \n", " \n", " \n", " Run in Google Colab\n", " \n", " \n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "8aPHF9kXFggA" }, "source": [ "## Introduction\n", "\n", "A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, evaluation and generating a video of the agent.\n", "\n", "While this is relatively straightforward to write in python, it is much more complex to write and debug in TensorFlow because it involves `tf.while` loops, `tf.cond` and `tf.control_dependencies`. Therefore we abstract this notion of a run loop into a class called `driver`, and provide well tested implementations both in Python and TensorFlow.\n", "\n", "Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, the reward obtained, the type of the current and the next step, etc." ] }, { "cell_type": "markdown", "metadata": { "id": "t7PM1QfMZqkS" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "0w-Ykwl1bn4v" }, "source": [ "If you haven't installed tf-agents or gym yet, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:16:56.666083Z", "iopub.status.busy": "2024-03-09T12:16:56.665499Z", "iopub.status.idle": "2024-03-09T12:17:06.265574Z", "shell.execute_reply": "2024-03-09T12:17:06.264652Z" }, "id": "TnE2CgilrngG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-agents\r\n", " Using cached tf_agents-0.19.0-py3-none-any.whl.metadata (12 kB)\r\n", "Requirement already satisfied: absl-py>=0.6.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.4.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting cloudpickle>=1.3 (from tf-agents)\r\n", " Using cached cloudpickle-3.0.0-py3-none-any.whl.metadata (7.0 kB)\r\n", "Collecting gin-config>=0.4.0 (from tf-agents)\r\n", " Using cached gin_config-0.5.0-py3-none-any.whl.metadata (2.9 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gym<=0.23.0,>=0.17.0 (from tf-agents)\r\n", " Using cached gym-0.23.0-py3-none-any.whl\r\n", "Requirement already satisfied: numpy>=1.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.26.4)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (10.2.0)\r\n", "Requirement already satisfied: six>=1.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.16.0)\r\n", "Requirement already satisfied: protobuf>=3.11.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (3.20.3)\r\n", "Requirement already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.16.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting typing-extensions==4.5.0 (from tf-agents)\r\n", " Using cached typing_extensions-4.5.0-py3-none-any.whl.metadata (8.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pygame==2.1.3 (from tf-agents)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.3 kB)\r\n", "Collecting tensorflow-probability~=0.23.0 (from 
{ "cell_type": "markdown", "metadata": { "id": "t7PM1QfMZqkS" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "0w-Ykwl1bn4v" }, "source": [ "If you haven't installed tf-agents or gym yet, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:16:56.666083Z", "iopub.status.busy": "2024-03-09T12:16:56.665499Z", "iopub.status.idle": "2024-03-09T12:17:06.265574Z", "shell.execute_reply": "2024-03-09T12:17:06.264652Z" }, "id": "TnE2CgilrngG" }, "outputs": [], "source": [ "!pip install tf-agents\n", "!pip install tf-keras" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:17:06.269802Z", "iopub.status.busy": "2024-03-09T12:17:06.269529Z", "iopub.status.idle": "2024-03-09T12:17:06.273225Z", "shell.execute_reply": "2024-03-09T12:17:06.272607Z" }, "id": "WPuD0bMEY9Iz" }, "outputs": [], "source": [ "import os\n", "# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:17:06.276528Z", "iopub.status.busy": "2024-03-09T12:17:06.275987Z", "iopub.status.idle": "2024-03-09T12:17:09.138313Z",
"shell.execute_reply": "2024-03-09T12:17:09.137186Z" }, "id": "whYNP894FSkA" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import tensorflow as tf\n", "\n", "\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.policies import random_py_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.metrics import py_metrics\n", "from tf_agents.metrics import tf_metrics\n", "from tf_agents.drivers import py_driver\n", "from tf_agents.drivers import dynamic_episode_driver" ] }, { "cell_type": "markdown", "metadata": { "id": "9V7DEcB8IeiQ" }, "source": [ "## Python Drivers\n", "\n", "The `PyDriver` class takes a python environment, a python policy and a list of observers to update at each step. The main method is `run()`, which steps the environment using actions from the policy until at least one of the following termination criteria is met: The number of steps reaches `max_steps` or the number of episodes reaches `max_episodes`.\n", "\n", "The implementation is roughly as follows:\n", "\n", "\n", "```python\n", "class PyDriver(object):\n", "\n", " def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):\n", " self._env = env\n", " self._policy = policy\n", " self._observers = observers or []\n", " self._max_steps = max_steps or np.inf\n", " self._max_episodes = max_episodes or np.inf\n", "\n", " def run(self, time_step, policy_state=()):\n", " num_steps = 0\n", " num_episodes = 0\n", " while num_steps < self._max_steps and num_episodes < self._max_episodes:\n", "\n", " # Compute an action using the policy for the given time_step\n", " action_step = self._policy.action(time_step, policy_state)\n", "\n", " # Apply the action to the environment and get the next step\n", " next_time_step = self._env.step(action_step.action)\n", "\n", " # Package information into a trajectory\n", " traj = trajectory.Trajectory(\n", " time_step.step_type,\n", " time_step.observation,\n", " action_step.action,\n", " action_step.info,\n", " next_time_step.step_type,\n", " next_time_step.reward,\n", " next_time_step.discount)\n", "\n", " for observer in self._observers:\n", " observer(traj)\n", "\n", " # Update statistics to check termination\n", " num_episodes += np.sum(traj.is_last())\n", " num_steps += np.sum(~traj.is_boundary())\n", "\n", " time_step = next_time_step\n", " policy_state = action_step.state\n", "\n", " return time_step, policy_state\n", "\n", "```\n", "\n", "Now, let us run through the example of running a random policy on the CartPole environment, saving the results to a replay buffer and computing some metrics." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:17:09.142835Z", "iopub.status.busy": "2024-03-09T12:17:09.142341Z", "iopub.status.idle": "2024-03-09T12:17:09.383343Z", "shell.execute_reply": "2024-03-09T12:17:09.382457Z" }, "id": "Dj4_-77_5ExP" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Replay Buffer:\n", "Trajectory(\n", "{'step_type': array(0, dtype=int32),\n", " 'observation': array([ 0.00374074, -0.02818722, -0.02798625, -0.0196638 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([ 0.00317699, 0.16732468, -0.02837953, -0.3210437 ], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([ 0.00652349, -0.02738187, -0.0348004 , -0.03744393], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([ 0.00597585, -0.22198795, -0.03554928, 0.24405919], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([ 0.00153609, -0.41658458, -0.0306681 , 0.5253204 ], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.0067956 , -0.61126184, -0.02016169, 0.80818397], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.01902084, -0.8061018 , -0.00399801, 1.0944574 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.03514287, -0.6109274 , 0.01789114, 0.8005227 ], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.04736142, -0.8062901 , 0.03390159, 1.0987796 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': 
array([-0.06348722, -0.61163044, 0.05587719, 0.816923 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.07571983, -0.41731614, 0.07221565, 0.54232585], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.08406615, -0.61337477, 0.08306216, 0.8568603 ], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.09633365, -0.8095243 , 0.10019937, 1.1744623 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.11252414, -0.6158369 , 0.12368862, 0.91479784], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.12484087, -0.8123951 , 0.14198457, 1.2436544 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.14108877, -0.61935145, 0.16685766, 0.9986062 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.1534758 , -0.42680538, 0.18682979, 0.7626272 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(1, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(1, dtype=int32),\n", " 'observation': array([-0.1620119 , -0.23468053, 0.20208232, 0.5340639 ], dtype=float32),\n", " 'action': array(0),\n", " 'policy_info': (),\n", " 'next_step_type': array(2, dtype=int32),\n", " 'reward': array(1., dtype=float32),\n", " 'discount': array(0., dtype=float32)})\n", "Trajectory(\n", "{'step_type': array(2, dtype=int32),\n", " 'observation': array([-0.16670552, -0.43198496, 0.21276361, 0.8830067 ], dtype=float32),\n", " 'action': array(1),\n", " 'policy_info': (),\n", " 'next_step_type': array(0, dtype=int32),\n", " 'reward': array(0., dtype=float32),\n", " 'discount': array(1., dtype=float32)})\n", "Average Return: 18.0\n" ] } ], "source": [ "env = suite_gym.load('CartPole-v0')\n", "policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(), \n", " action_spec=env.action_spec())\n", "replay_buffer = []\n", "metric = 
py_metrics.AverageReturnMetric()\n", "observers = [replay_buffer.append, metric]\n", "driver = py_driver.PyDriver(\n", "    env, policy, observers, max_steps=20, max_episodes=1)\n", "\n", "initial_time_step = env.reset()\n", "final_time_step, _ = driver.run(initial_time_step)\n", "\n", "print('Replay Buffer:')\n", "for traj in replay_buffer:\n", "  print(traj)\n", "\n", "print('Average Return: ', metric.result())" ] }, { "cell_type": "markdown", "metadata": { "id": "X3Yrxg36Ik1x" }, "source": [ "## TensorFlow Drivers\n", "\n", "We also have drivers in TensorFlow, which are functionally similar to Python drivers but use TF environments, TF policies, TF observers, etc. We currently have two TensorFlow drivers: `DynamicStepDriver`, which terminates after a given number of (valid) environment steps, and `DynamicEpisodeDriver`, which terminates after a given number of episodes. Let us look at an example of `DynamicEpisodeDriver` in action; a short `DynamicStepDriver` sketch follows at the end of this notebook.\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:17:09.386873Z", "iopub.status.busy": "2024-03-09T12:17:09.386599Z", "iopub.status.idle": "2024-03-09T12:17:13.134232Z", "shell.execute_reply": "2024-03-09T12:17:13.133386Z" }, "id": "WC4ba3ObSceA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "final_time_step TimeStep(\n", "{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=...>,\n", " 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=...>,\n", " 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=...>,\n", " 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=...>})\n", "Number of Steps: 34\n", "Number of Episodes: 2\n" ] } ], "source": [ "env = suite_gym.load('CartPole-v0')\n", "tf_env = tf_py_environment.TFPyEnvironment(env)\n", "\n", "tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),\n", "                                            time_step_spec=tf_env.time_step_spec())\n", "\n", "num_episodes = tf_metrics.NumberOfEpisodes()\n", "env_steps = tf_metrics.EnvironmentSteps()\n", "observers = [num_episodes, env_steps]\n", "driver = dynamic_episode_driver.DynamicEpisodeDriver(\n", "    tf_env, tf_policy, observers, num_episodes=2)\n", "\n", "# The initial driver.run() call resets the environment and initializes the policy.\n", "final_time_step, policy_state = driver.run()\n", "\n", "print('final_time_step', final_time_step)\n", "print('Number of Steps: ', env_steps.result().numpy())\n", "print('Number of Episodes: ', num_episodes.result().numpy())" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:17:13.137797Z", "iopub.status.busy": "2024-03-09T12:17:13.137522Z", "iopub.status.idle": "2024-03-09T12:17:13.366195Z", "shell.execute_reply": "2024-03-09T12:17:13.365483Z" }, "id": "Sz5jhHnU0fX1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "final_time_step TimeStep(\n", "{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=...>,\n", " 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=...>,\n", " 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=...>,\n", " 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=...>})\n", "Number of Steps: 63\n", "Number of Episodes: 4\n" ] } ], "source": [ "# Continue running from the previous state.\n", "final_time_step, _ = driver.run(final_time_step, policy_state)\n", "\n", "print('final_time_step', final_time_step)\n", "print('Number of Steps: ', env_steps.result().numpy())\n", "print('Number of Episodes: ', num_episodes.result().numpy())" ] }
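, { "cell_type": "markdown", "metadata": {}, "source": [ "For completeness, here is a minimal sketch of `DynamicStepDriver`, reusing `tf_env` and `tf_policy` from the cells above. It is not executed here, and the step budget `num_steps=100` is just an illustration:\n", "\n", "```python\n", "from tf_agents.drivers import dynamic_step_driver\n", "\n", "env_steps = tf_metrics.EnvironmentSteps()\n", "step_driver = dynamic_step_driver.DynamicStepDriver(\n", "    tf_env, tf_policy, observers=[env_steps], num_steps=100)\n", "\n", "# run() steps the environment until 100 valid steps have been collected.\n", "final_time_step, policy_state = step_driver.run()\n", "print('Number of Steps: ', env_steps.result().numpy())\n", "```" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "TF-Agents Drivers Tutorial.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python",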
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }