Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py: 12%

575 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Control flow statements: loops, conditionals, etc. 

16 

17Note: most of these operators accept pairs of get_state/set_state functions, to 

18capture mutations that the corresponding code blocks might make. These 

19mutations only need to be captured when staging the control flow, and they just 

20work when reverting to Python behavior. 

21 

22__Examples__ 

23 

24``` 

25while cond: 

26 self.x += i 

27``` 

28 

29When the functionalized version is executed as a Python loop, it just works: 

30 

31``` 

32def loop_body(): 

33 self.x += i # works as expected for Python loops 

34``` 

35 

36But it won't work for TF loops: 

37 

38``` 

39def loop_body(): 

40 self.x += i # self.x has the wrong value! 

41``` 

42 

43get_state/set_state allow piping the mutations through the loop variables as 

44well, in effect changing the loop body: 

45 

46``` 

47def loop_body(self_x): 

48 self.x = self_x # self.x now has the proper value 

49 self.x += i # the original block 

50 self_x = self.x # write self.x back into the loop vars 

51 return self_x 

52 

53self_x = tf.while_loop(...) 

54self.x = self_x # the result is now properly captured 

55``` 

56""" 

57 

58import functools 

59import sys 

60import traceback 

61 

62import numpy as np 

63 

64from tensorflow.python.autograph.operators import py_builtins 

65from tensorflow.python.autograph.operators import variables 

66from tensorflow.python.autograph.utils import ag_logging 

67from tensorflow.python.autograph.utils import misc 

68from tensorflow.python.autograph.utils import tensors 

69from tensorflow.python.autograph.utils import type_registry 

70from tensorflow.python.framework import dtypes 

71from tensorflow.python.framework import errors_impl 

72from tensorflow.python.framework import func_graph 

73from tensorflow.python.framework import ops 

74from tensorflow.python.framework import tensor_conversion 

75from tensorflow.python.framework import tensor_shape 

76from tensorflow.python.framework import tensor_util 

77from tensorflow.python.ops import array_ops 

78from tensorflow.python.ops import cond as tf_cond 

79from tensorflow.python.ops import control_flow_assert 

80from tensorflow.python.ops import control_flow_util 

81from tensorflow.python.ops import math_ops 

82from tensorflow.python.ops import tensor_array_ops 

83from tensorflow.python.ops import while_loop 

84from tensorflow.python.ops.ragged import ragged_tensor 

85from tensorflow.python.types import distribute 

86from tensorflow.python.util import nest 

87from tensorflow.python.util import variable_utils 

88 

89 

# Iteration cap enforced by _PythonLoopChecker on Python (unrolled) loops.
PYTHON_MAX_ITERATIONS = 100_000_000  # Fails in about one minute for empty loops.

# Initial value of _PythonLoopChecker.check_inefficient_unroll; turning it off
# disables the "large unrolled loop" warning.
WARN_INEFFICIENT_UNROLL = True

# The op-count check is only armed after a loop has run more than this many
# iterations.
INEFFICIENT_UNROLL_MIN_ITERATIONS = 50_000

# Fewest new graph ops per iteration for which the warning is emitted.
INEFFICIENT_UNROLL_MIN_OPS = 1

94 

95 

# TODO(mdan): Use the custom operator pattern instead of type dispatch.
# An example of this pattern is found in the implementation of distributed
# datasets. Before it can be used though, we need to standardize the interface.

# Registry consulted first by for_stmt (via `lookup(iter_)`) to pick a loop
# implementation for the iterated value; for_stmt falls back to _py_for_stmt
# when the lookup raises LookupError.
for_loop_registry = type_registry.TypeRegistry()

101 

102 

def _is_none_or_undef(value):
  """Reports whether `value` is None or an AutoGraph "undefined" marker.

  AutoGraph stands in for undefined symbols with instances of
  `variables.Undefined` or `variables.UndefinedReturnValue`.

  Args:
    value: the object to inspect.

  Returns:
    True if `value` is None or an instance of either marker type, False
    otherwise.
  """
  if value is None:
    return True
  if isinstance(value, variables.UndefinedReturnValue):
    return True
  return isinstance(value, variables.Undefined)

118 

119 

def _verify_tf_condition(cond, tag):
  """Converts `cond` to a scalar boolean tensor usable in TF control flow.

  Args:
    cond: the condition value; it is converted to a tensor first.
    tag: text naming the construct being checked, interpolated into error
      messages.

  Returns:
    The condition tensor. It is reshaped to `()` unless its static rank is
    already 0, in which case it is returned as converted.

  Raises:
    ValueError: if the dtype is not `tf.bool`, or if the statically known
      dimensions multiply to more than one element.
  """
  hint = 'to check for None, use `is not None`'
  cond_tensor = tensor_conversion.convert_to_tensor_v2(cond)

  if cond_tensor.dtype != dtypes.bool:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; to use as boolean Tensor, use `tf.cast`'
        '; {}'.format(tag, cond_tensor, hint))

  static_shape = cond_tensor.shape
  if static_shape is None or static_shape.ndims is None:
    # Rank unknown statically: the reshape is what enforces scalar-ness.
    # TODO(mdan): Consider a explicit size check, if not too slow.
    return array_ops.reshape(cond_tensor, ())

  if not static_shape.ndims > 0:
    # Already rank 0.
    return cond_tensor

  static_sizes = [dim for dim in static_shape.as_list() if dim is not None]
  if np.prod(static_sizes) > 1:
    raise ValueError(
        'condition of {} expected to be `tf.bool` scalar, got {}'
        '; {}'.format(tag, cond_tensor, hint))
  return array_ops.reshape(cond_tensor, ())

145 

146 

def verify_loop_init_vars(
    init_vars, symbol_names, first_iter_vars=None, extra_message=None
):
  """Checks that every initial loop variable is legal for a TF loop.

  The init_vars may contain placeholder values derived from first_iter_vars.
  Nothing is checked when `symbol_names` is empty.

  Args:
    init_vars: initial loop variables (as taken before entering the loop)
    symbol_names: corresponding names of the initial loop variables
    first_iter_vars: loop variables after one iteration of the loop; defaults
      to a tuple of None with one entry per symbol
    extra_message: an extra string to append to the error message, in case of
      "undefined variable" errors (see variables.Undefined)

  Raises:
    ValueError: if a variable is None or `variables.Undefined` before the
      loop, or if it is `variables.UndefinedReturnValue` and its first
      iteration value is truthy.
    NotImplementedError: if a variable is `variables.UndefinedReturnValue`
      and its first iteration value is falsy.
  """
  if not symbol_names:
    return

  symbol_count = len(symbol_names)
  if first_iter_vars is None:
    first_iter_vars = (None,) * symbol_count

  assert len(symbol_names) == len(init_vars)
  assert len(symbol_names) == len(first_iter_vars)

  for name, init_val, first_iter_val in zip(
      symbol_names, init_vars, first_iter_vars):
    if isinstance(init_val, variables.UndefinedReturnValue):
      if not first_iter_val:
        # TODO(mdan): This can be handled by removing the return value.
        raise NotImplementedError(
            'a return statement cannot be placed inside this TensorFlow loop;'
            ' this may happen if a return statement depends on a'
            ' static Python condition such as a hyperparameter')
      raise ValueError(
          'the return value from a TensorFlow loop may only be a {}; got {}'
          .format(LEGAL_LOOP_TYPES, type(first_iter_val)))

    if init_val is None:
      raise ValueError(
          "'{}' is not allowed to be None before the loop".format(name))

    if isinstance(init_val, variables.Undefined):
      message = "'{}' must be defined before the loop".format(name)
      if extra_message:
        message += '\n' + extra_message
      raise ValueError(message)

191 

192 

def _is_subshape(left, right):
  """Returns True if left shape is at least as specific as right shape.

  Args:
    left: shape object exposing `dims` and `ndims`; each dim exposes `value`.
    right: shape object with the same attributes.

  Returns:
    True if `right.dims` is None, or if both shapes have the same `ndims` and
    every dimension of `right` with a non-None value has the same value in
    `left`. False otherwise.
  """
  # TODO(mdan): This code should be in TensorShape.
  # Note: this is not the same as TensorShape.is_compatible_with, which is
  # symmetric.
  # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py.
  if right.dims is None:
    return True
  if left.ndims != right.ndims:
    return False
  return not any(
      right_dim.value is not None and left_dim.value != right_dim.value
      for left_dim, right_dim in zip(left.dims, right.dims))

207 

208 

209# TODO(mdan): Remove these verifications once TF ops can properly report names. 

def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Checks one loop variable's initial, entry and exit values for consistency.

  Args:
    name: symbol name, used in error messages.
    check_shape: whether shapes are verified in addition to dtypes.
    init: value of the variable before the loop.
    entry: value of the variable at the start of an iteration.
    exit_: value of the variable at the end of an iteration.
    shape_invariant: shape invariant to verify `init` and `exit_` against, or
      None to compare the exit shape against the entry shape.

  Raises:
    ValueError: if `exit_` is None, or if a shape check fails.
    TypeError: if `entry` and `exit_` have different dtypes.
  """
  assert entry is not None, "no TF op should set '{}' to None?".format(name)
  if exit_ is None:
    raise ValueError("'{}' is None at the end of the iteration.".format(name))

  # Plain Python / NumPy values are converted to tensors before comparison.
  python_value_types = (bool, int, float, str, np.ndarray)
  if isinstance(init, python_value_types):
    init = tensor_conversion.convert_to_tensor_v2(init)
  if isinstance(entry, python_value_types):
    entry = tensor_conversion.convert_to_tensor_v2(entry)
  if isinstance(exit_, python_value_types):
    exit_ = tensor_conversion.convert_to_tensor_v2(exit_)

  # Only pairs of TF-typed values are verified.
  if not (tensor_util.is_tf_type(entry) and tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  for required_attr in ('dtype', 'shape'):
    if not (hasattr(entry, required_attr) and hasattr(exit_, required_attr)):
      return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        "'{}' has dtype {} before the loop, but dtype {} after one"
        ' iteration'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))

  if not check_shape:
    return

  shape_after = exit_.shape
  if shape_invariant is None:
    shape_before = entry.shape
    if not _is_subshape(shape_after, shape_before):
      raise ValueError(
          "'{}' has shape {} before the loop, but shape {} after one"
          ' iteration. Use tf.autograph.experimental.set_loop_options to set'
          ' shape invariants.'.format(name, shape_before, shape_after))
    return

  shape_initial = init.shape
  if not _is_subshape(shape_initial, shape_invariant):
    raise ValueError(
        "'{}' has shape {} before the loop, which does not conform with"
        ' the shape invariant {}.'.format(name, shape_initial,
                                          shape_invariant))
  if not _is_subshape(shape_after, shape_invariant):
    raise ValueError(
        "'{}' has shape {} after one iteration, which does not conform with"
        ' the shape invariant {}.'.format(name, shape_after, shape_invariant)
    )

265 

266 

def verify_tf_loop_vars(
    init_vars,
    iter_entry_vars,
    iter_exit_vars,
    symbol_names,
    opts,
    check_shapes=True,
):
  """Verifies loop variables for consistency.

  For each symbol, checks that the nested structure of its value is the same
  before the loop, on entering an iteration and on leaving it (and matches the
  shape invariant, if any), then verifies each leaf with
  `_verify_single_loop_var`.

  Args:
    init_vars: values of the loop variables before the loop.
    iter_entry_vars: values of the loop variables at the start of an iteration.
    iter_exit_vars: values of the loop variables at the end of an iteration.
    symbol_names: names of the loop variables, in the same order.
    opts: dict of loop options; `opts['shape_invariants']` is used when present
      and `check_shapes` is true.
    check_shapes: whether shapes are verified.

  Raises:
    TypeError: if nested structures differ.
  """
  shape_invariants = (
      opts['shape_invariants']
      if check_shapes and 'shape_invariants' in opts
      else nest.map_structure(lambda _: None, iter_entry_vars))

  for per_symbol_values in (
      shape_invariants, init_vars, iter_entry_vars, iter_exit_vars):
    assert len(symbol_names) == len(per_symbol_values)

  for index, name in enumerate(symbol_names):
    init = init_vars[index]
    entry = iter_entry_vars[index]
    exit_ = iter_exit_vars[index]
    invariant = shape_invariants[index]

    try:
      nest.assert_same_structure(init, entry, expand_composites=True)
    except (ValueError, TypeError):
      # `Variable`s in `init` may be implicitly converted to `Tensor`s. Convert
      # `ResourceVariable`s to Tensors so tf.nest.assert_same_structure
      # won't break due to type spec mismatches between `ResourceVariable`s and
      # `Tensor`s.
      try:
        init_as_tensors = variable_utils.convert_variables_to_tensors(init)
        nest.assert_same_structure(
            init_as_tensors, entry, expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure after one"
                        ' iteration.\n\n{}'.format(name, e)) from e

    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError("'{}' does not have the same nested structure after one"
                      ' iteration.\n\n{}'.format(name, e)) from e

    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError("'{}' does not have the same nested structure as its"
                        ' corresponding shape invariant.\n\n{}'.format(
                            name, e)) from e

    verify_leaf = functools.partial(
        _verify_single_loop_var, name, check_shapes)
    nest.map_structure(verify_leaf, init, entry, exit_, invariant)

323 

324 

def verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent.

  Args:
    name: symbol name, used in error messages.
    body_var: value of the symbol at the end of the main branch.
    orelse_var: value of the symbol at the end of the else branch.

  Raises:
    ValueError: if either value is None.
    TypeError: if both values are TF types exposing `dtype` and the dtypes
      differ.
  """
  if body_var is None:
    raise ValueError("'{}' is None at the end of the main branch.".format(name))
  if orelse_var is None:
    raise ValueError(
        "'{}' is None at the end of the else branch.".format(name))

  # Plain Python / NumPy values are converted to tensors before comparison.
  python_value_types = (bool, int, float, str, np.ndarray)
  if isinstance(body_var, python_value_types):
    body_var = tensor_conversion.convert_to_tensor_v2(body_var)
  if isinstance(orelse_var, python_value_types):
    orelse_var = tensor_conversion.convert_to_tensor_v2(orelse_var)

  # Only pairs of TF-typed values are verified.
  if not (tensor_util.is_tf_type(body_var) and
          tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if not (hasattr(body_var, 'dtype') and hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        "'{}' has dtype {} in the main branch, but dtype {} in the else"
        ' branch'.format(name, body_var.dtype.name,
                         orelse_var.dtype.name))

353 

354 

def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
  """Verifies variables output by a conditional branch for consistency.

  Args:
    vars_: values produced by the branch, one per symbol.
    symbol_names: names of the symbols, in the same order as `vars_`.
    branch_name: text naming the branch, interpolated into error messages.

  Raises:
    ValueError: if any value is `variables.Undefined` or
      `variables.UndefinedReturnValue`.
  """
  for symbol, value in zip(symbol_names, vars_):
    if isinstance(value, variables.Undefined):
      message = "'{}' must also be initialized in the {} branch".format(
          symbol, branch_name)
      raise ValueError(message)
    if isinstance(value, variables.UndefinedReturnValue):
      message = 'the {} branch must also have a return statement.'.format(
          branch_name)
      raise ValueError(message)

366 

367 

def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency.

  For each symbol, checks that both branches produce the same nested structure
  (retrying with variables converted to tensors), then verifies each leaf pair
  with `verify_single_cond_var`.

  Args:
    body_vars: values produced by the main branch, one per symbol.
    orelse_vars: values produced by the else branch, one per symbol.
    symbol_names: names of the symbols, in the same order.

  Raises:
    TypeError: if the nested structures differ between branches.
  """
  for name, main_value, else_value in zip(
      symbol_names, body_vars, orelse_vars):
    try:
      nest.assert_same_structure(
          main_value, else_value, expand_composites=True)
    except (ValueError, TypeError):
      # One branch of cond could be a `Tensor`, while the other branch could be
      # a `ResourceVariable`. Convert `ResourceVariable`s to `Tensor`s so
      # assert_same_structure won't fail.
      try:
        main_as_tensors = variable_utils.convert_variables_to_tensors(
            main_value)
        else_as_tensors = variable_utils.convert_variables_to_tensors(
            else_value)
        nest.assert_same_structure(
            main_as_tensors, else_as_tensors, expand_composites=True)
      except (ValueError, TypeError) as e:
        raise TypeError(
            "'{}' must have the same nested structure in the main and else"
            ' branches:\n\n{}'.format(name, str(e))) from e

    verify_leaf = functools.partial(verify_single_cond_var, name)
    nest.map_structure(verify_leaf, main_value, else_value)

391 

392 

def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

  ```
    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a
  ```

  The state is represented by the variables named geo_mean and arith_mean. The
  `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
  original `geo_mean` and `arith_mean` symbols, using `nonlocal`.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  The implementation is chosen by the type of `iter_`: the `for_loop_registry`
  is consulted first (defaulting to the Python loop), and the result is then
  overridden for TF types and for distribute Iterator / Iterable objects.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with boolean return type. An additional loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    symbol_names: Tuple containing names of the loop variables returned by
      get_state.
    opts: Optional dict of extra loop parameters.
  """
  try:
    loop_impl = for_loop_registry.lookup(iter_)
  except LookupError:
    loop_impl = _py_for_stmt

  # TODO(bwieder): Refactor isinstance(iter_, ragged_tensor.RaggedTensor) to use
  # the registry once python/autograph/utils does not depend on dataset_ops.
  if not tensor_util.is_tf_type(iter_):
    if isinstance(iter_, distribute.Iterator):
      loop_impl = _tf_iterator_for_stmt
    elif isinstance(iter_, distribute.Iterable):
      # TODO(b/162250181): Use _tf_iterator_for_stmt(iter(iter_)...
      loop_impl = _tf_distributed_iterable_for_stmt
  elif tensors.is_range_tensor(iter_):
    loop_impl = _tf_range_for_stmt
  elif isinstance(iter_, ragged_tensor.RaggedTensor):
    loop_impl = _tf_ragged_for_stmt
  else:
    loop_impl = _known_len_tf_for_stmt

  loop_impl(iter_, extra_test, body, get_state, set_state, symbol_names, opts)

455 

456 

def _py_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts
):
  """Overload of for_stmt that executes a Python for loop.

  When `__debug__` is set, each iteration is bracketed by a
  `_PythonLoopChecker`. When `extra_test` is given, it is evaluated once before
  the loop and once after every iteration, and the loop stops as soon as it is
  falsy.

  Args:
    iter_: the Python iterable to loop over.
    extra_test: optional callable giving an additional loop condition.
    body: callable taking the current element.
    get_state: unused.
    set_state: unused.
    symbol_names: unused.
    opts: unused.

  Raises:
    NotImplementedError: if the result of `extra_test` cannot be converted to
      a Python bool because it raises OperatorNotAllowedInGraphError.
  """
  del get_state, set_state, symbol_names, opts

  if __debug__:
    checker = _PythonLoopChecker()
    before_iteration = checker.before_iteration
    after_iteration = checker.after_iteration
    before_iteration()

    unchecked_body = body

    def checked_body(element):
      unchecked_body(element)
      after_iteration()
      before_iteration()

    body = checked_body

  if extra_test is None:
    for element in iter_:
      body(element)
    return

  def guarded_extra_test():
    outcome = extra_test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(outcome)
    except errors_impl.OperatorNotAllowedInGraphError as e:
      ag_logging.log(
          1,
          'Caught error while evaluating loop stop condition',
          exc_info=True)
      # TODO(mdan): We can pass the location of extra_test and show it here.
      raise NotImplementedError(
          'break and return statements which depend on a TF condition are not'
          ' supported in Python for loops. Did you intend to make it a TF'
          ' loop?\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.') from e

  if not guarded_extra_test():
    return
  for element in iter_:
    body(element)
    if not guarded_extra_test():
      break

506 

507 

def _add_max_iterations_hint(opts, n):
  """Stores `n` as `opts['maximum_iterations']` when in an XLA context.

  Args:
    opts: dict of loop options; mutated in place.
    n: the iteration count to record.
  """
  # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
  current_graph = ops.get_default_graph()
  if not control_flow_util.GraphOrParentsInXlaContext(current_graph):
    return
  opts['maximum_iterations'] = n

512 

513 

def _known_len_tf_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF entities that admit a length.

  `iter_` is unstacked into a TensorArray and walked by `_tf_while_stmt` using
  an internal index, which is carried as an extra leading loop variable named
  '<internal iterate>'.

  Args:
    iter_: the TF value to iterate over.
    extra_test: optional callable giving an additional loop condition.
    body: callable taking the current element.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable storing the user loop variables.
    symbol_names: tuple of names of the user loop variables.
    opts: dict of loop options.
  """
  length = py_builtins.len_(iter_)

  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  elements = tensor_array_ops.TensorArray(
      iter_.dtype, size=length).unstack(iter_)

  position = 0

  def aug_get_state():
    return (position,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal position
    # TODO(b/171479293): Drop the lint override.
    position, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iteration
    # index is used outside the loop, it will appear in the loop vars
    # separately.
    set_state(user_vars)

  def aug_body():
    nonlocal position
    body(elements.read(position))
    position += 1

  def aug_test():
    in_bounds = position < length
    if extra_test is None:
      return in_bounds
    return tf_cond.cond(in_bounds, extra_test, lambda: False)

  _add_max_iterations_hint(opts, length)

  aug_symbol_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      aug_symbol_names,
      opts,
  )

559 

560 

def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  The initial loop variables are verified, then `iter_` is indexed row by row
  inside `_tf_while_stmt` using an internal index carried as an extra leading
  loop variable named '<internal iterate>'.

  Args:
    iter_: the ragged tensor to iterate over.
    extra_test: optional callable giving an additional loop condition.
    body: callable taking the current row.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable storing the user loop variables.
    symbol_names: tuple of names of the user loop variables.
    opts: dict of loop options.
  """
  verify_loop_init_vars(get_state(), symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  static_shape = iter_.shape
  if static_shape and static_shape[0] is not None:
    row_count = static_shape[0]
  else:
    row_count = iter_.row_lengths()[0]

  position = 0

  def aug_get_state():
    return (position,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal position
    # TODO(b/171479293): Drop the lint override.
    position, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iteration
    # index is used outside the loop, it will appear in the loop vars
    # separately.
    set_state(user_vars)

  def aug_body():
    nonlocal position
    body(iter_[position])
    position += 1

  def aug_test():
    in_bounds = position < row_count
    if extra_test is None:
      return in_bounds
    return tf_cond.cond(in_bounds, extra_test, lambda: False)

  _add_max_iterations_hint(opts, row_count)

  aug_symbol_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      aug_symbol_names,
      opts)

606 

607 

def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  The range tensor itself is not read; its `start`, `limit` and `delta` inputs
  drive a `_tf_while_stmt` whose leading loop variable, named
  '<internal iterate>', is the current value.

  Args:
    iter_: the range tensor; `iter_.op.inputs` supplies start, limit and delta.
    extra_test: optional callable giving an additional loop condition.
    body: callable taking the current value.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable storing the user loop variables.
    symbol_names: tuple of names of the user loop variables.
    opts: dict of loop options; `opts['iterate_names']` is compared against
      each symbol name.
  """
  start, limit, delta = iter_.op.inputs

  iterate = start

  def _iterate_if_undefined(name, var):
    # A symbol matching opts['iterate_names'] that is still undefined takes the
    # current iterate as its value.
    matches_iterate = name == opts['iterate_names']
    if matches_iterate and isinstance(var, variables.Undefined):
      return iterate
    return var

  def aug_get_state():
    user_vars = tuple(
        _iterate_if_undefined(name, var)
        for name, var in zip(symbol_names, get_state()))
    return (iterate,) + user_vars

  def aug_set_state(aug_loop_vars):
    nonlocal iterate
    # TODO(b/171479293): Drop the lint override.
    iterate, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(user_vars)

  def aug_body():
    nonlocal iterate
    body(iterate)
    iterate += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    static_delta = tensor_util.constant_value(delta)
    if static_delta is None:
      # Direction unknown until run time: test both directions.
      ascending = math_ops.logical_and(delta >= 0, iterate < limit)
      descending = math_ops.logical_and(delta < 0, iterate > limit)
      in_range = math_ops.logical_or(ascending, descending)
    elif static_delta >= 0:
      in_range = iterate < limit
    else:
      in_range = iterate > limit

    if extra_test is None:
      return in_range
    return tf_cond.cond(in_range, extra_test, lambda: False)

  range_len = misc.get_range_len(start, limit, delta)
  _add_max_iterations_hint(opts, math_ops.cast(range_len, dtypes.int32))

  aug_symbol_names = ('<internal iterate>',) + symbol_names
  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      aug_symbol_names,
      opts)

668 

669 

def _tf_iterator_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop.

  A boolean `has_next` is carried as an extra leading loop variable named
  '<internal has_next>'. Each iteration fetches the next element as an
  optional and runs `body` under a `tf.cond` on whether it has a value.

  Args:
    iter_: the iterator; must provide `get_next_as_optional()`.
    extra_test: optional callable giving an additional loop condition.
    body: callable taking the current element.
    get_state: callable returning the tuple of user loop variables.
    set_state: callable storing the user loop variables.
    symbol_names: tuple of names of the user loop variables.
    opts: dict of loop options.
  """
  symbol_names = ('<internal has_next>',) + symbol_names
  has_next = True

  def aug_get_state():
    return (has_next,) + get_state()

  def aug_set_state(aug_loop_vars):
    nonlocal has_next
    # TODO(b/171479293): Drop the lint override.
    has_next, *user_vars = aug_loop_vars  # pylint:disable=unused-variable
    set_state(user_vars)

  init_vars = aug_get_state()
  verify_loop_init_vars(init_vars, symbol_names)

  def aug_body():
    """Main body passed to _tf_while_stmt."""
    nonlocal has_next
    next_element = iter_.get_next_as_optional()
    has_next = next_element.has_value()
    # updated by set_state() in _tf_while_loop.
    entry_vars = aug_get_state()

    def main_path():
      body(next_element.get_value())
      exit_vars = aug_get_state()
      # Note: this verification duplicates the one performed in tf_while_stmt,
      # but needs to be done earlier to prevent the tf.cond from blowing up
      # first.
      verify_tf_loop_vars(
          init_vars, entry_vars, exit_vars, symbol_names, opts)
      return exit_vars

    def noop_path():
      return entry_vars

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # Calling set_state so that get_state() _tf_while_loop sees the conditional
    # tensors.
    branch_result = tf_cond.cond(has_next, main_path, noop_path)
    aug_set_state(branch_result)

  def aug_test():
    # This value takes a complicated path to get here:
    #   prev_iteration_body -> get_state -> tf.while_loop (as loop var)
    #   -> current_iteration_body -> set_state -> has_next
    still_has_next = has_next
    if extra_test is None:
      return still_has_next
    return tf_cond.cond(still_has_next, extra_test, lambda: False)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      symbol_names,
      opts)

730 

731 

def _tf_distributed_iterable_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF distributed datasets.

  The loop is expressed as `iter_.reduce`, threading the loop variables
  through as the reduction state.

  Args:
    iter_: the distributed iterable; must provide `reduce(initial, fn)`.
    extra_test: must be None.
    body: callable taking the current element.
    get_state: callable returning the tuple of loop variables.
    set_state: callable storing the loop variables.
    symbol_names: tuple of names of the loop variables.
    opts: dict of loop options; a 'shape_invariants' entry, if present, is
      replaced in place by its positional-list form.

  Raises:
    NotImplementedError: if `extra_test` is not None.
  """
  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  initial_state = get_state()
  verify_loop_init_vars(initial_state, symbol_names)

  if 'shape_invariants' in opts:
    invariants_mapping = opts['shape_invariants']
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        invariants_mapping, initial_state)

  def reduce_body(loop_vars, iterate):
    set_state(loop_vars)
    body(iterate)
    updated_vars = get_state()
    verify_tf_loop_vars(
        initial_state, loop_vars, updated_vars, symbol_names, opts)
    return updated_vars

  final_state = iter_.reduce(initial_state, reduce_body)
  set_state(final_state)

757 

758 

def while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  The inputs and outputs of the callables representing the loop blocks are not
  explicit - instead, these functions must use nonlocal/global for side effects.
  The inputs and outputs are instead controlled by the set_state/get_state
  functions.

  `test` is evaluated once inside a temporary FuncGraph to decide between the
  TF loop and the Python loop. The final state is not returned; it is
  communicated through `set_state` and the callables' side effects.

  Args:
    test: Callable with boolean return type. The loop condition.
    body: Callable representing the actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    symbol_names: Tuple containing the names of all loop variables.
    opts: Optional dict of extra loop parameters.
  """
  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    first_test = test()

  if tensors.is_dense_tensor(first_test):
    # TensorFlow: Multiple evaluations are acceptable in this case, so we're
    # fine with the re-evaluation of `test` that `_tf_while_stmt` will make.
    _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts)
  elif first_test:
    # Normal Python: We already consumed one evaluation of `test`;
    # consistently, unroll one iteration before dispatching to a normal loop.
    # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
    body()
    _py_while_stmt(test, body, get_state, set_state, opts)

807 

808 

class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits.

  Callers invoke `before_iteration` / `after_iteration` around every iteration
  of a Python loop. The checker enforces PYTHON_MAX_ITERATIONS and, once the
  loop has run more than INEFFICIENT_UNROLL_MIN_ITERATIONS iterations, warns
  (at most once) if iterations keep adding ops to the default graph.
  """

  __slots__ = (
      'iterations',
      'check_inefficient_unroll',
      'check_op_count_after_iteration',
      'ops_before_iteration',
  )

  def __init__(self):
    # Starts at 1 and is incremented by after_iteration.
    self.iterations = 1
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

    # Snapshot of the graph's ops taken by before_iteration. Initialized here
    # so the slot always exists: reading an unset slot raises AttributeError
    # rather than reaching the `is not None` assertion in
    # _verify_inefficient_unroll.
    self.ops_before_iteration = None

  def _get_ops(self):
    """Returns the set of operations currently in the default graph."""
    return set(ops.get_default_graph().get_operations())

  def _check_unroll_limits(self):
    """Raises ValueError once the loop exceeds PYTHON_MAX_ITERATIONS."""
    if self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    """Disables the op-count check and drops the op snapshot."""
    self.check_inefficient_unroll = False
    self.check_op_count_after_iteration = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop.

    Returns:
      True if a warning was logged, False if fewer than
      INEFFICIENT_UNROLL_MIN_OPS new ops were created since the snapshot.
    """
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    ag_logging.warning(
        'Large unrolled loop detected. Did you mean to use a TF loop?'
        ' The following ops were created after iteration %s: %s'
        '\nSee'
        ' https://github.com/tensorflow/tensorflow/blob/master/'
        'tensorflow/python/autograph/g3doc/reference/common_errors.md'
        '#warning-large-unrolled-loop-detected'
        '\n'
        'Location:'
        '\n%s'
        '', self.iterations, new_ops, '\n'.join(traceback.format_stack()))
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()

881 

882 

def _py_while_stmt(test, body, get_state, set_state, opts):
  """Overload of while_stmt that executes a Python while loop.

  Args:
    test: Callable evaluating the loop condition.
    body: Callable executing one iteration of the loop.
    get_state: Unused by the Python loop.
    set_state: Unused by the Python loop.
    opts: Unused by the Python loop.

  Raises:
    NotImplementedError: if converting the condition to bool raises
      OperatorNotAllowedInGraphError.
  """
  del opts, get_state, set_state

  run_iteration = body
  if __debug__:
    # Debug builds bracket every iteration with the loop checker's hooks.
    loop_checker = _PythonLoopChecker()
    notify_before = loop_checker.before_iteration
    notify_after = loop_checker.after_iteration
    notify_before()

    def run_iteration():
      body()
      notify_after()
      notify_before()

  def condition_holds():
    outcome = test()
    try:
      # Note: Using try/except and not tensor_util.is_tf_type to avoid
      # performance degradation.
      return bool(outcome)
    except errors_impl.OperatorNotAllowedInGraphError as e:
      ag_logging.log(
          1,
          'Caught error while evaluating while loop condition',
          exc_info=True)
      # TODO(mdan): distinguish between these two cases.
      raise NotImplementedError(
          'The condition of while loop started as non-Tensor, then changed to'
          ' Tensor. This may happen either because variables changed type, or'
          ' when a break or return statement inside the loop depends on a'
          ' Tensor condition. In both cases, changing to a TF loop should'
          ' remove the error.\nSee '
          'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
          'python/autograph/g3doc/reference/limitations.md'
          '#consistency-of-control-flow-types for more info.') from e

  while condition_holds():
    run_iteration()

923 

924 

925def _shape_invariants_mapping_to_positional_list(mapping, keys): 

926 # The keys are not expected to be hashable. 

927 mapping = {id(k): (k, v) for k, v in mapping} 

928 result = [] 

929 for k in keys: 

930 map_key, map_val = mapping.get(id(k), (None, None)) 

931 result.append( 

932 map_val if map_key is k else nest.map_structure(lambda _: None, k)) 

933 return tuple(result) 

934 

935 

# Human-readable summary of the types that are legal as TF loop variables. It
# mirrors what _placeholder_value below is able to handle, so the two must be
# kept next to each other and updated together.
LEGAL_LOOP_TYPES = (
    'Tensor, int, float, bool or a list, tuple or dict thereof')

940 

941 

942def _placeholder_value(like, shape_invariant, original=None): 

943 """Constructs a (dummy) placeholder value for a loop-initialized variable. 

944 

945 Args: 

946 like: Any object. The value created by the first iteration of the loop. If a 

947 Python scalar, the placeholder will be the zero value of that type. If a 

948 Tensor, the placeholder will be a zero tensor of matching shape and dtype. 

949 If a list, dict or tuple, the placeholder will be an identical structure 

950 of placeholders. 

951 shape_invariant: The shape invariant specified by the user (or None, if 

952 nothing was specified) for the respective variable. 

953 original: Any object. The value of the variable prior to entering the loop. 

954 Typically, this is one of the special "Undefined" value, because that's 

955 when a placeholder is needed. 

956 

957 Returns: 

958 Either a zero value of structure, shape and dtype mathing 'like', or 

959 'original', if no such zero value could be created. 

960 """ 

961 if like is None: 

962 return original, None 

963 

964 elif isinstance(like, (variables.Undefined, variables.UndefinedReturnValue)): 

965 return original, None 

966 

967 elif isinstance(like, (int, float, bool)): 

968 return type(like)(0), None 

969 

970 elif tensor_util.is_tf_type(like): 

971 

972 like_shape = shape_invariant if shape_invariant is not None else like.shape 

973 if like_shape is None or like_shape.rank is None: 

974 return array_ops.zeros((), like.dtype), like_shape 

975 

976 # If the shape contains dynamic values, set the corresponding starting 

977 # dimension to either zero or what the shape invariant specified. 

978 placeholder_shape = [] 

979 has_dynamic_dims = False 

980 for s, i in zip(like.shape, like_shape): 

981 if i is None: 

982 like_dim = 0 

983 elif isinstance(i, tensor_shape.Dimension): 

984 if i.value is None: 

985 like_dim = 0 

986 else: 

987 like_dim = i.value 

988 else: 

989 like_dim = i 

990 

991 if s is None: 

992 placeholder_shape.append(like_dim) 

993 has_dynamic_dims = True 

994 elif isinstance(s, tensor_shape.Dimension): 

995 if s.value is None: 

996 placeholder_shape.append(like_dim) 

997 has_dynamic_dims = True 

998 else: 

999 placeholder_shape.append(s.value) 

1000 else: 

1001 placeholder_shape.append(s) 

1002 

1003 if has_dynamic_dims: 

1004 invariant = like_shape 

1005 else: 

1006 invariant = None 

1007 

1008 return array_ops.zeros(placeholder_shape, like.dtype), invariant 

1009 

1010 elif isinstance(like, (list, tuple, dict)): 

1011 if shape_invariant is None: 

1012 zipped = nest.map_structure(lambda v: _placeholder_value(v, None), 

1013 nest.flatten(like)) 

1014 else: 

1015 zipped = nest.map_structure(_placeholder_value, nest.flatten(like), 

1016 nest.flatten(shape_invariant)) 

1017 vals, invars = zip(*zipped) 

1018 return (nest.pack_sequence_as(like, 

1019 vals), nest.pack_sequence_as(like, invars)) 

1020 

1021 # This is to be caught by _try_handling_undefineds, to give more context. 

1022 raise TypeError( 

1023 "Found an unsupported type '{}' while creating placeholder for {}." 

1024 ' Supported types include Tensor, int, float, bool, list, tuple or dict.' 

1025 .format(type(like).__name__, like)) 

1026 

1027 

def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
     iteration)
   2. these types could be replaced by placeholders (e.g. zero values, for
     tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop var
      is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
    * success is a boolean flag indicating whether both steps above
      succeeded
    * new_init_vars contains the loop vars, with None or undefined values
      replaced by default values, where possible (step 2 above); on failure
      it is `init_vars` unchanged
    * extra_shape_invariants contains shape invariants that would be needed
      by while_stmt, for instance if the placeholder values had a shape
      different from the corresponding loop outputs; on failure every entry
      is None
  """
  state_modified = False
  first_iter_vars = None
  failure_message = None
  # Failure-path defaults, so that the return statement below never reads an
  # unbound local when verify_loop_init_vars lets a staging failure through.
  success = False
  extra_shape_invariants = tuple(None for _ in init_vars)

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non_tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to account
      # that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = tensor_conversion.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()

    # Note: the actual placeholder value doesn't matter, because as the
    # staging proved, it will be replaced by an actual value before being
    # read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True

  except (UnboundLocalError, TypeError, ValueError, KeyError) as e:
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    failure_message = (
        'Note: AutoGraph tried to define it automatically, but ran into a'
        ' {}: {}'.format(type(e).__name__, e))

  finally:
    # Undo the temporary state installed for staging: this restores the
    # original loop vars on failure and installs the placeholders on success.
    if state_modified:
      set_state(init_vars)

  # This check runs regardless, in case we captured non-Tensor inputs.
  verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants

1111 

1112 

1113def _runtime_zero_iterations_errmsg(symbol_names, nulls, init_vars): 

1114 """Creates an error message asking for the loop to iterate at least once.""" 

1115 var_names = [] 

1116 for sn, n, v in zip(symbol_names, nulls, init_vars): 

1117 if not n: 

1118 continue 

1119 if isinstance(v, variables.UndefinedReturnValue): 

1120 var_names.append('the function return value') 

1121 else: 

1122 var_names.append(sn) 

1123 var_names = ', '.join(var_names) 

1124 return 'loop must iterate at least once to initialize {}'.format(var_names) 

1125 

1126 

def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt.

  Args:
    test: Callable evaluating the loop condition against the current state.
    body: Callable executing one iteration; it communicates through the state
      read by `get_state` and written by `set_state`.
    get_state: Callable returning a tuple with the current loop variables.
    set_state: Callable accepting a tuple of loop variables to install.
    symbol_names: Names of the loop variables, used for error reporting.
    opts: Dict of extra options. 'shape_invariants' (pairs of loop variable
      and invariant) and 'iterate_names' are interpreted here; the remaining
      entries are forwarded to while_loop.while_loop. The dict passed by the
      caller is not modified.
  """
  # Work on a private copy: 'shape_invariants' is rewritten below, and the
  # caller's dict must not observe that rewrite.
  opts = dict(opts)

  init_vars = get_state()
  orig_init_vars = init_vars

  nulls = tuple(_is_none_or_undef(v) for v in init_vars)
  if any(nulls):
    shape_invars_by_init_vals = {
        id(v): i for v, i in opts.get('shape_invariants', ())
    }
    shape_invariants = tuple(
        shape_invars_by_init_vals.get(id(v), None) for v in orig_init_vars)
    (require_one_iteration, init_vars,
     extra_shape_invariants) = _try_handling_undefineds(
         body, get_state, set_state, init_vars, nulls, shape_invariants,
         symbol_names)
  else:
    require_one_iteration = False

  if require_one_iteration:
    merged_shape_invariants = dict(shape_invars_by_init_vals)
    # This has two roles:
    #   1. Shape invariants are remapped from the old init vars to the new ones.
    #   2. Any new shape invariants created by the init vars are kept, but only
    #      if the user didn't already specify some.
    for v, nv, ni in zip(orig_init_vars, init_vars, extra_shape_invariants):
      merged_invariant = merged_shape_invariants.get(id(v), ni)
      if merged_invariant is not None:
        merged_shape_invariants[id(nv)] = merged_invariant
    merged_shape_invariants = tuple((nv, merged_shape_invariants[id(nv)])
                                    for nv in init_vars
                                    if id(nv) in merged_shape_invariants)
    if merged_shape_invariants:
      opts['shape_invariants'] = merged_shape_invariants

  # When one iteration is required, an extra leading boolean loop variable
  # records whether the body ran; it is stripped before touching user state.
  def aug_test(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    return _verify_tf_condition(test(), 'while loop')

  def aug_body(*loop_vars):
    if require_one_iteration:
      loop_vars = loop_vars[1:]

    set_state(loop_vars)
    body()
    new_loop_vars = get_state()
    verify_tf_loop_vars(
        init_vars, loop_vars, new_loop_vars, symbol_names, opts)

    if require_one_iteration:
      new_loop_vars = (True,) + new_loop_vars

    return new_loop_vars

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], init_vars)

  while_loop_opts = dict(opts)
  while_loop_opts.pop('iterate_names', None)

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  while_loop_opts['return_same_structure'] = True

  if require_one_iteration:
    aug_init_vars = (False,) + init_vars
    if 'shape_invariants' in while_loop_opts:
      while_loop_opts['shape_invariants'] = (
          (None,) + while_loop_opts['shape_invariants'])
  else:
    aug_init_vars = init_vars

  final_loop_vars = while_loop.while_loop(aug_test, aug_body, aug_init_vars,
                                          **while_loop_opts)

  if require_one_iteration:
    # Make the outputs depend on an assertion that the body ran at least once.
    with ops.control_dependencies([
        control_flow_assert.Assert(final_loop_vars[0], [
            _runtime_zero_iterations_errmsg(symbol_names, nulls, orig_init_vars)
        ])
    ]):
      final_loop_vars = nest.map_structure(
          lambda v: (array_ops.identity(v) if tensor_util.is_tf_type(v) else v),
          final_loop_vars[1:],
      )

  set_state(final_loop_vars)

1220 

1221 

def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Functional form of an if statement.

  The conditional operates on a state, which includes all symbols whose values
  are a function of the branch taken.

  For example, given the code below that calculates the abs function:

  ```
    x = 1
    if x > 0:
      x = -x
  ```

  The state is represented by the variable `x`. The `body`, `orelse` and
  `set_state` functions must bind to the original `x` symbol, using `nonlocal`.

  The inputs and outputs of the callables representing the branches of the
  conditional are not explicit - instead, these functions must use
  nonlocal/global for side effects. The inputs and outputs are instead
  controlled by the set_state/get_state functions.

  Args:
    cond: Boolean.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Function that returns a tuple containing the values of all
      composite symbols modified within the conditional. This allows access to
      state that branches may mutate through side effects. This function is not
      needed and should not be called when dispatching to code matching Python's
      default semantics. This is useful for checkpointing to avoid unintended
      side-effects when staging requires evaluating all code-paths.
    set_state: Function to set the values of all composite symbols modified
      within the conditional. This is the complement to get_state, used to
      restore checkpointed values. The single argument is a tuple containing
      values for each composite symbol that may be modified in a branch of the
      conditional. This is usually the result of a call to get_state.
    symbol_names: Tuple containing basic loop var names.
    nouts: Number of variables output by the statement. Vars which are not
      outputs will not be passed through staged control flow such as tf.cond.
      This includes variables that are defined before the conditional, but are
      not used after it.
  """
  # Note: tf.cond doesn't support SparseTensor.
  if not tensors.is_dense_tensor(cond):
    _py_if_stmt(cond, body, orelse)
    return

  _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)

1270 

1271 

def _tf_if_stmt(
    cond, body, orelse, get_state, set_state, symbol_names, nouts):
  """Overload of if_stmt that stages a TF cond.

  Args:
    cond: The condition; it is validated by `_verify_tf_condition`.
    body: Callable representing the main block of the conditional.
    orelse: Callable representing the else block of the conditional.
    get_state: Callable returning a tuple with the current state.
    set_state: Callable accepting a tuple of state values to install.
    symbol_names: Names of the state variables, used for error reporting. The
      sequence passed by the caller is not modified.
    nouts: Number of leading state variables that are outputs of the
      conditional; only those are passed through `tf_cond.cond`.
  """
  cond = _verify_tf_condition(cond, 'if statement')

  if not nouts:
    prev_get_state, prev_set_state = get_state, set_state
    # Control flow V1 wants at least one output.
    get_state = lambda: (0,) + prev_get_state()
    set_state = lambda v: prev_set_state(v[1:])
    # The dummy value is the first state element, so its name goes first too;
    # this keeps names and values aligned in the verification helpers.
    symbol_names = ('<unused dummy>',) + tuple(symbol_names)
    nouts = 1

  init_vars = get_state()

  # Outputs of each branch, recorded as the branches get staged so that
  # whichever branch is staged second can be checked against the first.
  staged_body_vars = None
  staged_orelse_vars = None

  def aug_body():
    nonlocal staged_body_vars
    set_state(init_vars)
    body()
    new_body_vars = get_state()
    new_body_vars = new_body_vars[:nouts]
    staged_body_vars = new_body_vars
    _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
    if staged_orelse_vars is not None:
      _verify_tf_cond_vars(new_body_vars, staged_orelse_vars, symbol_names)
    return new_body_vars

  def aug_orelse():
    nonlocal staged_orelse_vars
    set_state(init_vars)
    orelse()
    new_orelse_vars = get_state()
    new_orelse_vars = new_orelse_vars[:nouts]
    staged_orelse_vars = new_orelse_vars
    _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
    if staged_body_vars is not None:
      _verify_tf_cond_vars(staged_body_vars, new_orelse_vars, symbol_names)
    return new_orelse_vars

  final_cond_vars = tf_cond.cond(
      cond, aug_body, aug_orelse, strict=True)
  # State entries that are not outputs keep the values they had on entry.
  final_cond_vars = final_cond_vars + init_vars[nouts:]

  set_state(final_cond_vars)

1318 

1319 

1320def _py_if_stmt(cond, body, orelse): 

1321 """Overload of if_stmt that executes a Python if statement.""" 

1322 return body() if cond else orelse()