{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "beObUOFyuRjT" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-03-09T12:22:08.960696Z", "iopub.status.busy": "2024-03-09T12:22:08.960140Z", "iopub.status.idle": "2024-03-09T12:22:08.963690Z", "shell.execute_reply": "2024-03-09T12:22:08.963138Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "eutDVTs9aJEL" }, "source": [ "# Replay Buffers" ] },
{ "cell_type": "markdown", "metadata": { "id": "8aPHF9kXFggA" }, "source": [ "## Introduction\n", "\n", "Reinforcement learning algorithms use replay buffers to store trajectories of experience when executing a policy in an environment. During training, replay buffers are queried for a subset of the trajectories (either a sequential subset or a sample) to \"replay\" the agent's experience.\n", "\n", "In this colab, we explore two types of replay buffers, Python-backed and TensorFlow-backed, which share a common API. In the following sections, we describe the API, each of the buffer implementations, and how to use them during data collection and training.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "1uSlqYgvaG9b" }, "source": [ "## Setup" ] },
{ "cell_type": "markdown", "metadata": { "id": "GztmUpWKZ7kq" }, "source": [ "Install tf-agents if you haven't already." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:08.967454Z", "iopub.status.busy": "2024-03-09T12:22:08.966927Z", "iopub.status.idle": "2024-03-09T12:22:18.588872Z", "shell.execute_reply": "2024-03-09T12:22:18.588019Z" }, "id": "TnE2CgilrngG" }, "outputs": [], "source": [ "!pip install tf-agents\n", "!pip install tf-keras" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:18.593245Z", "iopub.status.busy": "2024-03-09T12:22:18.592981Z", "iopub.status.idle": "2024-03-09T12:22:18.596589Z", "shell.execute_reply": "2024-03-09T12:22:18.596042Z" }, "id": "WPuD0bMEY9Iz" }, "outputs": [], "source": [ "import os\n", "# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:18.599875Z", "iopub.status.busy": "2024-03-09T12:22:18.599258Z", "iopub.status.idle": "2024-03-09T12:22:21.475333Z", "shell.execute_reply": "2024-03-09T12:22:21.474561Z" }, "id": "whYNP894FSkA" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import tensorflow as tf\n", "import numpy as np\n", "\n", "from tf_agents import specs\n", "from tf_agents.agents.dqn import dqn_agent\n", "from tf_agents.drivers import dynamic_step_driver\n", "from tf_agents.environments import suite_gym\n", "from tf_agents.environments import tf_py_environment\n", "from tf_agents.networks import q_network\n", "from tf_agents.replay_buffers import py_uniform_replay_buffer\n", "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.trajectories import time_step" ] },
{ "cell_type": "markdown", "metadata": { "id": "xcQWclL9FpZl" }, "source": [ "## Replay Buffer API\n", "\n", "The `ReplayBuffer` class has the following definition and methods:\n", "\n", "```python\n", "class ReplayBuffer(tf.Module):\n", "  \"\"\"Abstract base class for TF-Agents replay buffer.\"\"\"\n", "\n", "  def __init__(self, data_spec, capacity):\n", "    \"\"\"Initializes the replay buffer.\n", "\n", "    Args:\n", "      data_spec: A spec or a list/tuple/nest of specs describing\n", "        a single item that can be stored in this buffer\n", "      capacity: number of elements that the replay buffer can hold.\n", "    \"\"\"\n", "\n", "  @property\n", "  def data_spec(self):\n", "    \"\"\"Returns the spec for items in the replay buffer.\"\"\"\n", "\n", "  @property\n", "  def capacity(self):\n", "    \"\"\"Returns the capacity of the replay buffer.\"\"\"\n", "\n", "  def add_batch(self, items):\n", "    \"\"\"Adds a batch of items to the replay buffer.\"\"\"\n", "\n", "  def get_next(self,\n", "               sample_batch_size=None,\n", "               num_steps=None,\n", "               time_stacked=True):\n", "    \"\"\"Returns an item or batch of items from the buffer.\"\"\"\n", "\n", "  def as_dataset(self,\n", "                 sample_batch_size=None,\n", "                 num_steps=None,\n", "                 num_parallel_calls=None):\n", "    \"\"\"Creates and returns a dataset that returns entries from the buffer.\"\"\"\n", "\n", "  def gather_all(self):\n", "    \"\"\"Returns all the items in buffer.\"\"\"\n", "    return self._gather_all()\n", "\n", "  def clear(self):\n", "    \"\"\"Resets the contents of replay buffer.\"\"\"\n", "\n", "```\n", "\n", "Note that when the replay buffer object is initialized, it requires the `data_spec` of the elements that it will store. This spec corresponds to the `TensorSpec` of trajectory elements that will be added to the buffer. It is usually obtained from an agent's `agent.collect_data_spec`, which defines the shapes, types, and structures expected by the agent when training (more on that later)." ] },
{ "cell_type": "markdown", "metadata": { "id": "X3Yrxg36Ik1x" }, "source": [ "## TFUniformReplayBuffer\n", "\n", "`TFUniformReplayBuffer` is the most commonly used replay buffer in TF-Agents, so we use it in this tutorial. In `TFUniformReplayBuffer`, the backing storage is held in TensorFlow variables and is therefore part of the compute graph.\n", "\n", "The buffer stores batches of elements and has a maximum capacity of `max_length` elements per batch segment. Thus, the total buffer capacity is `batch_size` x `max_length` elements. The elements stored in the buffer must all have a matching data spec. When the replay buffer is used for data collection, the spec is the agent's collect data spec.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "lYk-bn2taXlw" }, "source": [ "### Creating the buffer:\n", "To create a `TFUniformReplayBuffer` we pass in:\n", "1. the spec of the data elements that the buffer will store\n", "2. the `batch_size` of the buffer\n", "3. the `max_length` number of elements per batch segment\n", "\n", "Here is an example of creating a `TFUniformReplayBuffer` with sample data specs, `batch_size` 32 and `max_length` 1000."
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:21.480205Z", "iopub.status.busy": "2024-03-09T12:22:21.479277Z", "iopub.status.idle": "2024-03-09T12:22:24.287402Z", "shell.execute_reply": "2024-03-09T12:22:24.286565Z" }, "id": "Dj4_-77_5ExP" }, "outputs": [], "source": [ "data_spec = (\n", "    tf.TensorSpec([3], tf.float32, 'action'),\n", "    (\n", "        tf.TensorSpec([5], tf.float32, 'lidar'),\n", "        tf.TensorSpec([3, 2], tf.float32, 'camera')\n", "    )\n", ")\n", "\n", "batch_size = 32\n", "max_length = 1000\n", "\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "    data_spec,\n", "    batch_size=batch_size,\n", "    max_length=max_length)" ] },
{ "cell_type": "markdown", "metadata": { "id": "XB8rOw5ATDD2" }, "source": [ "### Writing to the buffer:\n", "To add elements to the replay buffer, we use the `add_batch(items)` method where `items` is a list/tuple/nest of tensors representing the batch of items to be added to the buffer. Each element of `items` must have an outer dimension equal to `batch_size`, and the remaining dimensions must adhere to the data spec of the item (same as the data specs passed to the replay buffer constructor).\n", "\n", "Here's an example of adding a batch of items:\n" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:24.290892Z", "iopub.status.busy": "2024-03-09T12:22:24.290617Z", "iopub.status.idle": "2024-03-09T12:22:24.317386Z", "shell.execute_reply": "2024-03-09T12:22:24.316811Z" }, "id": "nOvkp4vJhBOT" }, "outputs": [], "source": [ "# Build one constant-valued item per spec (action=1, lidar=2, camera=3).\n", "action = tf.constant(1 * np.ones(\n", "    data_spec[0].shape.as_list(), dtype=np.float32))\n", "lidar = tf.constant(\n", "    2 * np.ones(data_spec[1][0].shape.as_list(), dtype=np.float32))\n", "camera = tf.constant(\n", "    3 * np.ones(data_spec[1][1].shape.as_list(), dtype=np.float32))\n", "\n", "# Stack each item `batch_size` times so its outer dimension matches the buffer.\n", "values = (action, (lidar, camera))\n", "values_batched = tf.nest.map_structure(lambda t: tf.stack([t] * batch_size),\n", "                                       values)\n", "\n", "replay_buffer.add_batch(values_batched)" ] },
{ "cell_type": "markdown", "metadata": { "id": "smnVAxHghKly" }, "source": [ "### Reading from the buffer\n", "\n", "There are three ways to read data from the `TFUniformReplayBuffer`:\n", "\n", "1. `get_next()` - returns one sample from the buffer. The sample batch size and number of timesteps returned can be specified via arguments to this method. This method is deprecated in favor of `as_dataset(..., single_deterministic_pass=False)`.\n", "2. `as_dataset()` - returns the replay buffer as a `tf.data.Dataset`. One can then create a dataset iterator and iterate through the samples of the items in the buffer.\n", "3. 
`gather_all()` - returns all the items in the buffer as a Tensor with shape `[batch, time, data_spec]`. This method is deprecated in favor of `as_dataset(..., single_deterministic_pass=True)`.\n", "\n", "Below are examples of how to read from the replay buffer using each of these methods:" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:24.320530Z", "iopub.status.busy": "2024-03-09T12:22:24.320269Z", "iopub.status.idle": "2024-03-09T12:22:25.167346Z", "shell.execute_reply": "2024-03-09T12:22:25.166717Z" }, "id": "IlQ1eGhohM3M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_27300/1348928897.py:7: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `as_dataset(..., single_deterministic_pass=False) instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iterator trajectories:\n", "[(TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2])))]\n", "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_27300/1348928897.py:24: ReplayBuffer.gather_all (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `as_dataset(..., single_deterministic_pass=True)` instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trajectories from gather all:\n", "(TensorShape([32, 6, 3]), (TensorShape([32, 6, 5]), TensorShape([32, 6, 3, 2])))\n" ] } ], "source": [ "# Add more items to the buffer before reading.\n", "for _ in range(5):\n", "  replay_buffer.add_batch(values_batched)\n", "\n", "# Get one sample from the replay buffer with batch size 10 and 1 timestep:\n", "\n", "sample = replay_buffer.get_next(sample_batch_size=10, num_steps=1)\n", "\n", "# Convert the replay buffer to a tf.data.Dataset and iterate through it.\n", "dataset = replay_buffer.as_dataset(\n", "    sample_batch_size=4,\n", "    num_steps=2)\n", "\n", "iterator = iter(dataset)\n", "print(\"Iterator trajectories:\")\n", "trajectories = []\n", "for _ in range(3):\n", "  t, _ = next(iterator)\n", "  trajectories.append(t)\n", "\n", "print(tf.nest.map_structure(lambda t: t.shape, trajectories))\n", "\n", "# Read all elements in the replay buffer:\n", "trajectories = replay_buffer.gather_all()\n", "\n", "print(\"Trajectories from gather all:\")\n", "print(tf.nest.map_structure(lambda t: t.shape, trajectories))\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "BcS49HrNF34W" }, "source": [ "## PyUniformReplayBuffer\n", "`PyUniformReplayBuffer` has the same functionality as the `TFUniformReplayBuffer`, but instead of TF variables, its data is stored in NumPy arrays. This buffer can be used for out-of-graph data collection. Having the backing storage in NumPy may make it easier for some applications to do data manipulation (such as indexing for updating priorities) without using TensorFlow variables. However, this implementation won't have the benefit of graph optimizations with TensorFlow.\n", "\n", "Below is an example of instantiating a `PyUniformReplayBuffer` from the sample data specs defined above, converted to array specs:" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:25.171070Z", "iopub.status.busy": "2024-03-09T12:22:25.170396Z", "iopub.status.idle": "2024-03-09T12:22:25.174996Z", "shell.execute_reply": "2024-03-09T12:22:25.174413Z" }, "id": "F4neLPpL25wI" }, "outputs": [], "source": [ "replay_buffer_capacity = 1000 * 32  # same capacity as the TFUniformReplayBuffer\n", "\n", "py_replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(\n", "    capacity=replay_buffer_capacity,\n", "    data_spec=tensor_spec.to_nest_array_spec(data_spec))" ] },
{ "cell_type": "markdown", "metadata": { "id": "9V7DEcB8IeiQ" }, "source": [ "## Using replay buffers during training\n", "Now that we know how to create a replay buffer, write items to it, and read from it, we can use it to store trajectories while training our agents.\n", "\n", "### Data collection\n", "First, let's look at how to use the replay buffer during data collection.\n", "\n", "In TF-Agents we use a `Driver` (see the Driver tutorial for more details) to collect experience in an environment. To use a `Driver`, we specify an `Observer`, a function that the `Driver` executes whenever it receives a trajectory.\n", "\n", "Thus, to add trajectory elements to the replay buffer, we add an observer that calls `add_batch(items)` to add a batch of items to the replay buffer.\n", "\n", "Below is an example of this with `TFUniformReplayBuffer`. We first create an environment, a network and an agent. Then we create a `TFUniformReplayBuffer`. Note that the specs of the trajectory elements in the replay buffer are equal to the agent's collect data spec. 
We then set its `add_batch` method as the observer for the driver that will do the data collection during our training:\n" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:25.178210Z", "iopub.status.busy": "2024-03-09T12:22:25.177610Z", "iopub.status.idle": "2024-03-09T12:22:26.425216Z", "shell.execute_reply": "2024-03-09T12:22:26.424485Z" }, "id": "pCbTDO3Z5UCS" }, "outputs": [], "source": [ "env = suite_gym.load('CartPole-v0')\n", "tf_env = tf_py_environment.TFPyEnvironment(env)\n", "\n", "q_net = q_network.QNetwork(\n", "    tf_env.time_step_spec().observation,\n", "    tf_env.action_spec(),\n", "    fc_layer_params=(100,))\n", "\n", "agent = dqn_agent.DqnAgent(\n", "    tf_env.time_step_spec(),\n", "    tf_env.action_spec(),\n", "    q_network=q_net,\n", "    optimizer=tf.compat.v1.train.AdamOptimizer(0.001))\n", "\n", "replay_buffer_capacity = 1000\n", "\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "    agent.collect_data_spec,\n", "    batch_size=tf_env.batch_size,\n", "    max_length=replay_buffer_capacity)\n", "\n", "# Add an observer that adds to the replay buffer:\n", "replay_observer = [replay_buffer.add_batch]\n", "\n", "# Run the driver to collect `collect_steps_per_iteration` steps into the buffer.\n", "collect_steps_per_iteration = 10\n", "collect_op = dynamic_step_driver.DynamicStepDriver(\n", "    tf_env,\n", "    agent.collect_policy,\n", "    observers=replay_observer,\n", "    num_steps=collect_steps_per_iteration).run()" ] },
{ "cell_type": "markdown", "metadata": { "id": "huGCDbO4GAF1" }, "source": [ "### Reading data for a train step\n", "\n", "After adding trajectory elements to the replay buffer, we can read batches of trajectories from the replay buffer to use as input data for a train step.\n", "\n", "Here is an example of how to train on trajectories from the replay buffer in a training loop:" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:22:26.429582Z", "iopub.status.busy": 
"2024-03-09T12:22:26.428962Z", "iopub.status.idle": "2024-03-09T12:22:28.791955Z", "shell.execute_reply": "2024-03-09T12:22:28.791188Z" }, "id": "gg8SUyXXnSMr" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "back_prop=False is deprecated. Consider using tf.stop_gradient instead.\n", "Instead of:\n", "results = tf.foldr(fn, elems, back_prop=False)\n", "Use:\n", "results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))\n" ] } ], "source": [ "# Read the replay buffer as a Dataset,\n", "# read batches of 4 elements, each with 2 timesteps:\n", "dataset = replay_buffer.as_dataset(\n", " sample_batch_size=4,\n", " num_steps=2)\n", "\n", "iterator = iter(dataset)\n", "\n", "num_train_steps = 10\n", "\n", "for _ in range(num_train_steps):\n", " trajectories, _ = next(iterator)\n", " loss = agent.train(experience=trajectories)\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "TF-Agents Replay Buffers Tutorial.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }