{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "MhoQ0WE77laV" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T18:11:24.874443Z", "iopub.status.busy": "2024-01-11T18:11:24.874166Z", "iopub.status.idle": "2024-01-11T18:11:24.878249Z", "shell.execute_reply": "2024-01-11T18:11:24.877656Z" }, "id": "_ckMIh7O7s6D" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "jYysdyb-CaWM" }, "source": [ "# tf.distribute.Strategy を使用したカスタムトレーニング" ] }, { "cell_type": "markdown", "metadata": { "id": "S5Uhzt6vVIB2" }, "source": [ "
![]() | \n",
" ![]() | \n",
" ![]() | \n",
" ![]() | \n",
"
reduce
で、集約された値を取得することができます。また、tf.distribute.Strategy.experimental_local_results
を実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-q5qp31IQD8t"
},
"source": [
"## 最新のチェックポイントを復元してテストする"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WNW2P00bkMGJ"
},
"source": [
"`tf.distribute.Strategy`でチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-11T18:12:20.890492Z",
"iopub.status.busy": "2024-01-11T18:12:20.889866Z",
"iopub.status.idle": "2024-01-11T18:12:20.955159Z",
"shell.execute_reply": "2024-01-11T18:12:20.954559Z"
},
"id": "pg3B-Cw_cn3a"
},
"outputs": [],
"source": [
"eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
" name='eval_accuracy')\n",
"\n",
"new_model = create_model()\n",
"new_optimizer = tf.keras.optimizers.Adam()\n",
"\n",
"test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-11T18:12:20.958616Z",
"iopub.status.busy": "2024-01-11T18:12:20.958143Z",
"iopub.status.idle": "2024-01-11T18:12:20.961672Z",
"shell.execute_reply": "2024-01-11T18:12:20.961113Z"
},
"id": "7qYii7KUYiSM"
},
"outputs": [],
"source": [
"@tf.function\n",
"def eval_step(images, labels):\n",
" predictions = new_model(images, training=False)\n",
" eval_accuracy(labels, predictions)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-11T18:12:20.964537Z",
"iopub.status.busy": "2024-01-11T18:12:20.964012Z",
"iopub.status.idle": "2024-01-11T18:12:21.593962Z",
"shell.execute_reply": "2024-01-11T18:12:21.593146Z"
},
"id": "LeZ6eeWRoUNq"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy after restoring the saved model without strategy: 89.8800048828125\n"
]
}
],
"source": [
"checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)\n",
"checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n",
"\n",
"for images, labels in test_dataset:\n",
" eval_step(images, labels)\n",
"\n",
"print('Accuracy after restoring the saved model without strategy: {}'.format(\n",
" eval_accuracy.result() * 100))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EbcI87EEzhzg"
},
"source": [
"## データセットのイテレーションの代替方法\n",
"\n",
"### イテレータを使用する\n",
"\n",
"データセット全体ではなく、任意のステップ数のイテレーションを行う場合は、`iter` 呼び出しを使用してイテレータを作成し、そのイテレータ上で `next` を明示的に呼び出すことができます。`tf.function` の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し `tf.function` の外側のデータセットのイテレーションを実行する小さなスニペットを示します。\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-11T18:12:21.597478Z",
"iopub.status.busy": "2024-01-11T18:12:21.596924Z",
"iopub.status.idle": "2024-01-11T18:12:27.963415Z",
"shell.execute_reply": "2024-01-11T18:12:27.962710Z"
},
"id": "7c73wGC00CzN"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.22530433535575867, Accuracy: 92.0703125\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.21155035495758057, Accuracy: 92.421875\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.23270802199840546, Accuracy: 91.09375\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.2111983597278595, Accuracy: 92.421875\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.2315395325422287, Accuracy: 91.7578125\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.22891399264335632, Accuracy: 91.640625\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.23187729716300964, Accuracy: 91.4453125\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.23954670131206512, Accuracy: 91.4453125\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.21727390587329865, Accuracy: 92.5390625\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.2208312749862671, Accuracy: 92.3046875\n"
]
}
],
"source": [
"for _ in range(EPOCHS):\n",
" total_loss = 0.0\n",
" num_batches = 0\n",
" train_iter = iter(train_dist_dataset)\n",
"\n",
" for _ in range(10):\n",
" total_loss += distributed_train_step(next(train_iter))\n",
" num_batches += 1\n",
" average_train_loss = total_loss / num_batches\n",
"\n",
" template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
" print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))\n",
" train_accuracy.reset_states()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GxVp48Oy0m6y"
},
"source": [
"### tf.function 内でイテレーションする\n",
"\n",
"`for x in ...` コンストラクトを使用して、または上記で行ったようにイテレータを作成して、`tf.function` 内で `train_dist_dataset` の入力全体をイテレートすることもできます。以下の例では、1 エポックのトレーニングを `@tf.function` デコレータでラップし、関数内で `train_dist_dataset` をイテレーションする方法を示します。"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-11T18:12:27.966722Z",
"iopub.status.busy": "2024-01-11T18:12:27.966466Z",
"iopub.status.idle": "2024-01-11T18:12:41.531191Z",
"shell.execute_reply": "2024-01-11T18:12:41.530407Z"
},
"id": "-REzmcXv00qm"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.\n",
" warnings.warn(\"To make it possible to preserve tf.data options across \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 0.22321248054504395, Accuracy: 92.12667083740234\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2, Loss: 0.21352693438529968, Accuracy: 92.40499877929688\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3, Loss: 0.2031208723783493, Accuracy: 92.74666595458984\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4, Loss: 0.19914129376411438, Accuracy: 92.96666717529297\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5, Loss: 0.18742477893829346, Accuracy: 93.40333557128906\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6, Loss: 0.18182916939258575, Accuracy: 93.59166717529297\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7, Loss: 0.17676156759262085, Accuracy: 93.77666473388672\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8, Loss: 0.16894836723804474, Accuracy: 94.06333923339844\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9, Loss: 0.1639356017112732, Accuracy: 94.2683334350586\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 0.1561119258403778, Accuracy: 94.56999969482422\n"
]
}
],
"source": [
"@tf.function\n",
"def distributed_train_epoch(dataset):\n",
" total_loss = 0.0\n",
" num_batches = 0\n",
" for x in dataset:\n",
" per_replica_losses = strategy.run(train_step, args=(x,))\n",
" total_loss += strategy.reduce(\n",
" tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
" num_batches += 1\n",
" return total_loss / tf.cast(num_batches, dtype=tf.float32)\n",
"\n",
"for epoch in range(EPOCHS):\n",
" train_loss = distributed_train_epoch(train_dist_dataset)\n",
"\n",
" template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
" print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))\n",
"\n",
" train_accuracy.reset_states()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MuZGXiyC7ABR"
},
"source": [
"### レプリカ間でトレーニング損失を追跡する\n",
"\n",
"注意: 一般的なルールとして、サンプルごとの値の追跡には`tf.keras.Metrics`を使用し、レプリカ内で集約された値を避ける必要があります。\n",
"\n",
"損失スケーリングの計算が実行されるため、レプリカ間でトレーニング損失を追跡するために `tf.keras.metrics.Mean` を使用することは推奨されません。\n",
"\n",
"例えば、次のような特徴を持つトレーニングジョブを実行するとします。\n",
"\n",
"- レプリカ 2 つ\n",
"- 各レプリカで 2 つのサンプルを処理\n",
"- 結果の損失値 : 各レプリカで [2, 3] および [4, 5]\n",
"- グローバルバッチサイズ = 4\n",
"\n",
"損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、`(2 + 3) / 4 = 1.25`および`(4 + 5) / 4 = 2.25`となります。\n",
"\n",
"`tf.keras.metrics.Mean` を使用して 2 つのレプリカ間の損失を追跡すると、異なる結果が得られます。この例では、`total` は 3.50、`count` は 2 となるため、メトリックで `result()` が呼び出されると、`total`/`count` = 1.75 となります。`tf.keras.Metrics` で計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xisYJaV9KZTN"
},
"source": [
"### ガイドと例\n",
"\n",
"カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。\n",
"\n",
"1. 分散型トレーニングガイド\n",
"2. `MirroredStrategy`を使用した [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) の例。\n",
"3. MirroredStrategy
と`TPUStrategy`を使用してトレーニングされた BERT の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。\n",
"4. `MirroredStrategy` を使用してトレーニングされ、`keras_use_ctl` フラグを使用した有効化が可能な、[NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) の例。\n",
"5. `MirroredStrategy`を使用してトレーニングされた、[NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) の例。\n",
"\n",
"その他の例は、[分散型ストラテジーガイド](../../guide/distributed_training.ipynb)の「*例とチュートリアル*」に記載されています。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6hEJNsokjOKs"
},
"source": [
"## 次のステップ\n",
"\n",
"- 新しい`tf.distribute.Strategy` API を独自のモデルで試してみましょう。\n",
"- TensorFlow モデルのパフォーマンスを最適化する方法についてのその他の詳細は、[`tf.function` によるパフォーマンスの改善](../../guide/function.ipynb)と [TensorFlow Profiler](../../guide/profiler.md) をご覧ください。\n",
"- [TensorFlow での分散型トレーニング](../../guide/distributed_training.ipynb)ガイドでは、利用可能な分散ストラテジーの概要が説明されています。"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "custom_training.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}