# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutomaticControlDependencies and related functionality."""

import collections
import enum

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as utils
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import registry
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator

# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
ASYNC_STATEFUL_OPS = frozenset((
    "CollectiveGather",
    "CollectiveReduce",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
    "Recv",
    "CollectiveInitializeCommunicator",
    "CollectiveAssignGroupV2",
))

LEGACY_RANDOM_OPS = frozenset((
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations. This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation, and in eager mode initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order so
    # long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically not
    # guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop. Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respect this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
))

MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset((
    "InfeedEnqueue",
    "InfeedEnqueueTuple",
    "EnqueueTPUEmbeddingSparseBatch",
    "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch",
    "EnqueueTPUEmbeddingArbitraryTensorBatch",
    "DynamicEnqueueTPUEmbeddingArbitraryTensorBatch",
))

# These ops are order-insensitive and should in theory run, but at the moment
# they either always have the necessary data dependencies, or have workarounds
# in existing code that would break when adding new control deps. This
# inconsistency should eventually be fixed, but it would be more effective to
# retire the list instead.
SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset((
    "CudnnRNN",
    "CudnnRNNBackprop",
    "CudnnRNNV2",
    "CudnnRNNV3",
    "CudnnRNNBackpropV2",
    "CudnnRNNBackpropV3",
    "RestoreV2",
    "SaveV2",
))
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

# Op types that are marked as stateless, but should be allowlisted to add auto
# control dependencies.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there is more than one collective
    # op in the function, we need to make sure different collective ops are
    # scheduled in a certain order. Otherwise, if all the replicas launch
    # different collective ops/programs at the same time, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]


def op_is_stateful(op):
  # pylint: disable=protected-access
  ret = ((op._is_stateful and
          ((op.type not in ASYNC_STATEFUL_OPS) and
           (op.type not in LEGACY_RANDOM_OPS) and
           (op.type not in SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS))) or
         (op.type in _ALLOWLIST_STATELESS_OPS))
  return ret


class ResourceType(enum.Enum):
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"


def collective_manager_ids_from_op(op):
  """Returns the CollectiveManager IDs used by the op, if any.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts those IDs, returning an
  empty list if the node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager ID from.

  Returns:
    List of CollectiveManager IDs used by the op.
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
  1. All stateful ops in the scope will execute (with the exception of ops in
     ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
  2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).
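
  A minimal usage sketch (illustrative only, assuming graph mode; `v` stands
  for a pre-existing `tf.Variable`):

  ```
  with AutomaticControlDependencies() as a:
    v.assign_add(1.0)      # stateful op: forced to run
    out = v.read_value()   # ordered after the write above
    out = a.mark_as_return(out)
  ```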

  NOT THREAD SAFE
  """

  def __init__(self):
    self._returned_tensors = object_identity.ObjectIdentitySet()
    self.ops_which_must_run = set()
    self._independent_ops = []

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
    with AutomaticControlDependencies() as a:
      ...
      t = a.mark_as_return(t)
    _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, indexed_slices.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return indexed_slices.IndexedSlices(
          values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def run_independently(self, op):
    """Marks the given op as independent.

    Overrides any other rule for the op.

    Independent ops are guaranteed to execute before the return values, but
    are allowed to run in parallel with everything else. Use in programs which
    can guarantee that an op has side effects that don't affect any other op.
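
    A sketch of intended use (illustrative only; `independent_op` stands for
    an op created inside this context whose side effects are known not to
    interact with any other op):

    ```
    with AutomaticControlDependencies() as a:
      ...
      a.run_independently(independent_op)
    ```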

    Args:
      op: An operation
    """
    self._independent_ops.append(op)
    op._set_attr("_independent_side_effects", attr_value_pb2.AttrValue(b=True))  # pylint: disable=protected-access

  def __enter__(self):
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    g = ops.get_default_graph()
    self._graph = g
    g._add_control_dependencies = True  # pylint: disable=protected-access
    g.experimental_acd_manager = self
    self._n_operations = g.num_operations()
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    outer_graph = getattr(self._graph, "outer_graph", None)
    if outer_graph is not None:
      self._graph._add_control_dependencies = outer_graph._add_control_dependencies
    else:
      self._graph._add_control_dependencies = False
    self._graph.experimental_acd_manager = None

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same; in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.

    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      if op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS:
        # This will add it to self._independent_ops, but also mark it with an
        # attribute.
        self.run_independently(op)

      if op in self._independent_ops:
        ops_which_must_run.add(op)
        continue

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)
        if is_read:
          reads_since_last_write_to_resource[input_id].append(op)
        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)

    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op to
          # chain the control dependencies between stateful ops and function
          # outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used


_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def identity_resolver(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    def update(resource_inputs):
      to_remove = []
      to_add = []
      for resource in resource_inputs:
        if resource.op.type == "Identity":
          to_remove.append(resource)
          to_add.extend(resource.op.inputs)
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return to_add or to_remove
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
      (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f


@register_acd_resource_resolver
def _identity_resolver(op, resource_reads, resource_writes):
  """Replaces Identity output with its input in resource_inputs."""
  del op
  def update(resource_inputs):
    to_remove = []
    to_add = []
    for resource in resource_inputs:
      if resource.op.type == "Identity":
        to_remove.append(resource)
        to_add.extend(resource.op.inputs)
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return to_add or to_remove

  return update(resource_reads) or update(resource_writes)


def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  reads, writes = utils.get_read_write_resource_inputs(op)
  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only. We
  # don't have a special way of denoting an unused resource.
  for t in reads:
    yield (t, ResourceType.READ_ONLY)
  for t in writes:
    yield (t, ResourceType.READ_WRITE)


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
  1. All stateful ops in f run when the result of f runs
  2. Updates to the same resources happen in order.
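
  A minimal usage sketch (illustrative only; `v` stands for a pre-existing
  `tf.Variable` captured by the wrapped function):

  ```
  @automatic_control_dependencies
  def step():
    v.assign_add(1.0)
    return v.read_value()  # guaranteed to see the assignment above
  ```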

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)