Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gradients_util.py: 15%
459 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements the graph generation for computation of gradients."""
17import collections
18import contextlib
20from tensorflow.core.framework import attr_value_pb2
21from tensorflow.python import pywrap_tfe
22from tensorflow.python.eager import backprop_util
23from tensorflow.python.eager import context
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import composite_tensor_gradient
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import indexed_slices
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import control_flow_state
33from tensorflow.python.ops import control_flow_util
34from tensorflow.python.ops import default_gradient
35from tensorflow.python.ops import gen_functional_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import resource_variable_ops
38from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.util import compat
41from tensorflow.python.util import object_identity
42from tensorflow.python.util import variable_utils
43from tensorflow.python.util.compat import collections_abc
44from tensorflow.python.util.tf_export import tf_export
47def _MarkReachedOps(from_ops, reached_ops, func_graphs):
48 """Mark all ops reached from "from_ops".
50 Args:
51 from_ops: list of Operations.
52 reached_ops: set of Operations.
53 func_graphs: list of FuncGraphs. This method will traverse through
54 these functions if they capture from_ops or any reachable ops.
55 """
56 queue = collections.deque()
57 queue.extend(from_ops)
58 while queue:
59 op = queue.popleft()
60 if op not in reached_ops:
61 reached_ops.add(op)
62 for output in op.outputs:
63 if backprop_util.IsTrainable(output):
64 queue.extend(_Consumers(output, func_graphs))
67def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
68 xs_set):
69 """Initialize the pending count for ops between two lists of Operations.
71 'pending_count[op]' indicates the number of backprop inputs
72 to this operation.
74 Args:
75 to_ops: list of Operations.
76 from_ops: list of Operations.
77 colocate_gradients_with_ops: Python bool. See docstring of gradients().
78 func_graphs: list of FuncGraphs. This method will traverse through
79 these functions if they capture from_ops or any reachable ops. This is
80 useful if to_ops occur in a function and from_ops are in an outer function
81 or graph.
82 xs_set: ObjectIdentitySet of Tensors.
84 Returns:
85 A tuple containing: (1) the subset of to_ops reachable from from_ops by a
86 path of zero or more backpropagatable tensors, (2) a mapping from operation
87 to the number of backprop inputs to that op, and (3) a ControlFlowState
88 object which is not None if the ops between from_ops and to_ops contain
89 control flow loops.
90 """
91 # Mark reachable ops from from_ops.
92 reached_ops = set()
93 _MarkReachedOps(from_ops, reached_ops, func_graphs)
94 # X in reached_ops iff X is reachable from from_ops by a path of zero or more
95 # backpropagatable tensors.
97 reachable_to_ops = set(op for op in to_ops if op in reached_ops)
99 # Mark between ops.
100 between_ops = set()
101 between_op_list = []
102 queue = collections.deque()
103 queue.extend(to_ops)
104 while queue:
105 op = queue.popleft()
106 # We are interested in this op.
107 if op in reached_ops:
108 between_ops.add(op)
109 between_op_list.append(op)
110 # Clear the boolean so we won't add the inputs again.
111 reached_ops.remove(op)
112 for inp in _NonEagerInputs(op, xs_set):
113 queue.append(inp.op)
114 # X in between_ops iff X is on a path of zero or more backpropagatable tensors
115 # between from_ops and to_ops
117 # 'loop_state' is None if there are no while loops.
118 loop_state = control_flow_state.MaybeCreateControlFlowState(
119 between_op_list, between_ops, colocate_gradients_with_ops)
121 # Initialize pending count for between ops.
122 pending_count = collections.defaultdict(int)
123 for op in between_op_list:
124 for x in _NonEagerInputs(op, xs_set):
125 if x.op in between_ops:
126 pending_count[x.op] += 1
128 return reachable_to_ops, pending_count, loop_state
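# NOTE (editorial, illustrative only; not part of the TensorFlow module): a
# small sketch of what the pending counts look like, assuming a toy graph
# y = x * 2.0 + x * 3.0 built in graph mode. The op producing x feeds both
# multiplies, each multiply feeds the add, and the add produces y, so the
# mapping returned above gives pending_count[x.op] == 2, pending_count == 1
# for each multiply op, and 0 for the add op; an op's gradient is only
# processed once gradients have arrived from all of its pending consumers.
#
#   import tensorflow.compat.v1 as tf1
#   from tensorflow.python.ops import gradients_util
#   from tensorflow.python.util import object_identity
#
#   tf1.disable_eager_execution()
#   x = tf1.constant(1.0)
#   y = x * 2.0 + x * 3.0
#   _, pending, _ = gradients_util._PendingCount(
#       [y.op], [x.op], False, [], object_identity.ObjectIdentitySet([x]))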
131def _AsList(x):
132 return x if isinstance(x, (list, tuple)) else [x]
135def _DefaultGradYs(grad_ys,
136 ys,
137 colocate_gradients_with_ops,
138 gradient_uid="__unsupported__"):
139 """Fill in default values for grad_ys.
141 Args:
142 grad_ys: List of gradients, can contain None.
143 ys: List of tensors.
144 colocate_gradients_with_ops: If True, try colocating gradients with
145 the corresponding op.
146 gradient_uid: A unique identifier within the graph indicating
147 which invocation of gradients is being executed. Used to cluster
148 ops for compilation.
150 Returns:
151 A list of gradients to use, without None.
153 Raises:
154 ValueError: If sizes of gradients and inputs don't match
155 TypeError: If type of any gradient is not valid for its input.
156 """
157 if len(grad_ys) != len(ys):
158 raise ValueError(f"Length mismatch. Passed {len(grad_ys)} grad_ys for "
159 f"{len(ys)} ys")
160 grad_ys = indexed_slices.convert_n_to_tensor_or_indexed_slices(
161 grad_ys, name="grad_y")
162 new_grad_ys = []
163 for i, (y, grad_y) in enumerate(zip(ys, grad_ys)):
164 with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
165 if grad_y is None:
166 if y.dtype.is_complex:
167 raise TypeError(
168 f"Gradients of complex tensors ({y}) must set grad_ys (y.dtype = "
169 f"{dtypes.as_dtype(y.dtype).name})")
170 new_grad_ys.append(
171 array_ops.ones(
172 array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i))
173 continue
174 if y.dtype.is_floating or y.dtype.is_integer:
175 if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
176 raise TypeError(
177 f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
178 f"for real or integer-valued tensor {y} with type "
179 f"{dtypes.as_dtype(y.dtype).name} must be real or integer")
180 elif y.dtype.is_complex:
181 if not grad_y.dtype.is_complex:
182 raise TypeError(
183 f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
184 f"for complex-valued tensor {y} with type "
185 f"{dtypes.as_dtype(y.dtype).name} must be real")
186 elif y.dtype == dtypes.variant:
187 if grad_y.dtype != dtypes.variant:
188 raise TypeError(
189 f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated "
190 f"for variant tensor {y} with type "
191 f"{dtypes.as_dtype(y.dtype).name} must be variant")
192 elif y.dtype == dtypes.resource:
193 # We assume y is the handle of a ResourceVariable. The gradient of a
194 # ResourceVariable should be a numeric value, not another resource.
195 if grad_y.dtype == dtypes.resource:
196 raise TypeError(f"Input gradient {grad_y} for resource tensor {y} "
197 "should not be a resource")
198 else:
199 raise TypeError(
200 f"Tensor {y} with type {dtypes.as_dtype(y.dtype).name} must be "
201 "numeric to obtain a default gradient")
202 # Create a grad_y tensor in the name scope of the gradient.
203 # Required for TensorArrays to identify which gradient call a
204 # grad_y value is coming from.
205 if isinstance(grad_y, indexed_slices.IndexedSlices):
206 new_grad_ys.append(
207 indexed_slices.IndexedSlices(
208 indices=(array_ops.identity(
209 grad_y.indices, name="grad_ys_%d_indices" % i)
210 if isinstance(grad_y.indices, ops.Tensor) else
211 grad_y.indices),
212 values=(array_ops.identity(
213 grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
214 grad_y.values, ops.Tensor) else grad_y.values),
215 dense_shape=(array_ops.identity(
216 grad_y.dense_shape, name="grad_ys_%d_shape" % i)
217 if isinstance(grad_y.dense_shape, ops.Tensor) else
218 grad_y.dense_shape)))
219 else:
220 new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
222 return new_grad_ys
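# NOTE (editorial, illustrative only; not part of the TensorFlow module): the
# practical effect of the default above is that tf.gradients() seeds the
# backward pass with dL/dy == 1 for every y when grad_ys is omitted. Under
# that assumption, the two graph-mode calls sketched below should build
# equivalent gradient graphs:
#
#   import tensorflow.compat.v1 as tf1
#
#   tf1.disable_eager_execution()
#   x = tf1.placeholder(tf1.float32, shape=[3])
#   y = 2.0 * x
#   g_default = tf1.gradients(y, x)                           # grad_ys -> ones
#   g_explicit = tf1.gradients(y, x, grad_ys=tf1.ones_like(y))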
225def _VerifyGeneratedGradients(grads, op):
226 """Verify that gradients are valid in number and type.
228 Args:
229 grads: List of generated gradients.
230 op: Operation for which the gradients were generated.
232 Raises:
233 ValueError: if sizes of gradients and inputs don't match.
234 TypeError: if type of any gradient is not valid for its input.
235 """
236 # While ops have inputs added to them during the gradient computation, so we
237 # skip the below check. See while_v2 for details.
238 if op.type == "While" or op.type == "StatelessWhile":
239 return
241 if len(grads) != len(op.inputs):
242 raise ValueError(f"Num gradients {len(grads)} generated for op "
243 f"{op.node_def} do not match num inputs {len(op.inputs)}")
246def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
247 """The set of ops that terminate the gradient computation.
249 This computes the frontier of the forward graph *before* which backprop
250 should stop. Operations in the returned set will not be differentiated.
251 This set is defined as the subset of `from_ops` containing ops that have
252 no predecessor in `from_ops`. `pending_count` is the pending-count mapping
253 returned by `_PendingCount()`. An 'op' has predecessors in `from_ops`
254 iff pending_count[op] > 0.
256 In addition, none of `stop_gradient_ops` will be differentiated.
258 Args:
259 from_ops: list of Operations.
260 stop_gradient_ops: list of Operations never to backprop through.
261 pending_count: mapping from operation to number of backprop inputs.
262 xs_set: ObjectIdentitySet of Tensors.
264 Returns:
265 The set of operations.
266 """
267 stop_ops = set()
268 for op in from_ops:
269 is_stop_op = True
270 for inp in _NonEagerInputs(op, xs_set):
271 if pending_count[inp.op] > 0:
272 is_stop_op = False
273 break
274 if is_stop_op:
275 stop_ops.add(op)
276 stop_ops.update(op for op in stop_gradient_ops)
277 return stop_ops
280@contextlib.contextmanager
281def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name
282 """Context to colocate with `op` if `colocate_gradients_with_ops`."""
283 if colocate_gradients_with_ops:
284 with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access
285 yield
286 else:
287 yield
290def _IsPartitionedCall(op):
291 return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
294def _SymGrad(op, out_grads):
295 """Backprop through a function call node op given its outputs' gradients."""
296 f_in = [x for x in op.inputs] + out_grads
297 f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
298 f = attr_value_pb2.NameAttrList()
299 if _IsPartitionedCall(op):
300 f.name = op.get_attr("f").name
301 else:
302 f.name = op.type
303 for k in op.node_def.attr:
304 f.attr[k].CopyFrom(op.node_def.attr[k])
305 in_grads = gen_functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
306 return in_grads
309def _MaybeCompile(scope, op, func, grad_fn):
310 """Compile the calculation in grad_fn if op was marked as compiled."""
311 scope = scope.rstrip("/").replace("/", "_")
312 if func is not None:
313 xla_compile = func.cached_definition.attr["_XlaCompile"].b
314 xla_separate_compiled_gradients = func.cached_definition.attr[
315 "_XlaSeparateCompiledGradients"].b
316 xla_scope = func.cached_definition.attr["_XlaScope"].s.decode()
317 else:
318 try:
319 xla_compile = op.get_attr("_XlaCompile")
320 xla_separate_compiled_gradients = op.get_attr(
321 "_XlaSeparateCompiledGradients")
322 xla_scope = op.get_attr("_XlaScope").decode()
323 except ValueError:
324 xla_compile = False
326 if not xla_compile:
327 return grad_fn() # Exit early
329 # If the gradients are supposed to be compiled separately, we give them a
330 # _XlaScope name that is based on the name_scope of the gradients. Otherwise
331 # they just inherit the existing _XlaScope name, which lets them be merged
332 # together with the non-gradient computation.
333 if xla_separate_compiled_gradients:
334 xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
335 else:
336 xla_grad_scope = xla_scope
338 attrs = {
339 "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
340 "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
341 }
342 with ops.get_default_graph()._attr_scope(attrs): # pylint: disable=protected-access
343 return grad_fn()
346def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
347 """Raises an error if we backprop through a loop var."""
348 # Find the nearest op in 'from_ops' reachable from 'op' to provide a more
349 # helpful error message.
350 target_op = None
351 queue = collections.deque([op])
352 visited = set()
353 while queue:
354 curr_op = queue.popleft()
355 if curr_op in visited: continue
356 visited.add(curr_op)
357 if curr_op in from_ops:
358 target_op = curr_op
359 break
360 queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
361 assert target_op
362 raise ValueError(
363 "Cannot compute gradient inside while loop with respect to op "
364 f"'{target_op.name}'. We do not support taking the gradient wrt or "
365 "through the initial value of a loop variable. Gradients can be computed "
366 "through loop invariants or wrt the input parameters to the loop body.")
369def _IsFunction(graph):
370 # isinstance check for FuncGraphs that avoids the explicit dependency
371 # on func_graph.py and function.py
372 return isinstance(graph, ops.Graph) and graph._building_function # pylint: disable=protected-access
375def _Captures(func_graph):
376 assert _IsFunction(func_graph)
377 return func_graph.captures
380def _MaybeCaptured(t):
381 """If t is a captured value placeholder, returns the original captured value.
383 Args:
384 t: Tensor
386 Returns:
387 A tensor, potentially from a different Graph/FuncGraph.
388 """
389 # pylint: disable=protected-access
390 if (not isinstance(t, ops.EagerTensor) and
391 _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
392 for input_t, placeholder_t in _Captures(t.op.graph):
393 if t is placeholder_t:
394 return _MaybeCaptured(input_t)
395 # pylint: enable=protected-access
396 return t
399def _NonEagerInputs(op, xs_set):
400 """Returns the inputs of op, crossing closure boundaries where necessary.
402 Does not return any captured EagerTensors, i.e., the number of tensors
403 returned may be less than the actual number of inputs.
405 Args:
406 op: Operation
407 xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
409 Returns:
410 A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
411 is in a FuncGraph and has captured inputs.
412 """
413 return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
416# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
417# _GradientsHelper a class with xs as a member variable.
418def _Inputs(op, xs_set):
419 """Returns the inputs of op, crossing closure boundaries where necessary.
421 Args:
422 op: Operation
423 xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
425 Returns:
426 A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
427 is in a FuncGraph and has captured inputs.
428 """
429 if _IsFunction(op.graph): # pylint: disable=protected-access
430 inputs = []
431 for t in op.inputs:
432 # If we're differentiating w.r.t. `t`, do not attempt to traverse through
433 # it to a captured value. The algorithm needs to "see" `t` in this case,
434 # even if it's a function input for a captured value, whereas usually we'd
435 # like to traverse through these closures as if the captured value was the
436 # direct input to op.
437 if t not in xs_set:
438 t = _MaybeCaptured(t)
439 inputs.append(t)
440 return inputs
441 else:
442 return op.inputs
445def _Consumers(t, func_graphs):
446 """Returns the consumers of t, crossing closure boundaries where necessary.
448 Args:
449 t: Tensor
450 func_graphs: a list of FuncGraphs that may have captured t.
452 Returns:
453 A list of Operations that consume t. The operations may be from the
454 current graph and/or func_graphs.
455 """
456 consumers = t.consumers()
457 for func in func_graphs:
458 for input_t, placeholder in _Captures(func):
459 if input_t is t:
460 consumers.extend(_Consumers(placeholder, func_graphs))
461 return consumers
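# NOTE (editorial, illustrative only; not part of the TensorFlow module): the
# helpers above (_MaybeCaptured, _Inputs/_NonEagerInputs, _Consumers) let the
# traversal "see through" function-capture boundaries: if an outer-graph
# tensor t is captured by a FuncGraph F, it appears inside F as a Placeholder
# p; _MaybeCaptured(p) maps back to t, while _Consumers(t, [F]) additionally
# reports the ops inside F that consume p. A hypothetical probe of that
# behaviour, assuming tf.function capture of a graph tensor:
#
#   import tensorflow as tf
#   from tensorflow.python.ops import gradients_util
#
#   with tf.Graph().as_default():
#     t = tf.constant(1.0)
#
#     @tf.function
#     def f():
#       return t * 2.0          # t is captured by f's FuncGraph
#
#     fg = f.get_concrete_function().graph
#     consumers = gradients_util._Consumers(t, [fg])  # includes the Mul in fg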
464def _GradientsHelper(ys,
465 xs,
466 grad_ys=None,
467 name="gradients",
468 colocate_gradients_with_ops=False,
469 gate_gradients=False,
470 aggregation_method=None,
471 stop_gradients=None,
472 unconnected_gradients=UnconnectedGradients.NONE,
473 src_graph=None):
474 """Implementation of gradients()."""
475 if context.executing_eagerly():
476 raise RuntimeError("tf.gradients is not supported when eager execution "
477 "is enabled. Use tf.GradientTape instead.")
478 ys = variable_utils.convert_variables_to_tensors(_AsList(ys))
479 xs = [
480 x.handle if resource_variable_ops.is_resource_variable(x) else x
481 for x in _AsList(xs)
482 ]
483 if grad_ys is not None:
484 grad_ys = _AsList(grad_ys)
486 # Handle CompositeTensors.
487 if (any(isinstance(x, composite_tensor.CompositeTensor) for x in xs) or
488 any(isinstance(y, composite_tensor.CompositeTensor) for y in ys)):
489 flat_xs = composite_tensor_gradient.get_flat_tensors_for_gradients(xs)
490 flat_ys = composite_tensor_gradient.get_flat_tensors_for_gradients(ys)
491 flat_grad_ys = (
492 None if grad_ys is None else
493 composite_tensor_gradient.get_flat_tensors_for_gradients(grad_ys))
494 flat_grads = _GradientsHelper(flat_ys, flat_xs, flat_grad_ys, name,
495 colocate_gradients_with_ops, gate_gradients,
496 aggregation_method, stop_gradients,
497 unconnected_gradients, src_graph)
498 return composite_tensor_gradient.replace_flat_tensors_for_gradients(
499 xs, flat_grads)
501 if src_graph is None:
502 src_graph = ops.get_default_graph()
503 try:
504 unconnected_gradients = UnconnectedGradients(unconnected_gradients)
505 except ValueError:
506 raise ValueError(
507 f"Unknown value for unconnected_gradients: '{unconnected_gradients}'")
509 # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
510 # ancestor graphs. This is necessary for correctly handling captured values.
511 func_graphs = []
512 curr_graph = src_graph
513 while _IsFunction(curr_graph):
514 func_graphs.append(curr_graph)
515 curr_graph = curr_graph.outer_graph
517 stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
518 if grad_ys is None:
519 grad_ys = [None] * len(ys)
521 with ops.name_scope(
522 name, "gradients",
523 list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
524 # Get a uid for this call to gradients that can be used to help
525 # cluster ops for compilation.
526 gradient_uid = ops.get_default_graph().unique_name("uid")
527 ys = indexed_slices.convert_n_to_tensor_or_indexed_slices(ys, name="y")
528 xs = indexed_slices.internal_convert_n_to_tensor_or_indexed_slices(
529 xs, name="x", as_ref=True)
530 xs_set = object_identity.ObjectIdentitySet(xs)
531 grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
532 gradient_uid)
534 # The approach we take here is as follows: Create a list of all ops in the
535 # subgraph between the ys and xs. Visit these ops in reverse order of ids
536 # to ensure that when we visit an op the gradients w.r.t its outputs have
537 # been collected. Then aggregate these gradients if needed, call the op's
538 # gradient function, and add the generated gradients to the gradients for
539 # its input.
541 # Initialize the pending count for ops in the connected subgraph from ys
542 # to the xs.
543 to_ops = [t.op for t in ys]
544 from_ops = [t.op for t in xs]
545 stop_gradient_ops = [t.op for t in stop_gradients]
546 reachable_to_ops, pending_count, loop_state = _PendingCount(
547 to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
549 # Iterate over the collected ops.
550 #
551 # grads: op => list of gradients received on each output endpoint of the
552 # op. The gradients for each endpoint are initially collected as a list.
553 # When it is time to call the op's gradient function, for each endpoint we
554 # aggregate the list of received gradients into an Add() Operation if there
555 # is more than one.
556 grads = {}
558 # Add the initial gradients for the ys.
559 for y, grad_y in zip(ys, grad_ys):
560 _SetGrad(grads, y, grad_y)
562 # Initialize queue with to_ops.
563 queue = collections.deque()
564 # Add the ops in 'to_ops' into the queue.
565 to_ops_set = set()
566 for op in to_ops:
567 # 'ready' handles the case where one output gradient relies on
568 # another output's gradient.
569 ready = (pending_count[op] == 0)
570 if ready and op not in to_ops_set and op in reachable_to_ops:
571 to_ops_set.add(op)
572 queue.append(op)
574 if loop_state:
575 loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
576 for y in loop_exits:
577 if backprop_util.IsTrainable(y):
578 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
579 queue.append(y.op)
581 stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
582 while queue:
583 # generate gradient subgraph for op.
584 op = queue.popleft()
585 with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
586 if loop_state:
587 loop_state.EnterGradWhileContext(op, before=True)
588 out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
589 aggregation_method)
590 if loop_state:
591 loop_state.ExitGradWhileContext(op, before=True)
593 grad_fn = None
594 func_call = None
595 is_partitioned_call = _IsPartitionedCall(op)
596 # pylint: disable=protected-access
597 is_func_call = (
598 src_graph._is_function(op.type) or is_partitioned_call)
599 # pylint: enable=protected-access
600 has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
601 if has_out_grads and (op not in stop_ops):
602 try:
603 grad_fn = ops.get_gradient_function(op)
604 except LookupError:
605 if is_func_call:
606 if is_partitioned_call:
607 func_name = compat.as_bytes(op.get_attr("f").name)
608 func_call = src_graph._get_function( # pylint: disable=protected-access
609 func_name)
610 # When a graph is imported, the FunctionDefs are not copied over
611 # to each sub-graph so we recursively search the outer graphs
612 # for the FunctionDef.
613 if not func_call and hasattr(src_graph, "outer_graph"):
614 graph = src_graph.outer_graph
615 while graph is not None:
616 func_call = graph._get_function(func_name) # pylint: disable=protected-access
617 if func_call is not None:
618 break
619 if hasattr(graph, "outer_graph"):
620 graph = graph.outer_graph
621 else:
622 break
623 else:
624 func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
625 # Note that __defun is not set if the graph is
626 # imported. If it's set, we prefer to access the original
627 # defun.
628 func_call = getattr(op, "__defun", func_call)
629 grad_fn = func_call.python_grad_func
630 else:
631 raise LookupError(
632 "No gradient defined for operation"
633 f"'{op.name}' (op type: {op.type}). "
634 "In general every operation must have an associated "
635 "`@tf.RegisterGradient` for correct autodiff, which this "
636 "op is lacking. If you want to pretend this "
637 "operation is a constant in your program, you may insert "
638 "`tf.stop_gradient`. This can be useful to silence the "
639 "error in cases where you know gradients are not needed, "
640 "e.g. the forward pass of tf.custom_gradient. "
641 "Please see more details in "
642 "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.") # pylint: disable=line-too-long
643 if loop_state:
644 loop_state.EnterGradWhileContext(op, before=False)
646 # NOTE(skyewm): We don't support computing gradients wrt a loop variable
647 # unless it's within the context of a single iteration (i.e. the
648 gradient is wrt the loop parameter in the body function, not wrt or
649 # through the initial value). This means if we're in a while loop
650 # context, we should never see a switch node from this context.
651 # pylint: disable=protected-access
652 if (control_flow_util.IsSwitch(op) and
653 op._control_flow_context is not None and
654 op._control_flow_context.IsWhileContext() and
655 op._control_flow_context ==
656 ops.get_default_graph()._get_control_flow_context()):
657 _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
658 # pylint: enable=protected-access
660 if (grad_fn or is_func_call) and has_out_grads:
661 # NOTE: If _AggregatedGrads didn't compute a value for the i'th
662 # output, it means that the cost does not depend on output[i],
663 # therefore dC/doutput[i] is 0.
664 for i, out_grad in enumerate(out_grads):
665 if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
666 (not grad_fn and is_func_call)
667 or backprop_util.IsTrainable(op.outputs[i])):
668 # Only trainable outputs or outputs for a function call that
669 # will use SymbolicGradient get a zero gradient. Gradient
670 # functions should ignore the gradient for other outputs.
671 # TODO(apassos) gradients of resource handles might be an
672 # issue here because of zeros.
673 if loop_state:
674 out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i)
675 elif default_gradient.supports_default_grad(op.outputs[i]):
676 # TODO(b/143286622): The supports_default_grad check is needed
677 # because While op emits non-differentiable resource tensors
678 # as outputs. Remove this check when that is not the case.
679 out_grads[i] = control_flow_state.ZerosLike(op, i)
680 with ops.name_scope(op.name + "_grad"):
681 # pylint: disable=protected-access
682 with src_graph._original_op(op):
683 # pylint: enable=protected-access
684 if grad_fn:
685 # If grad_fn was found, do not use SymbolicGradient even for
686 # functions.
687 in_grads = _MaybeCompile(grad_scope, op, func_call,
688 lambda: grad_fn(op, *out_grads))
689 else:
690 # For function call ops, we add a 'SymbolicGradient'
691 # node to the graph to compute gradients.
692 in_grads = _MaybeCompile(grad_scope, op, func_call,
693 lambda: _SymGrad(op, out_grads))
694 in_grads = _AsList(in_grads)
695 _VerifyGeneratedGradients(in_grads, op)
696 if gate_gradients and len([x for x in in_grads
697 if x is not None]) > 1:
698 with ops.device(None):
699 with ops._colocate_with_for_gradient( # pylint: disable=protected-access
700 None,
701 gradient_uid,
702 ignore_existing=True):
703 in_grads = control_flow_ops.tuple(in_grads)
704 _LogOpGradients(op, out_grads, in_grads)
705 else:
706 # If no grad_fn is defined or none of out_grads is available,
707 # just propagate a list of None backwards.
708 in_grads = [None] * len(_Inputs(op, xs_set))
709 # Note: we don't filter out eager inputs here because the inputs need to
710 # line up with in_grads.
711 for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
712 if in_grad is not None:
713 if (isinstance(in_grad, ops.Tensor) and
714 t_in.dtype != dtypes.resource):
715 try:
716 in_grad.set_shape(t_in.get_shape())
717 except ValueError:
718 raise ValueError(
719 "Incompatible shapes between op input and calculated "
720 f"input gradient. Forward operation: {op.name}. Input "
721 f"index: {i}. Original input shape: {t_in.shape}. "
722 f"Calculated input gradient shape: {in_grad.shape}")
723 if not isinstance(t_in, ops.EagerTensor):
724 _SetGrad(grads, t_in, in_grad)
725 if loop_state:
726 loop_state.ExitGradWhileContext(op, before=False)
728 # Update pending count for the inputs of op and enqueue ready ops.
729 _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
730 xs_set)
732 if loop_state:
733 loop_state.PostProcessing()
734 return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
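# NOTE (editorial, illustrative only; not part of the TensorFlow module): the
# public tf.gradients() wrapper (gradients_impl.py) delegates to
# _GradientsHelper above. A minimal graph-mode usage sketch, assuming TF 1.x
# style execution, showing the unconnected_gradients behaviour handled here:
#
#   import tensorflow.compat.v1 as tf1
#
#   tf1.disable_eager_execution()
#   a = tf1.constant(2.0)
#   b = 3.0 * a
#   c = tf1.constant(5.0)               # never feeds into b
#   grads = tf1.gradients(
#       b, [a, c], unconnected_gradients=tf1.UnconnectedGradients.ZERO)
#   # grads[0] is db/da (3.0); grads[1] is a zeros tensor because c is not
#   # connected to b and unconnected_gradients was set to ZERO.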
737def _HasAnyNotNoneGrads(grads, op):
738 """Return true iff op has real gradient."""
739 out_grads = _GetGrads(grads, op)
740 for out_grad in out_grads:
741 if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)):
742 return True
743 if out_grad and isinstance(out_grad, collections_abc.Sequence):
744 if any(g is not None for g in out_grad):
745 return True
746 return False
749def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
750 xs_set):
751 """Update pending count for the inputs of op and enqueue ready ops."""
752 for x in _NonEagerInputs(op, xs_set):
753 pending_count[x.op] -= 1
754 ready = (pending_count[x.op] == 0)
755 if loop_state and not ready:
756 ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
757 if ready:
758 if control_flow_util.IsLoopExit(x.op):
759 # If x is an exit without a real gradient, defer processing it.
760 grad_state = loop_state.GetGradState(x.op, before=False)
761 grad_state.deferred_exits.append(x)
762 grad_state.pending_exits_count -= 1
763 if grad_state.pending_exits_count == 0:
764 # We now have all the exits so process them.
765 has_not_none_grad = False
766 for y in grad_state.deferred_exits:
767 if _HasAnyNotNoneGrads(grads, y.op):
768 has_not_none_grad = True
769 queue.append(y.op)
770 else:
771 grad_state.unused_exits.append(y)
772 if has_not_none_grad:
773 # For an unused exit, if it has trainable outputs, backprop
774 # a zero gradient. Otherwise, just ignore it.
775 for y in grad_state.unused_exits:
776 if backprop_util.IsTrainable(y):
777 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
778 queue.append(y.op)
779 else:
780 # All exits are "unused" so use None as gradient.
781 for y in grad_state.unused_exits:
782 queue.append(y.op)
783 else:
784 queue.append(x.op)
787def _SetGrad(grads, t, grad):
788 """Sets gradient "grad" in "grads" for tensor "t"."""
789 op = t.op
790 op_grads = grads.get(op)
791 if not op_grads:
792 op_grads = [[] for _ in range(len(op.outputs))]
793 grads[op] = op_grads
794 t_grads = op_grads[t.value_index]
795 if isinstance(t_grads, list):
796 t_grads.append(grad)
797 else:
798 assert control_flow_util.IsLoopSwitch(op)
799 op_grads[t.value_index] = grad
802def _ZerosLike(t):
803 t_dtype = default_gradient.get_zeros_dtype(t)
804 if t.dtype == dtypes.resource:
805 return array_ops.zeros(
806 resource_variable_ops.variable_shape(t), dtype=t_dtype)
807 else:
808 return array_ops.zeros_like(t, dtype=t_dtype)
811def _GetGrad(grads, t, unconnected_gradients):
812 """Gets gradient for tensor "t"."""
813 op = t.op
814 op_grads = grads.get(op)
815 if not op_grads:
816 if unconnected_gradients == UnconnectedGradients.ZERO:
817 return _ZerosLike(t)
818 elif unconnected_gradients == UnconnectedGradients.NONE:
819 return None
820 else:
821 raise ValueError(
822 f"Unknown value for unconnected_gradients: '{unconnected_gradients}'")
824 t_grad = op_grads[t.value_index]
825 # This can happen if some other output of `t.op` has non-None grad.
826 if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None:
827 return _ZerosLike(t)
829 assert not isinstance(
830 t_grad, list), ("gradients list should have been aggregated by now.")
831 return t_grad
834def _GetGrads(grads, op):
835 """Gets all gradients for op."""
836 if op in grads:
837 return grads[op]
838 else:
839 return [[] for _ in range(len(op.outputs))]
842def _AccumulatorShape(inputs):
843 shape = tensor_shape.unknown_shape()
844 for i in inputs:
845 if isinstance(i, ops.Tensor):
846 shape = shape.merge_with(i.get_shape())
847 return shape
850def _LogOpGradients(op, out_grads, in_grads):
851 """Log the in and out grads of an op."""
852 logging.vlog(1, "Gradient for '" + op.name + "'")
854 def _FilterGrad(x):
855 if x is None:
856 return False
857 if isinstance(x, (list, tuple)):
858 return bool(x)
859 else:
860 return True
862 logging.vlog(1, " in --> %s",
863 ", ".join(x.name for x in out_grads if _FilterGrad(x)))
864 logging.vlog(1, " out --> %s",
865 ", ".join(x.name for x in in_grads if _FilterGrad(x)))
868def _MultiDeviceAddN(tensor_list, gradient_uid):
869 """Adds tensors from potentially multiple devices."""
870 # Basic function structure comes from control_flow_ops.group().
871 # Sort tensors according to their devices.
872 tensors_on_device = collections.defaultdict(lambda: [])
873 for tensor in tensor_list:
874 tensors_on_device[tensor.device].append(tensor)
876 # For each device, add the tensors on that device first.
877 # Then gather the partial sums from multiple devices.
878 # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
879 # E.g., aggregate per GPU, then per task, and so on.
880 summands = []
882 def DeviceKey(dev):
883 return "" if dev is None else dev
885 for dev in sorted(tensors_on_device, key=DeviceKey):
886 tensors = tensors_on_device[dev]
887 with ops._colocate_with_for_gradient( # pylint: disable=protected-access
888 tensors[0].op,
889 gradient_uid,
890 ignore_existing=True):
891 summands.append(math_ops.add_n(tensors))
893 return math_ops.add_n(summands)
896@tf_export("AggregationMethod")
897class AggregationMethod:
898 """A class listing aggregation methods used to combine gradients.
900 Computing partial derivatives can require aggregating gradient
901 contributions. This class lists the various methods that can
902 be used to combine gradients in the graph.
904 The following aggregation methods are part of the stable API for
905 aggregating gradients:
907 * `ADD_N`: All of the gradient terms are summed as part of one
908 operation using the "AddN" op (see `tf.add_n`). This
909 method has the property that all gradients must be ready and
910 buffered separately in memory before any aggregation is performed.
911 * `DEFAULT`: The system-chosen default aggregation method.
913 The following aggregation methods are experimental and may not
914 be supported in future releases:
916 * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using
917 the "AddN" op. This method of summing gradients may reduce
918 performance, but it can improve memory utilization because the
919 gradients can be released earlier.
920 * `EXPERIMENTAL_ACCUMULATE_N`: Same as `EXPERIMENTAL_TREE`.
922 Example usage when computing gradient:
924 >>> @tf.function
925 ... def example():
926 ... x = tf.constant(1.0)
927 ... y = x * 2.0
928 ... z = y + y + y + y
929 ... return tf.gradients(z, [x, y],
930 ... aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
931 >>> example()
932 [<tf.Tensor: shape=(), dtype=float32, numpy=8.0>,
933 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>]
935 """
936 ADD_N = 0
937 DEFAULT = ADD_N
938 # The following are experimental and may not be supported in future releases.
939 EXPERIMENTAL_TREE = 1
940 EXPERIMENTAL_ACCUMULATE_N = 2 # Behaves the same as EXPERIMENTAL_TREE.
943def _AggregatedGrads(grads,
944 op,
945 gradient_uid,
946 loop_state,
947 aggregation_method=None):
948 """Get the aggregated gradients for op.
950 Args:
951 grads: The map of memoized gradients.
952 op: The op to get gradients for.
953 gradient_uid: A unique identifier within the graph indicating
954 which invocation of gradients is being executed. Used to cluster
955 ops for compilation.
956 loop_state: An object for maintaining the state of the while loops in the
957 graph. It is of type ControlFlowState. None if the graph
958 contains no while loops.
959 aggregation_method: Specifies the method used to combine gradient terms.
960 Accepted values are constants defined in the class `AggregationMethod`.
962 Returns:
963 A list of gradients, one per output of `op`. If the gradient for a
964 particular output is a list, this function aggregates it
965 before returning.
967 Raises:
968 TypeError: if the incoming grads are not Tensors or IndexedSlices.
969 ValueError: if the arguments are invalid.
971 """
972 if aggregation_method is None:
973 aggregation_method = AggregationMethod.DEFAULT
974 valid_aggregation_methods = [
975 AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
976 AggregationMethod.EXPERIMENTAL_ACCUMULATE_N]
977 if aggregation_method not in valid_aggregation_methods:
978 raise ValueError(
979 f"Invalid `aggregation_method` specified {aggregation_method}. "
980 f"Accepted values are {valid_aggregation_methods}.")
981 out_grads = _GetGrads(grads, op)
982 for i, out_grad in enumerate(out_grads):
983 if loop_state:
984 if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)):
985 assert control_flow_util.IsLoopSwitch(op)
986 continue
987 # Grads have to be Tensors or IndexedSlices
988 if (isinstance(out_grad, collections_abc.Sequence) and not all(
989 isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices))
990 for g in out_grad
991 if g is not None)):
992 raise TypeError(f"Invalid gradient {out_grad} [index = {i}]. Gradients "
993 "have to be either all Tensors or all IndexedSlices")
994 # Aggregate multiple gradients, and convert [] to None.
995 if out_grad:
996 if len(out_grad) < 2:
997 used = "nop"
998 out_grads[i] = out_grad[0]
999 elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
1000 tensor_shape = _AccumulatorShape(out_grad)
1001 if aggregation_method in [
1002 AggregationMethod.EXPERIMENTAL_TREE,
1003 AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
1004 ]:
1005 # Aggregate all gradients by doing pairwise sums: this may
1006 # reduce performance, but it can improve memory because the
1007 # gradients can be released earlier.
1008 #
1009 # TODO(vrv): Consider replacing this with a version of
1010 # tf.AddN() that eagerly frees its inputs as soon as they are
1011 # ready, so the order of this tree does not become a problem.
1012 used = "tree"
1013 with ops.name_scope(op.name + "_gradient_sum"):
1014 running_sum = out_grad[0]
1015 for grad in out_grad[1:]:
1016 running_sum = math_ops.add_n([running_sum, grad])
1017 out_grads[i] = running_sum
1018 else:
1019 used = "add_n"
1020 out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
1021 logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
1022 tensor_shape, used)
1023 else:
1024 out_grads[i] = backprop_util.AggregateIndexedSlicesGradients(out_grad) # pylint: disable=protected-access
1025 else: # not out_grad
1026 # out_grads[i] is [], thus its aggregation is simply None.
1027 out_grads[i] = None
1028 return out_grads
1031# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are
1032# unfortunately too slow to use here.
1033POSSIBLE_GRADIENT_TYPES_NONE = 0
1034POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1
1035POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2
1038def PossibleTapeGradientTypes(tensors):
1039 """Determines whether and how `args` may require tape gradients."""
1040 return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)