{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T23:13:05.973477Z", "iopub.status.busy": "2023-11-07T23:13:05.973211Z", "iopub.status.idle": "2023-11-07T23:13:05.977493Z", "shell.execute_reply": "2023-11-07T23:13:05.976873Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "# Multi-worker training with Estimator\n", "\n", "
" ] },
{ "cell_type": "markdown", "metadata": { "id": "-_ZO8y69hs-N" }, "source": [ "> Warning: Estimators are not recommended for new code. Estimators run `v1.Session`-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than security vulnerabilities. See the [migration guide](https://tensorflow.org/guide/migrate) for details." ] },
{ "cell_type": "markdown", "metadata": { "id": "xHxb-dlhMIzW" }, "source": [ "## Overview\n", "\n", "Note: While you can use Estimators with the `tf.distribute` API, it is recommended to use Keras with `tf.distribute` instead; see [multi-worker training with Keras](multi_worker_with_keras.ipynb). Estimator training with `tf.distribute.Strategy` has limited support.\n", "\n", "This tutorial demonstrates how to use `tf.distribute.Strategy` for distributed multi-worker training with `tf.estimator`. If you write your code using `tf.estimator` and want to scale to more machines with high performance, this tutorial is for you.\n", "\n", "Before getting started, please read the [distribution strategy](../../guide/distributed_training.ipynb) guide. The [multi-GPU training tutorial](./keras.ipynb) is also relevant, because this tutorial uses the same model.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "MUXex9ctTuDB" }, "source": [ "## Setup\n", "\n", "First, set up TensorFlow and the necessary imports." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:13:05.981427Z", "iopub.status.busy": "2023-11-07T23:13:05.981174Z", "iopub.status.idle": "2023-11-07T23:13:08.942120Z", "shell.execute_reply": "2023-11-07T23:13:08.941316Z" }, "id": "bnYxvfLD-LW-" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 23:13:06.920204: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-11-07 23:13:06.920259: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-11-07 23:13:06.921969: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one 
has already been registered\n" ] } ], "source": [ "import tensorflow_datasets as tfds\n", "import tensorflow as tf\n", "\n", "import os, json" ] },
{ "cell_type": "markdown", "metadata": { "id": "-xicK9byC7hi" }, "source": [ "Note: Starting from TF2.4, the multi-worker mirrored strategy fails with Estimators if run with eager execution enabled (the default). The error in TF2.4 is `TypeError: cannot pickle '_thread.lock' object`. See [issue #46556](https://github.com/tensorflow/tensorflow/issues/46556) for details. The workaround is to disable eager execution." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:13:08.946800Z", "iopub.status.busy": "2023-11-07T23:13:08.946302Z", "iopub.status.idle": "2023-11-07T23:13:08.950360Z", "shell.execute_reply": "2023-11-07T23:13:08.949645Z" }, "id": "5dJ6UYrGDsVs" }, "outputs": [], "source": [ "tf.compat.v1.disable_eager_execution()" ] },
{ "cell_type": "markdown", "metadata": { "id": "hPBuZUNSZmrQ" }, "source": [ "## Input function\n", "\n", "This tutorial uses the MNIST dataset from [TensorFlow Datasets](https://tensorflow.google.cn/datasets). The code here is similar to the [multi-GPU training tutorial](./keras.ipynb), with one key difference: when using Estimator for multi-worker training, the dataset must be sharded by the number of workers to ensure model convergence. The input data is sharded by worker index, so each worker processes a distinct `1/num_workers` portion of the dataset." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:13:08.953661Z", "iopub.status.busy": "2023-11-07T23:13:08.953425Z", "iopub.status.idle": "2023-11-07T23:13:08.958798Z", "shell.execute_reply": "2023-11-07T23:13:08.958098Z" }, "id": "dma_wUAxZqo2" }, "outputs": [], "source": [ "BUFFER_SIZE = 10000\n", "BATCH_SIZE = 64\n", "\n", "def input_fn(mode, input_context=None):\n", "  datasets, info = tfds.load(name='mnist',\n", "                             with_info=True,\n", "                             as_supervised=True)\n", "  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else\n", "                   datasets['test'])\n", "\n", "  def scale(image, label):\n", "    image = tf.cast(image, tf.float32)\n", "    image /= 255\n", "    return image, label\n", "\n", "  if input_context:\n", "    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,\n", "                                        input_context.input_pipeline_id)\n", "  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" ] },
{ "cell_type": "markdown", "metadata": { "id": "4BlcVXMhB59T" }, "source": [ "Another reasonable way to achieve convergence is to shuffle the dataset with a distinct random seed on each worker." ] },
{ "cell_type": "markdown", "metadata": { "id": "8YFpxrcsZ2xG" }, "source": [ "## Multi-worker configuration\n", "\n", "One of the key differences in this tutorial (compared to the [multi-GPU training tutorial](./keras.ipynb)) is the multi-worker setup. The `TF_CONFIG` environment variable is the standard way to specify the cluster configuration to each worker in the cluster.\n", "\n", "`TF_CONFIG` has two components: `cluster` and `task`. `cluster` provides information about the entire cluster, namely its workers and parameter servers. `task` provides information about the current task. In this example, the task `type` is `worker` and the task `index` is `0`.\n", "\n", "For illustration purposes, this tutorial shows how to set `TF_CONFIG` for two workers on `localhost`. In practice, you would create multiple workers on external IP addresses and ports, and set `TF_CONFIG` appropriately on each worker, i.e. change the task `index`.\n", "\n", "Warning: *Do not execute the following code in Colab*. TensorFlow's runtime will attempt to create a gRPC server at the specified IP address and port, which is likely to fail. See the [Keras version](multi_worker_with_keras.ipynb) of this tutorial for an example of how to test-run multiple workers on a single machine.\n", "\n", "```\n", "os.environ['TF_CONFIG'] = json.dumps({\n", "    'cluster': {\n", "        'worker': [\"localhost:12345\", \"localhost:23456\"]\n", "    },\n", "    'task': {'type': 'worker', 'index': 0}\n", "})\n", "```\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "qDreJzTffAP5" }, "source": [ "## Define the model\n", "\n", "Write the layers, the optimizer, and the loss function for training. This tutorial defines the model with Keras layers, similar to the [multi-GPU training tutorial](./keras.ipynb)." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:13:08.962288Z", "iopub.status.busy": "2023-11-07T23:13:08.962025Z", "iopub.status.idle": "2023-11-07T23:13:08.968958Z", "shell.execute_reply": "2023-11-07T23:13:08.968260Z" }, "id": "WNvOn_OeiUYC" }, "outputs": [], "source": [ "LEARNING_RATE = 1e-4\n", "def model_fn(features, labels, mode):\n", "  model = tf.keras.Sequential([\n", "      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),\n", "      tf.keras.layers.MaxPooling2D(),\n", "      tf.keras.layers.Flatten(),\n", "      tf.keras.layers.Dense(64, activation='relu'),\n", "      tf.keras.layers.Dense(10)\n", "  ])\n", "  logits = model(features, training=False)\n", "\n", "  if mode == tf.estimator.ModeKeys.PREDICT:\n", "    predictions = {'logits': logits}\n", "    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\n", "\n", "  optimizer = tf.compat.v1.train.GradientDescentOptimizer(\n", "      learning_rate=LEARNING_RATE)\n", "  loss = tf.keras.losses.SparseCategoricalCrossentropy(\n", "      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)\n", "  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)\n", "  if mode == tf.estimator.ModeKeys.EVAL:\n", "    return tf.estimator.EstimatorSpec(mode, loss=loss)\n", "\n", "  return tf.estimator.EstimatorSpec(\n", "      mode=mode,\n", "      loss=loss,\n", "      train_op=optimizer.minimize(\n", "          loss, tf.compat.v1.train.get_or_create_global_step()))" ] },
{ "cell_type": "markdown", "metadata": { "id": "P94PrIW_kSCE" }, "source": [ "Note: Although the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size." ] },
{ "cell_type": "markdown", "metadata": { "id": "UhNtHfuxCGVy" }, "source": [ "## MultiWorkerMirroredStrategy\n", "\n", "To train the model, use an instance of `tf.distribute.experimental.MultiWorkerMirroredStrategy`. `MultiWorkerMirroredStrategy` creates copies of all variables in the model's layers on each device across all workers. It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The [`tf.distribute.Strategy` guide](../../guide/distribute_strategy.ipynb) has more details about this strategy." ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:13:08.972719Z", "iopub.status.busy": "2023-11-07T23:13:08.972093Z", "iopub.status.idle": "2023-11-07T23:13:10.943223Z", "shell.execute_reply": "2023-11-07T23:13:10.942532Z" }, "id": "1uFSHCJXMrQ-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a 
future version.\n", "Instructions for updating:\n", "use distribute.MultiWorkerMirroredStrategy instead\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO\n" ] } ], "source": [ "strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()" ] },
{ "cell_type": "markdown", "metadata": { "id": "H47DDcOgfzm7" }, "source": [ "## Train and evaluate the model\n", "\n", "Next, specify the distribution strategy in the `RunConfig` for the Estimator, then train and evaluate by calling `tf.estimator.train_and_evaluate`. This tutorial distributes only the training, by specifying the strategy via `train_distribute`. It is also possible to distribute the evaluation via `eval_distribute`." ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T23:13:10.947191Z", "iopub.status.busy": "2023-11-07T23:13:10.946476Z", "iopub.status.idle": "2023-11-07T23:13:23.855817Z", "shell.execute_reply": "2023-11-07T23:13:23.854970Z" }, "id": "BcsuBYrpgnlS" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:1: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initializing RunConfig with distribution strategies.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Not using Distribute Coordinator.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:3: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) 
is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': , '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:7: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:8: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_515360/2557501124.py:5: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a 
future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Not using Distribute Coordinator.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running training and evaluation locally (non-distributed).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1246: StrategyBase.configure (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "use `update_config_proto` instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. 
Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.\n", " warnings.warn(\"To make it possible to preserve tf.data options across \"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from 
tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use the iterator's `initializer` property instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use the iterator's `initializer` property instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done running 
local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 23:13:17.215155: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorFromStringHandle}}\n", "\t. Registered: device='CPU'\n", "\n", "2023-11-07 23:13:17.216474: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorGetNextFromShard}}\n", "\t. Registered: device='CPU'\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 23:13:17.225235: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorFromStringHandle}}\n", "\t. 
Registered: device='CPU'\n", "\n", "2023-11-07 23:13:17.226238: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorGetNextFromShard}}\n", "\t. Registered: device='CPU'\n", "\n", "2023-11-07 23:13:17.249919: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorFromStringHandle}}\n", "\t. Registered: device='CPU'\n", "\n", "2023-11-07 23:13:17.250960: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorGetNextFromShard}}\n", "\t. Registered: device='CPU'\n", "\n", "2023-11-07 23:13:17.258385: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorFromStringHandle}}\n", "\t. Registered: device='CPU'\n", "\n", "2023-11-07 23:13:17.259003: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node {{node MultiDeviceIteratorGetNextFromShard}}\n", "\t. 
Registered: device='CPU'\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", 
"text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.308822, step = 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.308822, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 153.52\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 153.52\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.297161, step = 100 (0.654 sec)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.297161, step = 100 (0.654 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 211.129\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:global_step/sec: 211.129\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.2924602, step = 200 (0.473 sec)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.2924602, step = 200 (0.473 sec)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 234...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 234...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 234 into /tmp/multiworker/model.ckpt.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 234 into 
/tmp/multiworker/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 234...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 234...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-11-07T23:13:22\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-11-07T23:13:22\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [10/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [10/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [20/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [20/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [30/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [30/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [40/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [40/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [50/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [50/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [60/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [60/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [70/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"INFO:tensorflow:Evaluation [70/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [80/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [80/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [90/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [90/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [100/100]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [100/100]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 1.34731s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 1.34731s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-11-07-23:13:23\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-11-07-23:13:23\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.2986863\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.2986863\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 2.2939153.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 2.2939153.\n" ] }, { "data": { "text/plain": [ "({'loss': 
2.2986863, 'global_step': 234}, [])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = tf.estimator.RunConfig(train_distribute=strategy)\n", "\n", "classifier = tf.estimator.Estimator(\n", " model_fn=model_fn, model_dir='/tmp/multiworker', config=config)\n", "tf.estimator.train_and_evaluate(\n", " classifier,\n", " train_spec=tf.estimator.TrainSpec(input_fn=input_fn),\n", " eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "XVk4ftYx6JAO" }, "source": [ "## Optimize training performance\n", "\n", "You now have a model and a multi-worker capable Estimator powered by `tf.distribute.Strategy`. You can try the following techniques to optimize the performance of multi-worker training:\n", "\n", "- *Increase the batch size*: The batch size specified here is per GPU. In general, the largest batch size that fits in GPU memory is advisable.\n", "\n", "- *Cast variables*: Where possible, cast the variables to a lower-precision floating-point type such as `tf.float16`. The official ResNet model includes an [example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how this can be done.\n", "\n", "- *Use collective communication*: `MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/cross_device_ops.py).\n", "\n", "    - `RING` implements ring-based collectives using gRPC as the cross-host communication layer.\n", "    - `NCCL` uses [NVIDIA's NCCL](https://developer.nvidia.com/nccl) to implement collectives.\n", "    - `AUTO` defers the choice to the runtime.\n", "\n", "    The best choice of collective implementation depends on the number and kind of GPUs, and on the network interconnect in the cluster. To override the automatic choice, pass a valid value to the `communication` parameter of the `MultiWorkerMirroredStrategy` constructor, for example `communication=tf.distribute.experimental.CollectiveCommunication.NCCL`.\n", "\n", "Visit the [Performance section](../../guide/function.ipynb) of the guide to learn more about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "AW0Hb2xM6EGX" }, "source": [ "## More code examples\n", "\n", "1. An [end-to-end example](https://github.com/tensorflow/ecosystem/tree/master/distribution_strategy) in tensorflow/ecosystem for multi-worker training using Kubernetes templates. The example starts with a Keras model and converts it to an Estimator using the `tf.keras.estimator.model_to_estimator` API.\n", "2. [Official models](https://github.com/tensorflow/models/blob/master/official/resnet/imagenet_main.py), many of which can be configured to run multiple distribution strategies.\n" ] } ], "metadata": { "colab": { "collapsed_sections": [ "Tce3stUlHN0L" ], "name": "multi_worker_with_estimator.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }