Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/array_grad.py: 28%

614 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gradients for operators defined in array_ops.py.""" 

16 

17from tensorflow.compiler.tf2xla.ops import gen_xla_ops 

18from tensorflow.python import pywrap_tfe 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import constant_op 

21from tensorflow.python.framework import dtypes 

22from tensorflow.python.framework import indexed_slices as indexed_slices_lib 

23from tensorflow.python.framework import ops 

24from tensorflow.python.framework import sparse_tensor 

25from tensorflow.python.framework import tensor_shape 

26from tensorflow.python.framework import tensor_util 

27from tensorflow.python.ops import array_ops 

28from tensorflow.python.ops import array_ops_stack 

29from tensorflow.python.ops import cond 

30from tensorflow.python.ops import control_flow_util 

31from tensorflow.python.ops import gen_array_ops 

32from tensorflow.python.ops import gen_math_ops 

33from tensorflow.python.ops import gen_resource_variable_ops 

34from tensorflow.python.ops import math_ops 

35from tensorflow.python.ops import sparse_ops 

36 

37 

@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op: unstacks `grad` back into one piece per input.

  Args:
    op: The forward Pack operation; its "N" and "axis" attrs are read.
    grad: Gradient with respect to the op's output.

  Returns:
    The result of unstacking `grad` into "N" pieces along "axis".
  """
  num_inputs = op.get_attr("N")
  pack_axis = op.get_attr("axis")
  return array_ops_stack.unstack(grad, num=num_inputs, axis=pack_axis)

43 

44 

@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op: stacks the output gradients along "axis".

  Args:
    op: The forward Unpack operation; its "axis" attr is read.
    *grads: Gradients with respect to each output of the op.

  Returns:
    The result of stacking `grads` along the op's "axis" attr.
  """
  unpack_axis = op.get_attr("axis")
  return array_ops_stack.stack(grads, axis=unpack_axis)

49 

50 

def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
      the output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index one past the last value in the op.inputs
      (it is used as the exclusive end of a slice of op.inputs).
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list with one entry per op input: the partial gradients with respect to
    each concatenated value, and `None` in the position of the concat_dim/axis
    input.

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis is not
      statically known (or is negative while the rank of the first value is
      not statically known).
    TypeError: if more than one value was concatenated and `grad` is neither a
      `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation (a single value plus the dim input): the gradient
  # of the value is `grad` itself. `grad` is a single Tensor/IndexedSlices, so
  # it must be wrapped in a list rather than added to one.
  if len(op.inputs) == 2:
    return [grad, None] if end_value_index <= dim_index else [None, grad]

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops_stack.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, indexed_slices_lib.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  # The dim input gets no gradient; place `None` on whichever side it sits.
  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)

215 

216 

@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for Concat: input 0 is the concat dim, the rest are values."""
  num_inputs = len(op.inputs)
  return _ConcatGradHelper(
      op, grad, start_value_index=1, end_value_index=num_inputs, dim_index=0)

225 

226 

@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for ConcatV2: the last input is the axis, the rest are values."""
  index_args = dict(start_value_index=0, end_value_index=-1, dim_index=-1)
  return _ConcatGradHelper(op, grad, **index_args)

231 

232 

# Register the "ConcatOffset" op type as not differentiable.
ops.NotDifferentiable("ConcatOffset")

234 

235 

@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op.

  Outside of an XLA context the gradient is `grad` zero-padded back to the
  shape of the sliced input. The padding is an Nx2 tensor: column 0 is the
  number of zeros to prepend per dimension (the `begin` vector), column 1 the
  number to append (input shape minus slice size minus `begin`).

  Args:
    op: The forward Slice operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: the gradient for input 0 and `None` for the other two inputs.
  """
  sliced_input, begin = op.inputs[0], op.inputs[1]
  rank = array_ops.rank(sliced_input)
  idx_dtype = begin.dtype
  out_extent = array_ops.shape(op.outputs[0], out_type=idx_dtype)

  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # In an XLA context, write `grad` into a zero tensor at `begin` instead of
    # building a padding tensor.
    zeros = array_ops.zeros_like(sliced_input)
    updated = gen_xla_ops.xla_dynamic_update_slice(zeros, grad, begin)
    return updated, None, None

  # Both padding columns are reshaped to [rank, 1] so they can be joined
  # side by side.
  column_shape = array_ops_stack.stack([rank, 1])
  leading = array_ops.reshape(begin, column_shape)
  full_extent = array_ops.shape(sliced_input, out_type=idx_dtype)
  trailing = array_ops.reshape(full_extent - out_extent - begin, column_shape)
  pad_spec = array_ops.concat([leading, trailing], 1)
  return array_ops.pad(grad, pad_spec), None, None

264 

265 

@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op.

  Args:
    op: The forward StridedSlice operation. Inputs 1-3 are used as `begin`,
      `end` and `strides`; the mask attrs are forwarded unchanged.
    grad: Gradient with respect to the op's output.

  Returns:
    A 4-tuple: the result of `array_ops.strided_slice_grad` for input 0 and
    `None` for the `begin`, `end` and `strides` inputs.
  """

  def _StaticIfKnown(tensor):
    """Returns the constant value of `tensor` when available, else `tensor`."""
    static_value = tensor_util.constant_value(tensor)
    return tensor if static_value is None else static_value

  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]
  # StridedSliceGrad requires the shape, `begin`, `end` and `strides` to be of
  # the same dtype, so the shape is built with the dtype of the other args.
  # Picking `begin` for `out_type` is arbitrary: begin, end and strides are
  # required to share a dtype.
  input_shape = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  # Prefer statically-known values wherever they can be resolved.
  input_shape = _StaticIfKnown(input_shape)
  begin = _StaticIfKnown(begin)
  end = _StaticIfKnown(end)
  strides = _StaticIfKnown(strides)

  input_grad = array_ops.strided_slice_grad(
      input_shape,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))
  return input_grad, None, None, None

299 

300 

@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Args:
    op: The forward StridedSliceGrad operation. Inputs 1-3 are used as
      `begin`, `end` and `strides`; the mask attrs are forwarded unchanged.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: `None` for the first four inputs, and `grad` strided-sliced
    with the op's own slicing arguments for the fifth.
  """
  mask_names = ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                "shrink_axis_mask")
  masks = {name: op.get_attr(name) for name in mask_names}
  sliced_grad = array_ops.strided_slice(
      grad, op.inputs[1], op.inputs[2], op.inputs[3], **masks)
  return None, None, None, None, sliced_grad

318 

319 

@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):
  """Gradient for TensorStridedSliceUpdate op.

  Args:
    op: The forward operation. Inputs 1-3 are used as `begin`, `end` and
      `strides`, input 4 as the update value; mask attrs are forwarded.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: the gradient for input 0, `None` for `begin`, `end` and
    `strides`, and the gradient for the update value.
  """
  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]
  mask_kwargs = {
      "begin_mask": op.get_attr("begin_mask"),
      "end_mask": op.get_attr("end_mask"),
      "ellipsis_mask": op.get_attr("ellipsis_mask"),
      "new_axis_mask": op.get_attr("new_axis_mask"),
      "shrink_axis_mask": op.get_attr("shrink_axis_mask"),
  }

  # Portion of `grad` that falls inside the updated strided slice.
  value_grad = array_ops.strided_slice(grad, begin, end, strides, **mask_kwargs)
  # `grad` with the updated slice replaced by zeros.
  input_grad = array_ops.tensor_strided_slice_update(
      grad, begin, end, strides, array_ops.zeros_like(value_grad),
      **mask_kwargs)

  # The value is potentially broadcast to the shape of the strided slice, so
  # the slice gradient may need to be summed back down to the value's shape.
  slice_shape = array_ops.shape(value_grad, out_type=begin.dtype)
  value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)

  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      slice_shape, value_shape)
  summed = math_ops.reduce_sum(value_grad, axis=reduction_axes, keepdims=True)
  value_grad = array_ops.reshape(summed, value_shape)

  return input_grad, None, None, None, value_grad

353 

354 

@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for Split op.

  Args:
    op: The forward Split operation; input 0 is used as the concat axis.
    *grads: Gradients with respect to each output of the op.

  Returns:
    A 2-tuple: `None` for input 0, and `grads` concatenated along input 0.
  """
  split_axis = op.inputs[0]
  value_grad = array_ops.concat(list(grads), split_axis)
  return None, value_grad

358 

359 

@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for SplitV op.

  Args:
    op: The forward SplitV operation; input 2 is used as the concat axis.
    *grads: Gradients with respect to each output of the op.

  Returns:
    A list with `grads` concatenated along input 2 first, followed by `None`
    for every remaining input of the op.
  """
  value_grad = array_ops.concat(list(grads), op.inputs[2])
  remaining_inputs = len(op.inputs) - 1
  return [value_grad] + [None] * remaining_inputs

368 

369 

# Register the "Const" op type as not differentiable.
ops.NotDifferentiable("Const")

371 

372 

@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for Diag op: the diagonal part of `grad`."""
  diagonal_grad = array_ops.diag_part(grad)
  return diagonal_grad

376 

377 

@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for DiagPart op: `grad` passed through `array_ops.diag`."""
  full_grad = array_ops.diag(grad)
  return full_grad

381 

382 

@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for MatrixDiag op: the matrix diagonal part of `grad`."""
  diagonal_grad = array_ops.matrix_diag_part(grad)
  return diagonal_grad

386 

387 

@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  """Gradient for MatrixDiagV2 op.

  Args:
    op: The forward operation; input 1 is forwarded as `k`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: the matrix diagonal part of `grad` for input 0 and `None` for
    the other four inputs.
  """
  diag_index = op.inputs[1]
  diagonal_grad = array_ops.matrix_diag_part(grad, k=diag_index)
  return diagonal_grad, None, None, None, None

392 

393 

@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  """Gradient for MatrixDiagV3 op.

  Args:
    op: The forward operation; input 1 is forwarded as `k` and the "align"
      attr as `align`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: the matrix diagonal part of `grad` for input 0 and `None` for
    the other four inputs.
  """
  diag_index = op.inputs[1]
  alignment = op.get_attr("align")
  diagonal_grad = array_ops.matrix_diag_part(
      grad, k=diag_index, align=alignment)
  return diagonal_grad, None, None, None, None

398 

399 

@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for MatrixDiagPart op.

  Args:
    op: The forward operation; input 0 is the source matrix tensor.
    grad: Gradient with respect to the op's output.

  Returns:
    `array_ops.matrix_diag(grad)` when the last two static dimensions of
    input 0 are fully defined and equal; otherwise `grad` written onto the
    diagonal of a zero tensor shaped like input 0.
  """
  source = op.inputs[0]
  trailing_dims = source.get_shape()[-2:]
  is_static_square = (
      trailing_dims.is_fully_defined() and
      trailing_dims[0] == trailing_dims[1])
  if not is_static_square:
    return array_ops.matrix_set_diag(array_ops.zeros_like(source), grad)
  return array_ops.matrix_diag(grad)

407 

408 

@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2.

  Args:
    op: The forward operation; input 0 is the source tensor and input 1 is
      forwarded as `k`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: the gradient for input 0 and `None` for the other two inputs.
  """
  source, diag_index = op.inputs[0], op.inputs[1]
  trailing_dims = source.get_shape()[-2:]
  if not trailing_dims.is_fully_defined():
    # Matrix size is not statically known: write `grad` onto the requested
    # diagonals of a zero tensor shaped like the source.
    scattered = array_ops.matrix_set_diag(
        array_ops.zeros_like(source), grad, k=diag_index)
    return scattered, None, None
  rebuilt = array_ops.matrix_diag(
      grad,
      k=diag_index,
      num_rows=trailing_dims[0],
      num_cols=trailing_dims[1])
  return rebuilt, None, None

422 

423 

@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3.

  Args:
    op: The forward operation; input 0 is the source tensor, input 1 is
      forwarded as `k` and the "align" attr as `align`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: the gradient for input 0 and `None` for the other two inputs.
  """
  source, diag_index = op.inputs[0], op.inputs[1]
  trailing_dims = source.get_shape()[-2:]
  alignment = op.get_attr("align")
  if not trailing_dims.is_fully_defined():
    # Matrix size is not statically known: write `grad` onto the requested
    # diagonals of a zero tensor shaped like the source.
    scattered = array_ops.matrix_set_diag(
        array_ops.zeros_like(source), grad, k=diag_index, align=alignment)
    return scattered, None, None
  rebuilt = array_ops.matrix_diag(
      grad,
      k=diag_index,
      num_rows=trailing_dims[0],
      num_cols=trailing_dims[1],
      align=alignment)
  return rebuilt, None, None

440 

441 

@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Args:
    op: The forward operation; input 0 is the matrix tensor and input 1 the
      diagonal that was written into it.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-tuple: `grad` with its diagonal zeroed (gradient for input 0) and the
    diagonal part of `grad` (gradient for input 1).
  """
  merged_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  static_diag_shape = op.inputs[1].get_shape()
  static_batch = merged_shape[:-2].merge_with(static_diag_shape[:-1])
  static_matrix = merged_shape[-2:]

  if static_batch.is_fully_defined() and static_matrix.is_fully_defined():
    # Everything is known statically: build the zero-diagonal shape in Python.
    zeros_shape = static_batch.as_list() + [min(static_matrix.as_list())]
  else:
    # Otherwise derive the same shape at runtime from `grad`.
    with ops.colocate_with(grad):
      dyn_shape = array_ops.shape(grad)
      dyn_rank = array_ops.rank(grad)
      dyn_batch = array_ops.slice(dyn_shape, [0], [dyn_rank - 2])
      dyn_matrix = array_ops.slice(dyn_shape, [dyn_rank - 2], [2])
      shortest_side = math_ops.reduce_min(dyn_matrix)
      zeros_shape = array_ops.concat([dyn_batch, [shortest_side]], 0)

  zero_diag = array_ops.zeros(zeros_shape, dtype=grad.dtype)
  grad_input = array_ops.matrix_set_diag(grad, zero_diag)
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)

463 

464 

@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2.

  Args:
    op: The forward operation; input 1 is the diagonal and input 2 is
      forwarded as `k`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: `grad` with the selected diagonals zeroed, the selected
    diagonals of `grad`, and `None` for input 2.
  """
  band = op.inputs[2]
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    runtime_shape = array_ops.shape(grad)
    batch_part = runtime_shape[:-2]
    rows_cols = runtime_shape[-2:]
    flat_band = array_ops.reshape(band, [-1])  # Converts to vector.
    d_lower = flat_band[0]
    d_upper = flat_band[-1]  # Works both when len(flat_band) is 1 and 2.
    y_offset = cond.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = cond.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(rows_cols[0] + y_offset,
                                    rows_cols[1] + x_offset)

    def _SingleDiagonalSuffix():
      return ops.convert_to_tensor([max_diag_len])

    def _BandSuffix():
      return ops.convert_to_tensor([d_upper - d_lower + 1, max_diag_len])

    suffix = cond.cond(
        math_ops.equal(d_lower, d_upper), _SingleDiagonalSuffix, _BandSuffix)
    diag_shape = array_ops.concat([batch_part, suffix], 0)

  zero_diag = array_ops.zeros(diag_shape, dtype=grad.dtype)
  grad_input = array_ops.matrix_set_diag(grad, zero_diag, k=band)
  grad_diag = array_ops.matrix_diag_part(grad, k=band)
  return (grad_input, grad_diag, None)

499 

500 

@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3.

  Args:
    op: The forward operation; input 1 is the diagonal, input 2 is forwarded
      as `k` and the "align" attr as `align`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: `grad` with the selected diagonals zeroed, the selected
    diagonals of `grad`, and `None` for input 2.
  """
  band = op.inputs[2]
  diag_shape = op.inputs[1].get_shape()
  alignment = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    runtime_shape = array_ops.shape(grad)
    batch_part = runtime_shape[:-2]
    rows_cols = runtime_shape[-2:]
    flat_band = array_ops.reshape(band, [-1])  # Converts to vector.
    d_lower = flat_band[0]
    d_upper = flat_band[-1]  # Works both when len(flat_band) is 1 and 2.
    y_offset = cond.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = cond.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(rows_cols[0] + y_offset,
                                    rows_cols[1] + x_offset)

    def _SingleDiagonalSuffix():
      return ops.convert_to_tensor([max_diag_len])

    def _BandSuffix():
      return ops.convert_to_tensor([d_upper - d_lower + 1, max_diag_len])

    suffix = cond.cond(
        math_ops.equal(d_lower, d_upper), _SingleDiagonalSuffix, _BandSuffix)
    diag_shape = array_ops.concat([batch_part, suffix], 0)

  zero_diag = array_ops.zeros(diag_shape, dtype=grad.dtype)
  grad_input = array_ops.matrix_set_diag(
      grad, zero_diag, k=band, align=alignment)
  grad_diag = array_ops.matrix_diag_part(grad, k=band, align=alignment)
  return (grad_input, grad_diag, None)

539 

540 

@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for MatrixBandPart op.

  Args:
    op: The forward operation; inputs 1 and 2 are the band limits.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: `grad` restricted to the same band, and `None` for the two
    band-limit inputs.
  """
  lower_limit, upper_limit = op.inputs[1], op.inputs[2]
  banded_grad = array_ops.matrix_band_part(grad, lower_limit, upper_limit)
  return (banded_grad, None, None)

546 

547 

# Edit Distance has no gradient (but can be used to eval seq2seq or CTC), so
# the "EditDistance" op type is registered as not differentiable.
ops.NotDifferentiable("EditDistance")

550 

551 

@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for Fill op: `None` for input 0, sum of `grad` for input 1."""
  value_grad = math_ops.reduce_sum(grad)
  return None, value_grad

555 

556 

# Register the "ZerosLike" and "OnesLike" op types as not differentiable.
ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")

559 

560 

@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Always raises: a PreventGradient op must not be differentiated through.

  Args:
    op: The forward operation; its "message" attr is included in the error.
    _: Unused gradient.

  Raises:
    LookupError: always.
  """
  reason = op.get_attr("message")
  raise LookupError("Gradient explicitly disabled. Reason: %s" % reason)

565 

566 

def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings.

  Args:
    indexed_slices: An `IndexedSlices`, or any other value (expected to be a
      tensor already), which is returned unchanged.

  Returns:
    For an `IndexedSlices`, the unsorted segment sum of its values by its
    indices, with `dense_shape[0]` segments; otherwise the argument itself.

  Raises:
    ValueError: if the `IndexedSlices` has no `dense_shape`.
  """
  if isinstance(indexed_slices, indexed_slices_lib.IndexedSlices):
    dense_shape = indexed_slices.dense_shape
    if dense_shape is None:
      raise ValueError(
          "Tensor conversion requested for IndexedSlices without dense_shape: %s"
          % str(indexed_slices))
    return math_ops.unsorted_segment_sum(indexed_slices.values,
                                         indexed_slices.indices,
                                         dense_shape[0])
  # If it is not IndexedSlices, it had better be a tensor already.
  return indexed_slices

579 

580 

@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op.

  Args:
    op: The forward operation; input 0 is `params`, input 1 is `indices`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-element list: an `IndexedSlices` gradient for `params` and `None` for
    `indices`.
  """
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Flatten the indices and reshape the gradient to one row per index so the
  # two line up as the indices/values of an IndexedSlices.
  gathered_at = op.inputs[1]
  flat_count = array_ops.expand_dims(array_ops.size(gathered_at), 0)
  row_shape = array_ops.concat([flat_count, params_shape[1:]], 0)
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  rows = array_ops.reshape(dense_grad, row_shape)
  flat_indices = array_ops.reshape(gathered_at, flat_count)
  params_grad = indexed_slices_lib.IndexedSlices(
      rows, flat_indices, params_shape)
  return [params_grad, None]

597 

598 

def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results.

  Args:
    params_shape: Shape tensor of the params; cast to the dtype of `indices`.
    indices: The gather indices.
    batch_dims: Python integer number of leading batch dimensions.

  Returns:
    `indices` plus, for every batch dimension, an offset that grows with the
    position along that dimension.
  """
  offset_indices = indices
  index_dtype = indices.dtype.base_dtype
  shape_as_index = math_ops.cast(params_shape, index_dtype)
  stride = array_ops.ones((), dtype=index_dtype)
  # Walk the batch dimensions from innermost (batch_dims) to outermost (1).
  for dim in reversed(range(1, batch_dims + 1)):
    dim_extent = shape_as_index[dim - 1]
    stride *= shape_as_index[dim]
    range_start = array_ops.zeros((), dtype=index_dtype)
    range_step = array_ops.ones((), dtype=index_dtype)
    dim_offsets = math_ops.range(range_start, dim_extent, range_step)
    dim_offsets *= stride
    # Shape that is 1 everywhere except at position dim - 1, so the offsets
    # broadcast against `indices`.
    ones_before = array_ops.tile([1], [dim - 1])
    ones_after = array_ops.tile([1], [array_ops.rank(indices) - dim])
    broadcast_shape = array_ops.concat(
        [ones_before, [dim_extent], ones_after], axis=0)
    offset_indices += array_ops.reshape(dim_offsets, broadcast_shape)

  return offset_indices

619 

620 

def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions.

  Args:
    params_shape: Shape tensor of the params, passed on to `_GetBatchIndices`.
    values: The incoming gradient values.
    indices: The gather indices.
    batch_dims: Python integer number of leading batch dimensions.
    gather_dim_size: Size of the gathered dimension of the params.

  Returns:
    The gradient with respect to the params, computed as an unsorted segment
    sum of `values` by (batch-offset) `indices`.
  """
  # Axis is the first non-batch dimension.
  flat_index_count = array_ops.expand_dims(array_ops.size(indices), 0)
  has_batch = bool(batch_dims)

  if has_batch:
    grad_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = grad_shape[:batch_dims]
    inner_shape = grad_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, flat_index_count)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if has_batch:
    # Put back the batch dimensions.
    restored_shape = array_ops.concat([outer_shape, flat_values_shape], 0)
    params_grad = array_ops.reshape(params_grad, restored_shape)

  return params_grad

649 

650 

@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op.

  Args:
    op: The forward operation; inputs are `params`, `indices` and `axis`, and
      the "batch_dims" attr is read.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-element list: the gradient for `params` (an `IndexedSlices` when the
    axis is statically 0, otherwise a dense tensor) and `None` for `indices`
    and `axis`.

  Raises:
    ValueError: if batch_dims is negative and the rank of `indices` is
      unknown.
  """
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    # A negative batch_dims counts from the end of the indices' rank.
    indices_rank = indices.shape.ndims
    if indices_rank is None:
      raise ValueError(
          f"Currently, it is unsupported to take the gradient of tf.gather "
          f"when batch_dims < 0 and the rank of the indices is unknown. Please "
          f"pass a positive batch_dims or use tf.ensure_shape to update the "
          f"shape of indices when calling tf.gather. Got "
          f"batch_dims={batch_dims} and indices={indices}")
    batch_dims += indices_rank

  if axis_static == 0:
    # For axis 0 gathers, build an appropriately shaped IndexedSlices.
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = indexed_slices_lib.IndexedSlices(values, indices,
                                                   params_shape)
    return [params_grad, None, None]

  # Handle axis by transposing the axis dimension to be the first non-batch
  # dimension, compute the gradient and transpose the result back.
  outer_shape = params_shape[:axis]
  inner_shape = params_shape[axis:][1:]
  values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

  values_dims = array_ops.size(values_shape)
  axis_dims = array_ops.size(outer_shape)

  outer_batches_indices = math_ops.range(batch_dims)
  batch_axis_indices = math_ops.range(batch_dims, axis_dims)
  inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)

  # Move values[axis] up to values[batch_dims]
  forward_perm = array_ops.concat([
      outer_batches_indices, [axis_dims], batch_axis_indices,
      inner_axes_indices
  ], 0)
  values_transpose = array_ops.transpose(values, forward_perm)
  params_shape_transpose = array_ops.gather(params_shape, forward_perm)

  params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                 indices, batch_dims, params_shape[axis])

  # Inverts the above transpose by moving dimension batch_dims back to its
  # original position.
  inverse_perm = array_ops.concat([
      outer_batches_indices, batch_axis_indices + 1, [batch_dims],
      inner_axes_indices
  ], 0)
  params_grad = array_ops.transpose(params_grad, inverse_perm)

  return [params_grad, None, None]

731 

732 

@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for GatherNd op.

  Args:
    op: The forward operation; input 0 is the gathered-from tensor, input 1
      the indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-element list: the gradient for input 0 and `None` for the indices.
    The gradient is an `IndexedSlices` when the indices are statically rank 2
    with a last dimension of 1, otherwise the result of `scatter_nd`.
  """
  ref, indices = op.inputs[0], op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  static_shape = indices.shape
  if static_shape.ndims == 2 and static_shape.dims[-1].value == 1:
    flat_indices = array_ops.squeeze(indices, axis=-1)
    return [
        indexed_slices_lib.IndexedSlices(grad, flat_indices, ref_shape), None
    ]
  return [array_ops.scatter_nd(indices, grad, ref_shape), None]

744 

745 

@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):
  """Gradient for ResourceGatherNd op.

  Args:
    op: The forward operation; input 0 is the resource whose shape is read
      with `variable_shape`, input 1 the indices.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-element list: the gradient for input 0 and `None` for the indices.
    The gradient is an `IndexedSlices` when the indices are statically rank 2
    with a last dimension of 1, otherwise the result of `scatter_nd`.
  """
  ref, indices = op.inputs[0], op.inputs[1]
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  static_shape = indices.shape
  if static_shape.ndims == 2 and static_shape.dims[-1].value == 1:
    flat_indices = array_ops.squeeze(indices, axis=-1)
    return [
        indexed_slices_lib.IndexedSlices(grad, flat_indices, ref_shape), None
    ]
  return [array_ops.scatter_nd(indices, grad, ref_shape), None]

757 

758 

@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op.

  Args:
    op: The forward operation; its "message" attr is appended to the message
      used for the gradient check.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` wrapped in `array_ops.check_numerics`.
  """
  forward_message = op.get_attr("message")
  grad_message = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      forward_message)
  return array_ops.check_numerics(grad, grad_message)

766 

767 

@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics_v2 op.

  Args:
    op: The forward operation; its "message" attr is appended to the message
      used for the gradient check.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` wrapped in `array_ops.check_numerics_v2`.
  """
  forward_message = op.get_attr("message")
  grad_message = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      forward_message)
  return array_ops.check_numerics_v2(grad, grad_message)

775 

776 

@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Gradient for PlaceholderWithDefault and Identity: `grad` unchanged."""
  passthrough = grad
  return passthrough

781 

782 

@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  """Always raises: `_EagerConst` should never reach the gradient APIs.

  Raises:
    AssertionError: always.
  """
  message = (
      "This op should never interact with gradient APIs. Please file a bug.")
  raise AssertionError(message)

787 

788 

@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Passes the incoming gradient through unchanged for "RefIdentity"."""
  return grad

792 

793 

@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  """Passes every incoming gradient through unchanged for "IdentityN".

  Returns:
    The tuple of gradients exactly as received via `*grad`.
  """
  return grad

797 

798 

# Mark the "StopGradient" op type as not differentiable.
ops.NotDifferentiable("StopGradient")

800 

801 

@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Gradient for Reshape: reshapes `grad` back to the shape of input 0.

  Args:
    op: The forward Reshape op.
    grad: Gradient with respect to the op's output. It is first passed through
      _IndexedSlicesToTensorNoWarning (defined elsewhere in this module).

  Returns:
    A two-element list: the reshaped gradient, and None for the shape input.
  """
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  original_shape = array_ops.shape(op.inputs[0])
  return [array_ops.reshape(dense_grad, original_shape), None]

809 

810 

# Mark the "InvertPermutation" op type as not differentiable.
ops.NotDifferentiable("InvertPermutation")

812 

813 

def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input.

  Args:
    op: The forward op; only the dynamic shape of its input 0 is used.
    grad: Gradient to reshape. It is first passed through
      _IndexedSlicesToTensorNoWarning (defined elsewhere in this module).

  Returns:
    `grad` reshaped to the shape of `op.inputs[0]`.
  """
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  target_shape = array_ops.shape(op.inputs[0])
  return array_ops.reshape(dense_grad, target_shape)

818 

819 

@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Gradient for ExpandDims: reshape `grad` to the shape of input 0.

  Returns:
    A two-element list: the reshaped gradient, and None for the axis input.
  """
  input_grad = _ReshapeToInput(op, grad)
  return [input_grad, None]

823 

824 

@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Gradient for Squeeze: reshape `grad` to the shape of input 0."""
  input_grad = _ReshapeToInput(op, grad)
  return input_grad

828 

829 

@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad).

  The gradient is transposed with the inverse of the forward permutation
  (input 1). The permutation input itself receives None.
  """
  inverse_perm = array_ops.invert_permutation(op.inputs[1])
  return [array_ops.transpose(grad, inverse_perm), None]

835 

836 

@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad)).

  The gradient is transposed with the inverse of the forward permutation
  (input 1) and `conjugate=True`. The permutation input receives None.
  """
  inverse_perm = array_ops.invert_permutation(op.inputs[1])
  input_grad = array_ops.transpose(grad, inverse_perm, conjugate=True)
  return [input_grad, None]

845 

846 

# Mark the shape-querying op types "Shape", "ShapeN", "Rank" and "Size" as
# not differentiable.
ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")

854 

855 

@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions.

  Args:
    op: The forward Tile op; input 0 is the tiled tensor and input 1 holds the
      multiples.
    grad: Gradient with respect to the op's output; a dense tensor or
      IndexedSlices.

  Returns:
    A two-element list: the gradient for input 0, and None for the multiples.
  """
  tiled_input, multiples = op.inputs[0], op.inputs[1]
  input_shape = array_ops.shape(tiled_input, out_type=multiples.dtype)
  # Interleaving multiples with input_shape yields split_shape. Reshaping grad
  # to split_shape and summing over every even axis (the tiled ones) leaves a
  # result of shape input_shape. For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  interleaved = array_ops.transpose(
      array_ops_stack.stack([multiples, input_shape]))
  split_shape = array_ops.reshape(interleaved, [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  if isinstance(grad, indexed_slices_lib.IndexedSlices):
    # For IndexedSlices the first dimension is reduced here: row indices are
    # folded modulo the input's leading size and summed per resulting row.
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    folded_rows = math_ops.mod(grad.indices, input_shape_0)
    grad = math_ops.unsorted_segment_sum(
        grad.values, folded_rows, input_shape_0)
    # The leading tile factor has already been collapsed, so it becomes 1.
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  if not context.executing_eagerly():
    # Fix shape inference.
    input_grad.set_shape(tiled_input.get_shape())
  return [input_grad, None]

883 

884 

# Mark the "BroadcastGradientArgs" op type as not differentiable.
ops.NotDifferentiable("BroadcastGradientArgs")

886 

887 

def _PadGrad(op, grad):
  """Gradient for Pad.

  Pad introduces values around the original tensor, so the gradient slices
  the original shape back out of `grad`.

  Args:
    op: The forward op; input 0 is the padded tensor and input 1 is the
      paddings matrix. A third input may be present.
    grad: Gradient with respect to the op's output.

  Returns:
    A tuple holding the gradient for input 0 followed by None for each
    remaining input (two Nones when the op has three inputs, otherwise one).
  """
  x = op.inputs[0]
  paddings = op.inputs[1]  # [Rank(x), 2]
  # Take the first column of the paddings matrix: [Rank(x), 1].
  first_column_size = array_ops_stack.stack([array_ops.rank(x), 1])
  pad_before = array_ops.slice(paddings, [0, 0], first_column_size)
  # Flatten to 1-D so it can serve as the slice start offsets.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  return x_grad, None


# The same gradient function serves both "Pad" and "PadV2".
ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)

909 

910 

# ReverseSequence is just a permutation. The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Gradient for ReverseSequence.

  Applies reverse_sequence to `grad` with the forward op's "batch_dim" and
  "seq_dim" attrs and its sequence lengths (input 1).

  Returns:
    A two-element list: the gradient for input 0, and None for the lengths.
  """
  seq_lengths = op.inputs[1]
  input_grad = array_ops.reverse_sequence(
      grad,
      batch_axis=op.get_attr("batch_dim"),
      seq_axis=op.get_attr("seq_dim"),
      seq_lengths=seq_lengths)
  return [input_grad, None]

922 

923 

@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Gradient for Reverse: reverses `grad` using the same dims input.

  Returns:
    A tuple: the gradient for input 0, and None for the dims input.
  """
  dims_to_flip = op.inputs[1]
  input_grad = gen_array_ops.reverse(grad, dims_to_flip)
  return input_grad, None

928 

929 

@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Gradient for ReverseV2: reverses `grad` along the same axis input.

  Returns:
    A tuple: the gradient for input 0, and None for the axis input.
  """
  flip_axis = op.inputs[1]
  input_grad = array_ops.reverse_v2(grad, flip_axis)
  return input_grad, None

934 

935 

@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  """Gradient for SpaceToBatch, computed with the opposite op BatchToSpace.

  Returns:
    A two-element list: the gradient for input 0, and None for input 1.
  """
  block_size = op.get_attr("block_size")
  second_input = op.inputs[1]
  input_grad = array_ops.batch_to_space(
      grad, second_input, block_size=block_size)
  return [input_grad, None]

943 

944 

@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  """Gradient for SpaceToBatchND, computed with the opposite op BatchToSpaceND.

  Returns:
    A three-element list: the gradient for input 0, and None for inputs 1
    and 2.
  """
  second_input, third_input = op.inputs[1], op.inputs[2]
  input_grad = array_ops.batch_to_space_nd(grad, second_input, third_input)
  return [input_grad, None, None]

951 

952 

@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  """Gradient for BatchToSpace, computed with the opposite op SpaceToBatch.

  Returns:
    A two-element list: the gradient for input 0, and None for input 1.
  """
  block_size = op.get_attr("block_size")
  second_input = op.inputs[1]
  input_grad = array_ops.space_to_batch(
      grad, second_input, block_size=block_size)
  return [input_grad, None]

960 

961 

@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  """Gradient for BatchToSpaceND, computed with the opposite op SpaceToBatchND.

  Returns:
    A three-element list: the gradient for input 0, and None for inputs 1
    and 2.
  """
  second_input, third_input = op.inputs[1], op.inputs[2]
  input_grad = array_ops.space_to_batch_nd(grad, second_input, third_input)
  return [input_grad, None, None]

968 

969 

@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  """Gradient for SpaceToDepth, computed with the opposite op DepthToSpace.

  Args:
    op: The forward SpaceToDepth op; its "block_size" and "data_format" attrs
      are forwarded to depth_to_space.
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.

  Raises:
    ValueError: If the op's data_format is NCHW_VECT_C.
  """
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # String attrs can come back from get_attr as bytes, in which case a plain
  # comparison with the str literal is always False and the guard would never
  # trigger. Accept both spellings.
  if data_format in ("NCHW_VECT_C", b"NCHW_VECT_C"):
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)

979 

980 

@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  """Gradient for DepthToSpace, computed with the opposite op SpaceToDepth.

  Args:
    op: The forward DepthToSpace op; its "block_size" and "data_format" attrs
      are forwarded to space_to_depth.
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.

  Raises:
    ValueError: If the op's data_format is NCHW_VECT_C.
  """
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # String attrs can come back from get_attr as bytes, in which case a plain
  # comparison with the str literal is always False and the guard would never
  # trigger. Accept both spellings.
  if data_format in ("NCHW_VECT_C", b"NCHW_VECT_C"):
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)

990 

991 

# Mark the "OneHot" op type as not differentiable.
ops.NotDifferentiable("OneHot")

993 

994 

@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for MirrorPad, computed with the mirror_pad_grad op.

  Returns:
    A two-element list: the gradient for input 0, and None for the paddings.
  """
  pad_mode = op.get_attr("mode")
  paddings = op.inputs[1]
  input_grad = gen_array_ops.mirror_pad_grad(grad, paddings, mode=pad_mode)
  return [input_grad, None]

999 

1000 

@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  """Gradient for MirrorPadGrad, computed with the mirror_pad op.

  Returns:
    A two-element list: the gradient for input 0, and None for the paddings.
  """
  pad_mode = op.get_attr("mode")
  paddings = op.inputs[1]
  input_grad = gen_array_ops.mirror_pad(grad, paddings, mode=pad_mode)
  return [input_grad, None]

1005 

1006 

@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  """Passes the incoming gradient through unchanged."""
  return grad

1010 

1011 

@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  """Passes the gradient to the first input only.

  Returns:
    A three-element list: `grad` for input 0 and None for the other two
    inputs.
  """
  input_grads = [grad, None, None]
  return input_grads

1015 

1016 

@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  """Passes the gradient to the first input only.

  Returns:
    A four-element list: `grad` for input 0 and None for the other three
    inputs.
  """
  # Only propagate the gradient for the unquantized input.
  input_grads = [grad, None, None, None]
  return input_grads

1021 

1022 

@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  """Gradient for ExtractImagePatches.

  Each input pixel position is given a 1-based id, and the id image is run
  through extract_image_patches with the forward op's attrs to learn which
  input position feeds each patch element. The incoming gradient is then
  accumulated per input position with unsorted_segment_sum.

  Args:
    op: The forward ExtractImagePatches op.
    grad: Gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient for input 0.
  """
  in_shape = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = array_ops_stack.unstack(in_shape)

  out_shape = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = array_ops_stack.unstack(out_shape[1:3])

  ksizes = op.get_attr("ksizes")
  _, ksize_r, ksize_c, _ = ksizes

  # Build the id image for the input. Id 0 is reserved for padding locations,
  # so real positions are numbered 1 .. rows_in * cols_in.
  num_positions = rows_in * cols_in
  # XLA version of extract_image_patches does not support int64,
  # using float32 instead.
  position_ids = math_ops.range(1, num_positions + 1, dtype=dtypes.float32)
  id_image = array_ops.reshape(position_ids, (1, rows_in, cols_in, 1))
  patched_ids = gen_array_ops.extract_image_patches(
      id_image, ksizes, op.get_attr("strides"), op.get_attr("rates"),
      op.get_attr("padding"))
  patched_ids = math_ops.cast(patched_ids, dtypes.int64)

  # Move batch next to channels so that each row of the flattened gradient
  # corresponds to one patch element.
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  per_patch_shape = (
      batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)
  grad_by_patch = array_ops.transpose(
      array_ops.reshape(dense_grad, per_patch_shape), (1, 2, 3, 4, 0, 5))
  grad_rows = array_ops.reshape(grad_by_patch, (-1, batch_size * channels))

  # Shift all ids back to 0-based. Padding locations end up with "-1", which
  # is fortunately ignored by segmented sum.
  segment_ids = array_ops.reshape(patched_ids, [-1]) - 1
  summed = math_ops.unsorted_segment_sum(
      grad_rows, segment_ids, num_segments=num_positions)

  # Restore the [batch, rows, cols, channels] layout.
  summed = array_ops.reshape(summed, (rows_in, cols_in, batch_size, channels))
  input_grad = array_ops.transpose(summed, (2, 0, 1, 3))

  return [input_grad]

1068 

1069 

@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  """Gradient for ExtractVolumePatches.

  Each input voxel position is given a 1-based id, and the id volume is run
  through extract_volume_patches with the forward op's attrs to learn which
  input position feeds each patch element. That mapping is stored as a sparse
  (input position x output element) matrix of ones, which is multiplied with
  the flattened incoming gradient.

  The plane/row/column extents of the input and output are read from static
  shapes; batch size and channel count are read dynamically.

  Args:
    op: The forward ExtractVolumePatches op.
    grad: Gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient for input 0.
  """
  volume = op.inputs[0]
  _, planes_in, rows_in, cols_in, _ = [dim.value for dim in volume.shape.dims]
  dynamic_shape = array_ops.shape(volume)
  batch_size = dynamic_shape[0]
  channels = dynamic_shape[4]

  # Build the id volume for the input. Id 0 is reserved for padding
  # locations, so real positions are numbered
  # 1 .. planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  position_ids = math_ops.range(1, input_indices_num, dtype=dtypes.int64)
  id_volume = array_ops.reshape(
      position_ids, (1, planes_in, rows_in, cols_in, 1))
  patched_ids = gen_array_ops.extract_volume_patches(
      id_volume, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Number the output elements, starting from 0.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_ids = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Pair every output element id with the input position id it was read from.
  id_pairs = array_ops.concat(
      [
          array_ops.expand_dims(patched_ids, axis=-1),
          array_ops.expand_dims(output_ids, axis=-1),
      ],
      axis=-1)
  idx_map = array_ops.reshape(id_pairs, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype),
      sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(
      sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

  # Move batch next to channels so that each row of the flattened gradient
  # corresponds to one output element.
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  per_patch_shape = (batch_size, planes_out, rows_out, cols_out, ksize_p,
                     ksize_r, ksize_c, channels)
  grad_by_patch = array_ops.transpose(
      array_ops.reshape(dense_grad, per_patch_shape),
      (1, 2, 3, 4, 5, 6, 0, 7))
  grad_rows = array_ops.reshape(grad_by_patch, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_rows)

  # Restore the [batch, planes, rows, cols, channels] layout.
  per_position = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  input_grad = array_ops.transpose(per_position, (3, 0, 1, 2, 4))

  return [input_grad]

1131 

1132 

@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  """Gradient for ScatterNd: gathers `grad` at the scattered indices.

  Returns:
    A three-element list: None for the indices (input 0), the gathered
    gradient for input 1, and None for input 2.
  """
  scatter_indices = op.inputs[0]
  return [None, array_ops.gather_nd(grad, scatter_indices), None]

1138 

1139 

@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  """Gradient for TensorScatterUpdate.

  The updates (input 2) receive `grad` gathered at the indices. The tensor
  (input 0) receives `grad` with zeros written at those indices.

  Returns:
    A three-element list: the tensor gradient, None for the indices, and the
    updates gradient.
  """
  scatter_indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, scatter_indices)
  grad_copy = array_ops.identity(grad)
  zero_updates = array_ops.zeros_like(op.inputs[2], dtype=grad.dtype)
  tensor_grad = array_ops.tensor_scatter_update(
      grad_copy, scatter_indices, zero_updates)
  return [tensor_grad, None, updates_grad]

1148 

1149 

@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  """Gradient for TensorScatterAdd.

  Returns:
    A three-element list: an identity of `grad` for the tensor (input 0),
    None for the indices, and `grad` gathered at the indices for the updates
    (input 2).
  """
  scatter_indices = op.inputs[1]
  gathered = array_ops.gather_nd(grad, scatter_indices)
  return [array_ops.identity(grad), None, gathered]

1156 

1157 

def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax.

  An element of the tensor (input 0) or of the updates (input 2) receives
  gradient only where it equals the corresponding output value. When several
  candidates tie at a location, the gradient there is split evenly among them.

  Args:
    op: The forward op; inputs are (tensor, indices, updates).
    grad: Gradient with respect to the op's output.

  Returns:
    A three-element list: the tensor gradient, None for the indices, and the
    updates gradient.
  """
  scatter_indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  result = op.outputs[0]
  # 1 where the original tensor value is the one that appears in the output.
  x_hits = math_ops.cast(math_ops.equal(x, result), grad.dtype)
  # 1 where an update value is the one that appears in the output.
  result_at_updates = array_ops.gather_nd(result, scatter_indices)
  y_hits = math_ops.cast(math_ops.equal(y, result_at_updates), grad.dtype)
  y_hits_scattered = array_ops.scatter_nd(
      scatter_indices, y_hits,
      array_ops.shape(x, out_type=scatter_indices.dtype))
  hit_counts = x_hits + y_hits_scattered  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will
  # be divided between them.
  x_grad = grad * x_hits / hit_counts
  y_grad = array_ops.gather_nd(grad / hit_counts, scatter_indices) * y_hits
  return [x_grad, None, y_grad]

1175 

1176 

@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
  """Gradient for TensorScatterMax op.

  Delegates to _TensorScatterMinOrMaxGrad.
  """
  input_grads = _TensorScatterMinOrMaxGrad(op, grad)
  return input_grads

1181 

1182 

@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
  """Gradient for TensorScatterMin op.

  Delegates to _TensorScatterMinOrMaxGrad.
  """
  input_grads = _TensorScatterMinOrMaxGrad(op, grad)
  return input_grads

1187 

1188 

@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  """Gradient for TensorScatterSub.

  Returns:
    A three-element list: an identity of `grad` for the tensor (input 0),
    None for the indices, and the negated `grad` gathered at the indices for
    the updates (input 2).
  """
  scatter_indices = op.inputs[1]
  gathered = array_ops.gather_nd(grad, scatter_indices)
  return [array_ops.identity(grad), None, -gathered]

1195 

1196 

@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  """Gradient for ScatterNdNonAliasingAdd.

  Returns:
    A three-element list: `grad` itself for input 0, None for the indices,
    and `grad` gathered at the indices for the updates (input 2).
  """
  scatter_indices = op.inputs[1]
  gathered = array_ops.gather_nd(grad, scatter_indices)
  return [grad, None, gathered]

1202 

1203 

@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  """Gradient for BroadcastTo: sums `grad` over the broadcast dimensions.

  Args:
    op: The forward BroadcastTo op; input 0 is the value being broadcast and
      input 1 is the target shape.
    grad: Gradient with respect to the op's output.

  Returns:
    A two-element list: the gradient for input 0 (reshaped to that input's
    shape), and None for the shape input.
  """
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  # Use the target shape's dtype for shape arithmetic when it is a tensor;
  # otherwise fall back to int32.
  if isinstance(broadcast_shape, ops.Tensor):
    shape_dtype = broadcast_shape.dtype
  else:
    shape_dtype = dtypes.int32

  input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    # For non-eager shapes, try to evaluate the shape as a constant and, when
    # it is fully defined, substitute a constant tensor for it.
    static_target = tensor_shape.TensorShape(
        tensor_util.try_evaluate_constant(broadcast_shape))
    if static_target.is_fully_defined():
      broadcast_shape = constant_op.constant(
          static_target.as_list(), dtype=shape_dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  summed_grad = math_ops.reduce_sum(grad, axis=reduction_axes, keepdims=True)
  input_grad = array_ops.reshape(summed_grad, input_value_shape)
  return [input_grad, None]