{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "1Pi_B2cvdBiW" }, "source": [ "##### Copyright 2023 The TF-Agents Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-03-09T12:12:53.583961Z", "iopub.status.busy": "2024-03-09T12:12:53.583498Z", "iopub.status.idle": "2024-03-09T12:12:53.587205Z", "shell.execute_reply": "2024-03-09T12:12:53.586639Z" }, "id": "nQnmcm0oI1Q-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "3NFuTvWVZG_B" }, "source": [ "# Policies\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " View on TensorFlow.org\n", " \n", " \n", " \n", " Run in Google Colab\n", " \n", " \n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "31uij8nIo5bG" }, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": { "id": "PqFn7q5bs3BF" }, "source": [ "In Reinforcement Learning terminology, policies map an observation from the environment to an action or a distribution over actions. In TF-Agents, observations from the environment are contained in a named tuple `TimeStep('step_type', 'discount', 'reward', 'observation')`, and policies map timesteps to actions or distributions over actions. Most policies use `timestep.observation`, some policies use `timestep.step_type` (e.g. to reset the state at the beginning of an episode in stateful policies), but `timestep.discount` and `timestep.reward` are usually ignored.\n", "\n", "Policies are related to other components in TF-Agents in the following way. Most policies have a neural network to compute actions and/or distributions over actions from TimeSteps. Agents can contain one or more policies for different purposes, e.g. a main policy that is being trained for deployment, and a noisy policy for data collection. Policies can be saved/restored, and can be used indepedently of the agent for data collection, evaluation etc.\n", "\n", "Some policies are easier to write in Tensorflow (e.g. those with a neural network), whereas others are easier to write in Python (e.g. following a script of actions). So in TF agents, we allow both Python and Tensorflow policies. Morever, policies written in TensorFlow might have to be used in a Python environment, or vice versa, e.g. a TensorFlow policy is used for training but later deployed in a production Python environment. To make this easier, we provide wrappers for converting between Python and TensorFlow policies.\n", "\n", "Another interesting class of policies are policy wrappers, which modify a given policy in a certain way, e.g. add a particular type of noise, make a greedy or epsilon-greedy version of a stochastic policy, randomly mix multiple policies etc. 
" ] }, { "cell_type": "markdown", "metadata": { "id": "HdnG_TT_amWH" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "9Meq2nT_aquh" }, "source": [ "If you haven't installed tf-agents yet, run:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:12:53.590836Z", "iopub.status.busy": "2024-03-09T12:12:53.590591Z", "iopub.status.idle": "2024-03-09T12:13:03.218346Z", "shell.execute_reply": "2024-03-09T12:13:03.217500Z" }, "id": "xsLTHlVdiZP3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-agents\r\n", " Using cached tf_agents-0.19.0-py3-none-any.whl.metadata (12 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: absl-py>=0.6.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.4.0)\r\n", "Collecting cloudpickle>=1.3 (from tf-agents)\r\n", " Using cached cloudpickle-3.0.0-py3-none-any.whl.metadata (7.0 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gin-config>=0.4.0 (from tf-agents)\r\n", " Using cached gin_config-0.5.0-py3-none-any.whl.metadata (2.9 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gym<=0.23.0,>=0.17.0 (from tf-agents)\r\n", " Using cached gym-0.23.0-py3-none-any.whl\r\n", "Requirement already satisfied: numpy>=1.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.26.4)\r\n", "Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (10.2.0)\r\n", "Requirement already satisfied: six>=1.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.16.0)\r\n", "Requirement already satisfied: protobuf>=3.11.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (3.20.3)\r\n", "Requirement already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.16.0)\r\n", "Collecting typing-extensions==4.5.0 (from tf-agents)\r\n", " Using cached typing_extensions-4.5.0-py3-none-any.whl.metadata (8.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pygame==2.1.3 (from tf-agents)\r\n", " Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.3 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-probability~=0.23.0 (from tf-agents)\r\n", " Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl.metadata (13 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gym-notices>=0.0.4 (from gym<=0.23.0,>=0.17.0->tf-agents)\r\n", " Using cached gym_notices-0.0.8-py3-none-any.whl.metadata (1.0 kB)\r\n", "Requirement already satisfied: importlib-metadata>=4.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym<=0.23.0,>=0.17.0->tf-agents) (7.0.2)\r\n", "Requirement already satisfied: decorator in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents) (5.1.1)\r\n", "Requirement already satisfied: gast>=0.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents) (0.5.4)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents) (0.1.8)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 
in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.10.0->gym<=0.23.0,>=0.17.0->tf-agents) (3.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached tf_agents-0.19.0-py3-none-any.whl (1.4 MB)\r\n", "Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)\r\n", "Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)\r\n", "Using cached cloudpickle-3.0.0-py3-none-any.whl (20 kB)\r\n", "Using cached gin_config-0.5.0-py3-none-any.whl (61 kB)\r\n", "Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl (6.9 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: gym-notices, gin-config, typing-extensions, pygame, cloudpickle, tensorflow-probability, gym, tf-agents\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: typing-extensions\r\n", " Found existing installation: typing_extensions 4.10.0\r\n", " Uninstalling typing_extensions-4.10.0:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled typing_extensions-4.10.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed cloudpickle-3.0.0 gin-config-0.5.0 gym-0.23.0 gym-notices-0.0.8 pygame-2.1.3 tensorflow-probability-0.23.0 tf-agents-0.19.0 typing-extensions-4.5.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tf-keras in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.16.0)\r\n", "Requirement already satisfied: tensorflow<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-keras) (2.16.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.4.0)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (24.3.7)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.10.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (16.0.6)\r\n", "Requirement already satisfied: ml-dtypes~=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.3.2)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (23.2)\r\n", "Requirement already satisfied: 
protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.20.3)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (2.31.0)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (69.1.1)\r\n", "Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.16.0)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (2.4.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (4.5.0)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.16.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.62.1)\r\n", "Requirement already satisfied: tensorboard<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (2.16.2)\r\n", "Requirement already satisfied: keras>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.0.5)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.36.0)\r\n", "Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.26.4)\r\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow<2.17,>=2.16->tf-keras) (0.41.2)\r\n", "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (13.7.1)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (0.0.7)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (0.1.8)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (3.6)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (2.2.1)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (2024.2.2)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (3.5.2)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (3.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (7.0.2)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (2.1.5)\r\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (3.0.0)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (2.17.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (3.17.0)\r\n", "Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (0.1.2)\r\n" ] } ], "source": [ "!pip install tf-agents\n", "!pip install tf-keras" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:03.222956Z", "iopub.status.busy": "2024-03-09T12:13:03.222277Z", "iopub.status.idle": "2024-03-09T12:13:03.225977Z", "shell.execute_reply": "2024-03-09T12:13:03.225397Z" }, "id": "h3dkwi09ZQeJ" }, "outputs": [], "source": [ "import os\n", "# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:03.228877Z", "iopub.status.busy": "2024-03-09T12:13:03.228387Z", "iopub.status.idle": "2024-03-09T12:13:06.049678Z", "shell.execute_reply": "2024-03-09T12:13:06.048839Z" }, "id": "sdvop99JlYSM" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import abc\n", "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "import numpy as np\n", "\n", "from tf_agents.specs import array_spec\n", "from tf_agents.specs import tensor_spec\n", "from tf_agents.networks import network\n", "\n", "from tf_agents.policies import py_policy\n", "from tf_agents.policies import random_py_policy\n", "from tf_agents.policies import scripted_py_policy\n", "\n", "from tf_agents.policies import tf_policy\n", "from tf_agents.policies import random_tf_policy\n", "from tf_agents.policies import actor_policy\n", "from tf_agents.policies import q_policy\n", "from tf_agents.policies import greedy_policy\n", "\n", "from tf_agents.trajectories import time_step as ts" ] }, { "cell_type": "markdown", "metadata": { "id": 
"NyXO5-Aalb-6" }, "source": [ "## Python Policies" ] }, { "cell_type": "markdown", "metadata": { "id": "DOtUZ1hs02bu" }, "source": [ "The interface for Python policies is defined in `policies/py_policy.PyPolicy`. The main methods are:\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:06.054486Z", "iopub.status.busy": "2024-03-09T12:13:06.053654Z", "iopub.status.idle": "2024-03-09T12:13:06.059473Z", "shell.execute_reply": "2024-03-09T12:13:06.058887Z" }, "id": "4PqNEVls1uqc" }, "outputs": [], "source": [ "class Base(object):\n", "\n", " @abc.abstractmethod\n", " def __init__(self, time_step_spec, action_spec, policy_state_spec=()):\n", " self._time_step_spec = time_step_spec\n", " self._action_spec = action_spec\n", " self._policy_state_spec = policy_state_spec\n", "\n", " @abc.abstractmethod\n", " def reset(self, policy_state=()):\n", " # return initial_policy_state.\n", " pass\n", "\n", " @abc.abstractmethod\n", " def action(self, time_step, policy_state=()):\n", " # return a PolicyStep(action, state, info) named tuple.\n", " pass\n", "\n", " @abc.abstractmethod\n", " def distribution(self, time_step, policy_state=()):\n", " # Not implemented in python, only for TF policies.\n", " pass\n", "\n", " @abc.abstractmethod\n", " def update(self, policy):\n", " # update self to be similar to the input `policy`.\n", " pass\n", "\n", " @property\n", " def time_step_spec(self):\n", " return self._time_step_spec\n", "\n", " @property\n", " def action_spec(self):\n", " return self._action_spec\n", "\n", " @property\n", " def policy_state_spec(self):\n", " return self._policy_state_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "16kyDKk65bka" }, "source": [ "The most important method is `action(time_step)` which maps a `time_step` containing an observation from the environment to a PolicyStep named tuple containing the following attributes:\n", "\n", "* `action`: The action to be applied to the environment.\n", "* `state`: The state of the policy (e.g. RNN state) to be fed into the next call to action.\n", "* `info`: Optional side information such as action log probabilities.\n", "\n", "The `time_step_spec` and `action_spec` are specifications for the input time step and the output action. Policies also have a `reset` function which is typically used for resetting the state in stateful policies. The `update(new_policy)` function updates `self` towards `new_policy`.\n", "\n", "Now, let us look at a couple of examples of Python policies.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "YCH1Hs_WlmDT" }, "source": [ "### Example 1: Random Python Policy" ] }, { "cell_type": "markdown", "metadata": { "id": "lbnQ0BQ3_0N2" }, "source": [ "A simple example of a `PyPolicy` is the `RandomPyPolicy` which generates random actions for the discrete/continuous given action_spec. The input `time_step` is ignored." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:06.063249Z", "iopub.status.busy": "2024-03-09T12:13:06.062777Z", "iopub.status.idle": "2024-03-09T12:13:06.070057Z", "shell.execute_reply": "2024-03-09T12:13:06.069462Z" }, "id": "QX8M4Nl-_0uu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PolicyStep(action=array([5, 3], dtype=int32), state=(), info=())\n", "PolicyStep(action=array([-4, 3], dtype=int32), state=(), info=())\n" ] } ], "source": [ "action_spec = array_spec.BoundedArraySpec((2,), np.int32, -10, 10)\n", "my_random_py_policy = random_py_policy.RandomPyPolicy(time_step_spec=None,\n", " action_spec=action_spec)\n", "time_step = None\n", "action_step = my_random_py_policy.action(time_step)\n", "print(action_step)\n", "action_step = my_random_py_policy.action(time_step)\n", "print(action_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "B8WrFOR1lz31" }, "source": [ "### Example 2: Scripted Python Policy" ] }, { "cell_type": "markdown", "metadata": { "id": "AJ0Br1lGBnTT" }, "source": [ "A scripted policy plays back a script of actions represented as a list of `(num_repeats, action)` tuples. Every time the `action` function is called, it returns the next action from the list until the specified number of repeats is done, and then moves on to the next action in the list. The `reset` method can be called to start executing from the beginning of the list." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:06.073608Z", "iopub.status.busy": "2024-03-09T12:13:06.073175Z", "iopub.status.idle": "2024-03-09T12:13:06.081230Z", "shell.execute_reply": "2024-03-09T12:13:06.080537Z" }, "id": "_mZ244m4BUYv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Executing scripted policy...\n", "PolicyStep(action=array([5, 2], dtype=int32), state=[0, 1], info=())\n", "PolicyStep(action=array([1, 2], dtype=int32), state=[2, 1], info=())\n", "PolicyStep(action=array([1, 2], dtype=int32), state=[2, 2], info=())\n", "Resetting my_scripted_py_policy...\n", "PolicyStep(action=array([5, 2], dtype=int32), state=[0, 1], info=())\n" ] } ], "source": [ "action_spec = array_spec.BoundedArraySpec((2,), np.int32, -10, 10)\n", "action_script = [(1, np.array([5, 2], dtype=np.int32)),\n", " (0, np.array([0, 0], dtype=np.int32)), # Setting `num_repeats` to 0 will skip this action.\n", " (2, np.array([1, 2], dtype=np.int32)),\n", " (1, np.array([3, 4], dtype=np.int32))]\n", "\n", "my_scripted_py_policy = scripted_py_policy.ScriptedPyPolicy(\n", " time_step_spec=None, action_spec=action_spec, action_script=action_script)\n", "\n", "policy_state = my_scripted_py_policy.get_initial_state()\n", "time_step = None\n", "print('Executing scripted policy...')\n", "action_step = my_scripted_py_policy.action(time_step, policy_state)\n", "print(action_step)\n", "action_step= my_scripted_py_policy.action(time_step, action_step.state)\n", "print(action_step)\n", "action_step = my_scripted_py_policy.action(time_step, action_step.state)\n", "print(action_step)\n", "\n", "print('Resetting my_scripted_py_policy...')\n", "policy_state = my_scripted_py_policy.get_initial_state()\n", "action_step = my_scripted_py_policy.action(time_step, policy_state)\n", "print(action_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "3Dz7HSTZl6aU" }, "source": [ "## TensorFlow Policies" ] }, { "cell_type": "markdown", "metadata": { "id": "LwcoBXqKl8Yb" }, 
"source": [ "TensorFlow policies follow the same interface as Python policies. Let us look at a few examples." ] }, { "cell_type": "markdown", "metadata": { "id": "3x8pDWEFrQ5C" }, "source": [ "### Example 1: Random TF Policy\n", "\n", "A RandomTFPolicy can be used to generate random actions according to a given discrete/continuous `action_spec`. The input `time_step` is ignored.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:06.084808Z", "iopub.status.busy": "2024-03-09T12:13:06.084328Z", "iopub.status.idle": "2024-03-09T12:13:09.578982Z", "shell.execute_reply": "2024-03-09T12:13:09.578249Z" }, "id": "nZ3pe5G4rjrW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Action:\n", "tf.Tensor([9.8276138e-04 2.8761353e+00], shape=(2,), dtype=float32)\n" ] } ], "source": [ "action_spec = tensor_spec.BoundedTensorSpec(\n", " (2,), tf.float32, minimum=-1, maximum=3)\n", "input_tensor_spec = tensor_spec.TensorSpec((2,), tf.float32)\n", "time_step_spec = ts.time_step_spec(input_tensor_spec)\n", "\n", "my_random_tf_policy = random_tf_policy.RandomTFPolicy(\n", " action_spec=action_spec, time_step_spec=time_step_spec)\n", "observation = tf.ones(time_step_spec.observation.shape)\n", "time_step = ts.restart(observation)\n", "action_step = my_random_tf_policy.action(time_step)\n", "\n", "print('Action:')\n", "print(action_step.action)" ] }, { "cell_type": "markdown", "metadata": { "id": "GOBoWETprWCB" }, "source": [ "### Example 2: Actor Policy\n", "\n", "An actor policy can be created using either a network that maps `time_steps` to actions or a network that maps `time_steps` to distributions over actions.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2S94E5zQgge_" }, "source": [ "#### Using an action network" ] }, { "cell_type": "markdown", "metadata": { "id": "X2LM5STNgv1u" }, "source": [ "Let us define a network as follows:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:09.583240Z", "iopub.status.busy": "2024-03-09T12:13:09.582951Z", "iopub.status.idle": "2024-03-09T12:13:09.588550Z", "shell.execute_reply": "2024-03-09T12:13:09.587951Z" }, "id": "S2wFgzJFteQX" }, "outputs": [], "source": [ "class ActionNet(network.Network):\n", "\n", " def __init__(self, input_tensor_spec, output_tensor_spec):\n", " super(ActionNet, self).__init__(\n", " input_tensor_spec=input_tensor_spec,\n", " state_spec=(),\n", " name='ActionNet')\n", " self._output_tensor_spec = output_tensor_spec\n", " self._sub_layers = [\n", " tf.keras.layers.Dense(\n", " action_spec.shape.num_elements(), activation=tf.nn.tanh),\n", " ]\n", "\n", " def call(self, observations, step_type, network_state):\n", " del step_type\n", "\n", " output = tf.cast(observations, dtype=tf.float32)\n", " for layer in self._sub_layers:\n", " output = layer(output)\n", " actions = tf.reshape(output, [-1] + self._output_tensor_spec.shape.as_list())\n", "\n", " # Scale and shift actions to the correct range if necessary.\n", " return actions, network_state" ] }, { "cell_type": "markdown", "metadata": { "id": "k7fIn-ybVdC6" }, "source": [ "In TensorFlow most network layers are designed for batch operations, so we expect the input time_steps to be batched, and the output of the network will be batched as well. Also the network is responsible for producing actions in the correct range of the given action_spec. This is conventionally done using e.g. 
a tanh activation for the final layer to produce actions in [-1, 1] and then scaling and shifting this to the correct range as the input action_spec (e.g. see `tf_agents/agents/ddpg/networks.actor_network()`).\n", "\n", "Now, we can create an actor policy using the above network." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:09.592137Z", "iopub.status.busy": "2024-03-09T12:13:09.591476Z", "iopub.status.idle": "2024-03-09T12:13:09.884991Z", "shell.execute_reply": "2024-03-09T12:13:09.884259Z" }, "id": "0UGmFTe7a5VQ" }, "outputs": [], "source": [ "input_tensor_spec = tensor_spec.TensorSpec((4,), tf.float32)\n", "time_step_spec = ts.time_step_spec(input_tensor_spec)\n", "action_spec = tensor_spec.BoundedTensorSpec((3,),\n", " tf.float32,\n", " minimum=-1,\n", " maximum=1)\n", "\n", "action_net = ActionNet(input_tensor_spec, action_spec)\n", "\n", "my_actor_policy = actor_policy.ActorPolicy(\n", " time_step_spec=time_step_spec,\n", " action_spec=action_spec,\n", " actor_network=action_net)" ] }, { "cell_type": "markdown", "metadata": { "id": "xlmGPTAmfPK3" }, "source": [ "We can apply it to any batch of time_steps that follow time_step_spec:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:09.888973Z", "iopub.status.busy": "2024-03-09T12:13:09.888395Z", "iopub.status.idle": "2024-03-09T12:13:09.937540Z", "shell.execute_reply": "2024-03-09T12:13:09.936836Z" }, "id": "fvsIsR0VfOA4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Action:\n", "tf.Tensor(\n", "[[ 0.85880756 -0.74206954 -0.7772715 ]\n", " [ 0.85880756 -0.74206954 -0.7772715 ]], shape=(2, 3), dtype=float32)\n", "Action distribution:\n", "tfp.distributions.Deterministic(\"Deterministic\", batch_shape=[2, 3], event_shape=[], dtype=float32)\n" ] } ], "source": [ "batch_size = 2\n", "observations = tf.ones([2] + time_step_spec.observation.shape.as_list())\n", "\n", "time_step = ts.restart(observations, batch_size)\n", "\n", "action_step = my_actor_policy.action(time_step)\n", "print('Action:')\n", "print(action_step.action)\n", "\n", "distribution_step = my_actor_policy.distribution(time_step)\n", "print('Action distribution:')\n", "print(distribution_step.action)" ] }, { "cell_type": "markdown", "metadata": { "id": "lumtyhejZOXR" }, "source": [ "In the above example, we created the policy using an action network that produces an action tensor. In this case, `policy.distribution(time_step)` is a deterministic (delta) distribution around the output of `policy.action(time_step)`. One way to produce a stochastic policy is to wrap the actor policy in a policy wrapper that adds noise to the actions. Another way is to create the actor policy using an action distribution network instead of an action network as shown below." ] }, { "cell_type": "markdown", "metadata": { "id": "_eNrJ5gKgl3W" }, "source": [ "#### Using an action distribution network" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:09.941191Z", "iopub.status.busy": "2024-03-09T12:13:09.940513Z", "iopub.status.idle": "2024-03-09T12:13:10.072096Z", "shell.execute_reply": "2024-03-09T12:13:10.071340Z" }, "id": "sSYzC9LobVsK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Action:\n", "tf.Tensor(\n", "[[-1. 1. 1. 
]\n", " [-0.7074561 0.1602813 0.34091526]], shape=(2, 3), dtype=float32)\n", "Action distribution:\n", "tfp.distributions.MultivariateNormalDiag(\"MultivariateNormalDiag\", batch_shape=[2], event_shape=[3], dtype=float32)\n" ] } ], "source": [ "class ActionDistributionNet(ActionNet):\n", "\n", " def call(self, observations, step_type, network_state):\n", " action_means, network_state = super(ActionDistributionNet, self).call(\n", " observations, step_type, network_state)\n", "\n", " action_std = tf.ones_like(action_means)\n", " return tfp.distributions.MultivariateNormalDiag(action_means, action_std), network_state\n", "\n", "\n", "action_distribution_net = ActionDistributionNet(input_tensor_spec, action_spec)\n", "\n", "my_actor_policy = actor_policy.ActorPolicy(\n", " time_step_spec=time_step_spec,\n", " action_spec=action_spec,\n", " actor_network=action_distribution_net)\n", "\n", "action_step = my_actor_policy.action(time_step)\n", "print('Action:')\n", "print(action_step.action)\n", "distribution_step = my_actor_policy.distribution(time_step)\n", "print('Action distribution:')\n", "print(distribution_step.action)" ] }, { "cell_type": "markdown", "metadata": { "id": "BzoNGJnlibtz" }, "source": [ "Note that in the above, actions are clipped to the range of the given action spec [-1, 1]. This is because a constructor argument of ActorPolicy clip=True by default. Setting this to false will return unclipped actions produced by the network." ] }, { "cell_type": "markdown", "metadata": { "id": "PLj6A-5domNG" }, "source": [ "Stochastic policies can be converted to deterministic policies using, for example, a GreedyPolicy wrapper which chooses `stochastic_policy.distribution().mode()` as its action, and a deterministic/delta distribution around this greedy action as its `distribution()`." ] }, { "cell_type": "markdown", "metadata": { "id": "4Xxzo2a7rZ7v" }, "source": [ "### Example 3: Q Policy" ] }, { "cell_type": "markdown", "metadata": { "id": "79eGLqpOhQVp" }, "source": [ "A Q policy is used in agents like DQN and is based on a Q network that predicts a Q value for each discrete action. 
For a given time step, the action distribution in the Q Policy is a categorical distribution created using the q values as logits.\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:10.076135Z", "iopub.status.busy": "2024-03-09T12:13:10.075710Z", "iopub.status.idle": "2024-03-09T12:13:10.105420Z", "shell.execute_reply": "2024-03-09T12:13:10.104771Z" }, "id": "Haakr2VvjqKC" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Action:\n", "tf.Tensor([2 0], shape=(2,), dtype=int32)\n", "Action distribution:\n", "tfp.distributions.Categorical(\"Categorical\", batch_shape=[2], event_shape=[], dtype=int32)\n" ] } ], "source": [ "input_tensor_spec = tensor_spec.TensorSpec((4,), tf.float32)\n", "time_step_spec = ts.time_step_spec(input_tensor_spec)\n", "action_spec = tensor_spec.BoundedTensorSpec((),\n", "                                            tf.int32,\n", "                                            minimum=0,\n", "                                            maximum=2)\n", "num_actions = action_spec.maximum - action_spec.minimum + 1\n", "\n", "\n", "class QNetwork(network.Network):\n", "\n", "  def __init__(self, input_tensor_spec, action_spec, num_actions=num_actions, name=None):\n", "    super(QNetwork, self).__init__(\n", "        input_tensor_spec=input_tensor_spec,\n", "        state_spec=(),\n", "        name=name)\n", "    self._sub_layers = [\n", "        tf.keras.layers.Dense(num_actions),\n", "    ]\n", "\n", "  def call(self, inputs, step_type=None, network_state=()):\n", "    del step_type\n", "    inputs = tf.cast(inputs, tf.float32)\n", "    for layer in self._sub_layers:\n", "      inputs = layer(inputs)\n", "    return inputs, network_state\n", "\n", "\n", "batch_size = 2\n", "observation = tf.ones([batch_size] + time_step_spec.observation.shape.as_list())\n", "time_steps = ts.restart(observation, batch_size=batch_size)\n", "\n", "my_q_network = QNetwork(\n", "    input_tensor_spec=input_tensor_spec,\n", "    action_spec=action_spec)\n", "my_q_policy = q_policy.QPolicy(\n", "    time_step_spec, action_spec, q_network=my_q_network)\n", "action_step = my_q_policy.action(time_steps)\n", "distribution_step = my_q_policy.distribution(time_steps)\n", "\n", "print('Action:')\n", "print(action_step.action)\n", "\n", "print('Action distribution:')\n", "print(distribution_step.action)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xpu9m6mvqJY-" }, "source": [ "## Policy Wrappers" ] }, { "cell_type": "markdown", "metadata": { "id": "OfaUrqRAoigk" }, "source": [ "A policy wrapper can be used to wrap and modify a given policy, e.g. add noise. Policy wrappers are a subclass of Policy (Python/TensorFlow) and can therefore be used just like any other policy." ] }, { "cell_type": "markdown", "metadata": { "id": "-JJVVAALqVNQ" }, "source": [ "### Example: Greedy Policy\n", "\n", "\n", "A greedy wrapper can be used to wrap any TensorFlow policy that implements `distribution()`.
`GreedyPolicy.action()` will return `wrapped_policy.distribution().mode()` and `GreedyPolicy.distribution()` is a deterministic/delta distribution around `GreedyPolicy.action()`:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T12:13:10.109103Z", "iopub.status.busy": "2024-03-09T12:13:10.108514Z", "iopub.status.idle": "2024-03-09T12:13:10.133725Z", "shell.execute_reply": "2024-03-09T12:13:10.133099Z" }, "id": "xsRPBeLZtXvu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Action:\n", "tf.Tensor([0 0], shape=(2,), dtype=int32)\n", "Action distribution:\n", "tfp.distributions.Deterministic(\"Deterministic\", batch_shape=[2], event_shape=[], dtype=int32)\n" ] } ], "source": [ "my_greedy_policy = greedy_policy.GreedyPolicy(my_q_policy)\n", "\n", "action_step = my_greedy_policy.action(time_steps)\n", "print('Action:')\n", "print(action_step.action)\n", "\n", "distribution_step = my_greedy_policy.distribution(time_steps)\n", "print('Action distribution:')\n", "print(distribution_step.action)" ] } ], "metadata": { "colab": { "name": "TF-Agents Policies Tutorial.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }