Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py: 14%
495 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""cond_v2 and gradient.
17This is a version of cond that emits a single If op, as well as the gradient
18function for If ops produced by cond_v2. This will eventually replace the
19current tf.cond implementation once it reaches feature and performance parity.
20"""
22import collections
24from tensorflow.core.framework import types_pb2
25from tensorflow.python.eager import backprop_util
26from tensorflow.python.framework import auto_control_deps
27from tensorflow.python.framework import auto_control_deps_utils as acd
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors_impl
31from tensorflow.python.framework import func_graph as func_graph_module
32from tensorflow.python.framework import indexed_slices
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import control_flow_util
39from tensorflow.python.ops import control_flow_util_v2 as util
40from tensorflow.python.ops import default_gradient
41from tensorflow.python.ops import gen_functional_ops
42from tensorflow.python.ops import gen_optional_ops
43from tensorflow.python.ops import gradients_util
44from tensorflow.python.ops import handle_data_util
45from tensorflow.python.ops import math_ops
46from tensorflow.python.util import nest
49# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
50# that they aren't part of the official public API. These protected members
51# often need to be used by implementation code however. Rather than litter the
52# code with pylint comments, we ignore protected access violations for
53# readability.
54# pylint: disable=protected-access
56_COND = 1
57_CASE = 2
60def cond_v2(pred, true_fn, false_fn, name="cond"):
61 """Like tf.cond, except emits a single If op."""
62 if isinstance(pred, bool):
63 raise TypeError("pred must not be a Python bool", pred)
65 if not name:
66 name = "cond"
68 with ops.name_scope(name) as scope:
69 true_name = util.unique_fn_name(scope, "true")
70 false_name = util.unique_fn_name(scope, "false")
72 # Automatic control dependencies are added in defuns, but not in v1
73 # graphs. Propagate that behavior here.
74 add_control_dependencies = ops.get_default_graph()._add_control_dependencies
75 pred = ops.convert_to_tensor(pred)
76 if (tensor_util.is_tf_type(pred) and
77 (pred.shape.dims is None or pred.shape.dims)):
78 pred = array_ops.squeeze_v2(pred)
80 true_graph = func_graph_module.func_graph_from_py_func(
81 true_name,
82 true_fn, [], {},
83 func_graph=util.CondBranchFuncGraph(
84 true_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
85 add_control_dependencies=add_control_dependencies,
86 op_return_value=pred)
87 false_graph = func_graph_module.func_graph_from_py_func(
88 false_name,
89 false_fn, [], {},
90 func_graph=util.CondBranchFuncGraph(
91 false_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
92 add_control_dependencies=add_control_dependencies,
93 op_return_value=pred)
95 verify_captures(_COND, [true_graph, false_graph])
96 return _build_cond(
97 pred,
98 true_graph,
99 false_graph,
100 true_graph.external_captures,
101 false_graph.external_captures,
102 building_gradient=False,
103 name=scope)
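# --- Editor's illustrative sketch (not part of the original module) -----------
# With v2 control flow enabled, the public `tf.cond` API dispatches to
# `cond_v2` above: both branch callables are traced into FuncGraphs and a
# single If/StatelessIf op is emitted. The helper below is hypothetical and
# exists only to document intended usage; it is never called by this module.
def _example_cond_v2_usage():
  import tensorflow as tf  # local import; assumed available when called
  x = tf.constant(2.0)
  y = tf.constant(5.0)
  # Both lambdas are traced into branch graphs; the emitted op picks one
  # branch at runtime based on the boolean predicate tensor.
  return tf.cond(x < y, lambda: x + y, lambda: x * y)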
106@ops.RegisterGradient("StatelessIf")
107@ops.RegisterGradient("If")
108def _IfGrad(op, *grads): # pylint: disable=invalid-name
109 """The gradient of an If op produced by cond_v2."""
110 # Get the if operator (this logic handles the case where op is a MockOp)
111 if_op = op.outputs[0].op
112 true_graph, false_graph = get_func_graphs(if_op)
113 # Note: op.graph != ops.get_default_graph() when we are computing the gradient
114 # of a nested cond.
115 assert true_graph.outer_graph == if_op.graph
116 assert false_graph.outer_graph == if_op.graph
118 # Create grad functions that compute the gradient of the true/false forward
119 # graphs. These functions will capture tensors from the forward pass
120 # functions.
121 true_grad_graph = _create_grad_func(
122 true_graph, grads, util.unique_grad_fn_name(true_graph.name))
123 false_grad_graph = _create_grad_func(
124 false_graph, grads, util.unique_grad_fn_name(false_graph.name))
126 # Replaces output None grads with zeros if at least one branch has non-None
127 # grad at that index.
128 _create_zeros_for_none_grads([true_graph, false_graph],
129 [true_grad_graph, false_grad_graph])
131 if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
132 # Modify 'op' to output the intermediates needed by the grad functions. Note
133 # that all needed intermediates are wrapped in optionals. Each optional
134 # intermediate output will have a value iff its corresponding branch is
135 # taken.
136 # NOTE(skyewm): if there are any active sessions, this modification to `op`
137 # may make them unrunnable!
139 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
140 # XLA does not yet support optionals, so output intermediates directly and
141 # make them match via FakeParams, which can be converted to zeros in XLA.
142 # TODO(skyewm,jpienaar): can XLA support optionals?
143 true_intermediates = true_grad_graph.xla_intermediates
144 false_intermediates = false_grad_graph.xla_intermediates
145 extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
146 [true_graph, false_graph], [true_intermediates, false_intermediates])
147 else:
148 true_intermediates = true_grad_graph.wrapped_intermediates
149 false_intermediates = false_grad_graph.wrapped_intermediates
150 # Make outputs match by adding none optionals.
151 extra_true_outputs, extra_false_outputs = _make_intermediates_match(
152 [true_graph, false_graph], [true_intermediates, false_intermediates])
154 true_graph.outputs.extend(extra_true_outputs)
155 false_graph.outputs.extend(extra_false_outputs)
156 # TODO(skyewm): indicate it's an internal bug if this fails.
157 _check_same_outputs(_COND, [true_graph, false_graph])
159 true_graph.name += "_rewritten"
160 false_graph.name += "_rewritten"
162 if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
163 if_op._set_func_attr("else_branch",
164 util.create_new_tf_function(false_graph))
165 if_op._set_type_list_attr("Tout", true_graph.output_types)
166 if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
167 if_op._add_outputs(
168 [t.dtype for t in extra_true_outputs],
169 [t.shape for t in extra_true_outputs])
171 # Resolve references to forward graph tensors in grad graphs and ensure
172 # they are in-scope, i.e., belong to one of the outer graphs of the grad graph.
173 true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
174 false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
176 # This modifies true_grad_graph and false_grad_graph.
177 _make_output_composite_tensors_match(_COND,
178 [true_grad_graph, false_grad_graph])
180 outputs = _build_cond(
181 if_op.inputs[0],
182 true_grad_graph,
183 false_grad_graph,
184 true_grad_inputs,
185 false_grad_inputs,
186 building_gradient=True,
187 )
189 # The predicate has no gradient.
190 return [None] + outputs
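# --- Editor's illustrative sketch (not part of the original module) -----------
# `_IfGrad` above is what fires when a gradient is requested through a cond:
# each branch's forward FuncGraph is differentiated and a gradient If op is
# built, with None returned for the predicate. The helper below is
# hypothetical and only illustrates how that path is typically exercised.
def _example_cond_gradient():
  import tensorflow as tf  # local import; assumed available when called
  x = tf.Variable(3.0)
  with tf.GradientTape() as tape:
    # The gradient flows through whichever branch the If op takes at runtime.
    y = tf.cond(x > 0.0, lambda: tf.square(x), lambda: -x)
  return tape.gradient(y, x)  # 2 * x == 6.0 when the true branch is taken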
193def _build_cond(pred,
194 true_graph,
195 false_graph,
196 true_inputs,
197 false_inputs,
198 building_gradient,
199 name=None):
200 """Creates an If op from the specified predicate, branch functions and inputs.
202 Note that this modifies true_graph and false_graph to make the inputs match,
203 and to output all intermediate values so they're available for the gradient
204 computation.
206 true_graph and false_graph need not have the same input types, but they must
207 have the same output types.
209 Args:
210 pred: boolean Tensor
211 true_graph: FuncGraph
212 false_graph: FuncGraph
213 true_inputs: a list of Tensors to be passed to true_graph as input.
214 false_inputs: a list of Tensors to be passed to false_graph as input.
215 building_gradient: Whether this is a gradient If op.
216 name: the name for the If op.
218 Returns:
219 A list of Tensors which are the outputs of the If op. Does not include added
220 intermediate outputs.
221 """
222 _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
223 _check_same_outputs(_COND, [true_graph, false_graph])
225 # Add inputs to true_graph and false_graph to make them match. Note that
226 # this modifies true_graph and false_graph.
227 cond_inputs = _make_inputs_match([true_graph, false_graph],
228 [true_inputs, false_inputs])
229 # We do not output intermediates of the gradient If op since this is just
230 # for backwards compatibility with existing code.
231 if not building_gradient and util.output_all_intermediates():
232 # Add all intermediate tensors as function outputs so they're available for
233 # the gradient computation. Since the outputs of the two functions must
234 # match, we wrap all the intermediates in optionals. Each intermediate
235 # output will have a value iff its corresponding branch is taken.
237 true_intermediates = _get_intermediates(true_graph)
238 false_intermediates = _get_intermediates(false_graph)
240 # Wrap intermediates in optionals.
241 wrapped_true_intermediates = _wrap_intermediates(true_graph,
242 true_intermediates)
243 wrapped_false_intermediates = _wrap_intermediates(false_graph,
244 false_intermediates)
246 # Make outputs match by adding none optionals.
247 extra_true_outputs, extra_false_outputs = _make_intermediates_match( # pylint: disable=unbalanced-tuple-unpacking
248 [true_graph, false_graph],
249 [wrapped_true_intermediates, wrapped_false_intermediates])
251 true_graph.outputs.extend(extra_true_outputs)
252 false_graph.outputs.extend(extra_false_outputs)
253 _check_same_outputs(_COND, [true_graph, false_graph])
255 # Create the If op.
256 with ops.control_dependencies(
257 list(true_graph.function_captures.control) + list(
258 false_graph.function_captures.control)):
259 true_stateful_ops = [
260 op for op in true_graph.get_operations() if op._is_stateful
261 ]
262 false_stateful_ops = [
263 op for op in false_graph.get_operations() if op._is_stateful
264 ]
265 if (true_stateful_ops or false_stateful_ops):
266 op_fn = gen_functional_ops._if
267 else:
268 op_fn = gen_functional_ops.stateless_if
270 def _make_op(inputs):
271 if_op, tensors = util.get_op_and_outputs(op_fn(
272 pred,
273 inputs, [t.dtype for t in true_graph.outputs],
274 util.create_new_tf_function(true_graph),
275 util.create_new_tf_function(false_graph),
276 output_shapes=_get_output_shapes(true_graph.outputs,
277 false_graph.outputs),
278 name=name))
279 _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
280 # `if_op` is None if this is a `StatelessIf` op with no outputs.
281 if if_op is not None:
282 # The true and false graphs have already been created, and we need that
283 # to happen before we know which tensors will be captured and so whether
284 # to wrap the cond in a tf.function. Post-hoc mutation of the branch
285 # `outer_graph` properties seems like the only option if we want to
286 # conditionally wrap in a function.
287 true_graph.outer_graph = ops.get_default_graph()
288 false_graph.outer_graph = ops.get_default_graph()
289 if_op._true_graph = true_graph
290 if_op._false_graph = false_graph
291 util.maybe_set_lowering_attr(if_op)
292 util.maybe_propagate_compile_time_consts_in_xla(if_op)
293 _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
294 # Prevent fetching since the variant outputs can't be fetched directly.
295 if_op.graph.prevent_fetching(if_op)
296 return tensors
297 tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)
299 # Return identities for each output of the If op, rather than the output of
300 # the If op directly. This makes pruning work if the output of cond() is
301 # fetched: the lowering pass converts the If outputs into IdentityN outputs,
302 # which if fetched will cause all ops in the taken branch to be run (since
303 # it takes all merge ops as input). After lowering, each output identity op
304 # will end up with only the appropriate merge op as input.
305 # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
306 # correct output structure
307 tensors = [array_ops.identity(t) for t in tensors]
309 structured_output_specs = _get_compatible_structured_output_specs(true_graph,
310 false_graph)
311 return _pack_sequence_as(structured_output_specs, tensors)
314def get_func_graphs(op):
315 """Returns `FuncGraph`s for the input op branches.
317 Args:
318 op: The If or Case Operation.
320 Returns:
321 A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
322 for Case).
323 """
325 def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None):
326 """Generates and returns a FuncGraph for the given branch."""
327 func_graph = None
328 if cached_attr_name is not None:
329 func_graph = getattr(op, cached_attr_name, None)
330 inputs = op.inputs[1:] # First input is pred.
331 if func_graph is None:
332 input_shapes = [t.shape for t in inputs]
333 func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
334 for external_t, internal_t in zip(inputs, func_graph.inputs):
335 handle_data_util.copy_handle_data(external_t, internal_t)
336 func_graph.function_captures.reset_captures(inputs, func_graph.inputs)
337 # Link the op so that the gradient code can use it.
338 func_graph._forward_cond = op
339 return func_graph
341 if op.type in ["If", "StatelessIf"]:
342 return (_get_func_graph_for_branch(
343 op.get_attr("then_branch"), "_true_graph"),
344 _get_func_graph_for_branch(
345 op.get_attr("else_branch"), "_false_graph"))
346 elif op.type in ["Case", "StatelessCase"]:
347 return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i))
348 for i, branch_fn in enumerate(op.get_attr("branches"))]
349 else:
350 raise ValueError("Unsupported op type: {}".format(op.type))
353def _get_compatible_structured_output_specs(true_graph, false_graph):
354 """Returns the most specific compatible specs of graph structured outputs."""
355 return nest.map_structure(_get_compatible_spec,
356 true_graph.structured_outputs,
357 false_graph.structured_outputs)
360def _get_compatible_spec(value_or_spec1, value_or_spec2):
361 """Returns the most specific compatible spec.
363 Args:
364 value_or_spec1: A TypeSpec or a value that has a defined TypeSpec.
365 value_or_spec2: A TypeSpec or a value that has a defined TypeSpec.
367 Returns:
368 The most specific compatible TypeSpec of the inputs.
370 Raises:
371 TypeError: If value_or_spec1 is not compatible with value_or_spec2.
372 """
373 spec1 = _get_spec_for(value_or_spec1)
374 spec2 = _get_spec_for(value_or_spec2)
376 # pylint: disable=protected-access
377 common = spec1._without_tensor_names().most_specific_common_supertype(
378 [spec2._without_tensor_names()])
379 if common is None:
380 raise TypeError(f"No common supertype of {spec1} and {spec2}.")
381 return common
384def _get_spec_for(value_or_spec):
385 """Returns TypeSpec of a value or itself if it is a TypeSpec already."""
386 if isinstance(value_or_spec, type_spec.TypeSpec):
387 return value_or_spec
388 return type_spec.type_spec_from_value(value_or_spec)
391def _grad_fn(func_graph, grads):
392 """The gradient function for each conditional branch.
394 This function builds the gradient graph of the corresponding forward-pass
395 conditional branch in `func_graph`. This is done by differentiating
396 func_graph's outputs w.r.t. its inputs.
398 Args:
399 func_graph: FuncGraph. The corresponding forward-pass function.
400 grads: The list of input gradient Tensors.
402 Returns:
403 The output gradient Tensors.
404 """
405 # Filter out untrainable function outputs.
406 # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
407 # cause _GradientsHelper to raise an exception (e.g. the implementation
408 # doesn't expect 'ys' to contain boolean tensors).
409 assert len(func_graph.outputs) == len(grads)
410 ys = []
411 grad_ys = []
412 for y, grad_y in zip(func_graph.outputs, grads):
413 if not backprop_util.IsTrainable(y):
414 continue
415 ys.append(y)
416 grad_ys.append(grad_y)
418 # Build the gradient graph. Note that this builds the gradient computation of
419 # func_graph in the current graph, which requires capturing tensors from
420 # func_graph. The captured func_graph tensors are resolved to external tensors
421 # in _resolve_grad_inputs.
422 result = gradients_util._GradientsHelper(
423 ys, func_graph.inputs, grad_ys=grad_ys,
424 src_graph=func_graph)
426 return result
429def _create_grad_func(func_graph, grads, name):
430 """Returns the FuncGraph representation of _grad_fn."""
431 return func_graph_module.func_graph_from_py_func(
432 name,
433 lambda: _grad_fn(func_graph, grads), [], {},
434 func_graph=_CondGradFuncGraph(name, func_graph))
437def _resolve_grad_inputs(cond_graph, grad_graph):
438 """Returns the tensors to pass as inputs to `grad_graph`.
440 The `grad_graph` may have external references to
441 1. Its outer graph containing the input gradients. These references are kept
442 as is.
443 2. Tensors in the forward pass graph. These tensors may not be "live"
444 when the gradient is being computed. We replace such references by their
445 corresponding tensor in `cond_graph.outer_graph`. In the case of nested
446 control flow or functions, the gradient logic handling
447 `grad_graph.outer_graph` will make sure the tensor from
448 `cond_graph.outer_graph` is also correctly captured.
450 Args:
451 cond_graph: FuncGraph. The forward-pass function.
452 grad_graph: FuncGraph. The gradients function.
454 Returns:
455 A list of input tensors to be passed to grad_graph.
456 """
457 new_inputs = []
459 for t in grad_graph.external_captures:
460 # `t` must either be in `grad_graph.outer_graph` or in the forward
461 # `cond_graph`.
462 if t.graph != grad_graph.outer_graph:
463 assert t.graph == cond_graph
464 # `internal_captures` are not treated as intermediates and hence not added
465 # to If op outputs. So we get the outer tensor corresponding to those
466 # from the list of `external_captures`.
467 for i, output in enumerate(t.graph.outputs):
468 if output is t:
469 t = t.graph._forward_cond.outputs[i]
470 break
471 else:
472 for i, output in enumerate(t.graph.internal_captures):
473 if output is t:
474 t = t.graph.external_captures[i]
475 break
476 else:
477 raise ValueError("Could not find external tensor capture {tensor} in "
478 "captures or outputs".format(tensor=t))
480 # Note: We rely on the capturing logic of the gradient If op graph to
481 # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
482 # and while_v2 handle this while building their gradient functions.
483 assert t.graph == cond_graph.outer_graph
484 new_inputs.append(t)
486 return new_inputs
489def _get_intermediates(func_graph):
490 """Returns intermediate tensors of `func_graph` for gradient computation."""
491 intermediates = []
492 for op in func_graph.get_operations():
493 for t in op.outputs:
494 if t in func_graph.inputs: continue
495 if t in func_graph.outputs: continue
496 if t.dtype is dtypes.resource:
497 continue
498 # Accumulating mutexes can cause deadlock.
499 if op.type == "MutexLock":
500 continue
501 intermediates.append(t)
502 return intermediates
505def _make_intermediates_match(branch_graphs, branch_optionals):
506 """Returns new optionals lists that have matching signatures.
508 This is done by mirroring each list in the other using none optionals.
509 There is no merging of like optionals.
511 Args:
512 branch_graphs: `list` of `FuncGraph`.
513 branch_optionals: `list` of `list`s of optional `Tensor`s from other
514 branch_graphs
516 Returns:
517 A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
518 same number of `Tensor`s, all of which will be optionals of the same
519 shape/type.
520 """
521 new_branch_optionals = []
522 # Since the intermediates are optionals with dtype variant, we only need
523 # enough room for the longest list of intermediates.
524 intermediates_size = max(len(o) for o in branch_optionals)
525 for i, branch_graph in enumerate(branch_graphs):
526 other_optionals = _create_none_optionals(
527 branch_graph, intermediates_size - len(branch_optionals[i]))
528 new_branch_optionals.append(branch_optionals[i] + other_optionals)
529 return new_branch_optionals
532def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
533 """Like _make_intermediates_match but for the XLA case."""
534 new_branch_intermediates = []
535 for i, branch_graph in enumerate(branch_graphs):
536 other_fakeparams = _create_fakeparams(
537 branch_graph,
538 sum((bi for bi in branch_intermediates
539 if bi is not branch_intermediates[i]), []))
540 num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
541 new_branch_intermediates.append(other_fakeparams[:num_preceding] +
542 branch_intermediates[i] +
543 other_fakeparams[num_preceding:])
544 return new_branch_intermediates
547def _make_inputs_match(branch_graphs, branch_inputs):
548 """Modifies branch_graphs so they have the same input signature.
550 This method reorders and/or adds parameters to each graph in branch_graphs so
551 they have the same input signature, and updates the 'inputs' and 'captured'
552 fields of each graph accordingly. It uses the input tensors from the outer
553 graph to avoid duplicating shared arguments.
555 Args:
556 branch_graphs: a `list` of `FuncGraph`
557 branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
558 inputs for the corresponding graph in `branch_graphs`.
560 Returns:
561 A new list of Tensors from the outer graph that are the new inputs for each
562 branch_graph. This is a deduped version of `sum(branch_inputs)`.
563 """
564 assert len(branch_graphs) == len(branch_inputs)
565 added_inputs = set()
566 new_inputs = []
567 for branch_in in branch_inputs:
568 for tensor in branch_in:
569 tensor_id = ops.tensor_id(tensor)
570 if tensor_id not in added_inputs:
571 added_inputs.add(tensor_id)
572 new_inputs.append(tensor)
574 for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
575 input_ids = [ops.tensor_id(t) for t in branch_in]
576 branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
577 input_list = []
578 for in_t in new_inputs:
579 param = branch_input_to_param.get(ops.tensor_id(in_t))
580 if param is None:
581 param = _create_dummy_input(branch_graph, in_t)
582 input_list.append(param)
584 branch_graph.inputs = input_list
586 # Rewrite the FuncGraphs' state to reflect the new inputs.
587 branch_graph.function_captures.reset_captures(
588 new_inputs, branch_graph.inputs)
590 return new_inputs
593def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
594 """Creates zeros for None out grads if at least one branch has non-None grad.
596 Args:
597 forward_graphs: List of forward FuncGraphs.
598 grad_graphs: List of grad FuncGraphs.
599 """
600 assert len(forward_graphs) == len(grad_graphs)
601 branch_outputs = [g.structured_outputs for g in grad_graphs]
602 num_outputs_per_branch = [len(outs) for outs in branch_outputs]
603 assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
604 for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
605 if (any(t is None for t in branch_outs) and
606 any(t is not None for t in branch_outs)):
607 for branch_index, t in enumerate(branch_outs):
608 if t is None:
609 with grad_graphs[branch_index].as_default():
610 zeros = default_gradient.zeros_like(
611 forward_graphs[branch_index].inputs[output_idx])
612 grad_graphs[branch_index].structured_outputs[output_idx] = zeros
614 for grad_graph in grad_graphs:
615 grad_graph.outputs = [
616 t for t in func_graph_module.flatten(grad_graph.structured_outputs)
617 if t is not None
618 ]
621def _make_output_composite_tensors_match(op_type, branch_graphs):
622 """Modifies each branch_graph's outputs to have the same output signature.
624 Currently the only transformation implemented is turning a Tensor into an
625 equivalent IndexedSlices if the other branch returns an IndexedSlices.
626 Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
627 branch_graphs.
629 Args:
630 op_type: _COND or _CASE
631 branch_graphs: `list` of `FuncGraph`
633 Raises:
634 TypeError: if a set of outputs cannot be rewritten.
635 """
636 # Note: since this is only used for gradient graphs, we do not expect the
637 # outputs to be structured (e.g. nested lists), and thus do not need to use
638 # nest.flatten, etc.
639 assert branch_graphs
640 branch_outputs = [g.structured_outputs for g in branch_graphs]
641 outputs_per_branch = list(len(outs) for outs in branch_outputs)
642 assert len(set(outputs_per_branch)) == 1, outputs_per_branch
644 for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
645 if len(set(type(out) for out in branch_outs)) == 1:
646 continue
647 if not any(
648 isinstance(out, indexed_slices.IndexedSlices) for out in branch_outs):
649 continue
650 for branch_idx, branch_out in enumerate(branch_outs):
651 if isinstance(branch_out, indexed_slices.IndexedSlices):
652 continue
653 elif isinstance(branch_out, ops.Tensor):
654 with branch_graphs[branch_idx].as_default():
655 branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
656 branch_out)
657 else:
658 raise TypeError(
659 "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
660 " outputs from all branches: {outputs}".format(
661 op_name="tf.cond" if op_type == _COND else "tf.switch_case",
662 output_idx=output_idx,
663 outputs=branch_outs))
665 for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
666 branch_graph.structured_outputs = branch_outs
667 branch_graph.outputs = [
668 t for t in func_graph_module.flatten(branch_outs) if t is not None
669 ]
672def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
673 """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
674 assert branch_graphs
675 # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
676 indexed_slice_indices = []
677 current_index = 0
678 # Note that this still contains Nones. We leave those in so that error
679 # messages contain the correct indices. We handle the Nones later when
680 # updating `current_index`.
681 branch_outputs_flat_with_composites = [
682 nest.flatten(branch_graph.structured_outputs, expand_composites=False)
683 for branch_graph in branch_graphs
684 ]
685 outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
686 assert len(set(outs_per_branch)) == 1, outs_per_branch
687 # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
688 for output_idx, branch_outs in enumerate(
689 zip(*branch_outputs_flat_with_composites)):
690 if len(
691 set(
692 isinstance(out, indexed_slices.IndexedSlices)
693 for out in branch_outs)) != 1:
694 raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
695 " branches returned: {outputs}".format(
696 op_name="cond" if op_type == _COND else "switch_case",
697 output_idx=output_idx,
698 outputs=branch_outs))
699 if isinstance(branch_outs[0], indexed_slices.IndexedSlices):
700 # indices is the second component of the composite tensor.
701 indexed_slice_indices.append(current_index + 1)
702 if nest.is_nested_or_composite(branch_outs[0]):
703 current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
704 elif branch_outs[0] is not None:
705 # `FuncGraph.outputs` does not contain Nones so no need to update the
706 # counter in that case.
707 current_index += 1
709 if not indexed_slice_indices:
710 return
712 # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
713 # the Nones.
714 if current_index != len(branch_graphs[0].outputs):
715 raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
716 "Expected: %i\n"
717 "Actual: %i" %
718 (current_index, len(branch_graphs[0].outputs)))
720 # Cast indices with mismatching types to int64.
721 for index in indexed_slice_indices:
722 if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
723 for bg in branch_graphs):
724 raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
725 "Found: %s" %
726 str([bg.outputs[index].dtype for bg in branch_graphs]))
727 if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
728 for branch_graph in branch_graphs:
729 if branch_graph.outputs[index].dtype == dtypes.int32:
730 with branch_graph.as_default():
731 branch_graph.outputs[index] = math_ops.cast(
732 branch_graph.outputs[index], dtypes.int64)
734 for branch_graph in branch_graphs:
735 branch_graph.structured_outputs = _pack_sequence_as(
736 branch_graph.structured_outputs, branch_graph.outputs)
739def _pack_sequence_as(structured_outputs, op_outputs):
740 """Packs the outputs of the gradient If/Case op.
742 The branch functions may contain None's in the list of `structured_outputs`.
743 `op_outputs` has those outputs missing. So we need to add those Nones to the
744 list of `op_outputs` and then pack it in the same structure as
745 `structured_outputs`.
747 Args:
748 structured_outputs: structured_outputs from one of the branch functions.
749 op_outputs: List of output tensors of the op.
751 Returns:
752 `op_outputs` packed like `structured_outputs`.
753 """
754 outputs_with_nones = []
755 counter = 0
756 for output in nest.flatten(structured_outputs, expand_composites=True):
757 if output is None:
758 outputs_with_nones.append(None)
759 else:
760 outputs_with_nones.append(op_outputs[counter])
761 counter += 1
762 return func_graph_module.pack_sequence_as(structured_outputs,
763 outputs_with_nones)
766def _wrap_intermediates(func_graph, intermediates):
767 with func_graph.as_default():
768 return [gen_optional_ops.optional_from_value([t]) for t in intermediates]
771def _create_dummy_input(func_graph, template_tensor):
772 """Creates tensors in func_graph to represent template_tensors.
774 Args:
775 func_graph: FuncGraph.
776 template_tensor: a tensor in the outer graph.
778 Returns:
779 A tensor in func_graph.
780 """
781 with func_graph.as_default():
782 return array_ops.placeholder(
783 template_tensor.dtype, shape=template_tensor.shape)
786def _create_none_optionals(func_graph, n):
787 """Creates `n` `None` optionals in func_graph.
789 Args:
790 func_graph: FuncGraph.
791 n: `int` the number of `None` optionals to make.
793 Returns:
794 A list of tensors in func_graph.
795 """
796 with func_graph.as_default():
797 return [gen_optional_ops.optional_none() for _ in range(n)]
800# TODO(b/265317139): remove this function and move this dynamic dimension
801# handling logic to XLA once XLA shape is ready for dynamic dimensions.
802def _convert_dynamic_dimension_to_zero(shape):
803 """Converts dynamic dimensions in `shape` to zero.
805 The fake params created to match the intermediates captured in other branches
806 could have dynamic dimensions. But XLA shapes cannot represent dynamic
807 dimensions in a TF TensorShape. Setting the dynamic dimensions to
808 size zero helps avoid failing safety checks in the bridge. When the XLA
809 DynamicConditional op reconciles branch differences, XLA replaces the
810 dimension of size 0 with a bounded dimension determined from the shape of
811 the real argument in the other branch.
813 Note: Rank unknown shapes are returned as they are.
815 Args:
816 shape: The TensorShape of fake param.
818 Returns:
819 The new TensorShape with dynamic dimensions set to zero.
820 """
821 if shape.rank is None:
822 return shape
824 return tensor_shape.TensorShape(
825 [0 if d is None else d for d in shape.as_list()]
826 )
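# Editor's worked example for the conversion above: a partially known shape
# keeps its static dimensions and zeroes out the dynamic ones, e.g.
#   _convert_dynamic_dimension_to_zero(tensor_shape.TensorShape([None, 3]))
# returns TensorShape([0, 3]), while a rank-unknown TensorShape(None) is
# returned unchanged.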
829def _create_fakeparams(func_graph, template_tensors):
830 """Creates FakeParams for the XLA case."""
831 with func_graph.as_default():
832 return [
833 gen_functional_ops.fake_param(
834 dtype=t.dtype, shape=_convert_dynamic_dimension_to_zero(t.shape))
835 for t in template_tensors]
838def _check_same_outputs(op_type, graphs):
839 """Raises an error if `graphs` have different outputs."""
841 def error(branch_idx, error_detail):
842 raise TypeError(
843 "{b0_name} and {bn_name} arguments to {op_name} must have the same "
844 "number, type, and overall structure of return values.\n"
845 "\n"
846 "{b0_name} output: {b0_out}\n"
847 "{bn_name} output: {bn_out}\n"
848 "\n"
849 "Error details:\n"
850 "{detail}".format(
851 b0_name="true_fn" if op_type == _COND else "branches[0]",
852 bn_name=("false_fn" if op_type == _COND else
853 "branches[{}]".format(branch_idx)),
854 op_name="tf.cond" if op_type == _COND else "tf.switch_case",
855 b0_out=graphs[0].structured_outputs,
856 bn_out=graphs[branch_idx].structured_outputs,
857 detail=error_detail))
859 for b in range(1, len(graphs)):
860 try:
861 nest.assert_same_structure(
862 graphs[0].structured_outputs,
863 graphs[b].structured_outputs,
864 expand_composites=True)
865 except (ValueError, TypeError) as e:
866 error(b, str(e))
868 op_type_str = "cond" if op_type == _COND else "case"
869 if len(graphs[0].outputs) != len(graphs[b].outputs):
870 raise ValueError("Lengths of branch outputs of {op_type} must match.\n"
871 "len(graphs[0].outputs): {len_0}\n"
872 "len(graphs[{b}].outputs): {len_b}\n".format(
873 op_type=op_type_str,
874 len_0=len(graphs[0].outputs),
875 b=b,
876 len_b=len(graphs[b].outputs)))
877 for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
878 if b0_out.dtype != bn_out.dtype:
879 error(b, "%s and %s have different types" % (b0_out, bn_out))
882def _get_output_shapes(*branch_graph_outputs):
883 output_shapes = []
884 for out_by_branch in zip(*branch_graph_outputs):
885 shape = out_by_branch[0].shape
886 for other_out in out_by_branch[1:]:
887 shape = shape.most_specific_compatible_shape(other_out.shape)
888 output_shapes.append(shape)
889 return output_shapes
892def _copy_handle_data(external_tensors, *branch_graph_outputs):
893 """Combines shapes in handle data and sets metadata on `external_tensors`."""
894 for tensors in zip(external_tensors, *branch_graph_outputs):
895 external = tensors[0]
896 internal = tensors[1:]
897 internal_handle_data = []
898 for tensor in internal:
899 handle_data = handle_data_util.get_resource_handle_data(tensor)
900 # NOTE: Assumes handle data has only one ShapeAndType entry. It's
901 # unclear how to combine different lengths across branches.
902 if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
903 break
904 internal_handle_data.append(handle_data)
905 else: # There is handle data, so we need to combine it.
906 combined_shape = tensor_shape.TensorShape(None)
907 combined_dtype = None
908 for handle_data in internal_handle_data:
909 handle_shape = tensor_shape.TensorShape(
910 handle_data.shape_and_type[0].shape)
911 combined_shape = combined_shape.most_specific_compatible_shape(
912 handle_shape)
913 if combined_dtype is None:
914 combined_dtype = handle_data.shape_and_type[0].dtype
915 elif handle_data.shape_and_type[0].dtype != combined_dtype:
916 # Variants from different branches have different dtypes. The
917 # combined variant has no static dtype.
918 combined_dtype = types_pb2.DT_INVALID
919 combined_handle_data = internal_handle_data[0]
920 combined_handle_data.shape_and_type[0].shape.CopyFrom(
921 combined_shape.as_proto())
922 combined_handle_data.shape_and_type[0].dtype = combined_dtype
923 handle_data_util.set_handle_data(external, combined_handle_data)
926def verify_captures(op_type, branch_graphs):
927 """Verify that a branch's tensor is not accessed in another branch fn."""
928 # Note: It is technically not possible for lower-branch_index branches to
929 # capture tensors from higher-branch_index branches, because of the order of
930 # branch graph construction, but we check all for completeness and to
931 # guard against potential future changes.
932 other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
933 for i, branch_graph in enumerate(branch_graphs):
934 for t in branch_graph.external_captures:
935 if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
936 branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
937 "branch {}".format(bi) for bi in range(len(branch_graphs))]
938 raise ValueError(
939 "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
940 tname=t.name,
941 b0name=branch_names[other_branch_graphs[t.graph]],
942 b1name=branch_names[i]))
945class _CondGradFuncGraph(util.CondBranchFuncGraph):
946 """FuncGraph for the gradient function of the branch of an If op.
948 Handles wrapping and unwrapping intermediate values that are captured by the
949 gradient computation in optionals.
951 Attributes:
952 op_needs_rewrite: True if any intermediates were captured, meaning the
953 forward If op needs to be rewritten to output the wrapped intermediates.
954 """
956 def __init__(self, name, forward_graph):
957 super(_CondGradFuncGraph, self).__init__(
958 name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access
959 self.op_needs_rewrite = False
960 self._forward_graph = forward_graph
961 # Maps from forward intermediate tensor -> the unwrapped captured
962 # intermediate.
963 self._indirect_captures = {}
964 # Maps unwrapped intermediate -> optional-wrapped intermediate in the
965 # forward graph.
966 self._wrapped_intermediates = collections.OrderedDict()
967 # Raw intermediates captured from the forward graph. Populated iff we're in
968 # an XLA context.
969 self._xla_intermediates = []
970 # Maps forward intermediate constant valued tensor's id to the constant
971 # created in this graph for that tensor.
972 self._captured_constants = {}
974 @property
975 def wrapped_intermediates(self):
976 """The optional-wrapped intermediates captured from the forward graph."""
977 return list(self._wrapped_intermediates.values())
979 @property
980 def xla_intermediates(self):
981 """Raw intermediates captured from the forward graph if XLA is enabled."""
982 return self._xla_intermediates
984 def _capture_helper(self, tensor, name):
985 if (tensor.graph is not self._forward_graph or
986 any(tensor is t for t in self._forward_graph.inputs) or
987 any(tensor is t for t in self._forward_graph.outputs)):
988 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
990 tensor_id = ops.tensor_id(tensor)
992 # If `tensor` is a graph-building time constant, we create a constant with
993 # the same value in the backward graph instead of capturing it.
994 if tensor_id in self._captured_constants:
995 return self._captured_constants[tensor_id]
996 elif constant_op.is_constant(tensor):
997 self._captured_constants[tensor_id] = constant_op.constant(
998 tensor_util.constant_value(tensor), dtype=tensor.dtype)
999 return self._captured_constants[tensor_id]
1001 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
1002 # XLA does not yet support optionals, so capture intermediates directly.
1003 # TODO(skyewm,jpienaar): can XLA support optionals?
1004 if all(tensor is not capture for capture in self.external_captures):
1005 self.xla_intermediates.append(tensor)
1006 self.op_needs_rewrite = True
1007 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
1009 captured_tensor = self._indirect_captures.get(tensor_id)
1010 if captured_tensor is not None:
1011 return captured_tensor
1013 # 'tensor' is an uncaptured intermediate in the forward graph.
1014 # If it is not a resource, we wrap it in an optional in the forward graph
1015 # and capture the optional normally. We then unwrap the captured optional
1016 # value in the gradient graph to get the raw intermediate value.
1017 # If it is a resource, we trace the resource up to the input in the forward
1018 # graph and capture that.
1020 if tensor.dtype == dtypes.resource:
1021 # Index of the forward graph input corresponding to the resource tensor.
1022 index = util.resource_input_index(
1023 tensor.name, [t.name for t in self._forward_graph.inputs],
1024 {op.name: op.node_def for op in self._forward_graph.get_operations()},
1025 self._forward_graph._functions)
1026 # This gets mapped to the corresponding If op input in
1027 # `_resolve_grad_inputs`.
1028 captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
1029 self._forward_graph.inputs[index], name)
1030 else:
1031 if tensor_id not in self._wrapped_intermediates:
1032 # If the gradient has already been computed for this If op, 'tensor' may
1033 # already be wrapped.
1034 for consumer in tensor.consumers():
1035 if (consumer.type == "OptionalFromValue" and
1036 any(consumer.outputs[0] is output
1037 for output in self._forward_graph.outputs)):
1038 optional = consumer.outputs[0]
1039 break
1040 else:
1041 # 'tensor' hasn't been wrapped, do it now.
1042 with self._forward_graph.as_default():
1043 optional = gen_optional_ops.optional_from_value([tensor])
1044 self.op_needs_rewrite = True
1045 self._wrapped_intermediates[tensor_id] = optional
1047 optional = self._wrapped_intermediates[tensor_id]
1048 captured_optional = super(_CondGradFuncGraph,
1049 self)._capture_helper(optional, name)
1050 captured_tensor = gen_optional_ops.optional_get_value(
1051 captured_optional, [tensor.dtype], [tensor.shape]
1052 )[0]
1054 self._indirect_captures[tensor_id] = captured_tensor
1055 return captured_tensor
1058def indexed_case(branch_index,
1059 branch_fns,
1060 name="indexed_case",
1061 lower_using_switch_merge=None):
1062 """Like conv_v2, except emits a Case op instead of an If."""
1063 if isinstance(branch_index, int):
1064 raise TypeError("branch_index must not be a Python int", branch_index)
1066 with ops.name_scope(name) as scope:
1067 branch_names = [
1068 util.unique_fn_name(scope, "branch{}".format(b))
1069 for b in range(len(branch_fns))
1070 ]
1072 # Automatic control dependencies are added in defuns, but not in v1
1073 # graphs. Propagate that behavior here.
1074 add_control_dependencies = ops.get_default_graph()._add_control_dependencies
1075 branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
1077 branch_graphs = []
1078 for branch_name, branch_fn in zip(branch_names, branch_fns):
1079 branch_graphs.append(
1080 func_graph_module.func_graph_from_py_func(
1081 branch_name,
1082 branch_fn,
1083 [],
1084 {},
1085 func_graph=util.CondBranchFuncGraph(
1086 branch_name,
1087 collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
1088 add_control_dependencies=add_control_dependencies,
1089 op_return_value=branch_index))
1091 verify_captures(_CASE, branch_graphs)
1092 return _build_case(
1093 branch_index,
1094 branch_graphs, [g.external_captures for g in branch_graphs],
1095 name=scope,
1096 lower_using_switch_merge=lower_using_switch_merge)
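# --- Editor's illustrative sketch (not part of the original module) -----------
# The public `tf.switch_case` API routes to `indexed_case` above, emitting a
# single Case/StatelessCase op over the traced branch FuncGraphs. The helper
# below is hypothetical and only documents intended usage.
def _example_switch_case_usage():
  import tensorflow as tf  # local import; assumed available when called
  branch_index = tf.constant(1)
  # Each branch callable is traced into a FuncGraph; the Case op selects one
  # branch at runtime based on `branch_index`.
  return tf.switch_case(
      branch_index,
      branch_fns={0: lambda: tf.constant(10.0),
                  1: lambda: tf.constant(20.0)},
      default=lambda: tf.constant(-1.0))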
1099@ops.RegisterGradient("Case")
1100@ops.RegisterGradient("StatelessCase")
1101def _CaseGrad(op, *grads): # pylint: disable=invalid-name
1102 """The gradient of a Case op produced by tf.switch_case."""
1103 # Get the Case operator (this logic handles the case where op is a MockOp)
1104 case_op = op.outputs[0].op
1105 branch_graphs = get_func_graphs(case_op)
1106 assert branch_graphs
1107 # Note: op.graph != ops.get_default_graph() when we are computing the gradient
1108 # of a nested cond.
1109 for branch_graph in branch_graphs:
1110 assert branch_graph.outer_graph == case_op.graph
1112 # Create grad functions that compute the gradient of the branch forward
1113 # graphs. These functions will capture tensors from the forward pass
1114 # functions.
1115 branch_grad_graphs = []
1116 for branch_graph in branch_graphs:
1117 branch_grad_graphs.append(
1118 _create_grad_func(branch_graph, grads,
1119 util.unique_grad_fn_name(branch_graph.name)))
1120 # Replaces output None grads with zeros if at least one branch has non-None
1121 # grad at that index.
1122 _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)
1124 if any(g.op_needs_rewrite for g in branch_grad_graphs):
1125 # Modify 'op' to output the intermediates needed by the grad functions. Note
1126 # that all needed intermediates are wrapped in optionals. Each optional
1127 # intermediate output will have a value iff its corresponding branch is
1128 # taken.
1129 # NOTE(bjp): if there are any active sessions, this modification to `op`
1130 # may make them unrunnable!
1132 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
1133 # XLA does not yet support optionals, so output intermediates directly and
1134 # make them match via FakeParams, which can be converted to zeros in XLA.
1135 # TODO(bjp,jpienaar): can XLA support optionals?
1136 branches_intermediates = [
1137 branch_grad_graph.xla_intermediates
1138 for branch_grad_graph in branch_grad_graphs
1139 ]
1140 extra_branch_outputs = _make_intermediates_match_xla(
1141 branch_graphs, branches_intermediates)
1142 else:
1143 branch_intermediates = [
1144 g.wrapped_intermediates for g in branch_grad_graphs
1145 ]
1146 # Make outputs match by adding none optionals.
1147 extra_branch_outputs = _make_intermediates_match(branch_graphs,
1148 branch_intermediates)
1150 for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
1151 branch_graph.outputs.extend(extra_outputs)
1152 # TODO(bjp): indicate it's an internal bug if this fails.
1153 _check_same_outputs(_CASE, branch_graphs)
1155 for branch_graph in branch_graphs:
1156 branch_graph.name += "_rewritten"
1158 case_op._set_func_list_attr("branches", [
1159 util.create_new_tf_function(branch_graph)
1160 for branch_graph in branch_graphs
1161 ])
1162 case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
1163 case_op._set_shape_list_attr("output_shapes",
1164 branch_graphs[0].output_shapes)
1165 case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
1166 [t.shape for t in extra_branch_outputs[0]])
1168 # Resolve references to forward graph tensors in grad graphs and ensure
1169 # they are in-scope, i.e., belong to one of the outer graphs of the grad graph.
1170 branches_grad_inputs = [
1171 _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
1172 branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
1173 ]
1175 # This modifies the graphs in branch_grad_graphs.
1176 _make_output_composite_tensors_match(_CASE, branch_grad_graphs)
1178 try:
1179 lowering = case_op._get_attr_bool("_lower_using_switch_merge")
1180 except errors_impl.NotFoundError:
1181 lowering = None
1183 outputs = _build_case(
1184 case_op.inputs[0],
1185 branch_grad_graphs,
1186 branches_grad_inputs,
1187 name="gradient",
1188 lower_using_switch_merge=lowering)
1190 # The predicate has no gradient.
1191 return [None] + outputs
1194def _build_case(branch_index,
1195 branch_graphs,
1196 branch_inputs,
1197 name=None,
1198 lower_using_switch_merge=None):
1199 """Creates an `Case` op from `branch_index`, branch graphs and inputs.
1201 Note that this modifies `branch_graphs` to make the inputs match, and to
1202 output all intermediates values so they're available for the gradient
1203 computation.
1205 `branch_graphs` need not have the same input types, but they must
1206 have the same output types.
1208 Args:
1209 branch_index: integer Tensor
1210 branch_graphs: List of FuncGraph
1211 branch_inputs: List of lists of Tensors to be passed to corresponding
1212 branch_graph as input.
1213 name: the name for the Case op.
1214 lower_using_switch_merge: Lower this op using switch merge ops (optional).
1216 Returns:
1217 A list of Tensors which are the outputs of the Case op. Does not include
1218 added intermediate outputs.
1219 """
1220 _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
1221 _check_same_outputs(_CASE, branch_graphs)
1223 # Add inputs to branch_graphs to make them match. Note that this modifies the
1224 # graphs in `branch_graphs`.
1225 case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
1227 stateful_ops = []
1228 for bg in branch_graphs:
1229 stateful_ops.extend([
1230 op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
1231 ])
1233 if stateful_ops:
1234 op_fn = gen_functional_ops.case
1235 else:
1236 op_fn = gen_functional_ops.stateless_case
1238 # Create the Case op.
1239 with ops.control_dependencies(
1240 sum((list(bg.function_captures.control) for bg in branch_graphs), [])):
1242 def _make_op(inputs):
1243 case_op, tensors = util.get_op_and_outputs(op_fn(
1244 branch_index,
1245 inputs, [t.dtype for t in branch_graphs[0].outputs],
1246 [util.create_new_tf_function(g) for g in branch_graphs],
1247 output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
1248 name=name))
1249 _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
1250 if case_op is not None:
1251 util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
1252 util.maybe_propagate_compile_time_consts_in_xla(case_op)
1253 _set_read_only_resource_inputs_attr(case_op, branch_graphs)
1254 # Prevent fetching since the variant outputs can't be fetched directly.
1255 case_op.graph.prevent_fetching(case_op)
1257 # Store the branch graphs so they can be reused during the gradient
1258 # pass.
1259 for i, bg in enumerate(branch_graphs):
1260 bg.outer_graph = ops.get_default_graph()
1261 setattr(case_op, "_branch_graph_{}".format(i), bg)
1263 return tensors
1264 tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)
1266 # Return identities for each output of the Case op, rather than the output of
1267 # the Case op directly. This makes pruning work if the output of switch_case()
1268 # is fetched: the lowering pass converts the Case outputs into IdentityN
1269 # outputs, which if fetched will cause all ops in the taken branch to be run
1270 # (since it takes all merge ops as input). After lowering, each output
1271 # identity op will end up with only the appropriate merge op as input.
1272 # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
1273 # correct output structure
1274 tensors = [array_ops.identity(t) for t in tensors]
1276 return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
1279def _set_read_only_resource_inputs_attr(op, branch_graphs):
1280 """Sets the list of resource inputs which are read-only.
1282 This is used by AutomaticControlDependencies.
1284 Args:
1285 op: If or Case Operation.
1286 branch_graphs: List of branch FuncGraphs.
1287 """
1288 # The first entry in `op.inputs` is the predicate, which is not passed to
1289 # branch graphs, so len(branch_graphs[i].inputs) == len(op.inputs) - 1.
1290 read_only_indices = set(range(len(op.inputs) - 1))
1291 for branch_graph in branch_graphs:
1292 assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen"
1293 if not read_only_indices:
1294 break
1295 branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
1296 branch_graph)
1297 read_only_indices = read_only_indices.intersection(branch_read_only_indices)
1298 # Convert indices in `branch_graphs[i].inputs` to `op.inputs`.
1299 read_only_indices = [i + 1 for i in read_only_indices]
1300 ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1301 sorted(read_only_indices))