{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ISubpr_SSsiM" }, "source": [ "##### Copyright 2020 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:33:45.593987Z", "iopub.status.busy": "2022-12-14T22:33:45.593529Z", "iopub.status.idle": "2022-12-14T22:33:45.597263Z", "shell.execute_reply": "2022-12-14T22:33:45.596704Z" }, "id": "3jTMb1dySr3V" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "6DWfyNThSziV" }, "source": [ "# 使用 tf.function 提升性能\n", "\n", "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看\n", "在 Google Colab 中运行\n", "在 GitHub 上查看源代码\n", "下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "J122XQYG7W6w" }, "source": [ "在 TensorFlow 2 中,[Eager Execution](eager.ipynb) 默认处于启用状态。界面非常灵活直观(执行一次性运算要简单快速得多),不过,这可能对性能和可部署性造成一定影响。\n", "\n", "您可以使用 `tf.function` 将程序转换为计算图。这是一个转换工具,用于从 Python 代码创建独立于 Python 的数据流图。它可以帮助您创建高效且可移植的模型,并且如果要使用 `SavedModel`,则必须使用此工具。\n", "\n", "本指南介绍 `tf.function` 的底层工作原理,让您形成概念化理解,从而有效地加以利用。\n", "\n", "要点和建议包括:\n", "\n", "- 先在 Eager 模式下调试,然后使用 `@tf.function` 进行装饰。\n", "- 不依赖 Python 的副作用,如对象变异或列表追加。\n", "- `tf.function` 最适合处理 TensorFlow 运算;NumPy 和 Python 调用会转换为常量。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "SjvqpgepHJPd" }, "source": [ "## 设置" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:45.600636Z", "iopub.status.busy": "2022-12-14T22:33:45.600427Z", "iopub.status.idle": "2022-12-14T22:33:49.321477Z", "shell.execute_reply": "2022-12-14T22:33:49.320762Z" }, "id": "otIdN1TS8N7S" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 22:33:48.348405: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:33:48.348501: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:33:48.348510: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "# Update TensorFlow, as this notebook requires version 2.9 or later\n", "!pip install -q -U tensorflow>=2.9.0\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "I0xDjO4SHLUD" }, "source": [ "定义一个辅助函数来演示可能遇到的错误类型:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:49.325894Z", "iopub.status.busy": "2022-12-14T22:33:49.325486Z", "iopub.status.idle": "2022-12-14T22:33:49.330651Z", "shell.execute_reply": "2022-12-14T22:33:49.329892Z" }, "id": "D25apou9IOXa" }, "outputs": [], "source": [ "import traceback\n", "import contextlib\n", "\n", "# Some helper code to demonstrate the kinds of errors you might encounter.\n", "@contextlib.contextmanager\n", "def assert_raises(error_class):\n", " try:\n", " yield\n", " except error_class as e:\n", " print('Caught expected exception \\n {}:'.format(error_class))\n", " traceback.print_exc(limit=2)\n", " except Exception as e:\n", " raise e\n", " else:\n", " raise Exception('Expected {} to be raised but no error was raised!'.format(\n", " error_class))" ] }, { "cell_type": "markdown", "metadata": { "id": "WPSfepzTHThq" }, "source": [ "## 基础知识" ] }, { "cell_type": "markdown", "metadata": { "id": "CNwYTIJ8r56W" }, "source": [ "### 用法\n", "\n", "您定义的 `Function`(例如,通过应用 `@tf.function` 装饰器)就像核心 TensorFlow 运算:您可以在 Eager 模式下执行它,可以计算梯度,等等。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:49.334339Z", "iopub.status.busy": "2022-12-14T22:33:49.333916Z", "iopub.status.idle": "2022-12-14T22:33:52.718791Z", "shell.execute_reply": "2022-12-14T22:33:52.717889Z" }, "id": "SbtT1-Wm70F2" }, "outputs": [ { "data": 
{ "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function # The decorator converts `add` into a `Function`.\n", "def add(a, b):\n", " return a + b\n", "\n", "add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:52.722240Z", "iopub.status.busy": "2022-12-14T22:33:52.722008Z", "iopub.status.idle": "2022-12-14T22:33:52.749166Z", "shell.execute_reply": "2022-12-14T22:33:52.748537Z" }, "id": "uP-zUelB8DbX" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v = tf.Variable(1.0)\n", "with tf.GradientTape() as tape:\n", " result = add(v, 1.0)\n", "tape.gradient(result, v)" ] }, { "cell_type": "markdown", "metadata": { "id": "ocWZvqrmHnmX" }, "source": [ "`Function` 中可以嵌套其他 `Function`。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:52.752399Z", "iopub.status.busy": "2022-12-14T22:33:52.751935Z", "iopub.status.idle": "2022-12-14T22:33:53.124633Z", "shell.execute_reply": "2022-12-14T22:33:53.123647Z" }, "id": "l5qRjdbBVdU6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def dense_layer(x, w, b):\n", " return add(tf.matmul(x, w), b)\n", "\n", "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "piBhz7gYsHqU" }, "source": [ "`Function` 的执行速度比 Eager 代码快,尤其是对于包含很多简单运算的计算图。但是,对于包含一些复杂运算(如卷积)的计算图,速度提升不会太明显。\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:53.128218Z", "iopub.status.busy": "2022-12-14T22:33:53.127964Z", "iopub.status.idle": "2022-12-14T22:33:54.098762Z", "shell.execute_reply": "2022-12-14T22:33:54.098011Z" }, "id": "zuXt4wRysI03" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager conv: 0.006609583000681596\n", "Function conv: 0.006563433998962864\n", "Note how there's not much difference in performance for convolutions\n" ] } ], "source": [ "import timeit\n", "conv_layer = tf.keras.layers.Conv2D(100, 3)\n", "\n", "@tf.function\n", "def conv_fn(image):\n", " return conv_layer(image)\n", "\n", "image = tf.zeros([1, 200, 200, 100])\n", "# Warm up\n", "conv_layer(image); conv_fn(image)\n", "print(\"Eager conv:\", timeit.timeit(lambda: conv_layer(image), number=10))\n", "print(\"Function conv:\", timeit.timeit(lambda: conv_fn(image), number=10))\n", "print(\"Note how there's not much difference in performance for convolutions\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "uZ4Do2AV80cO" }, "source": [ "### 跟踪\n", "\n", "本部分介绍了 `Function` 的幕后运作方式,包括*未来可能会发生变化*的实现细节。但是,当您了解跟踪的原因和时间后,就能够更轻松高效地使用 `tf.function`!" ] }, { "cell_type": "markdown", "metadata": { "id": "nhpUtRqsXoyM" }, "source": [ "#### 什么是“跟踪”?\n", "\n", "`Function` 在 [TensorFlow 计算图](https://tensorflow.google.cn/guide/intro_to_graphs#what_are_graphs)中运行您的程序。但是,`tf.Graph` 不能代表您在 Eager TensorFlow 程序中编写的全部内容。例如,Python 支持多态,但是 `tf.Graph` 要求其输入具有指定的数据类型和维度。或者,您可能执行辅助任务,例如读取命令行参数、引发错误或使用更复杂的 Python 对象。这些内容均不能在 `tf.Graph` 中运行。\n", "\n", "`Function` 通过将代码分为以下两个阶段填补了这一空缺:\n", "\n", "1. 
第一阶段称为**跟踪**,在这一阶段中,`Function` 会创建新的 `tf.Graph`。Python 代码可以正常运行,但是所有 TensorFlow 运算(例如添加两个张量)都会被*推迟*:它们会被 `tf.Graph` 捕获而不运行。\n", "\n", "2. 在第二阶段中,将运行包含第一阶段中推迟的全部内容的 `tf.Graph`。此阶段比跟踪阶段快得多。\n", "\n", "根据输入,`Function` 在调用时并非总会运行第一阶段。请参阅下方的[跟踪规则](#rules_of_tracing)以更好地了解其决定方式。跳过第一阶段并仅执行第二阶段,可以实现 TensorFlow 的高性能。\n", "\n", "当 `Function` 决定跟踪时,在跟踪阶段完成后会立即运行第二阶段,因此调用 `Function` 会创建并运行 `tf.Graph`。稍后,您将了解如何使用 [`get_concrete_function`](#obtaining_concrete_functions) 来仅运行跟踪阶段。" ] }, { "cell_type": "markdown", "metadata": { "id": "K7scSzLx662f" }, "source": [ "当您将不同类型的参数传递给 `Function` 时,两个阶段都将运行:\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.102822Z", "iopub.status.busy": "2022-12-14T22:33:54.102181Z", "iopub.status.idle": "2022-12-14T22:33:54.155667Z", "shell.execute_reply": "2022-12-14T22:33:54.154939Z" }, "id": "kojmJrgq8U9v" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\n", "tf.Tensor(2, shape=(), dtype=int32)\n", "\n", "Tracing with Tensor(\"a:0\", shape=(), dtype=float32)\n", "tf.Tensor(2.2, shape=(), dtype=float32)\n", "\n", "Tracing with Tensor(\"a:0\", shape=(), dtype=string)\n", "tf.Tensor(b'aa', shape=(), dtype=string)\n", "\n" ] } ], "source": [ "@tf.function\n", "def double(a):\n", " print(\"Tracing with\", a)\n", " return a + a\n", "\n", "print(double(tf.constant(1)))\n", "print()\n", "print(double(tf.constant(1.1)))\n", "print()\n", "print(double(tf.constant(\"a\")))\n", "print()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "QPfouGUQrcNb" }, "source": [ "请注意,如果重复使用同一参数类型调用 `Function`,TensorFlow 会跳过跟踪阶段并重用之前跟踪的计算图,因为后面的调用生成的计算图可能相同。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.158987Z", "iopub.status.busy": "2022-12-14T22:33:54.158389Z", "iopub.status.idle": "2022-12-14T22:33:54.162749Z", "shell.execute_reply": "2022-12-14T22:33:54.162093Z" }, "id": "hFccbWFRrsBp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'bb', shape=(), dtype=string)\n" ] } ], "source": [ "# This doesn't print 'Tracing with ...'\n", "print(double(tf.constant(\"b\")))" ] }, { "cell_type": "markdown", "metadata": { "id": "fgIO_XEzcB9o" }, "source": [ "您可以使用 `pretty_printed_concrete_signatures()` 查看所有可用跟踪记录:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.165957Z", "iopub.status.busy": "2022-12-14T22:33:54.165489Z", "iopub.status.idle": "2022-12-14T22:33:54.169328Z", "shell.execute_reply": "2022-12-14T22:33:54.168699Z" }, "id": "IiQc4IKAb-NX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "double(a)\n", " Args:\n", " a: int32 Tensor, shape=()\n", " Returns:\n", " int32 Tensor, shape=()\n", "\n", "double(a)\n", " Args:\n", " a: float32 Tensor, shape=()\n", " Returns:\n", " float32 Tensor, shape=()\n", "\n", "double(a)\n", " Args:\n", " a: string Tensor, shape=()\n", " Returns:\n", " string Tensor, shape=()\n" ] } ], "source": [ "print(double.pretty_printed_concrete_signatures())" ] }, { "cell_type": "markdown", "metadata": { "id": "rKQ92VEWI7n8" }, "source": [ "目前,您已经了解 `tf.function` 通过 TensorFlow 的计算图跟踪逻辑创建缓存的动态调度层。对于术语的含义,更具体的解释如下:\n", "\n", "- `tf.Graph` 与语言无关,是 TensorFlow 计算的原始可移植表示。\n", "- `ConcreteFunction` 封装 `tf.Graph`。\n", "- `Function` 管理 `ConcreteFunction` 的缓存,并为输入选择正确的缓存。\n", "- `tf.function` 包装 Python 函数,并返回一个 `Function` 对象。\n", 
"- **跟踪**会创建 `tf.Graph` 并将其封装在 `ConcreteFunction` 中,也称为**跟踪**。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "129-iRsPS-gY" }, "source": [ "#### 跟踪规则\n", "\n", "被调用时,`Function` 使用每个参数的 `tf.types.experimental.TraceType` 将调用参数与现有的 `ConcreteFunction` 匹配。如果找到匹配的 `ConcreteFunction`,则将调用分派给它。如果未找到匹配项,则跟踪新的 `ConcreteFunction`。\n", "\n", "如果找到多个匹配项,则会选择最具体的签名。匹配是通过[子类型化](https://en.wikipedia.org/wiki/Subtyping)完成的,就像 C++ 或 Java 中的普通函数调用一样。例如,`TensorShape([1, 2])` 是 `TensorShape([None, None])` 的子类型,因此可以将使用 `TensorShape([1, 2])` 对 tf.function 进行的调用分派到使用 `TensorShape([None, None])` 生成的 `ConcreteFunction`。但是,如果具有 `TensorShape([1, None])` 的 `ConcreteFunction` 也存在,那么它将被优先考虑,因为它更具体。\n", "\n", "`TraceType` 由输入参数确定,具体如下所示:\n", "\n", "- 对于 `Tensor`,类型由 `Tensor` 的 `dtype` 和 `shape` 参数化;有秩形状是无秩形状的子类型;固定维度是未知维度的子类型\n", "- 对于 `Variable`,类型类似于 `Tensor`,但还包括变量的唯一资源 ID,这是正确连接控制依赖项所必需的\n", "- 对于 Python 基元值,类型对应于**值**本身。例如,值为 `3` 的 `TraceType` 是 `LiteralTraceType<3>`,而不是 `int`。\n", "- 对于 `list` 和 `tuple` 等 Python 有序容器,类型是通过其元素的类型来参数化的;例如,`[1, 2]` 的类型是 `ListTraceType, LiteralTraceType<2>>`,`[2, 1]` 的类型是 `ListTraceType, LiteralTraceType<1>>`,两者不同。\n", "- 对于 `dict` 等 Python 映射,类型也是从相同的键到值类型而不是实际值的映射。例如,`{1: 2, 3: 4}` 的类型为 `MappingTraceType<>>, >>>`。但是,与有序容器不同的是,`{1: 2, 3: 4}` 和 `{3: 4, 1: 2}` 具有等价的类型。\n", "- 对于实现 `__tf_tracing_type__` 方法的 Python 对象,类型为该方法返回的任何内容\n", "- 对于任何其他 Python 对象,类型是通用的 `TraceType`,它使用对象的 Python 相等性和散列进行匹配。(注:它依赖于对对象的[弱引用](https://docs.python.org/3/library/weakref.html),因此仅在对象处于范围内/未被删除时才有效。)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GNNN4lgRzpIs" }, "source": [ "注:`TraceType` 基于 `Function` 输入参数,因此仅对全局变量和自由变量进行更改将不会创建新的跟踪记录。有关处理 Python 全局变量和自由变量的建议做法,请参阅[本部分](https://docs.python.org/3/reference/executionmodel.html#binding-of-names)。" ] }, { "cell_type": "markdown", "metadata": { "id": "PEDwbumO32Wh" }, "source": [ "### 控制回溯\n", "\n", "回溯即 `Function` 创建多个跟踪记录的过程,可以确保 TensorFlow 为每组输入生成正确的计算图。但是,跟踪非常消耗资源!如果 `Function` 为每一次调用都回溯新的计算图,您会发现代码的执行速度远不如不使用 `tf.function` 时快。\n", "\n", "要控制跟踪行为,可以采用以下技巧:" ] }, { "cell_type": "markdown", "metadata": { "id": "EUtycWJa34TT" }, "source": [ "#### 将固定的 `input_signature` 传递给 `tf.function`" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.172975Z", "iopub.status.busy": "2022-12-14T22:33:54.172455Z", "iopub.status.idle": "2022-12-14T22:33:54.221083Z", "shell.execute_reply": "2022-12-14T22:33:54.220474Z" }, "id": "_BDMIRmu1RGB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n", "tf.Tensor([4 1], shape=(2,), dtype=int32)\n", "Caught expected exception \n", " :\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/1851403433.py\", line 9, in \n", " next_collatz(tf.constant([[1, 2], [3, 4]]))\n", "ValueError: Python inputs incompatible with input_signature:\n", " inputs: (\n", " tf.Tensor(\n", "[[1 2]\n", " [3 4]], shape=(2, 2), dtype=int32))\n", " input_signature: (\n", " TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n", "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/1851403433.py\", line 13, in \n", " next_collatz(tf.constant([1.0, 2.0]))\n", 
"ValueError: Python inputs incompatible with input_signature:\n", " inputs: (\n", " tf.Tensor([1. 2.], shape=(2,), dtype=float32))\n", " input_signature: (\n", " TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n" ] } ], "source": [ "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n", "def next_collatz(x):\n", " print(\"Tracing with\", x)\n", " return tf.where(x % 2 == 0, x // 2, 3 * x + 1)\n", "\n", "print(next_collatz(tf.constant([1, 2])))\n", "# You specified a 1-D tensor in the input signature, so this should fail.\n", "with assert_raises(ValueError):\n", " next_collatz(tf.constant([[1, 2], [3, 4]]))\n", "\n", "# You specified an int32 dtype in the input signature, so this should fail.\n", "with assert_raises(ValueError):\n", " next_collatz(tf.constant([1.0, 2.0]))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ocxX-HVk7P2o" }, "source": [ "#### 使用未知维度以获得灵活性\n", "\n", "由于 TensorFlow 根据其形状匹配张量,因此,对于可变大小输入,使用 `None` 维度作为通配符可以让 `Function` 重复使用跟踪记录。对于每个批次,如果有不同长度的序列或不同大小的图像,则会出现可变大小输入(请参阅 [Transformer](../tutorials/text/transformer.ipynb) 和 [Deep Dream](../tutorials/generative/deepdream.ipynb) 教程了解示例)。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.224428Z", "iopub.status.busy": "2022-12-14T22:33:54.224017Z", "iopub.status.idle": "2022-12-14T22:33:54.353634Z", "shell.execute_reply": "2022-12-14T22:33:54.352997Z" }, "id": "4Viun7dh7PmF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n", "tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n", "tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)\n" ] } ], "source": [ "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n", "def g(x):\n", " print('Tracing with', x)\n", " return x\n", "\n", "# No retrace!\n", "print(g(tf.constant([1, 2, 3])))\n", "print(g(tf.constant([1, 2, 3, 4, 5])))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "AY5oiQN0XIyA" }, "source": [ "#### 传递张量而不是 Python 文字\n", "\n", "通常,Python 参数用于控制超参数和计算图构造,例如 `num_layers=10`、`training=True` 或 `nonlinearity='relu'`。所以,如果 Python 参数改变,则有必要回溯计算图。\n", "\n", "但是,Python 参数有可能并未用于控制计算图构造。在这些情况下,Python 值的改变可能触发非必要的回溯。例如,在此训练循环中,AutoGraph 会动态展开。尽管有多个跟踪,但生成的计算图实际上是相同的,所以没有必要进行回溯。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.357594Z", "iopub.status.busy": "2022-12-14T22:33:54.356990Z", "iopub.status.idle": "2022-12-14T22:33:54.483057Z", "shell.execute_reply": "2022-12-14T22:33:54.482474Z" }, "id": "uydzR5JYUU8H" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Retracing occurs for different Python arguments.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Tracing with num_steps = 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Executing with num_steps = 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Tracing with num_steps = 20\n", "Executing with num_steps = 20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Traces are reused for Tensor arguments.\n", "Tracing with num_steps = Tensor(\"num_steps:0\", shape=(), dtype=int32)\n", "Executing with num_steps = 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Executing with num_steps = 20\n" ] } ], "source": [ "def train_one_step():\n", " pass\n", "\n", "@tf.function\n", "def train(num_steps):\n", " print(\"Tracing with num_steps = 
\", num_steps)\n", " tf.print(\"Executing with num_steps = \", num_steps)\n", " for _ in tf.range(num_steps):\n", " train_one_step()\n", "\n", "print(\"Retracing occurs for different Python arguments.\")\n", "train(num_steps=10)\n", "train(num_steps=20)\n", "\n", "print()\n", "print(\"Traces are reused for Tensor arguments.\")\n", "train(num_steps=tf.constant(10))\n", "train(num_steps=tf.constant(20))" ] }, { "cell_type": "markdown", "metadata": { "id": "4pJqkDR_Q2wz" }, "source": [ "如果需要强制执行回溯,可以创建一个新的 `Function`。单独的 `Function` 对象肯定不会共享跟踪记录。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.486848Z", "iopub.status.busy": "2022-12-14T22:33:54.486258Z", "iopub.status.idle": "2022-12-14T22:33:54.515780Z", "shell.execute_reply": "2022-12-14T22:33:54.515154Z" }, "id": "uHp4ousu4DdN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "Executing\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "Executing\n" ] } ], "source": [ "def f():\n", " print('Tracing!')\n", " tf.print('Executing')\n", "\n", "tf.function(f)()\n", "tf.function(f)()" ] }, { "cell_type": "markdown", "metadata": { "id": "-tZoWrA6INvc" }, "source": [ "#### 使用跟踪协议\n", "\n", "在可能的情况下,您应当首选将 Python 类型转换为 `tf.experimental.ExtensionType`。此外,`ExtensionType` 的 `TraceType` 是与其关联的 `tf.TypeSpec`。因此,如果需要,您只需重写默认的 `tf.TypeSpec` 即可控制 `ExtensionType` 的 `Tracing Protocol`。请参阅[扩展程序类型](extension_type.ipynb)指南中的*自定义 ExtensionType 的 TypeSpec*部分以了解详情。\n", "\n", "否则,要直接控制 `Function` 何时应针对特定 Python 类型进行重新跟踪,您可以自行为其实现 `Tracing Protocol`。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.519255Z", "iopub.status.busy": "2022-12-14T22:33:54.518673Z", "iopub.status.idle": "2022-12-14T22:33:54.575404Z", "shell.execute_reply": "2022-12-14T22:33:54.574786Z" }, "id": "gZkIh7UaIKc6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def get_mixed_flavor(fruit_a, fruit_b):\n", " return fruit_a.flavor + fruit_b.flavor\n", "\n", "class Fruit:\n", " flavor = tf.constant([0, 0])\n", "\n", "class Apple(Fruit):\n", " flavor = tf.constant([1, 2])\n", "\n", "class Mango(Fruit):\n", " flavor = tf.constant([3, 4])\n", "\n", "# As described in the above rules, a generic TraceType for `Apple` and `Mango`\n", "# is generated (and a corresponding ConcreteFunction is traced) but it fails to \n", "# match the second function call since the first pair of Apple() and Mango() \n", "# have gone out out of scope by then and deleted.\n", "get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function\n", "get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again\n", "\n", "# However, each subclass of the `Fruit` class has a fixed flavor, and you\n", "# can reuse an existing traced concrete function if it was the same\n", "# subclass. 
Avoiding such unnecessary tracing of concrete functions\n", "# can have significant performance benefits.\n", "\n", "class FruitTraceType(tf.types.experimental.TraceType):\n", " def __init__(self, fruit_type):\n", " self.fruit_type = fruit_type\n", "\n", " def is_subtype_of(self, other):\n", " return (type(other) is FruitTraceType and\n", " self.fruit_type is other.fruit_type)\n", "\n", " def most_specific_common_supertype(self, others):\n", " return self if all(self == other for other in others) else None\n", "\n", " def __eq__(self, other):\n", " return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n", " \n", " def __hash__(self):\n", " return hash(self.fruit_type)\n", "\n", "class FruitWithTraceType:\n", "\n", " def __tf_tracing_type__(self, context):\n", " return FruitTraceType(type(self))\n", "\n", "class AppleWithTraceType(FruitWithTraceType):\n", " flavor = tf.constant([1, 2])\n", "\n", "class MangoWithTraceType(FruitWithTraceType):\n", " flavor = tf.constant([3, 4])\n", "\n", "# Now if you try calling it again:\n", "get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function\n", "get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function" ] }, { "cell_type": "markdown", "metadata": { "id": "96IxS2WR37fF" }, "source": [ "### 获取具体函数\n", "\n", "每次跟踪函数时都会创建一个新的具体函数。您可以使用 `get_concrete_function` 直接获取具体函数。\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.578687Z", "iopub.status.busy": "2022-12-14T22:33:54.578216Z", "iopub.status.idle": "2022-12-14T22:33:54.584049Z", "shell.execute_reply": "2022-12-14T22:33:54.583381Z" }, "id": "mHg2CGtPQ3Hz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Obtaining concrete trace\n", "Executing traced function\n", "tf.Tensor(b'aa', shape=(), dtype=string)\n", "tf.Tensor(b'bb', shape=(), dtype=string)\n" ] } ], "source": [ "print(\"Obtaining concrete trace\")\n", "double_strings = double.get_concrete_function(tf.constant(\"a\"))\n", "print(\"Executing traced function\")\n", "print(double_strings(tf.constant(\"a\")))\n", "print(double_strings(a=tf.constant(\"b\")))\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.587218Z", "iopub.status.busy": "2022-12-14T22:33:54.586600Z", "iopub.status.idle": "2022-12-14T22:33:54.591347Z", "shell.execute_reply": "2022-12-14T22:33:54.590762Z" }, "id": "6IVZ-NVf9vsx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'cc', shape=(), dtype=string)\n" ] } ], "source": [ "# You can also call get_concrete_function on an InputSpec\n", "double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))\n", "print(double_strings_from_inputspec(tf.constant(\"c\")))" ] }, { "cell_type": "markdown", "metadata": { "id": "iR4fVmG34xvF" }, "source": [ "打印 `ConcreteFunction` 会显示其输入参数(及类型)和输出类型的摘要。" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.594764Z", "iopub.status.busy": "2022-12-14T22:33:54.594194Z", "iopub.status.idle": "2022-12-14T22:33:54.597590Z", "shell.execute_reply": "2022-12-14T22:33:54.596986Z" }, "id": "o3-JbkIk41r8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConcreteFunction double(a)\n", " Args:\n", " a: string Tensor, shape=()\n", " Returns:\n", " string Tensor, shape=()\n" 
] } ], "source": [ "print(double_strings)" ] }, { "cell_type": "markdown", "metadata": { "id": "QtqfvljZeuOV" }, "source": [ "您也可以直接检索具体函数的签名。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.600889Z", "iopub.status.busy": "2022-12-14T22:33:54.600339Z", "iopub.status.idle": "2022-12-14T22:33:54.603836Z", "shell.execute_reply": "2022-12-14T22:33:54.603218Z" }, "id": "nzbrqFABe0zG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})\n", "Tensor(\"Identity:0\", shape=(), dtype=string)\n" ] } ], "source": [ "print(double_strings.structured_input_signature)\n", "print(double_strings.structured_outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "lar5A_5m5IG1" }, "source": [ "对不兼容的类型使用具体跟踪会引发错误" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.607128Z", "iopub.status.busy": "2022-12-14T22:33:54.606462Z", "iopub.status.idle": "2022-12-14T22:33:54.611025Z", "shell.execute_reply": "2022-12-14T22:33:54.610441Z" }, "id": "G5eeTK-T5KYj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/3196284684.py\", line 2, in \n", " double_strings(tf.constant(1))\n", "tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_166 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_166]\n" ] } ], "source": [ "with assert_raises(tf.errors.InvalidArgumentError):\n", " double_strings(tf.constant(1))" ] }, { "cell_type": "markdown", "metadata": { "id": "st2L9VNQVtSG" }, "source": [ "您可能会注意到,在具体函数的输入签名中对 Python 参数进行了特别处理。TensorFlow 2.3 之前的版本会将 Python 参数直接从具体函数的签名中删除。从 TensorFlow 2.3 开始,Python 参数会保留在签名中,但是会受到约束,只能获取在跟踪期间设置的值。" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.614267Z", "iopub.status.busy": "2022-12-14T22:33:54.613794Z", "iopub.status.idle": "2022-12-14T22:33:54.637693Z", "shell.execute_reply": "2022-12-14T22:33:54.637099Z" }, "id": "U_QyPSGoaC35" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConcreteFunction pow(a, b=2)\n", " Args:\n", " a: float32 Tensor, shape=\n", " Returns:\n", " float32 Tensor, shape=\n" ] } ], "source": [ "@tf.function\n", "def pow(a, b):\n", " return a ** b\n", "\n", "square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)\n", "print(square)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.640896Z", "iopub.status.busy": "2022-12-14T22:33:54.640434Z", "iopub.status.idle": "2022-12-14T22:33:54.651642Z", "shell.execute_reply": "2022-12-14T22:33:54.651024Z" }, "id": "E76vIDhQbXIb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py\", line 1487, in _call_impl\n", " return 
self._call_with_flat_signature(args, kwargs,\n", "  File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py\", line 1532, in _call_with_flat_signature\n", "    raise TypeError(f\"{self._flat_signature_summary()} got unexpected \"\n", "TypeError: pow(a) got unexpected keyword arguments: b.\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", "  File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", "    yield\n", "  File \"/tmpfs/tmp/ipykernel_176735/2310937119.py\", line 4, in \n", "    square(tf.constant(10.0), b=3)\n", "TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.\n" ] } ], "source": [ "assert square(tf.constant(10.0)) == 100\n", "\n", "with assert_raises(TypeError):\n", "  square(tf.constant(10.0), b=3)" ] },
{ "cell_type": "markdown", "metadata": { "id": "41gJh_JGIfuA" }, "source": [ "### 获取计算图\n", "\n", "每个具体函数都是 `tf.Graph` 的可调用包装器。虽然一般不需要检索实际 `tf.Graph` 对象,不过,您可以从任何具体函数轻松获得实际对象。" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.654558Z", "iopub.status.busy": "2022-12-14T22:33:54.654355Z", "iopub.status.idle": "2022-12-14T22:33:54.658037Z", "shell.execute_reply": "2022-12-14T22:33:54.657493Z" }, "id": "5UENeGHfaX8g" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[] -> a\n", "['a', 'a'] -> add\n", "['add'] -> Identity\n" ] } ], "source": [ "graph = double_strings.graph\n", "for node in graph.as_graph_def().node:\n", "  print(f'{node.input} -> {node.name}')\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "aIKkgr6qdtp4" }, "source": [ "### 调试\n", "\n", "通常,在 Eager 模式下调试代码比在 `tf.function` 中简单。在使用 `tf.function` 进行装饰之前,您应该先确保代码可在 Eager 模式下无错误执行。为了帮助调试,您可以调用 `tf.config.run_functions_eagerly(True)` 来全局停用和重新启用 `tf.function`。\n", "\n", "追溯仅在 `tf.function` 中出现的问题时,可参考下面的几点提示:\n", "\n", "- 普通的 Python `print` 调用仅在跟踪期间执行,可以帮助您追溯函数在何时被(重新)跟踪。\n", "- `tf.print` 调用每次都会执行,可用于追溯执行过程中产生的中间值。\n", "- 利用 `tf.debugging.enable_check_numerics` 很容易追溯到 NaN 和 Inf 在何处创建。\n", "- `pdb`([Python 调试器](https://docs.python.org/3/library/pdb.html))可以帮助您理解跟踪的详细过程。(提醒:使用 `pdb` 调试时,AutoGraph 会自动转换 Python 源代码。)" ] },
{ "cell_type": "markdown", "metadata": { "id": "5f05Vr_YBUCz" }, "source": [ "## AutoGraph 转换\n", "\n", "AutoGraph 是一个库,在 `tf.function` 中默认处于启用状态。它可以将 Python Eager 代码的子集转换为与计算图兼容的 TensorFlow 运算。这包括 `if`、`for`、`while` 等控制流。\n", "\n", "`tf.cond` 和 `tf.while_loop` 等 TensorFlow 运算仍然可以运行,但是使用 Python 编写时,控制流通常更易于编写,代码也更易于理解。" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.661694Z", "iopub.status.busy": "2022-12-14T22:33:54.661130Z", "iopub.status.idle": "2022-12-14T22:33:54.767768Z", "shell.execute_reply": "2022-12-14T22:33:54.767099Z" }, "id": "yCQTtTPTW3WF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.682211161 0.396621943 0.451262951 0.643357158 0.87304759]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.592955 0.37705484 0.422936589 0.567181051 0.702919185]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.532017589 0.360147029 0.399401426 0.513286 0.606217444]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.486921817 0.3453435 0.379436702 0.472501546 0.541458964]\n" ] }, { "name": "stdout", "output_type": "stream", "text": 
[ "[0.451769888 0.332239449 0.362218171 0.4402183 0.49409157]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.423352748 0.320531547 0.347166359 0.413825333 0.45745784]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.399751157 0.309987456 0.333860159 0.391715884 0.428010017]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.379735976 0.300425678 0.321985 0.372838497 0.40365687]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.362478137 0.291702092 0.311300635 0.356472 0.383073539]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.347394943 0.283700645 0.301619858 0.342102677 0.365373671]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.334063202 0.276326627 0.292794317 0.329353303 0.349938482]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.322167 0.269501895 0.284704626 0.31793955 0.336320966]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.311465025 0.263161272 0.277253687 0.307642668 0.324188948]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.301769286 0.257249981 0.270361394 0.298291 0.313289642]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.292930901 0.251721531 0.263961077 0.289747804 0.303426832]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.284830183 0.24653624 0.257996708 0.281902701 0.294445485]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.277369589 0.241659909 0.252420813 0.274665147 0.286221236]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.270468801 0.237062961 0.247192904 0.26796037 0.278653115]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.264061 0.23271966 0.242278129 0.261725903 0.271658033]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.25809 0.228607446 0.237646371 0.255909115 0.265166938]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.252508163 0.224706501 0.23327139 0.250465214 0.259121925]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.24727492 0.22099933 0.229130313 0.245355889 0.253474057]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.242355347 0.217470333 0.225202918 0.240548164 0.248181522]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.237719223 0.214105651 0.221471429 0.236013427 0.243208483]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.233340293 0.210892901 0.21792005 0.231726721 0.23852399]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.229195595 0.207821 0.214534715 0.227666199 0.234101087]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.225264892 0.20487988 0.211302832 0.22381258 0.229916304]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.221530348 0.202060521 0.208213195 0.220148876 0.225948915]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.217976168 0.199354753 0.205255613 0.216659933 0.222180739]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.214588255 0.196755111 0.20242089 0.213332266 0.218595564]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.211354017 0.19425483 0.199700788 0.210153803 0.215179041]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.208262146 0.191847727 0.197087735 0.207113713 0.211918324]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.205302477 0.189528167 0.194574893 0.204202175 0.20880191]\n" 
] }, { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A simple loop\n", "\n", "@tf.function\n", "def f(x):\n", " while tf.reduce_sum(x) > 1:\n", " tf.print(x)\n", " x = tf.tanh(x)\n", " return x\n", "\n", "f(tf.random.uniform([5]))" ] }, { "cell_type": "markdown", "metadata": { "id": "KxwJ8znPI0Cg" }, "source": [ "如果您有兴趣,可以检查 Autograph 生成的代码。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.770977Z", "iopub.status.busy": "2022-12-14T22:33:54.770512Z", "iopub.status.idle": "2022-12-14T22:33:54.775213Z", "shell.execute_reply": "2022-12-14T22:33:54.774659Z" }, "id": "jlQD1ffRXJhl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def tf__f(x):\n", " with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n", " do_return = False\n", " retval_ = ag__.UndefinedReturnValue()\n", "\n", " def get_state():\n", " return (x,)\n", "\n", " def set_state(vars_):\n", " nonlocal x\n", " (x,) = vars_\n", "\n", " def loop_body():\n", " nonlocal x\n", " ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)\n", " x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)\n", "\n", " def loop_test():\n", " return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1\n", " ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})\n", " try:\n", " do_return = True\n", " retval_ = ag__.ld(x)\n", " except:\n", " do_return = False\n", " raise\n", " return fscope.ret(retval_, do_return)\n", "\n" ] } ], "source": [ "print(tf.autograph.to_code(f.python_function))" ] }, { "cell_type": "markdown", "metadata": { "id": "xgKmkrNTZSyz" }, "source": [ "### 条件语句\n", "\n", "AutoGraph 会将某些 `if ` 语句转换为等效的 `tf.cond` 调用。如果 `` 是张量,则会执行这种替换,否则会将 `if` 语句作为 Python 条件语句执行。\n", "\n", "Python 条件语句在跟踪时执行,因此会将该条件语句的一个分支添加到计算图。如果不使用 AutoGraph,当存在依赖于数据的控制流时,此跟踪计算图将无法选择替代分支。\n", "\n", "`tf.cond` 跟踪并将条件的两个分支添加到计算图,在执行时动态选择分支。跟踪可能产生意外的副作用;请参阅 [AutoGraph 跟踪作用](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process)以了解详情。" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.778451Z", "iopub.status.busy": "2022-12-14T22:33:54.777908Z", "iopub.status.idle": "2022-12-14T22:33:54.977784Z", "shell.execute_reply": "2022-12-14T22:33:54.976938Z" }, "id": "BOQl8PMq2Sf3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing for loop\n", "Tracing fizzbuzz branch\n", "Tracing fizz branch\n", "Tracing buzz branch\n", "Tracing default branch\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "11\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "13\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "14\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizzbuzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "16\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "17\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "19\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] } ], "source": [ "@tf.function\n", "def fizzbuzz(n):\n", " for i in tf.range(1, n + 1):\n", " print('Tracing for loop')\n", " if i % 15 == 0:\n", " print('Tracing fizzbuzz branch')\n", " tf.print('fizzbuzz')\n", " elif i % 3 == 0:\n", " print('Tracing fizz branch')\n", " tf.print('fizz')\n", " elif i % 5 == 0:\n", " print('Tracing buzz branch')\n", " tf.print('buzz')\n", " else:\n", " print('Tracing default branch')\n", " tf.print(i)\n", "\n", "fizzbuzz(tf.constant(5))\n", "fizzbuzz(tf.constant(20))" ] }, { "cell_type": "markdown", "metadata": { "id": "4rBO5AQ15HVC" }, "source": [ "有关 AutoGraph 转换的 if 语句的其他限制,请参阅[参考文档](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements)。" ] }, { "cell_type": "markdown", "metadata": { "id": "yho4J0a0ZkQS" }, "source": [ "### 循环\n", "\n", "AutoGraph 会将某些 `for` 和 `while` 语句转换为等效的 TensorFlow 循环运算,例如 `tf.while_loop`。如果不转换,则会将 `for` 或 `while` 循环作为 Python 循环执行。\n", "\n", "以下情形会执行这种替换:\n", "\n", "- `for x in y`:如果 `y` 是一个张量,则转换为 `tf.while_loop`。在特殊情况下,如果 `y` 是 `tf.data.Dataset`,则会生成 `tf.data.Dataset` 运算的组合。\n", "- `while `:如果 `` 是张量,则转换为 `tf.while_loop`。\n", "\n", "Python 循环在跟踪时执行,因而循环每迭代一次,都会将额外的运算添加到 `tf.Graph`。\n", "\n", "TensorFlow 循环会跟踪循环体,并在执行时动态选择迭代的运行次数。循环体仅在生成的 `tf.Graph` 中出现一次。\n", "\n", "有关 AutoGraph 转换的 `for` 和 `while` 语句的其他限制,请参阅[参考文档](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements)。" ] }, { "cell_type": "markdown", "metadata": { "id": "sp4rbIdfbM6s" }, "source": [ "#### 在 Python 数据上循环\n", "\n", "一个常见陷阱是在 `tf.function` 中的 Python/Numpy 数据上循环。此循环在跟踪过程中执行,因而循环每迭代一次,都会将模型的一个副本添加到 `tf.Graph`。\n", "\n", "如果要在 `tf.function` 中包装整个训练循环,最安全的方法是将数据包装为 `tf.data.Dataset`,以便 AutoGraph 动态展开训练循环。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:54.981886Z", "iopub.status.busy": "2022-12-14T22:33:54.981234Z", "iopub.status.idle": "2022-12-14T22:33:55.116667Z", "shell.execute_reply": "2022-12-14T22:33:55.115867Z" }, "id": "WGZ19LspbZ27" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph\n", "train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train(, dtype=tf.int32, name=None), TensorSpec(shape=, dtype=tf.int32, name=None))>) contains 6 nodes in its 
graph\n", "train(, dtype=tf.int32, name=None), TensorSpec(shape=, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n" ] } ], "source": [ "def measure_graph_size(f, *args):\n", " g = f.get_concrete_function(*args).graph\n", " print(\"{}({}) contains {} nodes in its graph\".format(\n", " f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n", "\n", "@tf.function\n", "def train(dataset):\n", " loss = tf.constant(0)\n", " for x, y in dataset:\n", " loss += tf.abs(y - x) # Some dummy computation.\n", " return loss\n", "\n", "small_data = [(1, 1)] * 3\n", "big_data = [(1, 1)] * 10\n", "measure_graph_size(train, small_data)\n", "measure_graph_size(train, big_data)\n", "\n", "measure_graph_size(train, tf.data.Dataset.from_generator(\n", " lambda: small_data, (tf.int32, tf.int32)))\n", "measure_graph_size(train, tf.data.Dataset.from_generator(\n", " lambda: big_data, (tf.int32, tf.int32)))" ] }, { "cell_type": "markdown", "metadata": { "id": "JeD2U-yrbfVb" }, "source": [ "在数据集中封装 Python/Numpy 数据时,要注意 `tf.data.Dataset.from_generator` 与 ` tf.data.Dataset.from_tensors`。前者将数据保留在 Python 中,并通过 `tf.py_function` 获取,这可能会影响性能;后者将数据的副本捆绑成计算图中的一个大 `tf.constant()` 节点,这可能会消耗较多内存。\n", "\n", "通过 `TFRecordDataset`、`CsvDataset` 等从文件中读取数据是最高效的数据使用方式,因为这样 TensorFlow 就可以自行管理数据的异步加载和预提取,不必利用 Python。要了解详细信息,请参阅 [`tf.data`:构建 TensorFlow 输入流水线](../../guide/data)指南。" ] }, { "cell_type": "markdown", "metadata": { "id": "hyksHW9TCukR" }, "source": [ "#### 累加循环值\n", "\n", "一种常见模式是不断累加循环的中间值。通常,这可以通过将元素追加到 Python 列表或将条目添加到 Python 字典来实现。但是,由于存在 Python 副作用,在动态展开循环中,这些方法无法达到预期效果。要从动态展开循环累加结果,可以使用 `tf.TensorArray` 来实现。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.120229Z", "iopub.status.busy": "2022-12-14T22:33:55.119943Z", "iopub.status.idle": "2022-12-14T22:33:55.261054Z", "shell.execute_reply": "2022-12-14T22:33:55.260194Z" }, "id": "HJ3Vb3dXfefN" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 2\n", "seq_len = 3\n", "feature_size = 4\n", "\n", "def rnn_step(inp, state):\n", " return inp + state\n", "\n", "@tf.function\n", "def dynamic_rnn(rnn_step, input_data, initial_state):\n", " # [batch, time, features] -> [time, batch, features]\n", " input_data = tf.transpose(input_data, [1, 0, 2])\n", " max_seq_len = input_data.shape[0]\n", "\n", " states = tf.TensorArray(tf.float32, size=max_seq_len)\n", " state = initial_state\n", " for i in tf.range(max_seq_len):\n", " state = rnn_step(input_data[i], state)\n", " states = states.write(i, state)\n", " return tf.transpose(states.stack(), [1, 0, 2])\n", "\n", "dynamic_rnn(rnn_step,\n", " tf.random.uniform([batch_size, seq_len, feature_size]),\n", " tf.zeros([batch_size, feature_size]))" ] }, { "cell_type": "markdown", "metadata": { "id": "i2MVoIVaNApG" }, "source": [ "## 限制\n", "\n", "TensorFlow `Function` 有意设计了一些限制,在将 Python 函数转换为 `Function` 时需加以注意。" ] }, { "cell_type": "markdown", "metadata": { "id": "EJqHGFSVLIKl" }, "source": [ "### 执行 Python 副作用\n", "\n", "副作用(如打印、附加到列表、改变全局变量)在 `Function` 内部可能会出现异常行为,有时会执行两次或完全无法执行。它们只会在您第一次使用一组输入调用 `Function` 时发生。之后,将重新执行跟踪的 `tf.Graph`,而不执行 Python 代码。\n", "\n", "一般经验法则是避免在逻辑中依赖 Python 副作用,而仅使用它们来调试跟踪记录。否则,TensorFlow API(例如 `tf.data`、`tf.print`、`tf.summary`、`tf.Variable.assign` 和 `tf.TensorArray`)是确保在每次调用时 TensorFlow 运行时都能执行您的代码的最佳方式。" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": 
"2022-12-14T22:33:55.264407Z", "iopub.status.busy": "2022-12-14T22:33:55.263946Z", "iopub.status.idle": "2022-12-14T22:33:55.298465Z", "shell.execute_reply": "2022-12-14T22:33:55.297686Z" }, "id": "w2sACuZ9TTRk" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Traced with 1\n", "Executed with 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Executed with 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Traced with 2\n", "Executed with 2\n" ] } ], "source": [ "@tf.function\n", "def f(x):\n", " print(\"Traced with\", x)\n", " tf.print(\"Executed with\", x)\n", "\n", "f(1)\n", "f(1)\n", "f(2)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "e1I0dPiqTV8H" }, "source": [ "如果希望在每次调用 `Function` 时都执行 Python 代码,`tf.py_function` 可以作为退出点。`tf.py_function` 的缺点是不可移植,性能不高,无法使用 SavedModel 保存并且在分布式(多 GPU、TPU)设置中效果不佳。另外,由于 `tf.py_function` 必须连接到计算图中,它会将所有输入/输出转换为张量。" ] }, { "cell_type": "markdown", "metadata": { "id": "bOW1v9WVKGgH" }, "source": [ "#### 更改 Python 全局变量和自由变量\n", "\n", "更改 Python 全局变量和[自由变量](https://docs.python.org/3/reference/executionmodel.html#binding-of-names)视为 Python 副作用,因此仅在跟踪期间发生。\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.301987Z", "iopub.status.busy": "2022-12-14T22:33:55.301675Z", "iopub.status.idle": "2022-12-14T22:33:55.325428Z", "shell.execute_reply": "2022-12-14T22:33:55.324655Z" }, "id": "7aJD--9qTWmg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python side effect\n" ] } ], "source": [ "external_list = []\n", "\n", "@tf.function\n", "def side_effect(x):\n", " print('Python side effect')\n", " external_list.append(x)\n", "\n", "side_effect(1)\n", "side_effect(1)\n", "side_effect(1)\n", "# The list append only happened once!\n", "assert len(external_list) == 1" ] }, { "cell_type": "markdown", "metadata": { "id": "5eZTFRv_k_nR" }, "source": [ "有时很难注意到意外行为。在下面的示例中,`counter` 旨在保护变量的增量。然而,由于它是一个 Python 整数而不是 TensorFlow 对象,它的值在第一次跟踪期间被捕获。使用 `tf.function` 时,`assign_add` 将被无条件记录在底层计算图中。因此,每次调用 `tf.function` 时 `v` 都会增加 1。当使用 Python 副作用(示例中的 `counter`)确定要运行的运算(示例中的 `assign_add`)时,此问题在尝试使用 `tf.function` 装饰器将其计算图模式 Tensorflow 代码迁移到 Tensorflow 2 的用户中十分常见。通常,用户只有在看到可疑的数值结果或明显低于预期的性能(例如,如果受保护运算的开销非常大)后才会意识到这一点。" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.328676Z", "iopub.status.busy": "2022-12-14T22:33:55.328084Z", "iopub.status.idle": "2022-12-14T22:33:55.374956Z", "shell.execute_reply": "2022-12-14T22:33:55.374125Z" }, "id": "5r6p7-9jk_3L" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "2\n", "3\n" ] } ], "source": [ "class Model(tf.Module):\n", " def __init__(self):\n", " self.v = tf.Variable(0)\n", " self.counter = 0\n", "\n", " @tf.function\n", " def __call__(self):\n", " if self.counter == 0:\n", " # A python side-effect\n", " self.counter += 1\n", " self.v.assign_add(1)\n", "\n", " return self.v\n", "\n", "m = Model()\n", "for n in range(3):\n", " print(m().numpy()) # prints 1, 2, 3" ] }, { "cell_type": "markdown", "metadata": { "id": "tXCTcHoVcxhX" }, "source": [ "实现预期行为的一种解决方法是使用 [`tf.init_scope`](https://tensorflow.google.cn/api_docs/python/tf/init_scope) 将运算提升到函数计算图以外。这样可以确保变量增量在跟踪期间只执行一次。应当注意的是,`init_scope` 还有其他副作用,包括清除控制流和梯度带。有时 `init_scope` 的使用会变得过于复杂而无法实际管理。" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.378434Z", 
"iopub.status.busy": "2022-12-14T22:33:55.377841Z", "iopub.status.idle": "2022-12-14T22:33:55.425119Z", "shell.execute_reply": "2022-12-14T22:33:55.424358Z" }, "id": "An4MrIbrcvi8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "1\n", "1\n" ] } ], "source": [ "class Model(tf.Module):\n", " def __init__(self):\n", " self.v = tf.Variable(0)\n", " self.counter = 0\n", "\n", " @tf.function\n", " def __call__(self):\n", " if self.counter == 0:\n", " # Lifts ops out of function-building graphs\n", " with tf.init_scope():\n", " self.counter += 1\n", " self.v.assign_add(1)\n", "\n", " return self.v\n", "\n", "m = Model()\n", "for n in range(3):\n", " print(m().numpy()) # prints 1, 1, 1" ] }, { "cell_type": "markdown", "metadata": { "id": "pbFG5CX4LwQA" }, "source": [ "总之,根据经验,您应避免改变整数或容器(如位于 `Function` 外部的列表)等 Python 对象,而应使用参数和 TF 对象。例如,[在循环中累加值](#accumulating_values_in_a_loop)部分中提供了一个如何实现类列表运算的示例。\n", "\n", "在某些情况下,如果为 [`tf.Variable`](https://tensorflow.google.cn/guide/variable),则您可以捕获和处理状态。这是通过重复调用相同的 `ConcreteFunction` 来更新 Keras 模型权重的方式。" ] }, { "cell_type": "markdown", "metadata": { "id": "X_oNNGrAqPJ1" }, "source": [ "#### 使用 Python 迭代器和生成器" ] }, { "cell_type": "markdown", "metadata": { "id": "msTmv-oyUNaf" }, "source": [ "很多 Python 功能(如生成器和迭代器)依赖 Python 运行时来跟踪状态。通常,虽然这些构造在 Eager 模式下可以正常工作,但它们是 Python 副作用的示例,因此仅在跟踪期间发生。" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.428858Z", "iopub.status.busy": "2022-12-14T22:33:55.428245Z", "iopub.status.idle": "2022-12-14T22:33:55.456458Z", "shell.execute_reply": "2022-12-14T22:33:55.455847Z" }, "id": "FNPD4unZUedH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] } ], "source": [ "@tf.function\n", "def buggy_consume_next(iterator):\n", " tf.print(\"Value:\", next(iterator))\n", "\n", "iterator = iter([1, 2, 3])\n", "buggy_consume_next(iterator)\n", "# This reuses the first value from the iterator, rather than consuming the next value.\n", "buggy_consume_next(iterator)\n", "buggy_consume_next(iterator)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "wcS3TAgCjTWR" }, "source": [ "就像 TensorFlow 具有用于列表构造的专用 `tf.TensorArray` 一样,它也具有用于迭代构造的专用 `tf.data.Iterator`。有关概述,请参阅 [AutoGraph 转换](#autograph_transformations)部分。此外,[`tf.data`](https://tensorflow.google.cn/guide/data) API 也可帮助实现生成器模式:\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.459883Z", "iopub.status.busy": "2022-12-14T22:33:55.459339Z", "iopub.status.idle": "2022-12-14T22:33:55.498504Z", "shell.execute_reply": "2022-12-14T22:33:55.497863Z" }, "id": "8D_iKetXW6VE" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 3\n" ] } ], "source": [ "@tf.function\n", "def good_consume_next(iterator):\n", " # This is ok, iterator is a tf.data.Iterator\n", " tf.print(\"Value:\", next(iterator))\n", "\n", "ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n", "iterator = iter(ds)\n", "good_consume_next(iterator)\n", "good_consume_next(iterator)\n", "good_consume_next(iterator)" ] }, { "cell_type": "markdown", "metadata": { "id": "i8YAMYb6KEh4" }, "source": [ "### tf.function 
的所有输出都必须是返回值\n", "\n", "除了 `tf.Variable` 外,一个 tf.function 必须返回其所有输出。尝试直接从函数访问任何张量而不遍历返回值会导致“泄漏”。\n", "\n", "例如,下面的函数通过 Python 全局变量 `x`“泄漏”张量 `a`:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.502054Z", "iopub.status.busy": "2022-12-14T22:33:55.501792Z", "iopub.status.idle": "2022-12-14T22:33:55.533576Z", "shell.execute_reply": "2022-12-14T22:33:55.533005Z" }, "id": "zrdp4rjxg6jo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n", "'Tensor' object has no attribute 'numpy'\n" ] } ], "source": [ "x = None\n", "\n", "@tf.function\n", "def leaky_function(a):\n", " global x\n", " x = a + 1 # Bad - leaks local tensor\n", " return a + 2\n", "\n", "correct_a = leaky_function(tf.constant(1))\n", "\n", "print(correct_a.numpy()) # Good - value obtained from function's returns\n", "try:\n", " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n", "except AttributeError as expected:\n", " print(expected)" ] }, { "cell_type": "markdown", "metadata": { "id": "-d4_J_DC5rxX" }, "source": [ "即使同时返回泄漏的值时也是如此:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.536863Z", "iopub.status.busy": "2022-12-14T22:33:55.536299Z", "iopub.status.idle": "2022-12-14T22:33:55.815261Z", "shell.execute_reply": "2022-12-14T22:33:55.814571Z" }, "id": "PrcpPB8C5s9T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2\n", "'Tensor' object has no attribute 'numpy'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/566849597.py\", line 21, in \n", " captures_leaked_tensor(tf.constant(2))\n", "TypeError: is out of scope and cannot be used here. 
Use return values, explicit Python locals or TensorFlow collections to access it.\n", "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.\n", "\n", " was defined here:\n", " File \"/usr/lib/python3.9/runpy.py\", line 197, in _run_module_as_main\n", " return _run_code(code, main_globals, None,\n", " File \"/usr/lib/python3.9/runpy.py\", line 87, in _run_code\n", " exec(code, run_globals)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py\", line 17, in \n", " app.launch_new_instance()\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py\", line 992, in launch_instance\n", " app.start()\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py\", line 711, in start\n", " self.io_loop.start()\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py\", line 215, in start\n", " self.asyncio_loop.run_forever()\n", " File \"/usr/lib/python3.9/asyncio/base_events.py\", line 601, in run_forever\n", " self._run_once()\n", " File \"/usr/lib/python3.9/asyncio/base_events.py\", line 1905, in _run_once\n", " handle._run()\n", " File \"/usr/lib/python3.9/asyncio/events.py\", line 80, in _run\n", " self._context.run(self._callback, *self._args)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 510, in dispatch_queue\n", " await self.process_one()\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 499, in process_one\n", " await dispatch(*args)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 406, in dispatch_shell\n", " await result\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 729, in execute_request\n", " reply_content = await reply_content\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py\", line 411, in do_execute\n", " res = shell.run_cell(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py\", line 531, in run_cell\n", " return super().run_cell(*args, **kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 2940, in run_cell\n", " result = self._run_cell(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 2995, in _run_cell\n", " return runner(coro)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n", " coro.send(None)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3194, in run_cell_async\n", " has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3373, in run_ast_nodes\n", " if await self.run_code(code, result, async_=asy):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3433, in run_code\n", " exec(code_obj, self.user_global_ns, self.user_ns)\n", " File \"/tmpfs/tmp/ipykernel_176735/566849597.py\", line 7, in \n", " correct_a = leaky_function(tf.constant(1))\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " return fn(*args, 
**kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 880, in __call__\n", " result = self._call(*args, **kwds)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 928, in _call\n", " self._initialize(args, kwds, add_initializers_to=initializers)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 749, in _initialize\n", " self._variable_creation_fn # pylint: disable=protected-access\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py\", line 162, in _get_concrete_function_internal_garbage_collected\n", " concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py\", line 157, in _maybe_define_concrete_function\n", " return self._maybe_define_function(args, kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py\", line 360, in _maybe_define_function\n", " concrete_function = self._create_concrete_function(args, kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py\", line 284, in _create_concrete_function\n", " func_graph_module.func_graph_from_py_func(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 1283, in func_graph_from_py_func\n", " func_outputs = python_func(*func_args, **func_kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 645, in wrapped_fn\n", " out = weak_wrapped_fn().__wrapped__(*args, **kwds)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 1258, in autograph_handler\n", " return autograph.converted_call(\n", " File \"/tmpfs/tmp/ipykernel_176735/566849597.py\", line 4, in leaky_function\n", " x = a + 1 # Bad - leaks local tensor\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " return fn(*args, **kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\", line 1407, in binary_op_wrapper\n", " return func(x, y, name=name)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " return fn(*args, **kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py\", line 1176, in op_dispatch_handler\n", " return dispatch_target(*args, **kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\", line 1757, in _add_dispatch\n", " return gen_math_ops.add_v2(x, y, name=name)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py\", line 475, in add_v2\n", " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py\", line 795, in 
_apply_op_helper\n", " op = g._create_op_internal(op_type_name, inputs, dtypes=None,\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 749, in _create_op_internal\n", " return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py\", line 3798, in _create_op_internal\n", " ret = Operation(\n", "\n", "The tensor cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139917558711056), which is out of scope.\n" ] } ], "source": [ "@tf.function\n", "def leaky_function(a):\n", " global x\n", " x = a + 1 # Bad - leaks local tensor\n", " return x # Good - uses local tensor\n", "\n", "correct_a = leaky_function(tf.constant(1))\n", "\n", "print(correct_a.numpy()) # Good - value obtained from function's returns\n", "try:\n", " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n", "except AttributeError as expected:\n", " print(expected)\n", "\n", "@tf.function\n", "def captures_leaked_tensor(b):\n", " b += x # Bad - `x` is leaked from `leaky_function`\n", " return b\n", "\n", "with assert_raises(TypeError):\n", " captures_leaked_tensor(tf.constant(2))" ] }, { "cell_type": "markdown", "metadata": { "id": "Sm2ghjyy50D4" }, "source": [ "通常,当您使用 Python 语句或数据结构时,会发生此类泄漏。除了泄漏不可访问的张量之外,此类语句也可能是错误的,因为它们被视为 Python 副作用,而且不能保证在每次函数调用时都执行。\n", "\n", "泄漏局部张量的常见方法还包括改变外部 Python 集合或对象:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.818749Z", "iopub.status.busy": "2022-12-14T22:33:55.818273Z", "iopub.status.idle": "2022-12-14T22:33:55.822545Z", "shell.execute_reply": "2022-12-14T22:33:55.821935Z" }, "id": "D7bLe8y652wU" }, "outputs": [], "source": [ "class MyClass:\n", "\n", " def __init__(self):\n", " self.field = None\n", "\n", "external_list = []\n", "external_object = MyClass()\n", "\n", "def leaky_function():\n", " a = tf.constant(1)\n", " external_list.append(a) # Bad - leaks tensor\n", " external_object.field = a # Bad - leaks tensor" ] }, { "cell_type": "markdown", "metadata": { "id": "g-XVQcD-wf5K" }, "source": [ "### 不支持递归 tf.functions\n", "\n", "不支持递归 `Function`,它们可能导致无限循环。例如:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:55.825959Z", "iopub.status.busy": "2022-12-14T22:33:55.825398Z", "iopub.status.idle": "2022-12-14T22:33:56.646068Z", "shell.execute_reply": "2022-12-14T22:33:56.645377Z" }, "id": "QSN-T1m5EFcR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 9, in \n", " recursive_fn(tf.constant(5)) # Bad - maximum recursion error.\n", "tensorflow.python.autograph.impl.api.StagingError: in user code:\n", "\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_176735/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/usr/lib/python3.9/abc.py\", line 119, in __instancecheck__\n", " return _abc_instancecheck(cls, instance)\n", " File \"/usr/lib/python3.9/abc.py\", line 123, in __subclasscheck__\n", " return _abc_subclasscheck(cls, subclass)\n", "\n", " RecursionError: maximum recursion depth exceeded while calling a Python object\n", "\n" ] } ], "source": [ "@tf.function\n", "def recursive_fn(n):\n", " if n > 0:\n", " return recursive_fn(n - 1)\n", " else:\n", " return 1\n", "\n", "with assert_raises(Exception):\n", " recursive_fn(tf.constant(5)) # Bad - maximum recursion error." 
] }, { "cell_type": "markdown", "metadata": { "id": "LyRyooKGUxNV" }, "source": [ "即使递归 `Function` 看似有效,Python 函数也会被多次跟踪,并且可能会对性能产生影响。例如:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.650889Z", "iopub.status.busy": "2022-12-14T22:33:56.650231Z", "iopub.status.idle": "2022-12-14T22:33:56.709335Z", "shell.execute_reply": "2022-12-14T22:33:56.708708Z" }, "id": "7FlmTqfMUwmT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tracing\n", "tracing\n", "tracing\n", "tracing\n", "tracing\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def recursive_fn(n):\n", " if n > 0:\n", " print('tracing')\n", " return recursive_fn(n - 1)\n", " else:\n", " return 1\n", "\n", "recursive_fn(5) # Warning - multiple tracings" ] }, { "cell_type": "markdown", "metadata": { "id": "-D6nh3QirXAd" }, "source": [ "## 已知问题\n", "\n", "如果您的 `Function` 评估不正确,则这些计划于将来得到修复的已知问题可能可以解释该问题。" ] }, { "cell_type": "markdown", "metadata": { "id": "ZoPg5w1Pjqna" }, "source": [ "### 取决于 Python 全局变量和自由变量\n", "\n", "当使用 Python 参数的新值进行调用时,`Function` 会创建新的 `ConcreteFunction`。但是,对于该 `Function` 的 Python 闭包、全局变量或非局部变量,则不会创建。如果它们的值在调用 `Function` 之间发生变化,则 `Function` 仍将使用其在跟踪时所具有的值。这与常规 Python 函数的工作方式不同。\n", "\n", "因此,您应采用使用参数的函数式编程风格而非闭合外部名称。" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.712893Z", "iopub.status.busy": "2022-12-14T22:33:56.712419Z", "iopub.status.idle": "2022-12-14T22:33:56.759741Z", "shell.execute_reply": "2022-12-14T22:33:56.759041Z" }, "id": "oeJMdXd3M0cM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Buggy: tf.Tensor(2, shape=(), dtype=int32)\n", "Correct: tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def buggy_add():\n", " return 1 + foo\n", "\n", "@tf.function\n", "def recommended_add(foo):\n", " return 1 + foo\n", "\n", "foo = 1\n", "print(\"Buggy:\", buggy_add())\n", "print(\"Correct:\", recommended_add(foo))" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.762975Z", "iopub.status.busy": "2022-12-14T22:33:56.762512Z", "iopub.status.idle": "2022-12-14T22:33:56.771889Z", "shell.execute_reply": "2022-12-14T22:33:56.771277Z" }, "id": "L3q7sUJWZOSU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updating the value of `foo` to 100!\n", "Buggy: tf.Tensor(2, shape=(), dtype=int32)\n", "Correct: tf.Tensor(101, shape=(), dtype=int32)\n" ] } ], "source": [ "print(\"Updating the value of `foo` to 100!\")\n", "foo = 100\n", "print(\"Buggy:\", buggy_add()) # Did not change!\n", "print(\"Correct:\", recommended_add(foo))" ] }, { "cell_type": "markdown", "metadata": { "id": "ZoPg5w1Pjqnb" }, "source": [ "更新全局值的另一种方法是使其成为 `tf.Variable` 并改用 `Variable.assign` 方法。\n" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.775080Z", "iopub.status.busy": "2022-12-14T22:33:56.774520Z", "iopub.status.idle": "2022-12-14T22:33:56.803620Z", "shell.execute_reply": "2022-12-14T22:33:56.802964Z" }, "id": "oeJMdXd3M0cc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable: tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def variable_add():\n", " return 1 + foo\n", "\n", "foo = tf.Variable(1)\n", 
"print(\"Variable:\", variable_add())\n" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.806561Z", "iopub.status.busy": "2022-12-14T22:33:56.806337Z", "iopub.status.idle": "2022-12-14T22:33:56.811617Z", "shell.execute_reply": "2022-12-14T22:33:56.811015Z" }, "id": "L3q7sUJWZOSd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updating the value of `foo` to 100!\n", "Variable: tf.Tensor(101, shape=(), dtype=int32)\n" ] } ], "source": [ "print(\"Updating the value of `foo` to 100!\")\n", "foo.assign(100)\n", "print(\"Variable:\", variable_add())" ] }, { "cell_type": "markdown", "metadata": { "id": "hvwe9gTIWfx6" }, "source": [ "#### 取决于 Python 对象" ] }, { "cell_type": "markdown", "metadata": { "id": "BJkZS-SwPvOQ" }, "source": [ "将 Python 对象作为参数传递给 `tf.function` 的建议存在许多已知问题,预计会在以后得到解决。通常,如果您使用 Python 基元或兼容 `tf.nest` 的结构作为参数,或将对象的*不同*实例传递给 `Function`,则可以依赖稳定的跟踪。但是,如果您传递**同一对象并仅更改其特性**时,`Function` 将*不会*创建新的跟踪记录。" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.815130Z", "iopub.status.busy": "2022-12-14T22:33:56.814534Z", "iopub.status.idle": "2022-12-14T22:33:56.847798Z", "shell.execute_reply": "2022-12-14T22:33:56.846985Z" }, "id": "ux8KJESVWDxX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "class SimpleModel(tf.Module):\n", " def __init__(self):\n", " # These values are *not* tf.Variables.\n", " self.bias = 0.\n", " self.weight = 2.\n", "\n", "@tf.function\n", "def evaluate(model, x):\n", " return model.weight * x + model.bias\n", "\n", "simple_model = SimpleModel()\n", "x = tf.constant(10.)\n", "print(evaluate(simple_model, x))" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.851366Z", "iopub.status.busy": "2022-12-14T22:33:56.850747Z", "iopub.status.idle": "2022-12-14T22:33:56.855158Z", "shell.execute_reply": "2022-12-14T22:33:56.854588Z" }, "id": "mUxRF4ghZZvX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "simple_model.bias += 5.0\n", "print(evaluate(simple_model, x)) # Didn't change :(" ] }, { "cell_type": "markdown", "metadata": { "id": "Ytcgg2qFWaBF" }, "source": [ "如果使用相同的 `Function` 评估模型的更新实例,那么更新后的模型与原始模型将具有[相同的缓存键](#rules_of_tracing),所以这种做法并不合理。\n", "\n", "因此,建议您编写 `Function` 以避免依赖于可变对象特性,或者创建新对象。\n", "\n", "如果这不可行,则一种解决方法是,每次修改对象时都创建新的 `Function` 以强制回溯:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.858164Z", "iopub.status.busy": "2022-12-14T22:33:56.857940Z", "iopub.status.idle": "2022-12-14T22:33:56.890062Z", "shell.execute_reply": "2022-12-14T22:33:56.889389Z" }, "id": "pFvWmWAAQjrv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "def evaluate(model, x):\n", " return model.weight * x + model.bias\n", "\n", "new_model = SimpleModel()\n", "evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n", "# Don't pass in `new_model`, `Function` already captured its state during tracing.\n", "print(evaluate_no_bias(x)) " ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "execution": { "iopub.execute_input": 
"2022-12-14T22:33:56.893315Z", "iopub.status.busy": "2022-12-14T22:33:56.892824Z", "iopub.status.idle": "2022-12-14T22:33:56.908076Z", "shell.execute_reply": "2022-12-14T22:33:56.907402Z" }, "id": "bdU2-jF4ZH0B" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(25.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "new_model.bias += 5.0\n", "# Create new Function and ConcreteFunction since you modified new_model.\n", "evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n", "print(evaluate_with_bias(x)) # Don't pass in `new_model`." ] }, { "cell_type": "markdown", "metadata": { "id": "uFgEZClsZrEi" }, "source": [ "[回溯可能十分耗费资源](https://tensorflow.google.cn/guide/intro_to_graphs#tracing_and_performance),您可以使用 `tf.Variable` 作为对象特性,可以对其进行改变(但非更改,请注意!) 以在无需回溯的情况下实现相似效果。\n" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.911373Z", "iopub.status.busy": "2022-12-14T22:33:56.910806Z", "iopub.status.idle": "2022-12-14T22:33:56.948046Z", "shell.execute_reply": "2022-12-14T22:33:56.947370Z" }, "id": "daAP_lucwS6w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "class BetterModel:\n", "\n", " def __init__(self):\n", " self.bias = tf.Variable(0.)\n", " self.weight = tf.Variable(2.)\n", "\n", "@tf.function\n", "def evaluate(model, x):\n", " return model.weight * x + model.bias\n", "\n", "better_model = BetterModel()\n", "print(evaluate(better_model, x))\n" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.951182Z", "iopub.status.busy": "2022-12-14T22:33:56.950614Z", "iopub.status.idle": "2022-12-14T22:33:56.956431Z", "shell.execute_reply": "2022-12-14T22:33:56.955850Z" }, "id": "ktqwMJBqwTFj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(25.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5\n", "print(evaluate(better_model, x)) # This works!" ] }, { "cell_type": "markdown", "metadata": { "id": "lPr_6mK_AQWL" }, "source": [ "### 创建 tf.Variables\n", "\n", "`Function` 仅支持在第一次调用时创建一次,并且在后续函数调用中重复使用的单例 `tf.Variable`。下面的代码段会在每个函数调用中创建一个新的 `tf.Variable`,这会导致 `ValueError` 异常。\n", "\n", "示例:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:56.959466Z", "iopub.status.busy": "2022-12-14T22:33:56.959187Z", "iopub.status.idle": "2022-12-14T22:33:57.004957Z", "shell.execute_reply": "2022-12-14T22:33:57.004319Z" }, "id": "Tx0Vvnb_9OB-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/3018268426.py\", line 7, in \n", " f(1.0)\n", "ValueError: in user code:\n", "\n", " File \"/tmpfs/tmp/ipykernel_176735/3018268426.py\", line 3, in f *\n", " v = tf.Variable(1.0)\n", "\n", " ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. 
See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n", "\n" ] } ], "source": [ "@tf.function\n", "def f(x):\n", " v = tf.Variable(1.0)\n", " return v\n", "\n", "with assert_raises(ValueError):\n", " f(1.0)" ] }, { "cell_type": "markdown", "metadata": { "id": "KYm6-5GCILXQ" }, "source": [ "用于解决这种限制的常见模式是从 Python None 值开始,随后,在值为 None 时,有条件地创建 `tf.Variable`:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:57.008463Z", "iopub.status.busy": "2022-12-14T22:33:57.007868Z", "iopub.status.idle": "2022-12-14T22:33:57.073908Z", "shell.execute_reply": "2022-12-14T22:33:57.073272Z" }, "id": "HQrG5_kOiKl_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(1, shape=(), dtype=int32)\n", "tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "class Count(tf.Module):\n", " def __init__(self):\n", " self.count = None\n", "\n", " @tf.function\n", " def __call__(self):\n", " if self.count is None:\n", " self.count = tf.Variable(0)\n", " return self.count.assign_add(1)\n", "\n", "c = Count()\n", "print(c())\n", "print(c())" ] }, { "cell_type": "markdown", "metadata": { "id": "7uD6qI7aJwbR" }, "source": [ "#### 与多个 Keras 优化器一起使用\n", "\n", "将多个 Keras 优化器与 `tf.function` 一起使用时,您可能会遇到 `ValueError: tf.function only supports singleton tf.Variables created on the first call.`。发生此错误的原因是优化器在首次应用梯度时会在内部创建 `tf.Variables`。" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:57.077415Z", "iopub.status.busy": "2022-12-14T22:33:57.076905Z", "iopub.status.idle": "2022-12-14T22:33:57.382371Z", "shell.execute_reply": "2022-12-14T22:33:57.381647Z" }, "id": "yWQ3-r99Jvze" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calling `train_step` with different optimizer...\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_176735/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_176735/3167358578.py\", line 18, in \n", " train_step(w, x, y, opt2)\n", "ValueError: in user code:\n", "\n", " File \"/tmpfs/tmp/ipykernel_176735/3167358578.py\", line 9, in train_step *\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py\", line 1140, in apply_gradients **\n", " return super().apply_gradients(grads_and_vars, name=name)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py\", line 621, in apply_gradients\n", " self.build(trainable_variables)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py\", line 139, in build\n", " self.add_variable_from_reference(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py\", line 1072, in add_variable_from_reference\n", " return super().add_variable_from_reference(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py\", line 496, in add_variable_from_reference\n", " variable = tf.Variable(\n", "\n", " ValueError: tf.function only supports singleton tf.Variables created on the first call. 
Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n", "\n" ] } ], "source": [ "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", " \n", "@tf.function\n", "def train_step(w, x, y, optimizer):\n", " with tf.GradientTape() as tape:\n", " L = tf.reduce_sum(tf.square(w*x - y))\n", " gradients = tape.gradient(L, [w])\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", "\n", "w = tf.Variable(2.)\n", "x = tf.constant([-1.])\n", "y = tf.constant([2.])\n", "\n", "train_step(w, x, y, opt1)\n", "print(\"Calling `train_step` with different optimizer...\")\n", "with assert_raises(ValueError):\n", " train_step(w, x, y, opt2)" ] }, { "cell_type": "markdown", "metadata": { "id": "7Q8BRPCThTjB" }, "source": [ "如果您需要在训练期间更改优化器,一种解决方法是为每个优化器创建一个新的 `Function`,直接调用 [`ConcreteFunction`](#obtaining_concrete_functions)。" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:33:57.386525Z", "iopub.status.busy": "2022-12-14T22:33:57.385888Z", "iopub.status.idle": "2022-12-14T22:33:57.768334Z", "shell.execute_reply": "2022-12-14T22:33:57.767541Z" }, "id": "YV5F2Gy9hSI3" }, "outputs": [], "source": [ "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", "\n", "# Not a tf.function.\n", "def train_step(w, x, y, optimizer):\n", " with tf.GradientTape() as tape:\n", " L = tf.reduce_sum(tf.square(w*x - y))\n", " gradients = tape.gradient(L, [w])\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", "\n", "w = tf.Variable(2.)\n", "x = tf.constant([-1.])\n", "y = tf.constant([2.])\n", "\n", "# Make a new Function and ConcreteFunction for each optimizer.\n", "train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)\n", "train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)\n", "for i in range(10):\n", " if i % 2 == 0:\n", " train_step_1(w, x, y) # `opt1` is not used as a parameter. \n", " else:\n", " train_step_2(w, x, y) # `opt2` is not used as a parameter." 
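] }, { "cell_type": "markdown", "metadata": {}, "source": [ "另一种示意性写法(非指南原有内容;`make_train_step` 为此处假设的辅助函数名)是用一个工厂函数为每个优化器分别构建 `tf.function`,这样每个优化器的变量只会在其对应 `Function` 的首次跟踪时创建:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only: build one tf.function per optimizer via a factory closure.\n", "def make_train_step(optimizer):\n", "  @tf.function\n", "  def train_step(w, x, y):\n", "    with tf.GradientTape() as tape:\n", "      L = tf.reduce_sum(tf.square(w * x - y))\n", "    gradients = tape.gradient(L, [w])\n", "    optimizer.apply_gradients(zip(gradients, [w]))\n", "  return train_step\n", "\n", "train_step_a = make_train_step(tf.keras.optimizers.Adam(learning_rate=1e-2))\n", "train_step_b = make_train_step(tf.keras.optimizers.Adam(learning_rate=1e-3))\n", "\n", "train_step_a(w, x, y)  # Each Function creates its optimizer's variables once.\n", "train_step_b(w, x, y)"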
] }, { "cell_type": "markdown", "metadata": { "id": "Xjnz5CcuqQac" }, "source": [ "#### 与多个 Keras 模型一起使用\n", "\n", "将不同的模型实例传递给同一 `Function` 时,您也可能会遇到 `ValueError: tf.function only supports singleton tf.Variables created on the first call.`。\n", "\n", "发生此错误的原因是 Keras 模型([未定义其输入形状](https://tensorflow.google.cn/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known))和 Keras 层会在首次调用时创建 `tf.Variables`。您可能正在尝试在已调用的 `Function` 中初始化这些变量。为避免此错误,请在训练模型之前尝试调用 `model.build(input_shape)` 以初始化所有权重。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "IKyrEY5GVX3M" }, "source": [ "## 延伸阅读\n", "\n", "要了解如何导出和加载 `Function`,请参阅 [SavedModel 指南](https://render.githubusercontent.com/guide/saved_model)。要详细了解跟踪后执行的计算图优化,请参阅 [Grappler 指南](https://render.githubusercontent.com/guide/graph_optimization)。要了解如何优化数据流水线和剖析模型性能,请参阅 [Profiler 指南](https://render.githubusercontent.com/guide/profiler.md)。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "function.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }