{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-08-15T02:32:10.500640Z", "iopub.status.busy": "2024-08-15T02:32:10.500391Z", "iopub.status.idle": "2024-08-15T02:32:10.504295Z", "shell.execute_reply": "2024-08-15T02:32:10.503718Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Advanced automatic differentiation" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "8a859404ce7e" }, "source": [ "The [Introduction to gradients and automatic differentiation](autodiff.ipynb) guide includes everything required to calculate gradients in TensorFlow. This guide focuses on deeper, less common features of the `tf.GradientTape` API." ] }, { "cell_type": "markdown", "metadata": { "id": "MUXex9ctTuDB" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:10.507922Z", "iopub.status.busy": "2024-08-15T02:32:10.507705Z", "iopub.status.idle": "2024-08-15T02:32:13.150344Z", "shell.execute_reply": "2024-08-15T02:32:13.149628Z" }, "id": "IqR2PQG4ZaZ0" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-08-15 02:32:10.761137: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-08-15 02:32:10.782161: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-08-15 02:32:10.788607: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "\n", "mpl.rcParams['figure.figsize'] = (8, 6)" ] }, { "cell_type": "markdown", "metadata": { "id": "uGRJJRi8TCkJ" }, "source": [ "## Controlling gradient recording\n", "\n", "In the [automatic differentiation guide](autodiff.ipynb) you saw how to control which variables and tensors are watched by the tape while building the gradient calculation.\n", "\n", "The tape also has methods to manipulate the recording." 
] }, { "cell_type": "markdown", "metadata": { "id": "gB_i0VnhQKt2" }, "source": [ "### Stop recording\n", "\n", "If you wish to stop recording gradients, you can use `tf.GradientTape.stop_recording` to temporarily suspend recording.\n", "\n", "This may be useful to reduce overhead if you do not wish to differentiate a complicated operation in the middle of your model. This could include calculating a metric or an intermediate result:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:13.154876Z", "iopub.status.busy": "2024-08-15T02:32:13.154063Z", "iopub.status.idle": "2024-08-15T02:32:15.357653Z", "shell.execute_reply": "2024-08-15T02:32:15.356975Z" }, "id": "mhFSYf7uQWxR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dz/dx: tf.Tensor(4.0, shape=(), dtype=float32)\n", "dz/dy: None\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1723689133.642575 116670 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723689133.646496 116670 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723689133.650243 116670 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723689135.003378 116670 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723689135.005704 116670 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723689135.008045 116670 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "x = tf.Variable(2.0)\n", "y = tf.Variable(3.0)\n", "\n", "with tf.GradientTape() as t:\n", "  x_sq = x * x\n", "  with t.stop_recording():\n", "    y_sq = y * y\n", "  z = x_sq + y_sq\n", "\n", "grad = t.gradient(z, {'x': x, 'y': y})\n", "\n", "print('dz/dx:', grad['x']) # 2*x => 4\n", "print('dz/dy:', grad['y'])" ] }, { "cell_type": "markdown", "metadata": { "id": "DEHbEZ1h4p8A" }, "source": [ "### Reset/start recording from scratch\n", "\n", "If you wish to start over entirely, use `tf.GradientTape.reset`. Simply exiting the gradient tape block and restarting is usually easier to read, but you can use the `reset` method when exiting the tape block is difficult or impossible." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:15.361000Z", "iopub.status.busy": "2024-08-15T02:32:15.360750Z", "iopub.status.idle": "2024-08-15T02:32:15.368010Z", "shell.execute_reply": "2024-08-15T02:32:15.367418Z" }, "id": "lsMHsmrh4pqM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dz/dx: tf.Tensor(4.0, shape=(), dtype=float32)\n", "dz/dy: None\n" ] } ], "source": [ "x = tf.Variable(2.0)\n", "y = tf.Variable(3.0)\n", "reset = True\n", "\n", "with tf.GradientTape() as t:\n", "  y_sq = y * y\n", "  if reset:\n", "    # Throw out all the tape recorded so far.\n", "    t.reset()\n", "  z = x * x + y_sq\n", "\n", "grad = t.gradient(z, {'x': x, 'y': y})\n", "\n", "print('dz/dx:', grad['x']) # 2*x => 4\n", "print('dz/dy:', grad['y'])" ] }, { "cell_type": "markdown", "metadata": { "id": "6zS7cLmS6zMf" }, "source": [ "## Stop gradient flow with precision\n", "\n", "In contrast to the global tape controls above, the `tf.stop_gradient` function is much more precise. 
It can be used to stop gradients from flowing along a particular path, without needing access to the tape itself:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:15.371465Z", "iopub.status.busy": "2024-08-15T02:32:15.370862Z", "iopub.status.idle": "2024-08-15T02:32:15.772901Z", "shell.execute_reply": "2024-08-15T02:32:15.772211Z" }, "id": "30qnZMe48BkB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dz/dx: tf.Tensor(4.0, shape=(), dtype=float32)\n", "dz/dy: None\n" ] } ], "source": [ "x = tf.Variable(2.0)\n", "y = tf.Variable(3.0)\n", "\n", "with tf.GradientTape() as t:\n", " y_sq = y**2\n", " z = x**2 + tf.stop_gradient(y_sq)\n", "\n", "grad = t.gradient(z, {'x': x, 'y': y})\n", "\n", "print('dz/dx:', grad['x']) # 2*x => 4\n", "print('dz/dy:', grad['y'])" ] }, { "cell_type": "markdown", "metadata": { "id": "mbb-9lnGVngH" }, "source": [ "## Custom gradients\n", "\n", "In some cases, you may want to control exactly how gradients are calculated rather than using the default. These situations include:\n", "\n", "1. There is no defined gradient for a new op you are writing.\n", "2. The default calculations are numerically unstable.\n", "3. You wish to cache an expensive computation from the forward pass.\n", "4. You want to modify a value (for example, using `tf.clip_by_value` or `tf.math.round`) without modifying the gradient.\n", "\n", "For the first case, to write a new op you can use `tf.RegisterGradient` to set up your own (refer to the API docs for details). (Note that the gradient registry is global, so change it with caution.)\n", "\n", "For the latter three cases, you can use `tf.custom_gradient`." 
] }, { "cell_type": "markdown", "metadata": { "id": "oHr31kc_irF_" }, "source": [ "Here is an example that applies `tf.clip_by_norm` to the intermediate gradient:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:15.776671Z", "iopub.status.busy": "2024-08-15T02:32:15.776398Z", "iopub.status.idle": "2024-08-15T02:32:16.184748Z", "shell.execute_reply": "2024-08-15T02:32:16.184058Z" }, "id": "Mjj01w4NYtwd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(2.0, shape=(), dtype=float32)\n" ] } ], "source": [ "# Establish an identity operation, but clip during the gradient pass.\n", "@tf.custom_gradient\n", "def clip_gradients(y):\n", "  def backward(dy):\n", "    return tf.clip_by_norm(dy, 0.5)\n", "  return y, backward\n", "\n", "v = tf.Variable(2.0)\n", "with tf.GradientTape() as t:\n", "  output = clip_gradients(v * v)\n", "# Calls \"backward\": the upstream gradient 1.0 is clipped to 0.5,\n", "# so the result is 0.5 * 2*v = 2.0 instead of 4.0.\n", "print(t.gradient(output, v))" ] }, { "cell_type": "markdown", "metadata": { "id": "n4t7S0scYrD3" }, "source": [ "Refer to the `tf.custom_gradient` decorator API docs for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "v0ODp4Oi--I0" }, "source": [ "### Custom gradients in SavedModel\n", "\n", "Note: This feature is available from TensorFlow 2.6.\n", "\n", "Custom gradients can be saved to SavedModel by using the option `tf.saved_model.SaveOptions(experimental_custom_gradients=True)`.\n", "\n", "To be saved into the SavedModel, the gradient function must be traceable (to learn more, check out the [Better performance with tf.function](function.ipynb) guide)." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.188528Z", "iopub.status.busy": "2024-08-15T02:32:16.188273Z", "iopub.status.idle": "2024-08-15T02:32:16.192935Z", "shell.execute_reply": "2024-08-15T02:32:16.192301Z" }, "id": "Q5JBgIBYjN1I" }, "outputs": [], "source": [ "class MyModule(tf.Module):\n", "\n", "  @tf.function(input_signature=[tf.TensorSpec(None)])\n", "  def call_custom_grad(self, x):\n", "    return clip_gradients(x)\n", "\n", "model = MyModule()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.196064Z", "iopub.status.busy": "2024-08-15T02:32:16.195809Z", "iopub.status.idle": "2024-08-15T02:32:16.474435Z", "shell.execute_reply": "2024-08-15T02:32:16.473627Z" }, "id": "xZTrgy2q-9pq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: saved_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: saved_model/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(2.0, shape=(), dtype=float32)\n" ] } ], "source": [ "tf.saved_model.save(\n", "    model,\n", "    'saved_model',\n", "    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))\n", "\n", "# The loaded gradients will be the same as the above example.\n", "v = tf.Variable(2.0)\n", "loaded = tf.saved_model.load('saved_model')\n", "with tf.GradientTape() as t:\n", "  output = loaded.call_custom_grad(v * v)\n", "print(t.gradient(output, v))" ] }, { "cell_type": "markdown", "metadata": { "id": "d-LfRs5FbJCk" }, "source": [ "A note about the above example: If you try replacing the above code with `tf.saved_model.SaveOptions(experimental_custom_gradients=False)`, the gradient will still produce the same result on loading. 
The reason is that the gradient registry still contains the custom gradient used in the function `call_custom_grad`. However, if you restart the runtime after saving without custom gradients, running the loaded model under the `tf.GradientTape` will throw the error: `LookupError: No gradient defined for operation 'IdentityN' (op type: IdentityN)`." ] }, { "cell_type": "markdown", "metadata": { "id": "8aENEt6Veryb" }, "source": [ "## Multiple tapes\n", "\n", "Multiple tapes interact seamlessly.\n", "\n", "For example, here each tape watches a different set of tensors:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.478620Z", "iopub.status.busy": "2024-08-15T02:32:16.477887Z", "iopub.status.idle": "2024-08-15T02:32:16.546935Z", "shell.execute_reply": "2024-08-15T02:32:16.546147Z" }, "id": "BJ0HdMvte0VZ" }, "outputs": [], "source": [ "x0 = tf.constant(0.0)\n", "x1 = tf.constant(0.0)\n", "\n", "with tf.GradientTape() as tape0, tf.GradientTape() as tape1:\n", "  tape0.watch(x0)\n", "  tape1.watch(x1)\n", "\n", "  y0 = tf.math.sin(x0)\n", "  y1 = tf.nn.sigmoid(x1)\n", "\n", "  y = y0 + y1\n", "\n", "  ys = tf.reduce_sum(y)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.550694Z", "iopub.status.busy": "2024-08-15T02:32:16.550376Z", "iopub.status.idle": "2024-08-15T02:32:16.561049Z", "shell.execute_reply": "2024-08-15T02:32:16.560472Z" }, "id": "6ApAoMNFfNz6" }, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tape0.gradient(ys, x0).numpy() # cos(x) => 1.0" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.564124Z", "iopub.status.busy": "2024-08-15T02:32:16.563895Z", "iopub.status.idle": "2024-08-15T02:32:16.570416Z", "shell.execute_reply": 
"2024-08-15T02:32:16.569822Z" }, "id": "rF1jrAJsfYW_" }, "outputs": [ { "data": { "text/plain": [ "0.25" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tape1.gradient(ys, x1).numpy() # sigmoid(x1)*(1-sigmoid(x1)) => 0.25" ] }, { "cell_type": "markdown", "metadata": { "id": "DK05KXrAAld3" }, "source": [ "### Higher-order gradients\n", "\n", "Operations inside of the `tf.GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients.\n", "\n", "For example:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.573544Z", "iopub.status.busy": "2024-08-15T02:32:16.573317Z", "iopub.status.idle": "2024-08-15T02:32:16.582207Z", "shell.execute_reply": "2024-08-15T02:32:16.581596Z" }, "id": "cPQgthZ7ugRJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dy_dx: 3.0\n", "d2y_dx2: 6.0\n" ] } ], "source": [ "x = tf.Variable(1.0) # Create a TensorFlow variable initialized to 1.0\n", "\n", "with tf.GradientTape() as t2:\n", "  with tf.GradientTape() as t1:\n", "    y = x * x * x\n", "\n", "  # Compute the gradient inside the outer `t2` context manager\n", "  # which means the gradient computation is differentiable as well.\n", "  dy_dx = t1.gradient(y, x)\n", "d2y_dx2 = t2.gradient(dy_dx, x)\n", "\n", "print('dy_dx:', dy_dx.numpy()) # 3 * x**2 => 3.0\n", "print('d2y_dx2:', d2y_dx2.numpy()) # 6 * x => 6.0" ] }, { "cell_type": "markdown", "metadata": { "id": "k0HV-Ah4_76i" }, "source": [ "While that does give you the second derivative of a _scalar_ function, this pattern does not generalize to produce a Hessian matrix, since `tf.GradientTape.gradient` only computes the gradient of a scalar. 
To construct a [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), go to the [Hessian example](#hessian) under the [Jacobian section](#jacobians).\n", "\n", "\"Nested calls to `tf.GradientTape.gradient`\" is a good pattern when you are calculating a scalar from a gradient, and then the resulting scalar acts as a source for a second gradient calculation, as in the following example.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "t7LRlcpVKHv1" }, "source": [ "#### Example: Input gradient regularization\n", "\n", "Many models are susceptible to \"adversarial examples\". This collection of techniques modifies the model's input to confuse the model's output. The simplest implementation—such as the [Adversarial example using the Fast Gradient Signed Method attack](https://www.tensorflow.org/tutorials/generative/adversarial_fgsm)—takes a single step along the gradient of the output with respect to the input; the \"input gradient\".\n", "\n", "One technique to increase robustness to adversarial examples is [input gradient regularization](https://arxiv.org/abs/1905.11468) (Finlay & Oberman, 2019), which attempts to minimize the magnitude of the input gradient. If the input gradient is small, then the change in the output should be small too.\n", "\n", "Below is a naive implementation of input gradient regularization. The implementation is:\n", "\n", "1. Calculate the gradient of the output with respect to the input using an inner tape.\n", "2. Calculate the magnitude of that input gradient.\n", "3. Calculate the gradient of that magnitude with respect to the model." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.585351Z", "iopub.status.busy": "2024-08-15T02:32:16.585116Z", "iopub.status.idle": "2024-08-15T02:32:16.597575Z", "shell.execute_reply": "2024-08-15T02:32:16.596989Z" }, "id": "tH3ZFuUfDLrR" }, "outputs": [], "source": [ "x = tf.random.normal([7, 5])\n", "\n", "layer = tf.keras.layers.Dense(10, activation=tf.nn.relu)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:16.600741Z", "iopub.status.busy": "2024-08-15T02:32:16.600515Z", "iopub.status.idle": "2024-08-15T02:32:17.414971Z", "shell.execute_reply": "2024-08-15T02:32:17.414200Z" }, "id": "E6yOFsjEDR9u" }, "outputs": [], "source": [ "with tf.GradientTape() as t2:\n", "  # The inner tape only takes the gradient with respect to the input,\n", "  # not the variables.\n", "  with tf.GradientTape(watch_accessed_variables=False) as t1:\n", "    t1.watch(x)\n", "    y = layer(x)\n", "    out = tf.reduce_sum(y**2)\n", "  # 1. Calculate the input gradient.\n", "  g1 = t1.gradient(out, x)\n", "  # 2. Calculate the magnitude of the input gradient.\n", "  g1_mag = tf.norm(g1)\n", "\n", "# 3. 
Calculate the gradient of the magnitude with respect to the model.\n", "dg1_mag = t2.gradient(g1_mag, layer.trainable_variables)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:17.419534Z", "iopub.status.busy": "2024-08-15T02:32:17.418960Z", "iopub.status.idle": "2024-08-15T02:32:17.423751Z", "shell.execute_reply": "2024-08-15T02:32:17.423140Z" }, "id": "123QMq6PqK_d" }, "outputs": [ { "data": { "text/plain": [ "[TensorShape([5, 10]), TensorShape([10])]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[var.shape for var in dg1_mag]" ] }, { "cell_type": "markdown", "metadata": { "id": "E4xiYigexMtQ" }, "source": [ "## Jacobians\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4-hVHVIeExkI" }, "source": [ "All the previous examples took the gradients of a scalar target with respect to some source tensor(s).\n", "\n", "The [Jacobian matrix](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) represents the gradients of a vector valued function. Each row contains the gradient of one of the vector's elements.\n", "\n", "The `tf.GradientTape.jacobian` method allows you to efficiently calculate a Jacobian matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "KzNyIM0QBYIH" }, "source": [ "Note that:\n", "\n", "* Like `gradient`: The `sources` argument can be a tensor or a container of tensors.\n", "* Unlike `gradient`: The `target` tensor must be a single tensor." ] }, { "cell_type": "markdown", "metadata": { "id": "O74K3hlxBC8a" }, "source": [ "### Scalar source" ] }, { "cell_type": "markdown", "metadata": { "id": "B08OKn1Orkuc" }, "source": [ "As a first example, here is the Jacobian of a vector-target with respect to a scalar-source." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:17.427622Z", "iopub.status.busy": "2024-08-15T02:32:17.426995Z", "iopub.status.idle": "2024-08-15T02:32:18.258955Z", "shell.execute_reply": "2024-08-15T02:32:18.258207Z" }, "id": "bAFeIE8EuVIq" }, "outputs": [], "source": [ "x = tf.linspace(-10.0, 10.0, 200+1)\n", "delta = tf.Variable(0.0)\n", "\n", "with tf.GradientTape() as tape:\n", "  y = tf.nn.sigmoid(x+delta)\n", "\n", "dy_dx = tape.jacobian(y, delta)" ] }, { "cell_type": "markdown", "metadata": { "id": "BgHbUk3zr-WU" }, "source": [ "When you take the Jacobian with respect to a scalar, the result has the shape of the **target**, and gives the gradient of each element with respect to the source:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.262973Z", "iopub.status.busy": "2024-08-15T02:32:18.262722Z", "iopub.status.idle": "2024-08-15T02:32:18.266543Z", "shell.execute_reply": "2024-08-15T02:32:18.265859Z" }, "id": "iZ6awnDzr_BA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(201,)\n", "(201,)\n" ] } ], "source": [ "print(y.shape)\n", "print(dy_dx.shape)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.269421Z", "iopub.status.busy": "2024-08-15T02:32:18.269170Z", "iopub.status.idle": "2024-08-15T02:32:18.439623Z", "shell.execute_reply": "2024-08-15T02:32:18.439009Z" }, "id": "siNZaklc0_-e" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(x.numpy(), y, label='y')\n", "plt.plot(x.numpy(), dy_dx, label='dy/dx')\n", "plt.legend()\n", "_ = plt.xlabel('x')" ] }, { "cell_type": "markdown", "metadata": { "id": "DsOMSD_1BGkD" }, "source": [ "### Tensor source" ] }, { "cell_type": "markdown", "metadata": { "id": "g3iXKN7KF-st" }, "source": [ "Whether the input is scalar or tensor, `tf.GradientTape.jacobian` efficiently calculates the gradient of each element of the target with respect to each element of the source(s).\n", "\n", "For example, the output of this layer has a shape of `(7, 10)`:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.443323Z", "iopub.status.busy": "2024-08-15T02:32:18.442721Z", "iopub.status.idle": "2024-08-15T02:32:18.455734Z", "shell.execute_reply": "2024-08-15T02:32:18.455165Z" }, "id": "39YXItgLxMBk" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([7, 10])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.random.normal([7, 5])\n", "layer = tf.keras.layers.Dense(10, activation=tf.nn.relu)\n", "\n", "with tf.GradientTape(persistent=True) as tape:\n", "  y = layer(x)\n", "\n", "y.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "tshNRtfKuVP_" }, "source": [ "And the layer's kernel's shape is `(5, 10)`:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.458758Z", "iopub.status.busy": "2024-08-15T02:32:18.458521Z", "iopub.status.idle": "2024-08-15T02:32:18.462766Z", "shell.execute_reply": "2024-08-15T02:32:18.462200Z" }, "id": "CigTWyfPvPuv" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([5, 10])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "layer.kernel.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "mN96JRpnAjpx" }, 
"source": [ "The shape of the Jacobian of the output with respect to the kernel is those two shapes concatenated together:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.465678Z", "iopub.status.busy": "2024-08-15T02:32:18.465464Z", "iopub.status.idle": "2024-08-15T02:32:18.585137Z", "shell.execute_reply": "2024-08-15T02:32:18.584542Z" }, "id": "pRLzTTbvEimH" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([7, 10, 5, 10])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "j = tape.jacobian(y, layer.kernel)\n", "j.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "2Lrv7miMvTll" }, "source": [ "If you sum over the target's dimensions, you're left with the gradient of the sum that would have been calculated by `tf.GradientTape.gradient`:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.588391Z", "iopub.status.busy": "2024-08-15T02:32:18.588166Z", "iopub.status.idle": "2024-08-15T02:32:18.596854Z", "shell.execute_reply": "2024-08-15T02:32:18.596171Z" }, "id": "FJjZpYRnDjVa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "g.shape: (5, 10)\n", "delta: 2.3841858e-07\n" ] } ], "source": [ "g = tape.gradient(y, layer.kernel)\n", "print('g.shape:', g.shape)\n", "\n", "j_sum = tf.reduce_sum(j, axis=[0, 1])\n", "delta = tf.reduce_max(abs(g - j_sum)).numpy()\n", "assert delta < 1e-3\n", "print('delta:', delta)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZKajuGlk_krs" }, "source": [ " \n", "\n", "#### Example: Hessian" ] }, { "cell_type": "markdown", "metadata": { "id": "NYcsXeo8TDLi" }, "source": [ "While `tf.GradientTape` doesn't give an explicit method for constructing a [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix) it's possible to build one using the `tf.GradientTape.jacobian` method.\n", "\n", "Note: The Hessian 
matrix contains `N**2` parameters. For this and other reasons it is not practical for most models. This example is included more as a demonstration of how to use the `tf.GradientTape.jacobian` method, and is not an endorsement of direct Hessian-based optimization. A Hessian-vector product can be [calculated efficiently with nested tapes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/benchmarks/resnet50/hvp_test.py), and is a much more efficient approach to second-order optimization." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.600014Z", "iopub.status.busy": "2024-08-15T02:32:18.599794Z", "iopub.status.idle": "2024-08-15T02:32:18.877192Z", "shell.execute_reply": "2024-08-15T02:32:18.876470Z" }, "id": "ELGTaell_j81" }, "outputs": [], "source": [ "x = tf.random.normal([7, 5])\n", "layer1 = tf.keras.layers.Dense(8, activation=tf.nn.relu)\n", "layer2 = tf.keras.layers.Dense(6, activation=tf.nn.relu)\n", "\n", "with tf.GradientTape() as t2:\n", "  with tf.GradientTape() as t1:\n", "    x = layer1(x)\n", "    x = layer2(x)\n", "    loss = tf.reduce_mean(x**2)\n", "\n", "  g = t1.gradient(loss, layer1.kernel)\n", "\n", "h = t2.jacobian(g, layer1.kernel)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.880444Z", "iopub.status.busy": "2024-08-15T02:32:18.880203Z", "iopub.status.idle": "2024-08-15T02:32:18.883872Z", "shell.execute_reply": "2024-08-15T02:32:18.883295Z" }, "id": "FVqQuZj4XGjm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "layer.kernel.shape: (5, 8)\n", "h.shape: (5, 8, 5, 8)\n" ] } ], "source": [ "print(f'layer.kernel.shape: {layer1.kernel.shape}')\n", "print(f'h.shape: {h.shape}')" ] }, { "cell_type": "markdown", "metadata": { "id": "_M7XElgaiMeP" }, "source": [ "To use this Hessian for a [Newton's 
method](https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization) step, you would first flatten out its axes into a matrix, and flatten out the gradient into a vector:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.887299Z", "iopub.status.busy": "2024-08-15T02:32:18.886625Z", "iopub.status.idle": "2024-08-15T02:32:18.899284Z", "shell.execute_reply": "2024-08-15T02:32:18.898703Z" }, "id": "6te7N6wVXwXX" }, "outputs": [], "source": [ "n_params = tf.reduce_prod(layer1.kernel.shape)\n", "\n", "g_vec = tf.reshape(g, [n_params, 1])\n", "h_mat = tf.reshape(h, [n_params, n_params])" ] }, { "cell_type": "markdown", "metadata": { "id": "L9rO8b-0mgOH" }, "source": [ "The Hessian matrix should be symmetric:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.902307Z", "iopub.status.busy": "2024-08-15T02:32:18.902088Z", "iopub.status.idle": "2024-08-15T02:32:18.905558Z", "shell.execute_reply": "2024-08-15T02:32:18.904989Z" }, "id": "8TCHc7Vrf52S" }, "outputs": [], "source": [ "def imshow_zero_center(image, **kwargs):\n", " lim = tf.reduce_max(abs(image))\n", " plt.imshow(image, vmin=-lim, vmax=lim, cmap='seismic', **kwargs)\n", " plt.colorbar()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:18.908508Z", "iopub.status.busy": "2024-08-15T02:32:18.908283Z", "iopub.status.idle": "2024-08-15T02:32:19.182786Z", "shell.execute_reply": "2024-08-15T02:32:19.182052Z" }, "id": "DExOxd7Ok2H0" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "imshow_zero_center(h_mat)" ] }, { "cell_type": "markdown", "metadata": { "id": "13fBswmtQes4" }, "source": [ "The Newton's method update step is shown below:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.186423Z", "iopub.status.busy": "2024-08-15T02:32:19.186142Z", "iopub.status.idle": "2024-08-15T02:32:19.192987Z", "shell.execute_reply": "2024-08-15T02:32:19.192387Z" }, "id": "3DdnbynBdSor" }, "outputs": [], "source": [ "eps = 1e-3\n", "eye_eps = tf.eye(h_mat.shape[0])*eps" ] }, { "cell_type": "markdown", "metadata": { "id": "-zPdtyoWeUeV" }, "source": [ "Note: [Don't actually invert the matrix](https://www.johndcook.com/blog/2010/01/19/dont-invert-that-matrix/)." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.196637Z", "iopub.status.busy": "2024-08-15T02:32:19.195958Z", "iopub.status.idle": "2024-08-15T02:32:19.248580Z", "shell.execute_reply": "2024-08-15T02:32:19.247954Z" }, "id": "k1LYftgmswOO" }, "outputs": [], "source": [ "# X(k+1) = X(k) - (∇²f(X(k)))^-1 @ ∇f(X(k))\n", "# h_mat = ∇²f(X(k))\n", "# g_vec = ∇f(X(k))\n", "update = tf.linalg.solve(h_mat + eye_eps, g_vec)\n", "\n", "# Reshape the update and apply it to the variable.\n", "_ = layer1.kernel.assign_sub(tf.reshape(update, layer1.kernel.shape))" ] }, { "cell_type": "markdown", "metadata": { "id": "pF6qjlHKWxF4" }, "source": [ "While this is relatively simple for a single `tf.Variable`, applying this to a non-trivial model would require careful concatenation and slicing to produce a full Hessian across multiple variables." 
] }, { "cell_type": "markdown", "metadata": { "id": "PQWM0uN-GO5t" }, "source": [ "### Batch Jacobian" ] }, { "cell_type": "markdown", "metadata": { "id": "hKtB3rY6EySJ" }, "source": [ "In some cases, you want to take the Jacobian of each of a stack of targets with respect to a stack of sources, where the Jacobians for each target-source pair are independent.\n", "\n", "For example, here the input `x` is shaped `(batch, ins)` and the output `y` is shaped `(batch, outs)`:\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.252063Z", "iopub.status.busy": "2024-08-15T02:32:19.251837Z", "iopub.status.idle": "2024-08-15T02:32:19.333771Z", "shell.execute_reply": "2024-08-15T02:32:19.333103Z" }, "id": "tQMndhIUHMes" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([7, 6])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.random.normal([7, 5])\n", "\n", "layer1 = tf.keras.layers.Dense(8, activation=tf.nn.elu)\n", "layer2 = tf.keras.layers.Dense(6, activation=tf.nn.elu)\n", "\n", "with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:\n", " tape.watch(x)\n", " y = layer1(x)\n", " y = layer2(y)\n", "\n", "y.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "Ff2spRHEJXBU" }, "source": [ "The full Jacobian of `y` with respect to `x` has a shape of `(batch, outs, batch, ins)`, even if you only want `(batch, outs, ins)`:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.337404Z", "iopub.status.busy": "2024-08-15T02:32:19.337109Z", "iopub.status.idle": "2024-08-15T02:32:19.453376Z", "shell.execute_reply": "2024-08-15T02:32:19.452768Z" }, "id": "1zSl2A5-HhMH" }, "outputs": [ { "data": { "text/plain": [ "TensorShape([7, 6, 7, 5])" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "j = tape.jacobian(y, 
x)\n", "j.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "UibJijPLJrpQ" }, "source": [ "If the gradients of each item in the stack are independent, then every `(batch, batch)` slice of this tensor is a diagonal matrix:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.456497Z", "iopub.status.busy": "2024-08-15T02:32:19.456243Z", "iopub.status.idle": "2024-08-15T02:32:19.678736Z", "shell.execute_reply": "2024-08-15T02:32:19.678068Z" }, "id": "ZFl9uj3ueVSH" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "imshow_zero_center(j[:, 0, :, 0])\n", "_ = plt.title('A (batch, batch) slice')" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.681818Z", "iopub.status.busy": "2024-08-15T02:32:19.681575Z", "iopub.status.idle": "2024-08-15T02:32:19.898017Z", "shell.execute_reply": "2024-08-15T02:32:19.897397Z" }, "id": "g4ZoRJcJNmy5" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def plot_as_patches(j):\n", " # Reorder axes so the diagonals will each form a contiguous patch.\n", " j = tf.transpose(j, [1, 0, 3, 2])\n", " # Pad in between each patch.\n", " lim = tf.reduce_max(abs(j))\n", " j = tf.pad(j, [[0, 0], [1, 1], [0, 0], [1, 1]],\n", " constant_values=-lim)\n", " # Reshape to form a single image.\n", " s = j.shape\n", " j = tf.reshape(j, [s[0]*s[1], s[2]*s[3]])\n", " imshow_zero_center(j, extent=[-0.5, s[2]-0.5, s[0]-0.5, -0.5])\n", "\n", "plot_as_patches(j)\n", "_ = plt.title('All (batch, batch) slices are diagonal')" ] }, { "cell_type": "markdown", "metadata": { "id": "OXpTBKyeK84z" }, "source": [ "To get the desired result, you can sum over the duplicate `batch` dimension, or else select the diagonals using `tf.einsum`:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.901354Z", "iopub.status.busy": "2024-08-15T02:32:19.901103Z", "iopub.status.idle": "2024-08-15T02:32:19.907632Z", "shell.execute_reply": "2024-08-15T02:32:19.906962Z" }, "id": "v65OAjEgLQwl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(7, 6, 5)\n", "(7, 6, 5)\n" ] } ], "source": [ "j_sum = tf.reduce_sum(j, axis=2)\n", "print(j_sum.shape)\n", "j_select = tf.einsum('bxby->bxy', j)\n", "print(j_select.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "zT_VfR6lcwxD" }, "source": [ "It would be much more efficient to do the calculation without the extra dimension in the first place. 
The `tf.GradientTape.batch_jacobian` method does exactly that:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:19.910527Z", "iopub.status.busy": "2024-08-15T02:32:19.910292Z", "iopub.status.idle": "2024-08-15T02:32:20.066210Z", "shell.execute_reply": "2024-08-15T02:32:20.065609Z" }, "id": "YJLIl9WpHqYq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f968c10d700> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "data": { "text/plain": [ "TensorShape([7, 6, 5])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jb = tape.batch_jacobian(y, x)\n", "jb.shape" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:20.069156Z", "iopub.status.busy": "2024-08-15T02:32:20.068902Z", "iopub.status.idle": "2024-08-15T02:32:20.074494Z", "shell.execute_reply": "2024-08-15T02:32:20.073908Z" }, "id": "-5t_q5SfHw7T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0\n" ] } ], "source": [ "error = tf.reduce_max(abs(jb - j_sum))\n", "assert error < 1e-3\n", "print(error.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "IUeY2ZCiL31I" }, "source": [ "Caution: `tf.GradientTape.batch_jacobian` only verifies that the first dimension of the source and target match. It doesn't check that the gradients are actually independent. It's up to you to make sure you only use `batch_jacobian` where it makes sense. For example, adding a `tf.keras.layers.BatchNormalization` destroys the independence, since it normalizes across the `batch` dimension:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:20.077545Z", "iopub.status.busy": "2024-08-15T02:32:20.077282Z", "iopub.status.idle": "2024-08-15T02:32:20.705360Z", "shell.execute_reply": "2024-08-15T02:32:20.704672Z" }, "id": "tnDugVc-L4fj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7f967c72d430> triggered tf.function retracing. 
Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "j.shape: (7, 6, 7, 5)\n" ] } ], "source": [ "x = tf.random.normal([7, 5])\n", "\n", "layer1 = tf.keras.layers.Dense(8, activation=tf.nn.elu)\n", "bn = tf.keras.layers.BatchNormalization()\n", "layer2 = tf.keras.layers.Dense(6, activation=tf.nn.elu)\n", "\n", "with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:\n", " tape.watch(x)\n", " y = layer1(x)\n", " y = bn(y, training=True)\n", " y = layer2(y)\n", "\n", "j = tape.jacobian(y, x)\n", "print(f'j.shape: {j.shape}')" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:20.708790Z", "iopub.status.busy": "2024-08-15T02:32:20.708535Z", "iopub.status.idle": "2024-08-15T02:32:20.930861Z", "shell.execute_reply": "2024-08-15T02:32:20.930156Z" }, "id": "SNyZ1WhJMVLm" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_as_patches(j)\n", "\n", "_ = plt.title('These slices are not diagonal')\n", "_ = plt.xlabel(\"Don't use `batch_jacobian`\")" ] }, { "cell_type": "markdown", "metadata": { "id": "M_x7ih5sarvG" }, "source": [ "In this case, `batch_jacobian` still runs and returns _something_ with the expected shape, but its contents have an unclear meaning:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T02:32:20.934532Z", "iopub.status.busy": "2024-08-15T02:32:20.934278Z", "iopub.status.idle": "2024-08-15T02:32:21.328241Z", "shell.execute_reply": "2024-08-15T02:32:21.327515Z" }, "id": "k8_mICHoasCi" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "jb.shape: (7, 6, 5)\n" ] } ], "source": [ "jb = tape.batch_jacobian(y, x)\n", "print(f'jb.shape: {jb.shape}')" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "advanced_autodiff.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }