{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "6bYaCABobL5q" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "FlUw7tSKbtg4" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "61dp4Hg5ksTC" }, "source": [ "# Migrating model checkpoints\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "avuMwzscPnHh" }, "source": [ "Note: Checkpoints saved with `tf.compat.v1.Saver` are often referred as *TF1 or name-based* checkpoints. Checkpoints saved with `tf.train.Checkpoint` are referred as *TF2 or object-based* checkpoints.\n", "\n", "\n", "## Overview \n", "This guide assumes that you have a model that saves and loads checkpoints with [`tf.compat.v1.Saver`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver), and want to migrate the code use the TF2 [`tf.train.Checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint) API, or use pre-existing checkpoints in your TF2 model.\n", "\n", "Below are some common scenarios that you may encounter:\n", "\n", "**Scenario 1**\n", "\n", "There are existing TF1 checkpoints from previous training runs that need to be loaded or converted to TF2.\n", "\n", "* To load the TF1 checkpoint in TF2, see the snippet [*Load a TF1 checkpoint in TF2*](#load-tf1-in-tf2).\n", "* To convert the checkpoint to TF2, see [*Checkpoint conversion*](#checkpoint-conversion).\n", " \n", "**Scenario 2**\n", "\n", "You are adjusting your model in a way that risks changing variable names and paths (such as when incrementally migrating away from `get_variable` to explicit `tf.Variable` creation), and would like to maintain saving/loading of existing checkpoints along the way.\n", "\n", "See the section on [*How to maintain checkpoint compatibility during model migration*](#maintain-checkpoint-compat)\n", "\n", "**Scenario 3**\n", "\n", "You are migrating your training code and checkpoints to TF2, but your inference pipeline continues to require TF1 checkpoints for now (for production stability).\n", "\n", "*Option 1*\n", "\n", "Save both TF1 and TF2 checkpoints when training. \n", "\n", "* see [*Save a TF1 checkpoint in TF2*](#save-tf1-in-tf2)\n", "\n", "*Option 2*\n", "\n", "Convert the TF2 checkpoint to TF1.\n", "\n", "* see [*Checkpoint conversion*](#checkpoint-conversion)\n", "\n", "\n", "\n", "---\n", "\n", "\n", "The examples below show all the combinations of saving and loading checkpoints in TF1/TF2, so you have some flexibility in determining how to migrate your model." ] }, { "cell_type": "markdown", "metadata": { "id": "TaYgaekzOAHf" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kcvTd5QhZ78L" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "\n", "def print_checkpoint(save_path):\n", " reader = tf.train.load_checkpoint(save_path)\n", " shapes = reader.get_variable_to_shape_map()\n", " dtypes = reader.get_variable_to_dtype_map()\n", " print(f\"Checkpoint at '{save_path}':\")\n", " for key in shapes:\n", " print(f\" (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, \"\n", " f\"value={reader.get_tensor(key)})\")" ] }, { "cell_type": "markdown", "metadata": { "id": "gO8Q6QkulJlj" }, "source": [ "## Changes from TF1 to TF2\n", "\n", "This section is included if you are curious about what has changed between TF1 and TF2, and what we mean by \"name-based\" (TF1) vs \"object-based\" (TF2) checkpoints. \n", "\n", "The two types of checkpoints are actually saved in the same format, which is essentially a key-value table. The difference lies in how the keys are generated.\n", "\n", " The keys in named-based checkpoints are the **names of the variables**. 
"\n",
"The keys in name-based checkpoints are the **names of the variables**. The keys in object-based checkpoints refer to the **path from the root object to the variable** (the examples below will help you get a better sense of what this means).\n",
"\n",
"First, save some checkpoints:\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8YXzbXvOWvdF" }, "outputs": [], "source": [
"with tf.Graph().as_default() as g:\n",
"  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  with tf1.Session() as sess:\n",
"    saver = tf1.train.Saver()\n",
"    sess.run(a.assign(1))\n",
"    sess.run(b.assign(2))\n",
"    sess.run(c.assign(3))\n",
"    saver.save(sess, 'tf1-ckpt')\n",
"\n",
"print_checkpoint('tf1-ckpt')"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "raOych1UaJzl" }, "outputs": [], "source": [
"a = tf.Variable(5.0, name='a')\n",
"b = tf.Variable(6.0, name='b')\n",
"with tf.name_scope('scoped'):\n",
"  c = tf.Variable(7.0, name='c')\n",
"\n",
"ckpt = tf.train.Checkpoint(variables=[a, b, c])\n",
"save_path_v2 = ckpt.save('tf2-ckpt')\n",
"print_checkpoint(save_path_v2)"
] }, { "cell_type": "markdown", "metadata": { "id": "UYyLhTYszcpl" }, "source": [
"If you look at the keys in `tf2-ckpt`, they all refer to the object paths of each variable. For example, variable `a` is the first element in the `variables` list, so its key becomes `variables/0/...` (feel free to ignore the `.ATTRIBUTES/VARIABLE_VALUE` constant).\n",
"\n",
"Take a closer look at the `Checkpoint` object below:"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kLOxvoosg4Al" }, "outputs": [], "source": [
"a = tf.Variable(0.)\n",
"b = tf.Variable(0.)\n",
"c = tf.Variable(0.)\n",
"root = ckpt = tf.train.Checkpoint(variables=[a, b, c])\n",
"print(\"root type =\", type(root).__name__)\n",
"print(\"root.variables =\", root.variables)\n",
"print(\"root.variables[0] =\", root.variables[0])"
] }, { "cell_type": "markdown", "metadata": { "id": "1Qaed1yAm3Ar" }, "source": [ "Try experimenting with the snippet below to see how the checkpoint keys change with the object structure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EdHJXlZOyDnn" }, "outputs": [], "source": [
"module = tf.Module()\n",
"module.d = tf.Variable(0.)\n",
"test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b},\n",
"                                c=c,\n",
"                                module=module)\n",
"test_ckpt_path = test_ckpt.save('root-tf2-ckpt')\n",
"print_checkpoint(test_ckpt_path)"
] }, { "cell_type": "markdown", "metadata": { "id": "8iWitEsayDWs" }, "source": [
"*Why does TF2 use this mechanism?*\n",
"\n",
"Because there is no longer a global graph in TF2, variable names are unreliable and can be inconsistent between programs. TF2 encourages the object-oriented modelling approach where variables are owned by layers, and layers are owned by a model:\n",
"\n",
"```\n",
"variable = tf.Variable(...)\n",
"layer.variable_name = variable\n",
"model.layer_name = layer\n",
"```"
] },
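{ "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is an added illustration (not one of the original snippets) of why names alone are unreliable in TF2: in eager execution, variable names are not uniquified, so two distinct variables can share the same name." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Added sketch: in eager execution two variables can share the same name,\n",
"# so a name-based key cannot tell their values apart.\n",
"v1 = tf.Variable(1., name='v')\n",
"v2 = tf.Variable(2., name='v')\n",
"print(v1.name, v2.name)  # Both print 'v:0'."
] },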
{ "cell_type": "markdown", "metadata": { "id": "9kv9SmyVjGLA" }, "source": [
"## How to maintain checkpoint compatibility during model migration\n",
"\n",
"<a id=\"maintain-checkpoint-compat\"/>\n",
"\n",
"One important step in the migration process is *ensuring that all variables are initialized to the correct values*, which in turn allows you to validate that the ops/functions are doing the correct computations. To accomplish this, you must consider the **checkpoint compatibility** between models in the various stages of migration. Essentially, this section answers the question, *how do I keep using the same checkpoint while changing the model?*\n",
"\n",
"Below are three ways of maintaining checkpoint compatibility, in order of increasing flexibility:\n",
"\n",
"1. The model has the **same variable names** as before.\n",
"2. The model has different variable names, and maintains an **assignment map** that maps variable names in the checkpoint to the new names.\n",
"3. The model has different variable names, and maintains a **TF2 Checkpoint object** that stores all of the variables."
] }, { "cell_type": "markdown", "metadata": { "id": "L5JhCyPZDx43" }, "source": [
"### When the variable names match\n",
"Long title: How to re-use checkpoints when the variable names match.\n",
"\n",
"Short answer: You can directly load the pre-existing checkpoint with either `tf1.train.Saver` or `tf.train.Checkpoint`.\n",
"\n",
"---\n",
"\n",
"If you are using `tf.compat.v1.keras.utils.track_tf1_style_variables`, then it will ensure that your model variable names are the same as before. You can also manually ensure that variable names match.\n",
"\n",
"When the variable names match in the migrated models, you may directly use either `tf.train.Checkpoint` or `tf.compat.v1.train.Saver` to load the checkpoint. Both APIs are compatible with eager and graph mode, so you can use them at any stage of the migration.\n",
"\n",
"Note: You can use `tf.train.Checkpoint` to load TF1 checkpoints, but you cannot use `tf.compat.v1.train.Saver` to load TF2 checkpoints without complicated name matching.\n",
"\n",
"Below are examples of using the same checkpoint with different models. First, save a TF1 checkpoint with `tf1.train.Saver`:"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ijlHS96URsfR" }, "outputs": [], "source": [
"with tf.Graph().as_default() as g:\n",
"  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  with tf1.Session() as sess:\n",
"    saver = tf1.train.Saver()\n",
"    sess.run(a.assign(1))\n",
"    sess.run(b.assign(2))\n",
"    sess.run(c.assign(3))\n",
"    save_path = saver.save(sess, 'tf1-ckpt')\n",
"print_checkpoint(save_path)"
] }, { "cell_type": "markdown", "metadata": { "id": "zg7nWZphQD9u" }, "source": [ "The example below uses `tf.compat.v1.train.Saver` to load the checkpoint while in eager mode:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y4K16m0PPncQ" }, "outputs": [], "source": [
"a = tf.Variable(0.0, name='a')\n",
"b = tf.Variable(0.0, name='b')\n",
"with tf.name_scope('scoped'):\n",
"  c = tf.Variable(0.0, name='c')\n",
"\n",
"# With the removal of collections in TF2, you must pass in the list of variables\n",
"# to the Saver object:\n",
"saver = tf1.train.Saver(var_list=[a, b, c])\n",
"saver.restore(sess=None, save_path=save_path)\n",
"print(f\"loaded values of [a, b, c]: [{a.numpy()}, {b.numpy()}, {c.numpy()}]\")\n",
"\n",
"# Saving also works in eager (sess must be None).\n",
"path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')\n",
"print_checkpoint(path)"
] },
{ "cell_type": "markdown", "metadata": { "id": "dWnq1f5yAPkq" }, "source": [ "The next snippet loads the checkpoint using the TF2 API `tf.train.Checkpoint`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "StyrzwGvW1YZ" }, "outputs": [], "source": [
"a = tf.Variable(0.0, name='a')\n",
"b = tf.Variable(0.0, name='b')\n",
"with tf.name_scope('scoped'):\n",
"  c = tf.Variable(0.0, name='c')\n",
"\n",
"# Without the name_scope, name=\"scoped/c\" works too:\n",
"c_2 = tf.Variable(0.0, name='scoped/c')\n",
"\n",
"print(\"Variable names: \")\n",
"print(f\"  a.name = {a.name}\")\n",
"print(f\"  b.name = {b.name}\")\n",
"print(f\"  c.name = {c.name}\")\n",
"print(f\"  c_2.name = {c_2.name}\")\n",
"\n",
"# Restore the values with tf.train.Checkpoint\n",
"ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])\n",
"ckpt.restore(save_path)\n",
"print(f\"loaded values of [a, b, c, c_2]: [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]\")"
] }, { "cell_type": "markdown", "metadata": { "id": "DYYgbj8F7Yb7" }, "source": [
"#### Variable names in TF2\n",
"\n",
"- Variables still all have a `name` argument you can set.\n",
"- Keras models also take a `name` argument, which they set as the prefix for their variables.\n",
"- The `v1.name_scope` function can be used to set variable name prefixes. This is very different from `tf.variable_scope`: it only affects names, and doesn't track variables or reuse.\n",
"\n",
"The `tf.compat.v1.keras.utils.track_tf1_style_variables` decorator is a shim that helps you maintain variable names and TF1 checkpoint compatibility, by keeping the naming and reuse semantics of `tf.variable_scope` and `tf.compat.v1.get_variable` unchanged. See the [Model mapping guide](./model_mapping.ipynb) for more info.\n",
"\n",
"**Note 1: If you are using the shim, use TF2 APIs to load your checkpoints (even when using pre-trained TF1 checkpoints).**\n",
"\n",
"See the section *Checkpointing Keras*.\n",
"\n",
"**Note 2: When migrating to `tf.Variable` from `get_variable`:**\n",
"\n",
"If your shim-decorated layer or module consists of some variables (or Keras layers/models) that use `tf.Variable` instead of `tf.compat.v1.get_variable` and get attached as properties/tracked in an object-oriented way, they may have different variable naming semantics in TF1.x graphs/sessions versus during eager execution.\n",
"\n",
"In short, *the names may not be what you expect them to be* when running in TF2.\n",
"\n",
"Warning: Variables may have duplicate names in eager execution, which may cause problems if multiple variables in the name-based checkpoint need to be mapped to the same name. You may be able to explicitly adjust the layer and variable names using `tf.name_scope` and layer constructor or `tf.Variable` `name` arguments to ensure there are no duplicates."
] }, { "cell_type": "markdown", "metadata": { "id": "NkUQJUUyjOJz" }, "source": [
"### Maintaining assignment maps\n",
"\n",
"Assignment maps are commonly used to transfer weights between TF1 models, and can also be used during your model migration if the variable names change."
] },
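{ "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is an added sketch (not one of the original snippets) showing the shape of an assignment map; the scope names here are placeholders:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Added sketch: an assignment map is a dictionary whose keys refer to names or\n",
"# scopes in the checkpoint, and whose values refer to the names (or variable\n",
"# objects) in the current model. The scope names below are placeholders.\n",
"assignment_map = {'old_scope/': 'new_scope/'}"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "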
\n", "\n", "You can use these maps with [`tf.compat.v1.train.init_from_checkpoint`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/init_from_checkpoint), [`tf.compat.v1.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver), and [`tf.train.load_checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/load_checkpoint) to load weights into models in which the variable or scope names may have changed.\n", "\n", "The examples in this section will use a previously saved checkpoint:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PItyo7DdJ6Ek" }, "outputs": [], "source": [ "print_checkpoint('tf1-ckpt')" ] }, { "cell_type": "markdown", "metadata": { "id": "rPryV_WBJrI3" }, "source": [ "#### Loading with `init_from_checkpoint`\n", "\n", "[`tf1.train.init_from_checkpoint`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/init_from_checkpoint) must be called while in a Graph/Session, because it places the values in the variable initializers instead of creating an assign op. \n", "\n", "You can use the `assignment_map` argument to configure how the variables are loaded. From the documentation:\n", "> Assignment map supports following syntax:\n", " * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in\n", " current `scope_name` from `checkpoint_scope_name` with matching tensor\n", " names.\n", " * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -\n", " will initialize `scope_name/variable_name` variable\n", " from `checkpoint_scope_name/some_other_variable`.\n", " * `'scope_variable_name': variable` - will initialize given `tf.Variable`\n", " object with tensor 'scope_variable_name' from the checkpoint.\n", " * `'scope_variable_name': list(variable)` - will initialize list of\n", " partitioned variables with tensor 'scope_variable_name' from the checkpoint.\n", " * `'/': 'scope_name/'` - will load all variables in current `scope_name` from\n", " checkpoint's root (e.g. no scope).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZM_7OzRpdH0A" }, "outputs": [], "source": [ "# Restoring with tf1.train.init_from_checkpoint:\n", "\n", "# A new model with a different scope for the variables.\n", "with tf.Graph().as_default() as g:\n", " with tf1.variable_scope('new_scope'):\n", " a = tf1.get_variable('a', shape=[], dtype=tf.float32, \n", " initializer=tf1.zeros_initializer())\n", " b = tf1.get_variable('b', shape=[], dtype=tf.float32, \n", " initializer=tf1.zeros_initializer())\n", " c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, \n", " initializer=tf1.zeros_initializer())\n", " with tf1.Session() as sess:\n", " # The assignment map will remap all variables in the checkpoint to the\n", " # new scope:\n", " tf1.train.init_from_checkpoint(\n", " 'tf1-ckpt',\n", " assignment_map={'/': 'new_scope/'})\n", " # `init_from_checkpoint` adds the initializers to these variables.\n", " # Use `sess.run` to run these initializers.\n", " sess.run(tf1.global_variables_initializer())\n", "\n", " print(\"Restored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "Za_8xhFWKVlH" }, "source": [ "#### Loading with `tf1.train.Saver`\n", "\n", "Unlike `init_from_checkpoint`, [`tf.compat.v1.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver) runs in both graph and eager mode. 
The `var_list` argument optionally accepts a dictionary, which must map variable names in the checkpoint to `tf.Variable` objects in the current model.\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IiKNmdGJgoX9" }, "outputs": [], "source": [
"# Restoring with tf1.train.Saver (works in both graph and eager):\n",
"\n",
"# A new model with a different scope for the variables.\n",
"with tf1.variable_scope('new_scope'):\n",
"  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"# Initialize the saver with a dictionary of the original variable names:\n",
"saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})\n",
"saver.restore(sess=None, save_path='tf1-ckpt')\n",
"print(\"Restored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])"
] }, { "cell_type": "markdown", "metadata": { "id": "7JsgCXt3Ly-h" }, "source": [
"#### Loading with `tf.train.load_checkpoint`\n",
"\n",
"This option is for you if you need precise control over the variable values. Again, this works in both graph and eager modes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pc39Bh6JMso6" }, "outputs": [], "source": [
"# Restoring with tf.train.load_checkpoint (works in both graph and eager):\n",
"\n",
"# A new model with a different scope for the variables.\n",
"with tf.Graph().as_default() as g:\n",
"  with tf1.variable_scope('new_scope'):\n",
"    a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n",
"                         initializer=tf1.zeros_initializer())\n",
"    b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n",
"                         initializer=tf1.zeros_initializer())\n",
"    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,\n",
"                         initializer=tf1.zeros_initializer())\n",
"    with tf1.Session() as sess:\n",
"      # It may be easier to write a loop if your model has many variables.\n",
"      reader = tf.train.load_checkpoint('tf1-ckpt')\n",
"      sess.run(a.assign(reader.get_tensor('a')))\n",
"      sess.run(b.assign(reader.get_tensor('b')))\n",
"      sess.run(c.assign(reader.get_tensor('scoped/c')))\n",
"      print(\"Restored [a, b, c]: \", sess.run([a, b, c]))"
] }, { "cell_type": "markdown", "metadata": { "id": "nBSTJVCNDKed" }, "source": [
"### Maintaining a TF2 Checkpoint object\n",
"\n",
"If the variable and scope names may change a lot during the migration, then use `tf.train.Checkpoint` and TF2 checkpoints. TF2 uses the **object structure** instead of variable names (more details in *Changes from TF1 to TF2*).\n",
"\n",
"In short, when creating a `tf.train.Checkpoint` to save or restore checkpoints, make sure it uses the same **ordering** (for lists) and **keys** (for dictionaries and keyword arguments to the `Checkpoint` initializer). Some examples of checkpoint compatibility:\n",
"\n",
"```\n",
"ckpt = tf.train.Checkpoint(foo=[var_a, var_b])\n",
"\n",
"# compatible with ckpt\n",
"tf.train.Checkpoint(foo=[var_a, var_b])\n",
"\n",
"# not compatible with ckpt\n",
"tf.train.Checkpoint(foo=[var_b, var_a])\n",
"tf.train.Checkpoint(bar=[var_a, var_b])\n",
"```"
] },
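{ "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is an added, runnable illustration of this rule (not one of the original snippets; the names `foo`, `bar`, `var_a`, and `var_b` are arbitrary):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Added sketch: restoring succeeds when the keys and ordering match, and\n",
"# `assert_consumed` raises when a key doesn't match anything.\n",
"var_a, var_b = tf.Variable(1.), tf.Variable(2.)\n",
"demo_path = tf.train.Checkpoint(foo=[var_a, var_b]).save('compat-demo')\n",
"\n",
"x, y = tf.Variable(0.), tf.Variable(0.)\n",
"tf.train.Checkpoint(foo=[x, y]).restore(demo_path).assert_consumed()\n",
"print(\"Same key and ordering restores:\", [x.numpy(), y.numpy()])\n",
"\n",
"try:\n",
"  tf.train.Checkpoint(bar=[x, y]).restore(demo_path).assert_consumed()\n",
"except AssertionError:\n",
"  print(\"A different key ('bar') does not match the checkpoint.\")"
] },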
{ "cell_type": "markdown", "metadata": {}, "source": [ "The code samples below show how to use the \"same\" `tf.train.Checkpoint` to load variables with different names. First, save a TF2 checkpoint:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tCSkz_-Tbct6" }, "outputs": [], "source": [
"with tf.Graph().as_default() as g:\n",
"  a = tf1.get_variable('a', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.constant_initializer(1))\n",
"  b = tf1.get_variable('b', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.constant_initializer(2))\n",
"  with tf1.variable_scope('scoped'):\n",
"    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n",
"                         initializer=tf1.constant_initializer(3))\n",
"  with tf1.Session() as sess:\n",
"    sess.run(tf1.global_variables_initializer())\n",
"    print(\"[a, b, c]: \", sess.run([a, b, c]))\n",
"\n",
"    # Save a TF2 checkpoint\n",
"    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])\n",
"    tf2_ckpt_path = ckpt.save('tf2-ckpt')\n",
"    print_checkpoint(tf2_ckpt_path)"
] }, { "cell_type": "markdown", "metadata": { "id": "62MWdZMxezeP" }, "source": [ "You can keep using `tf.train.Checkpoint` even if the variable/scope names change:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vh61SGeqb27b" }, "outputs": [], "source": [
"with tf.Graph().as_default() as g:\n",
"  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.zeros_initializer())\n",
"  with tf1.variable_scope('different_scope'):\n",
"    c = tf1.get_variable('c', shape=[], dtype=tf.float32,\n",
"                         initializer=tf1.zeros_initializer())\n",
"  with tf1.Session() as sess:\n",
"    sess.run(tf1.global_variables_initializer())\n",
"    print(\"Initialized [a, b, c]: \", sess.run([a, b, c]))\n",
"\n",
"    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])\n",
"    # `assert_consumed` validates that all checkpoint objects are restored from\n",
"    # the checkpoint. `run_restore_ops` is required when running in a TF1\n",
"    # session.\n",
"    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()\n",
"\n",
"    # Removing `assert_consumed` is fine if you want to skip the validation.\n",
"    # ckpt.restore(tf2_ckpt_path).run_restore_ops()\n",
"\n",
"    print(\"Restored [a, b, c]: \", sess.run([a, b, c]))"
] }, { "cell_type": "markdown", "metadata": { "id": "unDPmL-kldr2" }, "source": [ "And in eager mode:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "79S0zMAnfzx7" }, "outputs": [], "source": [
"a = tf.Variable(0.)\n",
"b = tf.Variable(0.)\n",
"c = tf.Variable(0.)\n",
"print(\"Initialized [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])\n",
"\n",
"# The keys \"scoped\" and \"unscoped\" are no longer relevant, but are used to\n",
"# maintain compatibility with the saved checkpoints.\n",
"ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])\n",
"\n",
"ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()\n",
"print(\"Restored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])"
] }, { "cell_type": "markdown", "metadata": { "id": "dKfNAr8l3aFg" }, "source": [
"## TF2 checkpoints in Estimator\n",
"\n",
"The sections above describe how to maintain checkpoint compatibility while migrating your model. These concepts also apply to Estimator models, although the way the checkpoint is saved/loaded is slightly different. As you migrate your Estimator model to use TF2 APIs, you may want to switch from TF1 to TF2 checkpoints *while the model is still using the estimator*. This section shows how to do so.\n",
"\n",
"[`tf.estimator.Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) and [`MonitoredSession`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/MonitoredSession) have a saving mechanism called the `scaffold`, a [`tf.compat.v1.train.Scaffold`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Scaffold) object. The `Scaffold` can contain a `tf1.train.Saver` or `tf.train.Checkpoint`, which enables `Estimator` and `MonitoredSession` to save TF1- or TF2-style checkpoints.\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D8AT_oO-5TXU" }, "outputs": [], "source": [
"# A model_fn that saves a TF1 checkpoint\n",
"def model_fn_tf1_ckpt(features, labels, mode):\n",
"  # This model adds 2 to the variable `v` in every train step.\n",
"  train_step = tf1.train.get_or_create_global_step()\n",
"  v = tf1.get_variable('var', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.constant_initializer(0))\n",
"  return tf.estimator.EstimatorSpec(\n",
"      mode,\n",
"      predictions=v,\n",
"      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),\n",
"      loss=tf.constant(1.),\n",
"      scaffold=None\n",
"  )\n",
"\n",
"!rm -rf est-tf1\n",
"est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')\n",
"\n",
"def train_fn():\n",
"  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))\n",
"est.train(train_fn, steps=1)\n",
"\n",
"latest_checkpoint = tf.train.latest_checkpoint('est-tf1')\n",
"print_checkpoint(latest_checkpoint)"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ttH6cUrl7jK2" }, "outputs": [], "source": [
"# A model_fn that saves a TF2 checkpoint\n",
"def model_fn_tf2_ckpt(features, labels, mode):\n",
"  # This model adds 2 to the variable `v` in every train step.\n",
"  train_step = tf1.train.get_or_create_global_step()\n",
"  v = tf1.get_variable('var', shape=[], dtype=tf.float32,\n",
"                       initializer=tf1.constant_initializer(0))\n",
"  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)\n",
"  return tf.estimator.EstimatorSpec(\n",
"      mode,\n",
"      predictions=v,\n",
"      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),\n",
"      loss=tf.constant(1.),\n",
"      scaffold=tf1.train.Scaffold(saver=ckpt)\n",
"  )\n",
"\n",
"!rm -rf est-tf2\n",
"est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',\n",
"                             warm_start_from='est-tf1')\n",
"\n",
"def train_fn():\n",
"  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))\n",
"est.train(train_fn, steps=1)\n",
"\n",
"latest_checkpoint = tf.train.latest_checkpoint('est-tf2')\n",
"print_checkpoint(latest_checkpoint)\n",
"\n",
"assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4"
] }, { "cell_type": "markdown", "metadata": { "id": "hYVYgahE8daL" }, "source": [ "The final value of `v` should be `4`: it is warm-started to `2` from the `est-tf1` checkpoint, then incremented by `2` in the single additional train step. The train step value doesn't carry over from the `warm_start` checkpoint.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Pq8EjblQUIA2" }, "source": [
"## Checkpointing Keras\n",
"\n",
"Models built with Keras can still use `tf1.train.Saver` and `tf.train.Checkpoint` to load pre-existing weights. When your model is fully migrated, switch to using `model.save_weights` and `model.load_weights`, especially if you are using the `ModelCheckpoint` callback when training.\n",
"\n",
"Some things you should know about checkpoints and Keras:\n",
"\n",
"**Initialization vs Building**\n",
"\n",
"Keras models and layers must go through **two steps** before being fully created. First is the *initialization* of the Python object: `layer = tf.keras.layers.Dense(x)`. Second is the *build* step, in which most of the weights are actually created: `layer.build(input_shape)`. You can also build a model by calling it or running a single `train`, `eval`, or `predict` step (the first time only).\n",
"\n",
"If you find that `model.load_weights(path).assert_consumed()` is raising an error, then it is likely that the model/layers have not been built.\n",
"\n",
"**Keras uses TF2 checkpoints**\n",
"\n",
"`tf.train.Checkpoint(model).write` is equivalent to `model.save_weights`. The same goes for `tf.train.Checkpoint(model).read` and `model.load_weights`. Note that `Checkpoint(model) != Checkpoint(model=model)`.\n",
"\n",
"**TF2 checkpoints work with Keras's `build()` step**\n",
"\n",
"`tf.train.Checkpoint.restore` has a mechanism called *deferred restoration* which allows `tf.Module` and Keras objects to store variable values if the variable has not yet been created. This allows *initialized* models to load weights and *build* afterwards.\n",
"\n",
"```\n",
"m = YourKerasModel()\n",
"status = m.load_weights(path)\n",
"\n",
"# This call builds the model. The variables are created with the restored\n",
"# values.\n",
"m.predict(inputs)\n",
"\n",
"status.assert_consumed()\n",
"```\n",
"\n",
"Because of this mechanism, we highly recommend that you use TF2 checkpoint loading APIs with Keras models (even when restoring pre-existing TF1 checkpoints into the [model mapping shims](./model_mapping.ipynb)). See more in the [checkpoint guide](https://www.tensorflow.org/guide/checkpoint#delayed_restorations).\n"
] },
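{ "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is an added sketch of deferred restoration with a small Keras model (not one of the original snippets; the layer sizes and file name are arbitrary):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Added sketch of deferred restoration with an arbitrary small model.\n",
"m = tf.keras.Sequential([tf.keras.layers.Dense(2)])\n",
"m.build(input_shape=(None, 3))  # The build step creates the weights.\n",
"m.save_weights('keras-demo-weights')\n",
"\n",
"m2 = tf.keras.Sequential([tf.keras.layers.Dense(2)])\n",
"status = m2.load_weights('keras-demo-weights')  # m2 is not built yet.\n",
"m2.build(input_shape=(None, 3))  # Weights are restored as they are created.\n",
"status.assert_consumed()\n",
"print(\"Restored kernel:\\n\", m2.layers[0].kernel.numpy())"
] },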
" ] }, { "cell_type": "markdown", "metadata": { "id": "C3SSc74olkX3" }, "source": [ "### Save a TF1 checkpoint in TF2\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t2ZPk8BPloE1" }, "outputs": [], "source": [ "a = tf.Variable(1.0, name='a')\n", "b = tf.Variable(2.0, name='b')\n", "with tf.name_scope('scoped'):\n", " c = tf.Variable(3.0, name='c')\n", "\n", "saver = tf1.train.Saver(var_list=[a, b, c])\n", "path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')\n", "print_checkpoint(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "BxyN5khVjhmA" }, "source": [ "### Load a TF1 checkpoint in TF2\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z5kSXy3FmA79" }, "outputs": [], "source": [ "a = tf.Variable(0., name='a')\n", "b = tf.Variable(0., name='b')\n", "with tf.name_scope('scoped'):\n", " c = tf.Variable(0., name='c')\n", "print(\"Initialized [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])\n", "saver = tf1.train.Saver(var_list=[a, b, c])\n", "saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')\n", "print(\"Restored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "Ul3V4pEwloeN" }, "source": [ "### Save a TF2 checkpoint in TF1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UhuP_2EIlRm4" }, "outputs": [], "source": [ "with tf.Graph().as_default() as g:\n", " a = tf1.get_variable('a', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(1))\n", " b = tf1.get_variable('b', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(2))\n", " with tf1.variable_scope('scoped'):\n", " c = tf1.get_variable('c', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(3))\n", " with tf1.Session() as sess:\n", " sess.run(tf1.global_variables_initializer())\n", " ckpt = tf.train.Checkpoint(\n", " var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})\n", " tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')\n", " print_checkpoint(tf2_in_tf1_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "GiViCjCDgxhz" }, "source": [ "### Load a TF2 checkpoint in TF1\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j-4hIPZvmXlb" }, "outputs": [], "source": [ "with tf.Graph().as_default() as g:\n", " a = tf1.get_variable('a', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(0))\n", " b = tf1.get_variable('b', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(0))\n", " with tf1.variable_scope('scoped'):\n", " c = tf1.get_variable('c', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(0))\n", " with tf1.Session() as sess:\n", " sess.run(tf1.global_variables_initializer())\n", " print(\"Initialized [a, b, c]: \", sess.run([a, b, c]))\n", " ckpt = tf.train.Checkpoint(\n", " var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})\n", " ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()\n", " print(\"Restored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "oRrSE2X6sgAM" }, "source": [ "## Checkpoint conversion\n", "\n", "\n", "\n", "You can convert checkpoints between TF1 and TF2 by loading and re-saving the checkpoints. An alternative is `tf.train.load_checkpoint`, shown in the code below." 
] }, { "cell_type": "markdown", "metadata": { "id": "o9KByaLous4q" }, "source": [ "### Convert TF1 checkpoint to TF2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NG8grCv2smAb" }, "outputs": [], "source": [ "def convert_tf1_to_tf2(checkpoint_path, output_prefix):\n", " \"\"\"Converts a TF1 checkpoint to TF2.\n", "\n", " To load the converted checkpoint, you must build a dictionary that maps\n", " variable names to variable objects.\n", " ```\n", " ckpt = tf.train.Checkpoint(vars={name: variable}) \n", " ckpt.restore(converted_ckpt_path)\n", " ```\n", "\n", " Args:\n", " checkpoint_path: Path to the TF1 checkpoint.\n", " output_prefix: Path prefix to the converted checkpoint.\n", "\n", " Returns:\n", " Path to the converted checkpoint.\n", " \"\"\"\n", " vars = {}\n", " reader = tf.train.load_checkpoint(checkpoint_path)\n", " dtypes = reader.get_variable_to_dtype_map()\n", " for key in dtypes.keys():\n", " vars[key] = tf.Variable(reader.get_tensor(key))\n", " return tf.train.Checkpoint(vars=vars).save(output_prefix)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "TyvqK6Sb3dad" }, "source": [ "Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gcHLN4lPvYvw" }, "outputs": [], "source": [ "# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`.\n", "print_checkpoint('tf1-ckpt-saved-in-eager')\n", "converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', \n", " 'converted-tf1-to-tf2')\n", "print(\"\\n[Converted]\")\n", "print_checkpoint(converted_path)\n", "\n", "# Try loading the converted checkpoint.\n", "a = tf.Variable(0.)\n", "b = tf.Variable(0.)\n", "c = tf.Variable(0.)\n", "ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})\n", "ckpt.restore(converted_path).assert_consumed()\n", "print(\"\\nRestored [a, b, c]: \", [a.numpy(), b.numpy(), c.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "fokg6ybZvE20" }, "source": [ "### Convert TF2 checkpoint to TF1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NPQsXQveuQiC" }, "outputs": [], "source": [ "def convert_tf2_to_tf1(checkpoint_path, output_prefix):\n", " \"\"\"Converts a TF2 checkpoint to TF1.\n", "\n", " The checkpoint must be saved using a \n", " `tf.train.Checkpoint(var_list={name: variable})`\n", "\n", " To load the converted checkpoint with `tf.compat.v1.Saver`:\n", " ```\n", " saver = tf.compat.v1.train.Saver(var_list={name: variable}) \n", "\n", " # An alternative, if the variable names match the keys:\n", " saver = tf.compat.v1.train.Saver(var_list=[variables]) \n", " saver.restore(sess, output_path)\n", " ```\n", " \"\"\"\n", " vars = {}\n", " reader = tf.train.load_checkpoint(checkpoint_path)\n", " dtypes = reader.get_variable_to_dtype_map()\n", " for key in dtypes.keys():\n", " # Get the \"name\" from the \n", " if key.startswith('var_list/'):\n", " var_name = key.split('/')[1]\n", " # TF2 checkpoint keys use '/', so if they appear in the user-defined name,\n", " # they are escaped to '.S'.\n", " var_name = var_name.replace('.S', '/')\n", " vars[var_name] = tf.Variable(reader.get_tensor(key))\n", " \n", " return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)" ] }, { "cell_type": "markdown", "metadata": { "id": "VjZD_OSf1mKX" }, "source": [ "Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"vc1MVeV6z2DB" }, "outputs": [], "source": [ "# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`.\n", "print_checkpoint('tf2-ckpt-saved-in-session-1')\n", "converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',\n", " 'converted-tf2-to-tf1')\n", "print(\"\\n[Converted]\")\n", "print_checkpoint(converted_path)\n", "\n", "# Try loading the converted checkpoint.\n", "with tf.Graph().as_default() as g:\n", " a = tf1.get_variable('a', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(0))\n", " b = tf1.get_variable('b', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(0))\n", " with tf1.variable_scope('scoped'):\n", " c = tf1.get_variable('c', shape=[], dtype=tf.float32, \n", " initializer=tf1.constant_initializer(0))\n", " with tf1.Session() as sess:\n", " saver = tf1.train.Saver([a, b, c])\n", " saver.restore(sess, converted_path)\n", " print(\"\\nRestored [a, b, c]: \", sess.run([a, b, c]))" ] }, { "cell_type": "markdown", "metadata": { "id": "JBMfArLQ0jb-" }, "source": [ "## Related Guides\n", "\n", "* [Validating numerical equivalence and correctness](./validate_correctness.ipynb)\n", "* [Model mapping guide](./model_mapping.ipynb) and `tf.compat.v1.keras.utils.track_tf1_style_variables`\n", "* [TF2 Checkpoint guide](https://www.tensorflow.org/guide/checkpoint)." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "migrating_checkpoints.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }