{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T20:00:03.143830Z", "iopub.status.busy": "2022-12-14T20:00:03.143405Z", "iopub.status.idle": "2022-12-14T20:00:03.147088Z", "shell.execute_reply": "2022-12-14T20:00:03.146590Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "# Multi-worker training with Keras\n", "\n",
"This tutorial demonstrates how to perform multi-worker distributed training with a Keras model and the `Model.fit` API using the `tf.distribute.MultiWorkerMirroredStrategy` API. With the help of this strategy, a Keras model that was designed to run on a single worker can seamlessly work on multiple workers with minimal code changes.\n",
"\n",
"To learn how to use the `MultiWorkerMirroredStrategy` with Keras and a custom training loop, refer to [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb).\n",
"\n",
"This tutorial includes a minimal multi-worker example with two workers for demonstration purposes."
]
},
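The following is a minimal sketch of what "minimal code changes" means in practice. It assumes TensorFlow 2.x with `tf.keras`. You create the strategy, then build and compile the model inside `strategy.scope()`; the `fit` call itself stays the same. The tiny dense model and random data below are hypothetical stand-ins for the tutorial's CNN and MNIST pipeline. With no `TF_CONFIG` set, the strategy falls back to a single local worker, so the snippet runs on one machine.

```python
import os

import numpy as np
import tensorflow as tf

# With no TF_CONFIG, MultiWorkerMirroredStrategy falls back to one local worker.
os.environ.pop('TF_CONFIG', None)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Create the model and optimizer variables inside the scope so that the
# strategy mirrors them on every replica.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer='sgd', loss='mse')

# The training call is the same as in single-worker code.
x = np.random.rand(32, 4).astype(np.float32)
y = np.random.rand(32, 1).astype(np.float32)
history = model.fit(x, y, batch_size=8, epochs=1, verbose=0)
print('Replicas in sync:', strategy.num_replicas_in_sync)
```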
{
"cell_type": "markdown",
"metadata": {
"id": "JUdRerXg6yz3"
},
"source": [
"### Choose the right strategy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YAiCV_oL63GM"
},
"source": [
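"Before you dive in, make sure that `tf.distribute.MultiWorkerMirroredStrategy` is the right choice for your accelerator(s) and training. These are two common ways of distributing training with data parallelism:\n",
"\n",
"- *Synchronous training*, where the steps of training are synced across the workers and replicas, such as `tf.distribute.MirroredStrategy`, `tf.distribute.TPUStrategy`, and `tf.distribute.MultiWorkerMirroredStrategy`. All workers train over different slices of input data in sync, aggregating gradients at each step.\n",
"- *Asynchronous training*, where the training steps are not strictly synced, such as `tf.distribute.experimental.ParameterServerStrategy`. All workers train over the input data independently and update variables asynchronously.\n",
"\n",
"If you are looking for multi-worker synchronous training without TPUs, use `tf.distribute.MultiWorkerMirroredStrategy`. It creates copies of all variables in the model's layers on each device across all workers. It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. For the available collective implementations, check the strategy's `communication_options` argument, which takes a `tf.distribute.experimental.CommunicationOptions` object.\n",
"\n",
"For an overview of the `tf.distribute.Strategy` APIs, refer to [Distributed training in TensorFlow](../../guide/distributed_training.ipynb)."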
]
},
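To make the synchronous case concrete, here is a toy NumPy simulation. It illustrates the idea only and is not how TensorFlow implements it. Two simulated workers each hold a mirrored copy of the weights of a hypothetical linear model and compute a gradient on their own slice of the data. The all-reduce that `MultiWorkerMirroredStrategy` performs with `CollectiveOps` is modeled here as a simple average. Every copy then applies the same update.

```python
import numpy as np


def gradient(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
y = x @ np.array([1.0, -2.0, 0.5])  # noiseless targets from known weights

num_workers = 2
learning_rate = 0.1
# Every worker starts from an identical (mirrored) copy of the variables.
replicas = [np.zeros(3) for _ in range(num_workers)]
# Each worker trains on a different, equally sized slice of the input data.
x_shards = np.array_split(x, num_workers)
y_shards = np.array_split(y, num_workers)

for step in range(100):
    # 1. Each worker computes a gradient on its own slice.
    local_grads = [gradient(w, xs, ys)
                   for w, xs, ys in zip(replicas, x_shards, y_shards)]
    # 2. All-reduce: combine the gradients so every worker sees the same value.
    mean_grad = np.mean(local_grads, axis=0)
    # 3. Every worker applies the identical update, so the copies stay in sync.
    replicas = [w - learning_rate * mean_grad for w in replicas]

print(replicas[0])
```

Because the shards are equal in size, the averaged gradient equals the full-batch gradient. This is why synchronous data parallelism behaves like single-worker training with the combined (global) batch.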
{
"cell_type": "markdown",
"metadata": {
"id": "MUXex9ctTuDB"
},
"source": [
"## Setup\n",
"\n",
"Start with some necessary imports:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:03.150605Z",
"iopub.status.busy": "2022-12-14T20:00:03.150155Z",
"iopub.status.idle": "2022-12-14T20:00:03.156960Z",
"shell.execute_reply": "2022-12-14T20:00:03.156339Z"
},
"id": "bnYxvfLD-LW-"
},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import sys"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zz0EY91y3mxy"
},
"source": [
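"Before importing TensorFlow, make a few changes to the environment:\n",
"\n",
"- In a real-world application, each worker would be on a different machine. For the purposes of this tutorial, all the workers will run on **this** machine. Therefore, disable all GPUs to prevent errors caused by all workers trying to use the same GPU."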
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:03.160306Z",
"iopub.status.busy": "2022-12-14T20:00:03.159739Z",
"iopub.status.idle": "2022-12-14T20:00:03.162864Z",
"shell.execute_reply": "2022-12-14T20:00:03.162219Z"
},
"id": "rpEIVI5upIzM"
},
"outputs": [],
"source": [
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7X1MS6385BWi"
},
"source": [
"- Reset the `TF_CONFIG` environment variable (you'll learn more about this later):"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:03.166028Z",
"iopub.status.busy": "2022-12-14T20:00:03.165585Z",
"iopub.status.idle": "2022-12-14T20:00:03.168435Z",
"shell.execute_reply": "2022-12-14T20:00:03.167905Z"
},
"id": "WEJLYa2_7OZF"
},
"outputs": [],
"source": [
"os.environ.pop('TF_CONFIG', None)"
]
},
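`TF_CONFIG` is a JSON string that tells each worker the cluster layout and its own role in it. As a preview, here is a sketch of the values for two workers. It assumes both workers run on `localhost` with arbitrarily chosen ports, and `make_tf_config` is a hypothetical helper, not a TensorFlow API.

```python
import json


def make_tf_config(worker_addresses, index):
    """Build the TF_CONFIG JSON string for the worker at position `index`."""
    return json.dumps({
        'cluster': {'worker': worker_addresses},
        'task': {'type': 'worker', 'index': index},
    })


# Two workers on this machine; the ports are arbitrary example values.
workers = ['localhost:12345', 'localhost:23456']
configs = [make_tf_config(workers, i) for i in range(len(workers))]

# Each worker process would set its own value before importing TensorFlow,
# e.g. os.environ['TF_CONFIG'] = configs[0] for the first worker.
for config in configs:
    print(config)
```

All workers share the same `cluster` entry and differ only in `task.index`.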
{
"cell_type": "markdown",
"metadata": {
"id": "Rd4L9Ii77SS8"
},
"source": [
"- Make sure that the current directory is on Python's path. This allows the notebook to import the files written by `%%writefile` later.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:03.171657Z",
"iopub.status.busy": "2022-12-14T20:00:03.171137Z",
"iopub.status.idle": "2022-12-14T20:00:03.174209Z",
"shell.execute_reply": "2022-12-14T20:00:03.173672Z"
},
"id": "hPBuZUNSZmrQ"
},
"outputs": [],
"source": [
"if '.' not in sys.path:\n",
" sys.path.insert(0, '.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9hLpDZhAz2q-"
},
"source": [
"Install `tf-nightly`. The `save_freq` argument of `tf.keras.callbacks.BackupAndRestore`, which saves checkpoints every given number of steps, is available starting in TensorFlow 2.10."
]
},
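The following is a minimal sketch of the `save_freq` argument, assuming TensorFlow 2.10 or later. The toy model, random data, and temporary backup directory are hypothetical. With an integer `save_freq`, the callback backs up the training state every that many batches instead of only at the end of each epoch (the default, `save_freq='epoch'`).

```python
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy model and data, standing in for the MNIST CNN.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')
x = np.random.rand(64, 4).astype(np.float32)
y = np.random.rand(64, 1).astype(np.float32)

backup_dir = tempfile.mkdtemp()
# Back up the training state every 4 batches rather than once per epoch.
backup = tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir, save_freq=4)

# If this fit() is interrupted and rerun with the same backup_dir, training
# resumes from the last backup instead of starting from scratch.
history = model.fit(x, y, batch_size=8, epochs=2, callbacks=[backup], verbose=0)
```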
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:03.177440Z",
"iopub.status.busy": "2022-12-14T20:00:03.177054Z",
"iopub.status.idle": "2022-12-14T20:00:31.425335Z",
"shell.execute_reply": "2022-12-14T20:00:31.424469Z"
},
"id": "-XqozLfzz30N"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting tf-nightly\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Downloading tf_nightly-2.12.0.dev20221214-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (556.4 MB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting jax>=0.3.15\r\n",
" Downloading jax-0.4.1.tar.gz (1.2 MB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Preparing metadata (setup.py) ... \u001b[?25l-"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b \bdone\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[?25hRequirement already satisfied: numpy>=1.20 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.24.0rc2)\r\n",
"Requirement already satisfied: flatbuffers>=2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (22.12.6)\r\n",
"Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.3.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting tb-nightly~=2.12.0.a\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Downloading tb_nightly-2.12.0a20221214-py3-none-any.whl (5.7 MB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.14.1)\r\n",
"Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (4.4.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting tf-estimator-nightly~=2.12.0.dev\r\n",
" Downloading tf_estimator_nightly-2.12.0.dev2022121409-py2.py3-none-any.whl (439 kB)\r\n",
"Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.16.0)\r\n",
"Requirement already satisfied: protobuf<3.20,>=3.9.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.19.6)\r\n",
"Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (2.1.1)\r\n",
"Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.2.0)\r\n",
"Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (65.6.3)\r\n",
"Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.4.0)\r\n",
"Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (22.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.51.1)\r\n",
"Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.28.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting keras-nightly~=2.12.0.dev\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Downloading keras_nightly-2.12.0.dev2022121408-py2.py3-none-any.whl (1.7 MB)\r\n",
"Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (14.0.6)\r\n",
"Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.7.0)\r\n",
"Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.6.3)\r\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.3.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tf-nightly) (0.37.1)\r\n",
"Requirement already satisfied: scipy>=1.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax>=0.3.15->tf-nightly) (1.9.3)\r\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (2.28.1)\r\n",
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (1.8.1)\r\n",
"Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (2.15.0)\r\n",
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (0.4.6)\r\n",
"Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (2.2.2)\r\n",
"Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (3.4.1)\r\n",
"Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.12.0.a->tf-nightly) (0.6.1)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.12.0.a->tf-nightly) (5.2.0)\r\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.12.0.a->tf-nightly) (4.9)\r\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.12.0.a->tf-nightly) (0.3.0rc1)\r\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly~=2.12.0.a->tf-nightly) (1.3.1)\r\n",
"Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tb-nightly~=2.12.0.a->tf-nightly) (5.1.0)\r\n",
"Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.12.0.a->tf-nightly) (2022.12.7)\r\n",
"Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.12.0.a->tf-nightly) (3.4)\r\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.12.0.a->tf-nightly) (1.26.13)\r\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.12.0.a->tf-nightly) (2.1.1)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tb-nightly~=2.12.0.a->tf-nightly) (2.1.1)\r\n",
"Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tb-nightly~=2.12.0.a->tf-nightly) (3.11.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tb-nightly~=2.12.0.a->tf-nightly) (0.5.0rc2)\r\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly~=2.12.0.a->tf-nightly) (3.2.2)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building wheels for collected packages: jax\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Building wheel for jax (setup.py) ... \u001b[?25l-"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b \b\\"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b \b|"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b \b/"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b \bdone\r\n",
"\u001b[?25h Created wheel for jax: filename=jax-0.4.1-py3-none-any.whl size=1332462 sha256=e4b7a0b05e48ea35ddbda56a810cf2446f98784785f2bf0acfab5db413b6e4b1\r\n",
" Stored in directory: /home/kbuilder/.cache/pip/wheels/50/a9/f3/86082312fd44e12e52b1b7744c37ed1d02e64deefdc735c77b\r\n",
"Successfully built jax\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Installing collected packages: tf-estimator-nightly, keras-nightly, jax, tb-nightly, tf-nightly\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully installed jax-0.4.1 keras-nightly-2.12.0.dev2022121408 tb-nightly-2.12.0a20221214 tf-estimator-nightly-2.12.0.dev2022121409 tf-nightly-2.12.0.dev20221214\r\n"
]
}
],
"source": [
"!pip install tf-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "524e38dab658"
},
"source": [
"Finally, import TensorFlow:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:31.429648Z",
"iopub.status.busy": "2022-12-14T20:00:31.429341Z",
"iopub.status.idle": "2022-12-14T20:00:33.890439Z",
"shell.execute_reply": "2022-12-14T20:00:33.889770Z"
},
"id": "vHNvttzV43sA"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-12-14 20:00:31.684068: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay\n"
]
}
],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0S2jpf6Sx50i"
},
"source": [
"### Dataset and model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fLW6D2TzvC-4"
},
"source": [
"Next, create an `mnist_setup.py` file with a simple model and dataset setup. This Python file will be used by the worker processes in this tutorial:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:33.894960Z",
"iopub.status.busy": "2022-12-14T20:00:33.894589Z",
"iopub.status.idle": "2022-12-14T20:00:33.899960Z",
"shell.execute_reply": "2022-12-14T20:00:33.899337Z"
},
"id": "dma_wUAxZqo2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing mnist_setup.py\n"
]
}
],
"source": [
"%%writefile mnist_setup.py\n",
"\n",
"import os\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"def mnist_dataset(batch_size):\n",
" (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n",
" # The `x` arrays are in uint8 and have values in the [0, 255] range.\n",
" # You need to convert them to float32 with values in the [0, 1] range.\n",
" x_train = x_train / np.float32(255)\n",
" y_train = y_train.astype(np.int64)\n",
" train_dataset = tf.data.Dataset.from_tensor_slices(\n",
" (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)\n",
" return train_dataset\n",
"\n",
"def build_and_compile_cnn_model():\n",
" model = tf.keras.Sequential([\n",
" tf.keras.layers.InputLayer(input_shape=(28, 28)),\n",
" tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
" tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(10)\n",
" ])\n",
" model.compile(\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=['accuracy'])\n",
" return model"
]
},
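The comment in `mnist_dataset` notes that the uint8 pixels must become float32 values in the [0, 1] range. Dividing by `np.float32(255)` rather than a plain Python number does both in one step. A NumPy float32 scalar keeps the result in float32, whereas a Python float such as `255.0` promotes it to float64. Here is a quick check on a small hypothetical array:

```python
import numpy as np

# A hypothetical 2x2 "image" in the same uint8 format as the MNIST arrays.
x = np.array([[0, 128], [255, 64]], dtype=np.uint8)

scaled = x / np.float32(255)  # float32 scalar: result stays float32
scaled_f64 = x / 255.0        # Python float: result is promoted to float64

print(scaled.dtype, scaled_f64.dtype)
print(scaled.min(), scaled.max())
```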
{
"cell_type": "markdown",
"metadata": {
"id": "2UL3kisMO90X"
},
"source": [
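"### Model training on a single worker\n",
"\n",
"Try training the model for a small number of epochs and observe the results of a single worker to make sure everything works correctly. As training progresses, the loss should decrease and the accuracy should increase."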
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T20:00:33.903561Z",
"iopub.status.busy": "2022-12-14T20:00:33.903029Z",
"iopub.status.idle": "2022-12-14T20:00:36.897242Z",
"shell.execute_reply": "2022-12-14T20:00:36.896543Z"
},
"id": "6Qe6iAf5O8iJ"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 8192/11490434 [..............................] - ETA: 0s"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 2932736/11490434 [======>.......................] - ETA: 0s"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 7176192/11490434 [=================>............] - ETA: 0s"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"11490434/11490434 [==============================] - 0s 0us/step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-12-14 20:00:34.473671: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/70 [..............................] - ETA: 33s - loss: 2.2978 - accuracy: 0.1562"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 8/70 [==>...........................] - ETA: 0s - loss: 2.3086 - accuracy: 0.1152 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"17/70 [======>.......................] - ETA: 0s - loss: 2.3051 - accuracy: 0.1296"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"26/70 [==========>...................] - ETA: 0s - loss: 2.3001 - accuracy: 0.1364"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"34/70 [=============>................] - ETA: 0s - loss: 2.2935 - accuracy: 0.1535"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"43/70 [=================>............] - ETA: 0s - loss: 2.2877 - accuracy: 0.1657"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"51/70 [====================>.........] - ETA: 0s - loss: 2.2831 - accuracy: 0.1798"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"59/70 [========================>.....] - ETA: 0s - loss: 2.2782 - accuracy: 0.1965"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"67/70 [===========================>..] - ETA: 0s - loss: 2.2737 - accuracy: 0.2113"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"70/70 [==============================] - 1s 7ms/step - loss: 2.2720 - accuracy: 0.2181\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/70 [..............................] - ETA: 0s - loss: 2.2347 - accuracy: 0.4062"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"10/70 [===>..........................] - ETA: 0s - loss: 2.2253 - accuracy: 0.3922"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"19/70 [=======>......................] - ETA: 0s - loss: 2.2219 - accuracy: 0.3964"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"28/70 [===========>..................] - ETA: 0s - loss: 2.2163 - accuracy: 0.4096"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"37/70 [==============>...............] - ETA: 0s - loss: 2.2139 - accuracy: 0.4067"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"46/70 [==================>...........] - ETA: 0s - loss: 2.2095 - accuracy: 0.4202"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"55/70 [======================>.......] - ETA: 0s - loss: 2.2046 - accuracy: 0.4330"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"63/70 [==========================>...] - ETA: 0s - loss: 2.2007 - accuracy: 0.4422"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"70/70 [==============================] - 0s 6ms/step - loss: 2.1978 - accuracy: 0.4471\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/70 [..............................] - ETA: 0s - loss: 2.1798 - accuracy: 0.4531"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 9/70 [==>...........................] - ETA: 0s - loss: 2.1475 - accuracy: 0.5382"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"16/70 [=====>........................] - ETA: 0s - loss: 2.1393 - accuracy: 0.5625"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"25/70 [=========>....................] - ETA: 0s - loss: 2.1369 - accuracy: 0.5587"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"33/70 [=============>................] - ETA: 0s - loss: 2.1329 - accuracy: 0.5691"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"42/70 [=================>............] - ETA: 0s - loss: 2.1279 - accuracy: 0.5722"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"51/70 [====================>.........] - ETA: 0s - loss: 2.1223 - accuracy: 0.5748"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"59/70 [========================>.....] - ETA: 0s - loss: 2.1160 - accuracy: 0.5792"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"68/70 [============================>.] - ETA: 0s - loss: 2.1079 - accuracy: 0.5885"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"70/70 [==============================] - 0s 6ms/step - loss: 2.1061 - accuracy: 0.5902\n"
]
},
{
"data": {
"text/plain": [
"