Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py: 14%

495 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""cond_v2 and gradient. 

16 

17This is a version of cond that emits a single If op, as well as the gradient 

18function for If ops produced by cond_v2. This will eventually replace the 

19current tf.cond implementation once it reaches feature and performance parity. 

20""" 

21 

22import collections 

23 

24from tensorflow.core.framework import types_pb2 

25from tensorflow.python.eager import backprop_util 

26from tensorflow.python.framework import auto_control_deps 

27from tensorflow.python.framework import auto_control_deps_utils as acd 

28from tensorflow.python.framework import constant_op 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import errors_impl 

31from tensorflow.python.framework import func_graph as func_graph_module 

32from tensorflow.python.framework import indexed_slices 

33from tensorflow.python.framework import ops 

34from tensorflow.python.framework import tensor_shape 

35from tensorflow.python.framework import tensor_util 

36from tensorflow.python.framework import type_spec 

37from tensorflow.python.ops import array_ops 

38from tensorflow.python.ops import control_flow_util 

39from tensorflow.python.ops import control_flow_util_v2 as util 

40from tensorflow.python.ops import default_gradient 

41from tensorflow.python.ops import gen_functional_ops 

42from tensorflow.python.ops import gen_optional_ops 

43from tensorflow.python.ops import gradients_util 

44from tensorflow.python.ops import handle_data_util 

45from tensorflow.python.ops import math_ops 

46from tensorflow.python.util import nest 

47 

48 

49# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify 

50# that they aren't part of the official public API. These protected members 

51# often need to be used by implementation code however. Rather than litter the 

52# code with pylint comments, we ignore protected access violations for 

53# readability. 

54# pylint: disable=protected-access 

55 

56_COND = 1 

57_CASE = 2 

58 

59 

60def cond_v2(pred, true_fn, false_fn, name="cond"): 

61 """Like tf.cond, except emits a single If op.""" 

62 if isinstance(pred, bool): 

63 raise TypeError("pred must not be a Python bool", pred) 

64 

65 if not name: 

66 name = "cond" 

67 

68 with ops.name_scope(name) as scope: 

69 true_name = util.unique_fn_name(scope, "true") 

70 false_name = util.unique_fn_name(scope, "false") 

71 

72 # Automatic control dependencies are added in defuns, but not in v1 

73 # graphs. Propagate that behavior here. 

74 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 

75 pred = ops.convert_to_tensor(pred) 

76 if (tensor_util.is_tf_type(pred) and 

77 (pred.shape.dims is None or pred.shape.dims)): 

78 pred = array_ops.squeeze_v2(pred) 

79 

80 true_graph = func_graph_module.func_graph_from_py_func( 

81 true_name, 

82 true_fn, [], {}, 

83 func_graph=util.CondBranchFuncGraph( 

84 true_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 

85 add_control_dependencies=add_control_dependencies, 

86 op_return_value=pred) 

87 false_graph = func_graph_module.func_graph_from_py_func( 

88 false_name, 

89 false_fn, [], {}, 

90 func_graph=util.CondBranchFuncGraph( 

91 false_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 

92 add_control_dependencies=add_control_dependencies, 

93 op_return_value=pred) 

94 

95 verify_captures(_COND, [true_graph, false_graph]) 

96 return _build_cond( 

97 pred, 

98 true_graph, 

99 false_graph, 

100 true_graph.external_captures, 

101 false_graph.external_captures, 

102 building_gradient=False, 

103 name=scope) 

104 

105 

106@ops.RegisterGradient("StatelessIf") 

107@ops.RegisterGradient("If") 

108def _IfGrad(op, *grads): # pylint: disable=invalid-name 

109 """The gradient of an If op produced by cond_v2.""" 

110 # Get the if operator (this logic handles the case where op is a MockOp) 

111 if_op = op.outputs[0].op 

112 true_graph, false_graph = get_func_graphs(if_op) 

113 # Note: op.graph != ops.get_default_graph() when we are computing the gradient 

114 # of a nested cond. 

115 assert true_graph.outer_graph == if_op.graph 

116 assert false_graph.outer_graph == if_op.graph 

117 

118 # Create grad functions that compute the gradient of the true/false forward 

119 # graphs. These functions will capture tensors from the forward pass 

120 # functions. 

121 true_grad_graph = _create_grad_func( 

122 true_graph, grads, util.unique_grad_fn_name(true_graph.name)) 

123 false_grad_graph = _create_grad_func( 

124 false_graph, grads, util.unique_grad_fn_name(false_graph.name)) 

125 

126 # Replaces output None grads with zeros if at least one branch has non-None 

127 # grad at that index. 

128 _create_zeros_for_none_grads([true_graph, false_graph], 

129 [true_grad_graph, false_grad_graph]) 

130 

131 if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite): 

132 # Modify 'op' to output the intermediates needed by the grad functions. Note 

133 # that all needed intermediates are wrapped in optionals. Each optional 

134 # intermediate output will have a value iff its corresponding branch is 

135 # taken. 

136 # NOTE(skyewm): if there are any active sessions, this modification to `op` 

137 # may make them unrunnable! 

138 

139 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 

140 # XLA does not yet support optionals, so output intermediates directly and 

141 # make them match via FakeParams, which can be converted to zeros in XLA. 

142 # TODO(skyewm,jpienaar): can XLA support optionals? 

143 true_intermediates = true_grad_graph.xla_intermediates 

144 false_intermediates = false_grad_graph.xla_intermediates 

145 extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla( 

146 [true_graph, false_graph], [true_intermediates, false_intermediates]) 

147 else: 

148 true_intermediates = true_grad_graph.wrapped_intermediates 

149 false_intermediates = false_grad_graph.wrapped_intermediates 

150 # Make outputs match by adding none optionals. 

151 extra_true_outputs, extra_false_outputs = _make_intermediates_match( 

152 [true_graph, false_graph], [true_intermediates, false_intermediates]) 

153 

154 true_graph.outputs.extend(extra_true_outputs) 

155 false_graph.outputs.extend(extra_false_outputs) 

156 # TODO(skyewm): indicate it's an internal bug if this fails. 

157 _check_same_outputs(_COND, [true_graph, false_graph]) 

158 

159 true_graph.name += "_rewritten" 

160 false_graph.name += "_rewritten" 

161 

162 if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph)) 

163 if_op._set_func_attr("else_branch", 

164 util.create_new_tf_function(false_graph)) 

165 if_op._set_type_list_attr("Tout", true_graph.output_types) 

166 if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes) 

167 if_op._add_outputs( 

168 [t.dtype for t in extra_true_outputs], 

169 [t.shape for t in extra_true_outputs]) 

170 

171 # Resolve references to forward graph tensors in grad graphs and ensure 

172 # they are in-scope, i.e., belong to one of outer graphs of the grad graph. 

173 true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) 

174 false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) 

175 

176 # This modifies true_grad_graph and false_grad_graph. 

177 _make_output_composite_tensors_match(_COND, 

178 [true_grad_graph, false_grad_graph]) 

179 

180 outputs = _build_cond( 

181 if_op.inputs[0], 

182 true_grad_graph, 

183 false_grad_graph, 

184 true_grad_inputs, 

185 false_grad_inputs, 

186 building_gradient=True, 

187 ) 

188 

189 # The predicate has no gradient. 

190 return [None] + outputs 
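# A rough sketch of when this gradient fires (assuming standard TF2 APIs;
# names and values are illustrative). Differentiating through a traced
# tf.cond backpropagates through the If/StatelessIf op via this function:
#
#   x = tf.constant(3.0)
#
#   @tf.function
#   def f(x):
#     return tf.cond(x > 0.0, lambda: x * x, lambda: -x)
#
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = f(x)
#   tape.gradient(y, x)  # => 6.0, the gradient of the taken (true) branch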

191 

192 

193def _build_cond(pred, 

194 true_graph, 

195 false_graph, 

196 true_inputs, 

197 false_inputs, 

198 building_gradient, 

199 name=None): 

200 """Creates an If op from the specified predicate, branch functions and inputs. 

201 

202 Note that this modifies true_graph and false_graph to make the inputs match, 

203 and to output all intermediate values so they're available for the gradient

204 computation. 

205 

206 true_graph and false_graph need not have the same input types, but they must 

207 have the same output types. 

208 

209 Args: 

210 pred: boolean Tensor 

211 true_graph: FuncGraph 

212 false_graph: FuncGraph 

213 true_inputs: a list of Tensors to be passed to true_graph as input. 

214 false_inputs: a list of Tensors to be passed to false_graph as input. 

215 building_gradient: Whether this is a gradient If op. 

216 name: the name for the If op. 

217 

218 Returns: 

219 A list of Tensors which are the outputs of the If op. Does not include added 

220 intermediate outputs. 

221 """ 

222 _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph]) 

223 _check_same_outputs(_COND, [true_graph, false_graph]) 

224 

225 # Add inputs to true_graph and false_graph to make them match. Note that 

226 # this modifies true_graph and false_graph. 

227 cond_inputs = _make_inputs_match([true_graph, false_graph], 

228 [true_inputs, false_inputs]) 

229 # We do not output intermediates of the gradient If op since this is just 

230 # for backwards compatibility with existing code. 

231 if not building_gradient and util.output_all_intermediates(): 

232 # Add all intermediate tensors as function outputs so they're available for 

233 # the gradient computation. Since the outputs of the two functions must 

234 # match, we wrap all the intermediates in optionals. Each intermediate 

235 # output will have a value iff its corresponding branch is taken. 

236 

237 true_intermediates = _get_intermediates(true_graph) 

238 false_intermediates = _get_intermediates(false_graph) 

239 

240 # Wrap intermediates in optionals. 

241 wrapped_true_intermediates = _wrap_intermediates(true_graph, 

242 true_intermediates) 

243 wrapped_false_intermediates = _wrap_intermediates(false_graph, 

244 false_intermediates) 

245 

246 # Make outputs match by adding none optionals. 

247 extra_true_outputs, extra_false_outputs = _make_intermediates_match( # pylint: disable=unbalanced-tuple-unpacking 

248 [true_graph, false_graph], 

249 [wrapped_true_intermediates, wrapped_false_intermediates]) 

250 

251 true_graph.outputs.extend(extra_true_outputs) 

252 false_graph.outputs.extend(extra_false_outputs) 

253 _check_same_outputs(_COND, [true_graph, false_graph]) 

254 

255 # Create the If op. 

256 with ops.control_dependencies( 

257 list(true_graph.function_captures.control) + list( 

258 false_graph.function_captures.control)): 

259 true_stateful_ops = [ 

260 op for op in true_graph.get_operations() if op._is_stateful 

261 ] 

262 false_stateful_ops = [ 

263 op for op in false_graph.get_operations() if op._is_stateful 

264 ] 

265 if (true_stateful_ops or false_stateful_ops): 

266 op_fn = gen_functional_ops._if 

267 else: 

268 op_fn = gen_functional_ops.stateless_if 

269 

270 def _make_op(inputs): 

271 if_op, tensors = util.get_op_and_outputs(op_fn( 

272 pred, 

273 inputs, [t.dtype for t in true_graph.outputs], 

274 util.create_new_tf_function(true_graph), 

275 util.create_new_tf_function(false_graph), 

276 output_shapes=_get_output_shapes(true_graph.outputs, 

277 false_graph.outputs), 

278 name=name)) 

279 _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) 

280 # `if_op` is None if this is a `StatelessIf` op with no outputs. 

281 if if_op is not None: 

282 # The true and false graphs have already been created, and we need that 

283 # to happen before we know which tensors will be captured and so whether 

284 # to wrap the cond in a tf.function. Post-hoc mutation of the branch 

285 # `outer_graph` properties seems like the only option if we want to 

286 # conditionally wrap in a function. 

287 true_graph.outer_graph = ops.get_default_graph() 

288 false_graph.outer_graph = ops.get_default_graph() 

289 if_op._true_graph = true_graph 

290 if_op._false_graph = false_graph 

291 util.maybe_set_lowering_attr(if_op) 

292 util.maybe_propagate_compile_time_consts_in_xla(if_op) 

293 _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) 

294 # Prevent fetching since the variant outputs can't be fetched directly. 

295 if_op.graph.prevent_fetching(if_op) 

296 return tensors 

297 tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs) 

298 

299 # Return identities for each output of the If op, rather than the output of 

300 # the If op directly. This makes pruning work if the output of cond() is 

301 # fetched: the lowering pass converts the If outputs into IdentityN outputs, 

302 # which if fetched will cause all ops in the taken branch to be run (since 

303 # it takes all merge ops as input). After lowering, each output identity op 

304 # will end up with only the appropriate merge op as input. 

305 # TODO(b/79984175): this doesn't have to be a tuple once we convert to the

306 # correct output structure 

307 tensors = [array_ops.identity(t) for t in tensors] 

308 

309 structured_output_specs = _get_compatible_structured_output_specs(true_graph, 

310 false_graph) 

311 return _pack_sequence_as(structured_output_specs, tensors) 

312 

313 

314def get_func_graphs(op): 

315 """Returns `FuncGraph`s for the input op branches. 

316 

317 Args: 

318 op: The If or Case Operation. 

319 

320 Returns: 

321 A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches 

322 for Case). 

323 """ 

324 

325 def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None): 

326 """Generates and returns a FuncGraph for the given branch.""" 

327 func_graph = None 

328 if cached_attr_name is not None: 

329 func_graph = getattr(op, cached_attr_name, None) 

330 inputs = op.inputs[1:] # First input is pred. 

331 if func_graph is None: 

332 input_shapes = [t.shape for t in inputs] 

333 func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name) 

334 for external_t, internal_t in zip(inputs, func_graph.inputs): 

335 handle_data_util.copy_handle_data(external_t, internal_t) 

336 func_graph.function_captures.reset_captures(inputs, func_graph.inputs) 

337 # Link the op so that the gradient code can use it. 

338 func_graph._forward_cond = op 

339 return func_graph 

340 

341 if op.type in ["If", "StatelessIf"]: 

342 return (_get_func_graph_for_branch( 

343 op.get_attr("then_branch"), "_true_graph"), 

344 _get_func_graph_for_branch( 

345 op.get_attr("else_branch"), "_false_graph")) 

346 elif op.type in ["Case", "StatelessCase"]: 

347 return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i)) 

348 for i, branch_fn in enumerate(op.get_attr("branches"))] 

349 else: 

350 raise ValueError("Unsupported op type: {}".format(op.type)) 

351 

352 

353def _get_compatible_structured_output_specs(true_graph, false_graph): 

354 """Returns the most specific compatible specs of graph structured outputs.""" 

355 return nest.map_structure(_get_compatible_spec, 

356 true_graph.structured_outputs, 

357 false_graph.structured_outputs) 

358 

359 

360def _get_compatible_spec(value_or_spec1, value_or_spec2): 

361 """Returns the most specific compatible spec. 

362 

363 Args: 

364 value_or_spec1: A TypeSpec or a value that has a defined TypeSpec.

365 value_or_spec2: A TypeSpec or a value that has a defined TypeSpec.

366 

367 Returns: 

368 The most specific compatible TypeSpec of the inputs.

369 

370 Raises: 

371 ValueError: If value_or_spec1 is not compatible with value_or_spec2. 

372 """ 

373 spec1 = _get_spec_for(value_or_spec1) 

374 spec2 = _get_spec_for(value_or_spec2) 

375 

376 # pylint: disable=protected-access 

377 common = spec1._without_tensor_names().most_specific_common_supertype( 

378 [spec2._without_tensor_names()]) 

379 if common is None: 

380 raise TypeError(f"No common supertype of {spec1} and {spec2}.") 

381 return common 
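# A small sketch of the supertype computation above (assuming the public
# tf.TensorSpec / TypeSpec.most_specific_common_supertype API; shapes are
# illustrative):
#
#   s1 = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
#   s2 = tf.TensorSpec(shape=[2, 3], dtype=tf.float32)
#   s1.most_specific_common_supertype([s2])
#   # => TensorSpec(shape=(None, 3), dtype=tf.float32)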

382 

383 

384def _get_spec_for(value_or_spec): 

385 """Returns TypeSpec of a value or itself if it is a TypeSpec already.""" 

386 if isinstance(value_or_spec, type_spec.TypeSpec): 

387 return value_or_spec 

388 return type_spec.type_spec_from_value(value_or_spec) 

389 

390 

391def _grad_fn(func_graph, grads): 

392 """The gradient function for each conditional branch. 

393 

394 This function builds the gradient graph of the corresponding forward-pass 

395 conditional branch in `func_graph`. This is done by differentiating 

396 func_graph's outputs w.r.t. its inputs. 

397 

398 Args: 

399 func_graph: FuncGraph. The corresponding forward-pass function. 

400 grads: The list of input gradient Tensors. 

401 

402 Returns: 

403 The output gradient Tensors. 

404 """ 

405 # Filter out untrainable function outputs. 

406 # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes 

407 # cause _GradientsHelper to raise an exception (e.g. the implementation 

408 # doesn't expect 'ys' to contain boolean tensors). 

409 assert len(func_graph.outputs) == len(grads) 

410 ys = [] 

411 grad_ys = [] 

412 for y, grad_y in zip(func_graph.outputs, grads): 

413 if not backprop_util.IsTrainable(y): 

414 continue 

415 ys.append(y) 

416 grad_ys.append(grad_y) 

417 

418 # Build the gradient graph. Note that this builds the gradient computation of 

419 # func_graph in the current graph, which requires capturing tensors from 

420 # func_graph. The captured func_graph tensors are resolved to external tensors 

421 # in _resolve_grad_inputs. 

422 result = gradients_util._GradientsHelper( 

423 ys, func_graph.inputs, grad_ys=grad_ys, 

424 src_graph=func_graph) 

425 

426 return result 

427 

428 

429def _create_grad_func(func_graph, grads, name): 

430 """Returns the FuncGraph representation of _grad_fn.""" 

431 return func_graph_module.func_graph_from_py_func( 

432 name, 

433 lambda: _grad_fn(func_graph, grads), [], {}, 

434 func_graph=_CondGradFuncGraph(name, func_graph)) 

435 

436 

437def _resolve_grad_inputs(cond_graph, grad_graph): 

438 """Returns the tensors to pass as inputs to `grad_graph`. 

439 

440 The `grad_graph` may have external references to 

441 1. Its outer graph containing the input gradients. These references are kept 

442 as is. 

443 2. Tensors in the forward pass graph. These tensors may not be "live" 

444 when the gradient is being computed. We replace such references by their 

445 corresponding tensor in `cond_graph.outer_graph`. In the case of nested 

446 control flow or functions, the gradient logic handling 

447 `grad_graph.outer_graph` will make sure the tensor from 

448 `cond_graph.outer_graph` is also correctly captured. 

449 

450 Args: 

451 cond_graph: FuncGraph. The forward-pass function. 

452 grad_graph: FuncGraph. The gradients function. 

453 

454 Returns: 

455 A list of input tensors to be passed to grad_graph.

456 """ 

457 new_inputs = [] 

458 

459 for t in grad_graph.external_captures: 

460 # `t` must either be in `grad_graph.outer_graph` or in the forward 

461 # `cond_graph`. 

462 if t.graph != grad_graph.outer_graph: 

463 assert t.graph == cond_graph 

464 # `internal_captures` are not treated as intermediates and hence not added 

465 # to If op outputs. So we get the outer tensor corresponding to those 

466 # from the list of `external_captures`. 

467 for i, output in enumerate(t.graph.outputs): 

468 if output is t: 

469 t = t.graph._forward_cond.outputs[i] 

470 break 

471 else: 

472 for i, output in enumerate(t.graph.internal_captures): 

473 if output is t: 

474 t = t.graph.external_captures[i] 

475 break 

476 else: 

477 raise ValueError("Could not find external tensor capture {tensor} in " 

478 "captures or outputs".format(tensor=t)) 

479 

480 # Note: We rely on the capturing logic of the gradient If op graph to 

481 # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2 

482 # and while_v2 handle this while building their gradient functions. 

483 assert t.graph == cond_graph.outer_graph 

484 new_inputs.append(t) 

485 

486 return new_inputs 

487 

488 

489def _get_intermediates(func_graph): 

490 """Returns intermediate tensors of `func_graph` for gradient computation.""" 

491 intermediates = [] 

492 for op in func_graph.get_operations(): 

493 for t in op.outputs: 

494 if t in func_graph.inputs: continue 

495 if t in func_graph.outputs: continue 

496 if t.dtype is dtypes.resource: 

497 continue 

498 # Accumulating mutexes can cause deadlock. 

499 if op.type == "MutexLock": 

500 continue 

501 intermediates.append(t) 

502 return intermediates 

503 

504 

505def _make_intermediates_match(branch_graphs, branch_optionals): 

506 """Returns new optionals lists that have matching signatures. 

507 

508 This is done by mirroring each list in the other using none optionals. 

509 There is no merging of like optionals. 

510 

511 Args: 

512 branch_graphs: `list` of `FuncGraph`. 

513 branch_optionals: `list` of `list`s of optional `Tensor`s, one `list` per

514 graph in `branch_graphs`.

515 

516 Returns: 

517 A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the 

518 same number of `Tensor`s, all of which will be optionals of the same 

519 shape/type. 

520 """ 

521 new_branch_optionals = [] 

522 # Since the intermediates are optionals with dtype variant, we only need 

523 # enough room for the longest list of intermediates. 

524 intermediates_size = max(len(o) for o in branch_optionals) 

525 for i, branch_graph in enumerate(branch_graphs): 

526 other_optionals = _create_none_optionals( 

527 branch_graph, intermediates_size - len(branch_optionals[i])) 

528 new_branch_optionals.append(branch_optionals[i] + other_optionals) 

529 return new_branch_optionals 
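# A plain-Python sketch of the padding behavior, with hypothetical optionals
# opt_a and opt_b1..opt_b3 standing in for real variant tensors:
#
#   branch_optionals = [[opt_a], [opt_b1, opt_b2, opt_b3]]
#   # intermediates_size == 3, so both branches end up with 3 outputs:
#   #   branch 0: [opt_a,  none_optional, none_optional]
#   #   branch 1: [opt_b1, opt_b2,        opt_b3]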

530 

531 

532def _make_intermediates_match_xla(branch_graphs, branch_intermediates): 

533 """Like _make_intermediates_match but for the XLA case.""" 

534 new_branch_intermediates = [] 

535 for i, branch_graph in enumerate(branch_graphs): 

536 other_fakeparams = _create_fakeparams( 

537 branch_graph, 

538 sum((bi for bi in branch_intermediates 

539 if bi is not branch_intermediates[i]), [])) 

540 num_preceding = sum(len(bi) for bi in branch_intermediates[:i]) 

541 new_branch_intermediates.append(other_fakeparams[:num_preceding] + 

542 branch_intermediates[i] + 

543 other_fakeparams[num_preceding:]) 

544 return new_branch_intermediates 

545 

546 

547def _make_inputs_match(branch_graphs, branch_inputs): 

548 """Modifies branch_graphs so they have the same input signature. 

549 

550 This method reorders and/or adds parameters to each graph in branch_graphs so 

551 they have the same input signature, and updates the 'inputs' and 'captured' 

552 fields of each graph accordingly. It uses the input tensors from the outer 

553 graph to avoid duplicating shared arguments. 

554 

555 Args: 

556 branch_graphs: a `list` of `FuncGraph` 

557 branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The 

558 inputs for the corresponding graph in `branch_graphs`. 

559 

560 Returns: 

561 A new list of Tensors from the outer graph that are the new inputs for each 

562 branch_graph. This is a deduped version of `sum(branch_inputs)`. 

563 """ 

564 assert len(branch_graphs) == len(branch_inputs) 

565 added_inputs = set() 

566 new_inputs = [] 

567 for branch_in in branch_inputs: 

568 for tensor in branch_in: 

569 tensor_id = ops.tensor_id(tensor) 

570 if tensor_id not in added_inputs: 

571 added_inputs.add(tensor_id) 

572 new_inputs.append(tensor) 

573 

574 for branch_graph, branch_in in zip(branch_graphs, branch_inputs): 

575 input_ids = [ops.tensor_id(t) for t in branch_in] 

576 branch_input_to_param = dict(zip(input_ids, branch_graph.inputs)) 

577 input_list = [] 

578 for in_t in new_inputs: 

579 param = branch_input_to_param.get(ops.tensor_id(in_t)) 

580 if param is None: 

581 param = _create_dummy_input(branch_graph, in_t) 

582 input_list.append(param) 

583 

584 branch_graph.inputs = input_list 

585 

586 # Rewrite the FuncGraphs' state to reflect the new inputs. 

587 branch_graph.function_captures.reset_captures( 

588 new_inputs, branch_graph.inputs) 

589 

590 return new_inputs 

591 

592 

593def _create_zeros_for_none_grads(forward_graphs, grad_graphs): 

594 """Creates zeros for None out grads if at least one branch has non-None grad. 

595 

596 Args: 

597 forward_graphs: List of forward FuncGraphs. 

598 grad_graphs: List of grad FuncGraphs. 

599 """ 

600 assert len(forward_graphs) == len(grad_graphs) 

601 branch_outputs = [g.structured_outputs for g in grad_graphs] 

602 num_outputs_per_branch = [len(outs) for outs in branch_outputs] 

603 assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch 

604 for output_idx, branch_outs in enumerate(zip(*branch_outputs)): 

605 if (any(t is None for t in branch_outs) and 

606 any(t is not None for t in branch_outs)): 

607 for branch_index, t in enumerate(branch_outs): 

608 if t is None: 

609 with grad_graphs[branch_index].as_default(): 

610 zeros = default_gradient.zeros_like( 

611 forward_graphs[branch_index].inputs[output_idx]) 

612 grad_graphs[branch_index].structured_outputs[output_idx] = zeros 

613 

614 for grad_graph in grad_graphs: 

615 grad_graph.outputs = [ 

616 t for t in func_graph_module.flatten(grad_graph.structured_outputs) 

617 if t is not None 

618 ] 

619 

620 

621def _make_output_composite_tensors_match(op_type, branch_graphs): 

622 """Modifies each branch_graph's outputs to have the same output signature. 

623 

624 Currently the only transformation implemented is turning a Tensor into an 

625 equivalent IndexedSlices if the other branch returns an IndexedSlices. 

626 Updates branch_graph.{outputs,structured_outputs} for each branch_graph in 

627 branch_graphs. 

628 

629 Args: 

630 op_type: _COND or _CASE 

631 branch_graphs: `list` of `FuncGraph` 

632 

633 Raises: 

634 TypeError: if a set of outputs cannot be rewritten. 

635 """ 

636 # Note: since this is only used for gradient graphs, we do not expect the 

637 # outputs to be structured (e.g. nested lists), and thus do not need to use 

638 # nest.flatten, etc. 

639 assert branch_graphs 

640 branch_outputs = [g.structured_outputs for g in branch_graphs] 

641 outputs_per_branch = list(len(outs) for outs in branch_outputs) 

642 assert len(set(outputs_per_branch)) == 1, outputs_per_branch 

643 

644 for output_idx, branch_outs in enumerate(zip(*branch_outputs)): 

645 if len(set(type(out) for out in branch_outs)) == 1: 

646 continue 

647 if not any( 

648 isinstance(out, indexed_slices.IndexedSlices) for out in branch_outs): 

649 continue 

650 for branch_idx, branch_out in enumerate(branch_outs): 

651 if isinstance(branch_out, indexed_slices.IndexedSlices): 

652 continue 

653 elif isinstance(branch_out, ops.Tensor): 

654 with branch_graphs[branch_idx].as_default(): 

655 branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices( 

656 branch_out) 

657 else: 

658 raise TypeError( 

659 "Cannot reconcile {op_name} {output_idx}-th outputs:\n" 

660 " outputs from all branches: {outputs}".format( 

661 op_name="tf.cond" if op_type == _COND else "tf.switch_case", 

662 output_idx=output_idx, 

663 outputs=branch_outs)) 

664 

665 for branch_graph, branch_outs in zip(branch_graphs, branch_outputs): 

666 branch_graph.structured_outputs = branch_outs 

667 branch_graph.outputs = [ 

668 t for t in func_graph_module.flatten(branch_outs) if t is not None 

669 ] 

670 

671 

672def _make_indexed_slices_indices_types_match(op_type, branch_graphs): 

673 """Match dtype of IndexedSlices.indices in outputs of branch_graphs.""" 

674 assert branch_graphs 

675 # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`. 

676 indexed_slice_indices = [] 

677 current_index = 0 

678 # Note that this still contains Nones. We leave those in so that error 

679 # messages contain the correct indices. We handle the Nones later when 

680 # updating `current_index`. 

681 branch_outputs_flat_with_composites = [ 

682 nest.flatten(branch_graph.structured_outputs, expand_composites=False) 

683 for branch_graph in branch_graphs 

684 ] 

685 outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites] 

686 assert len(set(outs_per_branch)) == 1, outs_per_branch 

687 # Store indices of IndexedSlices.indices in `indexed_slice_indices`. 

688 for output_idx, branch_outs in enumerate( 

689 zip(*branch_outputs_flat_with_composites)): 

690 if len( 

691 set( 

692 isinstance(out, indexed_slices.IndexedSlices) 

693 for out in branch_outs)) != 1: 

694 raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n" 

695 " branches returned: {outputs}".format( 

696 op_name="cond" if op_type == _COND else "switch_case", 

697 output_idx=output_idx, 

698 outputs=branch_outs)) 

699 if isinstance(branch_outs[0], indexed_slices.IndexedSlices): 

700 # indices is the second component of the composite tensor. 

701 indexed_slice_indices.append(current_index + 1) 

702 if nest.is_nested_or_composite(branch_outs[0]): 

703 current_index += len(nest.flatten(branch_outs[0], expand_composites=True)) 

704 elif branch_outs[0] is not None: 

705 # `FuncGraph.outputs` does not contain Nones so no need to update the 

706 # counter in that case. 

707 current_index += 1 

708 

709 if not indexed_slice_indices: 

710 return 

711 

712 # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus 

713 # the Nones. 

714 if current_index != len(branch_graphs[0].outputs): 

715 raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n" 

716 "Expected: %i\n" 

717 "Actual: %i" % 

718 (current_index, len(branch_graphs[0].outputs))) 

719 

720 # Cast indices with mismatching types to int64. 

721 for index in indexed_slice_indices: 

722 if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64) 

723 for bg in branch_graphs): 

724 raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " 

725 "Found: %s" % 

726 str([bg.outputs[index].dtype for bg in branch_graphs])) 

727 if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1: 

728 for branch_graph in branch_graphs: 

729 if branch_graph.outputs[index].dtype == dtypes.int32: 

730 with branch_graph.as_default(): 

731 branch_graph.outputs[index] = math_ops.cast( 

732 branch_graph.outputs[index], dtypes.int64) 

733 

734 for branch_graph in branch_graphs: 

735 branch_graph.structured_outputs = _pack_sequence_as( 

736 branch_graph.structured_outputs, branch_graph.outputs) 

737 

738 

739def _pack_sequence_as(structured_outputs, op_outputs): 

740 """Packs the outputs of the gradient If/Case op. 

741 

742 The branch functions may contain None's in the list of `structured_outputs`. 

743 `op_outputs` has those outputs missing. So we need to add those Nones to the 

744 list of `op_outputs` and then pack it in the same structure as 

745 `structured_outputs`. 

746 

747 Args: 

748 structured_outputs: structured_outputs from one of the branch functions. 

749 op_outputs: List of output tensors of the op. 

750 

751 Returns: 

752 `op_outputs` packed like `structured_outputs`. 

753 """ 

754 outputs_with_nones = [] 

755 counter = 0 

756 for output in nest.flatten(structured_outputs, expand_composites=True): 

757 if output is None: 

758 outputs_with_nones.append(None) 

759 else: 

760 outputs_with_nones.append(op_outputs[counter]) 

761 counter += 1 

762 return func_graph_module.pack_sequence_as(structured_outputs, 

763 outputs_with_nones) 
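# A plain-Python sketch of the None re-insertion, with hypothetical tensors
# t0/t1 (branch outputs) and o0/o1 (If/Case op outputs):
#
#   structured_outputs = [t0, None, (t1,)]  # as returned by a branch function
#   op_outputs         = [o0, o1]           # op outputs, Nones already dropped
#   # _pack_sequence_as(structured_outputs, op_outputs) => [o0, None, (o1,)]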

764 

765 

766def _wrap_intermediates(func_graph, intermediates): 

767 with func_graph.as_default(): 

768 return [gen_optional_ops.optional_from_value([t]) for t in intermediates] 

769 

770 

771def _create_dummy_input(func_graph, template_tensor): 

772 """Creates tensors in func_graph to represent template_tensors. 

773 

774 Args: 

775 func_graph: FuncGraph. 

776 template_tensor: a tensor in the outer graph. 

777 

778 Returns: 

779 A tensor in func_graph. 

780 """ 

781 with func_graph.as_default(): 

782 return array_ops.placeholder( 

783 template_tensor.dtype, shape=template_tensor.shape) 

784 

785 

786def _create_none_optionals(func_graph, n): 

787 """Creates `n` `None` optionals in func_graph. 

788 

789 Args: 

790 func_graph: FuncGraph. 

791 n: `int` the number of `None` optionals to make. 

792 

793 Returns: 

794 A list of tensors in func_graph. 

795 """ 

796 with func_graph.as_default(): 

797 return [gen_optional_ops.optional_none() for _ in range(n)] 
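# Roughly, these "none optionals" are the graph-level counterpart of the
# public tf.experimental.Optional API sketched below (values illustrative):
#
#   empty = tf.experimental.Optional.empty(
#       tf.TensorSpec(shape=[], dtype=tf.float32))
#   empty.has_value()   # => False
#   full = tf.experimental.Optional.from_value(tf.constant(1.0))
#   full.get_value()    # => 1.0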

798 

799 

800# TODO(b/265317139): remove this function and move this dynamic dimension 

801# handling logic to XLA once XLA shape is ready for dynamic dimensions. 

802def _convert_dynamic_dimension_to_zero(shape): 

803 """Converts dynamic dimensions in `shape` to zero. 

804 

805 The fake params created to match the intermediates captured in other branches 

806 could have dynamic dimensions. But the XLA shape is not able to handle 

807 dynamic dimensions in TF TensorShape. Setting the dynamic dimensions to 

808 size zero will help avoid failing safety checks in the bridge. When XLA

809 DynamicConditional op reconciles branch differences, XLA will replace the 

810 dimension size 0 with a bounded dimension determined from the shape of 

811 the real argument in the other branch.

812 

813 Note: Rank unknown shapes are returned as they are. 

814 

815 Args: 

816 shape: The TensorShape of fake param. 

817 

818 Returns: 

819 The new TensorShape with dynamic dimensions set to zero. 

820 """ 

821 if shape.rank is None: 

822 return shape 

823 

824 return tensor_shape.TensorShape( 

825 [0 if d is None else d for d in shape.as_list()] 

826 ) 
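# A quick sketch of the conversion (shapes illustrative):
#
#   _convert_dynamic_dimension_to_zero(tensor_shape.TensorShape([None, 3, None]))
#   # => TensorShape([0, 3, 0])
#   _convert_dynamic_dimension_to_zero(tensor_shape.TensorShape(None))
#   # => TensorShape(None), rank-unknown shapes are returned unchanged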

827 

828 

829def _create_fakeparams(func_graph, template_tensors): 

830 """Creates FakeParams for the XLA case.""" 

831 with func_graph.as_default(): 

832 return [ 

833 gen_functional_ops.fake_param( 

834 dtype=t.dtype, shape=_convert_dynamic_dimension_to_zero(t.shape)) 

835 for t in template_tensors] 

836 

837 

838def _check_same_outputs(op_type, graphs): 

839 """Raises an error if `graphs` have different outputs.""" 

840 

841 def error(branch_idx, error_detail): 

842 raise TypeError( 

843 "{b0_name} and {bn_name} arguments to {op_name} must have the same " 

844 "number, type, and overall structure of return values.\n" 

845 "\n" 

846 "{b0_name} output: {b0_out}\n" 

847 "{bn_name} output: {bn_out}\n" 

848 "\n" 

849 "Error details:\n" 

850 "{detail}".format( 

851 b0_name="true_fn" if op_type == _COND else "branches[0]", 

852 bn_name=("false_fn" if op_type == _COND else 

853 "branches[{}]".format(branch_idx)), 

854 op_name="tf.cond" if op_type == _COND else "tf.switch_case", 

855 b0_out=graphs[0].structured_outputs, 

856 bn_out=graphs[branch_idx].structured_outputs, 

857 detail=error_detail)) 

858 

859 for b in range(1, len(graphs)): 

860 try: 

861 nest.assert_same_structure( 

862 graphs[0].structured_outputs, 

863 graphs[b].structured_outputs, 

864 expand_composites=True) 

865 except (ValueError, TypeError) as e: 

866 error(b, str(e)) 

867 

868 op_type_str = "cond" if op_type == _COND else "case" 

869 if len(graphs[0].outputs) != len(graphs[b].outputs): 

870 raise ValueError("Lengths of branch outputs of {op_type} must match.\n" 

871 "len(graphs[0].outputs): {len_0}\n" 

872 "len(graphs[{b}].outputs): {len_b}\n".format( 

873 op_type=op_type_str, 

874 len_0=len(graphs[0].outputs), 

875 b=b, 

876 len_b=len(graphs[b].outputs))) 

877 for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs): 

878 if b0_out.dtype != bn_out.dtype: 

879 error(b, "%s and %s have different types" % (b0_out, bn_out)) 
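# A minimal sketch of how this check surfaces through the public API
# (assuming TF2; values illustrative): branches returning different
# structures raise the TypeError built in error() above.
#
#   @tf.function
#   def f(pred):
#     return tf.cond(pred,
#                    lambda: (tf.constant(1.0), tf.constant(2.0)),
#                    lambda: tf.constant(1.0))
#
#   f(tf.constant(True))
#   # => TypeError: true_fn and false_fn arguments to tf.cond must have the
#   #    same number, type, and overall structure of return values.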

880 

881 

882def _get_output_shapes(*branch_graph_outputs): 

883 output_shapes = [] 

884 for out_by_branch in zip(*branch_graph_outputs): 

885 shape = out_by_branch[0].shape 

886 for other_out in out_by_branch[1:]: 

887 shape = shape.most_specific_compatible_shape(other_out.shape) 

888 output_shapes.append(shape) 

889 return output_shapes 
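# A small sketch of the per-output shape merging (shapes illustrative):
#
#   a = tensor_shape.TensorShape([None, 3])
#   b = tensor_shape.TensorShape([2, 3])
#   a.most_specific_compatible_shape(b)  # => TensorShape([None, 3])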

890 

891 

892def _copy_handle_data(external_tensors, *branch_graph_outputs): 

893 """Combines shapes in handle data and sets metadata on `external_tensors`.""" 

894 for tensors in zip(external_tensors, *branch_graph_outputs): 

895 external = tensors[0] 

896 internal = tensors[1:] 

897 internal_handle_data = [] 

898 for tensor in internal: 

899 handle_data = handle_data_util.get_resource_handle_data(tensor) 

900 # NOTE: Assumes handle data has only one ShapeAndType entry. It's 

901 # unclear how to combine different lengths across branches. 

902 if not handle_data.is_set or len(handle_data.shape_and_type) != 1: 

903 break 

904 internal_handle_data.append(handle_data) 

905 else: # There is handle data, so we need to combine it. 

906 combined_shape = tensor_shape.TensorShape(None) 

907 combined_dtype = None 

908 for handle_data in internal_handle_data: 

909 handle_shape = tensor_shape.TensorShape( 

910 handle_data.shape_and_type[0].shape) 

911 combined_shape = combined_shape.most_specific_compatible_shape( 

912 handle_shape) 

913 if combined_dtype is None: 

914 combined_dtype = handle_data.shape_and_type[0].dtype 

915 elif handle_data.shape_and_type[0].dtype != combined_dtype: 

916 # Variants from different branches have different dtypes. The 

917 # combined variant has no static dtype. 

918 combined_dtype = types_pb2.DT_INVALID 

919 combined_handle_data = internal_handle_data[0] 

920 combined_handle_data.shape_and_type[0].shape.CopyFrom( 

921 combined_shape.as_proto()) 

922 combined_handle_data.shape_and_type[0].dtype = combined_dtype 

923 handle_data_util.set_handle_data(external, combined_handle_data) 

924 

925 

926def verify_captures(op_type, branch_graphs): 

927 """Verify that a branch's tensor is not accessed in another branch fn.""" 

928 # Note: It is technically not possible for lower-branch_index branches to 

929 # capture tensors from higher-branch_index branches, because of the order of 

930 # branch graph construction, but we check all for completeness and to 

931 # guard against potential future changes. 

932 other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)} 

933 for i, branch_graph in enumerate(branch_graphs): 

934 for t in branch_graph.external_captures: 

935 if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs: 

936 branch_names = ["true_fn", "false_fn"] if op_type == _COND else [ 

937 "branch {}".format(bi) for bi in range(len(branch_graphs))] 

938 raise ValueError( 

939 "Tensor {tname} in {b0name} is accessed from {b1name}.".format( 

940 tname=t.name, 

941 b0name=branch_names[other_branch_graphs[t.graph]], 

942 b1name=branch_names[i])) 

943 

944 

945class _CondGradFuncGraph(util.CondBranchFuncGraph): 

946 """FuncGraph for the gradient function of the branch of an If op. 

947 

948 Handles wrapping and unwrapping intermediate values that are captured by the 

949 gradient computation in optionals. 

950 

951 Attributes: 

952 op_needs_rewrite: True if any intermediates were captured, meaning the 

953 forward If op needs to be written to output the wrapped intermediates. 

954 """ 

955 

956 def __init__(self, name, forward_graph): 

957 super(_CondGradFuncGraph, self).__init__( 

958 name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access 

959 self.op_needs_rewrite = False 

960 self._forward_graph = forward_graph 

961 # Maps from forward intermediate tensor -> the unwrapped captured 

962 # intermediate. 

963 self._indirect_captures = {} 

964 # Maps unwrapped intermediate -> optional-wrapped intermediate in the 

965 # forward graph. 

966 self._wrapped_intermediates = collections.OrderedDict() 

967 # Raw intermediates captured from the forward graph. Populated iff we're in 

968 # an XLA context. 

969 self._xla_intermediates = [] 

970 # Maps forward intermediate constant valued tensor's id to the constant 

971 # created in this graph for that tensor. 

972 self._captured_constants = {} 

973 

974 @property 

975 def wrapped_intermediates(self): 

976 """The optional-wrapped intermediates captured from the forward graph.""" 

977 return list(self._wrapped_intermediates.values()) 

978 

979 @property 

980 def xla_intermediates(self): 

981 """Raw intermediates captured from the forward graph if XLA is enabled.""" 

982 return self._xla_intermediates 

983 

984 def _capture_helper(self, tensor, name): 

985 if (tensor.graph is not self._forward_graph or 

986 any(tensor is t for t in self._forward_graph.inputs) or 

987 any(tensor is t for t in self._forward_graph.outputs)): 

988 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) 

989 

990 tensor_id = ops.tensor_id(tensor) 

991 

992 # If `tensor` is a graph-building time constant, we create a constant with 

993 # the same value in the backward graph instead of capturing it. 

994 if tensor_id in self._captured_constants: 

995 return self._captured_constants[tensor_id] 

996 elif constant_op.is_constant(tensor): 

997 self._captured_constants[tensor_id] = constant_op.constant( 

998 tensor_util.constant_value(tensor), dtype=tensor.dtype) 

999 return self._captured_constants[tensor_id] 

1000 

1001 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 

1002 # XLA does not yet support optionals, so capture intermediates directly. 

1003 # TODO(skyewm,jpienaar): can XLA support optionals? 

1004 if all(tensor is not capture for capture in self.external_captures): 

1005 self.xla_intermediates.append(tensor) 

1006 self.op_needs_rewrite = True 

1007 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) 

1008 

1009 captured_tensor = self._indirect_captures.get(tensor_id) 

1010 if captured_tensor is not None: 

1011 return captured_tensor 

1012 

1013 # 'tensor' is an uncaptured intermediate in the forward graph. 

1014 # If it is not a resource, we wrap it in an optional in the forward graph 

1015 # and capture the optional normally. We then unwrap the captured optional 

1016 # value in the gradient graph to get the raw intermediate value. 

1017 # If it is a resource, we trace the resource up to the input in the forward 

1018 # graph and capture that. 

1019 

1020 if tensor.dtype == dtypes.resource: 

1021 # Index of the forward graph input corresponding to the resource tensor. 

1022 index = util.resource_input_index( 

1023 tensor.name, [t.name for t in self._forward_graph.inputs], 

1024 {op.name: op.node_def for op in self._forward_graph.get_operations()}, 

1025 self._forward_graph._functions) 

1026 # This gets mapped to the corresponding If op input in 

1027 # `_resolve_grad_inputs`. 

1028 captured_tensor = super(_CondGradFuncGraph, self)._capture_helper( 

1029 self._forward_graph.inputs[index], name) 

1030 else: 

1031 if tensor_id not in self._wrapped_intermediates: 

1032 # If the gradient has already been computed for this If op, 'tensor' may 

1033 # already be wrapped. 

1034 for consumer in tensor.consumers(): 

1035 if (consumer.type == "OptionalFromValue" and 

1036 any(consumer.outputs[0] is output 

1037 for output in self._forward_graph.outputs)): 

1038 optional = consumer.outputs[0] 

1039 break 

1040 else: 

1041 # 'tensor' hasn't been wrapped, do it now. 

1042 with self._forward_graph.as_default(): 

1043 optional = gen_optional_ops.optional_from_value([tensor]) 

1044 self.op_needs_rewrite = True 

1045 self._wrapped_intermediates[tensor_id] = optional 

1046 

1047 optional = self._wrapped_intermediates[tensor_id] 

1048 captured_optional = super(_CondGradFuncGraph, 

1049 self)._capture_helper(optional, name) 

1050 captured_tensor = gen_optional_ops.optional_get_value( 

1051 captured_optional, [tensor.dtype], [tensor.shape] 

1052 )[0] 

1053 

1054 self._indirect_captures[tensor_id] = captured_tensor 

1055 return captured_tensor 

1056 

1057 

1058def indexed_case(branch_index, 

1059 branch_fns, 

1060 name="indexed_case", 

1061 lower_using_switch_merge=None): 

1062 """Like conv_v2, except emits a Case op instead of an If.""" 

1063 if isinstance(branch_index, int): 

1064 raise TypeError("branch_index must not be a Python int", branch_index) 

1065 

1066 with ops.name_scope(name) as scope: 

1067 branch_names = [ 

1068 util.unique_fn_name(scope, "branch{}".format(b)) 

1069 for b in range(len(branch_fns)) 

1070 ] 

1071 

1072 # Automatic control dependencies are added in defuns, but not in v1 

1073 # graphs. Propagate that behavior here. 

1074 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 

1075 branch_index = ops.convert_to_tensor(branch_index, name="branch_index") 

1076 

1077 branch_graphs = [] 

1078 for branch_name, branch_fn in zip(branch_names, branch_fns): 

1079 branch_graphs.append( 

1080 func_graph_module.func_graph_from_py_func( 

1081 branch_name, 

1082 branch_fn, 

1083 [], 

1084 {}, 

1085 func_graph=util.CondBranchFuncGraph( 

1086 branch_name, 

1087 collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 

1088 add_control_dependencies=add_control_dependencies, 

1089 op_return_value=branch_index)) 

1090 

1091 verify_captures(_CASE, branch_graphs) 

1092 return _build_case( 

1093 branch_index, 

1094 branch_graphs, [g.external_captures for g in branch_graphs], 

1095 name=scope, 

1096 lower_using_switch_merge=lower_using_switch_merge) 
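# A minimal usage sketch (assuming the public tf.switch_case API, which uses
# this path under v2 control flow; values and lambdas are illustrative):
#
#   @tf.function
#   def f(i, x):
#     return tf.switch_case(i, branch_fns=[lambda: x + 1.0,
#                                          lambda: x * 2.0,
#                                          lambda: -x])
#
#   f(tf.constant(2), tf.constant(3.0))  # => -3.0, emitted as a single Case op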

1097 

1098 

1099@ops.RegisterGradient("Case") 

1100@ops.RegisterGradient("StatelessCase") 

1101def _CaseGrad(op, *grads): # pylint: disable=invalid-name 

1102 """The gradient of a Case op produced by tf.switch_case.""" 

1103 # Get the Case operator (this logic handles the case where op is a MockOp) 

1104 case_op = op.outputs[0].op 

1105 branch_graphs = get_func_graphs(case_op) 

1106 assert branch_graphs 

1107 # Note: op.graph != ops.get_default_graph() when we are computing the gradient 

1108 # of a nested cond. 

1109 for branch_graph in branch_graphs: 

1110 assert branch_graph.outer_graph == case_op.graph 

1111 

1112 # Create grad functions that compute the gradient of the branch forward 

1113 # graphs. These functions will capture tensors from the forward pass 

1114 # functions. 

1115 branch_grad_graphs = [] 

1116 for branch_graph in branch_graphs: 

1117 branch_grad_graphs.append( 

1118 _create_grad_func(branch_graph, grads, 

1119 util.unique_grad_fn_name(branch_graph.name))) 

1120 # Replaces output None grads with zeros if at least one branch has non-None 

1121 # grad at that index. 

1122 _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs) 

1123 

1124 if any(g.op_needs_rewrite for g in branch_grad_graphs): 

1125 # Modify 'op' to output the intermediates needed by the grad functions. Note 

1126 # that all needed intermediates are wrapped in optionals. Each optional 

1127 # intermediate output will have a value iff its corresponding branch is 

1128 # taken. 

1129 # NOTE(bjp): if there are any active sessions, this modification to `op` 

1130 # may make them unrunnable! 

1131 

1132 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 

1133 # XLA does not yet support optionals, so output intermediates directly and 

1134 # make them match via FakeParams, which can be converted to zeros in XLA. 

1135 # TODO(bjp,jpienaar): can XLA support optionals? 

1136 branches_intermediates = [ 

1137 branch_grad_graph.xla_intermediates 

1138 for branch_grad_graph in branch_grad_graphs 

1139 ] 

1140 extra_branch_outputs = _make_intermediates_match_xla( 

1141 branch_graphs, branches_intermediates) 

1142 else: 

1143 branch_intermediates = [ 

1144 g.wrapped_intermediates for g in branch_grad_graphs 

1145 ] 

1146 # Make outputs match by adding none optionals. 

1147 extra_branch_outputs = _make_intermediates_match(branch_graphs, 

1148 branch_intermediates) 

1149 

1150 for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs): 

1151 branch_graph.outputs.extend(extra_outputs) 

1152 # TODO(bjp): indicate it's an internal bug if this fails. 

1153 _check_same_outputs(_CASE, branch_graphs) 

1154 

1155 for branch_graph in branch_graphs: 

1156 branch_graph.name += "_rewritten" 

1157 

1158 case_op._set_func_list_attr("branches", [ 

1159 util.create_new_tf_function(branch_graph) 

1160 for branch_graph in branch_graphs 

1161 ]) 

1162 case_op._set_type_list_attr("Tout", branch_graphs[0].output_types) 

1163 case_op._set_shape_list_attr("output_shapes", 

1164 branch_graphs[0].output_shapes) 

1165 case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]], 

1166 [t.shape for t in extra_branch_outputs[0]]) 

1167 

1168 # Resolve references to forward graph tensors in grad graphs and ensure 

1169 # they are in-scope, i.e., belong to one of outer graphs of the grad graph. 

1170 branches_grad_inputs = [ 

1171 _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph, 

1172 branch_grad_graph in zip(branch_graphs, branch_grad_graphs) 

1173 ] 

1174 

1175 # This modifies the graphs in branch_grad_graphs. 

1176 _make_output_composite_tensors_match(_CASE, branch_grad_graphs) 

1177 

1178 try: 

1179 lowering = case_op._get_attr_bool("_lower_using_switch_merge") 

1180 except errors_impl.NotFoundError: 

1181 lowering = None 

1182 

1183 outputs = _build_case( 

1184 case_op.inputs[0], 

1185 branch_grad_graphs, 

1186 branches_grad_inputs, 

1187 name="gradient", 

1188 lower_using_switch_merge=lowering) 

1189 

1190 # The predicate has no gradient. 

1191 return [None] + outputs 

1192 

1193 

1194def _build_case(branch_index, 

1195 branch_graphs, 

1196 branch_inputs, 

1197 name=None, 

1198 lower_using_switch_merge=None): 

1199 """Creates an `Case` op from `branch_index`, branch graphs and inputs. 

1200 

1201 Note that this modifies `branch_graphs` to make the inputs match, and to 

1202 output all intermediate values so they're available for the gradient

1203 computation. 

1204 

1205 `branch_graphs` need not have the same input types, but they must 

1206 have the same output types. 

1207 

1208 Args: 

1209 branch_index: integer Tensor 

1210 branch_graphs: List of FuncGraph 

1211 branch_inputs: List of lists of Tensors to be passed to corresponding 

1212 branch_graph as input. 

1213 name: the name for the Case op. 

1214 lower_using_switch_merge: Lower this op using switch merge ops (optional). 

1215 

1216 Returns: 

1217 A list of Tensors which are the outputs of the Case op. Does not include 

1218 added intermediate outputs. 

1219 """ 

1220 _make_indexed_slices_indices_types_match(_CASE, branch_graphs) 

1221 _check_same_outputs(_CASE, branch_graphs) 

1222 

1223 # Add inputs to branch_graphs to make them match. Note that this modifies the 

1224 # graphs in `branch_graphs`. 

1225 case_inputs = _make_inputs_match(branch_graphs, branch_inputs) 

1226 

1227 stateful_ops = [] 

1228 for bg in branch_graphs: 

1229 stateful_ops.extend([ 

1230 op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op) 

1231 ]) 

1232 

1233 if stateful_ops: 

1234 op_fn = gen_functional_ops.case 

1235 else: 

1236 op_fn = gen_functional_ops.stateless_case 

1237 

1238 # Create the Case op. 

1239 with ops.control_dependencies( 

1240 sum((list(bg.function_captures.control) for bg in branch_graphs), [])): 

1241 

1242 def _make_op(inputs): 

1243 case_op, tensors = util.get_op_and_outputs(op_fn( 

1244 branch_index, 

1245 inputs, [t.dtype for t in branch_graphs[0].outputs], 

1246 [util.create_new_tf_function(g) for g in branch_graphs], 

1247 output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), 

1248 name=name)) 

1249 _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) 

1250 if case_op is not None: 

1251 util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) 

1252 util.maybe_propagate_compile_time_consts_in_xla(case_op) 

1253 _set_read_only_resource_inputs_attr(case_op, branch_graphs) 

1254 # Prevent fetching since the variant outputs can't be fetched directly. 

1255 case_op.graph.prevent_fetching(case_op) 

1256 

1257 # Store the branch graphs so they can be reused during the gradient 

1258 # pass. 

1259 for i, bg in enumerate(branch_graphs): 

1260 bg.outer_graph = ops.get_default_graph() 

1261 setattr(case_op, "_branch_graph_{}".format(i), bg) 

1262 

1263 return tensors 

1264 tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs) 

1265 

1266 # Return identities for each output of the Case op, rather than the output of 

1267 # the Case op directly. This makes pruning work if the output of switch_case() 

1268 # is fetched: the lowering pass converts the Case outputs into IdentityN 

1269 # outputs, which if fetched will cause all ops in the taken branch to be run 

1270 # (since it takes all merge ops as input). After lowering, each output 

1271 # identity op will end up with only the appropriate merge op as input. 

1272 # TODO(b/79984175): this doesn't have to be a tuple once we convert to the

1273 # correct output structure 

1274 tensors = [array_ops.identity(t) for t in tensors] 

1275 

1276 return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors) 

1277 

1278 

1279def _set_read_only_resource_inputs_attr(op, branch_graphs): 

1280 """Sets the list of resource inputs which are read-only. 

1281 

1282 This is used by AutomaticControlDependencies. 

1283 

1284 Args: 

1285 op: If or Case Operation. 

1286 branch_graphs: List of branch FuncGraphs. 

1287 """ 

1288 # The first entry in `op.inputs` is the predicate which is not passed to 

1289 # branch graphs so len(branch_graph[i].inputs) == len(op.inputs) - 1. 

1290 read_only_indices = set(range(len(op.inputs) - 1)) 

1291 for branch_graph in branch_graphs: 

1292 assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen" 

1293 if not read_only_indices: 

1294 break 

1295 branch_read_only_indices = acd.get_read_only_resource_input_indices_graph( 

1296 branch_graph) 

1297 read_only_indices = read_only_indices.intersection(branch_read_only_indices) 

1298 # Convert indices in `branch_graphs[i].inputs` to `op.inputs`. 

1299 read_only_indices = [i + 1 for i in read_only_indices] 

1300 ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, 

1301 sorted(read_only_indices))