{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "MhoQ0WE77laV" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T18:25:56.638149Z", "iopub.status.busy": "2024-01-11T18:25:56.637919Z", "iopub.status.idle": "2024-01-11T18:25:56.641768Z", "shell.execute_reply": "2024-01-11T18:25:56.641064Z" }, "id": "_ckMIh7O7s6D" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "jYysdyb-CaWM" }, "source": [ "# 分散入力" ] }, { "cell_type": "markdown", "metadata": { "id": "S5Uhzt6vVIB2" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示Google Colabで実行GitHubでソースを表示ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "FbVhjPpzn6BM" }, "source": [ "[tf.distribute](https://www.tensorflow.org/guide/distributed_training) API は、1 台のコンピュータから複数のコンピュータにトレーニングを簡単にスケーリングする方法を提供します。モデルをスケーリングする際には、ユーザーは入力を複数のデバイスに分散する必要がありますが、`tf.distribute` は、入力を自動的にデバイスに分散できる API を提供します。\n", "\n", "このガイドは、`tf.distribute` API を使用して、分散データセットとイテレータを作成するためのさまざまな方法を見ていきます。さらに、次のトピックについても説明しています。\n", "\n", "- `tf.distribute.Strategy.experimental_distribute_dataset` と `tf.distribute.Strategy.distribute_datasets_from_function` の使用方法、およびこれらを使用したシャーディングとバッチオプション\n", "- 分散データセットのさまざまなイテレーション方法\n", "- `tf.distribute.Strategy.experimental_distribute_dataset`/`tf.distribute.Strategy.distribute_datasets_from_function` API と `tf.data` API の違い、および使用時の制限\n", "\n", "このガイドでは、Keras API を使用した分散入力の使用方法は説明されていません。" ] }, { "cell_type": "markdown", "metadata": { "id": "MM6W__qraV55" }, "source": [ "## 分散データセット" ] }, { "cell_type": "markdown", "metadata": { "id": "lNy9GxjSlMKQ" }, "source": [ "`tf.distribute` API を使用してスケーリングするには、`tf.data.Dataset` を使って入力を表現します。`tf.distribute` は、パフォーマンス最適化を定期的に実装に統合しながら、`tf.data.Dataset` と効率的に動作します (各アクセラレータデバイスへのデータの自動プリフェッチ機能、定期的なパフォーマンスの更新など)。`tf.data.Dataset` 以外を使用するユースケースがある場合は、このガイドの [Tensor 入力セクション](#tensorinputs)を参照してください。非分散型トレーニングループでは、`tf.data.Dataset` インスタンスを作成してから要素をイテレートします。次に例を示します。\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:25:56.645433Z", "iopub.status.busy": "2024-01-11T18:25:56.645187Z", "iopub.status.idle": "2024-01-11T18:25:58.999612Z", "shell.execute_reply": "2024-01-11T18:25:58.998919Z" }, "id": "pCu2Jj-21AEf" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 18:25:57.072759: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 18:25:57.072803: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 18:25:57.074336: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2.15.0\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "# Helper libraries\n", "import numpy as np\n", "import os\n", "\n", "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:25:59.003650Z", "iopub.status.busy": "2024-01-11T18:25:59.003052Z", "iopub.status.idle": "2024-01-11T18:25:59.623568Z", "shell.execute_reply": "2024-01-11T18:25:59.622856Z" }, "id": "6cnilUtmKwpa" }, "outputs": [], "source": [ "# Simulate multiple CPUs with virtual devices\n", "N_VIRTUAL_DEVICES = 2\n", "physical_devices = tf.config.list_physical_devices(\"CPU\")\n", "tf.config.set_logical_device_configuration(\n", " physical_devices[0], [tf.config.LogicalDeviceConfiguration() for _ in range(N_VIRTUAL_DEVICES)])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:25:59.626868Z", "iopub.status.busy": "2024-01-11T18:25:59.626623Z", "iopub.status.idle": "2024-01-11T18:26:00.939104Z", "shell.execute_reply": "2024-01-11T18:26:00.938388Z" }, "id": "zd4l1ySeLRk1" }, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Available devices:\n", "0) LogicalDevice(name='/device:CPU:0', device_type='CPU')\n", "1) LogicalDevice(name='/device:CPU:1', device_type='CPU')\n", "2) LogicalDevice(name='/device:GPU:0', device_type='GPU')\n", "3) LogicalDevice(name='/device:GPU:1', device_type='GPU')\n", "4) LogicalDevice(name='/device:GPU:2', device_type='GPU')\n", "5) LogicalDevice(name='/device:GPU:3', device_type='GPU')\n" ] } ], "source": [ "print(\"Available devices:\")\n", "for i, device in enumerate(tf.config.list_logical_devices()):\n", " print(\"%d) %s\" % (i, device))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:00.942771Z", "iopub.status.busy": "2024-01-11T18:26:00.942318Z", "iopub.status.idle": "2024-01-11T18:26:01.375527Z", "shell.execute_reply": "2024-01-11T18:26:01.374850Z" }, "id": "dzLKpmZICaWN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n" ] } ], "source": [ "global_batch_size = 16\n", "# Create a tf.data.Dataset object.\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "\n", "@tf.function\n", "def train_step(inputs):\n", " features, labels = inputs\n", " return labels - 0.3 * features\n", "\n", "# Iterate over the dataset using the for..in construct.\n", "for inputs in dataset:\n", " print(train_step(inputs))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ihrhYDYRrVLH" }, "source": [ "ユーザーの既存のコードへの変更を最小限に抑えて `tf.distribute` ストラテジーを使用できるように、`tf.data.Dataset` インスタンスを分散し、分散化されたデータセットインスタンスを返す、2 つの API が導入されています。その分散化されたデータセットインスタンスを以前と同様にイテレートして、モデルをトレーニングすることができます。では、これら 2 つの API を詳しく見てみましょう。`tf.distribute.Strategy.experimental_distribute_dataset` API と `tf.distribute.Strategy.distribute_datasets_from_function` API です。" ] }, { "cell_type": "markdown", "metadata": { "id": "4AXoHhrsbdF3" }, "source": [ "### `tf.distribute.Strategy.experimental_distribute_dataset`" ] }, { 
"cell_type": "markdown", "metadata": { "id": "5mVuLZhbem8d" }, "source": [ "#### 使い方\n", "\n", "この API は`tf.data.Dataset`インスタンスを入力として取り、`tf.distribute.DistributedDataset`インスタンスを返します。この入力データセットを、グローバルバッチサイズと同じ値でバッチ化します。このグローバルバッチサイズは、1 つのステップで処理する全デバイスのサンプル数です。この分散データセットのイテレーションを Python 式に行うか、`iter` を使用してイテレータを作成します。返されるオブジェクトは`tf.data.Dataset`インスタンスではなく、またデータセットを変換したり検査したりするほかの API をまったくサポートしていません。これは、入力をさまざまなレプリカにシャーディングするための特定の方法がない場合に推奨される API です。\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:01.379119Z", "iopub.status.busy": "2024-01-11T18:26:01.378881Z", "iopub.status.idle": "2024-01-11T18:26:02.660797Z", "shell.execute_reply": "2024-01-11T18:26:02.659984Z" }, "id": "F2VeZUWUj5S4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "}, PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] } ], "source": [ "global_batch_size = 16\n", "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "# Distribute input using the `experimental_distribute_dataset`.\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "# 1 global batch of data fed to the model in 1 step.\n", "print(next(iter(dist_dataset)))" ] }, { "cell_type": "markdown", "metadata": { "id": "QPceDmRht54F" }, "source": [ "#### プロパティ" ] }, { "cell_type": "markdown", "metadata": { "id": "0Qb6nDgxiN_n" }, "source": [ "##### バッチ処理\n", "\n", "`tf.distribute` は、グローバルバッチサイズを同期中のレプリカの数で除算した値に等しい新しいバッチサイズで、入力 `tf.data.Dataset` インスタンスのバッチを再作成します。同期中のレプリカの数は、トレーニング中の勾配の allreduce に参加しているデバイスの数と同等です。ユーザーが分散イテレータで `next` を呼び出すと、レプリカ当たりのデータサイズが各レプリカに返されます。再バッチされたデータセットのカーディナリティは、必ずレプリカ数の倍数になります。次にいくつかの例を示します。\n", "\n", "- `tf.data.Dataset.range(6).batch(4, drop_remainder=False)`\n", "\n", " - 分散無し:\n", "\n", " - バッチ 1: [0, 1, 2, 3]\n", " - バッチ 2: [4, 5]\n", "\n", " - 2 つのレプリカで分散。最後のバッチ([4, 5])は、2つのレプリカ間で分割。\n", "\n", " - バッチ 1:\n", "\n", " - レプリカ 1: [0, 1]\n", " - レプリカ 2: [2, 3]\n", "\n", " - バッチ 2:\n", "\n", " - レプリカ 1: [4]\n", " - レプリカ 2: [5]\n", "\n", "- `tf.data.Dataset.range(4).batch(4)`\n", "\n", " - 分散無し:\n", " - バッチ 1: [0, 1, 2, 3]\n", " - 5 つのレプリカで分散:\n", " - バッチ 1:\n", " - レプリカ 1: [0]\n", " - レプリカ 2: [1]\n", " - レプリカ 3: [2]\n", " - レプリカ 4: [3]\n", " - レプリカ 5: []\n", "\n", "- `tf.data.Dataset.range(8).batch(4)`\n", "\n", " - 分散無し:\n", " - バッチ 1: [0, 1, 2, 3]\n", " - バッチ 2: [4, 5, 6, 7]\n", " - 3 つのレプリカで分散:\n", " - バッチ 1:\n", " - レプリカ 1: [0, 1]\n", " - レプリカ 2: [2, 3]\n", " - レプリカ 3: []\n", " - バッチ 2:\n", " - レプリカ 1: [4, 5]\n", " - レプリカ 2: [6, 7]\n", " - レプリカ 3: []\n", "\n", "注意: 上記の例は、異なるレプリカでグローバルバッチがどのように分割されるかのみを説明しています。実装によって実際の値が異なる可能性があるため、各レプリカで最終的に得られる可能性のある実際の値に依存することはお勧めできません。\n", "\n", "データセットのバッチの再作成には、レプリカの数とともに直線的に増加する空間的コストがあります。つまり、マルチワーカートレーニングのユースケースで言えば、入力パイプラインで OOM エラーが発生する可能性があります。 " ] }, { "cell_type": "markdown", "metadata": { "id": "IszBuubdtydp" }, "source": [ "##### シャーディング\n", "\n", "`tf.distribute` は、`MultiWorkerMirroredStrategy` と `TPUStrategy` のマルチワーカートレーニングで入力データセットの自動シャーディングも行います。各データセットはワーカーの CPU 
{ "cell_type": "markdown", "metadata": { "id": "IszBuubdtydp" }, "source": [ "##### Sharding\n", "\n", "`tf.distribute` also autoshards the input dataset in multi-worker training with `MultiWorkerMirroredStrategy` and `TPUStrategy`. Each dataset is created on the CPU device of the worker. Autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right `tf.data.experimental.AutoShardPolicy` is set). This ensures that at each step, a global batch size of non-overlapping dataset elements is processed by each worker. Autosharding has a couple of different options that can be specified with `tf.data.experimental.DistributeOptions`. Note that there is no autosharding in multi-worker training with `ParameterServerStrategy`; for more on dataset creation with that strategy, check out the [ParameterServerStrategy tutorial](parameter_server_training.ipynb). " ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:02.664611Z", "iopub.status.busy": "2024-01-11T18:26:02.664343Z", "iopub.status.idle": "2024-01-11T18:26:02.674283Z", "shell.execute_reply": "2024-01-11T18:26:02.673614Z" }, "id": "jwJtsCQhHK-E" }, "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)\n", "options = tf.data.Options()\n", "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA\n", "dataset = dataset.with_options(options)" ] },
{ "cell_type": "markdown", "metadata": { "id": "J7fj3GskHC8g" }, "source": [ "You can set `tf.data.experimental.AutoShardPolicy` to one of the following options:\n", "\n", "- AUTO: This is the default option, which means an attempt will be made to shard by FILE. The attempt fails if a file-based dataset is not detected, in which case `tf.distribute` falls back to sharding by DATA. Note that if the input dataset is file-based but the number of files is less than the number of workers, an `InvalidArgumentError` is raised. If this happens, explicitly set the policy to `AutoShardPolicy.DATA`, or split your input source into smaller files such that the number of files is greater than the number of workers.\n", "\n", "- FILE: This is the option to use if you want to shard the input files over all the workers. Use it if the number of input files is much larger than the number of workers and the data in the files is evenly distributed. The downside of this option is having idle workers when the data in the files is not evenly distributed. If the number of files is less than the number of workers, an `InvalidArgumentError` is raised; if this happens, explicitly set the policy to `AutoShardPolicy.DATA`. For example, distribute 2 files over 2 workers with 1 replica each, where File 1 contains [0, 1, 2, 3, 4, 5] and File 2 contains [6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2 and the global batch size be 4.\n", "\n", " - Worker 0:\n", " - Batch 1 = Replica 1: [0, 1]\n", " - Batch 2 = Replica 1: [2, 3]\n", " - Batch 3 = Replica 1: [4]\n", " - Batch 4 = Replica 1: [5]\n", " - Worker 1:\n", " - Batch 1 = Replica 2: [6, 7]\n", " - Batch 2 = Replica 2: [8, 9]\n", " - Batch 3 = Replica 2: [10]\n", " - Batch 4 = Replica 2: [11]\n", "\n", "- DATA: This autoshards the elements across all the workers. Each worker reads the entire dataset and processes only the shard assigned to it; all other shards are discarded. This is generally used when the number of input files is less than the number of workers and you want better sharding of the data across all workers. The downside is that the entire dataset is read on each worker. For example, distribute 1 file over 2 workers, where File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2.\n", "\n", " - Worker 0:\n", " - Batch 1 = Replica 1: [0, 1]\n", " - Batch 2 = Replica 1: [4, 5]\n", " - Batch 3 = Replica 1: [8, 9]\n", " - Worker 1:\n", " - Batch 1 = Replica 2: [2, 3]\n", " - Batch 2 = Replica 2: [6, 7]\n", " - Batch 3 = Replica 2: [10, 11]\n", "\n", "- OFF: If you turn autosharding off, each worker processes all the data. For example, distribute 1 file over 2 workers, where File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. Each worker then sees the following distribution:\n", "\n", " - Worker 0:\n", "\n", " - Batch 1 = Replica 1: [0, 1]\n", " - Batch 2 = Replica 1: [2, 3]\n", " - Batch 3 = Replica 1: [4, 5]\n", " - Batch 4 = Replica 1: [6, 7]\n", " - Batch 5 = Replica 1: [8, 9]\n", " - Batch 6 = Replica 1: [10, 11]\n", "\n", " - Worker 1:\n", "\n", " - Batch 1 = Replica 2: [0, 1]\n", " - Batch 2 = Replica 2: [2, 3]\n", " - Batch 3 = Replica 2: [4, 5]\n", " - Batch 4 = Replica 2: [6, 7]\n", " - Batch 5 = Replica 2: [8, 9]\n", " - Batch 6 = Replica 2: [10, 11] " ] },
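{ "cell_type": "markdown", "metadata": {}, "source": [ "As a sketch, this is how you could request FILE-based sharding for a file-based dataset; the file pattern below is hypothetical:\n", "\n", "```\n", "# Sketch: shard by FILE when there are many input files (paths are hypothetical).\n", "file_dataset = tf.data.Dataset.list_files(\"/path/to/train-*.tfrecord\", shuffle=False)\n", "file_dataset = file_dataset.interleave(tf.data.TFRecordDataset)\n", "options = tf.data.Options()\n", "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.FILE\n", "file_dataset = file_dataset.with_options(options)\n", "```" ] },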
"デフォルトでは、`tf.distribute`はユーザーが提供する`tf.data.Dataset`インスタンスにプリフェッチ変換を追加します。プリフェッチ変換に対する引数`buffer_size`は同期中のレプリカの数と同等です。" ] }, { "cell_type": "markdown", "metadata": { "id": "PjiGSY3gtr6_" }, "source": [ "### `tf.distribute.Strategy.distribute_datasets_from_function`" ] }, { "cell_type": "markdown", "metadata": { "id": "bAXAo_wWbWSb" }, "source": [ "#### 使い方\n", "\n", "この API は、入力関数を取って `tf.distribute.DistributedDataset` インスタンスを返します。ユーザーが渡す入力関数には `tf.distribute.InputContext` 引数があり、`tf.data.Dataset` インスタンスを返します。この API を使用すると、`tf.distribute` は、入力関数から返されたユーザーの `tf.data.Dataset` インスタンスにそれ以降の変更を適用しません。そのため、ユーザーがデータセットをバッチ処理してシャーディングする必要があります。`tf.distribute` は各ワーカーの CPU デバイスで入力関数を呼び出します。ユーザーが独自のバッチングとシャーディングのロジックを指定できるほか、この API は、マルチワーカートレーニングに使用される場合に、`tf.distribute.Strategy.experimental_distribute_dataset` よりも優れたスケーラビリティとパフォーマンスを示します。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:02.677548Z", "iopub.status.busy": "2024-01-11T18:26:02.677288Z", "iopub.status.idle": "2024-01-11T18:26:02.692418Z", "shell.execute_reply": "2024-01-11T18:26:02.691717Z" }, "id": "9ODch-OFCaW4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "def dataset_fn(input_context):\n", " batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n", " dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)\n", " dataset = dataset.shard(\n", " input_context.num_input_pipelines, input_context.input_pipeline_id)\n", " dataset = dataset.batch(batch_size)\n", " dataset = dataset.prefetch(2) # This prefetches 2 batches per device.\n", " return dataset\n", "\n", "dist_dataset = mirrored_strategy.distribute_datasets_from_function(dataset_fn)" ] }, { "cell_type": "markdown", "metadata": { "id": "M1bpzPYzt_R7" }, "source": [ "#### プロパティ" ] }, { "cell_type": "markdown", "metadata": { "id": "7cgzhwiiuBvO" }, "source": [ "##### バッチ処理\n", "\n", "入力関数の戻り値である`tf.data.Dataset`インスタンスは、レプリカごとのバッチサイズを使用してバッチ処理する必要があります。レプリカごとのバッチサイズは、グローバルバッチサイズを同期型トレーニングに参加しているレプリカの数で除算した値です。これは、`tf.distribute`が各ワーカーの CPU デバイスで入力関数を呼び出すためです。あるワーカーで作成されるデータセットは、そのワーカーのすべてのレプリカで使用する準備を整えています。 " ] }, { "cell_type": "markdown", "metadata": { "id": "e-wlFFZbP33n" }, "source": [ "##### シャーディング\n", "\n", "ユーザーの入力関数への引数として暗黙的に渡される`tf.distribute.InputContext`オブジェクトは、内部的に`tf.distribute`よって作成されます。このオブジェクトには、ワーカー数、現在のワーカー ID などの情報が含まれます。この入力関数は、`tf.distribute.InputContext`オブジェクトの一部であるプロパティを使用し、ユーザーが設定したポリシーに従って、シャーディングを処理することができます。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "7TGwnDM-ICHf" }, "source": [ "##### プリフェッチ\n", "\n", "`tf.distribute` は、ユーザーが提供する入力関数によって返される `tf.data.Dataset` の最後に、プリフェッチ変換を追加しません。したがって、上記の例では明示的に `Dataset.prefetch` を呼び出します。" ] }, { "cell_type": "markdown", "metadata": { "id": "iOMsf8kyZZpv" }, "source": [ "注意:`tf.distribute.Strategy.experimental_distribute_dataset` と `tf.distribute.Strategy.distribute_datasets_from_function` は両方とも、**`tf.data.Dataset` 型ではない `tf.distribute.DistributedDataset` インスタンス**を返します。これらのインスタンスをイテレートし(「分散イテレータ」を参照)、`element_spec` プロパティを使用することができます。 " ] }, { "cell_type": "markdown", "metadata": { "id": "dL3XbI1gzEjO" }, "source": [ "## 分散イテレータ" ] }, { "cell_type": 
"markdown", "metadata": { "id": "w8y54-o9T2Ni" }, "source": [ "非分散型`tf.data.Dataset`インスタンスと同様に、`tf.distribute.DistributedDataset`インスタンスを作成してイテレートし、`tf.distribute.DistributedDataset`の要素にアクセスする必要があります。次に、`tf.distribute.DistributedIterator`を作成して、それをモデルのトレーニングに使用する方法を示します。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FlKh8NV0uOtZ" }, "source": [ "### 使用方法" ] }, { "cell_type": "markdown", "metadata": { "id": "eSZz6EqOuSlB" }, "source": [ "#### Python 式の for ループコンストラクトを使用する\n", "\n", "ユーザーフレンドリーな Python 式のループを使用して、`tf.distribute.DistributedDataset` をイテレートすることができます。`tf.distribute.DistributedIterator` から返される要素は、単一の`tf.Tensor` か、レプリカあたりの値を含む `tf.distribute.DistributedValues` です。`tf.function` にループを配置すると、パフォーマンスは上昇しますが、`break` と`return` は、現在、`tf.function` 内に配置された `tf.distribute.DistributedDataset` のループではサポートされていません。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:02.696252Z", "iopub.status.busy": "2024-01-11T18:26:02.695868Z", "iopub.status.idle": "2024-01-11T18:26:03.173762Z", "shell.execute_reply": "2024-01-11T18:26:03.172944Z" }, "id": "zt3AHb46Tr3w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 
{ "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:02.696252Z", "iopub.status.busy": "2024-01-11T18:26:02.695868Z", "iopub.status.idle": "2024-01-11T18:26:03.173762Z", "shell.execute_reply": "2024-01-11T18:26:03.172944Z" }, "id": "zt3AHb46Tr3w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n", " 1: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n", " 2: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n", " 3: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32)\n", "}\n" ] } ], "source": [ "global_batch_size = 16\n", "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "\n", "@tf.function\n", "def train_step(inputs):\n", " features, labels = inputs\n", " return labels - 0.3 * features\n", "\n", "for x in dist_dataset:\n", " # train_step trains the model using the dataset elements\n", " loss = mirrored_strategy.run(train_step, args=(x,))\n", " print(\"Loss is \", loss)" ] },
{ "cell_type": "markdown", "metadata": { "id": "NchPwTEiuSqb" }, "source": [ "#### Use `iter` to create an explicit iterator\n", "\n", "To iterate over the elements of a `tf.distribute.DistributedDataset` instance, you can create a `tf.distribute.DistributedIterator` using the `iter` API on it. With an explicit iterator, you can iterate for a fixed number of steps. To get the next element from a `tf.distribute.DistributedIterator` instance `dist_iterator`, you can call `next(dist_iterator)`, `dist_iterator.get_next()`, or `dist_iterator.get_next_as_optional()`. The former two are essentially the same:" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:03.177533Z", "iopub.status.busy": "2024-01-11T18:26:03.176838Z", "iopub.status.idle": "2024-01-11T18:26:06.349981Z", "shell.execute_reply": "2024-01-11T18:26:06.349236Z" }, "id": "OrMmakq5EqeQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " 
[0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " 
[0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " 
[0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] } ], "source": [ "num_epochs = 10\n", "steps_per_epoch = 5\n", "for epoch in range(num_epochs):\n", " dist_iterator = iter(dist_dataset)\n", " for step in range(steps_per_epoch):\n", " # train_step trains the model using the dataset elements\n", " loss = mirrored_strategy.run(train_step, args=(next(dist_iterator),))\n", " # which is the same as\n", " # loss = mirrored_strategy.run(train_step, 
args=(dist_iterator.get_next(),))\n", " print(\"Loss is \", loss)" ] },
{ "cell_type": "markdown", "metadata": { "id": "UpJXIlxjqPYg" }, "source": [ "If `next` or `tf.distribute.DistributedIterator.get_next` is used and the `tf.distribute.DistributedIterator` has reached its end, an OutOfRange error is thrown. The client can catch the error on the Python side and continue doing other work, such as checkpointing and evaluation. However, this does not work if you are using a host training loop (that is, running multiple steps per `tf.function`), which looks like:\n", "\n", "```\n", "@tf.function\n", "def train_fn(iterator):\n", " for _ in tf.range(steps_per_loop):\n", " strategy.run(step_fn, args=(next(iterator),))\n", "```\n", "\n", "`train_fn` contains multiple steps by wrapping the step body inside a `tf.range`. In this case, different iterations in the loop with no dependency could start in parallel, so an OutOfRange error can be triggered in later iterations before the computation of previous iterations has finished. Once an OutOfRange error is thrown, all the ops in the function are terminated right away. If this is a case you would like to avoid, an alternative that does not throw an OutOfRange error is `tf.distribute.DistributedIterator.get_next_as_optional`. `get_next_as_optional` returns a `tf.experimental.Optional` that contains the next element, or no value if the `tf.distribute.DistributedIterator` has reached its end." ] },
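{ "cell_type": "markdown", "metadata": {}, "source": [ "For reference, catching the error on the Python side (outside a multi-step `tf.function`) is a simple try/except. This sketch reuses the `dist_dataset`, `mirrored_strategy`, and `train_step` defined earlier:\n", "\n", "```\n", "# Sketch: catch the end of a distributed iterator in eager Python code.\n", "dist_iterator = iter(dist_dataset)\n", "while True:\n", "  try:\n", "    batch = dist_iterator.get_next()  # Raises OutOfRangeError at the end.\n", "  except tf.errors.OutOfRangeError:\n", "    break  # The dataset is exhausted: checkpoint or evaluate here.\n", "  loss = mirrored_strategy.run(train_step, args=(batch,))\n", "```\n", "\n", "The cell below shows the `get_next_as_optional` alternative, which avoids raising the error altogether." ] },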
"pg3B-Cw_cn3a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: 
,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: 
\n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] } ], "source": [ "global_batch_size = 16\n", "epochs = 5\n", "steps_per_epoch = 5\n", "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "\n", "@tf.function(input_signature=[dist_dataset.element_spec])\n", "def train_step(per_replica_inputs):\n", " def step_fn(inputs):\n", " return 2 * inputs\n", "\n", " return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))\n", "\n", "for _ in range(epochs):\n", " iterator = iter(dist_dataset)\n", " for _ in range(steps_per_epoch):\n", " output = train_step(next(iterator))\n", " tf.print(output)" ] }, { "cell_type": "markdown", "metadata": { "id": "-OAa6svUzuWm" }, "source": [ "## データの前処理をする" ] }, { "cell_type": "markdown", "metadata": { "id": "pSMrs3kJQexW" }, "source": [ "以上、`tf.data.Dataset` を分散する方法を見てきましたが、データをモデルに使用する前に、データのクレンジング、変換、拡張などの前処理を行う必要があります。 次の 2 つの便利なツールを利用できます。\n", "\n", "- [Keras 前処理レイヤー](https://www.tensorflow.org/guide/keras/preprocessing_layers): 開発者が Keras ネイティブの入力処理パイプラインを構築できるようにする一連の Keras レイヤーです。一部の Keras 前処理レイヤーには、初期化または `adapt` 時に設定できるトレーニング不可能な状態が含まれています ([Keras 前処理レイヤー ガイド](https://www.tensorflow.org/guide/keras/preprocessing_layers)の `adapt` セクションを参照してください)。 ステートフルな前処理レイヤーを分散する場合、状態をすべてのワーカーに複製する必要があります。これらのレイヤーを使用するには、それらをモデルの一部にするか、データセットに適用します。\n", "\n", "- [TensorFlow Transform (tf.Transform)](https://www.tensorflow.org/tfx/transform/get_started): データ前処理パイプラインを介してインスタンスレベルとフルパスの両方のデータ変換を定義するための TensorFlow のライブラリです。Tensorflow Transform には 2 つのフェーズがあります。 1 つ目は分析フェーズです。ここでは、生のトレーニングデータがフルパスプロセスで分析され、変換に必要な統計が計算され、変換ロジックがインスタンスレベルの演算として生成されます。2 つ目は変換フェーズで、生のトレーニングデータがインスタンスレベルのプロセスで変換されます。\n" ] }, { "cell_type": "markdown", "metadata": { "id": 
"Pd4aUCFdVlZ1" }, "source": [ "### Keras 前処理レイヤーと Tensorflow Transform の比較\n", "\n", "Tensorflow Transform と Keras の前処理レイヤーはどちらも、トレーニング時の前処理を分割し、推論中に前処理をモデルにバンドルして、トレーニング/サーブ スキューを減らす方法を提供します。\n", "\n", "[TFX](https://www.tensorflow.org/tfx) と密に統合された Tensorflow Transform は、トレーニングパイプラインとは別のジョブで、あらゆるサイズのデータセットを分析および変換するスケーラブルな map-reduce ソリューションを提供します。単一のマシンに収まらないデータセットで分析を実行する必要がある場合は、Tensorflow Transform が最初の選択肢になります。\n", "\n", "Keras 前処理レイヤーは、ディスクからデータを読み取った後、トレーニング時に前処理を適用する場合に最適です。これらは、Keras ライブラリのモデル開発にシームレスに適合し、[`adapt`](https://www.tensorflow.org/guide/keras/preprocessing_layers#the_adapt_method) による小規模なデータセットの分析をサポートします。Keras 前処理レイヤーは、画像データの拡張などのユースケースをサポートし、入力データセットを通過するたびに、トレーニング用のさまざまな例が生成されます。\n", "\n", "2 つのライブラリを混在させることもできます。この場合、Tensorflow Transform は入力データの分析と静的変換に、Keras 前処理レイヤーはトレーニング時の変換(One-Hot エンコーディングやデータ拡張など)に使用されます。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MReKhhZpHUpj" }, "source": [ "### tf.distribute のベスト プラクティス\n", "\n", "両方のツールを使用する場合、変換ロジックを初期化してデータに適用する必要があります。これにより Tensorflow リソースが作成される場合があります。これらのリソースまたは状態は、ワーカー間またはワーカーとコーディネーター間の通信を節約するために、すべてのワーカーに複製する必要があります。そのためには、他の Keras レイヤーと同じように Keras 前処理レイヤー、`tft.TFTransformOutput.transform_features_layer`、または `tft.TransformFeaturesLayer` を `tf.distribute.Strategy.scope` の下に作成することをお勧めします。\n", "\n", "次の例は、`tf.distribute.Strategy` API を高レベルの Keras `Model.fit` API およびカスタムトレーニングループと別に使用する方法を示しています。" ] }, { "cell_type": "markdown", "metadata": { "id": "rwEGMWuoX7kJ" }, "source": [ "#### Keras 前処理レイヤーのユーザー向けの注意事項:\n", "\n", "**前処理レイヤーと大きな語彙**\n", "\n", "マルチワーカー設定で大きな語彙 (1 ギガバイト以上) を扱う場合 (`tf.distribute.MultiWorkerMirroredStrategy`、`tf.distribute.experimental.ParameterServerStrategy`、`tf.distribute.TPUStrategy` など)、すべてのワーカーからアクセス可能な静的ファイル (Cloud Storage などを使用する) に語彙を保存することをお勧めします。これにより、トレーニング時に語彙をすべてのワーカーに複製するのにかかる時間を短縮できます。\n", "\n", "**`tf.data` パイプラインでの前処理とモデルでの前処理の比較**\n", "\n", "Keras の前処理レイヤーはモデルの一部として適用することも、`tf.data.Dataset` に直接適用することもできますが、それぞれ利点があります。\n", "\n", "- モデル内に前処理レイヤーを適用すると、モデルが移植可能になり、トレーニング/サービング スキューを低減するのに役立ちます。(詳細については、[前処理レイヤーの使用ガイド](https://www.tensorflow.org/guide/keras/preprocessing_layers#benefits_of_doing_preprocessing_inside_the_model_at_inference_time)の*推論時にモデル内で前処理を行う利点*セクションを参照してください)。\n", "- `tf.data` パイプライン内で適用すると、プリフェッチまたは CPU へのオフロードが可能になり、アクセラレータを使用する際のパフォーマンスが向上します。\n", "\n", "1 つ以上の TPU で実行する場合、ほとんどの場合、ユーザーは Keras 前処理レイヤーを `tf.data` パイプラインに配置する必要があります。すべてのレイヤーは TPU をサポートしていないので、文字列演算は TPU では実行されません。(2 つの例外は、`tf.keras.layers.Normalization` と `tf.keras.layers.Rescaling` です。これらは TPU で正常に動作し、一般的に画像モデルの最初のレイヤーとして使用されます。)" ] }, { "cell_type": "markdown", "metadata": { "id": "hNCYZ9L-BD2R" }, "source": [ "### `Model.fit` で前処理する" ] }, { "cell_type": "markdown", "metadata": { "id": "NhRB2Xe8B6bX" }, "source": [ "Keras `Model.fit` を使用する場合、`tf.distribute.Strategy.experimental_distribute_dataset` や `tf.distribute.Strategy.distribute_datasets_from_function` でデータを分散する必要はありません。詳細については、[前処理レイヤーの使用](https://www.tensorflow.org/guide/keras/preprocessing_layers)ガイドと [Keras を使用した分散トレーニング](https://www.tensorflow.org/tutorials/distribute/keras)ガイドを参照してください。次に簡単な例を示します。\n", "\n", "```\n", "strategy = tf.distribute.MirroredStrategy()\n", "with strategy.scope():\n", " # Create the layer(s) under scope.\n", " integer_preprocessing_layer = tf.keras.layers.IntegerLookup(vocabulary=FILE_PATH)\n", " model = ...\n", " model.compile(...)\n", "dataset = dataset.map(lambda x, y: (integer_preprocessing_layer(x), y))\n", "model.fit(dataset)\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3zL2vzJ-G0yg" }, "source": [ "`Model.fit` API を使用する `tf.distribute.experimental.ParameterServerStrategy` のユーザーは、`tf.keras.utils.experimental.DatasetCreator` を入力として使用する必要があります。(詳細については、[パラメータサーバートレーニング](https://www.tensorflow.org/tutorials/distribute/parameter_server_training#parameter_server_training_with_modelfit_api)ガイドを参照してください)\n", "\n", "```\n", "strategy = tf.distribute.experimental.ParameterServerStrategy(\n", " cluster_resolver,\n", " variable_partitioner=variable_partitioner)\n", "\n", "with strategy.scope():\n", " preprocessing_layer = tf.keras.layers.StringLookup(vocabulary=FILE_PATH)\n", " model = ...\n", " model.compile(...)\n", "\n", "def dataset_fn(input_context):\n", " ...\n", " dataset = dataset.map(preprocessing_layer)\n", " ...\n", " return dataset\n", "\n", "dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)\n", "model.fit(dataset_creator, epochs=5, steps_per_epoch=20, callbacks=callbacks)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "imZLQUOYBJyW" }, "source": [ "### カスタムトレーニングループを使った前処理" ] }, { "cell_type": "markdown", "metadata": { "id": "r2PX1QH_OwU3" }, "source": [ "[カスタムトレーニングループ](https://www.tensorflow.org/tutorials/distribute/custom_training)を作成する場合、`tf.distribute.Strategy.experimental_distribute_dataset` API または `tf.distribute.Strategy.distribute_datasets_from_function` API のいずれかを使用してデータを分散します。`tf.distribute.Strategy.experimental_distribute_dataset` を介してデータセットを分散する場合、これらの前処理 API をデータパイプラインに適用すると、リソースが自動的にデータパイプラインと同じ場所に配置され、リモートリソースアクセスを回避できます。したがって、ここでの例はすべて `tf.distribute.Strategy.distribute_datasets_from_function` を使用します。この場合、これらの API の初期化を `strategy.scope()` の下に配置して効率化することが重要です。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:08.957398Z", "iopub.status.busy": "2024-01-11T18:26:08.956788Z", "iopub.status.idle": "2024-01-11T18:26:09.465469Z", "shell.execute_reply": "2024-01-11T18:26:09.464657Z" }, "id": "wJS1UmcWQeab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor([1], shape=(1,), dtype=int64),\n", " 1: tf.Tensor([3], shape=(1,), dtype=int64),\n", " 2: tf.Tensor([0], shape=(1,), dtype=int64),\n", " 3: tf.Tensor([1], shape=(1,), dtype=int64)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([3], shape=(1,), dtype=int64),\n", " 1: tf.Tensor([0], shape=(1,), dtype=int64),\n", " 2: tf.Tensor([1], shape=(1,), dtype=int64),\n", " 3: tf.Tensor([3], shape=(1,), dtype=int64)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([0], shape=(1,), dtype=int64),\n", " 1: tf.Tensor([1], shape=(1,), dtype=int64),\n", " 2: tf.Tensor([3], shape=(1,), dtype=int64),\n", " 3: tf.Tensor([0], shape=(1,), dtype=int64)\n", "}\n" ] } ], "source": [ "strategy = tf.distribute.MirroredStrategy()\n", "vocab = [\"a\", \"b\", \"c\", \"d\", \"f\"]\n", "\n", "with strategy.scope():\n", " # Create the layer(s) under scope.\n", " layer = tf.keras.layers.StringLookup(vocabulary=vocab)\n", "\n", "def dataset_fn(input_context):\n", " # a tf.data.Dataset\n", " dataset = tf.data.Dataset.from_tensor_slices([\"a\", \"c\", \"e\"]).repeat()\n", "\n", " # Custom your batching, sharding, prefetching, etc.\n", " global_batch_size = 4\n", " 
"  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n", "  dataset = dataset.batch(batch_size)\n", "  dataset = dataset.shard(\n", "      input_context.num_input_pipelines,\n", "      input_context.input_pipeline_id)\n", "\n", "  # Apply the preprocessing layer(s) to the tf.data.Dataset\n", "  def preprocess_with_kpl(input):\n", "    return layer(input)\n", "\n", "  processed_ds = dataset.map(preprocess_with_kpl)\n", "  return processed_ds\n", "\n", "distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n", "\n", "# Print out a few example batches.\n", "distributed_dataset_iterator = iter(distributed_dataset)\n", "for _ in range(3):\n", "  print(next(distributed_dataset_iterator))" ] }, { "cell_type": "markdown", "metadata": { "id": "PVl1cblWQy8b" }, "source": [ "Note that if you are training with `tf.distribute.experimental.ParameterServerStrategy`, you also call `tf.distribute.experimental.coordinator.ClusterCoordinator.create_per_worker_dataset`:\n", "\n", "```\n", "@tf.function\n", "def per_worker_dataset_fn():\n", "  return strategy.distribute_datasets_from_function(dataset_fn)\n", "\n", "per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)\n", "per_worker_iterator = iter(per_worker_dataset)\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Ol7SmPID1dAt" }, "source": [ "For TensorFlow Transform, as mentioned previously, the analyze phase is done separately from training and is therefore omitted here; refer to the [tutorial](https://www.tensorflow.org/tfx/tutorials/transform/census) for a detailed how-to. Usually, this phase includes creating a `tf.Transform` preprocessing function and transforming the data in an [Apache Beam](https://beam.apache.org/) pipeline with this preprocessing function. At the end of the analyze phase, the output can be exported as a TensorFlow graph, usable for both training and serving. This example covers only the training-pipeline portion:\n", "\n", "```\n", "with strategy.scope():\n", "  # working_dir contains the tf.Transform output.\n", "  tf_transform_output = tft.TFTransformOutput(working_dir)\n", "  # Loading from working_dir to create a Keras layer for applying the tf.Transform output to data\n", "  tft_layer = tf_transform_output.transform_features_layer()\n", "  ...\n", "\n", "def dataset_fn(input_context):\n", "  ...\n", "  dataset = dataset.map(tft_layer, num_parallel_calls=tf.data.AUTOTUNE)\n", "  ...\n", "  return dataset\n", "\n", "distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "3_IQxRXxQWof" }, "source": [ "## Partial batches" ] }, { "cell_type": "markdown", "metadata": { "id": "hW2_gVkiztUG" }, "source": [ "Partial batches occur when 1) the `tf.data.Dataset` instances that users create contain batch sizes that are not evenly divisible by the number of replicas, or 2) the cardinality of the dataset instance is not divisible by the batch size. This means that when the dataset is distributed over multiple replicas, the `next` call on some iterators results in a `tf.errors.OutOfRangeError`. To handle this use case, `tf.distribute` returns dummy batches of batch size `0` on replicas that have no more data to process.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rqutdpqtPcCH" }, "source": [ "For the single-worker case, if data is not returned by the iterator's `next` call, dummy batches of batch size 0 are created and used along with the real data in the dataset. In the case of partial batches, the last global batch of data contains real data alongside dummy batches. The stopping condition for processing data now checks whether any of the replicas have data; if there is no data on any of them, a `tf.errors.OutOfRangeError` is raised. A small sketch of this behavior follows below.\n", "\n", "For the multi-worker case, the boolean value representing the presence of data on each of the workers is aggregated using cross-replica communication, which is used to identify whether all the workers have finished processing the distributed dataset. Since this involves cross-worker communication, it comes with some performance penalty.\n" ] },
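{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the single-worker behavior concrete, here is a minimal sketch (the dataset and sizes are illustrative; it assumes the multiple logical devices configured at the top of this guide): ten elements batched by four leave a final partial batch of two, which is split across the replicas.\n", "\n", "```\n", "import tensorflow as tf\n", "\n", "strategy = tf.distribute.MirroredStrategy()\n", "# Batches of size 4, 4, and a final partial batch of 2.\n", "dataset = tf.data.Dataset.range(10).batch(4)\n", "dist_dataset = strategy.experimental_distribute_dataset(dataset)\n", "\n", "for batch in dist_dataset:\n", "  # Each per-replica component may have a different (possibly 0) batch size.\n", "  print([t.shape for t in strategy.experimental_local_results(batch)])\n", "```\n" ] },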
{ "cell_type": "markdown", "metadata": { "id": "vehLsgljz90Y" }, "source": [ "## Caveats" ] }, { "cell_type": "markdown", "metadata": { "id": "Nx4jyN_Az-Dy" }, "source": [ "- When using the `tf.distribute.Strategy.experimental_distribute_dataset` API in a multi-worker setup, you pass in a `tf.data.Dataset` that reads from files. If the `tf.data.experimental.AutoShardPolicy` is set to `AUTO` or `FILE`, the actual per-step batch size may be smaller than your defined global batch size. This can happen when the remaining elements in a file are fewer than the global batch size. You can either exhaust the dataset without depending on the number of steps to run, or set `tf.data.experimental.AutoShardPolicy` to `DATA` to work around it (see the sketch after this list).\n", "\n", "- Stateful dataset transformations are currently not supported with `tf.distribute`, and any stateful ops the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, you have a dataset graph that depends on state (namely, the random seed) on the local machine where the Python process is being executed.\n", "\n", "- Experimental `tf.data.experimental.OptimizationOptions` that are disabled by default can, in certain contexts, such as when used together with `tf.distribute`, cause performance degradation. Enable them only after validating that they benefit the performance of your workload in a distributed setting.\n", "\n", "- Refer to [this guide](https://www.tensorflow.org/guide/data_performance) for how to optimize your input pipeline with `tf.data` in general. A few additional tips:\n", "\n", "  - If you have multiple workers and are using `tf.data.Dataset.list_files` to create a dataset from all files matching one or more glob patterns, remember to set the `seed` argument or set `shuffle=False` so that each worker shards the files consistently.\n", "\n", "- If your input pipeline includes both shuffling the data on the record level and parsing the data, then unless the unparsed data is significantly larger than the parsed data (which is usually not the case), shuffle first and then parse, as shown in the following example. This can benefit memory usage and performance.\n", "\n", "```\n", "d = tf.data.Dataset.list_files(pattern, shuffle=False)\n", "d = d.shard(num_workers, worker_index)\n", "d = d.repeat(num_epochs)\n", "d = d.shuffle(shuffle_buffer_size)\n", "d = d.interleave(tf.data.TFRecordDataset,\n", "                 cycle_length=num_readers, block_length=1)\n", "d = d.map(parser_fn, num_parallel_calls=num_map_threads)\n", "```\n", "\n", "- `tf.data.Dataset.shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)` maintains an internal buffer of `buffer_size` elements, so reducing `buffer_size` can alleviate out-of-memory (OOM) issues.\n" ] },
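{ "cell_type": "markdown", "metadata": {}, "source": [ "As referenced in the first caveat, here is a minimal sketch of setting the auto-shard policy to `DATA` through dataset options (the dataset itself is a placeholder):\n", "\n", "```\n", "import tensorflow as tf\n", "\n", "# Shard by elements instead of by files.\n", "options = tf.data.Options()\n", "options.experimental_distribute.auto_shard_policy = (\n", "    tf.data.experimental.AutoShardPolicy.DATA)\n", "\n", "dataset = tf.data.Dataset.range(16).batch(4)  # placeholder dataset\n", "dataset = dataset.with_options(options)\n", "```\n" ] },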
{ "cell_type": "markdown", "metadata": { "id": "dAC_vRmJyzrB" }, "source": [ "- If you are using `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function`, the order in which the workers process the data is not guaranteed. Ordered outputs matter mostly when you are using `tf.distribute` to scale prediction. You can, however, insert an index for each element in the batch and order the outputs accordingly, as the following snippet shows.\n", "\n", "Note: `tf.distribute.MirroredStrategy` is used here for convenience. You only need to reorder inputs when you are using multiple workers; `tf.distribute.MirroredStrategy` distributes training on a single worker." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:09.470013Z", "iopub.status.busy": "2024-01-11T18:26:09.469369Z", "iopub.status.idle": "2024-01-11T18:26:09.841403Z", "shell.execute_reply": "2024-01-11T18:26:09.840589Z" }, "id": "Zr2xAy-uZZaL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0, 1: 2, 2: 4, 3: 6, 4: 8, 5: 10, 6: 12, 7: 14, 8: 16, 9: 18, 10: 20, 11: 22, 12: 24, 13: 26, 14: 28, 15: 30, 16: 32, 17: 34, 18: 36, 19: 38, 20: 40, 21: 42, 22: 44, 23: 46}\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "dataset_size = 24\n", "batch_size = 6\n", "dataset = tf.data.Dataset.range(dataset_size).enumerate().batch(batch_size)\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "\n", "def predict(index, inputs):\n", "  outputs = 2 * inputs\n", "  return index, outputs\n", "\n", "result = {}\n", "for index, inputs in dist_dataset:\n", "  output_index, outputs = mirrored_strategy.run(predict, args=(index, inputs))\n", "  indices = list(mirrored_strategy.experimental_local_results(output_index))\n", "  rindices = []\n", "  for a in indices:\n", "    rindices.extend(a.numpy())\n", "  outputs = list(mirrored_strategy.experimental_local_results(outputs))\n", "  routputs = []\n", "  for a in outputs:\n", "    routputs.extend(a.numpy())\n", "  for i, value in zip(rindices, routputs):\n", "    result[i] = value\n", "\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "nNbn7HXx0YqB" }, "source": [ "## Use tensor inputs instead of tf.data" ] }, { "cell_type": "markdown", "metadata": { "id": "dymZixqo0nKK" }, "source": [ "Sometimes you cannot use a `tf.data.Dataset` to represent your input, and therefore cannot use the distribution APIs shown above. In such cases, you can use raw tensors or inputs from a generator.\n", "\n", "### Use experimental_distribute_values_from_function for arbitrary tensor inputs\n", "\n", "`strategy.run` accepts `tf.distribute.DistributedValues`, which is the output of `next(iterator)`. To pass tensor values, use `tf.distribute.Strategy.experimental_distribute_values_from_function` to construct `tf.distribute.DistributedValues` from raw tensors. With this option, you have to specify your own batching and sharding logic in the input function, which can be done using the `tf.distribute.experimental.ValueContext` input object." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:09.845330Z", "iopub.status.busy": "2024-01-11T18:26:09.844660Z", "iopub.status.idle": "2024-01-11T18:26:09.864001Z", "shell.execute_reply": "2024-01-11T18:26:09.863387Z" }, "id": "ajZHNRQs0kqm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", "  0: tf.Tensor(0, shape=(), dtype=int32),\n", "  1: tf.Tensor(1, shape=(), dtype=int32),\n", "  2: tf.Tensor(2, shape=(), dtype=int32),\n", "  3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n", "PerReplica:{\n", "  0: tf.Tensor(0, shape=(), dtype=int32),\n", "  1: tf.Tensor(1, shape=(), dtype=int32),\n", "  2: tf.Tensor(2, shape=(), dtype=int32),\n", "  3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n", "PerReplica:{\n", "  0: tf.Tensor(0, shape=(), dtype=int32),\n", "  1: tf.Tensor(1, shape=(), dtype=int32),\n", "  2: tf.Tensor(2, shape=(), dtype=int32),\n", "  3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n", "PerReplica:{\n", "  0: tf.Tensor(0, shape=(), dtype=int32),\n", "  1: tf.Tensor(1, shape=(), dtype=int32),\n", "  2: tf.Tensor(2, shape=(), dtype=int32),\n", "  3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "def value_fn(ctx):\n", "  return tf.constant(ctx.replica_id_in_sync_group)\n", "\n", "distributed_values = mirrored_strategy.experimental_distribute_values_from_function(value_fn)\n", "for _ in range(4):\n", "  result = mirrored_strategy.run(lambda x: x, args=(distributed_values,))\n", "  print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "P98aFQGf0x_7" }, "source": [ "### Use tf.data.Dataset.from_generator if your input comes from a generator" ] }, { "cell_type": "markdown", "metadata": { "id": "emZCWQSi04qT" }, "source": [ "If you have a generator function that you want to use, you can create a `tf.data.Dataset` instance using the `from_generator` API.\n", "\n", "Note: This is currently not supported with `tf.distribute.TPUStrategy`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T18:26:09.867432Z", "iopub.status.busy": "2024-01-11T18:26:09.867050Z", "iopub.status.idle": "2024-01-11T18:26:10.252708Z", "shell.execute_reply": "2024-01-11T18:26:10.251613Z" }, "id": "jRhU0X230787" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", "  0: tf.Tensor([0.40325633], shape=(1,), dtype=float32),\n", "  1: tf.Tensor([0.28778756], shape=(1,), dtype=float32),\n", "  2: tf.Tensor([0.5146715], shape=(1,), dtype=float32),\n", "  3: tf.Tensor([0.3218396], shape=(1,), dtype=float32)\n", "}\n", "PerReplica:{\n", "  0: tf.Tensor([0.79922175], shape=(1,), dtype=float32),\n", "  1: tf.Tensor([0.02518538], shape=(1,), dtype=float32),\n", "  2: tf.Tensor([0.27494904], shape=(1,), dtype=float32),\n", "  3: tf.Tensor([0.54404545], shape=(1,), dtype=float32)\n", "}\n", "PerReplica:{\n", "  0: tf.Tensor([0.7162087], shape=(1,), dtype=float32),\n", "  1: tf.Tensor([0.76167136], shape=(1,), dtype=float32),\n", "  2: tf.Tensor([0.8244246], shape=(1,), dtype=float32),\n", "  3: tf.Tensor([0.37525535], shape=(1,), dtype=float32)\n", "}\n", "PerReplica:{\n", "  0: tf.Tensor([0.8093572], shape=(1,), dtype=float32),\n", "  1: tf.Tensor([0.0389544], shape=(1,), dtype=float32),\n", "  2: tf.Tensor([0.5250396], shape=(1,), dtype=float32),\n", "  3: 
tf.Tensor([0.04613635], shape=(1,), dtype=float32)\n", "}\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "def input_gen():\n", " while True:\n", " yield np.random.rand(4)\n", "\n", "# use Dataset.from_generator\n", "dataset = tf.data.Dataset.from_generator(\n", " input_gen, output_types=(tf.float32), output_shapes=tf.TensorShape([4]))\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "iterator = iter(dist_dataset)\n", "for _ in range(4):\n", " result = mirrored_strategy.run(lambda x: x, args=(next(iterator),))\n", " print(result)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "input.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }