{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T23:15:16.495560Z", "iopub.status.busy": "2023-11-07T23:15:16.495018Z", "iopub.status.idle": "2023-11-07T23:15:16.499343Z", "shell.execute_reply": "2023-11-07T23:15:16.498703Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "# 使用 Keras 和 MultiWorkerMirroredStrategy 的自定义训练循环\n", "\n", "
![]() | \n",
" ![]() | \n",
" ![]() | \n",
" ![]() | \n",
"
tf.distribute.Strategy
API 的自定义训练循环执行多工作进程分布式训练。训练循环通过 `tf.distribute.MultiWorkerMirroredStrategy` 进行分布。这样,设计为在[单个工作进程上运行的 `tf.keras` 模型](custom_training.ipynb)即可通过最少的代码更改无缝地在多个工作进程上运行。自定义训练循环提供了灵活性和更好的训练控制,同时也使模型的调试更加容易。请详细了解有关[编写基本训练循环](../../guide/basic_training_loops.ipynb)、 [从头开始编写训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)和[自定义训练](../customization/custom_training_walkthrough.ipynb)的信息。\n",
"\n",
"如果您正在寻找如何将 `MultiWorkerMirroredStrategy` 与 `tf.keras.Model.fit` 一起使用,请参阅此[教程](multi_worker_with_keras.ipynb)。\n",
"\n",
"[TensorFlow 中的分布式训练](../../guide/distributed_training.ipynb)指南概述了 TensorFlow 支持的分布式策略,并适用于想要更深入了解 `tf.distribute.Strategy` API 的人。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MUXex9ctTuDB"
},
"source": [
"## 安装\n",
"\n",
"首先,进行一些必要的导入。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:16.503236Z",
"iopub.status.busy": "2023-11-07T23:15:16.502604Z",
"iopub.status.idle": "2023-11-07T23:15:16.509179Z",
"shell.execute_reply": "2023-11-07T23:15:16.508511Z"
},
"id": "bnYxvfLD-LW-"
},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import sys"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zz0EY91y3mxy"
},
"source": [
"在导入 TensorFlow 之前,需要对环境进行一些变更:\n",
"\n",
"- 停用所有 GPU。这可以防止所有工作进程都尝试使用同一个 GPU 而导致的错误。对于真实应用,每个工作进程都将在不同的计算机上运行。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:16.512523Z",
"iopub.status.busy": "2023-11-07T23:15:16.511976Z",
"iopub.status.idle": "2023-11-07T23:15:16.515646Z",
"shell.execute_reply": "2023-11-07T23:15:16.515038Z"
},
"id": "685pbYEY3jGC"
},
"outputs": [],
"source": [
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7X1MS6385BWi"
},
"source": [
"- 重置 `'TF_CONFIG'` 环境变量(稍后您将看到更多相关信息)。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:16.519247Z",
"iopub.status.busy": "2023-11-07T23:15:16.518695Z",
"iopub.status.idle": "2023-11-07T23:15:16.521957Z",
"shell.execute_reply": "2023-11-07T23:15:16.521355Z"
},
"id": "WEJLYa2_7OZF"
},
"outputs": [],
"source": [
"os.environ.pop('TF_CONFIG', None)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rd4L9Ii77SS8"
},
"source": [
"- 确保当前目录位于 Python 的路径上。这样,笔记本可以导入稍后由 `%%writefile` 写入的文件。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:16.525511Z",
"iopub.status.busy": "2023-11-07T23:15:16.524881Z",
"iopub.status.idle": "2023-11-07T23:15:16.528298Z",
"shell.execute_reply": "2023-11-07T23:15:16.527662Z"
},
"id": "hPBuZUNSZmrQ"
},
"outputs": [],
"source": [
"if '.' not in sys.path:\n",
" sys.path.insert(0, '.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pDhHuMjb7bfU"
},
"source": [
"现在导入 TensorFlow。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:16.531882Z",
"iopub.status.busy": "2023-11-07T23:15:16.531274Z",
"iopub.status.idle": "2023-11-07T23:15:19.109502Z",
"shell.execute_reply": "2023-11-07T23:15:19.108553Z"
},
"id": "vHNvttzV43sA"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-07 23:15:16.986830: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2023-11-07 23:15:16.986881: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2023-11-07 23:15:16.988638: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0S2jpf6Sx50i"
},
"source": [
"### 数据集和模型定义"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fLW6D2TzvC-4"
},
"source": [
"接下来,使用简单的模型和数据集设置创建 `mnist.py` 文件。本教程中的工作进程将使用此 Python 文件:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.114488Z",
"iopub.status.busy": "2023-11-07T23:15:19.113982Z",
"iopub.status.idle": "2023-11-07T23:15:19.120140Z",
"shell.execute_reply": "2023-11-07T23:15:19.119342Z"
},
"id": "dma_wUAxZqo2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing mnist.py\n"
]
}
],
"source": [
"%%writefile mnist.py\n",
"\n",
"import os\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"def mnist_dataset(batch_size):\n",
" (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n",
" # The `x` arrays are in uint8 and have values in the range [0, 255].\n",
" # You need to convert them to float32 with values in the range [0, 1]\n",
" x_train = x_train / np.float32(255)\n",
" y_train = y_train.astype(np.int64)\n",
" train_dataset = tf.data.Dataset.from_tensor_slices(\n",
" (x_train, y_train)).shuffle(60000)\n",
" return train_dataset\n",
"\n",
"def dataset_fn(global_batch_size, input_context):\n",
" batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
" dataset = mnist_dataset(batch_size)\n",
" dataset = dataset.shard(input_context.num_input_pipelines,\n",
" input_context.input_pipeline_id)\n",
" dataset = dataset.batch(batch_size)\n",
" return dataset\n",
"\n",
"def build_cnn_model():\n",
" return tf.keras.Sequential([\n",
" tf.keras.Input(shape=(28, 28)),\n",
" tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
" tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(10)\n",
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JmgZwwymxqt5"
},
"source": [
"## 多工作进程配置\n",
"\n",
"接下来,我们进入多工作进程训练的世界。在 TensorFlow 中,在多台计算机上进行训练需要 `'TF_CONFIG'` 环境变量。每台计算机可能有不同的角色。下面使用的 `'TF_CONFIG'` 变量是一个 JSON 字符串,它指定集群中每个工作进程的集群配置。这是使用 `cluster_resolver.TFConfigClusterResolver` 指定集群的默认方法,但在 `distribute.cluster_resolver` 模块中还有其他可用选项。请在[分布式训练指南](../../guide/distributed_training.ipynb)中了解有关设置 `'TF_CONFIG'` 变量的更多信息。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SS8WhvRhe_Ya"
},
"source": [
"### 描述您的集群\n",
"\n",
"下面是一个示例配置:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.124437Z",
"iopub.status.busy": "2023-11-07T23:15:19.123838Z",
"iopub.status.idle": "2023-11-07T23:15:19.127351Z",
"shell.execute_reply": "2023-11-07T23:15:19.126741Z"
},
"id": "XK1eTYvSZiX7"
},
"outputs": [],
"source": [
"tf_config = {\n",
" 'cluster': {\n",
" 'worker': ['localhost:12345', 'localhost:23456']\n",
" },\n",
" 'task': {'type': 'worker', 'index': 0}\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JjgwJbPKZkJL"
},
"source": [
"请注意,`tf_config` 只是 Python 中的局部变量。要将其用于训练配置,请将其序列化为 JSON 并将其放在 `'TF_CONFIG'` 环境变量中。这是序列化为 JSON 字符串的相同 `'TF_CONFIG'`:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.130890Z",
"iopub.status.busy": "2023-11-07T23:15:19.130639Z",
"iopub.status.idle": "2023-11-07T23:15:19.137618Z",
"shell.execute_reply": "2023-11-07T23:15:19.136927Z"
},
"id": "yY-T0YDQZjbu"
},
"outputs": [
{
"data": {
"text/plain": [
"'{\"cluster\": {\"worker\": [\"localhost:12345\", \"localhost:23456\"]}, \"task\": {\"type\": \"worker\", \"index\": 0}}'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"json.dumps(tf_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AUBmYRZqxthH"
},
"source": [
"`'TF_CONFIG'` 有两个组件:`'cluster'` 和 `'task'`。\n",
"\n",
"- `'cluster'` 对所有工作进程都相同,并提供有关训练集群的信息,这是一个由不同类型的作业组成的字典,例如 `'worker'` 。在使用 `MultiWorkerMirroredStrategy` 进行的多工作进程训练中,除了普通的 `'worker'` 之外,通常还有一个 `'worker'` 承担更多的责任,例如保存检查点和为 TensorBoard 编写摘要文件。这样的工作进程被称为 `'chief'` 工作进程,习惯上将 `'index'` 为 0 的 `'worker'` 指定为首席 `worker`。\n",
"\n",
"- `'task'` 提供当前任务的信息,并且在每个工作进程上都不相同。它指定该工作进程的 `'type'` 和 `'index'`。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8YFpxrcsZ2xG"
},
"source": [
"在本例中,您会将任务 `'type'` 设置为 `'worker'`,将任务 `'index'` 设置为 `0`。这台计算机是首个工作进程,将被指定为首席工作进程,并需要比其他工作进程承担更多的工作。请注意,其他计算机也需要设置 `'TF_CONFIG'` 环境变量,且应该具有相同的 `'cluster'` 字典,但要根据这些计算机的具体角色来设置不同的任务 `'type'` 或任务 `'index'`。\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aogb74kHxynz"
},
"source": [
"出于演示的目的,本教程将展示如何在 `'localhost'` 上设置具有两个工作进程的 `'TF_CONFIG'`。在实践中,用户会在外部 IP 地址/端口上创建多个工作进程,并为每个工作进程正确设置 `'TF_CONFIG'`。\n",
"\n",
"本示例使用两个工作进程,第一个工作进程的 `'TF_CONFIG'` 如上所示。对于第二个工作进程,设置 `tf_config['task']['index']=1`。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cIlkfWmjz1PG"
},
"source": [
"### 笔记本中的环境变量和子进程"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FcjAbuGY1ACJ"
},
"source": [
"子进程会从其父进程继承环境变量。因此,如果您在此 Jupyter Notebook 进程中设置环境变量:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.141656Z",
"iopub.status.busy": "2023-11-07T23:15:19.141385Z",
"iopub.status.idle": "2023-11-07T23:15:19.144932Z",
"shell.execute_reply": "2023-11-07T23:15:19.144267Z"
},
"id": "PH2gHn2_0_U8"
},
"outputs": [],
"source": [
"os.environ['GREETINGS'] = 'Hello TensorFlow!'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gQkIX-cg18md"
},
"source": [
"然后,您可以从子进程访问环境变量:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.148549Z",
"iopub.status.busy": "2023-11-07T23:15:19.148286Z",
"iopub.status.idle": "2023-11-07T23:15:19.194941Z",
"shell.execute_reply": "2023-11-07T23:15:19.194072Z"
},
"id": "pquKO6IA18G5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello TensorFlow!\n"
]
}
],
"source": [
"%%bash\n",
"echo ${GREETINGS}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "af6BCA-Y2fpz"
},
"source": [
"在下一部分中,您将使用它来将 `'TF_CONFIG'` 传递给工作进程子进程。实际上,您永远不会以这种方式启动您的作业,但这完全可以满足此教程的演示目的:呈现最简单的多工作进程示例。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UhNtHfuxCGVy"
},
"source": [
"## MultiWorkerMirroredStrategy\n",
"\n",
"在训练模型之前,首先创建一个 `tf.distribute.MultiWorkerMirroredStrategy` 的实例:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.199657Z",
"iopub.status.busy": "2023-11-07T23:15:19.198947Z",
"iopub.status.idle": "2023-11-07T23:15:19.331499Z",
"shell.execute_reply": "2023-11-07T23:15:19.330736Z"
},
"id": "1uFSHCJXMrQ-"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-07 23:15:19.298847: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
]
}
],
"source": [
"strategy = tf.distribute.MultiWorkerMirroredStrategy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N0iv7SyyAohc"
},
"source": [
"注:在您调用 `tf.distribute.MultiWorkerMirroredStrategy` 时,会解析 `'TF_CONFIG'` 并启动 TensorFlow 的 GRPC 服务器。因此,您必须在实例化 `tf.distribute.Strategy` 之前设置 `'TF_CONFIG'` 环境变量。为了在这个说明性示例中节省时间,本教程中没有对此进行演示,因此不需要启动服务器。您可以在本教程的最后一个部分中找到完整的示例。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TS4S-faBHHam"
},
"source": [
"使用 `tf.distribute.Strategy.scope` 指定构建模型时应使用的策略。这使得该策略可以控制变量放置之类的事情,它将在所有工作进程的每个设备上,在模型的层中创建所有变量的副本。"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.335070Z",
"iopub.status.busy": "2023-11-07T23:15:19.334808Z",
"iopub.status.idle": "2023-11-07T23:15:19.441785Z",
"shell.execute_reply": "2023-11-07T23:15:19.441061Z"
},
"id": "nXV49tG1_opc"
},
"outputs": [],
"source": [
"import mnist\n",
"with strategy.scope():\n",
" # Model building needs to be within `strategy.scope()`.\n",
" multi_worker_model = mnist.build_cnn_model()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DSYkM-on6r3Y"
},
"source": [
"## 在工作进程之间对数据进行自动分片\n",
"\n",
"在多工作进程训练中,需要通过*数据集分片*来确保收敛性和可重复性。分片意味着将整个数据集的一个子集交给每个工作进程,这有助于创造类似于对单个工作进程进行训练的体验。在下面的示例中,您依赖于 `tf.distribute` 的默认自动分片策略。您还可以通过设置 `tf.data.experimental.DistributeOptions` 的 `tf.data.experimental.AutoShardPolicy` 来对其进行自定义。要了解更多信息,请参阅[分布式输入教程](input.ipynb)的*分片*部分。"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:19.446767Z",
"iopub.status.busy": "2023-11-07T23:15:19.446050Z",
"iopub.status.idle": "2023-11-07T23:15:20.082698Z",
"shell.execute_reply": "2023-11-07T23:15:20.081648Z"
},
"id": "65-p36pt6rUF"
},
"outputs": [],
"source": [
"per_worker_batch_size = 64\n",
"num_workers = len(tf_config['cluster']['worker'])\n",
"global_batch_size = per_worker_batch_size * num_workers\n",
"\n",
"with strategy.scope():\n",
" multi_worker_dataset = strategy.distribute_datasets_from_function(\n",
" lambda input_context: mnist.dataset_fn(global_batch_size, input_context))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkNzSR3g60iP"
},
"source": [
"## 定义自定义训练循环并训练模型\n",
"\n",
"指定优化器:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:20.087270Z",
"iopub.status.busy": "2023-11-07T23:15:20.086938Z",
"iopub.status.idle": "2023-11-07T23:15:20.100395Z",
"shell.execute_reply": "2023-11-07T23:15:20.099700Z"
},
"id": "NoMr4_zTeKSn"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
" # The creation of optimizer and train_accuracy needs to be in\n",
" # `strategy.scope()` as well, since they create variables.\n",
" optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n",
" train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
" name='train_accuracy')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RmrDcAii4B5O"
},
"source": [
"使用 `tf.function` 定义训练步骤:\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:20.104187Z",
"iopub.status.busy": "2023-11-07T23:15:20.103914Z",
"iopub.status.idle": "2023-11-07T23:15:20.110790Z",
"shell.execute_reply": "2023-11-07T23:15:20.110133Z"
},
"id": "znXWN5S3eUDB"
},
"outputs": [],
"source": [
"@tf.function\n",
"def train_step(iterator):\n",
" \"\"\"Training step function.\"\"\"\n",
"\n",
" def step_fn(inputs):\n",
" \"\"\"Per-Replica step function.\"\"\"\n",
" x, y = inputs\n",
" with tf.GradientTape() as tape:\n",
" predictions = multi_worker_model(x, training=True)\n",
" per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
" from_logits=True,\n",
" reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n",
" loss = tf.nn.compute_average_loss(\n",
" per_batch_loss, global_batch_size=global_batch_size)\n",
"\n",
" grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n",
" optimizer.apply_gradients(\n",
" zip(grads, multi_worker_model.trainable_variables))\n",
" train_accuracy.update_state(y, predictions)\n",
" return loss\n",
"\n",
" per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n",
" return strategy.reduce(\n",
" tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eFXHsUVBy0Rx"
},
"source": [
"### 检查点保存和恢复\n",
"\n",
"在编写自定义训练循环时,您需要手动处理[检查点保存](../../guide/checkpoint.ipynb),而不是依赖 Keras 回调。请注意,对于 `MultiWorkerMirroredStrategy`,保存检查点或完整模型需要所有工作进程的参与,因为尝试仅在首席工作进程上进行保存可能会导致死锁。工作进程还需要写入不同的路径以避免相互重写。以下是如何配置目录的示例:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:20.114623Z",
"iopub.status.busy": "2023-11-07T23:15:20.113953Z",
"iopub.status.idle": "2023-11-07T23:15:20.120496Z",
"shell.execute_reply": "2023-11-07T23:15:20.119689Z"
},
"id": "LcFO6x1KyjhI"
},
"outputs": [],
"source": [
"from multiprocessing import util\n",
"checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n",
"\n",
"def _is_chief(task_type, task_id, cluster_spec):\n",
" return (task_type is None\n",
" or task_type == 'chief'\n",
" or (task_type == 'worker'\n",
" and task_id == 0\n",
" and \"chief\" not in cluster_spec.as_dict()))\n",
"\n",
"def _get_temp_dir(dirpath, task_id):\n",
" base_dirpath = 'workertemp_' + str(task_id)\n",
" temp_dir = os.path.join(dirpath, base_dirpath)\n",
" tf.io.gfile.makedirs(temp_dir)\n",
" return temp_dir\n",
"\n",
"def write_filepath(filepath, task_type, task_id, cluster_spec):\n",
" dirpath = os.path.dirname(filepath)\n",
" base = os.path.basename(filepath)\n",
" if not _is_chief(task_type, task_id, cluster_spec):\n",
" dirpath = _get_temp_dir(dirpath, task_id)\n",
" return os.path.join(dirpath, base)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nrcdPHtG4ObO"
},
"source": [
"创建一个跟踪模型的 `tf.train.Checkpoint`,由 `tf.train.CheckpointManager` 管理,以便仅保留最新的检查点:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:20.124367Z",
"iopub.status.busy": "2023-11-07T23:15:20.123804Z",
"iopub.status.idle": "2023-11-07T23:15:20.133425Z",
"shell.execute_reply": "2023-11-07T23:15:20.132707Z"
},
"id": "4rURT2pI4aqV"
},
"outputs": [],
"source": [
"epoch = tf.Variable(\n",
" initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n",
"step_in_epoch = tf.Variable(\n",
" initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n",
" name='step_in_epoch')\n",
"task_type, task_id = (strategy.cluster_resolver.task_type,\n",
" strategy.cluster_resolver.task_id)\n",
"# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this \n",
"# illustrative example you did not set `'TF_CONFIG'` before initializing the\n",
"# strategy. Check out the next section for \"real-world\" usage.\n",
"cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])\n",
"\n",
"checkpoint = tf.train.Checkpoint(\n",
" model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n",
"\n",
"write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,\n",
" cluster_spec)\n",
"checkpoint_manager = tf.train.CheckpointManager(\n",
" checkpoint, directory=write_checkpoint_dir, max_to_keep=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RO7cbN40XD5v"
},
"source": [
"现在,当需要恢复检查点时,您可以方便地使用 `tf.train.latest_checkpoint` 函数(或通过调用 `tf.train.CheckpointManager.restore_or_initialize` )找到最新的已保存检查点。"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:20.137411Z",
"iopub.status.busy": "2023-11-07T23:15:20.136897Z",
"iopub.status.idle": "2023-11-07T23:15:20.140692Z",
"shell.execute_reply": "2023-11-07T23:15:20.140009Z"
},
"id": "gniynaQj6HMV"
},
"outputs": [],
"source": [
"latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
"if latest_checkpoint:\n",
" checkpoint.restore(latest_checkpoint)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1j9JuI-h6ObW"
},
"source": [
"恢复检查点后,您可以继续训练自定义训练循环。"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:20.143940Z",
"iopub.status.busy": "2023-11-07T23:15:20.143513Z",
"iopub.status.idle": "2023-11-07T23:15:24.222775Z",
"shell.execute_reply": "2023-11-07T23:15:24.221984Z"
},
"id": "kZzXZCh45FY6"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-07 23:15:20.366756: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0, accuracy: 0.807366, train_loss: 0.621664.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1, accuracy: 0.926786, train_loss: 0.255375.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 2, accuracy: 0.947656, train_loss: 0.172921.\n"
]
}
],
"source": [
"num_epochs = 3\n",
"num_steps_per_epoch = 70\n",
"\n",
"while epoch.numpy() < num_epochs:\n",
" iterator = iter(multi_worker_dataset)\n",
" total_loss = 0.0\n",
" num_batches = 0\n",
"\n",
" while step_in_epoch.numpy() < num_steps_per_epoch:\n",
" total_loss += train_step(iterator)\n",
" num_batches += 1\n",
" step_in_epoch.assign_add(1)\n",
"\n",
" train_loss = total_loss / num_batches\n",
" print('Epoch: %d, accuracy: %f, train_loss: %f.'\n",
" %(epoch.numpy(), train_accuracy.result(), train_loss))\n",
"\n",
" train_accuracy.reset_states()\n",
"\n",
" # Once the `CheckpointManager` is set up, you're now ready to save, and remove\n",
" # the checkpoints non-chief workers saved.\n",
" checkpoint_manager.save()\n",
" if not _is_chief(task_type, task_id, cluster_spec):\n",
" tf.io.gfile.rmtree(write_checkpoint_dir)\n",
"\n",
" epoch.assign_add(1)\n",
" step_in_epoch.assign(0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0W1Osks466DE"
},
"source": [
"## 完整代码一览"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jfYpmIxO6Jck"
},
"source": [
"总结一下到目前为止讨论的所有程序:\n",
"\n",
"1. 创建工作进程。\n",
"2. 将 `'TF_CONFIG'` 传递给工作进程。\n",
"3. 让每个工作进程运行下面包含训练代码的脚本。"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:24.227041Z",
"iopub.status.busy": "2023-11-07T23:15:24.226440Z",
"iopub.status.idle": "2023-11-07T23:15:24.233423Z",
"shell.execute_reply": "2023-11-07T23:15:24.232771Z"
},
"id": "MIDCESkVzN6M"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing main.py\n"
]
}
],
"source": [
"%%writefile main.py\n",
"#@title File: `main.py`\n",
"import os\n",
"import json\n",
"import tensorflow as tf\n",
"import mnist\n",
"from multiprocessing import util\n",
"\n",
"per_worker_batch_size = 64\n",
"tf_config = json.loads(os.environ['TF_CONFIG'])\n",
"num_workers = len(tf_config['cluster']['worker'])\n",
"global_batch_size = per_worker_batch_size * num_workers\n",
"\n",
"num_epochs = 3\n",
"num_steps_per_epoch=70\n",
"\n",
"# Checkpoint saving and restoring\n",
"def _is_chief(task_type, task_id, cluster_spec):\n",
" return (task_type is None\n",
" or task_type == 'chief'\n",
" or (task_type == 'worker'\n",
" and task_id == 0\n",
" and 'chief' not in cluster_spec.as_dict()))\n",
" \n",
"def _get_temp_dir(dirpath, task_id):\n",
" base_dirpath = 'workertemp_' + str(task_id)\n",
" temp_dir = os.path.join(dirpath, base_dirpath)\n",
" tf.io.gfile.makedirs(temp_dir)\n",
" return temp_dir\n",
"\n",
"def write_filepath(filepath, task_type, task_id, cluster_spec):\n",
" dirpath = os.path.dirname(filepath)\n",
" base = os.path.basename(filepath)\n",
" if not _is_chief(task_type, task_id, cluster_spec):\n",
" dirpath = _get_temp_dir(dirpath, task_id)\n",
" return os.path.join(dirpath, base)\n",
"\n",
"checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n",
"\n",
"# Define Strategy\n",
"strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
"\n",
"with strategy.scope():\n",
" # Model building/compiling need to be within `tf.distribute.Strategy.scope`.\n",
" multi_worker_model = mnist.build_cnn_model()\n",
"\n",
" multi_worker_dataset = strategy.distribute_datasets_from_function(\n",
" lambda input_context: mnist.dataset_fn(global_batch_size, input_context)) \n",
" optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n",
" train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
" name='train_accuracy')\n",
"\n",
"@tf.function\n",
"def train_step(iterator):\n",
" \"\"\"Training step function.\"\"\"\n",
"\n",
" def step_fn(inputs):\n",
" \"\"\"Per-Replica step function.\"\"\"\n",
" x, y = inputs\n",
" with tf.GradientTape() as tape:\n",
" predictions = multi_worker_model(x, training=True)\n",
" per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
" from_logits=True,\n",
" reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n",
" loss = tf.nn.compute_average_loss(\n",
" per_batch_loss, global_batch_size=global_batch_size)\n",
"\n",
" grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n",
" optimizer.apply_gradients(\n",
" zip(grads, multi_worker_model.trainable_variables))\n",
" train_accuracy.update_state(y, predictions)\n",
"\n",
" return loss\n",
"\n",
" per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n",
" return strategy.reduce(\n",
" tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
"\n",
"epoch = tf.Variable(\n",
" initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n",
"step_in_epoch = tf.Variable(\n",
" initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n",
" name='step_in_epoch')\n",
"\n",
"task_type, task_id, cluster_spec = (strategy.cluster_resolver.task_type,\n",
" strategy.cluster_resolver.task_id,\n",
" strategy.cluster_resolver.cluster_spec())\n",
"\n",
"checkpoint = tf.train.Checkpoint(\n",
" model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n",
"\n",
"write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,\n",
" cluster_spec)\n",
"checkpoint_manager = tf.train.CheckpointManager(\n",
" checkpoint, directory=write_checkpoint_dir, max_to_keep=1)\n",
"\n",
"# Restoring the checkpoint\n",
"latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
"if latest_checkpoint:\n",
" checkpoint.restore(latest_checkpoint)\n",
"\n",
"# Resume our CTL training\n",
"while epoch.numpy() < num_epochs:\n",
" iterator = iter(multi_worker_dataset)\n",
" total_loss = 0.0\n",
" num_batches = 0\n",
"\n",
" while step_in_epoch.numpy() < num_steps_per_epoch:\n",
" total_loss += train_step(iterator)\n",
" num_batches += 1\n",
" step_in_epoch.assign_add(1)\n",
"\n",
" train_loss = total_loss / num_batches\n",
" print('Epoch: %d, accuracy: %f, train_loss: %f.'\n",
" %(epoch.numpy(), train_accuracy.result(), train_loss))\n",
" \n",
" train_accuracy.reset_states()\n",
"\n",
" checkpoint_manager.save()\n",
" if not _is_chief(task_type, task_id, cluster_spec):\n",
" tf.io.gfile.rmtree(write_checkpoint_dir)\n",
"\n",
" epoch.assign_add(1)\n",
" step_in_epoch.assign(0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ItVOvPN1qnZ6"
},
"source": [
"当前目录现包含两个 Python 文件:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:24.237059Z",
"iopub.status.busy": "2023-11-07T23:15:24.236529Z",
"iopub.status.idle": "2023-11-07T23:15:24.300080Z",
"shell.execute_reply": "2023-11-07T23:15:24.299121Z"
},
"id": "bi6x05Sr60O9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"main.py\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mnist.py\n"
]
}
],
"source": [
"%%bash\n",
"ls *.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qmEEStPS6vR_"
},
"source": [
"因此,对 `'TF_CONFIG'` 执行 JSON 序列化,然后将其添加到环境变量:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:24.304381Z",
"iopub.status.busy": "2023-11-07T23:15:24.303636Z",
"iopub.status.idle": "2023-11-07T23:15:24.307915Z",
"shell.execute_reply": "2023-11-07T23:15:24.307209Z"
},
"id": "9uu3g7vV7Bbt"
},
"outputs": [],
"source": [
"os.environ['TF_CONFIG'] = json.dumps(tf_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MsY3dQLK7jdf"
},
"source": [
"现在,您可以启动一个将运行 `main.py` 并使用 `'TF_CONFIG'` 的工作进程:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:24.311524Z",
"iopub.status.busy": "2023-11-07T23:15:24.311017Z",
"iopub.status.idle": "2023-11-07T23:15:24.315129Z",
"shell.execute_reply": "2023-11-07T23:15:24.314498Z"
},
"id": "txMXaq8d8N_S"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All background processes were killed.\n"
]
}
],
"source": [
"# first kill any previous runs\n",
"%killbgscripts"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:24.318373Z",
"iopub.status.busy": "2023-11-07T23:15:24.318092Z",
"iopub.status.idle": "2023-11-07T23:15:24.375371Z",
"shell.execute_reply": "2023-11-07T23:15:24.374149Z"
},
"id": "qnSma_Ck7r-r"
},
"outputs": [],
"source": [
"%%bash --bg\n",
"python main.py &> job_0.log"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZChyazqS7v0P"
},
"source": [
"以上命令有几点需要注意:\n",
"\n",
"1. 它使用 `%%bash`,这是一项用于运行一些 bash 命令的[笔记本“魔术命令”](https://ipython.readthedocs.io/en/stable/interactive/magics.html)。\n",
"2. 它使用 `--bg` 标志在后台运行 `bash` 进程,因为此工作进程不会终止。它在开始之前会等待所有工作进程。\n",
"\n",
"后台工作进程不会将输出打印到此笔记本。`&>` 会将其输出重定向到一个文件,以便您可以查看所发生的情况。\n",
"\n",
"等待几秒钟以启动该进程:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:24.379842Z",
"iopub.status.busy": "2023-11-07T23:15:24.379527Z",
"iopub.status.idle": "2023-11-07T23:15:44.404224Z",
"shell.execute_reply": "2023-11-07T23:15:44.403201Z"
},
"id": "Hm2yrULE9281"
},
"outputs": [],
"source": [
"import time\n",
"time.sleep(20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZFPoNxg_9_Mx"
},
"source": [
"接下来,检查一下目前为止输出到工作进程日志文件的内容:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:44.409197Z",
"iopub.status.busy": "2023-11-07T23:15:44.408493Z",
"iopub.status.idle": "2023-11-07T23:15:44.474674Z",
"shell.execute_reply": "2023-11-07T23:15:44.473733Z"
},
"id": "vZEOuVgQ9-hn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
]
}
],
"source": [
"%%bash\n",
"cat job_0.log"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RqZhVF7L_KOy"
},
"source": [
"日志文件的最后一行内容应为:`Started server with target: grpc://localhost:12345`。第一个工作进程现已准备就绪,正在等待所有其他工作进程准备就绪以继续。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pi8vPNNA_l4a"
},
"source": [
"更新 `tf_config` 以供第二个工作进程取用:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:44.478977Z",
"iopub.status.busy": "2023-11-07T23:15:44.478379Z",
"iopub.status.idle": "2023-11-07T23:15:44.483079Z",
"shell.execute_reply": "2023-11-07T23:15:44.482298Z"
},
"id": "lAiYkkPu_Jqd"
},
"outputs": [],
"source": [
"tf_config['task']['index'] = 1\n",
"os.environ['TF_CONFIG'] = json.dumps(tf_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0AshGVO0_x0w"
},
"source": [
"现在,启动第二个工作进程。这将开始训练,因为所有工作进程都已处于活动状态(因此无需在后台执行此进程):"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:44.486857Z",
"iopub.status.busy": "2023-11-07T23:15:44.486284Z",
"iopub.status.idle": "2023-11-07T23:15:59.434365Z",
"shell.execute_reply": "2023-11-07T23:15:59.433017Z"
},
"id": "_ESVtyQ9_xjx"
},
"outputs": [],
"source": [
"%%bash\n",
"python main.py > /dev/null 2>&1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hX4FA2O2AuAn"
},
"source": [
"如果您重新检查第一个工作进程编写的日志,您会看到它参与了该模型的训练:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:59.441138Z",
"iopub.status.busy": "2023-11-07T23:15:59.440438Z",
"iopub.status.idle": "2023-11-07T23:15:59.507207Z",
"shell.execute_reply": "2023-11-07T23:15:59.506105Z"
},
"id": "rc6hw3yTBKXX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:24.897952: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:24.898016: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:24.899709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:27.043487: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-11-07 23:15:48.287770: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0, accuracy: 0.804129, train_loss: 0.624825.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1, accuracy: 0.920201, train_loss: 0.276320.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 2, accuracy: 0.946429, train_loss: 0.194815.\n"
]
}
],
"source": [
"%%bash\n",
"cat job_0.log"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:15:59.511709Z",
"iopub.status.busy": "2023-11-07T23:15:59.511401Z",
"iopub.status.idle": "2023-11-07T23:15:59.516717Z",
"shell.execute_reply": "2023-11-07T23:15:59.516003Z"
},
"id": "sG5_1UgrgniF"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All background processes were killed.\n"
]
}
],
"source": [
"# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.\n",
"os.environ.pop('TF_CONFIG', None)\n",
"%killbgscripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bhxMXa0AaZkK"
},
"source": [
"## 深入了解多工作进程训练\n",
"\n",
"本教程演示了多工作进程设置的自定义训练循环工作流程。有关其他主题的详细描述可在适用于自定义训练循环的[使用 Keras 进行多工作进程训练 (`tf.keras.Model.fit`)](multi_worker_with_keras.ipynb) 教程中找到。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ega2hdOQEmy_"
},
"source": [
"## 了解更多\n",
"\n",
"1. [TensorFlow 中的分布式训练](../../guide/distributed_training.ipynb)指南概述了可用的分布式策略。\n",
"2. [官方模型](https://github.com/tensorflow/models/tree/master/official),其中许多模型可以配置为运行多个分布式策略。\n",
"3. `tf.function` 指南中的[“性能”部分](../../guide/function.ipynb)提供了有关其他策略和[工具](../../guide/profiler.md)的信息,您可以使用它们来优化 TensorFlow 模型的性能。\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "multi_worker_with_ctl.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}