{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "WrcIOXsUQh8U" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T20:13:39.813593Z", "iopub.status.busy": "2022-12-14T20:13:39.812947Z", "iopub.status.idle": "2022-12-14T20:13:39.816731Z", "shell.execute_reply": "2022-12-14T20:13:39.816166Z" }, "id": "tXAbWHtqs1Y2" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "HTgMAvQq-PU_" }, "source": [ "# Extension types" ] },
{ "cell_type": "markdown", "metadata": { "id": "jHcw9MtgBo7e" }, "source": [ "## Setup" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:13:39.819980Z", "iopub.status.busy": "2022-12-14T20:13:39.819512Z", "iopub.status.idle": "2022-12-14T20:14:10.360169Z", "shell.execute_reply": "2022-12-14T20:14:10.359417Z" }, "id": "0MsE_F0WBpmc" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 20:14:08.129645: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay\n" ] } ], "source": [ "!pip install -q tf_nightly\n", "import tensorflow as tf\n", "import numpy as np\n", "from typing import Tuple, List, Mapping, Union, Optional\n", "import tempfile" ] },
{ "cell_type": "markdown", "metadata": { "id": "1BAk3bji_0wl" }, "source": [ "## Extension types\n", "\n", "User-defined types can make projects more readable, modular, and maintainable. However, most TensorFlow APIs have very limited support for user-defined Python types. This includes both high-level APIs (such as [Keras](https://tensorflow.google.cn/guide/keras/overview), [tf.function](https://tensorflow.google.cn/guide/function), and [tf.SavedModel](https://tensorflow.google.cn/guide/saved_model)) and lower-level APIs (such as `tf.while_loop` and `tf.concat`). TensorFlow **extension types** can be used to create user-defined object-oriented types that work seamlessly with TensorFlow's APIs. To create an extension type, define a Python class with `tf.experimental.ExtensionType` as its base, and use [type annotations](https://www.python.org/dev/peps/pep-0484/) to specify the type of each field." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:10.364641Z", "iopub.status.busy": "2022-12-14T20:14:10.364232Z", "iopub.status.idle": "2022-12-14T20:14:10.369711Z", "shell.execute_reply": "2022-12-14T20:14:10.369088Z" }, "id": "7o5KY7L5_nxy" }, "outputs": [], "source": [ "class TensorGraph(tf.experimental.ExtensionType):\n", "  \"\"\"A collection of labeled nodes connected by weighted edges.\"\"\"\n", "  edge_weights: tf.Tensor  # shape=[num_nodes, num_nodes]\n", "  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any\n", "\n", "class MaskedTensor(tf.experimental.ExtensionType):\n", "  \"\"\"A tensor paired with a boolean mask, indicating which values are valid.\"\"\"\n", "  values: tf.Tensor\n", "  mask: tf.Tensor  # shape=values.shape; false for missing/invalid values.\n", "\n", "class CSRSparseMatrix(tf.experimental.ExtensionType):\n", "  \"\"\"Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix).\"\"\"\n", "  values: tf.Tensor  # shape=[num_nonzero]; dtype=any\n", "  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64\n", "  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64" ] },
{ "cell_type": "markdown", "metadata": { "id": "FiaNXPa7pNK-" }, "source": [ "The `tf.experimental.ExtensionType` base class works similarly to [`typing.NamedTuple`](https://docs.python.org/3/library/typing.html#typing.NamedTuple) and [`@dataclasses.dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) from the standard Python library. In particular, it automatically adds a constructor and special methods (such as `__repr__` and `__eq__`) based on the field type annotations." ] },
{ "cell_type": "markdown", "metadata": { "id": "JsE7X6_uMyLo" }, "source": [ "Typically, extension types fall into one of two categories:\n", "\n", "- ***Data structures***, which group together a collection of related values and can provide useful operations based on those values. Data structures may be fairly general (such as the `TensorGraph` example above), or they may be highly customized to a specific model.\n", "\n", "- ***Tensor-like types***, which specialize or extend the concept of \"tensor.\" Types in this category have a `rank`, a `shape`, and usually a `dtype`; and it makes sense to use them with tensor operations (such as `tf.stack`, `tf.add`, or `tf.matmul`). `MaskedTensor` and `CSRSparseMatrix` are examples of tensor-like types." ] },
{ "cell_type": "markdown", "metadata": { "id": "uxngcajlMqIY" }, "source": [ "## Supported APIs\n", "\n", "Extension types are supported by the following TensorFlow APIs:\n", "\n", "- **Keras**: Extension types can be used as inputs and outputs for Keras `Models` and `Layers`.\n", "- **tf.data.Dataset**: Extension types can be included in `Datasets`, and returned by dataset `Iterators`.\n", "- **TensorFlow Hub**: Extension types can be used as inputs and outputs for `tf.hub` modules.\n", "- **SavedModel**: Extension types can be used as inputs and outputs for `SavedModel` functions.\n", "- **tf.function**: Extension types can be used as arguments and return values for functions wrapped with the `@tf.function` decorator.\n", "- **While loops**: Extension types can be used as loop variables in `tf.while_loop`, and can be used as arguments and return values for the while loop's body.\n", "- **Conditionals**: Extension types can be conditionally selected using `tf.cond` and `tf.case`.\n", "- **`tf.py_function`**: Extension types can be used as arguments and return values for the `func` argument to `tf.py_function`.\n", "- **Tensor ops**: Extension types can be extended to support most TensorFlow ops that accept `Tensor` inputs (such as `tf.matmul`, `tf.gather`, and `tf.reduce_sum`). For more information, see the *Tensor API dispatch* section below.\n", "- **Distribution strategy**: Extension types can be used as per-replica values.\n", "\n", "For more details, see the section on \"TensorFlow APIs that support ExtensionTypes\" below.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "VIpZwuPVpwOX" }, "source": [ "## Requirements\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "nNk_TQeJGVwV" }, "source": [ "### Field types\n", "\n", "All fields (instance variables) must be declared, and a type annotation must be provided for each field. The following type annotations are supported:\n", "\n", "Type | Example\n", "--- | ---\n", "Python integers | `i: int`\n", "Python floats | `f: float`\n", "Python strings | `s: str`\n", "Python Booleans | `b: bool`\n", "Python `None` | `n: None`\n", "[Tensor shapes](https://tensorflow.google.cn/api_docs/python/tf/TensorShape) | `shape: tf.TensorShape`\n", "[Tensor dtypes](https://tensorflow.google.cn/api_docs/python/tf/dtypes/DType) | `dtype: tf.DType`\n", "[Tensors](https://tensorflow.google.cn/api_docs/python/tf/Tensor) | `t: tf.Tensor`\n", "[Extension types](https://tensorflow.google.cn/api_docs/python/tf/experimental/ExtensionType) | `mt: MyMaskedTensor`\n", "[Ragged tensors](https://tensorflow.google.cn/api_docs/python/tf/RaggedTensor) | `rt: tf.RaggedTensor`\n", "[Sparse tensors](https://tensorflow.google.cn/api_docs/python/tf/sparse/SparseTensor) | `st: tf.SparseTensor`\n", "[Indexed slices](https://tensorflow.google.cn/api_docs/python/tf/IndexedSlices) | `s: tf.IndexedSlices`\n", "[Optional tensors](https://tensorflow.google.cn/api_docs/python/tf/experimental/Optional) | `o: tf.experimental.Optional`\n", "[Type unions](https://docs.python.org/3/library/typing.html#typing.Union) | `int_or_float: typing.Union[int, float]`\n", "[Tuples](https://docs.python.org/3/library/typing.html#typing.Tuple) | `params: typing.Tuple[int, float, tf.Tensor, int]`\n", "[Variable-length tuples](https://docs.python.org/3/library/typing.html#typing.Tuple) | `lengths: typing.Tuple[int, ...]`\n", "[Mappings](https://docs.python.org/3/library/typing.html#typing.Mapping) | `tags: typing.Mapping[str, tf.Tensor]`\n", "[Optional values](https://docs.python.org/3/library/typing.html#typing.Optional) | `weight: typing.Optional[tf.Tensor]`" ] },
{ "cell_type": "markdown", "metadata": { "id": "iFetYyZsIvf6" }, "source": [ "### Mutability\n", "\n", "Extension types are required to be immutable. This ensures that they can be properly tracked by TensorFlow's graph-tracing mechanisms. If you find yourself wanting to mutate an extension type value, consider instead defining methods that transform values. For example, rather than defining a `set_mask` method to mutate a `MaskedTensor`, you could define a `replace_mask` method that returns a new `MaskedTensor`:" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:10.373308Z", "iopub.status.busy": "2022-12-14T20:14:10.372903Z", "iopub.status.idle": "2022-12-14T20:14:10.376455Z", "shell.execute_reply": "2022-12-14T20:14:10.375919Z" }, "id": "DThZLYH2IwFh" }, "outputs": [], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "  def replace_mask(self, new_mask):\n", "    self.values.shape.assert_is_compatible_with(new_mask.shape)\n", "    return MaskedTensor(self.values, new_mask)" ] },
{ "cell_type": "markdown", "metadata": { "id": "x3JyivI_qAtt" }, "source": [ "## Functionality added by `ExtensionType`\n", "\n", "The `ExtensionType` base class provides the following functionality:\n", "\n", "- A constructor (`__init__`).\n", "- A printable representation method (`__repr__`).\n", "- Equality and inequality operators (`__eq__` and `__ne__`).\n", "- A validation method (`__validate__`).\n", "- Enforced immutability.\n", "- A nested `TypeSpec`.\n", "- Tensor API dispatch support.\n", "\n", "For more information on customizing this functionality, see the \"Customizing ExtensionTypes\" section below." ] },
{ "cell_type": "markdown", "metadata": { "id": "pfSYs6P26gKq" }, "source": [ "### Constructor\n", "\n", "The constructor added by `ExtensionType` takes each field as a named argument (in the order they are listed in the class definition). This constructor type-checks each parameter and converts it where necessary. In particular, `Tensor` fields are converted using `tf.convert_to_tensor`; `Tuple` fields are converted to `tuple`s; and `Mapping` fields are converted to immutable dicts." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:10.379606Z", "iopub.status.busy": "2022-12-14T20:14:10.379124Z", "iopub.status.idle": "2022-12-14T20:14:13.902346Z", "shell.execute_reply": "2022-12-14T20:14:13.901637Z" }, "id": "DiXwyZ5M5KFW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[1 2 3]\n", " [4 5 6]], shape=(2, 3), dtype=int32)\n" ] } ], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "# Constructor takes one parameter for each field.\n", "mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],\n", "                  mask=[[True, True, False], [True, False, True]])\n", "\n", "# Fields are type-checked and converted to the declared types.\n", "# For example, `mt.values` is converted to a Tensor.\n", "print(mt.values)" ] },
{ "cell_type": "markdown", "metadata": { "id": "ezNDe1cYF0Qb" }, "source": [ "The constructor raises a `TypeError` if a field value cannot be converted to its declared type:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.905716Z", "iopub.status.busy": "2022-12-14T20:14:13.905470Z", "iopub.status.idle": "2022-12-14T20:14:13.909559Z", "shell.execute_reply": "2022-12-14T20:14:13.908956Z" }, "id": "6HnrMaabF5VS" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got expected TypeError: mask: expected a Tensor, got 'NoneType'\n" ] } ], "source": [ "try:\n", "  MaskedTensor([1, 2, 3], None)\n", "except TypeError as e:\n", "  print(f\"Got expected TypeError: {e}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "FwQUI3X02s20" }, "source": [ "The default value for a field can be specified by setting its value at the class level:" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.912747Z", "iopub.status.busy": "2022-12-14T20:14:13.912316Z", "iopub.status.idle": "2022-12-14T20:14:13.920831Z", "shell.execute_reply": "2022-12-14T20:14:13.920282Z" }, "id": "GbzDT9fz20JA" }, "outputs": [ { "data": { "text/plain": [ "Pencil(color='black', has_eraser=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Pencil(tf.experimental.ExtensionType):\n", "  color: str = \"black\"\n", "  has_eraser: bool = True\n", "  length: tf.Tensor = 1.0\n", "\n", "Pencil()" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.923705Z", "iopub.status.busy": "2022-12-14T20:14:13.923288Z", "iopub.status.idle": "2022-12-14T20:14:13.927716Z", "shell.execute_reply": "2022-12-14T20:14:13.927122Z" }, "id": "nOW7lS9P4Foc" }, "outputs": [ { "data": { "text/plain": [ "Pencil(color='blue', has_eraser=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Pencil(length=0.5, color=\"blue\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "S5Eivtg07Aau" }, "source": [ "### Printable representation\n", "\n", "`ExtensionType` adds a default printable representation method (`__repr__`) that includes the class name and the value of each field:\n" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.930856Z", "iopub.status.busy": "2022-12-14T20:14:13.930464Z", "iopub.status.idle": "2022-12-14T20:14:13.934632Z", "shell.execute_reply": "2022-12-14T20:14:13.934101Z" }, "id": "5SyiKTe55krG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)\n" ] } ], "source": [ "print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))" ] },
{ "cell_type": "markdown", "metadata": { "id": "q4l_gnQh6nXR" }, "source": [ "### Equality operators\n", "\n", "`ExtensionType` adds default equality operators (`__eq__` and `__ne__`) that consider two values equal if they have the same type and all their fields are equal. Tensor fields are considered equal if they have the same shape and are elementwise equal for all elements." ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.937757Z", "iopub.status.busy": "2022-12-14T20:14:13.937294Z", "iopub.status.idle": "2022-12-14T20:14:13.952862Z", "shell.execute_reply": "2022-12-14T20:14:13.952317Z" }, "id": "bHdLg13V52Xm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a == a: True\n", "a == b: False\n", "a == a.values: False\n" ] } ], "source": [ "a = MaskedTensor([1, 2], [True, False])\n", "b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])\n", "print(f\"a == a: {a==a}\")\n", "print(f\"a == b: {a==b}\")\n", "print(f\"a == a.values: {a==a.values}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "O3HqsO3jZlQq" }, "source": [ "**Note:** If any field contains a `Tensor`, then `__eq__` may return a scalar boolean `Tensor` (rather than a Python boolean value)." ] },
{ "cell_type": "markdown", "metadata": { "id": "hCpBfkKqCuip" }, "source": [ "### Validation method\n", "\n", "`ExtensionType` adds a `__validate__` method, which can be overridden to perform validation checks on fields. It runs after the constructor is called, and after the fields have been type-checked and converted to their declared types, so it can assume that all fields have their declared types.\n", "\n", "The following example updates `MaskedTensor` to validate the `shape` and `dtype` of its fields:" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.956026Z", "iopub.status.busy": "2022-12-14T20:14:13.955496Z", "iopub.status.idle": "2022-12-14T20:14:13.959221Z", "shell.execute_reply": "2022-12-14T20:14:13.958607Z" }, "id": "dgZOJRINDn00" }, "outputs": [], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  \"\"\"A tensor paired with a boolean mask, indicating which values are valid.\"\"\"\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "  def __validate__(self):\n", "    self.values.shape.assert_is_compatible_with(self.mask.shape)\n", "    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.962041Z", "iopub.status.busy": "2022-12-14T20:14:13.961593Z", "iopub.status.idle": "2022-12-14T20:14:13.965249Z", "shell.execute_reply": "2022-12-14T20:14:13.964719Z" }, "id": "ajSgkGUUn9WL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got expected AssertionError: mask.dtype must be bool\n" ] } ], "source": [ "try:\n", "  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.\n", "except AssertionError as e:\n", "  print(f\"Got expected AssertionError: {e}\")" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.968165Z", "iopub.status.busy": "2022-12-14T20:14:13.967604Z", "iopub.status.idle": "2022-12-14T20:14:13.971653Z", "shell.execute_reply": "2022-12-14T20:14:13.971096Z" }, "id": "Fhb96luJn9K7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got expected ValueError: Shapes (3,) and (2,) are incompatible\n" ] } ], "source": [ "try:\n", "  MaskedTensor([1, 2, 3], [True, False])  # Shapes don't match.\n", "except ValueError as e:\n", "  print(f\"Got expected ValueError: {e}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "pjIPAF1OCAdO" }, "source": [ "### Enforced immutability\n", "\n", "`ExtensionType` overrides the `__setattr__` and `__delattr__` methods to prevent mutation, ensuring that extension type values are immutable." ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.974765Z", "iopub.status.busy": "2022-12-14T20:14:13.974298Z", "iopub.status.idle": "2022-12-14T20:14:13.977839Z", "shell.execute_reply": "2022-12-14T20:14:13.977302Z" }, "id": "NgmJ1C7ilN5C" }, "outputs": [], "source": [ "mt = MaskedTensor([1, 2, 3], [True, False, True])" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.980965Z", "iopub.status.busy": "2022-12-14T20:14:13.980505Z", "iopub.status.idle": "2022-12-14T20:14:13.983938Z", "shell.execute_reply": "2022-12-14T20:14:13.983386Z" }, "id": "cMYmJr3RoFKp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.\n" ] } ], "source": [ "try:\n", "  mt.mask = [True, True, True]\n", "except AttributeError as e:\n", "  print(f\"Got expected AttributeError: {e}\")" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.986835Z", "iopub.status.busy": "2022-12-14T20:14:13.986406Z", "iopub.status.idle": "2022-12-14T20:14:13.989971Z", "shell.execute_reply": "2022-12-14T20:14:13.989349Z" }, "id": "ZWwA-zWdzqlU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment\n" ] } ], "source": [ "try:\n", "  mt.mask[0] = False\n", "except TypeError as e:\n", "  print(f\"Got expected TypeError: {e}\")" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.992987Z", "iopub.status.busy": "2022-12-14T20:14:13.992596Z", "iopub.status.idle": "2022-12-14T20:14:13.995859Z", "shell.execute_reply": "2022-12-14T20:14:13.995295Z" }, "id": "PN_txJVKoFoF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.\n" ] } ], "source": [ "try:\n", "  del mt.mask\n", "except AttributeError as e:\n", "  print(f\"Got expected AttributeError: {e}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "FBVFtCYn69Ou" }, "source": [ "### Nested TypeSpec\n", "\n", "Each `ExtensionType` class has a corresponding `TypeSpec` class, which is created automatically and stored as `<extension_type_name>.Spec`.\n", "\n", "This class captures all the information from a value *except* for the values of any nested tensors. In particular, the `TypeSpec` for a value is created by replacing any nested Tensor, ExtensionType, or CompositeTensor with its `TypeSpec`.\n" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:13.998975Z", "iopub.status.busy": "2022-12-14T20:14:13.998408Z", "iopub.status.idle": "2022-12-14T20:14:14.005611Z", "shell.execute_reply": "2022-12-14T20:14:14.005073Z" }, "id": "GRjANkGYKGnV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TensorSpec(shape=(), dtype=tf.string, name=None)\n", "ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})\n" ] } ], "source": [ "class Player(tf.experimental.ExtensionType):\n", "  name: tf.Tensor\n", "  attributes: Mapping[str, tf.Tensor]\n", "\n", "anne = Player(\"Anne\", {\"height\": 8.3, \"speed\": 28.1})\n", "anne_spec = tf.type_spec_from_value(anne)\n", "print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.\n", "print(anne_spec.attributes)  # Records keys and TensorSpecs for values." ] },
{ "cell_type": "markdown", "metadata": { "id": "I2fkgckxO564" }, "source": [ "`TypeSpec` values can be constructed explicitly, or they can be built from an `ExtensionType` value using `tf.type_spec_from_value`:" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.008490Z", "iopub.status.busy": "2022-12-14T20:14:14.008099Z", "iopub.status.idle": "2022-12-14T20:14:14.011456Z", "shell.execute_reply": "2022-12-14T20:14:14.010889Z" }, "id": "1ehAa7d9OGai" }, "outputs": [], "source": [ "spec1 = Player.Spec(name=tf.TensorSpec([], tf.string), attributes={})\n", "spec2 = tf.type_spec_from_value(anne)" ] },
{ "cell_type": "markdown", "metadata": { "id": "owcFG3cAMCwA" }, "source": [ "TensorFlow uses `TypeSpec`s to divide values into a **static component** and a **dynamic component**:\n", "\n", "- The **static component** (which is fixed at graph-construction time) is encoded with a `tf.TypeSpec`.\n", "- The **dynamic component** (which can vary each time the graph is run) is encoded as a list of `tf.Tensor`s.\n", "\n", "For example, `tf.function` retraces its wrapped function whenever an argument has a previously unseen `TypeSpec`:" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.014315Z", "iopub.status.busy": "2022-12-14T20:14:14.013923Z", "iopub.status.idle": "2022-12-14T20:14:14.017207Z", "shell.execute_reply": "2022-12-14T20:14:14.016627Z" }, "id": "pg-m5YLRM1Nd" }, "outputs": [], "source": [ "@tf.function\n", "def anonymize_player(player):\n", "  print(\"<<TRACING>>\")\n", "  return Player(\"<<anonymized>>\", player.attributes)" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.020299Z", "iopub.status.busy": "2022-12-14T20:14:14.019818Z", "iopub.status.idle": "2022-12-14T20:14:14.061960Z", "shell.execute_reply": "2022-12-14T20:14:14.061423Z" }, "id": "0CCGm7cpeIq-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<<TRACING>>\n" ] }, { "data": { "text/plain": [ "Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<<anonymized>>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Function gets traced (first time the function has been called):\n", "anonymize_player(Player(\"Anne\", {\"height\": 8.3, \"speed\": 28.1}))" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.065025Z", "iopub.status.busy": "2022-12-14T20:14:14.064641Z", "iopub.status.idle": "2022-12-14T20:14:14.070462Z", "shell.execute_reply": "2022-12-14T20:14:14.069937Z" }, "id": "WB7bt7s83mFE" }, "outputs": [ { "data": { "text/plain": [ "Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<<anonymized>>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Function does NOT get traced (same TypeSpec: just tensor values changed)\n", "anonymize_player(Player(\"Bart\", {\"height\": 8.1, \"speed\": 25.3}))" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.073549Z", "iopub.status.busy": "2022-12-14T20:14:14.073118Z", "iopub.status.idle": "2022-12-14T20:14:14.090081Z", "shell.execute_reply": "2022-12-14T20:14:14.089518Z" }, "id": "dNm7vLpR3nMH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<<TRACING>>\n" ] }, { "data": { "text/plain": [ "Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<<anonymized>>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Function gets traced (new TypeSpec: keys for attributes changed):\n", "anonymize_player(Player(\"Chuck\", {\"height\": 11.0, \"jump\": 5.3}))" ] },
{ "cell_type": "markdown", "metadata": { "id": "U5rN1HPq25xC" }, "source": [ "For more information, see the [tf.function guide](https://tensorflow.google.cn/guide/function#rules_of_tracing)." ] },
{ "cell_type": "markdown", "metadata": { "id": "gX613uRk0qLz" }, "source": [ "## Customizing ExtensionTypes\n", "\n", "In addition to simply declaring fields and their types, extension types may:\n", "\n", "- Override the default printable representation (`__repr__`).\n", "- Define methods.\n", "- Define class methods and static methods.\n", "- Define properties.\n", "- Override the default constructor (`__init__`).\n", "- Override the default equality operator (`__eq__`).\n", "- Define operators (such as `__add__` and `__lt__`).\n", "- Declare default values for fields.\n", "- Define subclasses.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "MK-ePVDj-ROE" }, "source": [ "### Overriding the default printable representation\n", "\n", "You can override this default string conversion operator for extension types. The following example updates the `MaskedTensor` class to generate a more readable string representation when values are printed in Eager mode." ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.093337Z", "iopub.status.busy": "2022-12-14T20:14:14.092904Z", "iopub.status.idle": "2022-12-14T20:14:14.099910Z", "shell.execute_reply": "2022-12-14T20:14:14.099284Z" }, "id": "gdPhjYEr8IGO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<MaskedTensor [[1, 2, _], [4, _, 6]]>\n" ] } ], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  \"\"\"A tensor paired with a boolean mask, indicating which values are valid.\"\"\"\n", "  values: tf.Tensor\n", "  mask: tf.Tensor  # shape=values.shape; false for invalid values.\n", "\n", "  def __repr__(self):\n", "    return masked_tensor_str(self.values, self.mask)\n", "\n", "def masked_tensor_str(values, mask):\n", "  if isinstance(values, tf.Tensor):\n", "    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):\n", "      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'\n", "    else:\n", "      return f'MaskedTensor(values={values}, mask={mask})'\n", "  if len(values.shape) == 1:\n", "    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]\n", "  else:\n", "    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]\n", "  return '[%s]' % ', '.join(items)\n", "\n", "mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],\n", "                  mask=[[True, True, False], [True, False, True]])\n", "print(mt)" ] },
{ "cell_type": "markdown", "metadata": { "id": "_MLQU2_v8VjG" }, "source": [ "### Defining methods\n", "\n", "Extension types may define methods, just like any normal Python class. For example, the `MaskedTensor` type could define a `with_default` method that returns a copy of `self` with masked values replaced by a given `default` value. Methods can optionally be annotated with the `@tf.function` decorator." ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.103035Z", "iopub.status.busy": "2022-12-14T20:14:14.102518Z", "iopub.status.idle": "2022-12-14T20:14:14.109737Z", "shell.execute_reply": "2022-12-14T20:14:14.109202Z" }, "id": "7RR-tqee8ZdP" }, "outputs": [ { "data": { "text/plain": [ "<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "  def with_default(self, default):\n", "    return tf.where(self.mask, self.values, default)\n", "\n", "MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Qwd_gGKp9RP0" }, "source": [ "### Defining class methods and static methods\n", "\n", "Extension types may define methods using the `@classmethod` and `@staticmethod` decorators. For example, the `MaskedTensor` type could define a factory method that masks every element equal to a given value:" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.112821Z", "iopub.status.busy": "2022-12-14T20:14:14.112303Z", "iopub.status.idle": "2022-12-14T20:14:14.119894Z", "shell.execute_reply": "2022-12-14T20:14:14.119281Z" }, "id": "BacCEJYU9sBR" }, "outputs": [ { "data": { "text/plain": [ "<MaskedTensor [[1, _, 2], [3, _, _]]>" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "  def __repr__(self):\n", "    return masked_tensor_str(self.values, self.mask)\n", "\n", "  @staticmethod\n", "  def from_tensor_and_value_to_mask(values, value_to_mask):\n", "    return MaskedTensor(values, values != value_to_mask)\n", "\n", "x = tf.constant([[1, 0, 2], [3, 0, 0]])\n", "MaskedTensor.from_tensor_and_value_to_mask(x, 0)" ] },
{ "cell_type": "markdown", "metadata": { "id": "xIPf9PZX9AwL" }, "source": [ "### Defining properties\n", "\n", "Extension types may define properties using the `@property` decorator, just like any normal Python class. For example, the `MaskedTensor` type could define a `dtype` property that is a shorthand for the dtype of the values:" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.123071Z", "iopub.status.busy": "2022-12-14T20:14:14.122529Z", "iopub.status.idle": "2022-12-14T20:14:14.127931Z", "shell.execute_reply": "2022-12-14T20:14:14.127263Z" }, "id": "16E68wZ-9KXp" }, "outputs": [ { "data": { "text/plain": [ "tf.int32" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "  @property\n", "  def dtype(self):\n", "    return self.values.dtype\n", "\n", "MaskedTensor([1, 2, 3], [True, False, True]).dtype" ] },
{ "cell_type": "markdown", "metadata": { "id": "Mm5gxoG57nf3" }, "source": [ "### Overriding the default constructor\n", "\n", "You can override the default constructor for extension types. Custom constructors must set a value for every declared field; and after the custom constructor returns, all fields are type-checked, and values are converted as described above." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.131016Z", "iopub.status.busy": "2022-12-14T20:14:14.130798Z", "iopub.status.idle": "2022-12-14T20:14:14.135555Z", "shell.execute_reply": "2022-12-14T20:14:14.134830Z" }, "id": "-8K3KeB08G1S" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)\n" ] } ], "source": [ "class Toy(tf.experimental.ExtensionType):\n", "  name: str\n", "  price: tf.Tensor\n", "  def __init__(self, name, price, discount=0):\n", "    self.name = name\n", "    self.price = price * (1 - discount)\n", "\n", "print(Toy(\"ball\", 5.0, discount=0.2))  # On sale -- 20% off!"
] }, { "cell_type": "markdown", "metadata": { "id": "qyQxMlwLFQt7" }, "source": [ "或者,您可以考虑保留默认构造函数,但添加一个或多个工厂方法。例如:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.138704Z", "iopub.status.busy": "2022-12-14T20:14:14.138188Z", "iopub.status.idle": "2022-12-14T20:14:14.142567Z", "shell.execute_reply": "2022-12-14T20:14:14.142036Z" }, "id": "jiApK4hzFY89" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Toy(name='ball', price=)\n" ] } ], "source": [ "class Toy(tf.experimental.ExtensionType):\n", " name: str\n", " price: tf.Tensor\n", "\n", " @staticmethod\n", " def new_toy_with_discount(name, price, discount):\n", " return Toy(name, price * (1 - discount))\n", "\n", "print(Toy.new_toy_with_discount(\"ball\", 5.0, discount=0.2))" ] }, { "cell_type": "markdown", "metadata": { "id": "pdVcRBhG-Uee" }, "source": [ "### 重写默认相等运算符 (`__eq__`)\n", "\n", "您可以重写扩展程序类型的默认 `__eq__` 运算符。以下示例会更新 `MaskedTensor` 以在比较相等性时忽略遮盖元素。" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.145978Z", "iopub.status.busy": "2022-12-14T20:14:14.145365Z", "iopub.status.idle": "2022-12-14T20:14:14.156703Z", "shell.execute_reply": "2022-12-14T20:14:14.156116Z" }, "id": "dA7DyjfB-Yz0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(True, shape=(), dtype=bool)\n" ] } ], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", " values: tf.Tensor\n", " mask: tf.Tensor\n", "\n", " def __repr__(self):\n", " return masked_tensor_str(self.values, self.mask)\n", "\n", " def __eq__(self, other):\n", " result = tf.math.equal(self.values, other.values)\n", " result = result | ~(self.mask & other.mask)\n", " return tf.reduce_all(result)\n", "\n", "x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])\n", "y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])\n", "print(x == y)" ] }, { 
"cell_type": "markdown", "metadata": { "id": "n1mZ1Lkyi14B" }, "source": [ "**注**:您通常不需要重写 `__ne__`,因为其默认实现只需调用 `__eq__` 并对结果求反。" ] }, { "cell_type": "markdown", "metadata": { "id": "A_Jib1SQD1-z" }, "source": [ "### 使用前向引用\n", "\n", "如果字段的类型尚未定义,您可以改用包含类型名称的字符串。在以下示例中,字符串 `\"Node\"` 用于注解 `children` 字段,因为 `Node` 类型尚未(完全)定义。\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.160031Z", "iopub.status.busy": "2022-12-14T20:14:14.159536Z", "iopub.status.idle": "2022-12-14T20:14:14.164610Z", "shell.execute_reply": "2022-12-14T20:14:14.164091Z" }, "id": "_Z029QKED0Ao" }, "outputs": [ { "data": { "text/plain": [ "Node(value=, children=(Node(value=, children=()), Node(value=, children=())))" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Node(tf.experimental.ExtensionType):\n", " value: tf.Tensor\n", " children: Tuple[\"Node\", ...] = ()\n", "\n", "Node(3, [Node(5), Node(2)])" ] }, { "cell_type": "markdown", "metadata": { "id": "boaNg1zHgoVn" }, "source": [ "### 定义子类\n", "\n", "扩展程序类型可以使用标准 Python 语法进行子类化。扩展程序类型子类可以添加新字段、方法和属性;并且可以重写构造函数、可打印表示和相等运算符。以下示例定义了一个基本的 `TensorGraph` 类,使用三个 `Tensor` 字段来编码节点之间的一组边。然后,它会定义一个子类,添加一个 `Tensor` 字段来记录每个节点的“特征值”。该子类还会定义一个沿着边传播特征值的方法。" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.167712Z", "iopub.status.busy": "2022-12-14T20:14:14.167176Z", "iopub.status.idle": "2022-12-14T20:14:14.180628Z", "shell.execute_reply": "2022-12-14T20:14:14.180090Z" }, "id": "58r6qRiK-uZh" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original features: tf.Tensor([10. 0. 2. 5. -1. 0.], shape=(6,), dtype=float32)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "After propagating: tf.Tensor([10. 12. 4. 4. -1. 
0.], shape=(6,), dtype=float32)\n" ] } ], "source": [ "class TensorGraph(tf.experimental.ExtensionType):\n", "  num_nodes: tf.Tensor\n", "  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.\n", "  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.\n", "\n", "class TensorGraphWithNodeFeature(TensorGraph):\n", "  node_features: tf.Tensor  # node_features[n] = feature value for node n.\n", "\n", "  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':\n", "    # Add each edge's source-node feature (scaled by `weight`) to its destination node.\n", "    updates = tf.gather(self.node_features, self.edge_src) * weight\n", "    new_node_features = tf.tensor_scatter_nd_add(\n", "        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)\n", "    return TensorGraphWithNodeFeature(\n", "        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)\n", "\n", "g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1\n", "    num_nodes=6, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],\n", "    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])\n", "\n", "print(\"Original features:\", g.node_features)\n", "print(\"After propagating:\", g.propagate_features().node_features)" ] }, { "cell_type": "markdown", "metadata": { "id": "U_oElT5HzqSG" }, "source": [ "### Defining private fields\n", "\n", "An extension type's fields may be marked private by prefixing them with an underscore (following standard Python conventions). This does not affect how TensorFlow treats the fields in any way. It simply signals to any users of the extension type that those fields are private.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "oMdH7ORqh8Pl" }, "source": [ "### Customizing the ExtensionType's `TypeSpec`\n", "\n", "Each `ExtensionType` class has a corresponding `TypeSpec` class, which is created automatically and stored as `.Spec`. For more information, see the section \"Nested TypeSpec\" above.\n", "\n", "To customize the `TypeSpec`, simply define your own nested class named `Spec`. `ExtensionType` will use it as the basis for the automatically constructed `TypeSpec`. You can customize the `Spec` class by:\n", "\n", "- Overriding the default printable representation.\n", "- Overriding the default constructor.\n", "- Defining methods, classmethods, staticmethods, and properties.\n", "\n", "The following example customizes the `MaskedTensor.Spec` class to make it easier to use:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.183963Z", "iopub.status.busy": "2022-12-14T20:14:14.183522Z", "iopub.status.idle": 
"2022-12-14T20:14:14.188633Z", "shell.execute_reply": "2022-12-14T20:14:14.188107Z" }, "id": "Gm4RaqbkLlNG" }, "outputs": [], "source": [ "class MaskedTensor(tf.experimental.ExtensionType):\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "  shape = property(lambda self: self.values.shape)\n", "  dtype = property(lambda self: self.values.dtype)\n", "\n", "  def __repr__(self):\n", "    return masked_tensor_str(self.values, self.mask)\n", "\n", "  def with_values(self, new_values):\n", "    return MaskedTensor(new_values, self.mask)\n", "\n", "  class Spec:\n", "    def __init__(self, shape, dtype=tf.float32):\n", "      self.values = tf.TensorSpec(shape, dtype)\n", "      self.mask = tf.TensorSpec(shape, tf.bool)\n", "\n", "    def __repr__(self):\n", "      return f\"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})\"\n", "\n", "    shape = property(lambda self: self.values.shape)\n", "    dtype = property(lambda self: self.values.dtype)" ] }, { "cell_type": "markdown", "metadata": { "id": "s3zzUXPSNF72" }, "source": [ "**Note:** The custom `Spec` class may not use any instance variables that were not declared in the original `ExtensionType`." ] }, { "cell_type": "markdown", "metadata": { "id": "rip4GCuYPL7o" }, "source": [ "## Tensor API dispatch\n", "\n", "Extension types can be \"tensor-like\", in the sense that they specialize or extend the interface defined by the `tf.Tensor` type. Examples of tensor-like extension types include `RaggedTensor`, `SparseTensor`, and `MaskedTensor`. ***Dispatch decorators*** can be used to override the default behavior of TensorFlow operations when they are applied to tensor-like extension types. TensorFlow currently defines three dispatch decorators:\n", "\n", "- `@tf.experimental.dispatch_for_api(tf_api)`\n", "- `@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)`\n", "- `@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)`" ] }, { "cell_type": "markdown", "metadata": { "id": "5BTQHcY4gHwZ" }, "source": [ "### Dispatch for a single API\n", "\n", "The `tf.experimental.dispatch_for_api` decorator overrides the default behavior of a specified TensorFlow operation when it is called with the specified signature. For example, you can use this decorator to specify how `tf.stack` should process `MaskedTensor` values:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.192144Z", "iopub.status.busy": "2022-12-14T20:14:14.191672Z", 
"iopub.status.idle": "2022-12-14T20:14:14.195390Z", "shell.execute_reply": "2022-12-14T20:14:14.194868Z" }, "id": "B4QgO_fUW2o2" }, "outputs": [], "source": [ "@tf.experimental.dispatch_for_api(tf.stack)\n", "def masked_stack(values: List[MaskedTensor], axis=0):\n", "  return MaskedTensor(tf.stack([v.values for v in values], axis),\n", "                      tf.stack([v.mask for v in values], axis))" ] }, { "cell_type": "markdown", "metadata": { "id": "FxKcKWNUaLvm" }, "source": [ "This overrides the default implementation of `tf.stack` whenever it is called with a list of `MaskedTensor` values (since the `values` argument is annotated with `typing.List[MaskedTensor]`):" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.198377Z", "iopub.status.busy": "2022-12-14T20:14:14.197934Z", "iopub.status.idle": "2022-12-14T20:14:14.205485Z", "shell.execute_reply": "2022-12-14T20:14:14.204973Z" }, "id": "RqpFjaAvaA19" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = MaskedTensor([1, 2, 3], [True, True, False])\n", "y = MaskedTensor([4, 5, 6], [False, True, True])\n", "tf.stack([x, y])" ] }, { "cell_type": "markdown", "metadata": { "id": "loGi8taCa265" }, "source": [ "To allow `tf.stack` to handle lists of mixed `MaskedTensor` and `Tensor` values, you can refine the type annotation for the `values` parameter and update the body of the function appropriately:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.208703Z", "iopub.status.busy": "2022-12-14T20:14:14.208165Z", "iopub.status.idle": "2022-12-14T20:14:14.220314Z", "shell.execute_reply": "2022-12-14T20:14:14.219742Z" }, "id": "_xySkm0ganAI" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.experimental.unregister_dispatch_for(masked_stack)\n", "\n", "def convert_to_masked_tensor(x):\n", "  if isinstance(x, MaskedTensor):\n", "    return x\n", "  else:\n", "    # Wrap a plain tensor with an all-True mask (no elements masked out).\n", "    return MaskedTensor(x, tf.ones_like(x, tf.bool))\n", 
"\n", "@tf.experimental.dispatch_for_api(tf.stack)\n", "def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis=0):\n", "  values = [convert_to_masked_tensor(v) for v in values]\n", "  return MaskedTensor(tf.stack([v.values for v in values], axis),\n", "                      tf.stack([v.mask for v in values], axis))\n", "\n", "x = MaskedTensor([1, 2, 3], [True, True, False])\n", "y = tf.constant([4, 5, 6])\n", "tf.stack([x, y, x])" ] }, { "cell_type": "markdown", "metadata": { "id": "ITioFCyjQm8V" }, "source": [ "For a list of APIs that can be overridden, see the API documentation for `tf.experimental.dispatch_for_api`." ] }, { "cell_type": "markdown", "metadata": { "id": "f91SaHSqc-jO" }, "source": [ "### Dispatch for all unary elementwise APIs\n", "\n", "The `tf.experimental.dispatch_for_unary_elementwise_apis` decorator overrides the default behavior of ***all*** unary elementwise ops (such as `tf.math.cos`) whenever the value for the first argument (typically named `x`) matches the type annotation `x_type`. The decorated function should take two arguments:\n", "\n", "- `api_func`: A function that takes a single parameter and performs the elementwise operation (for example, `tf.abs`).\n", "- `x`: The first argument to the elementwise operation.\n", "\n", "The following example updates all unary elementwise operations to handle the `MaskedTensor` type:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.223569Z", "iopub.status.busy": "2022-12-14T20:14:14.223104Z", "iopub.status.idle": "2022-12-14T20:14:14.235529Z", "shell.execute_reply": "2022-12-14T20:14:14.234991Z" }, "id": "cv5fV4xxZI9q" }, "outputs": [], "source": [ "@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)\n", "def masked_tensor_unary_elementwise_api_handler(api_func, x):\n", "  # Apply the op to the values and keep the mask unchanged.\n", "  return MaskedTensor(api_func(x.values), x.mask)" ] }, { "cell_type": "markdown", "metadata": { "id": "qiK4n6vaeFwo" }, "source": [ "Now this function will be used whenever a unary elementwise operation is called on a `MaskedTensor`." ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.238379Z", "iopub.status.busy": "2022-12-14T20:14:14.237946Z", "iopub.status.idle": "2022-12-14T20:14:14.243124Z", "shell.execute_reply": "2022-12-14T20:14:14.242513Z" }, "id": "SkH0xi5gd_41" }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "\n" ] } ], "source": [ "x = MaskedTensor([1, -2, -3], [True, False, True])\n", "print(tf.abs(x))" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.245940Z", "iopub.status.busy": "2022-12-14T20:14:14.245495Z", "iopub.status.idle": "2022-12-14T20:14:14.250687Z", "shell.execute_reply": "2022-12-14T20:14:14.250115Z" }, "id": "2Ej5fxLBfaXW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "print(tf.ones_like(x, dtype=tf.float32))" ] }, { "cell_type": "markdown", "metadata": { "id": "Z9OgLyfEejqc" }, "source": [ "### Dispatch for all binary elementwise APIs\n", "\n", "Similarly, `tf.experimental.dispatch_for_binary_elementwise_apis` can be used to update all binary elementwise operations to handle the `MaskedTensor` type:\n" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.253722Z", "iopub.status.busy": "2022-12-14T20:14:14.253290Z", "iopub.status.idle": "2022-12-14T20:14:14.362546Z", "shell.execute_reply": "2022-12-14T20:14:14.361874Z" }, "id": "Z8Du-GPofpCW" }, "outputs": [], "source": [ "@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)\n", "def masked_tensor_binary_elementwise_api_handler(api_func, x, y):\n", "  # A result element is valid (mask=True) only if it is valid in both inputs.\n", "  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.365808Z", "iopub.status.busy": "2022-12-14T20:14:14.365569Z", "iopub.status.idle": "2022-12-14T20:14:14.373670Z", "shell.execute_reply": "2022-12-14T20:14:14.373096Z" }, "id": "gghVHDfSfyi2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = MaskedTensor([1, -2, -3], [True, False, True])\n", "y = MaskedTensor([[4], [5]], [[True], [False]])\n", "tf.math.add(x, y)" ] }, { "cell_type": "markdown", "metadata": { "id": 
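"binaryDispatchSketch" }, "source": [ "Because the handler above is registered for every binary elementwise API, other binary ops on two `MaskedTensor` values are routed to the same function. The following minimal sketch reuses `x` and `y` from the previous cell. `tf.math.maximum` and `tf.math.subtract` are simply two examples of binary elementwise APIs. In each case the result should be a `MaskedTensor` whose mask is the broadcast logical AND of the two input masks." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: reuses `x` and `y` from the previous cell. Both calls below are routed to\n", "# masked_tensor_binary_elementwise_api_handler, so each result is a MaskedTensor.\n", "print(tf.math.maximum(x, y))\n", "print(tf.math.subtract(x, y))" ] }, { "cell_type": "markdown", "metadata": { "id": 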
"txTGg9pzG0Ux" }, "source": [ "For a list of the elementwise APIs that are overridden, see the API documentation for `tf.experimental.dispatch_for_unary_elementwise_apis` and `tf.experimental.dispatch_for_binary_elementwise_apis`." ] }, { "cell_type": "markdown", "metadata": { "id": "UseRtohYKiE5" }, "source": [ "## Batchable `ExtensionType`s\n", "\n", "An `ExtensionType` is *batchable* if a single instance can be used to represent a batch of values. Typically, this is accomplished by adding batch dimensions to all nested `Tensor`s. The following TensorFlow APIs require that any extension type inputs be batchable:\n", "\n", "- `tf.data.Dataset` (`batch`, `unbatch`, `from_tensor_slices`)\n", "- `tf.keras` (`fit`, `evaluate`, `predict`)\n", "- `tf.map_fn`" ] }, { "cell_type": "markdown", "metadata": { "id": "hWPauKGj_yRz" }, "source": [ "By default, `BatchableExtensionType` creates batched values by batching any nested `Tensor`s, `CompositeTensor`s, and `ExtensionType`s. If this is not appropriate for your class, then you will need to use `tf.experimental.ExtensionTypeBatchEncoder` to override this default behavior. For example, it would not be appropriate to create a batch of `tf.SparseTensor` values by simply stacking the `values`, `indices`, and `dense_shape` fields of the individual sparse tensors. In most cases, you can't stack these tensors, since they have incompatible shapes; and even if you could, the result would not be a valid `SparseTensor`.\n", "\n", "**Note:** `BatchableExtensionType`s do *not* automatically define dispatchers for `tf.stack`, `tf.concat`, `tf.slice`, etc. If your class needs to be supported by these APIs, then use the dispatch decorators described above." ] }, { "cell_type": "markdown", "metadata": { "id": "xkOJ8ke8GH7s" }, "source": [ "### `BatchableExtensionType` example: `Network`\n", "\n", "As an example, consider a simple `Network` class used for load balancing. It tracks how much work is left to do at each node, and how much bandwidth is available to move work between nodes:" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.377155Z", "iopub.status.busy": "2022-12-14T20:14:14.376714Z", "iopub.status.idle": "2022-12-14T20:14:14.381966Z", "shell.execute_reply": "2022-12-14T20:14:14.381380Z" }, "id": "tOeEXwCcfrPd" }, "outputs": [], "source": [ "class Network(tf.experimental.ExtensionType):  # This version is not batchable.\n", "  work: tf.Tensor       # work[n] = work left to do at node n\n", "  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2\n", "\n", "net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])\n", "net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])" ] }, { "cell_type": "markdown", 
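"metadata": {}, "source": [ "One way to see why this version is not batchable is to inspect its `TypeSpec`. The following minimal sketch reuses `net1` from the cell above and prints the result of `tf.type_spec_from_value`. The spec simply records the shape of each field (`(3,)` for `work` and `(3, 3)` for `bandwidth`). The class derives from `tf.experimental.ExtensionType` rather than `tf.experimental.BatchableExtensionType`, so the batching APIs listed above have no supported way to batch or unbatch it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: reuses the non-batchable `Network` class and `net1` from the cell above.\n", "spec = tf.type_spec_from_value(net1)\n", "print(spec)  # one TensorSpec per field: work (3,), bandwidth (3, 3)\n", "# This version subclasses ExtensionType, not BatchableExtensionType.\n", "print(isinstance(net1, tf.experimental.BatchableExtensionType))  # False" ] }, { "cell_type": "markdown", 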
"metadata": { "id": "PaOzUev6g3wT" }, "source": [ "To make this type batchable, change the base type to `BatchableExtensionType`, and adjust the shape of each field to include optional batch dimensions. The following example also adds a `shape` field to keep track of the batch shape. This `shape` field is not required by `tf.data.Dataset` or `tf.map_fn`, but it *is* required by `tf.keras`." ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.384958Z", "iopub.status.busy": "2022-12-14T20:14:14.384522Z", "iopub.status.idle": "2022-12-14T20:14:14.390246Z", "shell.execute_reply": "2022-12-14T20:14:14.389633Z" }, "id": "T03WWBSMg2XC" }, "outputs": [], "source": [ "class Network(tf.experimental.BatchableExtensionType):\n", "  shape: tf.TensorShape  # batch shape. A single network has shape=[].\n", "  work: tf.Tensor        # work[*shape, n] = work left to do at node n\n", "  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2\n", "\n", "  def __init__(self, work, bandwidth):\n", "    self.work = tf.convert_to_tensor(work)\n", "    self.bandwidth = tf.convert_to_tensor(bandwidth)\n", "    # The batch shape is every dimension except the trailing node dimension(s).\n", "    work_batch_shape = self.work.shape[:-1]\n", "    bandwidth_batch_shape = self.bandwidth.shape[:-2]\n", "    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)\n", "\n", "  def __repr__(self):\n", "    return network_repr(self)\n", "\n", "def network_repr(network):\n", "  work = network.work\n", "  bandwidth = network.bandwidth\n", "  if hasattr(work, 'numpy'):\n", "    work = ' '.join(str(work.numpy()).split())\n", "  if hasattr(bandwidth, 'numpy'):\n", "    bandwidth = ' '.join(str(bandwidth.numpy()).split())\n", "  return (f\"<Network shape={network.shape} work={work} bandwidth={bandwidth}>\")" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.393112Z", "iopub.status.busy": "2022-12-14T20:14:14.392669Z", "iopub.status.idle": "2022-12-14T20:14:14.401547Z", "shell.execute_reply": "2022-12-14T20:14:14.401013Z" }, "id": "NUUJe9HuIPel" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "net1=\n", "net2=\n", "batch=\n" ] } ], "source": [ "net1 = Network([5., 3, 8], [[0., 2, 
0], [2, 0, 3], [0, 3, 0]])\n", "net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])\n", "batch_of_networks = Network(\n", "    work=tf.stack([net1.work, net2.work]),\n", "    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))\n", "print(f\"net1={net1}\")\n", "print(f\"net2={net2}\")\n", "print(f\"batch={batch_of_networks}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "r0qWur5JGc3d" }, "source": [ "You can then use `tf.data.Dataset` to iterate through a batch of networks:" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.404552Z", "iopub.status.busy": "2022-12-14T20:14:14.404090Z", "iopub.status.idle": "2022-12-14T20:14:14.420788Z", "shell.execute_reply": "2022-12-14T20:14:14.420161Z" }, "id": "BN_kixAUFZtv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Batch element 0: \n", "Batch element 1: \n" ] } ], "source": [ "dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)\n", "for i, network in enumerate(dataset):\n", "  print(f\"Batch element {i}: {network}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "aXENhTzIIjbM" }, "source": [ "You can also use `map_fn` to apply a function to each batch element:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.423694Z", "iopub.status.busy": "2022-12-14T20:14:14.423455Z", "iopub.status.idle": "2022-12-14T20:14:14.489592Z", "shell.execute_reply": "2022-12-14T20:14:14.489013Z" }, "id": "j1XEsSWj9a3D" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def balance_work_greedy(network):\n", "  # delta[..., i, j] = (work[i] - work[j]) / 4, clipped to +/- bandwidth[i, j].\n", "  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))\n", "  delta /= 4\n", "  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)\n", "  # Subtracting the row sums moves work from busier nodes to less busy ones.\n", "  new_work = network.work - tf.reduce_sum(delta, -1)\n", "  return Network(new_work, network.bandwidth)\n", "\n", "tf.map_fn(balance_work_greedy, 
batch_of_networks)" ] }, { "cell_type": "markdown", "metadata": { "id": "f_HLsTT02Xul" }, "source": [ "## TensorFlow APIs that support `ExtensionType`s" ] }, { "cell_type": "markdown", "metadata": { "id": "NNiQad2U2alT" }, "source": [ "### @tf.function\n", "\n", "[tf.function](https://tensorflow.google.cn/guide/function) is a decorator that precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Extension type values can be used transparently with `@tf.function`-decorated functions." ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.492759Z", "iopub.status.busy": "2022-12-14T20:14:14.492252Z", "iopub.status.idle": "2022-12-14T20:14:14.528424Z", "shell.execute_reply": "2022-12-14T20:14:14.527868Z" }, "id": "jQ_rAvrA6qEb" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Pastry(tf.experimental.ExtensionType):\n", "  sweetness: tf.Tensor  # 2d embedding that encodes sweetness\n", "  chewiness: tf.Tensor  # 2d embedding that encodes chewiness\n", "\n", "@tf.function\n", "def combine_pastry_features(x: Pastry):\n", "  return (x.sweetness + x.chewiness) / 2\n", "\n", "cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])\n", "combine_pastry_features(cookie)" ] }, { "cell_type": "markdown", "metadata": { "id": "u1P-0Udg71Vx" }, "source": [ "If you wish to explicitly specify the `input_signature` for `tf.function`, you can do so using the extension type's `TypeSpec`." ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.531557Z", "iopub.status.busy": "2022-12-14T20:14:14.531098Z", "iopub.status.idle": "2022-12-14T20:14:14.569291Z", "shell.execute_reply": "2022-12-14T20:14:14.568655Z" }, "id": "0df90E4x78d7" }, "outputs": [ { "data": { "text/plain": [ "Pastry(sweetness=, chewiness=)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec([2]))\n", "\n", 
"@tf.function(input_signature=[pastry_spec])\n", "def increase_sweetness(x: Pastry, delta=1.0):\n", "  return Pastry(x.sweetness + delta, x.chewiness)\n", "\n", "increase_sweetness(cookie)" ] }, { "cell_type": "markdown", "metadata": { "id": "CdTfc5nD9JpD" }, "source": [ "#### Concrete functions\n", "\n", "Concrete functions encapsulate the individual traced graphs that are built by `tf.function`. Extension types can be used transparently with concrete functions.\n" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.572307Z", "iopub.status.busy": "2022-12-14T20:14:14.571892Z", "iopub.status.idle": "2022-12-14T20:14:14.577055Z", "shell.execute_reply": "2022-12-14T20:14:14.576517Z" }, "id": "FyHBBQWk9xz2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cf = combine_pastry_features.get_concrete_function(pastry_spec)\n", "cf(cookie)" ] }, { "cell_type": "markdown", "metadata": { "id": "LYas8gtG5IMA" }, "source": [ "### Control flow operations\n", "\n", "Extension types are supported by TensorFlow's control-flow operations:\n", "\n", "- `tf.cond`\n", "- `tf.case`\n", "- `tf.while_loop`\n", "- `tf.identity`\n" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.580100Z", "iopub.status.busy": "2022-12-14T20:14:14.579602Z", "iopub.status.idle": "2022-12-14T20:14:14.584862Z", "shell.execute_reply": "2022-12-14T20:14:14.584318Z" }, "id": "6G2XE9ZtJu8z" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Example: using tf.cond to select between two MaskedTensors. 
Note that the\n", "# two MaskedTensors don't need to have the same shape.\n", "a = MaskedTensor([1., 2, 3], [True, False, True])\n", "b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])\n", "condition = tf.constant(True)\n", "print(tf.cond(condition, lambda: a, lambda: b))" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.587690Z", "iopub.status.busy": "2022-12-14T20:14:14.587210Z", "iopub.status.idle": "2022-12-14T20:14:14.593819Z", "shell.execute_reply": "2022-12-14T20:14:14.593271Z" }, "id": "2NwLOw1kKSek" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Example: using tf.while_loop with MaskedTensor.\n", "cond = lambda i, _: i < 10\n", "def body(i, mt):\n", "  return i + 1, mt.with_values(mt.values + 3 / 7)\n", "print(tf.while_loop(cond, body, [0, b])[1])" ] }, { "cell_type": "markdown", "metadata": { "id": "zkN7IuWVMRzn" }, "source": [ "### Autograph control flow\n", "\n", "Control flow statements in `tf.function` also support extension types (using autograph). In the following example, the `if` statement and the `for` statement are automatically converted into `tf.cond` and `tf.while_loop` operations, which support extension types." ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.596816Z", "iopub.status.busy": "2022-12-14T20:14:14.596291Z", "iopub.status.idle": "2022-12-14T20:14:14.738167Z", "shell.execute_reply": "2022-12-14T20:14:14.737594Z" }, "id": "4RFySEl8gZ8w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] } ], "source": [ "@tf.function\n", "def fn(x, b):\n", "  if b:\n", "    x = MaskedTensor(x, tf.less(x, 0))\n", "  else:\n", "    x = MaskedTensor(x, tf.greater(x, 0))\n", "  for i in tf.range(5 if b else 7):\n", "    x = x.with_values(x.values + 1 / 2)\n", "  return x\n", "\n", "print(fn(tf.constant([1., -2, 3]), tf.constant(True)))\n", "print(fn(tf.constant([1., -2, 3]), tf.constant(False)))" ] }, { "cell_type": "markdown", "metadata": { "id": "-FjZt2ohfja4" }, "source": [ 
"### Keras\n", "\n", "[tf.keras](https://tensorflow.google.cn/guide/keras) is TensorFlow's high-level API for building and training deep learning models. Extension types may be passed as inputs to a Keras model, passed between Keras layers, and returned by Keras models. Keras currently puts two requirements on extension types:\n", "\n", "- They must be batchable (see \"Batchable `ExtensionType`s\" above).\n", "- They must have a field or property named `shape`. `shape[0]` is assumed to be the batch dimension.\n", "\n", "The following two subsections give examples showing how extension types can be used with Keras.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "QH1TXQYiGv8u" }, "source": [ "#### Keras example: `Network`\n", "\n", "For the first example, consider the `Network` class defined in the \"Batchable `ExtensionType`s\" section above, which can be used for load-balancing work between nodes. Its definition is repeated here:" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.741621Z", "iopub.status.busy": "2022-12-14T20:14:14.741163Z", "iopub.status.idle": "2022-12-14T20:14:14.745602Z", "shell.execute_reply": "2022-12-14T20:14:14.745051Z" }, "id": "zHj1RIS2PK50" }, "outputs": [], "source": [ "class Network(tf.experimental.BatchableExtensionType):\n", "  shape: tf.TensorShape  # batch shape. A single network has shape=[].\n", "  work: tf.Tensor        # work[*shape, n] = work left to do at node n\n", "  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2\n", "\n", "  def __init__(self, work, bandwidth):\n", "    self.work = tf.convert_to_tensor(work)\n", "    self.bandwidth = tf.convert_to_tensor(bandwidth)\n", "    work_batch_shape = self.work.shape[:-1]\n", "    bandwidth_batch_shape = self.bandwidth.shape[:-2]\n", "    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)\n", "\n", "  def __repr__(self):\n", "    return network_repr(self)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.748578Z", "iopub.status.busy": "2022-12-14T20:14:14.748030Z", "iopub.status.idle": "2022-12-14T20:14:14.753251Z", "shell.execute_reply": "2022-12-14T20:14:14.752715Z" }, "id": "w9LPTEVJD0FD" }, "outputs": [], "source": [ "single_network = Network(  # A single network with 4 nodes.\n", "    work=[8.0, 5, 12, 2],\n", "
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])\n", "\n", "batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.\n", "    work=[[8.0, 5], [3, 2]],\n", "    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])" ] }, { "cell_type": "markdown", "metadata": { "id": "IUfWi3SDD0dj" }, "source": [ "You can define a new Keras layer that processes `Network`s." ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.755945Z", "iopub.status.busy": "2022-12-14T20:14:14.755729Z", "iopub.status.idle": "2022-12-14T20:14:14.759169Z", "shell.execute_reply": "2022-12-14T20:14:14.758576Z" }, "id": "2WSYt58r4SF1" }, "outputs": [], "source": [ "class BalanceNetworkLayer(tf.keras.layers.Layer):\n", "  \"\"\"Layer that balances work between nodes in a network.\n", "\n", "  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.\n", "  \"\"\"\n", "  def call(self, inputs):\n", "    # This function is defined above in the \"Batchable `ExtensionType`s\" section.\n", "    return balance_work_greedy(inputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "VWwFJNb1E03q" }, "source": [ "You can then use this layer to create a simple model. To feed an `ExtensionType` into a model, you can use a `tf.keras.layers.Input` layer with `type_spec` set to the extension type's `TypeSpec`. If the Keras model will be used to process batches, then the `type_spec` must include the batch dimension." ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.762295Z", "iopub.status.busy": "2022-12-14T20:14:14.761763Z", "iopub.status.idle": "2022-12-14T20:14:14.797676Z", "shell.execute_reply": "2022-12-14T20:14:14.797130Z" }, "id": "plTyqISRExA4" }, "outputs": [], "source": [ "input_spec = Network.Spec(shape=None,\n", "                          work=tf.TensorSpec(None, tf.float32),\n", "                          bandwidth=tf.TensorSpec(None, tf.float32))\n", "model = tf.keras.Sequential([\n", "    tf.keras.layers.Input(type_spec=input_spec),\n", "    BalanceNetworkLayer(),\n", "    ])" ] }, { "cell_type": "markdown", "metadata": { "id": "hyeAbt1WFIiO" }, 
"source": [ "Finally, you can apply the model to a single network and to a batch of networks." ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.800870Z", "iopub.status.busy": "2022-12-14T20:14:14.800433Z", "iopub.status.idle": "2022-12-14T20:14:14.806129Z", "shell.execute_reply": "2022-12-14T20:14:14.805612Z" }, "id": "hH1EtA5lFHdN" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(single_network)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.809112Z", "iopub.status.busy": "2022-12-14T20:14:14.808607Z", "iopub.status.idle": "2022-12-14T20:14:14.814238Z", "shell.execute_reply": "2022-12-14T20:14:14.813678Z" }, "id": "V7eM67M7FYYM" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(batch_of_networks)" ] }, { "cell_type": "markdown", "metadata": { "id": "tOxtt9Z1HDCv" }, "source": [ "#### Keras example: `MaskedTensor`\n", "\n", "In this example, `MaskedTensor` is extended to support `Keras`. `shape` is defined as a property that is calculated from the `values` field. Keras requires that you add this property to both the extension type and its `TypeSpec`. `MaskedTensor` also defines a `__name__` variable, which is required for `SavedModel` serialization (see below)." ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.817215Z", "iopub.status.busy": "2022-12-14T20:14:14.816693Z", "iopub.status.idle": "2022-12-14T20:14:14.822400Z", "shell.execute_reply": "2022-12-14T20:14:14.821872Z" }, "id": "1JBZ_t48Ht7e" }, "outputs": [], "source": [ "class MaskedTensor(tf.experimental.BatchableExtensionType):\n", "  # __name__ is required for serialization in SavedModel; see below for details.\n", "  __name__ = 'extension_type_colab.MaskedTensor'\n", "\n", "  values: tf.Tensor\n", "  mask: tf.Tensor\n", "\n", "  shape = property(lambda self: self.values.shape)\n", "  dtype = property(lambda self: self.values.dtype)\n", 
"\n", "  def with_default(self, default):\n", "    return tf.where(self.mask, self.values, default)\n", "\n", "  def __repr__(self):\n", "    return masked_tensor_str(self.values, self.mask)\n", "\n", "  class Spec:\n", "    def __init__(self, shape, dtype=tf.float32):\n", "      self.values = tf.TensorSpec(shape, dtype)\n", "      self.mask = tf.TensorSpec(shape, tf.bool)\n", "\n", "    shape = property(lambda self: self.values.shape)\n", "    dtype = property(lambda self: self.values.dtype)\n", "\n", "    def with_shape(self, shape):\n", "      # Spec.__init__ builds both the values and mask TensorSpecs from `shape`.\n", "      return MaskedTensor.Spec(shape, self.values.dtype)" ] }, { "cell_type": "markdown", "metadata": { "id": "oer8BVc8H7_V" }, "source": [ "Next, the dispatch decorators are used to override the default behavior of several TensorFlow APIs. Since these APIs are used by standard Keras layers (such as the `Dense` layer), overriding them allows those layers to be used with `MaskedTensor`. For the purposes of this example, `matmul` for masked tensors is defined to treat the masked values as zeros (that is, to not include them in the product)." ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.825418Z", "iopub.status.busy": "2022-12-14T20:14:14.824899Z", "iopub.status.idle": "2022-12-14T20:14:14.843289Z", "shell.execute_reply": "2022-12-14T20:14:14.842756Z" }, "id": "xy0dhQ_b-ca_" }, "outputs": [], "source": [ "@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)\n", "def unary_elementwise_op_handler(op, x):\n", "  return MaskedTensor(op(x.values), x.mask)\n", "\n", "@tf.experimental.dispatch_for_binary_elementwise_apis(\n", "    Union[MaskedTensor, tf.Tensor],\n", "    Union[MaskedTensor, tf.Tensor])\n", "def binary_elementwise_op_handler(op, x, y):\n", "  x = convert_to_masked_tensor(x)\n", "  y = convert_to_masked_tensor(y)\n", "  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)\n", "\n", "@tf.experimental.dispatch_for_api(tf.matmul)\n", "def masked_matmul(a: MaskedTensor, b,\n", "                  transpose_a=False, transpose_b=False,\n", "                  adjoint_a=False, adjoint_b=False,\n", "                  a_is_sparse=False, b_is_sparse=False,\n", "                  output_type=None):\n", "  # Replace masked-out entries with zeros, then use the standard matmul.\n", "  if isinstance(a, 
MaskedTensor):\n", "    a = a.with_default(0)\n", "  if isinstance(b, MaskedTensor):\n", "    b = b.with_default(0)\n", "  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,\n", "                   adjoint_b, a_is_sparse, b_is_sparse, output_type)" ] }, { "cell_type": "markdown", "metadata": { "id": "osJ_L-fKJusI" }, "source": [ "You can then build a Keras model that accepts `MaskedTensor` inputs, using standard Keras layers:" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.845928Z", "iopub.status.busy": "2022-12-14T20:14:14.845713Z", "iopub.status.idle": "2022-12-14T20:14:14.890997Z", "shell.execute_reply": "2022-12-14T20:14:14.890442Z" }, "id": "IS6JCVbk1rd0" }, "outputs": [], "source": [ "input_spec = MaskedTensor.Spec([None, 2], tf.float32)\n", "\n", "masked_tensor_model = tf.keras.Sequential([\n", "    tf.keras.layers.Input(type_spec=input_spec),\n", "    tf.keras.layers.Dense(16, activation=\"relu\"),\n", "    tf.keras.layers.Dense(1)])\n", "masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:14.894304Z", "iopub.status.busy": "2022-12-14T20:14:14.893758Z", "iopub.status.idle": "2022-12-14T20:14:18.136649Z", "shell.execute_reply": "2022-12-14T20:14:18.135898Z" }, "id": "SB1WUSzn1RPj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.6819" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 3s 3s/step - loss: 0.6819\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 
[==============================] - ETA: 0s - loss: 0.6239" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 5ms/step - loss: 0.6239\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.5903" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 5ms/step - loss: 0.5903\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0.18340722]\n", " [-0.08917451]\n", " [ 1.3972318 ]], shape=(3, 1), dtype=float32)\n" ] } ], "source": [ "a = MaskedTensor([[1., 2], [3, 4], [5, 6]],\n", "                 [[True, False], [False, True], [True, True]])\n", "masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)\n", "print(masked_tensor_model(a))" ] }, { "cell_type": "markdown", "metadata": { "id": "msmd9XcL2bqb" }, "source": [ "### SavedModel\n", "\n", "A [SavedModel](https://tensorflow.google.cn/guide/saved_model) is a serialized TensorFlow program, including both weights and computation. It can be built from a Keras model or from a custom model. In either case, extension types can be used transparently with the functions and methods defined by a SavedModel.\n", "\n", "SavedModel can save models, layers, and functions that process extension types, as long as the extension types have a `__name__` field. This name is used to register the extension type, so it can be located when the model is loaded." ] }, { "cell_type": "markdown", "metadata": { "id": "PEtbFrz6-Vku" }, "source": [ "#### Example: saving a Keras model\n", "\n", "Keras models that use extension types may be saved using `SavedModel`." ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:18.140494Z", "iopub.status.busy": "2022-12-14T20:14:18.139889Z", "iopub.status.idle": "2022-12-14T20:14:18.616286Z", "shell.execute_reply": "2022-12-14T20:14:18.615627Z" }, "id": 
"ecxQMnybSzV6" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkj9uvl1_/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkj9uvl1_/assets\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "masked_tensor_model_path = tempfile.mkdtemp()\n", "tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)\n", "imported_model = tf.saved_model.load(masked_tensor_model_path)\n", "imported_model(a)" ] }, { "cell_type": "markdown", "metadata": { "id": "Ne2nu3r6-XMr" }, "source": [ "#### Example: saving a custom model\n", "\n", "SavedModel can also be used to save custom `tf.Module` subclasses with functions that process extension types." ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:18.620185Z", "iopub.status.busy": "2022-12-14T20:14:18.619462Z", "iopub.status.idle": "2022-12-14T20:14:18.739436Z", "shell.execute_reply": "2022-12-14T20:14:18.738748Z" }, "id": "2V6hV3yOT2vz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpq2zpskdx/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpq2zpskdx/assets\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomModule(tf.Module):\n", "  def __init__(self, variable_value):\n", "    super().__init__()\n", "    self.v = tf.Variable(variable_value)\n", "\n", "  @tf.function\n", "  def grow(self, x: MaskedTensor):\n", "    \"\"\"Increase values in `x` by multiplying them by `self.v`.\"\"\"\n",
"    return MaskedTensor(x.values * self.v, x.mask)\n", "\n", "module = CustomModule(100.0)\n", "\n", "module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,\n", "                                                    dtype=tf.float32))\n", "custom_module_path = tempfile.mkdtemp()\n", "tf.saved_model.save(module, custom_module_path)\n", "imported_model = tf.saved_model.load(custom_module_path)\n", "imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))" ] }, { "cell_type": "markdown", "metadata": { "id": "o6beljh576ee" }, "source": [ "#### Loading a SavedModel when the ExtensionType is unavailable\n", "\n", "If you load a `SavedModel` that uses an `ExtensionType`, but that `ExtensionType` is not available (i.e., it has not been imported), TensorFlow emits a warning and falls back to using an \"anonymous extension type\" object. This object has the same fields as the original type, but lacks any further customization you added for the type, such as custom methods or properties." ] }, { "cell_type": "markdown", "metadata": { "id": "ec9PcUkJ9bFK" }, "source": [ "#### Using `ExtensionType` with TensorFlow Serving\n", "\n", "Currently, [TensorFlow Serving](https://tensorflow.google.cn/tfx/guide/serving) (and other consumers of the SavedModel \"signatures\" dictionary) requires all inputs and outputs to be raw tensors. If you want to use TensorFlow Serving with a model that uses extension types, you can add wrapper methods that compose extension type values from tensors, or decompose them into tensors. For example:" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:18.743110Z", "iopub.status.busy": "2022-12-14T20:14:18.742537Z", "iopub.status.idle": "2022-12-14T20:14:18.945909Z", "shell.execute_reply": "2022-12-14T20:14:18.945301Z" }, "id": "4VnzAwVo9tTc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbv2z853b/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbv2z853b/assets\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomModuleWrapper(tf.Module):\n", "  def __init__(self, variable_value):\n", "    super().__init__()\n", "    self.v = tf.Variable(variable_value)\n", "\n", "  @tf.function\n", "  def var_weighted_mean(self, x: MaskedTensor):\n", "
    \"\"\"Mean value of unmasked values in x, weighted by self.v.\"\"\"\n", "    x = MaskedTensor(x.values * self.v, x.mask)\n", "    return (tf.reduce_sum(x.with_default(0)) /\n", "            tf.reduce_sum(tf.cast(x.mask, x.dtype)))\n", "\n", "  @tf.function()\n", "  def var_weighted_mean_wrapper(self, x_values, x_mask):\n", "    \"\"\"Raw tensor wrapper for var_weighted_mean.\"\"\"\n", "    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))\n", "\n", "module = CustomModuleWrapper([3., 2., 8., 5.])\n", "\n", "module.var_weighted_mean_wrapper.get_concrete_function(\n", "    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))\n", "custom_module_path = tempfile.mkdtemp()\n", "tf.saved_model.save(module, custom_module_path)\n", "imported_model = tf.saved_model.load(custom_module_path)\n", "x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])\n", "imported_model.var_weighted_mean_wrapper(x.values, x.mask)" ] }, { "cell_type": "markdown", "metadata": { "id": "4dwBadWQ5G9_" }, "source": [ "### Datasets\n", "\n", "[tf.data](https://tensorflow.google.cn/guide/data) is an API that lets you build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements in which each element consists of one or more components." ] }, { "cell_type": "markdown", "metadata": { "id": "GcIR19FuwRJV" }, "source": [ "#### Building Datasets with extension types\n", "\n", "Datasets can be built from extension type values using `Dataset.from_tensors`, `Dataset.from_tensor_slices`, or `Dataset.from_generator`:" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:18.949690Z", "iopub.status.busy": "2022-12-14T20:14:18.949064Z", "iopub.status.idle": "2022-12-14T20:14:18.962452Z", "shell.execute_reply": "2022-12-14T20:14:18.961882Z" }, "id": "Oe7fRCkzwdub" }, "outputs": [ { "data": { "text/plain": [ "Pastry(sweetness=, chewiness=)" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = tf.data.Dataset.from_tensors(Pastry(5, 5))\n", "next(iter(ds))" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {
"execution": { "iopub.execute_input": "2022-12-14T20:14:18.965809Z", "iopub.status.busy": "2022-12-14T20:14:18.965294Z", "iopub.status.idle": "2022-12-14T20:14:18.982951Z", "shell.execute_reply": "2022-12-14T20:14:18.982321Z" }, "id": "fk9CD2fZx6yT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4], dtype=tf.bool))\n", "ds = tf.data.Dataset.from_tensor_slices(mt)\n", "for value in ds:\n", "  print(value)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:18.986127Z", "iopub.status.busy": "2022-12-14T20:14:18.985569Z", "iopub.status.idle": "2022-12-14T20:14:19.049862Z", "shell.execute_reply": "2022-12-14T20:14:19.049198Z" }, "id": "DGw8y87awsOJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "def value_gen():\n", "  for i in range(2, 7):\n", "    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])\n", "\n", "ds = tf.data.Dataset.from_generator(\n", "    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))\n", "for value in ds:\n", "  print(value)" ] }, { "cell_type": "markdown", "metadata": { "id": "wfEm4NInyqtj" }, "source": [ "#### Batching and unbatching Datasets with extension types\n", "\n", "Datasets with extension types can be batched and unbatched using `Dataset.batch` and `Dataset.unbatch`." ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:14:19.053247Z", "iopub.status.busy": "2022-12-14T20:14:19.052625Z", "iopub.status.idle": "2022-12-14T20:14:19.095149Z", "shell.execute_reply": "2022-12-14T20:14:19.094462Z" }, "id": "snoOUE1ay1rO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n" ] } ], "source": [ "batched_ds = ds.batch(2)\n", "for value in batched_ds:\n", "  print(value)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {
"execution": { "iopub.execute_input": "2022-12-14T20:14:19.098621Z", "iopub.status.busy": "2022-12-14T20:14:19.098174Z", "iopub.status.idle": "2022-12-14T20:14:19.146675Z", "shell.execute_reply": "2022-12-14T20:14:19.145999Z" }, "id": "f8PTky6EzBVY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", "\n" ] } ], "source": [ "unbatched_ds = batched_ds.unbatch()\n", "for value in unbatched_ds:\n", " print(value)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "extension_type.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }