Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/array_grad.py: 28%
614 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for operators defined in array_ops.py."""
17from tensorflow.compiler.tf2xla.ops import gen_xla_ops
18from tensorflow.python import pywrap_tfe
19from tensorflow.python.eager import context
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import indexed_slices as indexed_slices_lib
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import array_ops_stack
29from tensorflow.python.ops import cond
30from tensorflow.python.ops import control_flow_util
31from tensorflow.python.ops import gen_array_ops
32from tensorflow.python.ops import gen_math_ops
33from tensorflow.python.ops import gen_resource_variable_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import sparse_ops
38@ops.RegisterGradient("Pack")
39def _PackGrad(op, grad):
40 """Gradient for pack op."""
41 return array_ops_stack.unstack(
42 grad, num=op.get_attr("N"), axis=op.get_attr("axis"))
45@ops.RegisterGradient("Unpack")
46def _UnpackGrad(op, *grads):
47 """Gradient for unpack op."""
48 return array_ops_stack.stack(grads, axis=op.get_attr("axis"))
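# Illustrative sketch (not part of the original module): the two registrations
# above are inverses of each other, so the gradient of tf.stack is tf.unstack
# of the upstream gradient and vice versa. The helper name and values below
# are hypothetical and exist only for demonstration.
def _example_pack_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 2.])
  y = tf.constant([3., 4.])
  upstream = tf.constant([[10., 20.], [30., 40.]])
  with tf.GradientTape() as tape:
    tape.watch([x, y])
    packed = tf.stack([x, y], axis=0)
    loss = tf.reduce_sum(packed * upstream)
  dx, dy = tape.gradient(loss, [x, y])
  # dx == [10., 20.] and dy == [30., 40.]: exactly the rows obtained by
  # unstacking the upstream gradient along axis 0.
  return dx, dy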
51def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
52 """Gradient for concat op.
54 Args:
55 op: An operation.
56 grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
57 each output of the op.
58 start_value_index: An integer index of the first value in the op.inputs.
59 end_value_index: An integer index of the last value in the op.inputs.
60 dim_index: An integer index of concat_dim or axis parameter in op.inputs.
62 Returns:
63 Tensors representing the partial gradients with respect to each input
64 of the op.
66 Raises:
67 ValueError: if concat_dim/axis is not statically known.
68 """
70 def _CreateDenseMaskAndBegin(sizes, concat_dim):
71 """Create variables for iteratively slicing a dense gradients tensor."""
72 # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
73 shape_of_shape = array_ops.shape(sizes[0])
74 # Make a vector of length equal to the input's dimensions,
75 # with 0's everywhere and 1 in the concat dim position.
76 # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
77 mask = array_ops.concat([
78 array_ops.zeros(
79 array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
80 array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
81 ], 0)
82 begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
83 return mask, begin
85 def _ExtractInputShapes(inputs):
86 """Extract the shapes of a set of input tensors."""
87 if context.executing_eagerly():
88 return array_ops.shape_n(inputs)
89 sizes = []
90 fully_known = True
91 for x in inputs:
92 input_shape = array_ops.shape(x)
93 if not isinstance(input_shape,
94 ops.Tensor) or input_shape.op.type != "Const":
95 fully_known = False
96 break
97 sizes.append(input_shape)
99 if fully_known:
100 return sizes
101 else:
102 return array_ops.shape_n(inputs)
104 # Degenerate concatenation, just return grad.
105 if len(op.inputs) == 2:
106 return grad + [None] if end_value_index <= dim_index else [None] + grad
108 concat_dim = op.inputs[dim_index]
109 input_values = op.inputs[start_value_index:end_value_index]
111 out_grads = []
112 if isinstance(grad, ops.Tensor):
113 if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
114 # Using mod here for convenience since concat_dim is already verified
115 # in concat implementation to be within the allowed [-rank, rank) range.
116 non_neg_concat_dim = (
117 concat_dim._numpy().item(0) % input_values[0]._rank()) # pylint: disable=protected-access
118 # All inputs are guaranteed to be EagerTensors in eager mode
119 sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
120 non_neg_concat_dim)
121 out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
122 else:
123 if constant_op.is_constant(concat_dim):
124 # If concat_dim is a constant defined in a different context,
125 # then we duplicate it in the current context to avoid passing it
126 # through an Enter node.
127 # This is a small optimization in general, but it is required when
128 # compiling with XLA, as XLA needs the concat input to be folded into a
129 # constant.
130 grad_context = control_flow_util.GetOutputContext(grad.op)
131 dim_context = control_flow_util.GetOutputContext(concat_dim.op)
132 if dim_context != grad_context:
133 value = tensor_util.constant_value(concat_dim)
134 concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)
136 # Using mod here for convenience since concat_dim is already verified
137 # in concat implementation to be within the allowed [-rank, rank) range.
138 non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
140 # Get the inputs' tensor shapes
141 sizes = _ExtractInputShapes(input_values)
142 # The magic number of 16 was found through benchmarking a range of sizes
143 # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
144 # cases when switching implementations at N=16, but it is possible that
145 # there will be a small number of performance regressions.
146 if len(sizes) > 16:
147 # extract the size of each input along the concat dimension
148 sizes = array_ops.squeeze(
149 array_ops.slice(
150 array_ops_stack.stack(sizes, axis=1), [non_neg_concat_dim, 0],
151 [1, -1]))
152 out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
153 else:
154 offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
155 for (begin, size) in zip(offset, sizes):
156 out_grads.append(array_ops.slice(grad, begin, size))
157 elif isinstance(grad, indexed_slices_lib.IndexedSlices):
158 # Using mod here for convenience since concat_dim is already verified
159 # in concat implementation to be within the allowed [-rank, rank) range.
160 non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
161 concat_dim_static = tensor_util.constant_value(concat_dim)
162 if concat_dim_static is None:
163 raise ValueError("Can only compute IndexedSlices gradient with "
164 "statically-known concat_dim")
165 if concat_dim_static < 0:
166 rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
167 if rank is None:
168 raise ValueError("Can only compute IndexedSlices gradient with "
169 "negative concat_dim when first value rank is "
170 "statically-known.")
171 concat_dim_static %= rank
172 # Get the inputs' tensor shapes
173 sizes = [array_ops.shape(x) for x in input_values]
174 if concat_dim_static > 0:
175 # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
176 # gradients with all the indices, but with grad.values sliced accordingly.
177 # This is like the Tensor case, except shape(grad.values)[0] is not equal
178 # to shape(sizes[i])[0], since only a subset of the dim-0 values are
179 # stored.
180 mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
181 for size in sizes:
182 new_values = array_ops.slice(
183 grad.values, begin,
184 array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
185 out_grads.append(
186 indexed_slices_lib.IndexedSlices(new_values, grad.indices, size))
187 # Lint complains begin = begin + ...
188 begin = math_ops.add(begin, size * mask)
189 else:
190 # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
191 # only for the relevant indices.
192 start = constant_op.constant(0, dtype=grad.indices.dtype)
193 for size in sizes:
194 size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
195 if size_concat_dim.dtype != grad.indices.dtype:
196 size_concat_dim = math_ops.cast(
197 size_concat_dim, dtype=grad.indices.dtype)
198 end = start + size_concat_dim
199 # Compute the 1-D Tensor of indices relevant for this input.
200 indices_to_select = array_ops.squeeze(
201 array_ops.where(
202 math_ops.logical_and(grad.indices >= start,
203 grad.indices < end)),
204 axis=[1])
205 new_indices = array_ops.gather(grad.indices, indices_to_select) - start
206 new_values = array_ops.gather(grad.values, indices_to_select)
207 out_grads.append(
208 indexed_slices_lib.IndexedSlices(new_values, new_indices, size))
209 start = end
210 else:
211 raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))
213 return (out_grads + [None] if end_value_index <= dim_index else [None] +
214 out_grads)
217@ops.RegisterGradient("Concat")
218def _ConcatGrad(op, grad):
219 return _ConcatGradHelper(
220 op,
221 grad,
222 start_value_index=1,
223 end_value_index=len(op.inputs),
224 dim_index=0)
227@ops.RegisterGradient("ConcatV2")
228def _ConcatGradV2(op, grad):
229 return _ConcatGradHelper(
230 op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)
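# Illustrative sketch (not part of the original module): _ConcatGradHelper
# above splits the upstream gradient back along the concat axis, one slice per
# input. The helper name and values below are hypothetical.
def _example_concat_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 2.])
  y = tf.constant([3., 4., 5.])
  upstream = tf.constant([10., 20., 30., 40., 50.])
  with tf.GradientTape() as tape:
    tape.watch([x, y])
    z = tf.concat([x, y], axis=0)
    loss = tf.reduce_sum(z * upstream)
  dx, dy = tape.gradient(loss, [x, y])
  # dx == [10., 20.] and dy == [30., 40., 50.]; equivalently
  # tf.split(upstream, [2, 3], axis=0).
  return dx, dy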
233ops.NotDifferentiable("ConcatOffset")
236@ops.RegisterGradient("Slice")
237def _SliceGrad(op, grad):
238 """Gradient for Slice op."""
239 # Create an Nx2 padding where the first column represents how many
240 # zeros are to be prepended for each dimension, and the second
241 # column indicates how many zeros are appended.
242 #
243 # The number of zeros to append is the shape of the input
244 # elementwise-subtracted by both the begin vector and sizes vector.
245 #
246 # Some more reshaping is needed to assemble this tensor with the
247 # right dimensions.
248 input_vec = op.inputs[0]
249 begin_vec = op.inputs[1]
250 input_rank = array_ops.rank(input_vec)
251 index_dtype = begin_vec.dtype
252 slice_size = array_ops.shape(op.outputs[0], out_type=index_dtype)
253 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
254 return gen_xla_ops.xla_dynamic_update_slice(array_ops.zeros_like(input_vec),
255 grad, begin_vec), None, None
257 shape = array_ops_stack.stack([input_rank, 1])
258 before_pad = array_ops.reshape(begin_vec, shape)
259 after_pad = array_ops.reshape(
260 array_ops.shape(input_vec, out_type=index_dtype) - slice_size - begin_vec,
261 shape)
262 paddings = array_ops.concat([before_pad, after_pad], 1)
263 return array_ops.pad(grad, paddings), None, None
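# Illustrative sketch (not part of the original module): the Slice gradient
# above pads the upstream gradient with zeros so it lands back at the sliced
# position of the input. The helper name and values below are hypothetical.
def _example_slice_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 2., 3., 4.])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.slice(x, begin=[1], size=[2])
    loss = tf.reduce_sum(y * tf.constant([10., 20.]))
  dx = tape.gradient(loss, x)
  # dx == [0., 10., 20., 0.]: one zero prepended (begin) and one appended
  # (shape - size - begin), i.e. tf.pad(upstream, [[1, 1]]).
  return dx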
266@ops.RegisterGradient("StridedSlice")
267def _StridedSliceGrad(op, grad):
268 """Gradient for StridedSlice op."""
269 begin = op.inputs[1]
270 end = op.inputs[2]
271 strides = op.inputs[3]
272 # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
273 # same dtype so we build a shape of the same type as other args.
274 # Note that the choice of `begin` for specifying `out_type` is arbitrary.
275 # We could choose any of {begin|end|strides}.dtype since they are required to
276 # be the same.
277 x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
279 x_static = tensor_util.constant_value(x)
280 x = x_static if x_static is not None else x
281 begin_static = tensor_util.constant_value(begin)
282 begin = begin_static if begin_static is not None else begin
283 end_static = tensor_util.constant_value(end)
284 end = end_static if end_static is not None else end
285 strides_static = tensor_util.constant_value(strides)
286 strides = strides_static if strides_static is not None else strides
288 return array_ops.strided_slice_grad(
289 x,
290 begin,
291 end,
292 strides,
293 grad,
294 begin_mask=op.get_attr("begin_mask"),
295 end_mask=op.get_attr("end_mask"),
296 ellipsis_mask=op.get_attr("ellipsis_mask"),
297 new_axis_mask=op.get_attr("new_axis_mask"),
298 shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None
301@ops.RegisterGradient("StridedSliceGrad")
302def _StridedSliceGradGrad(op, grad):
303 """Gradient for StridedSliceGrad op."""
304 begin = op.inputs[1]
305 end = op.inputs[2]
306 strides = op.inputs[3]
308 return None, None, None, None, array_ops.strided_slice(
309 grad,
310 begin,
311 end,
312 strides,
313 begin_mask=op.get_attr("begin_mask"),
314 end_mask=op.get_attr("end_mask"),
315 ellipsis_mask=op.get_attr("ellipsis_mask"),
316 new_axis_mask=op.get_attr("new_axis_mask"),
317 shrink_axis_mask=op.get_attr("shrink_axis_mask"))
320@ops.RegisterGradient("TensorStridedSliceUpdate")
321def _TensorStridedSliceUpdateGrad(op, grad): # pylint:disable=missing-function-docstring
322 begin = op.inputs[1]
323 end = op.inputs[2]
324 strides = op.inputs[3]
325 begin_mask = op.get_attr("begin_mask")
326 end_mask = op.get_attr("end_mask")
327 ellipsis_mask = op.get_attr("ellipsis_mask")
328 new_axis_mask = op.get_attr("new_axis_mask")
329 shrink_axis_mask = op.get_attr("shrink_axis_mask")
330 def Apply(f, *args):
331 return f(*args,
332 begin_mask=begin_mask,
333 end_mask=end_mask,
334 shrink_axis_mask=shrink_axis_mask,
335 new_axis_mask=new_axis_mask,
336 ellipsis_mask=ellipsis_mask)
337 dy = Apply(array_ops.strided_slice,
338 grad, begin, end, strides)
339 dx = Apply(array_ops.tensor_strided_slice_update,
340 grad, begin, end, strides, array_ops.zeros_like(dy))
342 # The value is potentially broadcast to the shape of the strided slice, so we
343 # may need to adjust dy.
344 slice_shape = array_ops.shape(dy, out_type=begin.dtype)
345 value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)
347 _, reduction_axes = gen_array_ops.broadcast_gradient_args(
348 slice_shape, value_shape)
349 dy_reshaped = math_ops.reduce_sum(dy, axis=reduction_axes, keepdims=True)
350 dy = array_ops.reshape(dy_reshaped, value_shape)
352 return dx, None, None, None, dy
355@ops.RegisterGradient("Split")
356def _SplitGrad(op, *grads):
357 return None, array_ops.concat(list(grads), op.inputs[0])
360@ops.RegisterGradient("SplitV")
361def _SplitVGrad(op, *grads):
362 returnval = array_ops.concat(list(grads), op.inputs[2])
363 returnval = [returnval] + [
364 None,
365 ] * (
366 len(op.inputs) - 1)
367 return returnval
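# Illustrative sketch (not part of the original module): the Split/SplitV
# gradients above simply concatenate the per-output gradients back along the
# split axis. The helper name and values below are hypothetical.
def _example_split_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 2., 3., 4., 5., 6.])
  with tf.GradientTape() as tape:
    tape.watch(x)
    a, b, c = tf.split(x, num_or_size_splits=3, axis=0)
    loss = (tf.reduce_sum(a * [1., 2.]) + tf.reduce_sum(b * [3., 4.]) +
            tf.reduce_sum(c * [5., 6.]))
  dx = tape.gradient(loss, x)
  # dx == [1., 2., 3., 4., 5., 6.]: the per-output gradients concatenated.
  return dx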
370ops.NotDifferentiable("Const")
373@ops.RegisterGradient("Diag")
374def _DiagGrad(_, grad):
375 return array_ops.diag_part(grad)
378@ops.RegisterGradient("DiagPart")
379def _DiagPartGrad(_, grad):
380 return array_ops.diag(grad)
383@ops.RegisterGradient("MatrixDiag")
384def _MatrixDiagGrad(_, grad):
385 return array_ops.matrix_diag_part(grad)
388@ops.RegisterGradient("MatrixDiagV2")
389def _MatrixDiagV2Grad(op, grad):
390 return array_ops.matrix_diag_part(
391 grad, k=op.inputs[1]), None, None, None, None
394@ops.RegisterGradient("MatrixDiagV3")
395def _MatrixDiagV3Grad(op, grad):
396 return array_ops.matrix_diag_part(
397 grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None
400@ops.RegisterGradient("MatrixDiagPart")
401def _MatrixDiagPartGrad(op, grad):
402 matrix_shape = op.inputs[0].get_shape()[-2:]
403 if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
404 return array_ops.matrix_diag(grad)
405 else:
406 return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)
409@ops.RegisterGradient("MatrixDiagPartV2")
410def _MatrixDiagPartV2Grad(op, grad):
411 """Gradient for MatrixDiagPartV2."""
412 matrix_shape = op.inputs[0].get_shape()[-2:]
413 if matrix_shape.is_fully_defined():
414 return array_ops.matrix_diag(
415 grad,
416 k=op.inputs[1],
417 num_rows=matrix_shape[0],
418 num_cols=matrix_shape[1]), None, None
419 else:
420 return array_ops.matrix_set_diag(
421 array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None
424@ops.RegisterGradient("MatrixDiagPartV3")
425def _MatrixDiagPartV3Grad(op, grad):
426 """Gradient for MatrixDiagPartV3."""
427 matrix_shape = op.inputs[0].get_shape()[-2:]
428 align = op.get_attr("align")
429 if matrix_shape.is_fully_defined():
430 return array_ops.matrix_diag(
431 grad,
432 k=op.inputs[1],
433 num_rows=matrix_shape[0],
434 num_cols=matrix_shape[1],
435 align=align), None, None
436 else:
437 return array_ops.matrix_set_diag(
438 array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
439 align=align), None, None
442@ops.RegisterGradient("MatrixSetDiag")
443def _MatrixSetDiagGrad(op, grad):
444 """Gradient for MatrixSetDiag."""
445 input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
446 diag_shape = op.inputs[1].get_shape()
447 batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
448 matrix_shape = input_shape[-2:]
449 if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
450 diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
451 else:
452 with ops.colocate_with(grad):
453 grad_shape = array_ops.shape(grad)
454 grad_rank = array_ops.rank(grad)
455 batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
456 matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
457 min_dim = math_ops.reduce_min(matrix_shape)
458 diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
459 grad_input = array_ops.matrix_set_diag(
460 grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
461 grad_diag = array_ops.matrix_diag_part(grad)
462 return (grad_input, grad_diag)
465@ops.RegisterGradient("MatrixSetDiagV2")
466def _MatrixSetDiagGradV2(op, grad):
467 """Gradient for MatrixSetDiagV2."""
468 diag_shape = op.inputs[1].get_shape()
469 if not diag_shape.is_fully_defined():
470 # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
471 grad_shape = array_ops.shape(grad)
472 batch_shape = grad_shape[:-2]
473 matrix_shape = grad_shape[-2:]
474 diag_index = array_ops.reshape(op.inputs[2], [-1]) # Converts to vector.
475 d_lower = diag_index[0]
476 d_upper = diag_index[-1] # Works both when len(diag_index) is 1 and 2.
477 y_offset = cond.cond(
478 math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
479 x_offset = cond.cond(
480 math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)
482 max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
483 matrix_shape[1] + x_offset)
484 # pylint: disable=g-long-lambda
485 # pyformat: disable
486 postfix = cond.cond(
487 math_ops.equal(d_lower, d_upper),
488 lambda: ops.convert_to_tensor([max_diag_len]),
489 lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
490 max_diag_len]))
491 # pyformat: enable
492 # pylint: enable=g-long-lambda
493 diag_shape = array_ops.concat([batch_shape, postfix], 0)
495 grad_input = array_ops.matrix_set_diag(
496 grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
497 grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
498 return (grad_input, grad_diag, None)
501@ops.RegisterGradient("MatrixSetDiagV3")
502def _MatrixSetDiagGradV3(op, grad):
503 """Gradient for MatrixSetDiagV3."""
504 diag_shape = op.inputs[1].get_shape()
505 align = op.get_attr("align")
506 if not diag_shape.is_fully_defined():
507 # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
508 grad_shape = array_ops.shape(grad)
509 batch_shape = grad_shape[:-2]
510 matrix_shape = grad_shape[-2:]
511 diag_index = array_ops.reshape(op.inputs[2], [-1]) # Converts to vector.
512 d_lower = diag_index[0]
513 d_upper = diag_index[-1] # Works both when len(diag_index) is 1 and 2.
514 y_offset = cond.cond(
515 math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
516 x_offset = cond.cond(
517 math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)
519 max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
520 matrix_shape[1] + x_offset)
521 # pylint: disable=g-long-lambda
522 # pyformat: disable
523 postfix = cond.cond(
524 math_ops.equal(d_lower, d_upper),
525 lambda: ops.convert_to_tensor([max_diag_len]),
526 lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
527 max_diag_len]))
528 # pyformat: enable
529 # pylint: enable=g-long-lambda
530 diag_shape = array_ops.concat([batch_shape, postfix], 0)
532 grad_input = array_ops.matrix_set_diag(
533 grad,
534 array_ops.zeros(diag_shape, dtype=grad.dtype),
535 k=op.inputs[2],
536 align=align)
537 grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
538 return (grad_input, grad_diag, None)
541@ops.RegisterGradient("MatrixBandPart")
542def _MatrixBandPartGrad(op, grad):
543 num_lower = op.inputs[1]
544 num_upper = op.inputs[2]
545 return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)
548# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
549ops.NotDifferentiable("EditDistance")
552@ops.RegisterGradient("Fill")
553def _FillGrad(_, grad):
554 return None, math_ops.reduce_sum(grad)
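# Illustrative sketch (not part of the original module): the Fill gradient
# above reduces the upstream gradient to a scalar, because every output
# element depends on the single fill value. Values below are hypothetical.
def _example_fill_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  value = tf.constant(2.)
  with tf.GradientTape() as tape:
    tape.watch(value)
    y = tf.fill([2, 3], value)
    loss = tf.reduce_sum(y * tf.reshape(tf.range(6, dtype=tf.float32), [2, 3]))
  dvalue = tape.gradient(loss, value)
  # dvalue == 15. == sum of the upstream gradient (0 + 1 + ... + 5).
  return dvalue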
557ops.NotDifferentiable("ZerosLike")
558ops.NotDifferentiable("OnesLike")
561@ops.RegisterGradient("PreventGradient")
562def _PreventGradientGrad(op, _):
563 raise LookupError("Gradient explicitly disabled. Reason: %s" %
564 op.get_attr("message"))
567def _IndexedSlicesToTensorNoWarning(indexed_slices):
568 """Converts an IndexedSlices to a Tensor without sparse->dense warnings."""
569 if not isinstance(indexed_slices, indexed_slices_lib.IndexedSlices):
570 # If it is not IndexedSlices, it had better be a tensor.
571 return indexed_slices
572 if indexed_slices.dense_shape is None:
573 raise ValueError(
574 "Tensor conversion requested for IndexedSlices without dense_shape: %s"
575 % str(indexed_slices))
576 return math_ops.unsorted_segment_sum(indexed_slices.values,
577 indexed_slices.indices,
578 indexed_slices.dense_shape[0])
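# Illustrative sketch (not part of the original module): the helper above
# densifies an IndexedSlices by summing duplicate rows with
# unsorted_segment_sum. The helper name and values below are hypothetical.
def _example_indexed_slices_densify_sketch():
  import tensorflow as tf  # local import, illustration only
  values = tf.constant([[1., 2.], [3., 4.]])
  indices = tf.constant([0, 0])
  num_rows = 3  # plays the role of dense_shape[0]
  dense = tf.math.unsorted_segment_sum(values, indices, num_rows)
  # dense == [[4., 6.], [0., 0.], [0., 0.]]: both slices land on row 0 and are
  # accumulated, remaining rows are zero.
  return dense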
581@ops.RegisterGradient("Gather")
582def _GatherGrad(op, grad):
583 """Gradient for Gather op."""
584 # params can be large, so colocate the shape calculation with it.
585 params = op.inputs[0]
586 with ops.colocate_with(params):
587 params_shape = array_ops.shape(params)
589 # Build appropriately shaped IndexedSlices
590 indices = op.inputs[1]
591 size = array_ops.expand_dims(array_ops.size(indices), 0)
592 values_shape = array_ops.concat([size, params_shape[1:]], 0)
593 values = array_ops.reshape(
594 _IndexedSlicesToTensorNoWarning(grad), values_shape)
595 indices = array_ops.reshape(indices, size)
596 return [indexed_slices_lib.IndexedSlices(values, indices, params_shape), None]
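# Illustrative sketch (not part of the original module): the Gather gradient
# above is a sparse scatter of the upstream gradient into the rows that were
# gathered; rows that were never selected receive zero. The helper name and
# values below are hypothetical.
def _example_gather_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  params = tf.Variable([[1., 1.], [2., 2.], [3., 3.]])
  upstream = tf.constant([[10., 20.], [30., 40.]])
  with tf.GradientTape() as tape:
    y = tf.gather(params, [2, 0])
    loss = tf.reduce_sum(y * upstream)
  dparams = tape.gradient(loss, params)
  # dparams is an IndexedSlices; densified it equals
  # [[30., 40.], [0., 0.], [10., 20.]].
  return tf.convert_to_tensor(dparams)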
599def _GetBatchIndices(params_shape, indices, batch_dims):
600 """Addds the batch offsets to the given indices and returns the results."""
601 batch_indices = indices
602 indices_dtype = indices.dtype.base_dtype
603 casted_params_shape = math_ops.cast(params_shape, indices_dtype)
604 accum_dim_value = array_ops.ones((), dtype=indices_dtype)
605 for dim in range(batch_dims, 0, -1):
606 dim_value = casted_params_shape[dim - 1]
607 accum_dim_value *= casted_params_shape[dim]
608 start = array_ops.zeros((), dtype=indices_dtype)
609 step = array_ops.ones((), dtype=indices_dtype)
610 dim_indices = math_ops.range(start, dim_value, step)
611 dim_indices *= accum_dim_value
612 dim_shape = array_ops.concat([
613 array_ops.tile([1], [dim - 1]), [dim_value],
614 array_ops.tile([1], [array_ops.rank(indices) - dim])
615 ], axis=0)
616 batch_indices += array_ops.reshape(dim_indices, dim_shape)
618 return batch_indices
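# Illustrative sketch (not part of the original module): with batch_dims == 1
# and params of shape [2, 3, ...], the function above adds an offset of b * 3
# to the indices of batch b, so the batch dimension can be flattened away
# before the unsorted_segment_sum. Names and values below are hypothetical.
def _example_batch_indices_sketch():
  import tensorflow as tf  # local import, illustration only
  indices = tf.constant([[0, 2, 1],
                         [1, 0, 2]])       # shape [batch=2, 3], values < 3
  offsets = tf.range(2)[:, None] * 3       # [[0], [3]]
  flat_indices = indices + offsets         # [[0, 2, 1], [4, 3, 5]]
  # flat_indices now index into the flattened [2 * 3, ...] params.
  return flat_indices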
621def _BatchGatherGrad(params_shape, values, indices, batch_dims,
622 gather_dim_size):
623 """Returns the gradient of GatherV2 with batch dimensions."""
625 # Axis is the first non-batch dimension.
626 indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
627 if batch_dims:
628 values_shape = array_ops.shape(values)
629 # Add the batch offsets to indices and flatten the batch dimensions.
630 outer_shape = values_shape[:batch_dims]
631 inner_shape = values_shape[batch_dims:][1:]
632 batch_size = gen_math_ops.prod(outer_shape, [0], False)
633 flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
634 gather_dim_size *= batch_size
636 indices = _GetBatchIndices(params_shape, indices, batch_dims)
637 values = array_ops.reshape(
638 _IndexedSlicesToTensorNoWarning(values), flat_values_shape)
640 indices = array_ops.reshape(indices, indices_size)
641 params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)
643 if batch_dims:
644 # Put back the batch dimensions.
645 params_grad = array_ops.reshape(
646 params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))
648 return params_grad
651@ops.RegisterGradient("GatherV2")
652def _GatherV2Grad(op, grad):
653 """Gradient for GatherV2 op."""
654 # params can be large, so colocate the shape calculation with it.
655 #
656 # params can be very large for sparse models, and array_ops.shape raises an
657 # exception on Windows when any dimension exceeds the int32 range.
658 # params_shape is not used in optimizer apply_sparse gradients, so it is
659 # fine to convert it back to int32 regardless of truncation.
660 params = op.inputs[0]
661 with ops.colocate_with(params):
662 params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
663 params_shape = math_ops.cast(params_shape, dtypes.int32)
665 indices = op.inputs[1]
666 indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
667 axis = op.inputs[2]
668 axis_static = tensor_util.constant_value(axis)
669 batch_dims = int(op.get_attr("batch_dims"))
671 if batch_dims < 0:
672 if indices.shape.ndims is None:
673 raise ValueError(
674 f"Currently, it is unsupported to take the gradient of tf.gather "
675 f"when batch_dims < 0 and the rank of the indices is unknown. Please "
676 f"pass a positive batch_dims or use tf.ensure_shape to update the "
677 f"shape of indices when calling tf.gather. Got "
678 f"batch_dims={batch_dims} and indices={indices}")
679 batch_dims += indices.shape.ndims
681 # For axis 0 gathers, build an appropriately shaped IndexedSlices.
682 if axis_static == 0:
683 if context.executing_eagerly():
684 with ops.device(indices_size.device):
685 params_tail_shape = array_ops.identity(params_shape)[1:]
686 else:
687 params_tail_shape = params_shape[1:]
688 values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
689 values = array_ops.reshape(
690 _IndexedSlicesToTensorNoWarning(grad), values_shape)
691 indices = array_ops.reshape(indices, indices_size)
692 params_grad = indexed_slices_lib.IndexedSlices(values, indices,
693 params_shape)
694 else:
695 # Handle axis by transposing the axis dimension to be the first non-batch
696 # dimension, compute the gradient and transpose the result back.
697 outer_shape = params_shape[:axis]
698 inner_shape = params_shape[axis:][1:]
699 values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)
701 values_dims = array_ops.size(values_shape)
702 axis_dims = array_ops.size(outer_shape)
704 outer_batches_indices = math_ops.range(batch_dims)
705 batch_axis_indices = math_ops.range(batch_dims, axis_dims)
706 inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)
708 values = array_ops.reshape(
709 _IndexedSlicesToTensorNoWarning(grad), values_shape)
711 # Move values[axis] up to values[batch_dims]
712 transpose_dims = array_ops.concat([
713 outer_batches_indices, [axis_dims], batch_axis_indices,
714 inner_axes_indices
715 ], 0)
716 values_transpose = array_ops.transpose(values, transpose_dims)
717 params_shape_transpose = array_ops.gather(params_shape, transpose_dims)
719 params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
720 indices, batch_dims, params_shape[axis])
722 # Inverts the above transpose by moving dimension batch_dims back to its
723 # original position.
724 invert_transpose_dims = array_ops.concat([
725 outer_batches_indices, batch_axis_indices + 1, [batch_dims],
726 inner_axes_indices
727 ], 0)
728 params_grad = array_ops.transpose(params_grad, invert_transpose_dims)
730 return [params_grad, None, None]
733@ops.RegisterGradient("GatherNd")
734def _GatherNdGrad(op, grad):
735 ref = op.inputs[0]
736 indices = op.inputs[1]
737 ref_shape = array_ops.shape(ref, out_type=indices.dtype)
738 if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
739 ref_grad = indexed_slices_lib.IndexedSlices(
740 grad, array_ops.squeeze(indices, axis=-1), ref_shape)
741 else:
742 ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
743 return [ref_grad, None]
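# Illustrative sketch (not part of the original module): the GatherNd gradient
# above scatters the upstream gradient back to the gathered coordinates via
# scatter_nd. The helper name and values below are hypothetical.
def _example_gather_nd_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  params = tf.constant([[1., 2.], [3., 4.]])
  with tf.GradientTape() as tape:
    tape.watch(params)
    y = tf.gather_nd(params, [[0, 1], [1, 0]])  # [params[0, 1], params[1, 0]]
    loss = tf.reduce_sum(y * tf.constant([10., 20.]))
  dparams = tape.gradient(loss, params)
  # dparams == [[0., 10.], [20., 0.]], i.e.
  # tf.scatter_nd([[0, 1], [1, 0]], [10., 20.], [2, 2]).
  return dparams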
746@ops.RegisterGradient("ResourceGatherNd")
747def _ResourceGatherNdGrad(op, grad): # pylint: disable=missing-docstring
748 ref = op.inputs[0]
749 indices = op.inputs[1]
750 ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
751 if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
752 ref_grad = indexed_slices_lib.IndexedSlices(
753 grad, array_ops.squeeze(indices, axis=-1), ref_shape)
754 else:
755 ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
756 return [ref_grad, None]
759@ops.RegisterGradient("CheckNumerics")
760def _CheckNumericsGrad(op, grad):
761 """Gradient for check_numerics op."""
762 return array_ops.check_numerics(
763 grad,
764 "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
765 op.get_attr("message"))
768@ops.RegisterGradient("CheckNumericsV2")
769def _CheckNumericsV2Grad(op, grad):
770 """Gradient for check_numerics op."""
771 return array_ops.check_numerics_v2(
772 grad,
773 "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
774 op.get_attr("message"))
777@ops.RegisterGradient("PlaceholderWithDefault")
778@ops.RegisterGradient("Identity")
779def _IdGrad(_, grad):
780 return grad
783@ops.RegisterGradient("_EagerConst")
784def _EagerConstGrad(_, grad):
785 raise AssertionError(
786 "This op should never interact with gradient APIs. Please file a bug.")
789@ops.RegisterGradient("RefIdentity")
790def _RefIdGrad(_, grad):
791 return grad
794@ops.RegisterGradient("IdentityN")
795def _IdNGrad(_, *grad):
796 return grad
799ops.NotDifferentiable("StopGradient")
802@ops.RegisterGradient("Reshape")
803def _ReshapeGrad(op, grad):
804 return [
805 array_ops.reshape(
806 _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0])),
807 None
808 ]
811ops.NotDifferentiable("InvertPermutation")
814def _ReshapeToInput(op, grad):
815 """Reshapes the gradient to the shape of the original input."""
816 return array_ops.reshape(
817 _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))
820@ops.RegisterGradient("ExpandDims")
821def _ExpandDimsGrad(op, grad):
822 return [_ReshapeToInput(op, grad), None]
825@ops.RegisterGradient("Squeeze")
826def _SqueezeGrad(op, grad):
827 return _ReshapeToInput(op, grad)
830@ops.RegisterGradient("Transpose")
831def _TransposeGrad(op, grad):
832 """Returns unshuffle(grad)."""
833 p = op.inputs[1]
834 return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]
837@ops.RegisterGradient("ConjugateTranspose")
838def _ConjugateTransposeGrad(op, grad):
839 """Returns conj(unshuffle(grad))."""
840 p = op.inputs[1]
841 return [
842 array_ops.transpose(
843 grad, array_ops.invert_permutation(p), conjugate=True), None
844 ]
847ops.NotDifferentiable("Shape")
849ops.NotDifferentiable("ShapeN")
851ops.NotDifferentiable("Rank")
853ops.NotDifferentiable("Size")
856@ops.RegisterGradient("Tile")
857def _TileGrad(op, grad):
858 """Sum reduces grad along the tiled dimensions."""
859 input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
860 # We interleave multiples and input_shape to get split_shape,
861 # reshape grad to split_shape, and reduce along all even
862 # dimensions (the tiled dimensions) to get the result
863 # with shape input_shape. For example
864 # input_shape = [20, 30, 40]
865 # multiples = [2, 3, 4]
866 # split_shape = [2, 20, 3, 30, 4, 40]
867 # axes = [0, 2, 4]
868 split_shape = array_ops.reshape(
869 array_ops.transpose(array_ops_stack.stack([op.inputs[1], input_shape])),
870 [-1])
871 axes = math_ops.range(0, array_ops.size(split_shape), 2)
872 # Sum reduces grad along the first dimension for IndexedSlices
873 if isinstance(grad, indexed_slices_lib.IndexedSlices):
874 input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
875 grad = math_ops.unsorted_segment_sum(
876 grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
877 split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
878 input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
879 # Fix shape inference
880 if not context.executing_eagerly():
881 input_grad.set_shape(op.inputs[0].get_shape())
882 return [input_grad, None]
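# Illustrative sketch (not part of the original module): the Tile gradient
# above reshapes the upstream gradient to the interleaved split_shape and
# sum-reduces over the tiled (even) axes, which accumulates the gradient of
# every copy back onto the original element. The helper name and values below
# are hypothetical.
def _example_tile_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 2.])
  upstream = tf.constant([1., 2., 3., 4., 5., 6.])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.tile(x, [3])               # [x0, x1, x0, x1, x0, x1]
    loss = tf.reduce_sum(y * upstream)
  dx = tape.gradient(loss, x)
  # dx == [9., 12.]: x0 collects upstream[0] + upstream[2] + upstream[4] and
  # x1 collects upstream[1] + upstream[3] + upstream[5].
  return dx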
885ops.NotDifferentiable("BroadcastGradientArgs")
888def _PadGrad(op, grad):
889 """Gradient for Pad."""
890 # Pad introduces values around the original tensor, so the gradient function
891 # slices the original shape out of the gradient.
892 x = op.inputs[0]
893 a = op.inputs[1] # [Rank(x), 2]
894 # Takes a slice of a. The 1st column. [Rank(x), 1].
895 pad_before = array_ops.slice(a, [0, 0],
896 array_ops_stack.stack([array_ops.rank(x), 1]))
897 # Make it a 1-D tensor.
898 begin = array_ops.reshape(pad_before, [-1])
899 sizes = array_ops.shape(x, out_type=begin.dtype)
900 x_grad = array_ops.slice(grad, begin, sizes)
901 if len(op.inputs) == 3:
902 return x_grad, None, None
903 else:
904 return x_grad, None
907ops.RegisterGradient("Pad")(_PadGrad)
908ops.RegisterGradient("PadV2")(_PadGrad)
911# ReverseSequence is just a permutation. The gradient permutes back.
912@ops.RegisterGradient("ReverseSequence")
913def _ReverseSequenceGrad(op, grad):
914 seq_lengths = op.inputs[1]
915 return [
916 array_ops.reverse_sequence(
917 grad,
918 batch_axis=op.get_attr("batch_dim"),
919 seq_axis=op.get_attr("seq_dim"),
920 seq_lengths=seq_lengths), None
921 ]
924@ops.RegisterGradient("Reverse")
925def _ReverseGrad(op, grad):
926 reverse_dims = op.inputs[1]
927 return gen_array_ops.reverse(grad, reverse_dims), None
930@ops.RegisterGradient("ReverseV2")
931def _ReverseV2Grad(op, grad):
932 axis = op.inputs[1]
933 return array_ops.reverse_v2(grad, axis), None
936@ops.RegisterGradient("SpaceToBatch")
937def _SpaceToBatchGrad(op, grad):
938 # Its gradient is the opposite op: BatchToSpace.
939 block_size = op.get_attr("block_size")
940 return [
941 array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
942 ]
945@ops.RegisterGradient("SpaceToBatchND")
946def _SpaceToBatchNDGrad(op, grad):
947 # Its gradient is the opposite op: BatchToSpaceND.
948 return [
949 array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
950 ]
953@ops.RegisterGradient("BatchToSpace")
954def _BatchToSpaceGrad(op, grad):
955 # Its gradient is the opposite op: SpaceToBatch.
956 block_size = op.get_attr("block_size")
957 return [
958 array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
959 ]
962@ops.RegisterGradient("BatchToSpaceND")
963def _BatchToSpaceNDGrad(op, grad):
964 # Its gradient is the opposite op: SpaceToBatchND.
965 return [
966 array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
967 ]
970@ops.RegisterGradient("SpaceToDepth")
971def _SpaceToDepthGrad(op, grad):
972 # Its gradient is the opposite op: DepthToSpace.
973 block_size = op.get_attr("block_size")
974 data_format = op.get_attr("data_format")
975 if data_format == "NCHW_VECT_C":
976 raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
977 "NCHW_VECT_C requires qint8 data type.")
978 return array_ops.depth_to_space(grad, block_size, data_format=data_format)
981@ops.RegisterGradient("DepthToSpace")
982def _DepthToSpaceGrad(op, grad):
983 # Its gradient is the opposite op: SpaceToDepth.
984 block_size = op.get_attr("block_size")
985 data_format = op.get_attr("data_format")
986 if data_format == "NCHW_VECT_C":
987 raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
988 "NCHW_VECT_C requires qint8 data type.")
989 return array_ops.space_to_depth(grad, block_size, data_format=data_format)
992ops.NotDifferentiable("OneHot")
995@ops.RegisterGradient("MirrorPad")
996def _MirrorPadGrad(op, grad):
997 mode = op.get_attr("mode")
998 return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]
1001@ops.RegisterGradient("MirrorPadGrad")
1002def _MirrorPadGradGrad(op, grad):
1003 mode = op.get_attr("mode")
1004 return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]
1007@ops.RegisterGradient("QuantizeAndDequantize")
1008def _QuantizeAndDequantizeGrad(_, grad):
1009 return grad
1012@ops.RegisterGradient("QuantizeAndDequantizeV2")
1013def _QuantizeAndDequantizeV2Grad(_, grad):
1014 return [grad, None, None]
1017@ops.RegisterGradient("QuantizeAndDequantizeV3")
1018def _QuantizeAndDequantizeV3Grad(_, grad):
1019 # Only propagate the gradient for the unquantized input.
1020 return [grad, None, None, None]
1023@ops.RegisterGradient("ExtractImagePatches")
1024def _ExtractImagePatchesGrad(op, grad):
1025 input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
1026 batch_size, rows_in, cols_in, channels = array_ops_stack.unstack(input_bhwc)
1028 output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
1029 rows_out, cols_out = array_ops_stack.unstack(output_bhwc[1:3])
1031 _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
1033 # Create indices matrix for input tensor.
1034 # Note that 0 is reserved for padding locations,
1035 # so input indices run from 1 to rows_in * cols_in.
1036 input_indices_num = rows_in * cols_in
1037 # The XLA version of extract_image_patches does not support int64,
1038 # so float32 is used instead.
1039 input_idx = array_ops.reshape(
1040 math_ops.range(1, input_indices_num + 1, dtype=ops.dtypes.float32),
1041 (1, rows_in, cols_in, 1),
1042 )
1043 input_idx_patched = gen_array_ops.extract_image_patches(
1044 input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
1045 op.get_attr("rates"), op.get_attr("padding"))
1046 input_idx_patched = math_ops.cast(input_idx_patched, dtypes.int64)
1048 grad_expanded = array_ops.transpose(
1049 array_ops.reshape(
1050 _IndexedSlicesToTensorNoWarning(grad),
1051 (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
1052 (1, 2, 3, 4, 0, 5))
1053 grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
1055 # Shift all input indices back. Padding locations will have "-1" value
1056 # which is fortunately ignored by segmented sum.
1057 segment_ids = array_ops.reshape(input_idx_patched, [-1]) - 1
1058 grad_out = math_ops.unsorted_segment_sum(
1059 grad_flat, segment_ids, num_segments=input_indices_num
1060 )
1062 grad_out = array_ops.reshape(
1063 grad_out, (rows_in, cols_in, batch_size, channels)
1064 )
1065 grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))
1067 return [grad_out]
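# Illustrative sketch (not part of the original module): the "-1" trick above
# relies on unsorted_segment_sum dropping entries whose segment id is
# negative, so gradient slots that came from padding are silently discarded.
# Values below are hypothetical.
def _example_negative_segment_id_sketch():
  import tensorflow as tf  # local import, illustration only
  data = tf.constant([1., 2., 3.])
  segment_ids = tf.constant([-1, 0, 0])   # the -1 entry is ignored
  out = tf.math.unsorted_segment_sum(data, segment_ids, num_segments=1)
  # out == [5.]: only the entries with segment id 0 are accumulated.
  return out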
1070@ops.RegisterGradient("ExtractVolumePatches")
1071def _ExtractVolumePatchesGrad(op, grad):
1072 batch_size, planes_in, rows_in, cols_in, channels = [
1073 dim.value for dim in op.inputs[0].shape.dims
1074 ]
1075 input_bphwc = array_ops.shape(op.inputs[0])
1076 batch_size = input_bphwc[0]
1077 channels = input_bphwc[4]
1079 # Create indices matrix for input tensor.
1080 # Note that 0 is reserved for padding locations,
1081 # so input indices run from 1 to planes_in * rows_in * cols_in.
1082 input_indices_num = 1 + planes_in * rows_in * cols_in
1083 input_idx = array_ops.reshape(
1084 math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
1085 (1, planes_in, rows_in, cols_in, 1))
1086 input_idx_patched = gen_array_ops.extract_volume_patches(
1087 input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
1088 op.get_attr("padding"))
1090 # Create indices matrix for output tensor.
1091 _, planes_out, rows_out, cols_out, _ = [
1092 dim.value for dim in op.outputs[0].shape.dims
1093 ]
1094 _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
1095 # Indices for output start from 0.
1096 prc_indices_num = planes_out * rows_out * cols_out
1097 output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
1098 output_idx = array_ops.reshape(
1099 math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
1100 (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))
1102 # Construct mapping table for indices: (input -> output).
1103 idx_matrix = array_ops.concat([
1104 array_ops.expand_dims(input_idx_patched, axis=-1),
1105 array_ops.expand_dims(output_idx, axis=-1)
1106 ],
1107 axis=-1)
1108 idx_map = array_ops.reshape(idx_matrix, (-1, 2))
1110 sp_shape = (input_indices_num, output_indices_num)
1111 sp_mat_full = sparse_tensor.SparseTensor(
1112 idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
1113 # Remove all padding locations [0, :].
1114 sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
1115 (input_indices_num - 1, output_indices_num))
1117 grad_expanded = array_ops.transpose(
1118 array_ops.reshape(
1119 _IndexedSlicesToTensorNoWarning(grad),
1120 (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
1121 ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
1122 grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
1124 jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
1126 grad_out = array_ops.reshape(
1127 jac, (planes_in, rows_in, cols_in, batch_size, channels))
1128 grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))
1130 return [grad_out]
1133@ops.RegisterGradient("ScatterNd")
1134def _ScatterNdGrad(op, grad):
1135 indices = op.inputs[0]
1136 updates_grad = array_ops.gather_nd(grad, indices)
1137 return [None, updates_grad, None]
1140@ops.RegisterGradient("TensorScatterUpdate")
1141def _TensorScatterUpdateGrad(op, grad):
1142 indices = op.inputs[1]
1143 updates_grad = array_ops.gather_nd(grad, indices)
1144 tensor_grad = array_ops.tensor_scatter_update(
1145 array_ops.identity(grad), indices,
1146 array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
1147 return [tensor_grad, None, updates_grad]
1150@ops.RegisterGradient("TensorScatterAdd")
1151def _TensorScatterAddGrad(op, grad):
1152 indices = op.inputs[1]
1153 updates_grad = array_ops.gather_nd(grad, indices)
1154 tensor_grad = array_ops.identity(grad)
1155 return [tensor_grad, None, updates_grad]
1158def _TensorScatterMinOrMaxGrad(op, grad):
1159 """Gradient for TensorScatterMin and TensorScatterMax."""
1160 indices = op.inputs[1]
1161 x = op.inputs[0]
1162 y = op.inputs[2]
1163 output = op.outputs[0]
1164 x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
1165 y_output = array_ops.gather_nd(output, indices)
1166 y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
1167 ys_indicators = array_ops.scatter_nd(
1168 indices, y_indicators, array_ops.shape(x, out_type=indices.dtype))
1169 indicators = x_indicators + ys_indicators # All elements are >= 1.
1170 # If there are multiple minimum or maximum elements then the gradient will be
1171 # divided between them.
1172 x_grad = grad * x_indicators / indicators
1173 y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
1174 return [x_grad, None, y_grad]
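# Illustrative sketch (not part of the original module): when the original
# tensor value and the scattered update tie for the maximum, the gradient is
# split evenly between them, as described above. The helper name and values
# below are hypothetical.
def _example_tensor_scatter_max_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 5.])
  updates = tf.constant([1.])            # ties with x[0]
  with tf.GradientTape() as tape:
    tape.watch([x, updates])
    y = tf.tensor_scatter_nd_max(x, [[0]], updates)
    loss = tf.reduce_sum(y)
  dx, dupdates = tape.gradient(loss, [x, updates])
  # dx == [0.5, 1.0] and dupdates == [0.5]: position 0 is a tie, so its unit
  # of gradient is shared; position 1 is untouched and keeps a gradient of 1.
  return dx, dupdates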
1177@ops.RegisterGradient("TensorScatterMax")
1178def _TensorScatterMaxGrad(op, grad):
1179 """Gradient for TensorScatterMax op."""
1180 return _TensorScatterMinOrMaxGrad(op, grad)
1183@ops.RegisterGradient("TensorScatterMin")
1184def _TensorScatterMinGrad(op, grad):
1185 """Gradient for TensorScatterMin op."""
1186 return _TensorScatterMinOrMaxGrad(op, grad)
1189@ops.RegisterGradient("TensorScatterSub")
1190def _TensorScatterSubGrad(op, grad):
1191 indices = op.inputs[1]
1192 updates_grad = array_ops.gather_nd(grad, indices)
1193 tensor_grad = array_ops.identity(grad)
1194 return [tensor_grad, None, -updates_grad]
1197@ops.RegisterGradient("ScatterNdNonAliasingAdd")
1198def _ScatterNdNonAliasingAddGrad(op, grad):
1199 indices = op.inputs[1]
1200 updates_grad = array_ops.gather_nd(grad, indices)
1201 return [grad, None, updates_grad]
1204@ops.RegisterGradient("BroadcastTo")
1205def _BroadcastToGrad(op, grad):
1206 input_value = op.inputs[0]
1207 broadcast_shape = op.inputs[1]
1208 shape_dtype = dtypes.int32
1209 if isinstance(broadcast_shape, ops.Tensor):
1210 shape_dtype = broadcast_shape.dtype
1212 input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
1213 if not isinstance(broadcast_shape, ops.EagerTensor):
1214 broadcast_shape_static = tensor_shape.TensorShape(
1215 tensor_util.try_evaluate_constant(broadcast_shape))
1216 if broadcast_shape_static.is_fully_defined():
1217 broadcast_shape = constant_op.constant(
1218 broadcast_shape_static.as_list(), dtype=shape_dtype)
1219 _, reduction_axes = gen_array_ops.broadcast_gradient_args(
1220 broadcast_shape, input_value_shape)
1221 updates_grad_reshaped = math_ops.reduce_sum(
1222 grad, axis=reduction_axes, keepdims=True)
1223 updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
1224 return [updates_grad, None]
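# Illustrative sketch (not part of the original module): the BroadcastTo
# gradient above sum-reduces the upstream gradient over the broadcast axes
# (found with broadcast_gradient_args) and reshapes it back to the input
# shape. The helper name and values below are hypothetical.
def _example_broadcast_to_grad_sketch():
  import tensorflow as tf  # local import, illustration only
  x = tf.constant([1., 2., 3.])
  upstream = tf.constant([[1., 2., 3.], [4., 5., 6.]])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.broadcast_to(x, [2, 3])
    loss = tf.reduce_sum(y * upstream)
  dx = tape.gradient(loss, x)
  # dx == [5., 7., 9.]: the upstream gradient summed over the broadcast axis 0.
  return dx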