Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/math_grad.py: 24%

1208 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gradients for operators defined in math_ops.py.""" 

16import numpy as np 

17 

18from tensorflow.python.compat import compat 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import constant_op 

21from tensorflow.python.framework import dtypes 

22from tensorflow.python.framework import ops 

23from tensorflow.python.framework import tensor_util 

24from tensorflow.python.ops import array_ops 

25from tensorflow.python.ops import gen_array_ops 

26from tensorflow.python.ops import gen_math_ops 

27from tensorflow.python.ops import math_ops 

28from tensorflow.python.ops import special_math_ops 

29 

30 

31def _safe_shape_div(x, y): 

32 """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`.""" 

33 return x // math_ops.maximum(y, 1) 

34 

35 

36@ops.RegisterGradient("ArgMax") 

37def _ArgMaxGrad(op, grad): 

38 del op, grad 

39 return [None, None] 

40 

41 

42@ops.RegisterGradient("ArgMin") 

43def _ArgMinGrad(op, grad): 

44 del op, grad 

45 return [None, None] 

46 

47 

48@ops.RegisterGradient("EuclideanNorm") 

49def _EuclideanNormGrad(op, grad): 

50 """Gradient for EuclideanNorm.""" 

51 

52 output = op.outputs[0] 

53 

54 if not op.get_attr("keep_dims"): 

55 output_shape_kept_dims = math_ops.reduced_shape( 

56 array_ops.shape(op.inputs[0]), op.inputs[1]) 

57 output = array_ops.reshape(output, output_shape_kept_dims) 

58 grad = array_ops.reshape(grad, output_shape_kept_dims) 

59 

60 return math_ops.truediv(op.inputs[0], output / grad), None 

61 

62 

63def SmartBroadcastGradientArgs(x, y, grad): 

64 """Optimized version of `broadcast_gradient_args` that caches results. 

65 

66 This implementation avoids creating `broadcast_gradient_args` ops in the case 

67 that the input shapes are fully defined, and provides hints to the calling 

68 code that can be used to avoid creating reduction and reshaping ops. 

69 

70 Args: 

71 x: The left input tensor to a broadcasting binary op. 

72 y: The right input tensor to a broadcasting binary op. 

73 grad: The incoming gradient tensor for a broadcasting binary op. 

74 

75 Returns: 

76 A pair of tuples, containing: 

77 * A 3-tuple of broadcast information for x, containing: 

78 * The shape of x (as a tuple or Tensor). 

79 * The reduction indices for x (as a tuple or Tensor). 

80 * A boolean, which if True, indicates that x's shape differs from grad's 

81 shape (and so x's gradient must be reduced and/or reshaped). 

82 * A 3-tuple of broadcast information for y, containing the respective 

83 details for y. 

84 """ 

85 # NOTE: It may be productive to apply these optimizations in the eager case 

86 # as well. 

87 if context.executing_eagerly() or not ( 

88 isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) 

89 and isinstance(grad, ops.Tensor)): 

90 sx = array_ops.shape(x) 

91 sy = array_ops.shape(y) 

92 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

93 return (sx, rx, True), (sy, ry, True) 

94 

95 # pylint: disable=protected-access 

96 x_shape_tuple = x._shape_tuple() 

97 y_shape_tuple = y._shape_tuple() 

98 grad_shape_tuple = grad._shape_tuple() 

99 # pylint: enable=protected-access 

100 

101 if (x_shape_tuple is None or None in x_shape_tuple or 

102 y_shape_tuple is None or None in y_shape_tuple): 

103 sx = array_ops.shape_internal(x, optimize=False) 

104 sy = array_ops.shape_internal(y, optimize=False) 

105 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

106 return (sx, rx, True), (sy, ry, True) 

107 

108 x_needs_reduction = x_shape_tuple != grad_shape_tuple 

109 y_needs_reduction = y_shape_tuple != grad_shape_tuple 

110 

111 # Get the default graph rather than relying on `x.graph`, `y.graph`, or 

112 # `grad.graph`, because these may be eager tensors. 

113 g = ops.get_default_graph() 

114 

115 try: 

116 rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] # pylint: disable=protected-access 

117 return (x_shape_tuple, rx, x_needs_reduction), ( 

118 y_shape_tuple, ry, y_needs_reduction) 

119 except KeyError: 

120 rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple) 

121 # TODO(mrry): If this becomes a bottleneck, add a multi-output version of 

122 # `TF_TryEvaluateConstant()`. 

123 rx_value = tuple(tensor_util.try_evaluate_constant(rx)) 

124 assert rx_value is not None 

125 ry_value = tuple(tensor_util.try_evaluate_constant(ry)) 

126 assert ry_value is not None 

127 g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = ( # pylint: disable=protected-access 

128 rx_value, ry_value) 

129 

130 return (x_shape_tuple, rx_value, x_needs_reduction), ( 

131 y_shape_tuple, ry_value, y_needs_reduction) 

132 

133 

134_empty_tuple = () 

135 

136 

137def _IsScalar(x): 

138 return x._shape_tuple() is _empty_tuple # pylint: disable=protected-access 

139 

140 

141@ops.RegisterGradient("Sum") 

142def _SumGrad(op, grad): 

143 """Gradient for Sum.""" 

144 # Fast path for when reducing to a scalar and ndims is known: adds only 

145 # Reshape and Tile ops (and possibly a Shape). 

146 input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access 

147 if input_0_shape is not None: 

148 axes = tensor_util.constant_value(op.inputs[1]) 

149 if axes is not None: 

150 rank = len(input_0_shape) 

151 if np.array_equal(axes, np.arange(rank)): # Reduce all dims. 

152 if context.executing_eagerly(): 

153 ctx = context.context() 

154 new_shape = ctx.ones_rank_cache().get(rank) 

155 if new_shape is None: 

156 new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32) 

157 ctx.ones_rank_cache().put(rank, new_shape) 

158 else: 

159 new_shape = [1] * rank 

160 grad = array_ops.reshape(grad, new_shape) 

161 # If shape is not fully defined (but rank is), we use Shape. 

162 if None not in input_0_shape: 

163 input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32) 

164 else: 

165 input_shape = array_ops.shape(op.inputs[0]) 

166 return [array_ops.tile(grad, input_shape), None] 

167 elif None not in input_0_shape and not context.executing_eagerly(): 

168 # The shape and reduction indices are statically known, so we use a 

169 # graph-level cache to avoid recomputing `reduced_shape()` for each 

170 # invocation. 

171 graph = ops.get_default_graph() 

172 

173 # Canonicalize `axes` to be a tuple of indices. The incoming 

174 # value may be a scalar or a vector, and may include negative indices. 

175 axes = tuple(axes.reshape(-1)) 

176 

177 try: 

178 output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[ # pylint: disable=protected-access 

179 (input_0_shape, axes)] 

180 except KeyError: 

181 

182 # Compute and cache `output_shape_kept_dims` and `tile_scaling`. 

183 def EvaluateAsTuple(t): 

184 if tensor_util.is_tf_type(t): 

185 value = tensor_util.try_evaluate_constant(t) 

186 assert value is not None 

187 else: 

188 value = t 

189 return tuple(value) 

190 

191 output_shape_kept_dims = EvaluateAsTuple( 

192 math_ops.reduced_shape(input_0_shape, axes)) 

193 tile_scaling = EvaluateAsTuple( 

194 _safe_shape_div(input_0_shape, output_shape_kept_dims)) 

195 graph._reduced_shape_cache[(input_0_shape, axes)] = ( # pylint:disable=protected-access 

196 output_shape_kept_dims, tile_scaling) 

197 

198 grad = array_ops.reshape(grad, output_shape_kept_dims) 

199 return [array_ops.tile(grad, tile_scaling), None] 

200 

201 input_shape = array_ops.shape(op.inputs[0]) 

202 

203 if not op.get_attr("keep_dims"): 

204 with ops.colocate_with(input_shape): 

205 # TODO(apassos) remove this once device placement for eager ops makes 

206 # more sense. 

207 output_shape_kept_dims = math_ops.reduced_shape(input_shape, 

208 op.inputs[1]) 

209 grad = array_ops.reshape(grad, output_shape_kept_dims) 

210 return [array_ops.broadcast_to(grad, input_shape), None] 

211 

212 

213def _MinOrMaxGrad(op, grad): 

214 """Gradient for Min or Max. Amazingly it's precisely the same code.""" 

215 input_shape = array_ops.shape(op.inputs[0]) 

216 y = op.outputs[0] 

217 if not op.get_attr("keep_dims"): 

218 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) 

219 y = array_ops.reshape(y, output_shape_kept_dims) 

220 grad = array_ops.reshape(grad, output_shape_kept_dims) 

221 else: 

222 output_shape_kept_dims = array_ops.shape(y) 

223 

224 # Compute the number of selected (maximum or minimum) elements in each 

225 # reduction dimension. If there are multiple minimum or maximum elements 

226 # then the gradient will be divided between them. 

227 indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype) 

228 num_selected = array_ops.reshape( 

229 math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims) 

230 

231 return [math_ops.divide(indicators, num_selected) * grad, None] 

232 

233 

234@ops.RegisterGradient("Max") 

235def _MaxGrad(op, grad): 

236 """Gradient for Max.""" 

237 return _MinOrMaxGrad(op, grad) 

238 

239 

240@ops.RegisterGradient("Min") 

241def _MinGrad(op, grad): 

242 return _MinOrMaxGrad(op, grad) 

243 

244 

245@ops.RegisterGradient("Mean") 

246def _MeanGrad(op, grad): 

247 """Gradient for Mean.""" 

248 sum_grad = _SumGrad(op, grad)[0] 

249 input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access 

250 output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access 

251 if (input_shape is not None and output_shape is not None and 

252 None not in input_shape and None not in output_shape): 

253 input_size = np.prod(input_shape) 

254 output_size = np.prod(output_shape) 

255 factor = input_size // max(output_size, 1) 

256 factor = constant_op.constant(factor, dtype=sum_grad.dtype) 

257 else: 

258 input_shape = array_ops.shape(op.inputs[0]) 

259 output_shape = array_ops.shape(op.outputs[0]) 

260 factor = _safe_shape_div( 

261 math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape)) 

262 return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None 

263 

264 

265@ops.RegisterGradient("Prod") 

266def _ProdGrad(op, grad): 

267 """Gradient for Prod.""" 

268 # The gradient can be expressed by dividing the product by each entry of the 

269 # input tensor, but this approach can't deal with zeros in the input. 

270 # Here, we avoid this problem by composing the output as a product of two 

271 # cumprod operations. 

272 

273 input_shape = array_ops.shape(op.inputs[0]) 

274 # Reshape reduction indices for the case where the parameter is a scalar 

275 reduction_indices = array_ops.reshape(op.inputs[1], [-1]) 

276 

277 # Expand grad to full input shape 

278 if not op.get_attr("keep_dims"): 

279 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]) 

280 grad = array_ops.reshape(grad, output_shape_kept_dims) 

281 

282 grad = array_ops.broadcast_to(grad, input_shape) 

283 

284 # Pack all reduced dimensions into a single one, so we can perform the 

285 # cumprod ops. If the reduction dims list is empty, it defaults to float32, 

286 # so we need to cast here. We put all the shape-related ops on CPU to avoid 

287 # copying back and forth, and since listdiff is CPU only. 

288 with ops.device("/cpu:0"): 

289 rank = array_ops.rank(op.inputs[0]) 

290 reduction_indices = (reduction_indices + rank) % rank 

291 reduced = math_ops.cast(reduction_indices, dtypes.int32) 

292 idx = math_ops.range(0, rank) 

293 other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32) 

294 perm = array_ops.concat([reduced, other], 0) 

295 reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced)) 

296 other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other)) 

297 permuted = array_ops.transpose(op.inputs[0], perm) 

298 permuted_shape = array_ops.shape(permuted) 

299 reshaped = array_ops.reshape(permuted, (reduced_num, other_num)) 

300 

301 # Calculate product, leaving out the current entry 

302 left = math_ops.cumprod(reshaped, axis=0, exclusive=True) 

303 right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True) 

304 # For complex inputs, the gradient is in the conjugate direction. 

305 y = array_ops.reshape( 

306 math_ops.conj(left) * math_ops.conj(right), permuted_shape) 

307 

308 # Invert the transpose and reshape operations. 

309 # Make sure to set the statically known shape information through a reshape. 

310 out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm)) 

311 return array_ops.reshape(out, input_shape), None 

312 

313 

314@ops.RegisterGradient("SegmentSum") 

315def _SegmentSumGrad(op, grad): 

316 """Gradient for SegmentSum.""" 

317 return array_ops.gather(grad, op.inputs[1]), None 

318 

319 

320@ops.RegisterGradient("SegmentMean") 

321def _SegmentMeanGrad(op, grad): 

322 """Gradient for SegmentMean.""" 

323 input_rank = array_ops.rank(op.inputs[0]) 

324 ones_shape = array_ops.concat([ 

325 array_ops.shape(op.inputs[1]), 

326 array_ops.ones( 

327 array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32) 

328 ], 0) 

329 ones = array_ops.ones(ones_shape, dtype=grad.dtype) 

330 scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1])) 

331 return array_ops.gather(scaled_grad, op.inputs[1]), None 

332 

333 

334@ops.RegisterGradient("SparseSegmentSum") 

335def _SparseSegmentSumGrad(op, grad): 

336 """Gradient for SparseSegmentSum.""" 

337 dim0 = array_ops.shape(op.inputs[0])[0] 

338 if compat.forward_compatible(2021, 6, 10): 

339 return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2], 

340 dim0), None, None) 

341 else: 

342 return (math_ops.unsorted_segment_sum( 

343 array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None) 

344 

345 

346@ops.RegisterGradient("SparseSegmentSumWithNumSegments") 

347def _SparseSegmentSumWithNumSegmentsGrad(op, grad): 

348 """Gradient for SparseSegmentSumWithNumSegments.""" 

349 dim0 = array_ops.shape(op.inputs[0])[0] 

350 if compat.forward_compatible(2021, 6, 10): 

351 return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2], 

352 dim0), None, None, None) 

353 else: 

354 return (math_ops.unsorted_segment_sum( 

355 array_ops.gather(grad, op.inputs[2]), op.inputs[1], 

356 dim0), None, None, None) 

357 

358 

359@ops.RegisterGradient("SparseSegmentMean") 

360def _SparseSegmentMeanGrad(op, grad): 

361 """Gradient for SparseSegmentMean.""" 

362 dim0 = array_ops.shape(op.inputs[0])[0] 

363 return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], 

364 dim0), None, None) 

365 

366 

367@ops.RegisterGradient("SparseSegmentMeanWithNumSegments") 

368def _SparseSegmentMeanWithNumSegmentsGrad(op, grad): 

369 """Gradient for SparseSegmentMeanWithNumSegments.""" 

370 dim0 = array_ops.shape(op.inputs[0])[0] 

371 return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2], 

372 dim0), None, None, None) 

373 

374 

375@ops.RegisterGradient("SparseSegmentSqrtN") 

376def _SparseSegmentSqrtNGrad(op, grad): 

377 """Gradient for SparseSegmentSqrtN.""" 

378 dim0 = array_ops.shape(op.inputs[0])[0] 

379 return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], 

380 dim0), None, None) 

381 

382 

383@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments") 

384def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad): 

385 """Gradient for SparseSegmentSqrtNWithNumSegments.""" 

386 dim0 = array_ops.shape(op.inputs[0])[0] 

387 return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2], 

388 dim0), None, None, None) 

389 

390 

391def _SegmentMinOrMaxGrad(op, grad): 

392 """ Gradient for SegmentMin and SegmentMax. """ 

393 zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype) 

394 # Get the number of selected (minimum or maximum) elements in each segment. 

395 gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1]) 

396 is_selected = math_ops.equal(op.inputs[0], gathered_outputs) 

397 num_selected = math_ops.segment_sum( 

398 math_ops.cast(is_selected, grad.dtype), op.inputs[1]) 

399 # Compute the gradient for each segment. The gradient for the ith segment is 

400 # divided evenly among the selected elements in that segment. 

401 weighted_grads = math_ops.divide(grad, num_selected) 

402 gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) 

403 return array_ops.where_v2(is_selected, gathered_grads, zeros), None 

404 

405 

406@ops.RegisterGradient("SegmentMin") 

407def _SegmentMinGrad(op, grad): 

408 """Gradient for SegmentMin.""" 

409 return _SegmentMinOrMaxGrad(op, grad) 

410 

411 

412@ops.RegisterGradient("SegmentMax") 

413def _SegmentMaxGrad(op, grad): 

414 """Gradient for SegmentMax.""" 

415 return _SegmentMinOrMaxGrad(op, grad) 

416 

417 

418@ops.RegisterGradient("SegmentProd") 

419def _SegmentProdGrad(op, grad): 

420 """Gradient for SegmentProd. 

421 

422 The gradient can be expressed for each segment by dividing the segment's 

423 product by each element of the segment input tensor, but this approach can't 

424 deal with zeros in the input. 

425 Unlike reduce_prod we can't use cumsum here as individual segments may have 

426 a different number of elements. Therefore we consider three cases: 

427 1) A segment input contains no zeros and we can safely divide by the input 

428 tensor. 

429 2) A segment contains exactly one zero. Then the gradient of each input of 

430 the segment is zero except for the 0-input, there the gradient is 

431 the product of the remaining segment entries. 

432 3) A segment contains at least two zeros. The gradient is zero for all 

433 segment inputs. 

434 """ 

435 data = op.inputs[0] 

436 segment_ids = op.inputs[1] 

437 is_zero = math_ops.equal(data, 0) 

438 num_zeros = gen_math_ops.segment_sum( 

439 math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids) 

440 # handle case 3 and set the gradient to 0 for segments with more than one 

441 # 0 as input 

442 grad = array_ops.where_v2( 

443 math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad) 

444 # replace all zeros with ones and compute the segment_prod 

445 non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data) 

446 non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids) 

447 gathered_prod = array_ops.gather(op.outputs[0], segment_ids) 

448 gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids) 

449 prod_divided_by_el = gathered_prod / non_zero_data 

450 # Now fetch the individual results for segments containing 0 and those that 

451 # don't. 

452 partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod, 

453 prod_divided_by_el) 

454 gathered_grad = array_ops.gather(grad, segment_ids) 

455 return gathered_grad * partial_derivative, None 

456 

457 

458def _GatherDropNegatives(params, 

459 ids, 

460 zero_clipped_indices=None, 

461 is_positive=None): 

462 """ Helper function for unsorted segment ops. 

463 

464 Gathers params for 

465 positive segment ids and gathers 0 for inputs with negative segment id. 

466 Also returns the clipped indices and a boolean mask with the same shape 

467 as ids where a positive id is masked as true. With this, the latter two 

468 can be passed as arguments to this function to reuse them. 

469 """ 

470 if zero_clipped_indices is None: 

471 zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids)) 

472 gathered = array_ops.gather(params, zero_clipped_indices) 

473 if is_positive is None: 

474 is_positive = math_ops.greater_equal(ids, 0) 

475 # tf.where(condition, x, y) requires condition to have the same shape as x 

476 # and y. 

477 is_positive_shape = array_ops.shape(is_positive) 

478 broadcastable_shape = array_ops.concat( 

479 [is_positive_shape, 

480 array_ops.ones([array_ops.rank(gathered) 

481 - array_ops.rank(is_positive)], 

482 dtype=is_positive_shape.dtype)], 

483 axis=0) 

484 is_positive = array_ops.reshape(is_positive, broadcastable_shape) 

485 is_positive = ( 

486 is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool)) 

487 # replace gathered params of negative indices with 0 

488 zero_slice = array_ops.zeros_like(gathered) 

489 return (array_ops.where_v2(is_positive, gathered, 

490 zero_slice), zero_clipped_indices, is_positive) 

491 

492 

493def _UnsortedSegmentMinOrMaxGrad(op, grad): 

494 """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """ 

495 # Get the number of selected (minimum or maximum) elements in each segment. 

496 gathered_outputs, zero_clipped_indices, is_positive = \ 

497 _GatherDropNegatives(op.outputs[0], op.inputs[1]) 

498 is_selected = math_ops.equal(op.inputs[0], gathered_outputs) 

499 is_selected = math_ops.logical_and(is_selected, is_positive) 

500 num_selected = math_ops.unsorted_segment_sum( 

501 math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) 

502 # Compute the gradient for each segment. The gradient for the ith segment is 

503 # divided evenly among the selected elements in that segment. 

504 weighted_grads = math_ops.divide(grad, num_selected) 

505 gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, 

506 zero_clipped_indices, is_positive) 

507 zeros = array_ops.zeros_like(gathered_grads) 

508 return array_ops.where_v2(is_selected, gathered_grads, zeros), None, None 

509 

510 

511@ops.RegisterGradient("UnsortedSegmentSum") 

512def _UnsortedSegmentSumGrad(op, grad): 

513 """Gradient for UnsortedSegmentSum.""" 

514 return _GatherDropNegatives(grad, op.inputs[1])[0], None, None 

515 

516 

517@ops.RegisterGradient("UnsortedSegmentMax") 

518def _UnsortedSegmentMaxGrad(op, grad): 

519 """ Gradient for UnsortedSegmentMax. """ 

520 return _UnsortedSegmentMinOrMaxGrad(op, grad) 

521 

522 

523@ops.RegisterGradient("UnsortedSegmentMin") 

524def _UnsortedSegmentMinGrad(op, grad): 

525 """ Gradient for UnsortedSegmentMin. """ 

526 return _UnsortedSegmentMinOrMaxGrad(op, grad) 

527 

528 

529@ops.RegisterGradient("UnsortedSegmentProd") 

530def _UnsortedSegmentProdGrad(op, grad): 

531 """ Gradient for UnsortedSegmentProd. 

532 

533 The gradient can be expressed for each segment by dividing the segment's 

534 product by each element of the segment input tensor, but this approach can't 

535 deal with zeros in the input. 

536 Unlike reduce_prod we can't use cumsum here as individual segments may have 

537 a different number of elements. Therefore we consider three cases: 

538 1) A segment input contains no zeros and we can safely divide by the input 

539 tensor. 

540 2) A segment contains exactly one zero. Then the gradient of each input of 

541 the segment is zero except for the 0-input, there the gradient is 

542 the product of the remaining segment entries. 

543 3) A segment contains at least two zeros. The gradient is zero for all 

544 segment inputs. 

545 """ 

546 # Note that unsorted_segment_sum will filter out the negative indices, 

547 # so we don't need to do a logical_and with is_positive here 

548 is_zero = math_ops.equal(op.inputs[0], 0) 

549 num_zeros = gen_math_ops.unsorted_segment_sum( 

550 math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2]) 

551 # handle case 3 and set the gradient to 0 for segments with more than one 

552 # 0 as input 

553 grad = array_ops.where_v2( 

554 math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad) 

555 # replace all zeros with ones and compute the unsorted_segment_prod 

556 non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(op.inputs[0]), 

557 op.inputs[0]) 

558 non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data, 

559 op.inputs[1], op.inputs[2]) 

560 # clip the indices for gather to be positive 

561 zero_clipped_indices = math_ops.maximum(op.inputs[1], 

562 array_ops.zeros_like(op.inputs[1])) 

563 gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices) 

564 gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices) 

565 prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf. 

566 # Now fetch the individual results for segments containing 0 and those that 

567 # don't. is_zero will also fetch results for entries with negative index 

568 # but the following gather_drop_negatives sets the corresponding entry in 

569 # grad to 0 for these 

570 partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod, 

571 prod_divided_by_el) 

572 gathered_grad = _GatherDropNegatives(grad, op.inputs[1], 

573 zero_clipped_indices)[0] 

574 return gathered_grad * partial_derivative, None, None 

575 

576 

577@ops.RegisterGradient("Abs") 

578def _AbsGrad(op, grad): 

579 x = op.inputs[0] 

580 return grad * math_ops.sign(x) 

581 

582 

583@ops.RegisterGradient("Neg") 

584def _NegGrad(_, grad): 

585 """Returns -grad.""" 

586 return -grad 

587 

588 

589@ops.RegisterGradient("Inv") 

590def _InvGrad(op, grad): 

591 """Returns -grad * (1 / x^2).""" 

592 y = op.outputs[0] # y = 1 / x 

593 return gen_math_ops.reciprocal_grad(y, grad) 

594 

595 

596@ops.RegisterGradient("Reciprocal") 

597def _ReciprocalGrad(op, grad): 

598 """Returns -grad * (1 / x^2).""" 

599 y = op.outputs[0] # y = 1 / x 

600 return gen_math_ops.reciprocal_grad(y, grad) 

601 

602 

603@ops.RegisterGradient("InvGrad") 

604def _InvGradGrad(op, grad): 

605 b = op.inputs[1] 

606 # op.output[0]: y = -b * conj(a)^2 

607 with ops.control_dependencies([grad]): 

608 ca = math_ops.conj(op.inputs[0]) 

609 cg = math_ops.conj(grad) 

610 return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) 

611 

612 

613@ops.RegisterGradient("ReciprocalGrad") 

614def _ReciprocalGradGrad(op, grad): 

615 b = op.inputs[1] 

616 # op.output[0]: y = -b * conj(a)^2 

617 with ops.control_dependencies([grad]): 

618 ca = math_ops.conj(op.inputs[0]) 

619 cg = math_ops.conj(grad) 

620 return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) 

621 

622 

623@ops.RegisterGradient("Square") 

624def _SquareGrad(op, grad): 

625 x = op.inputs[0] 

626 # Added control dependencies to prevent 2*x from being computed too early. 

627 with ops.control_dependencies([grad]): 

628 x = math_ops.conj(x) 

629 y = constant_op.constant(2.0, dtype=x.dtype) 

630 return math_ops.multiply(grad, math_ops.multiply(x, y)) 

631 

632 

633@ops.RegisterGradient("Sqrt") 

634def _SqrtGrad(op, grad): 

635 y = op.outputs[0] # y = x^(1/2) 

636 return gen_math_ops.sqrt_grad(y, grad) 

637 

638 

639@ops.RegisterGradient("SqrtGrad") 

640def _SqrtGradGrad(op, grad): 

641 a = op.inputs[0] 

642 y = op.outputs[0] # y = 0.5 * b / conj(a) 

643 with ops.control_dependencies([grad]): 

644 ga = grad / a 

645 return -math_ops.conj(ga) * y, 0.5 * ga # pylint: disable=invalid-unary-operand-type 

646 

647 

648@ops.RegisterGradient("Rsqrt") 

649def _RsqrtGrad(op, grad): 

650 """Returns -0.5 * grad * conj(y)^3.""" 

651 y = op.outputs[0] # y = x^(-1/2) 

652 return gen_math_ops.rsqrt_grad(y, grad) 

653 

654 

655@ops.RegisterGradient("RsqrtGrad") 

656def _RsqrtGradGrad(op, grad): 

657 """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3.""" 

658 a = op.inputs[0] # a = x^{-1/2} 

659 b = op.inputs[1] # backprop gradient for a 

660 with ops.control_dependencies([grad]): 

661 ca = math_ops.conj(a) 

662 cg = math_ops.conj(grad) 

663 grad_a = -1.5 * cg * b * math_ops.square(ca) 

664 grad_b = gen_math_ops.rsqrt_grad(ca, grad) 

665 return grad_a, grad_b 

666 

667 

668@ops.RegisterGradient("Exp") 

669def _ExpGrad(op, grad): 

670 """Returns grad * exp(x).""" 

671 y = op.outputs[0] # y = e^x 

672 with ops.control_dependencies([grad]): 

673 y = math_ops.conj(y) 

674 return grad * y 

675 

676 

677@ops.RegisterGradient("Expm1") 

678def _Expm1Grad(op, grad): 

679 """Returns grad * exp(x).""" 

680 x = op.inputs[0] 

681 with ops.control_dependencies([grad]): 

682 x = math_ops.conj(x) 

683 y = math_ops.exp(x) 

684 return grad * y 

685 

686 

687@ops.RegisterGradient("Log") 

688def _LogGrad(op, grad): 

689 """Returns grad * (1/x).""" 

690 x = op.inputs[0] 

691 with ops.control_dependencies([grad]): 

692 x = math_ops.conj(x) 

693 return grad * math_ops.reciprocal(x) 

694 

695 

696@ops.RegisterGradient("Log1p") 

697def _Log1pGrad(op, grad): 

698 """Returns grad * (1/(1 + x)).""" 

699 x = op.inputs[0] 

700 with ops.control_dependencies([grad]): 

701 x = math_ops.conj(x) 

702 return grad * math_ops.reciprocal(1 + x) 

703 

704 

705@ops.RegisterGradient("Xlogy") 

706def _XLogyGrad(op, grad): 

707 """Returns gradient of xlogy(x, y) with respect to x and y.""" 

708 x = op.inputs[0] 

709 y = op.inputs[1] 

710 sx = array_ops.shape(x) 

711 sy = array_ops.shape(y) 

712 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

713 with ops.control_dependencies([grad]): 

714 not_zero_x = math_ops.cast( 

715 math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype) 

716 partial_x = gen_math_ops.xlogy(not_zero_x, y) 

717 partial_y = gen_math_ops.xdivy(x, y) 

718 return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx), 

719 array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) 

720 

721 

722@ops.RegisterGradient("Xlog1py") 

723def _XLog1pyGrad(op, grad): 

724 """Returns gradient of xlog1py(x, y) with respect to x and y.""" 

725 x = op.inputs[0] 

726 y = op.inputs[1] 

727 sx = array_ops.shape(x) 

728 sy = array_ops.shape(y) 

729 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

730 with ops.control_dependencies([grad]): 

731 not_zero_x = math_ops.cast( 

732 math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype) 

733 partial_x = gen_math_ops.xlog1py(not_zero_x, y) 

734 partial_y = gen_math_ops.xdivy(x, y + 1.) 

735 return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx), 

736 array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) 

737 

738 

739@ops.RegisterGradient("Xdivy") 

740def _XDivyGrad(op, grad): 

741 """Returns gradient of xdivy(x, y) with respect to x and y.""" 

742 x = op.inputs[0] 

743 y = op.inputs[1] 

744 sx = array_ops.shape(x) 

745 sy = array_ops.shape(y) 

746 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

747 with ops.control_dependencies([grad]): 

748 not_zero_x = math_ops.cast( 

749 math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype) 

750 partial_x = gen_math_ops.xdivy(not_zero_x, y) 

751 partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2) 

752 return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx), 

753 array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy)) 

754 

755 

756@ops.RegisterGradient("Sinh") 

757def _SinhGrad(op, grad): 

758 """Returns grad * cosh(x).""" 

759 x = op.inputs[0] 

760 with ops.control_dependencies([grad]): 

761 x = math_ops.conj(x) 

762 return grad * math_ops.cosh(x) 

763 

764 

765@ops.RegisterGradient("Cosh") 

766def _CoshGrad(op, grad): 

767 """Returns grad * sinh(x).""" 

768 x = op.inputs[0] 

769 with ops.control_dependencies([grad]): 

770 x = math_ops.conj(x) 

771 return grad * math_ops.sinh(x) 

772 

773 

774@ops.RegisterGradient("Tanh") 

775def _TanhGrad(op, grad): 

776 """Returns grad * (1 - tanh(x) * tanh(x)).""" 

777 y = op.outputs[0] # y = tanh(x) 

778 with ops.control_dependencies([grad]): 

779 y = math_ops.conj(y) 

780 return gen_math_ops.tanh_grad(y, grad) 

781 

782 

783@ops.RegisterGradient("Asinh") 

784def _AsinhGrad(op, grad): 

785 """Returns grad * 1/cosh(y).""" 

786 y = op.outputs[0] 

787 with ops.control_dependencies([grad]): 

788 y = math_ops.conj(y) 

789 return grad / math_ops.cosh(y) 

790 

791 

792@ops.RegisterGradient("Acosh") 

793def _AcoshGrad(op, grad): 

794 """Returns grad * 1/sinh(y).""" 

795 y = op.outputs[0] 

796 with ops.control_dependencies([grad]): 

797 y = math_ops.conj(y) 

798 return grad / math_ops.sinh(y) 

799 

800 

801@ops.RegisterGradient("Atanh") 

802def _AtanhGrad(op, grad): 

803 """Returns grad * 1/ (1 - x^2).""" 

804 x = op.inputs[0] 

805 with ops.control_dependencies([grad]): 

806 x = math_ops.conj(x) 

807 x2 = math_ops.square(x) 

808 one = constant_op.constant(1, dtype=grad.dtype) 

809 inv = math_ops.reciprocal(math_ops.subtract(one, x2)) 

810 return grad * inv 

811 

812 

813@ops.RegisterGradient("TanhGrad") 

814def _TanhGradGrad(op, grad): 

815 with ops.control_dependencies([grad]): 

816 a = math_ops.conj(op.inputs[0]) 

817 b = math_ops.conj(op.inputs[1]) 

818 return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad) 

819 

820 

821@ops.RegisterGradient("Erf") 

822def _ErfGrad(op, grad): 

823 """Returns grad * 2/sqrt(pi) * exp(-x**2).""" 

824 x = op.inputs[0] 

825 two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype) 

826 with ops.control_dependencies([grad]): 

827 x = math_ops.conj(x) 

828 return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x)) 

829 

830 

831@ops.RegisterGradient("Erfc") 

832def _ErfcGrad(op, grad): 

833 """Returns -grad * 2/sqrt(pi) * exp(-x**2).""" 

834 x = op.inputs[0] 

835 minus_two_over_root_pi = constant_op.constant( 

836 -2 / np.sqrt(np.pi), dtype=grad.dtype) 

837 with ops.control_dependencies([grad]): 

838 x = math_ops.conj(x) 

839 return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x)) 

840 

841 

842@ops.RegisterGradient("Erfinv") 

843def _ErfinvGrad(op, grad): 

844 """Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2).""" 

845 root_pi_over_two = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype) 

846 with ops.control_dependencies([grad]): 

847 return grad * root_pi_over_two * math_ops.exp( 

848 math_ops.square(op.outputs[0])) 

849 

850 

851@ops.RegisterGradient("Ndtri") 

852def _NdtriGrad(op, grad): 

853 """Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2).""" 

854 root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype) 

855 with ops.control_dependencies([grad]): 

856 return grad * root_two_pi * math_ops.exp( 

857 math_ops.square(op.outputs[0]) / 2.) 

858 

859 

860@ops.RegisterGradient("Lgamma") 

861def _LgammaGrad(op, grad): 

862 """Returns grad * digamma(x).""" 

863 x = op.inputs[0] 

864 with ops.control_dependencies([grad]): 

865 x = math_ops.conj(x) 

866 return grad * math_ops.digamma(x) 

867 

868 

869@ops.RegisterGradient("Digamma") 

870def _DigammaGrad(op, grad): 

871 """Compute gradient of the digamma function with respect to its argument.""" 

872 x = op.inputs[0] 

873 with ops.control_dependencies([grad]): 

874 x = math_ops.conj(x) 

875 partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x) 

876 return grad * partial_x 

877 

878 

879@ops.RegisterGradient("Dawsn") 

880def _DawsnGrad(op, grad): 

881 """Compute gradient of dawsn(x) with respect to its argument.""" 

882 x = op.inputs[0] 

883 y = op.outputs[0] 

884 with ops.control_dependencies([grad]): 

885 return grad * (1. - 2 * x * y) 

886 

887 

888@ops.RegisterGradient("Expint") 

889def _ExpintGrad(op, grad): 

890 """Compute gradient of expint(x) with respect to its argument.""" 

891 x = op.inputs[0] 

892 with ops.control_dependencies([grad]): 

893 return grad * math_ops.exp(x) / x 

894 

895 

896@ops.RegisterGradient("FresnelCos") 

897def _FresnelCosGrad(op, grad): 

898 """Compute gradient of fresnel_cos(x) with respect to its argument.""" 

899 x = op.inputs[0] 

900 with ops.control_dependencies([grad]): 

901 return grad * math_ops.cos((np.pi / 2.) * math_ops.square(x)) 

902 

903 

904@ops.RegisterGradient("FresnelSin") 

905def _FresnelSinGrad(op, grad): 

906 """Compute gradient of fresnel_sin(x) with respect to its argument.""" 

907 x = op.inputs[0] 

908 with ops.control_dependencies([grad]): 

909 return grad * math_ops.sin((np.pi / 2.) * math_ops.square(x)) 

910 

911 

912@ops.RegisterGradient("Spence") 

913def _SpenceGrad(op, grad): 

914 """Compute gradient of spence(x) with respect to its argument.""" 

915 x = op.inputs[0] 

916 with ops.control_dependencies([grad]): 

917 partial_x = math_ops.log(x) / (1 - x) 

918 partial_x = array_ops.where( 

919 math_ops.equal(x, 1.), -array_ops.ones_like(x), partial_x) # pylint: disable=invalid-unary-operand-type 

920 return grad * partial_x 

921 

922 

923@ops.RegisterGradient("BesselI0") 

924def _BesselI0Grad(op, grad): 

925 """Compute gradient of bessel_i0(x) with respect to its argument.""" 

926 x = op.inputs[0] 

927 with ops.control_dependencies([grad]): 

928 partial_x = special_math_ops.bessel_i1(x) 

929 return grad * partial_x 

930 

931 

932@ops.RegisterGradient("BesselI0e") 

933def _BesselI0eGrad(op, grad): 

934 """Compute gradient of bessel_i0e(x) with respect to its argument.""" 

935 x = op.inputs[0] 

936 y = op.outputs[0] 

937 with ops.control_dependencies([grad]): 

938 partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y) 

939 return grad * partial_x 

940 

941 

942@ops.RegisterGradient("BesselI1") 

943def _BesselI1Grad(op, grad): 

944 """Compute gradient of bessel_i1(x) with respect to its argument.""" 

945 x = op.inputs[0] 

946 y = op.outputs[0] 

947 with ops.control_dependencies([grad]): 

948 # For x = 0, the correct gradient is 1.0. 

949 # However, the main branch gives NaN because of the division by x, so 

950 # we impute the gradient manually. 

951 # An alternative solution is to express the gradient via bessel_i0 and 

952 # bessel_i2, but the latter is not yet implemented in Eigen. 

953 dy_dx = array_ops.where_v2( 

954 math_ops.equal(x, 0.), math_ops.cast(1., x.dtype), 

955 special_math_ops.bessel_i0(x) - math_ops.div(y, x)) 

956 return grad * dy_dx 

957 

958 

959@ops.RegisterGradient("BesselI1e") 

960def _BesselI1eGrad(op, grad): 

961 """Compute gradient of bessel_i1e(x) with respect to its argument.""" 

962 x = op.inputs[0] 

963 y = op.outputs[0] 

964 with ops.control_dependencies([grad]): 

965 # For x = 0, the correct gradient is 0.5. 

966 # However, the main branch gives NaN because of the division by x, so 

967 # we impute the gradient manually. 

968 # An alternative solution is to express the gradient via bessel_i0e and 

969 # bessel_i2e, but the latter is not yet implemented in Eigen. 

970 dy_dx = array_ops.where_v2( 

971 math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype), 

972 special_math_ops.bessel_i0e(x) - y * 

973 (math_ops.sign(x) + math_ops.reciprocal(x))) 

974 return grad * dy_dx 

975 

976 

977@ops.RegisterGradient("BesselK0") 

978def _BesselK0Grad(op, grad): 

979 """Compute gradient of bessel_k0(x) with respect to its argument.""" 

980 x = op.inputs[0] 

981 with ops.control_dependencies([grad]): 

982 partial_x = -special_math_ops.bessel_k1(x) 

983 return grad * partial_x 

984 

985 

986@ops.RegisterGradient("BesselK0e") 

987def _BesselK0eGrad(op, grad): 

988 """Compute gradient of bessel_k0e(x) with respect to its argument.""" 

989 x = op.inputs[0] 

990 y = op.outputs[0] 

991 with ops.control_dependencies([grad]): 

992 partial_x = (y - special_math_ops.bessel_k1e(x)) 

993 return grad * partial_x 

994 

995 

996@ops.RegisterGradient("BesselK1") 

997def _BesselK1Grad(op, grad): 

998 """Compute gradient of bessel_k1(x) with respect to its argument.""" 

999 x = op.inputs[0] 

1000 y = op.outputs[0] 

1001 with ops.control_dependencies([grad]): 

1002 # At 0., this is NaN which is fine since the derivative is undefined 

1003 # at 0. 

1004 partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x) 

1005 return grad * partial_x 

1006 

1007 

1008@ops.RegisterGradient("BesselK1e") 

1009def _BesselK1eGrad(op, grad): 

1010 """Compute gradient of bessel_k1e(x) with respect to its argument.""" 

1011 x = op.inputs[0] 

1012 y = op.outputs[0] 

1013 with ops.control_dependencies([grad]): 

1014 # At 0., this is NaN which is fine since the derivative is undefined 

1015 # at 0. 

1016 partial_x = ( 

1017 y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x)) 

1018 return grad * partial_x 

1019 

1020 

1021@ops.RegisterGradient("BesselJ0") 

1022def _BesselJ0Grad(op, grad): 

1023 """Compute gradient of bessel_j0(x) with respect to its argument.""" 

1024 x = op.inputs[0] 

1025 with ops.control_dependencies([grad]): 

1026 partial_x = -special_math_ops.bessel_j1(x) 

1027 return grad * partial_x 

1028 

1029 

1030@ops.RegisterGradient("BesselJ1") 

1031def _BesselJ1Grad(op, grad): 

1032 """Compute gradient of bessel_j1(x) with respect to its argument.""" 

1033 x = op.inputs[0] 

1034 y = op.outputs[0] 

1035 with ops.control_dependencies([grad]): 

1036 # For x = 0, the correct gradient is 0.5. 

1037 # However, the main branch gives NaN because of the division by x, so 

1038 # we impute the gradient manually. 

1039 # An alternative solution is to express the gradient via bessel_i0e and 

1040 # bessel_i2e, but the latter is not yet implemented in Eigen. 

1041 dy_dx = array_ops.where_v2( 

1042 math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype), 

1043 special_math_ops.bessel_j0(x) - math_ops.div(y, x)) 

1044 return grad * dy_dx 

1045 

1046 

1047@ops.RegisterGradient("BesselY0") 

1048def _BesselY0Grad(op, grad): 

1049 """Compute gradient of bessel_y0(x) with respect to its argument.""" 

1050 x = op.inputs[0] 

1051 with ops.control_dependencies([grad]): 

1052 partial_x = -special_math_ops.bessel_y1(x) 

1053 return grad * partial_x 

1054 

1055 

1056@ops.RegisterGradient("BesselY1") 

1057def _BesselY1Grad(op, grad): 

1058 """Compute gradient of bessel_y1(x) with respect to its argument.""" 

1059 x = op.inputs[0] 

1060 y = op.outputs[0] 

1061 with ops.control_dependencies([grad]): 

1062 # At 0., this is NaN which is fine since the derivative is undefined 

1063 # at 0. 

1064 partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x) 

1065 return grad * partial_x 

1066 

1067 

1068@ops.RegisterGradient("Igamma") 

1069def _IgammaGrad(op, grad): 

1070 """Returns gradient of igamma(a, x) with respect to a and x.""" 

1071 a = op.inputs[0] 

1072 x = op.inputs[1] 

1073 sa = array_ops.shape(a) 

1074 sx = array_ops.shape(x) 

1075 ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) 

1076 

1077 with ops.control_dependencies([grad]): 

1078 partial_a = gen_math_ops.igamma_grad_a(a, x) 

1079 # Perform operations in log space before summing, because Gamma(a) 

1080 # and Gamma'(a) can grow large. 

1081 partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - 

1082 math_ops.lgamma(a)) 

1083 return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), 

1084 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) 

1085 

1086 

1087@ops.RegisterGradient("Igammac") 

1088def _IgammacGrad(op, grad): 

1089 """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.""" 

1090 igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad) 

1091 return (-igamma_grad_a, -igamma_grad_x) 

1092 

1093 

1094@ops.RegisterGradient("Betainc") 

1095def _BetaincGrad(op, grad): 

1096 """Returns gradient of betainc(a, b, x) with respect to x.""" 

1097 # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b 

1098 a, b, x = op.inputs 

1099 

1100 # two cases: x is a scalar and a/b are same-shaped tensors, or vice 

1101 # versa; so its sufficient to check against shape(a). 

1102 sa = array_ops.shape(a) 

1103 sx = array_ops.shape(x) 

1104 _, rx = gen_array_ops.broadcast_gradient_args(sa, sx) 

1105 

1106 # Perform operations in log space before summing, because terms 

1107 # can grow large. 

1108 log_beta = ( 

1109 gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) - 

1110 gen_math_ops.lgamma(a + b)) 

1111 # We use xlog1py and xlogy since the derivatives should tend to 

1112 # zero one of the tails when a is 1. or b is 1. 

1113 partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) + 

1114 math_ops.xlogy(a - 1, x) - log_beta) 

1115 

1116 return ( 

1117 None, # da 

1118 None, # db 

1119 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) 

1120 

1121 

1122@ops.RegisterGradient("Zeta") 

1123def _ZetaGrad(op, grad): 

1124 """Returns gradient of zeta(x, q) with respect to x and q.""" 

1125 # TODO(tillahoffmann): Add derivative with respect to x 

1126 x = op.inputs[0] 

1127 q = op.inputs[1] 

1128 # Broadcast gradients 

1129 sx = array_ops.shape(x) 

1130 sq = array_ops.shape(q) 

1131 unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq) 

1132 # Evaluate gradient 

1133 with ops.control_dependencies([grad]): 

1134 x = math_ops.conj(x) 

1135 q = math_ops.conj(q) 

1136 partial_q = -x * math_ops.zeta(x + 1, q) # pylint: disable=invalid-unary-operand-type 

1137 return (None, 

1138 array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq)) 

1139 

1140 

1141@ops.RegisterGradient("Polygamma") 

1142def _PolygammaGrad(op, grad): 

1143 """Returns gradient of psi(n, x) with respect to n and x.""" 

1144 # TODO(tillahoffmann): Add derivative with respect to n 

1145 n = op.inputs[0] 

1146 x = op.inputs[1] 

1147 # Broadcast gradients 

1148 sn = array_ops.shape(n) 

1149 sx = array_ops.shape(x) 

1150 unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx) 

1151 # Evaluate gradient 

1152 with ops.control_dependencies([grad]): 

1153 n = math_ops.conj(n) 

1154 x = math_ops.conj(x) 

1155 partial_x = math_ops.polygamma(n + 1, x) 

1156 return (None, 

1157 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) 

1158 

1159 

1160@ops.RegisterGradient("Sigmoid") 

1161def _SigmoidGrad(op, grad): 

1162 """Returns grad * sigmoid(x) * (1 - sigmoid(x)).""" 

1163 y = op.outputs[0] # y = sigmoid(x) 

1164 with ops.control_dependencies([grad]): 

1165 y = math_ops.conj(y) 

1166 return gen_math_ops.sigmoid_grad(y, grad) 

1167 

1168 

1169@ops.RegisterGradient("SigmoidGrad") 

1170def _SigmoidGradGrad(op, grad): 

1171 with ops.control_dependencies([grad]): 

1172 a = math_ops.conj(op.inputs[0]) 

1173 b = math_ops.conj(op.inputs[1]) 

1174 gb = grad * b 

1175 return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad) 

1176 

1177 

1178@ops.RegisterGradient("Sign") 

1179def _SignGrad(op, _): 

1180 """Returns 0.""" 

1181 x = op.inputs[0] 

1182 return array_ops.zeros_like(x) 

1183 

1184 

1185@ops.RegisterGradient("Sin") 

1186def _SinGrad(op, grad): 

1187 """Returns grad * cos(x).""" 

1188 x = op.inputs[0] 

1189 with ops.control_dependencies([grad]): 

1190 x = math_ops.conj(x) 

1191 return grad * math_ops.cos(x) 

1192 

1193 

1194@ops.RegisterGradient("Cos") 

1195def _CosGrad(op, grad): 

1196 """Returns grad * -sin(x).""" 

1197 x = op.inputs[0] 

1198 with ops.control_dependencies([grad]): 

1199 x = math_ops.conj(x) 

1200 return -grad * math_ops.sin(x) 

1201 

1202 

1203@ops.RegisterGradient("Tan") 

1204def _TanGrad(op, grad): 

1205 """Returns grad * 1/sec^2(x).""" 

1206 x = op.inputs[0] 

1207 with ops.control_dependencies([grad]): 

1208 x = math_ops.conj(x) 

1209 secx = math_ops.reciprocal(math_ops.cos(x)) 

1210 secx2 = math_ops.square(secx) 

1211 return secx2 * grad 

1212 

1213 

1214@ops.RegisterGradient("Asin") 

1215def _AsinGrad(op, grad): 

1216 """Returns grad * 1/sqrt(1-x^2).""" 

1217 x = op.inputs[0] 

1218 with ops.control_dependencies([grad]): 

1219 x = math_ops.conj(x) 

1220 x2 = math_ops.square(x) 

1221 one = constant_op.constant(1, dtype=grad.dtype) 

1222 den = math_ops.sqrt(math_ops.subtract(one, x2)) 

1223 inv = math_ops.reciprocal(den) 

1224 return grad * inv 

1225 

1226 

1227@ops.RegisterGradient("Acos") 

1228def _AcosGrad(op, grad): 

1229 """Returns grad * -1/sqrt(1-x^2).""" 

1230 x = op.inputs[0] 

1231 with ops.control_dependencies([grad]): 

1232 x = math_ops.conj(x) 

1233 x2 = math_ops.square(x) 

1234 one = constant_op.constant(1, dtype=grad.dtype) 

1235 den = math_ops.sqrt(math_ops.subtract(one, x2)) 

1236 inv = math_ops.reciprocal(den) 

1237 return -grad * inv 

1238 

1239 

1240@ops.RegisterGradient("Atan") 

1241def _AtanGrad(op, grad): 

1242 """Returns grad * 1/ (1 + x^2).""" 

1243 x = op.inputs[0] 

1244 with ops.control_dependencies([grad]): 

1245 x = math_ops.conj(x) 

1246 x2 = math_ops.square(x) 

1247 one = constant_op.constant(1, dtype=grad.dtype) 

1248 inv = math_ops.reciprocal(math_ops.add(one, x2)) 

1249 return grad * inv 

1250 

1251 

1252@ops.RegisterGradient("Atan2") 

1253def _Atan2Grad(op, grad): 

1254 """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2).""" 

1255 y = op.inputs[0] 

1256 x = op.inputs[1] 

1257 with ops.control_dependencies([grad]): 

1258 grad_inv = grad / (math_ops.square(x) + math_ops.square(y)) 

1259 return x * grad_inv, -y * grad_inv 

1260 

1261 

1262@ops.RegisterGradient("AddN") 

1263def _AddNGrad(op, grad): 

1264 """Copies the gradient to all inputs.""" 

1265 # Not broadcasting. 

1266 return [grad] * len(op.inputs) 

1267 

1268 

1269def _ShapesFullySpecifiedAndEqual(x, y, grad): 

1270 # pylint: disable=protected-access 

1271 x_shape = x._shape_tuple() 

1272 y_shape = y._shape_tuple() 

1273 grad_shape = grad._shape_tuple() 

1274 # pylint: enable=protected-access 

1275 return (x_shape == y_shape and x_shape == grad_shape and 

1276 x_shape is not None and None not in x_shape) 

1277 

1278 

1279@ops.RegisterGradient("Add") 

1280@ops.RegisterGradient("AddV2") 

1281def _AddGrad(op, grad): 

1282 """Gradient for Add.""" 

1283 y = op.inputs[1] 

1284 skip_input_indices = None 

1285 try: 

1286 skip_input_indices = op.skip_input_indices 

1287 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( 

1288 y): 

1289 return grad, None 

1290 except AttributeError: 

1291 # No gradient skipping, so do the full gradient computation 

1292 pass 

1293 x = op.inputs[0] 

1294 if (isinstance(grad, ops.Tensor) and 

1295 _ShapesFullySpecifiedAndEqual(x, y, grad)): 

1296 return grad, grad 

1297 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( 

1298 SmartBroadcastGradientArgs(x, y, grad)) 

1299 if skip_input_indices is not None and 0 in skip_input_indices: 

1300 gx = None 

1301 elif not must_reduce_x: 

1302 gx = grad 

1303 else: 

1304 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) 

1305 if skip_input_indices is not None and 1 in skip_input_indices: 

1306 gy = None 

1307 elif not must_reduce_y: 

1308 gy = grad 

1309 else: 

1310 gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy) 

1311 return (gx, gy) 

1312 

1313 

1314@ops.RegisterGradient("Sub") 

1315def _SubGrad(op, grad): 

1316 """Gradient for Sub.""" 

1317 y = op.inputs[1] 

1318 skip_input_indices = None 

1319 try: 

1320 skip_input_indices = op.skip_input_indices 

1321 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( 

1322 y): 

1323 return grad, None 

1324 except AttributeError: 

1325 # No gradient skipping, so do the full gradient computation 

1326 pass 

1327 x = op.inputs[0] 

1328 if (isinstance(grad, ops.Tensor) and 

1329 _ShapesFullySpecifiedAndEqual(x, y, grad)): 

1330 return grad, -grad 

1331 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( 

1332 SmartBroadcastGradientArgs(x, y, grad)) 

1333 if skip_input_indices is not None and 0 in skip_input_indices: 

1334 gx = None 

1335 elif not must_reduce_x: 

1336 gx = grad 

1337 else: 

1338 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) 

1339 if skip_input_indices is not None and 1 in skip_input_indices: 

1340 gy = None 

1341 elif not must_reduce_y: 

1342 gy = -grad 

1343 else: 

1344 gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy) 

1345 return (gx, gy) 

1346 

1347 

1348@ops.RegisterGradient("Mul") 

1349def _MulGrad(op, grad): 

1350 """The gradient of scalar multiplication.""" 

1351 y = op.inputs[1] 

1352 skip_input_indices = None 

1353 try: 

1354 skip_input_indices = op.skip_input_indices 

1355 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( 

1356 y): 

1357 return gen_math_ops.mul(grad, math_ops.conj(y)), None 

1358 except AttributeError: 

1359 # No gradient skipping, so do the full gradient computation 

1360 pass 

1361 x = op.inputs[0] 

1362 if (isinstance(grad, ops.Tensor) and 

1363 _ShapesFullySpecifiedAndEqual(x, y, grad) and 

1364 grad.dtype in (dtypes.int32, dtypes.float32)): 

1365 return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) 

1366 assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) 

1367 

1368 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( 

1369 SmartBroadcastGradientArgs(x, y, grad)) 

1370 x = math_ops.conj(x) 

1371 y = math_ops.conj(y) 

1372 if skip_input_indices is not None and 0 in skip_input_indices: 

1373 gx = None 

1374 elif not must_reduce_x: 

1375 gx = gen_math_ops.mul(grad, y) 

1376 else: 

1377 gx = array_ops.reshape( 

1378 math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx) 

1379 if skip_input_indices is not None and 1 in skip_input_indices: 

1380 gy = None 

1381 elif not must_reduce_y: 

1382 gy = gen_math_ops.mul(x, grad) 

1383 else: 

1384 gy = array_ops.reshape( 

1385 math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy) 

1386 return (gx, gy) 

1387 

1388 

1389@ops.RegisterGradient("MulNoNan") 

1390def _MulNoNanGrad(op, grad): 

1391 """The gradient of scalar multiplication with NaN-suppression.""" 

1392 x = op.inputs[0] 

1393 y = op.inputs[1] 

1394 if (isinstance(grad, ops.Tensor) and 

1395 _ShapesFullySpecifiedAndEqual(x, y, grad)): 

1396 return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad) 

1397 assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype) 

1398 sx = array_ops.shape(x) 

1399 sy = array_ops.shape(y) 

1400 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1401 return (array_ops.reshape( 

1402 math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx), 

1403 array_ops.reshape( 

1404 math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy)) 

1405 

1406 

1407@ops.RegisterGradient("Div") 

1408def _DivGrad(op, grad): 

1409 """The gradient for the Div operator.""" 

1410 x = op.inputs[0] 

1411 y = op.inputs[1] 

1412 sx = array_ops.shape(x) 

1413 sy = array_ops.shape(y) 

1414 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1415 x = math_ops.conj(x) 

1416 y = math_ops.conj(y) 

1417 # pylint: disable=invalid-unary-operand-type 

1418 return ( 

1419 array_ops.reshape(math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx), 

1420 array_ops.reshape( 

1421 math_ops.reduce_sum(grad * math_ops.divide(math_ops.divide(-x, y), y), 

1422 ry), sy)) 

1423 

1424 

1425@ops.RegisterGradient("FloorDiv") 

1426def _FloorDivGrad(_, unused_grad): 

1427 """The gradient for the FloorDiv operator.""" 

1428 return None, None 

1429 

1430 

1431@ops.RegisterGradient("FloorMod") 

1432def _FloorModGrad(op, grad): 

1433 """Returns grad * (1, -floor(x/y)).""" 

1434 x = math_ops.conj(op.inputs[0]) 

1435 y = math_ops.conj(op.inputs[1]) 

1436 

1437 sx = array_ops.shape(x) 

1438 sy = array_ops.shape(y) 

1439 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1440 floor_xy = math_ops.floor_div(x, y) 

1441 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) 

1442 gy = array_ops.reshape( 

1443 math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy) 

1444 return gx, gy 

1445 

1446 

1447@ops.RegisterGradient("TruncateDiv") 

1448def _TruncateDivGrad(_, unused_grad): 

1449 return None, None 

1450 

1451 

1452@ops.RegisterGradient("RealDiv") 

1453def _RealDivGrad(op, grad): 

1454 """RealDiv op gradient.""" 

1455 x = op.inputs[0] 

1456 y = op.inputs[1] 

1457 sx = array_ops.shape(x) 

1458 sy = array_ops.shape(y) 

1459 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1460 x = math_ops.conj(x) 

1461 y = math_ops.conj(y) 

1462 return (array_ops.reshape( 

1463 math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), 

1464 array_ops.reshape( 

1465 math_ops.reduce_sum( 

1466 grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) # pylint: disable=invalid-unary-operand-type 

1467 

1468 

1469@ops.RegisterGradient("DivNoNan") 

1470def _DivNoNanGrad(op, grad): 

1471 """DivNoNan op gradient.""" 

1472 x = op.inputs[0] 

1473 y = op.inputs[1] 

1474 sx = array_ops.shape(x) 

1475 sy = array_ops.shape(y) 

1476 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1477 x = math_ops.conj(x) 

1478 y = math_ops.conj(y) 

1479 return ( 

1480 array_ops.reshape( 

1481 math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), 

1482 array_ops.reshape( 

1483 math_ops.reduce_sum( 

1484 grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), # pylint: disable=invalid-unary-operand-type 

1485 ry), 

1486 sy)) 

1487 

1488 

1489@ops.RegisterGradient("Pow") 

1490def _PowGrad(op, grad): 

1491 """Returns grad * (y*x^(y-1), z*log(x)).""" 

1492 x = op.inputs[0] 

1493 y = op.inputs[1] 

1494 skip_input_indices = None 

1495 try: 

1496 skip_input_indices = op.skip_input_indices 

1497 # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the 

1498 # constant `1` into a single constant op. 

1499 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( 

1500 y): 

1501 x = math_ops.conj(x) 

1502 y = math_ops.conj(y) 

1503 return grad * y * math_ops.pow(x, y - 1), None 

1504 

1505 except AttributeError: 

1506 # No gradient skipping, so do the full gradient computation 

1507 pass 

1508 

1509 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( 

1510 SmartBroadcastGradientArgs(x, y, grad)) 

1511 x = math_ops.conj(x) 

1512 y = math_ops.conj(y) 

1513 

1514 if skip_input_indices is None or 0 not in skip_input_indices: 

1515 gx = grad * y * math_ops.pow(x, y - 1) 

1516 if must_reduce_x: 

1517 gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx) 

1518 else: 

1519 gx = None 

1520 

1521 if skip_input_indices is None or 1 not in skip_input_indices: 

1522 z = math_ops.conj(op.outputs[0]) 

1523 

1524 # Avoid false singularity at x = 0 

1525 if x.dtype.is_complex: 

1526 # real(x) < 0 is fine for the complex case 

1527 mask = math_ops.not_equal(x, 0) 

1528 else: 

1529 # There's no sensible real value to return if x < 0, so return 0 

1530 mask = x > 0 

1531 safe_x = array_ops.where(mask, x, array_ops.ones_like(x)) 

1532 log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x)) 

1533 gy = grad * z * log_x 

1534 if must_reduce_y: 

1535 gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy) 

1536 else: 

1537 gy = None 

1538 

1539 return gx, gy 
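
# A minimal sketch (not part of the original module) of the Pow gradient
# formulas used above: for z = x**y, dz/dx = y * x**(y - 1) and
# dz/dy = z * log(x), with log(x) masked to 0 where it is undefined.
# The helper name is hypothetical and is never called.
def _example_pow_grad_sketch():
  import numpy as np
  x, y, eps = 2.0, 3.0, 1e-6
  dx_numeric = ((x + eps) ** y - x ** y) / eps
  dy_numeric = (x ** (y + eps) - x ** y) / eps
  return (np.isclose(dx_numeric, y * x ** (y - 1), atol=1e-4),
          np.isclose(dy_numeric, (x ** y) * np.log(x), atol=1e-4))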

1540 

1541 

1542def _MaximumMinimumGradInputOnly(op, grad, selector_op): 

1543 x = op.inputs[0] 

1544 y = op.inputs[1] 

1545 zeros = array_ops.zeros_like(grad) 

1546 xmask = selector_op(x, y) 

1547 xgrad = array_ops.where_v2(xmask, grad, zeros) 

1548  ygrad = None  # Return None for ygrad since the gradient for the second input is skipped. 

1549 return (xgrad, ygrad) 

1550 

1551 

1552def _MaximumMinimumGrad(op, grad, selector_op): 

1553 """Factor out the code for the gradient of Maximum or Minimum.""" 

1554 y = op.inputs[1] 

1555 skip_input_indices = None 

1556 try: 

1557 skip_input_indices = op.skip_input_indices 

1558 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar( 

1559 y): 

1560 # When we want to get gradients for the first input only, and the second 

1561 # input tensor is a scalar, we can do a much simpler calculation 

1562 return _MaximumMinimumGradInputOnly(op, grad, selector_op) 

1563 except AttributeError: 

1564 # No gradient skipping, so do the full gradient computation 

1565 pass 

1566 x = op.inputs[0] 

1567 sx = array_ops.shape(x) 

1568 sy = array_ops.shape(y) 

1569 zeros = array_ops.zeros_like(grad) 

1570 xmask = selector_op(x, y) 

1571 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1572 if skip_input_indices is not None and 0 in skip_input_indices: 

1573 gx = None 

1574 else: 

1575 xgrad = array_ops.where_v2(xmask, grad, zeros) 

1576 gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx) 

1577 

1578 if skip_input_indices is not None and 1 in skip_input_indices: 

1579 gy = None 

1580 else: 

1581 ygrad = array_ops.where_v2(xmask, zeros, grad) 

1582 gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy) 

1583 

1584 return (gx, gy) 
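
# A minimal sketch (not part of the original module) of the mask-based routing
# used above: for Maximum, the incoming gradient flows to x wherever x >= y and
# to y elsewhere, so gx + gy always equals grad. The helper name is
# hypothetical and is never called.
def _example_maximum_grad_routing_sketch():
  import numpy as np
  x = np.array([1.0, 5.0, 3.0])
  y = np.array([2.0, 2.0, 3.0])
  grad = np.array([10.0, 20.0, 30.0])
  xmask = x >= y
  gx = np.where(xmask, grad, 0.0)
  gy = np.where(xmask, 0.0, grad)
  return gx, gy                    # ([0, 20, 30], [10, 0, 0])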

1585 

1586 

1587@ops.RegisterGradient("Maximum") 

1588def _MaximumGrad(op, grad): 

1589  """Returns grad * (x >= y, x < y) with the dtype of grad.""" 

1590 return _MaximumMinimumGrad(op, grad, math_ops.greater_equal) 

1591 

1592 

1593@ops.RegisterGradient("Minimum") 

1594def _MinimumGrad(op, grad): 

1595  """Returns grad * (x <= y, x > y) with the dtype of grad.""" 

1596 return _MaximumMinimumGrad(op, grad, math_ops.less_equal) 

1597 

1598 

1599@ops.RegisterGradient("SquaredDifference") 

1600def _SquaredDifferenceGrad(op, grad): 

1601 """Returns the gradient for (x-y)^2.""" 

1602 x = op.inputs[0] 

1603 y = op.inputs[1] 

1604 skip_input_indices = None 

1605 try: 

1606 skip_input_indices = op.skip_input_indices 

1607 except AttributeError: 

1608 # No gradient skipping, so do the full gradient computation 

1609 pass 

1610 

1611 with ops.control_dependencies([grad]): 

1612    # The parentheses ensure that if grad is an IndexedSlices, it is multiplied by 

1613    # a Tensor (not a Python number like 2.0), which converts it to a Tensor. 

1614 x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) 

1615 

1616 if (isinstance(grad, ops.Tensor) and 

1617 _ShapesFullySpecifiedAndEqual(x, y, grad)): 

1618 return x_grad, -x_grad 

1619 

1620 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( 

1621 SmartBroadcastGradientArgs(x, y, grad)) 

1622 

1623 if skip_input_indices is not None and 0 in skip_input_indices: 

1624 gx = None 

1625 elif must_reduce_x: 

1626 gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx) 

1627 else: 

1628 gx = x_grad 

1629 

1630 if skip_input_indices is not None and 1 in skip_input_indices: 

1631 gy = None 

1632 elif must_reduce_y: 

1633 gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy) 

1634 else: 

1635 gy = -x_grad 

1636 return (gx, gy) 
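
# A minimal sketch (not part of the original module) of the SquaredDifference
# gradient: for f = (x - y)**2, df/dx = 2*(x - y) and df/dy = -2*(x - y), which
# is why gy above is just the negation of x_grad. The helper name is
# hypothetical and is never called.
def _example_squared_difference_grad_sketch():
  import numpy as np
  x, y, eps = 5.0, 3.0, 1e-6
  dx_numeric = ((x + eps - y) ** 2 - (x - y) ** 2) / eps
  return np.isclose(dx_numeric, 2.0 * (x - y), atol=1e-4)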

1637 

1638 

1639# Logical operations have no gradients. 

1640ops.NotDifferentiable("Less") 

1641ops.NotDifferentiable("LessEqual") 

1642ops.NotDifferentiable("Greater") 

1643ops.NotDifferentiable("GreaterEqual") 

1644ops.NotDifferentiable("Equal") 

1645ops.NotDifferentiable("ApproximateEqual") 

1646ops.NotDifferentiable("NotEqual") 

1647ops.NotDifferentiable("LogicalAnd") 

1648ops.NotDifferentiable("LogicalOr") 

1649ops.NotDifferentiable("LogicalNot") 

1650 

1651 

1652@ops.RegisterGradient("Select") 

1653def _SelectGrad(op, grad): 

1654 c = op.inputs[0] 

1655 x = op.inputs[1] 

1656 zeros = array_ops.zeros_like(x) 

1657 return (None, array_ops.where(c, grad, zeros), array_ops.where( 

1658 c, zeros, grad)) 

1659 

1660 

1661@ops.RegisterGradient("SelectV2") 

1662def _SelectGradV2(op, grad): 

1663 c = op.inputs[0] 

1664 x = op.inputs[1] 

1665 y = op.inputs[2] 

1666 zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype) 

1667 gx = array_ops.where_v2(c, grad, zeros) 

1668 x_shape = array_ops.shape(x) 

1669 output_shape = array_ops.shape(op.outputs[0]) 

1670 # Reduce away broadcasted leading dims. 

1671 reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape) 

1672 gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x) 

1673 gx = array_ops.reshape(gx, x_shape) 

1674 

1675 gy = array_ops.where_v2(c, zeros, grad) 

1676 y_shape = array_ops.shape(y) 

1677 # Reduce away broadcasted leading dims. 

1678 reduce_y, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape) 

1679 gy = math_ops.reduce_sum(gy, keepdims=True, axis=reduce_y) 

1680 gy = array_ops.reshape(gy, y_shape) 

1681 

1682 return (None, gx, gy) 
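
# A minimal sketch (not part of the original module) of the reduction used
# above when an operand of SelectV2 was broadcast: the masked gradient is
# summed over the broadcast axes (keepdims=True) and then reshaped back to the
# operand's shape. The helper name is hypothetical and is never called.
def _example_select_v2_reduce_sketch():
  import numpy as np
  c = np.array([[True, False], [False, True]])
  grad = np.array([[1.0, 2.0], [3.0, 4.0]])
  x_shape = (1, 2)                 # x was broadcast along axis 0
  gx = np.where(c, grad, 0.0)
  gx = np.sum(gx, axis=0, keepdims=True).reshape(x_shape)
  return gx                        # [[1.0, 4.0]]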

1683 

1684 

1685def _MatMulGradAgainstFirstOnly(op, grad): 

1686 """Gradient for MatMul, only for the first input.""" 

1687 t_a = op.get_attr("transpose_a") 

1688 t_b = op.get_attr("transpose_b") 

1689 b = math_ops.conj(op.inputs[1]) 

1690 if not t_a and not t_b: 

1691 grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) 

1692 elif not t_a and t_b: 

1693 grad_a = gen_math_ops.mat_mul(grad, b) 

1694 elif t_a and not t_b: 

1695 grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) 

1696 elif t_a and t_b: 

1697 grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) 

1698 return grad_a, None 

1699 

1700 

1701def _MatMulGradAgainstSecondOnly(op, grad): 

1702 """Gradient for MatMul, only for the second input.""" 

1703 t_a = op.get_attr("transpose_a") 

1704 t_b = op.get_attr("transpose_b") 

1705 a = math_ops.conj(op.inputs[0]) 

1706 if not t_a and not t_b: 

1707 grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) 

1708 elif not t_a and t_b: 

1709 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) 

1710 elif t_a and not t_b: 

1711 grad_b = gen_math_ops.mat_mul(a, grad) 

1712 elif t_a and t_b: 

1713 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) 

1714 return None, grad_b 

1715 

1716 

1717@ops.RegisterGradient("MatMul") 

1718def _MatMulGrad(op, grad): 

1719 """Gradient for MatMul.""" 

1720 try: 

1721 skip_input_indices = op.skip_input_indices 

1722 if skip_input_indices is not None: 

1723 if 1 in skip_input_indices: 

1724 return _MatMulGradAgainstFirstOnly(op, grad) 

1725 elif 0 in skip_input_indices: 

1726 return _MatMulGradAgainstSecondOnly(op, grad) 

1727 except AttributeError: 

1728 # No gradient skipping, so do the full gradient computation 

1729 pass 

1730 

1731 t_a = op.get_attr("transpose_a") 

1732 t_b = op.get_attr("transpose_b") 

1733 a = math_ops.conj(op.inputs[0]) 

1734 b = math_ops.conj(op.inputs[1]) 

1735 if not t_a and not t_b: 

1736 grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True) 

1737 grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True) 

1738 elif not t_a and t_b: 

1739 grad_a = gen_math_ops.mat_mul(grad, b) 

1740 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) 

1741 elif t_a and not t_b: 

1742 grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True) 

1743 grad_b = gen_math_ops.mat_mul(a, grad) 

1744 elif t_a and t_b: 

1745 grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True) 

1746 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True) 

1747 return grad_a, grad_b 
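
# A minimal sketch (not part of the original module) of the non-transposed
# MatMul case above: for Z = A @ B with upstream gradient dZ, dA = dZ @ B.T and
# dB = A.T @ dZ, so both gradients match the shapes of A and B. The helper name
# is hypothetical and is never called.
def _example_matmul_grad_sketch():
  import numpy as np
  a = np.random.rand(2, 3)
  b = np.random.rand(3, 4)
  dz = np.ones((2, 4))             # upstream gradient for z = a @ b
  da = dz @ b.T                    # shape (2, 3), matches a
  db = a.T @ dz                    # shape (3, 4), matches b
  return da.shape, db.shape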

1748 

1749 

1750@ops.RegisterGradient("SparseMatMul") 

1751def _SparseMatMulGrad(op, grad): 

1752 """Gradient for SparseMatMul.""" 

1753 

1754 t_a = op.get_attr("transpose_a") 

1755 t_b = op.get_attr("transpose_b") 

1756 is_sparse = {} 

1757 is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse") 

1758 is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse") 

1759  # Use a heuristic to figure out whether grad might be sparse. 

1760 is_sparse[grad.ref()] = not context.executing_eagerly() and ( 

1761 grad.op.type == "ReluGrad") 

1762 

1763 def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False): 

1764 """Helper function to create SparseMatMul op.""" 

1765 

1766 assert t1.ref() in is_sparse and t2.ref() in is_sparse 

1767 t1_sparse = is_sparse[t1.ref()] 

1768 t2_sparse = is_sparse[t2.ref()] 

1769 if transpose_b: 

1770 t2 = array_ops.transpose(t2) 

1771 transpose_b = False 

1772 prod = math_ops.matmul( 

1773 t1, 

1774 t2, 

1775 transpose_a=transpose_a, 

1776 transpose_b=transpose_b, 

1777 a_is_sparse=t1_sparse, 

1778 b_is_sparse=t2_sparse) 

1779 if prod.dtype != out_dtype: 

1780 prod = math_ops.cast(prod, out_dtype) 

1781 return prod 

1782 

1783 dtype_a = op.inputs[0].dtype 

1784 dtype_b = op.inputs[1].dtype 

1785 if not t_a and not t_b: 

1786 return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True), 

1787 _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True)) 

1788 elif not t_a and t_b: 

1789 return (_SparseMatMul(grad, op.inputs[1], dtype_a), 

1790 _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True)) 

1791 elif t_a and not t_b: 

1792 return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True), 

1793 _SparseMatMul(op.inputs[0], grad, dtype_b)) 

1794 elif t_a and t_b: 

1795 return (_SparseMatMul( 

1796 op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True), 

1797 _SparseMatMul( 

1798 grad, op.inputs[0], dtype_b, transpose_a=True, 

1799 transpose_b=True)) 

1800 

1801 

1802@ops.RegisterGradient("Floor") 

1803def _FloorGrad(_, unused_grad): 

1804 return [None] 

1805 

1806 

1807@ops.RegisterGradient("Ceil") 

1808def _CeilGrad(_, unused_grad): 

1809 return [None] 

1810 

1811 

1812@ops.RegisterGradient("Round") 

1813def _RoundGrad(_, unused_grad): 

1814 return [None] 

1815 

1816 

1817@ops.RegisterGradient("Rint") 

1818def _RintGrad(_, unused_grad): 

1819  # The gradient of Rint is zero. 

1820 return [None] 

1821 

1822 

1823@ops.RegisterGradient("BatchMatMul") 

1824def _BatchMatMul(op, grad): 

1825  """Returns the gradients of x and y given the gradient of matmul(x, y).""" 

1826 x = op.inputs[0] 

1827 y = op.inputs[1] 

1828 adj_x = op.get_attr("adj_x") 

1829 adj_y = op.get_attr("adj_y") 

1830 

1831 if not adj_x: 

1832 if not adj_y: 

1833 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) 

1834 grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) 

1835 else: 

1836 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) 

1837 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) 

1838 else: 

1839 if not adj_y: 

1840 grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) 

1841 grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) 

1842 else: 

1843 grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) 

1844 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) 

1845 

1846 return grad_x, grad_y 

1847 

1848 

1849@ops.RegisterGradient("BatchMatMulV2") 

1850@ops.RegisterGradient("BatchMatMulV3") 

1851def _BatchMatMulV2(op, grad): 

1852  """Returns the gradients of x and y given the gradient of matmul(x, y).""" 

1853 x = op.inputs[0] 

1854 y = op.inputs[1] 

1855 adj_x = op.get_attr("adj_x") 

1856 adj_y = op.get_attr("adj_y") 

1857 

1858 if not adj_x: 

1859 if not adj_y: 

1860 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True) 

1861 grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False) 

1862 else: 

1863 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False) 

1864 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False) 

1865 else: 

1866 if not adj_y: 

1867 grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True) 

1868 grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False) 

1869 else: 

1870 grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True) 

1871 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True) 

1872 

1873 # Possibly reduce along the broadcasted batch dimensions, if broadcasting 

1874 # is required. 

1875 shape_x_static = x.get_shape() 

1876 shape_y_static = y.get_shape() 

1877 output_may_have_non_empty_batch_shape = ( 

1878 (shape_x_static.rank is None or shape_x_static.rank > 2) or 

1879 (shape_y_static.rank is None or shape_y_static.rank > 2)) 

1880 batch_shapes_match = ( 

1881 shape_x_static[:-2].is_fully_defined() and 

1882 shape_y_static[:-2].is_fully_defined() and 

1883 shape_x_static[:-2] == shape_y_static[:-2]) 

1884 if (not output_may_have_non_empty_batch_shape) or batch_shapes_match: 

1885 return grad_x, grad_y 

1886 

1887 sx = array_ops.shape(x) 

1888 sy = array_ops.shape(y) 

1889 rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2]) 

1890 grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx) 

1891 grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy) 

1892 return grad_x, grad_y 
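
# A minimal sketch (not part of the original module) of the batch-dimension
# reduction above: when one operand's batch dims were broadcast (for example
# shape (1, 2, 3) against (5, 3, 4)), its gradient is summed over the broadcast
# batch axes and reshaped back. The helper name is hypothetical and never called.
def _example_batch_matmul_grad_reduce_sketch():
  import numpy as np
  grad_x_full = np.ones((5, 2, 3)) # gradient computed at the broadcast shape
  x_shape = (1, 2, 3)              # x's actual (un-broadcast) shape
  grad_x = np.reshape(np.sum(grad_x_full, axis=0), x_shape)
  return grad_x.shape              # (1, 2, 3), every entry 5.0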

1893 

1894 

1895ops.NotDifferentiable("Range") 

1896ops.NotDifferentiable("LinSpace") 

1897 

1898 

1899@ops.RegisterGradient("Complex") 

1900def _ComplexGrad(op, grad): 

1901 """Returns the real and imaginary components of 'grad', respectively.""" 

1902 x = op.inputs[0] 

1903 y = op.inputs[1] 

1904 sx = array_ops.shape(x) 

1905 sy = array_ops.shape(y) 

1906 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 

1907 return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx), 

1908 array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy)) 

1909 

1910 

1911@ops.RegisterGradient("Real") 

1912def _RealGrad(_, grad): 

1913  """Returns 'grad' as the real part and sets the imaginary part to 0.""" 

1914 zero = constant_op.constant(0, dtype=grad.dtype) 

1915 return math_ops.complex(grad, zero) 

1916 

1917 

1918@ops.RegisterGradient("Imag") 

1919def _ImagGrad(_, grad): 

1920  """Returns 'grad' as the imaginary part and sets the real part to 0.""" 

1921 zero = constant_op.constant(0, dtype=grad.dtype) 

1922 return math_ops.complex(zero, grad) 

1923 

1924 

1925@ops.RegisterGradient("Angle") 

1926def _AngleGrad(op, grad): 

1927  """Returns -grad / (Im(x) + iRe(x)).""" 

1928 x = op.inputs[0] 

1929 with ops.control_dependencies([grad]): 

1930 re = math_ops.real(x) 

1931 im = math_ops.imag(x) 

1932 z = math_ops.reciprocal(math_ops.complex(im, re)) 

1933 zero = constant_op.constant(0, dtype=grad.dtype) 

1934 complex_grad = math_ops.complex(grad, zero) 

1935 return -complex_grad * z 

1936 

1937 

1938@ops.RegisterGradient("Conj") 

1939def _ConjGrad(_, grad): 

1940 """Returns the complex conjugate of grad.""" 

1941 return math_ops.conj(grad) 

1942 

1943 

1944@ops.RegisterGradient("ComplexAbs") 

1945def _ComplexAbsGrad(op, grad): 

1946 """Returns the gradient of ComplexAbs.""" 

1947 return math_ops.div_no_nan( 

1948 math_ops.complex( 

1949 grad, array_ops.zeros_like(grad)) * op.inputs[0], 

1950 math_ops.complex( 

1951 op.outputs[0], array_ops.zeros_like(op.outputs[0]))) 
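
# A minimal sketch (not part of the original module) of the ComplexAbs rule
# used above: the returned value is grad * z / |z| (via div_no_nan, so z = 0
# yields 0), i.e. grad scaled by the unit phase of z. The helper name is
# hypothetical and is never called.
def _example_complex_abs_grad_sketch():
  import numpy as np
  z = 3.0 + 4.0j
  grad = 1.0
  return grad * z / np.abs(z)      # 0.6 + 0.8j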

1952 

1953 

1954@ops.RegisterGradient("Cast") 

1955def _CastGrad(op, grad): 

1956 t = [ 

1957 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16, 

1958 dtypes.complex64, dtypes.complex128 

1959 ] 

1960 src_type = op.inputs[0].dtype.base_dtype 

1961 dst_type = grad.dtype.base_dtype 

1962 if src_type in t and dst_type in t: 

1963 return math_ops.cast(grad, src_type) 

1964 else: 

1965 return None 

1966 

1967 

1968@ops.RegisterGradient("Cross") 

1969def _CrossGrad(op, grad): 

1970 u = op.inputs[0] 

1971 v = op.inputs[1] 

1972 return (math_ops.cross(v, grad), math_ops.cross(grad, u)) 

1973 

1974 

1975@ops.RegisterGradient("Cumsum") 

1976def _CumsumGrad(op, grad): 

1977 axis = op.inputs[1] 

1978 exclusive = op.get_attr("exclusive") 

1979 reverse = op.get_attr("reverse") 

1980 return [ 

1981 math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse), 

1982 None 

1983 ] 

1984 

1985 

1986@ops.RegisterGradient("Cumprod") 

1987def _CumprodGrad(op, grad): 

1988 x = op.inputs[0] 

1989 axis = op.inputs[1] 

1990 exclusive = op.get_attr("exclusive") 

1991 reverse = op.get_attr("reverse") 

1992 

1993 prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse) 

1994 out = math_ops.cumsum( 

1995 prod * grad, axis, exclusive=exclusive, reverse=not reverse) 

1996 return [math_ops.div_no_nan(out, x), None] 
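
# A minimal sketch (not part of the original module) of the Cumprod gradient
# identity used above (for nonzero x): d(sum_i g_i * cumprod(x)_i) / dx_j
# equals the reversed cumulative sum of g * cumprod(x) at j, divided by x_j,
# which is what cumsum(prod * grad, reverse=not reverse) / x computes.
# The helper name is hypothetical and is never called.
def _example_cumprod_grad_sketch():
  import numpy as np
  x = np.array([2.0, 3.0, 4.0])
  g = np.ones(3)                   # upstream gradient
  prod = np.cumprod(x)             # [2, 6, 24]
  reversed_cumsum = np.cumsum((prod * g)[::-1])[::-1]
  analytic = reversed_cumsum / x   # [16, 10, 6]
  eps = 1e-6
  numeric = (np.sum(np.cumprod(x + np.array([eps, 0.0, 0.0]))) -
             np.sum(prod)) / eps   # finite difference w.r.t. x[0] only
  return np.isclose(numeric, analytic[0], atol=1e-3)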

1997 

1998 

1999@ops.RegisterGradient("CumulativeLogsumexp") 

2000def _CumulativeLogsumexpGrad(op, grad): 

2001 x = op.inputs[0] 

2002 axis = op.inputs[1] 

2003 cumulative_logsumexp = op.outputs[0] 

2004 

2005 exclusive = op.get_attr("exclusive") 

2006 reverse = op.get_attr("reverse") 

2007 

2008  # Split the incoming gradient into its positive and negative parts so that 

2009  # logs can be taken; this is required for numerically stable results. 

2010 log_grad_positive = array_ops.where_v2( 

2011 math_ops.greater(grad, 0), 

2012 math_ops.log(grad), 

2013 grad.dtype.min) 

2014 

2015 log_grad_negative = array_ops.where_v2( 

2016 math_ops.less(grad, 0), 

2017 math_ops.log(-grad), 

2018 grad.dtype.min) 

2019 

2020 output_pos = math_ops.exp( 

2021 math_ops.cumulative_logsumexp( 

2022 log_grad_positive - cumulative_logsumexp, 

2023 axis=axis, reverse=not reverse, exclusive=exclusive) + x) 

2024 

2025 output_neg = math_ops.exp( 

2026 math_ops.cumulative_logsumexp( 

2027 log_grad_negative - cumulative_logsumexp, 

2028 axis=axis, reverse=not reverse, exclusive=exclusive) + x) 

2029 

2030 return [output_pos - output_neg, None] 
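
# A minimal sketch (not part of the original module) of why the incoming
# gradient is split by sign above: for the forward, inclusive case the backward
# rule is grad_x[j] = sum_{i >= j} g[i] * exp(x[j] - L[i]) with
# L = logcumsumexp(x), and evaluating that sum in log space (for stability)
# needs log(g[i]), which only exists for one sign at a time. The helper name is
# hypothetical and is never called.
def _example_logcumsumexp_grad_sketch():
  import numpy as np
  x = np.array([0.1, 0.5, 0.2])
  g = np.array([1.0, -2.0, 3.0])   # mixed-sign upstream gradient
  L = np.logaddexp.accumulate(x)   # logcumsumexp(x)
  direct = np.array([np.sum(g[j:] * np.exp(x[j] - L[j:]))
                     for j in range(len(x))])
  g_pos, g_neg = np.maximum(g, 0.0), np.maximum(-g, 0.0)
  split = np.array([np.sum(g_pos[j:] * np.exp(x[j] - L[j:])) -
                    np.sum(g_neg[j:] * np.exp(x[j] - L[j:]))
                    for j in range(len(x))])
  return np.allclose(direct, split)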

2031 

2032 

2033@ops.RegisterGradient("NextAfter") 

2034def _NextAfterGrad(op, grad): 

2035 """Returns gradient of nextafter(x1, x2) with respect to x1 and x2.""" 

2036 x1 = op.inputs[0] 

2037 x2 = op.inputs[1] 

2038 s_x1 = array_ops.shape(x1) 

2039 s_x2 = array_ops.shape(x2) 

2040 r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2) 

2041 with ops.control_dependencies([grad]): 

2042 partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype) 

2043 partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype) 

2044 return (array_ops.reshape( 

2045 math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1), 

2046 array_ops.reshape( 

2047 math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))