{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ISubpr_SSsiM" }, "source": [ "##### Copyright 2020 The TensorFlow Authors.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T19:25:23.957806Z", "iopub.status.busy": "2024-01-11T19:25:23.957355Z", "iopub.status.idle": "2024-01-11T19:25:23.961193Z", "shell.execute_reply": "2024-01-11T19:25:23.960634Z" }, "id": "3jTMb1dySr3V" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "6DWfyNThSziV" }, "source": [ "# tf.function によるパフォーマンスの改善\n", "\n", "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "J122XQYG7W6w" }, "source": [ "TensorFlow 2 の Eager execution はデフォルトで有効になっています。ユーザーインターフェースは直感的で柔軟性に優れていますが(一度限りの演算の実行ははるかに簡単で高速に行われます)、パフォーマンスとデプロイ能力に影響がでることがあります。\n", "\n", "プログラムからグラフを作成するには、`tf.function` を使用できます。変換ツールで Python コードから Python に依存しないデータフローグラフを作成するため、パフォーマンスと移植性に優れたモデルを作成できます。また、`SavedModel` を使用する際に必要となります。\n", "\n", "このチュートリアルでは `tf.function` と AutoGraph の基本的な特徴についてひととおり確認します。\n", "\n", "主に次の内容と推奨事項について説明しています。\n", "\n", "- Eager モードでデバッグしてから、`@tf.function` でデコレートする。\n", "- オブジェクトミューテーションまたはリストの追加といった Python 側の効果に依存しないこと。\n", "- `tf.function` は TensorFlow 演算子と最も相性が良く、NumPy と Python 呼び出しは定数に変換される。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "SjvqpgepHJPd" }, "source": [ "## セットアップ" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:23.964957Z", "iopub.status.busy": "2024-01-11T19:25:23.964364Z", "iopub.status.idle": "2024-01-11T19:25:26.343390Z", "shell.execute_reply": "2024-01-11T19:25:26.342699Z" }, "id": "otIdN1TS8N7S" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 19:25:24.397188: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 19:25:24.397236: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 19:25:24.398868: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "I0xDjO4SHLUD" }, "source": [ "発生する可能性のあるエラーの種類を示すヘルパー関数を定義します。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:26.347874Z", "iopub.status.busy": "2024-01-11T19:25:26.347258Z", "iopub.status.idle": "2024-01-11T19:25:26.352175Z", "shell.execute_reply": "2024-01-11T19:25:26.351526Z" }, "id": "D25apou9IOXa" }, "outputs": [], "source": [ "import traceback\n", "import contextlib\n", "\n", "# Some helper code to demonstrate the kinds of errors you might encounter.\n", "@contextlib.contextmanager\n", "def assert_raises(error_class):\n", " try:\n", " yield\n", " except error_class as e:\n", " print('Caught expected exception \\n {}:'.format(error_class))\n", " traceback.print_exc(limit=2)\n", " except Exception as e:\n", " raise e\n", " else:\n", " raise Exception('Expected {} to be raised but no error was raised!'.format(\n", " error_class))" ] }, { "cell_type": "markdown", "metadata": { "id": "WPSfepzTHThq" }, "source": [ "## 基礎" ] }, { "cell_type": "markdown", "metadata": { "id": "CNwYTIJ8r56W" }, "source": [ "### 使い方\n", "\n", "定義する `Function`(`@tf.function` デコレーターを適用するなどして)は、コアの TensorFlow 演算とまったく変わりません。Eager での実行や勾配の計算などを行えます。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:26.355553Z", "iopub.status.busy": "2024-01-11T19:25:26.355128Z", "iopub.status.idle": "2024-01-11T19:25:28.631766Z", "shell.execute_reply": "2024-01-11T19:25:28.631089Z" }, "id": "SbtT1-Wm70F2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" 
} ], "source": [ "@tf.function # The decorator converts `add` into a `Function`.\n", "def add(a, b):\n", " return a + b\n", "\n", "add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:28.635481Z", "iopub.status.busy": "2024-01-11T19:25:28.634856Z", "iopub.status.idle": "2024-01-11T19:25:28.677883Z", "shell.execute_reply": "2024-01-11T19:25:28.677263Z" }, "id": "uP-zUelB8DbX" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v = tf.Variable(1.0)\n", "with tf.GradientTape() as tape:\n", " result = add(v, 1.0)\n", "tape.gradient(result, v)" ] }, { "cell_type": "markdown", "metadata": { "id": "ocWZvqrmHnmX" }, "source": [ "`Function` をほかの `Function` 内で使用できます。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:28.681694Z", "iopub.status.busy": "2024-01-11T19:25:28.681013Z", "iopub.status.idle": "2024-01-11T19:25:28.774747Z", "shell.execute_reply": "2024-01-11T19:25:28.774115Z" }, "id": "l5qRjdbBVdU6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def dense_layer(x, w, b):\n", " return add(tf.matmul(x, w), b)\n", "\n", "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "piBhz7gYsHqU" }, "source": [ "`Function` は、特に小さな演算が多数含まれるグラフでは、Eager コードよりも高速に実行されることがありますが、高価な演算がいくつか含まれるグラフ(畳み込みなど)では、速度の差はあまり見られません。\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:28.778483Z", "iopub.status.busy": "2024-01-11T19:25:28.777754Z", "iopub.status.idle": "2024-01-11T19:25:29.551571Z", "shell.execute_reply": "2024-01-11T19:25:29.550854Z" }, "id": "zuXt4wRysI03" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager conv: 0.005722227000660496\n", "Function conv: 0.0052488470000753296\n", "Note how there's not much difference in performance for convolutions\n" ] } ], "source": [ "import timeit\n", "conv_layer = tf.keras.layers.Conv2D(100, 3)\n", "\n", "@tf.function\n", "def conv_fn(image):\n", " return conv_layer(image)\n", "\n", "image = tf.zeros([1, 200, 200, 100])\n", "# Warm up\n", "conv_layer(image); conv_fn(image)\n", "print(\"Eager conv:\", timeit.timeit(lambda: conv_layer(image), number=10))\n", "print(\"Function conv:\", timeit.timeit(lambda: conv_fn(image), number=10))\n", "print(\"Note how there's not much difference in performance for convolutions\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "uZ4Do2AV80cO" }, "source": [ "### トレーシング\n", "\n", "このセクションは、`Function` の内部動作や実装の詳細を説明します。*将来的に変更する可能性があります*が、いつなぜトレーシングが発生するのかを理解しておけば、`tf.function` を効果的に使用しやすくなります。" ] }, { "cell_type": "markdown", "metadata": { "id": "nhpUtRqsXoyM" }, "source": [ "#### 「トレーシング」とは?\n", "\n", "`Function` は [TensorFlow Graph](https://www.tensorflow.org/guide/intro_to_graphs#what_are_graphs) でプログラムを実行しますが、`tf.Graph` は、Eager TensorFlow プログラムにユーザーが記述するすべてのものを表現することはできません。たとえば、Python はポリモーフィズムをサポートしていますが、`tf.Graph` では、その入力に特定のデータ型と次元が必要です。またはコマンドラインの引数を読み取る、エラーを発生させる、より複雑な Python オブジェクトを扱うといったサイドタスクを実施しようとしても、どれも `tf.Graph` で実行することはできません。\n", "\n", "`Function` はコードを 2 つの段階に分けることで、このギャップの橋渡しの役割を果たします。\n", "\n", "1. 
「**トレーシング**」と呼ばれる第 1 段階において、`Function` は新しい `tf.Graph` を作成します。Python コードは通常どおり実行されますが、すべての TensorFlow 演算(2 つのテンソルを加算するなど)は *据え置き*となります。これらの演算は `tf.Graph` にキャプチャされるだけで、この段階では実行されません。\n", "\n", "2. 第 2 段階では、最初の段階で据え置きとなったすべての演算を含む `tf.Graph` が実行されます。この段階は、トレーシングの段階よりもはるかに高速に行われます。\n", "\n", "`Function` は、呼び出されたときに必ずしも最初の段階を実行するわけではなく、実行するかどうかは入力によって決まります。この判定がどのように行われるのかについては、以下の「[トレーシングの規則](#rules_of_tracing)」をご覧ください。最初の段階を省略して 2 番目の段階のみを実行できれば、TensorFlow の高いパフォーマンスが発揮されます。\n", "\n", "`Function` がトレーシングすると判断した場合、トレーシング段階の直後に第 2 段階が始まるため、`Function` を呼び出すと、`tf.Graph` の作成と実行が行われます。後の方で、[`get_concrete_function`](#obtaining_concrete_functions) を使ってトレーシング段階のみを実行する方法を説明します。" ] }, { "cell_type": "markdown", "metadata": { "id": "K7scSzLx662f" }, "source": [ "型の異なる引数を `Function` に渡すと、両方の段階が実行されます。\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:29.555697Z", "iopub.status.busy": "2024-01-11T19:25:29.555075Z", "iopub.status.idle": "2024-01-11T19:25:29.626414Z", "shell.execute_reply": "2024-01-11T19:25:29.625517Z" }, "id": "kojmJrgq8U9v" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\n", "tf.Tensor(2, shape=(), dtype=int32)\n", "\n", "Tracing with Tensor(\"a:0\", shape=(), dtype=float32)\n", "tf.Tensor(2.2, shape=(), dtype=float32)\n", "\n", "Tracing with Tensor(\"a:0\", shape=(), dtype=string)\n", "tf.Tensor(b'aa', shape=(), dtype=string)\n", "\n" ] } ], "source": [ "@tf.function\n", "def double(a):\n", " print(\"Tracing with\", a)\n", " return a + a\n", "\n", "print(double(tf.constant(1)))\n", "print()\n", "print(double(tf.constant(1.1)))\n", "print()\n", "print(double(tf.constant(\"a\")))\n", "print()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "QPfouGUQrcNb" }, "source": [ "同じ型の引数で `Function` を繰り返し呼び出すと、生成されるグラフはまったく同じになるため、TensorFlow はトレーシング段階を省略して前にトレーシングしたグラフを再利用することに注意してください。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:29.630566Z", "iopub.status.busy": "2024-01-11T19:25:29.629808Z", "iopub.status.idle": "2024-01-11T19:25:29.634821Z", "shell.execute_reply": "2024-01-11T19:25:29.634134Z" }, "id": "hFccbWFRrsBp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'bb', shape=(), dtype=string)\n" ] } ], "source": [ "# This doesn't print 'Tracing with ...'\n", "print(double(tf.constant(\"b\")))" ] }, { "cell_type": "markdown", "metadata": { "id": "fgIO_XEzcB9o" }, "source": [ "すべての利用可能なトレースを確認するには、`pretty_printed_concrete_signatures()` を使用できます。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:29.638655Z", "iopub.status.busy": "2024-01-11T19:25:29.637978Z", "iopub.status.idle": "2024-01-11T19:25:29.642454Z", "shell.execute_reply": "2024-01-11T19:25:29.641686Z" }, "id": "IiQc4IKAb-NX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.int32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Output Type:\n",
" TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "print(double.pretty_printed_concrete_signatures())" ] }, { "cell_type": "markdown", "metadata": { "id": "rKQ92VEWI7n8" }, "source": [ "ここまで、`tf.function` が TensorFlow のグラフトレーシングロジックにキャッシュされた動的ディスパッチレイヤーを作成するのを見てきました。用語についてより具体的に説明すると、次のように言えます。\n", "\n", "- `tf.Graph` は、言語に依存しない、生の移植可能な TensorFlow 計算の表現です。\n", "- `ConcreteFunction` は `tf.Graph` をラップします。\n", "- `Function` は `ConcreteFunction` のキャッシュを管理し、入力に適したものを選択します。\n", "- `tf.function` は Python 関数をラップし、`Function` オブジェクトを返します。\n", "- **トレーシング**は `tf.Graph` を作成し、それを `ConcreteFunction`(または**トレース**)をラップします。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "129-iRsPS-gY" }, "source": [ "#### トレーシングの規則\n", "\n", "`Function` が呼び出されると、各引数の `tf.types.experimental.TraceType` を使用して呼び出し引数を既存の `ConcreteFunction` に一致させます。一致する `ConcreteFunction` が見つかった場合、呼び出しはそれにディスパッチされます。一致するものが見つからない場合、新しい `ConcreteFunction` がトレースされます。\n", "\n", "複数の一致が見つかった場合は、最も具体的なシグネチャが選択されます。マッチングは、たとえば C++ や Java での通常の関数呼び出しと同じように、[サブタイプ化](https://en.wikipedia.org/wiki/Subtyping)によって行われます。例えば、`TensorShape([1, 2])` は `TensorShape([None, None])` のサブタイプ化であるため、`TensorShape([1, 2])` を使用した tf.function への呼び出しは、`TensorShape([None, None])` で生成された `ConcreteFunction` にディスパッチできます。しかし、`TensorShape([1, None])` を持つ `ConcreteFunction` も存在する場合は、より具体的であるため優先されます。\n", "\n", "`TraceType` は、次のように入力引数から決定されます。\n", "\n", "- `Tensor` の場合、型は `Tensor` の `dtype` と `shape` によってパラメータ化されます。階数付けされた形状は、階数付けされていない形状のサブタイプです。固定次元は未知次元のサブタイプです\n", "\n", "- `Variable` の場合、型は `Tensor` に似ていますが、変数の一意のリソース ID も含まれています。これは、制御の依存関係を正しく設定するために必要です。\n", "\n", "- Python プリミティブ値の場合、型は**値**自体に対応します。たとえば、値 `3` の `TraceType` は、`int` ではなく `LiteralTraceType<3>` です。\n", "\n", "- `list` や `tuple` などの Python の順序付きコンテナの場合、型はそれらの要素の型によってパラメータ化されます。たとえば、`[1, 2]` の型は `ListTraceType, LiteralTraceType<2>>` であり、`[2, 1]` の型は `ListTraceType, LiteralTraceType<1>>` であり、異なります。\n", "\n", "- `dict` などの Python マッピングの場合、型も同じキーからのマッピングですが、実際の値ではなく値の型へのマッピングです。たとえば、`{1: 2, 3: 4}` の型は `MappingTraceType<>>, >>>` です。ただし、順序付きコンテナとは異なり、`{1: 2, 3: 4}` と `{3: 4, 1: 2}` の型は同等です。\n", "\n", "- `__tf_tracing_type__` メソッドを実装する Python オブジェクトの場合、型はそのメソッドが返すものです\n", "\n", "- その他の Python オブジェクトの場合、型はジェネリックの `TraceType` で、マッチング手順は以下のとおりです。\n", "\n", " - まず、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の `id()` または `is` を使用します)。オブジェクトが変更された場合でも一致することに注意してください。そのため、Python オブジェクトを `tf.function` の引数として使用する場合、*イミュータブル*を使用するのが最適です。\n", " - 次に、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の `==` を使用)。\n", "\n", " この手順では、オブジェクトへの [weakref](https://docs.python.org/3/library/weakref.html) のみが維持されるため、オブジェクトが範囲内または削除されていない場合にのみ機能します。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GNNN4lgRzpIs" }, "source": [ "注意: `TraceType` は `Function` 入力パラメータに基づいているため、グローバル変数と[自由変数](https://docs.python.org/3/reference/executionmodel.html#binding-of-names)を変更するだけでは、新しいトレースは作成されません。Python のグローバル変数と自由変数を扱う際の推奨される方法については、[こちらのセクション](#depending_on_python_global_and_free_variables)をご覧ください。" ] }, { "cell_type": "markdown", "metadata": { "id": "PEDwbumO32Wh" }, "source": [ "### リトレーシングの制御\n", "\n", "リトレーシングは、`Function` が 2 つ以上のトレースを作成する際に発生します。これは、TensorFlow が一連の入力ごとに正しいグラフを生成する上で役立ちますが、トレーシングは高価な演算です!`Function` が呼び出しごとに新しいグラフをリトレーシングすると、コードの実行は `tf.function` を使用しない場合よりも遅くなってしまいます。\n", "\n", "トレーシングの動作を制御するには、次のテクニックを使用できます。" ] }, { "cell_type": "markdown", "metadata": { "id": "EUtycWJa34TT" }, "source": [ "#### 固定の `input_signature` を `tf.function` に渡す" ] }, { "cell_type": 
"code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:29.646647Z", "iopub.status.busy": "2024-01-11T19:25:29.646108Z", "iopub.status.idle": "2024-01-11T19:25:30.389521Z", "shell.execute_reply": "2024-01-11T19:25:30.388732Z" }, "id": "_BDMIRmu1RGB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([4 1], shape=(2,), dtype=int32)\n", "Caught expected exception \n", " :\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/3657259638.py\", line 9, in \n", " next_collatz(tf.constant([[1, 2], [3, 4]]))\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n", "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/3657259638.py\", line 13, in \n", " next_collatz(tf.constant([1.0, 2.0]))\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n" ] } ], "source": [ "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n", "def next_collatz(x):\n", " print(\"Tracing with\", x)\n", " return tf.where(x % 2 == 0, x // 2, 3 * x + 1)\n", "\n", "print(next_collatz(tf.constant([1, 2])))\n", "# You specified a 1-D tensor in the input signature, so this should fail.\n", "with assert_raises(TypeError):\n", " next_collatz(tf.constant([[1, 2], [3, 4]]))\n", "\n", "# You specified an int32 dtype in the input signature, so this should fail.\n", "with assert_raises(TypeError):\n", " next_collatz(tf.constant([1.0, 2.0]))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ocxX-HVk7P2o" }, "source": [ "#### 柔軟性のために未知の次元を使用する\n", "\n", "TensorFlow は形状に基づいてテンソルを一致させるため、ワイルドカードとして `None` 次元を使用することで、`Function` が可変サイズの入力にトレースを再利用できるようになります。可変サイズの入力は、長さの異なるシーケンスがある場合や、バッチごとに画像のサイズが異なる場合に発生します(例として、[Transformer](../tutorials/text/transformer.ipynb) と [Deep Dream](../tutorials/generative/deepdream.ipynb) チュートリアルをご覧ください)。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.394130Z", "iopub.status.busy": "2024-01-11T19:25:30.393368Z", "iopub.status.idle": "2024-01-11T19:25:30.433362Z", "shell.execute_reply": "2024-01-11T19:25:30.432713Z" }, "id": "4Viun7dh7PmF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n", "tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n", "tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)\n" ] } ], "source": [ "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n", "def g(x):\n", " print('Tracing with', x)\n", " return x\n", "\n", "# No retrace!\n", "print(g(tf.constant([1, 2, 3])))\n", 
"print(g(tf.constant([1, 2, 3, 4, 5])))\n" ] }, { "cell_type": "markdown", "metadata": { "id": "AY5oiQN0XIyA" }, "source": [ "#### Python リテラルの代わりにテンソルを渡す\n", "\n", "通常、Python 引数は、`num_layers=10` または `training=True` または `nonlinearity='relu'` などのように、ハイパーパラメータとグラフ構造の制御に使用されます。そのため、Python 引数が変わると、当然グラフをリトレースする必要が出てきます。\n", "\n", "しかし、Python 引数がグラフ構造の制御に使用されていない場合もあります。こういった場合、Python の値の変化によってリトレーシングがトリガーされますが、これは不要です。この、AutoGraph が動的にアンロールするトレーニングループを例に見てみましょう。トレースが何度も行われますが、生成されたグラフはまったく同じであるため、リトレーシングは不要と言えます。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.436665Z", "iopub.status.busy": "2024-01-11T19:25:30.436390Z", "iopub.status.idle": "2024-01-11T19:25:30.583540Z", "shell.execute_reply": "2024-01-11T19:25:30.582913Z" }, "id": "uydzR5JYUU8H" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Retracing occurs for different Python arguments.\n", "Tracing with num_steps = 10\n", "Executing with num_steps = 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Tracing with num_steps = 20\n", "Executing with num_steps = 20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Traces are reused for Tensor arguments.\n", "Tracing with num_steps = Tensor(\"num_steps:0\", shape=(), dtype=int32)\n", "Executing with num_steps = 10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Executing with num_steps = 20\n" ] } ], "source": [ "def train_one_step():\n", " pass\n", "\n", "@tf.function\n", "def train(num_steps):\n", " print(\"Tracing with num_steps = \", num_steps)\n", " tf.print(\"Executing with num_steps = \", num_steps)\n", " for _ in tf.range(num_steps):\n", " train_one_step()\n", "\n", "print(\"Retracing occurs for different Python arguments.\")\n", "train(num_steps=10)\n", "train(num_steps=20)\n", "\n", "print()\n", "print(\"Traces are reused for Tensor arguments.\")\n", "train(num_steps=tf.constant(10))\n", "train(num_steps=tf.constant(20))" ] }, { "cell_type": "markdown", "metadata": { "id": "4pJqkDR_Q2wz" }, "source": [ "リトレーシングを強制する必要がある場合は、新しい `Function` を作成します。トレースは絶対に、各 `Function` オブジェクト間で共有されることはありません。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.587385Z", "iopub.status.busy": "2024-01-11T19:25:30.586789Z", "iopub.status.idle": "2024-01-11T19:25:30.631677Z", "shell.execute_reply": "2024-01-11T19:25:30.631080Z" }, "id": "uHp4ousu4DdN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "Executing\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Executing\n" ] } ], "source": [ "def f():\n", " print('Tracing!')\n", " tf.print('Executing')\n", "\n", "tf.function(f)()\n", "tf.function(f)()" ] }, { "cell_type": "markdown", "metadata": { "id": "-tZoWrA6INvc" }, "source": [ "#### トレースプロトコルを使用する\n", "\n", "可能であれば、代わりに Python 型を `tf.experimental.ExtensionType` に変換することをお勧めします。さらに、`ExtensionType` の `TraceType` は、それに関連付けられた `tf.TypeSpec` です。したがって、必要に応じて、デフォルトの `tf.TypeSpec` を単純にオーバーライドして、`ExtensionType` の `Tracing Protocol` を制御できます。詳細については、[拡張型](extension_type.ipynb)ガイドの ExtensionType の TypeSpec のカスタマイズセクションをご覧ください。\n", "\n", "それ以外の場合は、`Function` が特定の Python 型に関していつ再トレースする必要があるかを直接制御するために、`Tracing Protocol` を自分で実装できます。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.635302Z", 
"iopub.status.busy": "2024-01-11T19:25:30.634810Z", "iopub.status.idle": "2024-01-11T19:25:30.714138Z", "shell.execute_reply": "2024-01-11T19:25:30.713519Z" }, "id": "gZkIh7UaIKc6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def get_mixed_flavor(fruit_a, fruit_b):\n", " return fruit_a.flavor + fruit_b.flavor\n", "\n", "class Fruit:\n", " flavor = tf.constant([0, 0])\n", "\n", "class Apple(Fruit):\n", " flavor = tf.constant([1, 2])\n", "\n", "class Mango(Fruit):\n", " flavor = tf.constant([3, 4])\n", "\n", "# As described in the above rules, a generic TraceType for `Apple` and `Mango`\n", "# is generated (and a corresponding ConcreteFunction is traced) but it fails to\n", "# match the second function call since the first pair of Apple() and Mango()\n", "# have gone out out of scope by then and deleted.\n", "get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function\n", "get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again\n", "\n", "# However, each subclass of the `Fruit` class has a fixed flavor, and you\n", "# can reuse an existing traced concrete function if it was the same\n", "# subclass. Avoiding such unnecessary tracing of concrete functions\n", "# can have significant performance benefits.\n", "\n", "class FruitTraceType(tf.types.experimental.TraceType):\n", " def __init__(self, fruit):\n", " self.fruit_type = type(fruit)\n", " self.fruit_value = fruit\n", "\n", " def is_subtype_of(self, other):\n", " # True if self subtypes `other` and `other`'s type matches FruitTraceType.\n", " return (type(other) is FruitTraceType and\n", " self.fruit_type is other.fruit_type)\n", "\n", " def most_specific_common_supertype(self, others):\n", " # `self` is the specific common supertype if all input types match it.\n", " return self if all(self == other for other in others) else None\n", "\n", " def placeholder_value(self, placeholder_context=None):\n", " # Use the fruit itself instead of the type for correct tracing.\n", " return self.fruit_value\n", "\n", " def __eq__(self, other):\n", " return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n", "\n", " def __hash__(self):\n", " return hash(self.fruit_type)\n", "\n", "class FruitWithTraceType:\n", "\n", " def __tf_tracing_type__(self, context):\n", " return FruitTraceType(self)\n", "\n", "class AppleWithTraceType(FruitWithTraceType):\n", " flavor = tf.constant([1, 2])\n", "\n", "class MangoWithTraceType(FruitWithTraceType):\n", " flavor = tf.constant([3, 4])\n", "\n", "# Now if you try calling it again:\n", "get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function\n", "get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function" ] }, { "cell_type": "markdown", "metadata": { "id": "96IxS2WR37fF" }, "source": [ "### 具象関数の取得\n", "\n", "関数がトレースされるたびに新しい具象関数が作成されますが、`get_concrete_function` を使うことで、具象関数を直接取得できます。\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.717619Z", "iopub.status.busy": "2024-01-11T19:25:30.717125Z", "iopub.status.idle": "2024-01-11T19:25:30.722556Z", "shell.execute_reply": "2024-01-11T19:25:30.721966Z" }, "id": "mHg2CGtPQ3Hz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Obtaining concrete trace\n", "Executing traced function\n", "tf.Tensor(b'aa', shape=(), dtype=string)\n", 
"tf.Tensor(b'bb', shape=(), dtype=string)\n" ] } ], "source": [ "print(\"Obtaining concrete trace\")\n", "double_strings = double.get_concrete_function(tf.constant(\"a\"))\n", "print(\"Executing traced function\")\n", "print(double_strings(tf.constant(\"a\")))\n", "print(double_strings(a=tf.constant(\"b\")))\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.725861Z", "iopub.status.busy": "2024-01-11T19:25:30.725306Z", "iopub.status.idle": "2024-01-11T19:25:30.729617Z", "shell.execute_reply": "2024-01-11T19:25:30.728971Z" }, "id": "6IVZ-NVf9vsx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'cc', shape=(), dtype=string)\n" ] } ], "source": [ "# You can also call get_concrete_function on an InputSpec\n", "double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))\n", "print(double_strings_from_inputspec(tf.constant(\"c\")))" ] }, { "cell_type": "markdown", "metadata": { "id": "iR4fVmG34xvF" }, "source": [ "`ConcreteFunction` を出力すると、入力引数(型付き)とその出力型の概要が表示されます。" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.732420Z", "iopub.status.busy": "2024-01-11T19:25:30.732189Z", "iopub.status.idle": "2024-01-11T19:25:30.735430Z", "shell.execute_reply": "2024-01-11T19:25:30.734855Z" }, "id": "o3-JbkIk41r8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConcreteFunction Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.string, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "print(double_strings)" ] }, { "cell_type": "markdown", "metadata": { "id": "QtqfvljZeuOV" }, "source": [ "また、具象関数のシグネチャを直接取得することもできます。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.738393Z", "iopub.status.busy": "2024-01-11T19:25:30.738176Z", "iopub.status.idle": "2024-01-11T19:25:30.741720Z", "shell.execute_reply": "2024-01-11T19:25:30.741107Z" }, "id": "nzbrqFABe0zG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})\n", "Tensor(\"Identity:0\", shape=(), dtype=string)\n" ] } ], "source": [ "print(double_strings.structured_input_signature)\n", "print(double_strings.structured_outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "lar5A_5m5IG1" }, "source": [ "互換性のない型で具象トレースを使用すると、エラーが発生します。" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.744548Z", "iopub.status.busy": "2024-01-11T19:25:30.744294Z", "iopub.status.idle": "2024-01-11T19:25:30.749868Z", "shell.execute_reply": "2024-01-11T19:25:30.749170Z" }, "id": "G5eeTK-T5KYj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py\", line 442, in bind_function_inputs\n", " bound_arguments = function_type.bind_with_defaults(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py\", line 277, in bind_with_defaults\n", " 
with_default_args[arg_name] = constraint.cast(\n", "TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)\n", "\n", "The above exception was the direct cause of the following exception:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1180, in _call_impl\n", " return self._call_with_structured_signature(args, kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1260, in _call_with_structured_signature\n", " function_type_utils.canonicalize_function_inputs(\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/3196284684.py\", line 2, in \n", " double_strings(tf.constant(1))\n", "tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]\n" ] } ], "source": [ "with assert_raises(tf.errors.InvalidArgumentError):\n", " double_strings(tf.constant(1))" ] }, { "cell_type": "markdown", "metadata": { "id": "st2L9VNQVtSG" }, "source": [ "Python 引数は、具象関数の入力シグネチャで特別に扱われていることに気づいたかもしれません。TensorFlow 2.3 より前では、Python 引数は単に具象関数のシグネチャから削除されていましたが、TensorFlow 2.3 からはシグネチャに残されたまま、トレーシング中に値セットを取るように制約されています。" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.752701Z", "iopub.status.busy": "2024-01-11T19:25:30.752474Z", "iopub.status.idle": "2024-01-11T19:25:30.777964Z", "shell.execute_reply": "2024-01-11T19:25:30.777355Z" }, "id": "U_QyPSGoaC35" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ConcreteFunction Input Parameters:\n", " a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=, dtype=tf.float32, name=None)\n", " b (POSITIONAL_OR_KEYWORD): Literal[2]\n", "Output Type:\n", " TensorSpec(shape=, dtype=tf.float32, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "@tf.function\n", "def pow(a, b):\n", " return a ** b\n", "\n", "square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)\n", "print(square)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.781262Z", "iopub.status.busy": "2024-01-11T19:25:30.780729Z", "iopub.status.idle": "2024-01-11T19:25:30.881037Z", "shell.execute_reply": "2024-01-11T19:25:30.880291Z" }, "id": "E76vIDhQbXIb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py\", line 442, in bind_function_inputs\n", " bound_arguments = 
function_type.bind_with_defaults(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py\", line 277, in bind_with_defaults\n", " with_default_args[arg_name] = constraint.cast(\n", "ValueError: Can not cast 3 to Literal[2]\n", "\n", "The above exception was the direct cause of the following exception:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1180, in _call_impl\n", " return self._call_with_structured_signature(args, kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1260, in _call_with_structured_signature\n", " function_type_utils.canonicalize_function_inputs(\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=, dtype=tf.float32, name=None).\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1183, in _call_impl\n", " return self._call_with_flat_signature(args, kwargs)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1234, in _call_with_flat_signature\n", " raise TypeError(f\"{self._flat_signature_summary()} got unexpected \"\n", "TypeError: pow(a) got unexpected keyword arguments: b.\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/2310937119.py\", line 4, in \n", " square(tf.constant(10.0), b=3)\n", "TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. 
Received args: (,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=, dtype=tf.float32, name=None).\n", "Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.\n" ] } ], "source": [ "assert square(tf.constant(10.0)) == 100\n", "\n", "with assert_raises(TypeError):\n", " square(tf.constant(10.0), b=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "41gJh_JGIfuA" }, "source": [ "### グラフの取得\n", "\n", "それぞれの具象関数は、`tf.Graph` を囲む呼び出し可能なラッパーです。通常、実際の `tf.Graph` オブジェクトを取得する必要はないにしろ、具象関数から簡単に取得することが可能です。" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.884547Z", "iopub.status.busy": "2024-01-11T19:25:30.884304Z", "iopub.status.idle": "2024-01-11T19:25:30.888431Z", "shell.execute_reply": "2024-01-11T19:25:30.887852Z" }, "id": "5UENeGHfaX8g" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[] -> a\n", "['a', 'a'] -> add\n", "['add'] -> Identity\n" ] } ], "source": [ "graph = double_strings.graph\n", "for node in graph.as_graph_def().node:\n", " print(f'{node.input} -> {node.name}')\n" ] }, { "cell_type": "markdown", "metadata": { "id": "aIKkgr6qdtp4" }, "source": [ "### デバッグ\n", "\n", "一般的に、コードのデバックは、`tf.function` 内で行うよりも、Eager モードで行う方が簡単です。Eager モードでは、`tf.function` でデコレートする前に、コードがエラーなく実行することを確認しておく必要があります。デバッグプロセスを支援する目的で、`tf.config.run_functions_eagerly(True)` を呼び出すと、`tf.function` をグローバルに無効にして、有効にし直すことができます。\n", "\n", "`tf.function` 内でのみ出現する問題を追跡する場合、次のようなヒントがあります。\n", "\n", "- 従来のシンプルな Python `print` 呼び出しは、トレーシング中にのみ実行されるため、関数が(リ)トレーシングされるときに追跡しやすくなります。\n", "- `tf.print` 呼び出しは毎回実行するため、実行中の中間値の追跡に役立ちます。\n", "- `tf.debugging.enable_check_numerics` は、NaN と Inf がいつ作成されるかを簡単に追跡できます。\n", "- `pdb`([Python デバッガ](https://docs.python.org/3/library/pdb.html))は、トレーシング中に何が起きているのかを理解する上で役立ちます。(注意: `pdb` が示すのは、AutoGraph 変換ソースコードです。)" ] }, { "cell_type": "markdown", "metadata": { "id": "5f05Vr_YBUCz" }, "source": [ "## AutoGraph 変換\n", "\n", "AutoGraph は、`tf.function` 内でデフォルトで利用できるようになっているライブラリで、Python の Eager コードのサブセットとグラフ対応の TensorFlow 演算に変換します。これには、`if`、`for`、`while` などの制御フローが含まれます。\n", "\n", "`tf.cond` や `tf.while_loop` などの TensorFlow 演算は機能し続けますが、制御フローは、Python で記述された場合の方が書きやすく理解しやすいことがほとんどです。" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.891742Z", "iopub.status.busy": "2024-01-11T19:25:30.891514Z", "iopub.status.idle": "2024-01-11T19:25:30.990489Z", "shell.execute_reply": "2024-01-11T19:25:30.989849Z" }, "id": "yCQTtTPTW3WF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.645081878 0.344026446 0.0305811167 0.918421388 0.0816819668]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.568349838 0.331067264 0.0305715837 0.725149751 0.0815007836]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.514146328 0.319479436 0.03056206 0.620089 0.0813208073]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.473169476 0.309036136 0.0305525456 0.551190078 0.0811420083]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.440756589 0.299559951 0.0305430386 0.501411617 0.0809643865]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.414271355 0.290909827 0.0305335429 0.463226587 0.0807879269]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.39209339 0.282972 0.0305240545 0.43271029 0.0806126148]\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "[0.373163491 0.275653511 0.0305145755 0.407583863 0.0804384425]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.356755644 0.268877536 0.030505104 0.386419207 0.0802654]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.342353106 0.262580037 0.0304956455 0.368269116 0.0800934657]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.329576522 0.256707132 0.0304861926 0.352476776 0.0799226314]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[0.318140209 0.251213 0.0304767471 0.338570237 0.0797528848]\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A simple loop\n", "\n", "@tf.function\n", "def f(x):\n", " while tf.reduce_sum(x) > 1:\n", " tf.print(x)\n", " x = tf.tanh(x)\n", " return x\n", "\n", "f(tf.random.uniform([5]))" ] }, { "cell_type": "markdown", "metadata": { "id": "KxwJ8znPI0Cg" }, "source": [ "興味があれば、AutoGraph が生成するコードを検査できます。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:30.993817Z", "iopub.status.busy": "2024-01-11T19:25:30.993581Z", "iopub.status.idle": "2024-01-11T19:25:30.998644Z", "shell.execute_reply": "2024-01-11T19:25:30.998012Z" }, "id": "jlQD1ffRXJhl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def tf__f(x):\n", " with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n", " do_return = False\n", " retval_ = ag__.UndefinedReturnValue()\n", "\n", " def get_state():\n", " return (x,)\n", "\n", " def set_state(vars_):\n", " nonlocal x\n", " (x,) = vars_\n", "\n", " def loop_body():\n", " nonlocal x\n", " ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)\n", " x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)\n", "\n", " def loop_test():\n", " return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1\n", " ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})\n", " try:\n", " do_return = True\n", " retval_ = ag__.ld(x)\n", " except:\n", " do_return = False\n", " raise\n", " return fscope.ret(retval_, do_return)\n", "\n" ] } ], "source": [ "print(tf.autograph.to_code(f.python_function))" ] }, { "cell_type": "markdown", "metadata": { "id": "xgKmkrNTZSyz" }, "source": [ "### 条件文\n", "\n", "AutoGraph は `if ` 文を相当する `tf.cond` 呼び出しに変換します。この置換は、`` がテンソルである場合に行われます。テンソルでない場合は、`if` 文は Python の条件文として実行されます。\n", "\n", "Python 条件文はトレーシング中に実行するため、条件文のブランチが 1 つだけグラフに追加されます。AutoGraph を使用しない場合、データに依存する制御フローが存在すると、トレーシングされたこのグラフは別のブランチを取ることができません。\n", "\n", "`tf.cond` は、条件文の両方のブランチをトレーシングし、実行時に動的に 1 つのブランチを選択してグラフに追加します。トレーシングには意図しない副作用がある場合があります。詳細は、[AutoGraph のトレーシング効果](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process)をご覧ください。" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.001786Z", "iopub.status.busy": "2024-01-11T19:25:31.001503Z", "iopub.status.idle": "2024-01-11T19:25:31.206018Z", "shell.execute_reply": "2024-01-11T19:25:31.205337Z" }, "id": "BOQl8PMq2Sf3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing for loop\n", "Tracing fizzbuzz branch\n", "Tracing fizz branch\n", "Tracing buzz branch\n", 
"Tracing default branch\n", "1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "11\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "13\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "14\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizzbuzz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "16\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "17\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fizz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "19\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "buzz\n" ] } ], "source": [ "@tf.function\n", "def fizzbuzz(n):\n", " for i in tf.range(1, n + 1):\n", " print('Tracing for loop')\n", " if i % 15 == 0:\n", " print('Tracing fizzbuzz branch')\n", " tf.print('fizzbuzz')\n", " elif i % 3 == 0:\n", " print('Tracing fizz branch')\n", " tf.print('fizz')\n", " elif i % 5 == 0:\n", " print('Tracing buzz branch')\n", " tf.print('buzz')\n", " else:\n", " print('Tracing default branch')\n", " tf.print(i)\n", "\n", "fizzbuzz(tf.constant(5))\n", "fizzbuzz(tf.constant(20))" ] }, { "cell_type": "markdown", "metadata": { "id": "4rBO5AQ15HVC" }, "source": [ "AutoGraph 変換の if 文におけるその他の制約事項については、[リファレンスドキュメント](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements)をご覧ください。" ] }, { "cell_type": "markdown", "metadata": { "id": "yho4J0a0ZkQS" }, "source": [ "### ループ\n", "\n", "AutoGraph は、一部の `for` 文と `while` 文を相当する `tf.while_loop` などの TensorFlow のループ演算に変換します。変換されない場合、`for` または `while` ループは Python ループとして実行されます。\n", "\n", "この置き換えは、次の場合に行われます。\n", "\n", "- `for x in y`: `y` がテンソルである場合、`tf.while_loop` に変換されます。`y` が `tf.data.Dataset` である特別なケースでは、`tf.data.Dataset` 演算の組み合わせが生成されます。\n", "- `while `: `` がテンソルである場合、`tf.while_loop` に変換されます。\n", "\n", "Python ループは、トレーシング中に実行され、ループのいてレーションごとに、`tf.Graph` に追加の演算が追加されます。\n", "\n", "TensorFlow ループはループの本体をトレーシングし、実行時に実行する反復回数を動的に選択します。ループ本体は、生成された `tf.Graph` に一度だけ出現します。\n", "\n", "AutoGraph 変換の `for` 文と `while` 文におけるその他の制約事項については、[リファレンスドキュメント](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements)をご覧ください。" ] }, { "cell_type": "markdown", "metadata": { "id": "sp4rbIdfbM6s" }, "source": [ "#### Python データのループ\n", "\n", "一般的な落とし穴は、`tf.function` 内で Python/NumPy データをループする際にあります。このループは、トレーシングプロセス中に実行し、ループのイテレーションごとにモデルのコピーを tf.Graph に追加してしまいます。\n", "\n", 
"トレーニングループ全体を `tf.function` にラップしたいのであれば、データを `tf.data.Dataset` としてラップし、AutoGraph にトレーニングループを動的に展開させるようにするのが最も安全な方法です。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.209479Z", "iopub.status.busy": "2024-01-11T19:25:31.209202Z", "iopub.status.idle": "2024-01-11T19:25:31.356892Z", "shell.execute_reply": "2024-01-11T19:25:31.356232Z" }, "id": "WGZ19LspbZ27" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph\n", "train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train(<_FlatMapDataset element_spec=(TensorSpec(shape=, dtype=tf.int32, name=None), TensorSpec(shape=, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n", "train(<_FlatMapDataset element_spec=(TensorSpec(shape=, dtype=tf.int32, name=None), TensorSpec(shape=, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n" ] } ], "source": [ "def measure_graph_size(f, *args):\n", " g = f.get_concrete_function(*args).graph\n", " print(\"{}({}) contains {} nodes in its graph\".format(\n", " f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n", "\n", "@tf.function\n", "def train(dataset):\n", " loss = tf.constant(0)\n", " for x, y in dataset:\n", " loss += tf.abs(y - x) # Some dummy computation.\n", " return loss\n", "\n", "small_data = [(1, 1)] * 3\n", "big_data = [(1, 1)] * 10\n", "measure_graph_size(train, small_data)\n", "measure_graph_size(train, big_data)\n", "\n", "measure_graph_size(train, tf.data.Dataset.from_generator(\n", " lambda: small_data, (tf.int32, tf.int32)))\n", "measure_graph_size(train, tf.data.Dataset.from_generator(\n", " lambda: big_data, (tf.int32, tf.int32)))" ] }, { "cell_type": "markdown", "metadata": { "id": "JeD2U-yrbfVb" }, "source": [ "Python/NumPy データをデータセットにラップする際は、`tf.data.Dataset.from_generator` と ` tf.data.Dataset.from_tensor_slices` の違いに注意してください。前者は、データを Python に維持し、`tf.py_function` 経由で取得するため、パフォーマンスに問題がありますが、後者は、データのコピーをグラフ内の大型の `tf.constant()` ノードとしてバンドル化するため、メモリに問題が現れます。\n", "\n", "データを消費するには、`TFRecordDataset` や `CsvDataset` などを介してファイルからデータを読み取るのが最も効果的な方法です。そうすれば、Python を使わずに、TensorFlow 自体でデータの非同期読み込みとプリフェッチを管理できるようになります。詳細は、「[`tf.data`: TensorFlow 入力パイプラインを構築する](../../guide/data)」ガイドをご覧ください。" ] }, { "cell_type": "markdown", "metadata": { "id": "hyksHW9TCukR" }, "source": [ "#### ループでの値の累積\n", "\n", "ループの反復ごとに値を累積していくのは一般的なパターンです。通常は、Python のリストに追加したり、Python ディクショナリにエントリを追加したりして行われますが、これらは Python の副作用であるため、動的に展開されるループでは期待どおりに動作しません。動的に展開されるループの結果を累積する場合は、`tf.TensorArray` を使用してください。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.360769Z", "iopub.status.busy": "2024-01-11T19:25:31.360118Z", "iopub.status.idle": "2024-01-11T19:25:31.513894Z", "shell.execute_reply": "2024-01-11T19:25:31.513250Z" }, "id": "HJ3Vb3dXfefN" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 2\n", "seq_len = 3\n", "feature_size = 4\n", "\n", "def rnn_step(inp, state):\n", " return inp + state\n", "\n", "@tf.function\n", "def dynamic_rnn(rnn_step, input_data, initial_state):\n", " # [batch, time, features] -> [time, batch, features]\n", " input_data = tf.transpose(input_data, [1, 0, 2])\n", " max_seq_len = input_data.shape[0]\n", 
"\n", " states = tf.TensorArray(tf.float32, size=max_seq_len)\n", " state = initial_state\n", " for i in tf.range(max_seq_len):\n", " state = rnn_step(input_data[i], state)\n", " states = states.write(i, state)\n", " return tf.transpose(states.stack(), [1, 0, 2])\n", "\n", "dynamic_rnn(rnn_step,\n", " tf.random.uniform([batch_size, seq_len, feature_size]),\n", " tf.zeros([batch_size, feature_size]))" ] }, { "cell_type": "markdown", "metadata": { "id": "i2MVoIVaNApG" }, "source": [ "## 制限事項\n", "\n", "TensorFlow の `Function` には、設計上、いくつかの制限事項があり、Python 関数を `Function` に変換する際には、注意が必要です。" ] }, { "cell_type": "markdown", "metadata": { "id": "EJqHGFSVLIKl" }, "source": [ "### Python の副作用の実行\n", "\n", "`Function` 内での出力、リストへのアペンド、グローバル変数のミューテーションといった副作用は、2 回実行されたり、まったく実行しなかったりといったように、予測のつかない動作をすることがあります。また、入力セットで `Function` を初めて呼び出した場合にのみ実行し、以降では、Python コードを実行せずに、トレーシング済みの `tf.Graph` が再実行されてしまうこともあります。\n", "\n", "基本的に、ロジックでは Python の副作用に依存しないようにし、トレースをデバッグするためだけに使用することをお勧めします。呼び出しごとに TensorFlow ランタイムが確実にコードを実行できるようにするには、`tf.data`、`tf.print`、`tf.summary`、`tf.Variable.assign`、`tf.TensorArray` などの TensorFlow API を使用するのが最善の方法です。" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.517280Z", "iopub.status.busy": "2024-01-11T19:25:31.517047Z", "iopub.status.idle": "2024-01-11T19:25:31.565778Z", "shell.execute_reply": "2024-01-11T19:25:31.565143Z" }, "id": "w2sACuZ9TTRk" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Traced with" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 1\n", "Executed with 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Executed with 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Traced with 2\n", "Executed with 2\n" ] } ], "source": [ "@tf.function\n", "def f(x):\n", " print(\"Traced with\", x)\n", " tf.print(\"Executed with\", x)\n", "\n", "f(1)\n", "f(1)\n", "f(2)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "e1I0dPiqTV8H" }, "source": [ "`Function` の呼び出しごとに Python コードを実行する場合は、`tf.py_function` が脱出口です。`tf.py_function` には移植性がなく、特にパフォーマンスに優れているわけでもなく、SavedModel で保存できなければ、分散型(マルチ GPU、TPU)の環境でうまく動作するわけでもありません。また、`tf.py_function` はグラフに組み込む必要もあるため、すべての入力/出力をテンソルにキャストしてしまいます。" ] }, { "cell_type": "markdown", "metadata": { "id": "bOW1v9WVKGgH" }, "source": [ "#### Python のグローバル変数と自由変数の変更\n", "\n", "Python のグローバル変数と[自由変数](https://docs.python.org/3/reference/executionmodel.html#binding-of-names)の変更は、Python の副作用としてみなされるため、トレーシング中にのみ発生します。\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.569033Z", "iopub.status.busy": "2024-01-11T19:25:31.568732Z", "iopub.status.idle": "2024-01-11T19:25:31.599615Z", "shell.execute_reply": "2024-01-11T19:25:31.598978Z" }, "id": "7aJD--9qTWmg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python side effect\n" ] } ], "source": [ "external_list = []\n", "\n", "@tf.function\n", "def side_effect(x):\n", " print('Python side effect')\n", " external_list.append(x)\n", "\n", "side_effect(1)\n", "side_effect(1)\n", "side_effect(1)\n", "# The list append only happened once!\n", "assert len(external_list) == 1" ] }, { "cell_type": "markdown", "metadata": { "id": "5eZTFRv_k_nR" }, "source": [ "場合によっては気づきにくい予期しない動作が発生することがあります。以下の例では、`counter` は変数のインクリメントを保護することを目的としています。ただし、これは Python 整数であり、TensorFlow オブジェクトではないため、その値は最初のトレース中にキャプチャされます。`tf.function` を使用すると、`assign_add` が下のグラフに無条件に記録されます。したがって、`v` 
は、`tf.function` が呼び出されるたびに 1 ずつ増加します。この問題は、`tf.function` デコレータを使用して Graph モードの TensorFlow コードを TensorFlow 2 に移行する際に、Python の副作用(この例では `counter`)を使って実行する演算(この例では `assign_add`)を決定していると、よく発生します。通常、ユーザーは、疑わしい数値結果が出たり、パフォーマンスが予想より大幅に低下したりした場合に、この問題に気付きます(たとえば、保護された演算に非常にコストがかかる場合)。" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.602748Z", "iopub.status.busy": "2024-01-11T19:25:31.602526Z", "iopub.status.idle": "2024-01-11T19:25:31.655801Z", "shell.execute_reply": "2024-01-11T19:25:31.655206Z" }, "id": "5r6p7-9jk_3L" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "2\n", "3\n" ] } ], "source": [ "class Model(tf.Module):\n", " def __init__(self):\n", " self.v = tf.Variable(0)\n", " self.counter = 0\n", "\n", " @tf.function\n", " def __call__(self):\n", " if self.counter == 0:\n", " # A python side-effect\n", " self.counter += 1\n", " self.v.assign_add(1)\n", "\n", " return self.v\n", "\n", "m = Model()\n", "for n in range(3):\n", " print(m().numpy()) # prints 1, 2, 3" ] }, { "cell_type": "markdown", "metadata": { "id": "tXCTcHoVcxhX" }, "source": [ "このような動作を回避し、期待される動作を実現するためには、[`tf.init_scope`](https://www.tensorflow.org/api_docs/python/tf/init_scope) を使用して演算を関数グラフの外に移動します。これにより、変数のインクリメントがトレース時間中に 1 回だけ実行されるようになります。`init_scope` には、制御フローや勾配テープがクリアされるなどの他の副作用があることに注意してください。`init_scope` を使用すると非常に複雑になり、現実的に管理できない場合があります。" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.658788Z", "iopub.status.busy": "2024-01-11T19:25:31.658556Z", "iopub.status.idle": "2024-01-11T19:25:31.711953Z", "shell.execute_reply": "2024-01-11T19:25:31.711344Z" }, "id": "An4MrIbrcvi8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "1\n", "1\n" ] } ], "source": [ "class Model(tf.Module):\n", " def __init__(self):\n", " self.v = tf.Variable(0)\n", " self.counter = 0\n", "\n", " @tf.function\n", " def __call__(self):\n", " if self.counter == 0:\n", " # Lifts ops out of function-building graphs\n", " with tf.init_scope():\n", " self.counter += 1\n", " self.v.assign_add(1)\n", "\n", " return self.v\n", "\n", "m = Model()\n", "for n in range(3):\n", " print(m().numpy()) # prints 1, 1, 1" ] }, { "cell_type": "markdown", "metadata": { "id": "pbFG5CX4LwQA" }, "source": [ "まとめると、経験則として、`Function` の外側に存在する整数やリストなどのコンテナといった Python オブジェクトのミューテーションは避けてください。代わりに、引数と TF オブジェクトを使用しましょう。たとえば、「[ループでの値の累積](#accumulating_values_in_a_loop)」セクションには、リストのような演算を実装する方法の一例が示されています。\n", "\n", "一部のケースでは、状態が [`tf.Variable`](https://www.tensorflow.org/guide/variable) であれば、それをキャプチャして操作することができます。Keras モデルの重みは、このようにして、同じ `ConcreteFunction` への呼び出しの繰り返しで更新されています。" ] }, { "cell_type": "markdown", "metadata": { "id": "X_oNNGrAqPJ1" }, "source": [ "#### Python イテレータとジェネレータの使用" ] }, { "cell_type": "markdown", "metadata": { "id": "msTmv-oyUNaf" }, "source": [ "ジェネレータやイテレータなどの多くの Python 機能は、Python ランタイムに依存して状態を追跡しています。一般的に、これらのコンストラクトは Eager モードでも期待どおりに動作しますが、Python の副作用の例であるため、トレーシング中にしか実行されません。" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.715127Z", "iopub.status.busy": "2024-01-11T19:25:31.714893Z", "iopub.status.idle": "2024-01-11T19:25:31.750119Z", "shell.execute_reply": "2024-01-11T19:25:31.749437Z" }, "id": "FNPD4unZUedH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 
1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] } ], "source": [ "@tf.function\n", "def buggy_consume_next(iterator):\n", " tf.print(\"Value:\", next(iterator))\n", "\n", "iterator = iter([1, 2, 3])\n", "buggy_consume_next(iterator)\n", "# This reuses the first value from the iterator, rather than consuming the next value.\n", "buggy_consume_next(iterator)\n", "buggy_consume_next(iterator)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "wcS3TAgCjTWR" }, "source": [ "TensorFlow にリストコントラクト用の特別な `tf.TensorArray` があるように、イテレーション用にも特別な `tf.data.Iterator` があります。概要は、[AutoGraph 変換](#autograph_transformations)をご覧ください。また、[`tf.data`](https://www.tensorflow.org/guide/data) API を使って、ジェネレータのパターンを実装できます。\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.753395Z", "iopub.status.busy": "2024-01-11T19:25:31.752942Z", "iopub.status.idle": "2024-01-11T19:25:31.799116Z", "shell.execute_reply": "2024-01-11T19:25:31.798310Z" }, "id": "8D_iKetXW6VE" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Value: 3\n" ] } ], "source": [ "@tf.function\n", "def good_consume_next(iterator):\n", " # This is ok, iterator is a tf.data.Iterator\n", " tf.print(\"Value:\", next(iterator))\n", "\n", "ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n", "iterator = iter(ds)\n", "good_consume_next(iterator)\n", "good_consume_next(iterator)\n", "good_consume_next(iterator)" ] }, { "cell_type": "markdown", "metadata": { "id": "i8YAMYb6KEh4" }, "source": [ "### tf.function のすべての出力は値を返す必要がある\n", "\n", "`tf.Variable`を除いて、tf.function はすべての出力を返す必要があります。戻り値を使用せずに関数からテンソルに直接アクセスしようとすると、「リーク」が発生します。\n", "\n", "たとえば、以下の関数は、Python グローバル `x` を介してテンソル `a` を「リーク」します。" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.802551Z", "iopub.status.busy": "2024-01-11T19:25:31.801955Z", "iopub.status.idle": "2024-01-11T19:25:31.841448Z", "shell.execute_reply": "2024-01-11T19:25:31.840830Z" }, "id": "zrdp4rjxg6jo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n", "'SymbolicTensor' object has no attribute 'numpy'\n" ] } ], "source": [ "x = None\n", "\n", "@tf.function\n", "def leaky_function(a):\n", " global x\n", " x = a + 1 # Bad - leaks local tensor\n", " return a + 2\n", "\n", "correct_a = leaky_function(tf.constant(1))\n", "\n", "print(correct_a.numpy()) # Good - value obtained from function's returns\n", "try:\n", " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n", "except AttributeError as expected:\n", " print(expected)" ] }, { "cell_type": "markdown", "metadata": { "id": "-d4_J_DC5rxX" }, "source": [ "リークされた値も返される場合でもリークします。" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.844650Z", "iopub.status.busy": "2024-01-11T19:25:31.844104Z", "iopub.status.idle": "2024-01-11T19:25:31.901552Z", "shell.execute_reply": "2024-01-11T19:25:31.900934Z" }, "id": "PrcpPB8C5s9T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2\n", "'SymbolicTensor' object has no attribute 'numpy'\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File 
\"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/566849597.py\", line 21, in \n", " captures_leaked_tensor(tf.constant(2))\n", "TypeError: is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.\n", "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.\n", "\n", " was defined here:\n", " File \"/usr/lib/python3.9/runpy.py\", line 197, in _run_module_as_main\n", " File \"/usr/lib/python3.9/runpy.py\", line 87, in _run_code\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py\", line 17, in \n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py\", line 701, in start\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py\", line 205, in start\n", " File \"/usr/lib/python3.9/asyncio/base_events.py\", line 601, in run_forever\n", " File \"/usr/lib/python3.9/asyncio/base_events.py\", line 1905, in _run_once\n", " File \"/usr/lib/python3.9/asyncio/events.py\", line 80, in _run\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3048, in run_cell\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3103, in _run_cell\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3308, in run_cell_async\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3490, in run_ast_nodes\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3550, in run_code\n", " File \"/tmpfs/tmp/ipykernel_178944/566849597.py\", line 7, in \n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 832, in __call__\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 888, in _call\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 695, in _initialize\n", " File 
\"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 178, in trace_function\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 283, in _maybe_define_function\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 310, in _create_concrete_function\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 1059, in func_graph_from_py_func\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 598, in wrapped_fn\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py\", line 41, in autograph_handler\n", " File \"/tmpfs/tmp/ipykernel_178944/566849597.py\", line 4, in leaky_function\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\", line 1478, in binary_op_wrapper\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py\", line 1260, in op_dispatch_handler\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\", line 1871, in _add_dispatch\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py\", line 490, in add_v2\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py\", line 796, in _apply_op_helper\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 670, in _create_op_internal\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py\", line 2652, in _create_op_internal\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py\", line 1160, in from_node_def\n", "\n", "The tensor cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140038351396288), which is out of scope.\n" ] } ], "source": [ "@tf.function\n", "def leaky_function(a):\n", " global x\n", " x = a + 1 # Bad - leaks local tensor\n", " return x # Good - uses local tensor\n", "\n", "correct_a = leaky_function(tf.constant(1))\n", "\n", "print(correct_a.numpy()) # Good - value obtained from function's returns\n", "try:\n", " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n", "except AttributeError as expected:\n", " print(expected)\n", "\n", "@tf.function\n", "def captures_leaked_tensor(b):\n", " b += x # Bad - `x` is leaked from `leaky_function`\n", " return b\n", "\n", "with assert_raises(TypeError):\n", " captures_leaked_tensor(tf.constant(2))" ] }, { "cell_type": "markdown", "metadata": { "id": "Sm2ghjyy50D4" }, "source": [ "通常、このようなリークは、Python ステートメントまたはデータ構造を使用するときに発生します。アクセスできないテンソルがリークするだけでなく、このようなステートメントは Python の副作用としてカウントされ、すべての関数呼び出しで実行されないことがあるため、間違っている可能性があります。\n", "\n", "また、一般的に外部 Python コレクションまたはオブジェクトの変更によりローカルテンソルがリークすることもあります。" ] }, { "cell_type": 
"code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.904815Z", "iopub.status.busy": "2024-01-11T19:25:31.904582Z", "iopub.status.idle": "2024-01-11T19:25:31.908519Z", "shell.execute_reply": "2024-01-11T19:25:31.907857Z" }, "id": "D7bLe8y652wU" }, "outputs": [], "source": [ "class MyClass:\n", "\n", " def __init__(self):\n", " self.field = None\n", "\n", "external_list = []\n", "external_object = MyClass()\n", "\n", "def leaky_function():\n", " a = tf.constant(1)\n", " external_list.append(a) # Bad - leaks tensor\n", " external_object.field = a # Bad - leaks tensor" ] }, { "cell_type": "markdown", "metadata": { "id": "g-XVQcD-wf5K" }, "source": [ "### 再帰的な tf.function はサポートされていない\n", "\n", "再帰的な `Function` はサポートされていないので、無限ループを引き起こす可能性があります。以下に例を示します。" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:31.912082Z", "iopub.status.busy": "2024-01-11T19:25:31.911556Z", "iopub.status.idle": "2024-01-11T19:25:32.776412Z", "shell.execute_reply": "2024-01-11T19:25:32.775652Z" }, "id": "QSN-T1m5EFcR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 9, in \n", " recursive_fn(tf.constant(5)) # Bad - maximum recursion error.\n", "tensorflow.python.autograph.impl.api.StagingError: in user code:\n", "\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File 
\"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/tmpfs/tmp/ipykernel_178944/2233998312.py\", line 4, in recursive_fn *\n", " return recursive_fn(n - 1)\n", " File \"/usr/lib/python3.9/abc.py\", line 119, in __instancecheck__\n", " return _abc_instancecheck(cls, instance)\n", "\n", " RecursionError: maximum recursion depth exceeded while calling a Python object\n", "\n" ] } ], "source": [ "@tf.function\n", "def recursive_fn(n):\n", " if n > 0:\n", " return recursive_fn(n - 1)\n", " else:\n", " return 1\n", "\n", "with assert_raises(Exception):\n", " recursive_fn(tf.constant(5)) # Bad - maximum recursion error." 
] }, { "cell_type": "markdown", "metadata": { "id": "LyRyooKGUxNV" }, "source": [ "再帰的な `Function` が正しく動作しているように見える場合でも、Python 関数は複数回トレースされ、パフォーマンスに影響を与える可能性があります。以下に例を示します。" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:32.780716Z", "iopub.status.busy": "2024-01-11T19:25:32.780105Z", "iopub.status.idle": "2024-01-11T19:25:32.849009Z", "shell.execute_reply": "2024-01-11T19:25:32.848346Z" }, "id": "7FlmTqfMUwmT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tracing\n", "tracing\n", "tracing\n", "tracing\n", "tracing\n" ] }, { "data": { "text/plain": [ "<tf.Tensor: shape=(), dtype=int32, numpy=1>" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@tf.function\n", "def recursive_fn(n):\n", "  if n > 0:\n", "    print('tracing')\n", "    return recursive_fn(n - 1)\n", "  else:\n", "    return 1\n", "\n", "recursive_fn(5) # Warning - multiple tracings" ] }, { "cell_type": "markdown", "metadata": { "id": "-D6nh3QirXAd" }, "source": [ "## 既知の問題\n", "\n", "`Function` が正しく評価されない場合、以下の既知の問題が該当する可能性があります。これらの問題は、今後修正される予定です。" ] }, { "cell_type": "markdown", "metadata": { "id": "ZoPg5w1Pjqna" }, "source": [ "### Python のグローバル変数と自由変数への依存\n", "\n", "`Function` は、Python 引数の新しい値で呼び出されたときに新しい `ConcreteFunction` を作成しますが、Python クロージャ、グローバル変数、またはその `Function` の非ローカル変数に対しては作成しません。`Function` への呼び出しの間に値が変化しても、`Function` はトレーシングされたときの値をそのまま使用してしまいます。これは、通常の Python 関数の動作とは異なります。\n", "\n", "このため、外側の名前をクロージャとして参照する代わりに、引数を使用する関数型プログラミングのスタイルをお勧めします。" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:32.852408Z", "iopub.status.busy": "2024-01-11T19:25:32.852173Z", "iopub.status.idle": "2024-01-11T19:25:32.913546Z", "shell.execute_reply": "2024-01-11T19:25:32.912923Z" }, "id": "oeJMdXd3M0cM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Buggy: tf.Tensor(2, shape=(), dtype=int32)\n", "Correct: tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def buggy_add():\n", "  return 1 + foo\n", "\n", "@tf.function\n", "def recommended_add(foo):\n", "  return 1 + foo\n", "\n", "foo = 1\n", "print(\"Buggy:\", buggy_add())\n", "print(\"Correct:\", recommended_add(foo))" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:32.916578Z", "iopub.status.busy": "2024-01-11T19:25:32.916343Z", "iopub.status.idle": "2024-01-11T19:25:32.933197Z", "shell.execute_reply": "2024-01-11T19:25:32.932596Z" }, "id": "L3q7sUJWZOSU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updating the value of `foo` to 100!\n", "Buggy: tf.Tensor(2, shape=(), dtype=int32)\n", "Correct: tf.Tensor(101, shape=(), dtype=int32)\n" ] } ], "source": [ "print(\"Updating the value of `foo` to 100!\")\n", "foo = 100\n", "print(\"Buggy:\", buggy_add()) # Did not change!\n", "print(\"Correct:\", recommended_add(foo))" ] }, { "cell_type": "markdown", "metadata": { "id": "ZoPg5w1Pjqnb" }, "source": [ "グローバル値を更新する別の方法として、それを `tf.Variable` にし、代わりに `Variable.assign` メソッドを使用することができます。\n" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:32.936454Z", "iopub.status.busy": "2024-01-11T19:25:32.935915Z", "iopub.status.idle": "2024-01-11T19:25:32.972772Z", "shell.execute_reply": "2024-01-11T19:25:32.972168Z" }, "id": "oeJMdXd3M0cc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Variable: tf.Tensor(2, 
shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def variable_add():\n", "  return 1 + foo\n", "\n", "foo = tf.Variable(1)\n", "print(\"Variable:\", variable_add())\n" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:32.975555Z", "iopub.status.busy": "2024-01-11T19:25:32.975322Z", "iopub.status.idle": "2024-01-11T19:25:32.980071Z", "shell.execute_reply": "2024-01-11T19:25:32.979384Z" }, "id": "L3q7sUJWZOSd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updating the value of `foo` to 100!\n", "Variable: tf.Tensor(101, shape=(), dtype=int32)\n" ] } ], "source": [ "print(\"Updating the value of `foo` to 100!\")\n", "foo.assign(100)\n", "print(\"Variable:\", variable_add())" ] }, { "cell_type": "markdown", "metadata": { "id": "hvwe9gTIWfx6" }, "source": [ "### Python オブジェクトへの依存" ] }, { "cell_type": "markdown", "metadata": { "id": "BJkZS-SwPvOQ" }, "source": [ "カスタム Python オブジェクトを引数として `tf.function` に渡すことはサポートされていますが、いくつかの制限があります。\n", "\n", "機能を最大限に活用するには、オブジェクトを `tf.function` に渡す前に[拡張型](extension_type.ipynb)に変換することを検討してください。Python プリミティブ型と `tf.nest` 対応構造も使用できます。\n", "\n", "ただし、[トレーシングのルール](#rules_of_tracing)で説明されるように、カスタム Python クラスがカスタム `TraceType` を提供しない場合、`tf.function` はインスタンスベースの等価性を使用せざるを得ません。そのため、**属性を変更した同じオブジェクト**を渡しても、**新しいトレースは作成されません**。" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:32.983089Z", "iopub.status.busy": "2024-01-11T19:25:32.982872Z", "iopub.status.idle": "2024-01-11T19:25:33.024712Z", "shell.execute_reply": "2024-01-11T19:25:33.024099Z" }, "id": "ux8KJESVWDxX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "class SimpleModel(tf.Module):\n", "  def __init__(self):\n", "    # These values are *not* tf.Variables.\n", "    self.bias = 0.\n", "    self.weight = 2.\n", "\n", "@tf.function\n", "def evaluate(model, x):\n", "  return model.weight * x + model.bias\n", "\n", "simple_model = SimpleModel()\n", "x = tf.constant(10.)\n", "print(evaluate(simple_model, x))" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.027711Z", "iopub.status.busy": "2024-01-11T19:25:33.027462Z", "iopub.status.idle": "2024-01-11T19:25:33.032036Z", "shell.execute_reply": "2024-01-11T19:25:33.031429Z" }, "id": "mUxRF4ghZZvX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "simple_model.bias += 5.0\n", "print(evaluate(simple_model, x)) # Didn't change :(" ] }, { "cell_type": "markdown", "metadata": { "id": "Ytcgg2qFWaBF" }, "source": [ "変更したモデルインスタンスの評価に同じ `Function` を使用すると、元のモデルと[インスタンスに基づく TraceType](#rules_of_tracing)が同じままであるため、不具合が生じます。\n", "\n", "そのため、ミュータブルなオブジェクト属性に依存しない `Function` を書くか、そのような属性を `Function` に伝えられるように、オブジェクトに[トレーシングプロトコル](#use_the_tracing_protocol)を実装することをお勧めします。\n", "\n", "それが困難な場合は、回避策として、オブジェクトを変更するたびに新しい `Function` を作成してリトレーシングを強制する方法があります。" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.035535Z", "iopub.status.busy": "2024-01-11T19:25:33.034964Z", "iopub.status.idle": "2024-01-11T19:25:33.075259Z", "shell.execute_reply": "2024-01-11T19:25:33.074661Z" }, "id": "pFvWmWAAQjrv" }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "def evaluate(model, x):\n", " return model.weight * x + model.bias\n", "\n", "new_model = SimpleModel()\n", "evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n", "# Don't pass in `new_model`, `Function` already captured its state during tracing.\n", "print(evaluate_no_bias(x))" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.078225Z", "iopub.status.busy": "2024-01-11T19:25:33.078004Z", "iopub.status.idle": "2024-01-11T19:25:33.100277Z", "shell.execute_reply": "2024-01-11T19:25:33.099625Z" }, "id": "bdU2-jF4ZH0B" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(25.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "new_model.bias += 5.0\n", "# Create new Function and ConcreteFunction since you modified new_model.\n", "evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n", "print(evaluate_with_bias(x)) # Don't pass in `new_model`." ] }, { "cell_type": "markdown", "metadata": { "id": "uFgEZClsZrEi" }, "source": [ "[リトレーシングにはコストがかかる](https://www.tensorflow.org/guide/intro_to_graphs#tracing_and_performance)ため、`tf.Variable` をオブジェクト属性として使用することができます。こうすることで、リトレーシングを行わずに、ミュートして(変更はしません)同様の効果を得ることができます。\n" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.103591Z", "iopub.status.busy": "2024-01-11T19:25:33.103006Z", "iopub.status.idle": "2024-01-11T19:25:33.147095Z", "shell.execute_reply": "2024-01-11T19:25:33.146416Z" }, "id": "daAP_lucwS6w" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(20.0, shape=(), dtype=float32)\n" ] } ], "source": [ "class BetterModel:\n", "\n", " def __init__(self):\n", " self.bias = tf.Variable(0.)\n", " self.weight = tf.Variable(2.)\n", "\n", "@tf.function\n", "def evaluate(model, x):\n", " return model.weight * x + model.bias\n", "\n", "better_model = BetterModel()\n", "print(evaluate(better_model, x))\n" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.150138Z", "iopub.status.busy": "2024-01-11T19:25:33.149872Z", "iopub.status.idle": "2024-01-11T19:25:33.156696Z", "shell.execute_reply": "2024-01-11T19:25:33.156040Z" }, "id": "ktqwMJBqwTFj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Adding bias!\n", "tf.Tensor(25.0, shape=(), dtype=float32)\n" ] } ], "source": [ "print(\"Adding bias!\")\n", "better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5\n", "print(evaluate(better_model, x)) # This works!" 
] }, { "cell_type": "markdown", "metadata": { "id": "lPr_6mK_AQWL" }, "source": [ "### tf.Variables の作成\n", "\n", "`Function` は、最初の呼び出しで 1 回作成され、後続の関数呼び出しで再利用されるシングルトン `tf.Variable` のみをサポートします。以下のコードスニペットは、すべての関数呼び出しで新しい `tf.Variable` を作成します。これにより、`ValueError` 例外が発生します。\n", "\n", "例:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.159746Z", "iopub.status.busy": "2024-01-11T19:25:33.159495Z", "iopub.status.idle": "2024-01-11T19:25:33.214436Z", "shell.execute_reply": "2024-01-11T19:25:33.213785Z" }, "id": "Tx0Vvnb_9OB-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/3018268426.py\", line 7, in \n", " f(1.0)\n", "ValueError: in user code:\n", "\n", " File \"/tmpfs/tmp/ipykernel_178944/3018268426.py\", line 3, in f *\n", " v = tf.Variable(1.0)\n", "\n", " ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n", "\n" ] } ], "source": [ "@tf.function\n", "def f(x):\n", " v = tf.Variable(1.0)\n", " return v\n", "\n", "with assert_raises(ValueError):\n", " f(1.0)" ] }, { "cell_type": "markdown", "metadata": { "id": "KYm6-5GCILXQ" }, "source": [ "この制限を回避するために使用される一般的なパターンは、Python None 値で開始し、値が None の場合は条件付きで `tf.Variable` を作成することです。" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.217660Z", "iopub.status.busy": "2024-01-11T19:25:33.217384Z", "iopub.status.idle": "2024-01-11T19:25:33.299266Z", "shell.execute_reply": "2024-01-11T19:25:33.298584Z" }, "id": "HQrG5_kOiKl_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(1, shape=(), dtype=int32)\n", "tf.Tensor(2, shape=(), dtype=int32)\n" ] } ], "source": [ "class Count(tf.Module):\n", " def __init__(self):\n", " self.count = None\n", "\n", " @tf.function\n", " def __call__(self):\n", " if self.count is None:\n", " self.count = tf.Variable(0)\n", " return self.count.assign_add(1)\n", "\n", "c = Count()\n", "print(c())\n", "print(c())" ] }, { "cell_type": "markdown", "metadata": { "id": "7uD6qI7aJwbR" }, "source": [ "#### 複数の Keras オプティマイザとの使用\n", "\n", "2 つ以上の Keras オプティマイザを `tf.function` で使用しようとすると、「`ValueError: tf.function only supports singleton tf.Variables created on the first call.` 」というエラーが発生することがあります。このエラーは、オプティマイザが初めて勾配を適用する際に、内部的に `tf.Variables` を作成するために発生するものです。" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.302485Z", "iopub.status.busy": "2024-01-11T19:25:33.302243Z", "iopub.status.idle": "2024-01-11T19:25:33.600004Z", "shell.execute_reply": "2024-01-11T19:25:33.599218Z" }, "id": "yWQ3-r99Jvze" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calling `train_step` with different optimizer...\n", "Caught expected exception \n", " :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1705001133.576006 179113 device_compiler.h:186] Compiled cluster using XLA! 
This line is logged at most once for the lifetime of the process.\n", "Traceback (most recent call last):\n", " File \"/tmpfs/tmp/ipykernel_178944/3551158538.py\", line 8, in assert_raises\n", " yield\n", " File \"/tmpfs/tmp/ipykernel_178944/950644149.py\", line 18, in \n", " train_step(w, x, y, opt2)\n", "ValueError: in user code:\n", "\n", " File \"/tmpfs/tmp/ipykernel_178944/950644149.py\", line 9, in train_step *\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py\", line 1223, in apply_gradients **\n", " return super().apply_gradients(grads_and_vars, name=name)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py\", line 638, in apply_gradients\n", " self.build(trainable_variables)\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py\", line 145, in build\n", " self.add_variable_from_reference(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py\", line 1125, in add_variable_from_reference\n", " return super().add_variable_from_reference(\n", " File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py\", line 513, in add_variable_from_reference\n", " variable = tf.Variable(\n", "\n", " ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n", "\n" ] } ], "source": [ "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", "\n", "@tf.function\n", "def train_step(w, x, y, optimizer):\n", " with tf.GradientTape() as tape:\n", " L = tf.reduce_sum(tf.square(w*x - y))\n", " gradients = tape.gradient(L, [w])\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", "\n", "w = tf.Variable(2.)\n", "x = tf.constant([-1.])\n", "y = tf.constant([2.])\n", "\n", "train_step(w, x, y, opt1)\n", "print(\"Calling `train_step` with different optimizer...\")\n", "with assert_raises(ValueError):\n", " train_step(w, x, y, opt2)" ] }, { "cell_type": "markdown", "metadata": { "id": "7Q8BRPCThTjB" }, "source": [ "トレーニング中にオプティマイザを変更する必要がある場合は、回避策として、オプティマイザごとに新しい `Function` を作成し、[`ConcreteFunction`](#obtaining_concrete_functions) を直接呼び出すようにすることができます。" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:25:33.604401Z", "iopub.status.busy": "2024-01-11T19:25:33.604142Z", "iopub.status.idle": "2024-01-11T19:25:33.961455Z", "shell.execute_reply": "2024-01-11T19:25:33.960664Z" }, "id": "YV5F2Gy9hSI3" }, "outputs": [], "source": [ "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", "\n", "# Not a tf.function.\n", "def train_step(w, x, y, optimizer):\n", " with tf.GradientTape() as tape:\n", " L = tf.reduce_sum(tf.square(w*x - y))\n", " gradients = tape.gradient(L, [w])\n", " optimizer.apply_gradients(zip(gradients, [w]))\n", "\n", "w = tf.Variable(2.)\n", "x = tf.constant([-1.])\n", "y = tf.constant([2.])\n", "\n", "# Make a new Function and ConcreteFunction for each optimizer.\n", "train_step_1 = tf.function(train_step)\n", "train_step_2 = tf.function(train_step)\n", "for i in range(10):\n", " if i % 2 == 0:\n", " train_step_1(w, x, y, opt1)\n", " else:\n", " 
train_step_2(w, x, y, opt2)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xjnz5CcuqQac" }, "source": [ "#### 複数の Keras モデルとの使用\n", "\n", "また、別のモデルインスタンスを同一の `Function` に渡す際に、「`ValueError: tf.function only supports singleton tf.Variables created on the first call.`」というエラーも発生することがあります。\n", "\n", "このエラーは、Keras モデル([入力形状が定義されていない](https://www.tensorflow.org/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known))と Keras レイヤーが、初めて呼び出されるときに `tf.Variables` を作成するために発生するものです。これらの変数をすでに呼び出された `Function` 内で初期化しようとしているのでしょう。このエラーを回避するには、モデルをトレーニングする前に、`model.build(input_shape)` を呼び出して、すべての重みを初期化するようにしてください。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "IKyrEY5GVX3M" }, "source": [ "## 参考資料\n", "\n", "`Function` のエクスポートと読み込みの方法については、[SavedModel ガイド](../../guide/saved_model)をご覧ください。トレーシングの後に実行するグラフの最適化については、[Grappler ガイド](../../guide/graph_optimization)をご覧ください。データパイプラインの最適化方法とモデルのプロファイリングについては、[Profiler ガイド](../../guide/profiler.md)をご覧ください。" ] } ], "metadata": { "colab": { "name": "function.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }