1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""while_v2 and gradient.
17This is a version of while_loop that emits a single While op, as well as the
18gradient function for While ops produced by while_loop. This will eventually
19replace the current tf.while_loop implementation once it reaches feature and
20performance parity.
21"""
22import collections
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.python.client import pywrap_tf_session as c_api
26from tensorflow.python.eager import backprop_util
27from tensorflow.python.framework import auto_control_deps_utils as acd
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import func_graph as func_graph_module
31from tensorflow.python.framework import indexed_slices
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import control_flow_util as util_v1
39from tensorflow.python.ops import control_flow_util_v2 as util
40from tensorflow.python.ops import default_gradient
41from tensorflow.python.ops import gen_functional_ops
42from tensorflow.python.ops import gen_resource_variable_ops
43from tensorflow.python.ops import gradients_util
44from tensorflow.python.ops import handle_data_util
45from tensorflow.python.ops import list_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import tensor_array_ops
48from tensorflow.python.ops import while_v2_indexed_slices_rewriter
49from tensorflow.python.util import compat
50from tensorflow.python.util import nest
51from tensorflow.python.util import object_identity
52from tensorflow.python.util import variable_utils
54# pylint: disable=protected-access
57def while_loop(cond,
58 body,
59 loop_vars,
60 shape_invariants=None,
61 parallel_iterations=10,
62 maximum_iterations=None,
63 name=None,
64 return_same_structure=True,
65 back_prop=True):
66 """Like tf.while_loop, except emits a single While op."""
67 loop_vars = variable_utils.convert_variables_to_tensors(loop_vars)
68 # Keep the original loop_vars around to know which args were TensorArrays.
69 orig_loop_vars = loop_vars
70 flat_orig_loop_vars = nest.flatten(orig_loop_vars, expand_composites=True)
71 # Cache its length since we use it at multiple places below.
72 len_orig_loop_vars = len(orig_loop_vars)
74 # Convert TensorArrays to their flow variables. These get converted back to
75 # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
76 # `wrapped_body` below.
77 loop_vars = _tensor_array_to_flow(loop_vars)
78 loop_vars = nest.map_structure(
79 indexed_slices.internal_convert_to_tensor_or_indexed_slices,
80 loop_vars,
81 expand_composites=True)
83 # `loop_vars_signature` is a structure of TypeSpecs and has the same
84 # structure as the `orig_loop_vars`. If `shape_invariants` is not None, its
85 # shape information comes from `shape_invariants` instead of `orig_loop_vars`.
86 # It is used to pack flattened vars into structured vars.
87 if shape_invariants is not None:
88 loop_vars_signature = nest.map_structure(
89 control_flow_ops._shape_invariant_to_type_spec,
90 loop_vars, shape_invariants)
91 else:
92 loop_vars_signature = nest.map_structure(
93 control_flow_ops._shape_invariant_to_type_spec, loop_vars)
95 flat_shape_invariants = nest.map_structure(
96 lambda spec: spec.shape,
97 nest.flatten(loop_vars_signature, expand_composites=True))
99 if not name:
100 name = "while"
102 with ops.name_scope(name) as scope:
103 with ops.name_scope(None):
104 cond_name = util.unique_fn_name(scope, "cond")
105 body_name = util.unique_fn_name(scope, "body")
106 maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
107 maximum_iterations)
108 loop_counter = constant_op.constant(
109 0,
110 dtype=maximum_iterations_loop_var.dtype
111 if maximum_iterations is not None else None,
112 name="loop_counter")
113 # Add loop counter needed for computing gradients.
114 loop_vars = [loop_counter, maximum_iterations_loop_var] + list(loop_vars)
116 func_graph_signature = (
117 [tensor_spec.TensorSpec.from_tensor(loop_counter),
118 tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
119 list(loop_vars_signature))
121 # Automatic control dependencies are added in defuns, but not in v1
122 # graphs. Propagate that behavior here.
123 add_control_dependencies = ops.get_default_graph()._add_control_dependencies
125 def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
126 """Extra `cond` wrapper that can handle the extra counter loop_var."""
127 # Convert the flow variables in `args` to TensorArrays. `args` should
128 # already have the same structure as `orig_loop_vars` but currently there
129 # is no nest.zip so we call `_pack_sequence_as` which flattens `args`,
130 # converts flows in `args` to TensorArrays and packs it into the
131 # structure of `loop_vars_signature`.
132 pred = cond(
133 *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args))
134 if (tensor_util.is_tf_type(pred) and
135 (pred.shape.dims is None or pred.shape.dims)):
136 pred = array_ops.squeeze_v2(pred)
138 if maximum_iterations is None:
139 return pred
140 else:
141 return math_ops.logical_and(
142 loop_counter < maximum_iterations_arg, pred)
144 # NOTE(skyewm): we set collections to the outer graph's collections for
145 # compatibility with TPUEstimator.
146 cond_graph = func_graph_module.func_graph_from_py_func(
147 cond_name,
148 wrapped_cond,
149 [], # We provide signature instead of args.
150 {},
151 signature=func_graph_signature,
152 func_graph=util.WhileCondFuncGraph(
153 cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
154 add_control_dependencies=add_control_dependencies)
156 def wrapped_body(loop_counter, maximum_iterations_arg, *args):
157 """Loop body augmented with counter update.
159 Args:
160 loop_counter: Loop counter which needs to be incremented in the body.
161 maximum_iterations_arg: Maximum iterations of the loop.
162 *args: List of args
164 Returns:
165 A list of tensors the same length as args.
166 """
167 # The function was created with a signature rather than tensors, so
168 # internal placeholders were created without handle data.
169 _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
170 nest.flatten(args, expand_composites=True))
171 # Capture the tensors already captured in cond_graph so that they appear
172 # in the same order in body_graph.external_captures.
173 for t in cond_graph.external_captures:
174 ops.get_default_graph().capture(t)
176 # Convert the flow variables in `args` to TensorArrays. `args` should
177 # already have the same structure as `orig_loop_vars` but currently there
178 # is no nest.zip so we call `_pack_sequence_as` which flattens `args`,
179 # converts flows in `args` to TensorArrays and packs it into the
180 # structure of `loop_vars_signature`.
181 outputs = body(
182 *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args))
183 if not nest.is_nested(outputs):
184 outputs = [outputs]
185 try:
186 # The legacy while_loop considers list and tuple to be the same
187 # structure.
188 nest.assert_same_structure(outputs, orig_loop_vars, check_types=False,
189 expand_composites=True)
190 except ValueError:
191 # Traditionally we consider variables and tensors to be the same
192 # structure.
193 vars1 = variable_utils.convert_variables_to_tensors(outputs)
194 vars2 = variable_utils.convert_variables_to_tensors(orig_loop_vars)
195 nest.assert_same_structure(vars1, vars2, check_types=False,
196 expand_composites=True)
197 outputs = _tensor_array_to_flow(outputs)
199 # TODO(srbs): Update lowering code to create _Enter nodes with
200 # is_constant=True for inputs that are directly passed to outputs.
201 return [loop_counter + 1, maximum_iterations_arg] + list(outputs)
203 body_graph = func_graph_module.func_graph_from_py_func(
204 body_name,
205 wrapped_body,
206 [], # We provide signature instead of args.
207 {},
208 signature=func_graph_signature,
209 func_graph=util.WhileBodyFuncGraph(
210 body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
211 add_control_dependencies=add_control_dependencies)
212 # Add external captures of body to the list of loop vars.
213 # Note that external tensors will be treated as loop invariants, i.e.,
214 # the value of that tensor in each iteration is the same as it was at the
215 # beginning of the loop execution.
216 deferred_external_captures = nest.flatten(
217 [c() for c in body_graph.deferred_external_captures],
218 expand_composites=True)
219 loop_vars = (
220 loop_vars + body_graph.external_captures + deferred_external_captures)
221 # TODO(srbs): Update lowering code to create _Enter nodes with
222 # is_constant=True for inputs that are directly passed to outputs.
223 body_graph.outputs.extend(body_graph.internal_captures)
224 body_graph.outputs.extend(body_graph.deferred_internal_captures)
226 # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
227 # that it expects to receive those as arguments.
228 with cond_graph.as_default():
229 num_cond_captures = len(cond_graph.external_captures)
230 assert (cond_graph.external_captures ==
231 body_graph.external_captures[:num_cond_captures])
232 _duplicate_body_captures_in_cond(
233 cond_graph, body_graph.external_captures[num_cond_captures:] +
234 deferred_external_captures)
236 # Make sure that the shapes of the loop outputs are compatible with the
237 # shape invariants, or the shapes of the loop vars if the invariants are not
238 # specified.
239 num_flattened_outputs = len(nest.flatten(orig_loop_vars,
240 expand_composites=True))
241 # First var is loop counter and second var is maximum_iterations.
242 first_loop_var_index = 2
243 _check_shapes_compat(
244 body_graph.outputs[first_loop_var_index:first_loop_var_index +
245 num_flattened_outputs],
246 flat_shape_invariants,
247 nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
248 len_orig_loop_vars], expand_composites=True))
250 num_original_outputs = len(body_graph.outputs)
251 if back_prop and util.output_all_intermediates():
252 # Export all tensors in the loop body that may be needed for gradient
253 # computation. We do this by accumulating the intermediate values in
254 # TensorLists.
255 intermediate_tensors = _get_intermediates(body_graph)
257 for intermediate_tensor in intermediate_tensors:
258 tensor_list = list_ops.empty_tensor_list(
259 element_dtype=intermediate_tensor.dtype,
260 element_shape=intermediate_tensor.shape,
261 max_num_elements=maximum_iterations)
262 loop_vars.append(tensor_list)
263 with cond_graph.as_default():
264 # Add a placeholder to cond_graph's inputs corresponding to the
265 # tensor_list.
266 cond_graph.capture(tensor_list)
267 with body_graph.as_default():
268 # Push the intermediate tensor to the tensor list. This captures the
269 # `tensor_list` as well.
270 appended_tensor_list = list_ops.tensor_list_push_back(
271 tensor_list, intermediate_tensor)
272 # Add this modified tensor list to the list of outputs.
273 body_graph.outputs.append(appended_tensor_list)
275 flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
276 _check_num_inputs_outputs(cond_graph, body_graph,
277 len(flattened_loop_vars))
278 _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)
280 with ops.control_dependencies(
281 list(cond_graph.function_captures.control) + list(
282 body_graph.function_captures.control)):
283 output_shapes = [t.shape for t in body_graph.outputs]
284 orig_loop_vars_range = slice(first_loop_var_index,
285 first_loop_var_index + num_flattened_outputs)
286 output_shapes[orig_loop_vars_range] = flat_shape_invariants
288 outputs = _build_while_op(
289 flattened_loop_vars,
290 cond_graph,
291 body_graph,
292 output_shapes=output_shapes,
293 parallel_iterations=parallel_iterations,
294 name=scope,
295 num_original_outputs=num_original_outputs)
296 if not ops.get_default_graph().building_function:
297 # In V1 graph mode, return identities for each output of the While op,
298 # rather than the output of the While op directly. This makes pruning work
299 # if the output of while_loop() is fetched: the lowering pass converts the
300 # While outputs into IdentityN outputs, which if fetched will cause all
301 # ops in the body to be run (since it takes all exit ops as input). After
302 # lowering, each output identity op will end up with only the appropriate
303 # exit op as input.
304 outputs = tuple(array_ops.identity(t) for t in outputs)
306 output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
307 num_flattened_outputs]
308 if not back_prop:
309 output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
310 outputs = _pack_sequence_as(
311 loop_vars_signature, flat_orig_loop_vars, output_loop_vars)
313 if return_same_structure:
314 return outputs
316 flattened_outputs = nest.flatten(outputs, expand_composites=True)
317 if len(flattened_outputs) == 1:
318 return flattened_outputs[0]
319 else:
320 return outputs
323@ops.RegisterGradient("StatelessWhile")
324@ops.RegisterGradient("While")
325def _WhileGrad(op, *grads): # pylint: disable=invalid-name
326 """The gradient of a While op produced by while_loop."""
327 # Note that op is not always the same as while_op because the gradient tape,
328 # for eager mode compatibility, forgets information about the proper op. Since
329 # the loop cannot run in eager mode, however, we can safely introspect into
330 # the graph here.
331 while_op = op.outputs[0].op
332 cond_graph = _get_graph(while_op, "cond", "_cond_graph")
333 body_graph = _get_graph(while_op, "body", "_body_graph")
334 orig_num_params = len(body_graph.outputs)
336 maximum_iterations = op.inputs[1]
337 parallel_iterations = op.get_attr("parallel_iterations")
339 try:
340 num_original_outputs = while_op.get_attr("_num_original_outputs")
341 except: # pylint: disable=bare-except
342 num_original_outputs = len(while_op.outputs)
344 num_intermediates = len(while_op.outputs) - num_original_outputs
345 grads = [
346 _preprocess_grad(grad, body_out, while_in, while_out) # pylint: disable=g-complex-comprehension
347 for grad, body_out, while_in, while_out in zip(
348 grads[:num_original_outputs],
349 body_graph.outputs[:num_original_outputs],
350 while_op.inputs[:num_original_outputs],
351 while_op.outputs[:num_original_outputs])
352 ] + [None] * num_intermediates
354 # Skip gradients with respect to the captures whenever possible.
355 if getattr(op, "skip_input_indices", None) is not None:
356 captures_start_index = (
357 len(body_graph.inputs) - len(body_graph.internal_captures))
358 for i in op.skip_input_indices:
359 if i >= captures_start_index:
360 grads[i] = None
362 # We compute the gradient for the sub-graph between trainable ys and xs
363 # with non-None incoming gradients. We later pad the None's to the list of
364 # outputs.
365 ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
366 body_graph.outputs, body_graph.inputs, grads) if grad is not None])
368 body_grad_graph, args = _create_grad_func(
369 ys, xs, non_none_grads, cond_graph, body_graph,
370 util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)
372 if body_grad_graph.while_op_needs_rewrite:
373 # Modify 'op' to output the intermediate accumulators needed by the grad
374 # function.
375 # NOTE(skyewm): if there are any active sessions, this modification to `op`
376 # may make them unrunnable!
378 cond_graph.name += "_rewritten"
379 body_graph.name += "_rewritten"
381 # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
382 # `body_graph.external_captures` added during `_create_grad_func`.
383 new_inputs = body_grad_graph.extra_inputs
384 new_outputs = body_graph.outputs[orig_num_params:]
386 while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
387 while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
388 if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
389 # Continuing leads to an invalid graph with disconnected inputs.
390 raise AssertionError(
391 "Inputs and outputs constructed for the forward op of a While "
392 "gradient don't match with 'output_types' at "
393 f"{len(body_graph.output_types)},'inputs' at length "
394 f"{len(while_op.inputs)}, and 'new_inputs' at length "
395 f"{len(new_inputs)}. This doesn't make sense, please file a bug.")
396 while_op._set_type_list_attr("T", body_graph.output_types)
397 while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
398 while_op._add_while_inputs(new_inputs)
399 while_op._add_outputs([t.dtype for t in new_outputs],
400 [t.shape for t in new_outputs])
401 _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])
403 # Do not ignore grads wrt extra outputs when computing higher order
404 # derivatives.
405 while_op._set_attr("_num_original_outputs",
406 attr_value_pb2.AttrValue(i=len(while_op.outputs)))
408 captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
409 while_op)
410 loop_vars = args + captured_inputs
412 # This modifies body_grad_graph.
413 loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
414 grads, body_grad_graph, loop_vars, while_op.inputs)
416 def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
417 *unused_args):
418 return counter < forward_loop_iters
420 grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
421 cond_grad_graph = func_graph_module.func_graph_from_py_func(
422 grad_cond_name, grad_cond, loop_vars, {},
423 func_graph=util.WhileCondFuncGraph(grad_cond_name))
425 _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
427 outputs = _build_while_op(
428 loop_vars,
429 cond_grad_graph,
430 body_grad_graph,
431 output_shapes=[t.shape for t in body_grad_graph.outputs],
432 parallel_iterations=parallel_iterations,
433 name="%s_grad" % while_op.name,
434 num_original_outputs=len(body_grad_graph.outputs))
436 # See comment in while_loop.
437 outputs = [array_ops.identity(t) for t in outputs]
438 return _get_structured_grad_output(outputs, grads, body_grad_graph)
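# Example (illustrative, public-API sketch): the gradient registered above is
# what runs when a While op is differentiated, e.g. via tf.GradientTape:
#
#   import tensorflow as tf
#
#   @tf.function
#   def cube(x):
#     return tf.while_loop(lambda i, a: i < 3,
#                          lambda i, a: (i + 1, a * x),
#                          (tf.constant(0), tf.constant(1.0)))[1]
#
#   x = tf.constant(2.0)
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = cube(x)                  # x ** 3 == 8.0
#   dy_dx = tape.gradient(y, x)    # 3 * x ** 2 == 12.0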
441def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
442 parallel_iterations, name, num_original_outputs):
443 """Builds the functional StatelessWhile/While op."""
444 cond_stateful_ops = [
445 op for op in cond_graph.get_operations() if op._is_stateful
446 ]
447 body_stateful_ops = [
448 op for op in body_graph.get_operations() if op._is_stateful
449 ]
450 if (cond_stateful_ops or body_stateful_ops):
451 op_fn = gen_functional_ops._while
452 else:
453 op_fn = gen_functional_ops.stateless_while
455 def _make_op(inputs):
456 while_op, tensors = util.get_op_and_outputs(op_fn(
457 inputs,
458 util.create_new_tf_function(cond_graph),
459 util.create_new_tf_function(body_graph),
460 output_shapes=output_shapes,
461 parallel_iterations=parallel_iterations,
462 name=name))
463 _copy_handle_data(body_graph.outputs, tensors)
464 util.maybe_set_lowering_attr(while_op)
465 util.maybe_propagate_compile_time_consts_in_xla(while_op)
466 _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
467 # This is needed so we do not compute derivative wrt these extra outputs.
468 while_op._set_attr("_num_original_outputs",
469 attr_value_pb2.AttrValue(i=num_original_outputs))
470 # The while op may be created inside a tf.function, in which case ops
471 # needs to capture "through" it when taking gradients; outer_graph is used
472 # as a sanity check that capturing only happens from parent to child.
473 cond_graph.outer_graph = ops.get_default_graph()
474 body_graph.outer_graph = ops.get_default_graph()
475 while_op._cond_graph = cond_graph
476 while_op._body_graph = body_graph
477 return tensors
478 return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
481def _get_intermediates(func_graph):
482 """Returns all tensors in `func_graph` that should be accumulated."""
483 # We currently accumulate output tensors of most ops in the function and rely
484 # on the pruning pass to get rid of the unused accumulators at runtime.
485 # However, this can bloat the GraphDef and make debugging harder so we perform
486 # some optimizations.
487 #
488 # Optimization we currently perform:
489 # 1. We do not accumulate tensors which already have an accumulator
490 # in the loop body.
491 # 2. We do not accumulate outputs of Identity nodes. When building the
492 # FuncGraph, we add an Identity node for each output (see
493 # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
494 # of all these nodes bloats the GraphDef quite a bit so we remove those.
495 # Since the gradient of an Identity node does not rely on its forward op's
496 # input this is safe to do.
497 #
498 # Other possible optimizations:
499 # 1. Only accumulate tensors that will be required by the backward pass.
500 # This will require running the gradient pass and hence would increase the
501 # graph building time for the forward pass.
502 # 2. Do not accumulate Const nodes created inside the loop body.
503 # 3. Do not accumulate loop vars that are returned as-is just like captured
504 # tensors.
505 intermediates = []
506 reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
508 for op in func_graph.get_operations():
509 if op.type == "Identity":
510 continue
511 # Accumulating mutexes can cause deadlock.
512 if op.type == "MutexLock":
513 continue
514 for o in op.outputs:
515 if (o is not func_graph.inputs[0] and # Loop counter.
516 o.dtype != dtypes.resource and # Do not accumulate resource tensors.
517 _get_accumulator(o) is None and # Has existing accumulator.
518 o.ref() not in reverse_captures
519 ): # Captured value, hence loop invariant.
520 intermediates.append(o)
521 return intermediates
524def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
525 """Returns the initial gradient to be used for a given output tensor.
527 Args:
528 grad: the original gradient Tensor passed to the gradient function.
529 body_graph_output: the corresponding Tensor in the body graph.
530 while_op_input: the corresponding Tensor input of the While op.
531 while_op_output: the corresponding Tensor output of the While op.
533 Returns:
534 A Tensor or None.
535 """
536 # Set the incoming gradient of non-trainable inputs to None. It is possible
537 # that we receive non-None gradients for non-trainable types in nested while
538 # loops because we accumulate outputs of the inner while as variant tensors
539 # which are trainable and hence receive zeros_like tensors in the gradient
540 # pass. The non-trainable tensors then receive the popped zeros tensor from
541 # this zeros variant. The gradient for the loop vars corresponding to these
542 # tensors is None or zeros (this happens only if the loop var is accumulated
543 # as well) in _grad_fn so we reset these.
544 # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
545 if not _is_trainable(body_graph_output):
546 return None
548 # GradientTape initializes resource and variant grads as None instead of
549 # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
550 # returning None.
551 # TODO(b/143286622): The supports_default_grad check is needed
552 # because While op emits non-differentiable resource tensors
553 # as outputs. Remove this check when that is not the case.
554 # Note: We use `while_op_input` instead of `while_op_output` for the call
555 # to `supports_default_grad` because `while_op_output` may be missing
556 # handle_data if the While is in a restored saved model.
557 if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
558 default_gradient.supports_default_grad(while_op_input) and grad is None):
559 return _zeros_like(while_op_input, while_op_output)
561 # Convert IndexedSlices to dense tensors since it is unlikely that downstream
562 # gradient functions will properly handle indexed slices. This is similar to
563 # what we do in tf.function gradients.
564 if isinstance(grad, indexed_slices.IndexedSlices):
565 return ops.convert_to_tensor(grad)
567 return grad
570# TODO(skyewm): make this return constants if op_output's shape is fully
571# defined (this can be done by checking the "shape" attr of resource vars).
572def _zeros_like(op_input, op_output):
573 """Like array_ops.zeros_like() but also accepts resource var handles."""
574 if op_output.dtype == dtypes.resource:
575 # Note: We use `op_input` instead of `op_output` to get the zeros dtype
576 # because `op_output` may be missing handle_data if the While is in a
577 # restored saved model.
578 return array_ops.zeros(
579 gen_resource_variable_ops.variable_shape(op_output),
580 dtype=default_gradient.get_zeros_dtype(op_input))
581 return array_ops.zeros_like(op_output)
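# Example (illustrative sketch): for a dense output this is just
# array_ops.zeros_like(op_output); for a resource output it recovers the
# variable's runtime shape instead, e.g. with v = tf.Variable([1.0, 2.0, 3.0]),
# _zeros_like(v.handle, v.handle) evaluates to zeros of shape [3] with the
# variable's dtype (float32).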
584def _is_trainable(tensor):
585 """Returns whether the given tensor is trainable."""
586 if not backprop_util.IsTrainable(tensor):
587 return False
589 # Special case: untrainable accumulator output. The gradients algorithm
590 # doesn't know about tensor lists of untrainable elements. In theory the
591 # tensor list gradient functions should return None as appropriate, but
592 # because we can't return None from the gradient function we filter out
593 # untrainable accumulator output here to avoid computing the gradient at all.
594 if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
595 assert tensor.dtype == dtypes.variant
596 element_type = tensor.op.get_attr("element_dtype")
597 return backprop_util.IsTrainable(element_type)
599 return True
602def _get_graph(while_op, func_attr_name, attr_graph_name):
603 """Returns `FuncGraph` for the given function attribute.
605 Args:
606 while_op: The While Operation.
607 func_attr_name: string
608 attr_graph_name: cached forward graph name
610 Returns:
611 `FuncGraph`
612 """
613 func_graph = getattr(while_op, attr_graph_name, None)
614 if func_graph is None:
615 # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
616 input_shapes = [
617 tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
618 ]
619 func_name = while_op.get_attr(func_attr_name).name
620 func_graph = util.get_func_graph(while_op, input_shapes, func_name)
621 func_graph._while = while_op
622 return func_graph
625def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
626 maximum_iterations):
627 """Builds and returns the gradient FuncGraph of `func_graph` and its args.
629 The returned grad_func_graph must be called with the returned
630 args + grad_func_graph.captures.
632 Args:
633 ys: A `Tensor` or list of tensors to be differentiated.
634 xs: A `Tensor` or list of tensors to be used for differentiation.
635 grads: The incoming grads for `ys`.
636 cond_graph: FuncGraph for the forward cond function.
637 body_graph: FuncGraph for the forward body function.
638 name: Name of the returned gradient function.
639 while_op: The forward While op.
640 maximum_iterations: Tensor. The maximum number of iterations.
642 Returns:
643 2-tuple of (grad_func_graph, args).
644 """
645 assert len(ys) == len(grads)
647 total_iters = while_op.outputs[0]
648 counter = constant_op.constant(
649 0, dtype=total_iters.dtype, name="grad_counter")
651 # Build frozen sets so that we do not have linear time lookups in
652 # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
653 # may get updated during gradient computation because we add accumulators to
654 # the forward op. However, those are not loop invariants so wouldn't affect
655 # the output of `_is_loop_invariant`. Also we would never attempt to capture
656 # those accumulators so `_is_loop_invariant` should never receive those new
657 # tensors as args.
658 body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
659 body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)
661 args = [counter, maximum_iterations, total_iters] + list(grads)
662 # Note: The returned function does not have `args` in the list of
663 # `external_captures`.
664 grad_func_graph = func_graph_module.func_graph_from_py_func(
665 name,
666 lambda *args: _grad_fn(ys, xs, args, body_graph),
667 args, {},
668 func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
669 maximum_iterations, while_op,
670 body_graph_inputs, body_graph_outputs))
672 # Update the list of outputs with tensors corresponding to the captured
673 # tensors. We capture 3 types of tensors when building the grad fn:
674 # 1. Accumulators for forward graph intermediates which are not loop
675 # invariants. The outputs corresponding to these are populated in
676 # `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
677 # 2. Resources, which are output as is.
678 # 3. Forward graph loop invariants, which are output as is.
679 for external_capture, internal_capture in grad_func_graph.captures:
680 if (ops.tensor_id(internal_capture)
681 in grad_func_graph.internal_capture_to_output):
682 new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
683 internal_capture)]
684 else:
685 raise ValueError(
686 f"Tensor {str(internal_capture)} which captures "
687 f"{str(external_capture)} is in list of "
688 f"internal_captures but not in internal_capture_to_output.")
689 grad_func_graph.outputs.append(new_output)
690 grad_func_graph.structured_outputs.append(new_output)
692 return grad_func_graph, args
695def _grad_fn(ys, xs, args, func_graph):
696 """Computes the gradient of `func_graph` in the current graph.
698 This function builds the gradient graph of the corresponding forward-pass
699 `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
701 Args:
702 ys: A `Tensor` or list of tensors to be differentiated.
703 xs: A `Tensor` or list of tensors to be used for differentiation.
704 args: The input arguments.
705 args[0] - Loop counter
706 args[1] - maximum_iterations.
707 args[2] - Total number of iterations.
708 args[3:] - Incoming gradients for `ys`.
709 func_graph: function.FuncGraph. The corresponding forward-pass function.
711 Returns:
712 The output gradient Tensors.
713 """
714 grad_ys = args[3:]
716 # Build the gradient graph. Note that this builds the gradient computation of
717 # func_graph in the current graph, which requires capturing tensors from
718 # func_graph. The captured func_graph tensors are resolved to external tensors
719 # after the forward While op has been rewritten in _resolve_grad_captures.
720 # TODO(srbs): Mark GradientsHelper as public?
721 grad_outs = gradients_util._GradientsHelper(
722 ys, xs, grad_ys=grad_ys, src_graph=func_graph,
723 unconnected_gradients="zero")
725 # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
726 # is a tf.StopGradient in the loop body.
727 assert all(g is not None for g in grad_outs)
728 counter = args[0]
729 maximum_iterations = args[1]
730 total_iters = args[2]
731 return [counter + 1, maximum_iterations, total_iters] + grad_outs
734def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
735 """Returns the tensors to pass as captured inputs to `body_grad_graph`.
737 `body_grad_graph` may have external references to:
738 1. Its outer graph containing the input gradients. These are left as-is.
739 2. Accumulators captured from the forward-pass graph. These should have been
740 added as `while_op` outputs after the gradient graph was built. We replace
741 these with the corresponding output of `while_op`, i.e. a tensor in
742 `body_graph.outer_graph`. In the case of nested control flow or functions,
743 the gradient logic handling `body_grad_graph.outer_graph` will make sure
744 the tensor from `body_graph.outer_graph` is also correctly captured.
746 Args:
747 body_graph: FuncGraph. The forward-pass body function.
748 body_grad_graph: FuncGraph. The body gradients function.
749 while_op: The forward-pass While Operation calling `body_graph`.
751 Returns:
752 A list of input tensors to be passed as the captured inputs to
753 `body_grad_graph`.
754 """
755 new_capture_inputs = []
756 for t in body_grad_graph.external_captures:
757 # Resolve tensors captured from the forward graph to the outputs of the
758 # forward while_op.
759 if t.graph == body_graph:
760 # Captured accumulator or loop invariant.
761 for i, output in enumerate(t.graph.outputs):
762 if output is t:
763 t = while_op.outputs[i]
764 break
766 # Note: We rely on the capturing logic of the gradient While op graph to
767 # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
768 # and while_v2 handle this while building their gradient functions.
769 assert t.graph == body_graph.outer_graph
771 new_capture_inputs.append(t)
772 return new_capture_inputs
775def _get_structured_grad_output(outputs, grads, body_grad_graph):
776 """Returns the values that should be returned from the while grad function.
778 Args:
779 outputs: the raw Tensor outputs of the grad While op.
780 grads: the input gradients to the gradient function.
781 body_grad_graph: _WhileBodyGradFuncGraph.
783 Returns:
784 A list of gradient values. May include Nones.
785 """
786 result = []
787 # outputs[0] is the loop counter.
788 # outputs[1] is maximum_iterations.
789 # outputs[2] is the total number of loop iterations.
790 outputs_idx = 3
791 structured_outputs_idx = 3
792 for g in grads:
793 # Set None as the output gradient for tensors with None input gradient.
794 if g is None:
795 result.append(None)
796 continue
797 output = body_grad_graph.structured_outputs[structured_outputs_idx]
798 structured_outputs_idx += 1
799 if isinstance(output, indexed_slices.IndexedSlices):
800 # TODO(skyewm): is there a more robust way to determine the order of
801 # flattened IndexedSlices components?
802 result.append(indexed_slices.IndexedSlices(
803 values=outputs[outputs_idx],
804 indices=outputs[outputs_idx + 1],
805 dense_shape=outputs[outputs_idx + 2]))
806 outputs_idx += 3
807 else:
808 assert isinstance(output, ops.Tensor)
809 result.append(outputs[outputs_idx])
810 outputs_idx += 1
812 return result
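# Example (illustrative): with grads = [g_dense, None, g_slices], and assuming
# the structured output matching g_slices is an IndexedSlices, the result is
#
#   [outputs[3], None, IndexedSlices(values=outputs[4], indices=outputs[5],
#                                    dense_shape=outputs[6])]
#
# i.e. each IndexedSlices gradient consumes three flat outputs while a dense
# gradient consumes one, and None input grads pass through as None.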
815def _get_accumulator(tensor):
816 r"""Returns TensorList if any containing accumulated values of tensor.
818 We try to find a pattern of the form:
820 input_tl tensor
821 \ /
822 (TensorListPushBack)
823 |
824 output_tl
826 which satisfies the following conditions:
828 1. input_tl must be in tensor.graph.inputs.
829 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
830 3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).
832 output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
833 returned if such a pattern is found else None is returned.
835 Args:
836 tensor: The Tensor to be accumulated.
838 Returns:
839 A variant tensor in the same graph as `tensor` or None if no accumulator is
840 found.
841 """
842 assert isinstance(tensor.graph, func_graph_module.FuncGraph)
844 def get_func_graph_output(t):
845 """Returns t or Identity(t) whichever exists in graph outputs else None."""
846 for output in tensor.graph.outputs:
847 if output is t:
848 return t
849 # tf.defun adds an Identity for each output, check whether that is the case.
850 identity_op = t.consumers()[0]
851 if (identity_op.type == "Identity" and
852 any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
853 return identity_op.outputs[0]
854 return None
856 for consumer in tensor.consumers():
857 # Find the consumer that is a TensorListPushBack node whose TensorList input
858 # is in the list of function inputs.
859 if consumer.type != "TensorListPushBack":
860 continue
862 accum_input_idx = -1
863 for accum_input_idx, inp in enumerate(tensor.graph.inputs):
864 if inp is consumer.inputs[0]:
865 break
866 else:
867 continue
869 output = get_func_graph_output(consumer.outputs[0])
870 if output is None:
871 # The TensorList output of `consumer` is not in the list of function
872 # outputs.
873 continue
875 for accum_output_idx, out in enumerate(tensor.graph.outputs):
876 if out is output:
877 if accum_input_idx == accum_output_idx:
878 return output
879 break
881 return None
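# Example (illustrative sketch of the pattern _get_accumulator matches, with
# hypothetical names): inside a forward body FuncGraph,
#
#   acc_in = graph.inputs[i]                                 # a TensorList
#   acc_out = list_ops.tensor_list_push_back(acc_in, tensor)
#   graph.outputs[i] = acc_out     # or Identity(acc_out), at the same index i
#
# makes _get_accumulator(tensor) return acc_out (or its Identity); if no such
# matched push-back/output pair exists it returns None.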
884OptimizedReductionOpsCacheKey = collections.namedtuple(
885 "OptimizedReductionOpsCacheKey", [
886 "op_type",
887 "inputs",
888 "dtypes",
889 "input_types",
890 "name",
891 "attrs",
892 "op_def",
893 "compute_device",
894 ])
897class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
898 """FuncGraph for the gradient function of the body of a While op.
900 Contains the logic for capturing the tensors from the body of the forward
901 While op which is as follows:
902 1. If the tensor is of resource type (these are not accumulated):
903 a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
904 inputs and outputs at the same index.
905 b. Lookup the corresponding resource tensor in the forward outer graph and
906 try to capture that.
907 2. If the tensor is not of resource type:
908 a. Create an accumulator for that tensor and output it from the forward
909 pass. Note this also requires adding it as an input to the forward pass.
910 b. Capture the accumulator from the forward pass in this FuncGraph. This
911 will later be resolved to the correct output of the forward While op.
912 c. Pop a value from the captured placeholder and use it as the captured
913 value for the forward pass tensor.
915 This only allows capturing tensors in the forward graph. A ValueError is
916 raised if an attempt is made to capture a tensor not in the forward graph.
917 To manually capture a tensor that is not in the forward graph, call `capture`
918 with `allowlisted=True`.
920 Note: The `captures` dict does not contain the forward tensor since it is not
921 directly captured. It contains the accumulator corresponding to this forward
922 tensor.
924 Attributes:
925 while_op_needs_rewrite: True if any non-resource intermediates were
926 captured, meaning the forward While op needs to be rewritten to output the
927 corresponding accumulators.
928 extra_inputs: list of EmptyTensorList tensors to be used as initial input to
929 the new accumulators in the forward graph. It may also contain external
930 captures of the custom gradient function.
931 internal_capture_to_output: dict from a tensor_id(captured placeholder) to
932 the corresponding tensor that needs to be added to the list of outputs.
933 For instance, when capturing an accumulator TensorList this contains the
934 TensorList obtained after popping a tensor from the list. Other entries
935 in this dict are expected, though not enforced, to be identities.
936 This dict is needed because these output tensors need to be added to
937 FuncGraph.outputs "after" the tensors returned from the gradient function.
938 """
940 def __init__(self, name, forward_cond_graph, forward_body_graph,
941 maximum_iterations, forward_while_op, body_graph_inputs,
942 body_graph_outputs):
943 super(_WhileBodyGradFuncGraph, self).__init__(name)
944 self.extra_inputs = []
945 self.internal_capture_to_output = {}
946 # FuncGraph for the body of the forward While op.
947 self._forward_graph = forward_body_graph
948 # FuncGraph for the cond of the forward While op.
949 self._forward_cond_graph = forward_cond_graph
950 self._maximum_iterations = maximum_iterations
951 self._forward_while_op = forward_while_op
952 # Dict from forward intermediate tensor to its indirectly captured tensor
953 # in this graph. Indirect capturing happens in two ways:
954 # 1. For non-resource tensors we capture their accumulators from the forward
955 # outer graph and pop values from that accumulator inside this graph
956 # using TensorListPopBack.
957 # 2. For resource tensors we directly capture their corresponding tensor
958 # in the forward outer graph.
959 self._indirect_captures = {}
961 @property
962 def while_op_needs_rewrite(self):
963 return self.extra_inputs
965 def _create_op_internal(
966 self,
967 op_type,
968 inputs,
969 dtypes=None, # pylint: disable=redefined-outer-name
970 input_types=None,
971 name=None,
972 attrs=None,
973 op_def=None,
974 compute_device=True):
975 # For a reduction op, if op is in the gradient body graph and its input is
976 # from the forward graph, moving op to the forward graph means we would
977 # store the tensor after the reduction as opposed to the tensor before
978 # reduction, and therefore could significantly reduce memory consumption.
979 # For now, we do this only for a few ops.
980 #
981 # We don't do this if any input tensor has already been accumulated. This
982 # can happen if we output all intermediates in the forward pass.
983 #
984 # If in XLA context, do not move constant ops to forward pass as pushing to
985 # and popping from a TensorList removes the constant property of an op and
986 # breaks XLA compilation, which requires certain inputs to be compile-time
987 # constant for certain ops.
988 #
989 # This optimization is currently also disabled when under a persistent tape,
990 # since it leads to an unbounded number of side outputs. With caching it may
991 # be possible to re-enable it.
992 optimized_reduction_ops = {
993 "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
994 }
995 if (op_type in optimized_reduction_ops and
996 not util.output_all_intermediates() and
997 all(input.graph is self._forward_graph for input in inputs) and
998 all(_get_accumulator(input) is None for input in inputs) and
999 not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
1000 not util.graph_wrapped_for_higher_order_tape_gradients(
1001 self._forward_graph)):
1002 return self._move_op_to_forward_graph(
1003 op_type,
1004 inputs,
1005 dtypes=dtypes,
1006 input_types=input_types,
1007 name=name,
1008 attrs=attrs,
1009 op_def=op_def,
1010 compute_device=compute_device)
1012 return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
1013 op_type,
1014 inputs,
1015 dtypes=dtypes,
1016 input_types=input_types,
1017 name=name,
1018 attrs=attrs,
1019 op_def=op_def,
1020 compute_device=compute_device)
1022 def _move_op_to_forward_graph(
1023 self,
1024 op_type,
1025 inputs,
1026 dtypes=None, # pylint: disable=redefined-outer-name
1027 input_types=None,
1028 name=None,
1029 attrs=None,
1030 op_def=None,
1031 compute_device=True):
1032 # We have a cache of reduction ops that have already been moved to the
1033 # forward graph, and we will check it first to avoid moving an op twice.
1034 if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
1035 self._forward_graph._optimized_reduction_ops_cache = {}
1036 cache_key = self._get_optimized_reduction_ops_cache_key(
1037 op_type, inputs, dtypes, input_types, name, attrs, op_def,
1038 compute_device)
1039 cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
1040 cache_key)
1041 if cached_op is not None:
1042 # This op has already been moved to the forward graph and we have it in
1043 # the cache.
1044 return cached_op
1046 with self._forward_graph.as_default():
1047 # `name` was built using name_scope stack of gradient graph and may not
1048 # be unique in the forward graph. `Graph.create_op` does not uniquify
1049 # names which are name scopes i.e. end in `/`. To ensure that the op
1050 # created gets a unique name in the forward graph we get rid of the
1051 # trailing slash.
1052 name = ops.name_from_scope_name(name)
1053 result = self._forward_graph._create_op_internal(
1054 op_type,
1055 inputs,
1056 dtypes=dtypes,
1057 input_types=input_types,
1058 name=name,
1059 attrs=attrs,
1060 op_def=op_def,
1061 compute_device=compute_device)
1063 # Store the op we just moved to the forward graph so that it does
1064 # not need to be added there again.
1065 self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
1066 return result
1068 def _get_optimized_reduction_ops_cache_key(
1069 self,
1070 op_type,
1071 inputs,
1072 dtypes=None, # pylint: disable=redefined-outer-name
1073 input_types=None,
1074 name=None,
1075 attrs=None,
1076 op_def=None,
1077 compute_device=True):
1078 # We need all elements of CacheKey to be hashable.
1079 inputs = tuple(map(lambda t: t.ref(), inputs))
1081 if dtypes is not None:
1082 dtypes = tuple(dtypes)
1084 if input_types is not None:
1085 input_types = tuple(input_types)
1087 if attrs is not None:
1088 hashable_attrs = []
1089 for attr_name, attr_value in sorted(attrs.items()):
1090 hashable_attrs.append((attr_name, attr_value.SerializeToString()))
1091 attrs = tuple(hashable_attrs)
1093 if op_def is not None:
1094 op_def = op_def.SerializeToString()
1096 return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
1097 name, attrs, op_def, compute_device)
1099 def _capture_helper(self, tensor, name):
1100 """Implements the capturing described in the class docstring."""
1101 captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1102 if captured_tensor is not None:
1103 return captured_tensor
1105 if tensor.graph is not self._forward_graph:
1106 already_captured = id(tensor) in self.function_captures.by_val_internal
1107 captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1108 tensor, name)
1109 if not already_captured:
1110 # Adds the captured tensor to the list of outputs so that the input
1111 # and output signatures match.
1112 self.internal_capture_to_output[ops.tensor_id(
1113 captured_tensor)] = captured_tensor
1114 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1115 return captured_tensor
1117 while tensor.op.type == "Identity":
1118 # We do not accumulate the output of identity nodes so we try to capture
1119 # the input of the Identity node instead.
1120 tensor = tensor.op.inputs[0]
1122 captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1123 if captured_tensor is not None:
1124 return captured_tensor
1126 # No need to accumulate loop invariants. Capture them directly.
1127 # The captured tensor gets resolved to the corresponding while output in
1128 # `_resolve_grad_captures`.
1129 if _is_loop_invariant(tensor, self._forward_graph.inputs,
1130 self._forward_graph.outputs):
1131 captured_tensor = super(_WhileBodyGradFuncGraph,
1132 self)._capture_helper(tensor, name)
1133 # Add to `internal_capture_to_output` so that this gets added to the list
1134 # of outputs.
1135 self.internal_capture_to_output[ops.tensor_id(
1136 captured_tensor)] = captured_tensor
1137 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1138 return captured_tensor
1140 # Do not accumulate Const nodes. Instead copy them directly in the backward
1141 # graph.
1142 # TODO(srbs): This just checks for `Const` nodes. Consider checking for
1143 # graph compile time consts in general.
1144 # TODO(srbs): Consider making this a loop input.
1145 if constant_op.is_constant(tensor):
1146 real_value = constant_op.constant(
1147 tensor_util.constant_value(tensor), dtype=tensor.dtype)
1148 self._indirect_captures[ops.tensor_id(tensor)] = real_value
1149 return real_value
1151 # Resource tensors are not accumulated and handled specially.
1152 if tensor.dtype == dtypes.resource:
1153 return self._resource_capture_helper(tensor)
1155 # Create or find an existing accumulator output for `tensor` in the forward
1156 # graph, and fetch from this accumulator in the gradient graph to get the
1157 # raw intermediate value.
1158 accumulator = _get_accumulator(tensor)
1159 if accumulator is None:
1160 # Create the initial empty tensor list.
1161 #
1162 # Note: We clear the control dependencies to avoid a cycle in case a
1163 # control tensor has an input path to an output of the forward While.
1164 #
1165 # E.g.:
1166 # x = tf.while_loop(...)
1167 # y = f(x)
1168 # with tf.control_dependencies([y]):
1169 # tf.gradients(y, x)
1170 #
1171 # Since the EmptyTensorList is fed back into the forward While, not
1172 # removing the control edge would cause a cycle.
1173 with self._forward_graph.outer_graph.as_default():
1174 with util.clear_control_inputs():
1175 tensor_list = list_ops.empty_tensor_list(
1176 element_dtype=tensor.dtype,
1177 element_shape=tensor.shape,
1178 max_num_elements=self._maximum_iterations,
1179 name=_build_accumulator_name(tensor))
1180 self.extra_inputs.append(tensor_list)
1182 # Push the intermediate tensor to the tensor list. This captures
1183 # `tensor_list`.
1184 with self._forward_graph.as_default():
1185 accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
1186 # Add the modified tensor list to the list of outputs. This output will be
1187 # all the accumulated values.
1188 self._forward_graph.outputs.append(accumulator)
1190 # Capture in the cond graph as well so the forward cond and body inputs
1191 # match.
1192 with self._forward_cond_graph.as_default():
1193 self._forward_cond_graph.capture(tensor_list)
1195 # Capture the accumulator tensor list in the gradient graph directly from
1196 # the forward graph -- we'll later modify this to capture the final list
1197 # output by the forward While op instead.
1198 captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1199 accumulator, name)
1201 # Pop the intermediate value from the tensor list in the gradient graph.
1202 new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
1203 captured_accumulator, element_dtype=tensor.dtype)
1205 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1206 self.internal_capture_to_output[ops.tensor_id(
1207 captured_accumulator)] = new_tensor_list
1208 return captured_tensor
1210 def _resource_capture_helper(self, tensor):
1211 """Returns the captured resource tensor.
1213 Resource-type tensors are not accumulated. If a resource tensor exists in
1214 the loop body it must either be a loop input or an output of a nested While
1215 op inside the loop body which had captured the external resource.
1217 Args:
1218 tensor: the external resource Tensor to be captured.
1220 Returns:
1221 Tensor in this graph.
1222 """
1223 assert tensor.dtype == dtypes.resource
1225 forward_graph_input_names = [t.name for t in self._forward_graph.inputs]
1226 forward_graph_name_to_opdef = {
1227 op.name: op.node_def for op in self._forward_graph.get_operations()}
1228 index = util.resource_input_index(
1229 tensor.name, forward_graph_input_names,
1230 forward_graph_name_to_opdef,
1231 self._forward_graph._functions)
1233 input_placeholder = self._forward_graph.inputs[index]
1234 tensor_in_outer_graph = self._forward_graph._while.inputs[index]
1236 assert input_placeholder.dtype == dtypes.resource
1237 assert tensor_in_outer_graph.dtype == dtypes.resource
1238 # This must be a loop invariant. However, infrastructure
1239 # (e.g. tf.vectorized_map) may insert identity nodes, function calls, conds,
1240 # etc. which take and return the resource tensor unmodified; this means that
1241 # the Python objects may differ.
1242 if index != util.resource_input_index(
1243 self._forward_graph.outputs[index].name, forward_graph_input_names,
1244 forward_graph_name_to_opdef,
1245 self._forward_graph._functions):
1246 raise AssertionError(
1247 f"Resource tensors must be loop invariants {tensor_in_outer_graph}")
1249 self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
1250 tensor_in_outer_graph)
1251 return self._indirect_captures[ops.tensor_id(tensor)]
1254def _check_shapes_compat(flat_output_tensors, flat_shape_invariants,
1255 flat_input_tensors):
1256 for (t, shape, input_t) in zip(flat_output_tensors, flat_shape_invariants,
1257 flat_input_tensors):
1258 if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
1259 raise ValueError(
1260 f"Input tensor `{input_t.name}` enters the loop with shape {shape}, "
1261 f"but has shape {t.shape} after one iteration. To allow the shape to "
1262 "vary across iterations, use the `shape_invariants` argument of "
1263 "tf.while_loop to specify a less-specific shape.")
1266def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
1267 """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
1268 assert len(cond_graph.inputs) == num_flattened_loop_vars, (
1269 "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
1270 num_flattened_loop_vars))
1271 assert len(cond_graph.outputs) == 1, (
1272 "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
1273 assert len(body_graph.inputs) == num_flattened_loop_vars, (
1274 "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
1275 num_flattened_loop_vars))
1276 assert len(body_graph.outputs) == num_flattened_loop_vars, (
1277 "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
1278 num_flattened_loop_vars))
1281def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
1282 for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
1283 flattened_loop_vars):
1284 if inp.dtype != out.dtype:
1285 raise TypeError(
1286 f"Loop var {loop_var.name} enters the loop with type {inp.dtype} "
1287 f"but has type {out.dtype} after 1 iteration. {loop_var.name} type "
1288 "should remain constant.")
1291def _build_cond_placeholders_name_prefix(cond_graph):
1292 return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")
1295def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
1296 """Creates placeholders for body captures in cond_graph.
1298 This is needed to match signatures of cond and body graphs.
1300 Args:
1301 cond_graph: cond branch graph
1302 body_graph_captures: Tensors which were captured when building the
1303 `body_graph`.
1304 """
1305 types = [t.dtype.as_datatype_enum for t in body_graph_captures]
1306 # TODO(srbs): Providing a unique prefix does not ensure that there is no
1307 # conflict between the placeholder names and existing nodes in the graph.
1308 # However passing a list of strings may not be performant.
1309 # Ideally we should move `Graph.unique_name` to C++ or make
1310 # `Graph._names_in_use` a trie so that we can find a unique prefix.
1311 # TODO(b/143286622): This should not be required once captures are separated
1312 # from regular loop vars.
1313 with cond_graph._c_graph.get() as c_graph:
1314 placeholders = c_api.TF_CreatePlaceholders(
1315 c_graph, types,
1316 compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
1317 placeholder_ops = [
1318 ops.Operation._from_c_op(ph.oper, cond_graph) for ph in placeholders
1319 ]
1321 tensors = []
1322 for op in placeholder_ops:
1323 tensors.append(op.outputs[0])
1325 # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
1326 # newly created placeholders.
1327 tuples = zip(body_graph_captures, tensors)
1328 keys = [id(t) for t in body_graph_captures]
1329 for k, v in zip(keys, tuples):
1330 cond_graph._function_captures.add_or_replace(
1331 key=k,
1332 external=v[0],
1333 internal=v[1],
1334 is_by_ref=False)
1335 cond_graph.inputs.extend(tensors)
1338def _copy_handle_data(src_tensors, tgt_tensors):
1339 for src_t, tgt_t in zip(src_tensors, tgt_tensors):
1340 handle_data_util.copy_handle_data(src_t, tgt_t)
1343def _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, loop_vars):
1344 """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""
1346 def flow_to_tensor_array(flow, ta): # pylint: disable=missing-docstring
1347 return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance( # pylint: disable=g-long-ternary
1348 ta, tensor_array_ops.TensorArray) else flow)
1350 flattened_loop_vars = [
1351 flow_to_tensor_array(*z)
1352 for z in zip(nest.flatten(loop_vars, expand_composites=True),
1353 flat_orig_loop_vars)
1354 ]
1355 return nest.pack_sequence_as(loop_vars_signature, flattened_loop_vars,
1356 expand_composites=True)
1359def _tensor_array_to_flow(loop_vars):
1361 def f(maybe_ta):
1362 if isinstance(maybe_ta, tensor_array_ops.TensorArray):
1363 return maybe_ta.flow
1364 return maybe_ta
1366 return nest.map_structure(f, loop_vars, expand_composites=True)
1369def _build_maximum_iterations_loop_var(maximum_iterations):
1370 if maximum_iterations is None:
1371 # Default value of max_num_elements for EmptyTensorList, meaning that the
1372 # list size is unbounded.
1373 maximum_iterations = -1
1374 # EmptyTensorList expects `max_num_elements` to be of type int32.
1375 return ops.convert_to_tensor(
1376 maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
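# Example (illustrative): _build_maximum_iterations_loop_var(None) yields an
# int32 tensor holding -1 (the "unbounded" sentinel for EmptyTensorList), while
# _build_maximum_iterations_loop_var(10) yields an int32 tensor holding 10,
# named "maximum_iterations".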
1379def _build_accumulator_name(tensor):
1380 # Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
1381 return "{}/accumulator".format(tensor.name).replace(":", "_")
1384def _is_loop_invariant(tensor, inputs, outputs):
1385 return (any(tensor is t for t in inputs) and
1386 any(tensor is t for t in outputs))
1389def _set_read_only_resource_inputs_attr(op, branch_graphs):
1390 """Sets the list of resource inputs which are read-only.
1392 This is used by AutomaticControlDependencies.
1394 Args:
1395 op: While Operation.
1396 branch_graphs: List of branch FuncGraphs.
1397 """
1398 read_only_indices = set(range(len(op.inputs)))
1399 for branch_graph in branch_graphs:
1400 if not read_only_indices:
1401 break
1402 branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
1403 branch_graph)
1404 read_only_indices = read_only_indices.intersection(branch_read_only_indices)
1406 ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1407 sorted(read_only_indices))
1409# pylint: enable=protected-access