##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Using DTensors with Keras

<table class="tfo-notebook-buttons" align="left">
  <td>     <a target="_blank" href="https://tensorflow.google.cn/tutorials/distribute/dtensor_keras_tutorial"><img src="https://tensorflow.google.cn/images/tf_logo_32px.png">View on TensorFlow.org</a> </td>
  <td>     <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/dtensor_keras_tutorial.ipynb"><img src="https://tensorflow.google.cn/images/colab_logo_32px.png">Run in Google Colab</a> </td>
  <td>     <a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/dtensor_keras_tutorial.ipynb"><img src="https://tensorflow.google.cn/images/GitHub-Mark-32px.png">View source on GitHub</a> </td>
  <td>     <a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/distribute/dtensor_keras_tutorial.ipynb"><img src="https://tensorflow.google.cn/images/download_logo_32px.png">Download notebook</a> </td>
</table>

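## Overview

In this tutorial, you will learn how to use DTensor with Keras.

Because DTensor is integrated with Keras, you can reuse your existing Keras layers and models to build and train distributed machine learning models.

You will train a multi-layer classification model on the MNIST data. The tutorial demonstrates how to set layouts for subclassed models, Sequential models, and functional models.

This tutorial assumes that you have already read the [DTensor programming guide](/guide/dtensor_overview) and are familiar with basic DTensor concepts such as `Mesh` and `Layout`.

This tutorial is based on https://tensorflow.google.cn/datasets/keras_example.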

## Setup

DTensor is part of the TensorFlow 2.9.0 release.

In [2]:
!pip install --quiet --upgrade --pre tensorflow tensorflow-datasets

Next, import `tensorflow` and `tensorflow.experimental.dtensor`, and configure TensorFlow to use 8 virtual CPUs.

Although this example uses CPUs, DTensor works the same way on CPU, GPU, or TPU devices.

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.experimental import dtensor



In [4]:
def configure_virtual_cpus(ncpu):
  phy_devices = tf.config.list_physical_devices('CPU')
  tf.config.set_logical_device_configuration(
        phy_devices[0], 
        [tf.config.LogicalDeviceConfiguration()] * ncpu)
  
configure_virtual_cpus(8)
tf.config.list_logical_devices('CPU')

devices = [f'CPU:{i}' for i in range(8)]

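## Deterministic pseudo-random number generators

Note that the DTensor API requires every running client to use the same random seed, so that weight initialization is deterministic. You can achieve this by setting the global seed in Keras with `tf.keras.utils.set_random_seed()`.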

In [5]:
tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1337)
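
The reason every client needs the same seed can be illustrated without TensorFlow. The sketch below is only an analogy: `init_weights` is a hypothetical stand-in for a weight initializer, built on Python's standard `random` module rather than on the generator that Keras or DTensor actually uses.

```python
import random


def init_weights(seed, n):
    """Hypothetical initializer: draws n weights from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 0.05) for _ in range(n)]


# Two clients that agree on the seed produce identical initial weights.
client_a = init_weights(1337, 4)
client_b = init_weights(1337, 4)

# A client with a different seed starts from different weights, so its copy
# of a "replicated" variable would silently disagree with the others.
client_c = init_weights(42, 4)

same_seed_matches = client_a == client_b
different_seed_matches = client_a == client_c
print(same_seed_matches, different_seed_matches)
```

With a shared seed, each client independently arrives at the same starting values, which is the property the replicated layouts in this tutorial rely on.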

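## Create a data parallel mesh

This tutorial demonstrates data parallel training. Adapting it to model parallel or spatial parallel training can be as simple as switching to a different set of `Layout` objects. For more information on distributed training beyond data parallelism, refer to the [DTensor in-depth ML tutorial](https://tensorflow.google.cn/tutorials/distribute/dtensor_ml_tutorial).

Data parallel training is a commonly used parallel training scheme, also used by, for example, `tf.distribute.MirroredStrategy`.

With DTensor, a data parallel training loop uses a `Mesh` that consists of a single "batch" dimension. Each device runs a replica of the model and receives one shard of the global batch.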


In [6]:
mesh = dtensor.create_mesh([("batch", 8)], devices=devices)

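Because each device runs a full replica of the model, the model variables should be fully replicated across the mesh (unsharded). For example, a fully replicated layout for a rank-2 weight on this `Mesh` looks like this: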

In [7]:
example_weight_layout = dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)  # or
example_weight_layout = dtensor.Layout.replicated(mesh, rank=2)

A layout for a rank-2 data tensor on this `Mesh` is sharded along the first dimension (sometimes called `batch_sharded`):

In [8]:
example_data_layout = dtensor.Layout(['batch', dtensor.UNSHARDED], mesh)  # or
example_data_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
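
To make the difference between the two layouts concrete, the following pure-Python sketch computes the shape each device would hold for a given global shape and sharding spec. The helper `local_shape` and the constant `UNSHARDED` are hypothetical names introduced for this illustration (they are not DTensor APIs), and the sketch assumes a sharded dimension divides evenly across the mesh dimension.

```python
UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED in this sketch


def local_shape(global_shape, sharding_specs, mesh_dims):
    """Return the per-device shape of a tensor laid out on a mesh.

    global_shape:   shape of the full (global) tensor.
    sharding_specs: one entry per tensor axis, either UNSHARDED or the name
                    of the mesh dimension that axis is split over.
    mesh_dims:      mapping from mesh dimension name to its size.
    """
    if len(global_shape) != len(sharding_specs):
        raise ValueError("layout rank must match tensor rank")
    shape = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            shape.append(size)  # every device holds the whole axis
        else:
            parts = mesh_dims[spec]
            if size % parts:
                raise ValueError(f"axis of size {size} does not split into {parts}")
            shape.append(size // parts)  # each device holds one slice
    return tuple(shape)


mesh_dims = {"batch": 8}

# A fully replicated rank-2 weight: every device stores all of it.
weight_local = local_shape((784, 128), [UNSHARDED, UNSHARDED], mesh_dims)
# A batch-sharded rank-2 data tensor: each device stores 1/8 of the rows.
data_local = local_shape((128, 784), ["batch", UNSHARDED], mesh_dims)

print(weight_local)  # (784, 128)
print(data_local)    # (16, 784)
```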

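## Create Keras layers with layouts

In the data parallel scheme, you usually create your model weights with a fully replicated layout, so that each replica of the model can compute on its shard of the input data.

To let you configure layout information for layer weights, Keras exposes extra arguments in the constructors of most built-in layers.

The following example builds a small image classification model with fully replicated weight layouts. In `tf.keras.layers.Dense`, you specify the layouts of `kernel` and `bias` through the `kernel_layout` and `bias_layout` arguments. Most built-in Keras layers accept an explicit `Layout` for their weights.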

In [9]:
unsharded_layout_2d = dtensor.Layout.replicated(mesh, 2)
unsharded_layout_1d = dtensor.Layout.replicated(mesh, 1)

In [10]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, 
                        activation='relu',
                        name='d1',
                        kernel_layout=unsharded_layout_2d, 
                        bias_layout=unsharded_layout_1d),
  tf.keras.layers.Dense(10,
                        name='d2',
                        kernel_layout=unsharded_layout_2d, 
                        bias_layout=unsharded_layout_1d)
])
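
The rank passed to `dtensor.Layout.replicated` has to match the rank of the weight it describes, which is why the kernels use a rank-2 layout and the biases a rank-1 layout. The sketch below works this out in plain Python from the shapes implied by the model above (a 28x28 input flattened to 784 features, then `Dense(128)` and `Dense(10)`); it does not query Keras, and the variable names are chosen only for this illustration.

```python
# Weight shapes implied by the Sequential model above.
weight_shapes = {
    "d1/kernel": (784, 128),  # 28 * 28 flattened inputs -> 128 units
    "d1/bias": (128,),
    "d2/kernel": (128, 10),   # 128 units -> 10 classes
    "d2/bias": (10,),
}

# The layout rank each weight needs is simply the rank of its shape.
layout_rank = {name: len(shape) for name, shape in weight_shapes.items()}


def num_elements(shape):
    """Number of scalar parameters in a tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


total_params = sum(num_elements(s) for s in weight_shapes.values())

# With fully replicated layouts, each of the 8 devices stores every parameter.
num_devices = 8
replicated_total = total_params * num_devices

print(layout_rank)
print(total_params)      # 101770
print(replicated_total)  # 814160
```

Replication costs memory proportional to the number of devices, which is acceptable for a model this small and is the trade-off that sharded (model parallel) layouts are meant to address.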

You can check the layout information by inspecting the `layout` property of the weights.

In [11]:
for weight in model.weights:
  print(f'Weight name: {weight.name} with layout: {weight.layout}')
  break

Weight name: d1/kernel:0 with layout: Layout.from_string(sharding_specs:unsharded,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)


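## Load a dataset and build the input pipeline

Load the MNIST dataset and configure a preprocessing input pipeline for it. The dataset itself is not associated with any DTensor layout information. There are plans to improve the integration of DTensor Keras with `tf.data` in future TensorFlow releases.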


In [12]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [13]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

In [14]:
batch_size = 128

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

In [15]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

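## Define the training logic for the model

Next, define the training and evaluation logic for the model.

As of TensorFlow 2.9, you have to write a custom training loop for a DTensor-enabled Keras model. The loop has to pack the input data with the proper layout information, a step that is not integrated with the standard `tf.keras.Model.fit()` or `tf.keras.Model.evaluate()` functions in Keras. More `tf.data` support is expected in upcoming releases.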

In [16]:
@tf.function
def train_step(model, x, y, optimizer, metrics):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))
    
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'loss': loss_per_sample}
  return results
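
The comment in `train_step` notes that `tf.reduce_sum` turns the batch-sharded per-example loss into a replicated scalar. The following pure-Python sketch shows why dividing that sum by the global batch size gives the same per-sample loss regardless of how the examples are split across devices. It does not call TensorFlow, and the loss values are random numbers made up for the illustration.

```python
import random

rng = random.Random(0)
global_batch = 128
num_devices = 8

# Made-up per-example losses for one global batch.
per_example_loss = [rng.uniform(0.0, 3.0) for _ in range(global_batch)]

# Split the batch into equal shards, one per device.
shard_size = global_batch // num_devices
shards = [
    per_example_loss[i * shard_size:(i + 1) * shard_size]
    for i in range(num_devices)
]

# Each device sums its own shard; adding the partial sums gives the global sum.
partial_sums = [sum(shard) for shard in shards]
global_sum = sum(partial_sums)

# Dividing by the global batch size recovers the mean loss per example.
loss_per_sample = global_sum / global_batch
reference_mean = sum(per_example_loss) / global_batch

print(abs(loss_per_sample - reference_mean) < 1e-9)
```

Note that `train_step` differentiates the summed loss rather than the mean, so the gradient magnitude scales with the batch size; only the reported `loss` value is normalized per sample.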

In [17]:
@tf.function
def eval_step(model, x, y, metrics):
  logits = model(x, training=False)
  loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'eval_loss': loss_per_sample}
  return results

In [18]:
def pack_dtensor_inputs(images, labels, image_layout, label_layout):
  num_local_devices = image_layout.mesh.num_local_devices()
  images = tf.split(images, num_local_devices)
  labels = tf.split(labels, num_local_devices)
  images = dtensor.pack(images, image_layout)
  labels = dtensor.pack(labels, label_layout)
  return  images, labels
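
`pack_dtensor_inputs` calls `tf.split` with an integer number of splits, which requires the batch dimension to divide evenly by the number of local devices. The arithmetic sketch below checks that this holds for every batch in this tutorial, including the final partial batch. It assumes the standard MNIST split sizes (60,000 training and 10,000 test examples); `batch_sizes` is a hypothetical helper written for this check.

```python
def batch_sizes(num_examples, batch_size):
    """Sizes of the batches produced by batching without dropping the remainder."""
    full, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)
    return sizes


num_devices = 8
batch_size = 128

train_sizes = batch_sizes(60000, batch_size)
test_sizes = batch_sizes(10000, batch_size)

print(len(train_sizes), train_sizes[-1])  # 469 96
print(len(test_sizes), test_sizes[-1])    # 79 16

# Every batch, including the last one, splits evenly across the 8 devices.
all_divisible = all(size % num_devices == 0 for size in train_sizes + test_sizes)
print(all_divisible)  # True
```

With a different dataset size, batch size, or device count, the last batch may not divide evenly; in that case one option is to pass `drop_remainder=True` to `Dataset.batch`.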

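## Metrics and optimizers

When you use the DTensor API with a Keras `Metric` or `Optimizer`, you need to provide the extra mesh information so that their internal state variables and tensors can work with the variables in the model.

- For optimizers, DTensor introduces a new experimental namespace, `keras.dtensor.experimental.optimizers`, in which many existing Keras optimizers are extended to accept an additional `mesh` argument. In future releases, it may be merged with the Keras core optimizers.

- For metrics, you can pass `mesh` directly to the constructor as an argument to make the `Metric` DTensor-compatible.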

In [19]:
optimizer = tf.keras.dtensor.experimental.optimizers.Adam(0.01, mesh=mesh)
metrics = {'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}
eval_metrics = {'eval_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}
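
The "internal state" mentioned above is, for an accuracy metric, essentially a running count of correct predictions and a running count of examples seen. The class below is a simplified pure-Python sketch of that bookkeeping. `RunningSparseAccuracy` is a hypothetical name and is not how Keras implements `SparseCategoricalAccuracy`; it only mirrors the `update_state` / `result` / `reset_state` call pattern used in this tutorial.

```python
class RunningSparseAccuracy:
    """Hypothetical, simplified accuracy metric with explicit running state."""

    def __init__(self):
        self.total = 0  # number of correct predictions so far
        self.count = 0  # number of examples seen so far

    def update_state(self, y_true, y_pred):
        for label, logits in zip(y_true, y_pred):
            # The predicted class is the index of the largest logit.
            predicted = max(range(len(logits)), key=logits.__getitem__)
            self.total += int(predicted == label)
            self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0
        self.count = 0


metric = RunningSparseAccuracy()
# First batch: the first example is classified correctly, the second is not.
metric.update_state([0, 1], [[2.0, 1.0], [3.0, 0.5]])
first_result = metric.result()
# Second batch: one more correct prediction.
metric.update_state([1], [[0.1, 0.9]])
second_result = metric.result()
print(first_result, second_result)

metric.reset_state()  # called at the start of each epoch in the loop below
```

In the DTensor setting these counters live in variables, which is why the metric needs to know the mesh on which to place them.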

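## Train the model

The following example shards the data from the input pipeline along the batch dimension and trains the model with fully replicated weights.

After 3 epochs, the model should reach about 97% accuracy.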

In [20]:
num_epochs = 3

image_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=4)
label_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

for epoch in range(num_epochs):
  print("============================") 
  print("Epoch: ", epoch)
  for metric in metrics.values():
    metric.reset_state()
  step = 0
  results = {}
  pbar = tf.keras.utils.Progbar(target=None, stateful_metrics=[])
  for input in ds_train:
    images, labels = input[0], input[1]
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)

    results.update(train_step(model, images, labels, optimizer, metrics))
    for metric_name, metric in metrics.items():
      results[metric_name] = metric.result()

    pbar.update(step, values=results.items(), finalize=False)
    step += 1
  pbar.update(step, values=results.items(), finalize=True)

  for metric in eval_metrics.values():
    metric.reset_state()
  for input in ds_test:
    images, labels = input[0], input[1]
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)
    results.update(eval_step(model, images, labels, eval_metrics))

  for metric_name, metric in eval_metrics.items():
    results[metric_name] = metric.result()
  
  for metric_name, metric in results.items():
    print(f"{metric_name}: {metric.numpy()}")


Epoch:  0


    469/Unknown - 7s 16ms/step - loss: 0.2907 - accuracy: 0.8308


loss: 0.12670570611953735
accuracy: 0.9109166860580444
eval_loss: 0.04795415699481964
eval_accuracy: 0.958899974822998
Epoch:  1
    469/Unknown - 4s 9ms/step - loss: 0.1285 - accuracy: 0.9595


loss: 0.08036476373672485
accuracy: 0.9599999785423279
eval_loss: 0.008920179679989815
eval_accuracy: 0.9642000198364258
Epoch:  2
    114/Unknown - 1s 10ms/step - loss: 0.1011 - accuracy: 0.9668

    120/Unknown - 1s 10ms/step - loss: 0.1007 - accuracy: 0.9669

    126/Unknown - 1s 10ms/step - loss: 0.0996 - accuracy: 0.9670

    132/Unknown - 1s 10ms/step - loss: 0.0982 - accuracy: 0.9671

    138/Unknown - 1s 10ms/step - loss: 0.0997 - accuracy: 0.9672

    144/Unknown - 1s 10ms/step - loss: 0.0997 - accuracy: 0.9672

    150/Unknown - 1s 10ms/step - loss: 0.0987 - accuracy: 0.9673

    156/Unknown - 1s 9ms/step - loss: 0.0986 - accuracy: 0.9674 

    162/Unknown - 2s 9ms/step - loss: 0.0994 - accuracy: 0.9674

    168/Unknown - 2s 9ms/step - loss: 0.1002 - accuracy: 0.9675

    174/Unknown - 2s 9ms/step - loss: 0.0998 - accuracy: 0.9675

    180/Unknown - 2s 9ms/step - loss: 0.0989 - accuracy: 0.9675

    186/Unknown - 2s 9ms/step - loss: 0.0998 - accuracy: 0.9676

    192/Unknown - 2s 9ms/step - loss: 0.0993 - accuracy: 0.9676

    198/Unknown - 2s 9ms/step - loss: 0.0994 - accuracy: 0.9676

    204/Unknown - 2s 9ms/step - loss: 0.0986 - accuracy: 0.9677

    210/Unknown - 2s 9ms/step - loss: 0.0990 - accuracy: 0.9677

    216/Unknown - 2s 9ms/step - loss: 0.1002 - accuracy: 0.9677

    222/Unknown - 2s 9ms/step - loss: 0.0992 - accuracy: 0.9678

    228/Unknown - 2s 9ms/step - loss: 0.0991 - accuracy: 0.9678

    234/Unknown - 2s 9ms/step - loss: 0.0982 - accuracy: 0.9678

    240/Unknown - 2s 9ms/step - loss: 0.0989 - accuracy: 0.9679

    246/Unknown - 2s 9ms/step - loss: 0.0982 - accuracy: 0.9679

    252/Unknown - 2s 9ms/step - loss: 0.0990 - accuracy: 0.9679

    258/Unknown - 2s 9ms/step - loss: 0.0992 - accuracy: 0.9680

    264/Unknown - 2s 9ms/step - loss: 0.0994 - accuracy: 0.9680

    270/Unknown - 3s 9ms/step - loss: 0.0989 - accuracy: 0.9680

    276/Unknown - 3s 9ms/step - loss: 0.0984 - accuracy: 0.9681

    282/Unknown - 3s 9ms/step - loss: 0.0980 - accuracy: 0.9681

    288/Unknown - 3s 9ms/step - loss: 0.0984 - accuracy: 0.9681

    294/Unknown - 3s 9ms/step - loss: 0.0983 - accuracy: 0.9681

    300/Unknown - 3s 9ms/step - loss: 0.0978 - accuracy: 0.9681

    306/Unknown - 3s 9ms/step - loss: 0.0983 - accuracy: 0.9682

    312/Unknown - 3s 9ms/step - loss: 0.0979 - accuracy: 0.9682

    318/Unknown - 3s 9ms/step - loss: 0.0981 - accuracy: 0.9682

    324/Unknown - 3s 9ms/step - loss: 0.0985 - accuracy: 0.9682

    330/Unknown - 3s 9ms/step - loss: 0.0987 - accuracy: 0.9682

    336/Unknown - 3s 9ms/step - loss: 0.0990 - accuracy: 0.9682

    342/Unknown - 3s 9ms/step - loss: 0.0992 - accuracy: 0.9683

    348/Unknown - 3s 9ms/step - loss: 0.0994 - accuracy: 0.9683

    354/Unknown - 3s 9ms/step - loss: 0.0998 - accuracy: 0.9683

    360/Unknown - 3s 9ms/step - loss: 0.0999 - accuracy: 0.9683

    366/Unknown - 3s 9ms/step - loss: 0.0997 - accuracy: 0.9683

    372/Unknown - 3s 9ms/step - loss: 0.0997 - accuracy: 0.9683

    378/Unknown - 3s 9ms/step - loss: 0.0999 - accuracy: 0.9683

    384/Unknown - 4s 9ms/step - loss: 0.1004 - accuracy: 0.9683

    390/Unknown - 4s 9ms/step - loss: 0.1012 - accuracy: 0.9683

    396/Unknown - 4s 9ms/step - loss: 0.1008 - accuracy: 0.9683

    402/Unknown - 4s 9ms/step - loss: 0.1006 - accuracy: 0.9682

    408/Unknown - 4s 9ms/step - loss: 0.1010 - accuracy: 0.9682

    414/Unknown - 4s 9ms/step - loss: 0.1014 - accuracy: 0.9682

    420/Unknown - 4s 9ms/step - loss: 0.1017 - accuracy: 0.9682

    426/Unknown - 4s 9ms/step - loss: 0.1021 - accuracy: 0.9682

    432/Unknown - 4s 9ms/step - loss: 0.1022 - accuracy: 0.9682

    438/Unknown - 4s 9ms/step - loss: 0.1020 - accuracy: 0.9682

    444/Unknown - 4s 9ms/step - loss: 0.1016 - accuracy: 0.9682

    450/Unknown - 4s 9ms/step - loss: 0.1014 - accuracy: 0.9682

    456/Unknown - 4s 9ms/step - loss: 0.1018 - accuracy: 0.9682

    462/Unknown - 4s 9ms/step - loss: 0.1016 - accuracy: 0.9682

    468/Unknown - 4s 9ms/step - loss: 0.1012 - accuracy: 0.9682

    469/Unknown - 4s 9ms/step - loss: 0.1010 - accuracy: 0.9682


loss: 0.044021397829055786
accuracy: 0.9682833552360535
eval_loss: 0.05413995310664177
eval_accuracy: 0.9656000137329102


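## Specify Layout for existing model code

Often you already have a model that works well for your use case. Specifying `Layout` information for every individual layer in that model would require a large number of edits.

To help you convert an existing Keras model to the DTensor API, you can use the new `tf.keras.dtensor.experimental.LayoutMap` API, which lets you specify `Layout`s from a global point of view.

First, create a `LayoutMap` instance. It is a dictionary-like object that holds every `Layout` you want to assign to your model weights.

`LayoutMap` requires a `Mesh` instance at initialization. The mesh is used to provide a default replicated `Layout` for any weight that has no layout configured. If you want all model weights to be fully replicated, you can provide an empty `LayoutMap`, and the default mesh will be used to create the replicated `Layout`s.

`LayoutMap` uses strings as keys and `Layout`s as values. Its behavior differs from a normal Python dictionary in one respect: when a value is retrieved, the string keys are treated as regular expressions.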
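The regex-keyed lookup described above can be illustrated with a small pure-Python sketch. `RegexLayoutMap` is a hypothetical stand-in written for this illustration, not the Keras class, and its values are plain strings rather than real `Layout` objects. The lookup rule it uses (exact match first, then the first stored pattern that matches via `re.match`, in insertion order) is an assumption; check the `LayoutMap` API reference for the actual behavior.

```python
import re


class RegexLayoutMap:
    """Hypothetical dict-like map whose keys act as regexes on lookup."""

    def __init__(self):
        # Python dicts preserve insertion order, so patterns are tried
        # in the order they were added.
        self._entries = {}

    def __setitem__(self, pattern, value):
        self._entries[pattern] = value

    def __getitem__(self, key):
        # Assumption: an exact key match takes priority.
        if key in self._entries:
            return self._entries[key]
        # Otherwise treat every stored key as a regular expression.
        for pattern, value in self._entries.items():
            if re.match(pattern, key):
                return value
        # No layout configured for this weight.
        return None


layouts = RegexLayoutMap()
layouts["feature.*kernel"] = "sharded rank-2 layout"
layouts["feature.*bias"] = "sharded rank-1 layout"

for path in ["feature.kernel", "feature_2.bias", "dropout.rate"]:
    print(path, "->", layouts[path])
```

In this sketch a single pattern such as `feature.*kernel` covers the kernels of several layers at once, which is what makes a global mapping shorter than per-layer configuration. A lookup that returns `None` stands for a weight with no configured layout, the case in which the text above says the default replicated `Layout` applies.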

### Subclassed Model

Consider the following model, defined using the Keras model subclassing syntax.

In [21]:
class SubclassedModel(tf.keras.Model):

  def __init__(self, name=None):
    super().__init__(name=name)
    self.feature = tf.keras.layers.Dense(16)
    self.feature_2 = tf.keras.layers.Dense(24)
    self.dropout = tf.keras.layers.Dropout(0.1)

  def call(self, inputs, training=None):
    x = self.feature(inputs)
    x = self.dropout(x, training=training)
    return self.feature_2(x)

This model has 4 weights: the `kernel` and `bias` of each of the two `Dense` layers. Each weight is mapped based on its object path:

- `model.feature.kernel`
- `model.feature.bias`
- `model.feature_2.kernel`
- `model.feature_2.bias`

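Note: For subclassed models, the attribute name, not the `.name` attribute of the layer, is used as the key to retrieve the layout from the mapping. This is consistent with the convention followed by `tf.Module` checkpointing. For complex models with many layers, you can [manually inspect checkpoints](https://tensorflow.google.cn/guide/checkpoint#manually_inspecting_checkpoints) to see the attribute mappings.

Now define the following `LayoutMap` and apply it to the model.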

In [22]:
layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

with layout_map.scope():
  subclassed_model = SubclassedModel()

The model weights are created on the first call, so call the model with a DTensor input and confirm that the weights have the expected layouts.

In [23]:
dtensor_input = dtensor.copy_to_mesh(tf.zeros((16, 16)), layout=unsharded_layout_2d)
# Trigger the weights creation for subclass model
subclassed_model(dtensor_input)

print(subclassed_model.feature.kernel.layout)

Layout.from_string(sharding_specs:batch,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)


With this, you can quickly map `Layout`s onto your model without updating any of your existing code.
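To read the layout printed above, it can help to work out what `sharding_specs:batch,unsharded` implies for a weight on the 8-device `batch` mesh. The following pure-Python sketch does that arithmetic. The helpers `batch_sharded_specs` and `local_shard_shape` are hypothetical names introduced for this illustration and are not part of the DTensor API. The sketch assumes that the first axis is the sharded one and that each sharded axis divides evenly by the mesh dimension size.

```python
def batch_sharded_specs(batch_dim, rank):
    """Shard axis 0 on `batch_dim` and leave every other axis unsharded."""
    return [batch_dim] + ["unsharded"] * (rank - 1)


def local_shard_shape(global_shape, specs, mesh_dims):
    """Compute the per-device shape of a tensor with the given sharding specs."""
    local = []
    for size, spec in zip(global_shape, specs):
        if spec == "unsharded":
            # Every device holds the full extent of this axis.
            local.append(size)
        else:
            num_shards = mesh_dims[spec]
            # Assumption: sharded axes must divide evenly across devices.
            if size % num_shards != 0:
                raise ValueError(f"axis of size {size} does not divide by {num_shards}")
            local.append(size // num_shards)
    return tuple(local)


# Mesh with a single 'batch' dimension of size 8, as in the printed layout.
mesh_dims = {"batch": 8}

kernel_specs = batch_sharded_specs("batch", rank=2)
bias_specs = batch_sharded_specs("batch", rank=1)

print(kernel_specs)                                          # ['batch', 'unsharded']
print(local_shard_shape((16, 16), kernel_specs, mesh_dims))  # (2, 16)
print(local_shard_shape((16,), bias_specs, mesh_dims))       # (2,)
```

Under these assumptions, the `(16, 16)` kernel of `feature` is split into eight `(2, 16)` slices, one per device, and its `(16,)` bias into eight slices of 2 elements each.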

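### Sequential and Functional Models

For Keras Sequential and Functional models, you can use `LayoutMap` as well.

Note: For Sequential and Functional models, the mappings are slightly different. The layers in the model are not attached to the model as public attributes (although you can access them as a list via `model.layers`). In this case, use the layer's string name as the key. The string name is guaranteed to be unique within a model.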

In [24]:
layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

In [25]:
with layout_map.scope():
  inputs = tf.keras.Input((16,), batch_size=16)
  x = tf.keras.layers.Dense(16, name='feature')(inputs)
  x = tf.keras.layers.Dropout(0.1)(x)
  output = tf.keras.layers.Dense(32, name='feature_2')(x)
  model = tf.keras.Model(inputs, output)

print(model.layers[1].kernel.layout)

Layout.from_string(sharding_specs:batch,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)


In [26]:
with layout_map.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, name='feature', input_shape=(16,)),
      tf.keras.layers.Dropout(0.1),
      tf.keras.layers.Dense(32, name='feature_2')
  ])

print(model.layers[2].kernel.layout)

Layout.from_string(sharding_specs:batch,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)
