##### Copyright 2018 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use TPUs

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/guide/tpu"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/tpu.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/guide/tpu.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/tpu.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This guide demonstrates how to perform basic training with `tf.keras` and custom training loops on [Tensor Processing Units (TPUs)](https://cloud.google.com/tpu/) and on TPU Pods, which are collections of TPU devices connected by dedicated high-speed network interfaces.

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning workloads. They are available through [Google Colab](https://colab.research.google.com/), the [TPU Research Cloud](https://sites.research.google/trc/), and [Cloud TPU](https://cloud.google.com/tpu).

## Setup

Before you run this Colab notebook, make sure that your hardware accelerator is a TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU**.

Import some necessary libraries, including TensorFlow Datasets:

In [2]:
import tensorflow as tf

import os
import tensorflow_datasets as tfds



## TPU initialization

TPUs are typically [Cloud TPU](https://cloud.google.com/tpu/docs/) workers, which are different from the local process running the user's Python program. Thus, you need to do some initialization work to connect to the remote cluster and initialize the TPUs. Note that the `tpu` argument to `tf.distribute.cluster_resolver.TPUClusterResolver` is a special address just for Colab. If you are running your code on Google Compute Engine (GCE), you should instead pass in the name of your Cloud TPU.

Note: The TPU initialization code has to be at the beginning of your program.
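The choice of `tpu` argument described above can be wrapped in a small helper. This is a minimal sketch. The function name `tpu_argument` is hypothetical. Treating an already-imported `google.colab` module as a sign that the code runs in Colab is a heuristic assumption, not an official check.

```python
import sys
from typing import Optional


def tpu_argument(tpu_name: Optional[str] = None) -> str:
    """Choose the value to pass as the `tpu` argument of TPUClusterResolver.

    Assumption: the presence of the `google.colab` module in sys.modules is
    used as a heuristic for "running in Colab"; it is not an official check.
    """
    if tpu_name:
        # On Google Compute Engine, pass the name of your Cloud TPU.
        return tpu_name
    if "google.colab" in sys.modules:
        # In Colab, this guide passes an empty string.
        return ""
    raise ValueError("Not running in Colab: pass the name of your Cloud TPU.")
```

You would then pass the result to the resolver. In Colab that is `tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_argument())`. On GCE it is `tpu_argument("my-tpu")`, where `my-tpu` is a placeholder for your Cloud TPU's name.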

In [3]:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: grpc://10.25.167.66:8470
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]
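Each logical TPU device listed above has a fully qualified name of the form `/job:<job>/replica:<r>/task:<t>/device:<TYPE>:<index>`. Here all eight names share `task:0`, so the cores are attached to a single worker. The sketch below splits such names into their parts with a regular expression. `parse_device_name` and `DeviceName` are hypothetical helpers for illustration, not TensorFlow APIs. The pattern accepts only fully qualified names, not short forms such as `'/TPU:0'`.

```python
import re
from typing import NamedTuple


class DeviceName(NamedTuple):
    job: str
    replica: int
    task: int
    device_type: str
    index: int


# Fully qualified device names look like
# /job:<job>/replica:<r>/task:<t>/device:<TYPE>:<index>
_DEVICE_RE = re.compile(
    r"^/job:(?P<job>[^/]+)/replica:(?P<replica>\d+)"
    r"/task:(?P<task>\d+)/device:(?P<device_type>[A-Z_]+):(?P<index>\d+)$"
)


def parse_device_name(name: str) -> DeviceName:
    """Split a fully qualified device name into its components."""
    match = _DEVICE_RE.match(name)
    if match is None:
        raise ValueError(f"Not a fully qualified device name: {name!r}")
    parts = match.groupdict()
    return DeviceName(
        job=parts["job"],
        replica=int(parts["replica"]),
        task=int(parts["task"]),
        device_type=parts["device_type"],
        index=int(parts["index"]),
    )


# The eight TPU device names from the output above.
names = [f"/job:worker/replica:0/task:0/device:TPU:{i}" for i in range(8)]
tpu_cores = [parse_device_name(n) for n in names]
print(len(tpu_cores), {d.task for d in tpu_cores})  # 8 {0}
```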


## Manual device placement

After the TPU is initialized, you can use manual device placement to place the computation on a single TPU device:


In [4]:
a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

with tf.device('/TPU:0'):
  c = tf.matmul(a, b)

print("c device: ", c.device)
print(c)

c device:  /job:worker/replica:0/task:0/device:TPU:0


tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)
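You can check the printed result by hand. The product computed on `/TPU:0` works out as:

$$
\begin{bmatrix}1 & 2 & 3\\ 4 & 5 & 6\end{bmatrix}
\begin{bmatrix}1 & 2\\ 3 & 4\\ 5 & 6\end{bmatrix}
=
\begin{bmatrix}1\cdot1+2\cdot3+3\cdot5 & 1\cdot2+2\cdot4+3\cdot6\\ 4\cdot1+5\cdot3+6\cdot5 & 4\cdot2+5\cdot4+6\cdot6\end{bmatrix}
=
\begin{bmatrix}22 & 28\\ 49 & 64\end{bmatrix}
$$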


## Distribution strategies

Usually, you run your model on multiple TPUs in a data-parallel way. To distribute your model on multiple TPUs (as well as multiple GPUs or multiple machines), TensorFlow offers the `tf.distribute.Strategy` API. You can swap in a different distribution strategy, and the same model will run on the corresponding devices, such as TPUs. Learn more in the [Distributed training with TensorFlow](./distributed_training.ipynb) guide.

`tf.distribute.TPUStrategy` implements synchronous distributed training. TPUs provide their own efficient implementations of all-reduce and other collective operations across multiple TPU cores, and `TPUStrategy` uses them.
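In synchronous data parallelism, each replica computes gradients on its own shard of the global batch. An all-reduce then combines those gradients, and every replica applies the same update. The plain-Python sketch below walks through one such step. It assumes a one-parameter model `y = w * x` with a mean squared error loss. The model, the learning rate, and all function names are assumptions for illustration. This is not how `TPUStrategy` is implemented internally.

```python
from typing import List, Sequence


def split_batch(batch: Sequence[float], num_replicas: int) -> List[List[float]]:
    """Split a global batch into equal, contiguous per-replica shards."""
    if len(batch) % num_replicas:
        raise ValueError("Global batch size must be divisible by num_replicas.")
    shard = len(batch) // num_replicas
    return [list(batch[i * shard:(i + 1) * shard]) for i in range(num_replicas)]


def local_gradient(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of the mean squared error of the prediction w * x over one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: Sequence[float]) -> float:
    """Stand-in for a collective all-reduce: every replica receives the mean."""
    return sum(values) / len(values)


def sync_step(w: float, xs: Sequence[float], ys: Sequence[float],
              num_replicas: int, lr: float = 0.01) -> float:
    """One synchronous data-parallel SGD step."""
    x_shards = split_batch(xs, num_replicas)
    y_shards = split_batch(ys, num_replicas)
    # Each replica computes a gradient on its own shard only.
    grads = [local_gradient(w, xb, yb) for xb, yb in zip(x_shards, y_shards)]
    # All replicas apply the same reduced gradient, so their copies of w stay in sync.
    return w - lr * all_reduce_mean(grads)


xs = [float(i) for i in range(16)]
ys = [3.0 * x for x in xs]
print(sync_step(0.0, xs, ys, num_replicas=8), sync_step(0.0, xs, ys, num_replicas=1))
```

The shards are equal-sized, so averaging the per-replica mean gradients gives exactly the full-batch gradient. As a result, a step with eight replicas matches a step with one, up to floating-point rounding.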

Create a `tf.distribute.TPUStrategy` object from the cluster resolver:

In [5]:
strategy = tf.distribute.TPUStrategy(resolver)

INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


To replicate a computation so that it runs on all TPU cores, pass it to the `Strategy.run` API. In the example below, every core receives the same inputs `(a, b)` and performs the matrix multiplication independently. The result is a `PerReplica` value that holds the output from each replica.

In [6]:
@tf.function
def matmul_fn(x, y):
  z = tf.matmul(x, y)
  return z

z = strategy.run(matmul_fn, args=(a, b))
print(z)

PerReplica:{
  0: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  1: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  2: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  3: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  4: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  5: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  6: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  7: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)
}
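The `PerReplica` output above is, in effect, a mapping from replica ID to that replica's result. The plain-Python analog below mimics this with a dict and then sums across replicas. `run_on_replicas` and `reduce_sum` are hypothetical stand-ins. In TensorFlow itself, you would combine per-replica values with `strategy.reduce`, for example `strategy.reduce(tf.distribute.ReduceOp.SUM, z, axis=None)`.

```python
from typing import Callable, Dict, List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain-Python matrix product standing in for tf.matmul."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*y)] for row in x]


def run_on_replicas(fn: Callable[..., Matrix], args: tuple,
                    num_replicas: int) -> Dict[int, Matrix]:
    """Hypothetical analog of Strategy.run: call fn once per replica ID."""
    return {replica_id: fn(*args) for replica_id in range(num_replicas)}


def reduce_sum(per_replica: Dict[int, Matrix]) -> Matrix:
    """Element-wise sum over replicas, analogous to a SUM reduction."""
    values = list(per_replica.values())
    rows, cols = len(values[0]), len(values[0][0])
    return [[sum(v[i][j] for v in values) for j in range(cols)] for i in range(rows)]


a_list = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
b_list = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
per_replica = run_on_replicas(matmul, (a_list, b_list), num_replicas=8)
print(per_replica[0])           # [[22.0, 28.0], [49.0, 64.0]]
print(reduce_sum(per_replica))  # [[176.0, 224.0], [392.0, 512.0]]
```

Every replica received identical inputs, so the sum is simply eight times the single-replica result.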


## Classification on TPUs

Having covered the basic concepts, consider a more concrete example. This section demonstrates how to use the `tf.distribute.TPUStrategy` distribution strategy to train a Keras model on a Cloud TPU.

### Define a Keras model

Start with a definition of a [`Sequential` Keras model](https://www.tensorflow.org/guide/keras/sequential_model) for image classification on the MNIST dataset. It is no different from what you would use if you were training on CPUs or GPUs. Note that the Keras model must be created inside `Strategy.scope`, so that its variables are created on each TPU device. Other parts of the code do not need to be inside the scope.

In [7]:
def create_model():
  return tf.keras.Sequential(
      [tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),
       tf.keras.layers.Conv2D(256, 3, activation='relu'),
       tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(256, activation='relu'),
       tf.keras.layers.Dense(128, activation='relu'),
       tf.keras.layers.Dense(10)])
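Keras `Conv2D` defaults to `'valid'` padding and a stride of 1. Each 3x3 convolution therefore shrinks every spatial dimension by 2, taking the image from 28 to 26 to 24. The plain-Python arithmetic below works out the shapes and parameter counts of `create_model`. The helper names are hypothetical. If the defaults above hold, the totals should agree with `model.summary()`.

```python
def conv2d_valid(size: int, kernel: int) -> int:
    """Spatial size after a stride-1 convolution with 'valid' (no) padding."""
    return size - kernel + 1


def conv2d_params(kernel: int, in_ch: int, out_ch: int) -> int:
    """Weights (kernel x kernel x in_ch x out_ch) plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units: int, out_units: int) -> int:
    """Weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units


side = conv2d_valid(conv2d_valid(28, 3), 3)  # 28 -> 26 -> 24
flat = side * side * 256                      # units after Flatten
params = {
    "conv_1": conv2d_params(3, 1, 256),
    "conv_2": conv2d_params(3, 256, 256),
    "dense_256": dense_params(flat, 256),
    "dense_128": dense_params(256, 128),
    "dense_10": dense_params(128, 10),
}
total = sum(params.values())
print(side, flat, total)  # 24 147456 38375818
```

Roughly 98% of the parameters sit in the first `Dense` layer. This is because flattening the 24x24x256 feature map produces 147,456 inputs to that layer.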

### Load the dataset

Efficient use of the `tf.data.Dataset` API is critical when using a Cloud TPU. You can learn more about dataset performance in the [Input pipeline performance guide](./data_performance.ipynb).

If you are using [TPU Nodes](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm), you need to store all data files read by the TensorFlow `Dataset` in [Google Cloud Storage (GCS) buckets](https://cloud.google.com/tpu/docs/storage-buckets). If you are using [TPU VMs](https://cloud.google.com/tpu/docs/users-guide-tpu-vm), you can store data wherever you like. For more information on TPU Nodes and TPU VMs, refer to the [TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) documentation.

For most use cases, it is recommended to convert your data into the `TFRecord` format and use a `tf.data.TFRecordDataset` to read it. Check the [TFRecord and tf.Example tutorial](../tutorials/load_data/tfrecord.ipynb) for details on how to do this. This is not a hard requirement, though: you can also use other dataset readers, such as `tf.data.FixedLengthRecordDataset` or `tf.data.TextLineDataset`.

You can load entire small datasets into memory using `tf.data.Dataset.cache`.

Regardless of the data format used, it is strongly recommended that you use large files, on the order of 100 MB. This is especially important in this networked setting, because opening each file over the network carries a significant overhead.
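One way to follow the file-size guidance is to decide the number of output files before writing the data. The sketch below assumes a decimal 100 MB target and a `-00000-of-00050`-style naming pattern. Both the helper names and the naming pattern are illustrative assumptions, not TensorFlow requirements.

```python
import math
from typing import List

TARGET_FILE_BYTES = 100_000_000  # assumption: "about 100 MB", decimal megabytes


def plan_shards(total_bytes: int, target_bytes: int = TARGET_FILE_BYTES) -> int:
    """Number of files needed so that, on average, each holds at most target_bytes."""
    if total_bytes <= 0:
        raise ValueError("total_bytes must be positive")
    return max(1, math.ceil(total_bytes / target_bytes))


def shard_filenames(prefix: str, num_shards: int) -> List[str]:
    """Illustrative naming pattern: <prefix>-00000-of-00050.tfrecord."""
    return [f"{prefix}-{i:05d}-of-{num_shards:05d}.tfrecord" for i in range(num_shards)]


n = plan_shards(5_000_000_000)  # about 5 GB of serialized records
print(n, shard_filenames("train", n)[:2])
```

Under these assumptions, about 5 GB of records would be spread over 50 files. Files can only split at record boundaries, so individual file sizes will vary around the target.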

As shown in the code below, use the TensorFlow Datasets `tfds.load` function to get a copy of the MNIST training and test data. Note that `try_gcs=True` is specified so that TFDS uses a copy of the data that is available in a public GCS bucket. If you don't specify this, the TPU will not be able to access the downloaded data.

In [8]:
def get_dataset(batch_size, is_training=True):
  split = 'train' if is_training else 'test'
  dataset, info = tfds.load(name='mnist', split=split, with_info=True,
                            as_supervised=True, try_gcs=True)

  # Normalize the input data.
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0
    return image, label

  dataset = dataset.map(scale)

  # Only shuffle and repeat the dataset in training. The advantage of having an
  # infinite dataset for training is to avoid the potential last partial batch
  # in each epoch, so that you don't need to think about scaling the gradients
  # based on the actual batch size.
  if is_training:
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()

  dataset = dataset.batch(batch_size)

  return dataset
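The comment in `get_dataset` about avoiding a partial last batch can be illustrated without TensorFlow. The sketch below groups a stream of example IDs into batches, the way `Dataset.batch` does with its default `drop_remainder=False`. The sizes 1,000 and 300 are arbitrary illustrative numbers, and the helper name is hypothetical.

```python
from itertools import chain, islice, repeat
from typing import Iterable, List


def batch_sizes(examples: Iterable[int], size: int) -> List[int]:
    """Sizes of consecutive chunks of `size` items, keeping a short final chunk."""
    it = iter(examples)
    sizes = []
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return sizes
        sizes.append(len(chunk))


demo_examples, demo_batch = 1000, 300

# Finite dataset: the last batch of the epoch is partial.
print(batch_sizes(range(demo_examples), demo_batch))  # [300, 300, 300, 100]

# Repeated dataset, cut to a fixed number of steps: every batch is full.
steps = 4
stream = chain.from_iterable(repeat(range(demo_examples)))
print(batch_sizes(islice(stream, steps * demo_batch), demo_batch))  # [300, 300, 300, 300]
```

With the repeated stream, some batches straddle an epoch boundary. Every step therefore sees a full batch, and no rescaling of the gradients is needed.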

### Train the model using Keras high-level APIs

You can train your model with the Keras `Model.compile` and `Model.fit` APIs. There is nothing TPU-specific in this step: you write the code as if you were using multiple GPUs and a `MirroredStrategy` instead of the `TPUStrategy`. You can learn more in the [Distributed training with Keras](../tutorials/distribute/keras.ipynb) tutorial.
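With `TPUStrategy`, the batch size of the input dataset is the global batch size, and that batch is split across the replicas. The arithmetic below works out the numbers used in the training cell. That cell uses `batch_size = 200` across the 8 TPU cores reported earlier, with 60,000 MNIST training examples and 10,000 test examples. `training_plan` is a hypothetical helper. Rejecting a batch size that is not divisible by the number of replicas is a simplifying choice of this sketch, so that every core gets the same shard size.

```python
def training_plan(global_batch: int, num_replicas: int,
                  train_examples: int, test_examples: int) -> dict:
    """Per-replica batch size and step counts for a given global batch size."""
    if global_batch % num_replicas:
        # Simplifying assumption of this sketch: equal-sized shards per core.
        raise ValueError("Choose a global batch size divisible by the number of replicas.")
    return {
        "per_replica_batch": global_batch // num_replicas,
        "steps_per_epoch": train_examples // global_batch,
        "validation_steps": test_examples // global_batch,
    }


print(training_plan(200, 8, 60000, 10000))
# {'per_replica_batch': 25, 'steps_per_epoch': 300, 'validation_steps': 50}
```

Each core therefore processes 25 examples per step. The 300 steps per epoch match the `300` shown in the progress bar.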

In [9]:
with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

batch_size = 200
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

train_dataset = get_dataset(batch_size, is_training=True)
test_dataset = get_dataset(batch_size, is_training=False)

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset,
          validation_steps=validation_steps)

Epoch 1/5


 67/300 [=====>........................] - ETA: 4s - loss: 0.3115 - sparse_categorical_accuracy: 0.9029

Epoch 2/5


 68/300 [=====>........................] - ETA: 4s - loss: 0.0411 - sparse_categorical_accuracy: 0.9860

Epoch 3/5


 67/300 [=====>........................] - ETA: 4s - loss: 0.0220 - sparse_categorical_accuracy: 0.9930

Epoch 4/5


 67/300 [=====>........................] - ETA: 4s - loss: 0.0150 - sparse_categorical_accuracy: 0.9951

Epoch 5/5


  1/300 [..............................] - ETA: 13s - loss: 0.0277 - sparse_categorical_accuracy: 0.9950

  4/300 [..............................] - ETA: 5s - loss: 0.0265 - sparse_categorical_accuracy: 0.9912 

  7/300 [..............................] - ETA: 5s - loss: 0.0172 - sparse_categorical_accuracy: 0.9936

 10/300 [>.............................] - ETA: 5s - loss: 0.0202 - sparse_categorical_accuracy: 0.9930

 13/300 [>.............................] - ETA: 5s - loss: 0.0173 - sparse_categorical_accuracy: 0.9938

 16/300 [>.............................] - ETA: 5s - loss: 0.0164 - sparse_categorical_accuracy: 0.9941

 19/300 [>.............................] - ETA: 5s - loss: 0.0153 - sparse_categorical_accuracy: 0.9945

 22/300 [=>............................] - ETA: 5s - loss: 0.0155 - sparse_categorical_accuracy: 0.9943

 25/300 [=>............................] - ETA: 5s - loss: 0.0174 - sparse_categorical_accuracy: 0.9936

 28/300 [=>............................] - ETA: 5s - loss: 0.0167 - sparse_categorical_accuracy: 0.9937

 31/300 [==>...........................] - ETA: 5s - loss: 0.0154 - sparse_categorical_accuracy: 0.9944

 34/300 [==>...........................] - ETA: 5s - loss: 0.0178 - sparse_categorical_accuracy: 0.9937

 37/300 [==>...........................] - ETA: 5s - loss: 0.0181 - sparse_categorical_accuracy: 0.9936

 40/300 [===>..........................] - ETA: 5s - loss: 0.0173 - sparse_categorical_accuracy: 0.9941

 43/300 [===>..........................] - ETA: 4s - loss: 0.0171 - sparse_categorical_accuracy: 0.9941

 46/300 [===>..........................] - ETA: 4s - loss: 0.0169 - sparse_categorical_accuracy: 0.9940

 49/300 [===>..........................] - ETA: 4s - loss: 0.0163 - sparse_categorical_accuracy: 0.9942

 52/300 [====>.........................] - ETA: 4s - loss: 0.0158 - sparse_categorical_accuracy: 0.9943

 55/300 [====>.........................] - ETA: 4s - loss: 0.0155 - sparse_categorical_accuracy: 0.9944

 58/300 [====>.........................] - ETA: 4s - loss: 0.0148 - sparse_categorical_accuracy: 0.9947

 61/300 [=====>........................] - ETA: 4s - loss: 0.0148 - sparse_categorical_accuracy: 0.9947

 64/300 [=====>........................] - ETA: 4s - loss: 0.0148 - sparse_categorical_accuracy: 0.9946

 67/300 [=====>........................] - ETA: 4s - loss: 0.0143 - sparse_categorical_accuracy: 0.9948

<keras.callbacks.History at 0x7f79107c8d30>

To reduce Python overhead and maximize the performance of your TPU, pass the `steps_per_execution` argument to Keras `Model.compile`. It sets how many batches run during each `tf.function` call. In this example, it increases throughput by about 50%:

In [10]:
with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                # Anything between 2 and `steps_per_epoch` could help here.
                steps_per_execution=50,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset,
          validation_steps=validation_steps)

Epoch 1/5


 50/300 [====>.........................] - ETA: 41s - loss: 0.4093 - sparse_categorical_accuracy: 0.8690

Epoch 2/5


 50/300 [====>.........................] - ETA: 1s - loss: 0.0452 - sparse_categorical_accuracy: 0.9874

Epoch 3/5


 50/300 [====>.........................] - ETA: 1s - loss: 0.0300 - sparse_categorical_accuracy: 0.9906

Epoch 4/5


 50/300 [====>.........................] - ETA: 1s - loss: 0.0145 - sparse_categorical_accuracy: 0.9948

Epoch 5/5


 50/300 [====>.........................] - ETA: 1s - loss: 0.0097 - sparse_categorical_accuracy: 0.9964

<keras.callbacks.History at 0x7f7898488e20>
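With `steps_per_execution=50`, Keras runs 50 batches per `tf.function` call, so the host launches work on the TPU far less often. This matches the progress bar above, which advances in increments of 50 out of 300 steps. The sketch below is a deliberately simplified cost model. It assumes a fixed host overhead per launch plus a fixed on-device time per step. The helper names (`host_calls_per_epoch`, `estimated_epoch_seconds`) and the timing values are hypothetical, not measurements from this notebook.

```python
import math


def host_calls_per_epoch(steps_per_epoch: int, steps_per_execution: int) -> int:
    """Hypothetical helper: number of host-to-device launches in one epoch."""
    if steps_per_execution < 1:
        raise ValueError("steps_per_execution must be >= 1")
    # In this toy model, the last call runs fewer steps if the division is inexact.
    return math.ceil(steps_per_epoch / steps_per_execution)


def estimated_epoch_seconds(steps_per_epoch: int, steps_per_execution: int,
                            overhead_per_call: float,
                            device_seconds_per_step: float) -> float:
    """Toy cost model: fixed host overhead per launch plus device time per step."""
    calls = host_calls_per_epoch(steps_per_epoch, steps_per_execution)
    return calls * overhead_per_call + steps_per_epoch * device_seconds_per_step


# 300 steps per epoch, matching the runs above.
print(host_calls_per_epoch(300, 1))   # 300
print(host_calls_per_epoch(300, 50))  # 6

# Hypothetical timings: 5 ms of host overhead per call, 10 ms of device time per step.
for spe in (1, 50):
    print(spe, round(estimated_epoch_seconds(300, spe, 0.005, 0.010), 3))
```

Under this toy model, the per-step device time is unchanged and only the launch overhead shrinks. The benefit is therefore largest when steps are short relative to the per-call overhead.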

### Train the model using a custom training loop

You can also create and train your model using `tf.function` and `tf.distribute` APIs directly. Use the `Strategy.distribute_datasets_from_function` API (named `experimental_distribute_datasets_from_function` in older TensorFlow releases) to distribute a `tf.data.Dataset` produced by a dataset function. Note that in the example below, the batch size passed to the `Dataset` is the per-replica batch size, not the global batch size. To learn more, check out the [Custom training with `tf.distribute.Strategy`](../tutorials/distribute/custom_training.ipynb) tutorial.
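The next cell derives the per-replica batch size with `batch_size // strategy.num_replicas_in_sync`. Each replica then receives batches of that size, and together they cover one global batch. The plain-Python sketch below shows that split. `per_replica_batch_size` is a hypothetical helper, and the values (a global batch of 1024 on 8 replicas) are illustrative only. Unlike bare `//`, it refuses to silently drop examples when the division is inexact. `tf.distribute.InputContext.get_per_replica_batch_size` performs a similar check, which you could use inside the dataset function instead of ignoring its argument.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Hypothetical helper: split a global batch evenly across replicas."""
    if num_replicas < 1:
        raise ValueError("num_replicas must be >= 1")
    if global_batch_size % num_replicas != 0:
        # Plain `//` would silently shrink the effective global batch here.
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_replicas} replicas")
    return global_batch_size // num_replicas


# Illustrative values: a global batch of 1024 examples on 8 TPU cores.
per_replica = per_replica_batch_size(1024, 8)
print(per_replica)  # 128
# The 8 per-replica batches of 128 together cover the 1024-example global batch.
assert per_replica * 8 == 1024
```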


First, create the model, datasets and `tf.function`s:

In [11]:
# Create the model, optimizer and metrics inside the `tf.distribute.Strategy`
# scope, so that the variables can be mirrored on each device.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

train_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

@tf.function
def train_step(iterator):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  strategy.run(step_fn, args=(next(iterator),))


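In `step_fn` above, `tf.nn.compute_average_loss` divides the sum of per-example losses by the *global* batch size, not the per-replica batch size. The optimizer sums gradients across replicas. As a result, the combined update corresponds to the mean loss over the whole global batch. The plain-Python sketch below mirrors only that scaling (it ignores sample weights) using made-up loss values for 2 replicas. It also shows why `step_fn` multiplies the loss by `strategy.num_replicas_in_sync` before updating the `Mean` metric. The function here is a hypothetical re-implementation for illustration, not TensorFlow's code.

```python
def compute_average_loss(per_example_losses, global_batch_size):
    """Plain-Python mirror of the scaling in tf.nn.compute_average_loss (no weights)."""
    return sum(per_example_losses) / global_batch_size


# Made-up per-example losses: a global batch of 8 split across 2 replicas.
replica_losses = [[0.5, 1.0, 1.5, 2.0], [0.0, 1.0, 2.0, 3.0]]
global_batch_size = sum(len(r) for r in replica_losses)  # 8
num_replicas = len(replica_losses)                       # 2

# Each replica scales its local sum by the global batch size.
per_replica = [compute_average_loss(r, global_batch_size) for r in replica_losses]
print(per_replica)  # [0.625, 0.75]

# Summing across replicas (as gradient aggregation does) yields the global mean.
global_mean = sum(sum(r) for r in replica_losses) / global_batch_size
print(sum(per_replica), global_mean)  # 1.375 1.375

# Multiplying by num_replicas recovers each replica's local mean, so a Mean
# metric averaged over replicas reports the global mean as well.
local_means = [loss * num_replicas for loss in per_replica]
print(local_means)  # [1.25, 1.5]
print(sum(local_means) / num_replicas)  # 1.375
```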
Then, run the training loop:

In [12]:
steps_per_eval = 10000 // batch_size

train_iterator = iter(train_dataset)
for epoch in range(5):
  # `epoch` is zero-based; print it one-based.
  print('Epoch: {}/5'.format(epoch + 1))

  for step in range(steps_per_epoch):
    train_step(train_iterator)
  print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
  # Reset the metrics so that each epoch reports its own averages.
  training_loss.reset_state()
  training_accuracy.reset_state()

Epoch: 1/5
Current step: 300, training loss: 0.1465, accuracy: 95.4%
Epoch: 2/5
Current step: 600, training loss: 0.035, accuracy: 98.94%
Epoch: 3/5
Current step: 900, training loss: 0.0197, accuracy: 99.39%
Epoch: 4/5
Current step: 1200, training loss: 0.0126, accuracy: 99.59%
Epoch: 5/5
Current step: 1500, training loss: 0.0109, accuracy: 99.64%


### Improving performance with multiple steps inside `tf.function`

You can further improve performance by running multiple steps within a single `tf.function` call. To do this, wrap the `Strategy.run` call in a `for` loop over `tf.range` inside the `tf.function`. AutoGraph converts this loop into a `tf.while_loop` that runs on the TPU worker. You can learn more about `tf.function`s in the [Better performance with `tf.function`](./function.ipynb) guide.

This performance gain comes with a tradeoff compared to running a single step per `tf.function` call: running multiple steps in one `tf.function` is less flexible, because you cannot run code eagerly or execute arbitrary Python code between the steps.


In [13]:
@tf.function
def train_multiple_steps(iterator, steps):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes.
train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))

Current step: 1800, training loss: 0.009, accuracy: 99.72%
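The comment before the `train_multiple_steps` call explains that `steps_per_epoch` is converted to a `tf.Tensor` to avoid retracing. `tf.function` specializes a trace on the *value* of a Python scalar argument, but only on the dtype and shape of a tensor argument. The toy class below models just that caching rule in plain Python. `ToyTracedFunction` and `ToyTensor` are hypothetical stand-ins, not TensorFlow's implementation, and "tracing" here merely records a cache entry instead of building a graph.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a tf.Tensor: a value plus dtype and shape."""
    value: Any
    dtype: str
    shape: Tuple[int, ...]


class ToyTracedFunction:
    """Toy trace cache; NOT how TensorFlow implements tf.function."""

    def __init__(self, python_fn: Callable[[Any], Any]):
        self.python_fn = python_fn
        self.traces: Dict[Tuple, Callable[[Any], Any]] = {}

    def _cache_key(self, arg: Any) -> Tuple:
        if isinstance(arg, ToyTensor):
            # Tensors: key on dtype and shape only, so new values reuse a trace.
            return ("tensor", arg.dtype, arg.shape)
        # Python scalars: key on the value itself, so each new value retraces.
        return ("python", arg)

    def __call__(self, arg: Any) -> Any:
        key = self._cache_key(arg)
        if key not in self.traces:
            # "Tracing": just store the function; TF would build a graph here.
            self.traces[key] = self.python_fn
        value = arg.value if isinstance(arg, ToyTensor) else arg
        return self.traces[key](value)


def double(steps):
    return steps * 2


with_python_ints = ToyTracedFunction(double)
for steps in (100, 200, 300):
    with_python_ints(steps)
print(len(with_python_ints.traces))  # 3

with_tensors = ToyTracedFunction(double)
for steps in (100, 200, 300):
    with_tensors(ToyTensor(steps, "int32", ()))
print(len(with_tensors.traces))  # 1
```

In this model, every distinct Python integer creates a new cache entry, while scalar tensors of the same dtype share one. This is the behavior that converting `steps_per_epoch` with `tf.convert_to_tensor` relies on.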


## Next steps

To learn more about Cloud TPUs and how to use them:

- [Google Cloud TPU](https://cloud.google.com/tpu): The Google Cloud TPU homepage.
- [Google Cloud TPU documentation](https://cloud.google.com/tpu/docs/): Google Cloud TPU documentation, which includes:
  - [Introduction to Cloud TPU](https://cloud.google.com/tpu/docs/intro-to-tpu): An overview of working with Cloud TPUs.
  - [Cloud TPU quickstarts](https://cloud.google.com/tpu/docs/quick-starts): Quickstart introductions to working with Cloud TPU VMs using TensorFlow and other main machine learning frameworks.
- [Google Cloud TPU Colab notebooks](https://cloud.google.com/tpu/docs/colabs): End-to-end training examples.
- [Google Cloud TPU performance guide](https://cloud.google.com/tpu/docs/performance-guide): Enhance Cloud TPU performance further by adjusting Cloud TPU configuration parameters for your application.
- [Distributed training with TensorFlow](./distributed_training.ipynb): How to use distribution strategies—including `tf.distribute.TPUStrategy`—with examples showing best practices.
- TPU embeddings: TensorFlow includes specialized support for training embeddings on TPUs via `tf.tpu.experimental.embedding`. In addition, [TensorFlow Recommenders](https://www.tensorflow.org/recommenders) has `tfrs.layers.embedding.TPUEmbedding`. Embeddings provide efficient and dense representations, capturing complex similarities and relationships between features. TensorFlow's TPU-specific embedding support allows you to train embeddings that are larger than the memory of a single TPU device, and to use sparse and ragged inputs on TPUs.
- [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/): TRC enables researchers to apply for access to a cluster of more than 1,000 Cloud TPU devices.
