Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py: 17%

1026 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Control Flow Operations. 

16 

17See the [autograph](https://www.tensorflow.org/guide/autograph) guide. 

18""" 

19# pylint: disable=g-bad-name 

20import abc 

21 

22from tensorflow.core.framework import attr_value_pb2 

23from tensorflow.core.protobuf import control_flow_pb2 

24from tensorflow.python.eager import context 

25from tensorflow.python.framework import composite_tensor 

26from tensorflow.python.framework import constant_op 

27from tensorflow.python.framework import dtypes 

28from tensorflow.python.framework import indexed_slices 

29from tensorflow.python.framework import ops 

30from tensorflow.python.framework import tensor_shape 

31from tensorflow.python.framework import tensor_spec 

32from tensorflow.python.framework import tensor_util 

33from tensorflow.python.framework import type_spec 

34from tensorflow.python.ops import array_ops 

35from tensorflow.python.ops import cond as tf_cond 

36from tensorflow.python.ops import control_flow_assert 

37from tensorflow.python.ops import control_flow_case 

38from tensorflow.python.ops import control_flow_util as util 

39from tensorflow.python.ops import gen_array_ops 

40from tensorflow.python.ops import gen_control_flow_ops 

41from tensorflow.python.ops import math_ops 

42from tensorflow.python.ops import tensor_array_ops 

43from tensorflow.python.ops import while_loop as while_loop_ops 

44# go/tf-wildcard-import 

45# pylint: disable=wildcard-import,undefined-variable 

46from tensorflow.python.ops.gen_control_flow_ops import * 

47# pylint: enable=wildcard-import 

48from tensorflow.python.util import compat 

49from tensorflow.python.util import dispatch 

50from tensorflow.python.util import nest 

51from tensorflow.python.util import variable_utils 

52from tensorflow.python.util.tf_export import tf_export 

53 

# TODO(b/269483538): needed for references while refactors are in progress
# The names below are re-exported from the modules imported above so that
# references through this module keep resolving.
# pylint: disable=protected-access

# Aliases for `control_flow_case`.
case = control_flow_case.case
case_v2 = control_flow_case.case_v2
_case_helper = control_flow_case._case_helper
_case_create_default_action = control_flow_case._case_create_default_action
_case_verify_and_canonicalize_args = (
    control_flow_case._case_verify_and_canonicalize_args)
_assert_at_most_n_true = control_flow_case._assert_at_most_n_true

# Aliases for `control_flow_assert`.
Assert = control_flow_assert.Assert
_summarize_eager = control_flow_assert._summarize_eager

# Aliases for `while_loop`.
while_loop = while_loop_ops.while_loop
while_loop_v2 = while_loop_ops.while_loop_v2

# Aliases for `cond`.
cond = tf_cond.cond
cond_for_tf_v2 = tf_cond.cond_for_tf_v2
_UnpackIfSingleton = tf_cond._UnpackIfSingleton
_eager_cond_implementation = tf_cond._eager_cond_implementation
_cast_indexed_slice_indices = tf_cond._cast_indexed_slice_indices

# pylint: enable=protected-access

# This module overrides the name 'tuple' for a control flow op, so Python's
# builtin 'tuple' is saved here for later use in this module.
_basetuple = tuple

74 

75 

76# pylint: disable=protected-access 

77 

78 

def _Identity(tensor, name=None):
  """Returns a tensor with the same shape and contents as the input.

  The input is first converted to a tensor or composite tensor (keeping ref
  dtypes), and any variables in it are converted to tensors.

  Args:
    tensor: A Tensor or CompositeTensor.
    name: A name for this operation (optional). Only used when the converted
      input is a single `Tensor`.

  Returns:
    A Tensor with the same type and value as the input Tensor. For a
    CompositeTensor, the same structure with `_Identity` applied to each
    component.

  Raises:
    TypeError: If the converted input is neither a Tensor nor a
      CompositeTensor.
  """
  converted = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)
  # TODO(b/246438937): Remove this when we expand ResourceVariables into
  # dt_resource tensors.
  converted = variable_utils.convert_variables_to_tensors(converted)

  if isinstance(converted, ops.Tensor):
    # Ref-typed tensors need the ref flavour of the identity op.
    if converted.dtype._is_ref_dtype:  # pylint: disable=protected-access
      identity_fn = gen_array_ops.ref_identity
    else:
      identity_fn = array_ops.identity
    return identity_fn(converted, name=name)

  if isinstance(converted, composite_tensor.CompositeTensor):
    return nest.map_structure(_Identity, converted, expand_composites=True)

  raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
                  f"Received: {type(converted)}.")

103 

104 

def _NextIteration(tensor, name=None):
  """Applies the next-iteration op to `tensor`.

  Args:
    tensor: A Tensor or CompositeTensor.
    name: A name for this operation (optional). Only used when the converted
      input is a single `Tensor`.

  Returns:
    The output of `ref_next_iteration` (ref dtypes) or `next_iteration` for a
    Tensor; for a CompositeTensor, the same structure with `_NextIteration`
    applied to each component.

  Raises:
    TypeError: If the converted input is neither a Tensor nor a
      CompositeTensor.
  """
  converted = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)

  if isinstance(converted, ops.Tensor):
    if converted.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return ref_next_iteration(converted, name=name)
    return next_iteration(converted, name=name)

  if isinstance(converted, composite_tensor.CompositeTensor):
    return nest.map_structure(
        _NextIteration, converted, expand_composites=True)

  raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
                  f"Received: {type(converted)}.")

117 

118 

def _Enter(tensor,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `tensor` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `tensor` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations`
  iterations are run in parallel in the child frame.

  Args:
    tensor: The tensor (or composite tensor) to be made available to the child
      frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if tensor is of ref type.
    use_input_shape: If true, set the result's shape based on tensor's shape.
    name: A name for this operation (optional). Not forwarded to the
      per-component ops when `tensor` is a composite tensor.

  Returns:
    The same tensor as `tensor`.

  Raises:
    TypeError: If the converted input is neither a Tensor nor a
      CompositeTensor.
  """
  value = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)

  if isinstance(value, ops.Tensor):
    if value.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
      enter_fn = gen_control_flow_ops.ref_enter
    else:
      enter_fn = gen_control_flow_ops.enter
    entered = enter_fn(
        value, frame_name, is_constant, parallel_iterations, name=name)
    if use_input_shape:
      entered.set_shape(value.get_shape())
    return entered

  if isinstance(value, composite_tensor.CompositeTensor):
    # Enter every component tensor with the same frame settings.
    return nest.map_structure(
        lambda component: _Enter(component, frame_name, is_constant,
                                 parallel_iterations, use_ref,
                                 use_input_shape),
        value,
        expand_composites=True)

  raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
                  f"Received: {type(value)}.")

170 

171 

def exit(tensor, name=None):  # pylint: disable=redefined-builtin
  """Exits the current frame to its parent frame.

  Exit makes its input `tensor` available to the parent frame.

  Args:
    tensor: The tensor (or composite tensor) to be made available to the
      parent frame.
    name: A name for this operation (optional). Only used when the converted
      input is a single `Tensor`.

  Returns:
    The same tensor as `tensor`.

  Raises:
    TypeError: If the converted input is neither a Tensor nor a
      CompositeTensor.
  """
  value = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True)

  if isinstance(value, ops.Tensor):
    # pylint: disable=protected-access
    if value.dtype._is_ref_dtype:
      return gen_control_flow_ops.ref_exit(value, name)
    return gen_control_flow_ops._exit(value, name)
    # pylint: enable=protected-access

  if isinstance(value, composite_tensor.CompositeTensor):
    return nest.map_structure(exit, value, expand_composites=True)

  raise TypeError("'tensor' must be a Tensor or CompositeTensor. "
                  f"Received: {type(value)}.")

195 

196 

def switch(data, pred, dtype=None, name=None):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and composite tensors such as `IndexedSlices`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: If the converted `data` is neither a Tensor nor a
      CompositeTensor.
  """
  with ops.name_scope(name, "Switch", [data, pred]) as name:
    data = ops.internal_convert_to_tensor_or_composite(
        data, dtype=dtype, name="data", as_ref=True)
    pred = ops.convert_to_tensor(pred, name="pred")

    if isinstance(data, ops.Tensor):
      return gen_control_flow_ops.switch(data, pred, name=name)

    if not isinstance(data, composite_tensor.CompositeTensor):
      raise TypeError(
          "'data' must be a Tensor or CompositeTensor. "
          f"Received: {type(data)}.")

    # Switch each component tensor, then rebuild one composite per branch.
    switched = [
        gen_control_flow_ops.switch(component, pred)
        for component in nest.flatten(data, expand_composites=True)
    ]
    false_outputs, true_outputs = zip(*switched)

    def repack(flat_outputs):
      return nest.pack_sequence_as(data, flat_outputs, expand_composites=True)

    return (repack(false_outputs), repack(true_outputs))

232 

233 

def _SwitchRefOrTensor(data, pred, name="Switch"):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  Ref-typed tensors go through `ref_switch`; everything else is delegated to
  `switch`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: if data is not a Tensor or IndexedSlices
  """
  data = ops.convert_to_tensor_or_composite(data, name="data")
  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
  # addresses the following scenario.
  #
  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
  #
  # 1. The update op is created inside a `with ops.colocate(var):` block
  #
  # 2. Some tensor `data` is captured and a switch is created in a
  #    `with ops.colocate_with(data):` block.
  #
  # with ops.colocate_with(var):
  #  with ops.colocate_with(data):
  #    op = ...
  #
  # var and data may be pinned to different devices, so we want the ops
  # created within ops.colocate_with(data) to ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    is_ref_tensor = (
        isinstance(data, ops.Tensor) and data.dtype._is_ref_dtype)  # pylint: disable=protected-access
    if is_ref_tensor:
      return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)

276 

277 

def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
    TypeError: If, after conversion, an input is not a CompositeTensor in the
      non-all-Tensor case.
  """
  if any(inp is None for inp in inputs):
    # Use an f-string rather than `"%s" % inputs`: the %-operator treats a
    # tuple as its argument list, so a tuple of inputs would raise a
    # formatting TypeError instead of this ValueError.
    raise ValueError(f"At least one of the merge inputs is None: {inputs}")
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      # The ref flavour is only used when every input is ref-typed.
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      else:
        return gen_control_flow_ops.merge(inputs, name)
    else:
      # If there is a mix of tensors and indexed slices, then convert the
      # tensors to indexed slices.
      if all(
          isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor))
          for v in inputs):
        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)  # pylint: disable=protected-access

      for v in inputs:
        if not isinstance(v, composite_tensor.CompositeTensor):
          raise TypeError("Type %s not supported" % type(v))

      # All inputs must flatten to the same component structure.
      for v in inputs[1:]:
        nest.assert_same_structure(inputs[0], v, expand_composites=True)

      # Merge component-wise, then rebuild a composite shaped like inputs[0].
      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
      merged_results = [
          gen_control_flow_ops.merge(component)
          for component in zip(*flat_inputs)
      ]
      flat_merged = [tensor for (tensor, _) in merged_results]
      # The index reported is the one chosen for the first component.
      chosen_index = merged_results[0][1]
      merged_inputs = nest.pack_sequence_as(
          inputs[0], flat_merged, expand_composites=True)
      return (merged_inputs, chosen_index)

340 

341 

def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  """Returns the `flow` of a TensorArray; any other value is returned as-is."""
  is_tensor_array = isinstance(tensor_or_tensor_array,
                               tensor_array_ops.TensorArray)
  return (tensor_or_tensor_array.flow
          if is_tensor_array else tensor_or_tensor_array)

347 

348 

def _convert_flow_to_tensorarray(tensor_or_tensor_array, tensor_or_flow):
  """Rebuilds a TensorArray around a new flow value.

  Args:
    tensor_or_tensor_array: The original value; decides whether a TensorArray
      is rebuilt.
    tensor_or_flow: The new flow (for a TensorArray) or the tensor to return.

  Returns:
    `build_ta_with_new_flow(tensor_or_tensor_array, tensor_or_flow)` when the
    original value is a TensorArray, otherwise `tensor_or_flow` unchanged.
  """
  if not isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    return tensor_or_flow
  return tensor_array_ops.build_ta_with_new_flow(tensor_or_tensor_array,
                                                 tensor_or_flow)

355 

356 

def _convert_to_tensor_or_composite_or_tensorarray(var):
  """Converts `var` to a tensor or composite, passing TensorArrays through."""
  if not isinstance(var, tensor_array_ops.TensorArray):
    return ops.convert_to_tensor_or_composite(var)
  return var

361 

362 

363# TODO(xjun): replace this with is_subtype_of after it is landed. 

364def _ShapeLessThanOrEqual(shape1, shape2): 

365 if shape2.dims is None: 

366 return True 

367 if shape1.ndims != shape2.ndims: 

368 return False 

369 for dim1, dim2 in zip(shape1.dims, shape2.dims): 

370 if dim2.value is not None and dim1.value != dim2.value: 

371 return False 

372 return True 

373 

374 

def _shape_invariant_to_type_spec(var, shape=None):
  """Converts a shape invariant to a TypeSpec.

  If `var` is a TensorArray, it will first be converted to its flow.

  Args:
    var: The tensor, tensor array or composite tensor whose shape is described
      by the shape invariant.
    shape: A `TypeSpec` or `TensorShape`.  If `shape` is already a `TypeSpec`,
      then it is simply returned as-is. If None, the spec is derived from
      `var` itself.

  Returns:
    A `TypeSpec` for `var`, consistent with the given shape.

  Raises:
    TypeError: If `shape` is a TypeSpec and not compatible with `var`.
    TypeError: If `shape` is not None, a TypeSpec, or a TensorShape.
    TypeError: If `shape` is a TensorShape, `var` is a CompositeTensor, and
      `var` doesn't implement the `_shape_invariant_to_type_spec` method.
  """
  var = _convert_tensorarray_to_flow(var)

  if shape is None:
    return type_spec.type_spec_from_value(var)

  if isinstance(shape, type_spec.TypeSpec):
    if shape.is_compatible_with(var):
      return shape
    raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))

  if not isinstance(shape, tensor_shape.TensorShape):
    raise TypeError(
        "'shape' must be one of TypeSpec, TensorShape or None. "
        f"Received: {type(shape)}")

  # From here on `shape` is a TensorShape.
  if isinstance(var, ops.Tensor):
    return tensor_spec.TensorSpec(shape, var.dtype)

  # pylint: disable=protected-access
  try:
    return var._shape_invariant_to_type_spec(shape)
  except NotImplementedError as e:
    raise TypeError(
        f"To describe or constrain a {type(var).__name__}, use a "
        f"{type(var._type_spec).__name__} instead of a TensorShape.") from e
  # pylint: enable=protected-access

416 

417 

def _EnforceShapeInvariant(merge_var, next_var):
  """Checks that a loop variable's shape is invariant across one iteration.

  Args:
    merge_var: The tensor representing the initial values of the loop
      variables.
    next_var: The tensor representing the values of the loop variables
      after one loop iteration.

  Raises:
    ValueError: If the shape of `next_var` is not at least as specific as the
      shape of `merge_var` (see `_ShapeLessThanOrEqual`).
    TypeError: If `merge_var` is not a Tensor.
  """
  if not isinstance(merge_var, ops.Tensor):
    raise TypeError("'merge_var' must be a Tensor. "
                    f"Received: {type(merge_var)}.")

  merge_shape = merge_var.get_shape()
  next_shape = next_var.get_shape()
  if _ShapeLessThanOrEqual(next_shape, merge_shape):
    return

  # Report the tensor feeding the loop's Enter op, which is what the user
  # passed in.
  enter_op = merge_var.op.inputs[0].op
  assert util.IsLoopEnter(enter_op)
  loop_input = enter_op.inputs[0]
  raise ValueError(
      "Input tensor '%s' enters the loop with shape %s, but has shape %s "
      "after one iteration. To allow the shape to vary across iterations, "
      "use the `shape_invariants` argument of tf.while_loop to specify a "
      "less-specific shape." % (loop_input.name, loop_input.shape, next_shape))

446 

447 

def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m.

  Args:
    m: The merge output (Tensor or CompositeTensor) whose op receives the back
      edge as input 1.
    v: The value produced by the loop body for the next iteration.
    enforce_shape_invariant: If true and `m` is a Tensor, check the shape
      invariant before wiring the back edge.

  Returns:
    For a Tensor `m`, the NextIteration output. For a CompositeTensor `m`, the
    result of mapping the input update over the components.

  Raises:
    TypeError: If `m` is neither a Tensor nor a CompositeTensor.
  """
  # pylint: disable=protected-access
  if isinstance(m, ops.Tensor):
    next_value = _NextIteration(ops.convert_to_tensor(v))
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, next_value)
    m.op._update_input(1, next_value)
    return next_value

  if isinstance(m, composite_tensor.CompositeTensor):
    if isinstance(m, indexed_slices.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    next_value = _NextIteration(v)

    def wire_back_edge(merge_component, next_component):
      merge_component.op._update_input(1, next_component)

    return nest.map_structure(
        wire_back_edge, m, next_value, expand_composites=True)
  # pylint: enable=protected-access

  raise TypeError("'m' must be a Tensor or CompositeTensor. "
                  f"Received: {type(m)}.")

474 

475 

class ControlFlowContext(metaclass=abc.ABCMeta):
  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
   1. graph has _control_flow_context: the current context used to
      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
   2. op has _control_flow_context: the context to which the op belongs.
      Set at the time the op is created. Immutable.
   3. A ControlFlowContext has _outer_context: the context in which this
      context is created. Set at the time a context is created. Immutable.
   4. A ControlFlowContext has _context_stack.
      Pushed and popped by ctxt.Enter() and ctxt.Exit()
  """

  # pylint: disable=protected-access

  def __init__(self, values_def=None, import_scope=None):
    """Links this context to the graph's current context and sets up values.

    Args:
      values_def: Optional `ValuesDef` protocol buffer to restore the value
        sets from.
      import_scope: Optional `string`. Name scope to add when restoring.
    """
    self._nested_contexts = []
    enclosing = ops.get_default_graph()._get_control_flow_context()
    self._outer_context = enclosing
    if enclosing:
      enclosing._nested_contexts.append(self)
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
      return
    # The names of tensors that have been already seen in this context.
    self._values = set()
    # The keys are the names of tensors referenced by but external to this
    # context. Each value is the Tensor that should be used by this context to
    # access the key value (e.g. a switch output guarding a cond input value).
    self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):
    """Initializes values and external_values from `ValuesDef` protocol buffer.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = {
        ops.prepend_name_scope(value, import_scope)
        for value in values_def.values
    }
    graph = ops.get_default_graph()
    self._external_values = {}
    for ext_name, ext_value in values_def.external_values.items():
      scoped_name = ops.prepend_name_scope(ext_name, import_scope)
      self._external_values[scoped_name] = graph.as_graph_element(
          ops.prepend_name_scope(ext_value, import_scope))
    # Ops that produced this context's own (non-external) values belong to it.
    own_op_names = {
        tensor_name.split(":")[0]
        for tensor_name in self._values - set(self._external_values.keys())
    }
    for op_name in own_op_names:
      graph.as_graph_element(op_name)._set_control_flow_context(self)

  @property
  def name(self):
    """The name of this context (`_name` is set by subclasses)."""
    return self._name

  @property
  def outer_context(self):
    """Return the context containing this context."""
    return self._outer_context

  @property
  def grad_state(self):
    """Abstract; the base implementation always raises NotImplementedError."""
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    """Abstract; the base implementation always raises NotImplementedError."""
    raise NotImplementedError("Abstract method")

  @abc.abstractmethod
  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Serializes this into `context_def`.

    Args:
      context_def: a `ControlFlowContextDef` protocol buffer.
      export_scope: Optional `string`. Name scope to remove.
    """
    raise NotImplementedError("Abstract method")

  def _to_values_def(self, export_scope=None):
    """Converts the values to a `ValuesDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `ValuesDef` protocol buffer.
    """
    proto = control_flow_pb2.ValuesDef()
    # Values are emitted in sorted order.
    proto.values.extend(
        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
    for ext_name, ext_tensor in self._external_values.items():
      stripped_name = ops.strip_name_scope(ext_name, export_scope)
      proto.external_values[stripped_name] = ops.strip_name_scope(
          ext_tensor.name, export_scope)
    return proto

  def AddName(self, name):
    """Records `name` as seen in this context."""
    self._values.add(name)

  def Enter(self):
    """Enter this control flow context."""
    graph = ops.get_default_graph()
    # Remember the context being replaced so Exit() can restore it.
    self._context_stack.append(graph._get_control_flow_context())
    graph._set_control_flow_context(self)

  def Exit(self):
    """Exit this control flow context."""
    graph = ops.get_default_graph()
    previous = self._context_stack.pop()
    graph._set_control_flow_context(previous)

  def EnterGradientColocation(self, op, gradient_uid):
    """Start building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.EnterGradientColocation(op, gradient_uid)

  def ExitGradientColocation(self, op, gradient_uid):
    """Finish building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.ExitGradientColocation(op, gradient_uid)

  def ExitResult(self, result):
    """Make a list of tensors available in the outer context."""
    if not self._outer_context:
      return

    def register(tensor):
      self._outer_context.AddName(tensor.name)
      return tensor

    nest.map_structure(register, result, expand_composites=True)

  def GetWhileContext(self):
    """Return the while context containing this context."""
    if self._outer_context:
      return self._outer_context.GetWhileContext()
    return None

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op.

    Args:
      op: The operation whose control inputs are filtered.

    Returns:
      A pair `(internal_control_inputs, external_control_inputs)`.
    """
    while_ctxt = self.GetWhileContext()
    # A control input of `op` is internal if it is in the same while
    # loop context as the enclosing while loop context of self.
    if while_ctxt is None:
      internal, external = op.control_inputs, []
    else:
      internal, external = [], []
      for control_input in op.control_inputs:
        input_ctxt = util.GetOutputContext(control_input)
        same_loop = (
            input_ctxt is not None and
            input_ctxt.GetWhileContext() == while_ctxt)
        (internal if same_loop else external).append(control_input)
    if len(internal) != len(op.control_inputs):
      # TODO(mdan): perhaps there should be a replace_control_inputs()
      op._remove_all_control_inputs()
      op._add_control_inputs(internal)
    return internal, external

  # pylint: enable=protected-access

  def AddInnerOp(self, op):
    """Notifies a scope about an operator added to an inner scope."""
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def GetControlPivot(self):
    """Returns the pivot node for this context, or None."""
    return None

  def IsWhileContext(self):
    """Returns False; the base context is not a while context."""
    return False

  def IsCondContext(self):
    """Returns False; the base context is not a cond context."""
    return False

  def IsXLAContext(self):
    """Returns False; the base context is not an XLA context."""
    return False

  def __str__(self):
    return self.name

663 

664 

665class CondContext(ControlFlowContext): 

666 """The context for the conditional construct.""" 

667 

668 def __init__(self, 

669 pred=None, 

670 pivot=None, 

671 branch=None, 

672 name="cond_text", 

673 context_def=None, 

674 import_scope=None): 

675 """Creates a `CondContext`. 

676 

677 Args: 

678 pred: The `boolean` tensor for the conditional predicate. 

679 pivot: The predicate tensor in this branch. 

680 branch: 0 or 1 representing this branch. 

681 name: Name of the `CondContext` python object. 

682 context_def: Optional `ContextDef` protocol buffer to initialize the 

683 `CondContext` object from. 

684 import_scope: Optional `string`. Name scope to add. Only used when 

685 initialing from protocol buffer. 

686 """ 

687 self._name = ops.get_default_graph().unique_name(name) 

688 

689 if context_def: 

690 self._init_from_proto(context_def, import_scope=import_scope) 

691 else: 

692 # Initializes the default fields. 

693 ControlFlowContext.__init__(self) 

694 self._pred = pred # The boolean tensor for the cond predicate 

695 self._pivot = pivot # The predicate tensor in this branch 

696 self._branch = branch # 0 or 1 representing this branch 

697 

698 # Values considered to have been already seen in this context. pred is not 

699 # included in this context. 

700 self._values.add(pred.name) 

701 self._external_values[pred.name] = pred 

702 self._values.add(pivot.name) 

703 pivot.op._set_control_flow_context(self) # pylint: disable=protected-access 

704 

705 def _init_from_proto(self, context_def, import_scope=None): 

706 """Creates a new `CondContext` from protocol buffer. 

707 

708 Args: 

709 context_def: `CondContextDef` protocol buffer. 

710 import_scope: Optional `string`. Name scope to add. 

711 """ 

712 assert isinstance(context_def, control_flow_pb2.CondContextDef) 

713 # Create from context_def. 

714 g = ops.get_default_graph() 

715 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 

716 self._pred = g.as_graph_element( 

717 ops.prepend_name_scope(context_def.pred_name, import_scope)) 

718 self._pivot = g.as_graph_element( 

719 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 

720 self._branch = context_def.branch 

721 super(CondContext, self).__init__( 

722 values_def=context_def.values_def, import_scope=import_scope) 

723 

724 @property 

725 def pred(self): 

726 return self._pred 

727 

728 @property 

729 def pivot(self): 

730 return self._pivot 

731 

732 @property 

733 def branch(self): 

734 return self._branch 

735 

736 @property 

737 def grad_state(self): 

738 if self.GetWhileContext(): 

739 return self.GetWhileContext().grad_state 

740 return None 

741 

742 @property 

743 def back_prop(self): 

744 if self.GetWhileContext(): 

745 self.GetWhileContext().back_prop 

746 return False 

747 

748 def GetControlPivot(self): 

749 return self._pivot 

750 

751 def to_proto(self, export_scope=None): 

752 """Converts a `CondContext` to a `CondContextDef` protocol buffer. 

753 

754 Args: 

755 export_scope: Optional `string`. Name scope to remove. 

756 

757 Returns: 

758 A `CondContextDef` protocol buffer. 

759 """ 

760 if (export_scope is None or self.name.startswith(export_scope)): 

761 context_def = control_flow_pb2.CondContextDef() 

762 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 

763 context_def.pred_name = ops.strip_name_scope(self._pred.name, 

764 export_scope) 

765 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 

766 export_scope) 

767 context_def.branch = self._branch 

768 context_def.values_def.MergeFrom( 

769 super(CondContext, self)._to_values_def(export_scope)) 

770 for nested in self._nested_contexts: 

771 nested_def = context_def.nested_contexts.add() 

772 nested.to_control_flow_context_def(nested_def) 

773 

774 return context_def 

775 else: 

776 return None 

777 

778 @staticmethod 

779 def from_proto(context_def, import_scope=None): 

780 """Returns a `CondContext` object created from `context_def`.""" 

781 ret = CondContext(context_def=context_def, import_scope=import_scope) 

782 

783 ret.Enter() 

784 for nested_def in context_def.nested_contexts: 

785 from_control_flow_context_def(nested_def, import_scope=import_scope) 

786 ret.Exit() 

787 return ret 

788 

def to_control_flow_context_def(self, context_def, export_scope=None):
  """Copies this context's proto into `context_def.cond_ctxt`.

  Args:
    context_def: The message whose `cond_ctxt` field receives the copy.
    export_scope: Optional name scope forwarded to `to_proto`.
  """
  cond_proto = self.to_proto(export_scope=export_scope)
  context_def.cond_ctxt.CopyFrom(cond_proto)

791 

def AddValue(self, val):
  """Add `val` to the current context and its outer context recursively.

  Args:
    val: The value to make available in this context; `val.name` is the
      bookkeeping key.

  Returns:
    For an already-known name, the recorded external value (or `val` when
    none is recorded). Otherwise the `self._branch` output of the Switch
    created for the value.
  """
  if val.name in self._values:
    # Use the real value if it comes from outer context. This is needed in
    # particular for nested conds.
    recorded = self._external_values.get(val.name)
    return val if recorded is None else recorded

  switched = val
  self._values.add(val.name)
  if self._outer_context:
    switched = self._outer_context.AddValue(val)
    self._values.add(switched.name)
    self._external_values[switched.name] = switched
  with ops.control_dependencies(None):
    switched = _SwitchRefOrTensor(switched, self._pred)[self._branch]
    if self._outer_context:
      self._outer_context.AddInnerOp(switched.op)

  switched.op.graph.prevent_fetching(switched.op)
  switched.op._set_control_flow_context(self)  # pylint: disable=protected-access

  # Mark Switch output as seen by this context and any outer contexts,
  # just like what we do for normal op outputs in _AddOpInternal().
  enclosing = self
  while enclosing is not None:
    # pylint: disable=protected-access
    enclosing._values.add(switched.name)
    enclosing = enclosing._outer_context
    # pylint: enable=protected-access

  self._external_values[val.name] = switched
  return switched

827 

def AddOp(self, op):
  """Adds `op` to this context by delegating to `_AddOpInternal`."""
  self._AddOpInternal(op)

830 

def _AddOpInternal(self, op):
  """Add `op` to the current context.

  Rewires the op's inputs and control inputs so that it runs under this
  conditional branch, records its outputs as known values in this context and
  all outer contexts, and notifies the outer context.

  Args:
    op: The operation being added.
  """
  # pylint: disable=protected-access
  if op.inputs:
    # Make each input to 'op' available in this CondContext. If an input is
    # already part of this context there's nothing to do, but if it's
    # external, AddValue() will handle adding the appropriate Switch node and
    # other bookkeeping.
    for index in range(len(op.inputs)):
      x = op.inputs[index]
      # Edge case: if we're importing a while loop inside this CondContext,
      # AddValue() will not correctly handle the NextIteration inputs to
      # Merge node. The problem is that the NextIteration should also be
      # part of this context, but if we're importing it won't have been
      # processed and added to the context yet, so AddValue() will try to
      # add a Switch which results in an invalid graph. Instead, we use the
      # NextIteration input as-is here, and it will eventually be added to
      # the context via AddOp().
      is_loop_back_edge = op.type == "Merge" and x.op.type == "NextIteration"
      real_x = x if is_loop_back_edge else self.AddValue(x)
      if real_x != x:
        op._update_input(index, real_x)
    # Remove any external control dependency on this op.
    self._RemoveExternalControlEdges(op)
    if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
      op._add_control_input(self._pivot.op)
  else:
    # If we're in a while loop, remove any control inputs from outside the
    # loop.
    self._RemoveExternalControlEdges(op)
    controlled_from_inside = any(
        util.OpInContext(input_op, self) for input_op in op.control_inputs)
    if not controlled_from_inside:
      op._add_control_input(self._pivot.op)

  # Mark op's outputs as seen by this context and any outer contexts.
  produced_names = [out.name for out in op.outputs]
  enclosing = self
  while enclosing is not None:
    enclosing._values.update(produced_names)
    enclosing = enclosing._outer_context
  # pylint: enable=protected-access

  if self._outer_context or not util.IsLoopExit(op):
    op.graph.prevent_fetching(op)

  if self._outer_context:
    self._outer_context.AddInnerOp(op)

887 

def _ProcessOutputTensor(self, val):
  """Process an output tensor of a conditional branch.

  Args:
    val: The tensor returned by the branch function.

  Returns:
    For a name already known to this context, the recorded external value
    (or `val` when none is recorded). Otherwise the `self._branch` output of
    a Switch created for the value.
  """
  if val.name in self._values:
    recorded = self._external_values.get(val.name)
    return val if recorded is None else recorded

  # Handle the special case of lambda: x
  switched = val
  self._values.add(val.name)
  if self._outer_context:
    switched = self._outer_context.AddValue(val)
    self._values.add(switched.name)
    self._external_values[switched.name] = switched
  switched = _SwitchRefOrTensor(switched, self._pred)[self._branch]
  self._external_values[val.name] = switched
  return switched

905 

def _BuildCondTensor(self, v):
  """Converts one branch result into a tensor of this context.

  Args:
    v: An `ops.Operation` or a value convertible to a tensor.

  Returns:
    For an operation, the pivot with a dependency on that operation. For
    anything else, the result of `_ProcessOutputTensor` on the converted
    tensor.
  """
  if not isinstance(v, ops.Operation):
    flow_value = nest.map_structure(
        _convert_tensorarray_to_flow, v, expand_composites=True)
    as_tensor = ops.convert_to_tensor(flow_value)
    return self._ProcessOutputTensor(as_tensor)
  # Use pivot as the proxy for this op.
  return with_dependencies([v], self._pivot)

914 

def BuildCondBranch(self, fn):
  """Add the subgraph defined by fn() to the graph.

  Summaries added to the summary collection while `fn` runs are removed from
  the collection again and turned into control dependencies of the result.

  Args:
    fn: Zero-argument callable that builds the branch.

  Returns:
    A pair `(original_result, result)`: the (possibly converted) value
    returned by `fn`, and the processed outputs wrapped in a list unless they
    already are a list or tuple. `(no_op(), None)` or `(None, None)` when `fn`
    returns None.
  """
  # pylint: disable=protected-access
  summary_key = ops.GraphKeys._SUMMARY_COLLECTION
  # pylint: enable=protected-access
  summaries_before = ops.get_collection(summary_key)
  original_result = fn()
  summaries_after = ops.get_collection(summary_key)

  count_before = len(summaries_before)
  if len(summaries_after) > count_before:
    added_summaries = summaries_after[count_before:]
    collection_ref = ops.get_collection_ref(summary_key)
    collection_ref[:] = summaries_before
    with ops.control_dependencies(added_summaries):
      if original_result is None:
        return no_op(), None
      if not isinstance(original_result, ops.Operation):
        original_result = variable_utils.convert_variables_to_tensors(
            original_result)
        original_result = nest.map_structure(
            array_ops.identity, original_result, expand_composites=True)

  if original_result is None:
    return None, None

  original_result = variable_utils.convert_variables_to_tensors(
      original_result)
  processed = nest.map_structure(
      self._BuildCondTensor, original_result, expand_composites=True)
  if isinstance(processed, (list, _basetuple)):
    return original_result, processed
  return original_result, [processed]

942 

def IsCondContext(self):
  """Always returns True for this context type."""
  is_cond = True
  return is_cond

945 

946 

947# pylint: enable=g-doc-args 

948# pylint: enable=redefined-outer-name 

949 

950 

def _resource_safe_shape(t):
  """Returns the shape of t or the variable it points to.

  Args:
    t: A tensor. For a resource-dtype tensor, the first input of each
      producing op is followed until an op without inputs is reached.

  Returns:
    For a resource tensor, a `TensorShape` built from the "shape" attr of the
    op found at the end of that chain. Otherwise the result of
    `array_ops.shape_internal(t, optimize=False)`.
  """
  if t.dtype != dtypes.resource:
    return array_ops.shape_internal(t, optimize=False)

  source = t
  while source.op.inputs:
    source = source.op.inputs[0]
  return tensor_shape.TensorShape(source.op.get_attr("shape"))

958 

959 

960# TODO(yuanbyu): Consider having a unified notion of context for 

961# not only conditionals and loops but also control dependency and 

962# subgraphs. 

963class WhileContext(ControlFlowContext): 

964 """The context for the loop construct.""" 

965 

def __init__(self,
             maximum_iterations=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name="while_context",
             grad_state=None,
             context_def=None,
             import_scope=None):
  """Creates a `WhileContext`.

  Args:
    maximum_iterations: Optional upper bound on number of loop iterations.
    parallel_iterations: The number of iterations allowed to run in parallel.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    grad_state: The gradient loop state.
    context_def: Optional `WhileContextDef` protocol buffer to initialize the
      `WhileContext` python object from.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
  """
  if not context_def:
    ControlFlowContext.__init__(self)
    self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
                         swap_memory, name)
  else:
    self._init_from_proto(context_def, import_scope=import_scope)
  # The gradient loop state is stored on both construction paths.
  self._grad_state = grad_state

997 

def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
                    swap_memory, name):
  """Initializes a new `WhileContext` from arguments.

  Args:
    maximum_iterations: Optional upper bound on number of loop iterations.
    parallel_iterations: The number of iterations allowed to run in parallel.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.

  Raises:
    ValueError: If `parallel_iterations` is not a positive `int`.
  """
  is_positive_int = (
      isinstance(parallel_iterations, int) and parallel_iterations > 0)
  if not is_positive_int:
    raise ValueError(
        "'parallel_iterations' must be a positive integer: %s" %
        parallel_iterations)

  self._name = ops.get_default_graph().unique_name(name)
  self._maximum_iterations = maximum_iterations
  self._parallel_iterations = parallel_iterations
  self._back_prop = back_prop
  self._swap_memory = swap_memory

  # Nodes used to control constants created by the pred and body lambdas.
  self._pivot_for_pred = None
  self._pivot_for_body = None
  # The boolean tensor for loop termination condition. Used in code
  # generation for gradient computation.
  self._pivot = None
  # Exit and enter tensors for loop variables.
  self._loop_exits = []
  self._loop_enters = []
  self._graph = ops.get_default_graph()

1032 

def _init_from_proto(self, context_def, import_scope=None):
  """Initializes a new `WhileContext` from a protocol buffer.

  Args:
    context_def: `WhileContextDef` protocol buffer.
    import_scope: Optional `string`. Name scope to add.
  """
  assert isinstance(context_def, control_flow_pb2.WhileContextDef)
  graph = ops.get_default_graph()

  def _element(serialized_name):
    # Resolves a serialized name to its graph element under `import_scope`.
    scoped_name = ops.prepend_name_scope(serialized_name, import_scope)
    return graph.as_graph_element(scoped_name)

  self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
  if not context_def.maximum_iterations_name:
    self._maximum_iterations = None
  else:
    self._maximum_iterations = _element(context_def.maximum_iterations_name)
  self._parallel_iterations = context_def.parallel_iterations
  self._back_prop = context_def.back_prop
  self._swap_memory = context_def.swap_memory
  # Nodes used to control constants created by the pred and body lambdas.
  self._pivot_for_pred = _element(context_def.pivot_for_pred_name)
  self._pivot_for_body = _element(context_def.pivot_for_body_name)
  # The boolean tensor for loop termination condition. Used in code
  # generation for gradient computation.
  self._pivot = _element(context_def.pivot_name)
  # Exit and enter tensors for loop variables.
  self._loop_exits = [
      _element(exit_name) for exit_name in context_def.loop_exit_names
  ]
  self._loop_enters = [
      _element(enter_name) for enter_name in context_def.loop_enter_names
  ]
  super(WhileContext, self).__init__(
      values_def=context_def.values_def, import_scope=import_scope)

  # import_scope causes self.name to be different from the original serialized
  # context's name. Rewrite "frame_name" attrs with the new name.
  if import_scope:
    for tensor_name in self._values:
      producer = graph.as_graph_element(tensor_name).op
      if not util.IsLoopEnter(producer):
        continue
      # pylint: disable=protected-access
      producer._set_attr(
          "frame_name",
          attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
      # pylint: enable=protected-access
  self._graph = ops.get_default_graph()

1086 

@property
def maximum_iterations(self):
  """Upper bound on the number of loop iterations, as stored on the context."""
  limit = self._maximum_iterations
  return limit

1091 

@property
def parallel_iterations(self):
  """Number of loop iterations allowed to run in parallel."""
  degree = self._parallel_iterations
  return degree

1096 

@property
def back_prop(self):
  """Whether backprop is enabled for this while loop."""
  enabled = self._back_prop
  return enabled

1101 

@property
def swap_memory(self):
  """Whether GPU-CPU memory swap is enabled for this while loop."""
  enabled = self._swap_memory
  return enabled

1106 

@property
def pivot(self):
  """Boolean tensor for the loop termination condition (`self._pivot`)."""
  termination_cond = self._pivot
  return termination_cond

1111 

@property
def loop_enters(self):
  """Enter tensors for loop variables; the stored list itself, not a copy."""
  enters = self._loop_enters
  return enters

1116 

@property
def loop_exits(self):
  """Exit tensors for loop variables; the stored list itself, not a copy."""
  exits = self._loop_exits
  return exits

1121 

@property
def grad_state(self):
  """Gradient loop state stored on this context (`self._grad_state`)."""
  state = self._grad_state
  return state

1126 

def to_proto(self, export_scope=None):
  """Serializes this `WhileContext` into a `WhileContextDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `WhileContextDef` protocol buffer, or None when `export_scope` is given
    and this context's name does not start with it.
  """
  # Guard clause: contexts outside the requested scope are not exported.
  if export_scope is not None and not self.name.startswith(export_scope):
    return None

  while_def = control_flow_pb2.WhileContextDef()
  while_def.context_name = ops.strip_name_scope(self.name, export_scope)
  while_def.parallel_iterations = self._parallel_iterations
  if self._maximum_iterations is not None:
    while_def.maximum_iterations_name = ops.strip_name_scope(
        self._maximum_iterations.name, export_scope)
  while_def.back_prop = self._back_prop
  while_def.swap_memory = self._swap_memory
  while_def.pivot_for_pred_name = ops.strip_name_scope(
      self._pivot_for_pred.name, export_scope)
  while_def.pivot_for_body_name = ops.strip_name_scope(
      self._pivot_for_body.name, export_scope)
  while_def.pivot_name = ops.strip_name_scope(self._pivot.name, export_scope)

  exit_names = [
      ops.strip_name_scope(tensor.name, export_scope)
      for tensor in self._loop_exits
  ]
  while_def.loop_exit_names.extend(exit_names)
  enter_names = [
      ops.strip_name_scope(tensor.name, export_scope)
      for tensor in self._loop_enters
  ]
  while_def.loop_enter_names.extend(enter_names)

  values_def = super(WhileContext, self)._to_values_def(
      export_scope=export_scope)
  while_def.values_def.MergeFrom(values_def)

  # Each nested context serializes itself into a freshly added entry.
  for inner_ctxt in self._nested_contexts:
    inner_def = while_def.nested_contexts.add()
    inner_ctxt.to_control_flow_context_def(inner_def)

  return while_def

1166 

def to_control_flow_context_def(self, context_def, export_scope=None):
  """Copies this context's proto into `context_def.while_ctxt`.

  Args:
    context_def: The message whose `while_ctxt` field receives the copy.
    export_scope: Optional name scope forwarded to `to_proto`.
  """
  while_proto = self.to_proto(export_scope=export_scope)
  context_def.while_ctxt.CopyFrom(while_proto)

1169 

@staticmethod
def from_proto(context_def, import_scope=None):
  """Builds a `WhileContext` from `context_def`.

  The new context is entered while every entry of
  `context_def.nested_contexts` is imported, then exited.

  Args:
    context_def: A `WhileContextDef` protocol buffer.
    import_scope: Optional `string`. Name scope to add.

  Returns:
    A `WhileContext` Python object.
  """
  while_ctxt = WhileContext(context_def=context_def, import_scope=import_scope)
  while_ctxt.Enter()
  for inner_def in context_def.nested_contexts:
    from_control_flow_context_def(inner_def, import_scope=import_scope)
  while_ctxt.Exit()
  return while_ctxt

1187 

def GetWhileContext(self):
  """Returns this object itself, which is the while context."""
  while_ctxt = self
  return while_ctxt

1190 

def GetControlPivot(self):
  """Returns the body pivot when it is set, otherwise the pred pivot."""
  if self._pivot_for_body is None:
    return self._pivot_for_pred
  return self._pivot_for_body

1195 

def AddValue(self, val):
  """Add `val` to the current context and its outer context recursively.

  Args:
    val: The value to make available in this loop context; `val.name` is the
      bookkeeping key.

  Returns:
    For a value that is already known or produced inside this context, the
    recorded external value (or `val` when none is recorded). For a value from
    the forward context of the current gradient context, the result of
    `GetRealValue`. Otherwise a constant Enter created for the value.
  """
  name_is_unseen = val.name not in self._values
  # Don't treat ops in this context as new values. Usually all known values
  # are in self._values, except when we're importing a while loop inside this
  # WhileContext. Since there's a cycle in this case, `val` may be part of the
  # imported while loop but not yet processed by this context and added to
  # self._values in _AddOpInternal. We only want to process external input
  # tensors to the while loop here.
  made_elsewhere = val.op._control_flow_context is not self  # pylint: disable=protected-access

  if not (name_is_unseen and made_elsewhere):
    recorded = self._external_values.get(val.name)
    return val if recorded is None else recorded

  self._values.add(val.name)

  # If we are in a grad context and val is from its forward context,
  # use GetRealValue(), which adds the logic to save the history of
  # val in forward.
  grad_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
  if grad_ctxt:
    grad_ctxt = grad_ctxt.GetWhileContext()
    if grad_ctxt.grad_state:
      forward_ctxt = util.GetWhileContext(val.op)
      if util.IsLoopExit(val.op):
        forward_ctxt = forward_ctxt.outer_context
        if forward_ctxt:
          forward_ctxt = forward_ctxt.GetWhileContext()
      if forward_ctxt == grad_ctxt.grad_state.forward_context:
        real_val = grad_ctxt.grad_state.GetRealValue(val)
        self._external_values[val.name] = real_val
        return real_val

  outer_value = val
  if self._outer_context is not None:
    outer_value = self._outer_context.AddValue(val)
  # Create an Enter to make the value known to this loop context.
  with ops.control_dependencies(None):
    enter = _Enter(
        outer_value,
        self._name,
        is_constant=True,
        parallel_iterations=self._parallel_iterations)
    enter.graph.prevent_feeding(enter)
    if self._outer_context:
      self._outer_context.AddInnerOp(enter.op)
  # Fix the control inputs and control flow context of these enter ops.
  self._FixControlInputsAndContext([enter])

  # Add `enter` in this context.
  self._values.add(enter.name)
  self._external_values[val.name] = enter
  return enter

1251 

def AddOp(self, op):
  """Add `op` to the current context.

  "Shape", "Size" and "Rank" ops outside an XLA context are added to the
  control flow context of their first input instead, when that input belongs
  to the forward context of the current gradient context.

  Args:
    op: The operation being added.
  """
  # For a reduction op, if op is in a grad context and its input is from
  # its forward context, moving op to the forward context means we would
  # store the tensor after the reduction as opposed to the tensor before
  # reduction, and therefore could significantly reduce memory consumption.
  # For now, we do this only for a few ops.
  #
  # If in XLA context, do not move constant ops to forward pass as pushing to
  # and popping from a stack removes the constant property of an op and breaks
  # XLA compilation, which requires certain inputs to be constant for certain
  # ops.
  # pylint: disable=protected-access
  movable = (not util.IsInXLAContext(op) and
             op.type in {"Shape", "Size", "Rank"})
  if movable:
    grad_ctxt = ops.get_default_graph()._get_control_flow_context()
    if grad_ctxt:
      grad_ctxt = grad_ctxt.GetWhileContext()
      if grad_ctxt.grad_state:
        input_while_ctxt = util.GetWhileContext(op.inputs[0].op)
        if input_while_ctxt == grad_ctxt.grad_state.forward_context:
          target_ctxt = op.inputs[0].op._get_control_flow_context()
          op._set_control_flow_context(target_ctxt)
          target_ctxt._AddOpInternal(op)
          return
  # pylint: enable=protected-access
  self._AddOpInternal(op)

1276 

def _AddOpInternal(self, op):
  """Add `op` to the current context.

  We move any external control dependencies of the op to the loop pivot, to
  ensure they get executed.

  Args:
    op: The operation being added.
  """
  # pylint: disable=protected-access
  # This is needed to prevent frame mismatch errors where there are Const
  # nodes inside tf.function in v1 while_loop and inlining is turned on.
  if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
    op._add_control_input(self.GetControlPivot().op)

  if op.inputs:
    for index in range(len(op.inputs)):
      x = op.inputs[index]
      real_x = self.AddValue(x)
      if real_x != x:
        op._update_input(index, real_x)
    # Remove any external control dependency on this op.
    _, external_inputs = self._RemoveExternalControlEdges(op)
    # Add a control dependency to prevent loop invariants from
    # enabling ops that should not be executed.
    self._MaybeAddControlDependency(op)
  else:
    # Remove any external control dependency on this op.
    control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
    # Add a control edge from the control pivot to this op.
    if not control_inputs:
      op._add_control_input(self.GetControlPivot().op)

  # Both paths end by recording the op's outputs as known values.
  for out in op.outputs:
    self._values.add(out.name)

  if external_inputs:
    # Use an identity to pull control inputs as data inputs. Note that we
    # ignore ops which don't have outputs. TODO(apassos): fix that
    with ops.control_dependencies(None):
      self.Enter()
      pulled_in = []
      for external_op in external_inputs:
        if external_op.outputs:
          pulled_in.append(array_ops.identity(external_op.outputs[0]).op)
      external_inputs = pulled_in
      self.Exit()
    op._add_control_inputs(external_inputs)
  # pylint: enable=protected-access

  if self._outer_context or not util.IsLoopExit(op):
    op.graph.prevent_fetching(op)
    for out in op.outputs:
      op.graph.prevent_feeding(out)

  if self._outer_context:
    self._outer_context.AddInnerOp(op)

1329 

def _MaybeAddControlDependency(self, op):
  """Add a control input to the op if it only depends on loop invariants.

  Nothing is added when the op already has control inputs. Otherwise the
  control pivot is added when the op is a function call or
  "SymbolicGradient", or when every input comes from a loop-constant Enter.

  Args:
    op: The operation to inspect and possibly modify.
  """
  # pylint: disable=protected-access
  if op.control_inputs:
    return
  is_function_like = (
      op.graph._is_function(op.type) or op.type == "SymbolicGradient")
  if not is_function_like:
    only_loop_constants = all(
        util.IsLoopConstantEnter(x.op) for x in op.inputs)
    if not only_loop_constants:
      return
  op._add_control_input(self.GetControlPivot().op)
  # pylint: enable=protected-access

1350 

def AddForwardLoopCounter(self, outer_grad_state):
  """Adds a loop that counts the number of iterations.

  This is added to the forward loop at the time when we start to
  create the loop for backprop gradient computation. Called in
  the outer context of this forward context.

  The pseudocode is:
    `n = 0; while (_pivot) { n++; }`

  Note that a control dependency is added to `n` to ensure the correct
  execution order of stack push ops.

  Args:
    outer_grad_state: The outer grad state. None if not nested.

  Returns:
    The number of iterations taken by the forward loop and the loop index.
  """
  counter_init = constant_op.constant(0, name="f_count")
  if outer_grad_state is not None:
    # Force the stack pushes of i-th execution of an inner loop to be ordered
    # before the pushes of (i+1)-th execution of the same inner loop.
    outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
    counter_init.op._add_control_input(outer_add_op)  # pylint: disable=protected-access

  self.Enter()
  self.AddName(counter_init.name)
  counter_enter = _Enter(
      counter_init,
      self._name,
      is_constant=False,
      parallel_iterations=self._parallel_iterations,
      name="f_count")
  self.loop_enters.append(counter_enter)

  # The second Merge input is a placeholder until the back edge exists.
  counter_merge = merge([counter_enter, counter_enter])[0]
  counter_switch = switch(counter_merge, self._pivot)

  incremented = math_ops.add(counter_switch[1], 1)
  next_n = _NextIteration(incremented)
  counter_merge.op._update_input(1, next_n)  # pylint: disable=protected-access

  total_iterations = exit(counter_switch[0], name="f_count")
  self.loop_exits.append(total_iterations)
  self.ExitResult([total_iterations])
  self.Exit()
  return total_iterations, next_n

1399 

def AddBackpropLoopCounter(self, count, outer_grad_state):
  """Add the backprop loop that controls the iterations.

  This is added to the backprop loop. It is used to control the loop
  termination of the backprop loop. Called in the outer context of
  this grad context.

  The pseudocode is:
    `n = count; while (n >= 1) { n--; }`

  Note that a control dependency is added to `final_zero` to ensure the
  correct execution order of stack pop ops.

  Args:
    count: The number of iterations for backprop.
    outer_grad_state: The outer grad state. None if not nested.

  Returns:
    The loop index.
  """
  count_from_other_graph = count.graph is not ops.get_default_graph()
  if not count_from_other_graph:
    # TODO(apassos) XLA expects this constant to be created outside the loop,
    # so doing that for now.
    one = constant_op.constant(1, name="b_count")
  else:
    # Brings the count into this graph
    count = array_ops.identity(count)

  self.Enter()
  self.AddName(count.name)
  count_enter = _Enter(
      count,
      self._name,
      is_constant=False,
      parallel_iterations=self._parallel_iterations,
      name="b_count")
  self.loop_enters.append(count_enter)

  # The second Merge input is a placeholder until the back edge exists.
  count_merge = merge([count_enter, count_enter])[0]
  self._pivot_for_pred = count_merge

  if count_from_other_graph:
    one = constant_op.constant(1, name="b_count")
  keep_going = math_ops.greater_equal(count_merge, one)
  self._pivot = loop_cond(keep_going, name="b_count")
  count_switch = switch(count_merge, self._pivot)

  decremented = math_ops.subtract(count_switch[1], one)
  self._pivot_for_body = decremented
  next_count = _NextIteration(decremented)
  count_merge.op._update_input(1, next_count)  # pylint: disable=protected-access

  final_zero = exit(count_switch[0], name="b_count")
  self.loop_exits.append(final_zero)
  if outer_grad_state is not None:
    # Force the stack pops of i-th execution of an inner loop to be ordered
    # before the pops of (i+1)-th execution of the same inner loop.
    # pylint: disable=protected-access
    outer_grad_state.grad_sync._add_control_input(final_zero.op)
    # pylint: enable=protected-access

  self.ExitResult([final_zero])
  self.Exit()
  return next_count

1465 

def AddBackpropAccumulator(self, op, grad):
  """Add an accumulation loop for every loop invariant.

  This is added to the backprop loop. It is used to accumulate partial
  gradients within each loop iteration. Called when in the gradient while
  context.

  The pseudocode is:
    ```
    acc = 0.0;
    while (_pivot) {
      acc += grad;
    }
    ```

  Args:
    op: The Enter op for a loop invariant.
    grad: The partial gradient of an iteration for a loop invariant.

  Returns:
    The gradient for a loop invariant.
  """
  self.Exit()
  # Create a zeros tensor with the right shape for acc. If we don't
  # know the full shape statically, we will have to get the shape
  # dynamically from the forward inference. Getting the shape right
  # for the zeros is only needed for the base case when the loop exits
  # without running any iterations.
  static_shape = grad.get_shape()
  if static_shape.is_fully_defined():
    if self.outer_context:
      self.outer_context.Enter()
    acc = constant_op.constant(0, grad.dtype, shape=static_shape, name="b_acc")
    if self.outer_context:
      self.outer_context.Exit()
  else:
    invariant_value = op.inputs[0]
    in_nested_grad_loop = (
        isinstance(self.outer_context, WhileContext) and
        self.outer_context.grad_state is not None)
    if not in_nested_grad_loop:
      if self.outer_context:
        self.outer_context.Enter()
      zeros_shape = array_ops.shape_internal(invariant_value, optimize=False)
      acc = array_ops.zeros(zeros_shape, grad.dtype)
      if self.outer_context:
        self.outer_context.Exit()
    else:
      # We are in a nested while loop.
      forward_ctxt = self.grad_state.forward_context
      forward_ctxt.outer_context.Enter()
      zeros_shape = array_ops.shape_internal(invariant_value, optimize=False)
      forward_ctxt.outer_context.Exit()
      outer_grad_state = self.grad_state.outer_grad_state
      history_zeros_shape = outer_grad_state.AddForwardAccumulator(
          zeros_shape)
      self.outer_context.Enter()
      real_shape = outer_grad_state.AddBackpropAccumulatedValue(
          history_zeros_shape, zeros_shape)
      acc = array_ops.zeros(real_shape, grad.dtype)
      self.outer_context.Exit()

  self.Enter()
  self.AddName(acc.name)
  acc_enter = _Enter(
      acc,
      self._name,
      is_constant=False,
      parallel_iterations=self._parallel_iterations,
      name="b_acc")
  self.loop_enters.append(acc_enter)

  # The second Merge input is a placeholder until the back edge exists.
  acc_merge = merge([acc_enter, acc_enter], name="b_acc")[0]
  acc_on_exit, acc_in_body = switch(acc_merge, self._pivot)

  acc_plus_grad = math_ops.add(acc_in_body, grad)
  acc_next = _NextIteration(acc_plus_grad)
  acc_merge.op._update_input(1, acc_next)  # pylint: disable=protected-access

  result_acc = exit(acc_on_exit, name="b_acc")
  self.loop_exits.append(result_acc)
  self.ExitResult([result_acc])
  return result_acc

1547 

def AddBackpropIndexedSlicesAccumulator(self, op, grad):
  """This is used for accumulating gradients that are IndexedSlices.

  This is essentially the equivalent of AddBackpropAccumulator but optimized
  for things like updating embeddings from within a while loop.

  Instead of summing dense tensors, each iteration concatenates the new
  indices/values onto the running accumulators along axis 0, and (when a
  dense shape is present) keeps the element-wise maximum of the dense shapes.
  The accumulators are seeded with a single zero-valued slice at index 0, so
  the returned IndexedSlices starts with that seed entry.

  Args:
    op: The Enter op for a loop invariant.
    grad: The partial gradients represented as an IndexedSlices.

  Returns:
    The accumulated IndexedSlices gradient of the loop invariant.
  """
  values = grad.values
  indices = grad.indices
  dense_shape = grad.dense_shape

  # Step out of this context (and into the outer one, if any) while the
  # initial accumulator tensors are created.
  self.Exit()
  if self.outer_context:
    self.outer_context.Enter()
  if values.get_shape().is_fully_defined():
    # Static case: the seed has one row and the same trailing dimensions as
    # `values`.
    values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
                                            values.get_shape().dims[1:])
    if self.outer_context:
      self.outer_context.Enter()
    values_acc = constant_op.constant(
        0, values.dtype, shape=values_shape, name="b_acc")
    if self.outer_context:
      self.outer_context.Exit()
  else:
    # Dynamic case: take the trailing dimensions from the loop-invariant
    # input of the Enter op and prepend a leading dimension of 1.
    values_shape = _resource_safe_shape(op.inputs[0])[1:]
    values_shape = array_ops.concat([[1], values_shape], 0)
    values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
  # The seed slice lives at index 0.
  indices_acc = constant_op.constant([0], indices.dtype)
  shape_acc = None
  if dense_shape is not None:
    if dense_shape.get_shape().is_fully_defined():
      if self.outer_context:
        self.outer_context.Enter()
      shape_acc = constant_op.constant(
          0, dense_shape.dtype, shape=dense_shape.get_shape())
      if self.outer_context:
        self.outer_context.Exit()
    else:
      # Zeros with the same length as the shape vector of the loop-invariant
      # input.
      shape_acc = array_ops.zeros_like(
          array_ops.shape_internal(
              op.inputs[0], optimize=False, out_type=dense_shape.dtype),
          optimize=False)

  if self.outer_context:
    self.outer_context.Exit()

  # Back inside this context: register the seeds by name and collect them in
  # the order [indices, values, (dense_shape)].
  self.Enter()
  self.AddName(values_acc.name)
  self.AddName(indices_acc.name)
  init_acc = [indices_acc, values_acc]
  if shape_acc is not None:
    self.AddName(shape_acc.name)
    init_acc.append(shape_acc)

  # Set use_input_shape=False since the accumulator tensors will grow in
  # size. If use_input_shape=True, the _update_input call below will result in
  # incompatible shapes.
  enter_acc = [
      _Enter(
          x,
          self._name,
          is_constant=False,
          parallel_iterations=self._parallel_iterations,
          use_input_shape=False,
          name="b_acc") for x in init_acc
  ]
  # Manually set appropriate partial shapes.
  enter_acc[0].set_shape([None])
  if values_acc.shape.dims is not None:
    enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
  self.loop_enters.extend(enter_acc)

  # Each merge initially takes the Enter output for both inputs; input 1 is
  # rewired to the NextIteration output further down.
  merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
  # Each switch result is indexed below: [1] feeds the accumulation and [0]
  # feeds the exit.
  switch_acc = [switch(x, self._pivot) for x in merge_acc]

  # The actual accumulation.
  acc_indexed_slices = [
      array_ops.concat([xa[1], xv], 0)
      for xa, xv in zip(switch_acc[:2], [indices, values])
  ]
  if shape_acc is not None:
    # For the shape we just keep the maximum
    acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))

  # Close the loop: replace the placeholder second merge input with the back
  # edge.
  next_acc = [_NextIteration(x) for x in acc_indexed_slices]
  for xm, xn in zip(merge_acc, next_acc):
    xm.op._update_input(1, xn)  # pylint: disable=protected-access

  exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
  self.loop_exits.extend(exit_acc)

  self.ExitResult(exit_acc)
  return indexed_slices.IndexedSlices(
      indices=exit_acc[0],
      values=exit_acc[1],
      dense_shape=exit_acc[2] if shape_acc is not None else None)

1650 

def _InitializeValues(self, values):
  """Resets the set of known value names to the names of `values`.

  Args:
    values: An iterable of Tensors whose names become known to this context.

  Raises:
    TypeError: If any element of `values` is not a Tensor. Names of the
      elements seen before the offending one have already been recorded.
  """
  self._values = set()
  for candidate in values:
    if not isinstance(candidate, ops.Tensor):
      raise TypeError("'values' must be a list of Tensors. "
                      f"Received: {type(candidate)}.")
    self._values.add(candidate.name)

1660 

def _BuildLoop(self, pred, body, flat_orig_loop_vars, flat_loop_vars,
               loop_vars_signature):
  """Core: Add the loop termination condition and body to the graph.

  Args:
    pred: Callable invoked with the packed loop variables; its result is
      converted to a tensor and used as the loop condition.
    body: Callable invoked with the packed loop variables; returns the next
      values of the loop variables.
    flat_orig_loop_vars: Flat list of the original loop variables, used to
      turn flow tensors back into their TensorArrays.
    flat_loop_vars: Flat list of the loop variable tensors that enter the
      loop.
    loop_vars_signature: Nested structure of specs (each exposing `.shape`)
      describing the loop variables; also used as the packing structure.

  Returns:
    A tuple `(original_body_result, exit_vars)`: the structure returned by
    `body` (wrapped in a list if it was not nested) and the flat list of loop
    exit tensors.

  Raises:
    ValueError: If a shape invariant is incompatible with the initial shape
      of its loop variable, or if the number of flattened outputs of `body`
      differs from the number of loop variables.
  """
  flat_shape_invariants = nest.map_structure(
      lambda spec: spec.shape,
      nest.flatten(loop_vars_signature, expand_composites=True))

  # Let the context know the loop variables so the loop variables
  # would be added in the outer contexts properly.
  self._InitializeValues(flat_loop_vars)
  if self._outer_context:
    real_vars = [self._outer_context.AddValue(x) for x in flat_loop_vars]
  else:
    real_vars = flat_loop_vars

  enter_vars = []
  with ops.control_dependencies(None):
    for real_var, shape_invariant in zip(real_vars, flat_shape_invariants):
      enter_var = _Enter(
          real_var,
          self._name,
          is_constant=False,
          parallel_iterations=self._parallel_iterations,
          use_input_shape=False)

      if _ShapeLessThanOrEqual(real_var.get_shape(), shape_invariant):
        enter_var.set_shape(shape_invariant)
      else:
        raise ValueError(
            f"The shape invariant specified for {real_var.name} is not "
            "compatible with the initial shape of the loop variable. It "
            f"enters the loop with shape {real_var.get_shape()}, but the "
            f"specified shape invariant is {shape_invariant}.")

      enter_var.graph.prevent_feeding(enter_var)
      if self._outer_context:
        self._outer_context.AddInnerOp(enter_var.op)
      enter_vars.append(enter_var)

  # Finds the closest enclosing non-None control pivot.
  outer_context = self._outer_context
  control_pivot = None
  while outer_context is not None and control_pivot is None:
    control_pivot = outer_context.GetControlPivot()
    # pylint: disable=protected-access
    outer_context = outer_context._outer_context
    # pylint: enable=protected-access

  if control_pivot is not None:
    for var in enter_vars:
      if util.IsLoopConstantEnter(var.op.inputs[0].op):
        # pylint: disable=protected-access
        var.op._add_control_input(control_pivot.op)
        # pylint: enable=protected-access

  # Fix the control inputs and control flow context of these enter ops.
  self._FixControlInputsAndContext(enter_vars)
  self._InitializeValues(enter_vars)
  self._loop_enters = enter_vars

  # Both merge inputs start as the Enter output; the second one is replaced
  # by the back edge once the body has been built.
  merge_vars = [merge([x, x])[0] for x in enter_vars]
  self._pivot_for_pred = merge_vars[0]

  merge_vars_with_tensorarrays = nest.map_structure(
      _convert_flow_to_tensorarray, flat_orig_loop_vars, merge_vars)
  # Build the graph for pred.
  packed_vars = nest.pack_sequence_as(
      structure=loop_vars_signature,
      flat_sequence=merge_vars_with_tensorarrays,
      expand_composites=True)
  c = ops.convert_to_tensor(pred(*packed_vars))
  self._pivot = loop_cond(c, name="LoopCond")
  switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]

  # Build the graph for body.
  vars_for_body = [_Identity(x[1]) for x in switch_vars]
  self._pivot_for_body = vars_for_body[0]
  # Convert TensorArray flow variables inside the context back into
  # their associated TensorArrays for calling the body.
  vars_for_body_with_tensorarrays = nest.map_structure(
      _convert_flow_to_tensorarray, flat_orig_loop_vars, vars_for_body)
  packed_vars_for_body = nest.pack_sequence_as(
      structure=loop_vars_signature,
      flat_sequence=vars_for_body_with_tensorarrays,
      expand_composites=True)
  pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
  body_result = body(*packed_vars_for_body)
  post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
  if not nest.is_nested(body_result):
    body_result = [body_result]
  if len(post_summaries) > len(pre_summaries):
    # Summaries created by the body are removed from the collection and
    # instead attached as control dependencies of the body outputs.
    new_summaries = post_summaries[len(pre_summaries):]
    summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    summary_ref[:] = pre_summaries
    with ops.control_dependencies(new_summaries):

      def map_fn(x):
        # TODO(apassos) figure out how to trigger with tensor arrays as well
        if isinstance(x, tensor_array_ops.TensorArray):
          return x
        return array_ops.identity(x)

      body_result = nest.map_structure(
          map_fn, body_result, expand_composites=True)

  body_result = variable_utils.convert_variables_to_tensors(body_result)
  # Compare the structure types of input and output of body.
  # For backwards compatibility, the first layer is forced to a list
  # during this comparison, because inputs are typically lists and
  # outputs of the body are typically tuples.
  nest.assert_same_structure(
      list(packed_vars_for_body), list(body_result), expand_composites=True)

  # Store body_result to keep track of TensorArrays returned by body
  original_body_result = body_result
  # Convert TensorArrays returned by body into their flow variables
  result = nest.map_structure(
      _convert_tensorarray_to_flow,
      nest.flatten(body_result, expand_composites=True),
      expand_composites=True)
  result = ops.convert_n_to_tensor_or_composite(result)

  # Add NextIteration and the back edges to complete the loop.
  if len(merge_vars) != len(result):
    # `result` holds the flattened outputs of `body`; `merge_vars` has one
    # entry per flattened loop variable.
    raise ValueError("Number of inputs and outputs of 'body' must match "
                     f"'loop_vars'. Got {len(result)} for the number of "
                     f"inputs/outputs, and {len(merge_vars)} for 'loop_vars'.")
  for m, v in zip(merge_vars, result):
    # Called for its graph side effect; the returned value is not needed.
    _AddNextAndBackEdge(m, v)

  # Add the exit ops.
  exit_vars = [exit(x[0]) for x in switch_vars]
  self._loop_exits = exit_vars

  # Exit the loop.
  self.ExitResult(exit_vars)

  return original_body_result, exit_vars

1800 

def BuildLoop(self, pred, body, loop_vars, shape_invariants,
              return_same_structure):
  """Add the loop termination condition and body to the graph.

  Args:
    pred: Callable producing the loop condition from the loop variables.
    body: Callable producing the next loop variables.
    loop_vars: (Possibly nested) structure of initial loop variables.
    shape_invariants: Optional structure of shape invariants matching
      `loop_vars`, or None to derive the signature from `loop_vars` alone.
    return_same_structure: If true, always return the packed structure;
      otherwise a single exit value is returned unwrapped.

  Returns:
    The loop exit values packed like the result of `body`, or the sole
    element of that structure when `return_same_structure` is false and
    there is exactly one exit tensor.
  """
  # The flattened originals are kept so TensorArrays can be recognised later.
  flat_orig_loop_vars = nest.flatten(loop_vars, expand_composites=True)

  loop_vars = nest.map_structure(
      _convert_to_tensor_or_composite_or_tensorarray, loop_vars)
  # TensorArrays are represented by their flow tensors inside the loop.
  flattened = nest.flatten(loop_vars, expand_composites=True)
  flat_loop_vars = nest.map_structure(_convert_tensorarray_to_flow, flattened)

  if shape_invariants is None:
    spec_inputs = (loop_vars,)
  else:
    spec_inputs = (loop_vars, shape_invariants)
  loop_vars_signature = nest.map_structure(
      _shape_invariant_to_type_spec, *spec_inputs)

  try:
    self.Enter()
    # _BuildLoop mutates ops it has just created (via _update_input); holding
    # the graph mutation lock keeps a Session.run call from interleaving.
    graph = ops.get_default_graph()
    with graph._mutation_lock():  # pylint: disable=protected-access
      built = self._BuildLoop(pred, body, flat_orig_loop_vars, flat_loop_vars,
                              loop_vars_signature)
      original_body_result, exit_vars = built
  finally:
    self.Exit()

  flat_result = nest.flatten(original_body_result, expand_composites=True)
  # Outside the context, flow tensors are turned back into the TensorArrays
  # they belong to before being handed to the caller.
  exit_vars_with_tensorarrays = nest.map_structure(
      _convert_flow_to_tensorarray, flat_result, exit_vars)

  packed_exit_vars = nest.pack_sequence_as(
      structure=original_body_result,
      flat_sequence=exit_vars_with_tensorarrays,
      expand_composites=True)

  if return_same_structure or len(exit_vars) != 1:
    return packed_exit_vars
  return packed_exit_vars[0]

1849 

def _FixControlInputsAndContext(self, enters):
  """Assigns this context to the given Enter tensors' ops and fixes controls.

  For each tensor, the control dependencies of its op's first input are
  filtered down to those whose output context is reached by walking outward
  from this context's outer context without passing the outer while context;
  the survivors are added as control inputs of the op.

  Args:
    enters: A list of Tensors.

  Raises:
    TypeError: If an element of `enters` is not a Tensor.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  for enter in enters:
    if not isinstance(enter, ops.Tensor):
      raise TypeError("'enters' must be a list of Tensors. "
                      f"Received: {type(enter)}.")
    enter_op_input = enter.op.inputs[0].op
    candidates = graph._control_dependencies_for_inputs([enter_op_input])
    kept_control_inputs = []
    for candidate in candidates:
      # We need to keep control inputs that are in any ancestor
      # ControlFlowContext, and within outer WhileContext.
      candidate_ctxt = util.GetOutputContext(candidate)
      ctxt = self.outer_context
      enclosing_while = ctxt.GetWhileContext() if ctxt is not None else None
      while ctxt != candidate_ctxt:
        if ctxt is None or ctxt == enclosing_while:
          # Walked past the allowed region without finding the candidate's
          # context: drop it.
          break
        ctxt = ctxt.outer_context
      else:
        kept_control_inputs.append(candidate)
    enter.op._set_control_flow_context(self)
    enter.op._add_control_inputs(kept_control_inputs)
    graph._record_op_seen_by_control_dependencies(enter.op)
  # pylint: enable=protected-access

1882 

def IsWhileContext(self):
  """Returns True: this context is a while-loop context."""
  return True

1885 

1886 

1887# pylint: enable=redefined-outer-name 

1888 

1889 

def _AsTensorList(x, p):
  """Return x as a list of Tensors or IndexedSlices.

  Every entry is passed through an identity op. Entries that are Operations
  are first replaced by `p` gated on that operation.

  Args:
    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
    p: A Tensor to return for entries in `x` that are Operations.

  Returns:
    A list of Tensors or IndexedSlices. Non-Tensor entries are rebuilt as
    IndexedSlices from their `values` and `indices` only.
  """
  entries = x if isinstance(x, (list, _basetuple)) else [x]

  converted = []
  for entry in entries:
    if isinstance(entry, ops.Operation):
      entry = with_dependencies([entry], p)
    entry = ops.convert_to_tensor_or_composite(entry)
    if isinstance(entry, ops.Tensor):
      converted.append(array_ops.identity(entry))
      continue
    values_copy = array_ops.identity(entry.values)
    indices_copy = array_ops.identity(entry.indices)
    converted.append(indexed_slices.IndexedSlices(values_copy, indices_copy))
  return converted

1918 

1919 

def _CheckResults(a, b):
  """Asserts that `a` and `b` have equal length and pairwise equal dtypes.

  Args:
    a: Sequence of objects exposing `dtype` and `name`.
    b: Sequence of objects exposing `dtype` and `name`.

  Raises:
    AssertionError: If the lengths differ or a pair of dtypes differs.
  """
  assert len(a) == len(b), (
      "Values returned by a() and b() must have the same length.")
  for left, right in zip(a, b):
    assert left.dtype == right.dtype, (
        "Values returned by a() [%s] and b() [%s] must have "
        "the same type: %s, %s." %
        (left.name, right.name, left.dtype.name, right.dtype.name))

1927 

1928 

def with_dependencies(dependencies, output_tensor, name=None):
  """Produces the content of `output_tensor` only after `dependencies`.

  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that this means that there is
  no guarantee that `output_tensor` itself will be evaluated after any
  `dependencies` have run; only the returned value is gated on them.

  When executing eagerly, `output_tensor` is returned unchanged and
  `dependencies` is not inspected.

  See also `tf.tuple` and `tf.group`.

  Args:
    dependencies: Iterable of operations to run before this op finishes. It
      may be a one-shot iterable such as a generator.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
  """
  if context.executing_eagerly():
    return output_tensor
  # `dependencies` is needed twice below (name scope values and control
  # dependencies). Materialize it once so a one-shot iterable is not already
  # exhausted when the control dependencies are added.
  dependencies = list(dependencies)
  with ops.name_scope(name, "control_dependency",
                      dependencies + [output_tensor]) as name:
    with ops.colocate_with(output_tensor):
      with ops.control_dependencies(dependencies):
        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
        if isinstance(output_tensor, indexed_slices.IndexedSlices):
          # Only the values are routed through the gated identity; indices
          # and dense_shape are passed along unchanged.
          return indexed_slices.IndexedSlices(
              _Identity(output_tensor.values, name=name), output_tensor.indices,
              output_tensor.dense_shape)
        else:
          return _Identity(output_tensor, name=name)

1965 

1966 

def _GroupControlDeps(dev, deps, name=None):
  """Creates a NoOp that runs after `deps`, placed on `dev` when given.

  Args:
    dev: Device for the NoOp, or None to create it without a device scope.
    deps: Control dependencies of the NoOp.
    name: (Optional) Name for the NoOp.

  Returns:
    The created NoOp.
  """
  with ops.control_dependencies(deps):
    if dev is not None:
      with ops.device(dev):
        return no_op(name=name)
    return no_op(name=name)

1974 

1975 

1976# TODO(touts): Accept "inputs" as a list. 

@tf_export("group")
def group(*inputs, **kwargs):
  """Create an op that groups multiple operations.

  When this op finishes, all ops in `inputs` have finished. This op has no
  output.

  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
  this method, as ops execute in the expected order thanks to automatic control
  dependencies.* Only use `tf.group` when working with v1
  `tf.Graph` code.

  When operating in a v1-style graph context, ops are not executed in the same
  order as specified in the code; TensorFlow will attempt to execute ops in
  parallel or in an order convenient to the result it is computing. `tf.group`
  allows you to request that one or more results finish before execution
  continues.

  `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
  control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph
  as this:

      with tf.control_dependencies([a, b]):
          c = tf.no_op()

  When the inputs live on more than one device, one intermediate `NoOp` is
  created per device and the returned `NoOp` depends on those.

  See also `tf.tuple` and
  `tf.control_dependencies`.

  Args:
    *inputs: Zero or more tensors to group.
    name: A name for this operation (optional).

  Returns:
    An Operation that executes all its inputs, or `None` when executing
    eagerly.

  Raises:
    ValueError: If an unknown keyword argument is provided.
    TypeError: If a (flattened) input has no `device` attribute.
  """
  if context.executing_eagerly():
    return None
  name = kwargs.pop("name", None)
  if kwargs:
    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs))
  with ops.name_scope(name, "group_deps", inputs) as name:
    # Grouping no inputs means do nothing
    if not inputs:
      return no_op(name=name)

    # Bucket the flattened inputs by the device they are placed on.
    inputs_by_device = {}
    for inp in nest.flatten(inputs, expand_composites=True):
      if not hasattr(inp, "device"):
        raise TypeError("'inputs' should be zero or more (nested) Tensors. "
                        f"Received '{inp}' with type '{type(inp)}'.")
      inputs_by_device.setdefault(inp.device, []).append(inp)

    if len(inputs_by_device) == 1:
      # 1-level tree. The root node is the returned NoOp node.
      only_dev, only_deps = next(iter(inputs_by_device.items()))
      return _GroupControlDeps(only_dev, only_deps, name=name)

    # 2-level tree. The root node is the returned NoOp node; it depends on one
    # NoOp per device, created in sorted device order (None sorts as "").
    per_device_noops = [
        _GroupControlDeps(dev, inputs_by_device[dev])
        for dev in sorted(inputs_by_device,
                          key=lambda d: "" if d is None else d)
    ]

    with ops.control_dependencies(per_device_noops):
      return no_op(name=name)

2054 

2055 

@tf_export("tuple", v1=[])
@dispatch.add_dispatch_support
def tuple_v2(tensors, control_inputs=None, name=None):
  """Groups tensors together.

  The returned tensors have the same value as the input tensors, but they
  are computed only after all the input tensors have been computed.

  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
  this method, as ops execute in the expected order thanks to automatic control
  dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.

  See also `tf.group` and `tf.control_dependencies`.

  Example (without `tf.tuple`, `update_op` is never run, so `v` stays 0):
  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     v = tf.Variable(0.0)
  ...     a = tf.constant(1.0)
  ...     sess.run(tf.compat.v1.global_variables_initializer())
  ...     for i in range(5):
  ...       update_op = v.assign_add(1.0)
  ...       b = a + v
  ...       res_b = sess.run(b)
  ...       res_v = sess.run(v)
  ...       print(res_v)
  0.0
  0.0
  0.0
  0.0
  0.0

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     v = tf.Variable(0.0)
  ...     a = tf.constant(1.0)
  ...     sess.run(tf.compat.v1.global_variables_initializer())
  ...     for i in range(5):
  ...       update_op = v.assign_add(1.0)
  ...       calc = [a + v]
  ...       # `tf.tuple` ensures `update_op` is run before `b`
  ...       b = tf.tuple(calc, [tf.group(update_op)])
  ...       res_b = sess.run(b)
  ...       res_v = sess.run(v)
  ...       print(res_v)
  1.0
  2.0
  3.0
  4.0
  5.0


  Args:
    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
    control_inputs: List of additional ops to finish before returning.
    name: (optional) A name to use as a `name_scope` for the operation.

  Returns:
    Same as `tensors`.

  Raises:
    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
      objects.

  """
  # Delegates to the v1 implementation, which takes `name` before
  # `control_inputs`; keywords keep the mapping explicit.
  return tuple(tensors=tensors, control_inputs=control_inputs, name=name)  # pylint: disable=redefined-builtin

2123 

2124 

@tf_export(v1=["tuple"])
@dispatch.add_dispatch_support
def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
  """Group tensors together.

  This creates a tuple of tensors with the same values as the `tensors`
  argument, except that the value of each tensor is only returned after the
  values of all tensors have been computed.

  `control_inputs` contains additional ops that have to finish before this op
  finishes, but whose outputs are not returned.

  This can be used as a "join" mechanism for parallel computations: all the
  argument tensors can be computed in parallel, but the values of any tensor
  returned by `tuple` are only available after all the parallel computations
  are done.

  See also `tf.group` and
  `tf.control_dependencies`.

  Args:
    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
    name: (optional) A name to use as a `name_scope` for the operation.
    control_inputs: List of additional ops to finish before returning.

  Returns:
    Same as `tensors`.

  Raises:
    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
      objects.

  """
  if context.executing_eagerly():
    return tensors
  with ops.name_scope(name, "tuple", tensors) as name:
    # Operations, TF-typed values and None are kept as they are; anything else
    # is converted to a tensor.
    normalized = []
    for item in tensors:
      keep_as_is = (isinstance(item, ops.Operation) or
                    tensor_util.is_tf_type(item) or item is None)
      normalized.append(item if keep_as_is else ops.convert_to_tensor(item))
    tensors = normalized

    # Every non-None entry contributes the op that must finish first.
    gating_ops = []
    for item in tensors:
      if item is None:
        continue
      gating_ops.append(item if isinstance(item, ops.Operation) else item.op)

    if control_inputs:
      for control in control_inputs:
        if isinstance(control, ops.Tensor):
          control = control.op
        elif not isinstance(control, ops.Operation):
          raise TypeError(
              "'control_inputs' must only contain Operation or Tensor. "
              f"Received: {type(control)}")
        gating_ops.append(control)
    # Note that in order to ensure ordering in the pbtxt, we must take care to
    # ensure the order here. Duplicates are removed and ops sorted by id.
    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)
    if not gating_ops:
      raise ValueError("'tensors' must have at least one Tensor. "
                       f"Received: {tensors}.")
    gate = group(*gating_ops)

    def _gated(item):
      # Returns `item` made dependent on `gate`, or None for other entries.
      if tensor_util.is_tf_type(item):
        return with_dependencies([gate], item)
      if isinstance(item, ops.Operation):
        with ops.control_dependencies([gate]):
          return group(item)
      return None

    return [_gated(item) for item in tensors]

2197 

2198 

class XLAControlFlowContext(ControlFlowContext):
  """Base class for XLA and TPU control flow contexts."""

  def __init__(self):
    """Initializes the base context and sets the fixed context name."""
    super().__init__()
    self._name = "XLAControlFlowContext"

  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Forwards serialization to the `ControlFlowContext` implementation."""
    # pylint: disable=useless-super-delegation
    # NOTE(slebedev): the method is required by `ControlFlowContext`.
    super().to_control_flow_context_def(context_def, export_scope)

  def IsXLAContext(self):
    """Returns True: this is an XLA context."""
    return True

  def AddOp(self, _):
    """Does nothing with the given op."""
    return None

  def AddValue(self, x):
    """Returns `x` unchanged."""
    return x

  def RequiresUniqueFunctionRetracing(self):
    """Returns whether the tf.function should be retraced if the context changes.
    """
    return False

2225 

2226 

@tf_export("__internal__.get_enclosing_xla_context", v1=[])
def get_enclosing_xla_context():
  """Find and return the closest enclosing XLAControlFlowContext.

  Starting at the default graph, each graph's control flow context chain is
  walked outward; if no XLA context is found, the search continues in the
  graph's `outer_graph` (if it has one).

  Returns:
    The first XLAControlFlowContext found, or None if there is none.
  """
  current_graph = ops.get_default_graph()
  while current_graph is not None:
    ctxt = current_graph._get_control_flow_context()  # pylint: disable=protected-access
    while ctxt is not None and not isinstance(ctxt, XLAControlFlowContext):
      ctxt = ctxt.outer_context
    if ctxt is not None:
      return ctxt
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    current_graph = getattr(current_graph, "outer_graph", None)
  return None

2243 

2244 

def from_control_flow_context_def(context_def, import_scope=None):
  """Deserializes `context_def` into the appropriate ControlFlowContext.

  Args:
    context_def: ControlFlowContextDef proto
    import_scope: Optional `string`. Name scope to add.

  Returns:
    A ControlFlowContext subclass

  Raises:
    NotImplementedError: If neither `cond_ctxt` nor `while_ctxt` is set.
  """
  if context_def.HasField("cond_ctxt"):
    cond_proto = context_def.cond_ctxt
    return CondContext.from_proto(cond_proto, import_scope=import_scope)
  if context_def.HasField("while_ctxt"):
    while_proto = context_def.while_ctxt
    return WhileContext.from_proto(while_proto, import_scope=import_scope)
  unknown_field = context_def.WhichOneof("ctxt")
  raise NotImplementedError(
      "Unknown ControlFlowContextDef field: %s" % unknown_field)

2263 

2264 

# Register the proto type and the to/from-proto converters for the
# COND_CONTEXT collection key.
ops.register_proto_function(
    ops.GraphKeys.COND_CONTEXT,
    proto_type=control_flow_pb2.CondContextDef,
    to_proto=CondContext.to_proto,
    from_proto=CondContext.from_proto)

# Same registration for the WHILE_CONTEXT collection key.
ops.register_proto_function(
    ops.GraphKeys.WHILE_CONTEXT,
    proto_type=control_flow_pb2.WhileContextDef,
    to_proto=WhileContext.to_proto,
    from_proto=WhileContext.from_proto)