{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Pnn4rDWGqDZL" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2021-02-12T21:36:16.959920Z", "iopub.status.busy": "2021-02-12T21:36:16.959335Z", "iopub.status.idle": "2021-02-12T21:36:16.961087Z", "shell.execute_reply": "2021-02-12T21:36:16.961462Z" }, "id": "l534d35Gp68G" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "3TI3Q3XBesaS" }, "source": [ "# トレーニングのチェックポイント" ] }, { "cell_type": "markdown", "metadata": { "id": "yw_a0iGucY8z" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示Google Colab で実行GitHub でソースを表示ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "LeDp7dovcbus" }, "source": [ "「TensorFlow のモデルを保存する」という言いまわしは通常、次の 2 つのいずれかを意味します。\n", "\n", "1. チェックポイント、または\n", "2. 保存されたモデル(SavedModel)\n", "\n", "チェックポイントは、モデルで使用されるすべてのパラメータ(`tf.Variable`オブジェクト)の正確な値をキャプチャします。チェックポイントにはモデルで定義された計算のいかなる記述も含まれていないため、通常は、保存されたパラメータ値を使用するソースコードが利用可能な場合に限り有用です。\n", "\n", "一方、SavedModel 形式には、パラメータ値(チェックポイント)に加え、モデルで定義された計算のシリアライズされた記述が含まれています。この形式のモデルは、モデルを作成したソースコードから独立しています。したがって、TensorFlow Serving、TensorFlow Lite、TensorFlow.js、または他のプログラミング言語のプログラム(C、C++、Java、Go、Rust、C# などの TensorFlow API)を介したデプロイに適しています。\n", "\n", "このガイドでは、チェックポイントの書き込みと読み取りを行う API について説明します。" ] }, { "cell_type": "markdown", "metadata": { "id": "U0nm8k-6xfh2" }, "source": [ "## セットアップ" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:16.969294Z", "iopub.status.busy": "2021-02-12T21:36:16.968704Z", "iopub.status.idle": "2021-02-12T21:36:22.854596Z", "shell.execute_reply": "2021-02-12T21:36:22.853992Z" }, "id": "VEvpMYAKsC4z" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:22.860193Z", "iopub.status.busy": "2021-02-12T21:36:22.859599Z", "iopub.status.idle": "2021-02-12T21:36:22.861245Z", "shell.execute_reply": "2021-02-12T21:36:22.861591Z" }, "id": "OEQCseyeC4Ev" }, "outputs": [], "source": [ "class Net(tf.keras.Model):\n", " \"\"\"A simple linear model.\"\"\"\n", "\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.l1 = tf.keras.layers.Dense(5)\n", "\n", " def call(self, x):\n", " return self.l1(x)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:22.865355Z", "iopub.status.busy": "2021-02-12T21:36:22.864749Z", "iopub.status.idle": "2021-02-12T21:36:24.472099Z", "shell.execute_reply": "2021-02-12T21:36:24.472513Z" }, "id": "utqeoDADC5ZR" }, "outputs": [], "source": [ "net = Net()" ] }, { "cell_type": "markdown", "metadata": { "id": "5vsq3-pffo1I" }, "source": [ "## `tf.keras`トレーニング API から保存する\n", "\n", "[`tf.keras`の保存と復元に関するガイド](./keras/overview.ipynb#save_and_restore)をご覧ください。\n", "\n", "`tf.keras.Model.save_weights`で TensorFlow チェックポイントを保存します。 " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:24.477581Z", "iopub.status.busy": "2021-02-12T21:36:24.476977Z", "iopub.status.idle": "2021-02-12T21:36:24.487879Z", "shell.execute_reply": "2021-02-12T21:36:24.487438Z" }, "id": "SuhmrYPEl4D_" }, "outputs": [], "source": [ "net.save_weights('easy_checkpoint')" ] }, { "cell_type": "markdown", "metadata": { "id": "XseWX5jDg4lQ" }, "source": [ "## チェックポイントを記述する\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1jpZPz76ZP3K" }, "source": [ "TensorFlow モデルの永続的な状態は、`tf.Variable`オブジェクトに格納されます。これらは直接作成できますが、多くの場合は`tf.keras.layers`や`tf.keras.Model`などの高レベル API を介して作成されます。\n", "\n", "変数を管理する最も簡単な方法は、変数を Python オブジェクトにアタッチし、それらのオブジェクトを参照することです。\n", "\n", "`tf.train.Checkpoint`、`tf.keras.layers.Layer`および`tf.keras.Model`のサブクラスは、属性に割り当てられた変数を自動的に追跡します。以下の例では、単純な線形モデルを作成し、モデルのすべての変数の値を含むチェックポイントを記述します。" ] }, { "cell_type": "markdown", "metadata": { "id": "x0vFBr_Im73_" }, "source": [ "`Model.save_weights`で、モデルチェックポイントを簡単に保存できます。" ] }, { "cell_type": "markdown", "metadata": { "id": "FHTJ1JzxCi8a" }, "source": [ "### 手動チェックポイント" ] }, { "cell_type": "markdown", "metadata": { "id": "6cF9fqYOCrEO" }, "source": [ "#### セットアップ" ] }, { "cell_type": "markdown", "metadata": { "id": "fNjf9KaLdIRP" }, "source": [ "`tf.train.Checkpoint`のすべての機能を実演するために、トイデータセットと最適化ステップを次のように定義します。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:24.493219Z", "iopub.status.busy": "2021-02-12T21:36:24.492598Z", "iopub.status.idle": "2021-02-12T21:36:24.494365Z", "shell.execute_reply": "2021-02-12T21:36:24.494719Z" }, "id": "tSNyP4IJ9nkU" }, "outputs": [], "source": [ "def toy_dataset():\n", " inputs = tf.range(10.)[:, None]\n", " labels = inputs * 5. + tf.range(5.)[None, :]\n", " return tf.data.Dataset.from_tensor_slices(\n", " dict(x=inputs, y=labels)).repeat().batch(2)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:24.499339Z", "iopub.status.busy": "2021-02-12T21:36:24.498745Z", "iopub.status.idle": "2021-02-12T21:36:24.500809Z", "shell.execute_reply": "2021-02-12T21:36:24.500388Z" }, "id": "ICm1cufh_JH8" }, "outputs": [], "source": [ "def train_step(net, example, optimizer):\n", " \"\"\"Trains `net` on `example` using `optimizer`.\"\"\"\n", " with tf.GradientTape() as tape:\n", " output = net(example['x'])\n", " loss = tf.reduce_mean(tf.abs(output - example['y']))\n", " variables = net.trainable_variables\n", " gradients = tape.gradient(loss, variables)\n", " optimizer.apply_gradients(zip(gradients, variables))\n", " return loss" ] }, { "cell_type": "markdown", "metadata": { "id": "vxzGpHRbOVO6" }, "source": [ "#### チェックポイントオブジェクトを作成する\n", "\n", "チェックポイントを手動で作成するには、`tf.train.Checkpoint`オブジェクトが必要です。チェックポイントするオブジェクトの場所は、オブジェクトの属性として設定します。\n", "\n", "`tf.train.CheckpointManager`は、複数のチェックポイントの管理にも役立ちます。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:24.505435Z", "iopub.status.busy": "2021-02-12T21:36:24.504782Z", "iopub.status.idle": "2021-02-12T21:36:24.515774Z", "shell.execute_reply": "2021-02-12T21:36:24.515316Z" }, "id": "ou5qarOQOWYl" }, "outputs": [], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "dataset = toy_dataset()\n", "iterator = iter(dataset)\n", "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n", "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "8ZbYSD4uCy96" }, "source": [ "#### モデルをトレーニングおよびチェックポイントする" ] }, { "cell_type": "markdown", "metadata": { "id": "NP9IySmCeCkn" }, "source": [ "次のトレーニングループは、モデルとオプティマイザのインスタンスを作成し、それらを`tf.train.Checkpoint`オブジェクトに集めます。それはデータの各バッチのループ内でトレーニングステップを呼び出し、定期的にチェックポイントをディスクに書き込みます。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:24.521365Z", "iopub.status.busy": "2021-02-12T21:36:24.520788Z", "iopub.status.idle": "2021-02-12T21:36:24.522295Z", "shell.execute_reply": "2021-02-12T21:36:24.522718Z" }, "id": "BbCS5A6K1VSH" }, "outputs": [], "source": [ "def train_and_checkpoint(net, manager):\n", " ckpt.restore(manager.latest_checkpoint)\n", " if manager.latest_checkpoint:\n", " print(\"Restored from {}\".format(manager.latest_checkpoint))\n", " else:\n", " print(\"Initializing from scratch.\")\n", "\n", " for _ in range(50):\n", " example = next(iterator)\n", " loss = train_step(net, example, opt)\n", " ckpt.step.assign_add(1)\n", " if int(ckpt.step) % 10 == 0:\n", " save_path = manager.save()\n", " print(\"Saved checkpoint for step {}: {}\".format(int(ckpt.step), save_path))\n", " print(\"loss {:1.2f}\".format(loss.numpy()))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:24.526240Z", "iopub.status.busy": "2021-02-12T21:36:24.525659Z", "iopub.status.idle": "2021-02-12T21:36:25.119783Z", "shell.execute_reply": "2021-02-12T21:36:25.119279Z" }, "id": "Ik3IBMTdPW41" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing from scratch.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 10: ./tf_ckpts/ckpt-1\n", "loss 26.91\n", "Saved checkpoint for step 20: ./tf_ckpts/ckpt-2\n", "loss 20.32\n", "Saved checkpoint for step 30: ./tf_ckpts/ckpt-3\n", "loss 13.76\n", "Saved checkpoint for step 40: ./tf_ckpts/ckpt-4\n", "loss 7.35\n", "Saved checkpoint for step 50: ./tf_ckpts/ckpt-5\n", "loss 2.48\n" ] } ], "source": [ "train_and_checkpoint(net, manager)" ] }, { "cell_type": "markdown", "metadata": { "id": "2wzcc1xYN-sH" }, "source": [ "#### 復元してトレーニングを続ける" ] }, { "cell_type": "markdown", "metadata": { "id": "lw1QeyRBgsLE" }, "source": [ "最初の実行後、新しいモデルとマネジャーを渡すことができますが、トレーニングをやめた所からトレーニングを再開します。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.126751Z", "iopub.status.busy": "2021-02-12T21:36:25.126167Z", "iopub.status.idle": "2021-02-12T21:36:25.373699Z", "shell.execute_reply": "2021-02-12T21:36:25.373183Z" }, "id": "UjilkTOV2PBK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Restored from ./tf_ckpts/ckpt-5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 60: ./tf_ckpts/ckpt-6\n", "loss 1.38\n", "Saved checkpoint for step 70: ./tf_ckpts/ckpt-7\n", "loss 0.95\n", "Saved checkpoint for step 80: ./tf_ckpts/ckpt-8\n", "loss 0.44\n", "Saved checkpoint for step 90: ./tf_ckpts/ckpt-9\n", "loss 0.35\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 100: ./tf_ckpts/ckpt-10\n", "loss 0.30\n" ] } ], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "net = Net()\n", "dataset = toy_dataset()\n", "iterator = iter(dataset)\n", "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n", "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)\n", "\n", "train_and_checkpoint(net, manager)" ] }, { "cell_type": "markdown", "metadata": { "id": "dxJT9vV-2PnZ" }, "source": [ "`tf.train.CheckpointManager`オブジェクトは古いチェックポイントを削除します。上記では、最新の 3 つのチェックポイントのみを保持するように構成されています。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.378499Z", "iopub.status.busy": "2021-02-12T21:36:25.377723Z", "iopub.status.idle": "2021-02-12T21:36:25.380407Z", "shell.execute_reply": "2021-02-12T21:36:25.380789Z" }, "id": "3zmM0a-F5XqC" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']\n" ] } ], "source": [ "print(manager.checkpoints) # List the three remaining checkpoints" ] }, { "cell_type": "markdown", "metadata": { "id": "qwlYDyjemY4P" }, "source": [ "これらのパス、例えば`'./tf_ckpts/ckpt-10'`などは、ディスク上のファイルではなく、`index`ファイルのプレフィックスで、変数値を含む 1 つまたはそれ以上のデータファイルです。これらのプレフィックスは、まとめて単一の`checkpoint`ファイル(`'./tf_ckpts/checkpoint'`)にグループ化され、`CheckpointManager`がその状態を保存します。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.401936Z", "iopub.status.busy": "2021-02-12T21:36:25.384856Z", "iopub.status.idle": "2021-02-12T21:36:25.527770Z", "shell.execute_reply": "2021-02-12T21:36:25.527241Z" }, "id": "t1feej9JntV_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "checkpoint\t\t ckpt-8.data-00000-of-00001 ckpt-9.index\r\n", "ckpt-10.data-00000-of-00001 ckpt-8.index\r\n", "ckpt-10.index\t\t ckpt-9.data-00000-of-00001\r\n" ] } ], "source": [ "!ls ./tf_ckpts" ] }, { "cell_type": "markdown", "metadata": { "id": "DR2wQc9x6b3X" }, "source": [ "\n", "\n", "## 読み込みの仕組み\n", "\n", "TensorFlowは、読み込まれたオブジェクトから始めて、名前付きエッジを持つ有向グラフを走査することにより、変数をチェックポイントされた値に合わせます。エッジ名は通常、オブジェクトの属性名に由来しており、`self.l1 = tf.keras.layers.Dense(5)`の`\"l1\"`などがその例です。`tf.train.Checkpoint`は、`tf.train.Checkpoint(step=...)`の`\"step\"`のように、キーワード引数名を使用します。\n", "\n", "上記の例の依存関係グラフは次のようになります。\n", "\n", "![Visualization of the dependency graph for the example training loop](https://tensorflow.org/images/guide/whole_checkpoint.svg)\n", "\n", "オプティマイザは赤色、通常変数は青色、オプティマイザスロット変数はオレンジ色です。他のノード、例えば`tf.train.Checkpoint`を表すものは黒色です。\n", "\n", "スロット変数はオプティマイザの状態の一部ですが、特定の変数のために作成されます。例えば、上記の`'m'`エッジはモメンタムに対応し、Adam オプティマイザが各変数のために追跡します。スロット変数は変数とオプティマイザの両方が保存される場合に限りチェックポイントに保存されるので、破線のエッジです。" ] }, { "cell_type": "markdown", "metadata": { "id": "VpY5IuanUEQ0" }, "source": [ "`tf.train.Checkpoint`オブジェクト上での`restore()`呼び出しは、要求された復元をキューに入れ、`Checkpoint`オブジェクトから一致するパスがあるとすぐに変数値を復元します。例えば、ネットワークとレイヤーを介してそれへのパスを 1 つ再構築することにより、上記で定義したモデルからバイアスのみを読み込むことができます。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.538751Z", "iopub.status.busy": "2021-02-12T21:36:25.537773Z", "iopub.status.idle": "2021-02-12T21:36:25.547320Z", "shell.execute_reply": "2021-02-12T21:36:25.547697Z" }, "id": "wmX2AuyH7TVt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 0. 0. 0. 0.]\n", "[1.8858831 1.9214293 2.8519926 2.9979987 5.1035223]\n" ] } ], "source": [ "to_restore = tf.Variable(tf.zeros([5]))\n", "print(to_restore.numpy()) # All zeros\n", "fake_layer = tf.train.Checkpoint(bias=to_restore)\n", "fake_net = tf.train.Checkpoint(l1=fake_layer)\n", "new_root = tf.train.Checkpoint(net=fake_net)\n", "status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))\n", "print(to_restore.numpy()) # We get the restored value now" ] }, { "cell_type": "markdown", "metadata": { "id": "GqEW-_pJDAnE" }, "source": [ "これらの新しいオブジェクトの依存関係グラフは、上で書いたより大きなチェックポイントの遥かに小さなサブグラフです。 これには、バイアスと`tf.train.Checkpoint`がチェックポイントに番号付けするために使用した保存カウンタのみを含みます。\n", "\n", "![Visualization of a subgraph for the bias variable](https://tensorflow.org/images/guide/partial_checkpoint.svg)\n", "\n", "`restore()`は、オプションのアサーションを持つ状態オブジェクトを返します。新しい`Checkpoint`で作成したすべてのオブジェクトが復元され、status.assert_existing_objects_matched()を渡します。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.557012Z", "iopub.status.busy": "2021-02-12T21:36:25.556283Z", "iopub.status.idle": "2021-02-12T21:36:25.559854Z", "shell.execute_reply": "2021-02-12T21:36:25.560296Z" }, "id": "P9TQXl81Dq5r" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "status.assert_existing_objects_matched()" ] }, { "cell_type": "markdown", "metadata": { "id": "GoMwf8CFDu9r" }, "source": [ "チェックポイントには、層のカーネルやオプティマイザの変数など、一致しない多くのオブジェクトがあります。`status.assert_consumed()`は、チェックポイントとプログラムが正確に一致する場合に限り渡すため、ここでは例外をスローします。" ] }, { "cell_type": "markdown", "metadata": { "id": "KCcmJ-2j9RUP" }, "source": [ "### 復元遅延(Delayed restoration)\n", "\n", "TensorFlow の`Layer`オブジェクトは、入力形状が利用可能な場合、最初の呼び出しまで変数の作成を遅らせる可能性があります。例えば、`Dense`レイヤーのカーネルの形状はレイヤーの入力形状と出力形状の両方に依存するため、コンストラクタ引数として必要な出力形状は、単独で変数を作成するために充分な情報ではありません。`Layer`の呼び出しは変数の値も読み取るため、復元は変数の作成とその最初の使用の間で発生する必要があります。\n", "\n", "このイディオムをサポートするために、`tf.train.Checkpoint`は一致する変数をまだ持たない復元をキューに入れます。" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.565549Z", "iopub.status.busy": "2021-02-12T21:36:25.564761Z", "iopub.status.idle": "2021-02-12T21:36:25.569642Z", "shell.execute_reply": "2021-02-12T21:36:25.570006Z" }, "id": "TXYUCO3v-I72" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 0.]]\n", "[[4.645149 4.866288 4.8621855 5.02036 4.9210296]]\n" ] } ], "source": [ "delayed_restore = tf.Variable(tf.zeros([1, 5]))\n", "print(delayed_restore.numpy()) # Not restored; still zeros\n", "fake_layer.kernel = delayed_restore\n", "print(delayed_restore.numpy()) # Restored" ] }, { "cell_type": "markdown", "metadata": { "id": "-DWhJ3glyobN" }, "source": [ "### チェックポイントを手動で検査する\n", "\n", "`tf.train.list_variables`は、チェックポイントキーとチェックポイント内の変数の形状をリスト表示します。チェックポイントキーは上で示したグラフのパスです。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.575441Z", "iopub.status.busy": "2021-02-12T21:36:25.573477Z", "iopub.status.idle": "2021-02-12T21:36:25.578566Z", "shell.execute_reply": "2021-02-12T21:36:25.578035Z" }, "id": "RlRsADTezoBD" }, "outputs": [ { "data": { "text/plain": [ "[('_CHECKPOINTABLE_OBJECT_GRAPH', []),\n", " ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),\n", " ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),\n", " ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),\n", " ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),\n", " ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),\n", " ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',\n", " [1, 5]),\n", " ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',\n", " [1, 5]),\n", " ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),\n", " ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),\n", " ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),\n", " ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),\n", " ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),\n", " ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),\n", " ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))" ] }, { "cell_type": "markdown", "metadata": { "id": "5fxk_BnZ4W1b" }, "source": [ "### リストとディクショナリを追跡する\n", "\n", "`self.l1 = tf.keras.layers.Dense(5)`のような直接の属性割り当てと同様に、リストとディクショナリを属性に割り当てると、それらの内容を追跡します。" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.584393Z", "iopub.status.busy": "2021-02-12T21:36:25.583812Z", "iopub.status.idle": "2021-02-12T21:36:25.603285Z", "shell.execute_reply": "2021-02-12T21:36:25.602706Z" }, "id": "rfaIbDtDHAr_" }, "outputs": [], "source": [ "save = tf.train.Checkpoint()\n", "save.listed = [tf.Variable(1.)]\n", "save.listed.append(tf.Variable(2.))\n", "save.mapped = {'one': save.listed[0]}\n", "save.mapped['two'] = save.listed[1]\n", "save_path = save.save('./tf_list_example')\n", "\n", "restore = tf.train.Checkpoint()\n", "v2 = tf.Variable(0.)\n", "assert 0. == v2.numpy() # Not restored yet\n", "restore.mapped = {'two': v2}\n", "restore.restore(save_path)\n", "assert 2. == v2.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "UTKvbxHcI3T2" }, "source": [ "リストとディクショナリのラッパーオブジェクトにお気づきでしょうか。これらのラッパーは基礎的なデータ構造のチェックポイント可能なバージョンです。属性に基づく読み込みと同様に、これらのラッパーは変数の値がコンテナに追加されるとすぐにそれを復元します。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.608488Z", "iopub.status.busy": "2021-02-12T21:36:25.607915Z", "iopub.status.idle": "2021-02-12T21:36:25.611096Z", "shell.execute_reply": "2021-02-12T21:36:25.611473Z" }, "id": "s0Uq1Hv5JCmm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ListWrapper([])\n" ] } ], "source": [ "restore.listed = []\n", "print(restore.listed) # ListWrapper([])\n", "v1 = tf.Variable(0.)\n", "restore.listed.append(v1) # Restores v1, from restore() in the previous cell\n", "assert 1. == v1.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "OxCIf2J6JyQ8" }, "source": [ "同じ追跡が`tf.keras.Model`のサブクラスに自動的に適用され、例えばレイヤーのリストの追跡にも使用される可能性があります。" ] }, { "cell_type": "markdown", "metadata": { "id": "zGG1tOM0L6iM" }, "source": [ "## Estimator でオブジェクトベースのチェックポイントを保存する\n", "\n", "[Estimator のガイド](https://www.tensorflow.org/guide/estimator)をご覧ください。\n", "\n", "Estimator はデフォルトで、前のセクションで説明したオブジェクトグラフではなく、変数名でチェックポイントを保存します。`tf.train.Checkpoint`は名前ベースのチェックポイントを受け取りますが、モデルの一部を Estimator の`model_fn`の外側に移動すると変数名が変わることがあります。 オブジェクトベースのチェックポイントを保存すると、Estimator の内側でモデルをトレーニングし、外側でそれを使用することが容易になります。" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.615495Z", "iopub.status.busy": "2021-02-12T21:36:25.614882Z", "iopub.status.idle": "2021-02-12T21:36:25.655926Z", "shell.execute_reply": "2021-02-12T21:36:25.655284Z" }, "id": "-8AMJeueNyoM" }, "outputs": [], "source": [ "import tensorflow.compat.v1 as tf_compat" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:25.663692Z", "iopub.status.busy": "2021-02-12T21:36:25.662995Z", "iopub.status.idle": "2021-02-12T21:36:26.172433Z", "shell.execute_reply": "2021-02-12T21:36:26.172792Z" }, "id": "T6fQsBzJQN2y" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 4.4524446, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 36.07061.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def model_fn(features, labels, mode):\n", " net = Net()\n", " opt = tf.keras.optimizers.Adam(0.1)\n", " ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),\n", " optimizer=opt, net=net)\n", " with tf.GradientTape() as tape:\n", " output = net(features['x'])\n", " loss = tf.reduce_mean(tf.abs(output - features['y']))\n", " variables = net.trainable_variables\n", " gradients = tape.gradient(loss, variables)\n", " return tf.estimator.EstimatorSpec(\n", " mode,\n", " loss=loss,\n", " train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),\n", " ckpt.step.assign_add(1)),\n", " # Tell the Estimator to save \"ckpt\" in an object-based format.\n", " scaffold=tf_compat.train.Scaffold(saver=ckpt))\n", "\n", "tf.keras.backend.clear_session()\n", "est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')\n", "est.train(toy_dataset, steps=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "tObYHnrrb_mL" }, "source": [ "その後、`tf.train.Checkpoint`は Estimator のチェックポイントをその`model_dir`から読み込むことができます。" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2021-02-12T21:36:26.181593Z", "iopub.status.busy": "2021-02-12T21:36:26.180927Z", "iopub.status.idle": "2021-02-12T21:36:26.189230Z", "shell.execute_reply": "2021-02-12T21:36:26.189636Z" }, "id": "Q6IP3Y_wb-fs" }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "net = Net()\n", "ckpt = tf.train.Checkpoint(\n", " step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)\n", "ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))\n", "ckpt.step.numpy() # From est.train(..., steps=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "knyUFMrJg8y4" }, "source": [ "## まとめ\n", "\n", "TensorFlow オブジェクトは、それらが使用する変数の値を保存および復元するための容易で自動的な仕組みを提供します。\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "checkpoint.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 0 }