{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "b518b04cbfe0" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:25:44.023452Z", "iopub.status.busy": "2022-12-14T22:25:44.022853Z", "iopub.status.idle": "2022-12-14T22:25:44.026779Z", "shell.execute_reply": "2022-12-14T22:25:44.026219Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "a5620ee4049e" }, "source": [ "# Model.fit의 동작 사용자 정의하기" ] }, { "cell_type": "markdown", "metadata": { "id": "0a56ffedf331" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org에서 보기Google Colab에서 실행 View source on GitHub\n", "노트북 다운로드
" ] }, { "cell_type": "markdown", "metadata": { "id": "7ebb4e65ef9b" }, "source": [ "## 시작하기\n", "\n", "감독 학습을 수행할 때 `fit()`를 사용할 수 있으며 모든 것이 원활하게 작동합니다.\n", "\n", "훈련 루프를 처음부터 작성해야 하는 경우, `GradientTape`를 사용하여 모든 세부 사항을 제어할 수 있습니다.\n", "\n", "그러나 사용자 정의 훈련 알고리즘이 필요하지만 콜백, 내장 배포 지원 또는 단계 융합과 같은 `fit()`의 편리한 특성을 계속 활용하려면 어떻게 해야 할까요?\n", "\n", "Keras의 핵심 원칙은 **복잡성의 점진적인 공개**입니다. 항상 점진적으로 저수준 워크플로부터 시작할 수 있어야 합니다. 높은 수준의 기능이 자신의 사용 사례와 정확하게 일치하지 않다고 해서 절망할 필요는 없습니다. 적절한 수준의 고수준 편의를 유지하면서 작은 세부 사항을 보다 효과적으로 제어할 수 있어야 합니다.\n", "\n", "`fit()`를 사용자 정의해야 하는 경우, **`Model` 클래스의 훈련 단계 함수를 재정의**해야 합니다. 이 함수는 모든 데이터 배치에 대해 `fit()`에 의해 호출되는 함수입니다. 그런 다음 평소와 같이 `fit()`을 호출 할 수 있으며 자체 학습 알고리즘을 실행합니다.\n", "\n", "이 패턴은 Functional API를 사용하여 모델을 빌드하는 데 방해가 되지 않습니다. `Sequential` 모델, Functional API 모델, 또는 하위 클래스화된 모델과 관계없이 수행할 수 있습니다.\n", "\n", "어떻게 동작하는지 살펴보겠습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "2849e371b9b6" }, "source": [ "## 설정\n", "\n", "TensorFlow 2.2 이상이 필요합니다." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:44.030418Z", "iopub.status.busy": "2022-12-14T22:25:44.029950Z", "iopub.status.idle": "2022-12-14T22:25:45.973057Z", "shell.execute_reply": "2022-12-14T22:25:45.972303Z" }, "id": "4dadb6688663" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 22:25:44.982786: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:25:44.982892: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:25:44.982903: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": { "id": "9022333acaa7" }, "source": [ "## 첫 번째 간단한 예제\n", "\n", "간단한 예제부터 시작하겠습니다.\n", "\n", "- `keras.Model`을 하위 클래스화하는 새 클래스를 만듭니다.\n", "- `train_step(self, data)` 메서드를 재정의합니다.\n", "- 손실을 포함하여 사전 매핑 메트릭 이름을 현재 값으로 반환합니다.\n", "\n", "입력 인수 `data`는 훈련 데이터에 맞게 전달됩니다.\n", "\n", "- `fit(x, y, ...)`를 호출하여 Numpy 배열을 전달하면 `data`는 튜플 `(x, y)`가 됩니다.\n", "- `tf.data.Dataset`를 전달하는 경우, `fit(dataset, ...)`를 호출하여 `data`가 각 배치에서 `dataset`에 의해 산출됩니다.\n", "\n", "`train_step` 메서드의 본문에서 이미 익숙한 것과 유사한 정기적인 훈련 업데이트를 구현합니다. 중요한 것은 **`self.compiled_loss`를 통해 손실을 계산하여** `compile()`로 전달된 손실 함수를 래핑합니다.\n", "\n", "마찬가지로, `self.compiled_metrics.update_state(y, y_pred)`를 호출하여 `compile()`에 전달된 메트릭의 상태를 업데이트하고, 마지막에 `self.metrics`의 결과를 쿼리하여 현재 값을 검색합니다." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:45.977551Z", "iopub.status.busy": "2022-12-14T22:25:45.976882Z", "iopub.status.idle": "2022-12-14T22:25:45.982131Z", "shell.execute_reply": "2022-12-14T22:25:45.981520Z" }, "id": "060c8bf4150d" }, "outputs": [], "source": [ "class CustomModel(keras.Model):\n", " def train_step(self, data):\n", " # Unpack the data. Its structure depends on your model and\n", " # on what you pass to `fit()`.\n", " x, y = data\n", "\n", " with tf.GradientTape() as tape:\n", " y_pred = self(x, training=True) # Forward pass\n", " # Compute the loss value\n", " # (the loss function is configured in `compile()`)\n", " loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n", "\n", " # Compute gradients\n", " trainable_vars = self.trainable_variables\n", " gradients = tape.gradient(loss, trainable_vars)\n", " # Update weights\n", " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", " # Update metrics (includes the metric that tracks the loss)\n", " self.compiled_metrics.update_state(y, y_pred)\n", " # Return a dict mapping metric names to current value\n", " return {m.name: m.result() for m in self.metrics}\n" ] }, { "cell_type": "markdown", "metadata": { "id": "c9d2cc7a7014" }, "source": [ "다음을 시도해봅시다." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:45.985772Z", "iopub.status.busy": "2022-12-14T22:25:45.985218Z", "iopub.status.idle": "2022-12-14T22:25:50.746952Z", "shell.execute_reply": "2022-12-14T22:25:50.746126Z" }, "id": "5e6bd7b554f6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 32s - loss: 1.1769 - mae: 0.9829" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "25/32 [======================>.......] - ETA: 0s - loss: 0.7192 - mae: 0.7248 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 2ms/step - loss: 0.6662 - mae: 0.6898\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.3443 - mae: 0.4840" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "25/32 [======================>.......] - ETA: 0s - loss: 0.2968 - mae: 0.4367" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2882 - mae: 0.4303\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2676 - mae: 0.4261" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/32 [========================>.....] - ETA: 0s - loss: 0.2299 - mae: 0.3877" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2313 - mae: 0.3891\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "# Construct and compile an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n", "\n", "# Just use `fit` as usual\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "model.fit(x, y, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "a882cb6467d6" }, "source": [ "## 더 낮은 수준으로 구성하기\n", "\n", "당연히 `compile()`에서 손실 함수의 전달을 건너뛰고, 대신 train_step에서 수동으로 모두 수행할 수 있습니다. 메트릭도 마찬가지입니다.\n", "\n", "다음은 옵티마이저를 구성하기 위해 `compile()`만 사용하는 하위 수준의 예입니다.\n", "\n", "- 먼저 손실과 MAE 점수를 추적하기 위해 `Metric` 인스턴스를 생성합니다.\n", "- (메트릭에 대한 `update_state()`를 호출하여) 메트릭의 상태를 업데이트하는 사용자 정의`train_step()`을 구현한 다음, 쿼리하여(`result()`를 통해) 현재 평균 값을 반환하여 진행률 표시줄에 표시되고 모든 콜백에 전달되도록 합니다.\n", "- 각 epoch 사이의 메트릭에 대해 `reset_states()`를 호출해야 합니다. 그렇지 않으면, `result()`를 호출하면 훈련 시작 이후부터 평균이 반환되지만, 일반적으로 epoch당 평균을 사용합니다. 다행히도 프레임워크에서는 다음과 같이 수행할 수 있습니다. 즉, 재설정하려는 매트릭을 모델의 `metrics` 속성에 나열하기만 하면 됩니다. 모델은 각 `fit()` epoch가 시작될 때 또는 `evaluate()` 호출이 시작될 때 여기에 나열된 모든 객체에 대해 `reset_states()`를 호출합니다." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:50.750459Z", "iopub.status.busy": "2022-12-14T22:25:50.749777Z", "iopub.status.idle": "2022-12-14T22:25:51.542923Z", "shell.execute_reply": "2022-12-14T22:25:51.542209Z" }, "id": "2308abf5fe7d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 11s - loss: 0.2490 - mae: 0.4097" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/32 [========================>.....] - ETA: 0s - loss: 0.2765 - mae: 0.4322 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2728 - mae: 0.4312\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1766 - mae: 0.3184" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/32 [========================>.....] - ETA: 0s - loss: 0.2215 - mae: 0.3808" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2168 - mae: 0.3769\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.2760 - mae: 0.4525" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/32 [========================>.....] - ETA: 0s - loss: 0.2141 - mae: 0.3754" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2082 - mae: 0.3701\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1593 - mae: 0.3233" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/32 [=========================>....] - ETA: 0s - loss: 0.2034 - mae: 0.3662" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.2004 - mae: 0.3634\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1927 - mae: 0.3662" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "28/32 [=========================>....] - ETA: 0s - loss: 0.1899 - mae: 0.3539" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1923 - mae: 0.3558\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_tracker = keras.metrics.Mean(name=\"loss\")\n", "mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n", "\n", "\n", "class CustomModel(keras.Model):\n", " def train_step(self, data):\n", " x, y = data\n", "\n", " with tf.GradientTape() as tape:\n", " y_pred = self(x, training=True) # Forward pass\n", " # Compute our own loss\n", " loss = keras.losses.mean_squared_error(y, y_pred)\n", "\n", " # Compute gradients\n", " trainable_vars = self.trainable_variables\n", " gradients = tape.gradient(loss, trainable_vars)\n", "\n", " # Update weights\n", " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", "\n", " # Compute our own metrics\n", " loss_tracker.update_state(loss)\n", " mae_metric.update_state(y, y_pred)\n", " return {\"loss\": loss_tracker.result(), \"mae\": mae_metric.result()}\n", "\n", " @property\n", " def metrics(self):\n", " # We list our `Metric` objects here so that `reset_states()` can be\n", " # called automatically at the start of each epoch\n", " # or at the start of `evaluate()`.\n", " # If you don't implement this property, you have to call\n", " # `reset_states()` yourself at the time of your choosing.\n", " return [loss_tracker, mae_metric]\n", "\n", "\n", "# Construct an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "\n", "# We don't passs a loss or metrics here.\n", "model.compile(optimizer=\"adam\")\n", "\n", "# Just use `fit` as usual -- you can use callbacks, etc.\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "model.fit(x, y, epochs=5)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "f451e382c6a8" }, "source": [ "## `sample_weight` 및 `class_weight` 지원하기\n", "\n", "첫 번째 기본 예제에서는 샘플 가중치에 대해 언급하지 않았습니다. `fit()` 인수 `sample_weight` 및 `class_weight`를 지원하려면 다음을 수행하면 됩니다.\n", "\n", "- `data` 인수에서 `sample_weight` 패키지를 풉니다.\n", "- `compiled_loss` 및 `compiled_metrics`에 전달합니다(손실 및 메트릭을 위해 `compile()`에 의존하지 않는다면 수동으로 적용할 수도 있습니다).\n", "- 다음은 그 목록입니다." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:51.546641Z", "iopub.status.busy": "2022-12-14T22:25:51.546032Z", "iopub.status.idle": "2022-12-14T22:25:52.579166Z", "shell.execute_reply": "2022-12-14T22:25:52.578388Z" }, "id": "522d7281f948" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 13s - loss: 0.1627 - mae: 0.4548" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "25/32 [======================>.......] - ETA: 0s - loss: 0.1440 - mae: 0.4346 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 1s 2ms/step - loss: 0.1389 - mae: 0.4280\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1250 - mae: 0.4456" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "26/32 [=======================>......] - ETA: 0s - loss: 0.1246 - mae: 0.4110" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1296 - mae: 0.4126\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 0s - loss: 0.1353 - mae: 0.4371" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27/32 [========================>.....] - ETA: 0s - loss: 0.1228 - mae: 0.4011" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 0.1229 - mae: 0.4022\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomModel(keras.Model):\n", " def train_step(self, data):\n", " # Unpack the data. Its structure depends on your model and\n", " # on what you pass to `fit()`.\n", " if len(data) == 3:\n", " x, y, sample_weight = data\n", " else:\n", " sample_weight = None\n", " x, y = data\n", "\n", " with tf.GradientTape() as tape:\n", " y_pred = self(x, training=True) # Forward pass\n", " # Compute the loss value.\n", " # The loss function is configured in `compile()`.\n", " loss = self.compiled_loss(\n", " y,\n", " y_pred,\n", " sample_weight=sample_weight,\n", " regularization_losses=self.losses,\n", " )\n", "\n", " # Compute gradients\n", " trainable_vars = self.trainable_variables\n", " gradients = tape.gradient(loss, trainable_vars)\n", "\n", " # Update weights\n", " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", "\n", " # Update the metrics.\n", " # Metrics are configured in `compile()`.\n", " self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)\n", "\n", " # Return a dict mapping metric names to current value.\n", " # Note that it will include the loss (tracked in self.metrics).\n", " return {m.name: m.result() for m in self.metrics}\n", "\n", "\n", "# Construct and compile an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n", "\n", "# You can now use sample_weight argument\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "sw = np.random.random((1000, 1))\n", "model.fit(x, y, sample_weight=sw, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "03000c5590db" }, "source": [ "## 자신만의 평가 단계 제공하기\n", "\n", "`model.evaluate()` 호출에 대해 같은 작업을 수행하려면 어떻게 해야 할까요? 정확히 같은 방식으로 `test_step`을 재정의합니다. 다음과 같습니다." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:52.582614Z", "iopub.status.busy": "2022-12-14T22:25:52.581940Z", "iopub.status.idle": "2022-12-14T22:25:52.840030Z", "shell.execute_reply": "2022-12-14T22:25:52.839409Z" }, "id": "999edb22c50e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/32 [..............................] - ETA: 4s - loss: 1.1469 - mae: 0.9393" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32/32 [==============================] - 0s 2ms/step - loss: 1.1476 - mae: 0.9479\n" ] }, { "data": { "text/plain": [ "[1.147562026977539, 0.9478825926780701]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomModel(keras.Model):\n", " def test_step(self, data):\n", " # Unpack the data\n", " x, y = data\n", " # Compute predictions\n", " y_pred = self(x, training=False)\n", " # Updates the metrics tracking the loss\n", " self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n", " # Update the metrics.\n", " self.compiled_metrics.update_state(y, y_pred)\n", " # Return a dict mapping metric names to current value.\n", " # Note that it will include the loss (tracked in self.metrics).\n", " return {m.name: m.result() for m in self.metrics}\n", "\n", "\n", "# Construct an instance of CustomModel\n", "inputs = keras.Input(shape=(32,))\n", "outputs = keras.layers.Dense(1)(inputs)\n", "model = CustomModel(inputs, outputs)\n", "model.compile(loss=\"mse\", metrics=[\"mae\"])\n", "\n", "# Evaluate with our custom test_step\n", "x = np.random.random((1000, 32))\n", "y = np.random.random((1000, 1))\n", "model.evaluate(x, y)" ] }, { "cell_type": "markdown", "metadata": { "id": "9e6a662e6588" }, "source": [ "## 마무리: 엔드-투-엔드 GAN 예제\n", "\n", "방금 배운 모든 내용을 활용하는 엔드 투 엔드 예제를 살펴보겠습니다.\n", "\n", "다음을 고려합니다.\n", "\n", "- 생성기 네트워크는 28x28x1 이미지를 생성합니다.\n", "- discriminator 네트워크는 28x28x1 이미지를 두 개의 클래스(\"false\" 및 \"real\")로 분류하기 위한 것입니다.\n", "- 각각 하나의 옵티마이저를 가집니다.\n", "- discriminator를 훈련하는 손실 함수입니다.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:52.843381Z", "iopub.status.busy": "2022-12-14T22:25:52.842861Z", "iopub.status.idle": "2022-12-14T22:25:52.961435Z", "shell.execute_reply": "2022-12-14T22:25:52.960763Z" }, "id": "6748db01dc7c" }, "outputs": [], "source": [ "from tensorflow.keras import layers\n", "\n", "# Create the discriminator\n", "discriminator = keras.Sequential(\n", " [\n", " keras.Input(shape=(28, 28, 1)),\n", " layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.GlobalMaxPooling2D(),\n", " layers.Dense(1),\n", " ],\n", " name=\"discriminator\",\n", ")\n", "\n", "# Create the generator\n", "latent_dim = 128\n", "generator = keras.Sequential(\n", " [\n", " keras.Input(shape=(latent_dim,)),\n", " # We want to generate 128 coefficients to reshape into a 7x7x128 map\n", " layers.Dense(7 * 7 * 128),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Reshape((7, 7, 128)),\n", " layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n", " layers.LeakyReLU(alpha=0.2),\n", " layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n", " ],\n", " name=\"generator\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "801e8dd0c92a" }, "source": [ "다음은 자신만의 서명을 사용하기 위해 `compile()`을 재정의하고 `train_step` 17줄로 전체 GAN 알고리즘을 구현하는 특성 완료형 GAN 클래스입니다." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:52.965563Z", "iopub.status.busy": "2022-12-14T22:25:52.965303Z", "iopub.status.idle": "2022-12-14T22:25:52.973749Z", "shell.execute_reply": "2022-12-14T22:25:52.973041Z" }, "id": "bc3fb4111393" }, "outputs": [], "source": [ "class GAN(keras.Model):\n", " def __init__(self, discriminator, generator, latent_dim):\n", " super(GAN, self).__init__()\n", " self.discriminator = discriminator\n", " self.generator = generator\n", " self.latent_dim = latent_dim\n", "\n", " def compile(self, d_optimizer, g_optimizer, loss_fn):\n", " super(GAN, self).compile()\n", " self.d_optimizer = d_optimizer\n", " self.g_optimizer = g_optimizer\n", " self.loss_fn = loss_fn\n", "\n", " def train_step(self, real_images):\n", " if isinstance(real_images, tuple):\n", " real_images = real_images[0]\n", " # Sample random points in the latent space\n", " batch_size = tf.shape(real_images)[0]\n", " random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n", "\n", " # Decode them to fake images\n", " generated_images = self.generator(random_latent_vectors)\n", "\n", " # Combine them with real images\n", " combined_images = tf.concat([generated_images, real_images], axis=0)\n", "\n", " # Assemble labels discriminating real from fake images\n", " labels = tf.concat(\n", " [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n", " )\n", " # Add random noise to the labels - important trick!\n", " labels += 0.05 * tf.random.uniform(tf.shape(labels))\n", "\n", " # Train the discriminator\n", " with tf.GradientTape() as tape:\n", " predictions = self.discriminator(combined_images)\n", " d_loss = self.loss_fn(labels, predictions)\n", " grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n", " self.d_optimizer.apply_gradients(\n", " zip(grads, self.discriminator.trainable_weights)\n", " )\n", "\n", " # Sample random points in the latent space\n", " random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n", "\n", " # Assemble labels that say \"all real images\"\n", " misleading_labels = tf.zeros((batch_size, 1))\n", "\n", " # Train the generator (note that we should *not* update the weights\n", " # of the discriminator)!\n", " with tf.GradientTape() as tape:\n", " predictions = self.discriminator(self.generator(random_latent_vectors))\n", " g_loss = self.loss_fn(misleading_labels, predictions)\n", " grads = tape.gradient(g_loss, self.generator.trainable_weights)\n", " self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n", " return {\"d_loss\": d_loss, \"g_loss\": g_loss}\n" ] }, { "cell_type": "markdown", "metadata": { "id": "095c499a6149" }, "source": [ "테스트해 봅시다." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:25:52.976856Z", "iopub.status.busy": "2022-12-14T22:25:52.976371Z", "iopub.status.idle": "2022-12-14T22:25:59.298016Z", "shell.execute_reply": "2022-12-14T22:25:59.297207Z" }, "id": "46832f2077ac" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/100 [..............................] - ETA: 6:12 - d_loss: 0.6858 - g_loss: 0.6390" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/100 [>.............................] - ETA: 1s - d_loss: 0.6609 - g_loss: 0.6753 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/100 [=>............................] - ETA: 1s - d_loss: 0.6379 - g_loss: 0.7131" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 13/100 [==>...........................] - ETA: 1s - d_loss: 0.6173 - g_loss: 0.7479" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 17/100 [====>.........................] - ETA: 1s - d_loss: 0.5990 - g_loss: 0.7753" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 21/100 [=====>........................] - ETA: 1s - d_loss: 0.5823 - g_loss: 0.7955" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 25/100 [======>.......................] - ETA: 1s - d_loss: 0.5676 - g_loss: 0.8086" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 29/100 [=======>......................] - ETA: 1s - d_loss: 0.5565 - g_loss: 0.8112" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 33/100 [========>.....................] - ETA: 1s - d_loss: 0.5498 - g_loss: 0.8018" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 37/100 [==========>...................] - ETA: 0s - d_loss: 0.5449 - g_loss: 0.7885" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 41/100 [===========>..................] - ETA: 0s - d_loss: 0.5392 - g_loss: 0.7767" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 45/100 [============>.................] - ETA: 0s - d_loss: 0.5321 - g_loss: 0.7698" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 49/100 [=============>................] - ETA: 0s - d_loss: 0.5226 - g_loss: 0.7705" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 53/100 [==============>...............] - ETA: 0s - d_loss: 0.5127 - g_loss: 0.7744" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 57/100 [================>.............] - ETA: 0s - d_loss: 0.5031 - g_loss: 0.7792" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 61/100 [=================>............] - ETA: 0s - d_loss: 0.4952 - g_loss: 0.7813" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 65/100 [==================>...........] - ETA: 0s - d_loss: 0.4876 - g_loss: 0.7841" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 69/100 [===================>..........] - ETA: 0s - d_loss: 0.4801 - g_loss: 0.7873" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 73/100 [====================>.........] - ETA: 0s - d_loss: 0.4727 - g_loss: 0.7906" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 77/100 [======================>.......] - ETA: 0s - d_loss: 0.4655 - g_loss: 0.7947" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 81/100 [=======================>......] - ETA: 0s - d_loss: 0.4585 - g_loss: 0.7999" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 85/100 [========================>.....] - ETA: 0s - d_loss: 0.4516 - g_loss: 0.8060" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 89/100 [=========================>....] - ETA: 0s - d_loss: 0.4447 - g_loss: 0.8134" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 93/100 [==========================>...] - ETA: 0s - d_loss: 0.4379 - g_loss: 0.8219" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 97/100 [============================>.] - ETA: 0s - d_loss: 0.4311 - g_loss: 0.8316" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "100/100 [==============================] - 5s 15ms/step - d_loss: 0.4243 - g_loss: 0.8427\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Prepare the dataset. We use both the training & test MNIST digits.\n", "batch_size = 64\n", "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n", "all_digits = np.concatenate([x_train, x_test])\n", "all_digits = all_digits.astype(\"float32\") / 255.0\n", "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n", "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n", "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n", "\n", "gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)\n", "gan.compile(\n", " d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n", " g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n", " loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n", ")\n", "\n", "# To limit the execution time, we only train on 100 batches. You can train on\n", "# the entire dataset. You will need about 20 epochs to get nice results.\n", "gan.fit(dataset.take(100), epochs=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "2ed211016c96" }, "source": [ "딥 러닝의 기본 개념은 간단합니다. 구현이 고통스러울 이유가 없습니다." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "customizing_what_happens_in_fit.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }