Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/custom_gradient.py: 18%
215 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Decorator to overrides the gradient for a function."""

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.framework import composite_tensor_gradient
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export


VAR_OP_TYPES = [
    "VariableV2",
    "VarHandleOp",
]


@tf_export("custom_gradient")
def custom_gradient(f=None):
47 """Decorator to define a function with a custom gradient.
49 This decorator allows fine grained control over the gradients of a sequence
50 for operations. This may be useful for multiple reasons, including providing
51 a more efficient or numerically stable gradient for a sequence of operations.
53 For example, consider the following function that commonly occurs in the
54 computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN. For example:

  ```python
  x = tf.constant(100.)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = log1pexp(x)
  dy_dx = tape.gradient(y, x)  # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
      return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient `dy_dx` at `x = 100` will be correctly
  evaluated as 1.0.

  The variable `upstream` is defined as the upstream gradient, i.e. the gradient
  from all the layers or functions originating from this layer. The above
  example has no upstream functions, therefore `upstream = dy/dy = 1.0`.

  Assume that `x_i` is `log1pexp` in the forward pass `x_1 = x_1(x_0)`,
  `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`. By the
  chain rule we know that `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... *
  dx_i/dx_i-1 * ... * dx_1/dx_0`.

  In this case the gradient of our current function is defined as
  `dx_i/dx_i-1 = (1 - 1 / (1 + e))`. The upstream gradient `upstream` would be
  `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient
  multiplied by the current gradient is then passed downstream.

  In case the function takes multiple inputs, the `grad` function must also
  return the same number of gradients, one per input. We take the function
  `z = x * y` as an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` computes the first-order gradient of `grad_fn`
  with respect to `dy`, which is used to generate forward-mode gradient graphs
  from backward-mode gradient graphs, but is not the same as the second-order
  gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient`s in another function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  The examples above illustrate how to specify custom gradients for functions
  which do not read from variables. The following example uses variables, which
  require special handling because they are effectively inputs of the forward
  function.

  >>> weights = tf.Variable(tf.ones([2]))  # Trainable variable weights
  >>> @tf.custom_gradient
  ... def linear_poly(x):
  ...   # Creating polynomial
  ...   poly = weights[1] * x + weights[0]
  ...
  ...   def grad_fn(dpoly, variables):
  ...     # dy/dx = weights[1] and we need to left multiply dpoly
  ...     grad_xs = dpoly * weights[1]  # Scalar gradient
  ...
  ...     grad_vars = []  # To store gradients of passed variables
  ...     assert variables is not None
  ...     assert len(variables) == 1
  ...     assert variables[0] is weights
  ...     # Manually computing dy/dweights
  ...     dy_dw = dpoly * tf.stack([x ** 1, x ** 0])
  ...     grad_vars.append(
  ...         tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1)
  ...     )
  ...     return grad_xs, grad_vars
  ...   return poly, grad_fn
  >>> x = tf.constant([1., 2., 3.])
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   poly = linear_poly(x)
  >>> poly  # poly = x + 1
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([2., 3., 4.], dtype=float32)>
  >>> tape.gradient(poly, x)  # conventional scalar gradient dy/dx
  <tf.Tensor: shape=(3,),
    dtype=float32,
    numpy=array([1., 1., 1.], dtype=float32)>
  >>> tape.gradient(poly, weights)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)>

  The above example illustrates the usage of the trainable variable `weights`.
  In the example, the inner `grad_fn` accepts an extra `variables` input
  parameter and also returns an extra `grad_vars` output. That extra argument
  is passed if the forward function reads any variables. You need to
  compute the gradient w.r.t. each of those `variables` and output it as a list
  of `grad_vars`. Note here that the default value of `variables` is set to
  `None` when no variables are used in the forward function.

  It should be noted that `tf.GradientTape` is still watching the forward pass
  of a `tf.custom_gradient`, and will use the ops it watches. As a consequence,
  calling `tf.function` while the tape is still watching leads
  to a gradient graph being built. If an op is used in `tf.function` without a
  registered gradient, a `LookupError` will be raised.

  Users can insert `tf.stop_gradient` to customize this behavior. This
  is demonstrated in the example below. `tf.random.shuffle` does not have a
  registered gradient. As a result `tf.stop_gradient` is used to avoid the
  `LookupError`.

  ```python
  x = tf.constant([0.3, 0.5], dtype=tf.float32)

  @tf.custom_gradient
  def test_func_with_stop_grad(x):
    @tf.function
    def _inner_func():
      # Avoid exception during the forward pass
      return tf.stop_gradient(tf.random.shuffle(x))
      # return tf.random.shuffle(x)  # This will raise

    res = _inner_func()
    def grad(upstream):
      return upstream  # Arbitrarily defined custom gradient
    return res, grad

  with tf.GradientTape() as g:
    g.watch(x)
    res = test_func_with_stop_grad(x)

  g.gradient(res, x)
  ```

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine-grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using
  [ResourceVariables](https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables).

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

       In a pure mathematical sense, a vector-argument vector-valued function
       `f`'s derivatives should be its Jacobian matrix `J`. Here we are
       expressing the Jacobian `J` as a function `grad_fn` which defines how
       `J` will transform a vector `grad_ys` when left-multiplied with it
       (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
       representation of a matrix is convenient to use for chain-rule
       calculation (in e.g. the back-propagation algorithm).
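
       As a minimal illustration of this convention, for a scalar function
       `y = x * x` the Jacobian reduces to `2 * x`, so a `grad_fn` expressing
       the VJP could simply be:

       ```python
       @tf.custom_gradient
       def square(x):
         def grad_fn(grad_ys):
           # VJP: `grad_ys` left-multiplied by the Jacobian `J = 2 * x`.
           return grad_ys * 2. * x
         return x * x, grad_fn
       ```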

       If `f` uses `Variable`s (that are not part of the
       inputs), i.e. through `get_variable`, then `grad_fn` should have
       signature `g(*grad_ys, variables=None)`, where `variables` is a list of
       the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
       `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
       with the derivatives of `Tensor`s in `y` with respect to the variables
       (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """
  if f is None:
    return lambda f: custom_gradient(f=f)

  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter


class Bind:
  """When called evaluates `d(f, args, kwargs)` but supports binding `f`.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo:
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    return lambda f: Bind(f, d)

  def __init__(self, f, d):
    self._f = f
    self._d = d

  def __get__(self, instance, owner):
    if instance is not None:
      f = self._f.__get__(instance, owner)
      return tf_decorator.make_decorator(f, Bind(f, self._d))
    else:
      return self

  def __call__(self, *a, **k):
    return self._d(self._f, a, k)


def get_variable_by_name(var_name):
  """Given a variable name, retrieves a handle on the tensorflow Variable."""
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def _filter_fn(item):
    try:
      return var_name == item.op.name
    except AttributeError:
      # Collection items without operation are ignored.
      return False

  candidate_vars = list(filter(_filter_fn, global_vars))

  if len(candidate_vars) >= 1:
    # Filter out non-trainable variables.
    candidate_vars = [v for v in candidate_vars if v.trainable]
  else:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  if len(candidate_vars) == 1:
    return candidate_vars[0]
  elif len(candidate_vars) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars))
  else:
    # The variable is not trainable.
    return None


def _get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph between input_ops and output_ops.

  Args:
    input_ops: Flattened list of input ops
    output_ops: Flattened list of output ops

  Returns:
    A list of variables
  """

  # avoids the edge-case when input_ops == output_ops.
  output_ops = nest.map_structure(gen_array_ops.identity, output_ops)
  inbetween_ops = op_selector.get_backward_walk_ops(
      seed_ops=output_ops,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)
  var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
  var_names = (op.name for op in var_ops)
  tf_vars = (get_variable_by_name(var_name) for var_name in var_names)
  tf_vars = [v for v in tf_vars if v is not None]
  return tf_vars


def generate_name():
  return "CustomGradient-%s" % ops.uid()


def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = generate_name()
  args = variable_utils.convert_variables_to_tensors(args)
  args = nest.map_structure(ops.convert_to_tensor, args, expand_composites=True)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
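  # Run the forward function while watching which variables it reads; these
  # are candidates for the `variables` argument passed to `grad_fn` below.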
  with record.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients(
      nest.flatten(args))
  flat_result = composite_tensor_gradient.get_flat_tensors_for_gradients(
      nest.flatten(result))
  flat_result_len = len(flat_result)

  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ])

  graphs = {getattr(o, "graph", None) for o in flat_result}
  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_input_tensors = []
    for i in flat_args:
      if i.graph == output_graph:
        filtered_input_tensors.append(i)
  else:
    filtered_input_tensors = flat_args

  variables_in_subgraph = frozenset([
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result)
  ])
  variables = sorted(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
      key=lambda v: v.name)

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.vlog(
        1, "@custom_gradient grad_fn has 'variables' in signature, "
        "but no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + flat_args + variables

  def tape_grad_fn(*result_grad_components):
    """Custom grad fn wrapper."""
    result_grads = composite_tensor_gradient.replace_flat_tensors_for_gradients(
        nest.flatten(result), result_grad_components[:flat_result_len])
    if not isinstance(result_grads, (list, tuple)):
      result_grads = [result_grads]

    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = composite_tensor_gradient.get_flat_tensors_for_gradients(
        nest.flatten(input_grads))
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)
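
  # Route all outputs, inputs, and captured variables through one IdentityN op
  # whose gradient is overridden with the function registered above under
  # `name`, so that `tape_grad_fn` runs when gradients are requested.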
  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  record.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    handle_data_util.copy_handle_data(ot, t)
  flat_result = composite_tensor_gradient.replace_flat_tensors_for_gradients(
      nest.flatten(result), all_tensors[:flat_result_len])
  return nest.pack_sequence_as(result, flat_result)


def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode."""
  with record.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients(
      nest.flatten(args))
  flat_kwargs = composite_tensor_gradient.get_flat_tensors_for_gradients(
      nest.flatten(kwargs))
  all_inputs = flat_args + flat_kwargs
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(v.ref() for v in variable_watcher.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = composite_tensor_gradient.get_flat_tensors_for_gradients(
      nest.flatten(result))
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]
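
  # Record both the explicit inputs and any captured variables as inputs of
  # the recorded operation so the tape can route gradients to each of them.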
  input_tensors = [
      ops.convert_to_tensor(x) for x in flat_args + list(variables)]

  recorded_inputs = input_tensors
  arg_count = len(flat_args)

  def actual_grad_fn(*result_grad_components):
    """Custom grad fn wrapper."""
    result_grads = composite_tensor_gradient.replace_flat_tensors_for_gradients(
        nest.flatten(result), result_grad_components)
    if not isinstance(result_grads, (list, tuple)):
      result_grads = [result_grads]

    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = composite_tensor_gradient.get_flat_tensors_for_gradients(
        nest.flatten(input_grads))
    if len(flat_grads) != arg_count:
      raise ValueError(
          f"custom_gradient function expected to return {arg_count} "
          f"gradients, but returned {len(flat_grads)} instead.")
    return flat_grads + variable_grads

  record.record_operation(f.__name__, flat_result, recorded_inputs,
                          actual_grad_fn)
  flat_result = composite_tensor_gradient.replace_flat_tensors_for_gradients(
      nest.flatten(result), flat_result)
  return nest.pack_sequence_as(result, flat_result)


@tf_export("recompute_grad")
def recompute_grad(f):
  """Defines a function as a recompute-checkpoint for the tape auto-diff.

  Tape checkpointing is a technique to reduce the memory consumption of the
  auto-diff tape:

  - Without tape checkpointing, operations and intermediate values are
    recorded to the tape for use in the backward pass.

  - With tape checkpointing, only the function call and its inputs are
    recorded. During back-propagation the `recompute_grad` custom gradient
    (`tf.custom_gradient`) recomputes the function under a localized Tape object.
    This recomputation of the function during backpropagation performs redundant
    calculation, but reduces the overall memory usage of the Tape.

  >>> y = tf.Variable(1.0)

  >>> def my_function(x):
  ...   tf.print('running')
  ...   z = x*y
  ...   return z

  >>> my_function_recompute = tf.recompute_grad(my_function)

  >>> with tf.GradientTape() as tape:
  ...   r = tf.constant(1.0)
  ...   for i in range(4):
  ...     r = my_function_recompute(r)
  running
  running
  running
  running

  >>> grad = tape.gradient(r, [y])
  running
  running
  running
  running

  Without `recompute_grad`, the tape contains all intermediate steps, and no
  recomputation is performed.

  >>> with tf.GradientTape() as tape:
  ...   r = tf.constant(1.0)
  ...   for i in range(4):
  ...     r = my_function(r)
  running
  running
  running
  running

  >>> grad = tape.gradient(r, [y])

  If `f` was a `tf.keras` `Model` or `Layer` object, methods and attributes
  such as `f.variables` are not available on the returned function `g`.
  Either keep a reference to `f`, or use `g.__wrapped__` for accessing
  these variables and methods.

  >>> def print_running_and_return(x):
  ...   tf.print("running")
  ...   return x

  >>> model = tf.keras.Sequential([
  ...   tf.keras.layers.Lambda(print_running_and_return),
  ...   tf.keras.layers.Dense(2)
  ... ])

  >>> model_recompute = tf.recompute_grad(model)

  >>> with tf.GradientTape(persistent=True) as tape:
  ...   r = tf.constant([[1,2]])
  ...   for i in range(4):
  ...     r = model_recompute(r)
  running
  running
  running
  running

  >>> grad = tape.gradient(r, model.variables)
  running
  running
  running
  running

  Alternatively, use the `__wrapped__` attribute to access the original
  model object.

  >>> grad = tape.gradient(r, model_recompute.__wrapped__.variables)
  running
  running
  running
  running

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
    A function `g` wrapping `f` that defines a custom gradient, which recomputes
    `f` on the backwards pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()
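    # Run the forward pass without recording it on the tape; the gradient
    # function below recomputes `f` under a local GradientTape instead.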
    with record.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, variables=None):
      """Wrapper function to accommodate lack of kwargs in graph mode custom_gradient."""

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
        # Gradient calculation for reverse mode autodiff.
        with backprop.GradientTape() as t:
          id_args = nest.map_structure(gen_array_ops.identity, args)
          # Tuple `dresult` should contain at least one tensor.
          assert len(dresult) >= 1

          if not context.executing_eagerly():
            # XLA doesn't respect `tf.control_dependencies`. The code block
            # below manually adds a data dependency to `dresult` to ensure
            # recomputation of `f(*args, **kwargs)` happens after `dresult`.

            # This works even if `dresult[0]` is a size 0 tensor as reduce_max
            # of a size 0 tensor returns -inf. Use reshape here to avoid reading
            # the entire `dresult[0]`.
            elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1])
            # Cast elem to bool in case elem is NaN.
            elem_bool = math_ops.cast(elem, dtypes.bool)
            dresult_dep = array_ops.where_v2(
                elem_bool == elem_bool, 0., float("nan"))  # pylint: disable=comparison-with-itself
            id_args = nest.map_structure(
                lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args)

          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with variable_scope.variable_scope(current_var_scope):
            recomputed_result = f(*id_args, **kwargs)
        kw_vars = []
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            recomputed_result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on
          # tape for recompute.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode "
              "autodiff".format(f.__name__))
        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper

  return tf_decorator.make_decorator(f, inner)


@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ...  # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()
  with backprop.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x)  # grads will evaluate to 1.0
  ```

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """
  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):
    def grad(*args, **kwargs):
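      # Pass the upstream gradients through unchanged, as if `f` were an
      # identity op.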
      variables = kwargs.get("variables")
      if variables is not None:
        # Variables involved in the wrapped op will not receive gradients.
        return args, [None] * len(variables)
      return args
    return f(*args, **kwargs), grad
  return tf_decorator.make_decorator(f, _grad_pass_through_op)