Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gradients_util.py: 15%

459 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements the graph generation for computation of gradients.""" 

16 

17import collections 

18import contextlib 

19 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.python import pywrap_tfe 

22from tensorflow.python.eager import backprop_util 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import composite_tensor 

25from tensorflow.python.framework import composite_tensor_gradient 

26from tensorflow.python.framework import dtypes 

27from tensorflow.python.framework import indexed_slices 

28from tensorflow.python.framework import ops 

29from tensorflow.python.framework import tensor_shape 

30from tensorflow.python.ops import array_ops 

31from tensorflow.python.ops import control_flow_ops 

32from tensorflow.python.ops import control_flow_state 

33from tensorflow.python.ops import control_flow_util 

34from tensorflow.python.ops import default_gradient 

35from tensorflow.python.ops import gen_functional_ops 

36from tensorflow.python.ops import math_ops 

37from tensorflow.python.ops import resource_variable_ops 

38from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 

39from tensorflow.python.platform import tf_logging as logging 

40from tensorflow.python.util import compat 

41from tensorflow.python.util import object_identity 

42from tensorflow.python.util import variable_utils 

43from tensorflow.python.util.compat import collections_abc 

44from tensorflow.python.util.tf_export import tf_export 

45 

46 

47def _MarkReachedOps(from_ops, reached_ops, func_graphs): 

48 """Mark all ops reached from "from_ops". 

49 

50 Args: 

51 from_ops: list of Operations. 

52 reached_ops: set of Operations. 

53 func_graphs: list of FuncGraphs. This method will traverse through 

54 these functions if they capture from_ops or any reachable ops. 

55 """ 

56 queue = collections.deque() 

57 queue.extend(from_ops) 

58 while queue: 

59 op = queue.popleft() 

60 if op not in reached_ops: 

61 reached_ops.add(op) 

62 for output in op.outputs: 

63 if backprop_util.IsTrainable(output): 

64 queue.extend(_Consumers(output, func_graphs)) 

65 

66 

67def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs, 

68 xs_set): 

69 """Initialize the pending count for ops between two lists of Operations. 

70 

71 'pending_count[op]' indicates the number of backprop inputs 

72 to this operation. 

73 

74 Args: 

75 to_ops: list of Operations. 

76 from_ops: list of Operations. 

77 colocate_gradients_with_ops: Python bool. See docstring of gradients(). 

78 func_graphs: list of FuncGraphs. This method will traverse through 

79 these functions if they capture from_ops or any reachable ops. This is 

80 useful if to_ops occur in a function and from_ops are in an outer function 

81 or graph. 

82 xs_set: ObjectIdentitySet of Tensors. 

83 

84 Returns: 

85 A tuple containing: (1) the subset of to_ops reachable from from_ops by a 

86 path of zero or more backpropagatable tensors, (2) a mapping from operation 

87 to the number of backprop inputs to that op, and (3) a ControlFlowState 

88 object which is not None if the ops between from_ops and to_ops contain 

89 control flow loops. 

90 """ 

91 # Mark reachable ops from from_ops. 

92 reached_ops = set() 

93 _MarkReachedOps(from_ops, reached_ops, func_graphs) 

94 # X in reached_ops iff X is reachable from from_ops by a path of zero or more 

95 # backpropagatable tensors. 

96 

97 reachable_to_ops = set(op for op in to_ops if op in reached_ops) 

98 

99 # Mark between ops. 

100 between_ops = set() 

101 between_op_list = [] 

102 queue = collections.deque() 

103 queue.extend(to_ops) 

104 while queue: 

105 op = queue.popleft() 

106 # We are interested in this op. 

107 if op in reached_ops: 

108 between_ops.add(op) 

109 between_op_list.append(op) 

110 # Clear the boolean so we won't add the inputs again. 

111 reached_ops.remove(op) 

112 for inp in _NonEagerInputs(op, xs_set): 

113 queue.append(inp.op) 

114 # X in between_ops iff X is on a path of zero or more backpropagatable tensors 

115 # between from_ops and to_ops 

116 

117 # 'loop_state' is None if there are no while loops. 

118 loop_state = control_flow_state.MaybeCreateControlFlowState( 

119 between_op_list, between_ops, colocate_gradients_with_ops) 

120 

121 # Initialize pending count for between ops. 

122 pending_count = collections.defaultdict(int) 

123 for op in between_op_list: 

124 for x in _NonEagerInputs(op, xs_set): 

125 if x.op in between_ops: 

126 pending_count[x.op] += 1 

127 

128 return reachable_to_ops, pending_count, loop_state 

129 

130 
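# Editor's note: a minimal, standalone sketch of the reachability /
# pending-count bookkeeping implemented by _MarkReachedOps and _PendingCount
# above, written against a hypothetical toy DAG of string node names instead
# of real tf Operations (no control flow, trainability, or FuncGraph capture
# handling). Illustration only; not part of the TensorFlow implementation.
def _pending_count_sketch(consumers, from_nodes, to_nodes):
  """consumers: {node: list of consumer nodes}; returns {node: pending count}."""
  # Forward pass: every node reachable from from_nodes.
  reached = set()
  queue = collections.deque(from_nodes)
  while queue:
    node = queue.popleft()
    if node not in reached:
      reached.add(node)
      queue.extend(consumers.get(node, ()))
  # Invert the edges so we can walk backwards from to_nodes.
  inputs = collections.defaultdict(list)
  for src, dsts in consumers.items():
    for dst in dsts:
      inputs[dst].append(src)
  # Backward pass: the "between" nodes on a path from from_nodes to to_nodes.
  between = set()
  queue = collections.deque(to_nodes)
  while queue:
    node = queue.popleft()
    if node in reached:
      reached.discard(node)  # don't visit twice
      between.add(node)
      queue.extend(inputs[node])
  # Pending count = number of "between" consumers whose gradients must arrive
  # before this node's gradient function can be called.
  pending = collections.defaultdict(int)
  for node in between:
    for inp in inputs[node]:
      if inp in between:
        pending[inp] += 1
  return pending
# Example (hypothetical graph): _pending_count_sketch(
#     {"x": ["sq"], "sq": ["sum"]}, ["x"], ["sum"]) -> {"sq": 1, "x": 1}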

131def _AsList(x): 

132 return x if isinstance(x, (list, tuple)) else [x] 

133 

134 

135def _DefaultGradYs(grad_ys, 

136 ys, 

137 colocate_gradients_with_ops, 

138 gradient_uid="__unsupported__"): 

139 """Fill in default values for grad_ys. 

140 

141 Args: 

142 grad_ys: List of gradients, can contain None. 

143 ys: List of tensors. 

144 colocate_gradients_with_ops: If True, try colocating gradients with 

145 the corresponding op. 

146 gradient_uid: A unique identifier within the graph indicating 

147 which invocation of gradients is being executed. Used to cluster 

148 ops for compilation. 

149 

150 Returns: 

151 A list of gradients to use, without None. 

152 

153 Raises: 

154 ValueError: If sizes of gradients and inputs don't match 

155 TypeError: If type of any gradient is not valid for its input. 

156 """ 

157 if len(grad_ys) != len(ys): 

158 raise ValueError(f"Length mismatch. Passed {len(grad_ys)} grad_ys for " 

159 f"{len(ys)} ys") 

160 grad_ys = indexed_slices.convert_n_to_tensor_or_indexed_slices( 

161 grad_ys, name="grad_y") 

162 new_grad_ys = [] 

163 for i, (y, grad_y) in enumerate(zip(ys, grad_ys)): 

164 with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops): 

165 if grad_y is None: 

166 if y.dtype.is_complex: 

167 raise TypeError( 

168 f"Gradients of complex tensors ({y}) must set grad_ys (y.dtype = " 

169 f"{dtypes.as_dtype(y.dtype).name})") 

170 new_grad_ys.append( 

171 array_ops.ones( 

172 array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i)) 

173 continue 

174 if y.dtype.is_floating or y.dtype.is_integer: 

175 if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer: 

176 raise TypeError( 

177 f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " 

178 f"for real or integer-valued tensor {y} with type " 

179 f"{dtypes.as_dtype(y.dtype).name} must be real or integer") 

180 elif y.dtype.is_complex: 

181 if not grad_y.dtype.is_complex: 

182 raise TypeError( 

183 f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " 

184 f"for complex-valued tensor {y} with type " 

185 f"{dtypes.as_dtype(y.dtype).name} must be real") 

186 elif y.dtype == dtypes.variant: 

187 if grad_y.dtype != dtypes.variant: 

188 raise TypeError( 

189 f"Gradient type {dtypes.as_dtype(grad_y.dtype).name} generated " 

190 f"for variant tensor {y} with type " 

191 f"{dtypes.as_dtype(y.dtype).name} must be variant") 

192 elif y.dtype == dtypes.resource: 

193 # We assume y is the handle of a ResourceVariable. The gradient of a 

194 # ResourceVariable should be a numeric value, not another resource. 

195 if grad_y.dtype == dtypes.resource: 

196 raise TypeError(f"Input gradient {grad_y} for resource tensor {y} " 

197 "should not be a resource") 

198 else: 

199 raise TypeError( 

200 f"Tensor {y} with type {dtypes.as_dtype(y.dtype).name} must be " 

201 "numeric to obtain a default gradient") 

202 # Create a grad_y tensor in the name scope of the gradient. 

203 # Required for TensorArrays to identify which gradient call a 

204 # grad_y value is coming from. 

205 if isinstance(grad_y, indexed_slices.IndexedSlices): 

206 new_grad_ys.append( 

207 indexed_slices.IndexedSlices( 

208 indices=(array_ops.identity( 

209 grad_y.indices, name="grad_ys_%d_indices" % i) 

210 if isinstance(grad_y.indices, ops.Tensor) else 

211 grad_y.indices), 

212 values=(array_ops.identity( 

213 grad_y.values, name="grad_ys_%d_values" % i) if isinstance( 

214 grad_y.values, ops.Tensor) else grad_y.values), 

215 dense_shape=(array_ops.identity( 

216 grad_y.dense_shape, name="grad_ys_%d_shape" % i) 

217 if isinstance(grad_y.dense_shape, ops.Tensor) else 

218 grad_y.dense_shape))) 

219 else: 

220 new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i)) 

221 

222 return new_grad_ys 

223 

224 
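# Editor's note: a hedged usage sketch of the default-grad_ys behaviour that
# _DefaultGradYs implements for the public tf.gradients symbol. Kept as a
# comment because this module does not import the top-level `tf` package; the
# function name below is hypothetical. tf.gradients is graph-only, hence the
# tf.function wrapper.
#
#   @tf.function
#   def _grad_ys_example():
#     x = tf.constant(3.0)
#     y = x * x                                   # dy/dx = 2x = 6
#     g_default, = tf.gradients(y, [x])           # grad_ys defaults to ones -> 6.0
#     g_weighted, = tf.gradients(y, [x], grad_ys=tf.constant(10.0))  # -> 60.0
#     return g_default, g_weighted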

225def _VerifyGeneratedGradients(grads, op): 

226 """Verify that gradients are valid in number and type. 

227 

228 Args: 

229 grads: List of generated gradients. 

230 op: Operation for which the gradients where generated. 

231 

232 Raises: 

233 ValueError: if sizes of gradients and inputs don't match. 

234 TypeError: if type of any gradient is not valid for its input. 

235 """ 

236 # While ops have inputs added to them during the gradient computation, so we 

237 # skip the below check. See while_v2 for details. 

238 if op.type == "While" or op.type == "StatelessWhile": 

239 return 

240 

241 if len(grads) != len(op.inputs): 

242 raise ValueError(f"Num gradients {len(grads)} generated for op " 

243 f"{op.node_def} do not match num inputs {len(op.inputs)}") 

244 

245 

246def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set): 

247 """The set of ops that terminate the gradient computation. 

248 

249 This computes the frontier of the forward graph *before* which backprop 

250 should stop. Operations in the returned set will not be differentiated. 

251 This set is defined as the subset of `from_ops` containing ops that have 

252 no predecessor in `from_ops`. `pending_count` is the result of 

253 `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops` 

254 iff pending_count[op] > 0. 

255 

256 In addition, none of `stop_gradient_ops` will be differentiated. 

257 

258 Args: 

259 from_ops: list of Operations. 

260 stop_gradient_ops: list of Operations never to backprop through. 

261 pending_count: mapping from operation to number of backprop inputs. 

262 xs_set: ObjectIdentitySet of Tensors. 

263 

264 Returns: 

265 The set of operations. 

266 """ 

267 stop_ops = set() 

268 for op in from_ops: 

269 is_stop_op = True 

270 for inp in _NonEagerInputs(op, xs_set): 

271 if pending_count[inp.op] > 0: 

272 is_stop_op = False 

273 break 

274 if is_stop_op: 

275 stop_ops.add(op) 

276 stop_ops.update(stop_gradient_ops) 

277 return stop_ops 

278 

279 
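# Editor's note: a hedged sketch of the public behaviour that _StopOps
# supports: tensors passed via `stop_gradients` are treated as constants
# during differentiation. Comment-only example; `tf` is not imported in this
# module and the function name is hypothetical.
#
#   @tf.function
#   def _stop_gradients_example():
#     a = tf.constant(2.0)
#     b = 2 * a
#     # With `b` stopped, the path a -> b -> (a + b) is not differentiated:
#     da, db = tf.gradients(a + b, [a, b], stop_gradients=[b])
#     return da, db   # da == 1.0 (instead of 3.0), db == 1.0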

280@contextlib.contextmanager 

281def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name 

282 """Context to colocate with `op` if `colocate_gradients_with_ops`.""" 

283 if colocate_gradients_with_ops: 

284 with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access 

285 yield 

286 else: 

287 yield 

288 

289 

290def _IsPartitionedCall(op): 

291 return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall" 

292 

293 

294def _SymGrad(op, out_grads): 

295 """Backprop through a function call node op given its outputs' gradients.""" 

296 f_in = list(op.inputs) + out_grads 

297 f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs] 

298 f = attr_value_pb2.NameAttrList() 

299 if _IsPartitionedCall(op): 

300 f.name = op.get_attr("f").name 

301 else: 

302 f.name = op.type 

303 for k in op.node_def.attr: 

304 f.attr[k].CopyFrom(op.node_def.attr[k]) 

305 in_grads = gen_functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f) 

306 return in_grads 

307 

308 

309def _MaybeCompile(scope, op, func, grad_fn): 

310 """Compile the calculation in grad_fn if op was marked as compiled.""" 

311 scope = scope.rstrip("/").replace("/", "_") 

312 if func is not None: 

313 xla_compile = func.cached_definition.attr["_XlaCompile"].b 

314 xla_separate_compiled_gradients = func.cached_definition.attr[ 

315 "_XlaSeparateCompiledGradients"].b 

316 xla_scope = func.cached_definition.attr["_XlaScope"].s.decode() 

317 else: 

318 try: 

319 xla_compile = op.get_attr("_XlaCompile") 

320 xla_separate_compiled_gradients = op.get_attr( 

321 "_XlaSeparateCompiledGradients") 

322 xla_scope = op.get_attr("_XlaScope").decode() 

323 except ValueError: 

324 xla_compile = False 

325 

326 if not xla_compile: 

327 return grad_fn() # Exit early 

328 

329 # If the gradients are supposed to be compiled separately, we give them a 

330 # _XlaScope name that is based on the name_scope of the gradients. Otherwise 

331 # they just inherit the existing _XlaScope name, which lets them be merged 

332 # together with the non-gradient computation. 

333 if xla_separate_compiled_gradients: 

334 xla_grad_scope = "%s_grad_%s" % (xla_scope, scope) 

335 else: 

336 xla_grad_scope = xla_scope 

337 

338 attrs = { 

339 "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile), 

340 "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode()) 

341 } 

342 with ops.get_default_graph()._attr_scope(attrs): # pylint: disable=protected-access 

343 return grad_fn() 

344 

345 

346def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set): 

347 """Raises an error if we backprop through a loop var.""" 

348 # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error 

349 # message. 

350 target_op = None 

351 queue = collections.deque([op]) 

352 visited = set() 

353 while queue: 

354 curr_op = queue.popleft() 

355 if curr_op in visited: continue 

356 visited.add(curr_op) 

357 if curr_op in from_ops: 

358 target_op = curr_op 

359 break 

360 queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set)) 

361 assert target_op 

362 raise ValueError( 

363 "Cannot compute gradient inside while loop with respect to op " 

364 f"'{target_op.name}'. We do not support taking the gradient wrt or " 

365 "through the initial value of a loop variable. Gradients can be computed " 

366 "through loop invariants or wrt the input parameters to the loop body.") 

367 

368 

369def _IsFunction(graph): 

370 # isinstance check for FuncGraphs that avoids the explicit dependency 

371 # on func_graph.py and function.py 

372 return isinstance(graph, ops.Graph) and graph._building_function # pylint: disable=protected-access 

373 

374 

375def _Captures(func_graph): 

376 assert _IsFunction(func_graph) 

377 return func_graph.captures 

378 

379 

380def _MaybeCaptured(t): 

381 """If t is a captured value placeholder, returns the original captured value. 

382 

383 Args: 

384 t: Tensor 

385 

386 Returns: 

387 A tensor, potentially from a different Graph/FuncGraph. 

388 """ 

389 # pylint: disable=protected-access 

390 if (not isinstance(t, ops.EagerTensor) and 

391 _IsFunction(t.op.graph) and t.op.type == "Placeholder"): 

392 for input_t, placeholder_t in _Captures(t.op.graph): 

393 if t is placeholder_t: 

394 return _MaybeCaptured(input_t) 

395 # pylint: enable=protected-access 

396 return t 

397 

398 

399def _NonEagerInputs(op, xs_set): 

400 """Returns the inputs of op, crossing closure boundaries where necessary. 

401 

402 Does not return any captured EagerTensors, i.e., the number of tensors 

403 returned may be less than the actual number of inputs. 

404 

405 Args: 

406 op: Operation 

407 xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t. 

408 

409 Returns: 

410 A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op 

411 is in a FuncGraph and has captured inputs. 

412 """ 

413 return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)] 

414 

415 

416# TODO(skyewm): plumbing xs through everywhere is ugly, consider making 

417# _GradientsHelper a class with xs as a member variable. 

418def _Inputs(op, xs_set): 

419 """Returns the inputs of op, crossing closure boundaries where necessary. 

420 

421 Args: 

422 op: Operation 

423 xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t. 

424 

425 Returns: 

426 A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op 

427 is in a FuncGraph and has captured inputs. 

428 """ 

429 if _IsFunction(op.graph): # pylint: disable=protected-access 

430 inputs = [] 

431 for t in op.inputs: 

432 # If we're differentiating w.r.t. `t`, do not attempt to traverse through 

433 # it to a captured value. The algorithm needs to "see" `t` in this case, 

434 # even if it's a function input for a captured value, whereas usually we'd 

435 # like to traverse through these closures as if the captured value was the 

436 # direct input to op. 

437 if t not in xs_set: 

438 t = _MaybeCaptured(t) 

439 inputs.append(t) 

440 return inputs 

441 else: 

442 return op.inputs 

443 

444 

445def _Consumers(t, func_graphs): 

446 """Returns the consumers of t, crossing closure boundaries where necessary. 

447 

448 Args: 

449 t: Tensor 

450 func_graphs: a list of FuncGraphs that may have captured t. 

451 

452 Returns: 

453 A list of tensors. The tensors will be from the current graph and/or 

454 func_graphs. 

455 """ 

456 consumers = t.consumers() 

457 for func in func_graphs: 

458 for input_t, placeholder in _Captures(func): 

459 if input_t is t: 

460 consumers.extend(_Consumers(placeholder, func_graphs)) 

461 return consumers 

462 

463 

464def _GradientsHelper(ys, 

465 xs, 

466 grad_ys=None, 

467 name="gradients", 

468 colocate_gradients_with_ops=False, 

469 gate_gradients=False, 

470 aggregation_method=None, 

471 stop_gradients=None, 

472 unconnected_gradients=UnconnectedGradients.NONE, 

473 src_graph=None): 

474 """Implementation of gradients().""" 

475 if context.executing_eagerly(): 

476 raise RuntimeError("tf.gradients is not supported when eager execution " 

477 "is enabled. Use tf.GradientTape instead.") 

478 ys = variable_utils.convert_variables_to_tensors(_AsList(ys)) 

479 xs = [ 

480 x.handle if resource_variable_ops.is_resource_variable(x) else x 

481 for x in _AsList(xs) 

482 ] 

483 if grad_ys is not None: 

484 grad_ys = _AsList(grad_ys) 

485 

486 # Handle CompositeTensors. 

487 if (any(isinstance(x, composite_tensor.CompositeTensor) for x in xs) or 

488 any(isinstance(y, composite_tensor.CompositeTensor) for y in ys)): 

489 flat_xs = composite_tensor_gradient.get_flat_tensors_for_gradients(xs) 

490 flat_ys = composite_tensor_gradient.get_flat_tensors_for_gradients(ys) 

491 flat_grad_ys = ( 

492 None if grad_ys is None else 

493 composite_tensor_gradient.get_flat_tensors_for_gradients(grad_ys)) 

494 flat_grads = _GradientsHelper(flat_ys, flat_xs, flat_grad_ys, name, 

495 colocate_gradients_with_ops, gate_gradients, 

496 aggregation_method, stop_gradients, 

497 unconnected_gradients, src_graph) 

498 return composite_tensor_gradient.replace_flat_tensors_for_gradients( 

499 xs, flat_grads) 

500 

501 if src_graph is None: 

502 src_graph = ops.get_default_graph() 

503 try: 

504 unconnected_gradients = UnconnectedGradients(unconnected_gradients) 

505 except ValueError: 

506 raise ValueError( 

507 f"Unknown value for unconnected_gradients: '{unconnected_gradients}'") 

508 

509 # If src_graph is a _FuncGraph (i.e. a function body), gather it and all 

510 # ancestor graphs. This is necessary for correctly handling captured values. 

511 func_graphs = [] 

512 curr_graph = src_graph 

513 while _IsFunction(curr_graph): 

514 func_graphs.append(curr_graph) 

515 curr_graph = curr_graph.outer_graph 

516 

517 stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) 

518 if grad_ys is None: 

519 grad_ys = [None] * len(ys) 

520 

521 with ops.name_scope( 

522 name, "gradients", 

523 list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope: 

524 # Get a uid for this call to gradients that can be used to help 

525 # cluster ops for compilation. 

526 gradient_uid = ops.get_default_graph().unique_name("uid") 

527 ys = indexed_slices.convert_n_to_tensor_or_indexed_slices(ys, name="y") 

528 xs = indexed_slices.internal_convert_n_to_tensor_or_indexed_slices( 

529 xs, name="x", as_ref=True) 

530 xs_set = object_identity.ObjectIdentitySet(xs) 

531 grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, 

532 gradient_uid) 

533 

534 # The approach we take here is as follows: Create a list of all ops in the 

535 # subgraph between the ys and xs. Visit these ops in reverse order of ids 

536 # to ensure that when we visit an op the gradients w.r.t its outputs have 

537 # been collected. Then aggregate these gradients if needed, call the op's 

538 # gradient function, and add the generated gradients to the gradients for 

539 # its input. 

540 

541 # Initialize the pending count for ops in the connected subgraph from ys 

542 # to the xs. 

543 to_ops = [t.op for t in ys] 

544 from_ops = [t.op for t in xs] 

545 stop_gradient_ops = [t.op for t in stop_gradients] 

546 reachable_to_ops, pending_count, loop_state = _PendingCount( 

547 to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set) 

548 

549 # Iterate over the collected ops. 

550 # 

551 # grads: op => list of gradients received on each output endpoint of the 

552 # op. The gradients for each endpoint are initially collected as a list. 

553 # When it is time to call the op's gradient function, for each endpoint we 

554 # aggregate the list of received gradients into a Add() Operation if there 

555 # is more than one. 

556 grads = {} 

557 

558 # Add the initial gradients for the ys. 

559 for y, grad_y in zip(ys, grad_ys): 

560 _SetGrad(grads, y, grad_y) 

561 

562 # Initialize queue with to_ops. 

563 queue = collections.deque() 

564 # Add the ops in 'to_ops' into the queue. 

565 to_ops_set = set() 

566 for op in to_ops: 

567 # 'ready' handles the case where one output gradient relies on 

568 # another output's gradient. 

569 ready = (pending_count[op] == 0) 

570 if ready and op not in to_ops_set and op in reachable_to_ops: 

571 to_ops_set.add(op) 

572 queue.append(op) 

573 

574 if loop_state: 

575 loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set) 

576 for y in loop_exits: 

577 if backprop_util.IsTrainable(y): 

578 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) 

579 queue.append(y.op) 

580 

581 stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set) 

582 while queue: 

583 # generate gradient subgraph for op. 

584 op = queue.popleft() 

585 with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): 

586 if loop_state: 

587 loop_state.EnterGradWhileContext(op, before=True) 

588 out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, 

589 aggregation_method) 

590 if loop_state: 

591 loop_state.ExitGradWhileContext(op, before=True) 

592 

593 grad_fn = None 

594 func_call = None 

595 is_partitioned_call = _IsPartitionedCall(op) 

596 # pylint: disable=protected-access 

597 is_func_call = ( 

598 src_graph._is_function(op.type) or is_partitioned_call) 

599 # pylint: enable=protected-access 

600 has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) 

601 if has_out_grads and (op not in stop_ops): 

602 try: 

603 grad_fn = ops.get_gradient_function(op) 

604 except LookupError: 

605 if is_func_call: 

606 if is_partitioned_call: 

607 func_name = compat.as_bytes(op.get_attr("f").name) 

608 func_call = src_graph._get_function( # pylint: disable=protected-access 

609 func_name) 

610 # When a graph is imported, the FunctionDefs are not copied over 

611 # to each sub-graph so we recursively search the outer graphs 

612 # for the FunctionDef. 

613 if not func_call and hasattr(src_graph, "outer_graph"): 

614 graph = src_graph.outer_graph 

615 while graph is not None: 

616 func_call = graph._get_function(func_name) # pylint: disable=protected-access 

617 if func_call is not None: 

618 break 

619 if hasattr(graph, "outer_graph"): 

620 graph = graph.outer_graph 

621 else: 

622 break 

623 else: 

624 func_call = src_graph._get_function(op.type) # pylint: disable=protected-access 

625 # Note that __defun is not set if the graph is 

626 # imported. If it's set, we prefer to access the original 

627 # defun. 

628 func_call = getattr(op, "__defun", func_call) 

629 grad_fn = func_call.python_grad_func 

630 else: 

631 raise LookupError( 

632 "No gradient defined for operation" 

633 f"'{op.name}' (op type: {op.type}). " 

634 "In general every operation must have an associated " 

635 "`@tf.RegisterGradient` for correct autodiff, which this " 

636 "op is lacking. If you want to pretend this " 

637 "operation is a constant in your program, you may insert " 

638 "`tf.stop_gradient`. This can be useful to silence the " 

639 "error in cases where you know gradients are not needed, " 

640 "e.g. the forward pass of tf.custom_gradient. " 

641 "Please see more details in " 

642 "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.") # pylint: disable=line-too-long 

643 if loop_state: 

644 loop_state.EnterGradWhileContext(op, before=False) 

645 

646 # NOTE(skyewm): We don't support computing gradients wrt a loop variable 

647 # unless it's within the context of a single iteration (i.e. the 

648 # gradient is wrt the loop parameter in the body function, not wrt or 

649 # through the initial value). This means if we're in a while loop 

650 # context, we should never see a switch node from this context. 

651 # pylint: disable=protected-access 

652 if (control_flow_util.IsSwitch(op) and 

653 op._control_flow_context is not None and 

654 op._control_flow_context.IsWhileContext() and 

655 op._control_flow_context == 

656 ops.get_default_graph()._get_control_flow_context()): 

657 _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set) 

658 # pylint: enable=protected-access 

659 

660 if (grad_fn or is_func_call) and has_out_grads: 

661 # NOTE: If _AggregatedGrads didn't compute a value for the i'th 

662 # output, it means that the cost does not depend on output[i], 

663 # therefore dC/doutput[i] is 0. 

664 for i, out_grad in enumerate(out_grads): 

665 if (not isinstance(out_grad, ops.Tensor) and not out_grad) and ( 

666 (not grad_fn and is_func_call) 

667 or backprop_util.IsTrainable(op.outputs[i])): 

668 # Only trainable outputs or outputs for a function call that 

669 # will use SymbolicGradient get a zero gradient. Gradient 

670 # functions should ignore the gradient for other outputs. 

671 # TODO(apassos) gradients of resource handles might be an 

672 # issue here because of zeros. 

673 if loop_state: 

674 out_grads[i] = loop_state.ZerosLikeV1WhileLoop(op, i) 

675 elif default_gradient.supports_default_grad(op.outputs[i]): 

676 # TODO(b/143286622): The supports_default_grad check is needed 

677 # because While op emits non-differentiable resource tensors 

678 # as outputs. Remove this check when that is not the case. 

679 out_grads[i] = control_flow_state.ZerosLike(op, i) 

680 with ops.name_scope(op.name + "_grad"): 

681 # pylint: disable=protected-access 

682 with src_graph._original_op(op): 

683 # pylint: enable=protected-access 

684 if grad_fn: 

685 # If grad_fn was found, do not use SymbolicGradient even for 

686 # functions. 

687 in_grads = _MaybeCompile(grad_scope, op, func_call, 

688 lambda: grad_fn(op, *out_grads)) 

689 else: 

690 # For function call ops, we add a 'SymbolicGradient' 

691 # node to the graph to compute gradients. 

692 in_grads = _MaybeCompile(grad_scope, op, func_call, 

693 lambda: _SymGrad(op, out_grads)) 

694 in_grads = _AsList(in_grads) 

695 _VerifyGeneratedGradients(in_grads, op) 

696 if gate_gradients and len([x for x in in_grads 

697 if x is not None]) > 1: 

698 with ops.device(None): 

699 with ops._colocate_with_for_gradient( # pylint: disable=protected-access 

700 None, 

701 gradient_uid, 

702 ignore_existing=True): 

703 in_grads = control_flow_ops.tuple(in_grads) 

704 _LogOpGradients(op, out_grads, in_grads) 

705 else: 

706 # If no grad_fn is defined or none of out_grads is available, 

707 # just propagate a list of None backwards. 

708 in_grads = [None] * len(_Inputs(op, xs_set)) 

709 # Note: we don't filter out eager inputs here because the inputs need to 

710 # line up with in_grads. 

711 for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)): 

712 if in_grad is not None: 

713 if (isinstance(in_grad, ops.Tensor) and 

714 t_in.dtype != dtypes.resource): 

715 try: 

716 in_grad.set_shape(t_in.get_shape()) 

717 except ValueError: 

718 raise ValueError( 

719 "Incompatible shapes between op input and calculated " 

720 f"input gradient. Forward operation: {op.name}. Input " 

721 f"index: {i}. Original input shape: {t_in.shape}. " 

722 f"Calculated input gradient shape: {in_grad.shape}") 

723 if not isinstance(t_in, ops.EagerTensor): 

724 _SetGrad(grads, t_in, in_grad) 

725 if loop_state: 

726 loop_state.ExitGradWhileContext(op, before=False) 

727 

728 # Update pending count for the inputs of op and enqueue ready ops. 

729 _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, 

730 xs_set) 

731 

732 if loop_state: 

733 loop_state.PostProcessing() 

734 return [_GetGrad(grads, x, unconnected_gradients) for x in xs] 

735 

736 
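# Editor's note: _GradientsHelper is the implementation behind the public
# tf.gradients symbol. A hedged end-to-end usage sketch, kept as a comment
# because `tf` is not imported here; the function name is hypothetical.
#
#   @tf.function
#   def _gradients_example():
#     x = tf.constant([1.0, 2.0, 3.0])
#     y = tf.reduce_sum(x * x)
#     dy_dx, = tf.gradients(y, [x])   # -> [2.0, 4.0, 6.0]
#     return dy_dx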

737def _HasAnyNotNoneGrads(grads, op): 

738 """Return true iff op has real gradient.""" 

739 out_grads = _GetGrads(grads, op) 

740 for out_grad in out_grads: 

741 if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)): 

742 return True 

743 if out_grad and isinstance(out_grad, collections_abc.Sequence): 

744 if any(g is not None for g in out_grad): 

745 return True 

746 return False 

747 

748 

749def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, 

750 xs_set): 

751 """Update pending count for the inputs of op and enqueue ready ops.""" 

752 for x in _NonEagerInputs(op, xs_set): 

753 pending_count[x.op] -= 1 

754 ready = (pending_count[x.op] == 0) 

755 if loop_state and not ready: 

756 ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op) 

757 if ready: 

758 if control_flow_util.IsLoopExit(x.op): 

759 # If x is an exit without a real gradient, defer processing it. 

760 grad_state = loop_state.GetGradState(x.op, before=False) 

761 grad_state.deferred_exits.append(x) 

762 grad_state.pending_exits_count -= 1 

763 if grad_state.pending_exits_count == 0: 

764 # We now have all the exits so process them. 

765 has_not_none_grad = False 

766 for y in grad_state.deferred_exits: 

767 if _HasAnyNotNoneGrads(grads, y.op): 

768 has_not_none_grad = True 

769 queue.append(y.op) 

770 else: 

771 grad_state.unused_exits.append(y) 

772 if has_not_none_grad: 

773 # For an unused exit, if it has trainable outputs, backprop 

774 # a zero gradient. Otherwise, just ignore it. 

775 for y in grad_state.unused_exits: 

776 if backprop_util.IsTrainable(y): 

777 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) 

778 queue.append(y.op) 

779 else: 

780 # All exits are "unused" so use None as gradient. 

781 for y in grad_state.unused_exits: 

782 queue.append(y.op) 

783 else: 

784 queue.append(x.op) 

785 

786 

787def _SetGrad(grads, t, grad): 

788 """Sets gradient "grad" in "grads" for tensor "t".""" 

789 op = t.op 

790 op_grads = grads.get(op) 

791 if not op_grads: 

792 op_grads = [[] for _ in range(len(op.outputs))] 

793 grads[op] = op_grads 

794 t_grads = op_grads[t.value_index] 

795 if isinstance(t_grads, list): 

796 t_grads.append(grad) 

797 else: 

798 assert control_flow_util.IsLoopSwitch(op) 

799 op_grads[t.value_index] = grad 

800 

801 

802def _ZerosLike(t): 

803 t_dtype = default_gradient.get_zeros_dtype(t) 

804 if t.dtype == dtypes.resource: 

805 return array_ops.zeros( 

806 resource_variable_ops.variable_shape(t), dtype=t_dtype) 

807 else: 

808 return array_ops.zeros_like(t, dtype=t_dtype) 

809 

810 

811def _GetGrad(grads, t, unconnected_gradients): 

812 """Gets gradient for tensor "t".""" 

813 op = t.op 

814 op_grads = grads.get(op) 

815 if not op_grads: 

816 if unconnected_gradients == UnconnectedGradients.ZERO: 

817 return _ZerosLike(t) 

818 elif unconnected_gradients == UnconnectedGradients.NONE: 

819 return None 

820 else: 

821 raise ValueError( 

822 f"Unknown value for unconnected_gradients: '{unconnected_gradients}'") 

823 

824 t_grad = op_grads[t.value_index] 

825 # This can happen if some other output of `t.op` has non-None grad. 

826 if unconnected_gradients == UnconnectedGradients.ZERO and t_grad is None: 

827 return _ZerosLike(t) 

828 

829 assert not isinstance( 

830 t_grad, list), ("gradients list should have been aggregated by now.") 

831 return t_grad 

832 

833 
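# Editor's note: a hedged sketch of the `unconnected_gradients` handling in
# _GetGrad above, as seen through the public tf.gradients symbol. Comment-only
# example; `tf` is not imported here and the function name is hypothetical.
#
#   @tf.function
#   def _unconnected_example():
#     x = tf.constant(1.0)
#     y = tf.constant(2.0)              # y does not depend on x
#     tf.gradients(y, [x])                                         # -> [None]
#     zeros, = tf.gradients(y, [x], unconnected_gradients="zero")  # -> 0.0
#     return zeros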

834def _GetGrads(grads, op): 

835 """Gets all gradients for op.""" 

836 if op in grads: 

837 return grads[op] 

838 else: 

839 return [[] for _ in range(len(op.outputs))] 

840 

841 

842def _AccumulatorShape(inputs): 

843 shape = tensor_shape.unknown_shape() 

844 for i in inputs: 

845 if isinstance(i, ops.Tensor): 

846 shape = shape.merge_with(i.get_shape()) 

847 return shape 

848 

849 

850def _LogOpGradients(op, out_grads, in_grads): 

851 """Log the in and out grads of an op.""" 

852 logging.vlog(1, "Gradient for '" + op.name + "'") 

853 

854 def _FilterGrad(x): 

855 if x is None: 

856 return False 

857 if isinstance(x, (list, tuple)): 

858 return bool(x) 

859 else: 

860 return True 

861 

862 logging.vlog(1, " in --> %s", 

863 ", ".join(x.name for x in out_grads if _FilterGrad(x))) 

864 logging.vlog(1, " out --> %s", 

865 ", ".join(x.name for x in in_grads if _FilterGrad(x))) 

866 

867 

868def _MultiDeviceAddN(tensor_list, gradient_uid): 

869 """Adds tensors from potentially multiple devices.""" 

870 # Basic function structure comes from control_flow_ops.group(). 

871 # Sort tensors according to their devices. 

872 tensors_on_device = collections.defaultdict(list) 

873 for tensor in tensor_list: 

874 tensors_on_device[tensor.device].append(tensor) 

875 

876 # For each device, add the tensors on that device first. 

877 # Then gather the partial sums from multiple devices. 

878 # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion. 

879 # E.g., aggregate per GPU, then per task, and so on. 

880 summands = [] 

881 

882 def DeviceKey(dev): 

883 return "" if dev is None else dev 

884 

885 for dev in sorted(tensors_on_device, key=DeviceKey): 

886 tensors = tensors_on_device[dev] 

887 with ops._colocate_with_for_gradient( # pylint: disable=protected-access 

888 tensors[0].op, 

889 gradient_uid, 

890 ignore_existing=True): 

891 summands.append(math_ops.add_n(tensors)) 

892 

893 return math_ops.add_n(summands) 

894 

895 
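# Editor's note: a standalone illustration of the per-device grouping strategy
# used by _MultiDeviceAddN, on plain Python numbers tagged with hypothetical
# device strings instead of tensors. Illustration only.
def _grouped_sum_sketch(values_with_device):
  """values_with_device: iterable of (device_string, number) pairs."""
  per_device = collections.defaultdict(list)
  for dev, value in values_with_device:
    per_device[dev].append(value)
  # Sum within each device first, then combine the per-device partial sums.
  return sum(sum(vals) for _, vals in sorted(per_device.items()))
# Example: _grouped_sum_sketch([("gpu:0", 1.0), ("gpu:1", 2.0), ("gpu:0", 3.0)]) -> 6.0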

896@tf_export("AggregationMethod") 

897class AggregationMethod: 

898 """A class listing aggregation methods used to combine gradients. 

899 

900 Computing partial derivatives can require aggregating gradient 

901 contributions. This class lists the various methods that can 

902 be used to combine gradients in the graph. 

903 

904 The following aggregation methods are part of the stable API for 

905 aggregating gradients: 

906 

907 * `ADD_N`: All of the gradient terms are summed as part of one 

908 operation using the "AddN" op (see `tf.add_n`). This 

909 method has the property that all gradients must be ready and 

910 buffered separately in memory before any aggregation is performed. 

911 * `DEFAULT`: The system-chosen default aggregation method. 

912 

913 The following aggregation methods are experimental and may not 

914 be supported in future releases: 

915 

916 * `EXPERIMENTAL_TREE`: Gradient terms are summed in pairs using 

917 the "AddN" op. This method of summing gradients may reduce 

918 performance, but it can improve memory utilization because the 

919 gradients can be released earlier. 

920 * `EXPERIMENTAL_ACCUMULATE_N`: Same as `EXPERIMENTAL_TREE`. 

921 

922 Example usage when computing gradient: 

923 

924 >>> @tf.function 

925 ... def example(): 

926 ... x = tf.constant(1.0) 

927 ... y = x * 2.0 

928 ... z = y + y + y + y 

929 ... return tf.gradients(z, [x, y], 

930 ... aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) 

931 >>> example() 

932 [<tf.Tensor: shape=(), dtype=float32, numpy=8.0>, 

933 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>] 

934 

935 """ 

936 ADD_N = 0 

937 DEFAULT = ADD_N 

938 # The following are experimental and may not be supported in future releases. 

939 EXPERIMENTAL_TREE = 1 

940 EXPERIMENTAL_ACCUMULATE_N = 2 # Same behavior as EXPERIMENTAL_TREE (see _AggregatedGrads). 

941 

942 

943def _AggregatedGrads(grads, 

944 op, 

945 gradient_uid, 

946 loop_state, 

947 aggregation_method=None): 

948 """Get the aggregated gradients for op. 

949 

950 Args: 

951 grads: The map of memoized gradients. 

952 op: The op to get gradients for. 

953 gradient_uid: A unique identifier within the graph indicating 

954 which invocation of gradients is being executed. Used to cluster 

955 ops for compilation. 

956 loop_state: An object for maintaining the state of the while loops in the 

957 graph. It is of type ControlFlowState. None if the graph 

958 contains no while loops. 

959 aggregation_method: Specifies the method used to combine gradient terms. 

960 Accepted values are constants defined in the class `AggregationMethod`. 

961 

962 Returns: 

963 A list of gradients, one per output of `op`. If the gradient 

964 for a particular output is a list, this function aggregates it 

965 before returning. 

966 

967 Raises: 

968 TypeError: if the incoming grads are not Tensors or IndexedSlices. 

969 ValueError: if the arguments are invalid. 

970 

971 """ 

972 if aggregation_method is None: 

973 aggregation_method = AggregationMethod.DEFAULT 

974 valid_aggregation_methods = [ 

975 AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE, 

976 AggregationMethod.EXPERIMENTAL_ACCUMULATE_N] 

977 if aggregation_method not in valid_aggregation_methods: 

978 raise ValueError( 

979 f"Invalid `aggregation_method` specified {aggregation_method}. " 

980 f"Accepted values are {valid_aggregation_methods}.") 

981 out_grads = _GetGrads(grads, op) 

982 for i, out_grad in enumerate(out_grads): 

983 if loop_state: 

984 if isinstance(out_grad, (ops.Tensor, indexed_slices.IndexedSlices)): 

985 assert control_flow_util.IsLoopSwitch(op) 

986 continue 

987 # Grads have to be Tensors or IndexedSlices 

988 if (isinstance(out_grad, collections_abc.Sequence) and not all( 

989 isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)) 

990 for g in out_grad 

991 if g is not None)): 

992 raise TypeError(f"Invalid gradient {out_grad} [index = {i}]. Gradients " 

993 "have to be either all Tensors or all IndexedSlices") 

994 # Aggregate multiple gradients, and convert [] to None. 

995 if out_grad: 

996 if len(out_grad) < 2: 

997 used = "nop" 

998 out_grads[i] = out_grad[0] 

999 elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None): 

1000 tensor_shape = _AccumulatorShape(out_grad) 

1001 if aggregation_method in [ 

1002 AggregationMethod.EXPERIMENTAL_TREE, 

1003 AggregationMethod.EXPERIMENTAL_ACCUMULATE_N 

1004 ]: 

1005 # Aggregate all gradients by doing pairwise sums: this may 

1006 # reduce performance, but it can improve memory because the 

1007 # gradients can be released earlier. 

1008 # 

1009 # TODO(vrv): Consider replacing this with a version of 

1010 # tf.AddN() that eagerly frees its inputs as soon as they are 

1011 # ready, so the order of this tree does not become a problem. 

1012 used = "tree" 

1013 with ops.name_scope(op.name + "_gradient_sum"): 

1014 running_sum = out_grad[0] 

1015 for grad in out_grad[1:]: 

1016 running_sum = math_ops.add_n([running_sum, grad]) 

1017 out_grads[i] = running_sum 

1018 else: 

1019 used = "add_n" 

1020 out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid) 

1021 logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad), 

1022 tensor_shape, used) 

1023 else: 

1024 out_grads[i] = backprop_util.AggregateIndexedSlicesGradients(out_grad) # pylint: disable=protected-access 

1025 else: # not out_grad 

1026 # out_grads[i] is [], thus its aggregation is simply None. 

1027 out_grads[i] = None 

1028 return out_grads 

1029 

1030 
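# Editor's note: a standalone sketch contrasting the two aggregation shapes
# used in _AggregatedGrads above, on plain Python numbers (illustration only;
# the function name is hypothetical).
def _aggregate_sketch(partials, pairwise=False):
  """Sums `partials` either all at once (ADD_N-style) or two at a time."""
  if len(partials) < 2:
    return partials[0]
  if pairwise:
    # EXPERIMENTAL_TREE-style: fold terms two at a time so earlier partial
    # results can be released before all terms are materialized.
    running = partials[0]
    for term in partials[1:]:
      running = running + term
    return running
  # ADD_N-style: a single n-ary sum over all buffered terms.
  return sum(partials)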

1031# Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are 

1032# unfortunately too slow to use here. 

1033POSSIBLE_GRADIENT_TYPES_NONE = 0 

1034POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1 

1035POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2 

1036 

1037 

1038def PossibleTapeGradientTypes(tensors): 

1039 """Determines whether and how `args` may require tape gradients.""" 

1040 return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)