{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "MhoQ0WE77laV" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T23:10:58.102099Z", "iopub.status.busy": "2023-11-07T23:10:58.101852Z", "iopub.status.idle": "2023-11-07T23:10:58.105808Z", "shell.execute_reply": "2023-11-07T23:10:58.105236Z" }, "id": "_ckMIh7O7s6D" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "jYysdyb-CaWM" }, "source": [ "# Distributed input" ] }, { "cell_type": "markdown", "metadata": { "id": "S5Uhzt6vVIB2" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook
" ] }, { "cell_type": "markdown", "metadata": { "id": "FbVhjPpzn6BM" }, "source": [ "The [tf.distribute](https://tensorflow.google.cn/guide/distributed_training) APIs provide an easy way for users to scale their training from a single machine to multiple machines. When scaling their model, users also have to distribute their input across multiple devices. `tf.distribute` provides APIs that automatically distribute your input across devices.\n", "\n", "This guide shows the different ways in which you can create distributed datasets and iterators using the `tf.distribute` APIs. It also covers the following topics:\n", "\n", "- Usage, sharding, and batching options when using `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function`.\n", "- Different ways in which you can iterate over the distributed dataset.\n", "- Differences between the `tf.distribute.Strategy.experimental_distribute_dataset`/`tf.distribute.Strategy.distribute_datasets_from_function` APIs and the `tf.data` APIs, as well as any limitations that users may come across in their usage.\n", "\n", "This guide does not cover the use of distributed input with the Keras APIs." ] }, { "cell_type": "markdown", "metadata": { "id": "MM6W__qraV55" }, "source": [ "## Distributed datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "lNy9GxjSlMKQ" }, "source": [ "To use the `tf.distribute` APIs to scale, use `tf.data.Dataset` to represent your input. `tf.distribute` works efficiently with `tf.data.Dataset` (for example, through automatic prefetching of data onto each accelerator device and regular performance updates). If you have a use case for using something other than `tf.data.Dataset`, refer to the [Tensor inputs](#tensorinputs) section of this guide. In a non-distributed training loop, first create a `tf.data.Dataset` instance and then iterate over its elements. For example:\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:10:58.109756Z", "iopub.status.busy": "2023-11-07T23:10:58.109512Z", "iopub.status.idle": "2023-11-07T23:11:00.688043Z", "shell.execute_reply": "2023-11-07T23:11:00.687256Z" }, "id": "pCu2Jj-21AEf" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 23:10:58.568819: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-11-07 23:10:58.568882: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT 
when one has already been registered\n", "2023-11-07 23:10:58.570564: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2.15.0-rc1\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "# Helper libraries\n", "import numpy as np\n", "import os\n", "\n", "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:00.691920Z", "iopub.status.busy": "2023-11-07T23:11:00.691466Z", "iopub.status.idle": "2023-11-07T23:11:01.251410Z", "shell.execute_reply": "2023-11-07T23:11:01.250604Z" }, "id": "6cnilUtmKwpa" }, "outputs": [], "source": [ "# Simulate multiple CPUs with virtual devices\n", "N_VIRTUAL_DEVICES = 2\n", "physical_devices = tf.config.list_physical_devices(\"CPU\")\n", "tf.config.set_logical_device_configuration(\n", " physical_devices[0], [tf.config.LogicalDeviceConfiguration() for _ in range(N_VIRTUAL_DEVICES)])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:01.255836Z", "iopub.status.busy": "2023-11-07T23:11:01.255518Z", "iopub.status.idle": "2023-11-07T23:11:02.616821Z", "shell.execute_reply": "2023-11-07T23:11:02.616047Z" }, "id": "zd4l1ySeLRk1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Available devices:\n", "0) LogicalDevice(name='/device:CPU:0', device_type='CPU')\n", "1) LogicalDevice(name='/device:CPU:1', device_type='CPU')\n", "2) LogicalDevice(name='/device:GPU:0', device_type='GPU')\n", "3) LogicalDevice(name='/device:GPU:1', device_type='GPU')\n", "4) LogicalDevice(name='/device:GPU:2', device_type='GPU')\n", "5) LogicalDevice(name='/device:GPU:3', device_type='GPU')\n" ] } ], "source": [ "print(\"Available devices:\")\n", "for i, device in 
enumerate(tf.config.list_logical_devices()):\n", " print(\"%d) %s\" % (i, device))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:02.620893Z", "iopub.status.busy": "2023-11-07T23:11:02.620319Z", "iopub.status.idle": "2023-11-07T23:11:03.081448Z", "shell.execute_reply": "2023-11-07T23:11:03.080671Z" }, "id": "dzLKpmZICaWN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(16, 1), dtype=float32)\n", "tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n" ] } ], "source": [ "global_batch_size = 16\n", "# Create a tf.data.Dataset object.\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "\n", "@tf.function\n", "def train_step(inputs):\n", " features, labels = inputs\n", " return labels - 0.3 * features\n", "\n", "# Iterate over the dataset using the for..in construct.\n", "for inputs in dataset:\n", " print(train_step(inputs))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ihrhYDYRrVLH" }, "source": [ "To let users adopt `tf.distribute` strategies with as few changes to their existing code as possible, two APIs were introduced that distribute a `tf.data.Dataset` instance and return a distributed dataset object. A user can then iterate over this distributed dataset instance and train their model as before. Let us now look at the two APIs, `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function`, in more detail:" ] }, { "cell_type": "markdown", "metadata": { "id": "4AXoHhrsbdF3" }, "source": [ "### `tf.distribute.Strategy.experimental_distribute_dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "5mVuLZhbem8d" }, "source": [ "#### Usage\n", "\n", "This API takes a `tf.data.Dataset` instance as input and returns a `tf.distribute.DistributedDataset` instance. You should batch the input dataset with a value equal to the global batch size, which is the number of samples that you want to process across all devices in one step. You can iterate over this distributed dataset in a Pythonic fashion or create an iterator using `iter`. The returned object is not a `tf.data.Dataset` instance and does not support any other APIs that transform or inspect the dataset in any way. This is the recommended API if you do not have a specific way in which you want to shard your input over different replicas.\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:03.085820Z", "iopub.status.busy": "2023-11-07T23:11:03.085130Z", "iopub.status.idle": "2023-11-07T23:11:04.446050Z", "shell.execute_reply": "2023-11-07T23:11:04.445250Z" }, "id": "F2VeZUWUj5S4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "}, PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] } ], "source": [ "global_batch_size = 16\n", "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "# Distribute input using the `experimental_distribute_dataset`.\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "# 1 global batch of data fed to the model in 1 step.\n", "print(next(iter(dist_dataset)))" ] }, { "cell_type": "markdown", "metadata": { "id": "QPceDmRht54F" }, "source": [ "#### Properties" ] }, { "cell_type": "markdown", "metadata": { "id": "0Qb6nDgxiN_n" }, "source": [ "##### Batching\n", "\n", "`tf.distribute` rebatches the input `tf.data.Dataset` instance with a new batch size that is equal to the global batch size divided by the number of replicas in sync. The number of replicas in sync is equal to the number of devices that take part in the gradient all-reduce during training. When a user calls `next` on the distributed iterator, a per-replica batch of data is returned on each replica. The cardinality of the rebatched dataset is always a multiple of the number of replicas. Here are some examples:\n", "\n", "- `tf.data.Dataset.range(6).batch(4, drop_remainder=False)`\n", "\n", "  - Without distribution:\n", "\n", "    - Batch 1: [0, 1, 2, 3]\n", "    - Batch 2: [4, 5]\n", "\n", "  - With distribution over 2 replicas. The last batch ([4, 5]) is split between the 2 replicas.\n", "\n", "    - Batch 1:\n", "\n", "      - Replica 1: [0, 1]\n", "      - Replica 2: [2, 3]\n", "\n", "    - Batch 2:\n", "\n", "      - Replica 1: [4]\n", "      - Replica 2: [5]\n", "\n", "- `tf.data.Dataset.range(4).batch(4)`\n", "\n", "  - Without distribution:\n", "    - Batch 1: [0, 1, 2, 3]\n", "  - With distribution over 5 replicas:\n", "    - Batch 1:\n", "      - Replica 1: [0]\n", "      - Replica 2: [1]\n", "      - Replica 3: [2]\n", "      - Replica 4: [3]\n", "      - Replica 5: []\n", "\n", "- `tf.data.Dataset.range(8).batch(4)`\n", "\n", "  - Without distribution:\n", "    - Batch 1: [0, 1, 2, 3]\n", "    - Batch 2: [4, 5, 6, 7]\n", "  - With distribution over 3 replicas:\n", "    - Batch 1:\n", "      - Replica 1: [0, 1]\n", "      - Replica 2: [2, 3]\n", "      - Replica 3: []\n", "    - Batch 2:\n", "      - Replica 1: [4, 5]\n", "      - Replica 2: [6, 7]\n", "      - Replica 3: []\n", "\n", "Rebatching the dataset has a space complexity that increases linearly with the number of replicas. This means that for the multi-worker training use case, the input pipeline can run into OOM errors." ] }, { "cell_type": "markdown", "metadata": { "id": "IszBuubdtydp" }, "source": [ "##### Sharding\n", "\n", "`tf.distribute` also autoshards the input dataset in multi-worker training with `MultiWorkerMirroredStrategy` and `TPUStrategy`. Each dataset is created on the CPU device of the worker. Autosharding a dataset over a set of workers means that each worker is assigned a subset of the entire dataset (if the right `tf.data.experimental.AutoShardPolicy` is set). This ensures that, at each step, each worker processes dataset elements that do not overlap with those of the other workers. Autosharding has a few different options that can be specified using `tf.data.experimental.DistributeOptions`. Note that there is no autosharding in multi-worker training with `ParameterServerStrategy`; for more information on dataset creation with this strategy, refer to the [ParameterServerStrategy tutorial](parameter_server_training.ipynb)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:04.450481Z", "iopub.status.busy": "2023-11-07T23:11:04.449754Z", "iopub.status.idle": "2023-11-07T23:11:04.460497Z", "shell.execute_reply": "2023-11-07T23:11:04.459788Z" }, "id": "jwJtsCQhHK-E" }, "outputs": [], "source": [ "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)\n", "options = tf.data.Options()\n", "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA\n", "dataset = dataset.with_options(options)" ] }, 
{ "cell_type": "markdown", "metadata": { "id": "J7fj3GskHC8g" }, "source": [ "There are four different options you can set for the `tf.data.experimental.AutoShardPolicy`:\n", "\n", "- AUTO: This is the default option, which means an attempt will be made to shard by FILE. The attempt to shard by FILE fails if a file-based dataset is not detected, in which case `tf.distribute` falls back to sharding by DATA. Note that if the input dataset is file-based but the number of files is less than the number of workers, an error will be raised.\n", "\n", "- FILE: This is the option if you want to shard the input files over all the workers. You should use this option if the number of input files is much larger than the number of workers and the data in the files is evenly distributed. The downside of this option is that workers can sit idle if the data in the files is not evenly distributed. If the number of files is less than the number of workers, an `InvalidArgumentError` will be raised. If this happens, explicitly set the policy to `AutoShardPolicy.DATA`. For example, distribute 2 files over 2 workers with 1 replica each. File 1 contains [0, 1, 2, 3, 4, 5] and File 2 contains [6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2 and the global batch size be 4.\n", "\n", "  - Worker 0:\n", "    - Batch 1 = Replica 1: [0, 1]\n", "    - Batch 2 = Replica 1: [2, 3]\n", "    - Batch 3 = Replica 1: [4]\n", "    - Batch 4 = Replica 1: [5]\n", "  - Worker 1:\n", "    - Batch 1 = Replica 2: [6, 7]\n", "    - Batch 2 = Replica 2: [8, 9]\n", "    - Batch 3 = Replica 2: [10]\n", "    - Batch 4 = Replica 2: [11]\n", "\n", "- DATA: This autoshards the elements across all the workers. Each worker reads the entire dataset and processes only the shard assigned to it; all other shards are discarded. This is generally used if the number of input files is less than the number of workers and you want better sharding of data across all workers. The downside is that the entire dataset is read on each worker. For example, distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2.\n", "\n", "  - Worker 0:\n", "    - Batch 1 = Replica 1: [0, 1]\n", "    - Batch 2 = Replica 1: [4, 5]\n", "    - Batch 3 = Replica 1: [8, 9]\n", "  - Worker 1:\n", "    - Batch 1 = Replica 2: [2, 3]\n", "    - Batch 2 = Replica 2: [6, 7]\n", "    - Batch 3 = Replica 2: [10, 11]\n", "\n", "- OFF: If you turn off autosharding, each worker processes all the data. For example, distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. Then the work on each worker looks like this:\n", "\n", "  - Worker 0:\n", "\n", "    - Batch 1 = Replica 1: [0, 1]\n", "    - Batch 2 = Replica 1: [2, 3]\n", "    - Batch 3 = Replica 1: [4, 5]\n", "    - Batch 4 = Replica 1: [6, 7]\n", "    - Batch 5 = Replica 1: [8, 9]\n", "    - Batch 6 = Replica 1: [10, 11]\n", "\n", "  - Worker 1:\n", "\n", "    - Batch 1 = Replica 2: [0, 1]\n", "    - Batch 2 = Replica 2: [2, 3]\n", "    - Batch 3 = Replica 2: [4, 5]\n", "    - Batch 4 = Replica 2: [6, 7]\n", "    - Batch 5 = Replica 2: [8, 9]\n", "    - Batch 6 = Replica 2: [10, 11]" ] }, { "cell_type": "markdown", "metadata": { "id": "OK46ZJGPH5H2" }, "source": [ "##### Prefetching\n", "\n", "By default, `tf.distribute` adds a prefetch transformation at the end of the user-provided `tf.data.Dataset` instance. The `buffer_size` argument of the prefetch transformation is equal to the number of replicas in sync." ] }, { "cell_type": "markdown", "metadata": { "id": "PjiGSY3gtr6_" }, "source": [ "### `tf.distribute.Strategy.distribute_datasets_from_function`" ] }, { "cell_type": "markdown", "metadata": { "id": "bAXAo_wWbWSb" }, "source": [ "#### Usage\n", "\n", "This API takes an input function and returns a `tf.distribute.DistributedDataset` instance. The input function that users pass in has a `tf.distribute.InputContext` argument and should return a `tf.data.Dataset` instance. With this API, `tf.distribute` does not make any further changes to the user's `tf.data.Dataset` instance returned from the input function. It is up to the user to batch and shard the dataset. `tf.distribute` calls the input function on the CPU device of each of the workers. Apart from allowing users to specify their own batching and sharding logic, this API also demonstrates better scalability and performance than `tf.distribute.Strategy.experimental_distribute_dataset` when used for multi-worker training." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:04.463974Z", "iopub.status.busy": "2023-11-07T23:11:04.463676Z", "iopub.status.idle": "2023-11-07T23:11:04.479809Z", "shell.execute_reply": "2023-11-07T23:11:04.479092Z" }, "id": "9ODch-OFCaW4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": 
[ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "def dataset_fn(input_context):\n", "  # Per-replica batch size = global batch size / number of replicas in sync.\n", "  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n", "  # Build an unbatched dataset; batching happens once, after sharding.\n", "  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64)\n", "  # Give each input pipeline (worker) a distinct shard of the data.\n", "  dataset = dataset.shard(\n", "      input_context.num_input_pipelines, input_context.input_pipeline_id)\n", "  dataset = dataset.batch(batch_size)\n", "  dataset = dataset.prefetch(2)  # Prefetch up to 2 batches in this input pipeline.\n", "  return dataset\n", "\n", "dist_dataset = mirrored_strategy.distribute_datasets_from_function(dataset_fn)" ] }, { "cell_type": "markdown", "metadata": { "id": "M1bpzPYzt_R7" }, "source": [ "#### Properties" ] }, { "cell_type": "markdown", "metadata": { "id": "7cgzhwiiuBvO" }, "source": [ "##### Batching\n", "\n", "The `tf.data.Dataset` instance returned by the input function should be batched using the per-replica batch size. The per-replica batch size is the global batch size divided by the number of replicas taking part in sync training. This is because `tf.distribute` calls the input function on the CPU device of each of the workers, and the dataset created on a given worker should be ready to use by all the replicas on that worker." ] }, { "cell_type": "markdown", "metadata": { "id": "e-wlFFZbP33n" }, "source": [ "##### Sharding\n", "\n", "The `tf.distribute.InputContext` object that is implicitly passed as an argument to the user's input function is created by `tf.distribute` under the hood. It has information about the number of workers, the current worker ID, and so on. The input function can use these properties of the `tf.distribute.InputContext` object to handle sharding according to policies set by the user.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "7TGwnDM-ICHf" }, "source": [ "##### Prefetching\n", "\n", "`tf.distribute` does not add a prefetch transformation at the end of the `tf.data.Dataset` returned by the user-provided input function, which is why `Dataset.prefetch` is called explicitly in the example above." ] }, { "cell_type": "markdown", "metadata": { "id": "iOMsf8kyZZpv" }, "source": [ "Note: Both `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function` return **`tf.distribute.DistributedDataset` instances that are not of type `tf.data.Dataset`**. You can iterate over these instances (as shown in the Distributed iterators section) and use the `element_spec` property." ] }, 
{ "cell_type": "markdown", "metadata": { "id": "dL3XbI1gzEjO" }, "source": [ "## Distributed iterators" ] }, { "cell_type": "markdown", "metadata": { "id": "w8y54-o9T2Ni" }, "source": [ "Similar to non-distributed `tf.data.Dataset` instances, you need to create an iterator on `tf.distribute.DistributedDataset` instances to iterate over them and access their elements. The following are the ways in which you can create a `tf.distribute.DistributedIterator` and use it to train your model:\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FlKh8NV0uOtZ" }, "source": [ "### Usage" ] }, { "cell_type": "markdown", "metadata": { "id": "eSZz6EqOuSlB" }, "source": [ "#### Use a Pythonic for loop construct\n", "\n", "You can use a user-friendly Pythonic loop to iterate over a `tf.distribute.DistributedDataset`. The elements returned from the `tf.distribute.DistributedIterator` can be a single `tf.Tensor` or a `tf.distribute.DistributedValues` that contains a value per replica. Placing the loop inside a `tf.function` improves performance. However, `break` and `return` are currently not supported in a loop over a `tf.distribute.DistributedDataset` that is placed inside a `tf.function`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:04.484073Z", "iopub.status.busy": "2023-11-07T23:11:04.483824Z", "iopub.status.idle": "2023-11-07T23:11:05.001871Z", "shell.execute_reply": "2023-11-07T23:11:05.000921Z" }, "id": "zt3AHb46Tr3w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", 
"[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n", " 1: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n", " 2: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n", " 3: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32)\n", "}\n" ] } ], "source": [ "global_batch_size = 16\n", "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "\n", "@tf.function\n", "def train_step(inputs):\n", " features, labels = inputs\n", " return labels - 0.3 * features\n", "\n", "for x in dist_dataset:\n", " # train_step trains the model using the dataset elements\n", " loss = mirrored_strategy.run(train_step, args=(x,))\n", " print(\"Loss is \", loss)" ] }, { "cell_type": "markdown", "metadata": { "id": "NchPwTEiuSqb" }, "source": [ "#### Use `iter` to create an explicit iterator\n", "\n", "To iterate over the elements in a `tf.distribute.DistributedDataset` instance, you can create a `tf.distribute.DistributedIterator` by calling the `iter` API on it. With an explicit iterator, you can iterate for a fixed number of steps. To get the next element from a `tf.distribute.DistributedIterator` instance `dist_iterator`, you can call `next(dist_iterator)`, `dist_iterator.get_next()`, or `dist_iterator.get_next_as_optional()`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:05.006566Z", "iopub.status.busy": "2023-11-07T23:11:05.005902Z", "iopub.status.idle": "2023-11-07T23:11:08.370061Z", "shell.execute_reply": "2023-11-07T23:11:08.369141Z" }, "id": "OrMmakq5EqeQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " 
[0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" 
] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " 
[0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], 
shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " 
[0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " 
[0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " 
[0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " 
[0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", 
"[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32)\n", "}\n", "Loss is PerReplica:{\n", " 0: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 1: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 2: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), dtype=float32),\n", " 3: tf.Tensor(\n", "[[0.7]\n", " [0.7]\n", " [0.7]\n", " [0.7]], shape=(4, 1), 
dtype=float32)\n", "}\n" ] } ], "source": [ "num_epochs = 10\n", "steps_per_epoch = 5\n", "for epoch in range(num_epochs):\n", " dist_iterator = iter(dist_dataset)\n", " for step in range(steps_per_epoch):\n", " # train_step trains the model using the dataset elements\n", " loss = mirrored_strategy.run(train_step, args=(next(dist_iterator),))\n", " # which is the same as\n", " # loss = mirrored_strategy.run(train_step, args=(dist_iterator.get_next(),))\n", " print(\"Loss is \", loss)" ] }, { "cell_type": "markdown", "metadata": { "id": "UpJXIlxjqPYg" }, "source": [ "When you use `next()` or `tf.distribute.DistributedIterator.get_next`, an OutOfRange error is raised once the `tf.distribute.DistributedIterator` has reached its end. The client can catch the error on the Python side and continue with other work, such as checkpointing and evaluation. However, this does not work if you are using a host training loop (that is, running multiple steps per `tf.function`), which looks like:\n", "\n", "```\n", "@tf.function\n", "def train_fn(iterator):\n", " for _ in tf.range(steps_per_loop):\n", " strategy.run(step_fn, args=(next(iterator),))\n", "```\n", "\n", "`train_fn` contains multiple steps by wrapping the step body in a `tf.range`. In this case, loop iterations that have no dependency on each other can start in parallel, so an OutOfRange error can be triggered in a later iteration before the computation of earlier iterations has finished. Once an OutOfRange error is thrown, all ops in the function are terminated immediately. To avoid this, use `tf.distribute.DistributedIterator.get_next_as_optional`, which does not throw an OutOfRange error. `get_next_as_optional` returns a `tf.experimental.Optional` that contains either the next element or no value, if the `tf.distribute.DistributedIterator` has reached its end." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:08.374448Z", "iopub.status.busy": "2023-11-07T23:11:08.374157Z", "iopub.status.idle": "2023-11-07T23:11:09.060161Z", "shell.execute_reply": "2023-11-07T23:11:09.059317Z" }, "id": "Iyjao96Vqwyz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "([0], [1], [2], [3])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "([4], [5], [6], [7])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "([8], [], [], [])\n" ] } ], "source": [ "# You can break the loop with `get_next_as_optional` by checking if the `Optional` contains a value\n", "global_batch_size = 4\n", "steps_per_loop = 5\n", "strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.range(9).batch(global_batch_size)\n", "distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))\n", "\n", "@tf.function\n", "def train_fn(distributed_iterator):\n", " for _ in tf.range(steps_per_loop):\n", " optional_data = distributed_iterator.get_next_as_optional()\n", " if not optional_data.has_value():\n", " break\n", " per_replica_results = strategy.run(lambda x: x, args=(optional_data.get_value(),))\n", " tf.print(strategy.experimental_local_results(per_replica_results))\n", "train_fn(distributed_iterator)" ] }, { "cell_type": "markdown", "metadata": { "id": "LaclbKnqzLjf" }, "source": [ "## Use the `element_spec` property" ] }, { "cell_type": "markdown", "metadata": { "id": "Z1YvXqOpwy08" }, "source": [ "If you pass the elements of a distributed dataset to a `tf.function` and want a `tf.TypeSpec` guarantee, you can specify the `input_signature` argument of the `tf.function`. The output of a distributed dataset is `tf.distribute.DistributedValues`, which can represent the input to a single device or to multiple devices. To get the `tf.TypeSpec` corresponding to this distributed value, you can use `tf.distribute.DistributedDataset.element_spec` or `tf.distribute.DistributedIterator.element_spec`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:09.064905Z", "iopub.status.busy": "2023-11-07T23:11:09.064122Z", "iopub.status.idle": "2023-11-07T23:11:11.131319Z", "shell.execute_reply": "2023-11-07T23:11:11.130485Z" }, "id": "pg3B-Cw_cn3a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', 
'/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: 
,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: 
\n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 
0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "},\n", " PerReplica:{\n", " 0: ,\n", " 1: ,\n", " 2: ,\n", " 3: \n", "})\n" ] } ], "source": [ "global_batch_size = 16\n", "epochs = 5\n", "steps_per_epoch = 5\n", "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "\n", "@tf.function(input_signature=[dist_dataset.element_spec])\n", "def train_step(per_replica_inputs):\n", " def step_fn(inputs):\n", " return 2 * inputs\n", "\n", " return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))\n", "\n", "for _ in range(epochs):\n", " iterator = iter(dist_dataset)\n", " for _ in range(steps_per_epoch):\n", " output = train_step(next(iterator))\n", " tf.print(output)" ] }, { "cell_type": "markdown", "metadata": { "id": "-OAa6svUzuWm" }, "source": [ "## 数据预处理" ] }, { "cell_type": "markdown", "metadata": { "id": "pSMrs3kJQexW" }, "source": [ "目前为止,您已经学习了如何分布 `tf.data.Dataset`。但在数据准备好用于模型之前,还需要对其进行预处理,例如对数据进行清理、转换和扩充。以下是两套方便的预处理工具:\n", "\n", "- [Keras 预处理层](https://tensorflow.google.cn/guide/keras/preprocessing_layers):一组可供开发者构建 Keras 原生输入处理流水线的 Keras 层。 一些 Keras 预处理层包含不可训练的状态,可以在初始化时设置或进行 `adapt`(请参阅 [Keras 预处理层指南](https://tensorflow.google.cn/guide/keras/preprocessing_layers)的 `adapt` 部分)。在分布有状态预处理层时,应将状态复制到所有工作进程。要使用这些层,您可以使其成为模型的一部分或将其应用于数据集。\n", "\n", "- [TensorFlow Transform 
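(tf.Transform)](https://tensorflow.google.cn/tfx/transform/get_started): a TensorFlow library that allows you to define both instance-level and full-pass data transformations through data preprocessing pipelines. TensorFlow Transform has two phases. The first is the Analyze phase, in which the raw training data is analyzed in a full-pass process to compute the statistics needed for the transformations, and the transformation logic is generated as instance-level operations. The second is the Transform phase, in which the raw training data is transformed in an instance-level process.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Pd4aUCFdVlZ1" }, "source": [ "### Keras preprocessing layers versus TensorFlow Transform\n", "\n", "Both TensorFlow Transform and Keras preprocessing layers provide a way to split out preprocessing during training and to bundle preprocessing with the model during inference, which reduces train/serve skew.\n", "\n", "TensorFlow Transform, which is deeply integrated with [TFX](https://tensorflow.google.cn/tfx), provides a scalable map-reduce solution for analyzing and transforming datasets of any size in a job separate from the training pipeline. If you need to run an analysis on a dataset that cannot fit on a single machine, TensorFlow Transform should be your first choice.\n", "\n", "Keras preprocessing layers are better suited to preprocessing that is applied during training, after the data is read from disk. They fit seamlessly into model development with the Keras library. They support the analysis of smaller datasets via [`adapt`](https://tensorflow.google.cn/guide/keras/preprocessing_layers#the_adapt_method), as well as use cases such as image data augmentation, where each pass over the input dataset yields different examples for training.\n", "\n", "The two libraries can also be mixed: TensorFlow Transform for analysis and static transformations of the input data, and Keras preprocessing layers for train-time transformations (for example, one-hot encoding or data augmentation).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MReKhhZpHUpj" }, "source": [ "### Best practices with tf.distribute\n", "\n", "Working with either tool involves initializing the transformation logic to apply to the data, which might create TensorFlow resources. These resources or states should be replicated to all workers to save inter-worker or worker-coordinator communication. To do so, it is recommended that you create Keras preprocessing layers, `tft.TFTransformOutput.transform_features_layer`, or `tft.TransformFeaturesLayer` under `tf.distribute.Strategy.scope`, just as you would any other Keras layer.\n", "\n", "The following examples demonstrate how to use the `tf.distribute.Strategy` API with the high-level Keras `Model.fit` API and with a custom training loop, respectively." ] }, { "cell_type": "markdown", "metadata": { "id": "rwEGMWuoX7kJ" }, "source": [ "#### Additional notes for Keras preprocessing layer users:\n", "\n", "**Preprocessing layers and large vocabularies**\n", "\n", "When dealing with large vocabularies (over 1 GB) in a multi-worker setting (for example, `tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.experimental.ParameterServerStrategy`, or `tf.distribute.TPUStrategy`), it is recommended to save the vocabulary to a static file that is accessible from all workers (for example, with Cloud Storage). This reduces the time spent replicating the vocabulary to all workers during training.\n", "\n", "**Preprocessing in the `tf.data` pipeline versus in the model**\n", "\n", "Keras preprocessing layers can be applied either as part of the model or directly to a `tf.data.Dataset`, and each option has its advantages:\n", "\n", "- 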
Applying the preprocessing layers within the model makes your model portable and helps reduce train/serve skew. (For details, refer to the *Benefits of doing preprocessing inside the model at inference time* section of the [Working with preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers#benefits_of_doing_preprocessing_inside_the_model_at_inference_time) guide.)\n", "- Applying them within the `tf.data` pipeline allows prefetching or offloading to the CPU, which generally gives better performance when using accelerators.\n", "\n", "When running on one or more TPUs, users should almost always place Keras preprocessing layers in the `tf.data` pipeline, because not all layers support TPUs and string ops do not execute on TPUs. (The two exceptions are `tf.keras.layers.Normalization` and `tf.keras.layers.Rescaling`, which run fine on TPUs and are commonly used as the first layer in an image model.)" ] }, { "cell_type": "markdown", "metadata": { "id": "hNCYZ9L-BD2R" }, "source": [ "### Preprocessing with `Model.fit`" ] }, { "cell_type": "markdown", "metadata": { "id": "NhRB2Xe8B6bX" }, "source": [ "When using Keras `Model.fit`, you do not need to distribute data yourself with `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function`. Refer to the [Working with preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers) guide and the [Distributed training with Keras](https://tensorflow.google.cn/tutorials/distribute/keras) guide for details. A short example looks like this:\n", "\n", "```\n", "strategy = tf.distribute.MirroredStrategy()\n", "with strategy.scope():\n", " # Create the layer(s) under scope.\n", " integer_preprocessing_layer = tf.keras.layers.IntegerLookup(vocabulary=FILE_PATH)\n", " model = ...\n", " model.compile(...)\n", "dataset = dataset.map(lambda x, y: (integer_preprocessing_layer(x), y))\n", "model.fit(dataset)\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "id": "3zL2vzJ-G0yg" }, "source": [ "Users of `tf.distribute.experimental.ParameterServerStrategy` with the `Model.fit` API need to use a `tf.keras.utils.experimental.DatasetCreator` as the input. (Refer to the [Parameter server training](https://tensorflow.google.cn/tutorials/distribute/parameter_server_training#parameter_server_training_with_modelfit_api) guide for details.)\n", "\n", "```\n", "strategy = tf.distribute.experimental.ParameterServerStrategy(\n", " cluster_resolver,\n", " variable_partitioner=variable_partitioner)\n", "\n", "with strategy.scope():\n", " preprocessing_layer = tf.keras.layers.StringLookup(vocabulary=FILE_PATH)\n", 
" model = ...\n", " model.compile(...)\n", "\n", "def dataset_fn(input_context):\n", " ...\n", " dataset = dataset.map(preprocessing_layer)\n", " ...\n", " return dataset\n", "\n", "dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)\n", "model.fit(dataset_creator, epochs=5, steps_per_epoch=20, callbacks=callbacks)\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "imZLQUOYBJyW" }, "source": [ "### Preprocessing with a custom training loop" ] }, { "cell_type": "markdown", "metadata": { "id": "r2PX1QH_OwU3" }, "source": [ "When writing a [custom training loop](https://tensorflow.google.cn/tutorials/distribute/custom_training), you distribute your data with either the `tf.distribute.Strategy.experimental_distribute_dataset` API or the `tf.distribute.Strategy.distribute_datasets_from_function` API. If you distribute your dataset through `tf.distribute.Strategy.experimental_distribute_dataset`, applying these preprocessing APIs in your data pipeline automatically co-locates the resources with the data pipeline, which avoids remote resource access. The examples here therefore all use `tf.distribute.Strategy.distribute_datasets_from_function`, in which case it is crucial to place the initialization of these APIs under `strategy.scope()` for efficiency:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:11.136642Z", "iopub.status.busy": "2023-11-07T23:11:11.135877Z", "iopub.status.idle": "2023-11-07T23:11:11.682141Z", "shell.execute_reply": "2023-11-07T23:11:11.681340Z" }, "id": "wJS1UmcWQeab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor([1], shape=(1,), dtype=int64),\n", " 1: tf.Tensor([3], shape=(1,), dtype=int64),\n", " 2: tf.Tensor([0], shape=(1,), dtype=int64),\n", " 3: tf.Tensor([1], shape=(1,), dtype=int64)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([3], shape=(1,), dtype=int64),\n", " 1: 
tf.Tensor([0], shape=(1,), dtype=int64),\n", " 2: tf.Tensor([1], shape=(1,), dtype=int64),\n", " 3: tf.Tensor([3], shape=(1,), dtype=int64)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([0], shape=(1,), dtype=int64),\n", " 1: tf.Tensor([1], shape=(1,), dtype=int64),\n", " 2: tf.Tensor([3], shape=(1,), dtype=int64),\n", " 3: tf.Tensor([0], shape=(1,), dtype=int64)\n", "}\n" ] } ], "source": [ "strategy = tf.distribute.MirroredStrategy()\n", "vocab = [\"a\", \"b\", \"c\", \"d\", \"f\"]\n", "\n", "with strategy.scope():\n", " # Create the layer(s) under scope.\n", " layer = tf.keras.layers.StringLookup(vocabulary=vocab)\n", "\n", "def dataset_fn(input_context):\n", " # a tf.data.Dataset\n", " dataset = tf.data.Dataset.from_tensor_slices([\"a\", \"c\", \"e\"]).repeat()\n", "\n", " # Customize your batching, sharding, prefetching, etc.\n", " global_batch_size = 4\n", " batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n", " dataset = dataset.batch(batch_size)\n", " dataset = dataset.shard(\n", " input_context.num_input_pipelines,\n", " input_context.input_pipeline_id)\n", "\n", " # Apply the preprocessing layer(s) to the tf.data.Dataset\n", " def preprocess_with_kpl(inputs):\n", " return layer(inputs)\n", "\n", " processed_ds = dataset.map(preprocess_with_kpl)\n", " return processed_ds\n", "\n", "distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n", "\n", "# Print out a few example batches.\n", "distributed_dataset_iterator = iter(distributed_dataset)\n", "for _ in range(3):\n", " print(next(distributed_dataset_iterator))" ] }, { "cell_type": "markdown", "metadata": { "id": "PVl1cblWQy8b" }, "source": [ "Note that if you are training with `tf.distribute.experimental.ParameterServerStrategy`, you will also call `tf.distribute.experimental.coordinator.ClusterCoordinator.create_per_worker_dataset`:\n", "\n", "```\n", "@tf.function\n", "def per_worker_dataset_fn():\n", " return strategy.distribute_datasets_from_function(dataset_fn)\n", "\n", "per_worker_dataset = 
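coordinator.create_per_worker_dataset(per_worker_dataset_fn)\n", "per_worker_iterator = iter(per_worker_dataset)\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Ol7SmPID1dAt" }, "source": [ "For TensorFlow Transform, as mentioned above, the analysis stage is done separately from training and is therefore omitted here. See the [tutorial](https://tensorflow.google.cn/tfx/tutorials/transform/census) for a detailed how-to. Usually, this stage includes creating a `tf.Transform` preprocessing function and using it to transform the data in an [Apache Beam](https://beam.apache.org/) pipeline. At the end of the analysis stage, the output can be exported as a TensorFlow graph that you can use for both training and serving. The example below covers only the training pipeline part:\n", "\n", "```\n", "with strategy.scope():\n", "  # working_dir contains the tf.Transform output.\n", "  tf_transform_output = tft.TFTransformOutput(working_dir)\n", "  # Loading from working_dir to create a Keras layer for applying the tf.Transform output to data\n", "  tft_layer = tf_transform_output.transform_features_layer()\n", "  ...\n", "\n", "def dataset_fn(input_context):\n", "  ...\n", "  dataset = dataset.map(tft_layer, num_parallel_calls=tf.data.AUTOTUNE)\n", "  ...\n", "  return dataset\n", "\n", "distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "3_IQxRXxQWof" }, "source": [ "## Partial batches" ] }, { "cell_type": "markdown", "metadata": { "id": "hW2_gVkiztUG" }, "source": [ "Partial batches are encountered when 1) the `tf.data.Dataset` instances that users create contain batch sizes that are not evenly divisible by the number of replicas, or 2) the cardinality of the dataset instance is not divisible by the batch size. This means that when the dataset is distributed over multiple replicas, the `next` call on some iterators will result in a `tf.errors.OutOfRangeError`. To handle this case, `tf.distribute` returns dummy batches of batch size `0` on replicas that have no more data to process.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rqutdpqtPcCH" }, "source": [ "For the single-worker case, if no data is returned by the `next` call on the iterator, dummy batches of batch size 0 are created and used along with the real data in the dataset. In the case of partial batches, the last global batch of data will contain real data alongside dummy batches of data. The stopping condition for processing data then checks whether any of the replicas have data. If none of the replicas have data, a `tf.errors.OutOfRangeError` is raised.\n", "\n", "For the multi-worker case, the boolean value representing the presence of data on each worker is aggregated using cross-replica communication, and it is used to determine whether all workers have finished processing the distributed dataset. Since this involves cross-worker communication, there is some performance penalty.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vehLsgljz90Y" }, "source": [ "## Caveats" ] }, { "cell_type": "markdown", "metadata": { "id": "Nx4jyN_Az-Dy" }, 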
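"source": [ "- When using the `tf.distribute.Strategy.experimental_distribute_dataset` API in a multi-worker setup, you pass a `tf.data.Dataset` that reads from files. If `tf.data.experimental.AutoShardPolicy` is set to `AUTO` or `FILE`, the actual per-step batch size may be smaller than the global batch size you defined. This can happen when the number of remaining elements in a file is smaller than the global batch size. You can either exhaust the dataset without depending on the number of steps to run, or work around the issue by setting `tf.data.experimental.AutoShardPolicy` to `DATA`.\n", "\n", "- Stateful dataset transformations are currently not supported with `tf.distribute`, and any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, then your dataset graph depends on state (that is, the random seed) on the local machine where the Python process is being executed.\n", "\n", "- Experimental `tf.data.experimental.OptimizationOptions` that are disabled by default can, in certain contexts such as when used together with `tf.distribute`, cause a performance degradation. Only enable them after you have validated that they benefit the performance of your workload in a distributed setting.\n", "\n", "- Refer to [this guide](https://tensorflow.google.cn/guide/data_performance) for how to optimize your input pipeline with `tf.data`. A few additional tips:\n", "\n", "  - If you have multiple workers and are using `tf.data.Dataset.list_files` to create a dataset from all files matching one or more glob patterns, remember to set the `seed` argument or set `shuffle=False` so that each worker shards the files consistently.\n", "\n", "  - If your input pipeline includes both shuffling the data at the record level and parsing the data, shuffle first and then parse, as shown in the example below, unless the unparsed data is significantly larger than the parsed data (which is usually not the case). This can benefit memory usage and performance.\n", "\n", "```\n", "# pattern, num_workers, worker_index, num_epochs, shuffle_buffer_size,\n", "# num_readers, parser_fn and num_map_threads are placeholders.\n", "d = tf.data.Dataset.list_files(pattern, shuffle=False)\n", "d = d.shard(num_workers, worker_index)\n", "d = d.repeat(num_epochs)\n", "d = d.shuffle(shuffle_buffer_size)\n", "d = d.interleave(tf.data.TFRecordDataset,\n", "                 cycle_length=num_readers, block_length=1)\n", "d = d.map(parser_fn, num_parallel_calls=num_map_threads)\n", "```\n", "\n", "- `tf.data.Dataset.shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)` maintains an internal buffer of `buffer_size` elements, so reducing `buffer_size` can alleviate OOM issues." ] }, { "cell_type": "markdown", "metadata": { "id": "dAC_vRmJyzrB" }, "source": [ "- The order in which workers process the data is not guaranteed when using `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function`. An ordering guarantee is typically needed if you use `tf.distribute` to scale prediction. You can, however, insert an index for each element in the batch and order the outputs accordingly. The following snippet is an example of how to order outputs.\n", "\n", "Note: `tf.distribute.MirroredStrategy` is used here for convenience. Reordering inputs is only necessary when you use multiple workers; `tf.distribute.MirroredStrategy` distributes training on a single worker." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { 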
"execution": { "iopub.execute_input": "2023-11-07T23:11:11.687311Z", "iopub.status.busy": "2023-11-07T23:11:11.686610Z", "iopub.status.idle": "2023-11-07T23:11:12.071571Z", "shell.execute_reply": "2023-11-07T23:11:12.070807Z" }, "id": "Zr2xAy-uZZaL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. 
We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{0: 0, 1: 2, 2: 4, 3: 6, 4: 8, 5: 10, 6: 12, 7: 14, 8: 16, 9: 18, 10: 20, 11: 22, 12: 24, 13: 26, 14: 28, 15: 30, 16: 32, 17: 34, 18: 36, 19: 38, 20: 40, 21: 42, 22: 44, 23: 46}\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "dataset_size = 24\n", "batch_size = 6\n", "# enumerate() pairs each element with its index so outputs can be reordered later.\n", "dataset = tf.data.Dataset.range(dataset_size).enumerate().batch(batch_size)\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "\n", "def predict(index, inputs):\n", "  outputs = 2 * inputs\n", "  return index, outputs\n", "\n", "result = {}\n", "for index, inputs in dist_dataset:\n", "  output_index, outputs = mirrored_strategy.run(predict, args=(index, inputs))\n", "  # Gather the per-replica indices and outputs, then key each output by its index.\n", "  indices = list(mirrored_strategy.experimental_local_results(output_index))\n", "  rindices = []\n", "  for a in indices:\n", "    rindices.extend(a.numpy())\n", "  outputs = list(mirrored_strategy.experimental_local_results(outputs))\n", "  routputs = []\n", "  for a in outputs:\n", "    routputs.extend(a.numpy())\n", "  for i, value in zip(rindices, routputs):\n", "    result[i] = value\n", "\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "nNbn7HXx0YqB" }, "source": [ "## Tensor inputs instead of tf.data" ] }, { "cell_type": "markdown", "metadata": { "id": "dymZixqo0nKK" }, "source": [ "Sometimes users cannot use a `tf.data.Dataset` to represent their input, and consequently cannot use the APIs above to distribute the dataset to multiple devices. In such cases, you can use raw tensors or inputs from a generator.\n", "\n", "### Use experimental_distribute_values_from_function for arbitrary tensor inputs\n", "\n", "`strategy.run` accepts `tf.distribute.DistributedValues`, which is the output of `next(iterator)`. To pass tensor values, use `tf.distribute.Strategy.experimental_distribute_values_from_function` to construct `tf.distribute.DistributedValues` from raw tensors. With this option, you must specify your own batching and sharding logic in the input function, which can be done using the `tf.distribute.experimental.ValueContext` input object." ] }, { 
"cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:12.075564Z", "iopub.status.busy": "2023-11-07T23:11:12.074987Z", "iopub.status.idle": "2023-11-07T23:11:12.096076Z", "shell.execute_reply": "2023-11-07T23:11:12.095349Z" }, "id": "ajZHNRQs0kqm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor(0, shape=(), dtype=int32),\n", " 1: tf.Tensor(1, shape=(), dtype=int32),\n", " 2: tf.Tensor(2, shape=(), dtype=int32),\n", " 3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor(0, shape=(), dtype=int32),\n", " 1: tf.Tensor(1, shape=(), dtype=int32),\n", " 2: tf.Tensor(2, shape=(), dtype=int32),\n", " 3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor(0, shape=(), dtype=int32),\n", " 1: tf.Tensor(1, shape=(), dtype=int32),\n", " 2: tf.Tensor(2, shape=(), dtype=int32),\n", " 3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor(0, shape=(), dtype=int32),\n", " 1: tf.Tensor(1, shape=(), dtype=int32),\n", " 2: tf.Tensor(2, shape=(), dtype=int32),\n", " 3: tf.Tensor(3, shape=(), dtype=int32)\n", "}\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "\n", "def value_fn(ctx):\n", " return 
tf.constant(ctx.replica_id_in_sync_group)\n", "\n", "distributed_values = mirrored_strategy.experimental_distribute_values_from_function(value_fn)\n", "for _ in range(4):\n", "  result = mirrored_strategy.run(lambda x: x, args=(distributed_values,))\n", "  print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "P98aFQGf0x_7" }, "source": [ "### Use tf.data.Dataset.from_generator if your input is from a generator" ] }, { "cell_type": "markdown", "metadata": { "id": "emZCWQSi04qT" }, "source": [ "If you have a generator function that you want to use, you can create a `tf.data.Dataset` instance using the `from_generator` API.\n", "\n", "Note: This is currently not supported for `tf.distribute.TPUStrategy`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:11:12.099842Z", "iopub.status.busy": "2023-11-07T23:11:12.099535Z", "iopub.status.idle": "2023-11-07T23:11:12.504139Z", "shell.execute_reply": "2023-11-07T23:11:12.503218Z" }, "id": "jRhU0X230787" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PerReplica:{\n", " 0: tf.Tensor([0.26498944], shape=(1,), dtype=float32),\n", " 1: tf.Tensor([0.9832243], shape=(1,), dtype=float32),\n", " 2: tf.Tensor([0.7569181], shape=(1,), dtype=float32),\n", " 3: tf.Tensor([0.5905416], shape=(1,), dtype=float32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([0.2529385], shape=(1,), dtype=float32),\n", " 1: tf.Tensor([0.75223196], shape=(1,), dtype=float32),\n", " 2: tf.Tensor([0.8507075], shape=(1,), dtype=float32),\n", " 3: tf.Tensor([0.35577485], shape=(1,), dtype=float32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([0.47461054], shape=(1,), dtype=float32),\n", " 1: tf.Tensor([0.46633008], shape=(1,), dtype=float32),\n", " 2: 
tf.Tensor([0.2187182], shape=(1,), dtype=float32),\n", " 3: tf.Tensor([0.8489092], shape=(1,), dtype=float32)\n", "}\n", "PerReplica:{\n", " 0: tf.Tensor([0.27852485], shape=(1,), dtype=float32),\n", " 1: tf.Tensor([0.10208022], shape=(1,), dtype=float32),\n", " 2: tf.Tensor([0.5859448], shape=(1,), dtype=float32),\n", " 3: tf.Tensor([0.4391938], shape=(1,), dtype=float32)\n", "}\n" ] } ], "source": [ "mirrored_strategy = tf.distribute.MirroredStrategy()\n", "def input_gen():\n", "  while True:\n", "    yield np.random.rand(4)\n", "\n", "# Wrap the generator with Dataset.from_generator; each element is a float32 vector of length 4.\n", "dataset = tf.data.Dataset.from_generator(\n", "    input_gen, output_signature=tf.TensorSpec(shape=(4,), dtype=tf.float32))\n", "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n", "iterator = iter(dist_dataset)\n", "for _ in range(4):\n", "  result = mirrored_strategy.run(lambda x: x, args=(next(iterator),))\n", "  print(result)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "input.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }