{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "6bYaCABobL5q" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:28:43.744580Z", "iopub.status.busy": "2022-12-14T22:28:43.744096Z", "iopub.status.idle": "2022-12-14T22:28:43.747705Z", "shell.execute_reply": "2022-12-14T22:28:43.747156Z" }, "id": "FlUw7tSKbtg4" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "61dp4Hg5ksTC" }, "source": [ "# Migrating model checkpoints\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "avuMwzscPnHh" }, "source": [ "Note: Checkpoints saved with `tf.compat.v1.train.Saver` are often referred to as *TF1 or name-based* checkpoints. Checkpoints saved with `tf.train.Checkpoint` are referred to as *TF2 or object-based* checkpoints.\n", "\n", "## Overview\n", "\n", "This guide assumes that you have a model that saves and loads checkpoints with [`tf.compat.v1.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver), and that you want either to migrate the code to the TF2 [`tf.train.Checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint) API or to use pre-existing checkpoints in your TF2 model.\n", "\n", "Below are some common scenarios that you may encounter.\n", "\n", "**Scenario 1**\n", "\n", "You have existing TF1 checkpoints from previous training runs that need to be loaded into or converted to TF2.\n", "\n", "- To load the TF1 checkpoint in TF2, see the snippet [*Load a TF1 checkpoint in TF2*](#load-tf1-in-tf2).\n", "- To convert the checkpoint to TF2, see [*Checkpoint conversion*](#checkpoint-conversion).\n", "\n", "**Scenario 2**\n", "\n", "You are adjusting your model in a way that risks changing variable names and paths (for example, when incrementally migrating from `get_variable` to explicit `tf.Variable` creation), and you want to keep saving and loading existing checkpoints along the way.\n", "\n", "See the section on [*How to maintain checkpoint compatibility during model migration*](#maintain-checkpoint-compat).\n", "\n", "**Scenario 3**\n", "\n", "You are migrating your training code and checkpoints to TF2, but your inference pipeline still requires TF1 checkpoints (for production stability).\n", "\n", "*Option 1*\n", "\n", "Save both TF1 and TF2 checkpoints when training.\n", "\n", "- See [*Save a TF1 checkpoint in TF2*](#save-tf1-in-tf2).\n", "\n", "*Option 2*\n", "\n", "Convert the TF2 checkpoint to TF1.\n", "\n", "- See [*Checkpoint conversion*](#checkpoint-conversion).\n", "\n", "---\n", "\n", "The examples below show every combination of saving and loading checkpoints in TF1/TF2, so you have some flexibility in deciding how to migrate your model." ] }, { "cell_type": "markdown", "metadata": { "id": "TaYgaekzOAHf" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:43.751046Z", "iopub.status.busy": "2022-12-14T22:28:43.750613Z", "iopub.status.idle": "2022-12-14T22:28:45.650070Z", "shell.execute_reply": "2022-12-14T22:28:45.649431Z" }, "id": "kcvTd5QhZ78L" }, "outputs": [ { "name": 
"stderr", "output_type": "stream", "text": [ "2022-12-14 22:28:44.682638: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:28:44.682724: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:28:44.682733: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "\n", "def print_checkpoint(save_path):\n", "  reader = tf.train.load_checkpoint(save_path)\n", "  shapes = reader.get_variable_to_shape_map()\n", "  dtypes = reader.get_variable_to_dtype_map()\n", "  print(f\"Checkpoint at '{save_path}':\")\n", "  for key in shapes:\n", "    print(f\" (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, \"\n", "          f\"value={reader.get_tensor(key)})\")" ] }, { "cell_type": "markdown", "metadata": { "id": "gO8Q6QkulJlj" }, "source": [ "## Changes from TF1 to TF2\n", "\n", "This section explains what changed between TF1 and TF2, and what \"name-based\" (TF1) and \"object-based\" (TF2) checkpoints mean.\n", "\n", "The two types of checkpoints are actually saved in the same format, which is essentially a key-value table. The difference lies in how the keys are generated.\n", "\n", "The keys in name-based checkpoints are the **names of the variables**. The keys in object-based checkpoints refer to the **path from the root object to the variable** (the examples below make this concrete).\n", "\n", "First, save some checkpoints:\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:45.654723Z", "iopub.status.busy": "2022-12-14T22:28:45.653935Z", "iopub.status.idle": "2022-12-14T22:28:49.069188Z", "shell.execute_reply": 
"2022-12-14T22:28:49.068498Z" }, "id": "8YXzbXvOWvdF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf1-ckpt':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n" ] } ], "source": [ "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  with tf1.Session() as sess:\n", "    saver = tf1.train.Saver()\n", "    sess.run(a.assign(1))\n", "    sess.run(b.assign(2))\n", "    sess.run(c.assign(3))\n", "    saver.save(sess, 'tf1-ckpt')\n", "\n", "print_checkpoint('tf1-ckpt')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.072793Z", "iopub.status.busy": "2022-12-14T22:28:49.072156Z", "iopub.status.idle": "2022-12-14T22:28:49.123848Z", "shell.execute_reply": "2022-12-14T22:28:49.123193Z" }, "id": "raOych1UaJzl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf2-ckpt-1':\n", " (key='variables/2/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=7.0)\n", " (key='variables/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=6.0)\n", " (key='variables/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=5.0)\n", " (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, 
value=b\"\\n%\\n\\r\\x08\\x01\\x12\\tvariables\\n\\x10\\x08\\x02\\x12\\x0csave_counter*\\x02\\x08\\x01\\n\\x19\\n\\x05\\x08\\x03\\x12\\x010\\n\\x05\\x08\\x04\\x12\\x011\\n\\x05\\x08\\x05\\x12\\x012*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nA\\x12;\\n\\x0eVARIABLE_VALUE\\x12\\x01a\\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nA\\x12;\\n\\x0eVARIABLE_VALUE\\x12\\x01b\\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nH\\x12B\\n\\x0eVARIABLE_VALUE\\x12\\x08scoped/c\\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n" ] } ], "source": [ "a = tf.Variable(5.0, name='a')\n", "b = tf.Variable(6.0, name='b')\n", "with tf.name_scope('scoped'):\n", "  c = tf.Variable(7.0, name='c')\n", "\n", "ckpt = tf.train.Checkpoint(variables=[a, b, c])\n", "save_path_v2 = ckpt.save('tf2-ckpt')\n", "print_checkpoint(save_path_v2)" ] }, { "cell_type": "markdown", "metadata": { "id": "UYyLhTYszcpl" }, "source": [ "The keys in `tf2-ckpt` all refer to the object path of each variable. For example, variable `a` is the first element of the `variables` list, so its key becomes `variables/0/...` (you can ignore the `.ATTRIBUTES/VARIABLE_VALUE` constant).\n", "\n", "The next cell takes a closer look at the `Checkpoint` object:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.127634Z", "iopub.status.busy": "2022-12-14T22:28:49.127114Z", "iopub.status.idle": "2022-12-14T22:28:49.134716Z", "shell.execute_reply": "2022-12-14T22:28:49.134093Z" }, "id": "kLOxvoosg4Al" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root type = Checkpoint\n", "root.variables = ListWrapper([, , ])\n", "root.variables[0] = \n" ] } ], "source": [ "a = tf.Variable(0.)\n", "b = tf.Variable(0.)\n", "c = tf.Variable(0.)\n", "root = ckpt = tf.train.Checkpoint(variables=[a, b, c])\n", "print(\"root type =\", type(root).__name__)\n", "print(\"root.variables =\", root.variables)\n", "print(\"root.variables[0] =\", 
root.variables[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "1Qaed1yAm3Ar" }, "source": [ "Try experimenting with the snippet below to see how the checkpoint keys change with the object structure:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.138101Z", "iopub.status.busy": "2022-12-14T22:28:49.137626Z", "iopub.status.idle": "2022-12-14T22:28:49.151315Z", "shell.execute_reply": "2022-12-14T22:28:49.150710Z" }, "id": "EdHJXlZOyDnn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'root-tf2-ckpt-1':\n", " (key='v/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)\n", " (key='v/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)\n", " (key='module/d/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)\n", " (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='c/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b\"\\n0\\n\\x05\\x08\\x01\\x12\\x01c\\n\\n\\x08\\x02\\x12\\x06module\\n\\x05\\x08\\x03\\x12\\x01v\\n\\x10\\x08\\x04\\x12\\x0csave_counter*\\x02\\x08\\x01\\n>\\x128\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a\\x1cc/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n\\x0b\\n\\x05\\x08\\x05\\x12\\x01d*\\x02\\x08\\x01\\n\\x12\\n\\x05\\x08\\x06\\x12\\x01a\\n\\x05\\x08\\x07\\x12\\x01b*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nE\\x12?\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a\\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a\\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n" ] } ], "source": [ "module = tf.Module()\n", "module.d = tf.Variable(0.)\n", "test_ckpt = 
tf.train.Checkpoint(v={'a': a, 'b': b},\n", "                                c=c,\n", "                                module=module)\n", "test_ckpt_path = test_ckpt.save('root-tf2-ckpt')\n", "print_checkpoint(test_ckpt_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "8iWitEsayDWs" }, "source": [ "*Why does TF2 use this mechanism?*\n", "\n", "Because TF2 has no global graph, variable names are unreliable and can be inconsistent between programs. TF2 encourages an object-oriented modelling approach in which variables are owned by layers, and layers are owned by a model:\n", "\n", "```\n", "variable = tf.Variable(...)\n", "layer.variable_name = variable\n", "model.layer_name = layer\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "9kv9SmyVjGLA" }, "source": [ "## How to maintain checkpoint compatibility during model migration\n", "\n", "\n", "\n", "One important step in the migration process is *ensuring that all variables are initialized to the correct values*, which in turn lets you validate that the ops and functions are performing the correct computations. To accomplish this, you must consider the **checkpoint compatibility** between models at the various stages of migration. Essentially, this section answers the question, *how do I keep using the same checkpoint while changing the model?*\n", "\n", "Below are three ways of maintaining checkpoint compatibility, listed in order of increasing flexibility:\n", "\n", "1. The model has the **same variable names** as before.\n", "2. The model has different variable names, and you maintain an **assignment map** that maps variable names in the checkpoint to the new names.\n", "3. The model has different variable names, and you maintain a **TF2 Checkpoint object** that stores all of the variables." ] }, { "cell_type": "markdown", "metadata": { "id": "L5JhCyPZDx43" }, "source": [ "### When the variable names match\n", "\n", "Long title: How to re-use checkpoints when the variable names match.\n", "\n", "Short answer: You can load the pre-existing checkpoint directly with either `tf1.train.Saver` or `tf.train.Checkpoint`.\n", "\n", "---\n", "\n", "If you use `tf.compat.v1.keras.utils.track_tf1_style_variables`, it ensures that your model variable names are the same as before. You can also check manually that the variable names match.\n", "\n", "When the variable names match in the migrated model, you can use either `tf.train.Checkpoint` or `tf.compat.v1.train.Saver` directly to load the checkpoint. Both APIs are compatible with eager and graph mode, so you can use them at any stage of the migration.\n", "\n", "Note: You can use `tf.train.Checkpoint` to load TF1 checkpoints, but loading TF2 checkpoints with `tf.compat.v1.train.Saver` requires complicated name matching.\n", "\n", "Below are examples of using the same checkpoint with different models. First, save a TF1 checkpoint with `tf1.train.Saver`:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.154956Z", "iopub.status.busy": "2022-12-14T22:28:49.154354Z", "iopub.status.idle": "2022-12-14T22:28:49.215328Z", "shell.execute_reply": "2022-12-14T22:28:49.214690Z" }, "id": "ijlHS96URsfR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf1-ckpt':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n" ] } ], "source": [ "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  with tf1.Session() as sess:\n", "    saver = tf1.train.Saver()\n", "    sess.run(a.assign(1))\n", "    sess.run(b.assign(2))\n", "    sess.run(c.assign(3))\n", "    save_path = saver.save(sess, 'tf1-ckpt')\n", 
"print_checkpoint(save_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "zg7nWZphQD9u" }, "source": [ "The example below uses `tf.compat.v1.train.Saver` to load the checkpoint in eager mode:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.218678Z", "iopub.status.busy": "2022-12-14T22:28:49.218078Z", "iopub.status.idle": "2022-12-14T22:28:49.235910Z", "shell.execute_reply": "2022-12-14T22:28:49.235223Z" }, "id": "Y4K16m0PPncQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from tf1-ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loaded values of [a, b, c]: [1.0, 2.0, 3.0]\n", "Checkpoint at 'tf1-ckpt-saved-in-eager':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n" ] } ], "source": [ "a = tf.Variable(0.0, name='a')\n", "b = tf.Variable(0.0, name='b')\n", "with tf.name_scope('scoped'):\n", "  c = tf.Variable(0.0, name='c')\n", "\n", "# With the removal of collections in TF2, you must pass in the list of variables\n", "# to the Saver object:\n", "saver = tf1.train.Saver(var_list=[a, b, c])\n", "saver.restore(sess=None, save_path=save_path)\n", "print(f\"loaded values of [a, b, c]: [{a.numpy()}, {b.numpy()}, {c.numpy()}]\")\n", "\n", "# Saving also works in eager (sess must be None).\n", "path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')\n", "print_checkpoint(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "dWnq1f5yAPkq" }, "source": [ "The next snippet loads the checkpoint using the TF2 API `tf.train.Checkpoint`:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.239308Z", "iopub.status.busy": "2022-12-14T22:28:49.238704Z", "iopub.status.idle": "2022-12-14T22:28:49.252798Z", "shell.execute_reply": "2022-12-14T22:28:49.252224Z" }, "id": "StyrzwGvW1YZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable names: \n", " a.name = a:0\n", " b.name = b:0\n", " c.name = scoped/c:0\n", " c_2.name = scoped/c:0\n", "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/checkpoint/checkpoint.py:1473: NameBasedSaverStatus.__init__ (from tensorflow.python.checkpoint.checkpoint) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. 
Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loaded values of [a, b, c, c_2]: [1.0, 2.0, 3.0, 3.0]\n" ] } ], "source": [ "a = tf.Variable(0.0, name='a')\n", "b = tf.Variable(0.0, name='b')\n", "with tf.name_scope('scoped'):\n", "  c = tf.Variable(0.0, name='c')\n", "\n", "# Without the name_scope, name=\"scoped/c\" works too:\n", "c_2 = tf.Variable(0.0, name='scoped/c')\n", "\n", "print(\"Variable names: \")\n", "print(f\" a.name = {a.name}\")\n", "print(f\" b.name = {b.name}\")\n", "print(f\" c.name = {c.name}\")\n", "print(f\" c_2.name = {c_2.name}\")\n", "\n", "# Restore the values with tf.train.Checkpoint\n", "ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])\n", "ckpt.restore(save_path)\n", "print(f\"loaded values of [a, b, c, c_2]: [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DYYgbj8F7Yb7" }, "source": [ "#### Variable names in TF2\n", "\n", "- Variables all still have a `name` argument that you can set.\n", "- Keras models also take a `name` argument, which they set as the prefix for their variables.\n", "- The `v1.name_scope` function can be used to set variable name prefixes. This is very different from `tf.variable_scope`: it only affects names, and does not track variables or reuse.\n", "\n", "The `tf.compat.v1.keras.utils.track_tf1_style_variables` decorator is a shim that helps you maintain variable names and TF1 checkpoint compatibility by keeping the naming and reuse semantics of `tf.variable_scope` and `tf.compat.v1.get_variable` unchanged. See the [Model mapping guide](./model_mapping.ipynb) for details.\n", "\n", "**Note 1: If you are using the shim, use TF2 APIs to load your checkpoints (even when using pre-trained TF1 checkpoints).**\n", "\n", "See the *Keras checkpoints* section.\n", "\n", "**Note 2: When migrating from `get_variable` to `tf.Variable`:**\n", "\n", "If your shim-decorated layer or module contains some variables (or Keras layers/models) that use `tf.Variable` instead of `tf.compat.v1.get_variable`, and they are attached as properties or otherwise tracked in an object-oriented way, their variable naming semantics may differ between TF1.x graphs/sessions and eager execution.\n", "\n", "In short, *the names may not be what you expect* when running in TF2.\n", "\n", "Warning: Problems can arise if multiple variables in the name-based checkpoint need to be mapped to the same name. You may be able to adjust the layer and variable names explicitly, using `tf.name_scope` together with the `name` argument of the layer constructor or of `tf.Variable`, to ensure that there are no duplicates." ] }, { "cell_type": "markdown", "metadata": { "id": "NkUQJUUyjOJz" }, "source": [ "### Maintaining assignment maps\n", "\n", "Assignment maps are commonly used to transfer weights between TF1 models, and can also be used during model migration if the variable names change.\n", "\n", "You can use these maps with [`tf.compat.v1.train.init_from_checkpoint`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/init_from_checkpoint), [`tf.compat.v1.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver), and [`tf.train.load_checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/load_checkpoint) to load weights into models in which the variable or scope names may have changed.\n", "\n", "The examples in this section use a previously saved checkpoint:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.256291Z", "iopub.status.busy": "2022-12-14T22:28:49.255695Z", "iopub.status.idle": "2022-12-14T22:28:49.259362Z", "shell.execute_reply": "2022-12-14T22:28:49.258765Z" }, "id": "PItyo7DdJ6Ek" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf1-ckpt':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n" ] } ], "source": [ "print_checkpoint('tf1-ckpt')" ] }, { "cell_type": "markdown", "metadata": { "id": "rPryV_WBJrI3" }, "source": [ "#### Loading with `init_from_checkpoint`\n", "\n", "[`tf1.train.init_from_checkpoint`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/init_from_checkpoint) must be called inside a Graph/Session, because it places the values in the variable initializers instead of creating an assign op.\n", "\n", "Use the `assignment_map` argument to configure how the variables are loaded. From the documentation:\n", "\n", "> Assignment map supports the following syntax:\n", "\n", "- `'checkpoint_scope_name/': 'scope_name/'` - loads all variables in the current `scope_name` from `checkpoint_scope_name` with matching tensor names.\n", "- `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` - initializes the `scope_name/variable_name` variable from `checkpoint_scope_name/some_other_variable`.\n", "- `'scope_variable_name': variable` - initializes the given `tf.Variable` object with the tensor 'scope_variable_name' from the checkpoint.\n", "- `'scope_variable_name': list(variable)` - initializes a list of partitioned variables with the tensor 'scope_variable_name' from the checkpoint.\n", "- `'/': 'scope_name/'` - loads all variables in the current `scope_name` from the checkpoint's root (that is, no scope).\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.262408Z", "iopub.status.busy": "2022-12-14T22:28:49.262015Z", "iopub.status.idle": "2022-12-14T22:28:49.307035Z", "shell.execute_reply": "2022-12-14T22:28:49.306388Z" }, "id": "ZM_7OzRpdH0A" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "# Restoring with tf1.train.init_from_checkpoint:\n", "\n", "# A new model with a different scope for the variables.\n", "with tf.Graph().as_default() as g:\n", "  with tf1.variable_scope('new_scope'):\n", "    a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "    b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "  with tf1.Session() as sess:\n", "    # The assignment map will remap all variables in the checkpoint to the\n", "    # new scope:\n", "    tf1.train.init_from_checkpoint(\n", "        'tf1-ckpt',\n", "        assignment_map={'/': 'new_scope/'})\n", "    # `init_from_checkpoint` adds the initializers to these variables.\n", "    # Use `sess.run` to run these initializers.\n", "    sess.run(tf1.global_variables_initializer())\n", "\n", "    print(\"Restored [a, b, c]: \", sess.run([a, b, c]))" ] }, { 
"cell_type": "markdown", "metadata": { "id": "Za_8xhFWKVlH" }, "source": [ "#### Loading with `tf1.train.Saver`\n", "\n", "Unlike `init_from_checkpoint`, [`tf.compat.v1.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver) runs in both graph and eager mode. The `var_list` argument optionally accepts a dictionary, in which case it must map variable names to `tf.Variable` objects.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.310528Z", "iopub.status.busy": "2022-12-14T22:28:49.310037Z", "iopub.status.idle": "2022-12-14T22:28:49.322597Z", "shell.execute_reply": "2022-12-14T22:28:49.321950Z" }, "id": "IiKNmdGJgoX9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from tf1-ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "# Restoring with tf1.train.Saver (works in both graph and eager):\n", "\n", "# A new model with a different scope for the variables.\n", "with tf1.variable_scope('new_scope'):\n", "  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "# Initialize the saver with a dictionary with the original variable names:\n", "saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})\n", "saver.restore(sess=None, save_path='tf1-ckpt')\n", "print(\"Restored [a, b, c]: \", [a.numpy(), 
b.numpy(), c.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "7JsgCXt3Ly-h" }, "source": [ "#### Loading with `tf.train.load_checkpoint`\n", "\n", "This option is for you if you need precise control over the variable values. Again, it works in both graph and eager modes." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.325765Z", "iopub.status.busy": "2022-12-14T22:28:49.325279Z", "iopub.status.idle": "2022-12-14T22:28:49.372723Z", "shell.execute_reply": "2022-12-14T22:28:49.372070Z" }, "id": "Pc39Bh6JMso6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "# Restoring with tf.train.load_checkpoint (works in both graph and eager):\n", "\n", "# A new model with a different scope for the variables.\n", "with tf.Graph().as_default() as g:\n", "  with tf1.variable_scope('new_scope'):\n", "    a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "    b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "  with tf1.Session() as sess:\n", "    # It may be easier writing a loop if your model has a lot of variables.\n", "    reader = tf.train.load_checkpoint('tf1-ckpt')\n", "    sess.run(a.assign(reader.get_tensor('a')))\n", "    sess.run(b.assign(reader.get_tensor('b')))\n", "    sess.run(c.assign(reader.get_tensor('scoped/c')))\n", "    print(\"Restored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "nBSTJVCNDKed" }, "source": [ "### Maintaining a TF2 Checkpoint object\n", "\n", "If the variable and scope names may change a lot during the migration, use `tf.train.Checkpoint` and TF2 checkpoints. TF2 uses the **object structure** instead of variable names (see *Changes from TF1 to TF2* for details).\n", "\n", "In short, when creating a `tf.train.Checkpoint` to save or restore checkpoints, make sure it uses the same **ordering** (for lists) and the same **keys** (for dictionaries and for keyword arguments to the `Checkpoint` initializer). Some examples of checkpoint compatibility:\n", "\n", "```\n", "ckpt = tf.train.Checkpoint(foo=[var_a, var_b])\n", "\n", "# compatible with ckpt\n", "tf.train.Checkpoint(foo=[var_a, var_b])\n", "\n", "# not compatible with ckpt\n", "tf.train.Checkpoint(foo=[var_b, var_a])\n", "tf.train.Checkpoint(bar=[var_a, var_b])\n", "```\n", "\n", "The code samples below show how to use the \"same\" `tf.train.Checkpoint` to load variables with different names. First, save a TF2 checkpoint:\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.376256Z", "iopub.status.busy": "2022-12-14T22:28:49.375640Z", "iopub.status.idle": "2022-12-14T22:28:49.444868Z", "shell.execute_reply": "2022-12-14T22:28:49.444234Z" }, "id": "tCSkz_-Tbct6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[a, b, c]: [1.0, 2.0, 3.0]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf2-ckpt-1':\n", " (key='unscoped/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)\n", " (key='unscoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)\n", " (key='scoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)\n", " (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, 
value=b\"\\n0\\n\\n\\x08\\x01\\x12\\x06scoped\\n\\x0c\\x08\\x02\\x12\\x08unscoped\\n\\x10\\x08\\x03\\x12\\x0csave_counter*\\x02\\x08\\x01\\n\\x0b\\n\\x05\\x08\\x04\\x12\\x010*\\x02\\x08\\x01\\n\\x12\\n\\x05\\x08\\x05\\x12\\x010\\n\\x05\\x08\\x06\\x12\\x011*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nE\\x12?\\n\\x0eVARIABLE_VALUE\\x12\\x08scoped/c\\x1a#scoped/0/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x01a\\x1a%unscoped/0/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x01b\\x1a%unscoped/1/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n" ] } ], "source": [ "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(1))\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(2))\n", "  with tf1.variable_scope('scoped'):\n", "    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.constant_initializer(3))\n", "  with tf1.Session() as sess:\n", "    sess.run(tf1.global_variables_initializer())\n", "    print(\"[a, b, c]: \", sess.run([a, b, c]))\n", "\n", "    # Save a TF2 checkpoint\n", "    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])\n", "    tf2_ckpt_path = ckpt.save('tf2-ckpt')\n", "    print_checkpoint(tf2_ckpt_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "62MWdZMxezeP" }, "source": [ "You can keep using `tf.train.Checkpoint` even if the variable or scope names change:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.448195Z", "iopub.status.busy": "2022-12-14T22:28:49.447743Z", "iopub.status.idle": "2022-12-14T22:28:49.503254Z", "shell.execute_reply": "2022-12-14T22:28:49.502673Z" }, "id": "Vh61SGeqb27b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialized [a, b, c]: [0.0, 0.0, 
0.0]\n", "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.zeros_initializer())\n", "  with tf1.variable_scope('different_scope'):\n", "    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.zeros_initializer())\n", "  with tf1.Session() as sess:\n", "    sess.run(tf1.global_variables_initializer())\n", "    print(\"Initialized [a, b, c]: \", sess.run([a, b, c]))\n", "\n", "    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])\n", "    # `assert_consumed` validates that all checkpoint objects are restored from\n", "    # the checkpoint. `run_restore_ops` is required when running in a TF1\n", "    # session.\n", "    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()\n", "\n", "    # Removing `assert_consumed` is fine if you want to skip the validation.\n", "    # ckpt.restore(tf2_ckpt_path).run_restore_ops()\n", "\n", "    print(\"Restored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "unDPmL-kldr2" }, "source": [ "And in eager mode:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.506594Z", "iopub.status.busy": "2022-12-14T22:28:49.506113Z", "iopub.status.idle": "2022-12-14T22:28:49.519055Z", "shell.execute_reply": "2022-12-14T22:28:49.518351Z" }, "id": "79S0zMAnfzx7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialized [a, b, c]: [0.0, 0.0, 0.0]\n", "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "a = tf.Variable(0.)\n", "b = tf.Variable(0.)\n", "c = tf.Variable(0.)\n", "print(\"Initialized [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])\n", "\n", "# The keys \"scoped\" and \"unscoped\" are no longer relevant, but are used to\n", "# maintain 
compatibility with the saved checkpoints.\n", "ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])\n", "\n", "ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()\n", "print(\"Restored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "dKfNAr8l3aFg" }, "source": [ "## TF2 checkpoints in Estimator\n", "\n", "The sections above describe how to maintain checkpoint compatibility while migrating your model. The same concepts apply to Estimator models, although checkpoints are saved and loaded slightly differently. As you migrate an Estimator model to TF2 APIs, you may want to switch from TF1 to TF2 checkpoints *while the model is still using the Estimator*. This section shows how.\n", "\n", "[`tf.estimator.Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) and [`MonitoredSession`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/MonitoredSession) have a saving mechanism called the `scaffold`, which is a [`tf.compat.v1.train.Scaffold`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Scaffold) object. The `Scaffold` can contain a `tf1.train.Saver` or a `tf.train.Checkpoint`, which lets `Estimator` and `MonitoredSession` save TF1- or TF2-style checkpoints.\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.522469Z", "iopub.status.busy": "2022-12-14T22:28:49.521958Z", "iopub.status.idle": "2022-12-14T22:28:49.950514Z", "shell.execute_reply": "2022-12-14T22:28:49.949758Z" }, "id": "D8AT_oO-5TXU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': 'est-tf1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': 
None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into est-tf1/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "INFO:tensorflow:loss = 1.0, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 1 into est-tf1/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 1.0.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'est-tf1/model.ckpt-1':\n", " (key='var', shape=[], dtype=float32, value=2.0)\n", " (key='global_step', shape=[], dtype=int64, value=1)\n" ] } ], "source": [ "# A model_fn that saves a TF1 checkpoint\n", "def model_fn_tf1_ckpt(features, labels, mode):\n", "  # This model adds 2 to the variable `v` in every train step.\n", "  train_step = tf1.train.get_or_create_global_step()\n", "  v = tf1.get_variable('var', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(0))\n", "  return tf.estimator.EstimatorSpec(\n", "      mode,\n", "      predictions=v,\n", "      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),\n", "      loss=tf.constant(1.),\n", "      scaffold=None\n", "  )\n", "\n", "!rm -rf est-tf1\n", "est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')\n", "\n", "def train_fn():\n", "  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))\n", "est.train(train_fn, steps=1)\n", "\n", "latest_checkpoint = tf.train.latest_checkpoint('est-tf1')\n", "print_checkpoint(latest_checkpoint) " ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:49.954279Z", "iopub.status.busy": "2022-12-14T22:28:49.953649Z", "iopub.status.idle": "2022-12-14T22:28:50.361729Z", "shell.execute_reply": "2022-12-14T22:28:50.360948Z" }, "id": "ttH6cUrl7jK2" }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': 'est-tf2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='est-tf1', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting from: est-tf1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Warm-started 1 variables.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into est-tf2/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 1.0, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 1 into est-tf2/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 1.0.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'est-tf2/model.ckpt-1':\n", " (key='var_list/var/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=4.0)\n", " (key='step/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, 
value=b\"\\n\\x1c\\n\\x08\\x08\\x01\\x12\\x04step\\n\\x0c\\x08\\x02\\x12\\x08var_list*\\x02\\x08\\x01\\nD\\x12>\\n\\x0eVARIABLE_VALUE\\x12\\x0bglobal_step\\x1a\\x1fstep/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n\\r\\n\\x07\\x08\\x03\\x12\\x03var*\\x02\\x08\\x01\\nD\\x12>\\n\\x0eVARIABLE_VALUE\\x12\\x03var\\x1a'var_list/var/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n" ] } ], "source": [ "# A model_fn that saves a TF2 checkpoint\n", "def model_fn_tf2_ckpt(features, labels, mode):\n", "  # This model adds 2 to the variable `v` in every train step.\n", "  train_step = tf1.train.get_or_create_global_step()\n", "  v = tf1.get_variable('var', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(0))\n", "  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)\n", "  return tf.estimator.EstimatorSpec(\n", "      mode,\n", "      predictions=v,\n", "      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),\n", "      loss=tf.constant(1.),\n", "      scaffold=tf1.train.Scaffold(saver=ckpt)\n", "  )\n", "\n", "!rm -rf est-tf2\n", "est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',\n", "                             warm_start_from='est-tf1')\n", "\n", "def train_fn():\n", "  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))\n", "est.train(train_fn, steps=1)\n", "\n", "latest_checkpoint = tf.train.latest_checkpoint('est-tf2')\n", "print_checkpoint(latest_checkpoint) \n", "\n", "assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4" ] }, { "cell_type": "markdown", "metadata": { "id": "hYVYgahE8daL" }, "source": [ "The final value of `v` should be `4`: it is warm-started from `est-tf1` at `2`, then incremented by 2 during one additional training step. The train step value does not carry over from the `warm_start` checkpoint.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Pq8EjblQUIA2" }, "source": [ "## Checkpointing in Keras\n", "\n", "Models built with Keras can still use `tf1.train.Saver` and `tf.train.Checkpoint` to load pre-existing weights. Once your model is fully migrated, switch to `model.save_weights` and `model.load_weights`, especially if you use the `ModelCheckpoint` callback when training.\n", 
"\n", "Some things you should know about checkpoints and Keras:\n", "\n", "**Initialization vs. building**\n", "\n", "Keras models and layers must go through **two steps** before they are fully created. The first is the *initialization* of the Python object: `layer = tf.keras.layers.Dense(x)`. The second is the *build* step, `layer.build(input_shape)`, in which most of the weights are actually created. You can also build a model by calling it or by running a single `train`, `eval`, or `predict` step (the first time only).\n", "\n", "If `model.load_weights(path).assert_consumed()` raises an error, the model or its layers have probably not been built yet.\n", "\n", "**Keras uses TF2 checkpoints**\n", "\n", "`tf.train.Checkpoint(model).write` is equivalent to `model.save_weights`, and `tf.train.Checkpoint(model).read` is equivalent to `model.load_weights`. Note that `Checkpoint(model) != Checkpoint(model=model)`.\n", "\n", "**TF2 checkpoints work with Keras's `build()` step**\n", "\n", "`tf.train.Checkpoint.restore` has a mechanism called *deferred restoration*, which allows `tf.Module` and Keras objects to store variable values when the variable has not been created yet. This lets an *initialized* model load weights first and *build* afterwards.\n", "\n", "```\n", "m = YourKerasModel()\n", "status = m.load_weights(path)\n", "\n", "# This call builds the model. 
The variables are created with the restored\n", "# values.\n", "m.predict(inputs)\n", "\n", "status.assert_consumed()\n", "```\n", "\n", "Because of this mechanism, it is highly recommended that you use the TF2 checkpoint loading APIs with Keras models, even when restoring pre-existing TF1 checkpoints into the [model mapping shims](./model_mapping.ipynb). See the [checkpoint guide](https://www.tensorflow.org/guide/checkpoint#delayed_restorations) for details.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xO2NucRtqMm6" }, "source": [ "## Code snippets\n", "\n", "The snippets below show the TF1/TF2 version compatibility of the checkpoint saving APIs." ] }, { "cell_type": "markdown", "metadata": { "id": "C3SSc74olkX3" }, "source": [ "### Save a TF1 checkpoint in TF2\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.366077Z", "iopub.status.busy": "2022-12-14T22:28:50.365466Z", "iopub.status.idle": "2022-12-14T22:28:50.378320Z", "shell.execute_reply": "2022-12-14T22:28:50.377754Z" }, "id": "t2ZPk8BPloE1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. 
When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf1-ckpt-saved-in-eager':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n" ] } ], "source": [ "a = tf.Variable(1.0, name='a')\n", "b = tf.Variable(2.0, name='b')\n", "with tf.name_scope('scoped'):\n", "  c = tf.Variable(3.0, name='c')\n", "\n", "saver = tf1.train.Saver(var_list=[a, b, c])\n", "path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')\n", "print_checkpoint(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "BxyN5khVjhmA" }, "source": [ "### Load a TF1 checkpoint in TF2\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.381881Z", "iopub.status.busy": "2022-12-14T22:28:50.381227Z", "iopub.status.idle": "2022-12-14T22:28:50.392889Z", "shell.execute_reply": "2022-12-14T22:28:50.392292Z" }, "id": "Z5kSXy3FmA79" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialized [a, b, c]: [0.0, 0.0, 0.0]\n", "WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. 
When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from tf1-ckpt-saved-in-eager\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "a = tf.Variable(0., name='a')\n", "b = tf.Variable(0., name='b')\n", "with tf.name_scope('scoped'):\n", "  c = tf.Variable(0., name='c')\n", "print(\"Initialized [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])\n", "saver = tf1.train.Saver(var_list=[a, b, c])\n", "saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')\n", "print(\"Restored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "Ul3V4pEwloeN" }, "source": [ "### Save a TF2 checkpoint in TF1" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.396335Z", "iopub.status.busy": "2022-12-14T22:28:50.395838Z", "iopub.status.idle": "2022-12-14T22:28:50.460742Z", "shell.execute_reply": "2022-12-14T22:28:50.460147Z" }, "id": "UhuP_2EIlRm4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf2-ckpt-saved-in-session-1':\n", " (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)\n", " (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)\n", " (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)\n", " (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b\"\\n$\\n\\x0c\\x08\\x01\\x12\\x08var_list\\n\\x10\\x08\\x02\\x12\\x0csave_counter*\\x02\\x08\\x01\\n 
\\n\\x05\\x08\\x03\\x12\\x01a\\n\\x05\\x08\\x04\\x12\\x01b\\n\\x0c\\x08\\x05\\x12\\x08scoped/c*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x01a\\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x01b\\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nO\\x12I\\n\\x0eVARIABLE_VALUE\\x12\\x08scoped/c\\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n" ] } ], "source": [ "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(1))\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(2))\n", "  with tf1.variable_scope('scoped'):\n", "    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.constant_initializer(3))\n", "  with tf1.Session() as sess:\n", "    sess.run(tf1.global_variables_initializer())\n", "    ckpt = tf.train.Checkpoint(\n", "        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})\n", "    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')\n", "    print_checkpoint(tf2_in_tf1_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "GiViCjCDgxhz" }, "source": [ "### Load a TF2 checkpoint in TF1\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.464232Z", "iopub.status.busy": "2022-12-14T22:28:50.463659Z", "iopub.status.idle": "2022-12-14T22:28:50.519514Z", "shell.execute_reply": "2022-12-14T22:28:50.518870Z" }, "id": "j-4hIPZvmXlb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialized [a, b, c]: [0.0, 0.0, 0.0]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a',
shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(0))\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(0))\n", "  with tf1.variable_scope('scoped'):\n", "    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.constant_initializer(0))\n", "  with tf1.Session() as sess:\n", "    sess.run(tf1.global_variables_initializer())\n", "    print(\"Initialized [a, b, c]: \", sess.run([a, b, c]))\n", "    ckpt = tf.train.Checkpoint(\n", "        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})\n", "    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()\n", "    print(\"Restored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "oRrSE2X6sgAM" }, "source": [ "## Checkpoint conversion\n", "\n", "\n", "\n", "You can convert checkpoints between TF1 and TF2 by loading and re-saving them. An alternative is `tf.train.load_checkpoint`, which reads the saved tensors directly, as shown in the code below." ] }, { "cell_type": "markdown", "metadata": { "id": "o9KByaLous4q" }, "source": [ "### Convert a TF1 checkpoint to TF2" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.523029Z", "iopub.status.busy": "2022-12-14T22:28:50.522519Z", "iopub.status.idle": "2022-12-14T22:28:50.527036Z", "shell.execute_reply": "2022-12-14T22:28:50.526456Z" }, "id": "NG8grCv2smAb" }, "outputs": [], "source": [ "def convert_tf1_to_tf2(checkpoint_path, output_prefix):\n", "  \"\"\"Converts a TF1 checkpoint to TF2.\n", "\n", "  To load the converted checkpoint, you must build a dictionary that maps\n", "  variable names to variable objects.\n", "  ```\n", "  ckpt = tf.train.Checkpoint(vars={name: variable})\n", "  ckpt.restore(converted_ckpt_path)\n", "  ```\n", "\n", "  Args:\n", "    checkpoint_path: Path to the TF1 checkpoint.\n", "    output_prefix: Path prefix to the converted checkpoint.\n", "\n", "  Returns:\n", "    Path to the converted checkpoint.\n", "  \"\"\"\n", "  vars = {}\n", "  reader =
tf.train.load_checkpoint(checkpoint_path)\n", "  dtypes = reader.get_variable_to_dtype_map()\n", "  for key in dtypes.keys():\n", "    vars[key] = tf.Variable(reader.get_tensor(key))\n", "  return tf.train.Checkpoint(vars=vars).save(output_prefix)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "TyvqK6Sb3dad" }, "source": [ "Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.530527Z", "iopub.status.busy": "2022-12-14T22:28:50.529905Z", "iopub.status.idle": "2022-12-14T22:28:50.552781Z", "shell.execute_reply": "2022-12-14T22:28:50.552160Z" }, "id": "gcHLN4lPvYvw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf1-ckpt-saved-in-eager':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n", "\n", "[Converted]\n", "Checkpoint at 'converted-tf1-to-tf2-1':\n", " (key='vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)\n", " (key='vars/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)\n", " (key='vars/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)\n", " (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b\"\\n \\n\\x08\\x08\\x01\\x12\\x04vars\\n\\x10\\x08\\x02\\x12\\x0csave_counter*\\x02\\x08\\x01\\n 
\\n\\x0c\\x08\\x03\\x12\\x08scoped/c\\n\\x05\\x08\\x04\\x12\\x01b\\n\\x05\\x08\\x05\\x12\\x01a*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nK\\x12E\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a)vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nC\\x12=\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a!vars/b/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nC\\x12=\\n\\x0eVARIABLE_VALUE\\x12\\x08Variable\\x1a!vars/a/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`.\n", "print_checkpoint('tf1-ckpt-saved-in-eager')\n", "converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager',\n", "                                    'converted-tf1-to-tf2')\n", "print(\"\\n[Converted]\")\n", "print_checkpoint(converted_path)\n", "\n", "# Try loading the converted checkpoint.\n", "a = tf.Variable(0.)\n", "b = tf.Variable(0.)\n", "c = tf.Variable(0.)\n", "ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})\n", "ckpt.restore(converted_path).assert_consumed()\n", "print(\"\\nRestored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "fokg6ybZvE20" }, "source": [ "### Convert a TF2 checkpoint to TF1" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.556174Z", "iopub.status.busy": "2022-12-14T22:28:50.555639Z", "iopub.status.idle": "2022-12-14T22:28:50.560409Z", "shell.execute_reply": "2022-12-14T22:28:50.559838Z" }, "id": "NPQsXQveuQiC" }, "outputs": [], "source": [ "def convert_tf2_to_tf1(checkpoint_path, output_prefix):\n", "  \"\"\"Converts a TF2 checkpoint to TF1.\n", "\n", "  The checkpoint must be saved using a\n", "  `tf.train.Checkpoint(var_list={name: variable})`\n", "\n", "  To load the converted checkpoint
with `tf.compat.v1.Saver`:\n", "  ```\n", "  saver = tf.compat.v1.train.Saver(var_list={name: variable})\n", "\n", "  # An alternative, if the variable names match the keys:\n", "  saver = tf.compat.v1.train.Saver(var_list=[variables])\n", "  saver.restore(sess, output_path)\n", "  ```\n", "  \"\"\"\n", "  vars = {}\n", "  reader = tf.train.load_checkpoint(checkpoint_path)\n", "  dtypes = reader.get_variable_to_dtype_map()\n", "  for key in dtypes.keys():\n", "    # Get the variable name from the checkpoint key.\n", "    if key.startswith('var_list/'):\n", "      var_name = key.split('/')[1]\n", "      # TF2 checkpoint keys use '/', so if they appear in the user-defined name,\n", "      # they are escaped to '.S'.\n", "      var_name = var_name.replace('.S', '/')\n", "      vars[var_name] = tf.Variable(reader.get_tensor(key))\n", "\n", "  return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)" ] }, { "cell_type": "markdown", "metadata": { "id": "VjZD_OSf1mKX" }, "source": [ "Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:28:50.563760Z", "iopub.status.busy": "2022-12-14T22:28:50.563179Z", "iopub.status.idle": "2022-12-14T22:28:50.615690Z", "shell.execute_reply": "2022-12-14T22:28:50.615049Z" }, "id": "vc1MVeV6z2DB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checkpoint at 'tf2-ckpt-saved-in-session-1':\n", " (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)\n", " (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)\n", " (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)\n", " (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)\n", " (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b\"\\n$\\n\\x0c\\x08\\x01\\x12\\x08var_list\\n\\x10\\x08\\x02\\x12\\x0csave_counter*\\x02\\x08\\x01\\n 
\\n\\x05\\x08\\x03\\x12\\x01a\\n\\x05\\x08\\x04\\x12\\x01b\\n\\x0c\\x08\\x05\\x12\\x08scoped/c*\\x02\\x08\\x01\\nM\\x12G\\n\\x0eVARIABLE_VALUE\\x12\\x0csave_counter\\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x01a\\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\n@\\x12:\\n\\x0eVARIABLE_VALUE\\x12\\x01b\\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\\nO\\x12I\\n\\x0eVARIABLE_VALUE\\x12\\x08scoped/c\\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\\x02\\x08\\x01\")\n", "WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "[Converted]\n", "Checkpoint at 'converted-tf2-to-tf1':\n", " (key='scoped/c', shape=[], dtype=float32, value=3.0)\n", " (key='b', shape=[], dtype=float32, value=2.0)\n", " (key='a', shape=[], dtype=float32, value=1.0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from converted-tf2-to-tf1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Restored [a, b, c]: [1.0, 2.0, 3.0]\n" ] } ], "source": [ "# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`.\n", "print_checkpoint('tf2-ckpt-saved-in-session-1')\n", "converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',\n", "                                    'converted-tf2-to-tf1')\n", "print(\"\\n[Converted]\")\n", "print_checkpoint(converted_path)\n", "\n", "# Try loading the converted checkpoint.\n", "with tf.Graph().as_default() as g:\n", "  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(0))\n", "  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n", "                       initializer=tf1.constant_initializer(0))\n", "  with
tf1.variable_scope('scoped'):\n", "    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n", "                         initializer=tf1.constant_initializer(0))\n", "  with tf1.Session() as sess:\n", "    saver = tf1.train.Saver([a, b, c])\n", "    saver.restore(sess, converted_path)\n", "    print(\"\\nRestored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "JBMfArLQ0jb-" }, "source": [ "## Related guides\n", "\n", "- [Validating numerical equivalence and correctness](./validate_correctness.ipynb)\n", "- [Model mapping guide](./model_mapping.ipynb) and `tf.compat.v1.keras.utils.track_tf1_style_variables`\n", "- [TF2 checkpoint guide](https://www.tensorflow.org/guide/checkpoint)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "migrating_checkpoints.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }