{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "USSV_OlCFKOD" }, "source": [ "# Training a neural network on MNIST with Keras\n", "\n", "This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "J8y9ZkLXmAZc" }, "source": [ "Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0" ] }, { "cell_type": "markdown", "metadata": { "id": "OGw9EgE0tC0C" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TTBSvHcSLBzc" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "VjI6VgOBf0v0" }, "source": [ "## 手順 1: 入力パイプラインを作成する\n", "\n", "まず、次のガイドを参照し、有効な入力パイプラインを構築します。\n", "\n", "- [TFDS パフォーマンスガイド](https://www.tensorflow.org/datasets/performances)\n", "- [tf.data パフォーマンスガイド](https://www.tensorflow.org/guide/data_performance#optimize_performance)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "c3aH3vP_XLI8" }, "source": [ "### データセットを読み込む\n", "\n", "次の引数を使って MNIST データセットを読み込みます。\n", "\n", "- `shuffle_files`: MNIST データは、単一のファイルにのみ保存されていますが、ディスク上の複数のファイルを伴うより大きなデータセットについては、トレーニングの際にシャッフルすることが良い実践です。\n", "- `as_supervised`: dict `{'image': img, 'label': label}` の代わりに tuple `(img, label)` を返します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZUMhCXhFXdHQ" }, "outputs": [], "source": [ "(ds_train, ds_test), ds_info = tfds.load(\n", " 'mnist',\n", " split=['train', 'test'],\n", " shuffle_files=True,\n", " as_supervised=True,\n", " with_info=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "rgwCFAcWXQTx" }, "source": [ "### トレーニングパイプラインを構築する\n", "\n", "次の変換を適用します。\n", "\n", "- `tf.data.Dataset.map`: TFDS は画像を `tf.uint8` として提供しますが、モデルは `tf.float32` を期待するため、画像を正規化します。\n", "- `tf.data.Dataset.cache`: データセットがメモリに収まる場合、シャッフル前にキャッシュすると、パフォーマンスを改善できます。
**注意:** ランダム変換は、キャッシュの後に適用してください。\n", "- `tf.data.Dataset.shuffle`: 真のランダム性を得るには、シャッフルバッファをデータセットの完全なサイズに設定してください。
**注意:** メモリに収まらない大きなデータセットについては、システムで可能な場合は `buffer_size=1000` にします。\n", "- `tf.data.Dataset.batch`: シャッフルの後にバッチ処理を行い、各エポックで一意のバッチを取得します。\n", "- `tf.data.Dataset.prefetch`: プリフェッチによってパイプラインを終了し、[パフォーマンスを向上](https://www.tensorflow.org/guide/data_performance#prefetching)させます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "haykx2K9XgiI" }, "outputs": [], "source": [ "def normalize_img(image, label):\n", " \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n", " return tf.cast(image, tf.float32) / 255., label\n", "\n", "ds_train = ds_train.map(\n", " normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", "ds_train = ds_train.cache()\n", "ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n", "ds_train = ds_train.batch(128)\n", "ds_train = ds_train.prefetch(tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "RbsMy4X1XVFv" }, "source": [ "### 評価パイプラインを構築する\n", "\n", "テストのパイプラインはトレーニングのパイプラインと似ていますが、次のようにわずかに異なります。\n", "\n", "- `tf.data.Dataset.shuffle` を呼び出す必要はありません。\n", "- エポック間のバッチが同一である可能性があるのでキャッシュはバッチ処理の後に行われます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A0KjuDf7XiqY" }, "outputs": [], "source": [ "ds_test = ds_test.map(\n", " normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", "ds_test = ds_test.batch(128)\n", "ds_test = ds_test.cache()\n", "ds_test = ds_test.prefetch(tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "nTFoji3INMEM" }, "source": [ "## 手順 2: モデルを作成してトレーニングする\n", "\n", "TFDS 入力パイプラインを簡単な Keras モデルにプラグインし、モデルをコンパイルしてトレーニングします。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XWqxdmS1NLKA" }, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(10)\n", "])\n", "model.compile(\n", " optimizer=tf.keras.optimizers.Adam(0.001),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", ")\n", "\n", "model.fit(\n", " ds_train,\n", " epochs=6,\n", " validation_data=ds_test,\n", ")" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "keras_example.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }