Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/custom_gradient.py: 18%

215 statements  


1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 """Decorator to override the gradient for a function."""

16 

17from tensorflow.python.eager import backprop 

18from tensorflow.python.eager import context 

19from tensorflow.python.eager import record 

20from tensorflow.python.framework import composite_tensor_gradient 

21from tensorflow.python.framework import dtypes 

22from tensorflow.python.framework import ops 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import gen_array_ops 

25from tensorflow.python.ops import handle_data_util 

26from tensorflow.python.ops import math_ops 

27from tensorflow.python.ops import op_selector 

28from tensorflow.python.ops import resource_variable_ops 

29from tensorflow.python.ops import variable_scope 

30from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients 

31from tensorflow.python.platform import tf_logging as logging 

32from tensorflow.python.util import nest 

33from tensorflow.python.util import tf_decorator 

34from tensorflow.python.util import tf_inspect 

35from tensorflow.python.util import variable_utils 

36from tensorflow.python.util.tf_export import tf_export 

37 

38 

39VAR_OP_TYPES = [ 

40 "VariableV2", 

41 "VarHandleOp", 

42] 

43 

44 

45@tf_export("custom_gradient") 

46def custom_gradient(f=None): 

47 """Decorator to define a function with a custom gradient. 

48 

49 This decorator allows fine-grained control over the gradients of a sequence

50 of operations. This may be useful for multiple reasons, including providing

51 a more efficient or numerically stable gradient for a sequence of operations.

52 

53 For example, consider the following function that commonly occurs in the 

54 computation of cross entropy and log likelihoods: 

55 

56 ```python

57 def log1pexp(x):

58   return tf.math.log(1 + tf.exp(x))

59 ```

60 

61 Due to numerical instability, the gradient of this function evaluated at x=100 

62 is NaN. For example: 

63 

64 ```python

65 with tf.GradientTape() as tape:

66   tape.watch(x)

67   y = log1pexp(x)

68 dy_dx = tape.gradient(y, x)  # Will be NaN when evaluated.

69 ```

70 

71 The gradient expression can be analytically simplified to provide numerical 

72 stability: 

73 

74 ```python

75 @tf.custom_gradient

76 def log1pexp(x):

77   e = tf.exp(x)

78   def grad(upstream):

79     return upstream * (1 - 1 / (1 + e))

80   return tf.math.log(1 + e), grad

81 ```

82 

83 With this definition, the gradient `dy_dx` at `x = 100` will be correctly 

84 evaluated as 1.0. 

85 

86 The variable `upstream` is the upstream gradient, i.e. the gradient flowing

87 in from all the layers or functions that consume this layer's output. The

88 above example has no such downstream ops, therefore `upstream = dy/dy = 1.0`.

89 

90 Assume the forward pass is a chain of functions `x_1 = x_1(x_0)`,

91 `x_2 = x_2(x_1)`, ..., `x_i = x_i(x_i-1)`, ..., `x_n = x_n(x_n-1)`, where

92 `log1pexp` computes the step `x_i`. By the chain rule,

93 `dx_n/dx_0 = dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i/dx_i-1 * ... * dx_1/dx_0`.

94 

95 In this case the gradient of our current function is

96 `dx_i/dx_i-1 = (1 - 1 / (1 + e))`, and the upstream gradient `upstream` is

97 `dx_n/dx_n-1 * dx_n-1/dx_n-2 * ... * dx_i+1/dx_i`. The upstream gradient

98 multiplied by the current gradient is then passed downstream.

99 
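As an illustrative sketch (an addition for clarity, not part of the original example), composing `log1pexp` with one more downstream op makes `upstream` non-trivial:

```python
# Illustrative sketch, not part of the original docstring.
@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(upstream):
    # `upstream` is d(final output) / d(log1pexp output).
    return upstream * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

x = tf.constant(100.)
with tf.GradientTape() as tape:
  tape.watch(x)
  z = 3.0 * log1pexp(x)   # One op downstream of log1pexp.
# `grad` receives upstream = dz/dy = 3.0, so dz/dx evaluates to
# 3.0 * sigmoid(100) ~= 3.0.
dz_dx = tape.gradient(z, x)
```
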

100 If the decorated function takes multiple inputs, the `grad`

101 function must also return the same number of gradients, one per input.

102 We take the function `z = x * y` as an example.

103 

104 >>> @tf.custom_gradient

105 ... def bar(x, y):

106 ...   def grad(upstream):

107 ...     dz_dx = y

108 ...     dz_dy = x

109 ...     return upstream * dz_dx, upstream * dz_dy

110 ...   z = x * y

111 ...   return z, grad

112 >>> x = tf.constant(2.0, dtype=tf.float32)

113 >>> y = tf.constant(3.0, dtype=tf.float32)

114 >>> with tf.GradientTape(persistent=True) as tape:

115 ...   tape.watch(x)

116 ...   tape.watch(y)

117 ...   z = bar(x, y)

118 >>> z 

119 <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 

120 >>> tape.gradient(z, x) 

121 <tf.Tensor: shape=(), dtype=float32, numpy=3.0> 

122 >>> tape.gradient(z, y) 

123 <tf.Tensor: shape=(), dtype=float32, numpy=2.0> 

124 

125 Nesting custom gradients can lead to unintuitive results. The default 

126 behavior does not correspond to n-th order derivatives. For example 

127 

128 ```python

129 @tf.custom_gradient

130 def op(x):

131   y = op1(x)

132   @tf.custom_gradient

133   def grad_fn(dy):

134     gdy = op2(x, y, dy)

135     def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.

136       return op3(x, y, dy, ddy)

137     return gdy, grad_grad_fn

138   return y, grad_fn

139 ```

140 

141 The function `grad_grad_fn` computes the first-order gradient

142 of `grad_fn` with respect to `dy`, which is used to generate forward-mode

143 gradient graphs from backward-mode gradient graphs, but is not the same as

144 the second-order gradient of `op` with respect to `x`.

145 

146 Instead, wrap nested `@tf.custom_gradient` functions in another function:

147 

148 ```python

149 @tf.custom_gradient

150 def op_with_fused_backprop(x):

151   y, x_grad = fused_op(x)

152   def first_order_gradient(dy):

153     @tf.custom_gradient

154     def first_order_custom(unused_x):

155       def second_order_and_transpose(ddy):

156         return second_order_for_x(...), gradient_wrt_dy(...)

157       return x_grad, second_order_and_transpose

158     return dy * first_order_custom(x)

159   return y, first_order_gradient

160 ```

161 

162 Additional arguments to the inner `@tf.custom_gradient`-decorated function 

163 control the expected return values of the innermost function. 

164 

165 The examples above illustrate how to specify custom gradients for functions 

166 which do not read from variables. The following example uses variables, which 

167 require special handling because they are effectively inputs of the forward 

168 function. 

169 

170 >>> weights = tf.Variable(tf.ones([2]))  # Trainable variable weights

171 >>> @tf.custom_gradient

172 ... def linear_poly(x):

173 ...   # Creating polynomial

174 ...   poly = weights[1] * x + weights[0]

175 ...

176 ...   def grad_fn(dpoly, variables):

177 ...     # dy/dx = weights[1] and we need to left multiply dpoly

178 ...     grad_xs = dpoly * weights[1]  # Scalar gradient

179 ...

180 ...     grad_vars = []  # To store gradients of passed variables

181 ...     assert variables is not None

182 ...     assert len(variables) == 1

183 ...     assert variables[0] is weights

184 ...     # Manually computing dy/dweights

185 ...     dy_dw = dpoly * tf.stack([x ** 1, x ** 0])

186 ...     grad_vars.append(

187 ...         tf.reduce_sum(tf.reshape(dy_dw, [2, -1]), axis=1)

188 ...     )

189 ...     return grad_xs, grad_vars

190 ...   return poly, grad_fn

191 >>> x = tf.constant([1., 2., 3.])

192 >>> with tf.GradientTape(persistent=True) as tape:

193 ...   tape.watch(x)

194 ...   poly = linear_poly(x)

195 >>> poly # poly = x + 1 

196 <tf.Tensor: shape=(3,), 

197 dtype=float32, 

198 numpy=array([2., 3., 4.], dtype=float32)> 

199 >>> tape.gradient(poly, x) # conventional scalar gradient dy/dx 

200 <tf.Tensor: shape=(3,), 

201 dtype=float32, 

202 numpy=array([1., 1., 1.], dtype=float32)> 

203 >>> tape.gradient(poly, weights) 

204 <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 3.], dtype=float32)> 

205 

206 The example above illustrates the usage of the trainable variable `weights`.

207 In the example, the inner `grad_fn` accepts an extra `variables` input

208 parameter and also returns an extra `grad_vars` output. That extra argument

209 is passed if the forward function reads any variables. You need to

210 compute the gradient w.r.t. each of those `variables` and return them as the

211 list `grad_vars`. Note that `variables` defaults to `None` when no variables

212 are used in the forward function.

213 
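A minimal sketch of that optional-`variables` signature (an illustrative assumption, not taken from the original docstring): the same `grad_fn` can serve both the no-variables and with-variables cases.

```python
# Illustrative sketch, not part of the original docstring.
w = tf.Variable(2.0)

@tf.custom_gradient
def scale(x):
  y = w * x
  def grad_fn(upstream, variables=None):
    grad_x = upstream * w
    if variables is None:
      # The forward function read no variables; return input grads only.
      return grad_x
    assert len(variables) == 1 and variables[0] is w
    grad_vars = [tf.reduce_sum(upstream * x)]  # dy/dw, summed over elements.
    return grad_x, grad_vars
  return y, grad_fn

x = tf.constant([1.0, 2.0])
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  y = scale(x)
# tape.gradient(y, x) -> [2.0, 2.0]; tape.gradient(y, w) -> 3.0.
```
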

214 Note that `tf.GradientTape` is still watching the forward pass of a

215 `tf.custom_gradient`, and will use the ops it watches. As a consequence,

216 calling `tf.function` while the tape is still watching leads

217 to a gradient graph being built. If an op is used inside the `tf.function`

218 without a registered gradient, a `LookupError` will be raised.

219 

220 Users can insert `tf.stop_gradient` to customize this behavior. This

221 is demonstrated in the example below. `tf.random.shuffle` does not have a

222 registered gradient. As a result, `tf.stop_gradient` is used to avoid the

223 `LookupError`.

224 

225 ```python

226 x = tf.constant([0.3, 0.5], dtype=tf.float32)

227 

228 @tf.custom_gradient

229 def test_func_with_stop_grad(x):

230   @tf.function

231   def _inner_func():

232     # Avoid exception during the forward pass

233     return tf.stop_gradient(tf.random.shuffle(x))

234     # return tf.random.shuffle(x)  # This will raise

235 

236   res = _inner_func()

237   def grad(upstream):

238     return upstream  # Arbitrarily defined custom gradient

239   return res, grad

240 

241 with tf.GradientTape() as g:

242   g.watch(x)

243   res = test_func_with_stop_grad(x)

244 

245 g.gradient(res, x)

246 ```

247 

248 See also `tf.RegisterGradient`, which registers a gradient function for a

249 primitive TensorFlow operation. `tf.custom_gradient`, on the other hand, allows

250 fine-grained control over the gradient computation of a sequence of

251 operations.

252 

253 Note that if the decorated function uses `Variable`s, the enclosing variable 

254 scope must be using 

255 [ResourceVariables](https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables). 

256 

257 Args: 

258 f: function `f(*x)` that returns a tuple `(y, grad_fn)` where: 

259 - `x` is a sequence of (nested structures of) `Tensor` inputs to the 

260 function. 

261 - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow 

262 operations in `f` to `x`. 

263 - `grad_fn` is a function with the signature `g(*grad_ys)` which returns 

264 a list of `Tensor`s the same size as (flattened) `x` - the derivatives 

265 of `Tensor`s in `y` with respect to the `Tensor`s in `x`. `grad_ys` is 

266 a sequence of `Tensor`s the same size as (flattened) `y` holding the 

267 initial value gradients for each `Tensor` in `y`. 

268 

269 In a pure mathematical sense, the derivative of a vector-argument

270 vector-valued function `f` is its Jacobian matrix `J`. Here we are

271 expressing the Jacobian `J` as a function `grad_fn` which defines how

272 `J` will transform a vector `grad_ys` when left-multiplied with it

273 (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional

274 representation of a matrix is convenient to use for chain-rule

275 calculation (e.g. in the back-propagation algorithm).

276 

277 If `f` uses `Variable`s (that are not part of the 

278 inputs), i.e. through `get_variable`, then `grad_fn` should have 

279 signature `g(*grad_ys, variables=None)`, where `variables` is a list of 

280 the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where 

281 `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>` 

282 with the derivatives of `Tensor`s in `y` with respect to the variables 

283 (that is, grad_vars has one Tensor per variable in variables). 

284 

285 Returns: 

286 A function `h(x)` which returns the same value as `f(x)[0]` and whose 

287 gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`. 

288 """ 

289 

290 if f is None: 

291 return lambda f: custom_gradient(f=f) 

292 

293 @Bind.decorator 

294 def decorated(wrapped, args, kwargs): 

295 """Decorated function with custom gradient.""" 

296 if context.executing_eagerly(): 

297 return _eager_mode_decorator(wrapped, args, kwargs) 

298 else: 

299 return _graph_mode_decorator(wrapped, args, kwargs) 

300 

301 return tf_decorator.make_decorator(f, decorated(f)) # pylint: disable=no-value-for-parameter 

302 

303 

304class Bind: 

305 """When called, evaluates `d(f, args, kwargs)` but supports binding `f`.

306 

307 >>> @Bind.decorator

308 ... def my_decorator(f, args, kwargs):

309 ...   print("my_decorator called with", args, kwargs)

310 ...   return f(*args, **kwargs)

311 

312 >>> class Foo:

313 ...   @my_decorator

314 ...   def bar(self, a, b, c):

315 ...     return a * b * c

316 

317 >>> Foo.bar(None, 1, 2, c=3) 

318 my_decorator called with (None, 1, 2) {'c': 3} 

319 6 

320 

321 >>> foo = Foo() 

322 >>> foo.bar(1, 2, c=3) 

323 my_decorator called with (1, 2) {'c': 3} 

324 6 

325 """ 

326 

327 @classmethod 

328 def decorator(cls, d): 

329 return lambda f: Bind(f, d) 

330 

331 def __init__(self, f, d): 

332 self._f = f 

333 self._d = d 

334 

335 def __get__(self, instance, owner): 

336 if instance is not None: 

337 f = self._f.__get__(instance, owner) 

338 return tf_decorator.make_decorator(f, Bind(f, self._d)) 

339 else: 

340 return self 

341 

342 def __call__(self, *a, **k): 

343 return self._d(self._f, a, k) 

344 

345 

346def get_variable_by_name(var_name): 

347 """Given a variable name, retrieves a handle on the TensorFlow Variable."""

348 global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 

349 

350 def _filter_fn(item): 

351 try: 

352 return var_name == item.op.name 

353 except AttributeError: 

354 # Collection items without operation are ignored. 

355 return False 

356 

357 candidate_vars = list(filter(_filter_fn, global_vars)) 

358 

359 if len(candidate_vars) >= 1: 

360 # Filter out non-trainable variables. 

361 candidate_vars = [v for v in candidate_vars if v.trainable] 

362 else: 

363 raise ValueError("Unsuccessful at finding variable {}.".format(var_name)) 

364 

365 if len(candidate_vars) == 1: 

366 return candidate_vars[0] 

367 elif len(candidate_vars) > 1: 

368 raise ValueError( 

369 "Unsuccessful at finding trainable variable {}. " 

370 "Number of candidates: {}. " 

371 "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars)) 

372 else: 

373 # The variable is not trainable. 

374 return None 

375 

376 

377def _get_dependent_variables(input_ops, output_ops): 

378 """Finds variables involved in the subgraph between input_ops and output_ops. 

379 

380 Args: 

381 input_ops: Flattened list of input ops 

382 output_ops: Flattened list of output ops 

383 

384 Returns: 

385 A list of variables 

386 """ 

387 

388 # avoids the edge-case when input_ops == output_ops. 

389 output_ops = nest.map_structure(gen_array_ops.identity, output_ops) 

390 inbetween_ops = op_selector.get_backward_walk_ops( 

391 seed_ops=output_ops, 

392 stop_at_ts=input_ops, 

393 inclusive=False, 

394 only_differentiable=True) 

395 var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES) 

396 var_names = (op.name for op in var_ops) 

397 tf_vars = (get_variable_by_name(var_name) for var_name in var_names) 

398 tf_vars = [v for v in tf_vars if v is not None] 

399 return tf_vars 

400 

401 

402def generate_name(): 

403 return "CustomGradient-%s" % ops.uid() 

404 

405 

406def _graph_mode_decorator(f, args, kwargs): 

407 """Implement custom gradient decorator for graph mode.""" 

408 # TODO(rsepassi): Add support for kwargs 

409 if kwargs: 

410 raise ValueError( 

411 "The custom_gradient decorator currently supports keyword "

412 "arguments only when eager execution is enabled.") 

413 name = generate_name() 

414 args = variable_utils.convert_variables_to_tensors(args) 

415 args = nest.map_structure(ops.convert_to_tensor, args, expand_composites=True) 

416 

417 # Checking global and local variables attempts to ensure that no non-resource 

418 # Variables are added to the graph. 

419 current_var_scope = variable_scope.get_variable_scope() 

420 before_vars = set([ 

421 v.ref() for v in current_var_scope.global_variables() + 

422 current_var_scope.local_variables() 

423 ]) 

424 with record.VariableWatcher() as variable_watcher: 

425 result, grad_fn = f(*args) 

426 

427 flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients( 

428 nest.flatten(args)) 

429 flat_result = composite_tensor_gradient.get_flat_tensors_for_gradients( 

430 nest.flatten(result)) 

431 flat_result_len = len(flat_result) 

432 

433 after_vars = set([ 

434 v.ref() for v in current_var_scope.global_variables() + 

435 current_var_scope.local_variables() 

436 ]) 

437 new_vars = after_vars - before_vars 

438 new_vars_list = [v.deref() for v in new_vars] 

439 for v in new_vars_list: 

440 if not resource_variable_ops.is_resource_variable(v): 

441 raise TypeError( 

442 "All variables used by a function wrapped with @custom_gradient must " 

443 "be `ResourceVariable`s. Ensure that no `variable_scope` is created " 

444 "with `use_resource=False`.") 

445 

446 # The variables that grad_fn needs to return gradients for are the set of 

447 # variables used that are *not* part of the inputs. 

448 variables_in_tape = frozenset([ 

449 v.ref() for v in variable_watcher.watched_variables() 

450 ]) 

451 

452 graphs = {getattr(o, "graph", None) for o in flat_result} 

453 # Not all results may be tensors. However, we want to ensure all tensor 

454 # outputs are from the same graph and get a list of captured inputs for 

455 # variable search 

456 graphs.discard(None) # Discard non-graph outputs 

457 if graphs: 

458 if len(graphs) > 1: 

459 raise ValueError( 

460 "All custom_gradient outputs should be from the same graph") 

461 output_graph = graphs.pop() 

462 filtered_input_tensors = [] 

463 for i in flat_args: 

464 if i.graph == output_graph: 

465 filtered_input_tensors.append(i) 

466 else: 

467 filtered_input_tensors = flat_args 

468 

469 variables_in_subgraph = frozenset([ 

470 v.ref() for v in _get_dependent_variables( 

471 input_ops=filtered_input_tensors, output_ops=flat_result) 

472 ]) 

473 variables = sorted( 

474 [v.deref() for v in variables_in_subgraph.union(variables_in_tape)], 

475 key=lambda v: v.name) 

476 

477 grad_argspec = tf_inspect.getfullargspec(grad_fn) 

478 variables_in_signature = ("variables" in grad_argspec.args or 

479 "variables" in grad_argspec.kwonlyargs or 

480 grad_argspec.varkw) 

481 if variables and not variables_in_signature: 

482 raise TypeError( 

483 "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " 

484 "since function uses variables: {}".format(variables)) 

485 if variables_in_signature and not variables: 

486 # User seems to intend to use variables but none were captured. 

487 logging.vlog( 

488 1, "@custom_gradient grad_fn has 'variables' in signature, " 

489 "but no ResourceVariables were used on the forward pass.") 

490 

491 all_tensors = flat_result + flat_args + variables 

492 

493 def tape_grad_fn(*result_grad_components): 

494 """Custom grad fn wrapper.""" 

495 result_grads = composite_tensor_gradient.replace_flat_tensors_for_gradients( 

496 nest.flatten(result), result_grad_components[:flat_result_len]) 

497 if not isinstance(result_grads, (list, tuple)): 

498 result_grads = [result_grads] 

499 

500 if variables: 

501 input_grads, variable_grads = grad_fn(*result_grads, variables=variables) 

502 if len(variable_grads) != len(variables): 

503 raise ValueError("Must return gradient for each variable from " 

504 "@custom_gradient grad_fn.") 

505 else: 

506 input_grads = grad_fn(*result_grads) 

507 variable_grads = [] 

508 

509 # Need to return one value per input to the IdentityN, so pad the 

510 # gradients of the inputs of the custom_gradient function with the 

511 # gradients of the outputs as well. 

512 input_grads = composite_tensor_gradient.get_flat_tensors_for_gradients( 

513 nest.flatten(input_grads)) 

514 return ([None] * flat_result_len) + input_grads + variable_grads 

515 

516 @ops.RegisterGradient(name) 

517 def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable 

518 """Custom grad fn wrapper.""" 

519 return tape_grad_fn(*result_grads) 

520 

521 original_tensors = all_tensors 

522 with ops.get_default_graph().gradient_override_map({"IdentityN": name}): 

523 all_tensors = array_ops.identity_n(all_tensors) 

524 

525 original_tensors = [ops.convert_to_tensor(x) for x in original_tensors] 

526 

527 # Propagate handle data for happier shape inference for resource variables. 

528 for i, t in enumerate(original_tensors): 

529 if t.dtype == dtypes.resource and hasattr(t, "_handle_data"): 

530 all_tensors[i]._handle_data = t._handle_data # pylint: disable=protected-access 

531 record.record_operation( 

532 f.__name__, all_tensors, original_tensors, tape_grad_fn) 

533 for ot, t in zip(original_tensors, all_tensors): 

534 handle_data_util.copy_handle_data(ot, t) 

535 flat_result = composite_tensor_gradient.replace_flat_tensors_for_gradients( 

536 nest.flatten(result), all_tensors[:flat_result_len]) 

537 return nest.pack_sequence_as(result, flat_result) 

538 

539 

540def _eager_mode_decorator(f, args, kwargs): 

541 """Implement custom gradient decorator for eager mode.""" 

542 with record.VariableWatcher() as variable_watcher: 

543 result, grad_fn = f(*args, **kwargs) 

544 flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients( 

545 nest.flatten(args)) 

546 flat_kwargs = composite_tensor_gradient.get_flat_tensors_for_gradients( 

547 nest.flatten(kwargs)) 

548 all_inputs = flat_args + flat_kwargs 

549 # The variables that grad_fn needs to return gradients for are the set of 

550 # variables used that are *not* part of the inputs. 

551 variables = [ 

552 v.deref() # pylint: disable=g-complex-comprehension 

553 for v in set(v.ref() for v in variable_watcher.watched_variables()) 

554 if all(v.deref() is not i for i in all_inputs) 

555 ] 

556 grad_argspec = tf_inspect.getfullargspec(grad_fn) 

557 if (variables and ("variables" not in grad_argspec.args) and 

558 ("variables" not in grad_argspec.kwonlyargs) and 

559 not grad_argspec.varkw): 

560 raise TypeError( 

561 "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " 

562 "since function uses variables: {}".format(variables)) 

563 flat_result = composite_tensor_gradient.get_flat_tensors_for_gradients( 

564 nest.flatten(result)) 

565 # TODO(apassos) consider removing the identity below. 

566 flat_result = [gen_array_ops.identity(x) for x in flat_result] 

567 

568 input_tensors = [ 

569 ops.convert_to_tensor(x) for x in flat_args + list(variables)] 

570 

571 recorded_inputs = input_tensors 

572 arg_count = len(flat_args) 

573 

574 def actual_grad_fn(*result_grad_components): 

575 """Custom grad fn wrapper.""" 

576 result_grads = composite_tensor_gradient.replace_flat_tensors_for_gradients( 

577 nest.flatten(result), result_grad_components) 

578 if not isinstance(result_grads, (list, tuple)): 

579 result_grads = [result_grads] 

580 

581 if variables: 

582 input_grads, variable_grads = grad_fn(*result_grads, variables=variables) 

583 if len(variable_grads) != len(variables): 

584 raise ValueError("Must return gradient for each variable from " 

585 "@custom_gradient grad_fn.") 

586 else: 

587 input_grads = grad_fn(*result_grads) 

588 variable_grads = [] 

589 flat_grads = composite_tensor_gradient.get_flat_tensors_for_gradients( 

590 nest.flatten(input_grads)) 

591 if len(flat_grads) != arg_count: 

592 raise ValueError( 

593 f"custom_gradient function expected to return {arg_count} " 

594 f"gradients, but returned {len(flat_grads)} instead.") 

595 return flat_grads + variable_grads 

596 

597 record.record_operation(f.__name__, flat_result, recorded_inputs, 

598 actual_grad_fn) 

599 flat_result = composite_tensor_gradient.replace_flat_tensors_for_gradients( 

600 nest.flatten(result), flat_result) 

601 return nest.pack_sequence_as(result, flat_result) 

602 

603 

604@tf_export("recompute_grad") 

605def recompute_grad(f): 

606 """Defines a function as a recompute-checkpoint for the tape auto-diff. 

607 

608 Tape checkpointing is a technique to reduce the memory consumption of the 

609 auto-diff tape: 

610 

611 - Without tape checkpointing, operations and intermediate values are

612 recorded to the tape for use in the backward pass.

613 

614 - With tape checkpointing, only the function call and its inputs are

615 recorded. During backpropagation the `recompute_grad` custom gradient

616 (`tf.custom_gradient`) recomputes the function under a localized Tape object.

617 This recomputation of the function during backpropagation performs redundant

618 calculation, but reduces the overall memory usage of the Tape.

619 

620 >>> y = tf.Variable(1.0) 

621 

622 >>> def my_function(x):

623 ...   tf.print('running')

624 ...   z = x*y

625 ...   return z

626 

627 >>> my_function_recompute = tf.recompute_grad(my_function) 

628 

629 >>> with tf.GradientTape() as tape:

630 ...   r = tf.constant(1.0)

631 ...   for i in range(4):

632 ...     r = my_function_recompute(r)

633 running 

634 running 

635 running 

636 running 

637 

638 >>> grad = tape.gradient(r, [y]) 

639 running 

640 running 

641 running 

642 running 

643 

644 Without `recompute_grad`, the tape contains all intermediate steps, and no

645 recomputation is performed.

646 

647 >>> with tf.GradientTape() as tape:

648 ...   r = tf.constant(1.0)

649 ...   for i in range(4):

650 ...     r = my_function(r)

651 running 

652 running 

653 running 

654 running 

655 

656 >>> grad = tape.gradient(r, [y]) 

657 
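A minimal correctness sketch (an illustrative assumption, not part of the original docstring): wrapping a function with `tf.recompute_grad` changes when the forward work is redone, not the gradient values it produces.

```python
# Illustrative sketch, not part of the original docstring.
def f(x):
  return tf.sin(x) * tf.exp(x)

f_recompute = tf.recompute_grad(f)
x = tf.constant(0.5)
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  y_plain = f(x)
  y_recomputed = f_recompute(x)
g_plain = tape.gradient(y_plain, x)
g_recomputed = tape.gradient(y_recomputed, x)
# Both gradients evaluate to the same value; the second gradient call
# re-runs `f` during the backward pass instead of reusing stored activations.
```
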

658 

659 If `f` is a `tf.keras` `Model` or `Layer` object, methods and attributes

660 such as `f.variables` are not available on the returned function `g`.

661 Either keep a reference to `f`, or use `g.__wrapped__` to access

662 these variables and methods.

663 

664 

665 >>> def print_running_and_return(x):

666 ...   tf.print("running")

667 ...   return x

668 

669 >>> model = tf.keras.Sequential([

670 ...   tf.keras.layers.Lambda(print_running_and_return),

671 ...   tf.keras.layers.Dense(2)

672 ... ])

673 

674 >>> model_recompute = tf.recompute_grad(model) 

675 

676 >>> with tf.GradientTape(persistent=True) as tape:

677 ...   r = tf.constant([[1,2]])

678 ...   for i in range(4):

679 ...     r = model_recompute(r)

680 running 

681 running 

682 running 

683 running 

684 

685 >>> grad = tape.gradient(r, model.variables) 

686 running 

687 running 

688 running 

689 running 

690 

691 Alternatively, use the `__wrapped__` attribute to access the original 

692 model object. 

693 

694 >>> grad = tape.gradient(r, model_recompute.__wrapped__.variables) 

695 running 

696 running 

697 running 

698 running 

699 

700 

701 Args: 

702 f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs. 

703 

704 Returns: 

705 A function `g` wrapping `f` that defines a custom gradient, which recomputes

706 `f` on the backward pass of a gradient call.

707 """ 

708 # TODO(cdfreeman) Add is_recomputing functionality from graph mode version 

709 

710 @custom_gradient 

711 def inner(*args, **kwargs): 

712 """Inner function closure for calculating gradients.""" 

713 current_var_scope = variable_scope.get_variable_scope() 

714 with record.stop_recording(): 

715 result = f(*args, **kwargs) 

716 

717 def grad_wrapper(*wrapper_args, variables=None): 

718 """Wrapper function to accommodate lack of kwargs in graph mode custom_gradient."""

719 

720 @custom_gradient 

721 def inner_recompute_grad(*dresult): 

722 """Nested custom gradient function for computing grads in reverse and forward mode autodiff.""" 

723 # Gradient calculation for reverse mode autodiff. 

724 with backprop.GradientTape() as t: 

725 id_args = nest.map_structure(gen_array_ops.identity, args) 

726 # Tuple `dresult` should contain at least one tensor. 

727 assert len(dresult) >= 1 

728 

729 if not context.executing_eagerly(): 

730 # XLA doesn't respect `tf.control_dependencies`. The code block 

731 # below manually adds a data dependency to `dresult` to ensure 

732 # recomputation of `f(*args, **kwargs)` happens after `dresult`. 

733 

734 # This works even if `dresult[0]` is a size 0 tensor as reduce_max 

735 # of a size 0 tensor returns -inf. Use reshape here to avoid reading 

736 # the entire `dresult[0]`. 

737 elem = math_ops.reduce_max(array_ops.reshape(dresult[0], [-1])[:1]) 

738 # Cast elem to bool in case elem is NaN. 

739 elem_bool = math_ops.cast(elem, dtypes.bool) 

740 dresult_dep = array_ops.where_v2( 

741 elem_bool == elem_bool, 0., float("nan")) # pylint: disable=comparison-with-itself 

742 id_args = nest.map_structure( 

743 lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args) 

744 

745 t.watch(id_args) 

746 if variables is not None: 

747 t.watch(variables) 

748 with variable_scope.variable_scope(current_var_scope): 

749 recomputed_result = f(*id_args, **kwargs) 

750 kw_vars = [] 

751 if variables is not None: 

752 kw_vars = list(variables) 

753 grads = t.gradient( 

754 recomputed_result, 

755 list(id_args) + kw_vars, 

756 output_gradients=dresult, 

757 unconnected_gradients=UnconnectedGradients.ZERO) 

758 

759 def transpose(*t_args, **t_kwargs): 

760 """Gradient function calculation for forward mode autodiff.""" 

761 # Just throw an error since gradients / activations are not stored on 

762 # tape for recompute. 

763 raise NotImplementedError( 

764 "recompute_grad tried to transpose grad of {}. " 

765 "Consider not using recompute_grad in forward mode "

766 "autodiff".format(f.__name__)) 

767 

768 return (grads[:len(id_args)], grads[len(id_args):]), transpose 

769 

770 return inner_recompute_grad(*wrapper_args) 

771 

772 return result, grad_wrapper 

773 

774 return tf_decorator.make_decorator(f, inner) 

775 

776 

777@tf_export("grad_pass_through") 

778def grad_pass_through(f): 

779 """Creates a grad-pass-through op with the forward behavior provided in f. 

780 

781 Use this function to wrap any op, maintaining its behavior in the forward 

782 pass, but replacing the original op in the backward graph with an identity. 

783 For example: 

784 

785 ```python

786 x = tf.Variable(1.0, name="x")

787 z = tf.Variable(3.0, name="z")

788 

789 with tf.GradientTape() as tape:

790   # y will evaluate to 9.0

791   y = tf.grad_pass_through(x.assign)(z**2)

792 # grads will evaluate to 6.0

793 grads = tape.gradient(y, z)

794 ```

795 

796 Another example is a 'differentiable' moving average approximation, where 

797 gradients are allowed to flow into the last value fed to the moving average, 

798 but the moving average is still used for the forward pass: 

799 

800 ```python

801 x = ...  # Some scalar value

802 # A moving average object, we don't need to know how this is implemented

803 moving_average = MovingAverage()

804 with tf.GradientTape() as tape:

805   # mavg_x will evaluate to the current running average value

806   mavg_x = tf.grad_pass_through(moving_average)(x)

807 grads = tape.gradient(mavg_x, x)  # grads will evaluate to 1.0

808 ```

809 

810 Args: 

811 f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor` 

812 outputs. 

813 

814 Returns: 

815 A function `h(x)` which returns the same values as `f(x)` and whose 

816 gradients are the same as those of an identity function. 

817 """ 

818 @custom_gradient 

819 def _grad_pass_through_op(*args, **kwargs): 

820 def grad(*args, **kwargs): 

821 variables = kwargs.get("variables") 

822 if variables is not None: 

823 # Variables involved in the wrapped op will not receive gradients. 

824 return args, [None] * len(variables) 

825 return args 

826 return f(*args, **kwargs), grad 

827 return tf_decorator.make_decorator(f, _grad_pass_through_op)