{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "# Multi-worker training with Keras\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "xHxb-dlhMIzW" }, "source": [ "## Overview\n", "\n", "This tutorial demonstrates how to perform multi-worker distributed training with a Keras model and the `Model.fit` API using the `tf.distribute.MultiWorkerMirroredStrategy` API. With the help of this strategy, a Keras model that was designed to run on a single-worker can seamlessly work on multiple workers with minimal code changes.\n", "\n", "To learn how to use the `MultiWorkerMirroredStrategy` with Keras and a custom training loop, refer to [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb).\n", "\n", "This tutorial contains a minimal multi-worker example with two workers for demonstration purposes." ] }, { "cell_type": "markdown", "metadata": { "id": "JUdRerXg6yz3" }, "source": [ "### Choose the right strategy" ] }, { "cell_type": "markdown", "metadata": { "id": "YAiCV_oL63GM" }, "source": [ "Before you dive in, make sure that `tf.distribute.MultiWorkerMirroredStrategy` is the right choice for your accelerator(s) and training. These are two common ways of distributing training with data parallelism:\n", "\n", "* _Synchronous training_, where the steps of training are synced across the workers and replicas, such as `tf.distribute.MirroredStrategy`, `tf.distribute.TPUStrategy`, and `tf.distribute.MultiWorkerMirroredStrategy`. All workers train over different slices of input data in sync, and aggregating gradients at each step.\n", "* _Asynchronous training_, where the training steps are not strictly synced, such as `tf.distribute.experimental.ParameterServerStrategy`. All workers are independently training over the input data and updating variables asynchronously.\n", "\n", "If you are looking for multi-worker synchronous training without TPU, then `tf.distribute.MultiWorkerMirroredStrategy` is your choice. It creates copies of all variables in the model's layers on each device across all workers. It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keeps the variables in sync. For those interested, check out the `tf.distribute.experimental.CommunicationOptions` parameter for the collective implementation options.\n", "\n", "For an overview of `tf.distribute.Strategy` APIs, refer to [Distributed training in TensorFlow](../../guide/distributed_training.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "id": "MUXex9ctTuDB" }, "source": [ "## Setup\n", "\n", "Start with some necessary imports:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bnYxvfLD-LW-" }, "outputs": [], "source": [ "import json\n", "import os\n", "import sys" ] }, { "cell_type": "markdown", "metadata": { "id": "Zz0EY91y3mxy" }, "source": [ "Before importing TensorFlow, make a few changes to the environment:\n", "\n", "* In a real-world application, each worker would be on a different machine. For the purposes of this tutorial, all the workers will run on the **this** machine. Therefore, disable all GPUs to prevent errors caused by all workers trying to use the same GPU." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rpEIVI5upIzM" }, "outputs": [], "source": [ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"" ] }, { "cell_type": "markdown", "metadata": { "id": "7X1MS6385BWi" }, "source": [ "* Reset the `TF_CONFIG` environment variable (you'll learn more about this later):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WEJLYa2_7OZF" }, "outputs": [], "source": [ "os.environ.pop('TF_CONFIG', None)" ] }, { "cell_type": "markdown", "metadata": { "id": "Rd4L9Ii77SS8" }, "source": [ "* Make sure that the current directory is on Python's path—this allows the notebook to import the files written by `%%writefile` later:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hPBuZUNSZmrQ" }, "outputs": [], "source": [ "if '.' not in sys.path:\n", " sys.path.insert(0, '.')" ] }, { "cell_type": "markdown", "metadata": { "id": "9hLpDZhAz2q-" }, "source": [ "Install `tf-nightly`, as the frequency of checkpoint saving at a particular step with the `save_freq` argument in `tf.keras.callbacks.BackupAndRestore` is introduced from TensorFlow 2.10:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XqozLfzz30N" }, "outputs": [], "source": [ "!pip install tf-nightly" ] }, { "cell_type": "markdown", "metadata": { "id": "524e38dab658" }, "source": [ "Finally, import TensorFlow:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vHNvttzV43sA" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "0S2jpf6Sx50i" }, "source": [ "### Dataset and model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "fLW6D2TzvC-4" }, "source": [ "Next, create an `mnist_setup.py` file with a simple model and dataset setup. This Python file will be used by the worker processes in this tutorial:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dma_wUAxZqo2" }, "outputs": [], "source": [ "%%writefile mnist_setup.py\n", "\n", "import os\n", "import tensorflow as tf\n", "import numpy as np\n", "\n", "def mnist_dataset(batch_size):\n", " (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n", " # The `x` arrays are in uint8 and have values in the [0, 255] range.\n", " # You need to convert them to float32 with values in the [0, 1] range.\n", " x_train = x_train / np.float32(255)\n", " y_train = y_train.astype(np.int64)\n", " train_dataset = tf.data.Dataset.from_tensor_slices(\n", " (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)\n", " return train_dataset\n", "\n", "def build_and_compile_cnn_model():\n", " model = tf.keras.Sequential([\n", " tf.keras.layers.InputLayer(input_shape=(28, 28)),\n", " tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(32, 3, activation='relu'),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(10)\n", " ])\n", " model.compile(\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n", " metrics=['accuracy'])\n", " return model" ] }, { "cell_type": "markdown", "metadata": { "id": "2UL3kisMO90X" }, "source": [ "### Model training on a single worker\n", "\n", "Try training the model for a small number of epochs and observe the results of _a single worker_ to make sure everything works correctly. 
As training progresses, the loss should drop and the accuracy should increase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6Qe6iAf5O8iJ" }, "outputs": [], "source": [ "import mnist_setup\n", "\n", "batch_size = 64\n", "single_worker_dataset = mnist_setup.mnist_dataset(batch_size)\n", "single_worker_model = mnist_setup.build_and_compile_cnn_model()\n", "single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)" ] }, { "cell_type": "markdown", "metadata": { "id": "JmgZwwymxqt5" }, "source": [ "## Multi-worker configuration\n", "\n", "Now let's enter the world of multi-worker training.\n", "\n", "### A cluster with jobs and tasks\n", "\n", "In TensorFlow, distributed training involves a `'cluster'`\n", "with several jobs, and each of the jobs may have one or more `'task'`s.\n", "\n", "You will need the `TF_CONFIG` configuration environment variable for training on multiple machines, each of which possibly has a different role. `TF_CONFIG` is a JSON string used to specify the cluster configuration for each worker that is part of the cluster.\n", "\n", "There are two components of a `TF_CONFIG` variable: `'cluster'` and `'task'`.\n", "\n", "* A `'cluster'` is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs, such as `'worker'` or `'chief'`.\n", " - In multi-worker training with `tf.distribute.MultiWorkerMirroredStrategy`, there is usually one `'worker'` that takes on more responsibilities, such as saving a checkpoint and writing a summary file for TensorBoard, in addition to what a regular `'worker'` does. Such `'worker'` is referred to as the chief worker (with a job name `'chief'`).\n", " - It is customary for the worker with `'index'` `0` to be the `'chief'`.\n", "\n", "* A `'task'` provides information on the current task and is different for each worker. It specifies the `'type'` and `'index'` of that worker.\n", "\n", "Below is an example configuration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XK1eTYvSZiX7" }, "outputs": [], "source": [ "tf_config = {\n", " 'cluster': {\n", " 'worker': ['localhost:12345', 'localhost:23456']\n", " },\n", " 'task': {'type': 'worker', 'index': 0}\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "JjgwJbPKZkJL" }, "source": [ "Note that `tf_config` is just a local variable in Python. To use it for training configuration, serialize it as a JSON and place it in a `TF_CONFIG` environment variable." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yY-T0YDQZjbu" }, "outputs": [], "source": [ "json.dumps(tf_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "8YFpxrcsZ2xG" }, "source": [ "In the example configuration above, you set the task `'type'` to `'worker'` and the task `'index'` to `0`. Therefore, this machine is the _first_ worker. It will be appointed as the `'chief'` worker.\n", "\n", "Note: Other machines will need to have the `TF_CONFIG` environment variable set as well, and it should have the same `'cluster'` dict, but different task `'type'`s or task `'index'`es, depending on the roles of those machines." ] }, { "cell_type": "markdown", "metadata": { "id": "aogb74kHxynz" }, "source": [ "In practice, you would create multiple workers on external IP addresses/ports and set a `TF_CONFIG` variable on each worker accordingly. 
For illustration purposes, this tutorial shows how you may set up a `TF_CONFIG` variable with two workers on a `localhost`:\n", "- The first (`'chief'`) worker's `TF_CONFIG` is as shown above.\n", "- For the second worker, you will set `tf_config['task']['index']=1`." ] }, { "cell_type": "markdown", "metadata": { "id": "cIlkfWmjz1PG" }, "source": [ "### Environment variables and subprocesses in notebooks" ] }, { "cell_type": "markdown", "metadata": { "id": "FcjAbuGY1ACJ" }, "source": [ "Subprocesses inherit environment variables from their parent. So if you set an environment variable in this Jupyter Notebook process:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PH2gHn2_0_U8" }, "outputs": [], "source": [ "os.environ['GREETINGS'] = 'Hello TensorFlow!'" ] }, { "cell_type": "markdown", "metadata": { "id": "gQkIX-cg18md" }, "source": [ "... then you can access the environment variable from the subprocesses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pquKO6IA18G5" }, "outputs": [], "source": [ "%%bash\n", "echo ${GREETINGS}" ] }, { "cell_type": "markdown", "metadata": { "id": "af6BCA-Y2fpz" }, "source": [ "In the next section, you'll use this method to pass the `TF_CONFIG` to the worker subprocesses. You would never really launch your jobs this way in a real-world scenario—this tutorial is just showing how to do it with a minimal multi-worker example." ] }, { "cell_type": "markdown", "metadata": { "id": "dnDJmaRA9qnf" }, "source": [ "## Train the model" ] }, { "cell_type": "markdown", "metadata": { "id": "UhNtHfuxCGVy" }, "source": [ "To train the model, first create an instance of the `tf.distribute.MultiWorkerMirroredStrategy`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1uFSHCJXMrQ-" }, "outputs": [], "source": [ "strategy = tf.distribute.MultiWorkerMirroredStrategy()" ] }, { "cell_type": "markdown", "metadata": { "id": "N0iv7SyyAohc" }, "source": [ "Note: `TF_CONFIG` is parsed and TensorFlow's gRPC servers are started at the time `MultiWorkerMirroredStrategy` is called, so the `TF_CONFIG` environment variable must be set before a `tf.distribute.Strategy` instance is created. Since `TF_CONFIG` is not set yet, the above strategy is effectively single-worker training." ] }, { "cell_type": "markdown", "metadata": { "id": "H47DDcOgfzm7" }, "source": [ "With the integration of the `tf.distribute.Strategy` API into `tf.keras`, the only change you will make to distribute the training to multiple workers is enclosing the model building and `model.compile()` call inside `strategy.scope()`. The distribution strategy's scope dictates how and where the variables are created, and in the case of `MultiWorkerMirroredStrategy`, the variables created are `MirroredVariable`s, and they are replicated on each of the workers.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wo6b9wX65glL" }, "outputs": [], "source": [ "with strategy.scope():\n", "  # Model building/compiling need to be within `strategy.scope()`.\n", "  multi_worker_model = mnist_setup.build_and_compile_cnn_model()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mhq3fzyR5hTw" }, "source": [ "Note: Currently there is a limitation in `MultiWorkerMirroredStrategy` where TensorFlow ops need to be created after the instance of strategy is created. 
If you encounter `RuntimeError: Collective ops must be configured at program startup`, try creating the instance of `MultiWorkerMirroredStrategy` at the beginning of the program and put the code that may create ops after the strategy is instantiated." ] }, { "cell_type": "markdown", "metadata": { "id": "jfYpmIxO6Jck" }, "source": [ "To actually run with `MultiWorkerMirroredStrategy` you'll need to run worker processes and pass a `TF_CONFIG` to them.\n", "\n", "Like the `mnist_setup.py` file written earlier, here is the `main.py` that each of the workers will run:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BcsuBYrpgnlS" }, "outputs": [], "source": [ "%%writefile main.py\n", "\n", "import os\n", "import json\n", "\n", "import tensorflow as tf\n", "import mnist_setup\n", "\n", "per_worker_batch_size = 64\n", "tf_config = json.loads(os.environ['TF_CONFIG'])\n", "num_workers = len(tf_config['cluster']['worker'])\n", "\n", "strategy = tf.distribute.MultiWorkerMirroredStrategy()\n", "\n", "global_batch_size = per_worker_batch_size * num_workers\n", "multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)\n", "\n", "with strategy.scope():\n", " # Model building/compiling need to be within `strategy.scope()`.\n", " multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n", "\n", "\n", "multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)" ] }, { "cell_type": "markdown", "metadata": { "id": "Aom9xelvJQ_6" }, "source": [ "In the code snippet above note that the `global_batch_size`, which gets passed to `Dataset.batch`, is set to `per_worker_batch_size * num_workers`. This ensures that each worker processes batches of `per_worker_batch_size` examples regardless of the number of workers." ] }, { "cell_type": "markdown", "metadata": { "id": "lHLhOii67Saa" }, "source": [ "The current directory now contains both Python files:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bi6x05Sr60O9" }, "outputs": [], "source": [ "%%bash\n", "ls *.py" ] }, { "cell_type": "markdown", "metadata": { "id": "qmEEStPS6vR_" }, "source": [ "Serialize the `TF_CONFIG` to JSON and add it to the environment variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9uu3g7vV7Bbt" }, "outputs": [], "source": [ "os.environ['TF_CONFIG'] = json.dumps(tf_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "MsY3dQLK7jdf" }, "source": [ "Now, you can launch a worker process that will run the `main.py` and use the `TF_CONFIG`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "txMXaq8d8N_S" }, "outputs": [], "source": [ "# first kill any previous runs\n", "%killbgscripts" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qnSma_Ck7r-r" }, "outputs": [], "source": [ "%%bash --bg\n", "python main.py &> job_0.log" ] }, { "cell_type": "markdown", "metadata": { "id": "ZChyazqS7v0P" }, "source": [ "There are a few things to note about the above command:\n", "\n", "1. It uses the `%%bash` which is a [notebook \"magic\"](https://ipython.readthedocs.io/en/stable/interactive/magics.html) to run some bash commands.\n", "2. It uses the `--bg` flag to run the `bash` process in the background, because this worker will not terminate. 
It waits for all the workers before it starts.\n", "\n", "The backgrounded worker process won't print output to this notebook, so the `&>` redirects its output to a file so that you can inspect what happened in a log file later.\n", "\n", "So, wait a few seconds for the process to start up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Hm2yrULE9281" }, "outputs": [], "source": [ "import time\n", "time.sleep(10)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZFPoNxg_9_Mx" }, "source": [ "Now, inspect what's been output to the worker's log file so far:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vZEOuVgQ9-hn" }, "outputs": [], "source": [ "%%bash\n", "cat job_0.log" ] }, { "cell_type": "markdown", "metadata": { "id": "RqZhVF7L_KOy" }, "source": [ "The last line of the log file should say: `Started server with target: grpc://localhost:12345`. The first worker is now ready and is waiting for all the other worker(s) to be ready to proceed." ] }, { "cell_type": "markdown", "metadata": { "id": "Pi8vPNNA_l4a" }, "source": [ "So update the `tf_config` for the second worker's process to pick up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lAiYkkPu_Jqd" }, "outputs": [], "source": [ "tf_config['task']['index'] = 1\n", "os.environ['TF_CONFIG'] = json.dumps(tf_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "0AshGVO0_x0w" }, "source": [ "Launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_ESVtyQ9_xjx" }, "outputs": [], "source": [ "%%bash\n", "python main.py" ] }, { "cell_type": "markdown", "metadata": { "id": "hX4FA2O2AuAn" }, "source": [ "If you recheck the logs written by the first worker, you'll learn that it participated in training that model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rc6hw3yTBKXX" }, "outputs": [], "source": [ "%%bash\n", "cat job_0.log" ] }, { "cell_type": "markdown", "metadata": { "id": "zL79ak5PMzEg" }, "source": [ "Note: This may run slower than the test run at the beginning of this tutorial because running multiple workers on a single machine only adds overhead. The goal here is not to improve the training time but to give an example of multi-worker training.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sG5_1UgrgniF" }, "outputs": [], "source": [ "# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.\n", "os.environ.pop('TF_CONFIG', None)\n", "%killbgscripts" ] }, { "cell_type": "markdown", "metadata": { "id": "9j2FJVHoUIrE" }, "source": [ "## Multi-worker training in depth\n" ] }, { "cell_type": "markdown", "metadata": { "id": "C1hBks_dAZmT" }, "source": [ "So far, you have learned how to perform a basic multi-worker setup. The rest of the tutorial goes over other factors, which may be useful or important for real use cases, in detail." ] }, { "cell_type": "markdown", "metadata": { "id": "Rr14Vl9GR4zq" }, "source": [ "### Dataset sharding\n", "\n", "In multi-worker training, _dataset sharding_ is needed to ensure convergence and performance.\n", "\n", "The example in the previous section relies on the default autosharding provided by the `tf.distribute.Strategy` API. 
You can control the sharding by setting the `tf.data.experimental.AutoShardPolicy` of the `tf.data.experimental.DistributeOptions`.\n", "\n", "To learn more about _auto-sharding_, refer to the [Distributed input guide](https://www.tensorflow.org/tutorials/distribute/input#sharding).\n", "\n", "Here is a quick example of how to turn the auto sharding off, so that each replica processes every example (_not recommended_):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JxEtdh1vH-TF" }, "outputs": [], "source": [ "options = tf.data.Options()\n", "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n", "\n", "global_batch_size = 64\n", "multi_worker_dataset = mnist_setup.mnist_dataset(batch_size=64)\n", "dataset_no_auto_shard = multi_worker_dataset.with_options(options)" ] }, { "cell_type": "markdown", "metadata": { "id": "z85hElxsBQsT" }, "source": [ "### Evaluation" ] }, { "cell_type": "markdown", "metadata": { "id": "gmqvlh5LhAoU" }, "source": [ "If you pass the `validation_data` into `Model.fit` as well, it will alternate between training and evaluation for each epoch. The evaluation work is distributed across the same set of workers, and its results are aggregated and available to all workers.\n", "\n", "Similar to training, the validation dataset is automatically sharded at the file level. You need to set a global batch size in the validation dataset and set the `validation_steps`.\n", "\n", "A repeated dataset (by calling `tf.data.Dataset.repeat`) is recommended for evaluation.\n", "\n", "Alternatively, you can also create another task that periodically reads checkpoints and runs the evaluation. This is what an Estimator does. But this is not a recommended way to perform evaluation and thus its details are omitted." ] }, { "cell_type": "markdown", "metadata": { "id": "FNkoxUPJBNTb" }, "source": [ "### Performance" ] }, { "cell_type": "markdown", "metadata": { "id": "XVk4ftYx6JAO" }, "source": [ "To tweak the performance of multi-worker training, you can try the following:\n", "\n", "- `tf.distribute.MultiWorkerMirroredStrategy` provides multiple [collective communication implementations](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CommunicationImplementation):\n", " - `RING` implements ring-based collectives using gRPC as the cross-host communication layer.\n", " - `NCCL` uses the [NVIDIA Collective Communication Library](https://developer.nvidia.com/nccl) to implement collectives.\n", " - `AUTO` defers the choice to the runtime.\n", " \n", " The best choice of collective implementation depends upon the number of GPUs, the type of GPUs, and the network interconnects in the cluster. To override the automatic choice, specify the `communication_options` parameter of `MultiWorkerMirroredStrategy`'s constructor. For example:\n", " \n", " ```python\n", " communication_options=tf.distribute.experimental.CommunicationOptions(implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)\n", " ```\n", "\n", "- Cast the variables to `tf.float` if possible:\n", " - The official ResNet model includes [an example](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466) of how to do this." 
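] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Building on the `communication_options` fragment above, here is a minimal sketch (assuming a cluster in which every worker has NVIDIA GPUs) of how the options could be passed to the strategy constructor. It is shown for illustration only and is not executed as part of this tutorial's two-worker run:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A sketch of explicitly selecting the NCCL collective implementation.\n", "# This assumes NVIDIA GPUs on every worker; otherwise, prefer `AUTO`.\n", "communication_options = tf.distribute.experimental.CommunicationOptions(\n", "    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)\n", "\n", "# Creating a second `MultiWorkerMirroredStrategy` in this notebook would clash\n", "# with the strategy created earlier, so the call is left commented out:\n", "# strategy = tf.distribute.MultiWorkerMirroredStrategy(\n", "#     communication_options=communication_options)"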
] }, { "cell_type": "markdown", "metadata": { "id": "97WhAu8uKw3j" }, "source": [ "### Fault tolerance\n", "\n", "In synchronous training, the cluster would fail if one of the workers fails and no failure-recovery mechanism exists.\n", "\n", "Using Keras with `tf.distribute.Strategy` comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. This works by preserving the training state in the distributed file system of your choice, such that upon a restart of the instance that previously failed or was preempted, the training state is recovered.\n", "\n", "When a worker becomes unavailable, other workers will fail (possibly after a timeout). In such cases, the unavailable worker needs to be restarted, as well as other workers that have failed.\n", "\n", "Note: Previously, the `ModelCheckpoint` callback provided a mechanism to restore the training state upon a restart from a job failure for multi-worker training. The TensorFlow team is introducing a new [`BackupAndRestore`](#scrollTo=kmH8uCUhfn4w) callback, which also adds support for single-worker training for a consistent experience, and has removed the fault tolerance functionality from the existing `ModelCheckpoint` callback. From now on, applications that rely on this behavior should migrate to the new `BackupAndRestore` callback." ] }, { "cell_type": "markdown", "metadata": { "id": "KvHPjGlyyFt6" }, "source": [ "#### The `ModelCheckpoint` callback\n", "\n", "The `ModelCheckpoint` callback no longer provides fault tolerance functionality; please use the [`BackupAndRestore`](#scrollTo=kmH8uCUhfn4w) callback instead.\n", "\n", "The `ModelCheckpoint` callback can still be used to save checkpoints. But with this approach, if training is interrupted or finishes successfully, in order to continue training from the checkpoint, the user is responsible for loading the model manually.\n", "\n", "Optionally, users can choose to save and restore the model/weights outside the `ModelCheckpoint` callback." ] }, { "cell_type": "markdown", "metadata": { "id": "EUNV5Utc1d0s" }, "source": [ "### Model saving and loading\n", "\n", "To save your model using `model.save` or `tf.saved_model.save`, the saving destination needs to be different for each worker.\n", "\n", "- For non-chief workers, you will need to save the model to a temporary directory.\n", "- For the chief, you will need to save to the provided model directory.\n", "\n", "The temporary directories on the workers need to be unique to prevent errors resulting from multiple workers trying to write to the same location.\n", "\n", "The model saved in all the directories is identical, and typically only the model saved by the chief should be referenced for restoring or serving.\n", "\n", "You should have some cleanup logic that deletes the temporary directories created by the workers once your training has completed.\n", "\n", "The reason for saving on the chief and workers at the same time is that you might be aggregating variables during checkpointing, which requires both the chief and workers to participate in the allreduce communication protocol. 
On the other hand, letting the chief and workers save to the same model directory will result in errors due to contention.\n", "\n", "With `MultiWorkerMirroredStrategy`, the program runs on every worker, and in order to know whether the current worker is the chief, it takes advantage of the cluster resolver object, which has the attributes `task_type` and `task_id`:\n", "- `task_type` tells you what the current job is (for example, `'worker'`).\n", "- `task_id` tells you the identifier of the worker.\n", "- The worker with `task_id == 0` is designated as the chief worker.\n", "\n", "In the code snippet below, the `write_filepath` function provides the file path to write, which depends on the worker's `task_id`:\n", "\n", "- For the chief worker (with `task_id == 0`), it writes to the original file path.\n", "- For other workers, it creates a temporary directory—`temp_dir`—with the `task_id` in the directory path to write in:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XQfGkmg-pfCY" }, "outputs": [], "source": [ "model_path = '/tmp/keras-model'\n", "\n", "def _is_chief(task_type, task_id):\n", "  # Note: there are two possible `TF_CONFIG` configurations.\n", "  # 1) In addition to `worker` tasks, a `chief` task type is used;\n", "  #    in this case, this function should be modified to\n", "  #    `return task_type == 'chief'`.\n", "  # 2) Only the `worker` task type is used; in this case, worker 0 is\n", "  #    regarded as the chief. The implementation demonstrated here\n", "  #    is for this case.\n", "  # For the purpose of this Colab section, the `task_type is None` case\n", "  # is added because it is effectively run with only a single worker.\n", "  return (task_type == 'worker' and task_id == 0) or task_type is None\n", "\n", "def _get_temp_dir(dirpath, task_id):\n", "  base_dirpath = 'workertemp_' + str(task_id)\n", "  temp_dir = os.path.join(dirpath, base_dirpath)\n", "  tf.io.gfile.makedirs(temp_dir)\n", "  return temp_dir\n", "\n", "def write_filepath(filepath, task_type, task_id):\n", "  dirpath = os.path.dirname(filepath)\n", "  base = os.path.basename(filepath)\n", "  if not _is_chief(task_type, task_id):\n", "    dirpath = _get_temp_dir(dirpath, task_id)\n", "  return os.path.join(dirpath, base)\n", "\n", "task_type, task_id = (strategy.cluster_resolver.task_type,\n", "                      strategy.cluster_resolver.task_id)\n", "write_model_path = write_filepath(model_path, task_type, task_id)" ] }, { "cell_type": "markdown", "metadata": { "id": "hs0_agYR_qKm" }, "source": [ "With that, you're now ready to save:" ] }, { "cell_type": "markdown", "metadata": { "id": "XnToxeIcg_6O" }, "source": [ "Deprecated: For Keras objects, it's recommended to use the new high-level `.keras` format and `tf.keras.Model.export`, as demonstrated in the guide [here](https://www.tensorflow.org/guide/keras/save_and_serialize). The low-level SavedModel format continues to be supported for existing code." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J-yA3BYG_vTs" }, "outputs": [], "source": [ "multi_worker_model.save(write_model_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "8LXUVVl9_v5x" }, "source": [ "As described above, later on, the model should only be loaded from the file path the chief worker saved to. 
Therefore, remove the temporary ones the non-chief workers have saved:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aJTyu-97ABpY" }, "outputs": [], "source": [ "if not _is_chief(task_type, task_id):\n", "  tf.io.gfile.rmtree(os.path.dirname(write_model_path))" ] }, { "cell_type": "markdown", "metadata": { "id": "Nr-2PKlHAPBT" }, "source": [ "Now, when it's time to load, use the convenient `tf.keras.models.load_model` API, and continue with further work.\n", "\n", "Here, assume you are only using a single worker to load and continue training, in which case you do not call `tf.keras.models.load_model` within another `strategy.scope()` (note that `strategy = tf.distribute.MultiWorkerMirroredStrategy()`, as defined earlier):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iUZna-JKAOrX" }, "outputs": [], "source": [ "loaded_model = tf.keras.models.load_model(model_path)\n", "\n", "# Now that the model is restored, you can continue with the training.\n", "loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)" ] }, { "cell_type": "markdown", "metadata": { "id": "YJ1fmxmTpocS" }, "source": [ "### Checkpoint saving and restoring\n", "\n", "On the other hand, checkpointing allows you to save your model's weights and restore them without having to save the whole model.\n", "\n", "Here, you'll create one `tf.train.Checkpoint` that tracks the model, which is managed by the `tf.train.CheckpointManager`, so that only the latest checkpoint is preserved:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_1-RYaB5xnNH" }, "outputs": [], "source": [ "checkpoint_dir = '/tmp/ckpt'\n", "\n", "checkpoint = tf.train.Checkpoint(model=multi_worker_model)\n", "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)\n", "checkpoint_manager = tf.train.CheckpointManager(\n", "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "7oBpPCRsW1MF" }, "source": [ "Once the `CheckpointManager` is set up, you're ready to save the checkpoint, and then remove the checkpoints that the non-chief workers saved:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l1ZXG_GbWzLp" }, "outputs": [], "source": [ "checkpoint_manager.save()\n", "if not _is_chief(task_type, task_id):\n", "  tf.io.gfile.rmtree(write_checkpoint_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "RO7cbN40XD5v" }, "source": [ "Now, when you need to restore the model, you can find the latest checkpoint saved using the convenient `tf.train.latest_checkpoint` function. After restoring the checkpoint, you can continue with training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NJW7vtknXFEH" }, "outputs": [], "source": [ "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n", "checkpoint.restore(latest_checkpoint)\n", "multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)" ] }, { "cell_type": "markdown", "metadata": { "id": "kmH8uCUhfn4w" }, "source": [ "#### The `BackupAndRestore` callback\n", "\n", "The `tf.keras.callbacks.BackupAndRestore` callback provides the fault tolerance functionality by backing up the model and current training state in a temporary checkpoint file under the `backup_dir` argument of `BackupAndRestore`.\n", "\n", "Note: In TensorFlow 2.9, the current model and the training state are backed up at epoch boundaries. 
In the `tf-nightly` version and from TensorFlow 2.10, the `BackupAndRestore` callback can back up the model and the training state at epoch or step boundaries. `BackupAndRestore` accepts an optional `save_freq` argument. `save_freq` accepts either `'epoch'` or an `int` value. If `save_freq` is set to `'epoch'`, the model is backed up after every epoch. If `save_freq` is set to an integer value greater than `0`, the model is backed up after every `save_freq` number of batches.\n", "\n", "Once the jobs get interrupted and restarted, the `BackupAndRestore` callback restores the last checkpoint, and you can continue training from the beginning of the epoch and step at which the training state was last saved.\n", "\n", "To use it, provide an instance of `tf.keras.callbacks.BackupAndRestore` at the `Model.fit` call.\n", "\n", "With `MultiWorkerMirroredStrategy`, if a worker gets interrupted, the whole cluster will pause until the interrupted worker is restarted. Other workers will also restart, and the interrupted worker will rejoin the cluster. Then, every worker will read the checkpoint file that was previously saved and pick up its former state, thereby allowing the cluster to get back in sync. Training then continues. The distributed dataset iterator state will be re-initialized and not restored.\n", "\n", "The `BackupAndRestore` callback uses the `CheckpointManager` to save and restore the training state, which generates a file called `checkpoint` that tracks existing checkpoints together with the latest one. For this reason, `backup_dir` should not be re-used to store other checkpoints, in order to avoid name collisions.\n", "\n", "Currently, the `BackupAndRestore` callback supports single-worker training with no strategy or with `MirroredStrategy`, and multi-worker training with `MultiWorkerMirroredStrategy`.\n", "\n", "Below are examples that use the `BackupAndRestore` callback with multi-worker training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CYdzZi4Qs1jz" }, "outputs": [], "source": [ "# Multi-worker training with `MultiWorkerMirroredStrategy`\n", "# and the `BackupAndRestore` callback. The training state\n", "# is backed up at epoch boundaries by default.\n", "\n", "callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]\n", "with strategy.scope():\n", "  multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n", "multi_worker_model.fit(multi_worker_dataset,\n", "                       epochs=3,\n", "                       steps_per_epoch=70,\n", "                       callbacks=callbacks)" ] }, { "cell_type": "markdown", "metadata": { "id": "f8e86TAp0Rsl" }, "source": [ "If the `save_freq` argument in the `BackupAndRestore` callback is set to `'epoch'`, the model is backed up after every epoch."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rZjQGPsF0aEI" }, "outputs": [], "source": [ "# The training state is backed up at epoch boundaries because `save_freq` is\n", "# set to `epoch`.\n", "\n", "callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup', save_freq='epoch')]\n", "with strategy.scope():\n", "  multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n", "multi_worker_model.fit(multi_worker_dataset,\n", "                       epochs=3,\n", "                       steps_per_epoch=70,\n", "                       callbacks=callbacks)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "p-r44kCM0jc6" }, "source": [ "Note: The next code block uses features that are only available in `tf-nightly` until TensorFlow 2.10 is released.\n", "\n", "If the `save_freq` argument in the `BackupAndRestore` callback is set to an integer value greater than `0`, the model is backed up after every `save_freq` number of batches." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bSJUyLSF0moC" }, "outputs": [], "source": [ "# The training state is backed up every 30 steps because `save_freq` is set\n", "# to an integer value of `30`.\n", "\n", "callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup', save_freq=30)]\n", "with strategy.scope():\n", "  multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n", "multi_worker_model.fit(multi_worker_dataset,\n", "                       epochs=3,\n", "                       steps_per_epoch=70,\n", "                       callbacks=callbacks)" ] }, { "cell_type": "markdown", "metadata": { "id": "rIV5_3ebzXmB" }, "source": [ "If you inspect the `backup_dir` directory you specified in `BackupAndRestore`, you may notice some temporarily generated checkpoint files. Those files are needed for recovering the previously lost instances, and they will be removed by the library at the end of `Model.fit` upon successful exit of your training.\n", "\n", "Note: Currently the `BackupAndRestore` callback only supports eager mode. In graph mode, consider using `Model.save`/`tf.saved_model.save` and `tf.keras.models.load_model` for saving and restoring models, respectively, as described in the _Model saving and loading_ section above, and by providing `initial_epoch` in `Model.fit` during training." ] }, { "cell_type": "markdown", "metadata": { "id": "ega2hdOQEmy_" }, "source": [ "## Additional resources\n", "\n", "1. The [Distributed training in TensorFlow](../../guide/distributed_training.ipynb) guide provides an overview of the available distribution strategies.\n", "1. The [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb) tutorial shows how to use the `MultiWorkerMirroredStrategy` with Keras and a custom training loop.\n", "1. Check out the [official models](https://github.com/tensorflow/models/tree/master/official), many of which can be configured to run multiple distribution strategies.\n", "1. The [Better performance with tf.function](../../guide/function.ipynb) guide provides information about other strategies and tools, such as the [TensorFlow Profiler](../../guide/profiler.md), which you can use to optimize the performance of your TensorFlow models." ] } ], "metadata": { "colab": { "name": "multi_worker_with_keras.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }