{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-08-15T02:19:00.117848Z", "iopub.status.busy": "2024-08-15T02:19:00.117627Z", "iopub.status.idle": "2024-08-15T02:19:00.121381Z", "shell.execute_reply": "2024-08-15T02:19:00.120848Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "D70XgUYdLwI6" }, "source": [ "# TensorFlow 1.x vs TensorFlow 2 - Behaviors and APIs" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "akxmN3SQsEcb" }, "source": [ "Under the hood, TensorFlow 2 follows a fundamentally different programming paradigm from TF1.x.\n", "\n", "This guide describes the fundamental differences between TF1.x and TF2 in terms of behaviors and the APIs, and how these all relate to your migration journey." ] }, { "cell_type": "markdown", "metadata": { "id": "Xzy2mT87mwth" }, "source": [ "## High-level summary of major changes\n", "\n", "Fundamentally, TF1.x and TF2 use a different set of runtime behaviors around execution (eager in TF2), variables, control flow, tensor shapes, and tensor equality comparisons. To be TF2 compatible, your code must be compatible with the full set of TF2 behaviors. During migration, you can enable or disable most of these behaviors individually via the `tf.compat.v1.enable_*` or `tf.compat.v1.disable_*` APIs. The one exception is the removal of collections, which is a side effect of enabling/disabling eager execution.\n", "\n", "At a high level, TensorFlow 2:\n", "\n", "* Removes [redundant APIs](https://github.com/tensorflow/community/blob/master/rfcs/20180827-api-names.md).\n", "* Makes APIs more consistent - for example, \n", "[Unified RNNs](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md) and \n", "[Unified Optimizers](https://github.com/tensorflow/community/blob/master/rfcs/20181016-optimizer-unification.md).\n", "* Prefers [functions over sessions](https://github.com/tensorflow/community/blob/master/rfcs/20180918-functions-not-sessions-20.md) and integrates better with the Python runtime with\n", "[Eager execution](https://www.tensorflow.org/guide/eager) enabled by default along with `tf.function` that provides automatic control dependencies for graphs and compilation.\n", "* Deprecates global graph [collections](https://github.com/tensorflow/community/blob/master/rfcs/20180905-deprecate-collections.md).\n", "* Alters Variable concurrency 
semantics by using [`ResourceVariables` over `ReferenceVariables`](https://github.com/tensorflow/community/blob/master/rfcs/20180817-variables-20.md).\n", "* Supports [function-based](https://github.com/tensorflow/community/blob/master/rfcs/20180507-cond-v2.md) and differentiable [control flow](https://github.com/tensorflow/community/blob/master/rfcs/20180821-differentiable-functional-while.md) (Control Flow v2).\n", "* Simplifies the TensorShape API to hold `int`s instead of `tf.compat.v1.Dimension` objects.\n", "* Updates tensor equality mechanics. In TF1.x the `==` operator on tensors and variables checks for object reference equality. In TF2 it checks for value equality. Additionally, tensors/variables are no longer hashable, but you can get hashable object references to them via `var.ref()` if you need to use them in sets or as `dict` keys.\n", "\n", "The sections below provide some more context on the differences between TF1.x and TF2. To learn more about the design process behind TF2, read the\n", "[RFCs](https://github.com/tensorflow/community/pulls?utf8=%E2%9C%93&q=is%3Apr) and the [design docs](https://github.com/tensorflow/community/tree/master/rfcs)." ] }, { "cell_type": "markdown", "metadata": { "id": "dlCiIgEE2OhY" }, "source": [ "## API cleanup\n", "\n", "Many APIs are either [gone or moved](https://github.com/tensorflow/community/blob/master/rfcs/20180827-api-names.md) in TF2. Some of the major changes include removing `tf.app`, `tf.flags`, and `tf.logging` in favor of the now open-source [absl-py](https://github.com/abseil/abseil-py), rehoming projects that lived in `tf.contrib`, and cleaning up the main `tf.*` namespace by moving lesser used functions into subpackages like `tf.math`. 
Some APIs have been replaced with their TF2 equivalents: `tf.summary`, `tf.keras.metrics`, and\n", "`tf.keras.optimizers`.\n", "\n", "### `tf.compat.v1`: Legacy and Compatibility API Endpoints\n", "\n", "Symbols under the `tf.compat` and `tf.compat.v1` namespaces are not considered TF2 APIs. These namespaces expose a mix of compatibility symbols, as well as legacy API endpoints from TF1.x. These are intended to aid migration from TF1.x to TF2. However, as none of these `compat.v1` APIs are idiomatic TF2 APIs, do not use them for writing brand-new TF2 code.\n", "\n", "Individual `tf.compat.v1` symbols may be TF2-compatible because they continue to work even with TF2 behaviors enabled (such as `tf.compat.v1.losses.mean_squared_error`), while others are incompatible with TF2 (such as `tf.compat.v1.metrics.accuracy`). Many `compat.v1` symbols (though not all) contain dedicated migration information in their documentation that explains their degree of compatibility with TF2 behaviors, as well as how to migrate them to TF2 APIs.\n", "\n", "The [TF2 upgrade script](https://www.tensorflow.org/guide/migrate/upgrade) can map many `compat.v1` API symbols to equivalent TF2 APIs where they are aliases or have the same arguments in a different order. You can also use the upgrade script to automatically rename TF1.x APIs.\n", "\n", "### False friend APIs\n", "\n", "There is a set of \"false-friend\" symbols in the TF2 `tf` namespace (not under `compat.v1`) that actually ignore TF2 behaviors under the hood and/or are not fully compatible with the full set of TF2 behaviors. As a result, these APIs are likely to misbehave with TF2 code, potentially in silent ways.\n", "\n", "* `tf.estimator.*`: Estimators create and use graphs and sessions under the hood. As such, these should not be considered TF2-compatible. 
If your code is running estimators, it is not using TF2 behaviors.\n", "* `tf.keras.estimator.model_to_estimator(...)`: This creates an Estimator under the hood, which, as mentioned above, is not TF2-compatible.\n", "* `tf.Graph().as_default()`: This enters TF1.x graph behaviors and does not follow standard TF2-compatible `tf.function` behaviors. Code that enters graphs like this will generally run them via Sessions, and should not be considered TF2-compatible.\n", "* `tf.feature_column.*`: The feature column APIs generally rely on TF1-style `tf.compat.v1.get_variable` variable creation and assume that the created variables will be accessed via global collections. As TF2 does not support collections, these APIs may not work correctly when running them with TF2 behaviors enabled.\n", "\n", "### Other API changes\n", "\n", "* TF2 features significant improvements to the device placement algorithms, which render the usage of `tf.colocate_with` unnecessary. If removing it causes a performance degradation, [please file a bug](https://github.com/tensorflow/tensorflow/issues).\n", "\n", "* Replace all usages of `tf.compat.v1.ConfigProto` with equivalent functions from `tf.config`." ] }, { "cell_type": "markdown", "metadata": { "id": "RxEU79Rd83Yz" }, "source": [ "## Eager execution\n", "\n", "TF1.x required you to manually stitch together an [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree) (the graph) by making `tf.*` API calls and then manually compile the abstract syntax tree by passing a set of output tensors and input tensors to a `session.run` call. TF2 executes eagerly (like Python normally does) and makes graphs and sessions feel like implementation details.\n", "\n", "One notable byproduct of eager execution is that `tf.control_dependencies` is no\n", "longer required, as all lines of code execute in order (within a `tf.function`,\n", "code with side effects executes in the order written)." 
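, "\n",
"\n",
"The sketch below illustrates both points. It assumes TensorFlow 2.x with default settings, and the names `x`, `v`, and `update_then_read` are illustrative only. Eager ops return concrete values immediately. Inside a `tf.function`, the assignment to `v` is ordered before the read without an explicit `tf.control_dependencies` block.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# Eager execution is on by default in TF2, so ops return values immediately.\n",
"x = tf.constant([1.0, 2.0])\n",
"print((x * 2).numpy())  # [2. 4.]\n",
"\n",
"v = tf.Variable(1.0)\n",
"\n",
"@tf.function\n",
"def update_then_read():\n",
"  # Automatic control dependencies run the assignment before the read,\n",
"  # so no tf.control_dependencies block is needed.\n",
"  v.assign(2.0)\n",
"  return v.read_value()\n",
"\n",
"print(tf.executing_eagerly())      # True\n",
"print(update_then_read().numpy())  # 2.0\n",
"```"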
] }, { "cell_type": "markdown", "metadata": { "id": "LH3YizX-9S7g" }, "source": [ "## No more globals\n", "\n", "TF1.x relied heavily on implicit global namespaces and collections. When you call `tf.Variable`, it would be put into a collection in the default graph, and it would remain there, even if you lost track of the Python variable pointing to it. You could then recover that `tf.Variable`, but only if you knew the name that it had been created with. This was difficult to do if you were not in control of the variable's creation. As a result, all sorts of mechanisms proliferated to\n", "attempt to help you find your variables again, and for frameworks to find\n", "user-created variables. Some of these include: variable scopes, global collections, helper methods like `tf.get_global_step` and `tf.global_variables_initializer`, optimizers implicitly\n", "computing gradients over all trainable variables, and so on. TF2 eliminates all of these mechanisms ([Variables 2.0 RFC](https://github.com/tensorflow/community/pull/11)) in favor of the default mechanism - you keep track of your variables. If you lose track of a `tf.Variable`, it gets garbage collected.\n", "\n", "The requirement to track variables creates some extra work, but with tools like the [modeling shims](./model_mapping.ipynb) and behaviors like [implicit object-oriented variable collections in `tf.Module`s and `tf.keras.layers.Layer`s](https://www.tensorflow.org/guide/intro_to_modules), the burden is minimized." ] }, { "cell_type": "markdown", "metadata": { "id": "NXwBgAjJ98J2" }, "source": [ "## Functions, not sessions\n", "\n", "A `session.run` call is almost like a function call: you specify the inputs and\n", "the function to be called, and you get back a set of outputs. In TF2, you can decorate a Python function using `tf.function` to mark it for JIT compilation so that TensorFlow runs it as a single graph ([Functions 2.0 RFC](https://github.com/tensorflow/community/pull/20)). 
This mechanism allows TF2 to gain all of the benefits of graph mode:\n", "\n", "- Performance: The function can be optimized (node pruning, kernel fusion,\n", "  etc.)\n", "- Portability: The function can be exported/reimported\n", "  ([SavedModel 2.0 RFC](https://github.com/tensorflow/community/pull/34)),\n", "  allowing you to reuse and share modular TensorFlow functions.\n", "\n", "```python\n", "# TF1.x\n", "outputs = session.run(f(placeholder), feed_dict={placeholder: input})\n", "# TF2\n", "outputs = f(input)\n", "```\n", "\n", "With the power to freely intersperse Python and TensorFlow code, you can take\n", "advantage of Python's expressiveness. However, portable TensorFlow executes in\n", "contexts without a Python interpreter, such as mobile, C++, and JavaScript. To\n", "help avoid rewriting your code when adding `tf.function`, use [AutoGraph](https://tensorflow.org/guide/function) to convert a subset of Python constructs\n", "into their TensorFlow equivalents:\n", "\n", "* `for`/`while` -> `tf.while_loop` (`break` and `continue` are supported)\n", "* `if` -> `tf.cond`\n", "* `for _ in dataset` -> `dataset.reduce`\n", "\n", "AutoGraph supports arbitrary nesting of control flow, which makes it possible\n", "to efficiently and concisely implement many complex ML programs such as\n", "sequence models, reinforcement learning, custom training loops, and more." ] }, { "cell_type": "markdown", "metadata": { "id": "Mj3gaj4tpi7O" }, "source": [ "## Adapting to TF 2.x Behavior Changes\n", "\n", "Your migration to TF2 is only complete once you have migrated to the full set of TF2 behaviors. The full set of behaviors can be enabled or disabled via `tf.compat.v1.enable_v2_behavior` and `tf.compat.v1.disable_v2_behavior`. The sections below discuss each major behavior change in detail." 
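, "\n",
"\n",
"The sketch below runs a small TF1.x-style graph after calling `tf.compat.v1.disable_v2_behavior()`. It assumes a TF2 installation where the legacy `tf.compat.v1` graph APIs are still available, and the names `x`, `y`, and `result` are illustrative only. The call changes global state, so make it once, at program start-up, in a process dedicated to the legacy code.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# Globally turn off TF2 behaviors (including eager execution, resource\n",
"# variables, and control flow v2) before building any graph.\n",
"tf.compat.v1.disable_v2_behavior()\n",
"\n",
"x = tf.compat.v1.placeholder(tf.float32, shape=[None])\n",
"y = x * 2.0\n",
"\n",
"with tf.compat.v1.Session() as sess:\n",
"  result = sess.run(y, feed_dict={x: [1.0, 2.0]})\n",
"\n",
"print(tf.executing_eagerly())  # False\n",
"print(result)                  # [2. 4.]\n",
"```"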
] }, { "cell_type": "markdown", "metadata": { "id": "_M0zEtR9p0XD" }, "source": [ "### Using `tf.function`s\n", "\n", "The largest changes to your programs during migration are likely to come from the fundamental programming model paradigm shift from graphs and sessions to eager execution and `tf.function`. Refer to the [TF2 migration guides](https://tensorflow.org/guide/migrate) to learn more about moving from APIs that are incompatible with eager execution and `tf.function` to APIs that are compatible with them.\n", "\n", "Note: During migration you may choose to directly enable and disable eager execution with `tf.compat.v1.enable_eager_execution` and `tf.compat.v1.disable_eager_execution`, but this may only be done once during the lifetime of your program.\n", "\n", "Below are some common program patterns not tied to any one API that may cause problems when switching from `tf.Graph`s and `tf.compat.v1.Session`s to eager execution with `tf.function`s." ] }, { "cell_type": "markdown", "metadata": { "id": "UgwEtwwN2PWy" }, "source": [ "#### Pattern 1: Python object manipulation and variable creation intended to be done only once get run multiple times\n", "\n", "\n", "In TF1.x programs that rely on graphs and sessions, the expectation is usually that all Python logic in your program will only run once. However, with eager execution and `tf.function` it is fair to expect that your Python logic will be run at least once, but possibly more times (either multiple times eagerly, or multiple times across different `tf.function` traces). Sometimes, `tf.function` will even trace twice on the same input, causing unexpected behaviors (see Example 1 and 2). 
Refer to the `tf.function` [guide](https://www.tensorflow.org/guide/function) for more details.\n", "\n", "Note: This pattern usually causes your code to silently misbehave when executing eagerly without `tf.function`s, but generally raises an `InaccessibleTensorError` or a `ValueError` when attempting to wrap the problematic code inside of a `tf.function`. To discover and debug this issue, wrap your code with `tf.function` early on, and use [pdb](https://docs.python.org/3/library/pdb.html) or interactive debugging to identify the source of the `InaccessibleTensorError`.\n", "\n", "**Example 1: Variable creation**\n", "\n", "Consider the example below, where the function creates a variable when called:\n", "\n", "```python\n", "def f():\n", "  v = tf.Variable(1.0)\n", "  return v\n", "\n", "with tf.Graph().as_default():\n", "  with tf.compat.v1.Session() as sess:\n", "    res = f()\n", "    sess.run(tf.compat.v1.global_variables_initializer())\n", "    sess.run(res)\n", "```\n", "\n", "However, naively wrapping the above function that contains variable creation with `tf.function` is not allowed. `tf.function` only supports [singleton variable creations on the first call](https://www.tensorflow.org/guide/function#creating_tfvariables). 
To enforce this, when `tf.function` detects variable creation in the first call, it will attempt to trace again and raise an error if there is variable creation in the second trace.\n", "\n", "```python\n", "@tf.function\n", "def f():\n", "  print(\"trace\") # This will print twice because the Python body is run twice\n", "  v = tf.Variable(1.0)\n", "  return v\n", "\n", "try:\n", "  f()\n", "except ValueError as e:\n", "  print(e)\n", "```\n", "\n", "A workaround is caching and reusing the variable after it is created in the first call.\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.v = None\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    print(\"trace\") # This will print twice because the Python body is run twice\n", "    if self.v is None:\n", "      self.v = tf.Variable(0)\n", "    return self.v\n", "\n", "m = Model()\n", "m()\n", "```\n", "\n", "**Example 2: Out-of-scope Tensors due to `tf.function` retracing**\n", "\n", "As demonstrated in Example 1, `tf.function` will retrace when it detects variable creation in the first call. This can cause extra confusion, because the two tracings will create two graphs. When the second graph from retracing attempts to access a Tensor from the graph generated during the first tracing, TensorFlow will raise an error complaining that the Tensor is out of scope. To demonstrate the scenario, the code below creates a dataset on the first `tf.function` call. 
This runs as expected.\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.dataset = None\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    print(\"trace\") # This will print once: only traced once\n", "    if self.dataset is None:\n", "      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])\n", "    it = iter(self.dataset)\n", "    return next(it)\n", "\n", "m = Model()\n", "m()\n", "```\n", "\n", "However, if you also attempt to create a variable on the first `tf.function` call, the code will raise an error complaining that the dataset is out of scope. This is because the dataset was created in the first graph, while the second graph is also attempting to access it.\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.v = None\n", "    self.dataset = None\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    print(\"trace\") # This will print twice because the Python body is run twice\n", "    if self.v is None:\n", "      self.v = tf.Variable(0)\n", "    if self.dataset is None:\n", "      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])\n", "    it = iter(self.dataset)\n", "    return [self.v, next(it)]\n", "\n", "m = Model()\n", "try:\n", "  m()\n", "except TypeError as e:\n", "  print(e) # Reports that the tensor is out of scope and cannot be used here.\n", "```\n", "\n", "The most straightforward solution is ensuring that the variable creation and dataset creation are both outside of the `tf.function` call. 
For example:\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.v = None\n", "    self.dataset = None\n", "\n", "  def initialize(self):\n", "    if self.dataset is None:\n", "      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])\n", "    if self.v is None:\n", "      self.v = tf.Variable(0)\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    it = iter(self.dataset)\n", "    return [self.v, next(it)]\n", "\n", "m = Model()\n", "m.initialize()\n", "m()\n", "```\n", "\n", "However, sometimes creating variables inside a `tf.function` is unavoidable (such as slot variables in some [TF Keras optimizers](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#slots)). Even then, you can still move the dataset creation outside of the `tf.function` call. This works because `tf.function` receives the dataset as an implicit input, so both graphs can access it properly.\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.v = None\n", "    self.dataset = None\n", "\n", "  def initialize(self):\n", "    if self.dataset is None:\n", "      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])\n", "\n", "  @tf.function\n", "  def __call__(self):\n", "    if self.v is None:\n", "      self.v = tf.Variable(0)\n", "    it = iter(self.dataset)\n", "    return [self.v, next(it)]\n", "\n", "m = Model()\n", "m.initialize()\n", "m()\n", "```\n", "\n", "**Example 3: Unexpected TensorFlow object re-creations due to dict usage**\n", "\n", "`tf.function` has very poor support for Python side effects such as appending to a list, or checking/adding to a dictionary. More details are in [\"Better performance with tf.function\"](https://www.tensorflow.org/guide/function#executing_python_side_effects). In the example below, the code uses dictionaries to cache datasets and iterators. 
For the same key, each call to the model will return the same iterator of the dataset.\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.datasets = {}\n", "    self.iterators = {}\n", "\n", "  def __call__(self, key):\n", "    if key not in self.datasets:\n", "      self.datasets[key] = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])\n", "      self.iterators[key] = self.datasets[key].make_initializable_iterator()\n", "    return self.iterators[key]\n", "\n", "with tf.Graph().as_default():\n", "  with tf.compat.v1.Session() as sess:\n", "    m = Model()\n", "    it = m('a')\n", "    sess.run(it.initializer)\n", "    for _ in range(3):\n", "      print(sess.run(it.get_next())) # prints 1, 2, 3\n", "```\n", "\n", "However, the pattern above will not work as expected in `tf.function`. During tracing, `tf.function` will ignore the Python side effect of adding entries to the dictionaries. Instead, it only remembers the creation of a new dataset and iterator. As a result, each call to the model will always return a new iterator. This issue is hard to notice unless its effect on the numerical results or performance is significant. 
Hence, think carefully about your code before naively wrapping it in `tf.function`.\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.datasets = {}\n", "    self.iterators = {}\n", "\n", "  @tf.function\n", "  def __call__(self, key):\n", "    if key not in self.datasets:\n", "      self.datasets[key] = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n", "      self.iterators[key] = iter(self.datasets[key])\n", "    return self.iterators[key]\n", "\n", "m = Model()\n", "for _ in range(3):\n", "  print(next(m('a'))) # prints 1, 1, 1\n", "```\n", "\n", "You can use [`tf.init_scope`](https://www.tensorflow.org/api_docs/python/tf/init_scope) to lift the dataset and iterator creation outside of the graph, to achieve the expected behavior:\n", "\n", "```python\n", "class Model(tf.Module):\n", "  def __init__(self):\n", "    self.datasets = {}\n", "    self.iterators = {}\n", "\n", "  @tf.function\n", "  def __call__(self, key):\n", "    if key not in self.datasets:\n", "      # Lifts ops out of function-building graphs\n", "      with tf.init_scope():\n", "        self.datasets[key] = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n", "        self.iterators[key] = iter(self.datasets[key])\n", "    return self.iterators[key]\n", "\n", "m = Model()\n", "for _ in range(3):\n", "  print(next(m('a'))) # prints 1, 2, 3\n", "```\n", "\n", "The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces.\n", "\n", "**Example 4: Manipulating a global Python list**\n", "\n", "The following TF1.x code uses a global list of losses that is meant to hold only the losses generated by the current training step. 
Note that the Python logic that appends losses to the list will only be called once regardless of how many training steps the session is run for.\n", "\n", "```python\n", "all_losses = []\n", "\n", "class Model():\n", "  def __call__(...):\n", "    ...\n", "    all_losses.append(regularization_loss)\n", "    all_losses.append(label_loss_a)\n", "    all_losses.append(label_loss_b)\n", "    ...\n", "\n", "g = tf.Graph()\n", "with g.as_default():\n", "  ...\n", "  # initialize all objects\n", "  model = Model()\n", "  optimizer = ...\n", "  ...\n", "  # train step\n", "  model(...)\n", "  total_loss = tf.reduce_sum(all_losses)\n", "  optimizer.minimize(total_loss)\n", "  ...\n", "...\n", "sess = tf.compat.v1.Session(graph=g)\n", "sess.run(...)\n", "```\n", "\n", "However, if this Python logic is naively mapped to TF2 with eager execution, the global list of losses will have new values appended to it in each training step. This means the training step code which previously expected the list to only contain losses from the current training step now actually sees the list of losses from all training steps run so far. 
This is an unintended behavior change, and the list will either need to be cleared at the start of each step or made local to the training step.\n", "\n", "```python\n", "all_losses = []\n", "\n", "class Model():\n", "  def __call__(...):\n", "    ...\n", "    all_losses.append(regularization_loss)\n", "    all_losses.append(label_loss_a)\n", "    all_losses.append(label_loss_b)\n", "    ...\n", "\n", "# initialize all objects\n", "model = Model()\n", "optimizer = ...\n", "\n", "def train_step(...):\n", "  ...\n", "  model(...)\n", "  total_loss = tf.reduce_sum(all_losses) # The global list is never cleared, so this\n", "  # accidentally sums the losses of all training steps run so far\n", "  optimizer.minimize(total_loss)\n", "  ...\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "qaYnjPo-tmTI" }, "source": [ "#### Pattern 2: A symbolic tensor meant to be recomputed every step in TF1.x is accidentally cached with the initial value when switching to eager.\n", "\n", "\n", "This pattern usually causes your code to silently misbehave when executing eagerly outside of `tf.function`s, but raises an `InaccessibleTensorError` if the initial value caching occurs inside of a `tf.function`. However, be aware that in order to avoid [Pattern 1](#pattern-1) above you will often inadvertently structure your code in such a way that this initial value caching will happen *outside* of any `tf.function` that would be able to raise an error. So, take extra care if you know your program may be susceptible to this pattern.\n", "\n", "The general solution to this pattern is to restructure the code or use Python callables if necessary to make sure the value is recomputed each time instead of being accidentally cached.\n", "\n", "**Example 1: Learning rate/hyperparameter/etc. 
schedules that depend on global step**\n", "\n", "In the following code snippet, the expectation is that every time the session is run, the most recent `global_step` value will be read and a new learning rate will be computed.\n", "```python\n", "g = tf.Graph()\n", "with g.as_default():\n", "  ...\n", "  global_step = tf.Variable(0)\n", "  learning_rate = 1.0 / global_step\n", "  opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)\n", "  ...\n", "  global_step.assign_add(1)\n", "...\n", "sess = tf.compat.v1.Session(graph=g)\n", "sess.run(...)\n", "```\n", "\n", "However, when switching to eager, be wary of ending up with a learning rate that is computed only once and then reused, rather than following the intended schedule:\n", "```python\n", "global_step = tf.Variable(0)\n", "learning_rate = 1.0 / global_step # Wrong! Only computed once!\n", "opt = tf.keras.optimizers.SGD(learning_rate)\n", "\n", "def train_step(...):\n", "  ...\n", "  opt.apply_gradients(...)\n", "  global_step.assign_add(1)\n", "  ...\n", "```\n", "\n", "Because this specific example is a common pattern and optimizers should only be initialized once rather than at each training step, TF2 optimizers support `tf.keras.optimizers.schedules.LearningRateSchedule` schedules or Python callables as arguments for the learning rate and other hyperparameters.\n", "\n", "**Example 2: Symbolic random number initializations assigned as object attributes then reused via pointer are accidentally cached when switching to eager**\n", "\n", "Consider the following `NoiseAdder` module:\n", "\n", "```python\n", "class NoiseAdder(tf.Module):\n", "  def __init__(self, shape, mean):\n", "    self.noise_distribution = tf.random.normal(shape=shape, mean=mean)\n", "    self.trainable_scale = tf.Variable(1.0, trainable=True)\n", "\n", "  def add_noise(self, x):\n", "    return (self.noise_distribution + x) * self.trainable_scale\n", "```\n", "\n", "Using it as follows in TF1.x will compute a new random noise tensor every time the 
session is run:\n", "```python\n", "g = tf.Graph()\n", "with g.as_default():\n", "  ...\n", "  # initialize all variable-containing objects\n", "  noise_adder = NoiseAdder(shape, mean)\n", "  ...\n", "  # computation pass\n", "  x_with_noise = noise_adder.add_noise(x)\n", "  ...\n", "...\n", "sess = tf.compat.v1.Session(graph=g)\n", "sess.run(...)\n", "```\n", "\n", "However, in TF2, initializing the `noise_adder` at the beginning will cause the `noise_distribution` to be computed only once and then frozen for all training steps:\n", "```python\n", "...\n", "# initialize all variable-containing objects\n", "noise_adder = NoiseAdder(shape, mean) # Freezes `self.noise_distribution`!\n", "...\n", "# computation pass\n", "x_with_noise = noise_adder.add_noise(x)\n", "...\n", "```\n", "\n", "To fix this, refactor `NoiseAdder` to call `tf.random.normal` every time a new random tensor is needed, instead of referring to the same tensor object each time.\n", "\n", "```python\n", "class NoiseAdder(tf.Module):\n", "  def __init__(self, shape, mean):\n", "    self.noise_distribution = lambda: tf.random.normal(shape=shape, mean=mean)\n", "    self.trainable_scale = tf.Variable(1.0, trainable=True)\n", "\n", "  def add_noise(self, x):\n", "    return (self.noise_distribution() + x) * self.trainable_scale\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "j2PXkSflCaCl" }, "source": [ "#### Pattern 3: TF1.x code directly relies on and looks up tensors by name\n", "\n", "\n", "It is common for TF1.x tests to rely on checking what tensors or operations are present in a graph. In some rare cases, modeling code will also rely on these lookups by name.\n", "\n", "Tensor names are not generated at all when executing eagerly outside of `tf.function`, so all usages of `tf.Tensor.name` must happen inside of a `tf.function`. 
Keep in mind the actual generated names are very likely to differ between TF1.x and TF2 even within the same `tf.function`, and API guarantees do not ensure stability of the generated names across TF versions.\n", "\n", "Note: Variable names are still generated even outside of `tf.function`s, but their names are also not guaranteed to match between TF1.x and TF2 except when following the relevant sections of the [model mapping guide](./model_mapping.ipynb).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "5NB3bycl5Lde" }, "source": [ "#### Pattern 4: TF1.x session selectively runs only part of the generated graph\n", "\n", "\n", "In TF1.x, you can construct a graph and then choose to run only a subset of it with a session by choosing a set of inputs and outputs that do not require running every op in the graph.\n", "\n", "For example, you may have both a generator and a discriminator inside of a single graph, and use separate `tf.compat.v1.Session.run` calls to alternate between only training the discriminator or only training the generator.\n", "\n", "In TF2, due to automatic control dependencies in `tf.function` and eager execution, there is no selective pruning of `tf.function` traces. A full graph containing all variable updates would get run even if, for example, only the output of the discriminator or the generator is returned from the `tf.function`.\n", "\n", "So, you would need to either use multiple `tf.function`s containing different parts of the program, or pass a conditional argument to the `tf.function` and branch on it so that only the parts you actually want are executed." ] }, { "cell_type": "markdown", "metadata": { "id": "CnNaUmROp5fV" }, "source": [ "### Collections Removal\n", "\n", "When eager execution is enabled, graph collection-related `compat.v1` APIs (including those that read or write to collections under the hood such as `tf.compat.v1.trainable_variables`) are no longer available. 
Some may raise `ValueError`s, while others may silently return empty lists.\n", "\n", "The most common uses of collections in TF1.x are to maintain initializers, the global step, weights, regularization losses, model output losses, and variable updates that need to be run, such as those from `BatchNormalization` layers.\n", "\n", "To handle each of these standard usages:\n", "1. Initializers - Ignore. Manual variable initialization is not required with eager execution enabled.\n", "2. Global step - See the documentation of `tf.compat.v1.train.get_or_create_global_step` for migration instructions.\n", "3. Weights - Map your models to `tf.Module`s/`tf.keras.layers.Layer`s/`tf.keras.Model`s by following the guidance in the [model mapping guide](./model_mapping.ipynb) and then use their respective weight-tracking mechanisms such as `tf.Module.trainable_variables`.\n", "4. Regularization losses - Map your models to `tf.Module`s/`tf.keras.layers.Layer`s/`tf.keras.Model`s by following the guidance in the [model mapping guide](./model_mapping.ipynb) and then use `tf.keras.losses`. Alternatively, you can also manually track your regularization losses.\n", "5. Model output losses - Use `tf.keras.Model` loss management mechanisms or separately track your losses without using collections.\n", "6. Weight updates - Ignore this collection. Eager execution and `tf.function` (with AutoGraph and automatic control dependencies) mean that all variable updates run automatically. So, you will not have to explicitly run all weight updates at the end, but note that the weight updates may happen at a different time than they did in your TF1.x code, depending on how you were using control dependencies.\n", "7. 
Summaries - Refer to the [migrating summary API guide](https://www.tensorflow.org/tensorboard/migrate).\n", "\n", "More complex collections usage (such as custom collections) may require you to refactor your code to either maintain your own global stores or not rely on global stores at all." ] }, { "cell_type": "markdown", "metadata": { "id": "8J_ckZstp8y1" }, "source": [ "### `ResourceVariables` instead of `ReferenceVariables`\n", "\n", "`ResourceVariables` have stronger read-write consistency guarantees than `ReferenceVariables`. This leads to more predictable, easier-to-reason-about semantics regarding whether you will observe the result of a previous write when using your variables. This change is extremely unlikely to cause existing code to raise errors or to break silently.\n", "\n", "However, it is ***possible though unlikely*** that these stronger consistency guarantees may increase the memory usage of your specific program. Please file an [issue](https://github.com/tensorflow/tensorflow/issues) if you find this to be the case. Additionally, if you have unit tests relying on exact string comparisons against the operator names in a graph corresponding to variable reads, be aware that enabling resource variables may slightly change the names of these operators.\n", "\n", "To isolate the impact of this behavior change on your code, if eager execution is disabled, you can use `tf.compat.v1.disable_resource_variables()` and `tf.compat.v1.enable_resource_variables()` to globally disable or enable this behavior change. `ResourceVariables` are always used if eager execution is enabled.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FTU-4P1vux0e" }, "source": [ "### Control flow v2\n", "\n", "In TF1.x, control flow ops such as `tf.cond` and `tf.while_loop` inline low-level ops such as `Switch` and `Merge` directly into the graph. 
TF2 provides improved functional control flow ops that are implemented with separate `tf.function` traces for every branch and support higher-order differentiation.\n", "\n", "To isolate the impact of this behavior change on your code, if eager execution is disabled, you can use `tf.compat.v1.disable_control_flow_v2()` and `tf.compat.v1.enable_control_flow_v2()` to globally disable or enable this behavior change. However, you can only disable control flow v2 if eager execution is also disabled. If it is enabled, control flow v2 will always be used.\n", "\n", "This behavior change can dramatically alter the structure of generated TF programs that use control flow, as they will contain several nested function traces rather than one flat graph. So, any code that is highly dependent on the exact semantics of produced traces may require some modification. This includes:\n", "\n", "* Code relying on operator and tensor names\n", "* Code referring to tensors created within a TensorFlow control flow branch from outside of that branch. This is likely to produce an `InaccessibleTensorError`.\n", "\n", "This behavior change is intended to be performance-neutral or better, but if you run into an issue where control flow v2 performs worse for you than TF1.x control flow, please file an [issue](https://github.com/tensorflow/tensorflow/issues) with reproduction steps." ] }, { "cell_type": "markdown", "metadata": { "id": "W7VwgVCGqE9S" }, "source": [ "## TensorShape API behavior changes\n", "\n", "The `TensorShape` class was simplified to hold `int`s instead of `tf.compat.v1.Dimension` objects, so there is no need to call `.value` to get an `int`.\n", "\n", "Individual `tf.compat.v1.Dimension` objects are still accessible from `tf.TensorShape.dims`.\n", "\n", "To isolate the impact of this behavior change on your code, you can use `tf.compat.v1.disable_v2_tensorshape()` and `tf.compat.v1.enable_v2_tensorshape()` to globally disable or enable this behavior change."
] }, { "cell_type": "markdown", "metadata": { "id": "x36cWcmM8Eu1" }, "source": [ "The following examples demonstrate the differences between TF1.x and TF2." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:00.126324Z", "iopub.status.busy": "2024-08-15T02:19:00.126074Z", "iopub.status.idle": "2024-08-15T02:19:02.486396Z", "shell.execute_reply": "2024-08-15T02:19:02.485665Z" }, "id": "QF4un9UpVTRA" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.490680Z", "iopub.status.busy": "2024-08-15T02:19:02.490277Z", "iopub.status.idle": "2024-08-15T02:19:02.497456Z", "shell.execute_reply": "2024-08-15T02:19:02.496850Z" }, "id": "PbpD-kHOZR4A" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([16, None, 256])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a shape and choose an index\n", "i = 0\n", "shape = tf.TensorShape([16, None, 256])\n", "shape" ] }, { "cell_type": "markdown", "metadata": { "id": "kDFck03neNy0" }, "source": [ "If you had this in TF1.x:\n", "\n", "```python\n", "value = shape[i].value\n", "```\n", "\n", "Then 
do this in TF2:\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.501299Z", "iopub.status.busy": "2024-08-15T02:19:02.500811Z", "iopub.status.idle": "2024-08-15T02:19:02.505043Z", "shell.execute_reply": "2024-08-15T02:19:02.504495Z" }, "id": "KuR73QGEeNdH" }, "outputs": [ { "data": { "text/plain": [ "16" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "value = shape[i]\n", "value" ] }, { "cell_type": "markdown", "metadata": { "id": "bPWPNKRiZmkd" }, "source": [ "If you had this in TF1.x:\n", "\n", "```python\n", "for dim in shape:\n", " value = dim.value\n", " print(value)\n", "```\n", "\n", "Then, do this in TF2:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.508039Z", "iopub.status.busy": "2024-08-15T02:19:02.507805Z", "iopub.status.idle": "2024-08-15T02:19:02.511202Z", "shell.execute_reply": "2024-08-15T02:19:02.510612Z" }, "id": "y6s0vuuprJfc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "16\n", "None\n", "256\n" ] } ], "source": [ "for value in shape:\n", " print(value)" ] }, { "cell_type": "markdown", "metadata": { "id": "YpRgngu3Zw-A" }, "source": [ "If you had this in TF1.x (or used any other dimension method):\n", "\n", "```python\n", "dim = shape[i]\n", "dim.assert_is_compatible_with(other_dim)\n", "```\n", "\n", "Then do this in TF2:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.514351Z", "iopub.status.busy": "2024-08-15T02:19:02.513892Z", "iopub.status.idle": "2024-08-15T02:19:02.519050Z", "shell.execute_reply": "2024-08-15T02:19:02.518373Z" }, "id": "LpViGEcUZDGX" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "other_dim = 16\n", "Dimension = 
tf.compat.v1.Dimension\n", "\n", "if shape.rank is None:\n", " dim = Dimension(None)\n", "else:\n", " dim = shape.dims[i]\n", "dim.is_compatible_with(other_dim) # or any other dimension method" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.522103Z", "iopub.status.busy": "2024-08-15T02:19:02.521601Z", "iopub.status.idle": "2024-08-15T02:19:02.525003Z", "shell.execute_reply": "2024-08-15T02:19:02.524426Z" }, "id": "GaiGe36dOdZ_" }, "outputs": [], "source": [ "shape = tf.TensorShape(None)\n", "\n", "if shape:\n", " dim = shape.dims[i]\n", " dim.is_compatible_with(other_dim) # or any other dimension method" ] }, { "cell_type": "markdown", "metadata": { "id": "3kLLY0I3PI-l" }, "source": [ "The boolean value of a `tf.TensorShape` is `True` if the rank is known, `False` otherwise." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.528319Z", "iopub.status.busy": "2024-08-15T02:19:02.527769Z", "iopub.status.idle": "2024-08-15T02:19:02.532614Z", "shell.execute_reply": "2024-08-15T02:19:02.532021Z" }, "id": "-Ow1ndKpOnJd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "True\n", "True\n", "True\n", "True\n", "\n", "False\n" ] } ], "source": [ "print(bool(tf.TensorShape([]))) # Scalar\n", "print(bool(tf.TensorShape([0]))) # 0-length vector\n", "print(bool(tf.TensorShape([1]))) # 1-length vector\n", "print(bool(tf.TensorShape([None]))) # Unknown-length vector\n", "print(bool(tf.TensorShape([1, 10, 100]))) # 3D tensor\n", "print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions\n", "print()\n", "print(bool(tf.TensorShape(None))) # A tensor with unknown rank." 
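] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If your code must work under both the TF1.x and TF2 `TensorShape` behaviors while you migrate, the `tf.compat.dimension_value` and `tf.compat.dimension_at_index` helpers may let you avoid the manual branching shown above. The following is a minimal sketch that assumes a TF2 installation. `dimension_value` returns a plain `int` (or `None`) whether it is given an `int` or a `tf.compat.v1.Dimension`. `dimension_at_index` returns a `tf.compat.v1.Dimension`, even for a shape of unknown rank." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "shape = tf.TensorShape([16, None, 256])\n", "\n", "# Works whether shape[i] is an int (TF2 behavior) or a Dimension (TF1.x behavior)\n", "values = [tf.compat.dimension_value(shape[i]) for i in range(shape.rank)]\n", "print(values) # [16, None, 256]\n", "\n", "# For an unknown-rank shape this returns Dimension(None) instead of failing\n", "unknown = tf.TensorShape(None)\n", "dim = tf.compat.dimension_at_index(unknown, 0)\n", "print(dim.is_compatible_with(16)) # True"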
] }, { "cell_type": "markdown", "metadata": { "id": "KvfEd-uSsWqN" }, "source": [ "### Potential errors due to TensorShape changes\n", "\n", "The TensorShape behavior changes are unlikely to silently break your code. However, you may see shape-related code begin to raise `AttributeError`s, because `int` and `None` values do not have the same attributes that `tf.compat.v1.Dimension` objects do. Below are some examples of these `AttributeError`s:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.535468Z", "iopub.status.busy": "2024-08-15T02:19:02.535245Z", "iopub.status.idle": "2024-08-15T02:19:02.538884Z", "shell.execute_reply": "2024-08-15T02:19:02.538323Z" }, "id": "r18f8JAGsQi6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "'int' object has no attribute 'value'\n" ] } ], "source": [ "try:\n", " # Create a shape and choose an index\n", " shape = tf.TensorShape([16, None, 256])\n", " value = shape[0].value\n", "except AttributeError as e:\n", " # 'int' object has no attribute 'value'\n", " print(e)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.541752Z", "iopub.status.busy": "2024-08-15T02:19:02.541531Z", "iopub.status.idle": "2024-08-15T02:19:02.545535Z", "shell.execute_reply": "2024-08-15T02:19:02.544918Z" }, "id": "t9flHru1uIdT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "'NoneType' object has no attribute 'assert_is_compatible_with'\n" ] } ], "source": [ "try:\n", " # Create a shape and choose an index\n", " shape = tf.TensorShape([16, None, 256])\n", " dim = shape[1]\n", " other_dim = shape[2]\n", " dim.assert_is_compatible_with(other_dim)\n", "except AttributeError as e:\n", " # 'NoneType' object has no attribute 'assert_is_compatible_with'\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "Og7H_TwJqIOF" }, "source": [ "## Tensor Equality by Value\n", "\n",
"The binary `==` and `!=` operators on variables and tensors were changed to compare by value in TF2, rather than by object reference as in TF1.x. Additionally, tensors and variables are no longer directly hashable or usable in sets or as dict keys, because it may not be possible to hash them by value. Instead, they expose a `.ref()` method that you can use to get a hashable reference to the tensor or variable.\n", "\n", "To isolate the impact of this behavior change, you can use `tf.compat.v1.disable_tensor_equality()` and `tf.compat.v1.enable_tensor_equality()` to globally disable or enable this behavior change." ] }, { "cell_type": "markdown", "metadata": { "id": "NGN4oL3lz0ki" }, "source": [ "For example, in TF1.x, comparing two variables that hold the same value with the `==` operator returns `False`:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:02.548682Z", "iopub.status.busy": "2024-08-15T02:19:02.548251Z", "iopub.status.idle": "2024-08-15T02:19:04.741654Z", "shell.execute_reply": "2024-08-15T02:19:04.741001Z" }, "id": "dkGPGpEZ5DI-" }, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.compat.v1.disable_tensor_equality()\n", "x = tf.Variable(0.0)\n", "y = tf.Variable(0.0)\n", "\n", "x == y" ] }, { "cell_type": "markdown", "metadata": { "id": "RqbewjIFz_oz" }, "source": [ "In contrast, in TF2 with tensor equality checks enabled, `x == y` returns a scalar boolean tensor whose value is `True`."
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:04.745617Z", "iopub.status.busy": "2024-08-15T02:19:04.745343Z", "iopub.status.idle": "2024-08-15T02:19:04.753819Z", "shell.execute_reply": "2024-08-15T02:19:04.753151Z" }, "id": "V5P_Rwy-zxVE" }, "outputs": [ { "data": { "text/plain": [ "<tf.Tensor: shape=(), dtype=bool, numpy=True>" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.compat.v1.enable_tensor_equality()\n", "x = tf.Variable(0.0)\n", "y = tf.Variable(0.0)\n", "\n", "x == y" ] }, { "cell_type": "markdown", "metadata": { "id": "BqdUPLhHypfs" }, "source": [ "So, in TF2, if you need to compare by object reference, use `is` and `is not`:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:04.757280Z", "iopub.status.busy": "2024-08-15T02:19:04.756622Z", "iopub.status.idle": "2024-08-15T02:19:04.762296Z", "shell.execute_reply": "2024-08-15T02:19:04.761674Z" }, "id": "iEjXVxlu4uxo" }, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.compat.v1.enable_tensor_equality()\n", "x = tf.Variable(0.0)\n", "y = tf.Variable(0.0)\n", "\n", "x is y" ] }, { "cell_type": "markdown", "metadata": { "id": "r2ai1BGN01VI" }, "source": [ "### Hashing tensors and variables\n", "\n", "With TF1.x behaviors, you could directly add variables and tensors to data structures that require hashing, such as `set`s and `dict` keys:\n", "\n", "```python\n", "x = tf.Variable(0.0)\n", "set([x, tf.constant(2.0)])\n", "```\n", "\n", "However, in TF2 with tensor equality enabled, tensors and variables are unhashable, because the `==` and `!=` operators now check for value equality instead of object identity."
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:04.765480Z", "iopub.status.busy": "2024-08-15T02:19:04.765114Z", "iopub.status.idle": "2024-08-15T02:19:04.771258Z", "shell.execute_reply": "2024-08-15T02:19:04.770641Z" }, "id": "-TR1KfJu462w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable is unhashable. Instead, use variable.ref() as the key. (Variable: )\n" ] } ], "source": [ "tf.compat.v1.enable_tensor_equality()\n", "x = tf.Variable(0.0)\n", "\n", "try:\n", " set([x, tf.constant(2.0)])\n", "except TypeError as e:\n", " # TypeError: Variable is unhashable. Instead, use variable.ref() as the key.\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "CQY7NvNAa7be" }, "source": [ "So, in TF2, if you need to use tensor or variable objects as `dict` keys or `set` contents, you can use `tensor.ref()` to get a hashable reference:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:04.774494Z", "iopub.status.busy": "2024-08-15T02:19:04.773857Z", "iopub.status.idle": "2024-08-15T02:19:04.779988Z", "shell.execute_reply": "2024-08-15T02:19:04.779420Z" }, "id": "p-1kVPs01ZuU" }, "outputs": [ { "data": { "text/plain": [ "{>,\n", " >}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.compat.v1.enable_tensor_equality()\n", "x = tf.Variable(0.0)\n", "\n", "tensor_set = set([x.ref(), tf.constant(2.0).ref()])\n", "assert x.ref() in tensor_set\n", "\n", "tensor_set" ] }, { "cell_type": "markdown", "metadata": { "id": "PqqRqfOYbaOX" }, "source": [ "If needed, you can also get the tensor or variable from the reference by using `reference.deref()`:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:19:04.782804Z", "iopub.status.busy": "2024-08-15T02:19:04.782496Z", 
"iopub.status.idle": "2024-08-15T02:19:04.787348Z", "shell.execute_reply": "2024-08-15T02:19:04.786760Z" }, "id": "DwRZMYV06M7q" }, "outputs": [ { "data": { "text/plain": [ "<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "referenced_var = x.ref().deref()\n", "assert referenced_var is x\n", "referenced_var" ] }, { "cell_type": "markdown", "metadata": { "id": "5XSFQbJaReVC" }, "source": [ "## Resources and further reading\n", "\n", "* Visit the [Migrate to TF2](https://tensorflow.org/guide/migrate) section to read more about migrating to TF2 from TF1.x.\n", "* Read the [model mapping guide](./model_mapping.ipynb) to learn more about mapping your TF1.x models to work directly in TF2." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "tf1_vs_tf2.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }