# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for managing state of v1 control flow for computing gradients."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops

# pylint: disable=protected-access
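# Illustrative context (a sketch, not part of this module's API): the state
# below is only built when tf.gradients() is called on a v1 while loop in
# graph mode, e.g. roughly:
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   tf.disable_control_flow_v2()  # use the v1 control flow handled here
#   x = tf.constant(2.0)
#   _, y = tf.while_loop(lambda i, a: i < 3,
#                        lambda i, a: (i + 1, a * x),
#                        [tf.constant(0), tf.constant(1.0)])
#   dx = tf.gradients(y, x)  # builds the backprop loop via this module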

def _GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
  """Calculate a max_size for use by stack ops inside an XLA while_loop.

  Args:
    value: The value inside the while_loop forward context. Used for printing
      error messages.
    while_ctxt: The forward context inside which value resides. This does not
      always match the value's immediate context, as `value` may be inside e.g.
      a cond context inside the while_loop.

  Returns:
    A tensor containing the `max_size` to feed to a Stack initializer.

  Raises:
    ValueError: If `value` is nested inside a `while_loop` that either
      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
      parameter:

      - is inside a `while_loop` that is a parent of the calling context, and
      - cannot be evaluated at graph build time to a constant.
  """
  value_name = value.name
  # curr_ctxt is the context that tf.gradients was called in.
  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access

  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
  max_size = constant_op.constant(1)

  # Loop through all containing while contexts between value and the
  # current context, multiplying together each context's
  # max_iterations to get the maximum stack size.
  while while_ctxt not in (None, curr_ctxt):
    max_iter = while_ctxt.maximum_iterations
    if max_iter is None:
      raise ValueError(
          "Cannot create a gradient accumulator for tensor '%s' inside "
          "XLA while_loop because maximum_iterations was not passed to "
          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))

    # pylint: disable=protected-access
    max_iter_ctxt = max_iter.op._get_control_flow_context()
    # pylint: enable=protected-access

    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
      max_size *= max_iter
    else:
      # We cannot use max_iter because it's defined in a nested while
      # or cond context, so will fail if we try to use it as input to
      # any ops in curr_ctxt (e.g. max_size or the final accumulator
      # stack). Attempt to get a constant value out to use instead.
      const_max_iter = tensor_util.constant_value(max_iter)
      if const_max_iter is None:
        raise ValueError(
            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
            "while_loop. maximum_iterations tensor '%s' for while_loop context "
            "'%s' must be statically known (e.g. a constant value or known "
            "shape dimension), or be defined at or outside the while loop "
            "context '%s' (currently defined in '%s')." %
            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
             max_iter_ctxt.name))
      max_size *= const_max_iter

    # Find the next outer WhileContext (or stop if we reach the
    # tf.gradient's context).
    while_ctxt = util.GetContainingWhileContext(
        while_ctxt.outer_context, stop_ctxt=curr_ctxt)

  return max_size
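# Worked example of the product computed above (illustrative): for a value
# inside an inner while_loop with maximum_iterations=3 that is nested in an
# outer while_loop with maximum_iterations=5, and a tf.gradients call made
# outside both loops, the accumulator stack gets max_size = 3 * 5 = 15.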

class _GradLoopState:
  """The state used for constructing the gradient graph for a while loop.

  We create a _GradLoopState for each while loop in forward and its
  corresponding while loop in backprop. This gives us access to both
  the forward and the backprop WhileContexts.

  During construction of the gradient graph, whenever we detect a forward
  value that is needed for backprop, we create a history accumulator and
  add it to `history_map`. Whenever we backprop a loop switch op
  (in _SwitchGrad), we add the grad merge op to `switch_map`.
  """

  def __init__(self, forward_ctxt, outer_grad_state):
    # The grad loop state for the outer while loop.
    self._outer_grad_state = None

    # The while loop context for forward.
    self._forward_context = None

    # The loop counter added by AddForwardLoopCounter. It is the value
    # of the loop counter for the next iteration.
    self._forward_index = None

    # A sync op for forward.
    self._forward_sync = None

    # The while loop context for backprop.
    self._grad_context = None

    # The loop counter added by AddBackpropLoopCounter. It is the value
    # of the loop counter for the current iteration.
    self._grad_index = None

    # A sync op for backprop.
    self._grad_sync = None

    # Information needed by backprop.
    self._history_map = {}
    self._switch_map = {}
    self._unused_exits = []
    self._deferred_exits = []
    self._forward_loop_exits = list(forward_ctxt.loop_exits)
    self._pending_exits_count = len(forward_ctxt.loop_exits)

    self._outer_grad_state = outer_grad_state
    if outer_grad_state:
      outer_forward_ctxt = outer_grad_state.forward_context
    else:
      if not hasattr(forward_ctxt, "outer_context"):
        raise ValueError("Failed to call gradients on a while loop without "
                         "properly serializing graph via MetaGraphDef")
      outer_forward_ctxt = forward_ctxt.outer_context

    # Add the forward loop counter.
    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()
    self._forward_context = forward_ctxt
    self._forward_index = forward_index

    # Add the backprop WhileContext, and the backprop loop counter.
    if outer_grad_state:
      # This is a nested loop. Remember the iteration counts for each
      # execution of this inner loop.
      outer_forward_ctxt.AddName(cnt.name)
      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)

      outer_grad_ctxt = outer_grad_state.grad_context
      outer_grad_ctxt.Enter()
      self._grad_context = control_flow_ops.WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          real_cnt, outer_grad_state)
      outer_grad_ctxt.Exit()
    else:
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      self._grad_context = control_flow_ops.WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          cnt, outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()

  @property
  def outer_grad_state(self):
    """The grad loop state for the outer loop."""
    return self._outer_grad_state

  @property
  def forward_context(self):
    """The while loop context for forward."""
    return self._forward_context

  @property
  def forward_index(self):
    """The loop index of the forward loop."""
    return self._forward_index

  @property
  def forward_sync(self):
    """A control trigger node for synchronization in the forward loop.

    One main use is to keep the push ops of a stack executed in the
    iteration order.
    """
    if self._forward_sync is None:
      with ops.control_dependencies(None):
        self._forward_sync = control_flow_ops.control_trigger(name="f_sync")
      self._forward_sync._set_control_flow_context(self._forward_context)
      self._forward_index.op._add_control_input(self._forward_sync)
    return self._forward_sync

  @property
  def grad_context(self):
    """The corresponding WhileContext for the gradient."""
    return self._grad_context

  @property
  def grad_index(self):
    """The loop index of the backprop loop."""
    return self._grad_index

  @property
  def grad_sync(self):
    """A control trigger node for synchronization in the grad loop.

    One main use is to keep the pop ops of a stack executed in the
    iteration order.
    """
    if self._grad_sync is None:
      with ops.control_dependencies(None):
        self._grad_sync = control_flow_ops.control_trigger(name="b_sync")
      self._grad_sync._set_control_flow_context(self._grad_context)
      self._grad_index.op._add_control_input(self._grad_sync)
      if self._grad_context.outer_context:
        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
    return self._grad_sync

  @property
  def history_map(self):
    """The map that records all the tensors needed for backprop."""
    return self._history_map

  @property
  def switch_map(self):
    """The map that records all the Switch ops for the while loop."""
    return self._switch_map

  @property
  def unused_exits(self):
    """The list of "unused" exits."""
    return self._unused_exits

  @property
  def deferred_exits(self):
    """The list of "deferred" exits."""
    return self._deferred_exits

  @property
  def forward_loop_exits(self):
    """The list of exits of the forward loop."""
    return self._forward_loop_exits

  @property
  def pending_exits_count(self):
    """The number of exits we expect to see but haven't."""
    return self._pending_exits_count

  @pending_exits_count.setter
  def pending_exits_count(self, cnt):
    """Set the pending count to cnt."""
    self._pending_exits_count = cnt

  def AddForwardAccumulator(self, value, dead_branch=False):
    """Add an accumulator for each forward tensor that is needed in backprop.

    This is added to the forward loop the first time a tensor in the forward
    loop is used by the backprop gradient computation loop. We create an
    accumulator that accumulates the value of the tensor at each iteration.
    Called in the control flow context where gradients() is called.

    The pseudocode is:
    ```
    acc = stack();
    while (_pivot) {
      acc = stack_push(acc, value);
    }
    ```

    We make sure that the stack push op in one iteration is executed before
    the next iteration. This is achieved by adding a control edge from
    `forward_index.op.inputs[0].op` to the push op, and another control
    edge from the push op to either `forward_index.op` or `forward_sync`.

    Args:
      value: The source tensor in forward that is to be accumulated.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The stack that contains the accumulated history of the tensor.

    Raises:
      TypeError: For internal errors involving the value condition context.
      ValueError: If `value` is inside an XLA scope and a valid max size
        for the stack can't be found.
    """
    # curr_ctxt is the context that tf.gradients was called in.
    with self._forward_index.graph.as_default():
      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      with ops.control_dependencies(None):
        if curr_ctxt:
          curr_ctxt.Enter()
        with ops.colocate_with(value):
          # We only need to pass maximum_iterations to the stack if
          # we're inside an XLA context.
          if not util.IsInXLAContext(value.op):
            max_size = constant_op.constant(-1, dtypes.int32)
          else:
            max_size = _GetMaxSizeFromNestedMaximumIterations(
                value, self.forward_context)
          acc = gen_data_flow_ops.stack_v2(
              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
        if curr_ctxt:
          curr_ctxt.Exit()

        # Make acc available in the forward context.
        enter_acc = self.forward_context.AddValue(acc)

        # Add the stack_push op in the context of value.op.
        swap_enabled = self.forward_context.swap_memory
        value_ctxt = util.GetOutputContext(value.op)
        if value_ctxt == self.forward_context:
          # value is not nested in the forward context.
          self.forward_context.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          self.forward_context.Exit()
          # Protect stack push and order it before forward_index.
          self.forward_index.op._add_control_input(push.op)
        else:
          # value is in a cond context within the forward context.
          if not isinstance(value_ctxt, control_flow_ops.CondContext):
            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
          if dead_branch:
            # The special case for creating a zero tensor for a dead
            # branch of a switch. See _ControlFlowState.ZerosLikeV1WhileLoop().
            value_ctxt.outer_context.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.outer_context.Exit()
            push.op._set_control_flow_context(value_ctxt)
          else:
            value_ctxt.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.Exit()
            # Protect stack push and order it before forward_sync.
            self.forward_sync._add_control_input(push.op)
          # Order stack push after the successor of forward_index.
          add_op = self.forward_index.op.inputs[0].op
          push.op._add_control_input(add_op)
        return acc

  def AddBackpropAccumulatedValue(self, history_value, value,
                                  dead_branch=False):
    """Add the getter for an accumulated value in the grad context.

    This is added to the backprop loop. Called in the grad context to
    get the value of an accumulated value. The stack pop op must be guarded
    by the pred of the controlling cond.

    Args:
      history_value: The history (a stack) of a value.
      value: The value that is pushed onto the stack.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The current value (the top of the stack).
    """
    history_ctxt = history_value.op._get_control_flow_context()
    # Find the cond context that controls history_value if any.
    cond_ctxt = None
    value_ctxt = value.op._get_control_flow_context()
    while value_ctxt and value_ctxt != history_ctxt:
      if isinstance(value_ctxt, control_flow_ops.CondContext):
        cond_ctxt = value_ctxt
        break
      value_ctxt = value_ctxt.outer_context
    with ops.control_dependencies(None):
      self.grad_context.Enter()
      if cond_ctxt:
        # Guard stack pop with a switch if it is controlled by a cond.
        grad_state = self
        pred = None
        while pred is None and grad_state:
          pred = grad_state.history_map.get(cond_ctxt.pred.name)
          grad_state = grad_state.outer_grad_state
        if pred is None:
          pred = cond_ctxt.pred
        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
        history_value = control_flow_ops._SwitchRefOrTensor(
            history_value, pred)[branch]
      pop = gen_data_flow_ops.stack_pop_v2(history_value,
                                           value.dtype.base_dtype)
      pop.set_shape(value.get_shape())
      self.grad_context.Exit()
    parallel_iterations = self.grad_context.parallel_iterations
    if parallel_iterations > 1:
      # All pops are ordered after pivot_for_body and before grad_sync.
      self.grad_sync._add_control_input(pop.op)
    return pop
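  # Paired with AddForwardAccumulator above, the overall scheme is roughly
  # (illustrative pseudocode, not executable):
  #
  #   acc = stack()
  #   while forward_pivot:                 # forward loop
  #     stack_push(acc, value)
  #   while backprop_pivot:                # backprop loop
  #     value_for_grad = stack_pop(acc)    # values return in reverse order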

  def GetRealValue(self, value):
    """Get the real value of `value`.

    If backprop "uses" a value produced by forward inference, an accumulator
    is added in the forward loop to accumulate its values. We use the
    accumulated value. This method must be called in the grad loop context.
    `value` must be in forward and needed for backprop.

    Args:
      value: A tensor to be captured.

    Returns:
      The same tensor obtained from the saved history.
    """
    assert value.op.type not in ["Variable", "VariableV2"]
    real_value = self._history_map.get(value.name)
    if real_value is None:
      cur_value = value
      cur_grad_state = self
      while True:
        enter_op = util.GetLoopConstantEnter(cur_value)
        if enter_op:
          # Special case: cur_value comes from a constant Enter node.
          cur_value = enter_op.inputs[0]
          cur_grad_state = cur_grad_state.outer_grad_state
          if cur_grad_state is None:
            # We are now outside all nested loops for this gradient(),
            # so `value` is a loop invariant and there is no need to
            # save the history of value. Just make cur_value enter
            # the right control flow context.
            real_value = self._grad_context.AddValue(cur_value)
            break
        elif constant_op.is_constant(cur_value):
          # If the value to be forwarded is a constant, clone the constant in
          # the gradient loop rather than using a stack.
          # TODO(phawkins): consider hoisting the constant out of the loop
          # instead.
          real_value = constant_op.constant(
              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
          break
        else:
          # Record the history of this value in forward_ctxt.
          self._grad_context.Exit()
          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
          self._grad_context.Enter()
          break

      if real_value is None:
        # Add the stack pop op in the grad context.
        real_value = cur_grad_state.AddBackpropAccumulatedValue(
            history_value, cur_value)
        if cur_grad_state != self:
          real_value = self._grad_context.AddValue(real_value)
      self._history_map[value.name] = real_value
    return real_value


class _ControlFlowState:
  """Maintain the mapping from the loops to their grad states."""

  def __init__(self):
    self._map = {}  # maps forward loop context to _GradLoopState

  def GetGradState(self, op, before):
    """Return the grad state for this op if it's in a forward loop context."""
    if before and util.IsLoopExit(op):
      forward_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
      forward_ctxt = forward_ctxt.outer_context
      if forward_ctxt:
        forward_ctxt = forward_ctxt.GetWhileContext()
    else:
      forward_ctxt = util.GetWhileContext(op)
    if forward_ctxt:
      return self._map.get(forward_ctxt)
    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is
    unused if its pending_count is 0. If there is an exit with real gradient,
    all these deferred exits will enter the backprop loop with zero gradient.
    Otherwise, they will enter the backprop loop with None. As an example,
    people often write:

    ```python
    v1, _ = tf.while_loop(p, b, [x1, x2])
    result = gradients(v1, x1)
    ```

    The exit node for x2 is not included by the betweenness analysis. But we
    need to backprop x2 if x2 is involved in computing v1.

    Args:
      pending_count: The number of backprop inputs for every op.
      to_ops_set: The set of ops for ys in gradients(ys, xs).

    Returns:
      The set of unused loop exits that we know at this point we need
      to backprop.
    """
    loop_exits = []
    for grad_state in self._map.values():
      for y in grad_state.forward_loop_exits:
        if pending_count[y.op] == 0:
          grad_state.pending_exits_count -= 1
          if y.op not in to_ops_set:
            grad_state.unused_exits.append(y)
          if grad_state.pending_exits_count == 0:
            loop_exits.extend(grad_state.unused_exits)
      # Need to include Enters in backprop for higher-order gradients.
      for y in grad_state.forward_context.loop_enters:
        if pending_count[y.op] == 0:
          pending_count[y.op] = 1
    return loop_exits

  def EnterGradWhileContext(self, op, before):
    """Enter the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):
    """Exit the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):
    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.
    """
    forward_ctxt = util.GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # This is a new while loop so create a grad state for it.
      outer_forward_ctxt = forward_ctxt.outer_context
      if outer_forward_ctxt:
        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
      outer_grad_state = None
      if outer_forward_ctxt:
        outer_grad_state = self._map.get(outer_forward_ctxt)
      grad_state = _GradLoopState(forward_ctxt, outer_grad_state)
      self._map[forward_ctxt] = grad_state

      # We need to include all exits of a loop for backprop.
      for loop_exit in grad_state.forward_loop_exits:
        if loop_exit.op not in between_ops:
          between_ops.add(loop_exit.op)
          between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):
    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in
    computing the result of some needed loop variable, we create a
    zero-valued tensor that is fed as gradient for the Exit node of that
    loop variable. Note that val.op is an Exit, and this method must be
    called in the control flow context where gradients() is called.

    Args:
      val: The output tensor of an Exit op.

    Returns:
      A zero tensor of the same shape as val.
    """
    val_shape = val.get_shape()
    forward_ctxt = val.op._get_control_flow_context()
    outer_forward_ctxt = forward_ctxt.outer_context
    if outer_forward_ctxt:
      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
    outer_grad_state = None
    if outer_forward_ctxt:
      outer_grad_state = self._map.get(outer_forward_ctxt)
    if outer_grad_state:
      # This is a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape in the right context.
        outer_grad_state.grad_context.Enter()
        result = array_ops.zeros(val_shape.dims, val.dtype)
        outer_grad_state.grad_context.Exit()
      else:
        # Only the shape of value is needed for backprop.
        forward_ctxt.outer_context.Enter()
        shape = array_ops.shape_internal(val, optimize=False)
        forward_ctxt.outer_context.Exit()
        # Save the shape to a stack.
        history_shape = outer_grad_state.AddForwardAccumulator(shape)
        # Get the shape back from the stack.
        outer_grad_ctxt = outer_grad_state.grad_context
        outer_grad_ctxt.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_shape, shape)
        result = array_ops.zeros(real_shape, val.dtype)
        outer_grad_ctxt.Exit()
    else:
      # This is not a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape.
        result = array_ops.zeros(val_shape.dims, val.dtype)
      else:
        result = array_ops.zeros_like(val, optimize=False)
    return result

  def ZerosLikeV1WhileLoop(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A TensorFlow operation.
      index: The index for a specific output of the op.

    Returns:
      A zero tensor of the same shape as op.outputs[index].
    """
    if util.IsLoopSwitch(op):
      return None
    if op.graph.building_function:
      # The optimization here is tricky to apply to functions.
      return array_ops.zeros_like(op.outputs[index])
    dead_branch = util.IsSwitch(op)
    forward_ctxt = util.GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # op is not in a while loop that is part of gradients().
      return ZerosLike(op, index)
    op_ctxt = op._get_control_flow_context()
    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
    shape = val.get_shape()
    if shape.is_fully_defined():
      # If the shape is known statically, just create a zero tensor with
      # the right shape in the grad loop context.
      if val.dtype == dtypes.resource:
        result = array_ops.zeros(
            resource_variable_ops.variable_shape(val),
            dtype=default_gradient.get_zeros_dtype(val))
      else:
        result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
      if dead_branch:
        # op is a cond switch. Guard the zero tensor with a switch.
        pred = grad_state.history_map.get(op_ctxt.pred.name)
        branch = op_ctxt.branch
        result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch]
    else:
      # Unknown shape so keep a history of the shape at runtime.
      if dead_branch:
        # Need to add a special switch to guard the value.
        pred = op_ctxt.pred
        branch = op_ctxt.branch
        op_ctxt.outer_context.Enter()
        val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
                                                  pred)[1 - branch]
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.outer_context.Exit()
        val.op._set_control_flow_context(op_ctxt)
        zeros_shape.op._set_control_flow_context(op_ctxt)
      else:
        op_ctxt.Enter()
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.Exit()

      # Add forward accumulator for shape.
      grad_state.grad_context.Exit()
      history_zeros_shape = grad_state.AddForwardAccumulator(
          zeros_shape, dead_branch=dead_branch)
      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.
      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                     zeros_shape, dead_branch)
      result = array_ops.zeros(shape, val.dtype)
    return result

  def PostProcessing(self):
    """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function
    can be used to perform any postprocessing on the gradient graph.
    We currently perform the following postprocessing:
      1. Patch the gradient graph if the output of a loop variable
         doesn't depend on its input.
    """
    for _, grad_state in self._map.items():
      for _, b_merge in grad_state.switch_map.items():
        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
          # The value of this loop variable at iteration i+1 doesn't
          # depend on its value at iteration i. So use zeros as the
          # gradients for all iterations > 0.
          dtype = b_merge.op.inputs[0].dtype
          shape = b_merge.op.inputs[0].get_shape()
          # pylint: disable=protected-access
          if shape.is_fully_defined():
            grad_state.grad_context.Enter()
            # Create a zeros tensor and use it for iterations > 0.
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = control_flow_ops._NextIteration(grad_val)
            grad_state.grad_context.Exit()
          else:
            # Create a zeros tensor in the outer grad context.
            outer_grad_ctxt = grad_state.grad_context.outer_context
            if outer_grad_ctxt:
              outer_grad_ctxt.Enter()
            enter_grad_op = b_merge.op.inputs[0].op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            grad_val = array_ops.zeros(grad_shape)
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
            # Use the zeros for iterations > 0.
            grad_state.grad_context.Enter()
            next_grad_val = control_flow_ops._NextIteration(grad_val)
            grad_state.grad_context.Exit()
          b_merge.op._update_input(1, next_grad_val)
          # pylint: enable=protected-access

def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Create the state for all the while loops involved in one gradients().

  We create a _ControlFlowState when there are while loops involved in
  gradients(). In gradients(), control flow logic is only invoked when
  the _ControlFlowState is not None.

  Note that this method modifies `between_op_list` and `between_ops`.
  """
  loop_state = None
  for op in between_op_list:
    if util.IsLoopExit(op):
      if loop_state is None:
        loop_state = _ControlFlowState()
      if colocate_gradients_with_ops:
        with ops.colocate_with(op):
          loop_state.AddWhileContext(op, between_op_list, between_ops)
      else:
        loop_state.AddWhileContext(op, between_op_list, between_ops)
  return loop_state
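# Illustrative call pattern (an assumption about the caller, simplified): the
# gradients machinery computes `between_op_list`/`between_ops` from the ops
# reachable between ys and xs and then does roughly
#
#   loop_state = MaybeCreateControlFlowState(
#       between_op_list, between_ops, colocate_gradients_with_ops)
#   if loop_state:
#     loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)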

def _ZerosLikeV1(op, index):
  """Branch of ZerosLike for TF1."""
  val = op.outputs[index]
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if op_ctxt:
    # We are in a cond context. Use a switch to create zeros only when needed.
    pred = op_ctxt.pred
    branch = op_ctxt.branch
    switch_val = control_flow_ops.switch(op.inputs[0], pred)[1 - branch]
    # An op is created along the taken branch because control dependencies are
    # on the whole op and not on the tensor output.
    pivot = array_ops.identity(switch_val)
    if val.dtype == dtypes.resource:
      with ops.control_dependencies([pivot]):
        return array_ops.zeros(
            gen_resource_variable_ops.variable_shape(switch_val),
            dtype=default_gradient.get_zeros_dtype(val))
    zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
    # Ensure ops created within array_ops.zeros are dominated by switch in
    # cond context.
    with ops.control_dependencies([pivot]):
      return array_ops.zeros(zeros_shape, dtype=val.dtype)
  else:
    return array_ops.zeros_like(val, optimize=False)

def _ZerosLikeV2(op, index):
  """Branch of ZerosLike for TF2."""
  val = op.outputs[index]
  if val.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(val),
        dtype=default_gradient.get_zeros_dtype(val))
  if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph)
      and val.dtype != dtypes.variant):
    # In while_v2 we do not want to add a `ZerosLike` op because that will
    # trigger accumulation of `val`. Normally `ZerosLike` is preferred because
    # it helps avoid creating extra nodes (possibly Consts) for the shape.
    # For variants, we must use ZerosLike.
    if val.shape.is_fully_defined():
      return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype)
    else:
      # Note: Even though we add `Shape` in the default graph, while_v2 is smart
      # enough to place it in the forward graph i.e. `val.graph`.
      zeros_shape = array_ops.shape_internal(val, optimize=False)
      return array_ops.zeros(zeros_shape, val.dtype)
  else:
    return array_ops.zeros_like(val, optimize=False)

def ZerosLike(op, index):
  """Create zeros_like for the specified output of an op."""
  if not util.IsSwitch(op):
    return _ZerosLikeV2(op, index)
  else:
    return _ZerosLikeV1(op, index)
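# Illustrative usage (an assumption about the caller, simplified): gradient
# code calls ZerosLike to synthesize a zero gradient for an op output that
# received no incoming gradient, e.g.
#
#   zero_grad = ZerosLike(y.op, y.value_index)
#
# Switch ops take the V1 path above so the zero can be guarded by the cond
# predicate; all other ops take the V2 path.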