{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "N7ITxKLUkX0v" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T19:23:34.221855Z", "iopub.status.busy": "2024-01-11T19:23:34.221403Z", "iopub.status.idle": "2024-01-11T19:23:34.225131Z", "shell.execute_reply": "2024-01-11T19:23:34.224568Z" }, "id": "yOYx6tzSnWQ3" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "6xgB0Oz5eGSQ" }, "source": [ "# グラフと `tf.function` の基礎" ] }, { "cell_type": "markdown", "metadata": { "id": "w4zzZVZtQb1w" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示Google Colab で実行GitHub でソースを表示ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "RBKqnXI9GOax" }, "source": [ "## 概要\n", "\n", "このガイドは、TensorFlow の仕組みを説明するために、TensorFlow と Keras 基礎を説明します。今すぐ Keras に取り組みたい方は、[Keras のガイド一覧](https://www.tensorflow.org/guide/keras/)をご覧ください。\n", "\n", "このガイドでは、TensorFlow でグラフ取得のための単純なコード変更、格納と表現、およびモデルの高速化とエクスポートを行う方法を説明します。\n", "\n", "注意: TensorFlow 1.x のみの知識をお持ちの場合は、このガイドでは、非常に異なるグラフビューが紹介されています。\n", "\n", "**ここでは、`tf.function` を使って Eager execution から Graph execution に切り替える方法を概説しています。**より詳しい `tf.function` の仕様については、`tf.function` によるパフォーマンスの改善ガイドをご覧ください。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "v0DdlfacAdTZ" }, "source": [ "### グラフとは?\n", "\n", "前の 3 つのガイドでは、TensorFlow を **Eager** で実行する方法を紹介しました。つまり、TensorFlow 演算は、Python によって演算ごとに実行され、Python に結果を戻しました。\n", "\n", "Eager execution には特有のメリットがいくつかありますが、Graph execution では Python 外への移植が可能になり、より優れたパフォーマンスを得られる傾向にあります。**Graph execution** では、テンソルの計算は *TensorFlow グラフ*(`tf.Graph` または単に「graph」とも呼ばれます)として実行されます。\n", "\n", "**グラフとは、計算のユニットを表す一連の `tf.Operation` オブジェクトと、演算間を流れるデータのユニットを表す `tf.Tensor` オブジェクトを含むデータ構造です。** `tf.Graph` コンテキストで定義されます。これらのグラフはデータ構造であるため、元の Python コードがなくても、保存、実行、および復元することができます。\n", "\n", "次は、TensorBoard で視覚化された二層ニューラルネットワークを表現する TensorFlow グラフです。" ] }, { "cell_type": "markdown", "metadata": { "id": "FvQ5aBuRGT1o" }, "source": [ "\"A " ] }, { "cell_type": "markdown", "metadata": { "id": "DHpY3avXGITP" }, "source": [ "### グラフのメリット\n", "\n", "グラフを使用すると、柔軟性が大幅に向上し、モバイルアプリケーション、組み込みデバイス、バックエンドサーバーといった Python インタプリタのない環境でも TensorFlow グラフを使用できます。TensorFlow は、Python からエクスポートする場合に、[SavedModel](./saved_model.ipynb) の形式としてグラフを使用します。\n", "\n", "また、グラフは最適化を簡単に行えるため、コンパイラは次のような変換を行えます。\n", "\n", "- 計算に定数ノードを畳み込むで、テンソルの値を統計的に推論します*(「定数畳み込み」)*。\n", "- 独立した計算のサブパートを分離し、スレッドまたはデバイスに分割します。\n", "- 共通部分式を取り除き、算術演算を単純化します。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "o1x1EOD9GjnB" }, "source": [ "これやほかの高速化を実行する [Grappler](./graph_optimization.ipynb) という総合的な最適化システムがあります。\n", "\n", "まとめると、グラフは非常に便利なもので、**複数のデバイス**で、TensorFlow の**高速化**、**並列化**、および効率化を期待することができます。\n", "\n", "ただし、便宜上、Python で機械学習モデル(またはその他の計算)を定義した後、必要となったときに自動的にグラフを作成することをお勧めします。" ] }, { "cell_type": "markdown", "metadata": { "id": "k-6Qi0thw2i9" }, "source": [ "## セットアップ" ] }, { "cell_type": "markdown", "metadata": { "id": "0d1689fa928f" }, "source": [ "いくつかの必要なライブラリをインポートします。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:34.229853Z", "iopub.status.busy": "2024-01-11T19:23:34.229199Z", "iopub.status.idle": "2024-01-11T19:23:36.598556Z", "shell.execute_reply": "2024-01-11T19:23:36.597755Z" }, "id": "goZwOXp_xyQj" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 19:23:34.661930: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 19:23:34.661977: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 19:23:34.663554: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "import timeit\n", "from datetime import datetime" ] }, { "cell_type": "markdown", "metadata": { "id": "pSZebVuWxDXu" }, "source": [ "## グラフを利用する\n", "\n", "TensorFlow では、`tf.function` を直接呼出しまたはデコレータとして使用し、グラフを作成して実行します。`tf.function` は通常の関数を入力として取り、`Function` を返します。`Function` は、Python 関数から TensorFlow グラフを構築する Python コーラブルです。`Function` は 相当する Python 関数と同様に使用します。\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:36.602682Z", "iopub.status.busy": "2024-01-11T19:23:36.602311Z", "iopub.status.idle": "2024-01-11T19:23:38.920078Z", "shell.execute_reply": "2024-01-11T19:23:38.919297Z" }, "id": "HKbLeJ1y0Umi" }, "outputs": [], "source": [ "# Define a Python function.\n", "def a_regular_function(x, y, b):\n", " x = tf.matmul(x, y)\n", " x = x + b\n", " return x\n", "\n", "# `a_function_that_uses_a_graph` is a TensorFlow `Function`.\n", "a_function_that_uses_a_graph = tf.function(a_regular_function)\n", "\n", "# Make some tensors.\n", "x1 = tf.constant([[1.0, 2.0]])\n", "y1 = tf.constant([[2.0], [3.0]])\n", "b1 = tf.constant(4.0)\n", "\n", "orig_value = a_regular_function(x1, y1, b1).numpy()\n", "# Call a `Function` like a Python function.\n", "tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()\n", "assert(orig_value == tf_function_value)" ] }, { "cell_type": "markdown", "metadata": { "id": "PNvuAYpdrTOf" }, "source": [ "一方、`Function` は TensorFlow 演算を使って記述する通常の関数のように見えます。ただし、[その根底では](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/def_function.py)*非常に異なります*。`Function` は **1 つの API の背後で複数の `tf.Graph` をカプセル化しています**(詳細については、*多態性*セクションをご覧ください)。`Function` が速度やデプロイ可能性といった Graph execution のメリットを提供できるのはこのためです。(上記の*グラフのメリット*をご覧ください)。" ] }, { "cell_type": "markdown", "metadata": { "id": "MT7U8ozok0gV" }, "source": [ "`tf.function` は関数と*それが呼び出すその他すべての関数に次のように適用します*。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:38.924500Z", "iopub.status.busy": "2024-01-11T19:23:38.923861Z", "iopub.status.idle": "2024-01-11T19:23:38.996604Z", "shell.execute_reply": "2024-01-11T19:23:38.995969Z" }, "id": "rpz08iLplm9F" }, "outputs": [ { "data": { "text/plain": [ "array([[12.]], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def inner_function(x, y, b):\n", " x = tf.matmul(x, y)\n", " x = x + b\n", " return x\n", "\n", "# Use the decorator to make `outer_function` a `Function`.\n", "@tf.function\n", "def outer_function(x):\n", " y = tf.constant([[2.0], [3.0]])\n", " b = tf.constant(4.0)\n", "\n", " return inner_function(x, y, b)\n", "\n", "# Note that the callable will create a graph that\n", "# includes `inner_function` as well as `outer_function`.\n", "outer_function(tf.constant([[1.0, 2.0]])).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "P88fOr88qgCj" }, "source": [ "TensorFlow 1.x を使用したことがある場合は、`Placeholder` または `tf.Sesssion` をまったく定義する必要がないことに気づくでしょう。" ] }, { "cell_type": "markdown", "metadata": { "id": "wfeKf0Nr1OEK" }, "source": [ "### Python 関数をグラフに変換する\n", "\n", "TensorFlow で記述するすべての関数には、組み込みの TF 演算と、`if-then` 句、ループ、`break`、`return`、`continue` などの Python ロジックが含まれます。TensorFlow 演算は `tf.Graph` で簡単にキャプチャされますが、Python 固有のロジックがグラフの一部となるには、さらにステップが必要となります。`tf.function` は、Python コードをグラフが生成するコードに変換するために、AutoGraph(`tf.autograph`)というライブラリを使用しています。\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.000144Z", "iopub.status.busy": "2024-01-11T19:23:38.999624Z", "iopub.status.idle": "2024-01-11T19:23:39.065947Z", "shell.execute_reply": "2024-01-11T19:23:39.065191Z" }, "id": "PFObpff1BMEb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First branch, with graph: 1\n", "Second branch, with graph: 0\n" ] } ], "source": [ "def simple_relu(x):\n", " if tf.greater(x, 0):\n", " return x\n", " else:\n", " return 0\n", "\n", "# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.\n", "tf_simple_relu = tf.function(simple_relu)\n", "\n", "print(\"First branch, with graph:\", tf_simple_relu(tf.constant(1)).numpy())\n", "print(\"Second branch, with graph:\", tf_simple_relu(tf.constant(-1)).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "hO4DBUNZBMwQ" }, "source": [ "直接グラフを閲覧する必要があることはほぼありませんが、正確な結果を確認するために出力を検査することは可能です。簡単に読み取れるものではありませんので、精査する必要はありません!" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.069549Z", "iopub.status.busy": "2024-01-11T19:23:39.068979Z", "iopub.status.idle": "2024-01-11T19:23:39.073843Z", "shell.execute_reply": "2024-01-11T19:23:39.073191Z" }, "id": "lAKaat3w0gnn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def tf__simple_relu(x):\n", " with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n", " do_return = False\n", " retval_ = ag__.UndefinedReturnValue()\n", "\n", " def get_state():\n", " return (do_return, retval_)\n", "\n", " def set_state(vars_):\n", " nonlocal retval_, do_return\n", " (do_return, retval_) = vars_\n", "\n", " def if_body():\n", " nonlocal retval_, do_return\n", " try:\n", " do_return = True\n", " retval_ = ag__.ld(x)\n", " except:\n", " do_return = False\n", " raise\n", "\n", " def else_body():\n", " nonlocal retval_, do_return\n", " try:\n", " do_return = True\n", " retval_ = 0\n", " except:\n", " do_return = False\n", " raise\n", " ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)\n", " return fscope.ret(retval_, do_return)\n", "\n" ] } ], "source": [ "# This is the graph-generating output of AutoGraph.\n", "print(tf.autograph.to_code(simple_relu))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.077009Z", "iopub.status.busy": "2024-01-11T19:23:39.076445Z", "iopub.status.idle": "2024-01-11T19:23:39.080856Z", "shell.execute_reply": "2024-01-11T19:23:39.080159Z" }, "id": "8x6RAqza1UWf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "node {\n", " name: \"x\"\n", " op: \"Placeholder\"\n", " attr {\n", " key: \"_user_specified_name\"\n", " value {\n", " s: \"x\"\n", " }\n", " }\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"shape\"\n", " value {\n", " shape {\n", " }\n", " }\n", " }\n", "}\n", "node {\n", " name: \"Greater/y\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_INT32\n", " tensor_shape {\n", " }\n", " int_val: 0\n", " }\n", " }\n", " }\n", "}\n", "node {\n", " name: \"Greater\"\n", " op: \"Greater\"\n", " input: \"x\"\n", " input: \"Greater/y\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "node {\n", " name: \"cond\"\n", " op: \"StatelessIf\"\n", " input: \"Greater\"\n", " input: \"x\"\n", " attr {\n", " key: \"Tcond\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " attr {\n", " key: \"Tin\"\n", " value {\n", " list {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"Tout\"\n", " value {\n", " list {\n", " type: DT_BOOL\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"_lower_using_switch_merge\"\n", " value {\n", " b: true\n", " }\n", " }\n", " attr {\n", " key: \"_read_only_resource_inputs\"\n", " value {\n", " list {\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"else_branch\"\n", " value {\n", " func {\n", " name: \"cond_false_31\"\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"output_shapes\"\n", " value {\n", " list {\n", " shape {\n", " }\n", " shape {\n", " }\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"then_branch\"\n", " value {\n", " func {\n", " name: \"cond_true_30\"\n", " }\n", " }\n", " }\n", "}\n", "node {\n", " name: \"cond/Identity\"\n", " op: \"Identity\"\n", " input: \"cond\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", "}\n", "node {\n", " name: \"cond/Identity_1\"\n", " op: \"Identity\"\n", " input: \"cond:1\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "node {\n", " name: \"Identity\"\n", " op: \"Identity\"\n", " input: \"cond/Identity_1\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", "}\n", "library {\n", " function {\n", " signature {\n", " name: \"cond_false_31\"\n", " input_arg {\n", " name: \"cond_placeholder\"\n", " type: DT_INT32\n", " }\n", " output_arg {\n", " name: \"cond_identity\"\n", " type: DT_BOOL\n", " }\n", " output_arg {\n", " name: \"cond_identity_1\"\n", " type: DT_INT32\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_1\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_2\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_INT32\n", " tensor_shape {\n", " }\n", " int_val: 0\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_3\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity\"\n", " op: \"Identity\"\n", " input: \"cond/Const_3:output:0\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const_4\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_INT32\n", " tensor_shape {\n", " }\n", " int_val: 0\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity_1\"\n", " op: \"Identity\"\n", " input: \"cond/Const_4:output:0\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " ret {\n", " key: \"cond_identity\"\n", " value: \"cond/Identity:output:0\"\n", " }\n", " ret {\n", " key: \"cond_identity_1\"\n", " value: \"cond/Identity_1:output:0\"\n", " }\n", " attr {\n", " key: \"_construction_context\"\n", " value {\n", " s: \"kEagerRuntime\"\n", " }\n", " }\n", " arg_attr {\n", " key: 0\n", " value {\n", " attr {\n", " key: \"_output_shapes\"\n", " value {\n", " list {\n", " shape {\n", " }\n", " }\n", " }\n", " }\n", " }\n", " }\n", " }\n", " function {\n", " signature {\n", " name: \"cond_true_30\"\n", " input_arg {\n", " name: \"cond_identity_1_x\"\n", " type: DT_INT32\n", " }\n", " output_arg {\n", " name: \"cond_identity\"\n", " type: DT_BOOL\n", " }\n", " output_arg {\n", " name: \"cond_identity_1\"\n", " type: DT_INT32\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Const\"\n", " op: \"Const\"\n", " attr {\n", " key: \"dtype\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " attr {\n", " key: \"value\"\n", " value {\n", " tensor {\n", " dtype: DT_BOOL\n", " tensor_shape {\n", " }\n", " bool_val: true\n", " }\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity\"\n", " op: \"Identity\"\n", " input: \"cond/Const:output:0\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_BOOL\n", " }\n", " }\n", " }\n", " node_def {\n", " name: \"cond/Identity_1\"\n", " op: \"Identity\"\n", " input: \"cond_identity_1_x\"\n", " attr {\n", " key: \"T\"\n", " value {\n", " type: DT_INT32\n", " }\n", " }\n", " }\n", " ret {\n", " key: \"cond_identity\"\n", " value: \"cond/Identity:output:0\"\n", " }\n", " ret {\n", " key: \"cond_identity_1\"\n", " value: \"cond/Identity_1:output:0\"\n", " }\n", " attr {\n", " key: \"_construction_context\"\n", " value {\n", " s: \"kEagerRuntime\"\n", " }\n", " }\n", " arg_attr {\n", " key: 0\n", " value {\n", " attr {\n", " key: \"_output_shapes\"\n", " value {\n", " list {\n", " shape {\n", " }\n", " }\n", " }\n", " }\n", " attr {\n", " key: \"_user_specified_name\"\n", " value {\n", " s: \"x\"\n", " }\n", " }\n", " }\n", " }\n", " }\n", "}\n", "versions {\n", " producer: 1645\n", " min_consumer: 12\n", "}\n", "\n" ] } ], "source": [ "# This is the graph itself.\n", "print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())" ] }, { "cell_type": "markdown", "metadata": { "id": "GZ4Ieg6tBE6l" }, "source": [ "ほとんどの場合、`tf.function` の動作に特別な考慮はいりませんが、いくつかの注意事項があり、これについては `tf.function` ガイドのほか、[詳細な Autograph リファレンス](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md)が役立ちます。" ] }, { "cell_type": "markdown", "metadata": { "id": "sIpc_jfjEZEg" }, "source": [ "### ポリモーフィズム: 1 つの `Function` で複数のグラフを得る\n", "\n", "`tf.Graph` は特定の型の入力(特定の [`dtype`](https://www.tensorflow.org/api_docs/python/tf/dtypes/DType) のテンソルまたは同じ [`id()` のオブジェクト](https://docs.python.org/3/library/functions.html#id%5D)など)に特化しています。\n", "\n", "既存のグラフ(新しい `dtypes` や互換性のない形状の引数など)では処理できない一連の引数を指定して `Function` を呼び出すたびに、`Function` はそれらの新しい引数に特化した新しい `tf.Graph` を作成します。`tf.Graph` の入力の型指定は、その**入力シグネチャ**、または単に**シグネチャ**として知られています。新しい `tf.Graph` がいつ生成されるか、およびそれをどのように制御できるかに関する詳細については、[`tf.function` ガイドによるパフォーマンスの改善](./function.ipynb)の*トレーシングのルール*セクションに移動してください。\n", "\n", "`Function` はそのシグネチャに対応する `tf.Graph` を `ConcreteFunction` に格納します。`ConcreteFunction` は `tf.Graph` を囲むラッパーです。\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.084333Z", "iopub.status.busy": "2024-01-11T19:23:39.083746Z", "iopub.status.idle": "2024-01-11T19:23:39.491090Z", "shell.execute_reply": "2024-01-11T19:23:39.490319Z" }, "id": "LOASwhbvIv_T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(5.5, shape=(), dtype=float32)\n", "tf.Tensor([1. 0.], shape=(2,), dtype=float32)\n", "tf.Tensor([3. 0.], shape=(2,), dtype=float32)\n" ] } ], "source": [ "@tf.function\n", "def my_relu(x):\n", " return tf.maximum(0., x)\n", "\n", "# `my_relu` creates new graphs as it observes more signatures.\n", "print(my_relu(tf.constant(5.5)))\n", "print(my_relu([1, -1]))\n", "print(my_relu(tf.constant([3., -3.])))" ] }, { "cell_type": "markdown", "metadata": { "id": "1qRtw7R4KL9X" }, "source": [ "`Function` がそのシグネチャですでに呼び出されている場合、`Function` は新しい `tf.Graph` を作成しません。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.494680Z", "iopub.status.busy": "2024-01-11T19:23:39.494420Z", "iopub.status.idle": "2024-01-11T19:23:39.500467Z", "shell.execute_reply": "2024-01-11T19:23:39.499794Z" }, "id": "TjjbnL5OKNDP" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.0, shape=(), dtype=float32)\n", "tf.Tensor([0. 1.], shape=(2,), dtype=float32)\n" ] } ], "source": [ "# These two calls do *not* create new graphs.\n", "print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.\n", "print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`." ] }, { "cell_type": "markdown", "metadata": { "id": "UohRmexhIpvQ" }, "source": [ "複数のグラフでサポートされているため、`Function` は**ポリモーフィック**です。そのため、単一の `tf.Graph` が表現できる以上の入力型をサポートし、パフォーマンスが改善されるように `tf.Graph` ごとに最適化することができます。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.503785Z", "iopub.status.busy": "2024-01-11T19:23:39.503251Z", "iopub.status.idle": "2024-01-11T19:23:39.506803Z", "shell.execute_reply": "2024-01-11T19:23:39.506201Z" }, "id": "dxzqebDYFmLy" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input Parameters:\n", " x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " x (POSITIONAL_OR_KEYWORD): List[Literal[1], Literal[-1]]\n", "Output Type:\n", " TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n", "\n", "Input Parameters:\n", " x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n", "Output Type:\n", " TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n", "Captures:\n", " None\n" ] } ], "source": [ "# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.\n", "# The `ConcreteFunction` also knows the return type and shape!\n", "print(my_relu.pretty_printed_concrete_signatures())" ] }, { "cell_type": "markdown", "metadata": { "id": "V11zkxU22XeD" }, "source": [ "## `tf.function` を使用する\n", "\n", "ここまでで、`tf.function` をデコレータまたはラッパーとして使用するだけで、Python 関数をグラフに変換できることを学習しました。しかし実際には、`tf.function` を正しく動作させるにはコツがいります!以下のセクションでは、`tf.function` を使って期待通りにコードを動作させるようにする方法を説明します。" ] }, { "cell_type": "markdown", "metadata": { "id": "yp_n0B5-P0RU" }, "source": [ "### Graph execution と Eager execution\n", "\n", "`Function` 内のコードは、Eager と Graph の両方で実行できますが、デフォルトでは、`Function` は Graph としてコードを実行するようになっています。\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.510126Z", "iopub.status.busy": "2024-01-11T19:23:39.509740Z", "iopub.status.idle": "2024-01-11T19:23:39.513159Z", "shell.execute_reply": "2024-01-11T19:23:39.512579Z" }, "id": "_R0BOvBFxqVZ" }, "outputs": [], "source": [ "@tf.function\n", "def get_MSE(y_true, y_pred):\n", " sq_diff = tf.pow(y_true - y_pred, 2)\n", " return tf.reduce_mean(sq_diff)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.516332Z", "iopub.status.busy": "2024-01-11T19:23:39.515925Z", "iopub.status.idle": "2024-01-11T19:23:39.523146Z", "shell.execute_reply": "2024-01-11T19:23:39.522558Z" }, "id": "zikMVPGhmDET" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([9 6 8 1 5], shape=(5,), dtype=int32)\n", "tf.Tensor([2 4 4 2 3], shape=(5,), dtype=int32)\n" ] } ], "source": [ "y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n", "y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n", "print(y_true)\n", "print(y_pred)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.526323Z", "iopub.status.busy": "2024-01-11T19:23:39.525856Z", "iopub.status.idle": "2024-01-11T19:23:39.574128Z", "shell.execute_reply": "2024-01-11T19:23:39.573545Z" }, "id": "07r08Dh158ft" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_MSE(y_true, y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "cyZNCRcQorGO" }, "source": [ "`Function` のグラフがそれに相当する Python 関数と同じように計算していることを確認するには、`tf.config.run_functions_eagerly(True)` を使って Eager で実行することができます。これは、通常どおりコードを実行するのではなく、グラフを作成して実行する `Function` の能力をオフにするスイッチです。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.577396Z", "iopub.status.busy": "2024-01-11T19:23:39.576871Z", "iopub.status.idle": "2024-01-11T19:23:39.579918Z", "shell.execute_reply": "2024-01-11T19:23:39.579294Z" }, "id": "lKoF6NjPoI8w" }, "outputs": [], "source": [ "tf.config.run_functions_eagerly(True)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.582840Z", "iopub.status.busy": "2024-01-11T19:23:39.582472Z", "iopub.status.idle": "2024-01-11T19:23:39.589153Z", "shell.execute_reply": "2024-01-11T19:23:39.588505Z" }, "id": "9ZLqTyn0oKeM" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_MSE(y_true, y_pred)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.592231Z", "iopub.status.busy": "2024-01-11T19:23:39.591721Z", "iopub.status.idle": "2024-01-11T19:23:39.594690Z", "shell.execute_reply": "2024-01-11T19:23:39.594057Z" }, "id": "cV7daQW9odn-" }, "outputs": [], "source": [ "# Don't forget to set it back when you are done.\n", "tf.config.run_functions_eagerly(False)" ] }, { "cell_type": "markdown", "metadata": { "id": "DKT3YBsqy0x4" }, "source": [ "ただし、Eager execution と Graph execution では `Function` の動作が異なることがあります。Python の [`print`](https://docs.python.org/3/library/functions.html#print) 関数がその例です。関数に `print` ステートメントを挿入して、それを繰り返し呼び出すとどうなるかを見てみましょう。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.598073Z", "iopub.status.busy": "2024-01-11T19:23:39.597591Z", "iopub.status.idle": "2024-01-11T19:23:39.601383Z", "shell.execute_reply": "2024-01-11T19:23:39.600794Z" }, "id": "BEJeVeBEoGjV" }, "outputs": [], "source": [ "@tf.function\n", "def get_MSE(y_true, y_pred):\n", " print(\"Calculating MSE!\")\n", " sq_diff = tf.pow(y_true - y_pred, 2)\n", " return tf.reduce_mean(sq_diff)" ] }, { "cell_type": "markdown", "metadata": { "id": "3sWTGwX3BzP1" }, "source": [ "何が出力されるか観察しましょう。" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.604663Z", "iopub.status.busy": "2024-01-11T19:23:39.604216Z", "iopub.status.idle": "2024-01-11T19:23:39.652380Z", "shell.execute_reply": "2024-01-11T19:23:39.651786Z" }, "id": "3rJIeBg72T9n" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculating MSE!\n" ] } ], "source": [ "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)" ] }, { "cell_type": "markdown", "metadata": { "id": "WLMXk1uxKQ44" }, "source": [ "この出力に驚きましたか?**`get_MSE` は *3 回*呼び出されたにもかかわらず、出力されたのは 1 回だけでした。**\n", "\n", "説明すると、`print` ステートメントは `Function` が「トレーシング」というプロセスでグラフを作成するために元のコードを実行したときに実行されます([`tf.function` ガイド](./function.ipynb)の*トレーシング*セクションをご覧ください)。トレーシングは、TensorFlow 演算をグラフにキャプチャしますが、グラフには `print` はキャプチャされません。以降、そのグラフは **Python コードを再実行せずに**、3 つのすべての呼び出しに対して実行されます。\n", "\n", "サニティーチェックとして、Graph execution をオフにして比較してみましょう。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.655330Z", "iopub.status.busy": "2024-01-11T19:23:39.655109Z", "iopub.status.idle": "2024-01-11T19:23:39.658146Z", "shell.execute_reply": "2024-01-11T19:23:39.657500Z" }, "id": "oFSxRtcptYpe" }, "outputs": [], "source": [ "# Now, globally set everything to run eagerly to force eager execution.\n", "tf.config.run_functions_eagerly(True)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.661050Z", "iopub.status.busy": "2024-01-11T19:23:39.660812Z", "iopub.status.idle": "2024-01-11T19:23:39.665541Z", "shell.execute_reply": "2024-01-11T19:23:39.664912Z" }, "id": "qYxrAtvzNgHR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculating MSE!\n", "Calculating MSE!\n", "Calculating MSE!\n" ] } ], "source": [ "# Observe what is printed below.\n", "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)\n", "error = get_MSE(y_true, y_pred)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.668201Z", "iopub.status.busy": "2024-01-11T19:23:39.667972Z", "iopub.status.idle": "2024-01-11T19:23:39.671034Z", "shell.execute_reply": "2024-01-11T19:23:39.670450Z" }, "id": "_Df6ynXcAaup" }, "outputs": [], "source": [ "tf.config.run_functions_eagerly(False)" ] }, { "cell_type": "markdown", "metadata": { "id": "PUR7qC_bquCn" }, "source": [ "`print` は *Python の副作用*です。違いは他にもあり、関数を `Function` に変換する場合には注意が必要です。詳細については、[`tf.function` でパフォーマンスを向上](./function.ipynb)ガイドの*制限*セクションをご覧ください。" ] }, { "cell_type": "markdown", "metadata": { "id": "oTZJfV_tccVp" }, "source": [ "注意: Eager execution と Graph execution の両方で値を出力する場合は、代わりに `tf.print` を使用してください。" ] }, { "cell_type": "markdown", "metadata": { "id": "rMT_Xf5yKn9o" }, "source": [ "### Non-strict execution\n", "\n", "\n", "\n", "Graph execution は、観測可能な効果を生成するために必要な演算のみを実行するもので、次が含まれています。\n", "\n", "- 関数の戻り値\n", "- 以下のような、文書化された既知の副作用\n", " - 入力/出力演算。`tf.print` など。\n", " - デバッグ演算。`tf.debugging` のアサート関数など。\n", " - `tf.Variable` のミューテーション\n", "\n", "この動作は、「Non-strict execution」としてよく知られており、Eager execution とは異なり、必要であるかに関係なく、すべてのプログラム演算をステップします。\n", "\n", "具体的には、ランタイムエラーチェックは観測可能な効果として考慮されません。演算が不要であるがためにスキップされると、その演算はランタイムエラーをスローできません。\n", "\n", "次の例では、Graph execution 中に「不要な」演算 `tf.gather` がスキップされるため、Eager execution とは異なり、ランタイムエラーの `InvalidArgumentError` は発生しません。グラフの実行中にはエラーが発生することをあまり信頼しないようにしましょう。" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.674194Z", "iopub.status.busy": "2024-01-11T19:23:39.673951Z", "iopub.status.idle": "2024-01-11T19:23:39.690189Z", "shell.execute_reply": "2024-01-11T19:23:39.689455Z" }, "id": "OdN0nKlUwj7M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([0.], shape=(1,), dtype=float32)\n" ] } ], "source": [ "def unused_return_eager(x):\n", " # Get index 1 will fail when `len(x) == 1`\n", " tf.gather(x, [1]) # unused \n", " return x\n", "\n", "try:\n", " print(unused_return_eager(tf.constant([0.0])))\n", "except tf.errors.InvalidArgumentError as e:\n", " # All operations are run during eager execution so an error is raised.\n", " print(f'{type(e).__name__}: {e}')" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.693316Z", "iopub.status.busy": "2024-01-11T19:23:39.692775Z", "iopub.status.idle": "2024-01-11T19:23:39.732021Z", "shell.execute_reply": "2024-01-11T19:23:39.731425Z" }, "id": "d80Fob4MwhTs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([0.], shape=(1,), dtype=float32)\n" ] } ], "source": [ "@tf.function\n", "def unused_return_graph(x):\n", " tf.gather(x, [1]) # unused\n", " return x\n", "\n", "# Only needed operations are run during graph execution. The error is not raised.\n", "print(unused_return_graph(tf.constant([0.0])))" ] }, { "cell_type": "markdown", "metadata": { "id": "def6MupG9R0O" }, "source": [ "### `tf.function` のベストプラクティス\n", "\n", "It may take some time to get used to the behavior of `Function`. To get started quickly, first-time users should play around with decorating toy functions with `@tf.function` to get experience with going from eager to graph execution.\n", "\n", "*`tf.function` の設計は、*グラフ互換の TensorFlow プログラムを作成するための最良の策かもしれません。いくつかのヒントを次に示します。\n", "\n", "- 早い段階で Eager execution と Graph execution を切り替えながら、2 つのモードで異なる結果を得るかどうか、またはそのタイミングを知るために `tf.config.run_functions_eagerly` を頻繁に使用しましょう。\n", "- Python 関数の外で `tf.Variable` を作成し、Python 関数内で変更するようにします。これは、`tf.keras.layers`、`tf.keras.Model`、`tf.keras.optimizers` などの `tf.Variable` を使用するオブジェクトでも同じです。\n", "- `tf.Variable` と Keras オブジェクトを除いて、外部の Python 変数に依存する関数を書くことは避けてください。[`tf.function` ガイド](./function.ipynb)の *Python のグローバル変数と自由変数に依存する*で詳細を確認してください。\n", "- テンソルと他の TensorFlow 型を入力として取る関数を記述するようにしましょう。他の型のオブジェクトを渡すことは可能ですが、十分な注意が必要です![`tf.function` ガイド](./function.ipynb)の *Python オブジェクトに依存する*で詳細を確認してください。\n", "- パフォーマンスを最大限に得るには、`tf.function` の下にできるだけ多くの計算を含めるようにしましょう。たとえば、トレーニングステップ全体またはトレーニングループ全体をデコレートすることができます。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ViM3oBJVJrDx" }, "source": [ "## 高速化の確認" ] }, { "cell_type": "markdown", "metadata": { "id": "A6NHDp7vAKcJ" }, "source": [ "コードのパフォーマンスは通常、`tf.function` によって改善されますが、改善率は実行する計算によって異なります。 小さな計算であれば、グラフ呼び出しのオーバーヘッドに制約を受ける可能性があります。パフォーマンスの変化は、次のようにして確認することができます。" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.735496Z", "iopub.status.busy": "2024-01-11T19:23:39.734900Z", "iopub.status.idle": "2024-01-11T19:23:39.739407Z", "shell.execute_reply": "2024-01-11T19:23:39.738734Z" }, "id": "jr7p1BBjauPK" }, "outputs": [], "source": [ "x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)\n", "\n", "def power(x, y):\n", " result = tf.eye(10, dtype=tf.dtypes.int32)\n", " for _ in range(y):\n", " result = tf.matmul(x, result)\n", " return result" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:39.742490Z", "iopub.status.busy": "2024-01-11T19:23:39.742259Z", "iopub.status.idle": "2024-01-11T19:23:43.748576Z", "shell.execute_reply": "2024-01-11T19:23:43.747795Z" }, "id": "ms2yJyAnUYxK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Eager execution: 4.002248740000141 seconds\n" ] } ], "source": [ "print(\"Eager execution:\", timeit.timeit(lambda: power(x, 100), number=1000), \"seconds\")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:43.751809Z", "iopub.status.busy": "2024-01-11T19:23:43.751566Z", "iopub.status.idle": "2024-01-11T19:23:44.578173Z", "shell.execute_reply": "2024-01-11T19:23:44.577396Z" }, "id": "gUB2mTyRYRAe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Graph execution: 0.8218019930000082 seconds\n" ] } ], "source": [ "power_as_graph = tf.function(power)\n", "print(\"Graph execution:\", timeit.timeit(lambda: power_as_graph(x, 100), number=1000), \"seconds\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Q1Pfo5YwwILi" }, "source": [ "`tf.function` は一般的にトレーニングループを高速化するために使用されます。詳細については、Keras ガイドの[トレーニングループを新規作成する](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)の{nbsp}`tf.function` を使用してトレーニングステップを高速化するセクションで詳細を確認してください。\n", "\n", "注意: パフォーマンスをさらに大きく改善させるには、`tf.function(jit_compile=True)` を使用することもできます。特に、コードで大量の TensorFlow 制御フローが使用されており、小さなテンソルが多数使用されている場合に最適です。詳細については、[XLA の概要](https://www.tensorflow.org/xla)の `tf.function(jit_compile=True)` を使用した明示的なコンパイルセクションをご覧ください。" ] }, { "cell_type": "markdown", "metadata": { "id": "sm0bNFp8PX53" }, "source": [ "### パフォーマンスとトレードオフ\n", "\n", "グラフを使ってコードを高速化することは可能ですが、グラフを作成するプロセスにはオーバーヘッドが伴います。一部の関数の場合、グラフの作成にはグラフを実行するよりも長い時間が掛かることがあります。**このオーバーヘッドは、以降の実行においてパフォーマンスが向上するのであれば挽回することができますが、大規模なモデルトレーニングの最初の数ステップではトレーシングにより速度が減少する可能性があることに注意してください。**\n", "\n", "モデルの規模に関係なく、頻繁にトレースするのは避けたほうがよいでしょう。[`tf.function` ガイド](./function.ipynb)では、トレーシングを回避できるよう、*リトレーシングの制御*セクションで入力仕様を設定してテンソル引数を使用する方法を説明しています。パフォーマンスが異常に低下している場合は、リトレーシングをうっかり行っていないかどうかを確認することをお勧めします。" ] }, { "cell_type": "markdown", "metadata": { "id": "F4InDaTjwmBA" }, "source": [ "## `Function` がトレーシングしているタイミングを確認するには\n", "\n", "`Function` がトレーシングしているタイミングを確認するには、コードに `print` ステートメントを追加すれば、`Function` がトレーシングを行うたびに `print` ステートメントが実行されるようになります。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:44.582286Z", "iopub.status.busy": "2024-01-11T19:23:44.581685Z", "iopub.status.idle": "2024-01-11T19:23:44.627283Z", "shell.execute_reply": "2024-01-11T19:23:44.626642Z" }, "id": "hXtwlbpofLgW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "tf.Tensor(6, shape=(), dtype=int32)\n", "tf.Tensor(11, shape=(), dtype=int32)\n" ] } ], "source": [ "@tf.function\n", "def a_function_with_python_side_effect(x):\n", " print(\"Tracing!\") # An eager-only side effect.\n", " return x * x + tf.constant(2)\n", "\n", "# This is traced the first time.\n", "print(a_function_with_python_side_effect(tf.constant(2)))\n", "# The second time through, you won't see the side effect.\n", "print(a_function_with_python_side_effect(tf.constant(3)))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:23:44.630244Z", "iopub.status.busy": "2024-01-11T19:23:44.629964Z", "iopub.status.idle": "2024-01-11T19:23:44.662690Z", "shell.execute_reply": "2024-01-11T19:23:44.662080Z" }, "id": "inzSg8yzfNjl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tracing!\n", "tf.Tensor(6, shape=(), dtype=int32)\n", "Tracing!\n", "tf.Tensor(11, shape=(), dtype=int32)\n" ] } ], "source": [ "# This retraces each time the Python argument changes,\n", "# as a Python argument could be an epoch count or other\n", "# hyperparameter.\n", "print(a_function_with_python_side_effect(2))\n", "print(a_function_with_python_side_effect(3))" ] }, { "cell_type": "markdown", "metadata": { "id": "rtN8NW6AfKye" }, "source": [ "新しい Python 引数は、必ず新しいグラフの作成をトリガーするため、追加のトレーシングが行われます。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "D1kbr5ocpS6R" }, "source": [ "## 次のステップ\n", "\n", "`tf.function` についてさらに詳しくは、API リファレンスページをご覧ください。また、`tf.function` によるパフォーマンスの改善ガイドもお試しください。" ] } ], "metadata": { "colab": { "name": "intro_to_graphs.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }