{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-08-15T01:30:06.291838Z", "iopub.status.busy": "2024-08-15T01:30:06.291620Z", "iopub.status.idle": "2024-08-15T01:30:06.295536Z", "shell.execute_reply": "2024-08-15T01:30:06.294994Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Introduction to gradients and automatic differentiation" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "r6P32iYYV27b" }, "source": [ "## Automatic Differentiation and Gradients\n", "\n", "[Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)\n", "is useful for implementing machine learning algorithms such as\n", "[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training\n", "neural networks.\n", "\n", "In this guide, you will explore ways to compute gradients with TensorFlow, especially in eager execution." ] }, { "cell_type": "markdown", "metadata": { "id": "MUXex9ctTuDB" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:06.298872Z", "iopub.status.busy": "2024-08-15T01:30:06.298622Z", "iopub.status.idle": "2024-08-15T01:30:08.870485Z", "shell.execute_reply": "2024-08-15T01:30:08.869676Z" }, "id": "IqR2PQG4ZaZ0" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-08-15 01:30:07.003169: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-08-15 01:30:07.023862: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-08-15 01:30:07.029954: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "xHxb-dlhMIzW" }, "source": [ "## Computing gradients\n", "\n", "To differentiate automatically, TensorFlow needs to remember what operations happen in what order during the *forward* pass. 
Then, during the *backward pass*, TensorFlow traverses this list of operations in reverse order to compute gradients." ] }, { "cell_type": "markdown", "metadata": { "id": "1CLWJl0QliB0" }, "source": [ "## Gradient tapes\n", "\n", "TensorFlow provides the `tf.GradientTape` API for automatic differentiation; that is, computing the gradient of a computation with respect to some inputs, usually `tf.Variable`s.\n", "TensorFlow \"records\" relevant operations executed inside the context of a `tf.GradientTape` onto a \"tape\". TensorFlow then uses that tape to compute the gradients of a \"recorded\" computation using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", "\n", "Here is a simple example:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:08.874835Z", "iopub.status.busy": "2024-08-15T01:30:08.874386Z", "iopub.status.idle": "2024-08-15T01:30:11.490841Z", "shell.execute_reply": "2024-08-15T01:30:11.490090Z" }, "id": "Xq9GgTCP7a4A" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1723685409.408818 20970 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723685409.412555 20970 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "x = tf.Variable(3.0)\n", "\n", "with tf.GradientTape() as tape:\n", " y = x**2" ] }, { "cell_type": "markdown", "metadata": { "id": "CR9tFAP_7cra" }, "source": [ "Once you've recorded some operations, use `GradientTape.gradient(target, sources)` to calculate the gradient of some target (often a loss) relative to some source (often the model's variables):" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:11.494831Z", "iopub.status.busy": "2024-08-15T01:30:11.494515Z", "iopub.status.idle": "2024-08-15T01:30:11.507497Z", "shell.execute_reply": "2024-08-15T01:30:11.506851Z" }, "id": "LsvrwF6bHroC" }, "outputs": [ { "data": { "text/plain": [ "6.0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# dy = 2x * dx\n", "dy_dx = tape.gradient(y, x)\n", "dy_dx.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "Q2_aqsO25Vx1" }, "source": [ "The above example uses scalars, but `tf.GradientTape` works as easily on any tensor:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:11.510543Z", "iopub.status.busy": "2024-08-15T01:30:11.510294Z", "iopub.status.idle": "2024-08-15T01:30:11.575341Z", "shell.execute_reply": "2024-08-15T01:30:11.574704Z" }, "id": "vacZ3-Ws5VdV" }, "outputs": [], "source": [ "w = tf.Variable(tf.random.normal((3, 2)), name='w')\n", "b = tf.Variable(tf.zeros(2, dtype=tf.float32), name='b')\n", "x = [[1., 2., 3.]]\n", "\n", "with tf.GradientTape(persistent=True) as tape:\n", " y = x @ w + b\n", " loss = tf.reduce_mean(y**2)" ] }, { "cell_type": "markdown", "metadata": { "id": "i4eXOkrQ-9Pb" }, "source": [ "To get the gradient of `loss` with respect to both variables, you can pass both as sources to the `gradient` method. 
The tape is flexible about how sources are passed and will accept any nested combination of lists or dictionaries and return the gradient structured the same way (see `tf.nest`)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:11.579188Z", "iopub.status.busy": "2024-08-15T01:30:11.578956Z", "iopub.status.idle": "2024-08-15T01:30:11.597447Z", "shell.execute_reply": "2024-08-15T01:30:11.596846Z" }, "id": "luOtK1Da_BR0" }, "outputs": [], "source": [ "[dl_dw, dl_db] = tape.gradient(loss, [w, b])" ] }, { "cell_type": "markdown", "metadata": { "id": "Ei4iVXi6qgM7" }, "source": [ "The gradient with respect to each source has the shape of the source:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:11.600712Z", "iopub.status.busy": "2024-08-15T01:30:11.600470Z", "iopub.status.idle": "2024-08-15T01:30:11.604206Z", "shell.execute_reply": "2024-08-15T01:30:11.603480Z" }, "id": "aYbWRFPZqk4U" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3, 2)\n", "(3, 2)\n" ] } ], "source": [ "print(w.shape)\n", "print(dl_dw.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "dI_SzxHsvao1" }, "source": [ "Here is the gradient calculation again, this time passing a dictionary of variables:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:11.607301Z", "iopub.status.busy": "2024-08-15T01:30:11.607081Z", "iopub.status.idle": "2024-08-15T01:30:11.614206Z", "shell.execute_reply": "2024-08-15T01:30:11.613564Z" }, "id": "d73cY6NOuaMd" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_vars = {\n", " 'w': w,\n", " 'b': b\n", "}\n", "\n", "grad = tape.gradient(loss, my_vars)\n", "grad['b']" ] }, { "cell_type": "markdown", "metadata": { "id": "HZ2LvHifEMgO" }, 
"source": [ "## Gradients with respect to a model\n", "\n", "It's common to collect `tf.Variables` into a `tf.Module` or one of its subclasses (`layers.Layer`, `keras.Model`) for [checkpointing](checkpoint.ipynb) and [exporting](saved_model.ipynb).\n", "\n", "In most cases, you will want to calculate gradients with respect to a model's trainable variables. Since all subclasses of `tf.Module` aggregate their variables in the `Module.trainable_variables` property, you can calculate these gradients in a few lines of code: " ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:11.617798Z", "iopub.status.busy": "2024-08-15T01:30:11.617178Z", "iopub.status.idle": "2024-08-15T01:30:12.317074Z", "shell.execute_reply": "2024-08-15T01:30:12.316345Z" }, "id": "JvesHtbQESc-" }, "outputs": [], "source": [ "layer = tf.keras.layers.Dense(2, activation='relu')\n", "x = tf.constant([[1., 2., 3.]])\n", "\n", "with tf.GradientTape() as tape:\n", " # Forward pass\n", " y = layer(x)\n", " loss = tf.reduce_mean(y**2)\n", "\n", "# Calculate gradients with respect to every trainable variable\n", "grad = tape.gradient(loss, layer.trainable_variables)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.321035Z", "iopub.status.busy": "2024-08-15T01:30:12.320769Z", "iopub.status.idle": "2024-08-15T01:30:12.324853Z", "shell.execute_reply": "2024-08-15T01:30:12.324246Z" }, "id": "PR_ezr6UFrpI" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "kernel, shape: (3, 2)\n", "bias, shape: (2,)\n" ] } ], "source": [ "for var, g in zip(layer.trainable_variables, grad):\n", " print(f'{var.name}, shape: {g.shape}')" ] }, { "cell_type": "markdown", "metadata": { "id": "f6Gx6LS714zR" }, "source": [ "\n", "\n", "## Controlling what the tape watches" ] }, { "cell_type": "markdown", "metadata": { "id": "N4VlqKFzzGaC" }, "source": [ "The default 
behavior is to record all operations after accessing a trainable `tf.Variable`. The reasons for this are:\n", "\n", "* The tape needs to know which operations to record in the forward pass to calculate the gradients in the backward pass.\n", "* The tape holds references to intermediate outputs, so you don't want to record unnecessary operations.\n", "* The most common use case involves calculating the gradient of a loss with respect to all of a model's trainable variables.\n", "\n", "For example, the following returns a gradient only for `x0`. The other results are `None` because a `tf.Tensor` is not \"watched\" by default, the `tf.Variable` `x1` is not trainable, and `x3` is never used in the computation:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.328318Z", "iopub.status.busy": "2024-08-15T01:30:12.328063Z", "iopub.status.idle": "2024-08-15T01:30:12.338607Z", "shell.execute_reply": "2024-08-15T01:30:12.338012Z" }, "id": "Kj9gPckdB37a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(6.0, shape=(), dtype=float32)\n", "None\n", "None\n", "None\n" ] } ], "source": [ "# A trainable variable\n", "x0 = tf.Variable(3.0, name='x0')\n", "# Not trainable\n", "x1 = tf.Variable(3.0, name='x1', trainable=False)\n", "# Not a Variable: A variable + tensor returns a tensor.\n", "x2 = tf.Variable(2.0, name='x2') + 1.0\n", "# Not a variable\n", "x3 = tf.constant(3.0, name='x3')\n", "\n", "with tf.GradientTape() as tape:\n", " y = (x0**2) + (x1**2) + (x2**2)\n", "\n", "grad = tape.gradient(y, [x0, x1, x2, x3])\n", "\n", "for g in grad:\n", " print(g)" ] }, { "cell_type": "markdown", "metadata": { "id": "RkcpQnLgNxgi" }, "source": [ "You can list the variables being watched by the tape using the `GradientTape.watched_variables` method:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.342261Z", "iopub.status.busy": "2024-08-15T01:30:12.341657Z", "iopub.status.idle": 
"2024-08-15T01:30:12.346467Z", "shell.execute_reply": "2024-08-15T01:30:12.345654Z" }, "id": "hwNwjW1eAkib" }, "outputs": [ { "data": { "text/plain": [ "['x0:0']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[var.name for var in tape.watched_variables()]" ] }, { "cell_type": "markdown", "metadata": { "id": "NB9I1uFvB4tf" }, "source": [ "`tf.GradientTape` provides hooks that give the user control over what is or is not watched.\n", "\n", "To record gradients with respect to a `tf.Tensor`, you need to call `GradientTape.watch(x)`:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.349661Z", "iopub.status.busy": "2024-08-15T01:30:12.349387Z", "iopub.status.idle": "2024-08-15T01:30:12.354618Z", "shell.execute_reply": "2024-08-15T01:30:12.353987Z" }, "id": "tVN1QqFRDHBK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.0\n" ] } ], "source": [ "x = tf.constant(3.0)\n", "with tf.GradientTape() as tape:\n", " tape.watch(x)\n", " y = x**2\n", "\n", "# dy = 2x * dx\n", "dy_dx = tape.gradient(y, x)\n", "print(dy_dx.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "qxsiYnf2DN8K" }, "source": [ "Conversely, to disable the default behavior of watching all `tf.Variables`, set `watch_accessed_variables=False` when creating the gradient tape. 
This calculation uses two variables, but only connects the gradient for one of the variables:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.358163Z", "iopub.status.busy": "2024-08-15T01:30:12.357581Z", "iopub.status.idle": "2024-08-15T01:30:12.451991Z", "shell.execute_reply": "2024-08-15T01:30:12.451324Z" }, "id": "7QPzwWvSEwIp" }, "outputs": [], "source": [ "x0 = tf.Variable(0.0)\n", "x1 = tf.Variable(10.0)\n", "\n", "with tf.GradientTape(watch_accessed_variables=False) as tape:\n", " tape.watch(x1)\n", " y0 = tf.math.sin(x0)\n", " y1 = tf.nn.softplus(x1)\n", " y = y0 + y1\n", " ys = tf.reduce_sum(y)" ] }, { "cell_type": "markdown", "metadata": { "id": "TRduLbE1H2IJ" }, "source": [ "Since `GradientTape.watch` was not called on `x0`, no gradient is computed with respect to it:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.456339Z", "iopub.status.busy": "2024-08-15T01:30:12.455695Z", "iopub.status.idle": "2024-08-15T01:30:12.521151Z", "shell.execute_reply": "2024-08-15T01:30:12.520521Z" }, "id": "e6GM-3evH1Sz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dy/dx0: None\n", "dy/dx1: 0.9999546\n" ] } ], "source": [ "# dys/dx1 = exp(x1) / (1 + exp(x1)) = sigmoid(x1)\n", "grad = tape.gradient(ys, {'x0': x0, 'x1': x1})\n", "\n", "print('dy/dx0:', grad['x0'])\n", "print('dy/dx1:', grad['x1'].numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "2g1nKB6P-OnA" }, "source": [ "## Intermediate results\n", "\n", "You can also request gradients of the output with respect to intermediate values computed inside the `tf.GradientTape` context." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.524608Z", "iopub.status.busy": "2024-08-15T01:30:12.524366Z", "iopub.status.idle": "2024-08-15T01:30:12.530956Z", "shell.execute_reply": "2024-08-15T01:30:12.530308Z" }, "id": "7XaPRAwUyYms" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18.0\n" ] } ], "source": [ "x = tf.constant(3.0)\n", "\n", "with tf.GradientTape() as tape:\n", " tape.watch(x)\n", " y = x * x\n", " z = y * y\n", "\n", "# Use the tape to compute the gradient of z with respect to the\n", "# intermediate value y.\n", "# dz_dy = 2 * y and y = x ** 2 = 9\n", "print(tape.gradient(z, y).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "ISkXuY7YzIcS" }, "source": [ "By default, the resources held by a `GradientTape` are released as soon as the `GradientTape.gradient` method is called. To compute multiple gradients over the same computation, create a gradient tape with `persistent=True`. This allows multiple calls to the `gradient` method as resources are released when the tape object is garbage collected. For example:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.534481Z", "iopub.status.busy": "2024-08-15T01:30:12.533804Z", "iopub.status.idle": "2024-08-15T01:30:12.541218Z", "shell.execute_reply": "2024-08-15T01:30:12.540652Z" }, "id": "zZaCm3-9zVCi" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 4. 108.]\n", "[2. 
6.]\n" ] } ], "source": [ "x = tf.constant([1.0, 3.0])\n", "with tf.GradientTape(persistent=True) as tape:\n", " tape.watch(x)\n", " y = x * x\n", " z = y * y\n", "\n", "print(tape.gradient(z, x).numpy()) # [4.0, 108.0] (4 * x**3 at x = [1.0, 3.0])\n", "print(tape.gradient(y, x).numpy()) # [2.0, 6.0] (2 * x at x = [1.0, 3.0])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.544392Z", "iopub.status.busy": "2024-08-15T01:30:12.543861Z", "iopub.status.idle": "2024-08-15T01:30:12.546772Z", "shell.execute_reply": "2024-08-15T01:30:12.546207Z" }, "id": "j8bv_jQFg6CN" }, "outputs": [], "source": [ "del tape # Drop the reference to the tape" ] }, { "cell_type": "markdown", "metadata": { "id": "O_ZY-9BUB7vX" }, "source": [ "## Notes on performance\n", "\n", "* There is a tiny overhead associated with doing operations inside a gradient tape context. For most eager execution this will not be a noticeable cost, but you should still use a tape context only around the areas where it is required.\n", "\n", "* Gradient tapes use memory to store intermediate results, including inputs and outputs, for use during the backward pass.\n", "\n", " For efficiency, some ops (like `ReLU`) don't need to keep their intermediate results and they are pruned during the forward pass. However, if you use `persistent=True` on your tape, *nothing is discarded* and your peak memory usage will be higher." ] }, { "cell_type": "markdown", "metadata": { "id": "9dLBpZsJebFq" }, "source": [ "## Gradients of non-scalar targets" ] }, { "cell_type": "markdown", "metadata": { "id": "7pldU9F5duP2" }, "source": [ "A gradient is fundamentally an operation on a scalar." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.550634Z", "iopub.status.busy": "2024-08-15T01:30:12.549991Z", "iopub.status.idle": "2024-08-15T01:30:12.619656Z", "shell.execute_reply": "2024-08-15T01:30:12.618932Z" }, "id": "qI0sDV_WeXBb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.0\n", "-0.25\n" ] } ], "source": [ "x = tf.Variable(2.0)\n", "with tf.GradientTape(persistent=True) as tape:\n", " y0 = x**2\n", " y1 = 1 / x\n", "\n", "print(tape.gradient(y0, x).numpy())\n", "print(tape.gradient(y1, x).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "COEyYp34fxj4" }, "source": [ "Thus, if you ask for the gradient of multiple targets, the result for each source is:\n", "\n", "* The gradient of the sum of the targets, or equivalently\n", "* The sum of the gradients of each target." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.623092Z", "iopub.status.busy": "2024-08-15T01:30:12.622831Z", "iopub.status.idle": "2024-08-15T01:30:12.630160Z", "shell.execute_reply": "2024-08-15T01:30:12.629521Z" }, "id": "o4a6_YOcfWKS" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.75\n" ] } ], "source": [ "x = tf.Variable(2.0)\n", "with tf.GradientTape() as tape:\n", " y0 = x**2\n", " y1 = 1 / x\n", "\n", "print(tape.gradient({'y0': y0, 'y1': y1}, x).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "uvP-mkBMgbym" }, "source": [ "Similarly, if the target(s) are not scalar the gradient of the sum is calculated:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.633668Z", "iopub.status.busy": "2024-08-15T01:30:12.633412Z", "iopub.status.idle": "2024-08-15T01:30:12.639753Z", "shell.execute_reply": "2024-08-15T01:30:12.639157Z" }, "id": "DArPWqsSh5un" }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "7.0\n" ] } ], "source": [ "x = tf.Variable(2.)\n", "\n", "with tf.GradientTape() as tape:\n", " y = x * [3., 4.]\n", "\n", "print(tape.gradient(y, x).numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "flDbx68Zh5Lb" }, "source": [ "This makes it simple to take the gradient of the sum of a collection of losses, or the gradient of the sum of an element-wise loss calculation.\n", "\n", "If you need a separate gradient for each item, refer to [Jacobians](advanced_autodiff.ipynb#jacobians)." ] }, { "cell_type": "markdown", "metadata": { "id": "iwFswok8RAly" }, "source": [ "In some cases you can skip the Jacobian. For an element-wise calculation, the gradient of the sum gives the derivative of each element with respect to its input-element, since each element is independent:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:12.643198Z", "iopub.status.busy": "2024-08-15T01:30:12.642763Z", "iopub.status.idle": "2024-08-15T01:30:13.331633Z", "shell.execute_reply": "2024-08-15T01:30:13.330889Z" }, "id": "JQvk_jnMmTDS" }, "outputs": [], "source": [ "x = tf.linspace(-10.0, 10.0, 200+1)\n", "\n", "with tf.GradientTape() as tape:\n", " tape.watch(x)\n", " y = tf.nn.sigmoid(x)\n", "\n", "dy_dx = tape.gradient(y, x)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.335667Z", "iopub.status.busy": "2024-08-15T01:30:13.335420Z", "iopub.status.idle": "2024-08-15T01:30:13.499401Z", "shell.execute_reply": "2024-08-15T01:30:13.498786Z" }, "id": "e_f2QgDPmcPE" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiMAAAGwCAYAAAB7MGXBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPSUlEQVR4nO3dd3hUZf7+8fdMekgDQgoQepcqSARULFFERbEtlhXEuooV15/irrLoV7AirrKirmJBV9RVUEFcQAGRIlIUkd4JJBBKejLJzPP7Y5KBQAKZkORkJvfruubKzGnzOTlk5uY5z3mOzRhjEBEREbGI3eoCREREpH5TGBERERFLKYyIiIiIpRRGRERExFIKIyIiImIphRERERGxlMKIiIiIWCrQ6gIqw+VysXfvXiIjI7HZbFaXIyIiIpVgjCE7O5umTZtit1fc/uETYWTv3r0kJSVZXYaIiIhUwe7du2nevHmF830ijERGRgLunYmKirK4GhEREamMrKwskpKSPN/jFfGJMFJ6aiYqKkphRERExMecqouFOrCKiIiIpRRGRERExFIKIyIiImIpn+gzUhkulwuHw2F1GX4nKCiIgIAAq8sQERE/5hdhxOFwsH37dlwul9Wl+KWYmBgSEhI0xouIiNQInw8jxhj27dtHQEAASUlJJx1URbxjjCEvL4/9+/cDkJiYaHFFIiLij3w+jBQXF5OXl0fTpk0JDw+3uhy/ExYWBsD+/fuJi4vTKRsREal2Pt+M4HQ6AQgODra4Ev9VGvKKioosrkRERPyRz4eRUurPUHP0uxURkZrkN2FEREREfJPXYWTRokUMGTKEpk2bYrPZmDFjxinXWbBgAWeeeSYhISG0a9eO9957rwqlioiIiD/yOozk5ubSo0cPJk+eXKnlt2/fzuWXX84FF1zAmjVreOihh7jjjjv47rvvvC5WRERE/I/XV9MMHjyYwYMHV3r5KVOm0Lp1a15++WUAOnfuzOLFi3nllVcYNGiQt28vIiJSpxhjcBn3TwMYUzIdUzL/2GXLTjfHbMMcs0zJBiq9rOHoAsfXYI6Zfnzdx2oSGUJIoDVXTNb4pb1Lly4lJSWlzLRBgwbx0EMPVbhOYWEhhYWFntdZWVk1VZ6IiFjIGIPD6aLA4SK/yOl+ONw/C455Xvq6oMhJkdPgKHZR5HRR7HI/L3a5KCo2FLlcFDkNRSXTHE5DsdO9bJHTUOxy4XS539fpMriM+8vaadzPXS5weea5n7unH33tdJVdxxz/Le+jvri3P2e2aGjJe9d4GElLSyM+Pr7MtPj4eLKyssjPz/eMY3GsCRMmMG7cuCq9nzGG/CJnldY9XWFBAZW+8uSDDz7g4YcfZu/evYSEhHimDx06lMjISD788MOaKlNEpNoZY8jKLyYjt5CDOQ4O5hSSkevgcK6DrPwisguKySo47mfJdIdTo2efjtKvHRtHr360lZluc0/A86P87dRUgZVQJwc9GzNmDKNHj/a8zsrKIikpqVLr5hc56fKUNf1R/nh6EOHBlfuVXn/99TzwwAN89dVXXH/99YB7YLFZs2bxv//9rybLFBHxijGGjBwHqUfy2Xck3/0zs4C9R/LZm1lAWmY+B3McFLtOr4kgKMBGaGAAocEBhAW5H+7ndvfr4ABCAwMIDrQTGGAjKMBe8ij7PNBuJyjQTpC9ZPoxzwMCbATYbATYbdhsYC95bi957nnYy86z2dzrHTvvhOflffkffYrNZjvmeclyHA0N5U2vcF0/G3KhxsNIQkIC6enpZaalp6cTFRVVbqsIQEhISJnWAn8UFhbGTTfdxNSpUz1hZNq0abRo0YLzzz/f2uJEpF5yFLvYeTCXrQdy2Xogp+SRy7YDOWQXFFdqG5GhgcRGhNC4QTCNI4Jp1CCYqNAgosKCiAwNJCq05OcxryNCAwkLCiAoQKNN1Fc1Hkb69evH7Nmzy0ybO3cu/fr1q5H3CwsK4I+nrekYGxbkXcefO++8k7POOov
U1FSaNWvGe++9x6233up3iVdE6h5HsYv1+7JYm5rJur2ZrE3NZGNaNkXO8ls3bDaIjwylaUwoiTFhNIsJo2m0+3lidChNIkNo1CDYsg6Q4tu8DiM5OTls2bLF83r79u2sWbOGRo0a0aJFC8aMGUNqaioffPABAH/5y194/fXX+X//7/9x22238f333/Ppp58ya9as6tuLY9hstkqfKrFar1696NGjBx988AGXXHIJ69atq7Hfi4jUb3mOYlbtPMLPOw6xYvshVu8+TEHRiX01IkICadukAW2aRBzzM4KWjcMJ9fI/XCKV5fW39i+//MIFF1zgeV3at2PEiBG899577Nu3j127dnnmt27dmlmzZvHwww/z6quv0rx5c/7973/rst4Sd9xxB5MmTSI1NZWUlJRK940RETkZYwzbM3L5YeMBFmzcz/Jth07oKNowPIhuzWPo1iyKrk2j6dosmuYNw9Q6K7XOZo6/0LgOysrKIjo6mszMTKKiosrMKygoYPv27bRu3ZrQ0FCLKqy6zMxMmjZtSnFxMR988AHDhg2zuqQT+PrvWKS+MMawfl82M39NZc7vaew8mFdmfrOYMM5q1ZC+rRvTt3VD2jaJUPCQGnWy7+9j+cb5DD8WHR3Ntddey6xZsxg6dKjV5YiID9p9KI+Za1KZuWYvm/fneKYHBdjo27oRF3SM44JOcbSJbaDwIXWSwkgdkJqays033+z3VxCJSPVxuQyLNh/gg6U7+WHjfs/AW8EBdi7o1IQrezRjYMcmRIToY17qPv0rtdDhw4dZsGABCxYs4F//+pfV5YiID8gtLObTX3bz4dKdbMvI9Uwf0K4xV/VsxqAzEogOC7KwQhHvKYxYqFevXhw+fJjnn3+ejh07Wl2OiNRh+Q4nHy7bwZSF2ziU6wAgMiSQ6/o055azW9KmSYTFFYpUncKIhXbs2GF1CSJSxxUUOfnPz7v414KtHMh237OrVeNwbj+nNVef2VynYcQv6F+xiEgdZIzhu3XpPP31OvZmFgCQ1CiMBy5sz9W9mhGo0UrFjyiMiIjUMbsP5TH2q3V8v2E/AInRodx/YXuu692c4ECFEPE/CiMiInWEo9jFvxdv45/zN1NQ5CIowMbd57Vl1AXtCAvW6KfivxRGRETqgM3p2Tz4yRr+2JcFwNltGvF/Q7vSLi7S4spEap7CiIiIhYwxTFu2k/+btZ7CYhcNw4N48oouXN2rmQYok3pDJx/rmPPPP5+HHnqoWrdps9mYMWNGtW5TRE5fbmExD3yyhidnrqOw2MV5HZrw3UPncc2ZzRVEpF5Ry4gfGDduHJs3b2batGlWlyIilbTzYC53vP8Lm/fnEGi38fjgTtw2oDV2u0KI1D8KI35g5syZPP7441aXISKVtHzbQf4ybSWH84qIjwph8k1n0qdVI6vLErGMTtNYKDc3l+HDhxMREUFiYiIvv/yyZ97TTz9N165dT1inZ8+ePPnkk57Xu3fvZt26dVx66aUAbN68mfPOO4/Q0FC6dOnC3Llzy6z/wQcfEBERwebNmz3T7r33Xjp16kReXtk7fIpI9Zu5JpU/v7Ocw3lF9Ggezdf3naMgIvWe/7WMGANFFn2pBoWDF+d5H330URYuXMjMmTOJi4vjiSeeYNWqVfTs2ZPbbruNcePGsWLFCs466ywAVq9ezW+//cYXX3zh2cZXX33F+eefT1RUFC6Xi2uuuYb4+HiWL19OZmbmCf1Phg8fzjfffMPNN9/MkiVL+O677/j3v//N0qVLCQ8Pr5Zfg4iUb+pP2xn39R8AXN4tkZf/1IPQIF2yK+J/YaQoD8Y3tea9n9gLwQ0qtWhOTg7vvPMO06ZN46KLLgLg/fffp3nz5gA0b96cQYMGMXXqVE8YmTp1KgMHDqRNmzae7cycOZOrrroKgHnz5rFhwwa+++47mjZ1/w7Gjx/P4MGDy7z3m2++Sffu3XnggQf44osv+Mc//kHv3r1Pb99F5KRenbeZV+ZtAuD
W/q146oou6h8iUkKnaSyydetWHA4HycnJnmmNGjUqc8O8O++8k//85z8UFBTgcDj4+OOPue222zzzs7KyWLhwIVdeeSUA69evJykpyRNEAPr163fCezds2JB33nmHN954g7Zt26q/iUgNe2XuJk8QGX1xB8YOURAROZb/tYwEhbtbKKx672o0ZMgQQkJC+PLLLwkODqaoqIjrrrvOM//bb7+lS5cuJCUleb3tRYsWERAQwL59+8jNzSUyUgMridSEV+dt5tX57j5aYwZ34u6BbS2uSKTu8b+WEZvNfarEiocX/UXatm1LUFAQy5cv90w7fPgwmzZt8rwODAxkxIgRTJ06lalTp3LDDTcQFhbmmX/sKRqAzp07s3v3bvbt2+eZtmzZshPee8mSJTz//PN8/fXXREREcN9991W6bhGpvKk/bfe0iCiIiFTM/1pGfERERAS33347jz76KI0bNyYuLo6//e1v2O1l8+Edd9xB586dAfjpp58804uLi/n222/561//6pmWkpJChw4dGDFiBC+++CJZWVn87W9/K7O97OxsbrnlFh544AEGDx5M8+bNOeussxgyZEiZVhcROT0z16R6OquOvriDgojISfhfy4gPefHFFzn33HMZMmQIKSkpnHPOOSd0JG3fvj39+/enU6dOZfqXLFy4kIiICM4880zPNLvdzpdffkl+fj59+/bljjvu4Nlnny2zvQcffJAGDRowfvx4ALp168b48eO5++67SU1NrcG9Fak/lmzN4JFPfwXcnVXvv7CdxRWJ1G02Y4yxuohTycrKIjo6mszMTKKiosrMKygoYPv27bRu3ZrQ0FCLKqw5xhjat2/Pvffey+jRoz3TH3jgAYqLi/nXv/5V4zX4++9YpDptO5DD0Mk/kVVQzBXdE/nnDb3UWVXqrZN9fx9Lp2nqsAMHDvDJJ5+QlpbGyJEjy8zr2rVruVfKiIh1juQ5uO29FWQVFHNmixheur6HgohIJSiM1GFxcXHExsby1ltv0bBhwzLz7rrrLouqEpHyOF2GBz5Zw46DeTSLCePNW/poQDORSlIYqcN84AyaiJR4df5mFm06QGiQnbeH96FJZIjVJYn4DHVgFRE5TT9s2M8/S8YSGX91N7o0rfjcuIicyG/CiFoRao5+tyIVS88qYPSnawD489ktuObM5tYWJOKDfD6MBAS4z8k6HA6LK/FfpXfzDQoKsrgSkbrF5TI88umvHM4roktiFE9e0cXqkkR8ks/3GQkMDCQ8PJwDBw4QFBR0wqBhUnXGGPLy8ti/fz8xMTGe4Ccibu8s3s7iLRmEBtn55429CAnU34hIVfh8GLHZbCQmJrJ9+3Z27txpdTl+KSYmhoSEBKvLEKlTNqZl8+J3GwF46oozaBcXYXFFIr7L58MIQHBwMO3bt9epmhoQFBSkFhGR4xQ7XTz6+a84nC4u6hTHjX29v1mliBzlF2EE3EOha3RQEakNb/24jd/2ZBIVGsj4a7ph8+ImmSJyInWwEBHxwpb9OUya676M96khZxAfpf8EiZwuhRERkUoyxvDkjN9xOF2c37EJ157ZzOqSRPyCwoiISCXNWJPK0m0HCQ2y88xVXXV6RqSaKIyIiFRCZl4Rz85aD8D9F7YnqVG4xRWJ+A+FERGRSnhl3iYychy0i4vgznPbWF2OiF9RGBEROYUt+7P5cJl7HKNxV55BcKA+OkWqk/6iRERO4Zlv1uN0GS7uEs+AdrFWlyPidxRGRERO4oeN+1m46QBBATb+dllnq8sR8UsKIyIiFXC6DONLOq2OHNCaVrENLK5IxD8pjIiIVODL1als3p9DdFgQoy5oZ3U5In5LYUREpBwFRU5embsJgHvPb0t0WJDFFYn4L4UREZFyfLR8F6lH8omPCmFE/1ZWlyPi1xRGRESOk1tYzOQftgDwUEoHQoN052qRmqQwIiJynA+X7eRQroOWjcO5vndzq8sR8XsKIyIix8hzFPP2om0A3HdBOwID9DE
pUtP0VyYicoxpy3ZyMNdBi0bhXN1Ld+UVqQ0KIyIiJfIdTt5Sq4hIrdNfmohIiekrdpGR46B5wzCuPlOtIiK1RWFERAQocrp4+8ftANw9sC1BahURqTX6axMRAWb9to/UI/k0bhCsK2hEapnCiIjUe8YYpizcCsCt/VtpXBGRWqYwIiL13sJNB9iQlk14cAC39GtpdTki9Y7CiIjUe+8sdvcVueGsFsSEB1tcjUj9ozAiIvXapvRsftycgd0GIwe0srockXqpSmFk8uTJtGrVitDQUJKTk/n5559PuvykSZPo2LEjYWFhJCUl8fDDD1NQUFClgkVEqtPUn3YAcHGXeJIahVtbjEg95XUYmT59OqNHj2bs2LGsWrWKHj16MGjQIPbv31/u8h9//DGPP/44Y8eOZf369bzzzjtMnz6dJ5544rSLFxE5HYdzHXy5eg8Atw1obXE1IvWX12Fk4sSJ3HnnnYwcOZIuXbowZcoUwsPDeffdd8tdfsmSJQwYMICbbrqJVq1acckll3DjjTeesjVFRKSm/WfFLgqKXJzRNIq+rRtZXY5IveVVGHE4HKxcuZKUlJSjG7DbSUlJYenSpeWu079/f1auXOkJH9u2bWP27NlcdtllFb5PYWEhWVlZZR4iItXJ6TJMW7oTgJEDWmOz2SyuSKT+CvRm4YyMDJxOJ/Hx8WWmx8fHs2HDhnLXuemmm8jIyOCcc87BGENxcTF/+ctfTnqaZsKECYwbN86b0kREvPL9hv3szSygYXgQV3RPtLockXqtxq+mWbBgAePHj+df//oXq1at4osvvmDWrFk888wzFa4zZswYMjMzPY/du3fXdJkiUs9MW+ZuFbm+T5IGOROxmFctI7GxsQQEBJCenl5menp6OgkJCeWu8+STT3LLLbdwxx13ANCtWzdyc3O56667+Nvf/obdfmIeCgkJISQkxJvSREQqbdfBPBZtPgDATX1bWFyNiHjVMhIcHEzv3r2ZP3++Z5rL5WL+/Pn069ev3HXy8vJOCBwBAe7/hRhjvK1XROS0ffTzToyBc9vH0iq2gdXliNR7XrWMAIwePZoRI0bQp08f+vbty6RJk8jNzWXkyJEADB8+nGbNmjFhwgQAhgwZwsSJE+nVqxfJycls2bKFJ598kiFDhnhCiYhIbSksdvLZL+7Lef98toZ+F6kLvA4jw4YN48CBAzz11FOkpaXRs2dP5syZ4+nUumvXrjItIX//+9+x2Wz8/e9/JzU1lSZNmjBkyBCeffbZ6tsLEZFKmvfHfg7lOoiPCuGiTnFWlyMigM34wLmSrKwsoqOjyczMJCoqyupyRMSHDX/3ZxZtOsCoC9ry6KBOVpcj4tcq+/2te9OISL2x53AeP5Z0XP1TnySLqxGRUgojIlJvfL5yD8ZAvzaNadlYHVdF6gqFERGpF1wu4+m4OuwstYqI1CUKIyJSLyzddpDUI/lEhgZyadfyx0USEWsojIhIvfDl6lQArujeVCOuitQxCiMi4vfyHU7m/J4GwDVnNrO4GhE5nsKIiPi9eevTySkspnnDMHq3aGh1OSJyHIUREfF7M0pO0VzVsyl2u83iakTkeAojIuLXDuU6WLjJPbbI0J46RSNSFymMiIhfm/XbXopdhq7NomgfH2l1OSJSDoUREfFrpVfRqFVEpO5SGBERv7XrYB6rdh3BboMhPZpaXY6IVEBhRET81ow17laRAe1iiY8KtbgaEamIwoiI+CVjzDFX0egUjUhdpjAiIn5pbWom2zJyCQ2yM+iMeKvLEZGTUBgREb9U2nH14i4JRIYGWVyNiJyMwoiI+B2XyzDrt30AXKWOqyJ1nsKIiPidVbsOsz+7kMiQQM7tEGt1OSJyCgojIuJ3Zq913xQvpUs8IYG6Q69IXacwIiJ+xRjDnN/dp2gu7ZpgcTUiUhkKIyLiV37dk8nezALCgwMY2KGJ1eWISCUojIiIX/m2pFXkgk5xhAbpFI2IL1AYERG/YYzh25L+Ipd
1TbS4GhGpLIUREfEbf+zLYtehPEKD7JzfUadoRHyFwoiI+I3SVpGBHZrQICTQ4mpEpLIURkTELxhjmF3SX+SybjpFI+JLFEZExC9s3p/DtgO5BAfYubBTnNXliIgXFEZExC/MXutuFTm3fazuRSPiYxRGRMQvzPnd3V9ksE7RiPgchRER8XnbDuSwIS2bQLuNizvHW12OiHhJYUREfN7cP9IB6Ne2MdHhOkUj4msURkTE581b7w4jF3dRq4iIL1IYERGfdijXwcqdhwF0FY2Ij1IYERGf9sOG/bgMdE6MonnDcKvLEZEqUBgREZ82f4P7FE1KZ7WKiPgqhRER8VmFxU4WbjwAQIquohHxWQojIuKzlm87RK7DSZPIELo1i7a6HBGpIoUREfFZpVfRpHSOw263WVyNiFSVwoiI+CRjDPPX7wfgok46RSPiyxRGRMQnrd+XTeqRfEKD7AxoF2t1OSJyGhRGRMQnzS85RXNOu1jCggMsrkZETofCiIj4pKP9RXSKRsTXKYyIiM/Zn1XAr3syAbhQ44uI+DyFERHxOfM3uDuu9kiKIS4y1OJqROR0KYyIiM+ZV3KX3ovVKiLiFxRGRMSnFBQ5+WlrBgAX6pJeEb+gMCIiPuXn7YcoKHIRHxVC58RIq8sRkWqgMCIiPmVByb1oBnZogs2mUVdF/IHCiIj4lAWb3J1Xz++o/iIi/kJhRER8xu5DeWw7kEuA3cY57TXqqoi/UBgREZ+xYKO7VaR3y4ZEhQZZXI2IVBeFERHxGaX9Rc7v2MTiSkSkOimMiIhPKChysmTrQQDO76D+IiL+RGFERHzCz9sPkV/k1CW9In5IYUREfILnFE2HOF3SK+JnFEZExCccvaRX/UVE/I3CiIjUeaWX9AbabQzQJb0ifqdKYWTy5Mm0atWK0NBQkpOT+fnnn0+6/JEjRxg1ahSJiYmEhITQoUMHZs+eXaWCRaT+Kb2k90xd0ivilwK9XWH69OmMHj2aKVOmkJyczKRJkxg0aBAbN24kLu7EHu4Oh4OLL76YuLg4Pv/8c5o1a8bOnTuJiYmpjvpFpB7QJb0i/s3rMDJx4kTuvPNORo4cCcCUKVOYNWsW7777Lo8//vgJy7/77rscOnSIJUuWEBTk/h9Nq1atTvoehYWFFBYWel5nZWV5W6aI+Ald0ivi/7w6TeNwOFi5ciUpKSlHN2C3k5KSwtKlS8td56uvvqJfv36MGjWK+Ph4unbtyvjx43E6nRW+z4QJE4iOjvY8kpKSvClTRPzIih26pFfE33kVRjIyMnA6ncTHx5eZHh8fT1paWrnrbNu2jc8//xyn08ns2bN58sknefnll/m///u/Ct9nzJgxZGZmeh67d+/2pkwR8SM/bs4A4Lz2ukuviL/y+jSNt1wuF3Fxcbz11lsEBATQu3dvUlNTefHFFxk7dmy564SEhBASElLTpYmIDygNI+d2UH8REX/lVRiJjY0lICCA9PT0MtPT09NJSEgod53ExESCgoIICAjwTOvcuTNpaWk4HA6Cg4OrULaI1AcHsgtZv8/dZ2xA28YWVyMiNcWr0zTBwcH07t2b+fPne6a5XC7mz59Pv379yl1nwIABbNmyBZfL5Zm2adMmEhMTFURE5KR+2uJuFenaLIrGEWotFfFXXo8zMnr0aN5++23ef/991q9fzz333ENubq7n6prhw4czZswYz/L33HMPhw4d4sEHH2TTpk3MmjWL8ePHM2rUqOrbCxHxS6WnaM5pp1M0Iv7M6z4jw4YN48CBAzz11FOkpaXRs2dP5syZ4+nUumvXLuz2oxknKSmJ7777jocffpju3bvTrFkzHnzwQR577LHq2wsR8TvGGH7c7B5f5FyNuiri12zGGGN1EaeSlZVFdHQ0mZmZREVFWV2OiNSCTenZXPLKIkKD7Kx56hJCgwJOvZKI1CmV/f7WvWlEpE5atMndKtK3dWMFERE/pzAiInXS4i2l44voFI2Iv1MYEZE6p7DYybJ
t7iHgz1EYEfF7CiMiUues3HmYgiIXTSJD6BivIeBF/J3CiIjUOZ5RV9vFagh4kXpAYURE6pzFpeOL6BSNSL2gMCIidcqhXAe/780E4Jx2CiMi9YHCiIjUKT9tycAY6JQQSVxUqNXliEgtUBgRkTpFo66K1D8KIyJSZxhjjukvovvRiNQXCiMiUmdsy8hlb2YBwYF2+rZqZHU5IlJLFEZEpM74sWQI+LNaNSQsWEPAi9QXCiMiUmeUDgF/TjudohGpTxRGRKROKHK6WLrVPQS8Oq+K1C8KIyJSJ6zedYRch5PGDYLpkljxrcZFxP8ojIhInVB6Se+AdrHY7RoCXqQ+URgRkTrhRw0BL1JvKYyIiOUy84r4bc8RQP1FROojhRERsdySrRm4DLSLiyAxOszqckSklimMiIjlFpWeotGN8UTqJYUREbGUMcbTefW8DgojIvWRwoiIWGrnwTz2HM4nKMBGcuvGVpcjIhZQGBERS/1YMurqmS0a0iAk0OJqRMQKCiMiYqnFJadodBWNSP2lMCIilil2uliypXQIeN2PRqS+UhgREcv8uieT7MJiosOC6Nos2upyRMQiCiMiYpnSq2jOaRdLgIaAF6m3FEZExDKLNQS8iKAwIiIWySooYvXuI4AGOxOp7xRGRMQSy7YexOkytI5tQFKjcKvLERELKYyIiCV+1BDwIlJCYURELLG4ZLAzjS8iIgojIlLrdh/KY3tGLgF2G2e31RDwIvWdwoiI1LrSVpFeSTFEhQZZXI2IWE1hRERqnWd8EZ2iEREURkSkljldhp80BLyIHENhRERq1drUTDLzi4gMDaRHcw0BLyIKIyJSy37c5D5FM6BtLIEB+ggSEYUREallpeOLnNtB/UVExE1hRERqTXZBEat2HQbgPPUXEZESCiMiUmuWbTtEscvQqnG4hoAXEQ+FERGpNaWX9OoqGhE5lsKIiNQaz/1oNL6IiBxDYUREasWxQ8D30xDwInIMhRERqRWlrSIaAl5EjqcwIiK1YvEW9RcRkfIpjIhIjXO6DIs1voiIVEBhRERq3G97jpBVUExUaCDdm2kIeBEpS2FERGpcaX+RAe00BLyInEifCiJS4zS+iIicjMKIiNQo9xDwRwA4V+OLiEg5FEZEpEYt3XoQp8vQOraBhoAXkXIpjIhIjfLcpVetIiJSAYUREalR6i8iIqeiMCIiNWbXwTx2HMwjwG7j7DaNrC5HROoohRERqTE/loy6emaLGCI1BLyIVKBKYWTy5Mm0atWK0NBQkpOT+fnnnyu13ieffILNZmPo0KFVeVsR8TGeUVd1ikZETsLrMDJ9+nRGjx7N2LFjWbVqFT169GDQoEHs37//pOvt2LGDv/71r5x77rlVLlZEfEex08VPW9R5VUROzeswMnHiRO68805GjhxJly5dmDJlCuHh4bz77rsVruN0Orn55psZN24cbdq0Oa2CRcQ3/JaaeXQI+OYxVpcjInWYV2HE4XCwcuVKUlJSjm7AbiclJYWlS5dWuN7TTz9NXFwct99+e6Xep7CwkKysrDIPEfEtCza6+4uc0z6WALvN4mpEpC7zKoxkZGTgdDqJj48vMz0+Pp60tLRy11m8eDHvvPMOb7/9dqXfZ8KECURHR3seSUlJ3pQpInXAwo3uU7fnd4izuBIRqetq9Gqa7OxsbrnlFt5++21iYyt/znjMmDFkZmZ6Hrt3767BKkWkuh3MKeS31EwABnZU51UROblAbxaOjY0lICCA9PT0MtPT09NJSEg4YfmtW7eyY8cOhgwZ4pnmcrncbxwYyMaNG2nbtu0J64WEhBASEuJNaSJShyzafABjoHNiFPFRoVaXIyJ1nFctI8HBwfTu3Zv58+d7prlcLubPn0+/fv1OWL5Tp06sXbuWNWvWeB5XXnklF1xwAWvWrNHpFxE/Vdpf5AK1iohIJXjVMgIwevRoRowYQZ8+fejbty+TJk0iNzeXkSNHAjB8+HCaNWvGhAkTCA0NpWv
XrmXWj4mJAThhuoj4B6fLsGiTO4yc31H9RUTk1LwOI8OGDePAgQM89dRTpKWl0bNnT+bMmePp1Lpr1y7sdg3sKlJf/bbnCIfziogMDeTMFjFWlyMiPsBmjDFWF3EqWVlZREdHk5mZSVRUlNXliMhJTJy7iX/O38xl3RL41829rS5HRCxU2e9vNWGISLXSJb0i4i2FERGpNrqkV0SqQmFERKqNLukVkapQGBGRalN6Se/5ahURES8ojIhItShzSW8HhRERqTyFERGpFmUu6W3Z0OpyRMSHKIyISLUoPUVzbvtYggL00SIiladPDBGpFgt0Sa+IVJHCiIictv1ZBfy6x31Jrzqvioi3FEZE5LR9v8HdKtKjeTRxuqRXRLykMCIip23e+nQAUjrHW1yJiPgihREROS35DieLt2QAcJHCiIhUgcKIiJyWn7ZkUFDkollMGJ0TI60uR0R8kMKIiJyW+Rvcp2gu6hyHzWazuBoR8UUKIyJSZS6XYd56d+dV9RcRkapSGBGRKlubmsmB7EIaBAeQ3KaR1eWIiI9SGBGRKiu9imZgxyaEBAZYXI2I+CqFERGpstJTNBd10ikaEak6hRERqZI9h/NYvy8Luw0u6KQh4EWk6hRGRKRKSkdd7d2yIY0aBFtcjYj4MoUREamSuX9o1FURqR4KIyLiteyCIpZtOwhASheFERE5PQojIuK1RZsyKHIaWsc2oG2TCKvLEREfpzAiIl779vd9AFyiVhERqQYKIyLilYIiJz+UdF4d3C3R4mpExB8ojIiIVxZtOkCuw0nT6FB6NI+2uhwR8QMKIyLilW9/TwPg0q6JujGeiFQLhRERqbTCYqdnCPjB3RIsrkZE/IXCiIhU2pItB8kuKCYuMoTeLRpaXY6I+AmFERGptNKraC7tmoDdrlM0IlI9FEZEpFKKnC7+VzLq6qVddYpGRKqPwoiIVMqybQc5kldE4wbB9G3VyOpyRMSPKIyISKWUXkVzyRnxBAboo0NEqo8+UUTklJwuw//WucPI4K4a6ExEqpfCiIic0oodh8jIcRAdFkS/to2tLkdE/IzCiIic0rdr3VfRXNwlniCdohGRaqZPFRE5KafLePqLDNZVNCJSAxRGROSklm49yP7sQmLCgzi3fROryxERP6QwIiIn9eXqVAAu75ZIcKA+MkSk+umTRUQqlO9wMqdk1NWrezWzuBoR8VcKIyJSobnr08l1OGneMIzeLXUvGhGpGQojIlKhmSWnaIb2bIbNVgP3onEWQ85+KMyp/m2LiM8ItLoAEambDuYUsnDTAQCG9mpaPRvNOQAbvoGt82H3z+4ggnHPC2oAse2h7YXQYRAkJUNNBCARqXMURkSkXLPW7qPYZejaLIp2cZGnt7GMLbDkVfh1OjgLy1+mKBf2rXE/Fk+EhO7Q/wHoeg3YA07v/UWkTlMYEZFyzTjmFE2VOfJg0Quw5DVwFbunJfaETpdDm/OhYSsIbwxF+ZCTDntWwJZ5sP4bSPsNvrgDlk2GIa9CYo/T3SURqaMURkTkBDsP5rJq1xHsNriyRxVP0exZCf+9DQ7vcL9ufwmcMxpanH3i6ZeQCPejcVvocQPkHYIV/4Ylr8Pe1fDW+XDOw3DB39RKIuKH1IFVRE4wY/VeAAa0iyUuKtS7lY2BX6bC1EvdQSSqGdzwMdz8GbTsV7l+IOGNYOD/g/tWQNdrwbjgx5dh2jWQm+H9DolInaYwIiJlGGOYuaaKp2hcLvj2MfjmIXA6oNMVcO9S92mZqoiMh+vedT+CGsC2BfDWBe4+KCLiNxRGRKSM3/Zksi0jl9AgO4O8uReNswi+vAt+fhOwwUVjYdg0CI0+/aK6Xgt3zodGbSBzF7w7CPb9evrbFZE6QWFERMr4bOVuAC7pkkBESCW7lTmL4NPhsPYzsAfCtf+Gc0dX76W5cZ3htv+5r7LJy4D3rnD3SxERn6cwIiIe+Q4nM0v6iww7K6lyK7mc8MVdsHE2BIbCDf+Bbtf
VTIERTeDWb6BFfyjMcvchSV9XM+8lIrVGYUREPL79fR/ZhcUkNQqjX5vGp17BGHf/kHVfgD3IfVqmwyU1W2RotLszbPOzoOAIfDAUDm6t2fcUkRqlMCIiHtNXuE/RXN87Cbu9EqdYFr0Iqz4Am919aqb9xTVcYYmQCHcgSegGufvho+sg92DtvLeIVDuFEREBYEdGLsu3H8Jmg+t6Nz/1Cr99Cj88635++UQ4Y2iN1neCsIbw5y8gpgUc2gaf3ARFBbVbg4hUC4UREQFg+i/uVpGBHZrQNCbs5Avv/hlmjnI/H/Ag9BlZw9VVICIObv4cQqJh9zL46n73qSMR8SkKIyKCo9jFZyVh5IZTdVzN2e++csbpgM5D4KJ/1HyBJ9OkIwz7EGwBsPZT98itIuJTqhRGJk+eTKtWrQgNDSU5OZmff/65wmXffvttzj33XBo2bEjDhg1JSUk56fIiUvvmrEsjI8dBfFQIKZ3jK17QWQyf3wbZ+yC2IwydAvY68H+aNgPhkmfcz+eMgd0rrK1HRLzi9afI9OnTGT16NGPHjmXVqlX06NGDQYMGsX///nKXX7BgATfeeCM//PADS5cuJSkpiUsuuYTU1NTTLl5Eqse0ZTsBuOGsFgQGnORj4ftnYMePEBzhvnImJKKWKqyEs++FLkPBVTLmSc4BqysSkUqyGePdCdbk5GTOOussXn/9dQBcLhdJSUncf//9PP7446dc3+l00rBhQ15//XWGDx9eqffMysoiOjqazMxMoqKivClXRE5hU3o2l7yyiAC7jZ8eu5CE6AruRbP+G5h+s/v59e/BGVfXWo2VVpgNb18IGZug9XlwywzdWE/EQpX9/vaqZcThcLBy5UpSUlKObsBuJyUlhaVLl1ZqG3l5eRQVFdGoUaMKlyksLCQrK6vMQ0RqxkclrSIpneMqDiIHt8KMe9zPzx5VN4MIQEgk/OlD931sti+C7//P6opEpBK8CiMZGRk4nU7i48ueU46PjyctLa1S23jsscdo2rRpmUBzvAkTJhAdHe15JCVVciRIEfFKdkER/13lPmX657Nblr9QcSF8OsI94mmLfnDxuFqssAriOsFVr7mfL54Im/5nbT0ickq12vPsueee45NPPuHLL78kNLTi25KPGTOGzMxMz2P37t21WKVI/fHZL3vIKSymbZMGnNMutvyF5j8N6WshvDFcNxUCgmq3yKroei30vdv9fOa97iuARKTO8iqMxMbGEhAQQHp6epnp6enpJCSc/O6eL730Es899xz/+9//6N69+0mXDQkJISoqqsxDRKqX02V4f+kOAEYOaI2tvJvabVsAS939w7hqMkQl1lp9p+3ipyG+K+QecJ9icrmsrkhEKuBVGAkODqZ3797Mnz/fM83lcjF//nz69etX4XovvPACzzzzDHPmzKFPnz5Vr1ZEqs33G/az82AeUaGBXHNmsxMXyDsEX5b0E+k9EjoOrt0CT1dQqHuI+sBQ2DIPfn7T6opEpAJen6YZPXo0b7/9Nu+//z7r16/nnnvuITc3l5Ej3SMwDh8+nDFjxniWf/7553nyySd59913adWqFWlpaaSlpZGTk1N9eyEiXpv603YAbkxuQXhwYNmZpTfAy94LjdvBoGdrv8DqENcZLinpxDr3KUhba209IlIur8PIsGHDeOmll3jqqafo2bMna9asYc6cOZ5Orbt27WLfvn2e5d944w0cDgfXXXcdiYmJnsdLL71UfXshIl5ZtzeTJVsPEmC3MbxfqxMXWPMx/DET7IHu1oXgBrVeY7U56w7oMNg9Yux/7wBHntUVichxvB5nxAoaZ0Skej3wn9V89etehvRoyms39io789A2mHIuOHLgoqfg3EesKbI65WbAG/0hJx363A5XTLS6IpF6oUbGGRER37f7UB7f/LYXgLvPa1N2prMYvrjbHURaDoABD9V+gTWhQSxcPcX9/Jd3YNN31tYjImUojIjUM2//uA2XgXPbx9K1WXTZmT+
+DHt+dt8F9+op/jV6adsL3UPGg/uOwxouXqTOUBgRqUcO5hTyacndee8Z2LbszN0rYOHz7ueXvwwxLWq5ulpw0ViI6+K+3Per+90ddUXEcgojIvXIO4u3U1DkoluzaPq1bXx0RmE2fHEnGCd0ux66X29dkTUpKBSueRsCgmHTt7ByqtUViQgKIyL1xuFcB+8v2QHAfRe2KzvI2ZzH4fB2iE6Cy/z8SreEru4WEoA5T0DGZmvrERGFEZH64p3F28l1OOmcGMUlXY65v9QfX8HqaYANrn4TwmKsKrH2nH2v+66+xfnuFiFnkdUVidRrCiMi9cCRPAfvlbSKPHhR+6OtIll74esH3M/PeRhaDbCmwNpmt8PQKRAaA3tXH+0rIyKWUBgRqQfeWbydnMJiOiVEHm0Vcbnc92zJPwyJPeH8MSfdht+JbgZDJrmf//gy7FpmaTki9ZnCiIify8gp5J3F7qHfH0ppj91e0iqy/A33jfACw0ru4RJsXZFWOeNq6HEjGBd8cRcUZFldkUi9pDAi4ude/34LeQ4nPZpHM+iMkrtrp62Fef9wPx/0LMS2t6w+yw1+wX0Z85Gd8O1jVlcjUi8pjIj4sd2H8vho+U4AHru0k7uviCMXPr/Nfa+WDoOhz20WV2mx0Ch3x12bHX79GNZ9aXVFIvWOwoiIH3tl7iaKnIZz28fSv12se+KcMZCxCSIT4arJcOwlvvVVy/7uDrwAXz/k7tgrIrVGYUTET63dk8kXq1MBeHRQR/fEdTNg1ft4LuNt0LjC9eudgY+7O/IWHIEv7waX0+qKROoNhRERP2SMYdzX6wC4plczujePgSO7y17G22agdQXWRYHB7o68QeGwfRH8qDv7itQWhRERPzRr7T5+2XmYsKAAHr20Y8ndeO+Egkxo1gcueMLqEuum2Pbu+/IALBgPOxZbW49IPaEwIuJnCoqcTJi9AYC/DGxLYnSYe1CvXUshONL9v/+AIIurrMN63gQ9bnJf7vvfOyA3w+qKRPyewoiIn3n9+y2kHsknMTqUu85rA5vnwaIX3TOvmAiNWltboC+47EWI7QDZ+0r6j7isrkjErymMiPiRLftzeHPRVgDGDjmDsLxU+OIOwEDvkdD9T9YW6CtCIuD69yAwFLbMgyX/tLoiEb+mMCLiJ4wx/H3GWoqchgs7xTGoYwx8OuLocO+XPmd1ib4l/gz3gGgA85+GnUusrUfEjymMiPiJ/65KZdm2Q4QG2Rl35RnY/vc32LvKfTO4P30AQaFWl+h7zhwO3a4H44RPh0NmqtUVifglhRERP5CeVcDTJZfyPnhRB5L2zIIV/3bPvOYtaNjSwup8mM0GQ16F+K6QewCm/xmKCqyuSsTvKIyI+DhjDE98sZasgmK6N4/mzvY5R8cTOfev0GGQtQX6uuAGcMNHENbQ3dI06xEwxuqqRPyKwoiIj/tydSrzN+wnOMDOxMsSCZx+ExTlQZsLNJ5IdWnYCq6b6r5/zZppR1udRKRaKIyI+LDdh/IYO9N9eubh85vTbv6dkLUHGreD66eCPcDiCv1I2wsgZZz7+ZzHYcdP1tYj4kcURkR8VLHTxYOfrCa7sJjeLWK4+8hESF3p7rB606fu0wpSvfrfD12vA1exu/9IxharKxLxCwojIj7qn99vYdWuI0SGBPJu6++xr/sC7IEwbBo0bmt1ef7JZoMrX4OmvSD/EHx0LeQcsLoqEZ+nMCLigxZuOsBr328G4L2zdhG9/CX3jMsnQutzLaysHggOd7c8NWwFh3fAx38CR67VVYn4NIURER+z53AeD36yGmNgbOe99F41xj2j333Qe4S1xdUXEXFw838hrJH7CpvPb3PfjFBEqkRhRMSHFBQ5ufejVRzJK2JY/B5u3f0kuIqg67Vw8dNWl1e/xLaDm6a7h4zfNAdm65JfkapSGBHxEcYYHv38N37bk0ly2B4m5P8ftuJ8aH8JXP2mrpyxQlJfuPYd9yW/K9+D+eMUSESqQGF
ExEe8On8zX/+6l3b2NKYFP4fdkQUt+sP170NAkNXl1V+dr4DLX3Y/X/wKfP+MAomIlxRGRHzAF6v2MGneZlrY0pkZ9SJBhYcgsQfc9Im7Q6VYq89tR2+q9+PL8MOzCiQiXlAYEanjvt+QzqOf/0Z72x5mRTxLg4J9ENsB/vwFhEZbXZ6USr776J2RF70IC3SXZJHKUhgRqcNW7DjEvR+torPZypfhzxJZlAFxZ8Cts6BBrNXlyfHOvgcuedb9fOFz8L1aSEQqQ2FEpI5aufMQI6euoGvxH3waOp4IZyY06w23fuO+tFTqpv73wcXPuJ8vegG+flCX/YqcgsKISB20cudhRry7gj5Fv/BR6POEmzxoeQ4MnwnhjawuT05lwANw2UuADVa9D9Nv1sBoIiehMCJSxyzenMEt7yzjT8Vf827wS4SYQmh3Mdz8GYREWl2eVFbfO91D85eOQ/L+EA0dL1IBhRGROmT22n3c/d5PjHVN4amgD7Hjgl5/hhs+1lUzvqjzFTD8K/dNC1NXwjsXw4GNVlclUucojIjUAcYY3lq0lac+/oGpAc8yLHABxmaHQRPgytchMNjqEqWqWiTD7XMhpiUc3g5vXQBrP7e6KpE6RWFExGKOYheP/3ctc779ihnBT9LXvhETEoXtps+g373uO8WKb4ttD3fMh9bnQVEu/Pd2mP0oFDusrkykTlAYEbHQ3iP53PjmYhLXvMJnweNobsvANGqD7Y750D7F6vKkOkU0gVtmwLl/db/++S2YOhiO7La0LJG6QGFExCILNx3grn/+lyfSR/NQ4BcE2Ax0vwHbXQuhSQery5OaYA+Ai56EG6e7B6xL/QXeGACrPtR4JFKv2Yyp+38BWVlZREdHk5mZSVRUlNXliJyWgiInz8/+g9yf3+fJwGlE2vJxBUdiHzIJul1ndXlSWw7vgM9vc3dsBWh7EQx5FWKSLC1LpDpV9vtbYUSkFv2y4xBvfvoVd+dMpo99EwCu5snYr30bGra0uDqpdc5iWDbZPVKrsxCCI+GSp+HMW8GuhmvxfQojInXIkTwHr327msTVr3BrwHcE2lwUBzYg8MIxkHwPBARaXaJYKWMzzBwFu5e7XzftBYPGQ8v+1tYlcpoURkTqAEexi4+WbCH1+ze503xOvO2Ie3rHqwi+bAJEN7O2QKk7XE5Y/ib8MB4c2e5pna+Ei8dBozbW1iZSRQojIhYyxjD391RWfT2Fmws+IcnuHnkzP6IFYVe9oitlpGI5B2DBeFj5HhgX2IPgrNuh//0Q3dzq6kS8ojAiYgGXyzB37W7Wz5vKkMz/0Na+D4D8kFhCLnwMe+8REBhicZXiE9L/gP/9DbZ+735tD4IeN8CAhyC2naWliVSWwohILSosdvLtsl/JXPQmlxXOpoktC4D8wGjs540m5Oy7NJy7VM3WH+DHl2HHjyUTbNDlKki+G1r006B4UqcpjIjUgu0Hsln+/UyiNnzKRa6fCLG5bxWfHRyHPfkuGgy4G0L1b1aqwe6f4ceJsOnbo9NiO0DvW6HHjbqbs9RJCiMiNSS3sJiflv5E7oqP6Jszj2a2g5556VHdib7gAUK7D4WAIOuKFP+V9jv8/Cas/a97aHmAgBDoOBi6XuO+w7Na4aSOUBgRqUZHcgtYufR78td+Q/sjP9LRtsszL8/WgIxWl5N4/l0EtTzLwiqlXinIgt8/h1+mQtpvR6cHNYCOl0KXodD2AgiJtKxEEYURkdPgdBk2bN5M6pr/EbjzR87IXea5LBegmAB2NepPw37DadjzSggKta5Yqd+MgX1r4PcvYN0MyDwalLEHQtLZ0O5C9wivCd01mJrUKoURES8UOIrZsnEtBzYuwbZrGUlZK2lLapllcgkjNXYADboNoelZV2LTOXqpa4yB1FXwx5ew/hs4vL3s/LCGkJTsfrQ4G5qeqSAtNUphRKQCmdnZ7Nn8G4d3/oYzbT2Rh36njWMDMbbcMsu5jI3dIW3JSexH4+6DSei
Rostyxbcc2gZb5rsvD96+CBw5ZefbgyCuEyT0gMTu7paT+C7um/iJVAOFEanXHI4i0lO3c2jvFvL2b8e5fxOhRzbTJH87zc0+9x1yj1NIEHtC2pET25PITufTotfFBEY0tqB6kRrgLHL3Ldm1zP3YvRxy0stftkEcNG7nHs+kcXuIbe/+GZOkQC5eqdEwMnnyZF588UXS0tLo0aMHr732Gn379q1w+c8++4wnn3ySHTt20L59e55//nkuu+yySr+fwoiUMsaQm5vN4fQ9ZB3YQ/7hvTgy0zHZ6QTkpBGen0ojxz7iTQZBNmeF28miAWkhrciNaoe9aQ/iO/cnvt2Z2PRBK/WFMXBklzug7Pvt6M/svSdfr0ETiGrmHg02qpn7lgZRzdzTG8RCeGP3Q1eTCTUYRqZPn87w4cOZMmUKycnJTJo0ic8++4yNGzcSFxd3wvJLlizhvPPOY8KECVxxxRV8/PHHPP/886xatYquXbtW685I3eZyOsnPyyE/L5vC3BwKC7Jw5OVQXJBDUUEuzoJsnLmHceUdhoIjBBRmEujIIrg4i7DibMJd2USZHBrYCir1fkUmgAP2JhwJSSC/QQuI60xki24ktutBZGySBosSKU9BFhzc4n5kbC55vhkOboWivMpvJzS6JJjEloSURhAS7b66p8wj6pjnERAUDoGhEBQG9oCa20+pFTUWRpKTkznrrLN4/fXXAXC5XCQlJXH//ffz+OOPn7D8sGHDyM3N5ZtvvvFMO/vss+nZsydTpkyp1p3xVkbaLhwF+e4XxnD0V+HCuNzPDQaMyzPPGNz3iyhdh9LpxjPdGI5u65h14Zj3KPN+BuNyebZVsoEK3qOc57iOPncZjKsY43Ie8zj6GpfT/V6m9PkxP0um4XJ65mNc2Epe25xF4HRgcxVhczqwuxzYXUXYXEUEuIoIcDmwmyICTclrigl2FRJKIaGmgDCb4/QPWol8E8xhe0OyAxuRHxxLUVgsroh4ghq3IiK+DY2btadRQgtsuhuuSPUwBvIOQdYeyEyFrFTI3OP+mbUXcjMg7yDkHzr6GXm67EHuUBIY6u5oGxjmPk1UOi0wxL1MQGDJzyD3FUQBQRW8Pm45m90deGz24x4B7v+sHD+9zLK245a3l52H7Zj/8NiO+VH6/Ph5tvLnlbvcyeYdvxwnmWcru1xEfLWfhqvs97dXn9QOh4OVK1cyZswYzzS73U5KSgpLly4td52lS5cyevToMtMGDRrEjBkzKnyfwsJCCgsLPa+zsrK8KbPSDv77ejoWb6iRbctxjvu7yDfBFNhCKbCF4LCFUmgPo8geiiMoiuLgaFwh0ZjQGOzhMQSENyI4shGhkY0Jj25MTJPmNIiMIUwtGyK1x2aDBo3dj8QeFS/nckL+EXcwycs4GlLyDkJhdjmPrLKvnYXHbKsICovcy0jNu30eJFkzVpJXYSQjIwOn00l8fHyZ6fHx8WzYUP6XelpaWrnLp6WlVfg+EyZMYNy4cd6UViXF9mDyTTDmmG/K0ucGG552i2MSpTlumYqnc8x2yt9+mXWPS7imzHaO24at/G2ULuey2THYcdkCSn7aMcc/P2basa85Zj72AM80bHZMQAgmIAgCgiEgGFtgMLbAEGylz4NCCAgMxh4Ugj0whICgUAKCQggOjyQkLJKQ8AjCGkQSGhZBWEAAYd4dLhHxBfaAo6GFDt6v73JCcSEUF0BR/sl/Oh3ujrmu4pKfReW8Lj5m+nGvTUmrsnGVPJzHPC+Z5zp+mrPsfOMqZxkXeFq6S3fsmJbv0tfHPvdmnud1RfNOtd7x71HCZt0YNHWyDXvMmDFlWlOysrJISkqq9vc544kfT72QiIjUHnuAezh7DWlfr3gVRmJjYwkICCA9vezlYOnp6SQkJJS7TkJCglfLA4SEhBASoqsaRERE6gOv2mSCg4Pp3bs38+fP90xzuVzMnz+ffv36lbtOv379yiwPMHfu3AqXFxERkfrF69M0o0ePZsS
IEfTp04e+ffsyadIkcnNzGTlyJADDhw+nWbNmTJgwAYAHH3yQgQMH8vLLL3P55ZfzySef8Msvv/DWW29V756IiIiIT/I6jAwbNowDBw7w1FNPkZaWRs+ePZkzZ46nk+quXbuwH3Mjpv79+/Pxxx/z97//nSeeeIL27dszY8aMSo8xIiIiIv5Nw8GLiIhIjajs97fuJS0iIiKWUhgRERERSymMiIiIiKUURkRERMRSCiMiIiJiKYURERERsZTCiIiIiFhKYUREREQspTAiIiIilvJ6OHgrlA4Sm5WVZXElIiIiUlml39unGuzdJ8JIdnY2AElJSRZXIiIiIt7Kzs4mOjq6wvk+cW8al8vF3r17iYyMxGazVdt2s7KySEpKYvfu3X57zxt/30ftn+/z933U/vk+f9/Hmtw/YwzZ2dk0bdq0zE10j+cTLSN2u53mzZvX2PajoqL88h/Ysfx9H7V/vs/f91H75/v8fR9rav9O1iJSSh1YRURExFIKIyIiImKpeh1GQkJCGDt2LCEhIVaXUmP8fR+1f77P3/dR++f7/H0f68L++UQHVhEREfFf9bplRERERKynMCIiIiKWUhgRERERSymMiIiIiKX8Pow8++yz9O/fn/DwcGJiYspdZteuXVx++eWEh4cTFxfHo48+SnFx8Um3e+jQIW6++WaioqKIiYnh9ttvJycnpwb2oPIWLFiAzWYr97FixYoK1zv//PNPWP4vf/lLLVbunVatWp1Q73PPPXfSdQoKChg1ahSNGzcmIiKCa6+9lvT09FqquPJ27NjB7bffTuvWrQkLC6Nt27aMHTsWh8Nx0vXq+jGcPHkyrVq1IjQ0lOTkZH7++eeTLv/ZZ5/RqVMnQkND6datG7Nnz66lSr0zYcIEzjrrLCIjI4mLi2Po0KFs3LjxpOu89957Jxyr0NDQWqrYe//4xz9OqLdTp04nXcdXjh+U/3lis9kYNWpUucvX9eO3aNEihgwZQtOmTbHZbMyYMaPMfGMMTz31FImJiYSFhZGSksLmzZtPuV1v/4a95fdhxOFwcP3113PPPfeUO9/pdHL55ZfjcDhYsmQJ77//Pu+99x5PPfXUSbd78803s27dOubOncs333zDokWLuOuuu2piFyqtf//+7Nu3r8zjjjvuoHXr1vTp0+ek6955551l1nvhhRdqqeqqefrpp8vUe//99590+Ycffpivv/6azz77jIULF7J3716uueaaWqq28jZs2IDL5eLNN99k3bp1vPLKK0yZMoUnnnjilOvW1WM4ffp0Ro8ezdixY1m1ahU9evRg0KBB7N+/v9zllyxZwo033sjtt9/O6tWrGTp0KEOHDuX333+v5cpPbeHChYwaNYply5Yxd+5cioqKuOSSS8jNzT3pelFRUWWO1c6dO2up4qo544wzytS7ePHiCpf1peMHsGLFijL7NnfuXACuv/76Ctepy8cvNzeXHj16MHny5HLnv/DCC/zzn/9kypQpLF++nAYNGjBo0CAKCgoq3Ka3f8NVYuqJqVOnmujo6BOmz54929jtdpOWluaZ9sYbb5ioqChTWFhY7rb++OMPA5gVK1Z4pn377bfGZrOZ1NTUaq+9qhwOh2nSpIl5+umnT7rcwIEDzYMPPlg7RVWDli1bmldeeaXSyx85csQEBQWZzz77zDNt/fr1BjBLly6tgQqr1wsvvGBat2590mXq8jHs27evGTVqlOe10+k0TZs2NRMmTCh3+T/96U/m8ssvLzMtOTnZ3H333TVaZ3XYv3+/AczChQsrXKaiz6K6auzYsaZHjx6VXt6Xj58xxjz44IOmbdu2xuVylTvfl44fYL788kvPa5fLZRISEsyLL77omXbkyBETEhJi/vOf/1S4HW//hqvC71tGTmXp0qV069aN+Ph4z7RBgwaRlZXFunXrKlwnJiamTGtDSkoKdrud5cuX13jNlfXVV19x8OBBRo4cecplP/roI2JjY+natStjxowhLy+vFiqsuue
ee47GjRvTq1cvXnzxxZOeVlu5ciVFRUWkpKR4pnXq1IkWLVqwdOnS2ij3tGRmZtKoUaNTLlcXj6HD4WDlypVlfvd2u52UlJQKf/dLly4tszy4/yZ95VgBpzxeOTk5tGzZkqSkJK666qoKP2vqis2bN9O0aVPatGnDzTffzK5duypc1pePn8PhYNq0adx2220nvSmrrx2/Utu3byctLa3M8YmOjiY5ObnC41OVv+Gq8Ikb5dWktLS0MkEE8LxOS0urcJ24uLgy0wIDA2nUqFGF61jhnXfeYdCgQae8yeBNN91Ey5Ytadq0Kb/99huPPfYYGzdu5IsvvqilSr3zwAMPcOaZZ9KoUSOWLFnCmDFj2LdvHxMnTix3+bS0NIKDg0/oMxQfH1+njld5tmzZwmuvvcZLL7100uXq6jHMyMjA6XSW+ze2YcOGctep6G+yrh8rl8vFQw89xIABA+jatWuFy3Xs2JF3332X7t27k5mZyUsvvUT//v1Zt25djd4QtKqSk5N577336NixI/v27WPcuHGce+65/P7770RGRp6wvK8eP4AZM2Zw5MgRbr311gqX8bXjd6zSY+DN8anK33BV+GQYefzxx3n++edPusz69etP2cnKV1Rlf/fs2cN3333Hp59+esrtH9vXpVu3biQmJnLRRRexdetW2rZtW/XCveDNPo4ePdozrXv37gQHB3P33XczYcKEOjtcc1WOYWpqKpdeeinXX389d95550nXrQvHsL4bNWoUv//++0n7UwD069ePfv36eV7379+fzp078+abb/LMM8/UdJleGzx4sOd59+7dSU5OpmXLlnz66afcfvvtFlZW/d555x0GDx5M06ZNK1zG146fr/DJMPLII4+cNLkCtGnTplLbSkhIOKFXcOlVFgkJCRWuc3zHneLiYg4dOlThOqejKvs7depUGjduzJVXXun1+yUnJwPu/5XX1hfZ6RzT5ORkiouL2bFjBx07djxhfkJCAg6HgyNHjpRpHUlPT6+R41Ueb/dv7969XHDBBfTv35+33nrL6/ez4hiWJzY2loCAgBOuXDrZ7z4hIcGr5euC++67z9OR3dv/HQcFBdGrVy+2bNlSQ9VVr5iYGDp06FBhvb54/AB27tzJvHnzvG5N9KXjV3oM0tPTSUxM9ExPT0+nZ8+e5a5Tlb/hKqm23id13Kk6sKanp3umvfnmmyYqKsoUFBSUu63SDqy//PKLZ9p3331XZzqwulwu07p1a/PII49Uaf3FixcbwPz666/VXFnNmDZtmrHb7ebQoUPlzi/twPr55597pm3YsKHOdmDds2ePad++vbnhhhtMcXFxlbZRl45h3759zX333ed57XQ6TbNmzU7agfWKK64oM61fv351sgOky+Uyo0aNMk2bNjWbNm2q0jaKi4tNx44dzcMPP1zN1dWM7Oxs07BhQ/Pqq6+WO9+Xjt+xxo4daxISEkxRUZFX69Xl40cFHVhfeuklz7TMzMxKdWD15m+4SrVW25bqqJ07d5rVq1ebcePGmYiICLN69WqzevVqk52dbYxx/0Pq2rWrueSSS8yaNWvMnDlzTJMmTcyYMWM821i+fLnp2LGj2bNnj2fapZdeanr16mWWL19uFi9ebNq3b29uvPHGWt+/8sybN88AZv369SfM27Nnj+nYsaNZvny5McaYLVu2mKefftr88ssvZvv27WbmzJmmTZs25rzzzqvtsitlyZIl5pVXXjFr1qwxW7duNdOmTTNNmjQxw4cP9yxz/D4aY8xf/vIX06JFC/P999+bX375xfTr18/069fPil04qT179ph27dqZiy66yOzZs8fs27fP8zh2GV86hp988okJCQkx7733nvnjjz/MXXfdZWJiYjxXsN1yyy3m8ccf9yz/008/mcDAQPPSSy+Z9evXm7Fjx5qgoCCzdu1aq3ahQvfcc4+Jjo42CxYsKHOs8vLyPMscv3/jxo0z3333ndm6datZuXKlueGGG0xoaKhZt26
dFbtwSo888ohZsGCB2b59u/npp59MSkqKiY2NNfv37zfG+PbxK+V0Ok2LFi3MY489dsI8Xzt+2dnZnu85wEycONGsXr3a7Ny50xhjzHPPPWdiYmLMzJkzzW+//Wauuuoq07p1a5Ofn+/ZxoUXXmhee+01z+tT/Q1XB78PIyNGjDDACY8ffvjBs8yOHTvM4MGDTVhYmImNjTWPPPJImXT8ww8/GMBs377dM+3gwYPmxhtvNBERESYqKsqMHDnSE3CsduONN5r+/fuXO2/79u1l9n/Xrl3mvPPOM40aNTIhISGmXbt25tFHHzWZmZm1WHHlrVy50iQnJ5vo6GgTGhpqOnfubMaPH1+mFev4fTTGmPz8fHPvvfeahg0bmvDwcHP11VeX+YKvK6ZOnVruv9djGzF98Ri+9tprpkWLFiY4ONj07dvXLFu2zDNv4MCBZsSIEWWW//TTT02HDh1McHCwOeOMM8ysWbNqueLKqehYTZ061bPM8fv30EMPeX4X8fHx5rLLLjOrVq2q/eIradiwYSYxMdEEBwebZs2amWHDhpktW7Z45vvy8Sv13XffGcBs3LjxhHm+dvxKv6+Of5Tug8vlMk8++aSJj483ISEh5qKLLjphv1u2bGnGjh1bZtrJ/oarg80YY6rvpI+IiIiId+r9OCMiIiJiLYURERERsZTCiIiIiFhKYUREREQspTAiIiIillIYEREREUspjIiIiIilFEZERETEUgojIiIiYimFEREREbGUwoiIiIhYSmFERGrdgQMHSEhIYPz48Z5pS5YsITg4mPnz51tYmYhYQTfKExFLzJ49m6FDh7JkyRI6duxIz549ueqqq5g4caLVpYlILVMYERHLjBo1innz5tGnTx/Wrl3LihUrCAkJsbosEallCiMiYpn8/Hy6du3K7t27WblyJd26dbO6JBGxgPqMiIhltm7dyt69e3G5XOzYscPqckTEImoZERFLOBwO+vbtS8+ePenYsSOTJk1i7dq1xMXFWV2aiNQyhRERscSjjz7K559/zq+//kpERAQDBw4kOjqab775xurSRKSW6TSNiNS6BQsWMGnSJD788EOioqKw2+18+OGH/Pjjj7zxxhtWlycitUwtIyIiImIptYyIiIiIpRRGRERExFIKIyIiImIphRERERGxlMKIiIiIWEphRERERCylMCIiIiKWUhgRERERSymMiIiIiKUURkRERMRSCiMiIiJiqf8PTCdlTkI8lFsAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(x, y, label='y')\n", "plt.plot(x, dy_dx, label='dy/dx')\n", "plt.legend()\n", "_ = plt.xlabel('x')" ] }, { "cell_type": "markdown", "metadata": { "id": "6kADybtQzYj4" }, "source": [ "## Control flow\n", "\n", "Because a gradient tape records operations as they are executed, Python control flow is naturally handled (for example, `if` and `while` statements).\n", "\n", "Here a different variable is used on each branch of an `if`. The gradient only connects to the variable that was used:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.502631Z", "iopub.status.busy": "2024-08-15T01:30:13.502403Z", "iopub.status.idle": "2024-08-15T01:30:13.510472Z", "shell.execute_reply": "2024-08-15T01:30:13.509868Z" }, "id": "ciFLizhrrjy7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(1.0, shape=(), dtype=float32)\n", "None\n" ] } ], "source": [ "x = tf.constant(1.0)\n", "\n", "v0 = tf.Variable(2.0)\n", "v1 = tf.Variable(2.0)\n", "\n", "with tf.GradientTape(persistent=True) as tape:\n", " tape.watch(x)\n", " if x > 0.0:\n", " result = v0\n", " else:\n", " result = v1**2 \n", "\n", "dv0, dv1 = tape.gradient(result, [v0, v1])\n", "\n", "print(dv0)\n", "print(dv1)" ] }, { "cell_type": "markdown", "metadata": { "id": "HKnLaiapsjeP" }, "source": [ "Just remember that the control statements themselves are not differentiable, so they are invisible to gradient-based optimizers.\n", "\n", "Depending on the value of `x` in the above example, the tape either records `result = v0` or `result = v1**2`. The gradient with respect to `x` is always `None`." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.513704Z", "iopub.status.busy": "2024-08-15T01:30:13.513457Z", "iopub.status.idle": "2024-08-15T01:30:13.517267Z", "shell.execute_reply": "2024-08-15T01:30:13.516678Z" }, "id": "8k05WmuAwPm7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "dx = tape.gradient(result, x)\n", "\n", "print(dx)" ] }, { "cell_type": "markdown", "metadata": { "id": "egypBxISAHhx" }, "source": [ "## Cases where `gradient` returns `None`\n", "\n", "When a target is not connected to a source, `gradient` will return `None`.\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.520596Z", "iopub.status.busy": "2024-08-15T01:30:13.520029Z", "iopub.status.idle": "2024-08-15T01:30:13.525557Z", "shell.execute_reply": "2024-08-15T01:30:13.524980Z" }, "id": "CU185WDM81Ut" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "x = tf.Variable(2.)\n", "y = tf.Variable(3.)\n", "\n", "with tf.GradientTape() as tape:\n", " z = y * y\n", "print(tape.gradient(z, x))" ] }, { "cell_type": "markdown", "metadata": { "id": "sZbKpHfBRJym" }, "source": [ "Here `z` is obviously not connected to `x`, but there are several less-obvious ways that a gradient can be disconnected." ] }, { "cell_type": "markdown", "metadata": { "id": "eHDzDOiQ8xmw" }, "source": [ "### 1. Replaced a variable with a tensor\n", "\n", "In the section on [\"controlling what the tape watches\"](#watches) you saw that the tape will automatically watch a `tf.Variable` but not a `tf.Tensor`.\n", "\n", "One common error is to inadvertently replace a `tf.Variable` with a `tf.Tensor`, instead of using `Variable.assign` to update the `tf.Variable`. 
Here is an example:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.528706Z", "iopub.status.busy": "2024-08-15T01:30:13.528492Z", "iopub.status.idle": "2024-08-15T01:30:13.534284Z", "shell.execute_reply": "2024-08-15T01:30:13.533635Z" }, "id": "QPKY4Tn9zX7_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ResourceVariable : tf.Tensor(1.0, shape=(), dtype=float32)\n", "EagerTensor : None\n" ] } ], "source": [ "x = tf.Variable(2.0)\n", "\n", "for epoch in range(2):\n", " with tf.GradientTape() as tape:\n", " y = x+1\n", "\n", " print(type(x).__name__, \":\", tape.gradient(y, x))\n", " x = x + 1 # This should be `x.assign_add(1)`" ] }, { "cell_type": "markdown", "metadata": { "id": "3gwZKxgA97an" }, "source": [ "### 2. Did calculations outside of TensorFlow\n", "\n", "The tape can't record the gradient path if the calculation exits TensorFlow.\n", "For example:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.537149Z", "iopub.status.busy": "2024-08-15T01:30:13.536925Z", "iopub.status.idle": "2024-08-15T01:30:13.544353Z", "shell.execute_reply": "2024-08-15T01:30:13.543704Z" }, "id": "jmoLCDJb_yw1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "x = tf.Variable([[1.0, 2.0],\n", " [3.0, 4.0]], dtype=tf.float32)\n", "\n", "with tf.GradientTape() as tape:\n", " x2 = x**2\n", "\n", " # This step is calculated with NumPy\n", " y = np.mean(x2, axis=0)\n", "\n", " # Like most ops, reduce_mean will cast the NumPy array to a constant tensor\n", " # using `tf.convert_to_tensor`.\n", " y = tf.reduce_mean(y, axis=0)\n", "\n", "print(tape.gradient(y, x))" ] }, { "cell_type": "markdown", "metadata": { "id": "p3YVfP3R-tp7" }, "source": [ "### 3. Took gradients through an integer or string\n", "\n", "Integers and strings are not differentiable. 
If a calculation path uses these data types there will be no gradient.\n", "\n", "Nobody expects strings to be differentiable, but it's easy to accidentally create an `int` constant or variable if you don't specify the `dtype`." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.547270Z", "iopub.status.busy": "2024-08-15T01:30:13.547047Z", "iopub.status.idle": "2024-08-15T01:30:13.552690Z", "shell.execute_reply": "2024-08-15T01:30:13.552120Z" }, "id": "9jlHXHqfASU3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "x = tf.constant(10)\n", "\n", "with tf.GradientTape() as g:\n", " g.watch(x)\n", " y = x * x\n", "\n", "print(g.gradient(y, x))" ] }, { "cell_type": "markdown", "metadata": { "id": "RsdP_mTHX9L1" }, "source": [ "TensorFlow doesn't automatically cast between types, so, in practice, you'll often get a type error instead of a missing gradient." ] }, { "cell_type": "markdown", "metadata": { "id": "WyAZ7C8qCEs6" }, "source": [ "### 4. Took gradients through a stateful object\n", "\n", "State stops gradients. When you read from a stateful object, the tape can only observe the current state, not the history that led to it.\n", "\n", "A `tf.Tensor` is immutable. You can't change a tensor once it's created. It has a _value_, but no _state_. All the operations discussed so far are also stateless: the output of a `tf.matmul` only depends on its inputs.\n", "\n", "A `tf.Variable` has internal state: its value. When you use the variable, the state is read. It's normal to calculate a gradient with respect to a variable, but the variable's state blocks gradient calculations from going farther back. 
For example:\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.556090Z", "iopub.status.busy": "2024-08-15T01:30:13.555864Z", "iopub.status.idle": "2024-08-15T01:30:13.563022Z", "shell.execute_reply": "2024-08-15T01:30:13.562452Z" }, "id": "C1tLeeRFE479" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "x0 = tf.Variable(3.0)\n", "x1 = tf.Variable(0.0)\n", "\n", "with tf.GradientTape() as tape:\n", " # Update x1 = x1 + x0.\n", " x1.assign_add(x0)\n", " # The tape starts recording from x1.\n", " y = x1**2 # y = (x1 + x0)**2\n", "\n", "# This doesn't work.\n", "print(tape.gradient(y, x0)) #dy/dx0 = 2*(x1 + x0)" ] }, { "cell_type": "markdown", "metadata": { "id": "xKA92-dqF2r-" }, "source": [ "Similarly, `tf.data.Dataset` iterators and `tf.queue`s are stateful, and will stop all gradients on tensors that pass through them." ] }, { "cell_type": "markdown", "metadata": { "id": "HHvcDGIbOj2I" }, "source": [ "## No gradient registered" ] }, { "cell_type": "markdown", "metadata": { "id": "aoc-A6AxVqry" }, "source": [ "Some `tf.Operation`s are **registered as being non-differentiable** and will return `None`. Others have **no gradient registered**.\n", "\n", "The `tf.raw_ops` page shows which low-level ops have gradients registered.\n", "\n", "If you attempt to take a gradient through a float op that has no gradient registered the tape will throw an error instead of silently returning `None`. 
This way you know something has gone wrong.\n", "\n", "For example, the `tf.image.adjust_contrast` function wraps `raw_ops.AdjustContrastv2`, which could have a gradient but the gradient is not implemented:\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.566420Z", "iopub.status.busy": "2024-08-15T01:30:13.566184Z", "iopub.status.idle": "2024-08-15T01:30:13.578588Z", "shell.execute_reply": "2024-08-15T01:30:13.578030Z" }, "id": "HSb20FXc_V0U" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LookupError: gradient registry has no entry for: AdjustContrastv2\n" ] } ], "source": [ "image = tf.Variable([[[0.5, 0.0, 0.0]]])\n", "delta = tf.Variable(0.1)\n", "\n", "with tf.GradientTape() as tape:\n", " new_image = tf.image.adjust_contrast(image, delta)\n", "\n", "try:\n", " print(tape.gradient(new_image, [image, delta]))\n", " assert False # This should not happen.\n", "except LookupError as e:\n", " print(f'{type(e).__name__}: {e}')\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pDoutjzATiEm" }, "source": [ "If you need to differentiate through this op, you'll either need to implement the gradient and register it (using `tf.RegisterGradient`) or re-implement the function using other ops." ] }, { "cell_type": "markdown", "metadata": { "id": "GCTwc_dQXp2W" }, "source": [ "## Zeros instead of None" ] }, { "cell_type": "markdown", "metadata": { "id": "TYDrVogA89eA" }, "source": [ "In some cases it would be convenient to get 0 instead of `None` for unconnected gradients. 
You can decide what to return when you have unconnected gradients using the `unconnected_gradients` argument:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2024-08-15T01:30:13.581859Z", "iopub.status.busy": "2024-08-15T01:30:13.581644Z", "iopub.status.idle": "2024-08-15T01:30:13.642074Z", "shell.execute_reply": "2024-08-15T01:30:13.641433Z" }, "id": "U6zxk1sf9Ixx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([0. 0.], shape=(2,), dtype=float32)\n" ] } ], "source": [ "x = tf.Variable([2., 2.])\n", "y = tf.Variable(3.)\n", "\n", "with tf.GradientTape() as tape:\n", " z = y**2\n", "print(tape.gradient(z, x, unconnected_gradients=tf.UnconnectedGradients.ZERO))" ] } ], "metadata": { "colab": { "collapsed_sections": [ "Tce3stUlHN0L" ], "name": "autodiff.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }