{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "b518b04cbfe0" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T21:05:03.531243Z", "iopub.status.busy": "2022-12-14T21:05:03.530736Z", "iopub.status.idle": "2022-12-14T21:05:03.534281Z", "shell.execute_reply": "2022-12-14T21:05:03.533777Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "a5620ee4049e" }, "source": [ "# Model.fit の処理をカスタマイズする" ] }, { "cell_type": "markdown", "metadata": { "id": "0a56ffedf331" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "7ebb4e65ef9b" }, "source": [ "## はじめに\n", "\n", "教師あり学習を実行するときに `fit()` を使用するとスムーズに学習を進めることができます。\n", "\n", "独自のトレーニングループを新規で書く必要がある場合には、`GradientTape` を使用すると、細かく制御することができます。\n", "\n", "しかし、カスタムトレーニングアルゴリズムが必要ながらも、コールバック、組み込みの分散サポート、ステップ結合など、`fit()` の便利な機能を利用したい場合には、どうすればよいのでしょうか?\n", "\n", "Keras の基本原則は、**複雑性のプログレッシブディスクロージャー**です。常に段階的に低レベルのワークフローに入ることが可能で、高レベルの機能性がユースケースと完全に一致しない場合でも、いきなり行き詰まるようなことはありません。相応の高レベルの利便性を維持しながら、細部をさらに制御することができます。\n", "\n", "`fit()` の動作をカスタマイズする必要がある場合は、**`Model` クラスのトレーニングステップ関数をオーバーライド**する必要があります。これは、データのバッチごとに `fit()` から呼び出される関数です。これによって、通常通り `fit()` を呼び出せるようになり、独自のトレーニングアルゴリズムが実行されるようになります。\n", "\n", "このパターンは Functional API を使用したモデル構築を妨げるものではないことに注意してください。これは、`Sequential` モデル、Functional API モデル、サブクラス化されたモデルのいずれを構築する場合にも適用可能です。\n", "\n", "では、その仕組みを見ていきましょう。" ] }, { "cell_type": "markdown", "metadata": { "id": "2849e371b9b6" }, "source": [ "## セットアップ\n", "\n", "TensorFlow 2.2 以降が必要です。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:03.537745Z", "iopub.status.busy": "2022-12-14T21:05:03.537248Z", "iopub.status.idle": "2022-12-14T21:05:05.452797Z", "shell.execute_reply": "2022-12-14T21:05:05.452060Z" }, "id": "4dadb6688663" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 21:05:04.471118: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 21:05:04.471213: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 21:05:04.471222: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT 
libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": { "id": "9022333acaa7" }, "source": [ "## 最初の簡単な例\n", "\n", "簡単な例から始めてみましょう。\n", "\n", "- `keras.Model` をサブクラス化する新しいクラスを作成します。\n", "- `train_step(self, data)` メソッドだけをオーバーライドします。\n", "- メトリクス名(損失を含む)を現在の値にマッピングするディクショナリを返します。\n", "\n", "入力引数の `data` は、トレーニングデータとして `fit()` に渡されるものです。\n", "\n", "- `fit(x, y, ...)` を呼び出して Numpy 配列を渡す場合は、`data` はタプル型 `(x, y)` になります。\n", "- `fit(dataset, ...)` を呼び出して `tf.data.Dataset` を渡す場合は、`data` は各バッチで `dataset` により生成される値になります。\n", "\n", "`train_step` メソッドの本体には、既に使い慣れているものと同様の、通常のトレーニングの更新処理を実装しています。重要なのは、**損失の計算を `self.compiled_loss` を介して行っている**ことで、それによって `compile()` に渡された損失関数がラップされています。\n", "\n", "同様に、`self.compiled_metrics.update_state(y, y_pred)` を呼び出して `compile()` に渡されたメトリクスの状態を更新し、最後に `self.metrics` の結果をクエリして現在の値を取得しています。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:05.457520Z", "iopub.status.busy": "2022-12-14T21:05:05.456786Z", "iopub.status.idle": "2022-12-14T21:05:05.462359Z", "shell.execute_reply": "2022-12-14T21:05:05.461748Z" }, "id": "060c8bf4150d" }, "outputs": [], "source": [ "class CustomModel(keras.Model):\n", "    def train_step(self, data):\n", "        # Unpack the data. 
Its structure depends on your model and\n", " # on what you pass to `fit()`.\n", " x, y = data\n", "\n", " with tf.GradientTape() as tape:\n", " y_pred = self(x, training=True) # Forward pass\n", " # Compute the loss value\n", " # (the loss function is configured in `compile()`)\n", " loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n", "\n", " # Compute gradients\n", " trainable_vars = self.trainable_variables\n", " gradients = tape.gradient(loss, trainable_vars)\n", " # Update weights\n", " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", " # Update metrics (includes the metric that tracks the loss)\n", " self.compiled_metrics.update_state(y, y_pred)\n", " # Return a dict mapping metric names to current value\n", " return {m.name: m.result() for m in self.metrics}\n" ] }, { "cell_type": "markdown", "metadata": { "id": "c9d2cc7a7014" }, "source": [ "これを試してみましょう。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:05.465622Z", "iopub.status.busy": "2022-12-14T21:05:05.465039Z", "iopub.status.idle": "2022-12-14T21:05:10.145581Z", "shell.execute_reply": "2022-12-14T21:05:10.144932Z" }, "id": "5e6bd7b554f6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 31s - loss: 0.3652 - mae: 0.5208" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/32 [========================>.....] 
- ETA: 0s - loss: 0.2593 - mae: 0.4162 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 2ms/step - loss: 0.2539 - mae: 0.4101\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1608 - mae: 0.3305" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "26/32 [=======================>......] - ETA: 0s - loss: 0.1893 - mae: 0.3509" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1882 - mae: 0.3503\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1891 - mae: 0.3509" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "26/32 [=======================>......] 
- ETA: 0s - loss: 0.1801 - mae: 0.3424" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1825 - mae: 0.3448\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "# Construct and compile an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n", "\n", "# Just use `fit` as usual\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "model.fit(x, y, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "a882cb6467d6" }, "source": [ "## 低レベルにする\n", "\n", "当然ながら、`compile()` に損失関数を渡すことを省略し、代わりに `train_step` ですべてを*手動で*実行することは可能です。これはメトリクスの場合でも同様です。\n", "\n", "オプティマイザの構成に `compile()` のみを使用した、低レベルの例を次に示します。\n", "\n", "- まず、損失と MAE スコアを追跡する `Metric` インスタンスを作成します。\n", "- これらのメトリクスの状態を更新するカスタム `train_step()` を実装し(メトリクスで `update_state()` を呼び出します)、現在の平均値を返して進捗バーで表示し、任意のコールバックに渡せるようにメトリクスをクエリします(result() を使用)。\n", "- エポックごとにメトリクスに `reset_states()` を呼び出す必要があるところに注意してください。呼び出さない場合、`result()` は通常処理しているエポックごとの平均ではなく、トレーニングを開始してからの平均を返してしまいます。幸いにも、これはフレームワークが行ってくれるため、モデルの `metrics` プロパティにリセットするメトリクスをリストするだけで実現できます。モデルは、そこにリストされているオブジェクトに対する `reset_states()` の呼び出しを各 `fit()` エポックの開始時または `evaluate()` への呼び出しの開始時に行うようになります。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:10.149179Z", "iopub.status.busy": "2022-12-14T21:05:10.148689Z", "iopub.status.idle": "2022-12-14T21:05:10.935539Z", "shell.execute_reply": "2022-12-14T21:05:10.934820Z" }, "id": "2308abf5fe7d" }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 11s - loss: 0.2348 - mae: 0.3991" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "26/32 [=======================>......] - ETA: 0s - loss: 0.1867 - mae: 0.3506 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1901 - mae: 0.3546\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1726 - mae: 0.3265" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/32 [=========================>....] - ETA: 0s - loss: 0.1784 - mae: 0.3454" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1799 - mae: 0.3451\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.1752 - mae: 0.3578" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/32 [=========================>....] - ETA: 0s - loss: 0.1716 - mae: 0.3383" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1705 - mae: 0.3363\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1291 - mae: 0.3050" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/32 [=========================>....] - ETA: 0s - loss: 0.1618 - mae: 0.3282" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1620 - mae: 0.3284\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2133 - mae: 0.3846" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/32 [=========================>....] 
- ETA: 0s - loss: 0.1557 - mae: 0.3215" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1545 - mae: 0.3212\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_tracker = keras.metrics.Mean(name=\"loss\")\n", "mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n", "\n", "\n", "class CustomModel(keras.Model):\n", "    def train_step(self, data):\n", "        x, y = data\n", "\n", "        with tf.GradientTape() as tape:\n", "            y_pred = self(x, training=True)  # Forward pass\n", "            # Compute our own loss\n", "            loss = keras.losses.mean_squared_error(y, y_pred)\n", "\n", "        # Compute gradients\n", "        trainable_vars = self.trainable_variables\n", "        gradients = tape.gradient(loss, trainable_vars)\n", "\n", "        # Update weights\n", "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", "\n", "        # Compute our own metrics\n", "        loss_tracker.update_state(loss)\n", "        mae_metric.update_state(y, y_pred)\n", "        return {\"loss\": loss_tracker.result(), \"mae\": mae_metric.result()}\n", "\n", "    @property\n", "    def metrics(self):\n", "        # We list our `Metric` objects here so that `reset_states()` can be\n", "        # called automatically at the start of each epoch\n", "        # or at the start of `evaluate()`.\n", "        # If you don't implement this property, you have to call\n", "        # `reset_states()` yourself at the time of your choosing.\n", "        return [loss_tracker, mae_metric]\n", "\n", "\n", "# Construct an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "\n", "# We don't pass a loss or metrics here.\n", "model.compile(optimizer=\"adam\")\n", "\n", "# Just use `fit` as usual -- you 
can use callbacks, etc.\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "model.fit(x, y, epochs=5)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "f451e382c6a8" }, "source": [ "## `sample_weight` と `class_weight` をサポートする\n", "\n", "最初の基本的な例では、サンプルの重み付けについては何も言及していないことに気付いているかもしれません。`fit()` の引数 `sample_weight` と `class_weight` をサポートする場合には、次のようにします。\n", "\n", "- `data` 引数から `sample_weight` をアンパックします。\n", "- それを `compiled_loss` と `compiled_metrics` に渡します(もちろん、損失とメトリクスが `compile()` に依存しない場合は手動での適用が可能です)。\n", "- 以上です。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:10.939394Z", "iopub.status.busy": "2022-12-14T21:05:10.938757Z", "iopub.status.idle": "2022-12-14T21:05:11.961093Z", "shell.execute_reply": "2022-12-14T21:05:11.960379Z" }, "id": "522d7281f948" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 13s - loss: 0.5207 - mae: 0.9753" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "24/32 [=====================>........] - ETA: 0s - loss: 0.4820 - mae: 0.8482 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "32/32 [==============================] - 1s 2ms/step - loss: 0.4508 - mae: 0.8067\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] 
- ETA: 0s - loss: 0.3352 - mae: 0.6110" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "25/32 [======================>.......] - ETA: 0s - loss: 0.1987 - mae: 0.4903" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1881 - mae: 0.4751\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1292 - mae: 0.3837" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "26/32 [=======================>......] - ETA: 0s - loss: 0.1213 - mae: 0.3811" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1230 - mae: 0.3819\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomModel(keras.Model):\n", " def train_step(self, data):\n", " # Unpack the data. 
Its structure depends on your model and\n", "        # on what you pass to `fit()`.\n", "        if len(data) == 3:\n", "            x, y, sample_weight = data\n", "        else:\n", "            sample_weight = None\n", "            x, y = data\n", "\n", "        with tf.GradientTape() as tape:\n", "            y_pred = self(x, training=True)  # Forward pass\n", "            # Compute the loss value.\n", "            # The loss function is configured in `compile()`.\n", "            loss = self.compiled_loss(\n", "                y,\n", "                y_pred,\n", "                sample_weight=sample_weight,\n", "                regularization_losses=self.losses,\n", "            )\n", "\n", "        # Compute gradients\n", "        trainable_vars = self.trainable_variables\n", "        gradients = tape.gradient(loss, trainable_vars)\n", "\n", "        # Update weights\n", "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", "\n", "        # Update the metrics.\n", "        # Metrics are configured in `compile()`.\n", "        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)\n", "\n", "        # Return a dict mapping metric names to current value.\n", "        # Note that it will include the loss (tracked in self.metrics).\n", "        return {m.name: m.result() for m in self.metrics}\n", "\n", "\n", "# Construct and compile an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n", "\n", "# You can now use the `sample_weight` argument\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "sw = np.random.random((1000, 1))\n", "model.fit(x, y, sample_weight=sw, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "03000c5590db" }, "source": [ "## 独自の評価ステップを提供する\n", "\n", "`model.evaluate()` の呼び出しで同じことをしたい場合はどうすればよいのでしょうか?その場合は、まったく同じ方法で `test_step` をオーバーライドします。これは次のようになります。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:11.964734Z", "iopub.status.busy": "2022-12-14T21:05:11.964160Z", 
"iopub.status.idle": "2022-12-14T21:05:12.228229Z", "shell.execute_reply": "2022-12-14T21:05:12.227555Z" }, "id": "999edb22c50e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 4s - loss: 0.3011 - mae: 0.4150" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "31/32 [============================>.] - ETA: 0s - loss: 0.2713 - mae: 0.4184" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2722 - mae: 0.4190\n" ] }, { "data": { "text/plain": [ "[0.27221357822418213, 0.4190024435520172]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomModel(keras.Model):\n", " def test_step(self, data):\n", " # Unpack the data\n", " x, y = data\n", " # Compute predictions\n", " y_pred = self(x, training=False)\n", " # Updates the metrics tracking the loss\n", " self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n", " # Update the metrics.\n", " self.compiled_metrics.update_state(y, y_pred)\n", " # Return a dict mapping metric names to current value.\n", " # Note that it will include the loss (tracked in self.metrics).\n", " return {m.name: m.result() for m in self.metrics}\n", "\n", "\n", "# Construct an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "model.compile(loss=\"mse\", metrics=[\"mae\"])\n", "\n", "# Evaluate with our custom test_step\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "model.evaluate(x, 
y)" ] }, { "cell_type": "markdown", "metadata": { "id": "9e6a662e6588" }, "source": [ "## まとめ: エンドツーエンド GAN の例\n", "\n", "ここで学んだことをすべて採り入れたエンドツーエンドの例を見てみましょう。\n", "\n", "以下を検討してみましょう。\n", "\n", "- 28x28x1 の画像を生成するジェネレーターネットワーク。\n", "- 28x28x1 の画像を 2 つのクラス(「偽物」と「本物」)に分類するディスクリミネーターネットワーク。\n", "- それぞれに 1 つのオプティマイザ。\n", "- ディスクリミネーターをトレーニングする損失関数。\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:12.232016Z", "iopub.status.busy": "2022-12-14T21:05:12.231345Z", "iopub.status.idle": "2022-12-14T21:05:12.349017Z", "shell.execute_reply": "2022-12-14T21:05:12.348367Z" }, "id": "6748db01dc7c" }, "outputs": [], "source": [ "from tensorflow.keras import layers\n", "\n", "# Create the discriminator\n", "discriminator = keras.Sequential(\n", " [\n", " keras.Input(shape=(28, 28, 1)),\n", " layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.GlobalMaxPooling2D(),\n", " layers.Dense(1),\n", " ],\n", " name=\"discriminator\",\n", ")\n", "\n", "# Create the generator\n", "latent_dim = 128\n", "generator = keras.Sequential(\n", " [\n", " keras.Input(shape=(latent_dim,)),\n", " # We want to generate 128 coefficients to reshape into a 7x7x128 map\n", " layers.Dense(7 * 7 * 128),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Reshape((7, 7, 128)),\n", " layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n", " ],\n", " name=\"generator\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "801e8dd0c92a" }, "source": [ "ここにフィーチャーコンプリートの GAN クラスがあります。`compile()`をオーバーライドして独自のシグネチャを使用することにより、GAN 
アルゴリズム全体を`train_step`の 17 行で実装しています。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:12.352806Z", "iopub.status.busy": "2022-12-14T21:05:12.352268Z", "iopub.status.idle": "2022-12-14T21:05:12.361186Z", "shell.execute_reply": "2022-12-14T21:05:12.360573Z" }, "id": "bc3fb4111393" }, "outputs": [], "source": [ "class GAN(keras.Model):\n", " def __init__(self, discriminator, generator, latent_dim):\n", " super(GAN, self).__init__()\n", " self.discriminator = discriminator\n", " self.generator = generator\n", " self.latent_dim = latent_dim\n", "\n", " def compile(self, d_optimizer, g_optimizer, loss_fn):\n", " super(GAN, self).compile()\n", " self.d_optimizer = d_optimizer\n", " self.g_optimizer = g_optimizer\n", " self.loss_fn = loss_fn\n", "\n", " def train_step(self, real_images):\n", " if isinstance(real_images, tuple):\n", " real_images = real_images[0]\n", " # Sample random points in the latent space\n", " batch_size = tf.shape(real_images)[0]\n", " random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n", "\n", " # Decode them to fake images\n", " generated_images = self.generator(random_latent_vectors)\n", "\n", " # Combine them with real images\n", " combined_images = tf.concat([generated_images, real_images], axis=0)\n", "\n", " # Assemble labels discriminating real from fake images\n", " labels = tf.concat(\n", " [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n", " )\n", " # Add random noise to the labels - important trick!\n", " labels += 0.05 * tf.random.uniform(tf.shape(labels))\n", "\n", " # Train the discriminator\n", " with tf.GradientTape() as tape:\n", " predictions = self.discriminator(combined_images)\n", " d_loss = self.loss_fn(labels, predictions)\n", " grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n", " self.d_optimizer.apply_gradients(\n", " zip(grads, self.discriminator.trainable_weights)\n", " )\n", "\n", " # 
Sample random points in the latent space\n", " random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n", "\n", " # Assemble labels that say \"all real images\"\n", " misleading_labels = tf.zeros((batch_size, 1))\n", "\n", " # Train the generator (note that we should *not* update the weights\n", " # of the discriminator)!\n", " with tf.GradientTape() as tape:\n", " predictions = self.discriminator(self.generator(random_latent_vectors))\n", " g_loss = self.loss_fn(misleading_labels, predictions)\n", " grads = tape.gradient(g_loss, self.generator.trainable_weights)\n", " self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n", " return {\"d_loss\": d_loss, \"g_loss\": g_loss}\n" ] }, { "cell_type": "markdown", "metadata": { "id": "095c499a6149" }, "source": [ "試運転してみましょう。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:05:12.364584Z", "iopub.status.busy": "2022-12-14T21:05:12.364043Z", "iopub.status.idle": "2022-12-14T21:05:19.399100Z", "shell.execute_reply": "2022-12-14T21:05:19.398040Z" }, "id": "46832f2077ac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/100 [..............................] - ETA: 7:19 - d_loss: 0.6902 - g_loss: 0.6724" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/100 [>.............................] - ETA: 1s - d_loss: 0.6676 - g_loss: 0.6966 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/100 [=>............................] 
- ETA: 1s - d_loss: 0.6478 - g_loss: 0.7175" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "100/100 [==============================] - 6s 15ms/step - d_loss: 0.4212 - g_loss: 0.8618\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Prepare the dataset. 
We use both the training & test MNIST digits.\n", "batch_size = 64\n", "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n", "all_digits = np.concatenate([x_train, x_test])\n", "all_digits = all_digits.astype(\"float32\") / 255.0\n", "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n", "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n", "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n", "\n", "gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)\n", "gan.compile(\n", " d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n", " g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n", " loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n", ")\n", "\n", "# To limit the execution time, we only train on 100 batches. You can train on\n", "# the entire dataset. You will need about 20 epochs to get nice results.\n", "gan.fit(dataset.take(100), epochs=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "2ed211016c96" }, "source": [ "The ideas behind deep learning are simple, so naturally their implementation should be simple as well." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "customizing_what_happens_in_fit.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }