# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Control flow statements: loops, conditionals, etc.
17Note: most of these operators accept pairs of get_state/set_state functions, to
18capture mutations that the corresponding code blocks might make. These
19mutations only need to be captured when staging the control flow, and they just
20work when reverting to Python behavior.
22__Examples__
24```
25while cond:
26 self.x += i
27```
29When the functionalized version is executed as a Python loop, it just works:
31```
32def loop_body():
33 self.x += i # works as expected for Python loops
34```
36But it won't work for TF loops:
38```
39def loop_body():
40 self.x += i # self.x has the wrong value!
41```
43get_state/set_state allow piping the mutations through the loop variables as
44well, in effect changing the loop body:
46```
47def loop_body(self_x):
48 self.x = self_x # self.x now has the proper value
49 self.x += i # the original block
50 self_x = self.x # write self.x back into the loop vars
51 return self_x
53self_x = tf.while_loop(...)
54self.x = self_x # the result is not properly captured
55```
56"""

import functools
import sys
import traceback

import numpy as np

from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import variables
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import misc
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.autograph.utils import type_registry
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond as tf_cond
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.types import distribute
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils


PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
WARN_INEFFICIENT_UNROLL = True
INEFFICIENT_UNROLL_MIN_ITERATIONS = 50000
INEFFICIENT_UNROLL_MIN_OPS = 1


# TODO(mdan): Use the custom operator pattern instead of type dispatch.
# An example of this pattern is found in the implementation of distributed
# datasets. Before it can be used though, we need to standardize the interface.
for_loop_registry = type_registry.TypeRegistry()


def _is_none_or_undef(value):
  """Tests whether a value is None or undefined.

  AutoGraph represents undefined symbols using special objects of type
  Undefined or UndefinedReturnValue.

  Args:
    value: value to test

  Returns:
    Boolean
  """
  return ((value is None)
          or isinstance(value, variables.UndefinedReturnValue)
          or isinstance(value, variables.Undefined))


def _verify_tf_condition(cond, tag):
  """Ensures that the condition can be used in a TF control flow."""
  extra_hint = 'to check for None, use `is not None`'
  cond = tensor_conversion.convert_to_tensor_v2(cond)

  if cond.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond, extra_hint))

  if cond.shape is None or cond.shape.ndims is None:
    # TODO(mdan): Consider an explicit size check, if not too slow.
    cond = array_ops.reshape(cond, ())

  elif cond.shape.ndims > 0:
    known_dims = [d for d in cond.shape.as_list() if d is not None]
    if np.prod(known_dims) > 1:
      raise ValueError(
          'condition of {} expected to be `tf.bool` scalar, got {}'
          '; {}'.format(tag, cond, extra_hint))
    else:
      cond = array_ops.reshape(cond, ())

  return cond
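

# Editorial sketch (not part of the library): what _verify_tf_condition
# accepts. The function name below is hypothetical; run it inside a TF
# program, e.g. under tf.function.
def _example_verify_tf_condition():
  scalar = tensor_conversion.convert_to_tensor_v2(True)
  ok = _verify_tf_condition(scalar, 'example')  # passes through unchanged

  # A single-element vector is tolerated, but reshaped to a scalar.
  vector = array_ops.ones((1,), dtype=dtypes.bool)
  reshaped = _verify_tf_condition(vector, 'example')  # shape becomes ()

  # A non-bool or multi-element condition raises ValueError, e.g.:
  # _verify_tf_condition(array_ops.ones((2,), dtype=dtypes.bool), 'example')
  return ok, reshaped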


def verify_loop_init_vars(
    init_vars, symbol_names, first_iter_vars=None, extra_message=None
):
  """Ensures that all values in the state are valid to use in a TF loop.

  The init_vars may contain placeholder values derived from first_iter_vars.

  Args:
    init_vars: initial loop variables (as taken before entering the loop)
    symbol_names: corresponding names of the initial loop variables
    first_iter_vars: loop variables after one iteration of the loop
    extra_message: an extra string to append to the error message, in case of
      "undefined variable" errors (see variables.Undefined)
  """
  if not symbol_names:
    return
  if first_iter_vars is None:
    first_iter_vars = (None,) * len(symbol_names)

  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(first_iter_vars)
  for name, val, fi_val in zip(symbol_names, init_vars, first_iter_vars):
    if isinstance(val, variables.UndefinedReturnValue):
      if fi_val:
        raise ValueError(
            'the return value from a TensorFlow loop may only be a {}; got {}'
            .format(LEGAL_LOOP_TYPES, type(fi_val)))
      else:
        # TODO(mdan): This can be handled by removing the return value.
        raise NotImplementedError(
            'a return statement cannot be placed inside this TensorFlow loop;'
            ' this may happen if a return statement depends on a'
            ' static Python condition such as a hyperparameter')

    error_msg = None
    if val is None:
      error_msg = "'{}' is not allowed to be None before the loop".format(name)
    elif isinstance(val, variables.Undefined):
      error_msg = "'{}' must be defined before the loop".format(name)
      if extra_message:
        error_msg += '\n' + extra_message

    if error_msg is not None:
      raise ValueError(error_msg)


def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape."""
  # TODO(mdan): This code should be in TensorShape.
  # Note: this is not the same as TensorShape.is_compatible_with, which is
  # symmetric.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  for ldim, rdim in zip(left.dims, right.dims):
    if rdim.value is not None and ldim.value != rdim.value:
      return False
  return True
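

# Editorial sketch (not part of the library): worked examples of the subshape
# relation. The function name is hypothetical.
def _example_is_subshape():
  s = tensor_shape.TensorShape
  assert _is_subshape(s([3, 4]), s([3, None]))      # more specific: ok
  assert _is_subshape(s([3, 4]), s(None))           # unknown rank accepts all
  assert not _is_subshape(s([3, None]), s([3, 4]))  # less specific: not ok
  assert not _is_subshape(s([3]), s([3, 4]))        # rank mismatch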


# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent."""
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = tensor_conversion.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = tensor_conversion.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = tensor_conversion.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            "'{}' has shape {} before the loop, but shape {} after one"
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} before the loop, which does not conform with"
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            "'{}' has shape {} after one iteration, which does not conform"
            ' with the shape invariant {}.'.format(name, exit_shape,
                                                   shape_invariant)
        )


def verify_tf_loop_vars(
    init_vars,
    iter_entry_vars,
    iter_exit_vars,
    symbol_names,
    opts,
    check_shapes=True,
):
  """Verifies loop variables for consistency."""
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  assert len(symbol_names) == len(shape_invariants)
  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(iter_entry_vars)
  assert len(symbol_names) == len(iter_exit_vars)

  for i in range(len(symbol_names)):
    name = symbol_names[i]
    init = init_vars[i]
    entry = iter_entry_vars[i]
    exit_ = iter_exit_vars[i]
    invariant = shape_invariants[i]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
    except (ValueError, TypeError):
      # `Variable`s in `init` may be implicitly converted to `Tensor`s. Convert
      # `ResourceVariable`s to Tensors so tf.nest.assert_same_structure
      # won't break due to type spec mismatches between `ResourceVariable`s and
      # `Tensor`s.
      try:
        init_tensors = variable_utils.convert_variables_to_tensors(init)
        nest.assert_same_structure(init_tensors, entry, expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure after"
                        ' one iteration.\n\n{}'.format(name, e)) from e

    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e)) from e
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e)) from e

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)


def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  if isinstance(body_var, (bool, int, float, str, np.ndarray)):
    body_var = tensor_conversion.convert_to_tensor_v2(body_var)

  if isinstance(orelse_var, (bool, int, float, str, np.ndarray)):
    orelse_var = tensor_conversion.convert_to_tensor_v2(orelse_var)

  if (not tensor_util.is_tf_type(body_var) or
      not tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(body_var, 'dtype') or
      not hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        "'{}' has dtype {} in the main branch, but dtype {} in the else"
        ' branch'.format(name, body_var.dtype.name,
                         orelse_var.dtype.name))


def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency."""
  for name, var_ in zip(symbol_names, vars_):
    if isinstance(var_, variables.Undefined):
      raise ValueError(
          "'{}' must also be initialized in the {} branch".format(
              name, branch_name))
    if isinstance(var_, variables.UndefinedReturnValue):
      raise ValueError(
          'the {} branch must also have a return statement.'.format(
              branch_name))


def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  named_vars = zip(symbol_names, body_vars, orelse_vars)

  for name, body_var, orelse_var in named_vars:
    try:
      nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError):
      # One branch of cond could be a `Tensor`, while the other branch could be
      # a `ResourceVariable`. Convert `ResourceVariable`s to `Tensor`s so
      # assert_same_structure won't fail.
      try:
        body_var_tensors = variable_utils.convert_variables_to_tensors(
            body_var)
        orelse_var_tensors = variable_utils.convert_variables_to_tensors(
            orelse_var)
        nest.assert_same_structure(body_var_tensors, orelse_var_tensors,
                                   expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError(
            "'{}' must have the same nested structure in the main and else"
            ' branches:\n\n{}'.format(name, str(e))) from e

    nest.map_structure(
        functools.partial(verify_single_cond_var, name), body_var, orelse_var)


def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
  geo_mean = 1
  arith_mean = 0
  for i in range(n):
    a = numbers[i]
    geo_mean *= a
    arith_mean += a
  ```

  The state is represented by the variables named geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop
      condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  try:
    for_fn = for_loop_registry.lookup(iter_)
  except LookupError:
    for_fn = _py_for_stmt

  # TODO(bwieder): Refactor isinstance(iter_, ragged_tensor.RaggedTensor) to
  # use the registry once python/autograph/utils does not depend on
  # dataset_ops.
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      for_fn = _tf_range_for_stmt
    elif isinstance(iter_, ragged_tensor.RaggedTensor):
      for_fn = _tf_ragged_for_stmt
    else:
      for_fn = _known_len_tf_for_stmt
  elif isinstance(iter_, distribute.Iterator):
    for_fn = _tf_iterator_for_stmt
  elif isinstance(iter_, distribute.Iterable):
    # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
    for_fn = _tf_distributed_iterable_for_stmt

  for_fn(iter_, extra_test, body, get_state, set_state, symbol_names, opts)
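

# Editorial sketch (not part of the library): how the geo_mean/arith_mean loop
# from the for_stmt docstring would drive this operator. The function name is
# hypothetical; AutoGraph emits equivalent code during conversion.
def _example_for_stmt(numbers):
  geo_mean = 1
  arith_mean = 0

  def body(a):
    nonlocal geo_mean, arith_mean
    geo_mean *= a
    arith_mean += a

  def get_state():
    return (geo_mean, arith_mean)

  def set_state(loop_vars):
    nonlocal geo_mean, arith_mean
    geo_mean, arith_mean = loop_vars

  for_stmt(
      numbers, None, body, get_state, set_state,
      ('geo_mean', 'arith_mean'), {})
  return geo_mean, arith_mean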


def _py_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts
):
  """Overload of for_stmt that executes a Python for loop."""
  del get_state, set_state, symbol_names, opts

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body
    def protected_body(protected_iter):
      original_body(protected_iter)
      after_iteration()
      before_iteration()
    body = protected_body

  if extra_test is not None:
    def guarded_extra_test():
      extra_test_result = extra_test()
      try:
        # Note: Using try/except and not tensor_util.is_tf_type to avoid
        # performance degradation.
        return bool(extra_test_result)
      except errors_impl.OperatorNotAllowedInGraphError as e:
        ag_logging.log(
            1,
            'Caught error while evaluating loop stop condition',
            exc_info=True)
        # TODO(mdan): We can pass the location of extra_test and show it here.
        raise NotImplementedError(
            'break and return statements which depend on a TF condition are'
            ' not supported in Python for loops. Did you intend to make it a'
            ' TF loop?\nSee '
            'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/g3doc/reference/limitations.md'
            '#consistency-of-control-flow-types for more info.') from e

    if guarded_extra_test():
      for target in iter_:
        body(target)
        if not guarded_extra_test():
          break

  else:
    for target in iter_:
      body(target)
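

# Editorial sketch (not part of the library): a `break` guarded by a Python
# condition lowers to an extra_test. Hypothetical example, calling the Python
# overload directly:
def _example_py_for_stmt_with_break(numbers, limit):
  total = 0

  def body(a):
    nonlocal total
    total += a

  def extra_test():
    # The loop continues only while this holds; equivalent to placing
    # `if total >= limit: break` at the top of each iteration.
    return total < limit

  _py_for_stmt(numbers, extra_test, body, None, None, None, None)
  return total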


def _add_max_iterations_hint(opts, n):
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n


def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length."""
  n = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iteration
    # index is used outside the loop, it will appear in the loop vars
    # separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_.read(iterate_index))
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return tf_cond.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts,
  )
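

# Editorial sketch (not part of the library): the TensorArray unstack/read
# pattern used above, in isolation. Hypothetical function; run inside a graph
# context (e.g. tf.function).
def _example_tensor_array_iteration(values):
  # `values` is assumed to be a 1-D Tensor. Unstacking lets each iteration
  # read one element, with more efficient gradients than StridedSlice.
  n = array_ops.shape(values)[0]
  ta = tensor_array_ops.TensorArray(values.dtype, size=n)
  ta = ta.unstack(values)
  first = ta.read(0)  # element 0, analogous to iter_.read(iterate_index)
  return first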


def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors."""
  init_vars = get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    n = iter_.row_lengths()[0]

  iterate_index = 0

  def aug_get_state():
    return (iterate_index,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal iterate_index
    # TODO(b/171479293): Drop the lint override.
    iterate_index, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iteration
    # index is used outside the loop, it will appear in the loop vars
    # separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate_index
    body(iter_[iterate_index])
    iterate_index += 1

  def aug_test():
    main_test = iterate_index < n
    if extra_test is not None:
      return tf_cond.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(opts, n)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)


def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it)."""
  start, limit, delta = iter_.op.inputs

  iterate = start

  def _value_or(name, var, default):
    if (name == opts['iterate_names']
        and isinstance(var, variables.Undefined)):
      return default
    return var

  def aug_get_state():
    state_vars = get_state()
    state_vars = tuple(
        _value_or(name, var, iterate)
        for name, var in zip(symbol_names, state_vars))
    return (iterate,) + state_vars

  def aug_set_state(aug_loop_vars):
    nonlocal iterate
    # TODO(b/171479293): Drop the lint override.
    iterate, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is not None:
      if const_delta >= 0:
        main_test = iterate < limit
      else:
        main_test = iterate > limit
    else:
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate < limit),
          math_ops.logical_and(delta < 0, iterate > limit))

    if extra_test is not None:
      main_test = tf_cond.cond(main_test, extra_test, lambda: False)
    return main_test

  _add_max_iterations_hint(
      opts,
      math_ops.cast(misc.get_range_len(start, limit, delta), dtypes.int32))

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)


def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop."""
  symbol_names = ('<internal has_next>',) + symbol_names
  has_next = True

  def aug_get_state():
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *loop_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(loop_vars)

  init_vars = aug_get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    opt_iterate = iter_.get_next_as_optional()
    has_next = opt_iterate.has_value()
    loop_vars = aug_get_state()  # updated by set_state() in _tf_while_loop.

    def main_path():
      body(opt_iterate.get_value())
      new_loop_vars = aug_get_state()
      # Note: this verification duplicates the one performed in tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      verify_tf_loop_vars(
          init_vars, loop_vars, new_loop_vars, symbol_names, opts)
      return new_loop_vars

    def noop_path():
      return loop_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that the get_state() call inside _tf_while_stmt
    # sees the conditional tensors.
    aug_set_state(
        tf_cond.cond(has_next, main_path, noop_path))

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next
    main_test = has_next
    if extra_test is not None:
      return tf_cond.cond(main_test, extra_test, lambda: False)
    return main_test

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)


def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets."""

  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  init_vars = get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  def reduce_body(loop_vars, iterate):
    set_state(loop_vars)
    body(iterate)
    new_loop_vars = get_state()
    verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)
    return new_loop_vars

  set_state(iter_.reduce(init_vars, reduce_body))


def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state
      back into the Python environment. This is only useful when staging the
      loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.

  Returns:
    None. The final state is communicated through set_state.
  """
  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test()

  # TensorFlow: Multiple evaluations are acceptable in this case, so we're
  # fine with the re-evaluation of `test` that `_tf_while_stmt` will make.
  if tensors.is_dense_tensor(init_test):
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
    return

  # Normal Python: We already consumed one evaluation of `test`; consistently,
  # unroll one iteration before dispatching to a normal loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if not init_test:
    return
  body()

  _py_while_stmt(test, body, get_state, set_state, opts)
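

# Editorial sketch (not part of the library): how AutoGraph-generated code is
# expected to drive while_stmt. The function name and loop variable are
# hypothetical; this mirrors what conversion emits for `while i > 0: i -= 1`.
def _example_while_stmt(n):
  i = n

  def test():
    return i > 0

  def body():
    nonlocal i
    i -= 1

  def get_state():
    return (i,)

  def set_state(loop_vars):
    nonlocal i
    i, = loop_vars

  # With a Python int, this dispatches to _py_while_stmt; with a Tensor, it
  # stages a TF while loop via _tf_while_stmt.
  while_stmt(test, body, get_state, set_state, ('i',), {})
  return i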


class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits."""

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      'ops_before_iteration',
  )

  def __init__(self):
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

  def _get_ops(self):
    return set(ops.get_default_graph().get_operations())

  def _check_unroll_limits(self):
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration
        if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few
        # iterations.
        self._stop_checking_inefficient_unroll()


def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop."""
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    original_body = body
    def protected_body():
      original_body()
      after_iteration()
      before_iteration()
    body = protected_body

  def guarded_test():
    test_result = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(test_result)
    except errors_impl.OperatorNotAllowedInGraphError as e:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of while loop started as non-Tensor, then changed to'
          ' Tensor. This may happen either because variables changed type, or'
          ' when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.') from e

  while guarded_test():
    body()


def _shape_invariants_mapping_to_positional_list(mapping, keys):
  # The keys are not expected to be hashable.
  mapping = {id(k): (k, v) for k, v in mapping}
  result = []
  for k in keys:
    map_key, map_val = mapping.get(id(k), (None, None))
    result.append(
        map_val if map_key is k else nest.map_structure(lambda _: None, k))
  return tuple(result)
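

# Editorial sketch (not part of the library): the id-keyed lookup above pairs
# each loop variable with its user-specified invariant by identity, since
# Tensors are not hashable by value. Hypothetical example with plain objects:
def _example_shape_invariants_mapping():
  x, y = object(), object()
  mapping = ((x, 'x-invariant'),)
  result = _shape_invariants_mapping_to_positional_list(mapping, [x, y])
  # x gets its invariant; y, which has no entry, gets a None placeholder.
  assert result[0] == 'x-invariant'
  assert result[1] is None
  return result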


# Textual description of what a legal TF loop variable is. This description
# summarizes types that _placeholder_value below can handle. Keep the two
# together and in sync.
LEGAL_LOOP_TYPES = 'Tensor, int, float, bool or a list, tuple or dict thereof'


def _placeholder_value(like, shape_invariant, original=None):
  """Constructs a (dummy) placeholder value for a loop-initialized variable.

  Args:
    like: Any object. The value created by the first iteration of the loop. If
      a Python scalar, the placeholder will be the zero value of that type. If
      a Tensor, the placeholder will be a zero tensor of matching shape and
      dtype. If a list, dict or tuple, the placeholder will be an identical
      structure of placeholders.
    shape_invariant: The shape invariant specified by the user (or None, if
      nothing was specified) for the respective variable.
    original: Any object. The value of the variable prior to entering the
      loop. Typically, this is one of the special "Undefined" values, because
      that's when a placeholder is needed.

  Returns:
    A tuple (value, invariant): either a zero value of structure, shape and
    dtype matching 'like' (or 'original', if no such zero value could be
    created), along with the shape invariant required for it, if any.
  """
  if like is None:
    return original, None

  elif isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)):
    return original, None

  elif isinstance(like, (int, float, bool)):
    return type(like)(0), None

  elif tensor_util.is_tf_type(like):

    like_shape = shape_invariant if shape_invariant is not None else like.shape
    if like_shape is None or like_shape.rank is None:
      return array_ops.zeros((), like.dtype), like_shape

    # If the shape contains dynamic values, set the corresponding starting
    # dimension to either zero or what the shape invariant specified.
    placeholder_shape = []
    has_dynamic_dims = False
    for s, i in zip(like.shape, like_shape):
      if i is None:
        like_dim = 0
      elif isinstance(i, tensor_shape.Dimension):
        if i.value is None:
          like_dim = 0
        else:
          like_dim = i.value
      else:
        like_dim = i

      if s is None:
        placeholder_shape.append(like_dim)
        has_dynamic_dims = True
      elif isinstance(s, tensor_shape.Dimension):
        if s.value is None:
          placeholder_shape.append(like_dim)
          has_dynamic_dims = True
        else:
          placeholder_shape.append(s.value)
      else:
        placeholder_shape.append(s)

    if has_dynamic_dims:
      invariant = like_shape
    else:
      invariant = None

    return array_ops.zeros(placeholder_shape, like.dtype), invariant

  elif isinstance(like, (list, tuple, dict)):
    if shape_invariant is None:
      zipped = nest.map_structure(lambda v: _placeholder_value(v, None),
                                  nest.flatten(like))
    else:
      zipped = nest.map_structure(_placeholder_value, nest.flatten(like),
                                  nest.flatten(shape_invariant))
    vals, invars = zip(*zipped)
    return (nest.pack_sequence_as(like, vals),
            nest.pack_sequence_as(like, invars))

  # This is to be caught by _try_handling_undefineds, to give more context.
  raise TypeError(
      "Found an unsupported type '{}' while creating placeholder for {}."
      ' Supported types include Tensor, int, float, bool, list, tuple or dict.'
      .format(type(like).__name__, like))
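

# Editorial sketch (not part of the library): what _placeholder_value produces
# for a few inputs. Hypothetical function; the Tensor case should run inside a
# graph context (array_ops.placeholder is graph-only).
def _example_placeholder_value():
  # Python scalars become the zero of their type; no invariant is needed.
  val, invariant = _placeholder_value(3.5, None)
  assert val == 0.0 and invariant is None

  # A Tensor with a dynamic leading dimension becomes a zero tensor whose
  # dynamic dimension starts at 0, with the original shape as the invariant.
  like = array_ops.placeholder(dtypes.float32, shape=(None, 4))
  val, invariant = _placeholder_value(like, None)
  # val has shape (0, 4); invariant is TensorShape([None, 4]).
  return val, invariant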


def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
      iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
      tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop var
      is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
    * success is a boolean flag indicating
      whether types could be successfully inferred (step 1 above)
    * new_init_vars contains the loop vars, with None or undefined values
      replaced by default values, where possible (step 2 above)
    * extra_shape_invariants contains shape invariants that would be needed
      by while_stmt, for instance if the placeholder values had a shape
      different from the corresponding loop outputs
  """
  state_modified = False
  first_iter_vars = None
  failure_message = None
  # Defaults for the failure path. Note: verify_loop_init_vars below raises
  # on that path before these are returned.
  success = False
  extra_shape_invariants = None

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non_tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to
      # account for that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = tensor_conversion.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()

    # Note: the actual placeholder value doesn't matter, because as the
    # staging proved, it will be replaced by an actual value before being
    # read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True

  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    exc = sys.exc_info()
    failure_message = (
        'Note: AutoGraph tried to define it automatically, but ran into a'
        ' {}: {}'.format(exc[0].__name__, exc[1]))

  finally:
    if state_modified:
      set_state(init_vars)

  # This check runs regardless, in case we captured non-Tensor inputs.
  verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants


def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars):
  """Creates an error message asking for the loop to iterate at least once."""
  var_names = []
  for sn, n, v in zip(symbol_names, nulls, init_vars):
    if not n:
      continue
    if isinstance(v, variables.UndefinedReturnValue):
      var_names.append('the function return value')
    else:
      var_names.append(sn)
  var_names = ', '.join(var_names)
  return 'loop must iterate at least once to initialize {}'.format(var_names)


def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt."""
  init_vars = get_state()
  orig_init_vars = init_vars

  nulls = tuple(_is_none_or_undef(v) for v in init_vars)
  if any(nulls):
    shape_invars_by_init_vals = {
        id(v): i for v, i in opts.get('shape_invariants', ())
    }
    shape_invariants = tuple(
        shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars)
    (require_one_iteration, init_vars,
     extra_shape_invariants) = _try_handling_undefineds(body, get_state,
                                                        set_state, init_vars,
                                                        nulls,
                                                        shape_invariants,
                                                        symbol_names)
  else:
    require_one_iteration = False

  if require_one_iteration:
    merged_shape_invariants = dict(shape_invars_by_init_vals)
    # This has two roles:
    # 1. Shape invariants are remapped from the old init vars to the new ones.
    # 2. Any new shape invariants created by the init vars are kept, but only
    #    if the user didn't already specify some.
    for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants):
      merged_invariant = merged_shape_invariants.get(id(v), ni)
      if merged_invariant is not None:
        merged_shape_invariants[id(nv)] = merged_invariant
    merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)])
                                    for nv in init_vars
                                    if id(nv) in merged_shape_invariants)
    if merged_shape_invariants:
      opts = dict(**opts)
      opts['shape_invariants'] = merged_shape_invariants

  def aug_test(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    body()
    new_loop_vars = get_state()
    verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)

    if require_one_iteration:
      new_loop_vars = (True,) + new_loop_vars

    return new_loop_vars

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  while_loop_opts = dict(opts)
  while_loop_opts.pop('iterate_names', None)

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  while_loop_opts['return_same_structure'] = True

  if require_one_iteration:
    aug_init_vars = (False,) + init_vars
    if 'shape_invariants' in while_loop_opts:
      while_loop_opts['shape_invariants'] = (
          (None,) + while_loop_opts['shape_invariants'])
  else:
    aug_init_vars = init_vars

  final_loop_vars = while_loop.while_loop(aug_test, aug_body, aug_init_vars,
                                          **while_loop_opts)

  if require_one_iteration:
    with ops.control_dependencies([
        control_flow_assert.Assert(final_loop_vars[0], [
            _runtime_zero_iterations_errmsg(symbol_names, nulls,
                                            orig_init_vars)
        ])
    ]):
      final_loop_vars = nest.map_structure(
          lambda v: (array_ops.identity(v)
                     if tensor_util.is_tf_type(v) else v),
          final_loop_vars[1:],
      )

  set_state(final_loop_vars)
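

# Editorial sketch (not part of the library): why _tf_while_stmt sometimes
# requires at least one iteration. In code like the hypothetical snippet
# below, `result` is Undefined before the loop, so a placeholder is
# substituted, a "ran at least once" boolean is prepended to the loop vars,
# and a runtime Assert rejects zero-iteration executions:
#
#   @tf.function
#   def f(n):
#     while n > 0:   # staged via _tf_while_stmt
#       result = n   # `result` only comes into existence inside the loop
#       n -= 1
#     return result  # fails at runtime if the loop never ran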


def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state, which includes all symbols whose values
  are a function of the branch taken.

  For example, given the code below that calculates the abs function:

  ```
  x = 1
  if x < 0:
    x = -x
  ```

  The state is represented by the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using
  `nonlocal`.

  The inputs and outputs of the callables representing the branches are not
  explicit - instead, these functions must use nonlocal/global for side
  effects. The inputs and outputs are instead controlled by the
  set_state/get_state functions.

  Args:
    cond: Boolean.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is
      not needed and should not be called when dispatching to code matching
      Python's default semantics. This is useful for checkpointing to avoid
      unintended side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. The single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    symbol_names: Tuple containing basic loop var names.
    nouts: Number of variables output by the statement. Vars which are not
      outputs will not be passed through staged control flow such as tf.cond.
      This includes variables that are defined before the conditional, but are
      not used after it.
  """
  # Note: tf.cond doesn't support SparseTensor.
  if tensors.is_dense_tensor(cond):
    _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
  else:
    _py_if_stmt(cond, body, orelse)
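

# Editorial sketch (not part of the library): how AutoGraph-generated code is
# expected to drive if_stmt for the abs example in the docstring. The function
# name is hypothetical; with a Python bool cond this dispatches to
# _py_if_stmt, and with a dense Tensor cond it stages a tf.cond.
def _example_if_stmt(x):

  def body():
    nonlocal x
    x = -x

  def orelse():
    pass  # no else block in the original code

  def get_state():
    return (x,)

  def set_state(cond_vars):
    nonlocal x
    x, = cond_vars

  if_stmt(x < 0, body, orelse, get_state, set_state, ('x',), 1)
  return x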


def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond."""
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    prev_get_state, prev_set_state = get_state, set_state
    # Control flow V1 wants at least one output.
    get_state = lambda: (0,) + prev_get_state()
    set_state = lambda v: prev_set_state(v[1:])
    symbol_names += ('<unused dummy>',)
    nouts = 1

  init_vars = get_state()

  # TODO(mdan): Use nonlocal once we no longer need to support py2.
  new_body_vars_ = [None]
  new_orelse_vars_ = [None]

  def aug_body():
    set_state(init_vars)
    body()
    new_body_vars = get_state()
    new_body_vars = new_body_vars[:nouts]
    new_body_vars_[0] = new_body_vars
    _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
    if new_orelse_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names)
    return new_body_vars

  def aug_orelse():
    set_state(init_vars)
    orelse()
    new_orelse_vars = get_state()
    new_orelse_vars = new_orelse_vars[:nouts]
    new_orelse_vars_[0] = new_orelse_vars
    _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
    if new_body_vars_[0] is not None:
      _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names)
    return new_orelse_vars

  final_cond_vars = tf_cond.cond(
      cond, aug_body, aug_orelse, strict=True)
  final_cond_vars = final_cond_vars + init_vars[nouts:]

  set_state(final_cond_vars)


def _py_if_stmt(cond, body, orelse):
  """Overload of if_stmt that executes a Python if statement."""
  return body() if cond else orelse()