{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "b518b04cbfe0" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:54:50.484142Z", "iopub.status.busy": "2022-12-14T22:54:50.483754Z", "iopub.status.idle": "2022-12-14T22:54:50.487583Z", "shell.execute_reply": "2022-12-14T22:54:50.487041Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "e2d97c7e31aa" }, "source": [ "# 하위 클래스화를 통한 새로운 레이어 및 모델 만들기" ] }, { "cell_type": "markdown", "metadata": { "id": "4e352274064f" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org에서 보기Google Colab에서 실행하기GitHub에서소스 보기노트북 다운로드하기
" ] }, { "cell_type": "markdown", "metadata": { "id": "8d4ac441b1fc" }, "source": [ "## !pip install -U tf-hub-nightly
import tensorflow_hub as hub

from tensorflow.keras import layers" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:50.491091Z", "iopub.status.busy": "2022-12-14T22:54:50.490663Z", "iopub.status.idle": "2022-12-14T22:54:52.413742Z", "shell.execute_reply": "2022-12-14T22:54:52.413062Z" }, "id": "4e7dce39dd1d" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 22:54:51.448414: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:54:51.448525: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:54:51.448535: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "metadata": { "id": "7b363673d96c" }, "source": [ "## `Layer` 클래스: 상태(가중치)와 일부 계산의 조합\n", "\n", "Keras의 주요 추상화 중 하나는 `Layer` 클래스입니다. 레이어는 상태(레이어의 \"가중치\")와 입력에서 출력으로의 변환(\"호출, 레이어의 정방향 패스\")을 모두 캡슐화합니다.\n", "\n", "다음은 밀집 레이어입니다. 상태는 변수 `w` 및 `b`입니다." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:52.417865Z", "iopub.status.busy": "2022-12-14T22:54:52.417467Z", "iopub.status.idle": "2022-12-14T22:54:52.423104Z", "shell.execute_reply": "2022-12-14T22:54:52.422299Z" }, "id": "59b8317dbd3c" }, "outputs": [], "source": [ "class Linear(keras.layers.Layer):\n", " def __init__(self, units=32, input_dim=32):\n", " super(Linear, self).__init__()\n", " w_init = tf.random_normal_initializer()\n", " self.w = tf.Variable(\n", " initial_value=w_init(shape=(input_dim, units), dtype=\"float32\"),\n", " trainable=True,\n", " )\n", " b_init = tf.zeros_initializer()\n", " self.b = tf.Variable(\n", " initial_value=b_init(shape=(units,), dtype=\"float32\"), trainable=True\n", " )\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n" ] }, { "cell_type": "markdown", "metadata": { "id": "dac8fb03a642" }, "source": [ "파이썬 함수와 매우 유사한 일부 텐서 입력에서 레이어를 호출하여 레이어를 사용합니다." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:52.426353Z", "iopub.status.busy": "2022-12-14T22:54:52.425773Z", "iopub.status.idle": "2022-12-14T22:54:56.102784Z", "shell.execute_reply": "2022-12-14T22:54:56.101872Z" }, "id": "cdcd15d5e68a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[-0.07012919 -0.02425323 0.06503569 0.01101883]\n", " [-0.07012919 -0.02425323 0.06503569 0.01101883]], shape=(2, 4), dtype=float32)\n" ] } ], "source": [ "x = tf.ones((2, 2))\n", "linear_layer = Linear(4, 2)\n", "y = linear_layer(x)\n", "print(y)" ] }, { "cell_type": "markdown", "metadata": { "id": "382960020a56" }, "source": [ "가중치 `w`와 `b`는 레이어 속성으로 설정될 때 레이어에 의해 자동으로 추적됩니다." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.106632Z", "iopub.status.busy": "2022-12-14T22:54:56.105970Z", "iopub.status.idle": "2022-12-14T22:54:56.109858Z", "shell.execute_reply": "2022-12-14T22:54:56.109124Z" }, "id": "d3d875af9465" }, "outputs": [], "source": [ "assert linear_layer.weights == [linear_layer.w, linear_layer.b]" ] }, { "cell_type": "markdown", "metadata": { "id": "ec9d72aa7538" }, "source": [ "레이어에 가중치를 추가하는 더 빠른 바로 가기에 액세스할 수도 있습니다. `add_weight()` 메서드:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.113539Z", "iopub.status.busy": "2022-12-14T22:54:56.112964Z", "iopub.status.idle": "2022-12-14T22:54:56.128216Z", "shell.execute_reply": "2022-12-14T22:54:56.127514Z" }, "id": "168548eba841" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0.01591256 -0.08179533 -0.01888242 0.05522368]\n", " [ 0.01591256 -0.08179533 -0.01888242 0.05522368]], shape=(2, 4), dtype=float32)\n" ] } ], "source": [ "class Linear(keras.layers.Layer):\n", " def __init__(self, units=32, input_dim=32):\n", " super(Linear, self).__init__()\n", " self.w = self.add_weight(\n", " shape=(input_dim, units), initializer=\"random_normal\", trainable=True\n", " )\n", " self.b = self.add_weight(shape=(units,), initializer=\"zeros\", trainable=True)\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n", "\n", "\n", "x = tf.ones((2, 2))\n", "linear_layer = Linear(4, 2)\n", "y = linear_layer(x)\n", "print(y)" ] }, { "cell_type": "markdown", "metadata": { "id": "070ea9b4db6c" }, "source": [ "## 레이어는 훈련 불가능한 가중치를 가질 수 있습니다\n", "\n", "훈련 가능한 가중치 외에도 훈련 불가능한 가중치를 레이어에 추가할 수 있습니다. 이러한 가중치는 레이어를 훈련할 때 역전파 동안 고려되지 않아야 합니다.\n", "\n", "훈련 불가능한 가중치를 추가 및 사용하는 방법은 다음과 같습니다." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.131723Z", "iopub.status.busy": "2022-12-14T22:54:56.131127Z", "iopub.status.idle": "2022-12-14T22:54:56.144690Z", "shell.execute_reply": "2022-12-14T22:54:56.143992Z" }, "id": "7c4cb404145f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2. 2.]\n", "[4. 4.]\n" ] } ], "source": [ "class ComputeSum(keras.layers.Layer):\n", " def __init__(self, input_dim):\n", " super(ComputeSum, self).__init__()\n", " self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)\n", "\n", " def call(self, inputs):\n", " self.total.assign_add(tf.reduce_sum(inputs, axis=0))\n", " return self.total\n", "\n", "\n", "x = tf.ones((2, 2))\n", "my_sum = ComputeSum(2)\n", "y = my_sum(x)\n", "print(y.numpy())\n", "y = my_sum(x)\n", "print(y.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "40f5b74d3d87" }, "source": [ "`layer.weights`의 일부이지만, 훈련 불가능한 가중치로 분류됩니다." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.147613Z", "iopub.status.busy": "2022-12-14T22:54:56.147394Z", "iopub.status.idle": "2022-12-14T22:54:56.151610Z", "shell.execute_reply": "2022-12-14T22:54:56.150750Z" }, "id": "3d4db4ef4fa4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "weights: 1\n", "non-trainable weights: 1\n", "trainable_weights: []\n" ] } ], "source": [ "print(\"weights:\", len(my_sum.weights))\n", "print(\"non-trainable weights:\", len(my_sum.non_trainable_weights))\n", "\n", "# It's not included in the trainable weights:\n", "print(\"trainable_weights:\", my_sum.trainable_weights)" ] }, { "cell_type": "markdown", "metadata": { "id": "fe6942aff7c6" }, "source": [ "## 모범 사례: 입력 형상이 알려질 때까지 가중치 생성 지연하기\n", "\n", "위의 `Linear` 레이어는 `__init__()`에서 가중치 `w` 및 `b`의 형상을 계산하는 데 사용되는 `input_dim` 인수를 사용했습니다." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.154978Z", "iopub.status.busy": "2022-12-14T22:54:56.154400Z", "iopub.status.idle": "2022-12-14T22:54:56.158750Z", "shell.execute_reply": "2022-12-14T22:54:56.158005Z" }, "id": "275b68d5ea9f" }, "outputs": [], "source": [ "class Linear(keras.layers.Layer):\n", " def __init__(self, units=32, input_dim=32):\n", " super(Linear, self).__init__()\n", " self.w = self.add_weight(\n", " shape=(input_dim, units), initializer=\"random_normal\", trainable=True\n", " )\n", " self.b = self.add_weight(shape=(units,), initializer=\"zeros\", trainable=True)\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n" ] }, { "cell_type": "markdown", "metadata": { "id": "5ebcacebb348" }, "source": [ "대부분의 경우, 입력의 크기를 미리 알지 못할 수 있으며, 레이어를 인스턴스화한 후 얼마 지나지 않아 해당 값을 알게 되면 가중치를 지연 생성하고자 합니다.\n", "\n", "Keras API에서는 레이어의 `build(self, inputs_shape)` 메서드에서 레이어 가중치를 만드는 것이 좋습니다. 다음과 같습니다." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.162362Z", "iopub.status.busy": "2022-12-14T22:54:56.161725Z", "iopub.status.idle": "2022-12-14T22:54:56.166640Z", "shell.execute_reply": "2022-12-14T22:54:56.165866Z" }, "id": "118c899f427e" }, "outputs": [], "source": [ "class Linear(keras.layers.Layer):\n", " def __init__(self, units=32):\n", " super(Linear, self).__init__()\n", " self.units = units\n", "\n", " def build(self, input_shape):\n", " self.w = self.add_weight(\n", " shape=(input_shape[-1], self.units),\n", " initializer=\"random_normal\",\n", " trainable=True,\n", " )\n", " self.b = self.add_weight(\n", " shape=(self.units,), initializer=\"random_normal\", trainable=True\n", " )\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n" ] }, { "cell_type": "markdown", "metadata": { "id": "78061e0583c6" }, "source": [ "레이어의 `__call__()` 메서드는 처음 호출될 때 자동으로 빌드를 실행합니다. 지연되어 사용하기 쉬운 레이어입니다." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.169710Z", "iopub.status.busy": "2022-12-14T22:54:56.169235Z", "iopub.status.idle": "2022-12-14T22:54:56.181608Z", "shell.execute_reply": "2022-12-14T22:54:56.181038Z" }, "id": "0697afb97bc1" }, "outputs": [], "source": [ "# At instantiation, we don't know on what inputs this is going to get called\n", "linear_layer = Linear(32)\n", "\n", "# The layer's weights are created dynamically the first time the layer is called\n", "y = linear_layer(x)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "51b81f42b466" }, "source": [ "위에 표시된 대로 `build()`를 별도로 구현하면 가중치를 한 번만 생성하는 것과 모든 호출에서 가중치를 사용하는 것을 잘 구분할 수 있습니다. 그러나 일부 고급 사용자 지정 레이어의 경우 상태 생성과 계산을 분리하는 것이 비실용적일 수 있습니다. 레이어 구현자는 가중치 생성을 첫 번째 `__call__()`로 연기할 수 있지만 이후 호출에서 동일한 가중치를 사용하도록 주의해야 합니다. 또한 `__call__()`은 `tf.function` 내부에서 처음으로 실행될 가능성이 높기 때문에 `__call__()`에서 발생하는 모든 변수 생성은 `tf.init_scope`로 래핑되어야 합니다." ] }, { "cell_type": "markdown", "metadata": { "id": "0b7a45f57610" }, "source": [ "## 재귀적으로 구성 가능한 레이어\n", "\n", "또 다른 인스턴스의 속성으로 Layer 인스턴스를 할당하면 외부 레이어가 내부 레이어로 생성한 가중치를 추적하기 시작합니다.\n", "\n", "`__init__()` 메서드에서 이러한 서브 레이어를 만들고 가중치를 빌드하도록 트리거할 수 있게 첫 번째 `__call__()`에 그대로 두는 것이 좋습니다." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.185139Z", "iopub.status.busy": "2022-12-14T22:54:56.184551Z", "iopub.status.idle": "2022-12-14T22:54:56.214947Z", "shell.execute_reply": "2022-12-14T22:54:56.214203Z" }, "id": "1aaaf82ab8ce" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "weights: 6\n", "trainable weights: 6\n" ] } ], "source": [ "class MLPBlock(keras.layers.Layer):\n", " def __init__(self):\n", " super(MLPBlock, self).__init__()\n", " self.linear_1 = Linear(32)\n", " self.linear_2 = Linear(32)\n", " self.linear_3 = Linear(1)\n", "\n", " def call(self, inputs):\n", " x = self.linear_1(inputs)\n", " x = tf.nn.relu(x)\n", " x = self.linear_2(x)\n", " x = tf.nn.relu(x)\n", " return self.linear_3(x)\n", "\n", "\n", "mlp = MLPBlock()\n", "y = mlp(tf.ones(shape=(3, 64))) # The first call to the `mlp` will create the weights\n", "print(\"weights:\", len(mlp.weights))\n", "print(\"trainable weights:\", len(mlp.trainable_weights))" ] }, { "cell_type": "markdown", "metadata": { "id": "2bf11b296bd2" }, "source": [ "## `add_loss()` 메서드\n", "\n", "레이어의 `call()` 메서드를 작성할 때는 훈련 루프를 작성할 때 나중에 사용하려는 손실 텐서를 만들 수 있습니다. `self.add_loss(value)`를 호출하면 됩니다." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.218322Z", "iopub.status.busy": "2022-12-14T22:54:56.217743Z", "iopub.status.idle": "2022-12-14T22:54:56.222011Z", "shell.execute_reply": "2022-12-14T22:54:56.221249Z" }, "id": "ba2782dc0879" }, "outputs": [], "source": [ "# A layer that creates an activity regularization loss\n", "class ActivityRegularizationLayer(keras.layers.Layer):\n", " def __init__(self, rate=1e-2):\n", " super(ActivityRegularizationLayer, self).__init__()\n", " self.rate = rate\n", "\n", " def call(self, inputs):\n", " self.add_loss(self.rate * tf.reduce_sum(inputs))\n", " return inputs\n" ] }, { "cell_type": "markdown", "metadata": { "id": "a883b230a9e9" }, "source": [ "이러한 손실(내부 레이어에서 생성된 손실 포함)은 `layer.losses`를 통해 검색할 수 있습니다. 이 속성은 모든 `__call__()`이 시작될 때 최상위 레이어로 재설정되므로 `layer.losses`에는 항상 마지막 정방향 패스에서 생성된 손실값이 포함됩니다." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.225372Z", "iopub.status.busy": "2022-12-14T22:54:56.224784Z", "iopub.status.idle": "2022-12-14T22:54:56.234226Z", "shell.execute_reply": "2022-12-14T22:54:56.233656Z" }, "id": "b56d223a30cd" }, "outputs": [], "source": [ "class OuterLayer(keras.layers.Layer):\n", " def __init__(self):\n", " super(OuterLayer, self).__init__()\n", " self.activity_reg = ActivityRegularizationLayer(1e-2)\n", "\n", " def call(self, inputs):\n", " return self.activity_reg(inputs)\n", "\n", "\n", "layer = OuterLayer()\n", "assert len(layer.losses) == 0 # No losses yet since the layer has never been called\n", "\n", "_ = layer(tf.zeros(1, 1))\n", "assert len(layer.losses) == 1 # We created one loss value\n", "\n", "# `layer.losses` gets reset at the start of each __call__\n", "_ = layer(tf.zeros(1, 1))\n", "assert len(layer.losses) == 1 # This is the loss created during the call above" ] }, { "cell_type": "markdown", "metadata": { "id": "0809dec680ff" }, "source": [ "또한, `loss` 속성에는 내부 레이어의 가중치에 대해 생성된 정규화 손실도 포함됩니다." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.237473Z", "iopub.status.busy": "2022-12-14T22:54:56.236915Z", "iopub.status.idle": "2022-12-14T22:54:56.256562Z", "shell.execute_reply": "2022-12-14T22:54:56.255874Z" }, "id": "41016153e983" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[]\n" ] } ], "source": [ "class OuterLayerWithKernelRegularizer(keras.layers.Layer):\n", " def __init__(self):\n", " super(OuterLayerWithKernelRegularizer, self).__init__()\n", " self.dense = keras.layers.Dense(\n", " 32, kernel_regularizer=tf.keras.regularizers.l2(1e-3)\n", " )\n", "\n", " def call(self, inputs):\n", " return self.dense(inputs)\n", "\n", "\n", "layer = OuterLayerWithKernelRegularizer()\n", "_ = layer(tf.zeros((1, 1)))\n", "\n", "# This is `1e-3 * sum(layer.dense.kernel ** 2)`,\n", "# created by the `kernel_regularizer` above.\n", "print(layer.losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "589465e06e4f" }, "source": [ "이러한 손실은 다음과 같이 훈련 루프를 작성할 때 고려됩니다.\n", "\n", "```python\n", "# Instantiate an optimizer.\n", "optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)\n", "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", "\n", "# Iterate over the batches of a dataset.\n", "for x_batch_train, y_batch_train in train_dataset:\n", " with tf.GradientTape() as tape:\n", " logits = layer(x_batch_train) # Logits for this minibatch\n", " # Loss value for this minibatch\n", " loss_value = loss_fn(y_batch_train, logits)\n", " # Add extra losses created during this forward pass:\n", " loss_value += sum(model.losses)\n", "\n", " grads = tape.gradient(loss_value, model.trainable_weights)\n", " optimizer.apply_gradients(zip(grads, model.trainable_weights))\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "7fb41ca8c3b0" }, "source": [ "훈련 루프 작성에 대한 자세한 가이드는 [처음부터 훈련 루프 작성하기 가이드](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch/)를 참조하세요.\n", "\n", "이러한 손실은 `fit()`에서도 완벽하게 작동합니다(손실이 있는 경우, 자동으로 합산되어 주 손실에 추가됨)." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.260114Z", "iopub.status.busy": "2022-12-14T22:54:56.259483Z", "iopub.status.idle": "2022-12-14T22:54:56.658971Z", "shell.execute_reply": "2022-12-14T22:54:56.658230Z" }, "id": "769bc6612ebf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.1996" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 179ms/step - loss: 0.1996\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.0425" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 34ms/step - loss: 0.0425\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "inputs = keras.Input(shape=(3,))\n", "outputs = ActivityRegularizationLayer()(inputs)\n", "model = keras.Model(inputs, outputs)\n", "\n", "# If there is a loss passed in `compile`, the regularization\n", "# losses get added to it\n", "model.compile(optimizer=\"adam\", loss=\"mse\")\n", "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n", "\n", "# It's also possible not to pass any loss in `compile`,\n", "# since the model already has a loss to minimize, via the `add_loss`\n", "# call during the forward pass!\n", "model.compile(optimizer=\"adam\")\n", "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))" ] }, { "cell_type": "markdown", "metadata": { "id": "149c71e442bb" }, "source": [ "## `add_metric()` 메서드\n", "\n", "`add_loss()`와 마찬가지로, 레이어에는 훈련 중 수량의 이동 평균을 추적하기 위한 `add_metric()` 메서드도 있습니다.\n", "\n", "다음 \"로지스틱 엔드포인트\" 레이어를 고려합니다. 입력 예측 및 목표치로 사용하여 `add_loss()`를 통해 추적하는 손실을 계산하고 `add_metric()`을 통해 추적하는 정확도 스칼라를 계산합니다." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.662323Z", "iopub.status.busy": "2022-12-14T22:54:56.662021Z", "iopub.status.idle": "2022-12-14T22:54:56.667565Z", "shell.execute_reply": "2022-12-14T22:54:56.666879Z" }, "id": "bfb2df515096" }, "outputs": [], "source": [ "class LogisticEndpoint(keras.layers.Layer):\n", " def __init__(self, name=None):\n", " super(LogisticEndpoint, self).__init__(name=name)\n", " self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n", " self.accuracy_fn = keras.metrics.BinaryAccuracy()\n", "\n", " def call(self, targets, logits, sample_weights=None):\n", " # Compute the training-time loss value and add it\n", " # to the layer using `self.add_loss()`.\n", " loss = self.loss_fn(targets, logits, sample_weights)\n", " self.add_loss(loss)\n", "\n", " # Log accuracy as a metric and add it\n", " # to the layer using `self.add_metric()`.\n", " acc = self.accuracy_fn(targets, logits, sample_weights)\n", " self.add_metric(acc, name=\"accuracy\")\n", "\n", " # Return the inference-time prediction tensor (for `.predict()`).\n", " return tf.nn.softmax(logits)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "e68f88373800" }, "source": [ "이러한 방식으로 추적되는 메트릭은 `layer.metrics`를 통해 액세스할 수 있습니다." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.671040Z", "iopub.status.busy": "2022-12-14T22:54:56.670365Z", "iopub.status.idle": "2022-12-14T22:54:56.700373Z", "shell.execute_reply": "2022-12-14T22:54:56.699661Z" }, "id": "1834d74450b6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "layer.metrics: []\n", "current accuracy value: 1.0\n" ] } ], "source": [ "layer = LogisticEndpoint()\n", "\n", "targets = tf.ones((2, 2))\n", "logits = tf.ones((2, 2))\n", "y = layer(targets, logits)\n", "\n", "print(\"layer.metrics:\", layer.metrics)\n", "print(\"current accuracy value:\", float(layer.metrics[0].result()))" ] }, { "cell_type": "markdown", "metadata": { "id": "546cfbd4ea05" }, "source": [ "`add_loss()`와 마찬가지로, 이러한 메트릭은 `fit()`의해 추적됩니다." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:56.703608Z", "iopub.status.busy": "2022-12-14T22:54:56.703112Z", "iopub.status.idle": "2022-12-14T22:54:57.441661Z", "shell.execute_reply": "2022-12-14T22:54:57.440892Z" }, "id": "f5e74cb4da34" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.8806 - binary_accuracy: 0.0000e+00" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 1s 588ms/step - loss: 0.8806 - binary_accuracy: 0.0000e+00\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = keras.Input(shape=(3,), name=\"inputs\")\n", "targets = keras.Input(shape=(10,), name=\"targets\")\n", "logits = keras.layers.Dense(10)(inputs)\n", "predictions = LogisticEndpoint(name=\"predictions\")(logits, targets)\n", "\n", "model = keras.Model(inputs=[inputs, targets], outputs=predictions)\n", "model.compile(optimizer=\"adam\")\n", "\n", "data = {\n", " \"inputs\": np.random.random((3, 3)),\n", " \"targets\": np.random.random((3, 10)),\n", "}\n", "model.fit(data)" ] }, { "cell_type": "markdown", "metadata": { "id": "4012fa8683e5" }, "source": [ "## 레이어에서 선택적으로 직렬화를 활성화할 수 있습니다\n", "\n", "[함수 모델](https://www.tensorflow.org/guide/keras/functional/)의 일부로 사용자 정의 레이어를 직렬화해야 하는 경우, 선택적으로 `get_config()` 메서드를 구현할 수 있습니다." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:57.445496Z", "iopub.status.busy": "2022-12-14T22:54:57.444807Z", "iopub.status.idle": "2022-12-14T22:54:57.452763Z", "shell.execute_reply": "2022-12-14T22:54:57.451990Z" }, "id": "0a720cbd5f54" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'units': 64}\n" ] } ], "source": [ "class Linear(keras.layers.Layer):\n", " def __init__(self, units=32):\n", " super(Linear, self).__init__()\n", " self.units = units\n", "\n", " def build(self, input_shape):\n", " self.w = self.add_weight(\n", " shape=(input_shape[-1], self.units),\n", " initializer=\"random_normal\",\n", " trainable=True,\n", " )\n", " self.b = self.add_weight(\n", " shape=(self.units,), initializer=\"random_normal\", trainable=True\n", " )\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n", "\n", " def get_config(self):\n", " return {\"units\": self.units}\n", "\n", "\n", "# Now you can recreate the layer from its config:\n", "layer = Linear(64)\n", "config = layer.get_config()\n", "print(config)\n", "new_layer = Linear.from_config(config)" ] }, { "cell_type": "markdown", "metadata": { "id": "1b43aad6c145" }, "source": [ "기본 `Layer` 클래스의 `__init__()` 메서드는 일부 키워드 인수, 특히 `name` 및 `dtype`를 사용합니다. 이러한 인수를 `__init__()`의 부모 클래스에 전달하고 레이어 구성에 포함하는 것이 좋습니다." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:57.456108Z", "iopub.status.busy": "2022-12-14T22:54:57.455637Z", "iopub.status.idle": "2022-12-14T22:54:57.462563Z", "shell.execute_reply": "2022-12-14T22:54:57.461867Z" }, "id": "0cbad8a6e6cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'name': 'linear_8', 'trainable': True, 'dtype': 'float32', 'units': 64}\n" ] } ], "source": [ "class Linear(keras.layers.Layer):\n", " def __init__(self, units=32, **kwargs):\n", " super(Linear, self).__init__(**kwargs)\n", " self.units = units\n", "\n", " def build(self, input_shape):\n", " self.w = self.add_weight(\n", " shape=(input_shape[-1], self.units),\n", " initializer=\"random_normal\",\n", " trainable=True,\n", " )\n", " self.b = self.add_weight(\n", " shape=(self.units,), initializer=\"random_normal\", trainable=True\n", " )\n", "\n", " def call(self, inputs):\n", " return tf.matmul(inputs, self.w) + self.b\n", "\n", " def get_config(self):\n", " config = super(Linear, self).get_config()\n", " config.update({\"units\": self.units})\n", " return config\n", "\n", "\n", "layer = Linear(64)\n", "config = layer.get_config()\n", "print(config)\n", "new_layer = Linear.from_config(config)" ] }, { "cell_type": "markdown", "metadata": { "id": "2421f80b5b86" }, "source": [ "구성에서 레이어를 역직렬화할 때 유연성이 더 필요한 경우, `from_config()` 클래스 메서드를 재정의할 수도 있습니다. 다음은 `from_config()`의 기본 구현입니다.\n", "\n", "```python\n", "def from_config(cls, config):\n", " return cls(**config)\n", "```\n", "\n", "직렬화 및 저장에 대한 자세한 내용은 [모델 저장 및 직렬화 가이드](https://www.tensorflow.org/guide/keras/save_and_serialize/)를 참조하세요." ] }, { "cell_type": "markdown", "metadata": { "id": "3d7e2304a047" }, "source": [ "## `call()` 메서드의 권한 있는 `training` 인수\n", "\n", "일부 레이어, 특히 `BatchNormalization` 레이어와 `Dropout` 레이어는 훈련 및 추론 중에 서로 다른 동작을 갖습니다. 이러한 레이어의 경우, `call()` 메서드에서 `training`(boolean) 인수를 노출하는 것이 표준 관행입니다.\n", "\n", "이 인수를 `call()`에서 노출하면 내장 훈련 및 평가 루프(예: `fit()`)를 사용하여 훈련 및 추론에서 레이어를 올바르게 사용할 수 있습니다." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:57.466251Z", "iopub.status.busy": "2022-12-14T22:54:57.465771Z", "iopub.status.idle": "2022-12-14T22:54:57.470084Z", "shell.execute_reply": "2022-12-14T22:54:57.469326Z" }, "id": "a169812c2c00" }, "outputs": [], "source": [ "class CustomDropout(keras.layers.Layer):\n", " def __init__(self, rate, **kwargs):\n", " super(CustomDropout, self).__init__(**kwargs)\n", " self.rate = rate\n", "\n", " def call(self, inputs, training=None):\n", " if training:\n", " return tf.nn.dropout(inputs, rate=self.rate)\n", " return inputs\n" ] }, { "cell_type": "markdown", "metadata": { "id": "9e1482c9f010" }, "source": [ "## `call()` 메서드의 권한 있는 `mask` 인수\n", "\n", "`call()`에서 지원되는 다른 권한 있는 인수는 `mask` 인수입니다.\n", "\n", "이 인수는 모든 Keras RNN 레이어에서 볼 수 있습니다. 마스크는 시계열 데이터를 처리할 때 특정 입력 타임스텝을 건너뛰는 데 사용되는 부울 텐서(입력의 타임스텝당 하나의 부울 값)입니다.\n", "\n", "Keras는 이전 레이어에서 마스크가 생성될 때 이를 지원하는 레이어에 대해 올바른 `mask` 인수를 `__call__()`에 자동으로 전달합니다. 마스크 생성 레이어는 `mask_zero=True` 레이어와 `Masking` 레이어로 구성된 `Embedding`입니다.\n", "\n", "마스킹 및 마스킹 지원 레이어를 작성하는 방법에 대한 자세한 내용은 [\"패딩 및 마스킹 이해하기\"](https://www.tensorflow.org/guide/keras/masking_and_padding/) 가이드를 확인하세요." ] }, { "cell_type": "markdown", "metadata": { "id": "344110f9e134" }, "source": [ "## `Model` 클래스\n", "\n", "일반적으로, `Layer` 클래스를 사용하여 내부 계산 블록을 정의하고 `Model` 클래스를 사용하여 훈련할 객체인 외부 모델을 정의합니다.\n", "\n", "예를 들어, ResNet50 모델에는 `Layer`를 하위 클래스화하는 여러 ResNet 블록과 전체 ResNet50 네트워크를 포괄하는 단일 `Model`이 있습니다.\n", "\n", "`Model` 클래스는 `Layer`와 같은 API를 가지며, 다음과 같은 차이점이 있습니다.\n", "\n", "- 내장 훈련, 평가 및 예측 루프( `model.fit()` , `model.evaluate()`, `model.predict()`)를 제공합니다.\n", "- `model.layers` 속성을 통해 내부 레이어의 목록을 노출합니다.\n", "- 저장 및 직렬화 API(`save()`, `save_weights()`...)를 노출합니다.\n", "\n", "효과적으로, `Layer` 클래스는 문서에서 일컫는 \"레이어\"(\"컨볼루션 레이어\" 또는 \"되풀이 레이어\"에서와 같이) 또는 \"블록\"(\"ResNet 블록\" 또는 \"Inception 블록\"에서와 같이)에 해당합니다.\n", "\n", "한편, `Model` 클래스는 문서에서 \"모델\"(\"딥 러닝 모델\"에서) 또는 \"네트워크\"( \"딥 신경망\"에서)로 지칭되는 것에 해당합다.\n", "\n", "\"`Layer` 클래스를 사용해야 할까요? 아니면 `Model` 클래스를 사용해야 할까요?\"라는 질문이 있다면 자문해 보세요. `fit()`을 호출해야 할까? `save()`를 호출해야 할까? 만약 그렇다면 `Model`를 사용하세요. 그렇지 않다면(클래스가 더 큰 시스템의 블록이거나 직접 훈련을 작성하고 코드를 저장하기 때문에) `Layer`를 사용하세요.\n", "\n", "예를 들어, 위의 mini-resnet 예제를 사용하여 `fit()`으로 훈련하고 `save_weights()`로 저장할 수 있는 `Model`을 빌드할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "09caa642b72e" }, "source": [ "```python\n", "class ResNet(tf.keras.Model):\n", "\n", " def __init__(self, num_classes=1000):\n", " super(ResNet, self).__init__()\n", " self.block_1 = ResNetBlock()\n", " self.block_2 = ResNetBlock()\n", " self.global_pool = layers.GlobalAveragePooling2D()\n", " self.classifier = Dense(num_classes)\n", "\n", " def call(self, inputs):\n", " x = self.block_1(inputs)\n", " x = self.block_2(x)\n", " x = self.global_pool(x)\n", " return self.classifier(x)\n", "\n", "\n", "resnet = ResNet()\n", "dataset = ...\n", "resnet.fit(dataset, epochs=10)\n", "resnet.save(filepath)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "a2e32d225a1b" }, "source": [ "## 종합: 엔드 투 엔드 예제\n", "\n", "지금까지 배운 내용은 다음과 같습니다.\n", "\n", "- `Layer`는 상태(`__init__()` 또는 `build()`) 및 일부 계산(`call()`에서 정의)을 캡슐화합니다.\n", "- 레이어를 재귀적으로 중첩하여 새롭고 더 큰 계산 블록을 만들 수 있습니다.\n", "- 레이어는 `add_loss()` 및 `add_metric()`을 통해 메트릭뿐만 아니라 손실(일반적으로, 정규화 손실)을 생성 및 추적할 수 있습니다.\n", "- 훈련하려는 외부 컨테이너는 `Model`입니다. `Model`은 `Layer`와 비슷하지만, 훈련 및 직렬화 유틸리티가 추가되었습니다.\n", "\n", "이 모든 것을 엔드 투 엔드 예제에 넣어봅시다. VAE(Variational AutoEncoder)를 구현할 것이며, MNIST 숫자로 훈련할 것입니다.\n", "\n", "VAE는 `Model`의 서브 클래스가 될 것이며 `Layer`를 하위 클래스화하는 중첩된 레이어 구성으로 빌드됩니다. 정규화 손실(KL 확산)을 제공합니다." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:57.473725Z", "iopub.status.busy": "2022-12-14T22:54:57.473085Z", "iopub.status.idle": "2022-12-14T22:54:57.483727Z", "shell.execute_reply": "2022-12-14T22:54:57.483056Z" }, "id": "56aaae7af872" }, "outputs": [], "source": [ "from tensorflow.keras import layers\n", "\n", "\n", "class Sampling(layers.Layer):\n", " \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n", "\n", " def call(self, inputs):\n", " z_mean, z_log_var = inputs\n", " batch = tf.shape(z_mean)[0]\n", " dim = tf.shape(z_mean)[1]\n", " epsilon = tf.keras.backend.random_normal(shape=(batch, dim))\n", " return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n", "\n", "\n", "class Encoder(layers.Layer):\n", " \"\"\"Maps MNIST digits to a triplet (z_mean, z_log_var, z).\"\"\"\n", "\n", " def __init__(self, latent_dim=32, intermediate_dim=64, name=\"encoder\", **kwargs):\n", " super(Encoder, self).__init__(name=name, **kwargs)\n", " self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n", " self.dense_mean = layers.Dense(latent_dim)\n", " self.dense_log_var = layers.Dense(latent_dim)\n", " self.sampling = Sampling()\n", "\n", " def call(self, inputs):\n", " x = self.dense_proj(inputs)\n", " z_mean = self.dense_mean(x)\n", " z_log_var = self.dense_log_var(x)\n", " z = self.sampling((z_mean, z_log_var))\n", " return z_mean, z_log_var, z\n", "\n", "\n", "class Decoder(layers.Layer):\n", " \"\"\"Converts z, the encoded digit vector, back into a readable digit.\"\"\"\n", "\n", " def __init__(self, original_dim, intermediate_dim=64, name=\"decoder\", **kwargs):\n", " super(Decoder, self).__init__(name=name, **kwargs)\n", " self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n", " self.dense_output = layers.Dense(original_dim, activation=\"sigmoid\")\n", "\n", " def call(self, inputs):\n", " x = self.dense_proj(inputs)\n", " return self.dense_output(x)\n", "\n", "\n", "class VariationalAutoEncoder(keras.Model):\n", " \"\"\"Combines the encoder and decoder into an end-to-end model for training.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " original_dim,\n", " intermediate_dim=64,\n", " latent_dim=32,\n", " name=\"autoencoder\",\n", " **kwargs\n", " ):\n", " super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)\n", " self.original_dim = original_dim\n", " self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)\n", " self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)\n", "\n", " def call(self, inputs):\n", " z_mean, z_log_var, z = self.encoder(inputs)\n", " reconstructed = self.decoder(z)\n", " # Add KL divergence regularization loss.\n", " kl_loss = -0.5 * tf.reduce_mean(\n", " z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1\n", " )\n", " self.add_loss(kl_loss)\n", " return reconstructed\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2f8ae035a7c9" }, "source": [ "MNIST에 간단한 훈련 루프를 작성해 봅시다." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:57.487002Z", "iopub.status.busy": "2022-12-14T22:54:57.486431Z", "iopub.status.idle": "2022-12-14T22:55:41.697397Z", "shell.execute_reply": "2022-12-14T22:55:41.696691Z" }, "id": "40f11d1ef3bc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Start of epoch 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 5 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:6 out of the last 6 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 0: mean loss = 0.3124\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 100: mean loss = 0.1235\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 200: mean loss = 0.0981\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 300: mean loss = 0.0885\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 400: mean loss = 0.0837\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 500: mean loss = 0.0805\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 600: mean loss = 0.0784\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 700: mean loss = 0.0768\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 800: mean loss = 0.0757\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 900: mean loss = 0.0747\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Start of epoch 1\n", "step 0: mean loss = 0.0744\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 100: mean loss = 0.0738\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 200: mean loss = 0.0733\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 300: mean loss = 0.0729\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 400: mean loss = 0.0725\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 500: mean loss = 0.0722\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 600: mean loss = 0.0719\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 700: mean loss = 0.0716\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 800: mean loss = 0.0714\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "step 900: mean loss = 0.0711\n" ] } ], "source": [ "original_dim = 784\n", "vae = VariationalAutoEncoder(original_dim, 64, 32)\n", "\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n", "mse_loss_fn = tf.keras.losses.MeanSquaredError()\n", "\n", "loss_metric = tf.keras.metrics.Mean()\n", "\n", "(x_train, _), _ = tf.keras.datasets.mnist.load_data()\n", "x_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\n", "\n", "train_dataset = tf.data.Dataset.from_tensor_slices(x_train)\n", "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n", "\n", "epochs = 2\n", "\n", "# Iterate over epochs.\n", "for epoch in range(epochs):\n", " print(\"Start of epoch %d\" % (epoch,))\n", "\n", " # Iterate over the batches of the dataset.\n", " for step, x_batch_train in enumerate(train_dataset):\n", " with tf.GradientTape() as tape:\n", " reconstructed = vae(x_batch_train)\n", " # Compute reconstruction loss\n", " loss = mse_loss_fn(x_batch_train, reconstructed)\n", " loss += sum(vae.losses) # Add KLD regularization loss\n", "\n", " grads = tape.gradient(loss, vae.trainable_weights)\n", " optimizer.apply_gradients(zip(grads, vae.trainable_weights))\n", "\n", " loss_metric(loss)\n", "\n", " if step % 100 == 0:\n", " print(\"step %d: mean loss = %.4f\" % (step, loss_metric.result()))" ] }, { "cell_type": "markdown", "metadata": { "id": "f0d65fae5d3d" }, "source": [ "VAE는 `Model`을 하위 클래스화하기 때문에 내장된 훈련 루프를 제공합니다. 따라서 다음과 같이 훈련할 수도 있습니다." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:55:41.701889Z", "iopub.status.busy": "2022-12-14T22:55:41.701254Z", "iopub.status.idle": "2022-12-14T22:55:49.277320Z", "shell.execute_reply": "2022-12-14T22:55:49.276673Z" }, "id": "5af13f70d528" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/938 [..............................] - ETA: 25:26 - loss: 0.3454" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 18/938 [..............................] - ETA: 2s - loss: 0.2274 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 36/938 [>.............................] - ETA: 2s - loss: 0.1893" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 55/938 [>.............................] - ETA: 2s - loss: 0.1625" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 74/938 [=>............................] - ETA: 2s - loss: 0.1439" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 93/938 [=>............................] - ETA: 2s - loss: 0.1308" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "112/938 [==>...........................] - ETA: 2s - loss: 0.1215" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "131/938 [===>..........................] - ETA: 2s - loss: 0.1145" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "150/938 [===>..........................] - ETA: 2s - loss: 0.1091" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "169/938 [====>.........................] - ETA: 2s - loss: 0.1048" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "188/938 [=====>........................] - ETA: 2s - loss: 0.1014" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "207/938 [=====>........................] - ETA: 2s - loss: 0.0985" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "226/938 [======>.......................] - ETA: 1s - loss: 0.0960" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "244/938 [======>.......................] - ETA: 1s - loss: 0.0941" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "263/938 [=======>......................] - ETA: 1s - loss: 0.0923" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "282/938 [========>.....................] - ETA: 1s - loss: 0.0907" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "301/938 [========>.....................] - ETA: 1s - loss: 0.0893" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "320/938 [=========>....................] - ETA: 1s - loss: 0.0881" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "339/938 [=========>....................] - ETA: 1s - loss: 0.0870" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "357/938 [==========>...................] - ETA: 1s - loss: 0.0861" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "376/938 [===========>..................] - ETA: 1s - loss: 0.0852" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "395/938 [===========>..................] - ETA: 1s - loss: 0.0844" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "414/938 [============>.................] - ETA: 1s - loss: 0.0837" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "433/938 [============>.................] - ETA: 1s - loss: 0.0830" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "452/938 [=============>................] - ETA: 1s - loss: 0.0823" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "471/938 [==============>...............] - ETA: 1s - loss: 0.0818" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "489/938 [==============>...............] - ETA: 1s - loss: 0.0813" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "508/938 [===============>..............] - ETA: 1s - loss: 0.0807" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "527/938 [===============>..............] - ETA: 1s - loss: 0.0803" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "546/938 [================>.............] - ETA: 1s - loss: 0.0798" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "565/938 [=================>............] - ETA: 1s - loss: 0.0794" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "584/938 [=================>............] - ETA: 0s - loss: 0.0790" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "603/938 [==================>...........] - ETA: 0s - loss: 0.0787" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "622/938 [==================>...........] - ETA: 0s - loss: 0.0784" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "641/938 [===================>..........] - ETA: 0s - loss: 0.0781" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "660/938 [====================>.........] - ETA: 0s - loss: 0.0778" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "679/938 [====================>.........] - ETA: 0s - loss: 0.0775" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "698/938 [=====================>........] - ETA: 0s - loss: 0.0772" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "717/938 [=====================>........] - ETA: 0s - loss: 0.0770" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "736/938 [======================>.......] - ETA: 0s - loss: 0.0768" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "755/938 [=======================>......] - ETA: 0s - loss: 0.0765" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "774/938 [=======================>......] - ETA: 0s - loss: 0.0763" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "793/938 [========================>.....] - ETA: 0s - loss: 0.0761" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "812/938 [========================>.....] - ETA: 0s - loss: 0.0759" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "830/938 [=========================>....] - ETA: 0s - loss: 0.0757" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "849/938 [==========================>...] - ETA: 0s - loss: 0.0756" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "868/938 [==========================>...] - ETA: 0s - loss: 0.0754" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "887/938 [===========================>..] - ETA: 0s - loss: 0.0753" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "906/938 [===========================>..] - ETA: 0s - loss: 0.0751" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "924/938 [============================>.] - ETA: 0s - loss: 0.0750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - 4s 3ms/step - loss: 0.0749\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/938 [..............................] - ETA: 3s - loss: 0.0700" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 20/938 [..............................] - ETA: 2s - loss: 0.0674" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 39/938 [>.............................] - ETA: 2s - loss: 0.0674" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 58/938 [>.............................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 77/938 [=>............................] - ETA: 2s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 96/938 [==>...........................] - ETA: 2s - loss: 0.0678" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "114/938 [==>...........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "132/938 [===>..........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "150/938 [===>..........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "168/938 [====>.........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "187/938 [====>.........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "206/938 [=====>........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "225/938 [======>.......................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "244/938 [======>.......................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "263/938 [=======>......................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "282/938 [========>.....................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "301/938 [========>.....................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "320/938 [=========>....................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "339/938 [=========>....................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "357/938 [==========>...................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "375/938 [==========>...................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "393/938 [===========>..................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "411/938 [============>.................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "430/938 [============>.................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "449/938 [=============>................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "468/938 [=============>................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "487/938 [==============>...............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "505/938 [===============>..............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "524/938 [===============>..............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "543/938 [================>.............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "562/938 [================>.............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "581/938 [=================>............] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "600/938 [==================>...........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "619/938 [==================>...........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "638/938 [===================>..........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "657/938 [====================>.........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "676/938 [====================>.........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "695/938 [=====================>........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "714/938 [=====================>........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "733/938 [======================>.......] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "752/938 [=======================>......] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "771/938 [=======================>......] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "790/938 [========================>.....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "808/938 [========================>.....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "826/938 [=========================>....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "844/938 [=========================>....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "862/938 [==========================>...] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "881/938 [===========================>..] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "900/938 [===========================>..] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "919/938 [============================>.] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - 3s 3ms/step - loss: 0.0677\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vae = VariationalAutoEncoder(784, 64, 32)\n", "\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n", "\n", "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n", "vae.fit(x_train, x_train, epochs=2, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": { "id": "d34b7ba21662" }, "source": [ "## 객체 지향 개발을 넘어: 함수형 API\n", "\n", "이 예제가 너무 지나친 객체 지향 개발입니까? [함수형 API(Functional API)](https://www.tensorflow.org/guide/keras/functional/)를 사용하여 모델을 빌드할 수도 있습니다. 중요한 것은 하나의 스타일을 선택한다고 해서 다른 스타일로 작성된 구성 요소를 활용하지 못하는 것은 아닙니다. 항상 목적에 따라 다르게 선택할 수 있습니다.\n", "\n", "예를 들어, 아래의 함수형 API 예제는 위 예제에서 정의한 것과 같은 `Sampling` 레이어를 재사용합니다." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:55:49.281214Z", "iopub.status.busy": "2022-12-14T22:55:49.280490Z", "iopub.status.idle": "2022-12-14T22:55:59.575666Z", "shell.execute_reply": "2022-12-14T22:55:59.574949Z" }, "id": "be77fc8f9b26" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/938 [..............................] - ETA: 23:29 - loss: 0.3185" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 19/938 [..............................] - ETA: 2s - loss: 0.2189 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 38/938 [>.............................] - ETA: 2s - loss: 0.1830" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 56/938 [>.............................] - ETA: 2s - loss: 0.1594" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 74/938 [=>............................] - ETA: 2s - loss: 0.1427" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 92/938 [=>............................] - ETA: 2s - loss: 0.1307" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "111/938 [==>...........................] - ETA: 2s - loss: 0.1215" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "129/938 [===>..........................] - ETA: 2s - loss: 0.1147" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "147/938 [===>..........................] - ETA: 2s - loss: 0.1096" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "165/938 [====>.........................] - ETA: 2s - loss: 0.1055" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "183/938 [====>.........................] - ETA: 2s - loss: 0.1020" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "201/938 [=====>........................] - ETA: 2s - loss: 0.0992" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "219/938 [======>.......................] - ETA: 2s - loss: 0.0968" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "237/938 [======>.......................] - ETA: 1s - loss: 0.0948" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "255/938 [=======>......................] - ETA: 1s - loss: 0.0929" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "274/938 [=======>......................] - ETA: 1s - loss: 0.0912" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "293/938 [========>.....................] - ETA: 1s - loss: 0.0898" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "311/938 [========>.....................] - ETA: 1s - loss: 0.0886" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "329/938 [=========>....................] - ETA: 1s - loss: 0.0874" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "348/938 [==========>...................] - ETA: 1s - loss: 0.0865" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "366/938 [==========>...................] - ETA: 1s - loss: 0.0856" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "384/938 [===========>..................] - ETA: 1s - loss: 0.0848" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "402/938 [===========>..................] - ETA: 1s - loss: 0.0841" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "420/938 [============>.................] - ETA: 1s - loss: 0.0833" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "439/938 [=============>................] - ETA: 1s - loss: 0.0827" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "457/938 [=============>................] - ETA: 1s - loss: 0.0821" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "475/938 [==============>...............] - ETA: 1s - loss: 0.0816" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "493/938 [==============>...............] - ETA: 1s - loss: 0.0811" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "511/938 [===============>..............] - ETA: 1s - loss: 0.0806" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "529/938 [===============>..............] - ETA: 1s - loss: 0.0802" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "547/938 [================>.............] - ETA: 1s - loss: 0.0798" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "565/938 [=================>............] - ETA: 1s - loss: 0.0794" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "583/938 [=================>............] - ETA: 1s - loss: 0.0790" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "601/938 [==================>...........] - ETA: 0s - loss: 0.0787" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "619/938 [==================>...........] - ETA: 0s - loss: 0.0784" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "637/938 [===================>..........] - ETA: 0s - loss: 0.0781" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "655/938 [===================>..........] - ETA: 0s - loss: 0.0778" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "673/938 [====================>.........] - ETA: 0s - loss: 0.0775" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "692/938 [=====================>........] - ETA: 0s - loss: 0.0773" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "710/938 [=====================>........] - ETA: 0s - loss: 0.0770" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "728/938 [======================>.......] - ETA: 0s - loss: 0.0768" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "746/938 [======================>.......] - ETA: 0s - loss: 0.0766" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "764/938 [=======================>......] - ETA: 0s - loss: 0.0764" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "783/938 [========================>.....] - ETA: 0s - loss: 0.0762" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "802/938 [========================>.....] - ETA: 0s - loss: 0.0760" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "820/938 [=========================>....] - ETA: 0s - loss: 0.0758" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "838/938 [=========================>....] - ETA: 0s - loss: 0.0756" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "857/938 [==========================>...] - ETA: 0s - loss: 0.0755" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "875/938 [==========================>...] - ETA: 0s - loss: 0.0753" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "894/938 [===========================>..] - ETA: 0s - loss: 0.0751" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "913/938 [============================>.] - ETA: 0s - loss: 0.0750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "931/938 [============================>.] - ETA: 0s - loss: 0.0749" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - 4s 3ms/step - loss: 0.0748\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/938 [..............................] - ETA: 3s - loss: 0.0670" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 19/938 [..............................] - ETA: 2s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 37/938 [>.............................] - ETA: 2s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 56/938 [>.............................] - ETA: 2s - loss: 0.0678" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 74/938 [=>............................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 92/938 [=>............................] - ETA: 2s - loss: 0.0678" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "110/938 [==>...........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "128/938 [===>..........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "146/938 [===>..........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "164/938 [====>.........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "182/938 [====>.........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "200/938 [=====>........................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "219/938 [======>.......................] - ETA: 2s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "237/938 [======>.......................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "256/938 [=======>......................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "274/938 [=======>......................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "292/938 [========>.....................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "311/938 [========>.....................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "329/938 [=========>....................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "348/938 [==========>...................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "366/938 [==========>...................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "384/938 [===========>..................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "403/938 [===========>..................] - ETA: 1s - loss: 0.0678" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "422/938 [============>.................] - ETA: 1s - loss: 0.0678" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "441/938 [=============>................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "459/938 [=============>................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "478/938 [==============>...............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "496/938 [==============>...............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "514/938 [===============>..............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "533/938 [================>.............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "551/938 [================>.............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "570/938 [=================>............] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "588/938 [=================>............] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "606/938 [==================>...........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "624/938 [==================>...........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "643/938 [===================>..........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "661/938 [====================>.........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "680/938 [====================>.........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "698/938 [=====================>........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "716/938 [=====================>........] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "734/938 [======================>.......] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "752/938 [=======================>......] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "770/938 [=======================>......] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "788/938 [========================>.....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "806/938 [========================>.....] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "824/938 [=========================>....] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "842/938 [=========================>....] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "860/938 [==========================>...] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "878/938 [===========================>..] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "896/938 [===========================>..] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "914/938 [============================>.] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "933/938 [============================>.] - ETA: 0s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - 3s 3ms/step - loss: 0.0676\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/938 [..............................] - ETA: 3s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 20/938 [..............................] - ETA: 2s - loss: 0.0683" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 38/938 [>.............................] - ETA: 2s - loss: 0.0681" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 56/938 [>.............................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 74/938 [=>............................] - ETA: 2s - loss: 0.0679" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 92/938 [=>............................] - ETA: 2s - loss: 0.0679" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "110/938 [==>...........................] - ETA: 2s - loss: 0.0678" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "128/938 [===>..........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "146/938 [===>..........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "164/938 [====>.........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "183/938 [====>.........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "202/938 [=====>........................] - ETA: 2s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "221/938 [======>.......................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "240/938 [======>.......................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "258/938 [=======>......................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "275/938 [=======>......................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "293/938 [========>.....................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "311/938 [========>.....................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "329/938 [=========>....................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "347/938 [==========>...................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "366/938 [==========>...................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "385/938 [===========>..................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "403/938 [===========>..................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "421/938 [============>.................] - ETA: 1s - loss: 0.0677" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "439/938 [=============>................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "457/938 [=============>................] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "475/938 [==============>...............] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "494/938 [==============>...............] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "513/938 [===============>..............] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "532/938 [================>.............] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "551/938 [================>.............] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "570/938 [=================>............] - ETA: 1s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "589/938 [=================>............] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "607/938 [==================>...........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "625/938 [==================>...........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "643/938 [===================>..........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "661/938 [====================>.........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "679/938 [====================>.........] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "697/938 [=====================>........] - ETA: 0s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "715/938 [=====================>........] - ETA: 0s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "733/938 [======================>.......] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "752/938 [=======================>......] - ETA: 0s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "770/938 [=======================>......] - ETA: 0s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "789/938 [========================>.....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "808/938 [========================>.....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "827/938 [=========================>....] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "845/938 [==========================>...] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "863/938 [==========================>...] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "881/938 [===========================>..] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "900/938 [===========================>..] - ETA: 0s - loss: 0.0675" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "919/938 [============================>.] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - ETA: 0s - loss: 0.0676" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "938/938 [==============================] - 3s 3ms/step - loss: 0.0676\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "original_dim = 784\n", "intermediate_dim = 64\n", "latent_dim = 32\n", "\n", "# Define encoder model.\n", "original_inputs = tf.keras.Input(shape=(original_dim,), name=\"encoder_input\")\n", "x = layers.Dense(intermediate_dim, activation=\"relu\")(original_inputs)\n", "z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n", "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n", "z = Sampling()((z_mean, z_log_var))\n", "encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name=\"encoder\")\n", "\n", "# Define decoder model.\n", "latent_inputs = tf.keras.Input(shape=(latent_dim,), name=\"z_sampling\")\n", "x = layers.Dense(intermediate_dim, activation=\"relu\")(latent_inputs)\n", "outputs = layers.Dense(original_dim, activation=\"sigmoid\")(x)\n", "decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name=\"decoder\")\n", "\n", "# Define VAE model.\n", "outputs = decoder(z)\n", "vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name=\"vae\")\n", "\n", "# Add KL divergence regularization loss.\n", "kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)\n", "vae.add_loss(kl_loss)\n", "\n", "# Train.\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n", "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n", "vae.fit(x_train, x_train, epochs=3, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": { "id": "e2f135ea7cf5" }, "source": [ "자세한 정보는 [함수형 API 가이드](https://www.tensorflow.org/guide/keras/functional/)를 참고하세요." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "custom_layers_and_models.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }