{ "cells": [
{ "cell_type": "markdown", "metadata": { "id": "rX8mhOLljYeM" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T21:56:02.947709Z", "iopub.status.busy": "2024-01-11T21:56:02.947251Z", "iopub.status.idle": "2024-01-11T21:56:02.950907Z", "shell.execute_reply": "2024-01-11T21:56:02.950265Z" }, "id": "BZSlp3DAjdYf" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "3wF5wszaj97Y" }, "source": [ "# TensorFlow 2 quickstart for experts" ] },
{ "cell_type": "markdown", "metadata": { "id": "hiH7AC-NTniF" }, "source": [ "This is a [Google Colaboratory](https://colab.research.google.com/notebooks/welcome.ipynb) notebook file. Python programs run directly in the browser, which makes Colab a convenient way to learn and use TensorFlow. To follow this tutorial, run the notebook in Google Colab:\n", "\n", "1. In Colab, connect to a Python runtime: at the top right of the menu bar, select *CONNECT*.\n", "2. Run all the notebook code cells: select *Runtime* > *Run all*." ] },
{ "cell_type": "markdown", "metadata": { "id": "eOsVdx6GGHmU" }, "source": [ "Download and install TensorFlow 2.\n", "\n", "Note: Upgrade `pip` before installing the TensorFlow 2 package. See the [install guide](https://www.tensorflow.org/install) for details." ] },
{ "cell_type": "markdown", "metadata": { "id": "QS7DDTiZGRTo" }, "source": [ "Import TensorFlow into your program:" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:02.954501Z", "iopub.status.busy": "2024-01-11T21:56:02.954065Z", "iopub.status.idle": "2024-01-11T21:56:05.338696Z", "shell.execute_reply": "2024-01-11T21:56:05.337982Z" }, "id": "0trJmd6DjqBZ" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 21:56:03.385012: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 21:56:03.385056: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 21:56:03.386602: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "TensorFlow version: 2.15.0\n" ] } ], "source": [ "import tensorflow as tf\n", "print(\"TensorFlow version:\", tf.__version__)\n", "\n", "from tensorflow.keras.layers import Dense, Flatten, Conv2D\n", "from tensorflow.keras import Model" ] },
{ "cell_type": "markdown", "metadata": { "id": "7NAbSZiaoJ4z" }, "source": [ "Load and prepare the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The integer pixel values in the range 0 to 255 are scaled to floats between 0 and 1, and a trailing channels dimension is added because `Conv2D` expects inputs of shape `(batch, height, width, channels)`." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:05.342442Z", "iopub.status.busy": "2024-01-11T21:56:05.342020Z", "iopub.status.idle": "2024-01-11T21:56:05.890651Z", "shell.execute_reply": "2024-01-11T21:56:05.889808Z" }, "id": "JqFRS6K07jJs" }, "outputs": [], "source": [ "mnist = tf.keras.datasets.mnist\n", "\n", "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "x_train, x_test = x_train / 255.0, x_test / 255.0\n", "\n", "# Add a channels dimension\n", "x_train = x_train[..., tf.newaxis].astype(\"float32\")\n", "x_test = x_test[..., tf.newaxis].astype(\"float32\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "k1Evqx0S22r_" }, "source": [ "Use `tf.data` to shuffle and batch the training set, and to batch the test set:" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:05.894698Z", "iopub.status.busy": "2024-01-11T21:56:05.894441Z", "iopub.status.idle": "2024-01-11T21:56:08.697197Z", "shell.execute_reply": "2024-01-11T21:56:08.696464Z" }, "id": "8Iu_quO024c2" }, "outputs": [], "source": [ "train_ds = tf.data.Dataset.from_tensor_slices(\n", "    (x_train, y_train)).shuffle(10000).batch(32)\n", "\n", "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)" ] },
{ "cell_type": "markdown", "metadata": { "id": "BPZ68wASog_I" }, "source": [ "Build the `tf.keras` model using the Keras model subclassing API:" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:08.701528Z", "iopub.status.busy": "2024-01-11T21:56:08.700972Z", "iopub.status.idle": "2024-01-11T21:56:08.719846Z", "shell.execute_reply": "2024-01-11T21:56:08.719231Z" }, "id": "h3IKyzTCDNGo" }, "outputs": [], "source": [ "class MyModel(Model):\n", "  def __init__(self):\n", "    super().__init__()\n", "    self.conv1 = Conv2D(32, 3, activation='relu')\n", "    self.flatten = Flatten()\n", "    self.d1 = Dense(128, activation='relu')\n", "    self.d2 = Dense(10)\n", "\n", "  def call(self, x):\n", "    x = self.conv1(x)\n", "    x = self.flatten(x)\n", "    x = self.d1(x)\n", "    return self.d2(x)\n", "\n", "# Create an instance of the model\n", "model = MyModel()" ] },
{ "cell_type": "markdown", "metadata": { "id": "uGih-c2LgbJu" }, "source": [ "Choose an optimizer and a loss function for training. The final `Dense(10)` layer has no activation, so the model outputs logits; this is why the loss is created with `from_logits=True`:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:08.723244Z", "iopub.status.busy": "2024-01-11T21:56:08.722788Z", "iopub.status.idle": "2024-01-11T21:56:08.729508Z", "shell.execute_reply": "2024-01-11T21:56:08.728882Z" }, "id": "u48C9WQ774n4" }, "outputs": [], "source": [ "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", "\n", "optimizer = tf.keras.optimizers.Adam()" ] },
{ "cell_type": "markdown", "metadata": { "id": "JB6A1vcigsIe" }, "source": [ "Select metrics to measure the loss and the accuracy of the model. These metrics accumulate values over the batches of an epoch and then report the overall result:" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:08.732958Z", "iopub.status.busy": "2024-01-11T21:56:08.732474Z", "iopub.status.idle": "2024-01-11T21:56:08.751840Z", "shell.execute_reply": "2024-01-11T21:56:08.751181Z" }, "id": "N0MqHFb4F_qn" }, "outputs": [], "source": [ "train_loss = tf.keras.metrics.Mean(name='train_loss')\n", "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n", "\n", "test_loss = tf.keras.metrics.Mean(name='test_loss')\n", "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')" ] },
{ "cell_type": "markdown", "metadata": { "id": "ix4mEL65on-w" }, "source": [ "Use `tf.GradientTape` to train the model:" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:08.755122Z", "iopub.status.busy": "2024-01-11T21:56:08.754666Z", "iopub.status.idle": "2024-01-11T21:56:08.758930Z", "shell.execute_reply": "2024-01-11T21:56:08.758348Z" }, "id": "OZACiVqA8KQV" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(images, labels):\n", "  with tf.GradientTape() as tape:\n", "    # training=True is only needed if there are layers with different\n", "    # behavior during training versus inference (e.g. Dropout).\n", "    predictions = model(images, training=True)\n", "    loss = loss_object(labels, predictions)\n", "  gradients = tape.gradient(loss, model.trainable_variables)\n", "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n", "\n", "  train_loss(loss)\n", "  train_accuracy(labels, predictions)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Z8YT7UmFgpjV" }, "source": [ "Test the model:" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:08.762067Z", "iopub.status.busy": "2024-01-11T21:56:08.761623Z", "iopub.status.idle": "2024-01-11T21:56:08.765239Z", "shell.execute_reply": "2024-01-11T21:56:08.764668Z" }, "id": "xIKdEzHAJGt7" }, "outputs": [], "source": [ "@tf.function\n", "def test_step(images, labels):\n", "  # training=False is only needed if there are layers with different\n", "  # behavior during training versus inference (e.g. Dropout).\n", "  predictions = model(images, training=False)\n", "  t_loss = loss_object(labels, predictions)\n", "\n", "  test_loss(t_loss)\n", "  test_accuracy(labels, predictions)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:56:08.768178Z", "iopub.status.busy": "2024-01-11T21:56:08.767729Z", "iopub.status.idle": "2024-01-11T21:56:30.221999Z", "shell.execute_reply": "2024-01-11T21:56:30.221070Z" }, "id": "i-2pkctU_Ci7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1705010170.188690 982259 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Loss: 0.14312978088855743, Accuracy: 95.6883316040039, Test Loss: 0.06150883808732033, Test Accuracy: 98.06999969482422\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2, Loss: 0.04425348341464996, Accuracy: 98.6199951171875, Test Loss: 0.05480688810348511, Test Accuracy: 98.18999481201172\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3, Loss: 0.023569559678435326, Accuracy: 99.23333740234375, Test Loss: 0.05095665156841278, Test Accuracy: 98.3699951171875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4, Loss: 0.01400002371519804, Accuracy: 99.54833221435547, Test Loss: 0.05548116937279701, Test Accuracy: 98.43999481201172\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5, Loss: 0.010487733408808708, Accuracy: 99.62999725341797, Test Loss: 0.05794484168291092, Test Accuracy: 98.5\n" ] } ], "source": [ "EPOCHS = 5\n", "\n", "for epoch in range(EPOCHS):\n", "  # Reset the metrics at the start of the next epoch\n", "  train_loss.reset_state()\n", "  train_accuracy.reset_state()\n", "  test_loss.reset_state()\n", "  test_accuracy.reset_state()\n", "\n", "  for images, labels in train_ds:\n", "    train_step(images, labels)\n", "\n", "  for test_images, test_labels in test_ds:\n", "    test_step(test_images, test_labels)\n", "\n", "  print(\n", "    f'Epoch {epoch + 1}, '\n", "    f'Loss: {train_loss.result()}, '\n", "    f'Accuracy: {train_accuracy.result() * 100}, '\n", "    f'Test Loss: {test_loss.result()}, '\n", "    f'Test Accuracy: {test_accuracy.result() * 100}'\n", "  )" ] },
{ "cell_type": "markdown", "metadata": { "id": "T4JfEh7kvx6m" }, "source": [ "The image classifier is now trained to about 98% test accuracy on this dataset. To learn more, read the [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)." ] }
], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "advanced.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }