##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Distributed training with DTensor


<table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial"><img src="https://www.tensorflow.org/images/tf_logo_32px.png">View on TensorFlow.org</a></td>
  <td><a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ko/tutorials/distribute/dtensor_ml_tutorial.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a></td>
  <td><a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/ko/tutorials/distribute/dtensor_ml_tutorial.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a></td>
  <td><a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ko/tutorials/distribute/dtensor_ml_tutorial.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download notebook</a></td>
</table>

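## Overview

DTensor lets you distribute the training of a model across multiple devices to improve efficiency, reliability, and scalability. For more details on DTensor concepts, see the [DTensor programming guide](https://www.tensorflow.org/guide/dtensor_overview).

In this tutorial, you will train a sentiment analysis model with DTensor. The example demonstrates three distributed training schemes:

- Data parallel training, where the training samples are sharded (partitioned) across devices.
- Model parallel training, where the model variables are sharded across devices.
- Spatial parallel training, where the features of the input data are sharded across devices (also known as [spatial partitioning](https://cloud.google.com/blog/products/ai-machine-learning/train-ml-models-on-large-images-and-3d-volumes-with-spatial-partitioning-on-cloud-tpus)).

The training portion of this tutorial is inspired by the [Kaggle guide on sentiment analysis](https://www.kaggle.com/code/anasofiauzsoy/yelp-review-sentiment-analysis-tensorflow-tfds/notebook) notebook. To learn about the complete training and evaluation workflow (without DTensor), refer to that notebook.

This tutorial walks through the following steps:

- First, some data cleaning to obtain a `tf.data.Dataset` of tokenized sentences and their polarity.

- Next, building an MLP model with custom Dense and BatchNorm layers, using a `tf.Module` to track the inference variables. The model constructor takes additional `Layout` arguments to control the sharding of variables.

- For training, you first use data parallel training together with the checkpoint feature of `tf.experimental.dtensor`. You then continue with model parallel training and spatial parallel training.

- The final section briefly describes the interaction between `tf.saved_model` and `tf.experimental.dtensor` as of TensorFlow 2.9.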


## Setup

DTensor has been part of TensorFlow since the 2.9.0 release.

In [2]:
!pip install --quiet --upgrade --pre tensorflow tensorflow-datasets

Next, import `tensorflow` and `tensorflow.experimental.dtensor`. Then configure TensorFlow to use 8 virtual CPUs.

This example uses CPUs, but DTensor works the same way on CPU, GPU, or TPU devices.

In [3]:
import tempfile
import numpy as np
import tensorflow_datasets as tfds

import tensorflow as tf

from tensorflow.experimental import dtensor
print('TensorFlow version:', tf.__version__)

2022-12-15 01:50:37.265914: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 01:50:37.266014: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


TensorFlow version: 2.11.0


In [4]:
def configure_virtual_cpus(ncpu):
  phy_devices = tf.config.list_physical_devices('CPU')
  tf.config.set_logical_device_configuration(phy_devices[0], [
        tf.config.LogicalDeviceConfiguration(),
    ] * ncpu)

configure_virtual_cpus(8)
DEVICES = [f'CPU:{i}' for i in range(8)]

tf.config.list_logical_devices('CPU')

[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:CPU:1', device_type='CPU'),
 LogicalDevice(name='/device:CPU:2', device_type='CPU'),
 LogicalDevice(name='/device:CPU:3', device_type='CPU'),
 LogicalDevice(name='/device:CPU:4', device_type='CPU'),
 LogicalDevice(name='/device:CPU:5', device_type='CPU'),
 LogicalDevice(name='/device:CPU:6', device_type='CPU'),
 LogicalDevice(name='/device:CPU:7', device_type='CPU')]

## Download the dataset

Download the IMDB reviews dataset to train the sentiment analysis model.

In [5]:
train_data = tfds.load('imdb_reviews', split='train', shuffle_files=True, batch_size=64)
train_data

<PrefetchDataset element_spec={'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'text': TensorSpec(shape=(None,), dtype=tf.string, name=None)}>

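## Prepare the data

First, tokenize the text. This tutorial uses the `'tf_idf'` mode of `tf.keras.layers.TextVectorization`, which is an extension of one-hot encoding.

- For the sake of speed, limit the number of tokens to 1200.
- To keep the `tf.Module` simple, run `TextVectorization` as a preprocessing step before training.

The final result of the data cleaning section is a `Dataset` with the tokenized text as `x` and the label as `y`.

**Note**: Running `TextVectorization` as a preprocessing step is **neither common practice nor recommended**. Doing so assumes that the training data fits into client memory, which is not always the case.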
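
To make the `'tf_idf'` output mode more concrete, here is a minimal plain-Python sketch of TF-IDF encoding over a capped vocabulary. It is an illustration only: the toy corpus, the `tf_idf_vectors` helper, and the smoothed IDF formula `log(1 + N / (1 + df))` are assumptions chosen for this example, and `TextVectorization` may weight terms differently.

```python
import math
from collections import Counter

def tf_idf_vectors(docs, max_tokens):
  """Encodes whitespace-tokenized docs as fixed-length TF-IDF vectors.

  Hypothetical helper for illustration; not the Keras implementation.
  """
  tokenized = [doc.lower().split() for doc in docs]
  # Keep only the `max_tokens` most frequent tokens (a capped vocabulary).
  counts = Counter(tok for doc in tokenized for tok in doc)
  vocab = [tok for tok, _ in counts.most_common(max_tokens)]
  # Document frequency: number of docs that contain each vocabulary token.
  df = {tok: sum(tok in doc for doc in tokenized) for tok in vocab}
  n_docs = len(docs)
  idf = {tok: math.log(1 + n_docs / (1 + df[tok])) for tok in vocab}
  # Each doc becomes one vector with one entry per vocabulary token.
  vectors = []
  for doc in tokenized:
    tf = Counter(doc)
    vectors.append([tf[tok] * idf[tok] for tok in vocab])
  return vocab, vectors

vocab, vectors = tf_idf_vectors(
    ["a great great movie", "a dull movie", "a great cast"], max_tokens=3)
print(vocab)
print(vectors[0])
```

Unlike one-hot encoding, each entry carries a weight rather than a 0/1 flag: a token that is repeated in a document gets a larger value, and a token that appears in most documents gets a smaller one.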


In [6]:
text_vectorization = tf.keras.layers.TextVectorization(output_mode='tf_idf', max_tokens=1200, output_sequence_length=None)
text_vectorization.adapt(data=train_data.map(lambda x: x['text']))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [7]:
def vectorize(features):
  return text_vectorization(features['text']), features['label']

train_data_vec = train_data.map(vectorize)
train_data_vec

<MapDataset element_spec=(TensorSpec(shape=(None, 1200), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

## Build a neural network with DTensor

Now build a multi-layer perceptron (MLP) network with `DTensor`. The network uses fully connected Dense and BatchNorm layers.

`DTensor` extends TensorFlow through single-program multiple-data (SPMD) expansion of regular TensorFlow Ops, according to the `dtensor.Layout` attributes of their input `Tensor`s and variables.

The variables of `DTensor`-aware layers are `dtensor.DVariable`s, and the constructors of `DTensor`-aware layer objects take additional `Layout` inputs on top of the usual layer parameters.

Note: As of TensorFlow 2.9, Keras layers such as `tf.keras.layers.Dense` and `tf.keras.layers.BatchNormalization` accept `dtensor.Layout` arguments. For more information on using Keras with DTensor, see the [DTensor Keras integration tutorial](/tutorials/distribute/dtensor_keras_tutorial).

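### Dense layer

The following custom Dense layer defines 2 layer variables: $W_{ij}$ is the variable for the weights, and $b_j$ is the variable for the biases.

$$ y_j = \sigma(\sum_i x_i W_{ij} + b_j) $$


### Layout deduction

The layouts of the `Dense` layer's variables, in particular sharding the bias the same way as the last axis of the weight, follow from these observations:

- The preferred DTensor sharding for the operands of the matrix dot product $t_j = \sum_i x_i W_{ij}$ is to shard $\mathbf{W}$ and $\mathbf{x}$ the same way along the $i$ axis.

- The preferred DTensor sharding for the operands of the matrix sum $t_j + b_j$ is to shard $\mathbf{t}$ and $\mathbf{b}$ the same way along the $j$ axis.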
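
These two observations can be checked by hand with a small plain-Python sketch. It uses made-up numbers and hypothetical helper names (`matvec`, `split`) and is not DTensor code. Sharding $\mathbf{x}$ and $\mathbf{W}$ the same way along $i$ lets each device compute a partial sum that only needs to be added up, and sharding $\mathbf{t}$ and $\mathbf{b}$ the same way along $j$ makes the bias addition purely local.

```python
def matvec(x, w):
  """Computes t_j = sum_i x_i * w[i][j] for a vector x and a matrix w."""
  return [sum(x[i] * w[i][j] for i in range(len(x)))
          for j in range(len(w[0]))]

def split(seq, num_shards):
  """Splits a list into `num_shards` equal contiguous pieces."""
  size = len(seq) // num_shards
  return [seq[k * size:(k + 1) * size] for k in range(num_shards)]

x = [1.0, 2.0, 3.0, 4.0]                                 # axis i, size 4
w = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0]]    # axes (i, j)
b = [0.5, -0.5]                                          # axis j, size 2

# Contraction axis i sharded over 2 devices: each device multiplies its
# shard of x with the matching rows of w, then the partials are summed.
partials = [matvec(xs, ws) for xs, ws in zip(split(x, 2), split(w, 2))]
t = [sum(p[j] for p in partials) for j in range(2)]
assert t == matvec(x, w)

# Output axis j sharded over 2 devices: each device adds its own shard of
# b to its own shard of t, with no communication between devices.
y_shards = [[tj + bj for tj, bj in zip(ts, bs)]
            for ts, bs in zip(split(t, 2), split(b, 2))]
y = [v for shard in y_shards for v in shard]
print(y)
```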


In [8]:
class Dense(tf.Module):

  def __init__(self, input_size, output_size,
               init_seed, weight_layout, activation=None):
    super().__init__()

    random_normal_initializer = tf.function(tf.random.stateless_normal)

    self.weight = dtensor.DVariable(
        dtensor.call_with_layout(
            random_normal_initializer, weight_layout,
            shape=[input_size, output_size],
            seed=init_seed
            ))
    if activation is None:
      activation = lambda x:x
    self.activation = activation
    
    # bias is sharded the same way as the last axis of weight.
    bias_layout = weight_layout.delete([0])

    self.bias = dtensor.DVariable(
        dtensor.call_with_layout(tf.zeros, bias_layout, [output_size]))

  def __call__(self, x):
    y = tf.matmul(x, self.weight) + self.bias
    y = self.activation(y)

    return y

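### BatchNorm

A batch normalization layer helps avoid mode collapse during training. In this case, adding a batch normalization layer keeps training from producing a model that only outputs zeros.

The constructor of the custom `BatchNorm` layer below does not take a `Layout` argument, because `BatchNorm` has no layer variables. It still works with DTensor because `x`, the only input to the layer, is already a DTensor that represents the global batch.

Note: With DTensor, the input Tensor `x` always represents the global batch, so `tf.nn.batch_normalization` is applied to the global batch. This differs from training with `tf.distribute.MirroredStrategy`, where Tensor `x` only represents the per-replica shard of the batch (the local batch).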
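
The difference between normalizing over the global batch and over a per-replica shard can be seen in a small plain-Python sketch. The numbers and the `normalize` helper are made up for illustration. The helper applies the usual batch normalization arithmetic (population mean and variance, offset 0, scale 1) to a single feature column.

```python
import math

def normalize(column, epsilon=1e-5):
  """Normalizes one feature column to zero mean and (nearly) unit variance."""
  mean = sum(column) / len(column)
  variance = sum((v - mean) ** 2 for v in column) / len(column)
  return [(v - mean) / math.sqrt(variance + epsilon) for v in column]

global_batch = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]

# Global view: statistics are computed over the whole global batch, even
# if the batch is physically sharded across devices.
global_norm = normalize(global_batch)

# Per-replica view (2 replicas): each replica only sees its local shard,
# so each shard is normalized with its own mean and variance.
local_norm = normalize(global_batch[:4]) + normalize(global_batch[4:])

print([round(v, 3) for v in global_norm])
print([round(v, 3) for v in local_norm])
```

In the per-replica view the two shards normalize to almost identical values, so the scale difference between them is lost. The global view keeps it.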

In [9]:
class BatchNorm(tf.Module):

  def __init__(self):
    super().__init__()

  def __call__(self, x, training=True):
    if not training:
      # This branch is not used in the Tutorial.
      pass
    mean, variance = tf.nn.moments(x, axes=[0])
    return tf.nn.batch_normalization(x, mean, variance, 0.0, 1.0, 1e-5)

A full-featured batch normalization layer (such as `tf.keras.layers.BatchNormalization`) needs `Layout` arguments for its variables.

In [10]:
def make_keras_bn(bn_layout):
  return tf.keras.layers.BatchNormalization(gamma_layout=bn_layout,
                                            beta_layout=bn_layout,
                                            moving_mean_layout=bn_layout,
                                            moving_variance_layout=bn_layout,
                                            fused=False)

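### Putting layers together

Next, build a multi-layer perceptron (MLP) network from the building blocks above. The diagram below shows the axis relationships between the input `x` and the weight matrices of the two `Dense` layers, without any DTensor sharding or replication applied.

<img src="https://www.tensorflow.org/images/dtensor/no_dtensor.png" class="no-filter" alt="The input and weight matrices for a non-distributed model.">


The output of the first `Dense` layer is passed as the input to the second `Dense` layer, after the `BatchNorm`. Therefore, the preferred DTensor sharding for the output of the first `Dense` layer ($\mathbf{W_1}$) and the input of the second `Dense` layer ($\mathbf{W_2}$) is to shard $\mathbf{W_1}$ and $\mathbf{W_2}$ the same way along the common axis $\hat{j}$.

$$ \begin{aligned} \mathsf{Layout}[{W_{1,ij}}; i, j] &= \left[\hat{i}, \hat{j}\right] \\ \mathsf{Layout}[{W_{2,jk}}; j, k] &= \left[\hat{j}, \hat{k} \right] \end{aligned} $$

Even though the layout deduction shows that the 2 layouts are not independent, `MLP` takes 2 `Layout` arguments, one per Dense layer, to keep the model interface simple.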
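
The dependency between the two layouts can be written down as a simple check. The sketch below is hypothetical: it uses plain lists of mesh-dimension names in place of `dtensor.Layout` objects, the string `'unsharded'` stands in for `dtensor.UNSHARDED`, and `'model'` is an invented mesh dimension name. It only tests whether the output axis of $\mathbf{W_1}$ and the input axis of $\mathbf{W_2}$ follow the preferred sharding along the common axis.

```python
UNSHARDED = 'unsharded'  # Stand-in for dtensor.UNSHARDED in this sketch.

def shares_common_axis(w1_specs, w2_specs):
  """Returns True if W1's last axis and W2's first axis are sharded alike.

  Each argument is a list of two sharding specs, one per weight axis.
  """
  return w1_specs[-1] == w2_specs[0]

# Fully replicated weights trivially agree on the common axis.
print(shares_common_axis([UNSHARDED, UNSHARDED], [UNSHARDED, UNSHARDED]))

# Both weights sharded along the hypothetical 'model' dimension on axis j.
print(shares_common_axis([UNSHARDED, 'model'], ['model', UNSHARDED]))

# Mismatch: W1's output axis is sharded but W2's input axis is not.
print(shares_common_axis([UNSHARDED, 'model'], [UNSHARDED, UNSHARDED]))
```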

In [11]:
from typing import Tuple

class MLP(tf.Module):

  def __init__(self, dense_layouts: Tuple[dtensor.Layout, dtensor.Layout]):
    super().__init__()

    self.dense1 = Dense(
        1200, 48, (1, 2), dense_layouts[0], activation=tf.nn.relu)
    self.bn = BatchNorm()
    self.dense2 = Dense(48, 2, (3, 4), dense_layouts[1])

  def __call__(self, x):
    y = x
    y = self.dense1(y)
    y = self.bn(y)
    y = self.dense2(y)
    return y


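The trade-off between strictly enforcing the layout deduction constraints and keeping the API simple is a common design consideration for APIs that use DTensor. It is also possible to capture the dependency between `Layout`s with a different API. For example, the `MLPStricter` class creates the `Layout` objects in its constructor.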

In [12]:
class MLPStricter(tf.Module):

  def __init__(self, mesh, input_mesh_dim, inner_mesh_dim1, output_mesh_dim):
    super().__init__()

    self.dense1 = Dense(
        1200, 48, (1, 2), dtensor.Layout([input_mesh_dim, inner_mesh_dim1], mesh),
        activation=tf.nn.relu)
    self.bn = BatchNorm()
    self.dense2 = Dense(48, 2, (3, 4), dtensor.Layout([inner_mesh_dim1, output_mesh_dim], mesh))


  def __call__(self, x):
    y = x
    y = self.dense1(y)
    y = self.bn(y)
    y = self.dense2(y)
    return y

To make sure the model runs, probe it with fully replicated layouts and a fully replicated batch of `x` input.

In [13]:
WORLD = dtensor.create_mesh([("world", 8)], devices=DEVICES)

model = MLP([dtensor.Layout.replicated(WORLD, rank=2),
             dtensor.Layout.replicated(WORLD, rank=2)])

sample_x, sample_y = train_data_vec.take(1).get_single_element()
sample_x = dtensor.copy_to_mesh(sample_x, dtensor.Layout.replicated(WORLD, rank=2))
print(model(sample_x))

tf.Tensor([[-5.61041546 5.04737568]
 [-7.14075 6.86515808]
 [-3.10483789 1.5816828]
 ...
 [6.87280321 -3.56776118]
 [8.27548695 -5.7091856]
 [-1.98807693 1.71495843]], layout="sharding_specs:unsharded,unsharded, mesh:|world=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7", shape=(64, 2), dtype=float32)


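## Moving data to the device

Usually, `tf.data` iterators (and other data fetching methods) yield tensor objects backed by local host device memory. This data must be transferred to the accelerator device memory that backs DTensor's component tensors.

`dtensor.copy_to_mesh` is unsuitable for this situation because, due to DTensor's global perspective, it replicates the input tensors to all devices. This tutorial therefore uses a helper function, `repack_local_tensor`, to transfer the data. The helper uses `dtensor.pack` to send the shard of the global batch that is intended for a replica, and only that shard, to the device backing the replica.

This simplified function assumes a single client. In a multi-client application, determining the correct way to split the local tensor, and the mapping between the pieces of the split and the local devices, can be laborious.

Additional DTensor APIs that simplify `tf.data` integration and support both single-client and multi-client applications are planned. Stay tuned.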
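
The idea behind such a helper can be sketched without TensorFlow. In the plain-Python illustration below, a global batch of 8 rows is split into 4 contiguous shards and each device name is paired with exactly one shard, which is what packing sends. Replicating would instead give every device all 8 rows. The `shard_rows` helper, the 4-device setup, and the data are assumptions made for this example.

```python
def shard_rows(rows, num_shards):
  """Splits a list of rows into `num_shards` equal contiguous shards."""
  if len(rows) % num_shards != 0:
    raise ValueError('number of rows must be divisible by num_shards')
  size = len(rows) // num_shards
  return [rows[k * size:(k + 1) * size] for k in range(num_shards)]

devices = [f'CPU:{i}' for i in range(4)]
# A made-up global batch: 8 rows with 2 features each.
global_batch = [[float(r), float(r) * 10] for r in range(8)]

# Sharded placement: device k holds only shard k of the global batch.
sharded = dict(zip(devices, shard_rows(global_batch, len(devices))))
# Replicated placement: every device holds the full global batch.
replicated = {d: global_batch for d in devices}

for d in devices:
  print(d, 'sharded rows:', len(sharded[d]),
        'replicated rows:', len(replicated[d]))
```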

In [14]:
def repack_local_tensor(x, layout):
  """Repacks a local Tensor-like to a DTensor with layout.

  This function assumes a single-client application.
  """
  x = tf.convert_to_tensor(x)
  sharded_dims = []

  # For every sharded dimension, use tf.split to split the tensor along that dimension.
  # The result is a nested list of split-tensors in queue[0].
  queue = [x]
  for axis, dim in enumerate(layout.sharding_specs):
    if dim == dtensor.UNSHARDED:
      continue
    num_splits = layout.shape[axis]
    queue = tf.nest.map_structure(lambda x: tf.split(x, num_splits, axis=axis), queue)
    sharded_dims.append(dim)

  # Now we can build the list of component tensors by looking up the location in
  # the nested list of split-tensors created in queue[0].
  components = []
  for locations in layout.mesh.local_device_locations():
    t = queue[0]
    for dim in sharded_dims:
      split_index = locations[dim]  # Only valid on single-client mesh.
      t = t[split_index]
    components.append(t)

  return dtensor.pack(components, layout)

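## Data parallel training

In this section, you train the MLP model with data parallel training. The following sections demonstrate model parallel training and spatial parallel training.

Data parallel training is a commonly used scheme for distributed machine learning:

- The model variables are replicated on each of the N devices.
- A global batch is split into N per-replica batches.
- Each per-replica batch is trained on its replica device.
- The gradients are reduced (aggregated) across replicas before the weight update is applied collectively on all replicas.

Data parallel training provides a nearly linear speedup with respect to the number of devices.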
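
The steps listed above can be made concrete with a toy example. The plain-Python sketch below uses a one-parameter linear model with a summed squared-error loss, which is an assumption for illustration and not this tutorial's MLP. It shows that adding up the per-replica gradients reproduces the gradient of the global batch, so every replica applies the same update.

```python
def grad_sum_squared_error(w, xs, ys):
  """Gradient of sum((w * x - y)^2) with respect to the scalar weight w."""
  return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
num_replicas = 2
per_replica = len(xs) // num_replicas

# Each replica computes a gradient on its own shard of the global batch.
local_grads = [
    grad_sum_squared_error(w,
                           xs[r * per_replica:(r + 1) * per_replica],
                           ys[r * per_replica:(r + 1) * per_replica])
    for r in range(num_replicas)
]

# Reducing (summing) the local gradients gives the global-batch gradient.
reduced = sum(local_grads)
assert reduced == grad_sum_squared_error(w, xs, ys)

# Every replica applies the same update to its copy of the weight.
learning_rate = 0.01
replica_weights = [w - learning_rate * reduced for _ in range(num_replicas)]
print(local_grads, reduced, replica_weights)
```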

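### Creating a data parallel mesh

A typical data parallel training loop uses a DTensor `Mesh` that consists of a single `batch` dimension, where each device becomes a replica that receives a shard of the global batch.

<img src="https://www.tensorflow.org/images/dtensor/dtensor_data_para.png" class="no-filter" alt="Data parallel mesh">

The replicated model runs on each replica, so the model variables are fully replicated (unsharded).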

In [15]:
mesh = dtensor.create_mesh([("batch", 8)], devices=DEVICES)

model = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),
             dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),])


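### Packing training data into DTensors

The training data batch must be packed into DTensors sharded along the `'batch'` (first) axis, so that DTensor distributes the training data evenly over the `'batch'` mesh dimension.

**Note**: In DTensor, the `batch size` always refers to the global batch size. Choose the batch size so that it is evenly divisible by the size of the `batch` mesh dimension.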
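
As a quick illustration of this note, the plain-Python sketch below derives the per-replica batch size from the global batch size and rejects sizes that do not divide evenly. The `per_replica_batch_size` helper is hypothetical and is not part of DTensor.

```python
def per_replica_batch_size(global_batch_size, batch_mesh_dim_size):
  """Returns the number of examples each replica receives per step."""
  if global_batch_size % batch_mesh_dim_size != 0:
    raise ValueError(
        f'global batch size {global_batch_size} is not divisible by the '
        f'batch mesh dimension size {batch_mesh_dim_size}')
  return global_batch_size // batch_mesh_dim_size

# This tutorial's global batch of 64 on an 8-device 'batch' mesh dimension.
print(per_replica_batch_size(64, 8))

# A global batch size that does not divide evenly is rejected.
try:
  per_replica_batch_size(60, 8)
except ValueError as e:
  print('rejected:', e)
```

The IMDB training split has 25,000 examples, so with a batch size of 64 the last of the 391 batches holds 40 examples, which 8 still divides evenly.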

In [16]:
def repack_batch(x, y, mesh):
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y

sample_x, sample_y = train_data_vec.take(1).get_single_element()
sample_x, sample_y = repack_batch(sample_x, sample_y, mesh)

print('x', sample_x[:, 0])
print('y', sample_y)

x tf.Tensor({"CPU:0": [85.6979828 57.1319885 139.655975 ... 260.267944 438.011902 111.089973], "CPU:1": [117.437973 66.6539841 107.915977 ... 146.003967 260.267944 47.6099892], "CPU:2": [136.481964 215.831955 285.659943 ... 355.487915 206.309952 101.567978], "CPU:3": [107.915977 57.1319885 79.3499832 ... 63.4799881 203.135956 371.35791], "CPU:4": [206.309952 73.0019836 34.9139938 ... 82.5239792 44.4359894 69.8279877], "CPU:5": [95.2199783 219.005951 434.837891 ... 98.3939819 95.2199783 345.965912], "CPU:6": [174.569962 282.485931 38.0879898 ... 234.875946 79.3499832 79.3499832], "CPU:7": [215.831955 590.363892 107.915977 ... 238.049942 244.397949 82.5239792]}, layout="sharding_specs:batch, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:

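### Training step

This example uses a stochastic gradient descent optimizer with a custom training loop (CTL). For more information on these topics, see the [custom training loop guide](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) and the [walkthrough](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough).

`train_step` is wrapped in a `tf.function` to indicate that its body is to be traced as a TensorFlow graph. The body of `train_step` consists of a forward inference pass, a backward gradient pass, and the variable update.

The body of `train_step` does not contain any special DTensor annotations. Instead, `train_step` contains only high-level TensorFlow operations that process the inputs `x` and `y` from the global view of the input batch and the model. All of the DTensor annotations (`Mesh`, `Layout`) are factored out of the training step.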
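
To make the loss and metrics of the training step concrete, the plain-Python sketch below recomputes them for a tiny batch of made-up logits. The helper name and the data are hypothetical. The arithmetic follows the same steps as this tutorial's `train_step`: a sparse softmax cross-entropy summed over the batch, divided by the batch size for reporting, and an accuracy computed as one minus the error rate.

```python
import math

def sparse_softmax_cross_entropy(logits, label):
  """Cross-entropy of one example: -log(softmax(logits)[label])."""
  max_logit = max(logits)  # Subtract the max for numerical stability.
  log_sum_exp = max_logit + math.log(
      sum(math.exp(v - max_logit) for v in logits))
  return log_sum_exp - logits[label]

logits = [[2.0, 0.0], [0.0, 1.0], [3.0, 1.0], [0.5, 0.5]]  # Made-up values.
labels = [0, 1, 1, 0]

# Summed loss over the (global) batch, then per-sample loss for reporting.
loss = sum(sparse_softmax_cross_entropy(row, y)
           for row, y in zip(logits, labels))
loss_per_sample = loss / len(labels)

# Accuracy as 1 - (number of wrong argmax predictions / batch size).
# Ties resolve to the first maximal index.
predictions = [max(range(len(row)), key=lambda j: row[j]) for row in logits]
errors = sum(p != y for p, y in zip(predictions, labels))
accuracy = 1.0 - errors / len(labels)

print(round(loss_per_sample, 4), accuracy)
```

Because the loss is summed rather than averaged, the gradients scale with the global batch size, which is worth keeping in mind when choosing the learning rate.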

In [17]:
# Refer to the CTL (custom training loop guide)
@tf.function
def train_step(model, x, y, learning_rate=tf.constant(1e-4)):
  with tf.GradientTape() as tape:
    logits = model(x)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=y))
  parameters = model.trainable_variables
  gradients = tape.gradient(loss, parameters)
  for parameter, parameter_gradient in zip(parameters, gradients):
    parameter.assign_sub(learning_rate * parameter_gradient)

  # Define some metrics
  accuracy = 1.0 - tf.reduce_sum(tf.cast(tf.argmax(logits, axis=-1, output_type=tf.int64) != y, tf.float32)) / x.shape[0]
  loss_per_sample = loss / len(x)
  return {'loss': loss_per_sample, 'accuracy': accuracy}

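### Checkpointing

You can checkpoint a DTensor model with `dtensor.DTensorCheckpoint`. The format of a DTensor checkpoint is fully compatible with a standard TensorFlow checkpoint. This tutorial was written while work to consolidate `dtensor.DTensorCheckpoint` into `tf.train.Checkpoint` was still in progress. With the TensorFlow version used for this run (2.11.0), the training output includes a deprecation notice that recommends using `tf.train.Checkpoint` instead.

When a DTensor checkpoint is restored, the `Layout`s of the variables can differ from the ones in use when the checkpoint was saved. This tutorial uses that feature to continue training in the model parallel training and spatial parallel training sections.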


In [18]:
CHECKPOINT_DIR = tempfile.mkdtemp()

def start_checkpoint_manager(mesh, model):
  ckpt = dtensor.DTensorCheckpoint(mesh, root=model)
  manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=3)

  if manager.latest_checkpoint:
    print("Restoring a checkpoint")
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
  else:
    print("new training")
  return manager


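### Training loop

For the data parallel training scheme, train for a few epochs and report the progress. The 2 epochs used here are not enough to train the model: an accuracy of 50% is no better than random guessing.

Enable checkpointing so that you can resume training later. In the following sections, you load the checkpoint and train with a different parallel scheme.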

In [19]:
num_epochs = 2
manager = start_checkpoint_manager(mesh, model)

for epoch in range(num_epochs):
  step = 0
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()), stateful_metrics=[])
  metrics = {'epoch': epoch}
  for x,y in train_data_vec:

    x, y = repack_batch(x, y, mesh)

    metrics.update(train_step(model, x, y, 1e-2))

    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)

Instructions for updating:
Please use tf.train.Checkpoint instead of DTensorCheckpoint. DTensor is integrated with tf.train.Checkpoint and it can be used out of the box to save and restore dtensors.


new training


  0/391 [..............................] - ETA: 0s - epoch: 0.0000e+00 - loss: 6.3831 - accuracy: 0.3750

  1/391 [..............................] - ETA: 6:05 - epoch: 0.0000e+00 - loss: 4.8145 - accuracy: 0.4297

  2/391 [..............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 4.1734 - accuracy: 0.4479 

  3/391 [..............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 3.9043 - accuracy: 0.4141

  4/391 [..............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 3.7511 - accuracy: 0.4313

  5/391 [..............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 3.5512 - accuracy: 0.4427

  6/391 [..............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 3.3543 - accuracy: 0.4464

  7/391 [..............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 3.0973 - accuracy: 0.4551

  8/391 [..............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 2.9359 - accuracy: 0.4566

  9/391 [..............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.8797 - accuracy: 0.4625

 11/391 [..............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.7526 - accuracy: 0.4753

 12/391 [..............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.6465 - accuracy: 0.4736

 13/391 [..............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.5404 - accuracy: 0.4766

 14/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.5242 - accuracy: 0.4771

 15/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.5309 - accuracy: 0.4746

 16/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.5012 - accuracy: 0.4807

 17/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.4961 - accuracy: 0.4792

 18/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.5176 - accuracy: 0.4836

 19/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.4740 - accuracy: 0.4859

 20/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.4322 - accuracy: 0.4859

 21/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.4050 - accuracy: 0.4851

 22/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.3588 - accuracy: 0.4864

 23/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.3012 - accuracy: 0.4883

 24/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.2531 - accuracy: 0.4900

 25/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.2060 - accuracy: 0.4940

 26/391 [>.............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 2.1812 - accuracy: 0.4931

 27/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.1839 - accuracy: 0.4933

 28/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.1809 - accuracy: 0.4925

 29/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.1578 - accuracy: 0.4948

 30/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.1466 - accuracy: 0.4955

 31/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.1469 - accuracy: 0.5000

 32/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.1074 - accuracy: 0.5028

 33/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.0644 - accuracy: 0.5055

 34/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.0328 - accuracy: 0.5071

 35/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.0063 - accuracy: 0.5082

 36/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 2.0017 - accuracy: 0.5089

 37/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 1.9800 - accuracy: 0.5090

 38/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 1.9733 - accuracy: 0.5080

 39/391 [=>............................] - ETA: 19s - epoch: 0.0000e+00 - loss: 1.9468 - accuracy: 0.5070

 40/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.9322 - accuracy: 0.5053

 42/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.9282 - accuracy: 0.5051

 43/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.9194 - accuracy: 0.5057

 44/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.9201 - accuracy: 0.5069

 45/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.9055 - accuracy: 0.5065

 46/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.8895 - accuracy: 0.5093

 47/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.8682 - accuracy: 0.5094

 49/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.8244 - accuracy: 0.5122

 50/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.8087 - accuracy: 0.5119

 51/391 [==>...........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 1.8030 - accuracy: 0.5153

 52/391 [==>...........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.7854 - accuracy: 0.5162

 53/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.7715 - accuracy: 0.5153

 54/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.7652 - accuracy: 0.5165

 55/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.7773 - accuracy: 0.5162

 56/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8098 - accuracy: 0.5148

 57/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8055 - accuracy: 0.5162

 58/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8047 - accuracy: 0.5156

 59/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8169 - accuracy: 0.5148

 60/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8111 - accuracy: 0.5149

 61/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8152 - accuracy: 0.5144

 62/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8259 - accuracy: 0.5131

 63/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8223 - accuracy: 0.5134

 64/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8162 - accuracy: 0.5147

 65/391 [===>..........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8102 - accuracy: 0.5152

 66/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 1.8276 - accuracy: 0.5131

 68/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8367 - accuracy: 0.5129

 70/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8154 - accuracy: 0.5121

 71/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8167 - accuracy: 0.5119

 72/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8096 - accuracy: 0.5122

 73/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8053 - accuracy: 0.5125

 74/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7936 - accuracy: 0.5135

 75/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7870 - accuracy: 0.5138

 76/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8003 - accuracy: 0.5120

 77/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.8016 - accuracy: 0.5132

 78/391 [====>.........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7883 - accuracy: 0.5150

 79/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7757 - accuracy: 0.5162

 80/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7630 - accuracy: 0.5185

 81/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7507 - accuracy: 0.5192

 82/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7391 - accuracy: 0.5198

 83/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 1.7282 - accuracy: 0.5212

 84/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.7218 - accuracy: 0.5206

 85/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.7271 - accuracy: 0.5205

 87/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.7148 - accuracy: 0.5206

 88/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.7135 - accuracy: 0.5207

 89/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.7026 - accuracy: 0.5208

 90/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.6971 - accuracy: 0.5206

 91/391 [=====>........................] - ETA: 15s - epoch: 0.0000e+00 - loss: 1.6962 - accuracy: 0.5200


  0/391 [..............................] - ETA: 0s - epoch: 1.0000 - loss: 0.8295 - accuracy: 0.5625

  1/391 [..............................] - ETA: 2:51 - epoch: 1.0000 - loss: 0.9183 - accuracy: 0.5312

  2/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 0.9802 - accuracy: 0.5208 

  3/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1663 - accuracy: 0.5195

  4/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 1.3465 - accuracy: 0.5156

  5/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 1.4158 - accuracy: 0.5182

  6/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 1.3362 - accuracy: 0.5268

  7/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 1.2632 - accuracy: 0.5352

  8/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 1.2339 - accuracy: 0.5382

  9/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2263 - accuracy: 0.5375

 10/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2733 - accuracy: 0.5312

 11/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2434 - accuracy: 0.5404

 12/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2063 - accuracy: 0.5493

 13/391 [..............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1795 - accuracy: 0.5502

 14/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2024 - accuracy: 0.5417

 15/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2203 - accuracy: 0.5361

 16/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2314 - accuracy: 0.5395

 17/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2502 - accuracy: 0.5347

 18/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2873 - accuracy: 0.5354

 19/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2668 - accuracy: 0.5367

 20/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2569 - accuracy: 0.5379

 21/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2720 - accuracy: 0.5412

 22/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2526 - accuracy: 0.5408

 23/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2319 - accuracy: 0.5495

 24/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.2090 - accuracy: 0.5550

 25/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1914 - accuracy: 0.5565

 26/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1983 - accuracy: 0.5561

 27/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1970 - accuracy: 0.5564

 28/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1790 - accuracy: 0.5587

 29/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1730 - accuracy: 0.5578

 30/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1846 - accuracy: 0.5580

 31/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 1.1803 - accuracy: 0.5571

 32/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1999 - accuracy: 0.5568

 33/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1894 - accuracy: 0.5551

 34/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1883 - accuracy: 0.5571

 35/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1731 - accuracy: 0.5629

 36/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1598 - accuracy: 0.5638

 37/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1495 - accuracy: 0.5662

 38/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1524 - accuracy: 0.5665

 39/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 1.1682 - accuracy: 0.5660

 40/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1802 - accuracy: 0.5655

 41/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1757 - accuracy: 0.5655

 42/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1696 - accuracy: 0.5658

 43/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1845 - accuracy: 0.5632

 44/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.2015 - accuracy: 0.5625

 45/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.2076 - accuracy: 0.5615

 46/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1973 - accuracy: 0.5622

 47/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.2013 - accuracy: 0.5618

 48/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1924 - accuracy: 0.5625

 49/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1811 - accuracy: 0.5650

 50/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1694 - accuracy: 0.5686

 51/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1579 - accuracy: 0.5721

 52/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 1.1484 - accuracy: 0.5743

 53/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1389 - accuracy: 0.5752

 54/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1328 - accuracy: 0.5764

 55/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1421 - accuracy: 0.5753

 56/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1627 - accuracy: 0.5735

 57/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1655 - accuracy: 0.5735

 58/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1643 - accuracy: 0.5728

 59/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1686 - accuracy: 0.5719

 60/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1663 - accuracy: 0.5720

 61/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1661 - accuracy: 0.5721

 62/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1723 - accuracy: 0.5709

 63/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1723 - accuracy: 0.5708

 64/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1662 - accuracy: 0.5712

 65/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1615 - accuracy: 0.5717

 66/391 [====>.........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1680 - accuracy: 0.5697

 67/391 [====>.........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1780 - accuracy: 0.5692

 68/391 [====>.........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1745 - accuracy: 0.5691

 69/391 [====>.........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1701 - accuracy: 0.5690

 70/391 [====>.........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1677 - accuracy: 0.5682

 71/391 [====>.........................] - ETA: 18s - epoch: 1.0000 - loss: 1.1665 - accuracy: 0.5686

 72/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1620 - accuracy: 0.5685

 73/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1599 - accuracy: 0.5686

 74/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1536 - accuracy: 0.5700

 75/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1499 - accuracy: 0.5705

 76/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1491 - accuracy: 0.5698

 77/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1523 - accuracy: 0.5707

 78/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1471 - accuracy: 0.5718

 79/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1422 - accuracy: 0.5725

 80/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1358 - accuracy: 0.5741

 81/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1323 - accuracy: 0.5749

 82/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1271 - accuracy: 0.5759

 83/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1235 - accuracy: 0.5757

 84/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1182 - accuracy: 0.5767

 85/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1139 - accuracy: 0.5774

 86/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1138 - accuracy: 0.5772

 87/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1169 - accuracy: 0.5779

 88/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1185 - accuracy: 0.5774

 89/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 1.1213 - accuracy: 0.5773

 90/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 1.1204 - accuracy: 0.5768

 91/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 1.1225 - accuracy: 0.5766



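## Model Parallel Training

Switching to a 2-dimensional `Mesh` and sharding the model variables along the second mesh dimension makes the training model parallel.

In model parallel training, each model replica spans multiple devices (2 in this case):

- There are 4 model replicas, and each training data batch is distributed across the 4 replicas.
- The 2 devices within a single model replica receive replicated training data.

<img src="https://www.tensorflow.org/images/dtensor/dtensor_model_para.png" class="no-filter" alt="Model parallel mesh">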
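The grouping of devices into replicas can be sketched in plain Python. This is only an illustration: it assumes 8 devices numbered 0 to 7 and a row-major arrangement into a 4 x 2 grid, which is an assumption rather than something this section states.

```python
# Hypothetical illustration: arrange 8 device ids into a 4 x 2 grid,
# assuming row-major order (an assumption, not taken from the tutorial).
BATCH, MODEL = 4, 2
device_ids = list(range(BATCH * MODEL))

# grid[b][m] is the device at batch index b and model index m.
grid = [device_ids[b * MODEL:(b + 1) * MODEL] for b in range(BATCH)]

# Each row is one model replica; its two devices see the same data shard.
for b, replica in enumerate(grid):
    print(f"replica {b}: devices {replica}")
```

Under these assumptions, each row of `grid` is one model replica, and the two devices in a row would hold different halves of the sharded model variables.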


In [20]:
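# 2-D mesh: a 4-way "batch" dimension and a 2-way "model" dimension.
mesh = dtensor.create_mesh([("batch", 4), ("model", 2)], devices=DEVICES)
# The first layout shards a rank-2 weight along its second axis, the second along its first axis.
model = MLP([dtensor.Layout([dtensor.UNSHARDED, "model"], mesh),
             dtensor.Layout(["model", dtensor.UNSHARDED], mesh)])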
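To see why this pair of layouts fits together, the following NumPy sketch splits one weight by columns and the next by rows, then checks that summing the per-device partial outputs reproduces the unsharded result. The sizes are hypothetical, any layers between the two weights are ignored, and the explicit `sum` only stands in for whatever communication DTensor actually inserts.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical sizes chosen for illustration only.
x = rng.normal(size=(16, 12))   # per-replica batch, replicated on both devices
w1 = rng.normal(size=(12, 8))   # layout [UNSHARDED, "model"]: split by columns
w2 = rng.normal(size=(8, 2))    # layout ["model", UNSHARDED]: split by rows


def relu(a):
    return np.maximum(a, 0.0)


# Reference: unsharded forward pass.
full = relu(x @ w1) @ w2

# Each "device" holds half of w1's columns and the matching half of w2's rows.
w1_shards = np.split(w1, 2, axis=1)
w2_shards = np.split(w2, 2, axis=0)
partials = [relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# Summing the partial outputs plays the role of a cross-device reduction.
sharded = sum(partials)
print(np.allclose(full, sharded))
```

The element-wise activation can be applied to each column shard independently, which is why no exchange is needed between the two matrix multiplications in this sketch.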

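The training data is still sharded along the batch dimension, so you can reuse the same `repack_batch` function as in the data parallel training case. DTensor automatically replicates the per-replica batch to all devices inside the replica along the `"model"` mesh dimension.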

In [21]:
def repack_batch(x, y, mesh):
  # Shard features and labels along the "batch" mesh dimension only.
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y
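As a rough guide to what these layouts imply per device, the helper below computes the shard shape for a given global shape and layout. `local_shape`, the `UNSHARDED` stand-in, and the batch and feature sizes are all hypothetical names and values introduced for illustration; the sketch also assumes every sharded axis divides evenly by the mesh dimension size.

```python
UNSHARDED = None  # stand-in for dtensor.UNSHARDED in this sketch


def local_shape(global_shape, sharding_specs, mesh_dims):
    """Return the per-device shard shape for a global shape and layout."""
    shape = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec is UNSHARDED:
            # Axis is not split: every device keeps the full extent.
            shape.append(size)
        else:
            parts = mesh_dims[spec]
            if size % parts != 0:
                raise ValueError(f"{size} is not divisible by {parts}")
            shape.append(size // parts)
    return tuple(shape)


mesh_dims = {"batch": 4, "model": 2}

# Hypothetical global shapes: 64 examples with 1200 features each.
x_local = local_shape((64, 1200), ["batch", UNSHARDED], mesh_dims)
y_local = local_shape((64,), ["batch"], mesh_dims)
print(x_local, y_local)
```

Because neither layout mentions `"model"`, the shard shape does not depend on the size of that mesh dimension: the two devices of a replica hold identical copies of their slice.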

Next, run the training loop. The training loop reuses the same checkpoint manager as the data parallel training example, and the code looks identical.

You can continue training the model that was trained with data parallelism under model parallel training.
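One way to see why this is plausible: a replicated weight and a sharded weight are two placements of the same global value. The NumPy sketch below assumes a checkpoint records that global value, which is an assumption about the checkpoint format and not something this section shows; the array and its shape are hypothetical.

```python
import numpy as np

# Hypothetical global value of one weight, as a checkpoint might store it.
global_w = np.arange(24, dtype=np.float32).reshape(4, 6)

# Data parallel: every device holds the full, replicated weight.
replicated = [global_w.copy() for _ in range(8)]

# Model parallel: the 2 devices along "model" each hold half of the columns.
model_shards = np.split(global_w, 2, axis=1)

# Both placements reassemble to the same global value.
same = (np.array_equal(replicated[0], global_w)
        and np.array_equal(np.concatenate(model_shards, axis=1), global_w))
print(same)
```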

In [22]:
num_epochs = 2
# Reuse the checkpoint manager helper from the data parallel example.
manager = start_checkpoint_manager(mesh, model)

for epoch in range(num_epochs):
  step = 0
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()))
  metrics = {'epoch': epoch}
  for x, y in train_data_vec:
    # Pack the host batch into DTensors sharded along the "batch" mesh dimension.
    x, y = repack_batch(x, y, mesh)
    metrics.update(train_step(model, x, y, 1e-2))
    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  # Save a checkpoint at the end of each epoch.
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)

Restoring a checkpoint


  0/391 [..............................] - ETA: 0s - epoch: 0.0000e+00 - loss: 0.9760 - accuracy: 0.4844

  1/391 [..............................] - ETA: 4:32 - epoch: 0.0000e+00 - loss: 1.1283 - accuracy: 0.5078

  2/391 [..............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.9604 - accuracy: 0.5729 

  3/391 [..............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 1.0287 - accuracy: 0.5742

  4/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0962 - accuracy: 0.5656

  5/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.1744 - accuracy: 0.5599

  6/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.1026 - accuracy: 0.5692

  7/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0744 - accuracy: 0.5742

  8/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0553 - accuracy: 0.5816

  9/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0362 - accuracy: 0.5813

 10/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0463 - accuracy: 0.5739

 11/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0185 - accuracy: 0.5820

 12/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9964 - accuracy: 0.5913

 13/391 [..............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9788 - accuracy: 0.5960

 14/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9498 - accuracy: 0.6062

 15/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9406 - accuracy: 0.6055

 16/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9377 - accuracy: 0.6085

 17/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9690 - accuracy: 0.6016

 18/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 1.0026 - accuracy: 0.6012

 19/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9880 - accuracy: 0.6023

 20/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.9740 - accuracy: 0.6049

 21/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9729 - accuracy: 0.6072

 22/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9613 - accuracy: 0.6094

 23/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9482 - accuracy: 0.6172

 24/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9372 - accuracy: 0.6212

 25/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9279 - accuracy: 0.6220

 26/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9336 - accuracy: 0.6209

 27/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9282 - accuracy: 0.6228

 28/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9163 - accuracy: 0.6261

 29/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.9079 - accuracy: 0.6281

 30/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.8992 - accuracy: 0.6305

 31/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.8956 - accuracy: 0.6304

 32/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9076 - accuracy: 0.6307

 33/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9062 - accuracy: 0.6282

 34/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9083 - accuracy: 0.6286

 35/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.8959 - accuracy: 0.6328

 36/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.8910 - accuracy: 0.6322

 37/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.8868 - accuracy: 0.6340

 38/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.8904 - accuracy: 0.6338

 39/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9041 - accuracy: 0.6328

 40/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9170 - accuracy: 0.6315

 41/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9152 - accuracy: 0.6321

 42/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9146 - accuracy: 0.6315

 43/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9256 - accuracy: 0.6293

 44/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9406 - accuracy: 0.6285

 45/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.9456 - accuracy: 0.6277

 46/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9415 - accuracy: 0.6280

 47/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9475 - accuracy: 0.6279

 48/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9398 - accuracy: 0.6285

 49/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9313 - accuracy: 0.6313

 50/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9232 - accuracy: 0.6330

 51/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9176 - accuracy: 0.6328

 52/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9177 - accuracy: 0.6338

 53/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9109 - accuracy: 0.6354

 54/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9060 - accuracy: 0.6355

 55/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9084 - accuracy: 0.6362

 56/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9237 - accuracy: 0.6338

 57/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9285 - accuracy: 0.6334

 58/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9294 - accuracy: 0.6337

 59/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9338 - accuracy: 0.6328

 60/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9314 - accuracy: 0.6347

 61/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.9315 - accuracy: 0.6343

 62/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9310 - accuracy: 0.6332

 63/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9337 - accuracy: 0.6323

 64/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9291 - accuracy: 0.6344

 65/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9241 - accuracy: 0.6352

 66/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9258 - accuracy: 0.6339

 67/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9336 - accuracy: 0.6328

 68/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9338 - accuracy: 0.6322

 69/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9327 - accuracy: 0.6326

 70/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9349 - accuracy: 0.6316

 71/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9354 - accuracy: 0.6319

 72/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9307 - accuracy: 0.6323

 73/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9280 - accuracy: 0.6330

 74/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9229 - accuracy: 0.6344

 75/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9215 - accuracy: 0.6343

 76/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.9180 - accuracy: 0.6347

 77/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.9166 - accuracy: 0.6354

 78/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.9140 - accuracy: 0.6355

 79/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.9108 - accuracy: 0.6363

 80/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.9074 - accuracy: 0.6375

 81/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.9058 - accuracy: 0.6381

 82/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.9019 - accuracy: 0.6386

 83/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8998 - accuracy: 0.6391

 84/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8952 - accuracy: 0.6404

 85/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8912 - accuracy: 0.6417

 86/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8873 - accuracy: 0.6431

 87/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8863 - accuracy: 0.6435

 88/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8899 - accuracy: 0.6420

 89/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8939 - accuracy: 0.6413

 90/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8936 - accuracy: 0.6411

 91/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.8957 - accuracy: 0.6410

  0/391 [..............................] - ETA: 0s - epoch: 1.0000 - loss: 0.8002 - accuracy: 0.5625

  1/391 [..............................] - ETA: 3:03 - epoch: 1.0000 - loss: 0.8801 - accuracy: 0.5859

  2/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.7753 - accuracy: 0.6458 

  3/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.8636 - accuracy: 0.6406

  4/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.9272 - accuracy: 0.6281

  5/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.9971 - accuracy: 0.6146

  6/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.9293 - accuracy: 0.6295

  7/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.9288 - accuracy: 0.6270

  8/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.9206 - accuracy: 0.6337

  9/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.8984 - accuracy: 0.6438

 10/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.8901 - accuracy: 0.6449

 11/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.8628 - accuracy: 0.6536

 12/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.8747 - accuracy: 0.6514

 13/391 [..............................] - ETA: 21s - epoch: 1.0000 - loss: 0.8870 - accuracy: 0.6507

 14/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8651 - accuracy: 0.6583

 15/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8584 - accuracy: 0.6562

 16/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8534 - accuracy: 0.6599

 17/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8740 - accuracy: 0.6562

 18/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.9013 - accuracy: 0.6538

 19/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8833 - accuracy: 0.6586

 20/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8668 - accuracy: 0.6637

 21/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8531 - accuracy: 0.6669

 22/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8388 - accuracy: 0.6719

 23/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8270 - accuracy: 0.6771

 24/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8181 - accuracy: 0.6787

 25/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8146 - accuracy: 0.6791

 26/391 [>.............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8180 - accuracy: 0.6811

 27/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8109 - accuracy: 0.6814

 28/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 0.8026 - accuracy: 0.6827

 29/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 0.7975 - accuracy: 0.6839

 30/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 0.7903 - accuracy: 0.6855

 31/391 [=>............................] - ETA: 20s - epoch: 1.0000 - loss: 0.7872 - accuracy: 0.6860

 32/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7933 - accuracy: 0.6856

 33/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7934 - accuracy: 0.6838

 34/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7950 - accuracy: 0.6839

 35/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7836 - accuracy: 0.6888

 36/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7797 - accuracy: 0.6900

 37/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7776 - accuracy: 0.6912

 38/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7805 - accuracy: 0.6911

 39/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.7907 - accuracy: 0.6891

 40/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8033 - accuracy: 0.6860

 41/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8024 - accuracy: 0.6860

 42/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8022 - accuracy: 0.6871

 43/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8077 - accuracy: 0.6847

 44/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8175 - accuracy: 0.6844

 45/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8215 - accuracy: 0.6834

 46/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8171 - accuracy: 0.6845

 47/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.8216 - accuracy: 0.6839

 48/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.8159 - accuracy: 0.6846

 49/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.8078 - accuracy: 0.6866

 50/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.8017 - accuracy: 0.6875

 51/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7983 - accuracy: 0.6869

 52/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7985 - accuracy: 0.6869

 53/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7920 - accuracy: 0.6889

 54/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7887 - accuracy: 0.6892

 55/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7887 - accuracy: 0.6895

 56/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7912 - accuracy: 0.6878

 57/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7960 - accuracy: 0.6870

 58/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7975 - accuracy: 0.6864

 59/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7984 - accuracy: 0.6865

 60/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7955 - accuracy: 0.6883

 61/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7952 - accuracy: 0.6885

 62/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7925 - accuracy: 0.6885

 63/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7937 - accuracy: 0.6875

 64/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7897 - accuracy: 0.6889

 65/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.7849 - accuracy: 0.6896

 66/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7860 - accuracy: 0.6882

 67/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7933 - accuracy: 0.6861

 68/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7951 - accuracy: 0.6855

 69/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7956 - accuracy: 0.6857

 70/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7984 - accuracy: 0.6842

 71/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.8005 - accuracy: 0.6840

 72/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7966 - accuracy: 0.6854

 73/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7948 - accuracy: 0.6856

 74/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7913 - accuracy: 0.6867

 75/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7896 - accuracy: 0.6873

 76/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7871 - accuracy: 0.6883

 77/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7864 - accuracy: 0.6887

 78/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7842 - accuracy: 0.6891

 79/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7822 - accuracy: 0.6891

 80/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7788 - accuracy: 0.6906

 81/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7757 - accuracy: 0.6915

 82/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 0.7734 - accuracy: 0.6911

 83/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7720 - accuracy: 0.6912

 84/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7686 - accuracy: 0.6915

 85/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7668 - accuracy: 0.6915

 86/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7642 - accuracy: 0.6916

 87/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7635 - accuracy: 0.6912

 88/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7620 - accuracy: 0.6917

 89/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7593 - accuracy: 0.6927

 90/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7576 - accuracy: 0.6932

 91/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.7556 - accuracy: 0.6941

























































































































































































































































































































































































































































































































































































































## Spatial parallel training

When training on data of very high dimensionality (for example, very large images or videos), it can be desirable to shard along the feature dimension. This is called [Spatial Partitioning](https://cloud.google.com/blog/products/ai-machine-learning/train-ml-models-on-large-images-and-3d-volumes-with-spatial-partitioning-on-cloud-tpus), which was first introduced into TensorFlow for training models with large 3-D input samples.

<img src="https://www.tensorflow.org/images/dtensor/dtensor_spatial_para.png" class="no-filter" alt="Spatial parallel mesh">

DTensor also supports this case. The only change you need to make is to create a mesh that includes a `feature` dimension and apply the corresponding `Layout`.


In [23]:
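# Build a 2x2x2 mesh that adds a "feature" dimension for spatial partitioning.
mesh = dtensor.create_mesh([("batch", 2), ("feature", 2), ("model", 2)], devices=DEVICES)
# First layout: sharded along "feature" and "model". Second layout: sharded along "model" only.
model = MLP([dtensor.Layout(["feature", "model"], mesh),
             dtensor.Layout(["model", dtensor.UNSHARDED], mesh)])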
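
To make the effect of these layouts concrete, the sketch below computes the per-device shard shape that a layout implies on this mesh. It is a plain-Python illustration, not part of the DTensor API: `local_shard_shape` and the `UNSHARDED` stand-in are hypothetical helpers, the global shapes (a batch of 64 samples with 1200 features, a hidden size of 48, and 2 outputs) are assumed values chosen for the example, and the sketch assumes every sharded axis divides evenly by the mesh dimension size.

```python
from typing import Dict, Optional, Sequence, Tuple

# Stand-in for dtensor.UNSHARDED in this TensorFlow-free sketch.
UNSHARDED: Optional[str] = None


def local_shard_shape(
    global_shape: Sequence[int],
    sharding_specs: Sequence[Optional[str]],
    mesh_dims: Dict[str, int],
) -> Tuple[int, ...]:
    """Return the shape of the slice a single device holds under a layout."""
    if len(global_shape) != len(sharding_specs):
        raise ValueError("Need exactly one sharding spec per tensor axis.")
    local = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec is UNSHARDED:
            # An unsharded axis is kept whole on every device.
            local.append(size)
            continue
        parts = mesh_dims[spec]
        if size % parts:
            raise ValueError(f"Axis of size {size} does not divide evenly by {parts}.")
        local.append(size // parts)
    return tuple(local)


# Mesh dimension sizes taken from the cell above.
mesh_dims = {"batch": 2, "feature": 2, "model": 2}

# Hypothetical global shapes, chosen only for illustration.
inputs = local_shard_shape((64, 1200), ("batch", "feature"), mesh_dims)
dense1 = local_shard_shape((1200, 48), ("feature", "model"), mesh_dims)
dense2 = local_shard_shape((48, 2), ("model", UNSHARDED), mesh_dims)
print(inputs, dense1, dense2)  # (32, 600) (600, 24) (24, 2)
```

Under these assumptions, each device holds a quarter of an input batch and a quarter of the first weight matrix, which is what makes spatial partitioning attractive when a single sample is too large for one device.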


When packing the input tensors into DTensors, shard the input data along the `feature` dimension. You do this with a slightly different repack function, `repack_batch_for_spt`, where `spt` stands for Spatial Parallel Training.

In [24]:
def repack_batch_for_spt(x, y, mesh):
    # Shard the inputs along both the batch and feature dimensions.
    x = repack_local_tensor(x, layout=dtensor.Layout(["batch", "feature"], mesh))
    # Labels have no feature axis, so shard them along the batch dimension only.
    y = repack_local_tensor(y, layout=dtensor.Layout(["batch"], mesh))
    return x, y
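
As a rough picture of what sharding along both `batch` and `feature` means for the data, the NumPy sketch below cuts a small 2-D array into a 2x2 grid of blocks and then reassembles it. This is only an illustration with NumPy arrays, not what `repack_local_tensor` does internally; `shard_2d` is a hypothetical helper, and the 4x6 array is a made-up stand-in for a batch.

```python
import numpy as np


def shard_2d(x, n_batch, n_feature):
    """Split x into an n_batch x n_feature grid of equally sized blocks."""
    # np.split raises ValueError if an axis does not divide evenly.
    rows = np.split(x, n_batch, axis=0)
    return [np.split(row, n_feature, axis=1) for row in rows]


# Toy "batch": 4 samples with 6 features each.
x = np.arange(4 * 6).reshape(4, 6)

shards = shard_2d(x, n_batch=2, n_feature=2)
print(shards[0][1].shape)  # (2, 3)

# Reassembling the grid of blocks recovers the original array.
restored = np.block(shards)
print(np.array_equal(restored, x))  # True
```

On the 2x2x2 mesh used here, a layout of `["batch", "feature"]` does not mention the `model` dimension, so each of the four blocks would be replicated on the two devices that differ only in their `model` coordinate.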

Spatial parallel training can also continue from a checkpoint created with one of the other parallel training schemes.

In [25]:
num_epochs = 2

# Restore the latest checkpoint, if one exists, into the spatially sharded model.
manager = start_checkpoint_manager(mesh, model)
for epoch in range(num_epochs):
  step = 0
  metrics = {'epoch': epoch}
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()))

  for x, y in train_data_vec:
    # Pack each batch into DTensors sharded along "batch" and "feature".
    x, y = repack_batch_for_spt(x, y, mesh)
    metrics.update(train_step(model, x, y, 1e-2))

    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  # Save a checkpoint at the end of every epoch.
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)

Restoring a checkpoint


  0/391 [..............................] - ETA: 0s - epoch: 0.0000e+00 - loss: 0.7536 - accuracy: 0.6250

  1/391 [..............................] - ETA: 4:13 - epoch: 0.0000e+00 - loss: 0.7093 - accuracy: 0.6719

2022-12-15 01:52:24.807448: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.813784: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.818575: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.825993: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.827772: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.827811: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.827939: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:24.828144: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions


  2/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.6289 - accuracy: 0.7135 

  3/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7070 - accuracy: 0.6992

  4/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7735 - accuracy: 0.6875

  5/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8479 - accuracy: 0.6693

  6/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7911 - accuracy: 0.6830

  7/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8112 - accuracy: 0.6719

  8/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.8101 - accuracy: 0.6753

  9/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7871 - accuracy: 0.6828

 10/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7771 - accuracy: 0.6832

 11/391 [..............................] - ETA: 24s - epoch: 0.0000e+00 - loss: 0.7612 - accuracy: 0.6875

 12/391 [..............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.7818 - accuracy: 0.6815

 13/391 [..............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.7907 - accuracy: 0.6797

 14/391 [>.............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.7660 - accuracy: 0.6875

 15/391 [>.............................] - ETA: 23s - epoch: 0.0000e+00 - loss: 0.7637 - accuracy: 0.6826

 16/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7607 - accuracy: 0.6847

 17/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7737 - accuracy: 0.6814

 18/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7856 - accuracy: 0.6793

 19/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7689 - accuracy: 0.6852

 20/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7550 - accuracy: 0.6912

 21/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7410 - accuracy: 0.6953

 22/391 [>.............................] - ETA: 22s - epoch: 0.0000e+00 - loss: 0.7303 - accuracy: 0.6997

 23/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7224 - accuracy: 0.7031

 24/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7137 - accuracy: 0.7056

 25/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7114 - accuracy: 0.7061

 26/391 [>.............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7098 - accuracy: 0.7089

 27/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.7045 - accuracy: 0.7093

 28/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6995 - accuracy: 0.7101

 29/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6961 - accuracy: 0.7099

 30/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6907 - accuracy: 0.7122

 31/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6876 - accuracy: 0.7134

 32/391 [=>............................] - ETA: 21s - epoch: 0.0000e+00 - loss: 0.6905 - accuracy: 0.7121

 33/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6922 - accuracy: 0.7100

 34/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6937 - accuracy: 0.7089

 35/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6834 - accuracy: 0.7140

 36/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6806 - accuracy: 0.7166

 37/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6793 - accuracy: 0.7171

 38/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6795 - accuracy: 0.7188

 39/391 [=>............................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6831 - accuracy: 0.7176

 40/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6940 - accuracy: 0.7161

 41/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6945 - accuracy: 0.7147

 42/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6952 - accuracy: 0.7155

 43/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.6998 - accuracy: 0.7124

 44/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.7078 - accuracy: 0.7115

 45/391 [==>...........................] - ETA: 20s - epoch: 0.0000e+00 - loss: 0.7125 - accuracy: 0.7113

 46/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.7096 - accuracy: 0.7131

 47/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.7134 - accuracy: 0.7129

 48/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.7079 - accuracy: 0.7156

 49/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.7019 - accuracy: 0.7169

 50/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6975 - accuracy: 0.7175

 51/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6957 - accuracy: 0.7175

 52/391 [==>...........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6963 - accuracy: 0.7173

 53/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6908 - accuracy: 0.7196

 54/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6894 - accuracy: 0.7193

 55/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6914 - accuracy: 0.7193

 56/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.6973 - accuracy: 0.7168

 57/391 [===>..........................] - ETA: 19s - epoch: 0.0000e+00 - loss: 0.7034 - accuracy: 0.7161

 58/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7049 - accuracy: 0.7153

 59/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7051 - accuracy: 0.7167

 60/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7026 - accuracy: 0.7185

 61/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7015 - accuracy: 0.7190

 62/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.6985 - accuracy: 0.7197

 63/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.6986 - accuracy: 0.7200

 64/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.6952 - accuracy: 0.7204

 65/391 [===>..........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.6911 - accuracy: 0.7216

 66/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.6931 - accuracy: 0.7206

 67/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.6993 - accuracy: 0.7190

 68/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7007 - accuracy: 0.7181

 69/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7025 - accuracy: 0.7179

 70/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7044 - accuracy: 0.7170

 71/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7062 - accuracy: 0.7164

 72/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7034 - accuracy: 0.7173

 73/391 [====>.........................] - ETA: 18s - epoch: 0.0000e+00 - loss: 0.7017 - accuracy: 0.7173

 74/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6990 - accuracy: 0.7177

 75/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6983 - accuracy: 0.7185

 76/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6968 - accuracy: 0.7190

 77/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6962 - accuracy: 0.7194

 78/391 [====>.........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6945 - accuracy: 0.7203

 79/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6933 - accuracy: 0.7201

 80/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6907 - accuracy: 0.7209

 81/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6885 - accuracy: 0.7214

 82/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6880 - accuracy: 0.7206

 83/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6884 - accuracy: 0.7201

 84/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6862 - accuracy: 0.7210

 85/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6863 - accuracy: 0.7206

 86/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6841 - accuracy: 0.7209

 87/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6834 - accuracy: 0.7207

 88/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6822 - accuracy: 0.7205

 89/391 [=====>........................] - ETA: 17s - epoch: 0.0000e+00 - loss: 0.6799 - accuracy: 0.7215

 90/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 0.6785 - accuracy: 0.7215

 91/391 [=====>........................] - ETA: 16s - epoch: 0.0000e+00 - loss: 0.6774 - accuracy: 0.7223























































































































































































































































































































































































































































































































































































































2022-12-15 01:52:46.658600: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.661797: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.663141: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.669512: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.670651: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.671392: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.671950: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions
2022-12-15 01:52:46.671989: E tensorflow/core/grappler/costs/op_level_cost_estimator.cc:1104] Incompatible Matrix dimensions




 38/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.6234 - accuracy: 0.7496

 39/391 [=>............................] - ETA: 19s - epoch: 1.0000 - loss: 0.6263 - accuracy: 0.7488

 40/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6325 - accuracy: 0.7466

 41/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6333 - accuracy: 0.7455

 42/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6338 - accuracy: 0.7456

 43/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6367 - accuracy: 0.7436

 44/391 [==>...........................] - ETA: 19s - epoch: 1.0000 - loss: 0.6408 - accuracy: 0.7424

 45/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6449 - accuracy: 0.7415

 46/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6424 - accuracy: 0.7430

 47/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6460 - accuracy: 0.7432

 48/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6414 - accuracy: 0.7455

 49/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6364 - accuracy: 0.7466

 50/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6327 - accuracy: 0.7463

 51/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6309 - accuracy: 0.7458

 52/391 [==>...........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6298 - accuracy: 0.7456

 53/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6248 - accuracy: 0.7477

 54/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6240 - accuracy: 0.7474

 55/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6264 - accuracy: 0.7472

 56/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6326 - accuracy: 0.7442

 57/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6376 - accuracy: 0.7435

 58/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6388 - accuracy: 0.7431

 59/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6382 - accuracy: 0.7440

 60/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6354 - accuracy: 0.7462

 61/391 [===>..........................] - ETA: 18s - epoch: 1.0000 - loss: 0.6340 - accuracy: 0.7465

 62/391 [===>..........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6313 - accuracy: 0.7478

 63/391 [===>..........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6315 - accuracy: 0.7471

 64/391 [===>..........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6288 - accuracy: 0.7478

 65/391 [===>..........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6256 - accuracy: 0.7491

 66/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6286 - accuracy: 0.7479

 67/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6333 - accuracy: 0.7468

 68/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6344 - accuracy: 0.7457

 69/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6369 - accuracy: 0.7453

 70/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6378 - accuracy: 0.7447

 71/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6392 - accuracy: 0.7441

 72/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6373 - accuracy: 0.7449

 73/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6367 - accuracy: 0.7447

 74/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6340 - accuracy: 0.7452

 75/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6338 - accuracy: 0.7463

 76/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6340 - accuracy: 0.7463

 77/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6341 - accuracy: 0.7464

 78/391 [====>.........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6331 - accuracy: 0.7472

 79/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6324 - accuracy: 0.7471

 80/391 [=====>........................] - ETA: 17s - epoch: 1.0000 - loss: 0.6300 - accuracy: 0.7471

 81/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6285 - accuracy: 0.7475

 82/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6309 - accuracy: 0.7459

 83/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6335 - accuracy: 0.7448

 84/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6313 - accuracy: 0.7454

 85/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6298 - accuracy: 0.7462

 86/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6273 - accuracy: 0.7471

 87/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6264 - accuracy: 0.7468

 88/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6259 - accuracy: 0.7461

 89/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6242 - accuracy: 0.7467

 90/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6228 - accuracy: 0.7469

 91/391 [=====>........................] - ETA: 16s - epoch: 1.0000 - loss: 0.6219 - accuracy: 0.7480

## SavedModel and DTensor

The integration of DTensor and SavedModel is still under development. This section describes only the current state as of TensorFlow 2.9.0.

As of TensorFlow 2.9.0, `tf.saved_model` only accepts DTensor models whose variables are fully replicated.
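
To make the term "fully replicated" concrete, here is a toy sketch in plain Python. It does not use the DTensor API: lists stand in for device-local tensors, and the helper names `shard`, `replicate`, and `is_fully_replicated` are hypothetical, introduced only for this illustration.

```python
# Toy illustration only: plain Python lists stand in for device-local tensors.
# None of these helpers are part of TensorFlow or DTensor.

def shard(values, num_devices):
    """Split `values` into `num_devices` equal, contiguous shards."""
    if len(values) % num_devices != 0:
        raise ValueError("length must be divisible by the number of devices")
    size = len(values) // num_devices
    return [values[i * size:(i + 1) * size] for i in range(num_devices)]

def replicate(values, num_devices):
    """Give every device its own full copy of `values`."""
    return [list(values) for _ in range(num_devices)]

def is_fully_replicated(per_device, full_values):
    """Return True when every device holds the complete value."""
    return all(local == list(full_values) for local in per_device)

weights = [0.1, 0.2, 0.3, 0.4]
sharded = shard(weights, 2)          # each device holds half of the weights
replicated = replicate(weights, 2)   # each device holds all of the weights

# Concatenating the shards in device order recovers the full value.
gathered = [x for local in sharded for x in local]
```

In this picture, a sharded variable gives each device only a slice, whereas a fully replicated variable gives every device the whole value, which is the only kind of variable the paragraph above says `tf.saved_model` accepts.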

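As a workaround, you can convert a DTensor model into a fully replicated one by reloading a checkpoint. However, once a model is saved, all DTensor annotations are lost, and the saved signatures can only be used with regular Tensors, not DTensors.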

In [26]:
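# Build a single-device mesh so that every variable is fully replicated.
mesh = dtensor.create_mesh([("world", 1)], devices=DEVICES[:1])
mlp = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),
           dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)])

# Reload the trained weights into the replicated model (restores the latest
# checkpoint, if one exists).
manager = start_checkpoint_manager(mesh, mlp)

# Chain text vectorization and the MLP so the SavedModel accepts raw strings.
model_for_saving = tf.keras.Sequential([
  text_vectorization,
  mlp
])

# Serving signature: a batch of strings in, a dict holding the model output out.
@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def run(inputs):
  return {'result': model_for_saving(inputs)}

tf.saved_model.save(
    model_for_saving, "/tmp/saved_model",
    signatures=run)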

Restoring a checkpoint


INFO:tensorflow:Assets written to: /tmp/saved_model/assets


As of TensorFlow 2.9.0, you can only call a loaded signature with a regular Tensor or with a fully replicated DTensor (which is converted to a regular Tensor).

In [27]:
sample_batch = train_data.take(1).get_single_element()
sample_batch

{'label': <tf.Tensor: shape=(64,), dtype=int64, numpy=
 array([0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1])>,
 'text': <tf.Tensor: shape=(64,), dtype=string, numpy=
 array([b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Wa

In [28]:
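# Load the SavedModel and look up its default serving signature.
loaded = tf.saved_model.load("/tmp/saved_model")

run_sig = loaded.signatures["serving_default"]
# Call the signature with a regular tf.string Tensor, not a DTensor.
result = run_sig(sample_batch['text'])['result']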

In [29]:
np.mean(tf.argmax(result, axis=-1) == sample_batch['label'])

0.75
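
The accuracy expression in the cell above takes the argmax of each output row, compares it with the label, and averages the matches. The sketch below reproduces that arithmetic in plain Python on a small, made-up batch. The logits and labels are hypothetical values chosen for illustration, not output from the model.

```python
def argmax(row):
    """Index of the largest value in `row` (the first index wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(logits, labels):
    """Fraction of rows whose argmax equals the corresponding label."""
    predictions = [argmax(row) for row in logits]
    matches = [int(p == y) for p, y in zip(predictions, labels)]
    return sum(matches) / len(matches)

# Hypothetical logits for four reviews (column 0 = negative, column 1 = positive).
logits = [[2.0, -1.0], [0.3, 0.9], [1.5, 0.2], [-0.4, 0.1]]
labels = [0, 1, 1, 0]

print(accuracy(logits, labels))  # prints 0.5: two of the four predictions match
```

By the same arithmetic, the value 0.75 shown above for a batch of 64 reviews corresponds to 48 correct predictions.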

## Next steps

This tutorial demonstrated how to build and train an MLP sentiment analysis model with DTensor.

Through the `Mesh` and `Layout` primitives, DTensor can transform a TensorFlow `tf.function` into a distributed program suited to a variety of training schemes.

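In a real-world machine learning application, you should apply evaluation and cross-validation to avoid producing an overfitted model. The techniques introduced in this tutorial can also be used to bring parallelism to evaluation.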
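
As a rough illustration of the cross-validation mentioned above, the following plain-Python sketch splits example indices into k folds. The helper `k_fold_indices` is a hypothetical name and is not part of this tutorial or of TensorFlow. A real workflow would typically shuffle the data first and train one model per fold.

```python
def k_fold_indices(num_examples, k):
    """Yield (train_indices, validation_indices) for each of k contiguous folds."""
    if not 1 < k <= num_examples:
        raise ValueError("k must satisfy 1 < k <= num_examples")
    indices = list(range(num_examples))
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [num_examples // k + (1 if i < num_examples % k else 0)
                  for i in range(k)]
    start = 0
    for size in fold_sizes:
        stop = start + size
        validation = indices[start:stop]
        train = indices[:start] + indices[stop:]
        yield train, validation
        start = stop

# Ten examples split into three folds of sizes 4, 3, and 3.
folds = list(k_fold_indices(10, 3))
for train, validation in folds:
    print(len(train), validation)
```

Each example lands in exactly one validation fold, so every example is used for both training and validation across the k runs.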

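Composing a model from scratch with `tf.Module` takes a lot of work, and reusing existing building blocks such as layers and helper functions can greatly speed up model development. As of TensorFlow 2.9, all Keras layers under `tf.keras.layers` accept DTensor layouts as arguments and can be used to build DTensor models. You can even reuse a Keras model directly with DTensor without modifying the model implementation. For information on using DTensor with Keras, refer to the [DTensor Keras integration tutorial](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial).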