{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "_jQ1tEQCxwRx" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T20:24:41.571315Z", "iopub.status.busy": "2024-01-11T20:24:41.571058Z", "iopub.status.idle": "2024-01-11T20:24:41.574757Z", "shell.execute_reply": "2024-01-11T20:24:41.574202Z" }, "id": "V_sgB_5dx1f1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "p62G8M_viUJp" }, "source": [ "# Actor-Critic 法による CartPole の実験\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-mJ2i6jvZ3sK" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "kFgN7h_wiUJq" }, "source": [ "このチュートリアルでは、[(深層)強化学習](https://en.wikipedia.org/wiki/Deep_reinforcement_learning)の[ポリシー勾配メソッド](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)を理解していることを前提に、TensorFlow を使用して [Actor-Critic](https://papers.nips.cc/paper/1786-actor-critic-algorithms.pdf) 法を実装し、[Open AI Gym](https://www.gymlibrary.dev/) の [`CartPole-v0`](https://www.gymlibrary.dev/environments/classic_control/cart_pole/) 環境のエージェントをトレーニングする方法を示します。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "_kA10ZKRR0hi" }, "source": [ "**Actor-Critic 法**\n", "\n", "Actor-Critic 法は、価値関数から独立してポリシー関数を表す[TD(時間的差分)学習](https://en.wikipedia.org/wiki/Temporal_difference_learning)の手法です。\n", "\n", "ポリシー関数(またはポリシー)は、ある特定の状態に基づいてエージェントが実行できるアクションの確率分布を返します。価値関数は、特定の状態で開始し、その後永久に特定のポリシーに従って動作するエージェントの期待される戻り値を決定します。\n", "\n", "Actor-Critic 法では、ポリシーは状態に応じて一連の可能なアクションを提案する「*アクター*」と呼ばれます。推定される価値関数は「*クリティック*’と呼ばれ、特定のポリシーに基づいて「*アクター*」が実行するアクションを評価します。\n", "\n", "このチュートリアルでは、*アクター*と*クリティック*は、2 つの出力を持つ 1 つのニューラルネットワークを使って表現されます。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rBfiafKSRs2k" }, "source": [ "**`CartPole-v0`**\n", "\n", "[`CartPole-v0` 環境](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)では、ポールは摩擦のないレール上を移動するカートに取り付けられています。ポールは直立状態で始まり、エージェントの目標は、カートに `-1` または `+1` の力を加えてポールが倒れないようにすることです。ポールが直立状態を維持する時間ステップごとに `+1` の報酬が与えられます。エピソードは、(1)ポールが直立から 15 度以上に傾斜したとき、または(2)カートが中央から 2.4 ユニット以上移動したときに、終了します。\n", "\n", "
\n", "
<figure>\n",
    "    <image src=\"https://tensorflow.org/tutorials/reinforcement_learning/images/cartpole-v0.gif\">\n",
    "    <figcaption>\n",
    "      Trained actor-critic model in Cartpole-v0 environment\n",
    "    </figcaption>\n",
    "  </figure>
\n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "XSNVK0AeRoJd" }, "source": [ "この問題は、100 回の連続トライアルにおいて、エピソードの平均合計報酬が 195 に達すると「解決」とみなされます。" ] }, { "cell_type": "markdown", "metadata": { "id": "glLwIctHiUJq" }, "source": [ "## セットアップ\n", "\n", "必要なパッケージをインポートし、グローバル設定を構成します。\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:24:41.578963Z", "iopub.status.busy": "2024-01-11T20:24:41.578354Z", "iopub.status.idle": "2024-01-11T20:24:51.944115Z", "shell.execute_reply": "2024-01-11T20:24:51.943292Z" }, "id": "13l6BbxKhCKp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting gym[classic_control]\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading gym-0.26.2.tar.gz (721 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Installing build dependencies ... \u001b[?25l-" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b \b\\" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b \b|" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b \bdone\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25h Getting requirements to build wheel ... \u001b[?25l-\b \bdone\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25l-" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b \bdone\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: gym 0.26.2 does not provide the extra 'classic-control'\u001b[0m\u001b[33m\r\n", "\u001b[0m\u001b[?25hRequirement already satisfied: numpy>=1.18.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym[classic_control]) (1.26.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting cloudpickle>=1.2.0 (from gym[classic_control])\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading cloudpickle-3.0.0-py3-none-any.whl.metadata (7.0 kB)\r\n", "Collecting gym-notices>=0.0.4 (from gym[classic_control])\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n", "Requirement already satisfied: importlib-metadata>=4.8.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym[classic_control]) (7.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pygame==2.1.0 (from gym[classic_control])\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading pygame-2.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.8.0->gym[classic_control]) (3.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading cloudpickle-3.0.0-py3-none-any.whl (20 kB)\r\n", "Building wheels for collected packages: gym\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Building wheel for gym (pyproject.toml) ... \u001b[?25l-" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b \b\\" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b \bdone\r\n", "\u001b[?25h Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827620 sha256=8170d880af0e8c9336c9f2111586b5bd2f0af1f6dce1fe9c1bab4000085e5fef\r\n", " Stored in directory: /home/kbuilder/.cache/pip/wheels/af/2b/30/5e78b8b9599f2a2286a582b8da80594f654bf0e18d825a4405\r\n", "Successfully built gym\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: gym-notices, pygame, cloudpickle, gym\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed cloudpickle-3.0.0 gym-0.26.2 gym-notices-0.0.8 pygame-2.1.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting pyglet\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading pyglet-2.0.10-py3-none-any.whl.metadata (8.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading pyglet-2.0.10-py3-none-any.whl (858 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: pyglet\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed pyglet-2.0.10\r\n" ] } ], "source": [ "!pip install gym[classic_control]\n", "!pip install pyglet" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:24:51.947959Z", "iopub.status.busy": "2024-01-11T20:24:51.947699Z", "iopub.status.idle": "2024-01-11T20:25:02.908823Z", "shell.execute_reply": "2024-01-11T20:25:02.907893Z" }, "id": "WBeQhPi2S4m5" }, "outputs": [], "source": [ "%%bash\n", "# Install additional packages for visualization\n", "sudo apt-get install -y python-opengl > /dev/null 2>&1\n", "pip install git+https://github.com/tensorflow/docs > /dev/null 2>&1" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:02.913528Z", "iopub.status.busy": "2024-01-11T20:25:02.913251Z", "iopub.status.idle": "2024-01-11T20:25:05.572707Z", "shell.execute_reply": "2024-01-11T20:25:05.571905Z" }, "id": "tT4N3qYviUJr" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 20:25:03.412854: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 20:25:03.412898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 20:25:03.414545: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import collections\n", "import gym\n", "import numpy as np\n", "import statistics\n", "import tensorflow as tf\n", "import tqdm\n", "\n", "from matplotlib import pyplot as plt\n", "from tensorflow.keras import layers\n", "from typing import Any, List, Sequence, Tuple\n", "\n", "\n", "# Create the environment\n", "env = gym.make(\"CartPole-v1\")\n", "\n", "# Set seed for experiment reproducibility\n", "seed = 42\n", "tf.random.set_seed(seed)\n", "np.random.seed(seed)\n", "\n", "# Small epsilon value for stabilizing division operations\n", "eps = np.finfo(np.float32).eps.item()" ] }, { "cell_type": "markdown", "metadata": { "id": "AOUCe2D0iUJu" }, "source": [ "## モデル\n", "\n", "*アクター*と*クリティック*は、アクションの確率とクリティック値をそれぞれ生成する 1 つのニューラルネットワークを使ってモデル化されます。このチュートリアルでは、モデルの定義にモデルのサブクラス化を使用します。\n", "\n", "フォワードパス中、モデルは、状態を入力として取り、アクション確率とクリティック値 $V$ の両方を出力し、これによって状態に依存する[価値関数](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html#value-functions)を形成します。期待される[戻り値](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html#reward-and-return)を最大化するポリシー $\\pi$ に基づいてアクションを選択するモデルをトレーニングするのが目標です。\n", "\n", "`CartPole-v0` では、状態は 4 つの値で表現されます。カートの位置、カートの速度、ポールの角度、およびポールの速度です。エージェントは 2 つのアクションを取って、カートを左(`0`)と右(`1`)に押します。\n", "\n", "詳細については、[Gym's Cart Pole ドキュメントページ](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)と Barto, Sutton and Anderson (1983) の「[*Neuronlike adaptive elements that can solve difficult learning control problems*](http://www.derongliu.org/adp/adp-cdrom/Barto1983.pdf)」をご覧ください。\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:05.576923Z", "iopub.status.busy": "2024-01-11T20:25:05.576513Z", "iopub.status.idle": "2024-01-11T20:25:05.581905Z", "shell.execute_reply": "2024-01-11T20:25:05.581261Z" }, "id": "aXKbbMC-kmuv" }, "outputs": [], "source": [ "class ActorCritic(tf.keras.Model):\n", " \"\"\"Combined actor-critic network.\"\"\"\n", "\n", " def __init__(\n", " self, \n", " num_actions: int, \n", " num_hidden_units: int):\n", " \"\"\"Initialize.\"\"\"\n", " super().__init__()\n", "\n", " self.common = layers.Dense(num_hidden_units, activation=\"relu\")\n", " self.actor = layers.Dense(num_actions)\n", " self.critic = layers.Dense(1)\n", "\n", " def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:\n", " x = self.common(inputs)\n", " return self.actor(x), self.critic(x)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:05.585305Z", "iopub.status.busy": "2024-01-11T20:25:05.584856Z", "iopub.status.idle": "2024-01-11T20:25:07.827621Z", "shell.execute_reply": "2024-01-11T20:25:07.826900Z" }, "id": "nWyxJgjLn68c" }, "outputs": [], "source": [ "num_actions = env.action_space.n # 2\n", "num_hidden_units = 128\n", "\n", "model = ActorCritic(num_actions, num_hidden_units)" ] }, { "cell_type": "markdown", "metadata": { "id": "hk92njFziUJw" }, "source": [ "## エージェントをトレーニングする\n", "\n", "エージェントをトレーニングするには、次の手順を実行します。\n", "\n", "1. 環境でエージェントを実行し、エピソードごとのトレーニングデータを収集します。\n", "2. 時間ステップごとに期待される戻り値を計算します。\n", "3. Actor-Critic の混合モデルの損失を計算します。\n", "4. 勾配を計算し、ネットワークパラメーターを更新します。\n", "5. 成功基準または最大エピソード数に達するまで、1~4 の手順を繰り返します。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "R2nde2XDs8Gh" }, "source": [ "### 1. トレーニングデータを収集する\n", "\n", "アクタークリティックモデルのトレーニングでは、教師あり学習と同様にトレーニングデータが必要です。ただし、そのようなデータを収集するには、モデルを環境で「実行」する必要があります。\n", "\n", "トレーニングデータは、エピソードごとに収集されます。次に、モデルの重みによってパラメーター化された現在のポリシーに基づいてアクションの確率とクリティック値を生成するために、時間ステップごとにモデルのフォワードパスが環境の状態で実行されます。\n", "\n", "次のアクションはモデルが生成したアクション確率からサンプリングされます。これが環境に適用されると、次の状態と報酬が生成されます。\n", "\n", "このプロセスは、`run_episode` 関数に実装されます。後で TensorFlow グラフにコンパイルしてトレーニングを加速化できるように、TensorFlow 演算が使用されています。可変長配列でテンソルをイテレーションできるように、`tf.TensorArray` が使用されていることに注意してください。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:07.831888Z", "iopub.status.busy": "2024-01-11T20:25:07.831616Z", "iopub.status.idle": "2024-01-11T20:25:07.836911Z", "shell.execute_reply": "2024-01-11T20:25:07.836239Z" }, "id": "5URrbGlDSAGx" }, "outputs": [], "source": [ "# Wrap Gym's `env.step` call as an operation in a TensorFlow function.\n", "# This would allow it to be included in a callable TensorFlow graph.\n", "\n", "def env_step(action: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n", " \"\"\"Returns state, reward and done flag given an action.\"\"\"\n", "\n", " state, reward, done, truncated, info = env.step(action)\n", " return (state.astype(np.float32), \n", " np.array(reward, np.int32), \n", " np.array(done, np.int32))\n", "\n", "\n", "def tf_env_step(action: tf.Tensor) -> List[tf.Tensor]:\n", " return tf.numpy_function(env_step, [action], \n", " [tf.float32, tf.int32, tf.int32])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:07.840193Z", "iopub.status.busy": "2024-01-11T20:25:07.839735Z", "iopub.status.idle": "2024-01-11T20:25:07.847494Z", "shell.execute_reply": "2024-01-11T20:25:07.846926Z" }, "id": "a4qVRV063Cl9" }, "outputs": [], "source": [ "def run_episode(\n", " initial_state: tf.Tensor, \n", " model: tf.keras.Model, \n", " max_steps: int) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:\n", " \"\"\"Runs a single episode to collect training data.\"\"\"\n", "\n", " action_probs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)\n", " values = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)\n", " rewards = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)\n", "\n", " initial_state_shape = initial_state.shape\n", " state = initial_state\n", "\n", " for t in tf.range(max_steps):\n", " # Convert state into a batched tensor (batch size = 1)\n", " state = tf.expand_dims(state, 0)\n", " \n", " # Run the model and to get action probabilities and critic value\n", " action_logits_t, value = model(state)\n", " \n", " # Sample next action from the action probability distribution\n", " action = tf.random.categorical(action_logits_t, 1)[0, 0]\n", " action_probs_t = tf.nn.softmax(action_logits_t)\n", "\n", " # Store critic values\n", " values = values.write(t, tf.squeeze(value))\n", "\n", " # Store log probability of the action chosen\n", " action_probs = action_probs.write(t, action_probs_t[0, action])\n", " \n", " # Apply action to the environment to get next state and reward\n", " state, reward, done = tf_env_step(action)\n", " state.set_shape(initial_state_shape)\n", " \n", " # Store reward\n", " rewards = rewards.write(t, reward)\n", "\n", " if tf.cast(done, tf.bool):\n", " break\n", "\n", " action_probs = action_probs.stack()\n", " values = values.stack()\n", " rewards = rewards.stack()\n", " \n", " return action_probs, values, rewards" ] }, { "cell_type": "markdown", "metadata": { "id": "lBnIHdz22dIx" }, "source": [ "### 2. 期待される戻り値を計算する\n", "\n", "1 つのエピソード中に収集される各時間ステップ $t$ の報酬のシーケンス ${r_{t}}^{T}{t=1}$ は、期待される戻り値のシーケンス ${G{t}}^{T}_{t=1}$ に変換されます。ここで、報酬の合計は現在の時間ステップ $t$ ~ $T$ から取得され、各報酬は指数関数で気に減衰するディスカウント要因 $\\gamma$ で乗算されます。\n", "\n", "$$G_{t} = \\sum^{T}*{t'=t} \\gamma^{t'-t}r*{t'}$$\n", "\n", "$\\gamma\\in(0,1)$ であるため、現在の時間ステップ以降の報酬には与えられる重みは、徐々に少なくなります。\n", "\n", "直感的に、期待される戻り値は単に、現在の報酬が後の報酬より良いことを示しています。数学的に見れば、報酬の合計が収束することが保証されています。\n", "\n", "トレーニングを安定化するには、生成される戻り値のシーケンスも標準化されます(つまり、平均と単位標準偏差がゼロ)。\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:07.850499Z", "iopub.status.busy": "2024-01-11T20:25:07.850281Z", "iopub.status.idle": "2024-01-11T20:25:07.855654Z", "shell.execute_reply": "2024-01-11T20:25:07.855084Z" }, "id": "jpEwFyl315dl" }, "outputs": [], "source": [ "def get_expected_return(\n", " rewards: tf.Tensor, \n", " gamma: float, \n", " standardize: bool = True) -> tf.Tensor:\n", " \"\"\"Compute expected returns per timestep.\"\"\"\n", "\n", " n = tf.shape(rewards)[0]\n", " returns = tf.TensorArray(dtype=tf.float32, size=n)\n", "\n", " # Start from the end of `rewards` and accumulate reward sums\n", " # into the `returns` array\n", " rewards = tf.cast(rewards[::-1], dtype=tf.float32)\n", " discounted_sum = tf.constant(0.0)\n", " discounted_sum_shape = discounted_sum.shape\n", " for i in tf.range(n):\n", " reward = rewards[i]\n", " discounted_sum = reward + gamma * discounted_sum\n", " discounted_sum.set_shape(discounted_sum_shape)\n", " returns = returns.write(i, discounted_sum)\n", " returns = returns.stack()[::-1]\n", "\n", " if standardize:\n", " returns = ((returns - tf.math.reduce_mean(returns)) / \n", " (tf.math.reduce_std(returns) + eps))\n", "\n", " return returns" ] }, { "cell_type": "markdown", "metadata": { "id": "qhr50_Czxazw" }, "source": [ "### 3. Actor-Critic 損失\n", "\n", "ハイブリッドの Actor-Critic モデルを使用しているため、選択される損失関数は、以下に示すように、トレーニング用のアクター損失とクリティック損失の合計です。\n", "\n", "$$L = L_{actor} + L_{critic}$$" ] }, { "cell_type": "markdown", "metadata": { "id": "nOQIJuG1xdTH" }, "source": [ "#### アクター損失\n", "\n", "アクター損失は、[クリティックを状態依存の基準としたポリシー勾配](https://www.youtube.com/watch?v=EKqxumCuAAY&t=62m23s)に基づき、単一サンプル(エピソード単位)の推定で計算されます。\n", "\n", "$$L_{actor} = -\\sum^{T}*{t=1} log\\pi*{\\theta}(a_{t} | s_{t})[G(s_{t}, a_{t}) - V^{\\pi}*{\\theta}(s*{t})]$$\n", "\n", "上記は以下を意味します。\n", "\n", "- $T$: エピソードごとの時間ステップの数。エピソードごとにことなります。\n", "- $s_{t}$: 時間ステップ $t$ における状態。\n", "- $a_{t}$: 状態 $s$ の場合の時間ステップ $t$ で選択されたアクション。\n", "- $\\pi_{\\theta}$: $\\theta$ でパラメータ化されたポリシー(アクター)。\n", "- $V^{\\pi}_{\\theta}$: $\\theta$ でパラメータ化された価値関数(クリティック)。\n", "- $G = G_{t}$: 時間ステップ $t$ において特定の状態とアクションに対して期待される戻り値。\n", "\n", "組み合わせの損失を最小限に抑えることでアクションがより高い報酬を生み出す確率を最大化しようとしているため、合計に負の項が追加されます。\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Y304O4OAxiAv" }, "source": [ "##### アドバンテージ\n", "\n", "$L_{actor}$ 式の $G - V$ の項は[アドバンテージ](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html#advantage-functions)と呼ばれ、特定の状態において、あるアクションが、その状態のポリシー $\\pi$ に従って選択されたランダムアクションと比べてどれくらい優れているかを示します。\n", "\n", "ベースラインを除外することは可能ですが、除外した場合、トレーニング中のバリアンスが高くなってしまう可能性があります。また、ベースラインとしてクリティック $V$ を選択すると、できる限り $G$ に近くなるようにトレーニングされるため、バリアンスがより低くなります。\n", "\n", "さらに、クリティックがなければ、アルゴリズムは特定の状態で実行されるアクションの確率を期待される戻り値に応じて高めようとするため、アクション間の相対的な確率が同じままである場合、結果はあまり変わりません。\n", "\n", "たとえば、特定の状態における 2 つのアクションが同じ期待される戻り値を生成したとします。クリティックがなければ、アルゴリズムは客観的な $J$ に基づき、これらのアクションの確率をあげようとします。クリティックがあれば、アドバンテージがなく($G - V = 0$)、そのためアクションの確率を上げることにアドバンテージがなく、アルゴリズムは勾配をゼロに設定します。\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "1hrPLrgGxlvb" }, "source": [ "#### The Critic loss\n", "\n", "Training $V$ to be as close possible to $G$ can be set up as a regression problem with the following loss function:\n", "\n", "$$L_{critic} = L_{\\delta}(G, V^{\\pi}_{\\theta})$$\n", "\n", "上記の $L_{\\delta}$ は [Huber 損失](https://en.wikipedia.org/wiki/Huber_loss)でこれは、二乗誤差損失よりもデータの外れ値に対する感度が低くなります。\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:07.859000Z", "iopub.status.busy": "2024-01-11T20:25:07.858781Z", "iopub.status.idle": "2024-01-11T20:25:07.862981Z", "shell.execute_reply": "2024-01-11T20:25:07.862330Z" }, "id": "9EXwbEez6n9m" }, "outputs": [], "source": [ "huber_loss = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.SUM)\n", "\n", "def compute_loss(\n", " action_probs: tf.Tensor, \n", " values: tf.Tensor, \n", " returns: tf.Tensor) -> tf.Tensor:\n", " \"\"\"Computes the combined Actor-Critic loss.\"\"\"\n", "\n", " advantage = returns - values\n", "\n", " action_log_probs = tf.math.log(action_probs)\n", " actor_loss = -tf.math.reduce_sum(action_log_probs * advantage)\n", "\n", " critic_loss = huber_loss(values, returns)\n", "\n", " return actor_loss + critic_loss" ] }, { "cell_type": "markdown", "metadata": { "id": "HSYkQOmRfV75" }, "source": [ "### 4. パラメータを更新するようにトレーニングステップを定義する\n", "\n", "上記のすべてのステップは、エピソードごとに実行されるトレーニングステップに結合されます。損失関数に導くすべてのステップは、自動微分を可能にする `tf.GradientTape` コンテキストで実行されます。\n", "\n", "このチュートリアルでは、Adam オプティマイザを使って勾配をモデルパラメーターに適用します。\n", "\n", "ディスカウントされていない報酬の合計 `episode_reward` も、このステップで計算されます。この値は、成功基準が満たされるかどうかを評価するために、後で使用されます。\n", "\n", "`tf.function` コンテキストは `train_step` 関数に適用することで、コーラブル TenSorFlow グラフにコンパイルできるようなります。そうすることで、トレーニングを 10 倍加速させることができます。\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:07.866155Z", "iopub.status.busy": "2024-01-11T20:25:07.865720Z", "iopub.status.idle": "2024-01-11T20:25:07.875457Z", "shell.execute_reply": "2024-01-11T20:25:07.874887Z" }, "id": "QoccrkF3IFCg" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n", "\n", "\n", "@tf.function\n", "def train_step(\n", " initial_state: tf.Tensor, \n", " model: tf.keras.Model, \n", " optimizer: tf.keras.optimizers.Optimizer, \n", " gamma: float, \n", " max_steps_per_episode: int) -> tf.Tensor:\n", " \"\"\"Runs a model training step.\"\"\"\n", "\n", " with tf.GradientTape() as tape:\n", "\n", " # Run the model for one episode to collect training data\n", " action_probs, values, rewards = run_episode(\n", " initial_state, model, max_steps_per_episode) \n", "\n", " # Calculate the expected returns\n", " returns = get_expected_return(rewards, gamma)\n", "\n", " # Convert training data to appropriate TF tensor shapes\n", " action_probs, values, returns = [\n", " tf.expand_dims(x, 1) for x in [action_probs, values, returns]] \n", "\n", " # Calculate the loss values to update our network\n", " loss = compute_loss(action_probs, values, returns)\n", "\n", " # Compute the gradients from the loss\n", " grads = tape.gradient(loss, model.trainable_variables)\n", "\n", " # Apply the gradients to the model's parameters\n", " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", "\n", " episode_reward = tf.math.reduce_sum(rewards)\n", "\n", " return episode_reward" ] }, { "cell_type": "markdown", "metadata": { "id": "HFvZiDoAflGK" }, "source": [ "### 5. トレーニングツールを実行する\n", "\n", "トレーニングは、成功基準またはエピソードの最大数に達するまでトレーニングステップを実行することで、実行されます。\n", "\n", "エピソード報酬の実行中の記録はキューに保持されます。100 トライアルに達したら、キューの左終端(テール)から最も古い報酬が削除され、キューの右(ヘッド)に最も新しい報酬が追加されます。実行中の報酬の合計も、計算の効率を得るために管理されます。\n", "\n", "ランタイムによっては、トレーニングを 1 分未満で完了することもできます。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:25:07.878550Z", "iopub.status.busy": "2024-01-11T20:25:07.878312Z", "iopub.status.idle": "2024-01-11T20:26:38.146763Z", "shell.execute_reply": "2024-01-11T20:26:38.146046Z" }, "id": "kbmBxnzLiUJx" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/10000 [00:00= 475 over 500 \n", "# consecutive trials\n", "reward_threshold = 475\n", "running_reward = 0\n", "\n", "# The discount factor for future rewards\n", "gamma = 0.99\n", "\n", "# Keep the last episodes reward\n", "episodes_reward: collections.deque = collections.deque(maxlen=min_episodes_criterion)\n", "\n", "t = tqdm.trange(max_episodes)\n", "for i in t:\n", " initial_state, info = env.reset()\n", " initial_state = tf.constant(initial_state, dtype=tf.float32)\n", " episode_reward = int(train_step(\n", " initial_state, model, optimizer, gamma, max_steps_per_episode))\n", " \n", " episodes_reward.append(episode_reward)\n", " running_reward = statistics.mean(episodes_reward)\n", " \n", "\n", " t.set_postfix(\n", " episode_reward=episode_reward, running_reward=running_reward)\n", " \n", " # Show the average episode reward every 10 episodes\n", " if i % 10 == 0:\n", " pass # print(f'Episode {i}: average reward: {avg_reward}')\n", " \n", " if running_reward > reward_threshold and i >= min_episodes_criterion: \n", " break\n", "\n", "print(f'\\nSolved at episode {i}: average reward: {running_reward:.2f}!')" ] }, { "cell_type": "markdown", "metadata": { "id": "ru8BEwS1EmAv" }, "source": [ "## 可視化\n", "\n", "トレーニングが終わったら、モデルが環境でどのように実行するかを可視化すると良いでしょう。以下のセルを実行すると、モデルの 1 エピソードの実行を視覚化する GIF アニメーションを生成することができます。Colab で環境の画像を正しくレンダリングするには、Gym の追加パッケージをインストールする必要があることに注意してください。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:26:38.150318Z", "iopub.status.busy": "2024-01-11T20:26:38.150066Z", "iopub.status.idle": "2024-01-11T20:26:40.833324Z", "shell.execute_reply": "2024-01-11T20:26:40.832562Z" }, "id": "qbIMMkfmRHyC" }, "outputs": [], "source": [ "# Render an episode and save as a GIF file\n", "\n", "from IPython import display as ipythondisplay\n", "from PIL import Image\n", "\n", "render_env = gym.make(\"CartPole-v1\", render_mode='rgb_array')\n", "\n", "def render_episode(env: gym.Env, model: tf.keras.Model, max_steps: int): \n", " state, info = env.reset()\n", " state = tf.constant(state, dtype=tf.float32)\n", " screen = env.render()\n", " images = [Image.fromarray(screen)]\n", " \n", " for i in range(1, max_steps + 1):\n", " state = tf.expand_dims(state, 0)\n", " action_probs, _ = model(state)\n", " action = np.argmax(np.squeeze(action_probs))\n", "\n", " state, reward, done, truncated, info = env.step(action)\n", " state = tf.constant(state, dtype=tf.float32)\n", "\n", " # Render screen every 10 steps\n", " if i % 10 == 0:\n", " screen = env.render()\n", " images.append(Image.fromarray(screen))\n", " \n", " if done:\n", " break\n", " \n", " return images\n", "\n", "\n", "# Save GIF image\n", "images = render_episode(render_env, model, max_steps_per_episode)\n", "image_file = 'cartpole-v1.gif'\n", "# loop=0: loop forever, duration=1: play each frame for 1ms\n", "images[0].save(\n", " image_file, save_all=True, append_images=images[1:], loop=0, duration=1)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T20:26:40.837323Z", "iopub.status.busy": "2024-01-11T20:26:40.837026Z", "iopub.status.idle": "2024-01-11T20:26:40.846810Z", "shell.execute_reply": "2024-01-11T20:26:40.846217Z" }, "id": "TLd720SejKmf" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import tensorflow_docs.vis.embed as embed\n", "embed.embed_file(image_file)" ] }, { "cell_type": "markdown", "metadata": { "id": "lnq9Hzo1Po6X" }, "source": [ "## 次のステップ\n", "\n", "このチュートリアルでは、TensorFlow を使って Actor-Critic 法を実装する方法を説明しました。\n", "\n", "次のステップでは、Gym の別の環境でモデルをトレーニングしてみるとよいでしょう。\n", "\n", "Actor-Critic 法と Cartpole-v0 問題に関するその他の詳細については、以下のリソースをご覧ください。\n", "\n", "- [Actor-Critic 法](https://hal.inria.fr/hal-00840470/document)\n", "- [Actor-Critic に関する講義(CAL)](https://www.youtube.com/watch?v=EKqxumCuAAY&list=PLkFD6_40KJIwhWJpGazJ9VSj9CFMkb79A&index=7&t=0s)\n", "- [Cart Pole learning control problem [Barto, et al. 1983]](http://www.derongliu.org/adp/adp-cdrom/Barto1983.pdf)\n", "\n", "TenSorFlow における強化学習のその他の例については、次のリソースをご覧ください。\n", "\n", "- [強化学習のコードサンプル(keras.io)](https://keras.io/examples/rl/)\n", "- [TF-Agents の強化学習用ライブラリ](https://www.tensorflow.org/agents)\n" ] } ], "metadata": { "colab": { "collapsed_sections": [ "_jQ1tEQCxwRx" ], "name": "actor_critic.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }