{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MT-LkFOl2axM" }, "source": [ "# Using DTensors with Keras" ] }, { "cell_type": "markdown", "metadata": { "id": "r6P32iYYV27b" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "vTe9dcbUAwqx" }, "source": [ "## Overview\n", "\n", "In this tutorial, you will learn how to use DTensors with Keras.\n", "\n", "Through DTensor integration with Keras, you can reuse your existing Keras layers and models to build and train distributed machine learning models.\n", "\n", "You will train a multi-layer classification model with the MNIST dataset. You will also learn how to set the layout for subclassed, Sequential, and Functional models.\n", "\n", "This tutorial assumes that you have already read the [DTensor programming guide](/guide/dtensor_overview) and are familiar with basic DTensor concepts like `Mesh` and `Layout`.\n", "\n", "This tutorial is based on [Training a neural network on MNIST with Keras](https://www.tensorflow.org/datasets/keras_example)." ] }, { "cell_type": "markdown", "metadata": { "id": "keIyP3IoA1o4" }, "source": [ "## Setup\n", "\n", "DTensor (`tf.experimental.dtensor`) has been part of TensorFlow since the 2.9.0 release.\n", "\n", "First, install or upgrade TensorFlow Datasets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4dHik7NYA5vm" }, "outputs": [], "source": [ "!pip install --quiet --upgrade tensorflow-datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "VttBMZngDx8x" }, "source": [ "Next, import TensorFlow and `dtensor`, and configure TensorFlow to use 8 virtual CPUs.\n", "\n", "Even though this example uses virtual CPUs, DTensor works the same way on CPU, GPU, or TPU devices." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CodX6idGBGSm" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "from tensorflow.experimental import dtensor" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aAtvrpasDpDD" }, "outputs": [], "source": [ "def configure_virtual_cpus(ncpu):\n", "  # Split the first physical CPU into `ncpu` logical (virtual) CPU devices.\n", "  phy_devices = tf.config.list_physical_devices('CPU')\n", "  tf.config.set_logical_device_configuration(\n", "      phy_devices[0],\n", "      [tf.config.LogicalDeviceConfiguration()] * ncpu)\n", "\n", "configure_virtual_cpus(8)\n", "print(tf.config.list_logical_devices('CPU'))\n", "\n", "devices = [f'CPU:{i}' for i in range(8)]" ] }, { "cell_type": "markdown", "metadata": { "id": "ogULE1OHtyd9" }, "source": [ "## Deterministic pseudo-random number generators\n", "\n", "Note that DTensor requires every running client to use the same random seeds, so that the weights are initialized deterministically. You can achieve this by setting the global seeds in Keras via `tf.keras.utils.set_random_seed()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9u85YypguL8N" }, "outputs": [], "source": [ "tf.keras.backend.experimental.enable_tf_random_generator()\n", "tf.keras.utils.set_random_seed(1337)" ] }, { "cell_type": "markdown", "metadata": { "id": "tO11XvPDAu3_" }, "source": [ "## Creating a Data Parallel Mesh\n", "\n", "This tutorial demonstrates Data Parallel training. Adapting to Model Parallel training and Spatial Parallel training can be as simple as switching to a different set of `Layout` objects. 
Refer to the [Distributed training with DTensors](dtensor_ml_tutorial.ipynb) tutorial for more information on distributed training beyond Data Parallel.\n", "\n", "Data Parallel training is a commonly used parallel training scheme; it is also the scheme used by, for example, `tf.distribute.MirroredStrategy`.\n", "\n", "With DTensor, a Data Parallel training loop uses a `Mesh` that consists of a single 'batch' dimension, where each device runs a replica of the model that receives a shard from the global batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6sT6s6z4j9H-" }, "outputs": [], "source": [ "mesh = dtensor.create_mesh([(\"batch\", 8)], devices=devices)" ] }, { "cell_type": "markdown", "metadata": { "id": "rouFcF6FE0aF" }, "source": [ "As each device runs a full replica of the model, the model variables should be fully replicated across the mesh (unsharded). For example, a fully replicated `Layout` for a rank-2 weight on this `Mesh` would be:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U8OxvkDKE1Nu" }, "outputs": [], "source": [ "example_weight_layout = dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)  # or\n", "example_weight_layout = dtensor.Layout.replicated(mesh, rank=2)" ] }, { "cell_type": "markdown", "metadata": { "id": "6Bnic98RE0xi" }, "source": [ "A layout for a rank-2 data tensor on this `Mesh` would be sharded along the first dimension (sometimes known as `batch_sharded`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PhYp0EKBFfxt" }, "outputs": [], "source": [ "example_data_layout = dtensor.Layout(['batch', dtensor.UNSHARDED], mesh)  # or\n", "example_data_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)" ] }, { "cell_type": "markdown", "metadata": { "id": "4U-6n0DericV" }, "source": [ "## Create Keras layers with layout\n", "\n", "In the Data Parallel scheme, you usually create your model weights with a fully replicated layout, so that each 
replica of the model can do calculations with the sharded input data.\n", "\n", "To configure the layout information for your layers' weights, Keras exposes an extra parameter in the layer constructor of most built-in layers.\n", "\n", "The following example builds a small image classification model with a fully replicated weight layout. You can specify the layouts of the `kernel` and `bias` weights of `tf.keras.layers.Dense` via the `kernel_layout` and `bias_layout` arguments. Most of the built-in Keras layers accept an explicit `Layout` for their weights in the same way." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Koc5GlA1tFXY" }, "outputs": [], "source": [ "unsharded_layout_2d = dtensor.Layout.replicated(mesh, 2)\n", "unsharded_layout_1d = dtensor.Layout.replicated(mesh, 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GfOGTIxGs5Ql" }, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n", "  tf.keras.layers.Dense(128,\n", "                        activation='relu',\n", "                        name='d1',\n", "                        kernel_layout=unsharded_layout_2d,\n", "                        bias_layout=unsharded_layout_1d),\n", "  tf.keras.layers.Dense(10,\n", "                        name='d2',\n", "                        kernel_layout=unsharded_layout_2d,\n", "                        bias_layout=unsharded_layout_1d)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "0frf3jsVtx_n" }, "source": [ "You can check the layout information by examining the `layout` property of the weights. For example, print the layout of the first weight:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z_nqv_VdwcXo" }, "outputs": [], "source": [ "for weight in model.weights:\n", "  print(f'Weight name: {weight.name} with layout: {weight.layout}')\n", "  break" ] }, { "cell_type": "markdown", "metadata": { "id": "6FMGB-QsxPtU" }, "source": [ "## Load a dataset and build an input pipeline\n", "\n", "Load the MNIST dataset and configure a pre-processing input pipeline for it. 
The dataset itself is not associated with any DTensor layout information." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zGt4kwltxOt4" }, "outputs": [], "source": [ "(ds_train, ds_test), ds_info = tfds.load(\n", "    'mnist',\n", "    split=['train', 'test'],\n", "    shuffle_files=True,\n", "    as_supervised=True,\n", "    with_info=True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HkUaOB_ryaLH" }, "outputs": [], "source": [ "def normalize_img(image, label):\n", "  \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n", "  return tf.cast(image, tf.float32) / 255., label" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Efm2H1iqydan" }, "outputs": [], "source": [ "batch_size = 128\n", "\n", "ds_train = ds_train.map(\n", "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", "ds_train = ds_train.cache()\n", "ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n", "ds_train = ds_train.batch(batch_size)\n", "ds_train = ds_train.prefetch(tf.data.AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Lcrg6QAtyis4" }, "outputs": [], "source": [ "ds_test = ds_test.map(\n", "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", "ds_test = ds_test.batch(batch_size)\n", "ds_test = ds_test.cache()\n", "ds_test = ds_test.prefetch(tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "fHEZwib7lhqn" }, "source": [ "## Define the training logic for the model\n", "\n", "Next, define the training and evaluation logic for the model.\n", "\n", "As of TensorFlow 2.9, you have to write a custom training loop for a DTensor-enabled Keras model. This is because the input data must be packed with the proper layout information, which is not integrated with the standard Keras `tf.keras.Model.fit()` or `tf.keras.Model.evaluate()` methods. More `tf.data` support is expected in an upcoming release." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CAx11gMjzzjs" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(model, x, y, optimizer, metrics):\n", "  with tf.GradientTape() as tape:\n", "    logits = model(x, training=True)\n", "    # tf.reduce_sum sums the batch sharded per-example loss to a replicated\n", "    # global loss (scalar).\n", "    loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(\n", "        y, logits, from_logits=True))\n", "\n", "  gradients = tape.gradient(loss, model.trainable_variables)\n", "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n", "\n", "  for metric in metrics.values():\n", "    metric.update_state(y_true=y, y_pred=logits)\n", "\n", "  loss_per_sample = loss / len(x)\n", "  results = {'loss': loss_per_sample}\n", "  return results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "maSTWeRemO0P" }, "outputs": [], "source": [ "@tf.function\n", "def eval_step(model, x, y, metrics):\n", "  logits = model(x, training=False)\n", "  loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(\n", "      y, logits, from_logits=True))\n", "\n", "  for metric in metrics.values():\n", "    metric.update_state(y_true=y, y_pred=logits)\n", "\n", "  loss_per_sample = loss / len(x)\n", "  results = {'eval_loss': loss_per_sample}\n", "  return results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dt00axcLmvLr" }, "outputs": [], "source": [ "def pack_dtensor_inputs(images, labels, image_layout, label_layout):\n", "  # Split the global batch into one shard per local device, then pack the\n", "  # shards into batch-sharded DTensors. `tf.split` requires the batch size\n", "  # to be divisible by the number of local devices.\n", "  num_local_devices = image_layout.mesh.num_local_devices()\n", "  images = tf.split(images, num_local_devices)\n", "  labels = tf.split(labels, num_local_devices)\n", "  images = dtensor.pack(images, image_layout)\n", "  labels = dtensor.pack(labels, label_layout)\n", "  return images, labels" ] }, { "cell_type": "markdown", "metadata": { "id": "9Eb-qIJGrxB9" }, "source": [ "## Metrics and optimizers\n", "\n", "When using the DTensor API with 
Keras `Metric` and `Optimizer`, you need to provide the extra mesh information, so that any internal state variables and tensors can work with the variables in the model.\n", "\n", "- For an optimizer, DTensor introduces a new experimental namespace, `tf.keras.dtensor.experimental.optimizers`, where many existing Keras optimizers are extended to receive an additional `mesh` argument. In future releases, it may be merged with the Keras core optimizers.\n", "\n", "- For metrics, you can pass the `mesh` directly to the constructor as an argument to make it a DTensor-compatible `Metric`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1lu_0mz1sxrl" }, "outputs": [], "source": [ "optimizer = tf.keras.dtensor.experimental.optimizers.Adam(0.01, mesh=mesh)\n", "metrics = {'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}\n", "eval_metrics = {'eval_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}" ] }, { "cell_type": "markdown", "metadata": { "id": "QzufrkistELx" }, "source": [ "## Train the model\n", "\n", "The following example demonstrates how to shard the data from the input pipeline along the batch dimension and train the model, which has fully replicated weights. 
\n", "\n", "After 3 epochs, the model should achieve about 97% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kZW568Dk0vvL" }, "outputs": [], "source": [ "num_epochs = 3\n", "\n", "image_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=4)\n", "label_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)\n", "\n", "for epoch in range(num_epochs):\n", "  print(\"============================\")\n", "  print(\"Epoch: \", epoch)\n", "  for metric in metrics.values():\n", "    metric.reset_state()\n", "  step = 0\n", "  results = {}\n", "  pbar = tf.keras.utils.Progbar(target=None, stateful_metrics=[])\n", "  for images, labels in ds_train:\n", "    images, labels = pack_dtensor_inputs(\n", "        images, labels, image_layout, label_layout)\n", "\n", "    results.update(train_step(model, images, labels, optimizer, metrics))\n", "    for metric_name, metric in metrics.items():\n", "      results[metric_name] = metric.result()\n", "\n", "    pbar.update(step, values=results.items(), finalize=False)\n", "    step += 1\n", "  pbar.update(step, values=results.items(), finalize=True)\n", "\n", "  for metric in eval_metrics.values():\n", "    metric.reset_state()\n", "  for images, labels in ds_test:\n", "    images, labels = pack_dtensor_inputs(\n", "        images, labels, image_layout, label_layout)\n", "    results.update(eval_step(model, images, labels, eval_metrics))\n", "\n", "  for metric_name, metric in eval_metrics.items():\n", "    results[metric_name] = metric.result()\n", "\n", "  for metric_name, metric in results.items():\n", "    print(f\"{metric_name}: {metric.numpy()}\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HYEXF6qCuoSr" }, "source": [ "## Specify Layout for existing model code\n", "\n", "Often you have models that work well for your use case. 
Specifying `Layout` information for each individual layer within the model can take a large amount of work and require many edits.\n", "\n", "To help you easily convert your existing Keras model to work with the DTensor API, you can use the new `tf.keras.dtensor.experimental.LayoutMap` API, which allows you to specify the `Layout` from a global point of view.\n", "\n", "First, you need to create a `LayoutMap` instance, which is a dictionary-like object that contains all the `Layout`s you would like to specify for your model weights.\n", "\n", "`LayoutMap` needs a `Mesh` instance at init, which is used to provide a default replicated `Layout` for any weights that don't have a `Layout` configured. If you would like all your model weights to be fully replicated, you can provide an empty `LayoutMap`, and its mesh will be used to create replicated `Layout`s.\n", "\n", "`LayoutMap` uses a string as the key and a `Layout` as the value. Unlike a normal Python dict, the string key is treated as a regex when retrieving the value." ] }, { "cell_type": "markdown", "metadata": { "id": "SCq5Nl-UP_dS" }, "source": [ "### Subclassed Model\n", "\n", "Consider the following model defined using the Keras subclassing Model syntax." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LZ0hRFs8unu0" }, "outputs": [], "source": [ "class SubclassedModel(tf.keras.Model):\n", "\n", "  def __init__(self, name=None):\n", "    super().__init__(name=name)\n", "    self.feature = tf.keras.layers.Dense(16)\n", "    self.feature_2 = tf.keras.layers.Dense(24)\n", "    self.dropout = tf.keras.layers.Dropout(0.1)\n", "\n", "  def call(self, inputs, training=None):\n", "    x = self.feature(inputs)\n", "    x = self.dropout(x, training=training)\n", "    return self.feature_2(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "1njxqPB-yS97" }, "source": [ "There are 4 weights in this model: the `kernel` and `bias` of each of the two `Dense` layers. 
Each of them is mapped based on its object path:\n", "\n", "* `model.feature.kernel`\n", "* `model.feature.bias`\n", "* `model.feature_2.kernel`\n", "* `model.feature_2.bias`\n", "\n", "Note: For subclassed Models, the attribute name, rather than the `.name` attribute of the layer, is used as the key to retrieve the `Layout` from the mapping. This is consistent with the convention followed by `tf.Module` checkpointing. For complex models with more than a few layers, you can [manually inspect checkpoints](https://www.tensorflow.org/guide/checkpoint#manually_inspecting_checkpoints) to view the attribute mappings.\n", "\n", "Now define the following `LayoutMap` and apply it to the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "goVX6iIZw468" }, "outputs": [], "source": [ "layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)\n", "\n", "layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)\n", "layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)\n", "\n", "with layout_map.scope():\n", "  subclassed_model = SubclassedModel()" ] }, { "cell_type": "markdown", "metadata": { "id": "M32HcSp_PyWs" }, "source": [ "The model weights are created on the first call, so call the model with a DTensor input and confirm the weights have the expected layouts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c3CbD9l7qUNq" }, "outputs": [], "source": [ "dtensor_input = dtensor.copy_to_mesh(tf.zeros((16, 16)), layout=unsharded_layout_2d)\n", "# Trigger weight creation for the subclassed model.\n", "subclassed_model(dtensor_input)\n", "\n", "print(subclassed_model.feature.kernel.layout)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZyCnfd-4Q2jk" }, "source": [ "With this, you can quickly map the `Layout` to your models without updating any of your existing code." 
] }, { "cell_type": "markdown", "metadata": { "id": "6GliUdWTQnKC" }, "source": [ "### Sequential and Functional Models" ] }, { "cell_type": "markdown", "metadata": { "id": "6zzvTqAR2Teu" }, "source": [ "For Keras Functional and Sequential models, you can use `tf.keras.dtensor.experimental.LayoutMap` as well.\n", "\n", "Note: For Functional and Sequential models, the mappings are slightly different. The layers in the model are not attached to the model as public attributes (though you can access them as a list via `Model.layers`). Use the layer's string name as the key in this case. The string name is guaranteed to be unique within a model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gXK2EquIRJCC" }, "outputs": [], "source": [ "layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)\n", "\n", "layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)\n", "layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cBzwJqrg2TH3" }, "outputs": [], "source": [ "with layout_map.scope():\n", "  inputs = tf.keras.Input((16,), batch_size=16)\n", "  x = tf.keras.layers.Dense(16, name='feature')(inputs)\n", "  x = tf.keras.layers.Dropout(0.1)(x)\n", "  output = tf.keras.layers.Dense(32, name='feature_2')(x)\n", "  model = tf.keras.Model(inputs, output)\n", "\n", "# model.layers[0] is the InputLayer, so index 1 is the 'feature' layer.\n", "print(model.layers[1].kernel.layout)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pPuh1NlE3-wO" }, "outputs": [], "source": [ "with layout_map.scope():\n", "  model = tf.keras.Sequential([\n", "      tf.keras.layers.Dense(16, name='feature', input_shape=(16,)),\n", "      tf.keras.layers.Dropout(0.1),\n", "      tf.keras.layers.Dense(32, name='feature_2')\n", "  ])\n", "\n", "# Sequential.layers omits the implicit InputLayer, so index 2 is 'feature_2'.\n", "print(model.layers[2].kernel.layout)" ] } ], "metadata": { "colab": { "name": "dtensor_keras_tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", 
"name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }