##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 分散ストラテジーを使ってモデルを保存して読み込む

<table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://www.tensorflow.org/tutorials/distribute/save_and_load"><img src="https://www.tensorflow.org/images/tf_logo_32px.png">TensorFlow.org で表示</a></td>
  <td><a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/distribute/save_and_load.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Google Colab で実行</a></td>
  <td><a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/distribute/save_and_load.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">GitHub でソースを表示</a></td>
  <td><a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/distribute/save_and_load.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png"> ノートブックをダウンロード</a></td>
</table>

## 概要

このチュートリアルでは、トレーニング中またはトレーニング後に `tf.distribute.Strategy` を使用して SavedModel 形式でモデルを保存して読み込む方法を説明します。Keras モデルの保存と読み込みには、高レベル（`tf.keras.Model.save` と `tf.keras.models.load_model`）と低レベル（`tf.saved_model.save` と `tf.saved_model.load`）の 2 種類の API があります。

SavedModel とシリアル化の全般的な内容については、[SavedModel ガイド](../../guide/saved_model.ipynb)と [Keras モデルのシリアル化ガイド](https://www.tensorflow.org/guide/keras/save_and_serialize)をお読みください。では、単純な例から始めましょう。

注意: TensorFlow モデルはコードであるため、信頼できないコードには注意する必要があります。詳細は、[TensorFlow を安全に使用する](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)をご覧ください。


依存関係をインポートします。

In [2]:
import tensorflow_datasets as tfds

import tensorflow as tf


2024-01-11 18:14:24.741176: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:14:24.741223: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:14:24.742716: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


TensorFlow Datasets と `tf.data` でデータを読み込んで準備し、`tf.distribute.MirroredStrategy` を使ってモデルを作成します。

In [3]:
# Notebook-level strategy: `get_data()` reads its replica count to size the
# global batch, and `get_model()` builds the model inside its scope.
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data(buffer_size=10000, batch_size_per_replica=64):
  """Loads MNIST from TensorFlow Datasets and builds batched input pipelines.

  Args:
    buffer_size: Size of the shuffle buffer applied to the training split.
    batch_size_per_replica: Number of examples per replica and step. The
      global batch size is this value multiplied by
      `mirrored_strategy.num_replicas_in_sync`.

  Returns:
    A `(train_dataset, eval_dataset)` tuple of batched datasets yielding
    `(image, label)` pairs. Images are cast to `float32` and divided by 255.
    Only the training dataset is cached and shuffled.
  """
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  # Scale the batch with the number of replicas so that each replica still
  # receives `batch_size_per_replica` examples per step.
  global_batch_size = (
      batch_size_per_replica * mirrored_strategy.num_replicas_in_sync)

  def scale(image, label):
    # Convert to float32 before dividing; the label passes through unchanged.
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = (
      mnist_train.map(scale).cache().shuffle(buffer_size).batch(global_batch_size))
  eval_dataset = mnist_test.map(scale).batch(global_batch_size)

  return train_dataset, eval_dataset

def get_model():
  """Builds and compiles a small convolutional classifier.

  The layers, loss, optimizer and metric are all created inside
  `mirrored_strategy.scope()`.

  Returns:
    A compiled `tf.keras.Sequential` model that takes `(28, 28, 1)` inputs and
    outputs 10 logits (the loss is configured with `from_logits=True`).
  """
  with mirrored_strategy.scope():
    layer_stack = [
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10),
    ]
    model = tf.keras.Sequential(layer_stack)

    # The final Dense layer has no activation, so the loss consumes raw logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()
    accuracy = tf.metrics.SparseCategoricalAccuracy()
    model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy])

  return model

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


`tf.keras.Model.fit` を使用してモデルをトレーニングします。 

In [4]:
# Build the compiled model and the input pipelines, then train for two epochs.
# `model`, `train_dataset` and `eval_dataset` are reused by later cells.
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

2024-01-11 18:14:31.467356: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 1/2


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


I0000 00:00:1704996877.849655   50970 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


  1/235 [..............................] - ETA: 25:12 - loss: 2.3161 - sparse_categorical_accuracy: 0.1016

  7/235 [..............................] - ETA: 2s - loss: 1.8962 - sparse_categorical_accuracy: 0.5631   

 14/235 [>.............................] - ETA: 1s - loss: 1.4808 - sparse_categorical_accuracy: 0.6643

 21/235 [=>............................] - ETA: 1s - loss: 1.2110 - sparse_categorical_accuracy: 0.7145

 28/235 [==>...........................] - ETA: 1s - loss: 1.0177 - sparse_categorical_accuracy: 0.7549

 35/235 [===>..........................] - ETA: 1s - loss: 0.9060 - sparse_categorical_accuracy: 0.7775

 42/235 [====>.........................] - ETA: 1s - loss: 0.8139 - sparse_categorical_accuracy: 0.7971

 49/235 [=====>........................] - ETA: 1s - loss: 0.7475 - sparse_categorical_accuracy: 0.8109





















































INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).




Epoch 2/2


  1/235 [..............................] - ETA: 8s - loss: 0.1046 - sparse_categorical_accuracy: 0.9766

  9/235 [>.............................] - ETA: 1s - loss: 0.1336 - sparse_categorical_accuracy: 0.9614

 17/235 [=>............................] - ETA: 1s - loss: 0.1402 - sparse_categorical_accuracy: 0.9584

 25/235 [==>...........................] - ETA: 1s - loss: 0.1255 - sparse_categorical_accuracy: 0.9634

 33/235 [===>..........................] - ETA: 1s - loss: 0.1260 - sparse_categorical_accuracy: 0.9640

 41/235 [====>.........................] - ETA: 1s - loss: 0.1268 - sparse_categorical_accuracy: 0.9636

 49/235 [=====>........................] - ETA: 1s - loss: 0.1253 - sparse_categorical_accuracy: 0.9640

















































<keras.src.callbacks.History at 0x7f7e10386cd0>

## モデルを保存して読み込む

作業に使用する単純なモデルを準備できたので、保存と読み込みに使用する API を見てみましょう。使用できる API には、以下の 2 種類があります。

- 高レベル（Keras）: `Model.save` および `tf.keras.models.load_model`（`.keras` zip アーカイブ形式）
- 低レベル: `tf.saved_model.save` および `tf.saved_model.load`（TF SavedModel 形式）


### Keras API

Keras API を使用したモデルの保存と読み込みの例を以下に示します。

In [5]:
# Save the trained model with the high-level Keras API; the path carries the
# `.keras` extension.
keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

`tf.distribute.Strategy` を使用せずにモデルを復元します。

In [6]:
# Load outside of any strategy scope and continue training directly; no
# `compile` call is made on the restored model here.
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

Epoch 1/2


  1/235 [..............................] - ETA: 3:36 - loss: 0.0558 - sparse_categorical_accuracy: 0.9922

 13/235 [>.............................] - ETA: 1s - loss: 0.0834 - sparse_categorical_accuracy: 0.9784  

 25/235 [==>...........................] - ETA: 0s - loss: 0.0850 - sparse_categorical_accuracy: 0.9775

 37/235 [===>..........................] - ETA: 0s - loss: 0.0844 - sparse_categorical_accuracy: 0.9773

 49/235 [=====>........................] - ETA: 0s - loss: 0.0814 - sparse_categorical_accuracy: 0.9781































Epoch 2/2


  1/235 [..............................] - ETA: 7s - loss: 0.0566 - sparse_categorical_accuracy: 0.9961

 14/235 [>.............................] - ETA: 0s - loss: 0.0643 - sparse_categorical_accuracy: 0.9821

 28/235 [==>...........................] - ETA: 0s - loss: 0.0636 - sparse_categorical_accuracy: 0.9824

 42/235 [====>.........................] - ETA: 0s - loss: 0.0608 - sparse_categorical_accuracy: 0.9837































<keras.src.callbacks.History at 0x7f7f5048b670>

モデルを復元したら、`Model.compile` をもう一度呼び出さずにそのままトレーニングを続行できます。これは、保存前にすでにコンパイル済みであるためです。このモデルは、Keras zip アーカイブ形式で保存されており、`.keras` 拡張子で識別できます。詳細については、[Keras の保存に関するガイド](https://www.tensorflow.org/guide/keras/save_and_serialize)をご覧ください。

次に、`tf.distribute.Strategy` を使用してモデルを復元し、トレーニングします。

In [7]:
# Load and train under a different strategy (a single CPU device) than the
# `mirrored_strategy` the model was created with.
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

Epoch 1/2


2024-01-11 18:14:45.633878: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2024-01-11 18:14:45.694588: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


  1/235 [..............................] - ETA: 1:49 - loss: 0.0754 - sparse_categorical_accuracy: 0.9805

  5/235 [..............................] - ETA: 3s - loss: 0.0787 - sparse_categorical_accuracy: 0.9766  

 10/235 [>.............................] - ETA: 2s - loss: 0.0798 - sparse_categorical_accuracy: 0.9770

 15/235 [>.............................] - ETA: 2s - loss: 0.0826 - sparse_categorical_accuracy: 0.9776

2024-01-11 18:14:46.266915: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2024-01-11 18:14:46.267137: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2024-01-11 18:14:46.301591: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.


 20/235 [=>............................] - ETA: 2s - loss: 0.0776 - sparse_categorical_accuracy: 0.9789

 25/235 [==>...........................] - ETA: 2s - loss: 0.0782 - sparse_categorical_accuracy: 0.9781

 30/235 [==>...........................] - ETA: 2s - loss: 0.0782 - sparse_categorical_accuracy: 0.9779

 35/235 [===>..........................] - ETA: 2s - loss: 0.0765 - sparse_categorical_accuracy: 0.9785

 40/235 [====>.........................] - ETA: 2s - loss: 0.0758 - sparse_categorical_accuracy: 0.9792

 45/235 [====>.........................] - ETA: 2s - loss: 0.0767 - sparse_categorical_accuracy: 0.9791

 50/235 [=====>........................] - ETA: 2s - loss: 0.0758 - sparse_categorical_accuracy: 0.9793

 54/235 [=====>........................] - ETA: 2s - loss: 0.0753 - sparse_categorical_accuracy: 0.9792











































































Epoch 2/2


  1/235 [..............................] - ETA: 9s - loss: 0.1249 - sparse_categorical_accuracy: 0.9609

  5/235 [..............................] - ETA: 2s - loss: 0.0685 - sparse_categorical_accuracy: 0.9773

  9/235 [>.............................] - ETA: 2s - loss: 0.0683 - sparse_categorical_accuracy: 0.9787

 14/235 [>.............................] - ETA: 2s - loss: 0.0670 - sparse_categorical_accuracy: 0.9794

 19/235 [=>............................] - ETA: 2s - loss: 0.0632 - sparse_categorical_accuracy: 0.9817

 24/235 [==>...........................] - ETA: 2s - loss: 0.0606 - sparse_categorical_accuracy: 0.9821

 29/235 [==>...........................] - ETA: 2s - loss: 0.0617 - sparse_categorical_accuracy: 0.9824

 33/235 [===>..........................] - ETA: 2s - loss: 0.0638 - sparse_categorical_accuracy: 0.9820

 38/235 [===>..........................] - ETA: 2s - loss: 0.0617 - sparse_categorical_accuracy: 0.9823

 43/235 [====>.........................] - ETA: 2s - loss: 0.0595 - sparse_categorical_accuracy: 0.9831

 48/235 [=====>........................] - ETA: 2s - loss: 0.0578 - sparse_categorical_accuracy: 0.9835

 53/235 [=====>........................] - ETA: 2s - loss: 0.0570 - sparse_categorical_accuracy: 0.9835













































































`Model.fit` の出力からわかるように、`tf.distribute.Strategy` を使用した場合も期待どおりに読み込まれています。ここで使用するストラテジーは、保存前と同じストラテジーである必要はありません。

### `tf.saved_model` API

より低レベルの API を使用したモデルの保存方法は、Keras API を使う方法に似ています。

In [8]:
# Rebinds `model` to a newly built model (no `fit` call is made on it here)
# and exports it with the lower-level `tf.saved_model.save` API.
model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


読み込みは、`tf.saved_model.load` を使用して行えますが、これは低レベル API（したがって、より幅広いユースケースのある API）であるため、Keras モデルを返しません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。以下に例を示します。

In [9]:
# Key under which the inference function is looked up in `loaded.signatures`.
DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
# `loaded` is used here only through its `signatures` mapping.
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

読み込まれたオブジェクトには、それぞれにキーが関連付けられた複数の関数が含まれている可能性があります。`"serving_default"` キーは、保存された Keras モデルを使用した推論関数のデフォルトのキーです。この関数で推論するには、以下のようにします。 

In [10]:
# Drop the labels so the dataset yields images only, then run the loaded
# signature on the first batch.
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[ 0.04390968,  0.30341768,  0.05374109, ..., -0.35343656,
         0.03065785, -0.00975093],
       [-0.04910231,  0.16482985,  0.06436244, ..., -0.27770516,
         0.02216907,  0.13293922],
       [-0.05661844,  0.2683993 , -0.06041192, ..., -0.26340052,
         0.02152548,  0.10264045],
       ...,
       [-0.12805948,  0.11079367, -0.10359426, ..., -0.26105058,
         0.0311166 ,  0.02954188],
       [-0.11231118,  0.22162321,  0.04027553, ..., -0.34616578,
         0.02095792,  0.01622906],
       [-0.07966347,  0.08217648, -0.14690818, ..., -0.21150741,
         0.03090278, -0.12792973]], dtype=float32)>}


2024-01-11 18:14:52.675006: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


また、分散方式で読み込んで推論を実行することもできます。

In [11]:
# Rebinds `another_strategy`, `loaded` and `inference_func`: the SavedModel is
# loaded again, this time inside a MirroredStrategy scope.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  # Let the strategy distribute the image-only dataset.
  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    # Only the first distributed batch is needed for the demonstration.
    break

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


2024-01-11 18:14:52.889461: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.






{'dense_3': PerReplica:{
  0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 0.04390973,  0.30341768,  0.05374111, -0.02709487, -0.28792804,
        -0.19333729,  0.325674  , -0.3534366 ,  0.03065785, -0.00975094],
       [-0.04910226,  0.1648299 ,  0.06436247,  0.0941631 , -0.36874533,
        -0.20513675,  0.09101719, -0.27770516,  0.02216907,  0.1329391 ],
       [-0.05661836,  0.26839924, -0.0604119 ,  0.04293879, -0.37735796,
        -0.01866844,  0.23681116, -0.26340055,  0.02152544,  0.10264044],
       [ 0.10640895,  0.1561212 ,  0.06909597, -0.06987031, -0.1984469 ,
        -0.04289627,  0.18389529, -0.18640813,  0.06818488, -0.11756891],
       [ 0.01074794,  0.21058783, -0.13376951, -0.07198893, -0.34633294,
        -0.0951823 ,  0.09859037, -0.17102982,  0.00822654,  0.02078733],
       [ 0.02912218,  0.19898024, -0.2194208 , -0.09297523, -0.22816458,
        -0.14823863,  0.1251952 , -0.22406554,  0.07672149, -0.06627873],
       [-0.07373619,  0.1642075 ,  0.0

2024-01-11 18:14:53.554141: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


復元された関数の呼び出しは、保存されたモデルに対するフォワードパス（`tf.keras.Model.predict` に相当）に過ぎません。読み込まれた関数のトレーニングを続けたい場合はどうすればよいでしょうか。あるいは、読み込まれた関数をより大きなモデルに埋め込みたい場合はどうでしょうか。一般的には、読み込まれたオブジェクトを Keras レイヤーにラップすることで実現できます。幸いにも、[TF Hub](https://www.tensorflow.org/hub) には、以下に示すとおり、この目的に使用できる [`hub.KerasLayer`](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) が用意されています。

In [12]:
import tensorflow_hub as hub

def build_model(loaded):
  """Wraps a loaded object in a trainable Keras functional model.

  Args:
    loaded: Object handed to `hub.KerasLayer`; in this notebook it is the
      result of `tf.saved_model.load`.

  Returns:
    A `tf.keras.Model` that maps a `(28, 28, 1)` input named `input_x` to the
    output of the wrapped layer.
  """
  inputs = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap the loaded object so it can be used (and trained) as a Keras layer.
  wrapped_layer = hub.KerasLayer(loaded, trainable=True)
  outputs = wrapped_layer(inputs)
  return tf.keras.Model(inputs, outputs)

# Load the SavedModel, wrap it with `build_model`, then compile and train the
# resulting Keras model, all inside a new MirroredStrategy scope.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  # `build_model` returns an uncompiled model, so compile it before `fit`.
  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


2024-01-11 18:14:54.395157: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 1/2


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


  1/235 [..............................] - ETA: 11:14 - loss: 2.3140 - sparse_categorical_accuracy: 0.0586

  8/235 [>.............................] - ETA: 1s - loss: 1.8867 - sparse_categorical_accuracy: 0.4834   

 16/235 [=>............................] - ETA: 1s - loss: 1.4627 - sparse_categorical_accuracy: 0.6167

 24/235 [==>...........................] - ETA: 1s - loss: 1.1927 - sparse_categorical_accuracy: 0.6810

 32/235 [===>..........................] - ETA: 1s - loss: 1.0109 - sparse_categorical_accuracy: 0.7264

 40/235 [====>.........................] - ETA: 1s - loss: 0.8893 - sparse_categorical_accuracy: 0.7588

 48/235 [=====>........................] - ETA: 1s - loss: 0.8022 - sparse_categorical_accuracy: 0.7811

















































Epoch 2/2


  1/235 [..............................] - ETA: 8s - loss: 0.1879 - sparse_categorical_accuracy: 0.9531

  9/235 [>.............................] - ETA: 1s - loss: 0.1717 - sparse_categorical_accuracy: 0.9549

 17/235 [=>............................] - ETA: 1s - loss: 0.1583 - sparse_categorical_accuracy: 0.9561

 25/235 [==>...........................] - ETA: 1s - loss: 0.1561 - sparse_categorical_accuracy: 0.9564

 33/235 [===>..........................] - ETA: 1s - loss: 0.1521 - sparse_categorical_accuracy: 0.9576

 41/235 [====>.........................] - ETA: 1s - loss: 0.1447 - sparse_categorical_accuracy: 0.9591

 49/235 [=====>........................] - ETA: 1s - loss: 0.1386 - sparse_categorical_accuracy: 0.9609

















































上記の例では、`hub.KerasLayer` は `tf.saved_model.load()` から読み込まれた結果を、別のモデルの構築に使用できる Keras レイヤーにラップしています。転移学習を行う際に非常に便利な手法です。 

### どの API を使用すべきですか？

保存においては、Keras モデルを使用している場合は、低レベル API が実現できる追加の制御が必要でない限り、Keras の `Model.save` API を使用します。保存しているものが Keras モデルでない場合は、低レベル API の `tf.saved_model.save` しか使用できません。

読み込みにおいては、使用する API はモデルの読み込みから得ようとしているものによって異なります。Keras モデルを取得できない場合（または取得したくない場合）は `tf.saved_model.load` を使用し、それ以外の場合は `tf.keras.models.load_model` を使用します。Keras モデルを保存した場合にのみ、Keras モデルを読み込めることに注意してください。

API を混在させることも可能です。`model.save` で Keras モデルを保存し、低レベルの `tf.saved_model.load` API を使用して、非 Keras モデルとして読み込むことができます。

In [13]:
# Mixing APIs: save with Keras, load with the lower-level API.
model = get_model()

# Saving the model using Keras `Model.save`
# `saved_model_path` has no `.keras` extension; the log output of this cell
# reports assets written under `/tmp/tf_save/assets`.
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


### ローカルデバイスからの読み込みまたは保存

ローカル I/O デバイスから読み込みと保存を行い、リモートデバイスでトレーニングする場合（Cloud TPU を使用する場合など）、`tf.saved_model.SaveOptions` と `tf.saved_model.LoadOptions` に `experimental_io_device` を使用して、I/O デバイスを `localhost` に設定する必要があります。以下に例を示します。

In [14]:
# Save and load while pinning the I/O device to `/job:localhost` through
# `experimental_io_device` on both the save and the load options.
model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  # Note: `loaded` is rebound here to the result of the Keras loader.
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


### 警告

Keras モデルを特定の方法で作成してから、トレーニングする前に保存するという、以下のような特別なケースがあります。

In [15]:
class SubclassedModel(tf.keras.Model):
  """Minimal `tf.keras.Model` subclass wrapping a single dense layer."""

  # Name given to the dense layer created in `__init__`.
  output_name = 'output_layer'

  def __init__(self):
    super().__init__()
    self._dense_layer = tf.keras.layers.Dense(
        units=5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    """Forward pass: applies the dense layer to `inputs` and returns the result."""
    outputs = self._dense_layer(inputs)
    return outputs

# Instantiate the model and try to save it right away, without calling,
# building, or training it first.
my_model = SubclassedModel()
try:
  my_model.save(saved_model_path)
except ValueError as save_error:
  # Print the error inline so the notebook keeps running.
  print(f'{type(save_error).__name__}: ', *save_error.args)





ValueError:  Model <__main__.SubclassedModel object at 0x7f7f09dc29a0> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.


SavedModel は `tf.function` をトレースする際に生成される `tf.types.experimental.ConcreteFunction` オブジェクトを保存します（詳細は、[グラフと tf.function の基本](../../guide/intro_to_graphs.ipynb)ガイドの*関数はいつトレースしますか？* をご覧ください）。このような `ValueError` が発生した場合、`Model.save` がトレースされた `ConcreteFunction` を見つけられなかったか作成できなかったことが原因です。

**注意:** `ConcreteFunction` が 1 つもない状態でのモデルの保存に頼るべきではありません。その場合、低レベル API は `ConcreteFunction` シグネチャのない SavedModel を生成してしまうためです（SavedModel 形式については、[こちら](../../guide/saved_model.ipynb)をご覧ください）。以下に例を示します。

In [16]:
# Export the untrained model with the low-level API, reload it, and display
# the signatures of the reloaded object.
tf.saved_model.save(obj=my_model, export_dir=saved_model_path)
x = tf.saved_model.load(export_dir=saved_model_path)
x.signatures









INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


_SignatureMap({})

通常、モデルのフォワードパス（`call` メソッド）は、モデルが Keras の `Model.fit` メソッドを通じて初めて呼び出されたときに、自動的にトレースされます。また、最初のレイヤーを `tf.keras.layers.InputLayer` または別の種類のレイヤーにして、そのレイヤーに `input_shape` キーワード引数を渡すなどの方法で入力形状を設定している場合は、Keras の [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) API と [Functional](https://www.tensorflow.org/guide/keras/functional) API によって `ConcreteFunction` が生成されることもあります。

モデルにトレース済みの `ConcreteFunction` が存在するかを確認するには、`Model.save_spec` が `None` かどうかを調べます（`None` の場合、トレース済みの `ConcreteFunction` は存在しません）。

In [17]:
# Prints True when `save_spec()` returns None for `my_model`, False otherwise.
print(my_model.save_spec() is None)

True


`tf.keras.Model.fit` を使ってモデルをトレーニングし、`save_spec` が定義され、モデルの保存が機能するかを確認しましょう。

In [18]:
# Global batch size = per-replica batch size x number of replicas in sync.
BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

# Toy dataset: one (features, labels) pair of identical float ranges,
# repeated `dataset_size` times and then grouped into batches.
dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32)))
dataset = dataset.repeat(dataset_size)
dataset = dataset.batch(BATCH_SIZE)

# Train for two epochs.
my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

# Re-check `save_spec()` after training, then save the model again.
print(my_model.save_spec() is None)
my_model.save(saved_model_path)

Epoch 1/2


1/7 [===>..........................] - ETA: 5s - loss: 4.8873



Epoch 2/2


1/7 [===>..........................] - ETA: 0s - loss: 4.5952



False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets
