Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py: 15%

463 statements  


1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""while_v2 and gradient. 

16 

17This is a version of while_loop that emits a single While op, as well as the 

18gradient function for While ops produced by while_loop. This will eventually 

19replace the current tf.while_loop implementation once it reaches feature and 

20performance parity. 

21""" 

22import collections 

23 

24from tensorflow.core.framework import attr_value_pb2 

25from tensorflow.python.client import pywrap_tf_session as c_api 

26from tensorflow.python.eager import backprop_util 

27from tensorflow.python.framework import auto_control_deps_utils as acd 

28from tensorflow.python.framework import constant_op 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import func_graph as func_graph_module 

31from tensorflow.python.framework import indexed_slices 

32from tensorflow.python.framework import ops 

33from tensorflow.python.framework import tensor_shape 

34from tensorflow.python.framework import tensor_spec 

35from tensorflow.python.framework import tensor_util 

36from tensorflow.python.ops import array_ops 

37from tensorflow.python.ops import control_flow_ops 

38from tensorflow.python.ops import control_flow_util as util_v1 

39from tensorflow.python.ops import control_flow_util_v2 as util 

40from tensorflow.python.ops import default_gradient 

41from tensorflow.python.ops import gen_functional_ops 

42from tensorflow.python.ops import gen_resource_variable_ops 

43from tensorflow.python.ops import gradients_util 

44from tensorflow.python.ops import handle_data_util 

45from tensorflow.python.ops import list_ops 

46from tensorflow.python.ops import math_ops 

47from tensorflow.python.ops import tensor_array_ops 

48from tensorflow.python.ops import while_v2_indexed_slices_rewriter 

49from tensorflow.python.util import compat 

50from tensorflow.python.util import nest 

51from tensorflow.python.util import object_identity 

52from tensorflow.python.util import variable_utils 

53 

54# pylint: disable=protected-access 

55 

56 

57def while_loop(cond, 

58 body, 

59 loop_vars, 

60 shape_invariants=None, 

61 parallel_iterations=10, 

62 maximum_iterations=None, 

63 name=None, 

64 return_same_structure=True, 

65 back_prop=True): 

66 """Like tf.while_loop, except emits a single While op.""" 

67 loop_vars = variable_utils.convert_variables_to_tensors(loop_vars) 

68 # Keep the original loop_vars around to know which args were TensorArrays. 

69 orig_loop_vars = loop_vars 

70 flat_orig_loop_vars = nest.flatten(orig_loop_vars, expand_composites=True) 

71 # Cache its length since we use it at multiple places below. 

72 len_orig_loop_vars = len(orig_loop_vars) 

73 

74 # Convert TensorArrays to their flow variables. These get converted back to 

75 # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and 

76 # `wrapped_body` below. 

77 loop_vars = _tensor_array_to_flow(loop_vars) 

78 loop_vars = nest.map_structure( 

79 indexed_slices.internal_convert_to_tensor_or_indexed_slices, 

80 loop_vars, 

81 expand_composites=True) 

82 

83 # `loop_vars_signature` is a structure of TypeSpecs and has the same 

84 # structure as `orig_loop_vars`. If `shape_invariants` is not None, its

85 # shape information comes from `shape_invariants` instead of `orig_loop_vars`. 

86 # It is used to pack flattened vars into structured vars. 

87 if shape_invariants is not None: 

88 loop_vars_signature = nest.map_structure( 

89 control_flow_ops._shape_invariant_to_type_spec, 

90 loop_vars, shape_invariants) 

91 else: 

92 loop_vars_signature = nest.map_structure( 

93 control_flow_ops._shape_invariant_to_type_spec, loop_vars) 

94 

95 flat_shape_invariants = nest.map_structure( 

96 lambda spec: spec.shape, 

97 nest.flatten(loop_vars_signature, expand_composites=True)) 

98 

99 if not name: 

100 name = "while" 

101 

102 with ops.name_scope(name) as scope: 

103 with ops.name_scope(None): 

104 cond_name = util.unique_fn_name(scope, "cond") 

105 body_name = util.unique_fn_name(scope, "body") 

106 maximum_iterations_loop_var = _build_maximum_iterations_loop_var( 

107 maximum_iterations) 

108 loop_counter = constant_op.constant( 

109 0, 

110 dtype=maximum_iterations_loop_var.dtype 

111 if maximum_iterations is not None else None, 

112 name="loop_counter") 

113 # Add loop counter needed for computing gradients. 

114 loop_vars = [loop_counter, maximum_iterations_loop_var] + list(loop_vars) 

115 

116 func_graph_signature = ( 

117 [tensor_spec.TensorSpec.from_tensor(loop_counter), 

118 tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] + 

119 list(loop_vars_signature)) 

120 

121 # Automatic control dependencies are added in defuns, but not in v1 

122 # graphs. Propagate that behavior here. 

123 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 

124 

125 def wrapped_cond(loop_counter, maximum_iterations_arg, *args): 

126 """Extra `cond` wrapper that can handle the extra counter loop_var.""" 

127 # Convert the flow variables in `args` to TensorArrays. `args` should 

128 # already have the same structure as `orig_loop_vars` but currently there 

129 # is no nest.zip so we call `_pack_sequence_as` which flattens `args`, 

130 # converts flows in `args` to TensorArrays and packs it into the 

131 # structure of `loop_vars_signature`. 

132 pred = cond( 

133 *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)) 

134 if (tensor_util.is_tf_type(pred) and 

135 (pred.shape.dims is None or pred.shape.dims)): 

136 pred = array_ops.squeeze_v2(pred) 

137 

138 if maximum_iterations is None: 

139 return pred 

140 else: 

141 return math_ops.logical_and( 

142 loop_counter < maximum_iterations_arg, pred) 

143 

144 # NOTE(skyewm): we set collections to the outer graph's collections for 

145 # compatibility with TPUEstimator. 

146 cond_graph = func_graph_module.func_graph_from_py_func( 

147 cond_name, 

148 wrapped_cond, 

149 [], # We provide signature instead of args. 

150 {}, 

151 signature=func_graph_signature, 

152 func_graph=util.WhileCondFuncGraph( 

153 cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 

154 add_control_dependencies=add_control_dependencies) 

155 

156 def wrapped_body(loop_counter, maximum_iterations_arg, *args): 

157 """Loop body augmented with counter update. 

158 

159 Args: 

160 loop_counter: Loop counter which needs to be incremented in the body. 

161 maximum_iterations_arg: Maximum iterations of the loop. 

162 *args: List of args 

163 

164 Returns: 

165 A list of tensors the same length as args. 

166 """ 

167 # The function was created with a signature rather than tensors, so 

168 # internal placeholders were created without handle data. 

169 _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True), 

170 nest.flatten(args, expand_composites=True)) 

171 # Capture the tensors already captured in cond_graph so that they appear 

172 # in the same order in body_graph.external_captures. 

173 for t in cond_graph.external_captures: 

174 ops.get_default_graph().capture(t) 

175 

176 # Convert the flow variables in `args` to TensorArrays. `args` should 

177 # already have the same structure as `orig_loop_vars` but currently there 

178 # is no nest.zip so we call `_pack_sequence_as` which flattens `args`, 

179 # converts flows in `args` to TensorArrays and packs it into the 

180 # structure of `loop_vars_signature`. 

181 outputs = body( 

182 *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)) 

183 if not nest.is_nested(outputs): 

184 outputs = [outputs] 

185 try: 

186 # The legacy while_loop considers list and tuple to be the same 

187 # structure. 

188 nest.assert_same_structure(outputs, orig_loop_vars, check_types=False, 

189 expand_composites=True) 

190 except ValueError: 

191 # Traditionally we consider variables and tensors to be the same 

192 # structure. 

193 vars1 = variable_utils.convert_variables_to_tensors(outputs) 

194 vars2 = variable_utils.convert_variables_to_tensors(orig_loop_vars) 

195 nest.assert_same_structure(vars1, vars2, check_types=False, 

196 expand_composites=True) 

197 outputs = _tensor_array_to_flow(outputs) 

198 

199 # TODO(srbs): Update lowering code to create _Enter nodes with 

200 # is_constant=True for inputs that are directly passed to outputs. 

201 return [loop_counter + 1, maximum_iterations_arg] + list(outputs) 

202 

203 body_graph = func_graph_module.func_graph_from_py_func( 

204 body_name, 

205 wrapped_body, 

206 [], # We provide signature instead of args. 

207 {}, 

208 signature=func_graph_signature, 

209 func_graph=util.WhileBodyFuncGraph( 

210 body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 

211 add_control_dependencies=add_control_dependencies) 

212 # Add external captures of body to the list of loop vars. 

213 # Note that external tensors will be treated as loop invariants, i.e., 

214 # the value of that tensor in each iteration is the same as it was at the 

215 # beginning of the loop execution. 

216 deferred_external_captures = nest.flatten( 

217 [c() for c in body_graph.deferred_external_captures], 

218 expand_composites=True) 

219 loop_vars = ( 

220 loop_vars + body_graph.external_captures + deferred_external_captures) 

221 # TODO(srbs): Update lowering code to create _Enter nodes with 

222 # is_constant=True for inputs that are directly passed to outputs. 

223 body_graph.outputs.extend(body_graph.internal_captures) 

224 body_graph.outputs.extend(body_graph.deferred_internal_captures) 

225 

226 # Capture the extra `external_captures` of `body_graph` in `cond_graph` so 

227 # that it expects to receive those as arguments. 

228 with cond_graph.as_default(): 

229 num_cond_captures = len(cond_graph.external_captures) 

230 assert (cond_graph.external_captures == 

231 body_graph.external_captures[:num_cond_captures]) 

232 _duplicate_body_captures_in_cond( 

233 cond_graph, body_graph.external_captures[num_cond_captures:] + 

234 deferred_external_captures) 

235 

236 # Make sure that the shapes of the loop outputs are compatible with the 

237 # shape invariants, or the shapes of the loop vars if the invariants are not 

238 # specified. 

239 num_flattened_outputs = len(nest.flatten(orig_loop_vars, 

240 expand_composites=True)) 

241 # First var is loop counter and second var is maximum_iterations. 

242 first_loop_var_index = 2 

243 _check_shapes_compat( 

244 body_graph.outputs[first_loop_var_index:first_loop_var_index + 

245 num_flattened_outputs], 

246 flat_shape_invariants, 

247 nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index + 

248 len_orig_loop_vars], expand_composites=True)) 

249 

250 num_original_outputs = len(body_graph.outputs) 

251 if back_prop and util.output_all_intermediates(): 

252 # Export all tensors in the loop body that may be needed for gradient 

253 # computation. We do this by accumulating the intermediate values in 

254 # TensorLists. 

255 intermediate_tensors = _get_intermediates(body_graph) 

256 

257 for intermediate_tensor in intermediate_tensors: 

258 tensor_list = list_ops.empty_tensor_list( 

259 element_dtype=intermediate_tensor.dtype, 

260 element_shape=intermediate_tensor.shape, 

261 max_num_elements=maximum_iterations) 

262 loop_vars.append(tensor_list) 

263 with cond_graph.as_default(): 

264 # Add a placeholder to cond_graph's inputs corresponding to the 

265 # tensor_list. 

266 cond_graph.capture(tensor_list) 

267 with body_graph.as_default(): 

268 # Push the intermediate tensor to the tensor list. This captures the 

269 # `tensor_list` as well. 

270 appended_tensor_list = list_ops.tensor_list_push_back( 

271 tensor_list, intermediate_tensor) 

272 # Add this modified tensor list to the list of outputs. 

273 body_graph.outputs.append(appended_tensor_list) 

274 

275 flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True) 

276 _check_num_inputs_outputs(cond_graph, body_graph, 

277 len(flattened_loop_vars)) 

278 _check_inputs_outputs_types_match(body_graph, flattened_loop_vars) 

279 

280 with ops.control_dependencies( 

281 list(cond_graph.function_captures.control) + list( 

282 body_graph.function_captures.control)): 

283 output_shapes = [t.shape for t in body_graph.outputs] 

284 orig_loop_vars_range = slice(first_loop_var_index, 

285 first_loop_var_index + num_flattened_outputs) 

286 output_shapes[orig_loop_vars_range] = flat_shape_invariants 

287 

288 outputs = _build_while_op( 

289 flattened_loop_vars, 

290 cond_graph, 

291 body_graph, 

292 output_shapes=output_shapes, 

293 parallel_iterations=parallel_iterations, 

294 name=scope, 

295 num_original_outputs=num_original_outputs) 

296 if not ops.get_default_graph().building_function: 

297 # In V1 graph mode, return identities for each output of the While op, 

298 # rather than the output of the While op directly. This makes pruning work 

299 # if the output of while_loop() is fetched: the lowering pass converts the 

300 # While outputs into IdentityN outputs, which if fetched will cause all 

301 # ops in the body to be run (since it takes all exit ops as input). After 

302 # lowering, each output identity op will end up with only the appropriate 

303 # exit op as input. 

304 outputs = tuple(array_ops.identity(t) for t in outputs) 

305 

306 output_loop_vars = outputs[first_loop_var_index:first_loop_var_index + 

307 num_flattened_outputs] 

308 if not back_prop: 

309 output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars] 

310 outputs = _pack_sequence_as( 

311 loop_vars_signature, flat_orig_loop_vars, output_loop_vars) 

312 

313 if return_same_structure: 

314 return outputs 

315 

316 flattened_outputs = nest.flatten(outputs, expand_composites=True) 

317 if len(flattened_outputs) == 1: 

318 return flattened_outputs[0] 

319 else: 

320 return outputs 

321 

322 
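# --- Editor's note: illustrative usage sketch, not part of the original
# module. It assumes a standard TensorFlow 2.x installation; the helper name
# `_example_single_while_op_sketch` is hypothetical and chosen for clarity.
# It shows how the `while_loop` defined above emits a single While /
# StatelessWhile op when traced into a graph.
def _example_single_while_op_sketch():
  """Builds a tiny loop and lists the While-style ops in its graph."""
  import tensorflow as tf  # Assumed to be importable in this environment.

  @tf.function
  def summed():
    # Sum the integers 0..9 with the single-op implementation defined above.
    return while_loop(
        cond=lambda i, total: i < 10,
        body=lambda i, total: (i + 1, total + i),
        loop_vars=(tf.constant(0), tf.constant(0)))

  graph = summed.get_concrete_function().graph
  # A stateless body yields one "StatelessWhile" op (or "While" if the body
  # were stateful), instead of the legacy Enter/Merge/Switch/Exit ops.
  return [op.type for op in graph.get_operations()
          if op.type in ("While", "StatelessWhile")]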

323@ops.RegisterGradient("StatelessWhile") 

324@ops.RegisterGradient("While") 

325def _WhileGrad(op, *grads): # pylint: disable=invalid-name 

326 """The gradient of a While op produced by while_loop.""" 

327 # Note that op is not always the same as while_op because the gradient tape, 

328 # for eager mode compatibility, forgets information about the proper op. Since 

329 # the loop cannot run in eager mode, however, we can safely introspect into 

330 # the graph here. 

331 while_op = op.outputs[0].op 

332 cond_graph = _get_graph(while_op, "cond", "_cond_graph") 

333 body_graph = _get_graph(while_op, "body", "_body_graph") 

334 orig_num_params = len(body_graph.outputs) 

335 

336 maximum_iterations = op.inputs[1] 

337 parallel_iterations = op.get_attr("parallel_iterations") 

338 

339 try: 

340 num_original_outputs = while_op.get_attr("_num_original_outputs") 

341 except: # pylint: disable=bare-except 

342 num_original_outputs = len(while_op.outputs) 

343 

344 num_intermediates = len(while_op.outputs) - num_original_outputs 

345 grads = [ 

346 _preprocess_grad(grad, body_out, while_in, while_out) # pylint: disable=g-complex-comprehension 

347 for grad, body_out, while_in, while_out in zip( 

348 grads[:num_original_outputs], 

349 body_graph.outputs[:num_original_outputs], 

350 while_op.inputs[:num_original_outputs], 

351 while_op.outputs[:num_original_outputs]) 

352 ] + [None] * num_intermediates 

353 

354 # Skip gradients with respect to the captures whenever possible. 

355 if getattr(op, "skip_input_indices", None) is not None: 

356 captures_start_index = ( 

357 len(body_graph.inputs) - len(body_graph.internal_captures)) 

358 for i in op.skip_input_indices: 

359 if i >= captures_start_index: 

360 grads[i] = None 

361 

362 # We compute the gradient for the sub-graph between trainable ys and xs 

363 # with non-None incoming gradients. We later pad the None's to the list of 

364 # outputs. 

365 ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip( 

366 body_graph.outputs, body_graph.inputs, grads) if grad is not None]) 

367 

368 body_grad_graph, args = _create_grad_func( 

369 ys, xs, non_none_grads, cond_graph, body_graph, 

370 util.unique_grad_fn_name(body_graph.name), op, maximum_iterations) 

371 

372 if body_grad_graph.while_op_needs_rewrite: 

373 # Modify 'op' to output the intermediate accumulators needed by the grad 

374 # function. 

375 # NOTE(skyewm): if there are any active sessions, this modification to `op` 

376 # may make them unrunnable! 

377 

378 cond_graph.name += "_rewritten" 

379 body_graph.name += "_rewritten" 

380 

381 # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new 

382 # `body_graph.external_captures` added during `_create_grad_func`. 

383 new_inputs = body_grad_graph.extra_inputs 

384 new_outputs = body_graph.outputs[orig_num_params:] 

385 

386 while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) 

387 while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) 

388 if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs): 

389 # Continuing leads to an invalid graph with disconnected inputs. 

390 raise AssertionError( 

391 "Inputs and outputs constructed for the forward op of a While " 

392 "gradient don't match with 'output_types' at " 

393 f"{len(body_graph.output_types)},'inputs' at length " 

394 f"{len(while_op.inputs)}, and 'new_inputs' at length " 

395 f"{len(new_inputs)}. This doesn't make sense, please file a bug.") 

396 while_op._set_type_list_attr("T", body_graph.output_types) 

397 while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes) 

398 while_op._add_while_inputs(new_inputs) 

399 while_op._add_outputs([t.dtype for t in new_outputs], 

400 [t.shape for t in new_outputs]) 

401 _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:]) 

402 

403 # Do not ignore grads wrt extra outputs when computing higher order 

404 # derivatives. 

405 while_op._set_attr("_num_original_outputs", 

406 attr_value_pb2.AttrValue(i=len(while_op.outputs))) 

407 

408 captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, 

409 while_op) 

410 loop_vars = args + captured_inputs 

411 

412 # This modifies body_grad_graph. 

413 loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices( 

414 grads, body_grad_graph, loop_vars, while_op.inputs) 

415 

416 def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters, 

417 *unused_args): 

418 return counter < forward_loop_iters 

419 

420 grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name) 

421 cond_grad_graph = func_graph_module.func_graph_from_py_func( 

422 grad_cond_name, grad_cond, loop_vars, {}, 

423 func_graph=util.WhileCondFuncGraph(grad_cond_name)) 

424 

425 _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars)) 

426 

427 outputs = _build_while_op( 

428 loop_vars, 

429 cond_grad_graph, 

430 body_grad_graph, 

431 output_shapes=[t.shape for t in body_grad_graph.outputs], 

432 parallel_iterations=parallel_iterations, 

433 name="%s_grad" % while_op.name, 

434 num_original_outputs=len(body_grad_graph.outputs)) 

435 

436 # See comment in while_loop. 

437 outputs = [array_ops.identity(t) for t in outputs] 

438 return _get_structured_grad_output(outputs, grads, body_grad_graph) 

439 

440 
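# --- Editor's note: illustrative gradient sketch, not part of the original
# module. It assumes a standard TensorFlow 2.x installation; the helper name
# `_example_while_gradient_sketch` is hypothetical. Tracing tf.while_loop
# inside tf.function builds a While/StatelessWhile op, so tape.gradient goes
# through the _WhileGrad registration above.
def _example_while_gradient_sketch():
  """Differentiates y = x * 2**3 produced by a three-iteration loop."""
  import tensorflow as tf  # Assumed to be importable in this environment.

  @tf.function
  def grad_fn(x):
    with tf.GradientTape() as tape:
      tape.watch(x)
      _, y = tf.while_loop(
          cond=lambda i, v: i < 3,
          body=lambda i, v: (i + 1, v * 2.0),
          loop_vars=(tf.constant(0), x))
    return tape.gradient(y, x)

  return grad_fn(tf.constant(2.0))  # y = 8 * x, so the expected result is 8.0.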

441def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes, 

442 parallel_iterations, name, num_original_outputs): 

443 """Builds the functional StatelessWhile/While op.""" 

444 cond_stateful_ops = [ 

445 op for op in cond_graph.get_operations() if op._is_stateful 

446 ] 

447 body_stateful_ops = [ 

448 op for op in body_graph.get_operations() if op._is_stateful 

449 ] 

450 if (cond_stateful_ops or body_stateful_ops): 

451 op_fn = gen_functional_ops._while 

452 else: 

453 op_fn = gen_functional_ops.stateless_while 

454 

455 def _make_op(inputs): 

456 while_op, tensors = util.get_op_and_outputs(op_fn( 

457 inputs, 

458 util.create_new_tf_function(cond_graph), 

459 util.create_new_tf_function(body_graph), 

460 output_shapes=output_shapes, 

461 parallel_iterations=parallel_iterations, 

462 name=name)) 

463 _copy_handle_data(body_graph.outputs, tensors) 

464 util.maybe_set_lowering_attr(while_op) 

465 util.maybe_propagate_compile_time_consts_in_xla(while_op) 

466 _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph]) 

467 # This is needed so we do not compute derivative wrt these extra outputs. 

468 while_op._set_attr("_num_original_outputs", 

469 attr_value_pb2.AttrValue(i=num_original_outputs)) 

470 # The while op may be created inside a tf.function, in which case ops 

471 # need to capture "through" it when taking gradients; outer_graph is used

472 # as a sanity check that capturing only happens from parent to child.

473 cond_graph.outer_graph = ops.get_default_graph() 

474 body_graph.outer_graph = ops.get_default_graph() 

475 while_op._cond_graph = cond_graph 

476 while_op._body_graph = body_graph 

477 return tensors 

478 return util.run_as_function_for_tape_gradients(_make_op, loop_vars) 

479 

480 

481def _get_intermediates(func_graph): 

482 """Returns all tensors in `func_graph` that should be accumulated.""" 

483 # We currently accumulate output tensors of most ops in the function and rely 

484 # on the pruning pass to get rid of the unused accumulators at runtime.

485 # However, this can bloat the GraphDef and make debugging harder so we perform

486 # some optimizations.

487 #

488 # Optimizations we currently perform:

489 # 1. We do not accumulate tensors which already have an accumulator

490 # in the loop body.

491 # 2. We do not accumulate outputs of Identity nodes. When building the

492 # FuncGraph, we add an Identity node for each output (see

493 # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs

494 # of all these nodes bloats the GraphDef quite a bit so we remove those.

495 # Since the gradient of an Identity node does not rely on its forward op's

496 # input this is safe to do.

497 #

498 # Other possible optimizations:

499 # 1. Only accumulate tensors that will be required by the backward pass.

500 # This will require running the gradient pass and hence would increase the

501 # graph building time for the forward pass.

502 # 2. Do not accumulate Const nodes created inside the loop body.

503 # 3. Do not accumulate loop vars that are returned as-is just like captured

504 # tensors.

505 intermediates = [] 

506 reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures) 

507 

508 for op in func_graph.get_operations(): 

509 if op.type == "Identity": 

510 continue 

511 # Accumulating mutexes can cause deadlock. 

512 if op.type == "MutexLock": 

513 continue 

514 for o in op.outputs: 

515 if (o is not func_graph.inputs[0] and # Loop counter. 

516 o.dtype != dtypes.resource and # Do not accumulate resource tensors. 

517 _get_accumulator(o) is None and # Has existing accumulator. 

518 o.ref() not in reverse_captures 

519 ): # Captured value, hence loop invariant. 

520 intermediates.append(o) 

521 return intermediates 

522 

523 

524def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output): 

525 """Returns the initial gradient to be used for a given output tensor. 

526 

527 Args: 

528 grad: the original gradient Tensor passed to the gradient function. 

529 body_graph_output: the corresponding Tensor in the body graph. 

530 while_op_input: the corresponding Tensor input of the While op. 

531 while_op_output: the corresponding Tensor output of the While op. 

532 

533 Returns: 

534 A Tensor or None. 

535 """ 

536 # Set the incoming gradient of non-trainable inputs to None. It is possible 

537 # that we receive non-None gradients for non-trainable types in nested while

538 # loops because we accumulate outputs of the inner while as variant tensors

539 # which are trainable and hence receive zeros_like tensors in the gradient

540 # pass. The non-trainable tensors then receive the popped zeros tensor from

541 # this zeros variant. The gradient for the loop vars corresponding to these

542 # tensors is None or zeros (this happens only if the loop var is accumulated

543 # as well) in _grad_fn so we reset these.

544 # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn. 

545 if not _is_trainable(body_graph_output): 

546 return None 

547 

548 # GradientTape initializes resource and variant grads as None instead of 

549 # zeros. Set to zeros so _GradientsHelper computes the gradients instead of

550 # returning None.

551 # TODO(b/143286622): The supports_default_grad check is needed

552 # because While op emits non-differentiable resource tensors

553 # as outputs. Remove this check when that is not the case.

554 # Note: We use `while_op_input` instead of `while_op_output` for the call

555 # to `supports_default_grad` because `while_op_output` may be missing

556 # handle_data if the While is in a restored saved model.

557 if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and 

558 default_gradient.supports_default_grad(while_op_input) and grad is None): 

559 return _zeros_like(while_op_input, while_op_output) 

560 

561 # Convert IndexedSlices to dense tensors since it is unlikely that downstream 

562 # gradient functions will properly handle indexed slices. This is similar to

563 # what we do in tf.function gradients.

564 if isinstance(grad, indexed_slices.IndexedSlices): 

565 return ops.convert_to_tensor(grad) 

566 

567 return grad 

568 

569 

570# TODO(skyewm): make this return constants if op_output's shape is fully 

571# defined (this can be done by checking the "shape" attr of resource vars). 

572def _zeros_like(op_input, op_output): 

573 """Like array_ops.zeros_like() but also accepts resource var handles.""" 

574 if op_output.dtype == dtypes.resource: 

575 # Note: We use `op_input` instead of `op_output` to get the zeros dtype 

576 # because `op_output` may be missing handle_data if the While is in a

577 # restored saved model.

578 return array_ops.zeros( 

579 gen_resource_variable_ops.variable_shape(op_output), 

580 dtype=default_gradient.get_zeros_dtype(op_input)) 

581 return array_ops.zeros_like(op_output) 

582 

583 

584def _is_trainable(tensor): 

585 """Returns whether the given tensor is trainable.""" 

586 if not backprop_util.IsTrainable(tensor): 

587 return False 

588 

589 # Special case: untrainable accumulator output. The gradients algorithm 

590 # doesn't know about tensor lists of untrainable elements. In theory the

591 # tensor list gradient functions should return None as appropriate, but

592 # because we can't return None from the gradient function we filter out

593 # untrainable accumulator output here to avoid computing the gradient at all.

594 if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0: 

595 assert tensor.dtype == dtypes.variant 

596 element_type = tensor.op.get_attr("element_dtype") 

597 return backprop_util.IsTrainable(element_type) 

598 

599 return True 

600 

601 

602def _get_graph(while_op, func_attr_name, attr_graph_name): 

603 """Returns `FuncGraph` for the given function attribute. 

604 

605 Args: 

606 while_op: The While Operation. 

607 func_attr_name: string 

608 attr_graph_name: cached forward graph name 

609 

610 Returns: 

611 `FuncGraph` 

612 """ 

613 func_graph = getattr(while_op, attr_graph_name, None) 

614 if func_graph is None: 

615 # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes. 

616 input_shapes = [ 

617 tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes") 

618 ] 

619 func_name = while_op.get_attr(func_attr_name).name 

620 func_graph = util.get_func_graph(while_op, input_shapes, func_name) 

621 func_graph._while = while_op 

622 return func_graph 

623 

624 

625def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op, 

626 maximum_iterations): 

627 """Builds and returns the gradient FuncGraph of `func_graph` and its args. 

628 

629 The returned grad_func_graph must be called with the returned 

630 args + grad_func_graph.captures. 

631 

632 Args: 

633 ys: A `Tensor` or list of tensors to be differentiated. 

634 xs: A `Tensor` or list of tensors to be used for differentiation. 

635 grads: The incoming grads for `ys`. 

636 cond_graph: FuncGraph for the forward cond function. 

637 body_graph: FuncGraph for the forward body function. 

638 name: Name of the returned gradient function. 

639 while_op: The forward While op. 

640 maximum_iterations: Tensor. The maximum number of iterations. 

641 

642 Returns: 

643 2-tuple of (grad_func_graph, args). 

644 """ 

645 assert len(ys) == len(grads) 

646 

647 total_iters = while_op.outputs[0] 

648 counter = constant_op.constant( 

649 0, dtype=total_iters.dtype, name="grad_counter") 

650 

651 # Build frozen sets so that we do not have linear time lookups in 

652 # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`

653 # may get updated during gradient computation because we add accumulators to

654 # the forward op. However, those are not loop invariants so wouldn't affect

655 # the output of `_is_loop_invariant`. Also we would never attempt to capture

656 # those accumulators so `_is_loop_invariant` should never receive those new

657 # tensors as args.

658 body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs) 

659 body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs) 

660 

661 args = [counter, maximum_iterations, total_iters] + list(grads) 

662 # Note: The returned function does not have `args` in the list of 

663 # `external_captures`.

664 grad_func_graph = func_graph_module.func_graph_from_py_func( 

665 name, 

666 lambda *args: _grad_fn(ys, xs, args, body_graph), 

667 args, {}, 

668 func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph, 

669 maximum_iterations, while_op, 

670 body_graph_inputs, body_graph_outputs)) 

671 

672 # Update the list of outputs with tensors corresponding to the captured 

673 # tensors. We capture 3 types of tensors when building the grad fn:

674 # 1. Accumulators for forward graph intermediates which are not loop

675 # invariants. The outputs corresponding to these are populated in

676 # `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.

677 # 2. Resources, which are output as is. 

678 # 3. Forward graph loop invariants, which are output as is. 

679 for external_capture, internal_capture in grad_func_graph.captures: 

680 if (ops.tensor_id(internal_capture) 

681 in grad_func_graph.internal_capture_to_output): 

682 new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id( 

683 internal_capture)] 

684 else: 

685 raise ValueError( 

686 f"Tensor {str(internal_capture)} which captures " 

687 f"{str(external_capture)} is in list of " 

688 f"internal_captures but not in internal_capture_to_output.") 

689 grad_func_graph.outputs.append(new_output) 

690 grad_func_graph.structured_outputs.append(new_output) 

691 

692 return grad_func_graph, args 

693 

694 

695def _grad_fn(ys, xs, args, func_graph): 

696 """Computes the gradient of `func_graph` in the current graph. 

697 

698 This function builds the gradient graph of the corresponding forward-pass 

699 `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs. 

700 

701 Args: 

702 ys: A `Tensor` or list of tensors to be differentiated. 

703 xs: A `Tensor` or list of tensors to be used for differentiation. 

704 args: The input arguments. 

705 args[0] - Loop counter 

706 args[1] - Total number of iterations. 

707 args[2] - maximum_iterations. 

708 args[3:] - Incoming gradients for `ys`. 

709 func_graph: function.FuncGraph. The corresponding forward-pass function. 

710 

711 Returns: 

712 The output gradient Tensors. 

713 """ 

714 grad_ys = args[3:] 

715 

716 # Build the gradient graph. Note that this builds the gradient computation of 

717 # func_graph in the current graph, which requires capturing tensors from

718 # func_graph. The captured func_graph tensors are resolved to external tensors

719 # after the forward While op has been rewritten in _resolve_grad_captures.

720 # TODO(srbs): Mark GradientsHelper as public? 

721 grad_outs = gradients_util._GradientsHelper( 

722 ys, xs, grad_ys=grad_ys, src_graph=func_graph, 

723 unconnected_gradients="zero") 

724 

725 # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there 

726 # is a tf.StopGradient in the loop body.

727 assert all(g is not None for g in grad_outs) 

728 counter = args[0] 

729 maximum_iterations = args[1] 

730 total_iters = args[2] 

731 return [counter + 1, maximum_iterations, total_iters] + grad_outs 

732 

733 

734def _resolve_grad_captures(body_graph, body_grad_graph, while_op): 

735 """Returns the tensors to pass as captured inputs to `body_grad_graph`. 

736 

737 `body_grad_graph` may have external references to: 

738 1. Its outer graph containing the input gradients. These are left as-is. 

739 2. Accumulators captured from the forward-pass graph. These should have been 

740 added as `while_op` outputs after the gradient graph was built. We replace 

741 these with the corresponding output of `while_op`, i.e. a tensor in 

742 `body_graph.outer_graph`. In the case of nested control flow or functions, 

743 the gradient logic handling `body_grad_graph.outer_graph` will make sure 

744 the tensor from `body_graph.outer_graph` is also correctly captured. 

745 

746 Args: 

747 body_graph: FuncGraph. The forward-pass body function. 

748 body_grad_graph: FuncGraph. The body gradients function. 

749 while_op: The forward-pass While Operation calling `body_graph`. 

750 

751 Returns: 

752 A list of input tensors to be passed as the captured inputs to 

753 `body_grad_graph`. 

754 """ 

755 new_capture_inputs = [] 

756 for t in body_grad_graph.external_captures: 

757 # Resolve tensors captured from the forward graph to the outputs of the 

758 # forward while_op.

759 if t.graph == body_graph: 

760 # Captured accumulator or loop invariant. 

761 for i, output in enumerate(t.graph.outputs): 

762 if output is t: 

763 t = while_op.outputs[i] 

764 break 

765 

766 # Note: We rely on the capturing logic of the gradient While op graph to 

767 # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2

768 # and while_v2 handle this while building their gradient functions.

769 assert t.graph == body_graph.outer_graph 

770 

771 new_capture_inputs.append(t) 

772 return new_capture_inputs 

773 

774 

775def _get_structured_grad_output(outputs, grads, body_grad_graph): 

776 """Returns the values that should be returned from the while grad function. 

777 

778 Args: 

779 outputs: the raw Tensor outputs of the grad While op. 

780 grads: the input gradients to the gradient function. 

781 body_grad_graph: _WhileBodyGradFuncGraph. 

782 

783 Returns: 

784 A list of gradient values. May include Nones. 

785 """ 

786 result = [] 

787 # outputs[0] is the loop counter. 

788 # outputs[1] is maximum_iterations. 

789 # outputs[2] is the total number of loop iterations. 

790 outputs_idx = 3 

791 structured_outputs_idx = 3 

792 for g in grads: 

793 # Set None as the output gradient for tensors with None input gradient. 

794 if g is None: 

795 result.append(None) 

796 continue 

797 output = body_grad_graph.structured_outputs[structured_outputs_idx] 

798 structured_outputs_idx += 1 

799 if isinstance(output, indexed_slices.IndexedSlices): 

800 # TODO(skyewm): is there a more robust way to determine the order of 

801 # flattened IndexedSlices components? 

802 result.append(indexed_slices.IndexedSlices( 

803 values=outputs[outputs_idx], 

804 indices=outputs[outputs_idx + 1], 

805 dense_shape=outputs[outputs_idx + 2])) 

806 outputs_idx += 3 

807 else: 

808 assert isinstance(output, ops.Tensor) 

809 result.append(outputs[outputs_idx]) 

810 outputs_idx += 1 

811 

812 return result 

813 

814 

815def _get_accumulator(tensor): 

816 r"""Returns TensorList if any containing accumulated values of tensor. 

817 

818 We try to find a pattern of the form: 

819 

820      input_tl   tensor

821         \        /

822     (TensorListPushBack)

823             |

824         output_tl

825 

826 which satisfies the following conditions: 

827 

828 1. input_tl must be in tensor.graph.inputs. 

829 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs. 

830 3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

831 

832 output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is 

833 returned if such a pattern is found else None is returned. 

834 

835 Args: 

836 tensor: The Tensor to be accumulated. 

837 

838 Returns: 

839 A variant tensor in the same graph as `tensor` or None if no accumulator is 

840 found. 

841 """ 

842 assert isinstance(tensor.graph, func_graph_module.FuncGraph) 

843 

844 def get_func_graph_output(t): 

845 """Returns t or Identity(t) whichever exists in graph outputs else None.""" 

846 for output in tensor.graph.outputs: 

847 if output is t: 

848 return t 

849 # tf.defun adds an Identity for each output, check whether that is the case. 

850 identity_op = t.consumers()[0] 

851 if (identity_op.type == "Identity" and 

852 any(identity_op.outputs[0] is t for t in tensor.graph.outputs)): 

853 return identity_op.outputs[0] 

854 return None 

855 

856 for consumer in tensor.consumers(): 

857 # Find the consumer that is a TensorListPushBack node whose TensorList input 

858 # is in the list of function inputs.

859 if consumer.type != "TensorListPushBack": 

860 continue 

861 

862 accum_input_idx = -1 

863 for accum_input_idx, inp in enumerate(tensor.graph.inputs): 

864 if inp is consumer.inputs[0]: 

865 break 

866 else: 

867 continue 

868 

869 output = get_func_graph_output(consumer.outputs[0]) 

870 if output is None: 

871 # The TensorList output of `consumer` is not in the list of function 

872 # outputs.

873 continue 

874 

875 for accum_output_idx, out in enumerate(tensor.graph.outputs): 

876 if out is output: 

877 if accum_input_idx == accum_output_idx: 

878 return output 

879 break 

880 

881 return None 

882 

883 
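# --- Editor's note: illustrative sketch, not part of the original module.
# It shows, in eager mode, the TensorListPushBack / TensorListPopBack data
# flow that _get_accumulator searches for in a forward FuncGraph; it does not
# reproduce the input/output index matching that the real check requires.
# The helper name `_example_accumulator_pattern_sketch` is hypothetical and
# `import tensorflow as tf` is assumed to be available.
def _example_accumulator_pattern_sketch():
  """Pushes one per-iteration value onto a TensorList and pops it back."""
  import tensorflow as tf

  value = tf.constant([1.0, 2.0])
  # Forward pass: accumulate `value` (one push per loop iteration).
  acc = list_ops.empty_tensor_list(
      element_dtype=value.dtype,
      element_shape=value.shape,
      max_num_elements=-1)  # -1 means unbounded, as in the module above.
  acc = list_ops.tensor_list_push_back(acc, value)
  # Backward pass: pop the accumulated value back off the list.
  acc, popped = list_ops.tensor_list_pop_back(acc, element_dtype=value.dtype)
  return popped  # Expected: [1.0, 2.0].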

884OptimizedReductionOpsCacheKey = collections.namedtuple( 

885 "OptimizedReductionOpsCacheKey", [ 

886 "op_type", 

887 "inputs", 

888 "dtypes", 

889 "input_types", 

890 "name", 

891 "attrs", 

892 "op_def", 

893 "compute_device", 

894 ]) 

895 

896 

897class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): 

898 """FuncGraph for the gradient function of the body of a While op. 

899 

900 Contains the logic for capturing the tensors from the body of the forward 

901 While op which is as follows: 

902 1. If the tensor is of resource type (these are not accumulated): 

903 a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop 

904 inputs and outputs at the same index. 

905 b. Lookup the corresponding resource tensor in the forward outer graph and 

906 try to capture that. 

907 2. If the tensor is not of resource type: 

908 a. Create an accumulator for that tensor and output it from the forward 

909 pass. Note this also requires adding it as an input to the forward pass. 

910 b. Capture the accumulator from the forward pass in this FuncGraph. This 

911 will later be resolved to the correct output of the forward While op. 

912 c. Pop a value from the captured placeholder and use it as the captured 

913 value for the forward pass tensor. 

914 

915 This only allows capturing tensors in the forward graph. A ValueError is 

916 raised if an attempt is made to capture a tensor not in the forward graph. 

917 To manually capture a tensor that is not in the forward graph, call `capture` 

918 with `allowlisted=True`. 

919 

920 Note: The `captures` dict does not contain the forward tensor since it is not 

921 directly captured. It contains the accumulator corresponding to this forward 

922 tensor. 

923 

924 Attributes: 

925 while_op_needs_rewrite: True if any non-resource intermediates were 

926 captured, meaning the forward While op needs to be rewritten to output the 

927 corresponding accumulators. 

928 extra_inputs: list of EmptyTensorList tensors to be used as initial input to 

929 the new accumulators in the forward graph. It may also contain external 

930 captures of the custom gradient function. 

931 internal_capture_to_output: dict from a tensor_id(captured placeholder) to 

932 the corresponding tensor that needs to be added to the list of outputs. 

933 For instance, when capturing an accumulator TensorList this contains the 

934 TensorList obtained after popping a tensor from the list. Other entries 

935 in this dict are expected, though not enforced, to be identities. 

936 This dict is needed because these output tensors need to be added to 

937 FuncGraph.outputs "after" the tensors returned from the gradient function. 

938 """ 

939 

940 def __init__(self, name, forward_cond_graph, forward_body_graph, 

941 maximum_iterations, forward_while_op, body_graph_inputs, 

942 body_graph_outputs): 

943 super(_WhileBodyGradFuncGraph, self).__init__(name) 

944 self.extra_inputs = [] 

945 self.internal_capture_to_output = {} 

946 # FuncGraph for the body of the forward While op. 

947 self._forward_graph = forward_body_graph 

948 # FuncGraph for the cond of the forward While op. 

949 self._forward_cond_graph = forward_cond_graph 

950 self._maximum_iterations = maximum_iterations 

951 self._forward_while_op = forward_while_op 

952 # Dict from forward intermediate tensor to its indirectly captured tensor 

953 # in this graph. Indirect capturing happens in two ways:

954 # 1. For non-resource tensors we capture their accumulators from the forward

955 # outer graph and pop values from that accumulator inside this graph

956 # using TensorListPopBack.

957 # 2. For resource tensors we directly capture their corresponding tensor

958 # in the forward outer graph.

959 self._indirect_captures = {} 

960 

961 @property 

962 def while_op_needs_rewrite(self): 

963 return self.extra_inputs 

964 

965 def _create_op_internal( 

966 self, 

967 op_type, 

968 inputs, 

969 dtypes=None, # pylint: disable=redefined-outer-name 

970 input_types=None, 

971 name=None, 

972 attrs=None, 

973 op_def=None, 

974 compute_device=True): 

975 # For a reduction op, if op is in the gradient body graph and its input is 

976 # from the forward graph, moving op to the forward graph means we would

977 # store the tensor after the reduction as opposed to the tensor before

978 # reduction, and therefore could significantly reduce memory consumption.

979 # For now, we do this only for a few ops.

980 #

981 # We don't do this if any input tensor has already been accumulated. This

982 # can happen if we output all intermediates in the forward pass.

983 #

984 # If in XLA context, do not move constant ops to forward pass as pushing to

985 # and popping from a TensorList removes the constant property of an op and

986 # breaks XLA compilation, which requires certain inputs to be compile-time

987 # constant for certain ops.

988 #

989 # This optimization is currently also disabled when under a persistent tape,

990 # since it leads to an unbounded number of side outputs. With caching it may

991 # be possible to re-enable it.

992 optimized_reduction_ops = { 

993 "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength" 

994 } 

995 if (op_type in optimized_reduction_ops and 

996 not util.output_all_intermediates() and 

997 all(input.graph is self._forward_graph for input in inputs) and 

998 all(_get_accumulator(input) is None for input in inputs) and 

999 not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and 

1000 not util.graph_wrapped_for_higher_order_tape_gradients( 

1001 self._forward_graph)): 

1002 return self._move_op_to_forward_graph( 

1003 op_type, 

1004 inputs, 

1005 dtypes=dtypes, 

1006 input_types=input_types, 

1007 name=name, 

1008 attrs=attrs, 

1009 op_def=op_def, 

1010 compute_device=compute_device) 

1011 

1012 return super(_WhileBodyGradFuncGraph, self)._create_op_internal( 

1013 op_type, 

1014 inputs, 

1015 dtypes=dtypes, 

1016 input_types=input_types, 

1017 name=name, 

1018 attrs=attrs, 

1019 op_def=op_def, 

1020 compute_device=compute_device) 

1021 

1022 def _move_op_to_forward_graph( 

1023 self, 

1024 op_type, 

1025 inputs, 

1026 dtypes=None, # pylint: disable=redefined-outer-name 

1027 input_types=None, 

1028 name=None, 

1029 attrs=None, 

1030 op_def=None, 

1031 compute_device=True): 

1032 # We have a cache of reduction ops that have already been moved to the 

1033 # forward graph, and we will check it first to avoid moving an op twice.

1034 if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"): 

1035 self._forward_graph._optimized_reduction_ops_cache = {} 

1036 cache_key = self._get_optimized_reduction_ops_cache_key( 

1037 op_type, inputs, dtypes, input_types, name, attrs, op_def, 

1038 compute_device) 

1039 cached_op = self._forward_graph._optimized_reduction_ops_cache.get( 

1040 cache_key) 

1041 if cached_op is not None: 

1042 # This op has already been moved to the forward graph and we have it in 

1043 # the cache.

1044 return cached_op 

1045 

1046 with self._forward_graph.as_default(): 

1047 # `name` was built using name_scope stack of gradient graph and may not 

1048 # be unique in the forward graph. `Graph.create_op` does not uniquify

1049 # names which are name scopes i.e. end in `/`. To ensure that the op

1050 # created gets a unique name in the forward graph we get rid of the

1051 # trailing slash.

1052 name = ops.name_from_scope_name(name) 

1053 result = self._forward_graph._create_op_internal( 

1054 op_type, 

1055 inputs, 

1056 dtypes=dtypes, 

1057 input_types=input_types, 

1058 name=name, 

1059 attrs=attrs, 

1060 op_def=op_def, 

1061 compute_device=compute_device) 

1062 

1063 # Store the op we just moved to the forward graph so that it does 

1064 # not need to be added there again.

1065 self._forward_graph._optimized_reduction_ops_cache[cache_key] = result 

1066 return result 

1067 

1068 def _get_optimized_reduction_ops_cache_key( 

1069 self, 

1070 op_type, 

1071 inputs, 

1072 dtypes=None, # pylint: disable=redefined-outer-name 

1073 input_types=None, 

1074 name=None, 

1075 attrs=None, 

1076 op_def=None, 

1077 compute_device=True): 

1078 # We need all elements of CacheKey to be hashable. 

1079 inputs = tuple(map(lambda t: t.ref(), inputs)) 

1080 

1081 if dtypes is not None: 

1082 dtypes = tuple(dtypes) 

1083 

1084 if input_types is not None: 

1085 input_types = tuple(input_types) 

1086 

1087 if attrs is not None: 

1088 hashable_attrs = [] 

1089 for attr_name, attr_value in sorted(attrs.items()): 

1090 hashable_attrs.append((attr_name, attr_value.SerializeToString())) 

1091 attrs = tuple(hashable_attrs) 

1092 

1093 if op_def is not None: 

1094 op_def = op_def.SerializeToString() 

1095 

1096 return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types, 

1097 name, attrs, op_def, compute_device) 

1098 

1099 def _capture_helper(self, tensor, name): 

1100 """Implements the capturing described in the class docstring.""" 

1101 captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor)) 

1102 if captured_tensor is not None: 

1103 return captured_tensor 

1104 

1105 if tensor.graph is not self._forward_graph: 

1106 already_captured = id(tensor) in self.function_captures.by_val_internal 

1107 captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper( 

1108 tensor, name) 

1109 if not already_captured: 

1110 # Adds the captured tensor to the list of outputs so that the input 

1111 # and output signatures match.

1112 self.internal_capture_to_output[ops.tensor_id( 

1113 captured_tensor)] = captured_tensor 

1114 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor 

1115 return captured_tensor 

1116 

1117 while tensor.op.type == "Identity": 

1118 # We do not accumulate the output of identity nodes so we try to capture 

1119 # the input of the Identity node instead.

1120 tensor = tensor.op.inputs[0] 

1121 

1122 captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor)) 

1123 if captured_tensor is not None: 

1124 return captured_tensor 

1125 

1126 # No need to accumulate loop invariants. Capture them directly. 

1127 # The captured tensor gets resolved to the corresponding while output in 

1128 # `_resolve_grad_captures`.

1129 if _is_loop_invariant(tensor, self._forward_graph.inputs, 

1130 self._forward_graph.outputs): 

1131 captured_tensor = super(_WhileBodyGradFuncGraph, 

1132 self)._capture_helper(tensor, name) 

1133 # Add to `internal_capture_to_output` so that this gets added to the list 

1134 # of outputs.

1135 self.internal_capture_to_output[ops.tensor_id( 

1136 captured_tensor)] = captured_tensor 

1137 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor 

1138 return captured_tensor 

1139 

1140 # Do not accumulate Const nodes. Instead copy them directly in the backward 

1141 # graph.

1142 # TODO(srbs): This just checks for `Const` nodes. Consider checking for

1143 # graph compile time consts in general.

1144 # TODO(srbs): Consider making this a loop input. 

1145 if constant_op.is_constant(tensor): 

1146 real_value = constant_op.constant( 

1147 tensor_util.constant_value(tensor), dtype=tensor.dtype) 

1148 self._indirect_captures[ops.tensor_id(tensor)] = real_value 

1149 return real_value 

1150 

1151 # Resource tensors are not accumulated and handled specially. 

1152 if tensor.dtype == dtypes.resource: 

1153 return self._resource_capture_helper(tensor) 

1154 

1155 # Create or find an existing accumulator output for `tensor` in the forward 

1156 # graph, and fetch from this accumulator in the gradient graph to get the

1157 # raw intermediate value.

1158 accumulator = _get_accumulator(tensor) 

1159 if accumulator is None: 

1160 # Create the initial empty tensor list. 

1161 # 

1162 # Note: We clear the control dependencies to avoid a cycle in case a 

1163 # control tensor has an input path to an output of the forward While. 

1164 # 

1165 # E.g.: 

1166 # x = tf.while_loop(...) 

1167 # y = f(x) 

1168 # with tf.control_dependencies([y]): 

1169 # tf.gradients(y, x) 

1170 # 

1171 # Since the EmptyTensorList is fed back into the forward While, not 

1172 # removing the control edge would cause a cycle.

1173 with self._forward_graph.outer_graph.as_default(): 

1174 with util.clear_control_inputs(): 

1175 tensor_list = list_ops.empty_tensor_list( 

1176 element_dtype=tensor.dtype, 

1177 element_shape=tensor.shape, 

1178 max_num_elements=self._maximum_iterations, 

1179 name=_build_accumulator_name(tensor)) 

1180 self.extra_inputs.append(tensor_list) 

1181 

1182 # Push the intermediate tensor to the tensor list. This captures 

1183 # `tensor_list`.

1184 with self._forward_graph.as_default(): 

1185 accumulator = list_ops.tensor_list_push_back(tensor_list, tensor) 

1186 # Add the modified tensor list to the list of outputs. This output will be 

1187 # all the accumulated values.

1188 self._forward_graph.outputs.append(accumulator) 

1189 

1190 # Capture in the cond graph as well so the forward cond and body inputs 

1191 # match.

1192 with self._forward_cond_graph.as_default(): 

1193 self._forward_cond_graph.capture(tensor_list) 

1194 

1195 # Capture the accumulator tensor list in the gradient graph directly from 

1196 # the forward graph -- we'll later modify this to capture the final list

1197 # output by the forward While op instead.

1198 captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper( 

1199 accumulator, name) 

1200 

1201 # Pop the intermediate value from the tensor list in the gradient graph. 

1202 new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back( 

1203 captured_accumulator, element_dtype=tensor.dtype) 

1204 

1205 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor 

1206 self.internal_capture_to_output[ops.tensor_id( 

1207 captured_accumulator)] = new_tensor_list 

1208 return captured_tensor 

1209 

1210 def _resource_capture_helper(self, tensor): 

1211 """Returns the captured resource tensor. 

1212 

1213 Resource-type tensors are not accumulated. If a resource tensor exists in 

1214 the loop body it must either be a loop input or an output of a nested While 

1215 op inside the loop body which had captured the external resource. 

1216 

1217 Args: 

1218 tensor: the external resource Tensor to be captured. 

1219 

1220 Returns: 

1221 Tensor in this graph. 

1222 """ 

1223 assert tensor.dtype == dtypes.resource 

1224 

1225 forward_graph_input_names = [t.name for t in self._forward_graph.inputs] 

1226 forward_graph_name_to_opdef = { 

1227 op.name: op.node_def for op in self._forward_graph.get_operations()} 

1228 index = util.resource_input_index( 

1229 tensor.name, forward_graph_input_names, 

1230 forward_graph_name_to_opdef, 

1231 self._forward_graph._functions) 

1232 

1233 input_placeholder = self._forward_graph.inputs[index] 

1234 tensor_in_outer_graph = self._forward_graph._while.inputs[index] 

1235 

1236 assert input_placeholder.dtype == dtypes.resource 

1237 assert tensor_in_outer_graph.dtype == dtypes.resource 

1238 # This must be a loop invariant. However, infrastructure 

1239 # (e.g. tf.vectorized_map) may insert identity nodes, function calls, conds,

1240 # etc. which take and return the resource tensor unmodified; this means that

1241 # the Python objects may differ.

1242 if index != util.resource_input_index( 

1243 self._forward_graph.outputs[index].name, forward_graph_input_names, 

1244 forward_graph_name_to_opdef, 

1245 self._forward_graph._functions): 

1246 raise AssertionError( 

1247 f"Resource tensors must be loop invariants {tensor_in_outer_graph}") 

1248 

1249 self._indirect_captures[ops.tensor_id(tensor)] = self.capture( 

1250 tensor_in_outer_graph) 

1251 return self._indirect_captures[ops.tensor_id(tensor)] 

1252 

1253 

1254def _check_shapes_compat(flat_output_tensors, flat_shape_invariants, 

1255 flat_input_tensors): 

1256 for (t, shape, input_t) in zip(flat_output_tensors, flat_shape_invariants, 

1257 flat_input_tensors): 

1258 if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape): 

1259 raise ValueError( 

1260 f"Input tensor `{input_t.name}` enters the loop with shape {shape}, " 

1261 f"but has shape {t.shape} after one iteration. To allow the shape to " 

1262 "vary across iterations, use the `shape_invariants` argument of " 

1263 "tf.while_loop to specify a less-specific shape.") 

1264 

1265 
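# --- Editor's note: illustrative sketch, not part of the original module.
# It assumes a standard TensorFlow 2.x installation; the helper name
# `_example_shape_invariants_sketch` is hypothetical. It shows the behavior
# enforced by _check_shapes_compat above: a loop variable whose shape changes
# per iteration needs a less-specific shape invariant.
def _example_shape_invariants_sketch():
  """Grows a vector inside the loop, declaring its length as unknown."""
  import tensorflow as tf

  @tf.function
  def grow():
    return tf.while_loop(
        cond=lambda i, v: i < 3,
        body=lambda i, v: (i + 1, tf.concat([v, v], axis=0)),
        loop_vars=(tf.constant(0), tf.constant([1.0])),
        # Without this invariant the loop fails, because the vector's shape
        # changes from [1] to [2] after one iteration.
        shape_invariants=(tf.TensorShape([]), tf.TensorShape([None])))

  return grow()  # The final vector has 2**3 = 8 elements.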

1266def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars): 

1267 """Checks the number of inputs/outputs of `cond_graph` and `body_graph`.""" 

1268 assert len(cond_graph.inputs) == num_flattened_loop_vars, ( 

1269 "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs), 

1270 num_flattened_loop_vars)) 

1271 assert len(cond_graph.outputs) == 1, ( 

1272 "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs)) 

1273 assert len(body_graph.inputs) == num_flattened_loop_vars, ( 

1274 "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs), 

1275 num_flattened_loop_vars)) 

1276 assert len(body_graph.outputs) == num_flattened_loop_vars, ( 

1277 "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs), 

1278 num_flattened_loop_vars)) 

1279 

1280 

1281def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars): 

1282 for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs, 

1283 flattened_loop_vars): 

1284 if inp.dtype != out.dtype: 

1285 raise TypeError( 

1286 f"Loop var {loop_var.name} enters the loop with type {inp.dtype} " 

1287 f"but has type {out.dtype} after 1 iteration. {loop_var.name} type " 

1288 "should remain constant.") 

1289 

1290 

1291def _build_cond_placeholders_name_prefix(cond_graph): 

1292 return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder") 

1293 

1294 

1295def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures): 

1296 """Creates placeholders for body captures in cond_graph. 

1297 

1298 This is needed to match signatures of cond and body graphs. 

1299 

1300 Args: 

1301 cond_graph: cond branch graph 

1302 body_graph_captures: Tensors which were captured when building the 

1303 `body_graph`. 

1304 """ 

1305 types = [t.dtype.as_datatype_enum for t in body_graph_captures] 

1306 # TODO(srbs): Providing a unique prefix does not ensure that there is no 

1307 # conflict between the placeholder names and existing nodes in the graph.

1308 # However passing a list of strings may not be performant.

1309 # Ideally we should move `Graph.unique_name` to C++ or make

1310 # `Graph._names_in_use` a trie so that we can find a unique prefix.

1311 # TODO(b/143286622): This should not be required once captures are separated

1312 # from regular loop vars.

1313 with cond_graph._c_graph.get() as c_graph: 

1314 placeholders = c_api.TF_CreatePlaceholders( 

1315 c_graph, types, 

1316 compat.as_str(_build_cond_placeholders_name_prefix(cond_graph))) 

1317 placeholder_ops = [ 

1318 ops.Operation._from_c_op(ph.oper, cond_graph) for ph in placeholders 

1319 ] 

1320 

1321 tensors = [] 

1322 for op in placeholder_ops: 

1323 tensors.append(op.outputs[0]) 

1324 

1325 # Update `cond_graph._captures` and `cond_graph.inputs` to contain the 

1326 # newly created placeholders.

1327 tuples = zip(body_graph_captures, tensors) 

1328 keys = [id(t) for t in body_graph_captures] 

1329 for k, v in zip(keys, tuples): 

1330 cond_graph._function_captures.add_or_replace( 

1331 key=k, 

1332 external=v[0], 

1333 internal=v[1], 

1334 is_by_ref=False) 

1335 cond_graph.inputs.extend(tensors) 

1336 

1337 

1338def _copy_handle_data(src_tensors, tgt_tensors): 

1339 for src_t, tgt_t in zip(src_tensors, tgt_tensors): 

1340 handle_data_util.copy_handle_data(src_t, tgt_t) 

1341 

1342 

1343def _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, loop_vars): 

1344 """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays.""" 

1345 

1346 def flow_to_tensor_array(flow, ta): # pylint: disable=missing-docstring 

1347 return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance( # pylint: disable=g-long-ternary 

1348 ta, tensor_array_ops.TensorArray) else flow) 

1349 

1350 flattened_loop_vars = [ 

1351 flow_to_tensor_array(*z) 

1352 for z in zip(nest.flatten(loop_vars, expand_composites=True), 

1353 flat_orig_loop_vars) 

1354 ] 

1355 return nest.pack_sequence_as(loop_vars_signature, flattened_loop_vars, 

1356 expand_composites=True) 

1357 

1358 

1359def _tensor_array_to_flow(loop_vars): 

1360 

1361 def f(maybe_ta): 

1362 if isinstance(maybe_ta, tensor_array_ops.TensorArray): 

1363 return maybe_ta.flow 

1364 return maybe_ta 

1365 

1366 return nest.map_structure(f, loop_vars, expand_composites=True) 

1367 

1368 
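# --- Editor's note: illustrative sketch, not part of the original module.
# It assumes a standard TensorFlow 2.x installation; the helper name
# `_example_tensor_array_flow_sketch` is hypothetical. It shows the
# TensorArray <-> flow round trip performed by _tensor_array_to_flow and
# _pack_sequence_as above, traced inside tf.function so the flow tensor is
# meaningful.
def _example_tensor_array_flow_sketch():
  """Converts a TensorArray to its flow tensor and rebuilds it."""
  import tensorflow as tf

  @tf.function
  def round_trip():
    ta = tf.TensorArray(tf.float32, size=2).write(0, 1.0).write(1, 2.0)
    flow = ta.flow  # The tensor that is threaded through the loop variables.
    rebuilt = tensor_array_ops.build_ta_with_new_flow(ta, flow)
    return rebuilt.stack()

  return round_trip()  # Expected: [1.0, 2.0].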

1369def _build_maximum_iterations_loop_var(maximum_iterations): 

1370 if maximum_iterations is None: 

1371 # Default value for max_num_elements to EmptyTensorList meaning that the 

1372 # list size is unbounded.

1373 maximum_iterations = -1 

1374 # EmptyTensorList expects `max_num_elements` to be of type int32. 

1375 return ops.convert_to_tensor( 

1376 maximum_iterations, dtype=dtypes.int32, name="maximum_iterations") 

1377 

1378 

1379def _build_accumulator_name(tensor): 

1380 # Tensor name may be of the form "pow/y:0". Name scope does not allow ":". 

1381 return "{}/accumulator".format(tensor.name).replace(":", "_") 

1382 

1383 

1384def _is_loop_invariant(tensor, inputs, outputs): 

1385 return (any(tensor is t for t in inputs) and 

1386 any(tensor is t for t in outputs)) 

1387 

1388 

1389def _set_read_only_resource_inputs_attr(op, branch_graphs): 

1390 """Sets the list of resource inputs which are read-only. 

1391 

1392 This is used by AutomaticControlDependencies. 

1393 

1394 Args: 

1395 op: While Operation. 

1396 branch_graphs: List of branch FuncGraphs. 

1397 """ 

1398 read_only_indices = set(range(len(op.inputs))) 

1399 for branch_graph in branch_graphs: 

1400 if not read_only_indices: 

1401 break 

1402 branch_read_only_indices = acd.get_read_only_resource_input_indices_graph( 

1403 branch_graph) 

1404 read_only_indices = read_only_indices.intersection(branch_read_only_indices) 

1405 

1406 ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, 

1407 sorted(read_only_indices)) 

1408 

1409# pylint: enable=protected-access