{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T20:37:02.191456Z", "iopub.status.busy": "2022-12-14T20:37:02.191030Z", "iopub.status.idle": "2022-12-14T20:37:02.194921Z", "shell.execute_reply": "2022-12-14T20:37:02.194270Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# 乱数の生成" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "BlGY1iiph_C2" }, "source": [ "TensorFlow では、`tf.random` モジュールに疑似乱数ジェネレータ(RNG)を提供しています。このドキュメントは、乱数ジェネレータをどのように制御し、これらのジェネレータがほかの TensorFlow サブシステムとどのように対話するのかを説明します。\n", "\n", "注意: TensorFlow バージョン間での乱数の一貫性は保証されていません。「[バージョンの互換性](https://www.tensorflow.org/guide/versions#what_is_not_covered)」をご覧ください。\n", "\n", "TensorFlow provides two approaches for controlling the random number generation process:\n", "\n", "1. `tf.random.Generator` オブジェクトを明示的に使用する。各オブジェクトは、数字が生成された後に変更される状態を維持します(`tf.Variable`)。\n", "\n", "2. `tf.random.stateless_uniform` などの純粋関数型のステートレスランダム関数を使用する。同一の引数を使って(シードを含む)同じデバイスでこれらの関数を呼び出すと、同じ結果が必ず生成されます。\n", "\n", "警告: `tf.random.uniform` や `tf.random.normal` といった TF 1.x の古い RNG は、まだ使用廃止になっていませんが、使用しないことを強くお勧めします。" ] }, { "cell_type": "markdown", "metadata": { "id": "zIGh9faCOp6x" }, "source": [ "## セットアップ" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:02.198447Z", "iopub.status.busy": "2022-12-14T20:37:02.197991Z", "iopub.status.idle": "2022-12-14T20:37:04.230875Z", "shell.execute_reply": "2022-12-14T20:37:04.230089Z" }, "id": "ECDrttf0s8Nu" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 20:37:03.134096: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 20:37:03.134199: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 20:37:03.134210: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "# Creates some virtual devices (cpu:0, cpu:1, etc.) for using distribution strategy\n", "physical_devices = tf.config.list_physical_devices(\"CPU\")\n", "tf.config.experimental.set_virtual_device_configuration(\n", " physical_devices[0], [\n", " tf.config.experimental.VirtualDeviceConfiguration(),\n", " tf.config.experimental.VirtualDeviceConfiguration(),\n", " tf.config.experimental.VirtualDeviceConfiguration()\n", " ])" ] }, { "cell_type": "markdown", "metadata": { "id": "eqMlrUsVu2Ai" }, "source": [ "## `tf.random.Generator` クラス\n", "\n", "`tf.random.Generator` クラスは、それぞれの RNG 呼び出しで異なる結果を得たい場合に使用します。また、乱数が生成されるたびに更新される内部状態(`tf.Variable` オブジェクトが管理)を維持します。状態は `tf.Variable` によって管理されるため、チェックポイントの簡単な設定、自動制御依存関係、およびスレッドセーフなど、`tf.Variable` が提供する便利な機能を利用できます。\n", "\n", "クラスのオブジェクトを手動で作成して `tf.random.Generator` を得るか、`tf.random.get_global_generator()` を呼び出して、デフォルトグローバルジェネレータを得ることができます。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:04.235225Z", "iopub.status.busy": "2022-12-14T20:37:04.234780Z", "iopub.status.idle": "2022-12-14T20:37:07.528893Z", "shell.execute_reply": "2022-12-14T20:37:07.528221Z" }, "id": "7yU1E3JvxOQD" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0.43842277 -0.53439844 -0.07710262]\n", " [ 1.5658046 -0.1012345 -0.2744976 ]], shape=(2, 3), dtype=float32)\n", "tf.Tensor(\n", "[[-0.50628775 -0.1510015 -0.06051261]\n", " [ 0.8713965 -1.641594 0.34522513]], shape=(2, 3), dtype=float32)\n" ] } ], "source": [ "g1 = tf.random.Generator.from_seed(1)\n", "print(g1.normal(shape=[2, 3]))\n", "g2 = tf.random.get_global_generator()\n", "print(g2.normal(shape=[2, 3]))" ] }, { "cell_type": "markdown", "metadata": { "id": "QmRCeAvTxulW" }, "source": [ "ジェネレータオブジェクトには、さまざまな作成方法があります。最も簡単なのは、上記に示した `Generator.from_seed` で、シードからジェネレータを作成します。シードは、負でない整数値です。`from_seed` にはオプションの引数 `alg` があり、このジェネレータが使用する RNG アルゴリズムを指定します。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.532302Z", "iopub.status.busy": "2022-12-14T20:37:07.532020Z", "iopub.status.idle": "2022-12-14T20:37:07.540160Z", "shell.execute_reply": "2022-12-14T20:37:07.539562Z" }, "id": "kISbOE4Xfjhv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0.43842277 -0.53439844 -0.07710262]\n", " [ 1.5658046 -0.1012345 -0.2744976 ]], shape=(2, 3), dtype=float32)\n" ] } ], "source": [ "g1 = tf.random.Generator.from_seed(1, alg='philox')\n", "print(g1.normal(shape=[2, 3]))" ] }, { "cell_type": "markdown", "metadata": { "id": "_mCRaN7dfd8j" }, "source": [ "この詳細については、以下の「*アルゴリズム*」のセクションをご覧ください。\n", "\n", "ジェネレータを作成するもう 1 つの方法に、`Generator.from_non_deterministic_state` を使用する方法があります。この方法で作成されたジェネレータは、時刻や OS に基づいて、非確定的状態から開始します。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.543404Z", "iopub.status.busy": "2022-12-14T20:37:07.543170Z", "iopub.status.idle": "2022-12-14T20:37:07.550306Z", "shell.execute_reply": "2022-12-14T20:37:07.549731Z" }, "id": "gxPLCLsz00qY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0.28561354 0.44855464 2.2449033 ]\n", " [-0.7224477 1.1459732 0.65525055]], shape=(2, 3), dtype=float32)\n" ] } ], "source": [ "g = tf.random.Generator.from_non_deterministic_state()\n", "print(g.normal(shape=[2, 3]))" ] }, { "cell_type": "markdown", "metadata": { "id": "zSAp2BMj1JZ6" }, "source": [ "ジェネレータの作成方法はほかにもありますが、このガイドでは説明されていません。\n", "\n", "`tf.random.get_global_generator` を使用してグローバルジェネレータを得る場合、デバイスの配置に注意する必要があります。グローバルジェネレータは、`tf.random.get_global_generator` が呼び出されたときに初めて作成され(非確定的状態から)、その呼び出し時のデフォルトのデバイスに配置されます。そのため、たとえば `tf.random.get_global_generator` を呼び出す最初の場所が `tf.device(\"gpu\")` スコープ内である場合、グローバルジェネレータは GPU に配置され、後で CPU からそのグローバルジェネレータを使用すると、GPU から CPU へのコピーを招くことになります。\n", "\n", "また、グローバルジェネレータを別のジェネレータオブジェクトに置き換える `tf.random.set_global_generator` 関数もあります。ただし、古いグローバルジェネレータが `tf.function` によってキャプチャされている可能性があり(弱参照として)、それを置き換えるとガベージコレクタで解放され、`tf.function` が機能しなくなるため、この関数の使用には注意が必要です。グローバルジェネレータをリセットする方法としては、 `Generator.reset_from_seed` などの「リセット」関数を使用する方が有効です。この関数は新しいジェネレータオブジェクトを作成しません。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.553319Z", "iopub.status.busy": "2022-12-14T20:37:07.553075Z", "iopub.status.idle": "2022-12-14T20:37:07.571537Z", "shell.execute_reply": "2022-12-14T20:37:07.570931Z" }, "id": "324S5bpd9HRg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.43842277, shape=(), dtype=float32)\n", "tf.Tensor(1.6272374, shape=(), dtype=float32)\n", "tf.Tensor(0.43842277, shape=(), dtype=float32)\n" ] } ], "source": [ "g = tf.random.Generator.from_seed(1)\n", "print(g.normal([]))\n", "print(g.normal([]))\n", "g.reset_from_seed(1)\n", "print(g.normal([]))" ] }, { "cell_type": "markdown", "metadata": { "id": "z9H0wuvp9VwH" }, "source": [ "### 独立乱数ストリームを作成する\n", "\n", "多くのアプリケーションでは、複数の独立乱数ストリームが必要です。ここでいう独立とは、これらのストリームがオーバーラップすることがないため、統計的に検出可能な相関を持つことがないということです。これは、`Generator.split` を使用して、互いに独立することが保証された複数のジェネレータを作成する(独立ストリームを生成する)ことで実現できます。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.575000Z", "iopub.status.busy": "2022-12-14T20:37:07.574386Z", "iopub.status.idle": "2022-12-14T20:37:07.598921Z", "shell.execute_reply": "2022-12-14T20:37:07.598230Z" }, "id": "Vg5_KN18OZjo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.43842277, shape=(), dtype=float32)\n", "tf.Tensor(2.536413, shape=(), dtype=float32)\n", "tf.Tensor(0.33186463, shape=(), dtype=float32)\n", "tf.Tensor(-0.07144657, shape=(), dtype=float32)\n", "tf.Tensor(-0.79253083, shape=(), dtype=float32)\n" ] } ], "source": [ "g = tf.random.Generator.from_seed(1)\n", "print(g.normal([]))\n", "new_gs = g.split(3)\n", "for new_g in new_gs:\n", " print(new_g.normal([]))\n", "print(g.normal([]))" ] }, { "cell_type": "markdown", "metadata": { "id": "dqOaGVzKOsRJ" }, "source": [ "`split` は、`normal` などの RNG メソッドと同様に、それを呼び出したジェネレータ(上記の例の `g`)の状態を変更します。互いに独立しているほかに、新しいジェネレータ(`new_gs`)は、古いジェネレータ(`g`)からの独立も保証されます。\n", "\n", "新しいジェネレータをスポーンする方法は、デバイス間コピーのオーバーヘッドを回避するために、使用するジェネレータがほかの計算と同じデバイス上にあることを確実にする上でも役立ちます。次に例を示します。 " ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.602062Z", "iopub.status.busy": "2022-12-14T20:37:07.601794Z", "iopub.status.idle": "2022-12-14T20:37:07.623203Z", "shell.execute_reply": "2022-12-14T20:37:07.622604Z" }, "id": "5jSnJBlUQzF3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(-0.55285686, shape=(), dtype=float32)\n" ] } ], "source": [ "with tf.device(\"cpu\"): # change \"cpu\" to the device you want\n", " g = tf.random.get_global_generator().split(1)[0] \n", " print(g.normal([])) # use of g won't cause cross-device copy, unlike the global generator" ] }, { "cell_type": "markdown", "metadata": { "id": "sCxbccYMRdd4" }, "source": [ "注意: 理論的には、`split` の代わりに `from_seed` などのコンストラクタを使用して新しいジェネレータを取得することはできますが、そうした場合、新しいジェネレータがグローバルジェネレータから独立する保証がなくなります。また、偶発的に、同じシードまたは乱数ストリームをオーバーラップさせるシードを伴うジェネレータを 2 つ作成してしまうリスクもあります。\n", "\n", "split ジェネレータに `split` を呼び出して分割を再帰的に実行することができます。再帰の深度には制限(整数オーバーフローの禁止)はありません。" ] }, { "cell_type": "markdown", "metadata": { "id": "8JUgnQM_O0lg" }, "source": [ "### `tf.function` とのインタラクション\n", "\n", "`tf.random.Generator` は、`tf.function` と使用した場合の `tf.Variable` と同じルールに従います。これには、3 つの側面が含まれます。" ] }, { "cell_type": "markdown", "metadata": { "id": "jnSjhY6WM-J8" }, "source": [ "#### `tf.function` の外部でジェネレータを作成する\n", "\n", "`tf.function` は、その外部で作成されたジェネレータを使用できます。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.626339Z", "iopub.status.busy": "2022-12-14T20:37:07.626105Z", "iopub.status.idle": "2022-12-14T20:37:07.681051Z", "shell.execute_reply": "2022-12-14T20:37:07.680478Z" }, "id": "a5EEy0E2UHMw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.43842277, shape=(), dtype=float32)\n" ] } ], "source": [ "g = tf.random.Generator.from_seed(1)\n", "@tf.function\n", "def foo():\n", " return g.normal([])\n", "print(foo())" ] }, { "cell_type": "markdown", "metadata": { "id": "L_8kC7kbO5uu" }, "source": [ "ユーザーは、関数が呼び出されたときにジェネレータオブジェクトがアライブである(ガベージコレクトされていない)ことを確認する必要があります。" ] }, { "cell_type": "markdown", "metadata": { "id": "PwIrBv_zUYwI" }, "source": [ "#### `tf.function` 内部でジェネレータを作成する\n", "\n", "`tf.function` 内部でのジェネレータの作成は、関数を初めて実行したときにのみ発生します。 " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.684426Z", "iopub.status.busy": "2022-12-14T20:37:07.684130Z", "iopub.status.idle": "2022-12-14T20:37:07.796740Z", "shell.execute_reply": "2022-12-14T20:37:07.796156Z" }, "id": "3JzpUvqJU4MW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.43842277, shape=(), dtype=float32)\n", "tf.Tensor(1.6272374, shape=(), dtype=float32)\n" ] } ], "source": [ "g = None\n", "@tf.function\n", "def foo():\n", " global g\n", " if g is None:\n", " g = tf.random.Generator.from_seed(1)\n", " return g.normal([])\n", "print(foo())\n", "print(foo())" ] }, { "cell_type": "markdown", "metadata": { "id": "UaTVnOhHVM9a" }, "source": [ "#### ジェネレータを引数として `tf.function` に渡す\n", "\n", "`tf.function` の引数として使用される場合、別のジェネレータオブジェクトによって `tf.function` の再トレースが始まります。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.799834Z", "iopub.status.busy": "2022-12-14T20:37:07.799592Z", "iopub.status.idle": "2022-12-14T20:37:07.877115Z", "shell.execute_reply": "2022-12-14T20:37:07.876529Z" }, "id": "DeR9kvt0V-ad" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2\n" ] } ], "source": [ "num_traces = 0\n", "@tf.function\n", "def foo(g):\n", " global num_traces\n", " num_traces += 1\n", " return g.normal([])\n", "foo(tf.random.Generator.from_seed(1))\n", "foo(tf.random.Generator.from_seed(2))\n", "print(num_traces)" ] }, { "cell_type": "markdown", "metadata": { "id": "E0RxllJzkGfo" }, "source": [ "この再トレースの動作は `tf.Variable` と同じです。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.880315Z", "iopub.status.busy": "2022-12-14T20:37:07.880080Z", "iopub.status.idle": "2022-12-14T20:37:07.916855Z", "shell.execute_reply": "2022-12-14T20:37:07.916274Z" }, "id": "oWD2f_qxkSe7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "num_traces = 0\n", "@tf.function\n", "def foo(v):\n", " global num_traces\n", " num_traces += 1\n", " return v.read_value()\n", "foo(tf.Variable(1))\n", "foo(tf.Variable(2))\n", "print(num_traces)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxcS6IY8WZuh" }, "source": [ "### 分散ストラテジーとのインタラクション\n", "\n", "`Generator` が分散ストラテジーとインタラクションする方法には 2 つあります。" ] }, { "cell_type": "markdown", "metadata": { "id": "GyZv9QJkZfkQ" }, "source": [ "#### 分散ストラテジーの外部でジェネレータを作成する\n", "\n", "ストラテジーのスコープ外でジェネレータを作成する場合、すべてのレプリカによるジェネレータへのアクセスはシリアライズされるため、レプリカは異なる乱数を取得します。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.919950Z", "iopub.status.busy": "2022-12-14T20:37:07.919704Z", "iopub.status.idle": "2022-12-14T20:37:07.949233Z", "shell.execute_reply": "2022-12-14T20:37:07.948560Z" }, "id": "HX_beT9SZWMp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.43842274, shape=(), dtype=float32)\n", "tf.Tensor(1.6272374, shape=(), dtype=float32)\n" ] } ], "source": [ "g = tf.random.Generator.from_seed(1)\n", "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n", "with strat.scope():\n", " def f():\n", " print(g.normal([]))\n", " results = strat.run(f)" ] }, { "cell_type": "markdown", "metadata": { "id": "ydYQbUqLPAgH" }, "source": [ "この使用方法では、ジェネレータのデバイスがレプリカとは異なるため、パフォーマンスの問題が発生する可能性があります。" ] }, { "cell_type": "markdown", "metadata": { "id": "Yal4LbBKbAeN" }, "source": [ "#### 分散ストラテジーの内部でジェネレータを作成する\n", "\n", "ジェネレータがストラテジーの範囲内で作成される場合、レプリカごとに異なる独立した乱数のストリームが取得されます。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.952497Z", "iopub.status.busy": "2022-12-14T20:37:07.952256Z", "iopub.status.idle": "2022-12-14T20:37:07.982297Z", "shell.execute_reply": "2022-12-14T20:37:07.981548Z" }, "id": "5SeUu7IFmTyQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-0.87930447, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.020661574, shape=(), dtype=float32)\n", "}\n", "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-1.5822568, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.77539235, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n", "with strat.scope():\n", " g = tf.random.Generator.from_seed(1)\n", " print(strat.run(lambda: g.normal([])))\n", " print(strat.run(lambda: g.normal([])))" ] }, { "cell_type": "markdown", "metadata": { "id": "PFBlrOudfu9u" }, "source": [ "注意: 現在、`tf.random.Generator` には異なるレプリカに異なるストリームの代わりに同一のストリームを取得させるオプションはありません(厳密には困難なことではありませんが)。この特徴量のユースケースがある場合は、TensorFlow 開発者に知らせてください。\n", "\n", "ジェネレータがシードされる場合(`Generator.from_seed` で作成されるなど)、レプリカごとに異なる無関係な数字が取得されるにも関わらず、乱数はシードによって決定されます。レプリカで生成される乱数をレプリカ ID のハッシュやすべてのレプリカに共通するランダムな「素」数として考えることができます。そのため、システム全体は依然として確定的なままです。\n", "\n", "`tf.random.Generator` は、`Strategy.run` 内にも作成できます。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:07.985617Z", "iopub.status.busy": "2022-12-14T20:37:07.985356Z", "iopub.status.idle": "2022-12-14T20:37:08.026073Z", "shell.execute_reply": "2022-12-14T20:37:08.025507Z" }, "id": "nlQXi5Msb1Wu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor([-0.87930447 -1.5822568 ], shape=(2,), dtype=float32),\n", " 1: tf.Tensor([0.02066157 0.77539235], shape=(2,), dtype=float32)\n", "}\n", "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor([-0.87930447 -1.5822568 ], shape=(2,), dtype=float32),\n", " 1: tf.Tensor([0.02066157 0.77539235], shape=(2,), dtype=float32)\n", "}\n" ] } ], "source": [ "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n", "with strat.scope():\n", " def f():\n", " g = tf.random.Generator.from_seed(1)\n", " a = g.normal([])\n", " b = g.normal([])\n", " return tf.stack([a, b])\n", " print(strat.run(f))\n", " print(strat.run(f))" ] }, { "cell_type": "markdown", "metadata": { "id": "4Sv-aiaOmrOr" }, "source": [ "`tf.random.Generator` を `Strategy.run` の引数として渡すことは推奨されなくなりました。`Strategy.run` が一般的にジェネレータではなくテンソルの引数を期待するためです。" ] }, { "cell_type": "markdown", "metadata": { "id": "8RbM4vabtiWM" }, "source": [ "### ジェネレータを保存する\n", "\n", "一般的に保存またはシリアル化については、`tf.random.Generator` を `tf.Variable` や `tf.Module`(またはそのサブクラス)と同じように扱うことができます。TF には、[チェックポイント](https://www.tensorflow.org/guide/checkpoint)と [SavedModel](https://www.tensorflow.org/guide/saved_model) という 2 つのシリアル化の仕組みがあります。" ] }, { "cell_type": "markdown", "metadata": { "id": "PDtySQDotWQc" }, "source": [ "#### チェックポイント\n", "\n", "ジェネレータは `tf.train.Checkpoint` を使って保存と復元を自在に行うことができます。復元ポイントの乱数ストリームは保存ポイントの乱数ストリームと同じになります。 " ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.029362Z", "iopub.status.busy": "2022-12-14T20:37:08.029128Z", "iopub.status.idle": "2022-12-14T20:37:08.037449Z", "shell.execute_reply": "2022-12-14T20:37:08.036803Z" }, "id": "uB_bDSbzpbne" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.43842277, shape=(), dtype=float32)\n" ] } ], "source": [ "filename = \"./checkpoint\"\n", "g = tf.random.Generator.from_seed(1)\n", "cp = tf.train.Checkpoint(generator=g)\n", "print(g.normal([]))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.040594Z", "iopub.status.busy": "2022-12-14T20:37:08.040054Z", "iopub.status.idle": "2022-12-14T20:37:08.061218Z", "shell.execute_reply": "2022-12-14T20:37:08.060607Z" }, "id": "bKKtRWeIkIjX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from saving point:\n", "tf.Tensor(1.6272374, shape=(), dtype=float32)\n", "tf.Tensor(1.6307176, shape=(), dtype=float32)\n" ] } ], "source": [ "cp.write(filename)\n", "print(\"RNG stream from saving point:\")\n", "print(g.normal([]))\n", "print(g.normal([]))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.064004Z", "iopub.status.busy": "2022-12-14T20:37:08.063773Z", "iopub.status.idle": "2022-12-14T20:37:08.076102Z", "shell.execute_reply": "2022-12-14T20:37:08.075529Z" }, "id": "-cIHcHwRkQp3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from restoring point:\n", "tf.Tensor(1.6272374, shape=(), dtype=float32)\n", "tf.Tensor(1.6307176, shape=(), dtype=float32)\n" ] } ], "source": [ "cp.restore(filename)\n", "print(\"RNG stream from restoring point:\")\n", "print(g.normal([]))\n", "print(g.normal([]))" ] }, { "cell_type": "markdown", "metadata": { "id": "A-OeUUQEJ37X" }, "source": [ "また、分散ストラテジー内でも保存と復元を実行することができます。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.078954Z", "iopub.status.busy": "2022-12-14T20:37:08.078724Z", "iopub.status.idle": "2022-12-14T20:37:08.092898Z", "shell.execute_reply": "2022-12-14T20:37:08.092352Z" }, "id": "3aI6TQ2lq28w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-0.87930447, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.020661574, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "filename = \"./checkpoint\"\n", "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n", "with strat.scope():\n", " g = tf.random.Generator.from_seed(1)\n", " cp = tf.train.Checkpoint(my_generator=g)\n", " print(strat.run(lambda: g.normal([])))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.095721Z", "iopub.status.busy": "2022-12-14T20:37:08.095494Z", "iopub.status.idle": "2022-12-14T20:37:08.112706Z", "shell.execute_reply": "2022-12-14T20:37:08.112129Z" }, "id": "kTZcdaMwkvJI" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from saving point:\n", "PerReplica:{\n", " 0: tf.Tensor(-1.5822568, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.77539235, shape=(), dtype=float32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor(-0.5039703, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.1251838, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "with strat.scope():\n", " cp.write(filename)\n", " print(\"RNG stream from saving point:\")\n", " print(strat.run(lambda: g.normal([])))\n", " print(strat.run(lambda: g.normal([])))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.115534Z", "iopub.status.busy": "2022-12-14T20:37:08.115264Z", "iopub.status.idle": "2022-12-14T20:37:08.135548Z", "shell.execute_reply": "2022-12-14T20:37:08.134910Z" }, "id": "nizFA5IrkzN1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from restoring point:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-1.5822568, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.77539235, shape=(), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-0.5039703, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.1251838, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "with strat.scope():\n", " cp.restore(filename)\n", " print(\"RNG stream from restoring point:\")\n", " print(strat.run(lambda: g.normal([])))\n", " print(strat.run(lambda: g.normal([])))" ] }, { "cell_type": "markdown", "metadata": { "id": "Z2rsPfp9J6JA" }, "source": [ "保存する前に、レプリカが RNG 呼び出し履歴で分岐しないことを確認する必要があります(1 つのレプリカが 1 つの RNG 呼び出しを行い、もう 1 つが 2 つの RNG 呼び出しを行うなど)。そうでない場合、RNG の内部状態が分岐してしまい、最初のレプリカの状態のみを保存する `tf.train.Checkpoint` が、すべてのレプリカを適切に復元できなくなります。\n", "\n", "また、保存したチェックポイントをレプリカ数の異なる別の分散ストラテジーに復元することもできます。ストラテジー内で作成された `tf.random.Generator` オブジェクトは、同じストラテジーでのみ使用できるため、異なるストラテジーを復元するには以下の例に示すとおり、目的のストラテジーで新しい `tf.random.Generator` を作成し、それに使用する新しい `tf.train.Checkpoint` を作成する必要があります。" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.138913Z", "iopub.status.busy": "2022-12-14T20:37:08.138655Z", "iopub.status.idle": "2022-12-14T20:37:08.152834Z", "shell.execute_reply": "2022-12-14T20:37:08.152266Z" }, "id": "zgoFRf59-IvW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-0.87930447, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.020661574, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "filename = \"./checkpoint\"\n", "strat1 = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n", "with strat1.scope():\n", " g1 = tf.random.Generator.from_seed(1)\n", " cp1 = tf.train.Checkpoint(my_generator=g1)\n", " print(strat1.run(lambda: g1.normal([])))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.155709Z", "iopub.status.busy": "2022-12-14T20:37:08.155487Z", "iopub.status.idle": "2022-12-14T20:37:08.172216Z", "shell.execute_reply": "2022-12-14T20:37:08.171654Z" }, "id": "Lu79ETxMlDpO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from saving point:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-1.5822568, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.77539235, shape=(), dtype=float32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor(-0.5039703, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.1251838, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "with strat1.scope():\n", " cp1.write(filename)\n", " print(\"RNG stream from saving point:\")\n", " print(strat1.run(lambda: g1.normal([])))\n", " print(strat1.run(lambda: g1.normal([])))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.175093Z", "iopub.status.busy": "2022-12-14T20:37:08.174838Z", "iopub.status.idle": "2022-12-14T20:37:08.223308Z", "shell.execute_reply": "2022-12-14T20:37:08.222718Z" }, "id": "VYoRFUjklKOk" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1', '/job:localhost/replica:0/task:0/device:CPU:2')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from restoring point:\n", "PerReplica:{\n", " 0: tf.Tensor(-1.5822568, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.77539235, shape=(), dtype=float32),\n", " 2: tf.Tensor(0.6851049, shape=(), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-0.5039703, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.1251838, shape=(), dtype=float32),\n", " 2: tf.Tensor(-0.58519536, shape=(), dtype=float32)\n", "}\n" ] } ], "source": [ "strat2 = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\", \"cpu:2\"])\n", "with strat2.scope():\n", " g2 = tf.random.Generator.from_seed(1)\n", " cp2 = tf.train.Checkpoint(my_generator=g2)\n", " cp2.restore(filename)\n", " print(\"RNG stream from restoring point:\")\n", " print(strat2.run(lambda: g2.normal([])))\n", " print(strat2.run(lambda: g2.normal([])))" ] }, { "cell_type": "markdown", "metadata": { "id": "kMltUKbANqgl" }, "source": [ "`g1` と `cp1` は `g2` と `cp2` からの別々のオブジェクトですが、共通の `filename` というチェックポイントファイルと `my_generator` というオブジェクト名を介してリンクされています。ストラテジー間で重複するレプリカ(上の `cpu:0` To `cpu:1` など)は、前の例と同様に、RNG ストリームが適切に復元されます。この動作は、ジェネレータが 1 つのストラテジーの範囲内で保存されており、ストラテジーの範囲外で復元される場合やその逆の場合には保証されません。ストラテジー外のデバイスは、ストラテジー内のレプリカと異なって処理されるためです。" ] }, { "cell_type": "markdown", "metadata": { "id": "w9dqrp1LnTaJ" }, "source": [ "#### SavedModel\n", "\n", "`tf.random.Generator` は、SavedModel に保存可能です。ジェネレータはストラテジーの範囲内で作成できます。保存もストラテジーの範囲内で行われます。 " ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.227136Z", "iopub.status.busy": "2022-12-14T20:37:08.226483Z", "iopub.status.idle": "2022-12-14T20:37:08.345914Z", "shell.execute_reply": "2022-12-14T20:37:08.345281Z" }, "id": "0AKO5SnUtyqx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(-1.4154755, shape=(), dtype=float32),\n", " 1: tf.Tensor(-0.11388441, shape=(), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "state: tf.Tensor([256 0 0], shape=(3,), dtype=int64)\n" ] } ], "source": [ "filename = \"./saved_model\"\n", "\n", "class MyModule(tf.Module):\n", "\n", " def __init__(self):\n", " super(MyModule, self).__init__()\n", " self.g = tf.random.Generator.from_seed(0)\n", "\n", " @tf.function\n", " def __call__(self):\n", " return self.g.normal([])\n", "\n", " @tf.function\n", " def state(self):\n", " return self.g.state\n", "\n", "strat = tf.distribute.MirroredStrategy(devices=[\"cpu:0\", \"cpu:1\"])\n", "with strat.scope():\n", " m = MyModule()\n", " print(strat.run(m))\n", " print(\"state:\", m.state())" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.348938Z", "iopub.status.busy": "2022-12-14T20:37:08.348692Z", "iopub.status.idle": "2022-12-14T20:37:08.424069Z", "shell.execute_reply": "2022-12-14T20:37:08.423431Z" }, "id": "jg2148hulfLB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: ./saved_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from saving point:\n", "PerReplica:{\n", " 0: tf.Tensor(-0.68758255, shape=(), dtype=float32),\n", " 1: tf.Tensor(0.8084062, shape=(), dtype=float32)\n", "}\n", "state: tf.Tensor([512 0 0], shape=(3,), dtype=int64)\n", "PerReplica:{\n", " 0: tf.Tensor(-0.27342677, shape=(), dtype=float32),\n", " 1: tf.Tensor(-0.53093255, shape=(), dtype=float32)\n", "}\n", "state: tf.Tensor([768 0 0], shape=(3,), dtype=int64)\n" ] } ], "source": [ "with strat.scope():\n", " tf.saved_model.save(m, filename)\n", " print(\"RNG stream from saving point:\")\n", " print(strat.run(m))\n", " print(\"state:\", m.state())\n", " print(strat.run(m))\n", " print(\"state:\", m.state())" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.427213Z", "iopub.status.busy": "2022-12-14T20:37:08.426964Z", "iopub.status.idle": "2022-12-14T20:37:08.481859Z", "shell.execute_reply": "2022-12-14T20:37:08.481084Z" }, "id": "93AgVyzOllG7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNG stream from loading point:\n", "state: tf.Tensor([256 0 0], shape=(3,), dtype=int64)\n", "tf.Tensor(-1.0359411, shape=(), dtype=float32)\n", "state: tf.Tensor([512 0 0], shape=(3,), dtype=int64)\n", "tf.Tensor(-0.06425078, shape=(), dtype=float32)\n", "state: tf.Tensor([768 0 0], shape=(3,), dtype=int64)\n" ] } ], "source": [ "imported = tf.saved_model.load(filename)\n", "print(\"RNG stream from loading point:\")\n", "print(\"state:\", imported.state())\n", "print(imported())\n", "print(\"state:\", imported.state())\n", "print(imported())\n", "print(\"state:\", imported.state())" ] }, { "cell_type": "markdown", "metadata": { "id": "sbb23j3pZNNq" }, "source": [ "`tf.random.Generator` を含む SavedModel を分散ストラテジーに読み込むのは、レプリカがすべて同じ乱数ストリームを生成するため、推奨されません。これは、レプリカの ID が SavedModel のグラフで凍結しているために起こります。\n", "\n", "上記の例のように、分散された `tf.random.Generator`(分散ストラテジーで作成されたジェネレータ)を非ストラテジー環境に読み込む場合にも注意が必要です。RNG の状態は適切に復元されますが、生成される乱数は元のストラテジーのジェネレータとは異なります(やはり、ストラテジー外のデバイスは、ストラテジー内のrプリカとは異なって扱われるためです)。" ] }, { "cell_type": "markdown", "metadata": { "id": "73an1POpsi6V" }, "source": [ "## ステートレス RNG\n", "\n", "ステートレス RNG の使用法はシンプルです。純粋関数であるため、ステートや副次的効果はありません。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:37:08.485212Z", "iopub.status.busy": "2022-12-14T20:37:08.484953Z", "iopub.status.idle": "2022-12-14T20:37:08.492581Z", "shell.execute_reply": "2022-12-14T20:37:08.492010Z" }, "id": "0-aOOA3gasn_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0.5441101 0.20738031 0.07356433]\n", " [ 0.04643455 -1.30159 -0.95385665]], shape=(2, 3), dtype=float32)\n", "tf.Tensor(\n", "[[ 0.5441101 0.20738031 0.07356433]\n", " [ 0.04643455 -1.30159 -0.95385665]], shape=(2, 3), dtype=float32)\n" ] } ], "source": [ "print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))\n", "print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "2O_D-RAFNH2Q" }, "source": [ "各ステートレス RNG には、`seed` 引数が必要です。これは、形状 `[2]` の整数テンソルである必要があります。この演算の結果はこのシードによって完全に決定されます。\n", "\n", "ステートレス RNG で使用される RNG アルゴリズムはデバイス依存型であるため、異なるデバイスで実行している同一の演算の出力は異なることがあります。" ] }, { "cell_type": "markdown", "metadata": { "id": "4BvGkPnaOUPF" }, "source": [ "## アルゴリズム" ] }, { "cell_type": "markdown", "metadata": { "id": "58-8kvR4pRwO" }, "source": [ "### 全般\n", "\n", "`tf.random.Generator` クラスと `stateless` 関数は、すべてのデバイスで Philox アルゴリズム(`\"philox\"` または `tf.random.Algorithm.PHILOX` として記述されている)をサポートします。\n", "\n", "異なるデバイスで同じアルゴリズムを使用し、同じ状態から開始した場合、同じ整数が生成されます。また、デバイスが浮動小数計算の実行の仕方がデバイスによって異なるため(還元順など)わずかな数値の違いが出るかもしれませんが、「ほぼ同じ」浮動小数点数も生成します。" ] }, { "cell_type": "markdown", "metadata": { "id": "WETA04F1OYPL" }, "source": [ "### XLA デバイス\n", "\n", "XLA 駆動のデバイス(TPU など、および XLA が有効で和える場合の CPU/GPU)では、ThreeFry アルゴリズム(`\"threefry\"` また `tf.random.Algorithm.THREEFRY`)もサポートされています。このアルゴリズムは TPU では高速ですが、Philox と比べると、CPU/GPU では遅くなります。 " ] }, { "cell_type": "markdown", "metadata": { "id": "c04JkebCPTPu" }, "source": [ "これらのアルゴリズムに関する詳細は、「[Parallel Random Numbers: As Easy as 1, 2, 3](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)」をご覧ください。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "random_numbers.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }