##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Using DTensors with Keras

<table class="tfo-notebook-buttons" align="left">
  <td>     <a target="_blank" href="https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial">     <img src="https://www.tensorflow.org/images/tf_logo_32px.png">     View on TensorFlow.org</a> </td>
  <td>     <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/distribute/dtensor_keras_tutorial.ipynb">     <img src="https://www.tensorflow.org/images/colab_logo_32px.png">     Run in Google Colab</a> </td>
  <td><a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/distribute/dtensor_keras_tutorial.ipynb">     <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">     View source on GitHub</a></td>
  <td><a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/distribute/dtensor_keras_tutorial.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download notebook</a></td>
</table>

## Overview

In this tutorial, you will learn how to use DTensor with Keras.

By combining DTensor with Keras, you can reuse your existing Keras layers and models to build and train distributed machine learning models.

You will train a multi-layer classification model on the MNIST data. The tutorial also covers how to set layouts for subclassed models, Sequential models, and Functional models.

This tutorial assumes that you have already read the [DTensor programming guide](/guide/dtensor_overview) and are familiar with basic DTensor concepts such as `Mesh` and `Layout`.

This tutorial is based on https://www.tensorflow.org/datasets/keras_example.

## Build the MNIST model

DTensor is included in the TensorFlow 2.9.0 release.

In [2]:
!pip install --quiet --upgrade --pre tensorflow tensorflow-datasets

Next, import `tensorflow` and `tensorflow.experimental.dtensor`, and configure TensorFlow to use 8 virtual CPUs.

This example uses CPUs, but DTensor works the same way on CPU, GPU, or TPU devices.

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.experimental import dtensor


In [4]:
def configure_virtual_cpus(ncpu):
  """Splits the first physical CPU into `ncpu` logical (virtual) devices."""
  phy_devices = tf.config.list_physical_devices('CPU')
  tf.config.set_logical_device_configuration(
      phy_devices[0],
      [tf.config.LogicalDeviceConfiguration()] * ncpu)

configure_virtual_cpus(8)
tf.config.list_logical_devices('CPU')

# Device names used to build the mesh below.
devices = [f'CPU:{i}' for i in range(8)]

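## Deterministic pseudo-random number generators

One thing to note is that the DTensor API requires every running client to use the same random seed, so that weight initialization behaves deterministically. In Keras, you can achieve this by setting the global seed with `tf.keras.utils.set_random_seed()`.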

In [5]:
tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1337)
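
To illustrate why a shared seed matters, here is a minimal pure-Python sketch. It is not how Keras or DTensor initialize weights: `init_weights` is a hypothetical helper, and Python's `random` module stands in for the TensorFlow random generator. Each "client" owns its own generator, and only clients seeded identically draw identical initial values.

```python
import random

def init_weights(seed, n=4):
    """Draw n pseudo-random 'weights' from a generator seeded with `seed`."""
    rng = random.Random(seed)  # independent generator, one per client
    return [rng.uniform(-0.05, 0.05) for _ in range(n)]

client_a = init_weights(1337)
client_b = init_weights(1337)  # same seed as client_a
client_c = init_weights(42)    # different seed

print(client_a == client_b)  # True: same seed, same initial values
print(client_a == client_c)  # clients with different seeds diverge
```

If the clients disagreed on the seed, each would start from different weights and the replicas would no longer represent the same model.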

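## Create a data parallel mesh

This tutorial demonstrates data parallel training. Adapting it to model parallel or spatial parallel training can be as simple as switching to a different set of `Layout` objects. For more information on distributed training beyond data parallelism, refer to the [DTensor in-depth ML tutorial](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial).

Data parallel training is a commonly used parallel training scheme. It is also the scheme used by, for example, `tf.distribute.MirroredStrategy`.

With DTensor, a data parallel training loop uses a `Mesh` that consists of a single 'batch' dimension. Each device runs a replica of the model, and each replica receives one shard of the global batch.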


In [6]:
mesh = dtensor.create_mesh([("batch", 8)], devices=devices)

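Because each device runs a full replica of the model, the model variables are fully replicated across the mesh (unsharded). For example, a fully replicated layout for a rank-2 weight on this `Mesh` looks like this: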

In [7]:
example_weight_layout = dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)  # or
example_weight_layout = dtensor.Layout.replicated(mesh, rank=2)

The layout for a rank-2 data tensor on this `Mesh` is sharded along the first dimension (also known as `batch_sharded`):

In [8]:
example_data_layout = dtensor.Layout(['batch', dtensor.UNSHARDED], mesh)  # or
example_data_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
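
As a rough mental model of the two layouts, the following pure-Python sketch contrasts a replicated value with a batch-sharded one. It does not call the DTensor API: `shard_along_batch` is a hypothetical helper, plain lists stand in for tensors, and the global batch of 128 examples is an assumption that matches the batch size used later in this tutorial.

```python
def shard_along_batch(global_batch, num_devices):
    """Split a list of examples into equal, contiguous per-device shards."""
    if len(global_batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    shard_size = len(global_batch) // num_devices
    return [global_batch[i * shard_size:(i + 1) * shard_size]
            for i in range(num_devices)]

num_devices = 8
global_batch = list(range(128))  # stand-in for 128 examples

# Batch-sharded: every device holds a different slice of the batch.
shards = shard_along_batch(global_batch, num_devices)
print(len(shards), len(shards[0]))  # 8 16

# Replicated: every device holds the same full value.
weight = [0.1, 0.2, 0.3]
replicas = [list(weight) for _ in range(num_devices)]
print(all(r == weight for r in replicas))  # True
```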

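## Create Keras layers with layouts

In a data parallel scheme, you usually create the model weights with a fully replicated layout, so that each replica of the model can compute on its shard of the input data.

To configure layout information for a layer's weights, Keras exposes extra parameters in the layer constructor for most of the built-in layers.

The following example builds a small image classification model with a fully replicated weight layout. You can specify the layout for the `kernel` and `bias` of `tf.keras.layers.Dense` through the `kernel_layout` and `bias_layout` arguments. Most of the built-in Keras layers accept an explicit `Layout` for their weights in this way.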

In [9]:
unsharded_layout_2d = dtensor.Layout.replicated(mesh, 2)
unsharded_layout_1d = dtensor.Layout.replicated(mesh, 1)

In [10]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, 
                        activation='relu',
                        name='d1',
                        kernel_layout=unsharded_layout_2d, 
                        bias_layout=unsharded_layout_1d),
  tf.keras.layers.Dense(10,
                        name='d2',
                        kernel_layout=unsharded_layout_2d, 
                        bias_layout=unsharded_layout_1d)
])
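
The rank-2 and rank-1 layouts above follow from the weight shapes of a fully connected layer: the kernel is a matrix and the bias is a vector. The sketch below works through that arithmetic in plain Python for the two `Dense` layers. `dense_param_shapes` is a hypothetical helper, not a Keras function, and the count of values stored across devices simply assumes one full copy per device.

```python
def dense_param_shapes(in_features, units):
    """Return (kernel_shape, bias_shape) for a fully connected layer."""
    return (in_features, units), (units,)

layers = {
    "d1": dense_param_shapes(28 * 28, 128),  # input is the flattened 28x28 image
    "d2": dense_param_shapes(128, 10),
}

total = 0
for name, (kernel_shape, bias_shape) in layers.items():
    count = kernel_shape[0] * kernel_shape[1] + bias_shape[0]
    total += count
    # The kernel is rank 2 and the bias is rank 1.
    print(name, kernel_shape, bias_shape, count)

print(total)      # 101770 parameters in the model
print(total * 8)  # 814160 values stored when replicated on 8 devices
```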

You can check the layout information by inspecting the `layout` property of the weights.

In [11]:
for weight in model.weights:
  print(f'Weight name: {weight.name} with layout: {weight.layout}')
  break

Weight name: d1/kernel:0 with layout: Layout.from_string(sharding_specs:unsharded,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)


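## Load a dataset and build the input pipeline

Load the MNIST dataset and configure a pre-processing input pipeline for it. The dataset itself is not associated with any DTensor layout information. Improved DTensor Keras integration with `tf.data` is planned for future TensorFlow releases.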


In [12]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [13]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

In [14]:
batch_size = 128

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

In [15]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

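## Define the training logic for the model

Next, define the training and evaluation logic for the model.

As of TensorFlow 2.9, you have to write a custom training loop for a DTensor-enabled Keras model. The loop packs the input data with the proper layout information, a step that is not built into the standard `tf.keras.Model.fit()` or `tf.keras.Model.evaluate()` functions. More `tf.data` support is planned for upcoming releases.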

In [16]:
@tf.function
def train_step(model, x, y, optimizer, metrics):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))
    
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'loss': loss_per_sample}
  return results
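
The loss arithmetic in `train_step` can be checked by hand. The plain-Python sketch below uses made-up per-example losses and 4 hypothetical devices to show that summing each shard locally and then combining the partial sums gives the same scalar as summing the whole batch, and that dividing by the batch size gives the per-sample loss. It only illustrates the arithmetic; how DTensor carries out the reduction internally is not shown here.

```python
per_example_loss = [0.5, 1.5, 2.0, 1.0, 0.25, 0.75, 3.0, 1.0]  # made-up values
num_devices = 4
shard_size = len(per_example_loss) // num_devices

# Each device sums the losses of its own shard of the batch.
partial_sums = [sum(per_example_loss[i * shard_size:(i + 1) * shard_size])
                for i in range(num_devices)]

# Combining the partial sums yields one global scalar.
global_loss = sum(partial_sums)
loss_per_sample = global_loss / len(per_example_loss)

print(partial_sums)     # [2.0, 3.0, 1.0, 4.0]
print(global_loss)      # 10.0
print(loss_per_sample)  # 1.25
```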

In [17]:
@tf.function
def eval_step(model, x, y, metrics):
  logits = model(x, training=False)
  loss = tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(
        y, logits, from_logits=True))

  for metric in metrics.values():
    metric.update_state(y_true=y, y_pred=logits)

  loss_per_sample = loss / len(x)
  results = {'eval_loss': loss_per_sample}
  return results

In [18]:
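def pack_dtensor_inputs(images, labels, image_layout, label_layout):
  """Packs a batch of images and labels into batch-sharded DTensors."""
  num_local_devices = image_layout.mesh.num_local_devices()
  # Split the batch into one equally sized shard per local device.
  images = tf.split(images, num_local_devices)
  labels = tf.split(labels, num_local_devices)
  # Assemble the per-device shards into DTensors with the given layouts.
  images = dtensor.pack(images, image_layout)
  labels = dtensor.pack(labels, label_layout)
  return images, labels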
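
Because `tf.split` with an integer argument needs the batch dimension to divide evenly, every batch (including the last, smaller one) must be a multiple of the number of local devices. The plain-Python check below works this out for a batch size of 128 and 8 devices, using the standard MNIST split sizes of 60,000 training and 10,000 test examples. `batch_sizes` is a hypothetical helper that mimics batching without dropping the remainder.

```python
def batch_sizes(num_examples, batch_size):
    """Sizes of the batches produced when the remainder is kept."""
    full, remainder = divmod(num_examples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

num_local_devices = 8
report = {}
for name, num_examples in [("train", 60000), ("test", 10000)]:
    sizes = batch_sizes(num_examples, 128)
    divisible = all(size % num_local_devices == 0 for size in sizes)
    report[name] = (len(sizes), sizes[-1], divisible)
    print(name, len(sizes), sizes[-1], divisible)
# train 469 96 True
# test 79 16 True
```

With a different batch size or device count, the final batch might not divide evenly; dropping the remainder when batching is one way to avoid that.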

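## Metrics and optimizers

When you use the DTensor API with a Keras `Metric` or `Optimizer`, you need to provide extra mesh information so that their internal state variables and tensors can work with the variables in the model.

- For optimizers, DTensor introduces a new experimental namespace, `keras.dtensor.experimental.optimizers`, in which many existing Keras optimizers are extended to accept an additional `mesh` argument. It may be merged into the core Keras optimizers in a future release.

- For metrics, you can pass `mesh` directly to the constructor as an argument to make the `Metric` DTensor-compatible.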

In [19]:
optimizer = tf.keras.dtensor.experimental.optimizers.Adam(0.01, mesh=mesh)
metrics = {'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}
eval_metrics = {'eval_accuracy': tf.keras.metrics.SparseCategoricalAccuracy(mesh=mesh)}

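## Train the model

The following example shards the data from the input pipeline along the batch dimension and trains a model with fully replicated weights.

After 3 epochs, the model should reach about 97% accuracy.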

In [20]:
num_epochs = 3

image_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=4)
label_layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

for epoch in range(num_epochs):
  print("============================") 
  print("Epoch: ", epoch)
  for metric in metrics.values():
    metric.reset_state()
  step = 0
  results = {}
  pbar = tf.keras.utils.Progbar(target=None, stateful_metrics=[])
  for images, labels in ds_train:
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)

    results.update(train_step(model, images, labels, optimizer, metrics))
    for metric_name, metric in metrics.items():
      results[metric_name] = metric.result()

    pbar.update(step, values=results.items(), finalize=False)
    step += 1
  pbar.update(step, values=results.items(), finalize=True)

  for metric in eval_metrics.values():
    metric.reset_state()
  for images, labels in ds_test:
    images, labels = pack_dtensor_inputs(
        images, labels, image_layout, label_layout)
    results.update(eval_step(model, images, labels, eval_metrics))

  for metric_name, metric in eval_metrics.items():
    results[metric_name] = metric.result()
  
  for metric_name, metric in results.items():
    print(f"{metric_name}: {metric.numpy()}")


Epoch:  0


    469/Unknown - 7s 15ms/step - loss: 0.2907 - accuracy: 0.8308


loss: 0.12670570611953735
accuracy: 0.9109166860580444
eval_loss: 0.04795415699481964
eval_accuracy: 0.958899974822998
Epoch:  1
    469/Unknown - 4s 9ms/step - loss: 0.1285 - accuracy: 0.9595


loss: 0.08036476373672485
accuracy: 0.9599999785423279
eval_loss: 0.008920179679989815
eval_accuracy: 0.9642000198364258
Epoch:  2
    469/Unknown - 4s 9ms/step - loss: 0.1010 - accuracy: 0.9682


loss: 0.044021397829055786
accuracy: 0.9682833552360535
eval_loss: 0.05413995310664177
eval_accuracy: 0.9656000137329102


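## Specify Layout for existing model code

Often you already have models that work well for your use case. Adding `Layout` information to every individual layer in such a model would be a large amount of work and require many edits.

To make it easier to convert an existing Keras model to work with the DTensor API, you can use the new `tf.keras.dtensor.experimental.LayoutMap` API, which lets you specify `Layout`s from a global point of view.

First, create a `LayoutMap` instance, a dictionary-like object that holds every `Layout` you want to specify for your model weights.

`LayoutMap` requires a `Mesh` instance at init time. The mesh is used to provide a default replicated `Layout` for any weight that has no `Layout` configured. If you just want all model weights to be fully replicated, you can provide an empty `LayoutMap`, and the default mesh will be used to create replicated `Layout`s.

`LayoutMap` uses a string as the key and a `Layout` as the value. Its behavior differs from a normal Python dict: the string key is treated as a regular expression when a value is retrieved.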
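
The lookup behavior described above can be sketched in plain Python. The `RegexKeyedMap` class below is a hypothetical stand-in, not the Keras implementation. It assumes that an exact-key hit is tried first, that stored keys are then tested in insertion order with `re.match`, and that a miss yields `None`. The keys and string values are illustrative placeholders for real `Layout` objects.

```python
import re


class RegexKeyedMap:
    """Hypothetical stand-in for a regex-keyed mapping; not the Keras class."""

    def __init__(self):
        self._patterns = {}  # insertion-ordered: pattern -> value

    def __setitem__(self, pattern, value):
        self._patterns[pattern] = value

    def __getitem__(self, key):
        # An exact key hit wins first, as with a normal dict.
        if key in self._patterns:
            return self._patterns[key]
        # Otherwise treat each stored key as a regex and return the first match.
        for pattern, value in self._patterns.items():
            if re.match(pattern, key):
                return value
        # Assumption: a miss returns None instead of raising KeyError.
        return None


layouts = RegexKeyedMap()
layouts["dense.*kernel"] = "sharded-2d"  # placeholder for a rank-2 Layout
layouts["dense.*bias"] = "sharded-1d"    # placeholder for a rank-1 Layout

print(layouts["dense_1.kernel"])   # resolved through the pattern "dense.*kernel"
print(layouts["dense_1.bias"])     # resolved through the pattern "dense.*bias"
print(layouts["embedding.table"])  # no stored pattern matches this key
```

Because the first matching pattern wins in this sketch, the order in which overlapping patterns are inserted would matter. Whether the real `LayoutMap` resolves overlaps the same way should be checked against its documentation.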

### Subclassed Model

Consider the following model, defined using the Keras subclassing syntax.

In [21]:
class SubclassedModel(tf.keras.Model):

  def __init__(self, name=None):
    super().__init__(name=name)
    self.feature = tf.keras.layers.Dense(16)
    self.feature_2 = tf.keras.layers.Dense(24)
    self.dropout = tf.keras.layers.Dropout(0.1)

  def call(self, inputs, training=None):
    x = self.feature(inputs)
    x = self.dropout(x, training=training)
    return self.feature_2(x)

This model has four weights: the `kernel` and `bias` of each of the two `Dense` layers. Each weight is mapped based on its object path:

- `model.feature.kernel`
- `model.feature.bias`
- `model.feature_2.kernel`
- `model.feature_2.bias`

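Note: For subclassed models, the attribute name, not the layer's `.name` attribute, is used as the key when retrieving a Layout from the mapping. This follows the same convention as `tf.Module` checkpointing. For complex models with many layers, you can [manually inspect checkpoints](https://www.tensorflow.org/guide/checkpoint#manually_inspecting_checkpoints) to view the attribute mappings.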
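
The four object paths listed above can be checked against candidate keys before any model is built. This is a minimal sketch using only the standard `re` module. It assumes the paths are matched without the leading `model.` prefix and that patterns are tested from the start of the string with `re.match`. The helper `matching_paths` is a hypothetical name introduced for illustration.

```python
import re

# Object paths of the four weights, assumed to be matched without 'model.'.
weight_paths = [
    "feature.kernel",
    "feature.bias",
    "feature_2.kernel",
    "feature_2.bias",
]


def matching_paths(pattern, paths):
    """Return the paths that the pattern matches from the start of the string."""
    return [p for p in paths if re.match(pattern, p)]


# A broad pattern such as 'feature.*kernel' selects both kernels.
print(matching_paths("feature.*kernel", weight_paths))
# Escaping the dot restricts the match to the attribute named exactly 'feature'.
print(matching_paths(r"feature\.kernel", weight_paths))
# Naming the second attribute explicitly selects only its bias.
print(matching_paths(r"feature_2\.bias", weight_paths))
```

A broad pattern is convenient when several layers should share a layout. An escaped, more specific pattern is the safer choice when two attribute names share a prefix, as `feature` and `feature_2` do here.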

Now define the following `LayoutMap` and apply it to the model.

In [22]:
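# Bind the LayoutMap to the mesh; weights without a matching key stay replicated.
layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

# Keys are regexes matched against weight object paths such as 'feature.kernel'.
layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

# A model constructed inside the scope is associated with the layout map.
with layout_map.scope():
  subclassed_model = SubclassedModel()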

The model weights are created on the first call, so call the model with a DTensor input and confirm that the weights have the expected layouts.

In [23]:
dtensor_input = dtensor.copy_to_mesh(tf.zeros((16, 16)), layout=unsharded_layout_2d)
# Trigger the weights creation for subclass model
subclassed_model(dtensor_input)

print(subclassed_model.feature.kernel.layout)

Layout.from_string(sharding_specs:batch,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)


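With this, you can quickly map `Layout`s to your models without updating any of your existing code.

### Sequential and Functional Models

You can also use `LayoutMap` with Keras Functional and Sequential models.

Note: For Functional and Sequential models, the mapping is slightly different. The layers in the model do not have a public attribute attached to the model (although you can access them as a list via `model.layers`). In this case, use the layer's string name as the key. The string name is guaranteed to be unique within a model.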

In [24]:
layout_map = tf.keras.dtensor.experimental.LayoutMap(mesh=mesh)

layout_map['feature.*kernel'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=2)
layout_map['feature.*bias'] = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)

In [25]:
with layout_map.scope():
  inputs = tf.keras.Input((16,), batch_size=16)
  x = tf.keras.layers.Dense(16, name='feature')(inputs)
  x = tf.keras.layers.Dropout(0.1)(x)
  output = tf.keras.layers.Dense(32, name='feature_2')(x)
  model = tf.keras.Model(inputs, output)

print(model.layers[1].kernel.layout)

Layout.from_string(sharding_specs:batch,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)


In [26]:
with layout_map.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, name='feature', input_shape=(16,)),
      tf.keras.layers.Dropout(0.1),
      tf.keras.layers.Dense(32, name='feature_2')
  ])

print(model.layers[2].kernel.layout)

Layout.from_string(sharding_specs:batch,unsharded, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7)
