Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/nn_grad.py: 33%

400 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""

import functools
import itertools
import operator

from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      None,
      gen_nn_ops.conv2d_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()),
      gen_nn_ops.conv2d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      gen_nn_ops.conv2d_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()), None,
      gen_nn_ops.conv2d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")), None,
      gen_nn_ops.depthwise_conv2d_native(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])
  return [
      nn_ops.conv3d_backprop_input_v2(
          shape_0,
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d_backprop_filter_v2(
          op.inputs[0],
          shape_1,
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      None,
      nn_ops.conv3d_backprop_filter_v2(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format), None,
      nn_ops.conv3d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  return gen_nn_ops.avg_pool3d_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool3d(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool3d_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool3d_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  We assume that probs is of shape [batch_size, dim].
  The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
  This matrix is diagonal minus a rank-one matrix, so it is easy to implement
  as follows:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  Args:
    op: the Softmax op.
    grad_softmax: the tensor representing the gradient w.r.t. the softmax
      output.

  Returns:
    gradient w.r.t. the input to the softmax

  """
  softmax = op.outputs[0]
  sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True)
  return (grad_softmax - sum_channels) * softmax


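# --- Illustrative sketch (not part of the original module) -------------------
# The docstring of _SoftmaxGrad describes the Jacobian d softmax / dx as
# diag(softmax) - softmax * softmax'. A minimal sketch, assuming a 2-D
# [batch, dim] softmax output: materialize that Jacobian per row and apply it
# to grad_softmax explicitly. It should agree with _SoftmaxGrad; the helper
# name is hypothetical and exists only for illustration.
def _softmax_jacobian_vjp_sketch(softmax, grad_softmax):
  """Applies (diag(s) - s s^T) to grad_softmax row by row."""
  jacobian = (array_ops.matrix_diag(softmax) -
              math_ops.matmul(array_ops.expand_dims(softmax, -1),
                              array_ops.expand_dims(softmax, -2)))
  # The Jacobian is symmetric, so J @ grad is the vector-Jacobian product.
  return array_ops.squeeze(
      math_ops.matmul(jacobian, array_ops.expand_dims(grad_softmax, -1)), -1)
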

@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

    log_softmax = input - log(sum(exp(input)))
    dlog_softmax/dinput = diag - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  softmax = math_ops.exp(op.outputs[0])
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax


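# --- Illustrative note (not part of the original module) ---------------------
# With y = log_softmax(x), dy_i/dx_j = delta_ij - softmax(x)_j, so the
# vector-Jacobian product for an upstream gradient g is
#   g_j - (sum_i g_i) * softmax(x)_j,
# which is exactly the expression returned by _LogSoftmaxGrad above.
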

@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of the bias op is the tensor t, and its gradient is
  just the gradient the bias op received.

  The second input of the bias op is the bias vector which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed on the batch dimension, which is the first dimension.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  return (received_grad,
          gen_nn_ops.bias_add_grad(
              out_backprop=received_grad, data_format=data_format))



@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    expanded_shape = array_ops.concat([
        array_ops.ones_like(shape[:1]), bias_shape,
        array_ops.ones_like(shape[2:])
    ], 0)
    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
  else:
    expanded_shape = array_ops.concat(
        [array_ops.ones_like(shape[:-1]), bias_shape], 0)
    tile_mults = array_ops.concat([shape[:-1], [1]], 0)

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)


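# --- Illustrative note (not part of the original module) ---------------------
# Concrete shapes, assuming an NHWC input of shape [2, 3, 3, 4] and a bias
# gradient of shape [4]: _BiasAddGradGrad reshapes the incoming gradient to
# expanded_shape = [1, 1, 1, 4] and tiles it by tile_mults = [2, 3, 3, 1],
# i.e. the bias gradient is broadcast back over every data position. For NCHW
# with input shape [2, 4, 3, 3] the same gradient becomes [1, 4, 1, 1] tiled
# by [2, 1, 3, 3].
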

@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed on the batch dimension, which is the first dimension.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
  return (received_grad, math_ops.reduce_sum(received_grad,
                                             reduction_dim_tensor))



@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  return gen_nn_ops.relu_grad(grad, op.outputs[0])


@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, elu_x),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  selu_x = op.inputs[1]
  return (gen_nn_ops.selu_grad(grad, selu_x),
          array_ops.where(
              selu_x < 0., grad * op.inputs[0], array_ops.zeros_like(selu_x)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops.relu6_grad(grad, op.outputs[0])


@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros_like(x))


@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  x = op.inputs[0]
  alpha = op.get_attr("alpha")
  return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)


@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  x = op.inputs[1]
  alpha = op.get_attr("alpha")
  return (gen_nn_ops.leaky_relu_grad(grad, x,
                                     alpha=alpha), array_ops.zeros_like(x))


@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  return gen_nn_ops.elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops.selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return grad * math_ops.sigmoid(op.inputs[0])


@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  # Let:
  #   y = tf.nn.softplus(x)
  #   dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  dy, x = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
    return (ddy, d2x)


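# --- Illustrative sketch (not part of the original module) -------------------
# softplus_grad(dy, x) equals dy * sigmoid(x), so its derivative w.r.t. x is
# dy * sigmoid(x) * (1 - sigmoid(x)) = dy / (exp(-x) + 2 + exp(x)), which is
# the d2x term computed above. The helper below (hypothetical, illustration
# only) computes the same quantity through sigmoid directly.
def _softplus_grad_grad_via_sigmoid_sketch(grad, dy, x):
  sig = math_ops.sigmoid(x)
  return grad * dy * sig * (1.0 - sig)
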

@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  return gen_nn_ops.softsign_grad(grad, op.inputs[0])


@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros_like(x))


def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # Reshape vec to [D0, 1]
  vec = array_ops.expand_dims(vec, -1)
  return vec * mat


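# --- Illustrative note (not part of the original module) ---------------------
# _BroadcastMul turns a per-example scalar into a per-row scaling. For example,
# vec = [0.5, 2.0] (shape [2]) against mat = [[0.1, -0.1], [0.3, -0.3]]
# (shape [2, 2]) reshapes vec to [[0.5], [2.0]] and yields
# [[0.05, -0.05], [0.6, -0.6]]. This is how grad_loss scales each example's
# softmax gradient in the cross-entropy gradient functions below.
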

@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))  # pylint: disable=invalid-unary-operand-type


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  # There is no gradient for the labels
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, None


@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D."""
  dilations = op.get_attr("dilations")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  explicit_paddings = op.get_attr("explicit_paddings")
  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
  data_format = op.get_attr("data_format")
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we were
  # to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  return [
      gen_nn_ops.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format),
      gen_nn_ops.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format)
  ]


@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [
      nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                       op.get_attr("strides"),
                                       op.get_attr("rates"),
                                       op.get_attr("padding")),
      nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                        op.get_attr("strides"),
                                        op.get_attr("rates"),
                                        op.get_attr("padding"))
  ]


@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  depth_radius = op.get_attr("depth_radius")
  bias = op.get_attr("bias")
  alpha = op.get_attr("alpha")
  beta = op.get_attr("beta")
  return [
      gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias,
                          alpha, beta)
  ]


@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  return gen_nn_ops.avg_pool_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format")), None, None


@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  del unused_argmax_grad
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0],
      grad,
      op.outputs[1],
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))


@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  ksize = op.inputs[3]
  strides = op.inputs[4]
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad_grad_v2(
              op.inputs[0],
              op.inputs[1],
              grad,
              ksize,
              strides,
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")), None, None)


@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  Since FractionalMaxPool has three outputs, there are three gradients passed
  in, one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
      op.get_attr("overlapping"))


@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  Since FractionalAvgPool has three outputs, there are three gradients passed
  in, one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))


@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  We do not backprop anything for the mean and var intentionally as they are
  not being trained with backprop in the operation.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
  return dx, dm, dv, db, dg


def _BaseFusedBatchNormGrad(op, version, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    version: Integer indicating which version to use of the fused batch
      norm gradient.
    *grad: An argument list for tensors of gradients wrt the outputs
      with grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")
  if version == 2:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v3
  elif version == 1:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v2
  else:
    grad_fun = gen_nn_ops.fused_batch_norm_grad
  if is_training:
    args = {
        "y_backprop": grad_y,
        "x": x,
        "scale": scale,
        "reserve_space_1": op.outputs[3],
        "reserve_space_2": op.outputs[4],
        "epsilon": epsilon,
        "data_format": data_format,
        "is_training": is_training
    }
    if version == 2:
      args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**args)
  else:
    pop_mean = op.inputs[3]
    pop_var = op.inputs[4]
    if data_format == b"NCHW":
      x = array_ops.transpose(x, [0, 2, 3, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
    elif data_format == b"NCDHW":
      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
    target_data_format = ("NHWC" if data_format in (b"NCHW",
                                                    b"NHWC") else "NDHWC")
    args = {
        "y_backprop": grad_y,
        "x": x,
        "scale": scale,
        "reserve_space_1": pop_mean,
        "reserve_space_2": pop_var,
        "epsilon": epsilon,
        "data_format": target_data_format,
        "is_training": is_training
    }
    if version == 2:
      args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**args)
    if data_format == b"NCHW":
      dx = array_ops.transpose(dx, [0, 3, 1, 2])
    elif data_format == b"NCDHW":
      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
  return dx, dscale, doffset, None, None


@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 0, *grad)


@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 1, *grad)


@ops.RegisterGradient("FusedBatchNormV3")
def _FusedBatchNormV3Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 2, *grad)


def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
    x: A `Tensor` of 4 or 5 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16 or x_dtype == dtypes.bfloat16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      keepdims = False
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    else:
      keepdims = True
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(scale), 1, 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW" or data_format == b"NCDHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
    else:
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset


@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y = op.inputs[0]
  x = op.inputs[1]
  scale = op.inputs[2]
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  grad_grad_x = grad[0]
  grad_grad_scale = grad[1]
  grad_grad_offset = grad[2]
  with backprop.GradientTape() as tape:
    tape.watch(grad_y)
    tape.watch(x)
    tape.watch(scale)
    grad_x, grad_scale, grad_offset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
  return grad_grad_y, grad_x, grad_scale, None, None


@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  return _FusedBatchNormGradGrad(op, *grad)


@ops.RegisterGradient("FusedBatchNormGradV3")
def _FusedBatchNormGradGradV3(op, *grad):
  grad_grad_y, grad_x, grad_scale, _, _ = _FusedBatchNormGradGrad(op, *grad)
  return grad_grad_y, grad_x, grad_scale, None, None, None


@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  return op.inputs[0] * grad


@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU hence up-casting
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(
      op.outputs[1], array_ops_stack.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]


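# --- Illustrative note (not part of the original module) ---------------------
# Worked example of the linear-index computation above, for an input of shape
# [2, 4] and k = 2 with top-k indices [[1, 3], [0, 2]]: outerdim = 2,
# in_lastdim = 4, the per-row offsets are [[0], [4]], so the flattened indices
# become [1, 3, 4, 6]. scatter_nd then places the incoming gradients at those
# positions in a zero vector of length 8, which is reshaped back to [2, 4].
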

@ops.RegisterGradient("ApproxTopK")
def _ApproxTopKGradient(op, grad, _):
  """Return the gradients for ApproxTopK.

  Args:
    op: The ApproxTopK for which we need to generate gradients.
    grad: The gradients for backprop.

  Returns:
    Scattered gradient based on the top-k indices.
  """
  # The code below is to generate the correct index and value mapping for
  # scatter_nd to work properly.
  #
  # We use static evaluations as much as possible to reduce the runtime cost.
  # That said, use operation.shape instead of array_ops.shape;
  # and use functools.reduce(operator.mul, ...) instead of math_ops.reduce_prod
  idx_shape = op.outputs[1].shape
  lifted_idx_shape = idx_shape + [1]
  flat_shape_len = functools.reduce(operator.mul, idx_shape)
  rank = idx_shape.rank
  reduction_dim = op.get_attr("reduction_dimension")
  if reduction_dim < 0:
    reduction_dim = rank + reduction_dim

  def GetLiftedIdx(d):
    if d == reduction_dim:
      return array_ops.reshape(op.outputs[1], lifted_idx_shape)
    iota_len = idx_shape[d]
    iota_shape = list(itertools.repeat(1, rank + 1))
    iota_shape[d] = iota_len
    iota = array_ops.reshape(math_ops.range(iota_len), iota_shape)
    return array_ops.broadcast_to(iota, lifted_idx_shape)

  lifted_idx = array_ops.concat(
      list(GetLiftedIdx(d) for d in range(rank)), axis=rank)
  flat_idx = array_ops.reshape(lifted_idx, [flat_shape_len, rank])
  flat_grad = array_ops.reshape(grad, [flat_shape_len])
  return array_ops.scatter_nd(flat_idx, flat_grad, op.inputs[0].shape)


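# --- Illustrative note (not part of the original module) ---------------------
# Sketch of what GetLiftedIdx builds, for op.outputs[1] of shape [2, 3]
# (a batch of 2 with k = 3) and reduction_dimension = 1: GetLiftedIdx(0)
# broadcasts the row iota [[0], [1]] and GetLiftedIdx(1) is the top-k index
# tensor itself. Concatenating them along the last axis yields one [row, col]
# coordinate per retained value, exactly the N-D indices scatter_nd needs.
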

@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  input = op.inputs[0]  # pylint: disable=redefined-builtin
  output = op.outputs[0]

  # Compute the number of elements which are equal to output in each reduction
  # dimension. If there are multiple elements then the gradient will be
  # divided between them.
  indicators = math_ops.cast(
      math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)

  grad = array_ops.expand_dims(grad, -1)
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.divide(indicators, num_selected) * grad, None]


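# --- Illustrative note (not part of the original module) ---------------------
# Tie handling in _NthElementGrad, with a concrete row: for input [3., 7., 7.]
# and output 7., the indicators are [0., 1., 1.] and num_selected is 2, so
# each tied element receives grad / 2 and the remaining element receives 0.
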

def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Specifically, each value in the `inputs` tensor gets replaced by the mean
  value computed from the values that belong to the same segment.

  Args:
    inputs: A 2-tensor. Aggregation is done over dimension 1.
    segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """
  result = []
  for inputs_i, segments_i in zip(
      array_ops.split(inputs, inputs.shape[0]),
      array_ops.split(segments, segments.shape[0])):
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    means_i = math_ops.unsorted_segment_mean(
        inputs_i, segments_i, num_segments=math_ops.reduce_max(segments_i) + 1)
    result.append(
        array_ops.reshape(array_ops.gather(means_i, segments_i), [-1]))
  return array_ops_stack.stack(result, axis=0)


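# --- Illustrative note (not part of the original module) ---------------------
# Example of _MeanAggregator on a single row: inputs [1., 3., 5., 7.] with
# segments [0, 0, 1, 1] has segment means [2., 6.], so the row is replaced by
# [2., 2., 6., 6.].
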

# We have to register the gradients for these ops so that tensorflow will know
# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Gradient for the isotonic regression function.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments.

  Returns:
    A tensor, same size as `grad_output` with the gradient with respect to
    the input.
  """
  del grad_segments  # Discrete, non-differentiable.
  segments = op.outputs[1]
  return _MeanAggregator(grad_output, segments)
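

# --- Illustrative note (not part of the original module) ---------------------
# Why a mean: within each pooled segment the isotonic solution is the average
# of the pooled inputs, so every input in a segment contributes equally to
# that segment's output. Averaging `grad_output` over each segment (via
# _MeanAggregator above) therefore gives the gradient with respect to the
# inputs.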