Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py: 17%
1026 statements
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control Flow Operations.
17See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
18"""
19# pylint: disable=g-bad-name
20import abc
22from tensorflow.core.framework import attr_value_pb2
23from tensorflow.core.protobuf import control_flow_pb2
24from tensorflow.python.eager import context
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import indexed_slices
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.framework import type_spec
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import cond as tf_cond
36from tensorflow.python.ops import control_flow_assert
37from tensorflow.python.ops import control_flow_case
38from tensorflow.python.ops import control_flow_util as util
39from tensorflow.python.ops import gen_array_ops
40from tensorflow.python.ops import gen_control_flow_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import tensor_array_ops
43from tensorflow.python.ops import while_loop as while_loop_ops
44# go/tf-wildcard-import
45# pylint: disable=wildcard-import,undefined-variable
46from tensorflow.python.ops.gen_control_flow_ops import *
47# pylint: enable=wildcard-import
48from tensorflow.python.util import compat
49from tensorflow.python.util import dispatch
50from tensorflow.python.util import nest
51from tensorflow.python.util import variable_utils
52from tensorflow.python.util.tf_export import tf_export
54# TODO(b/269483538): needed for references while refactors are in progress
55case = control_flow_case.case
56_case_helper = control_flow_case._case_helper # pylint: disable=protected-access
57case_v2 = control_flow_case.case_v2
58_case_create_default_action = control_flow_case._case_create_default_action # pylint: disable=protected-access
59_case_verify_and_canonicalize_args = control_flow_case._case_verify_and_canonicalize_args # pylint: disable=protected-access
60_assert_at_most_n_true = control_flow_case._assert_at_most_n_true # pylint: disable=protected-access
61Assert = control_flow_assert.Assert
62_summarize_eager = control_flow_assert._summarize_eager # pylint: disable=protected-access
63while_loop = while_loop_ops.while_loop
64while_loop_v2 = while_loop_ops.while_loop_v2
65cond = tf_cond.cond
66cond_for_tf_v2 = tf_cond.cond_for_tf_v2
67_UnpackIfSingleton = tf_cond._UnpackIfSingleton # pylint: disable=protected-access
68_eager_cond_implementation = tf_cond._eager_cond_implementation # pylint: disable=protected-access
69_cast_indexed_slice_indices = tf_cond._cast_indexed_slice_indices # pylint: disable=protected-access
71# We override the 'tuple' for a control flow op, so we keep python's
72# existing 'tuple' for later use in this module.
73_basetuple = tuple
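# Illustrative usage sketch (not part of the original module): the `cond` and
# `while_loop` aliases above behave like `tf.cond` / `tf.while_loop`. The
# helper name below is hypothetical.
def _example_cond_and_while_loop():
  x = constant_op.constant(3)
  # Choose a branch based on a boolean predicate.
  doubled_or_offset = cond(x < 5, lambda: x * 2, lambda: x + 10)
  # Run the body repeatedly while the predicate holds.
  i0 = constant_op.constant(0)
  counter = while_loop(lambda i: i < 10, lambda i: i + 1, [i0])
  return doubled_or_offset, counter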
76# pylint: disable=protected-access
79def _Identity(tensor, name=None):
80 """Return a tensor with the same shape and contents as the input tensor.
82 Args:
83 tensor: A Tensor.
84 name: A name for this operation (optional).
86 Returns:
87 A Tensor with the same type and value as the input Tensor.
88 """
89 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
90 # TODO(b/246438937): Remove this when we expand ResourceVariables into
91 # dt_resource tensors.
92 tensor = variable_utils.convert_variables_to_tensors(tensor)
93 if isinstance(tensor, ops.Tensor):
94 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access
95 return gen_array_ops.ref_identity(tensor, name=name)
96 else:
97 return array_ops.identity(tensor, name=name)
98 elif isinstance(tensor, composite_tensor.CompositeTensor):
99 return nest.map_structure(_Identity, tensor, expand_composites=True)
100 else:
101 raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
102 f"Received: {type(tensor)}.")
105def _NextIteration(tensor, name=None):
106 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
107 if isinstance(tensor, ops.Tensor):
108 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access
109 return ref_next_iteration(tensor, name=name)
110 else:
111 return next_iteration(tensor, name=name)
112 elif isinstance(tensor, composite_tensor.CompositeTensor):
113 return nest.map_structure(_NextIteration, tensor, expand_composites=True)
114 else:
115 raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
116 f"Received: {type(tensor)}.")
119def _Enter(tensor,
120 frame_name,
121 is_constant=False,
122 parallel_iterations=10,
123 use_ref=True,
124 use_input_shape=True,
125 name=None):
126 """Creates or finds a child frame, and makes `tensor` available to it.
128 The unique `frame_name` is used by the `Executor` to identify frames. If
129 `is_constant` is true, `tensor` is a constant in the child frame; otherwise
130 it may be changed in the child frame. At most `parallel_iterations`
131 iterations are run in parallel in the child frame.
133 Args:
134 tensor: The tensor to be made available to the child frame.
135 frame_name: The name of the child frame.
136 is_constant: If true, the output is constant within the child frame.
137 parallel_iterations: The number of iterations allowed to run in parallel.
138 use_ref: If true, use ref_enter if tensor is of ref type.
139 use_input_shape: If true, set the result's shape based on tensor's shape.
140 name: A name for this operation (optional).
142 Returns:
143 The same tensor as `tensor`.
145 Raises:
146 ValueError: If any tensor in `tensor` has a less specific shape
147 than its corresponding shape in `shape_invariant`.
148 """
149 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
150 if isinstance(tensor, ops.Tensor):
151 if tensor.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
152 result = gen_control_flow_ops.ref_enter(
153 tensor, frame_name, is_constant, parallel_iterations, name=name)
154 else:
155 result = gen_control_flow_ops.enter(
156 tensor, frame_name, is_constant, parallel_iterations, name=name)
157 if use_input_shape:
158 result.set_shape(tensor.get_shape())
159 return result
160 elif isinstance(tensor, composite_tensor.CompositeTensor):
162 def enter_component(t):
163 return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
164 use_input_shape)
166 return nest.map_structure(enter_component, tensor, expand_composites=True)
167 else:
168 raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
169 f"Received: {type(tensor)}.")
172def exit(tensor, name=None): # pylint: disable=redefined-builtin
173 """Exits the current frame to its parent frame.
175 Exit makes its input `tensor` available to the parent frame.
177 Args:
178 tensor: The tensor to be made available to the parent frame.
179 name: A name for this operation (optional).
181 Returns:
182 The same tensor as `tensor`.
183 """
184 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
185 if isinstance(tensor, ops.Tensor):
186 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access
187 return gen_control_flow_ops.ref_exit(tensor, name)
188 else:
189 return gen_control_flow_ops._exit(tensor, name)
190 elif isinstance(tensor, composite_tensor.CompositeTensor):
191 return nest.map_structure(exit, tensor, expand_composites=True)
192 else:
193 raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
194 f"Received: {type(tensor)}.")
197def switch(data, pred, dtype=None, name=None):
198 """Forwards `data` to an output determined by `pred`.
200 If `pred` is false, the `data` input is forwarded to the first output.
201 Otherwise, the data goes to the second output.
203 This op handles `Tensor`s and `IndexedSlices`.
205 Args:
206 data: The tensor to be forwarded to the appropriate output.
207 pred: A scalar that specifies which output port will receive data.
208 dtype: Optional element type for the returned tensor. If missing, the type
209 is inferred from the type of `data`.
210 name: A name for this operation (optional).
212 Returns:
213 `(output_false, output_true)`: If `pred` is true, data will be forwarded
214 to `output_true`, otherwise it goes to `output_false`.
215 """
216 with ops.name_scope(name, "Switch", [data, pred]) as name:
217 data = ops.internal_convert_to_tensor_or_composite(
218 data, dtype=dtype, name="data", as_ref=True)
219 pred = ops.convert_to_tensor(pred, name="pred")
220 if isinstance(data, ops.Tensor):
221 return gen_control_flow_ops.switch(data, pred, name=name)
222 else:
223 if not isinstance(data, composite_tensor.CompositeTensor):
224 raise TypeError(
225 "'data' must be a Tensor or CompositeTensor. "
226 f"Received: {type(data)}.")
227 tensors = nest.flatten(data, expand_composites=True)
228 mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
229 mapped_f, mapped_t = zip(*mapped)
230 return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
231 nest.pack_sequence_as(data, mapped_t, expand_composites=True))
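# Illustrative sketch (not part of the original module): `switch` forwards its
# input to exactly one of its two outputs; the untaken output is a "dead"
# tensor that downstream ops will not execute. Assumes graph construction; the
# helper name is hypothetical.
def _example_switch():
  with ops.Graph().as_default():
    data = constant_op.constant([1.0, 2.0])
    pred = constant_op.constant(True)
    output_false, output_true = switch(data, pred)
    # With pred == True, only `output_true` carries `data` at run time.
    return output_false, output_true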
234def _SwitchRefOrTensor(data, pred, name="Switch"):
235 """Forwards `data` to an output determined by `pred`.
237 If `pred` is false, the `data` input is forwarded to the first output.
238 Otherwise, the data goes to the second output.
240 This op handles `Tensor`s and `IndexedSlices`.
242 Args:
243 data: The tensor to be forwarded to the appropriate output.
244 pred: A scalar that specifies which output port will receive data.
245 name: A name for this operation (optional).
247 Returns:
248 `(output_false, output_true)`: If `pred` is true, data will be forwarded to
249 `output_true`, otherwise it goes to `output_false`.
251 Raises:
252 TypeError: if data is not a Tensor or IndexedSlices
253 """
254 data = ops.convert_to_tensor_or_composite(data, name="data")
255 # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
256 # addresses the following scenario.
257 #
258 # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
259 #
260 # 1. The update op is created inside a `with ops.colocate(var):` block
261 #
262 # 2. Some tensor `data` is captured and a switch is created in a
263 # `with ops.colocate_with(data):` block.
264 #
265 # with ops.colocate_with(var):
266 # with ops.colocate_with(data):
267 # op = ...
268 #
269 # var and data may be pinned to different devices, so we want ops
270 # created within ops.colocate_with(data) to ignore the existing stack.
271 with ops.colocate_with(data, ignore_existing=True):
272 if isinstance(data, ops.Tensor):
273 if data.dtype._is_ref_dtype: # pylint: disable=protected-access
274 return ref_switch(data, pred, name=name)
275 return switch(data, pred, name=name)
278def merge(inputs, name=None):
279 """Returns the value of an available element of `inputs`.
281 This op tests each of the tensors in `inputs` in turn to determine if any of
282 them is available. If it finds an available tensor, it returns it and its
283 index in `inputs`.
285 It is an error if more than one tensor in `inputs` is available. If no tensor
286 in `inputs` is available, the returned tensor and index are not set.
288 This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
289 `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
290 before merging.
292 Args:
293 inputs: The input tensors, at most one of which is available.
294 name: A name for this operation (optional).
296 Returns:
297 A tuple containing the chosen input tensor and its index in `inputs`.
299 Raises:
300 ValueError: If any of the inputs is None, or inputs are IndexedSlices and
301 some but not all have a dense_shape property.
302 """
303 if any(inp is None for inp in inputs):
304 raise ValueError("At least one of the merge inputs is None: %s" % inputs)
305 with ops.name_scope(name, "Merge", inputs) as name:
306 inputs = [
307 ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
308 for inp in inputs
309 ]
310 if all(isinstance(v, ops.Tensor) for v in inputs):
311 if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access
312 return gen_control_flow_ops.ref_merge(inputs, name)
313 else:
314 return gen_control_flow_ops.merge(inputs, name)
315 else:
316 # If there is a mix of tensors and indexed slices, then convert the
317 # tensors to indexed slices.
318 if all(
319 isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor))
320 for v in inputs):
321 inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
323 for v in inputs:
324 if not isinstance(v, composite_tensor.CompositeTensor):
325 raise TypeError("Type %s not supported" % type(v))
327 for v in inputs[1:]:
328 nest.assert_same_structure(inputs[0], v, expand_composites=True)
330 flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
331 merged_results = [
332 gen_control_flow_ops.merge(component)
333 for component in zip(*flat_inputs)
334 ]
335 flat_merged = [tensor for (tensor, _) in merged_results]
336 chosen_index = merged_results[0][1]
337 merged_inputs = nest.pack_sequence_as(
338 inputs[0], flat_merged, expand_composites=True)
339 return (merged_inputs, chosen_index)
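# Illustrative sketch (not part of the original module): `switch` and `merge`
# are the dataflow primitives underneath `cond`. A hand-wired conditional
# (assumes graph construction; helper name hypothetical) looks roughly like:
def _example_switch_merge_cond():
  with ops.Graph().as_default():
    x = constant_op.constant(2.0)
    pred = constant_op.constant(True)
    x_false, x_true = switch(x, pred)
    true_branch = math_ops.multiply(x_true, 2.0)
    false_branch = math_ops.add(x_false, 10.0)
    # `merge` returns the value of whichever branch actually ran, plus its
    # index in the input list.
    result, _ = merge([false_branch, true_branch])
    return result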
342def _convert_tensorarray_to_flow(tensor_or_tensor_array):
343 if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
344 return tensor_or_tensor_array.flow
345 else:
346 return tensor_or_tensor_array
349def _convert_flow_to_tensorarray(tensor_or_tensor_array, tensor_or_flow):
350 if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
351 return tensor_array_ops.build_ta_with_new_flow(tensor_or_tensor_array,
352 tensor_or_flow)
353 else:
354 return tensor_or_flow
357def _convert_to_tensor_or_composite_or_tensorarray(var):
358 if isinstance(var, tensor_array_ops.TensorArray):
359 return var
360 return ops.convert_to_tensor_or_composite(var)
363# TODO(xjun): replace this with is_subtype_of after it is landed.
364def _ShapeLessThanOrEqual(shape1, shape2):
365 if shape2.dims is None:
366 return True
367 if shape1.ndims != shape2.ndims:
368 return False
369 for dim1, dim2 in zip(shape1.dims, shape2.dims):
370 if dim2.value is not None and dim1.value != dim2.value:
371 return False
372 return True
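# Illustrative sketch (not part of the original module): `_ShapeLessThanOrEqual`
# returns True when `shape1` is at least as specific as `shape2` (helper name
# hypothetical).
def _example_shape_less_than_or_equal():
  known = tensor_shape.TensorShape([2, 3])
  partial = tensor_shape.TensorShape([2, None])
  assert _ShapeLessThanOrEqual(known, partial)      # [2, 3] refines [2, None]
  assert not _ShapeLessThanOrEqual(partial, known)  # less specific than [2, 3]
  assert _ShapeLessThanOrEqual(known, tensor_shape.TensorShape(None))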
375def _shape_invariant_to_type_spec(var, shape=None):
376 """Converts a shape invariant to a TypeSpec.
378 If `var` is a TensorArray, it will first be converted to its flow.
380 Args:
381 var: The tensor, tensor array or composite tensor whose shape is described
382 by the shape invariant.
383 shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`,
384 then it is simply returned as-is.
386 Returns:
387 A `TypeSpec` for `var`, consistent with the given shape.
389 Raises:
390 TypeError: If `shape` is a TypeSpec and not compatible with `var`.
391 TypeError: If `shape` is not None, a TypeSpec, or a TensorShape.
392 TypeError: If `shape` is a TensorShape, `var` is a CompositeTensor, and
393 `var` doesn't implement the `_shape_invariant_to_type_spec` method.
394 """
395 var = _convert_tensorarray_to_flow(var)
396 if shape is None:
397 return type_spec.type_spec_from_value(var)
398 elif isinstance(shape, type_spec.TypeSpec):
399 if not shape.is_compatible_with(var):
400 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
401 return shape
402 elif not isinstance(shape, tensor_shape.TensorShape):
403 raise TypeError(
404 "'shape' must be one of TypeSpec, TensorShape or None. "
405 f"Received: {type(shape)}")
407 if isinstance(var, ops.Tensor):
408 return tensor_spec.TensorSpec(shape, var.dtype)
409 else:
410 try:
411 return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access
412 except NotImplementedError as e:
413 raise TypeError(
414 f"To describe or constrain a {type(var).__name__}, use a "
415 f"{type(var._type_spec).__name__} instead of a TensorShape.") from e # pylint: disable=protected-access
418def _EnforceShapeInvariant(merge_var, next_var):
419 """Check if the shapes of the loops variables are invariants.
421 Args:
422 merge_var: The tensor representing the initial values of the loop
423 variables.
424 next_var: The tensor representing the values of the loop variables
425 after one loop iteration.
427 Raises:
428 ValueError: If any tensor in `merge_var` has a more specific shape than
429 its corresponding tensor in `next_var`.
430 """
431 if isinstance(merge_var, ops.Tensor):
432 m_shape = merge_var.get_shape()
433 n_shape = next_var.get_shape()
434 if not _ShapeLessThanOrEqual(n_shape, m_shape):
435 enter = merge_var.op.inputs[0].op
436 assert util.IsLoopEnter(enter)
437 input_t = enter.inputs[0]
438 raise ValueError(
439 "Input tensor '%s' enters the loop with shape %s, but has shape %s "
440 "after one iteration. To allow the shape to vary across iterations, "
441 "use the `shape_invariants` argument of tf.while_loop to specify a "
442 "less-specific shape." % (input_t.name, input_t.shape, n_shape))
443 else:
444 raise TypeError("'merge_var' must be a Tensor. "
445 f"Received: {type(merge_var)}.")
448def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
449 """Add NextIteration and back edge from v to m."""
450 if isinstance(m, ops.Tensor):
451 v = ops.convert_to_tensor(v)
452 v = _NextIteration(v)
453 if enforce_shape_invariant:
454 # Make sure the shapes of loop outputs are correct. We do this before
455 # calling _update_input, which will raise a less-helpful error message if
456 # the types don't match.
457 # TODO(skyewm): call this for other cases below (needs testing)
458 _EnforceShapeInvariant(m, v)
459 m.op._update_input(1, v) # pylint: disable=protected-access
460 elif isinstance(m, composite_tensor.CompositeTensor):
461 # pylint: disable=protected-access
462 def update_component(m_component, v_component):
463 m_component.op._update_input(1, v_component)
465 if isinstance(m, indexed_slices.IndexedSlices):
466 v = math_ops._as_indexed_slices(v, optimize=False)
467 # pylint: enable=protected-access
468 v = _NextIteration(v)
469 return nest.map_structure(update_component, m, v, expand_composites=True)
470 else:
471 raise TypeError("'m' must be a Tensor or CompositeTensor. "
472 f"Received: {type(m)}.")
473 return v
476class ControlFlowContext(metaclass=abc.ABCMeta):
477 """The base class for control flow context.
479 The usage pattern is a sequence of (Enter, Exit) followed by a final
480 ExitResult.
482 We maintain the following state for control flow contexts during graph
483 construction:
484 1. graph has _control_flow_context: the current context used to
485 construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
486 2. op has _control_flow_context: the context to which the op belongs.
487 Set at the time the op is created. Immutable.
488 3. A ControlFlowContext has _outer_context: the context in which this
489 context is created. Set at the time a context is created. Immutable.
490 4. A ControlFlowContext has _context_stack.
491 Pushed and popped by ctxt.Enter() and ctxt.Exit()
492 """
494 def __init__(self, values_def=None, import_scope=None):
495 self._nested_contexts = []
496 self._outer_context = ops.get_default_graph()._get_control_flow_context()
497 if self._outer_context:
498 self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access
499 self._context_stack = []
500 if values_def:
501 self._init_values_from_proto(values_def, import_scope=import_scope)
502 else:
503 # The names of tensors that have been already seen in this context.
504 self._values = set()
505 # The keys are the names of tensors referenced by but external to this
506 # context. Each value is the Tensor that should be used by this context to
507 # access the key value (e.g. a switch output guarding a cond input value).
508 self._external_values = {}
510 def _init_values_from_proto(self, values_def, import_scope=None):
511 """Initializes values and external_values from `ValuesDef` protocol buffer.
513 Args:
514 values_def: `ValuesDef` protocol buffer.
515 import_scope: Optional `string`. Name scope to add.
516 """
517 assert isinstance(values_def, control_flow_pb2.ValuesDef)
518 self._values = set(
519 ops.prepend_name_scope(value, import_scope)
520 for value in values_def.values)
521 g = ops.get_default_graph()
522 self._external_values = {}
523 for k, v in values_def.external_values.items():
524 k = ops.prepend_name_scope(k, import_scope)
525 self._external_values[k] = g.as_graph_element(
526 ops.prepend_name_scope(v, import_scope))
527 op_names = set([
528 op.split(":")[0]
529 for op in self._values - set(self._external_values.keys())
530 ])
531 for op in op_names:
532 # pylint: disable=protected-access
533 g.as_graph_element(op)._set_control_flow_context(self)
534 # pylint: enable=protected-access
536 @property
537 def name(self):
538 return self._name
540 @property
541 def outer_context(self):
542 """Return the context containing this context."""
543 return self._outer_context
545 @property
546 def grad_state(self):
547 raise NotImplementedError("Abstract method")
549 @property
550 def back_prop(self):
551 raise NotImplementedError("Abstract method")
553 @abc.abstractmethod
554 def to_control_flow_context_def(self, context_def, export_scope=None):
555 """Serializes this into `context_def`.
557 Args:
558 context_def: a `ControlFlowContextDef` protocol buffer.
559 export_scope: Optional `string`. Name scope to remove.
560 """
561 raise NotImplementedError("Abstract method")
563 def _to_values_def(self, export_scope=None):
564 """Converts the values to a `ValuesDef` protocol buffer.
566 Args:
567 export_scope: Optional `string`. Name scope to remove.
569 Returns:
570 A `ValuesDef` protocol buffer.
571 """
572 values_def = control_flow_pb2.ValuesDef()
573 values_def.values.extend(
574 [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
575 for k, v in self._external_values.items():
576 k = ops.strip_name_scope(k, export_scope)
577 values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
578 return values_def
580 def AddName(self, name):
581 self._values.add(name)
583 # pylint: disable=protected-access
584 def Enter(self):
585 """Enter this control flow context."""
586 graph = ops.get_default_graph()
587 self._context_stack.append(graph._get_control_flow_context())
588 graph._set_control_flow_context(self)
590 def Exit(self):
591 """Exit this control flow context."""
592 graph = ops.get_default_graph()
593 last_context = self._context_stack.pop()
594 graph._set_control_flow_context(last_context)
596 def EnterGradientColocation(self, op, gradient_uid):
597 """Start building a gradient colocated with an op."""
598 if self._outer_context:
599 self._outer_context.EnterGradientColocation(op, gradient_uid)
601 def ExitGradientColocation(self, op, gradient_uid):
602 """Start building a gradient colocated with an op."""
603 if self._outer_context:
604 self._outer_context.ExitGradientColocation(op, gradient_uid)
606 def ExitResult(self, result):
607 """Make a list of tensors available in the outer context."""
608 if self._outer_context:
609 def fn(x):
610 self._outer_context.AddName(x.name)
611 return x
612 nest.map_structure(fn, result, expand_composites=True)
614 def GetWhileContext(self):
615 """Return the while context containing this context."""
616 if self._outer_context:
617 return self._outer_context.GetWhileContext()
618 return None
620 def _RemoveExternalControlEdges(self, op):
621 """Remove any external control dependency on this op."""
622 while_ctxt = self.GetWhileContext()
623 # A control input of `op` is internal if it is in the same while
624 # loop context as the enclosing while loop context of self.
625 if while_ctxt is None:
626 internal_control_inputs, external_control_inputs = op.control_inputs, []
627 else:
628 internal_control_inputs, external_control_inputs = [], []
629 for x in op.control_inputs:
630 ctxt = util.GetOutputContext(x)
631 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
632 internal_control_inputs.append(x)
633 else:
634 external_control_inputs.append(x)
635 if len(internal_control_inputs) != len(op.control_inputs):
636 # TODO(mdan): perhaps there should be a replace_control_inputs()
637 op._remove_all_control_inputs()
638 op._add_control_inputs(internal_control_inputs)
639 return internal_control_inputs, external_control_inputs
641 # pylint: enable=protected-access
643 def AddInnerOp(self, op):
644 """Notifies a scope about an operator added to an inner scope."""
645 if self._outer_context:
646 self._outer_context.AddInnerOp(op)
648 def GetControlPivot(self):
649 """Returns the pivot node for this context, or None."""
650 return None
652 def IsWhileContext(self):
653 return False
655 def IsCondContext(self):
656 return False
658 def IsXLAContext(self):
659 return False
661 def __str__(self):
662 return self.name
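# Illustrative sketch (not part of the original module): callers follow the
# (Enter, ..., Exit, ExitResult) pattern described in the class docstring. For
# some existing context `ctxt` and tensor `x` (both hypothetical):
#
#   ctxt.Enter()
#   y = math_ops.add(x, 1)   # ops created here are attributed to `ctxt`
#   ctxt.Exit()
#   ctxt.ExitResult([y])     # make `y` visible to the outer context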
665class CondContext(ControlFlowContext):
666 """The context for the conditional construct."""
668 def __init__(self,
669 pred=None,
670 pivot=None,
671 branch=None,
672 name="cond_text",
673 context_def=None,
674 import_scope=None):
675 """Creates a `CondContext`.
677 Args:
678 pred: The `boolean` tensor for the conditional predicate.
679 pivot: The predicate tensor in this branch.
680 branch: 0 or 1 representing this branch.
681 name: Name of the `CondContext` python object.
682 context_def: Optional `ContextDef` protocol buffer to initialize the
683 `CondContext` object from.
684 import_scope: Optional `string`. Name scope to add. Only used when
685 initializing from a protocol buffer.
686 """
687 self._name = ops.get_default_graph().unique_name(name)
689 if context_def:
690 self._init_from_proto(context_def, import_scope=import_scope)
691 else:
692 # Initializes the default fields.
693 ControlFlowContext.__init__(self)
694 self._pred = pred # The boolean tensor for the cond predicate
695 self._pivot = pivot # The predicate tensor in this branch
696 self._branch = branch # 0 or 1 representing this branch
698 # Values considered to have been already seen in this context. pred is not
699 # included in this context.
700 self._values.add(pred.name)
701 self._external_values[pred.name] = pred
702 self._values.add(pivot.name)
703 pivot.op._set_control_flow_context(self) # pylint: disable=protected-access
705 def _init_from_proto(self, context_def, import_scope=None):
706 """Creates a new `CondContext` from protocol buffer.
708 Args:
709 context_def: `CondContextDef` protocol buffer.
710 import_scope: Optional `string`. Name scope to add.
711 """
712 assert isinstance(context_def, control_flow_pb2.CondContextDef)
713 # Create from context_def.
714 g = ops.get_default_graph()
715 self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
716 self._pred = g.as_graph_element(
717 ops.prepend_name_scope(context_def.pred_name, import_scope))
718 self._pivot = g.as_graph_element(
719 ops.prepend_name_scope(context_def.pivot_name, import_scope))
720 self._branch = context_def.branch
721 super(CondContext, self).__init__(
722 values_def=context_def.values_def, import_scope=import_scope)
724 @property
725 def pred(self):
726 return self._pred
728 @property
729 def pivot(self):
730 return self._pivot
732 @property
733 def branch(self):
734 return self._branch
736 @property
737 def grad_state(self):
738 if self.GetWhileContext():
739 return self.GetWhileContext().grad_state
740 return None
742 @property
743 def back_prop(self):
744 if self.GetWhileContext():
745 return self.GetWhileContext().back_prop
746 return False
748 def GetControlPivot(self):
749 return self._pivot
751 def to_proto(self, export_scope=None):
752 """Converts a `CondContext` to a `CondContextDef` protocol buffer.
754 Args:
755 export_scope: Optional `string`. Name scope to remove.
757 Returns:
758 A `CondContextDef` protocol buffer.
759 """
760 if (export_scope is None or self.name.startswith(export_scope)):
761 context_def = control_flow_pb2.CondContextDef()
762 context_def.context_name = ops.strip_name_scope(self.name, export_scope)
763 context_def.pred_name = ops.strip_name_scope(self._pred.name,
764 export_scope)
765 context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
766 export_scope)
767 context_def.branch = self._branch
768 context_def.values_def.MergeFrom(
769 super(CondContext, self)._to_values_def(export_scope))
770 for nested in self._nested_contexts:
771 nested_def = context_def.nested_contexts.add()
772 nested.to_control_flow_context_def(nested_def)
774 return context_def
775 else:
776 return None
778 @staticmethod
779 def from_proto(context_def, import_scope=None):
780 """Returns a `CondContext` object created from `context_def`."""
781 ret = CondContext(context_def=context_def, import_scope=import_scope)
783 ret.Enter()
784 for nested_def in context_def.nested_contexts:
785 from_control_flow_context_def(nested_def, import_scope=import_scope)
786 ret.Exit()
787 return ret
789 def to_control_flow_context_def(self, context_def, export_scope=None):
790 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
792 def AddValue(self, val):
793 """Add `val` to the current context and its outer context recursively."""
794 if val.name in self._values:
795 # Use the real value if it comes from outer context. This is needed in
796 # particular for nested conds.
797 result = self._external_values.get(val.name)
798 result = val if result is None else result
799 else:
800 result = val
801 self._values.add(val.name)
802 if self._outer_context:
803 result = self._outer_context.AddValue(val)
804 self._values.add(result.name)
805 self._external_values[result.name] = result
806 with ops.control_dependencies(None):
807 result = _SwitchRefOrTensor(result, self._pred)[self._branch]
808 if self._outer_context:
809 self._outer_context.AddInnerOp(result.op)
811 result.op.graph.prevent_fetching(result.op)
812 # pylint: disable=protected-access
813 result.op._set_control_flow_context(self)
814 # pylint: enable=protected-access
816 # Mark Switch output as seen by this context and any outer contexts,
817 # just like what we do for normal op outputs in _AddOpInternal() below.
818 ctxt = self
819 while ctxt is not None:
820 # pylint: disable=protected-access
821 ctxt._values.add(result.name)
822 ctxt = ctxt._outer_context
823 # pylint: enable=protected-access
825 self._external_values[val.name] = result
826 return result
828 def AddOp(self, op):
829 self._AddOpInternal(op)
831 def _AddOpInternal(self, op):
832 """Add `op` to the current context."""
833 if not op.inputs:
834 # If we're in a while loop, remove any control inputs from outside the
835 # loop.
836 self._RemoveExternalControlEdges(op)
838 if not any(
839 util.OpInContext(input_op, self) for input_op in op.control_inputs):
840 # pylint: disable=protected-access
841 op._add_control_input(self._pivot.op)
842 # pylint: enable=protected-access
843 else:
844 # Make each input to 'op' available in this CondContext. If an input is
845 # already part of this context there's nothing to do, but if it's
846 # external, AddValue() will handle adding the appropriate Switch node and
847 # other bookkeeping.
848 for index in range(len(op.inputs)):
849 x = op.inputs[index]
850 if op.type == "Merge" and x.op.type == "NextIteration":
851 # Edge case: if we're importing a while loop inside this CondContext,
852 # AddValue() will not correctly handle the NextIteration inputs to
853 # Merge node. The problem is that the NextIteration should also be
854 # part of this context, but if we're importing it won't have been
855 # processed and added to the context yet, so AddValue() will try to
856 # add a Switch which results in an invalid graph. Instead, we use the
857 # NextIteration input as-is here, and it will eventually be added to
858 # the context via AddOp().
859 real_x = x
860 else:
861 real_x = self.AddValue(x)
862 if real_x != x:
863 # pylint: disable=protected-access
864 op._update_input(index, real_x)
865 # pylint: enable=protected-access
866 # Remove any external control dependency on this op.
867 self._RemoveExternalControlEdges(op)
868 # pylint: disable=protected-access
869 if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
870 op._add_control_input(self._pivot.op)
871 # pylint: enable=protected-access
873 # Mark op's outputs as seen by this context and any outer contexts.
874 output_names = [x.name for x in op.outputs]
875 ctxt = self
876 while ctxt is not None:
877 # pylint: disable=protected-access
878 ctxt._values.update(output_names)
879 ctxt = ctxt._outer_context
880 # pylint: enable=protected-access
882 if self._outer_context or not util.IsLoopExit(op):
883 op.graph.prevent_fetching(op)
885 if self._outer_context:
886 self._outer_context.AddInnerOp(op)
888 def _ProcessOutputTensor(self, val):
889 """Process an output tensor of a conditional branch."""
890 real_val = val
891 if val.name not in self._values:
892 # Handle the special case of lambda: x
893 self._values.add(val.name)
894 if self._outer_context:
895 real_val = self._outer_context.AddValue(val)
896 self._values.add(real_val.name)
897 self._external_values[real_val.name] = real_val
898 real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
899 self._external_values[val.name] = real_val
900 else:
901 external_val = self._external_values.get(val.name)
902 if external_val is not None:
903 real_val = external_val
904 return real_val
906 def _BuildCondTensor(self, v):
907 if isinstance(v, ops.Operation):
908 # Use pivot as the proxy for this op.
909 return with_dependencies([v], self._pivot)
910 else:
911 v = nest.map_structure(
912 _convert_tensorarray_to_flow, v, expand_composites=True)
913 return self._ProcessOutputTensor(ops.convert_to_tensor(v))
915 def BuildCondBranch(self, fn):
916 """Add the subgraph defined by fn() to the graph."""
917 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
918 original_result = fn()
919 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
920 if len(post_summaries) > len(pre_summaries):
921 new_summaries = post_summaries[len(pre_summaries):]
922 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
923 summary_ref[:] = pre_summaries
924 with ops.control_dependencies(new_summaries):
925 if original_result is None:
926 return no_op(), None
927 elif not isinstance(original_result, ops.Operation):
928 original_result = variable_utils.convert_variables_to_tensors(
929 original_result)
930 original_result = nest.map_structure(
931 array_ops.identity, original_result, expand_composites=True)
932 if original_result is None:
933 return None, None
935 original_result = variable_utils.convert_variables_to_tensors(
936 original_result)
937 result = nest.map_structure(
938 self._BuildCondTensor, original_result, expand_composites=True)
939 if not isinstance(result, (list, _basetuple)):
940 result = [result]
941 return original_result, result
943 def IsCondContext(self):
944 return True
947# pylint: enable=g-doc-args
948# pylint: enable=redefined-outer-name
951def _resource_safe_shape(t):
952 """Returns the shape of t or the variable it points to."""
953 if t.dtype == dtypes.resource:
954 while t.op.inputs:
955 t = t.op.inputs[0]
956 return tensor_shape.TensorShape(t.op.get_attr("shape"))
957 return array_ops.shape_internal(t, optimize=False)
960# TODO(yuanbyu): Consider having a unified notion of context for
961# not only conditionals and loops but also control dependency and
962# subgraphs.
963class WhileContext(ControlFlowContext):
964 """The context for the loop construct."""
966 def __init__(self,
967 maximum_iterations=None,
968 parallel_iterations=10,
969 back_prop=True,
970 swap_memory=False,
971 name="while_context",
972 grad_state=None,
973 context_def=None,
974 import_scope=None):
975 """"Creates a `WhileContext`.
977 Args:
978 maximum_iterations: Optional upper bound on number of loop iterations.
979 parallel_iterations: The number of iterations allowed to run in parallel.
980 back_prop: Whether backprop is enabled for this while loop.
981 swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
982 name: Optional name prefix for the returned tensors.
983 grad_state: The gradient loop state.
984 context_def: Optional `WhileContextDef` protocol buffer to initialize the
985 `WhileContext` Python object from.
986 import_scope: Optional `string`. Name scope to add. Only used when
987 initializing from a protocol buffer.
988 """
989 if context_def:
990 self._init_from_proto(context_def, import_scope=import_scope)
991 else:
992 ControlFlowContext.__init__(self)
993 self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
994 swap_memory, name)
995 # The gradient loop state.
996 self._grad_state = grad_state
998 def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
999 swap_memory, name):
1000 """Creates a new `WhileContext` from arguments.
1002 Args:
1003 maximum_iterations: Optional upper bound on number of loop iterations.
1004 parallel_iterations: The number of iterations allowed to run in parallel.
1005 back_prop: Whether backprop is enabled for this while loop.
1006 swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1007 name: Optional name prefix for the returned tensors.
1009 Raises:
1010 ValueError: If `parallel_iterations` has invalid value.
1011 """
1012 if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
1013 raise ValueError("'parallel_iterations' must be a positive integer: "
1014 "%s" % parallel_iterations)
1015 self._name = ops.get_default_graph().unique_name(name)
1016 self._maximum_iterations = maximum_iterations
1017 self._parallel_iterations = parallel_iterations
1018 self._back_prop = back_prop
1019 self._swap_memory = swap_memory
1020 # We use this node to control constants created by the pred lambda.
1021 self._pivot_for_pred = None
1022 # We use this node to control constants created by the body lambda.
1023 self._pivot_for_body = None
1024 # The boolean tensor for loop termination condition. Used in code
1025 # generation for gradient computation
1026 self._pivot = None
1027 # The list of exit tensors for loop variables.
1028 self._loop_exits = []
1029 # The list of enter tensors for loop variables.
1030 self._loop_enters = []
1031 self._graph = ops.get_default_graph()
1033 def _init_from_proto(self, context_def, import_scope=None):
1034 """Creates a new `WhileContext` from protocol buffer.
1036 Args:
1037 context_def: `WhileContextDef` protocol buffer.
1038 import_scope: Optional `string`. Name scope to add.
1039 """
1040 assert isinstance(context_def, control_flow_pb2.WhileContextDef)
1041 # Create from context_def.
1042 g = ops.get_default_graph()
1043 self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1044 if context_def.maximum_iterations_name:
1045 self._maximum_iterations = g.as_graph_element(
1046 ops.prepend_name_scope(context_def.maximum_iterations_name,
1047 import_scope))
1048 else:
1049 self._maximum_iterations = None
1050 self._parallel_iterations = context_def.parallel_iterations
1051 self._back_prop = context_def.back_prop
1052 self._swap_memory = context_def.swap_memory
1053 self._pivot_for_pred = g.as_graph_element(
1054 ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
1055 # We use this node to control constants created by the body lambda.
1056 self._pivot_for_body = g.as_graph_element(
1057 ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
1058 # The boolean tensor for loop termination condition. Used in code
1059 # generation for gradient computation.
1060 self._pivot = g.as_graph_element(
1061 ops.prepend_name_scope(context_def.pivot_name, import_scope))
1062 # The list of exit tensors for loop variables.
1063 self._loop_exits = [
1064 g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
1065 for exit_name in context_def.loop_exit_names
1066 ]
1067 # The list of enter tensors for loop variables.
1068 self._loop_enters = [
1069 g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
1070 for enter_name in context_def.loop_enter_names
1071 ]
1072 super(WhileContext, self).__init__(
1073 values_def=context_def.values_def, import_scope=import_scope)
1075 # import_scope causes self.name to be different from the original serialized
1076 # context's name. Rewrite "frame_name" attrs with the new name.
1077 if import_scope:
1078 for tensor_name in self._values:
1079 op = g.as_graph_element(tensor_name).op
1080 if util.IsLoopEnter(op):
1081 # pylint: disable=protected-access
1082 op._set_attr("frame_name",
1083 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
1084 # pylint: enable=protected-access
1085 self._graph = ops.get_default_graph()
1087 @property
1088 def maximum_iterations(self):
1089 """The maximum number of iterations that will be executed."""
1090 return self._maximum_iterations
1092 @property
1093 def parallel_iterations(self):
1094 """The number of iterations allowed to run in parallel."""
1095 return self._parallel_iterations
1097 @property
1098 def back_prop(self):
1099 """True iff backprop is enabled for this while loop."""
1100 return self._back_prop
1102 @property
1103 def swap_memory(self):
1104 """True iff GPU-CPU memory swap is enabled for this while loop."""
1105 return self._swap_memory
1107 @property
1108 def pivot(self):
1109 """The boolean tensor representing the loop termination condition."""
1110 return self._pivot
1112 @property
1113 def loop_enters(self):
1114 """The list of enter tensors for loop variables."""
1115 return self._loop_enters
1117 @property
1118 def loop_exits(self):
1119 """The list of exit tensors for loop variables."""
1120 return self._loop_exits
1122 @property
1123 def grad_state(self):
1124 """The gradient loop state."""
1125 return self._grad_state
1127 def to_proto(self, export_scope=None):
1128 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
1130 Args:
1131 export_scope: Optional `string`. Name scope to remove.
1133 Returns:
1134 A `WhileContextDef` protocol buffer.
1135 """
1136 if (export_scope is None or self.name.startswith(export_scope)):
1137 context_def = control_flow_pb2.WhileContextDef()
1138 context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1139 context_def.parallel_iterations = self._parallel_iterations
1140 if self._maximum_iterations is not None:
1141 context_def.maximum_iterations_name = ops.strip_name_scope(
1142 self._maximum_iterations.name, export_scope)
1143 context_def.back_prop = self._back_prop
1144 context_def.swap_memory = self._swap_memory
1145 context_def.pivot_for_pred_name = ops.strip_name_scope(
1146 self._pivot_for_pred.name, export_scope)
1147 context_def.pivot_for_body_name = ops.strip_name_scope(
1148 self._pivot_for_body.name, export_scope)
1149 context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1150 export_scope)
1151 context_def.loop_exit_names.extend([
1152 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
1153 ])
1154 context_def.loop_enter_names.extend([
1155 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
1156 ])
1157 context_def.values_def.MergeFrom(
1158 super(WhileContext, self)._to_values_def(export_scope=export_scope))
1159 for nested in self._nested_contexts:
1160 nested_def = context_def.nested_contexts.add()
1161 nested.to_control_flow_context_def(nested_def)
1163 return context_def
1164 else:
1165 return None
1167 def to_control_flow_context_def(self, context_def, export_scope=None):
1168 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1170 @staticmethod
1171 def from_proto(context_def, import_scope=None):
1172 """Returns a `WhileContext` object created from `context_def`.
1174 Args:
1175 context_def: A `WhileContextDef` protocol buffer.
1176 import_scope: Optional `string`. Name scope to add.
1178 Returns:
1179 A `WhileContext` Python object.
1180 """
1181 ret = WhileContext(context_def=context_def, import_scope=import_scope)
1182 ret.Enter()
1183 for nested_def in context_def.nested_contexts:
1184 from_control_flow_context_def(nested_def, import_scope=import_scope)
1185 ret.Exit()
1186 return ret
1188 def GetWhileContext(self):
1189 return self
1191 def GetControlPivot(self):
1192 if self._pivot_for_body is not None:
1193 return self._pivot_for_body
1194 return self._pivot_for_pred
1196 def AddValue(self, val):
1197 """Add `val` to the current context and its outer context recursively."""
1198 result = val
1199 new_value = val.name not in self._values
1200 # Don't treat ops in this context as new values. Usually all known values
1201 # are in self._values, except when we're importing a while loop inside this
1202 # WhileContext. Since there's a cycle in this case, `val` may be part of the
1203 # imported while loop but not yet processed by this context and added to
1204 # self._values in _AddOpInternal. We only want to process external input
1205 # tensors to the while loop here.
1206 new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access
1207 if new_value:
1208 self._values.add(val.name)
1210 # If we are in a grad context and val is from its forward context,
1211 # use GetRealValue(), which adds the logic to save the history of
1212 # val in forward.
1213 grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1214 if grad_ctxt:
1215 grad_ctxt = grad_ctxt.GetWhileContext()
1216 if grad_ctxt.grad_state:
1217 forward_ctxt = util.GetWhileContext(val.op)
1218 if util.IsLoopExit(val.op):
1219 forward_ctxt = forward_ctxt.outer_context
1220 if forward_ctxt:
1221 forward_ctxt = forward_ctxt.GetWhileContext()
1222 if forward_ctxt == grad_ctxt.grad_state.forward_context:
1223 real_val = grad_ctxt.grad_state.GetRealValue(val)
1224 self._external_values[val.name] = real_val
1225 return real_val
1227 if self._outer_context is not None:
1228 result = self._outer_context.AddValue(val)
1229 # Create an Enter to make `result` known to this loop context.
1230 with ops.control_dependencies(None):
1231 enter = _Enter(
1232 result,
1233 self._name,
1234 is_constant=True,
1235 parallel_iterations=self._parallel_iterations)
1236 enter.graph.prevent_feeding(enter)
1237 if self._outer_context:
1238 self._outer_context.AddInnerOp(enter.op)
1239 # Fix the control inputs and control flow context of these enter ops.
1240 self._FixControlInputsAndContext([enter])
1242 # Add `enter` in this context.
1243 self._values.add(enter.name)
1244 self._external_values[val.name] = enter
1245 result = enter
1246 else:
1247 actual_val = self._external_values.get(val.name)
1248 if actual_val is not None:
1249 result = actual_val
1250 return result
1252 def AddOp(self, op):
1253 """Add `op` to the current context."""
1254 # For a reduction op, if op is in a grad context and its input is from
1255 # its forward context, moving op to the forward context means we would
1256 # store the tensor after the reduction as opposed to the tensor before
1257 # reduction, and therefore could significantly reduce memory consumption.
1258 # For now, we do this only for a few ops.
1259 #
1260 # If in XLA context, do not move constant ops to forward pass as pushing to
1261 # and popping from a stack removes the constant property of an op and breaks
1262 # XLA compilation, which requires certain inputs to be constant for certain
1263 # ops.
1264 if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
1265 grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1266 if grad_ctxt:
1267 grad_ctxt = grad_ctxt.GetWhileContext()
1268 if grad_ctxt.grad_state:
1269 op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
1270 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
1271 op_input_ctxt = op.inputs[0].op._get_control_flow_context()
1272 op._set_control_flow_context(op_input_ctxt)
1273 op_input_ctxt._AddOpInternal(op)
1274 return
1275 self._AddOpInternal(op)
1277 def _AddOpInternal(self, op):
1278 """Add `op` to the current context.
1280 We move any external control dependencies of the op to the loop pivot, to
1281 ensure they get executed.
1282 """
1283 # This is needed to prevent frame mismatch errors where there are Const
1284 # nodes inside tf.function in v1 while_loop and inlining is turned on.
1285 if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
1286 op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access
1287 if not op.inputs:
1288 # Remove any external control dependency on this op
1289 control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
1290 # Add a control edge from the control pivot to this op.
1291 if not control_inputs:
1292 # pylint: disable=protected-access
1293 op._add_control_input(self.GetControlPivot().op)
1294 # pylint: enable=protected-access
1295 for x in op.outputs:
1296 self._values.add(x.name)
1297 else:
1298 for index in range(len(op.inputs)):
1299 x = op.inputs[index]
1300 real_x = self.AddValue(x)
1301 if real_x != x:
1302 op._update_input(index, real_x) # pylint: disable=protected-access
1303 # Remove any external control dependency on this op.
1304 _, external_inputs = self._RemoveExternalControlEdges(op)
1305 # Add a control dependency to prevent loop invariants from
1306 # enabling ops that should not be executed.
1307 self._MaybeAddControlDependency(op)
1308 for x in op.outputs:
1309 self._values.add(x.name)
1310 if external_inputs:
1311 # Use an identity to pull control inputs as data inputs. Note that we
1312 # ignore ops which don't have outputs. TODO(apassos): fix that
1313 with ops.control_dependencies(None):
1314 self.Enter()
1315 external_inputs = [
1316 array_ops.identity(x.outputs[0]).op
1317 for x in external_inputs
1318 if x.outputs
1319 ]
1320 self.Exit()
1321 op._add_control_inputs(external_inputs) # pylint: disable=protected-access
1322 if self._outer_context or not util.IsLoopExit(op):
1323 op.graph.prevent_fetching(op)
1324 for x in op.outputs:
1325 op.graph.prevent_feeding(x)
1327 if self._outer_context:
1328 self._outer_context.AddInnerOp(op)
1330 def _MaybeAddControlDependency(self, op):
1331 """Add a control input to the op if it only depends on loop invariants."""
1333 def _IsOpFree(op):
1334 """Determines if `op` needs a control dependency."""
1335 if op.control_inputs:
1336 return False
1337 # pylint: disable=protected-access
1338 if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1339 return True
1340 # pylint: enable=protected-access
1341 for x in op.inputs:
1342 if not util.IsLoopConstantEnter(x.op):
1343 return False
1344 return True
1346 if _IsOpFree(op):
1347 # pylint: disable=protected-access
1348 op._add_control_input(self.GetControlPivot().op)
1349 # pylint: enable=protected-access
1351 def AddForwardLoopCounter(self, outer_grad_state):
1352 """Adds a loop that counts the number of iterations.
1354 This is added to the forward loop at the time when we start to
1355 create the loop for backprop gradient computation. Called in
1356 the outer context of this forward context.
1358 The pseudocode is:
1359 `n = 0; while (_pivot) { n++; }`
1361 Note that a control dependency is added to `n` to ensure the correct
1362 execution order of stack push ops.
1364 Args:
1365 outer_grad_state: The outer grad state. None if not nested.
1367 Returns:
1368 The number of iterations taken by the forward loop and the loop index.
1369 """
1370 n = constant_op.constant(0, name="f_count")
1371 if outer_grad_state is not None:
1372 # Force the stack pushes of i-th execution of an inner loop to be ordered
1373 # before the pushes of (i+1)-th execution of the same inner loop.
1374 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
1375 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access
1377 self.Enter()
1378 self.AddName(n.name)
1379 enter_n = _Enter(
1380 n,
1381 self._name,
1382 is_constant=False,
1383 parallel_iterations=self._parallel_iterations,
1384 name="f_count")
1385 self.loop_enters.append(enter_n)
1387 merge_n = merge([enter_n, enter_n])[0]
1388 switch_n = switch(merge_n, self._pivot)
1390 index = math_ops.add(switch_n[1], 1)
1391 next_n = _NextIteration(index)
1392 merge_n.op._update_input(1, next_n)
1394 total_iterations = exit(switch_n[0], name="f_count")
1395 self.loop_exits.append(total_iterations)
1396 self.ExitResult([total_iterations])
1397 self.Exit()
1398 return total_iterations, next_n
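  # Illustrative sketch (not part of the original module): the wiring built by
  # AddForwardLoopCounter is the standard Enter -> Merge -> Switch ->
  # NextIteration -> Exit pattern for a counting loop, roughly:
  #
  #   enter_n = _Enter(n, frame_name, is_constant=False)
  #   merge_n = merge([enter_n, enter_n])[0]
  #   switch_n = switch(merge_n, pivot)          # pivot: loop condition
  #   next_n = _NextIteration(switch_n[1] + 1)
  #   merge_n.op._update_input(1, next_n)        # back edge
  #   total = exit(switch_n[0])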
1400 def AddBackpropLoopCounter(self, count, outer_grad_state):
1401 """Add the backprop loop that controls the iterations.
1403 This is added to the backprop loop. It is used to control the loop
1404 termination of the backprop loop. Called in the outer context of
1405 this grad context.
1407 The pseudocode is:
1408 `n = count; while (n >= 1) { n--; }`
1410 Note that a control dependency is added to `final_zero` to ensure the
1411 correct execution order of stack pop ops.
1413 Args:
1414 count: The number of iterations for backprop.
1415 outer_grad_state: The outer grad state. None if not nested.
1417 Returns:
1418 The loop index.
1419 """
1420 in_separate_functions = count.graph is not ops.get_default_graph()
1421 if in_separate_functions:
1422 # Brings the count into this graph
1423 count = array_ops.identity(count)
1424 else:
1425 # TODO(apassos) XLA expects this constant to be created outside the loop,
1426 # so doing that for now.
1427 one = constant_op.constant(1, name="b_count")
1429 self.Enter()
1430 self.AddName(count.name)
1431 enter_count = _Enter(
1432 count,
1433 self._name,
1434 is_constant=False,
1435 parallel_iterations=self._parallel_iterations,
1436 name="b_count")
1437 self.loop_enters.append(enter_count)
1439 merge_count = merge([enter_count, enter_count])[0]
1440 self._pivot_for_pred = merge_count
1442 if in_separate_functions:
1443 one = constant_op.constant(1, name="b_count")
1444 pred = math_ops.greater_equal(merge_count, one)
1445 self._pivot = loop_cond(pred, name="b_count")
1446 switch_count = switch(merge_count, self._pivot)
1448 index = math_ops.subtract(switch_count[1], one)
1449 self._pivot_for_body = index
1450 next_count = _NextIteration(index)
1451 merge_count.op._update_input(1, next_count)
1453 final_zero = exit(switch_count[0], name="b_count")
1454 self.loop_exits.append(final_zero)
1455 if outer_grad_state is not None:
1456 # Force the stack pops of i-th execution of an inner loop to be ordered
1457 # before the pops of (i+1)-th execution of the same inner loop.
1458 # pylint: disable=protected-access
1459 outer_grad_state.grad_sync._add_control_input(final_zero.op)
1460 # pylint: enable=protected-access
1462 self.ExitResult([final_zero])
1463 self.Exit()
1464 return next_count
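  # Illustrative sketch (not part of the original module): semantically the
  # backprop counter above is the public-API loop
  #
  #   while_loop(lambda n: n >= 1, lambda n: n - 1, [count])
  #
  # with its pivot reused as the termination condition of the whole backprop
  # while loop.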
1466 def AddBackpropAccumulator(self, op, grad):
1467 """Add an accumulation loop for every loop invariant.
1469 This is added to the backprop loop. It is used to accumulate partial
1470 gradients within each loop iteration. Called when in the gradient while
1471 context.
1473 The pseudocode is:
1474 ```
1475 acc = 0.0;
1476 while (_pivot) {
1477 acc += grad;
1478 }
1479 ```
1481 Args:
1482 op: The Enter op for a loop invariant.
1483 grad: The partial gradient of an iteration for a loop invariant.
1485 Returns:
1486 The gradient for a loop invariant.
1487 """
1488 self.Exit()
1489 # Create a zeros tensor with the right shape for acc. If we don't
1490 # know the full shape statically, we will have to get the shape
1491 # dynamically from the forward inference. Getting the shape right
1492 # for the zeros is only needed for the base case when the loop exits
1493 # without running any iterations.
1494 shape = grad.get_shape()
1495 if shape.is_fully_defined():
1496 if self.outer_context:
1497 self.outer_context.Enter()
1498 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
1499 if self.outer_context:
1500 self.outer_context.Exit()
1501 else:
1502 value = op.inputs[0]
1503 if (isinstance(self.outer_context, WhileContext) and
1504 self.outer_context.grad_state is not None):
1505 # We are in a nested while loop.
1506 forward_ctxt = self.grad_state.forward_context
1507 forward_ctxt.outer_context.Enter()
1508 zeros_shape = array_ops.shape_internal(value, optimize=False)
1509 forward_ctxt.outer_context.Exit()
1510 outer_grad_state = self.grad_state.outer_grad_state
1511 history_zeros_shape = outer_grad_state.AddForwardAccumulator(
1512 zeros_shape)
1513 self.outer_context.Enter()
1514 real_shape = outer_grad_state.AddBackpropAccumulatedValue(
1515 history_zeros_shape, zeros_shape)
1516 acc = array_ops.zeros(real_shape, grad.dtype)
1517 self.outer_context.Exit()
1518 else:
1519 if self.outer_context:
1520 self.outer_context.Enter()
1521 zeros_shape = array_ops.shape_internal(value, optimize=False)
1522 acc = array_ops.zeros(zeros_shape, grad.dtype)
1523 if self.outer_context:
1524 self.outer_context.Exit()
1526 self.Enter()
1527 self.AddName(acc.name)
1528 enter_acc = _Enter(
1529 acc,
1530 self._name,
1531 is_constant=False,
1532 parallel_iterations=self._parallel_iterations,
1533 name="b_acc")
1534 self.loop_enters.append(enter_acc)
1536 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
1537 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
1539 add_acc = math_ops.add(switch_acc_true, grad)
1540 next_acc = _NextIteration(add_acc)
1541 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access
1543 result_acc = exit(switch_acc_false, name="b_acc")
1544 self.loop_exits.append(result_acc)
1545 self.ExitResult([result_acc])
1546 return result_acc
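  # Illustrative sketch only (not part of the original TensorFlow source).
  # The accumulation loop assembled above realizes the docstring pseudocode
  # `acc = 0.0; while (_pivot) { acc += grad; }`; the hypothetical helper
  # below mirrors it, given one partial gradient per backprop iteration.
  @staticmethod
  def _backprop_accumulator_sketch(partial_grads):
    acc = 0.0                    # zeros constant -> enter_acc -> merge_acc
    for grad in partial_grads:   # one backprop iteration per element
      acc = acc + grad           # add_acc -> next_acc (back edge into merge_acc)
    return acc                   # result_acc, the Exit value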
1548 def AddBackpropIndexedSlicesAccumulator(self, op, grad):
1549 """This is used for accumulating gradients that are IndexedSlices.
1551 This is essentially the equivalent of AddBackpropAccumulator but optimized
1552 for things like updating embeddings from within a while loop.
1554 Args:
1555 op: The Enter op for a loop invariant.
1556 grad: The partial gradients represented as an IndexedSlices.
1558 Returns:
1559 The accumulated IndexedSlices gradient of the loop invariant.
1560 """
1561 values = grad.values
1562 indices = grad.indices
1563 dense_shape = grad.dense_shape
1565 self.Exit()
1566 if self.outer_context:
1567 self.outer_context.Enter()
1568 if values.get_shape().is_fully_defined():
1569 values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
1570 values.get_shape().dims[1:])
1571 if self.outer_context:
1572 self.outer_context.Enter()
1573 values_acc = constant_op.constant(
1574 0, values.dtype, shape=values_shape, name="b_acc")
1575 if self.outer_context:
1576 self.outer_context.Exit()
1577 else:
1578 values_shape = _resource_safe_shape(op.inputs[0])[1:]
1579 values_shape = array_ops.concat([[1], values_shape], 0)
1580 values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
1581 indices_acc = constant_op.constant([0], indices.dtype)
1582 shape_acc = None
1583 if dense_shape is not None:
1584 if dense_shape.get_shape().is_fully_defined():
1585 if self.outer_context:
1586 self.outer_context.Enter()
1587 shape_acc = constant_op.constant(
1588 0, dense_shape.dtype, shape=dense_shape.get_shape())
1589 if self.outer_context:
1590 self.outer_context.Exit()
1591 else:
1592 shape_acc = array_ops.zeros_like(
1593 array_ops.shape_internal(
1594 op.inputs[0], optimize=False, out_type=dense_shape.dtype),
1595 optimize=False)
1597 if self.outer_context:
1598 self.outer_context.Exit()
1600 self.Enter()
1601 self.AddName(values_acc.name)
1602 self.AddName(indices_acc.name)
1603 init_acc = [indices_acc, values_acc]
1604 if shape_acc is not None:
1605 self.AddName(shape_acc.name)
1606 init_acc.append(shape_acc)
1608 # Set use_input_shape=False since the accumulator tensors will grow in
1609 # size. If use_input_shape=True, the _update_input call below will result in
1610 # incompatible shapes.
1611 enter_acc = [
1612 _Enter(
1613 x,
1614 self._name,
1615 is_constant=False,
1616 parallel_iterations=self._parallel_iterations,
1617 use_input_shape=False,
1618 name="b_acc") for x in init_acc
1619 ]
1620 # Manually set appropriate partial shapes.
1621 enter_acc[0].set_shape([None])
1622 if values_acc.shape.dims is not None:
1623 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
1624 self.loop_enters.extend(enter_acc)
1626 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
1627 switch_acc = [switch(x, self._pivot) for x in merge_acc]
1629 # The actual accumulation.
1630 acc_indexed_slices = [
1631 array_ops.concat([xa[1], xv], 0)
1632 for xa, xv in zip(switch_acc[:2], [indices, values])
1633 ]
1634 if shape_acc is not None:
1635 # For the dense shape, just keep the elementwise maximum.
1636 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
1638 next_acc = [_NextIteration(x) for x in acc_indexed_slices]
1639 for xm, xn in zip(merge_acc, next_acc):
1640 xm.op._update_input(1, xn) # pylint: disable=protected-access
1642 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
1643 self.loop_exits.extend(exit_acc)
1645 self.ExitResult(exit_acc)
1646 return indexed_slices.IndexedSlices(
1647 indices=exit_acc[0],
1648 values=exit_acc[1],
1649 dense_shape=exit_acc[2] if shape_acc is not None else None)
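  # Illustrative sketch only (not part of the original TensorFlow source).
  # Instead of summing dense tensors, the loop above grows the accumulator by
  # concatenating each iteration's indices and values (and keeping the
  # elementwise maximum of the dense shape). The hypothetical helper below is
  # the plain-Python analogue for a sequence of (indices, values) pairs.
  @staticmethod
  def _indexed_slices_accumulator_sketch(per_iteration_grads):
    acc_indices, acc_values = [0], [0.0]   # the "b_acc" initial accumulators
    for indices, values in per_iteration_grads:
      acc_indices = acc_indices + list(indices)   # concat along axis 0
      acc_values = acc_values + list(values)
    return acc_indices, acc_values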
1651 def _InitializeValues(self, values):
1652 """Makes the values known to this context."""
1653 self._values = set()
1654 for x in values:
1655 if isinstance(x, ops.Tensor):
1656 self._values.add(x.name)
1657 else:
1658 raise TypeError("'values' must be a list of Tensors. "
1659 f"Received: {type(x)}.")
1661 def _BuildLoop(self, pred, body, flat_orig_loop_vars, flat_loop_vars,
1662 loop_vars_signature):
1663 """Core: Add the loop termination condition and body to the graph."""
1664 flat_shape_invariants = nest.map_structure(
1665 lambda spec: spec.shape,
1666 nest.flatten(loop_vars_signature, expand_composites=True))
1668 # Let the context know the loop variables so that they are
1669 # added to the outer contexts properly.
1670 self._InitializeValues(flat_loop_vars)
1671 if self._outer_context:
1672 real_vars = [self._outer_context.AddValue(x) for x in flat_loop_vars]
1673 else:
1674 real_vars = flat_loop_vars
1676 enter_vars = []
1677 with ops.control_dependencies(None):
1678 for real_var, shape_invariant in zip(real_vars, flat_shape_invariants):
1679 enter_var = _Enter(
1680 real_var,
1681 self._name,
1682 is_constant=False,
1683 parallel_iterations=self._parallel_iterations,
1684 use_input_shape=False)
1686 if _ShapeLessThanOrEqual(real_var.get_shape(), shape_invariant):
1687 enter_var.set_shape(shape_invariant)
1688 else:
1689 raise ValueError(
1690 f"The shape invariant specified for {real_var.name} is not "
1691 "compatible with the initial shape of the loop variable. It "
1692 f"enters the loop with shape {real_var.get_shape()}, but the "
1693 f"specified shape invariant is {shape_invariant}.")
1695 enter_var.graph.prevent_feeding(enter_var)
1696 if self._outer_context:
1697 self._outer_context.AddInnerOp(enter_var.op)
1698 enter_vars.append(enter_var)
1700 # Finds the closest enclosing non-None control pivot.
1701 outer_context = self._outer_context
1702 control_pivot = None
1703 while outer_context is not None and control_pivot is None:
1704 control_pivot = outer_context.GetControlPivot()
1705 # pylint: disable=protected-access
1706 outer_context = outer_context._outer_context
1707 # pylint: enable=protected-access
1709 if control_pivot is not None:
1710 for var in enter_vars:
1711 if util.IsLoopConstantEnter(var.op.inputs[0].op):
1712 # pylint: disable=protected-access
1713 var.op._add_control_input(control_pivot.op)
1714 # pylint: enable=protected-access
1716 # Fix the control inputs and control flow context of these enter ops.
1717 self._FixControlInputsAndContext(enter_vars)
1718 self._InitializeValues(enter_vars)
1719 self._loop_enters = enter_vars
1721 merge_vars = [merge([x, x])[0] for x in enter_vars]
1722 self._pivot_for_pred = merge_vars[0]
1724 merge_vars_with_tensorarrays = nest.map_structure(
1725 _convert_flow_to_tensorarray, flat_orig_loop_vars, merge_vars)
1726 # Build the graph for pred.
1727 packed_vars = nest.pack_sequence_as(
1728 structure=loop_vars_signature,
1729 flat_sequence=merge_vars_with_tensorarrays,
1730 expand_composites=True)
1731 c = ops.convert_to_tensor(pred(*packed_vars))
1732 self._pivot = loop_cond(c, name="LoopCond")
1733 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
1735 # Build the graph for body.
1736 vars_for_body = [_Identity(x[1]) for x in switch_vars]
1737 self._pivot_for_body = vars_for_body[0]
1738 # Convert TensorArray flow variables inside the context back into
1739 # their associated TensorArrays for calling the body.
1740 vars_for_body_with_tensorarrays = nest.map_structure(
1741 _convert_flow_to_tensorarray, flat_orig_loop_vars, vars_for_body)
1742 packed_vars_for_body = nest.pack_sequence_as(
1743 structure=loop_vars_signature,
1744 flat_sequence=vars_for_body_with_tensorarrays,
1745 expand_composites=True)
1746 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
1747 body_result = body(*packed_vars_for_body)
1748 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
1749 if not nest.is_nested(body_result):
1750 body_result = [body_result]
1751 if len(post_summaries) > len(pre_summaries):
1752 new_summaries = post_summaries[len(pre_summaries):]
1753 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
1754 summary_ref[:] = pre_summaries
1755 with ops.control_dependencies(new_summaries):
1757 def map_fn(x):
1758 # TODO(apassos) figure out how to trigger with tensor arrays as well
1759 if isinstance(x, tensor_array_ops.TensorArray):
1760 return x
1761 return array_ops.identity(x)
1763 body_result = nest.map_structure(
1764 map_fn, body_result, expand_composites=True)
1766 body_result = variable_utils.convert_variables_to_tensors(body_result)
1767 # Compare the structure types of input and output of body.
1768 # For backwards compatibility, the first layer is forced to a list
1769 # during this comparison, because inputs are typically lists and
1770 # outputs of the body are typically tuples.
1771 nest.assert_same_structure(
1772 list(packed_vars_for_body), list(body_result), expand_composites=True)
1774 # Store body_result to keep track of TensorArrays returned by body
1775 original_body_result = body_result
1776 # Convert TensorArrays returned by body into their flow variables
1777 result = nest.map_structure(
1778 _convert_tensorarray_to_flow,
1779 nest.flatten(body_result, expand_composites=True),
1780 expand_composites=True)
1781 result = ops.convert_n_to_tensor_or_composite(result)
1783 # Add NextIteration and the back edges to complete the loop.
1784 if len(merge_vars) != len(result):
1785 raise ValueError("Number of inputs and outputs of 'body' must match "
1786 f"'loop_vars'. Got {len(merge_vars)} for the number of "
1787 f"inputs/outputs, and {len(result)} for 'loop_vars'.")
1788 next_vars = []
1789 for m, v in zip(merge_vars, result):
1790 next_vars.append(_AddNextAndBackEdge(m, v))
1792 # Add the exit ops.
1793 exit_vars = [exit(x[0]) for x in switch_vars]
1794 self._loop_exits = exit_vars
1796 # Exit the loop.
1797 self.ExitResult(exit_vars)
1799 return original_body_result, exit_vars
1801 def BuildLoop(self, pred, body, loop_vars, shape_invariants,
1802 return_same_structure):
1803 """Add the loop termination condition and body to the graph."""
1805 # Keep flat_orig_loop_vars to identify which are TensorArrays
1806 flat_orig_loop_vars = nest.flatten(loop_vars, expand_composites=True)
1808 loop_vars = nest.map_structure(
1809 _convert_to_tensor_or_composite_or_tensorarray, loop_vars)
1810 # Convert TensorArrays to their flow variables
1811 flat_loop_vars = nest.map_structure(
1812 _convert_tensorarray_to_flow,
1813 nest.flatten(loop_vars, expand_composites=True))
1815 if shape_invariants is not None:
1816 loop_vars_signature = nest.map_structure(
1817 _shape_invariant_to_type_spec, loop_vars, shape_invariants)
1818 else:
1819 loop_vars_signature = nest.map_structure(
1820 _shape_invariant_to_type_spec, loop_vars)
1822 try:
1823 self.Enter()
1824 # _BuildLoop calls _update_input in several places. _mutation_lock()
1825 # ensures a Session.run call cannot occur between creating and mutating
1826 # new ops.
1827 with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
1828 original_body_result, exit_vars = self._BuildLoop(
1829 pred, body, flat_orig_loop_vars, flat_loop_vars,
1830 loop_vars_signature)
1831 finally:
1832 self.Exit()
1834 flat_result = nest.flatten(original_body_result, expand_composites=True)
1835 # Convert TensorArray flow variables outside the context back into
1836 # their associated TensorArrays for returning to caller.
1837 exit_vars_with_tensorarrays = nest.map_structure(
1838 _convert_flow_to_tensorarray, flat_result, exit_vars)
1840 packed_exit_vars = nest.pack_sequence_as(
1841 structure=original_body_result,
1842 flat_sequence=exit_vars_with_tensorarrays,
1843 expand_composites=True)
1845 if return_same_structure:
1846 return packed_exit_vars
1847 else:
1848 return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
1850 def _FixControlInputsAndContext(self, enters):
1851 graph = ops.get_default_graph()
1852 # pylint: disable=protected-access
1853 for e in enters:
1854 if isinstance(e, ops.Tensor):
1855 xs = [e]
1856 else:
1857 raise TypeError("'enters' must be a list of Tensors. "
1858 f"Received: {type(e)}.")
1859 for x in xs:
1860 inp_op = x.op.inputs[0].op
1861 control_inputs = graph._control_dependencies_for_inputs([inp_op])
1862 outer_control_inputs = []
1863 for op in control_inputs:
1864 # We need to keep control inputs that are in any ancestor
1865 # ControlFlowContext, and within outer WhileContext.
1866 keep_as_control_input = True
1867 op_ctxt = util.GetOutputContext(op)
1868 outer_ctxt = self.outer_context
1869 outer_while_context = (None if outer_ctxt is None else
1870 outer_ctxt.GetWhileContext())
1871 while outer_ctxt != op_ctxt:
1872 if outer_ctxt is None or outer_ctxt == outer_while_context:
1873 keep_as_control_input = False
1874 break
1875 outer_ctxt = outer_ctxt.outer_context
1876 if keep_as_control_input:
1877 outer_control_inputs.append(op)
1878 x.op._set_control_flow_context(self)
1879 x.op._add_control_inputs(outer_control_inputs)
1880 graph._record_op_seen_by_control_dependencies(x.op)
1881 # pylint: enable=protected-access
1883 def IsWhileContext(self):
1884 return True
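# Illustrative sketch only (not part of the original TensorFlow source).
# WhileContext.BuildLoop above is the graph-construction machinery behind
# `tf.while_loop`; the hypothetical helper below shows the user-facing
# equivalent, including the shape-invariant check enforced in _BuildLoop.
def _while_loop_usage_sketch():
  import tensorflow as tf  # local import; sketch only
  with tf.Graph().as_default():
    i = tf.constant(0)
    x = tf.constant([0.0])
    # `x` grows every iteration, so its invariant must leave the first
    # dimension unknown or the shape-invariant ValueError above is raised.
    i_out, x_out = tf.while_loop(
        cond=lambda i, x: i < 5,
        body=lambda i, x: (i + 1, tf.concat([x, x], axis=0)),
        loop_vars=[i, x],
        shape_invariants=[i.get_shape(), tf.TensorShape([None])])
    with tf.compat.v1.Session() as sess:
      return sess.run([i_out, tf.shape(x_out)])  # [5, [32]]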
1887# pylint: enable=redefined-outer-name
1890def _AsTensorList(x, p):
1891 """Return x as a list of Tensors or IndexedSlices.
1893 For entries of `x` that are Operations, this returns an Identity of `p`
1894 with a dependency on the operation.
1896 Args:
1897 x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
1898 p: A Tensor to return for entries in `x` that are Operations.
1900 Returns:
1901 A list of Tensors or IndexedSlices.
1902 """
1903 if not isinstance(x, (list, _basetuple)):
1904 x = [x]
1906 l = []
1907 for v in x:
1908 if isinstance(v, ops.Operation):
1909 v = with_dependencies([v], p)
1910 v = ops.convert_to_tensor_or_composite(v)
1911 if isinstance(v, ops.Tensor):
1912 l.append(array_ops.identity(v))
1913 else:
1914 l.append(
1915 indexed_slices.IndexedSlices(
1916 array_ops.identity(v.values), array_ops.identity(v.indices)))
1917 return l
1920def _CheckResults(a, b):
1921 assert len(a) == len(b), (
1922 "Values returned by a() and b() must have the same length.")
1923 for x, y in zip(a, b):
1924 assert x.dtype == y.dtype, (
1925 "Values returned by a() [%s] and b() [%s] must have "
1926 "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
1929def with_dependencies(dependencies, output_tensor, name=None):
1930 """Produces the content of `output_tensor` only after `dependencies`.
1932 In some cases, a user may want the output of an operation to be
1933 consumed externally only after some other dependencies have run
1934 first. This function returns `output_tensor`, but only after all
1935 operations in `dependencies` have run. Note that this only gates the
1936 returned tensor: there is no guarantee that the op producing
1937 `output_tensor` itself will run after any of the `dependencies`.
1939 See also `tf.tuple` and `tf.group`.
1941 Args:
1942 dependencies: Iterable of operations to run before this op finishes.
1943 output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
1944 name: (Optional) A name for this operation.
1946 Returns:
1947 Same as `output_tensor`.
1949 Raises:
1950 TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
1951 """
1952 if context.executing_eagerly():
1953 return output_tensor
1954 with ops.name_scope(name, "control_dependency",
1955 list(dependencies) + [output_tensor]) as name:
1956 with ops.colocate_with(output_tensor):
1957 with ops.control_dependencies(dependencies):
1958 output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
1959 if isinstance(output_tensor, indexed_slices.IndexedSlices):
1960 return indexed_slices.IndexedSlices(
1961 _Identity(output_tensor.values, name=name), output_tensor.indices,
1962 output_tensor.dense_shape)
1963 else:
1964 return _Identity(output_tensor, name=name)
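# Illustrative sketch only (not part of the original TensorFlow source).
# A hypothetical graph-mode use of `with_dependencies`: running the returned
# tensor forces the dependency (here a variable update) to run first.
def _with_dependencies_usage_sketch():
  import tensorflow as tf  # local import; sketch only
  with tf.Graph().as_default():
    v = tf.compat.v1.Variable(0.0)
    update = tf.compat.v1.assign_add(v, 1.0)
    gated = with_dependencies([update], tf.constant(7.0))
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      out = sess.run(gated)      # 7.0; evaluating `gated` also ran `update`
      return out, sess.run(v)    # (7.0, 1.0)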
1967def _GroupControlDeps(dev, deps, name=None):
1968 with ops.control_dependencies(deps):
1969 if dev is None:
1970 return no_op(name=name)
1971 else:
1972 with ops.device(dev):
1973 return no_op(name=name)
1976# TODO(touts): Accept "inputs" as a list.
1977@tf_export("group")
1978def group(*inputs, **kwargs):
1979 """Create an op that groups multiple operations.
1981 When this op finishes, all ops in `inputs` have finished. This op has no
1982 output.
1984 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
1985 this method, as ops execute in the expected order thanks to automatic control
1986 dependencies.* Only use `tf.group` when working with v1
1987 `tf.Graph` code.
1989 When operating in a v1-style graph context, ops are not executed in the same
1990 order as specified in the code; TensorFlow will attempt to execute ops in
1991 parallel or in an order convenient to the result it is computing. `tf.group`
1992 allows you to request that one or more results finish before execution
1993 continues.
1995 `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
1996 control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph
1997 as this:
1999 with tf.control_dependencies([a, b]):
2000 c = tf.no_op()
2002 See also `tf.tuple` and
2003 `tf.control_dependencies`.
2005 Args:
2006 *inputs: Zero or more tensors to group.
2007 name: A name for this operation (optional).
2009 Returns:
2010 An Operation that executes all its inputs.
2012 Raises:
2013 ValueError: If an unknown keyword argument is provided.
2014 """
2015 if context.executing_eagerly():
2016 return None
2017 name = kwargs.pop("name", None)
2018 if kwargs:
2019 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
2020 with ops.name_scope(name, "group_deps", inputs) as name:
2021 # Grouping no inputs means do nothing
2022 if not inputs:
2023 return no_op(name=name)
2025 # Sorts *inputs according to their devices.
2026 ops_on_device = {} # device -> operations specified on the device.
2027 for inp in nest.flatten(inputs, expand_composites=True):
2028 if not hasattr(inp, "device"):
2029 raise TypeError("'inputs' should be zero or more (nested) Tensors. "
2030 f"Received '{inp}' with type '{type(inp)}'.")
2031 dev = inp.device
2032 if dev in ops_on_device:
2033 ops_on_device[dev].append(inp)
2034 else:
2035 ops_on_device[dev] = [inp]
2036 if len(ops_on_device) == 1:
2037 # 1-level tree. The root node is the returned NoOp node.
2038 (dev, deps), = ops_on_device.items()
2039 return _GroupControlDeps(dev, deps, name=name)
2041 # 2-level tree. The root node is the returned NoOp node.
2042 # deps contains 1 NoOp node for each device.
2043 deps = []
2045 def device_key(dev):
2046 """A sort key that allows None to be compared to strings."""
2047 return "" if dev is None else dev
2049 for dev in sorted(ops_on_device, key=device_key):
2050 deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
2052 with ops.control_dependencies(deps):
2053 return no_op(name=name)
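# Illustrative sketch only (not part of the original TensorFlow source).
# A hypothetical v1-graph use of `tf.group`: both updates are bundled into a
# single NoOp that can be run as one step and has no output value.
def _group_usage_sketch():
  import tensorflow as tf  # local import; sketch only
  with tf.Graph().as_default():
    a = tf.compat.v1.Variable(0.0)
    b = tf.compat.v1.Variable(10.0)
    train_step = tf.group(tf.compat.v1.assign_add(a, 1.0),
                          tf.compat.v1.assign_sub(b, 1.0))
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(train_step)       # runs both updates; returns no value
      return sess.run([a, b])    # [1.0, 9.0]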
2056@tf_export("tuple", v1=[])
2057@dispatch.add_dispatch_support
2058def tuple_v2(tensors, control_inputs=None, name=None):
2059 """Groups tensors together.
2061 The returned tensors have the same value as the input tensors, but they
2062 are computed only after all the input tensors have been computed.
2064 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
2065 this method, as ops execute in the expected order thanks to automatic control
2066 dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
2068 See also `tf.group` and `tf.control_dependencies`.
2070 Example:
2071 >>> with tf.Graph().as_default():
2072 ... with tf.compat.v1.Session() as sess:
2073 ... v = tf.Variable(0.0)
2074 ... a = tf.constant(1.0)
2075 ... sess.run(tf.compat.v1.global_variables_initializer())
2076 ... for i in range(5):
2077 ... update_op = v.assign_add(1.0)
2078 ... b = a + v
2079 ... res_b = sess.run(b)
2080 ... res_v = sess.run(v)
2081 ... print(res_v)
2082 0.0
2083 0.0
2084 0.0
2085 0.0
2086 0.0
2088 >>> with tf.Graph().as_default():
2089 ... with tf.compat.v1.Session() as sess:
2090 ... v = tf.Variable(0.0)
2091 ... a = tf.constant(1.0)
2092 ... sess.run(tf.compat.v1.global_variables_initializer())
2093 ... for i in range(5):
2094 ... update_op = v.assign_add(1.0)
2095 ... calc = [a + v]
2096 ... # `tf.tuple` ensures `update_op` is run before `b`
2097 ... b = tf.tuple(calc, [tf.group(update_op)])
2098 ... res_b = sess.run(b)
2099 ... res_v = sess.run(v)
2100 ... print(res_v)
2101 1.0
2102 2.0
2103 3.0
2104 4.0
2105 5.0
2108 Args:
2109 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
2110 control_inputs: List of additional ops to finish before returning.
2111 name: (optional) A name to use as a `name_scope` for the operation.
2113 Returns:
2114 Same as `tensors`.
2116 Raises:
2117 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
2118 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
2119 objects.
2121 """
2122 return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin
2125@tf_export(v1=["tuple"])
2126@dispatch.add_dispatch_support
2127def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin
2128 """Group tensors together.
2130 This creates a tuple of tensors with the same values as the `tensors`
2131 argument, except that the value of each tensor is only returned after the
2132 values of all tensors have been computed.
2134 `control_inputs` contains additional ops that have to finish before this op
2135 finishes, but whose outputs are not returned.
2137 This can be used as a "join" mechanism for parallel computations: all the
2138 argument tensors can be computed in parallel, but the values of any tensor
2139 returned by `tuple` are only available after all the parallel computations
2140 are done.
2142 See also `tf.group` and
2143 `tf.control_dependencies`.
2145 Args:
2146 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
2147 name: (optional) A name to use as a `name_scope` for the operation.
2148 control_inputs: List of additional ops to finish before returning.
2150 Returns:
2151 Same as `tensors`.
2153 Raises:
2154 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
2155 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
2156 objects.
2158 """
2159 if context.executing_eagerly():
2160 return tensors
2161 with ops.name_scope(name, "tuple", tensors) as name:
2162 tensors = [
2163 t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or
2164 t is None) else ops.convert_to_tensor(t) for t in tensors
2165 ]
2166 gating_ops = [
2167 t if isinstance(t, ops.Operation) else t.op
2168 for t in tensors
2169 if t is not None
2170 ]
2171 if control_inputs:
2172 for c in control_inputs:
2173 if isinstance(c, ops.Tensor):
2174 c = c.op
2175 elif not isinstance(c, ops.Operation):
2176 raise TypeError(
2177 "'control_inputs' must only contain Operation or Tensor. "
2178 f"Received: {type(c)}")
2179 gating_ops.append(c)
2180 # Deduplicate the ops and sort them by id so that the generated pbtxt
2181 # has a deterministic ordering.
2182 gating_ops = sorted(set(gating_ops), key=lambda op: op._id)
2183 if not gating_ops:
2184 raise ValueError("'tensors' must have at least one Tensor. "
2185 f"Received: {tensors}.")
2186 gate = group(*gating_ops)
2187 tpl = []
2188 for t in tensors:
2189 if tensor_util.is_tf_type(t):
2190 tpl.append(with_dependencies([gate], t))
2191 elif isinstance(t, ops.Operation):
2192 with ops.control_dependencies([gate]):
2193 tpl.append(group(t))
2194 else:
2195 tpl.append(None)
2196 return tpl
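# Illustrative sketch only (not part of the original TensorFlow source).
# A hypothetical use of `tf.compat.v1.tuple` as a "join": both inputs may be
# computed in parallel, but neither returned tensor is available until both
# have been computed.
def _tuple_usage_sketch():
  import tensorflow as tf  # local import; sketch only
  with tf.Graph().as_default():
    x = tf.constant(2.0) * 3.0
    y = tf.constant(10.0) / 2.0
    joined_x, joined_y = tf.compat.v1.tuple([x, y])
    with tf.compat.v1.Session() as sess:
      return sess.run([joined_x, joined_y])  # [6.0, 5.0]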
2199class XLAControlFlowContext(ControlFlowContext):
2200 """Base class for XLA and TPU control flow contexts."""
2202 def __init__(self):
2203 super(XLAControlFlowContext, self).__init__()
2204 self._name = "XLAControlFlowContext"
2206 def to_control_flow_context_def(self, context_def, export_scope=None):
2207 # pylint: disable=useless-super-delegation
2208 # NOTE(slebedev): the method is required by `ControlFlowContext`.
2209 super(XLAControlFlowContext,
2210 self).to_control_flow_context_def(context_def, export_scope)
2212 def IsXLAContext(self):
2213 return True
2215 def AddOp(self, _):
2216 pass
2218 def AddValue(self, x):
2219 return x
2221 def RequiresUniqueFunctionRetracing(self):
2222 """Returns whether the tf.function should be retraced if the context changes.
2223 """
2224 return False
2227@tf_export("__internal__.get_enclosing_xla_context", v1=[])
2228def get_enclosing_xla_context():
2229 """Recursively find and return the XLAControlFlowContext."""
2230 graph = ops.get_default_graph()
2231 while graph is not None:
2232 # pylint: disable=protected-access
2233 context_ = graph._get_control_flow_context()
2234 # pylint: enable=protected-access
2235 while context_ is not None:
2236 if isinstance(context_, XLAControlFlowContext):
2237 return context_
2238 context_ = context_.outer_context
2239 # This may be a FuncGraph due to defuns or v2 control flow. We need to
2240 # find the original graph with the XLAControlFlowContext.
2241 graph = getattr(graph, "outer_graph", None)
2242 return None
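# Illustrative sketch only (not part of the original TensorFlow source).
# A hypothetical helper showing how library code might branch on whether the
# graph currently being built is enclosed by an XLA/TPU compilation context.
def _inside_xla_sketch():
  # True under e.g. TPU replication or an XLA jit scope, False otherwise.
  return get_enclosing_xla_context() is not None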
2245def from_control_flow_context_def(context_def, import_scope=None):
2246 """Deserializes `context_def` into the appropriate ControlFlowContext.
2248 Args:
2249 context_def: ControlFlowContextDef proto
2250 import_scope: Optional `string`. Name scope to add.
2252 Returns:
2253 A ControlFlowContext subclass
2254 """
2255 if context_def.HasField("cond_ctxt"):
2256 return CondContext.from_proto(
2257 context_def.cond_ctxt, import_scope=import_scope)
2258 if context_def.HasField("while_ctxt"):
2259 return WhileContext.from_proto(
2260 context_def.while_ctxt, import_scope=import_scope)
2261 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
2262 context_def.WhichOneof("ctxt"))
2265ops.register_proto_function(
2266 ops.GraphKeys.COND_CONTEXT,
2267 proto_type=control_flow_pb2.CondContextDef,
2268 to_proto=CondContext.to_proto,
2269 from_proto=CondContext.from_proto)
2271ops.register_proto_function(
2272 ops.GraphKeys.WHILE_CONTEXT,
2273 proto_type=control_flow_pb2.WhileContextDef,
2274 to_proto=WhileContext.to_proto,
2275 from_proto=WhileContext.from_proto)