Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/math_grad.py: 24%
1208 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for operators defined in math_ops.py."""
16import numpy as np
18from tensorflow.python.compat import compat
19from tensorflow.python.eager import context
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gen_array_ops
26from tensorflow.python.ops import gen_math_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import special_math_ops
31def _safe_shape_div(x, y):
32 """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
33 return x // math_ops.maximum(y, 1)
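# Illustrative note (added commentary, not part of TensorFlow): with NumPy
# semantics the same "safe" division is
#   np.array([6, 0]) // np.maximum(np.array([3, 0]), 1)  ->  array([2, 0])
# i.e. the 0 / 0 entry becomes 0 instead of raising or producing NaN.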
36@ops.RegisterGradient("ArgMax")
37def _ArgMaxGrad(op, grad):
38 del op, grad
39 return [None, None]
42@ops.RegisterGradient("ArgMin")
43def _ArgMinGrad(op, grad):
44 del op, grad
45 return [None, None]
48@ops.RegisterGradient("EuclideanNorm")
49def _EuclideanNormGrad(op, grad):
50 """Gradient for EuclideanNorm."""
52 output = op.outputs[0]
54 if not op.get_attr("keep_dims"):
55 output_shape_kept_dims = math_ops.reduced_shape(
56 array_ops.shape(op.inputs[0]), op.inputs[1])
57 output = array_ops.reshape(output, output_shape_kept_dims)
58 grad = array_ops.reshape(grad, output_shape_kept_dims)
60 return math_ops.truediv(op.inputs[0], output / grad), None
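# Illustrative note (added commentary, not part of TensorFlow): for
# y = ||x||_2 the derivative is dy/dx_i = x_i / y, so the expression above is
# grad * x / y written as x / (y / grad). E.g. x = [3., 4.] gives y = 5. and
# dy/dx = [0.6, 0.8].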
63def SmartBroadcastGradientArgs(x, y, grad):
64 """Optimized version of `broadcast_gradient_args` that caches results.
66 This implementation avoids creating `broadcast_gradient_args` ops in the case
67 that the input shapes are fully defined, and provides hints to the calling
68 code that can be used to avoid creating reduction and reshaping ops.
70 Args:
71 x: The left input tensor to a broadcasting binary op.
72 y: The right input tensor to a broadcasting binary op.
73 grad: The incoming gradient tensor for a broadcasting binary op.
75 Returns:
76 A pair of tuples, containing:
77 * A 3-tuple of broadcast information for x, containing:
78 * The shape of x (as a tuple or Tensor).
79 * The reduction indices for x (as a tuple or Tensor).
80 * A boolean, which if True, indicates that x's shape differs from grad's
81 shape (and so x's gradient must be reduced and/or reshaped).
82 * A 3-tuple of broadcast information for y, containing the respective
83 details for y.
84 """
85 # NOTE: It may be productive to apply these optimizations in the eager case
86 # as well.
87 if context.executing_eagerly() or not (
88 isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
89 and isinstance(grad, ops.Tensor)):
90 sx = array_ops.shape(x)
91 sy = array_ops.shape(y)
92 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
93 return (sx, rx, True), (sy, ry, True)
95 # pylint: disable=protected-access
96 x_shape_tuple = x._shape_tuple()
97 y_shape_tuple = y._shape_tuple()
98 grad_shape_tuple = grad._shape_tuple()
99 # pylint: enable=protected-access
101 if (x_shape_tuple is None or None in x_shape_tuple or
102 y_shape_tuple is None or None in y_shape_tuple):
103 sx = array_ops.shape_internal(x, optimize=False)
104 sy = array_ops.shape_internal(y, optimize=False)
105 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
106 return (sx, rx, True), (sy, ry, True)
108 x_needs_reduction = x_shape_tuple != grad_shape_tuple
109 y_needs_reduction = y_shape_tuple != grad_shape_tuple
111 # Get the default graph rather than relying on `x.graph`, `y.graph`, or
112 # `grad.graph`, because these may be eager tensors.
113 g = ops.get_default_graph()
115 try:
116 rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] # pylint: disable=protected-access
117 return (x_shape_tuple, rx, x_needs_reduction), (
118 y_shape_tuple, ry, y_needs_reduction)
119 except KeyError:
120 rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
121 # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
122 # `TF_TryEvaluateConstant()`.
123 rx_value = tuple(tensor_util.try_evaluate_constant(rx))
124 assert rx_value is not None
125 ry_value = tuple(tensor_util.try_evaluate_constant(ry))
126 assert ry_value is not None
127 g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = ( # pylint: disable=protected-access
128 rx_value, ry_value)
130 return (x_shape_tuple, rx_value, x_needs_reduction), (
131 y_shape_tuple, ry_value, y_needs_reduction)
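# Illustrative note (added commentary, not part of TensorFlow): for x of shape
# (2, 3) and y of shape (3,), the broadcast result has shape (2, 3); the
# reduction indices are () for x (nothing to reduce) and (0,) for y (its
# gradient is summed over the broadcast leading dimension before being
# reshaped to y's shape).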
134_empty_tuple = ()
137def _IsScalar(x):
138 return x._shape_tuple() is _empty_tuple # pylint: disable=protected-access
141@ops.RegisterGradient("Sum")
142def _SumGrad(op, grad):
143 """Gradient for Sum."""
144 # Fast path for when reducing to a scalar and ndims is known: adds only
145 # Reshape and Tile ops (and possibly a Shape).
146 input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
147 if input_0_shape is not None:
148 axes = tensor_util.constant_value(op.inputs[1])
149 if axes is not None:
150 rank = len(input_0_shape)
151 if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
152 if context.executing_eagerly():
153 ctx = context.context()
154 new_shape = ctx.ones_rank_cache().get(rank)
155 if new_shape is None:
156 new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
157 ctx.ones_rank_cache().put(rank, new_shape)
158 else:
159 new_shape = [1] * rank
160 grad = array_ops.reshape(grad, new_shape)
161 # If shape is not fully defined (but rank is), we use Shape.
162 if None not in input_0_shape:
163 input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
164 else:
165 input_shape = array_ops.shape(op.inputs[0])
166 return [array_ops.tile(grad, input_shape), None]
167 elif None not in input_0_shape and not context.executing_eagerly():
168 # The shape and reduction indices are statically known, so we use a
169 # graph-level cache to avoid recomputing `reduced_shape()` for each
170 # invocation.
171 graph = ops.get_default_graph()
173 # Canonicalize `axes` to be a tuple of indices. The incoming
174 # value may be a scalar or a vector, and may include negative indices.
175 axes = tuple(axes.reshape(-1))
177 try:
178 output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[ # pylint: disable=protected-access
179 (input_0_shape, axes)]
180 except KeyError:
182 # Compute and cache `output_shape_kept_dims` and `tile_scaling`.
183 def EvaluateAsTuple(t):
184 if tensor_util.is_tf_type(t):
185 value = tensor_util.try_evaluate_constant(t)
186 assert value is not None
187 else:
188 value = t
189 return tuple(value)
191 output_shape_kept_dims = EvaluateAsTuple(
192 math_ops.reduced_shape(input_0_shape, axes))
193 tile_scaling = EvaluateAsTuple(
194 _safe_shape_div(input_0_shape, output_shape_kept_dims))
195 graph._reduced_shape_cache[(input_0_shape, axes)] = ( # pylint:disable=protected-access
196 output_shape_kept_dims, tile_scaling)
198 grad = array_ops.reshape(grad, output_shape_kept_dims)
199 return [array_ops.tile(grad, tile_scaling), None]
201 input_shape = array_ops.shape(op.inputs[0])
203 if not op.get_attr("keep_dims"):
204 with ops.colocate_with(input_shape):
205 # TODO(apassos) remove this once device placement for eager ops makes
206 # more sense.
207 output_shape_kept_dims = math_ops.reduced_shape(input_shape,
208 op.inputs[1])
209 grad = array_ops.reshape(grad, output_shape_kept_dims)
210 return [array_ops.broadcast_to(grad, input_shape), None]
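# Illustrative NumPy sketch (added commentary, not part of TensorFlow): the
# gradient of a sum reduction is the incoming gradient broadcast back to the
# input shape, which is what the reshape + tile / broadcast_to paths above
# implement. The helper name is hypothetical and exists only for illustration.
def _example_sum_grad_by_broadcast():
  x = np.arange(6.).reshape(2, 3)
  g = np.array([1., 10.])  # upstream gradient of reduce_sum(x, axis=1)
  return np.broadcast_to(g[:, None], x.shape)  # [[1., 1., 1.], [10., 10., 10.]]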
213def _MinOrMaxGrad(op, grad):
214 """Gradient for Min or Max. Amazingly it's precisely the same code."""
215 input_shape = array_ops.shape(op.inputs[0])
216 y = op.outputs[0]
217 if not op.get_attr("keep_dims"):
218 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
219 y = array_ops.reshape(y, output_shape_kept_dims)
220 grad = array_ops.reshape(grad, output_shape_kept_dims)
221 else:
222 output_shape_kept_dims = array_ops.shape(y)
224 # Compute the number of selected (maximum or minimum) elements in each
225 # reduction dimension. If there are multiple minimum or maximum elements
226 # then the gradient will be divided between them.
227 indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
228 num_selected = array_ops.reshape(
229 math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)
231 return [math_ops.divide(indicators, num_selected) * grad, None]
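# Illustrative NumPy sketch (added commentary, not part of TensorFlow): when
# the reduced values tie for the maximum, the incoming gradient is split
# evenly among them, as above. The helper name is hypothetical.
def _example_max_grad_with_ties():
  x = np.array([1., 3., 3.])
  upstream = 1.0
  indicators = (x == x.max()).astype(x.dtype)  # [0., 1., 1.]
  return indicators / indicators.sum() * upstream  # [0., 0.5, 0.5]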
234@ops.RegisterGradient("Max")
235def _MaxGrad(op, grad):
236 """Gradient for Max."""
237 return _MinOrMaxGrad(op, grad)
240@ops.RegisterGradient("Min")
241def _MinGrad(op, grad):
242 return _MinOrMaxGrad(op, grad)
245@ops.RegisterGradient("Mean")
246def _MeanGrad(op, grad):
247 """Gradient for Mean."""
248 sum_grad = _SumGrad(op, grad)[0]
249 input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
250 output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access
251 if (input_shape is not None and output_shape is not None and
252 None not in input_shape and None not in output_shape):
253 input_size = np.prod(input_shape)
254 output_size = np.prod(output_shape)
255 factor = input_size // max(output_size, 1)
256 factor = constant_op.constant(factor, dtype=sum_grad.dtype)
257 else:
258 input_shape = array_ops.shape(op.inputs[0])
259 output_shape = array_ops.shape(op.outputs[0])
260 factor = _safe_shape_div(
261 math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
262 return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None
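# Illustrative note (added commentary, not part of TensorFlow): the mean
# gradient is the sum gradient divided by the number of inputs folded into
# each output. For x of shape (2, 3) reduced over axis 1, factor == 6 // 2 == 3
# and every input element receives upstream_grad / 3.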
265@ops.RegisterGradient("Prod")
266def _ProdGrad(op, grad):
267 """Gradient for Prod."""
268 # The gradient can be expressed by dividing the product by each entry of the
269 # input tensor, but this approach can't deal with zeros in the input.
270 # Here, we avoid this problem by composing the output as a product of two
271 # cumprod operations.
273 input_shape = array_ops.shape(op.inputs[0])
274 # Reshape reduction indices for the case where the parameter is a scalar
275 reduction_indices = array_ops.reshape(op.inputs[1], [-1])
277 # Expand grad to full input shape
278 if not op.get_attr("keep_dims"):
279 output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
280 grad = array_ops.reshape(grad, output_shape_kept_dims)
282 grad = array_ops.broadcast_to(grad, input_shape)
284 # Pack all reduced dimensions into a single one, so we can perform the
285 # cumprod ops. If the reduction dims list is empty, it defaults to float32,
286 # so we need to cast here. We put all the shape-related ops on CPU to avoid
287 # copying back and forth, and since listdiff is CPU only.
288 with ops.device("/cpu:0"):
289 rank = array_ops.rank(op.inputs[0])
290 reduction_indices = (reduction_indices + rank) % rank
291 reduced = math_ops.cast(reduction_indices, dtypes.int32)
292 idx = math_ops.range(0, rank)
293 other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32)
294 perm = array_ops.concat([reduced, other], 0)
295 reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
296 other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
297 permuted = array_ops.transpose(op.inputs[0], perm)
298 permuted_shape = array_ops.shape(permuted)
299 reshaped = array_ops.reshape(permuted, (reduced_num, other_num))
301 # Calculate product, leaving out the current entry
302 left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
303 right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
304 # For complex inputs, the gradient is in the conjugate direction.
305 y = array_ops.reshape(
306 math_ops.conj(left) * math_ops.conj(right), permuted_shape)
308 # Invert the transpose and reshape operations.
309 # Make sure to set the statically known shape information through a reshape.
310 out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
311 return array_ops.reshape(out, input_shape), None
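# Illustrative NumPy sketch (added commentary, not part of TensorFlow): the
# exclusive-cumprod trick used above, shown on a concrete vector. left[i] is
# the product of all entries before i and right[i] the product of all entries
# after i, so left * right is the product of every entry except the i-th one,
# even when the input contains zeros. The helper name is hypothetical.
def _example_prod_grad_cumprods():
  x = np.array([2., 0., 3., 4.])
  left = np.cumprod(np.concatenate([[1.], x[:-1]]))              # exclusive cumprod
  right = np.cumprod(np.concatenate([x[1:], [1.]])[::-1])[::-1]  # exclusive, reversed
  return left * right  # [0., 24., 0., 0.]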
314@ops.RegisterGradient("SegmentSum")
315def _SegmentSumGrad(op, grad):
316 """Gradient for SegmentSum."""
317 return array_ops.gather(grad, op.inputs[1]), None
320@ops.RegisterGradient("SegmentMean")
321def _SegmentMeanGrad(op, grad):
322 """Gradient for SegmentMean."""
323 input_rank = array_ops.rank(op.inputs[0])
324 ones_shape = array_ops.concat([
325 array_ops.shape(op.inputs[1]),
326 array_ops.ones(
327 array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
328 ], 0)
329 ones = array_ops.ones(ones_shape, dtype=grad.dtype)
330 scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1]))
331 return array_ops.gather(scaled_grad, op.inputs[1]), None
334@ops.RegisterGradient("SparseSegmentSum")
335def _SparseSegmentSumGrad(op, grad):
336 """Gradient for SparseSegmentSum."""
337 dim0 = array_ops.shape(op.inputs[0])[0]
338 if compat.forward_compatible(2021, 6, 10):
339 return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
340 dim0), None, None)
341 else:
342 return (math_ops.unsorted_segment_sum(
343 array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)
346@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
347def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
348 """Gradient for SparseSegmentSumWithNumSegments."""
349 dim0 = array_ops.shape(op.inputs[0])[0]
350 if compat.forward_compatible(2021, 6, 10):
351 return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
352 dim0), None, None, None)
353 else:
354 return (math_ops.unsorted_segment_sum(
355 array_ops.gather(grad, op.inputs[2]), op.inputs[1],
356 dim0), None, None, None)
359@ops.RegisterGradient("SparseSegmentMean")
360def _SparseSegmentMeanGrad(op, grad):
361 """Gradient for SparseSegmentMean."""
362 dim0 = array_ops.shape(op.inputs[0])[0]
363 return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
364 dim0), None, None)
367@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
368def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
369 """Gradient for SparseSegmentMeanWithNumSegments."""
370 dim0 = array_ops.shape(op.inputs[0])[0]
371 return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
372 dim0), None, None, None)
375@ops.RegisterGradient("SparseSegmentSqrtN")
376def _SparseSegmentSqrtNGrad(op, grad):
377 """Gradient for SparseSegmentSqrtN."""
378 dim0 = array_ops.shape(op.inputs[0])[0]
379 return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
380 dim0), None, None)
383@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
384def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
385 """Gradient for SparseSegmentSqrtNWithNumSegments."""
386 dim0 = array_ops.shape(op.inputs[0])[0]
387 return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
388 dim0), None, None, None)
391def _SegmentMinOrMaxGrad(op, grad):
392 """ Gradient for SegmentMin and SegmentMax. """
393 zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype)
394 # Get the number of selected (minimum or maximum) elements in each segment.
395 gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
396 is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
397 num_selected = math_ops.segment_sum(
398 math_ops.cast(is_selected, grad.dtype), op.inputs[1])
399 # Compute the gradient for each segment. The gradient for the ith segment is
400 # divided evenly among the selected elements in that segment.
401 weighted_grads = math_ops.divide(grad, num_selected)
402 gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
403 return array_ops.where_v2(is_selected, gathered_grads, zeros), None
406@ops.RegisterGradient("SegmentMin")
407def _SegmentMinGrad(op, grad):
408 """Gradient for SegmentMin."""
409 return _SegmentMinOrMaxGrad(op, grad)
412@ops.RegisterGradient("SegmentMax")
413def _SegmentMaxGrad(op, grad):
414 """Gradient for SegmentMax."""
415 return _SegmentMinOrMaxGrad(op, grad)
418@ops.RegisterGradient("SegmentProd")
419def _SegmentProdGrad(op, grad):
420 """Gradient for SegmentProd.
422 The gradient can be expressed for each segment by dividing the segment's
423 product by each element of the segment input tensor, but this approach can't
424 deal with zeros in the input.
425 Unlike reduce_prod we can't use cumsum here as individual segments may have
426 a different number of elements. Therefore we consider three cases:
427 1) A segment input contains no zeros and we can safely divide by the input
428 tensor.
429 2) A segment contains exactly one zero. Then the gradient of each input of
430 the segment is zero except for the 0-input, there the gradient is
431 the product of the remaining segment entries.
432 3) A segment contains at least two zeros. The gradient is zero for all
433 segment inputs.
434 """
435 data = op.inputs[0]
436 segment_ids = op.inputs[1]
437 is_zero = math_ops.equal(data, 0)
438 num_zeros = gen_math_ops.segment_sum(
439 math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids)
440 # handle case 3 and set the gradient to 0 for segments with more than one
441 # 0 as input
442 grad = array_ops.where_v2(
443 math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
444 # replace all zeros with ones and compute the segment_prod
445 non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
446 non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)
447 gathered_prod = array_ops.gather(op.outputs[0], segment_ids)
448 gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids)
449 prod_divided_by_el = gathered_prod / non_zero_data
450 # Now fetch the individual results for segments containing 0 and those that
451 # don't.
452 partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
453 prod_divided_by_el)
454 gathered_grad = array_ops.gather(grad, segment_ids)
455 return gathered_grad * partial_derivative, None
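# Illustrative NumPy sketch (added commentary, not part of TensorFlow): the
# three zero-handling cases above for a single segment. The helper name is
# hypothetical and exists only for illustration.
def _example_segment_prod_partials():
  seg = np.array([2., 0., 3.])            # a single segment with one zero
  non_zero = np.where(seg == 0., 1., seg)
  prod = np.prod(seg)                     # 0.
  non_zero_prod = np.prod(non_zero)       # 6.
  # Partial derivative w.r.t. each entry: the product of the other entries.
  return np.where(seg == 0., non_zero_prod, prod / non_zero)  # [0., 6., 0.]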
458def _GatherDropNegatives(params,
459 ids,
460 zero_clipped_indices=None,
461 is_positive=None):
462 """ Helper function for unsorted segment ops.
464 Gathers params for
465 positive segment ids and gathers 0 for inputs with negative segment id.
466 Also returns the clipped indices and a boolean mask with the same shape
467 as ids where a positive id is masked as true. With this, the latter two
468 can be passed as arguments to this function to reuse them.
469 """
470 if zero_clipped_indices is None:
471 zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
472 gathered = array_ops.gather(params, zero_clipped_indices)
473 if is_positive is None:
474 is_positive = math_ops.greater_equal(ids, 0)
475 # tf.where(condition, x, y) requires condition to have the same shape as x
476 # and y.
477 is_positive_shape = array_ops.shape(is_positive)
478 broadcastable_shape = array_ops.concat(
479 [is_positive_shape,
480 array_ops.ones([array_ops.rank(gathered)
481 - array_ops.rank(is_positive)],
482 dtype=is_positive_shape.dtype)],
483 axis=0)
484 is_positive = array_ops.reshape(is_positive, broadcastable_shape)
485 is_positive = (
486 is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
487 # replace gathered params of negative indices with 0
488 zero_slice = array_ops.zeros_like(gathered)
489 return (array_ops.where_v2(is_positive, gathered,
490 zero_slice), zero_clipped_indices, is_positive)
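# Illustrative NumPy sketch (added commentary, not part of TensorFlow):
# gathering with negative ids clipped to 0 and the corresponding rows then
# zeroed out, as the helper above does. The function name is hypothetical.
def _example_gather_drop_negatives():
  params = np.array([[1., 2.], [3., 4.]])
  ids = np.array([1, -1])
  clipped = np.maximum(ids, 0)                         # [1, 0]
  gathered = params[clipped]                           # [[3., 4.], [1., 2.]]
  return np.where((ids >= 0)[:, None], gathered, 0.)   # [[3., 4.], [0., 0.]]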
493def _UnsortedSegmentMinOrMaxGrad(op, grad):
494 """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """
495 # Get the number of selected (minimum or maximum) elements in each segment.
496 gathered_outputs, zero_clipped_indices, is_positive = \
497 _GatherDropNegatives(op.outputs[0], op.inputs[1])
498 is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
499 is_selected = math_ops.logical_and(is_selected, is_positive)
500 num_selected = math_ops.unsorted_segment_sum(
501 math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2])
502 # Compute the gradient for each segment. The gradient for the ith segment is
503 # divided evenly among the selected elements in that segment.
504 weighted_grads = math_ops.divide(grad, num_selected)
505 gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
506 zero_clipped_indices, is_positive)
507 zeros = array_ops.zeros_like(gathered_grads)
508 return array_ops.where_v2(is_selected, gathered_grads, zeros), None, None
511@ops.RegisterGradient("UnsortedSegmentSum")
512def _UnsortedSegmentSumGrad(op, grad):
513 """Gradient for UnsortedSegmentSum."""
514 return _GatherDropNegatives(grad, op.inputs[1])[0], None, None
517@ops.RegisterGradient("UnsortedSegmentMax")
518def _UnsortedSegmentMaxGrad(op, grad):
519 """ Gradient for UnsortedSegmentMax. """
520 return _UnsortedSegmentMinOrMaxGrad(op, grad)
523@ops.RegisterGradient("UnsortedSegmentMin")
524def _UnsortedSegmentMinGrad(op, grad):
525 """ Gradient for UnsortedSegmentMin. """
526 return _UnsortedSegmentMinOrMaxGrad(op, grad)
529@ops.RegisterGradient("UnsortedSegmentProd")
530def _UnsortedSegmentProdGrad(op, grad):
531 """ Gradient for UnsortedSegmentProd.
533 The gradient can be expressed for each segment by dividing the segment's
534 product by each element of the segment input tensor, but this approach can't
535 deal with zeros in the input.
536 Unlike reduce_prod we can't use cumsum here as individual segments may have
537 a different number of elements. Therefore we consider three cases:
538 1) A segment input contains no zeros and we can safely divide by the input
539 tensor.
540 2) A segment contains exactly one zero. Then the gradient of each input of
541 the segment is zero except for the 0-input, there the gradient is
542 the product of the remaining segment entries.
543 3) A segment contains at least two zeros. The gradient is zero for all
544 segment inputs.
545 """
546 # Note that unsorted_segment_sum will filter out the negative indices,
547 # so we don't need to do a logical_and with is_positive here
548 is_zero = math_ops.equal(op.inputs[0], 0)
549 num_zeros = gen_math_ops.unsorted_segment_sum(
550 math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
551 # handle case 3 and set the gradient to 0 for segments with more than one
552 # 0 as input
553 grad = array_ops.where_v2(
554 math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
555 # replace all zeros with ones and compute the unsorted_segment_prod
556 non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(op.inputs[0]),
557 op.inputs[0])
558 non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
559 op.inputs[1], op.inputs[2])
560 # clip the indices for gather to be positive
561 zero_clipped_indices = math_ops.maximum(op.inputs[1],
562 array_ops.zeros_like(op.inputs[1]))
563 gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
564 gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
565 prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf.
566 # Now fetch the individual results for segments containing 0 and those that
567 # don't. is_zero will also fetch results for entries with negative index
568 # but the following gather_drop_negatives sets the corresponding entry in
569 # grad to 0 for these
570 partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
571 prod_divided_by_el)
572 gathered_grad = _GatherDropNegatives(grad, op.inputs[1],
573 zero_clipped_indices)[0]
574 return gathered_grad * partial_derivative, None, None
577@ops.RegisterGradient("Abs")
578def _AbsGrad(op, grad):
579 x = op.inputs[0]
580 return grad * math_ops.sign(x)
583@ops.RegisterGradient("Neg")
584def _NegGrad(_, grad):
585 """Returns -grad."""
586 return -grad
589@ops.RegisterGradient("Inv")
590def _InvGrad(op, grad):
591 """Returns -grad * (1 / x^2)."""
592 y = op.outputs[0] # y = 1 / x
593 return gen_math_ops.reciprocal_grad(y, grad)
596@ops.RegisterGradient("Reciprocal")
597def _ReciprocalGrad(op, grad):
598 """Returns -grad * (1 / x^2)."""
599 y = op.outputs[0] # y = 1 / x
600 return gen_math_ops.reciprocal_grad(y, grad)
603@ops.RegisterGradient("InvGrad")
604def _InvGradGrad(op, grad):
605 b = op.inputs[1]
606 # op.outputs[0]: y = -b * conj(a)^2
607 with ops.control_dependencies([grad]):
608 ca = math_ops.conj(op.inputs[0])
609 cg = math_ops.conj(grad)
610 return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)
613@ops.RegisterGradient("ReciprocalGrad")
614def _ReciprocalGradGrad(op, grad):
615 b = op.inputs[1]
616 # op.outputs[0]: y = -b * conj(a)^2
617 with ops.control_dependencies([grad]):
618 ca = math_ops.conj(op.inputs[0])
619 cg = math_ops.conj(grad)
620 return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)
623@ops.RegisterGradient("Square")
624def _SquareGrad(op, grad):
625 x = op.inputs[0]
626 # Added control dependencies to prevent 2*x from being computed too early.
627 with ops.control_dependencies([grad]):
628 x = math_ops.conj(x)
629 y = constant_op.constant(2.0, dtype=x.dtype)
630 return math_ops.multiply(grad, math_ops.multiply(x, y))
633@ops.RegisterGradient("Sqrt")
634def _SqrtGrad(op, grad):
635 y = op.outputs[0] # y = x^(1/2)
636 return gen_math_ops.sqrt_grad(y, grad)
639@ops.RegisterGradient("SqrtGrad")
640def _SqrtGradGrad(op, grad):
641 a = op.inputs[0]
642 y = op.outputs[0] # y = 0.5 * b / conj(a)
643 with ops.control_dependencies([grad]):
644 ga = grad / a
645 return -math_ops.conj(ga) * y, 0.5 * ga # pylint: disable=invalid-unary-operand-type
648@ops.RegisterGradient("Rsqrt")
649def _RsqrtGrad(op, grad):
650 """Returns -0.5 * grad * conj(y)^3."""
651 y = op.outputs[0] # y = x^(-1/2)
652 return gen_math_ops.rsqrt_grad(y, grad)
655@ops.RegisterGradient("RsqrtGrad")
656def _RsqrtGradGrad(op, grad):
657 """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
658 a = op.inputs[0] # a = x^{-1/2}
659 b = op.inputs[1] # backprop gradient for a
660 with ops.control_dependencies([grad]):
661 ca = math_ops.conj(a)
662 cg = math_ops.conj(grad)
663 grad_a = -1.5 * cg * b * math_ops.square(ca)
664 grad_b = gen_math_ops.rsqrt_grad(ca, grad)
665 return grad_a, grad_b
668@ops.RegisterGradient("Exp")
669def _ExpGrad(op, grad):
670 """Returns grad * exp(x)."""
671 y = op.outputs[0] # y = e^x
672 with ops.control_dependencies([grad]):
673 y = math_ops.conj(y)
674 return grad * y
677@ops.RegisterGradient("Expm1")
678def _Expm1Grad(op, grad):
679 """Returns grad * exp(x)."""
680 x = op.inputs[0]
681 with ops.control_dependencies([grad]):
682 x = math_ops.conj(x)
683 y = math_ops.exp(x)
684 return grad * y
687@ops.RegisterGradient("Log")
688def _LogGrad(op, grad):
689 """Returns grad * (1/x)."""
690 x = op.inputs[0]
691 with ops.control_dependencies([grad]):
692 x = math_ops.conj(x)
693 return grad * math_ops.reciprocal(x)
696@ops.RegisterGradient("Log1p")
697def _Log1pGrad(op, grad):
698 """Returns grad * (1/(1 + x))."""
699 x = op.inputs[0]
700 with ops.control_dependencies([grad]):
701 x = math_ops.conj(x)
702 return grad * math_ops.reciprocal(1 + x)
705@ops.RegisterGradient("Xlogy")
706def _XLogyGrad(op, grad):
707 """Returns gradient of xlogy(x, y) with respect to x and y."""
708 x = op.inputs[0]
709 y = op.inputs[1]
710 sx = array_ops.shape(x)
711 sy = array_ops.shape(y)
712 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
713 with ops.control_dependencies([grad]):
714 not_zero_x = math_ops.cast(
715 math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
716 partial_x = gen_math_ops.xlogy(not_zero_x, y)
717 partial_y = gen_math_ops.xdivy(x, y)
718 return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
719 array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
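# Illustrative note (added commentary, not part of TensorFlow): for
# z = xlogy(x, y) = x * log(y), dz/dx = log(y) and dz/dy = x / y; the
# xlogy / xdivy forms above reproduce these partials while keeping the
# gradient exactly 0 wherever x == 0.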
722@ops.RegisterGradient("Xlog1py")
723def _XLog1pyGrad(op, grad):
724 """Returns gradient of xlog1py(x, y) with respect to x and y."""
725 x = op.inputs[0]
726 y = op.inputs[1]
727 sx = array_ops.shape(x)
728 sy = array_ops.shape(y)
729 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
730 with ops.control_dependencies([grad]):
731 not_zero_x = math_ops.cast(
732 math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
733 partial_x = gen_math_ops.xlog1py(not_zero_x, y)
734 partial_y = gen_math_ops.xdivy(x, y + 1.)
735 return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
736 array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
739@ops.RegisterGradient("Xdivy")
740def _XDivyGrad(op, grad):
741 """Returns gradient of xdivy(x, y) with respect to x and y."""
742 x = op.inputs[0]
743 y = op.inputs[1]
744 sx = array_ops.shape(x)
745 sy = array_ops.shape(y)
746 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
747 with ops.control_dependencies([grad]):
748 not_zero_x = math_ops.cast(
749 math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
750 partial_x = gen_math_ops.xdivy(not_zero_x, y)
751 partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
752 return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
753 array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
756@ops.RegisterGradient("Sinh")
757def _SinhGrad(op, grad):
758 """Returns grad * cosh(x)."""
759 x = op.inputs[0]
760 with ops.control_dependencies([grad]):
761 x = math_ops.conj(x)
762 return grad * math_ops.cosh(x)
765@ops.RegisterGradient("Cosh")
766def _CoshGrad(op, grad):
767 """Returns grad * sinh(x)."""
768 x = op.inputs[0]
769 with ops.control_dependencies([grad]):
770 x = math_ops.conj(x)
771 return grad * math_ops.sinh(x)
774@ops.RegisterGradient("Tanh")
775def _TanhGrad(op, grad):
776 """Returns grad * (1 - tanh(x) * tanh(x))."""
777 y = op.outputs[0] # y = tanh(x)
778 with ops.control_dependencies([grad]):
779 y = math_ops.conj(y)
780 return gen_math_ops.tanh_grad(y, grad)
783@ops.RegisterGradient("Asinh")
784def _AsinhGrad(op, grad):
785 """Returns grad * 1/cosh(y)."""
786 y = op.outputs[0]
787 with ops.control_dependencies([grad]):
788 y = math_ops.conj(y)
789 return grad / math_ops.cosh(y)
792@ops.RegisterGradient("Acosh")
793def _AcoshGrad(op, grad):
794 """Returns grad * 1/sinh(y)."""
795 y = op.outputs[0]
796 with ops.control_dependencies([grad]):
797 y = math_ops.conj(y)
798 return grad / math_ops.sinh(y)
801@ops.RegisterGradient("Atanh")
802def _AtanhGrad(op, grad):
803 """Returns grad * 1/ (1 - x^2)."""
804 x = op.inputs[0]
805 with ops.control_dependencies([grad]):
806 x = math_ops.conj(x)
807 x2 = math_ops.square(x)
808 one = constant_op.constant(1, dtype=grad.dtype)
809 inv = math_ops.reciprocal(math_ops.subtract(one, x2))
810 return grad * inv
813@ops.RegisterGradient("TanhGrad")
814def _TanhGradGrad(op, grad):
815 with ops.control_dependencies([grad]):
816 a = math_ops.conj(op.inputs[0])
817 b = math_ops.conj(op.inputs[1])
818 return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad)
821@ops.RegisterGradient("Erf")
822def _ErfGrad(op, grad):
823 """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
824 x = op.inputs[0]
825 two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
826 with ops.control_dependencies([grad]):
827 x = math_ops.conj(x)
828 return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))
831@ops.RegisterGradient("Erfc")
832def _ErfcGrad(op, grad):
833 """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
834 x = op.inputs[0]
835 minus_two_over_root_pi = constant_op.constant(
836 -2 / np.sqrt(np.pi), dtype=grad.dtype)
837 with ops.control_dependencies([grad]):
838 x = math_ops.conj(x)
839 return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))
842@ops.RegisterGradient("Erfinv")
843def _ErfinvGrad(op, grad):
844 """Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2)."""
845 root_pi_over_two = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype)
846 with ops.control_dependencies([grad]):
847 return grad * root_pi_over_two * math_ops.exp(
848 math_ops.square(op.outputs[0]))
851@ops.RegisterGradient("Ndtri")
852def _NdtriGrad(op, grad):
853 """Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2)."""
854 root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
855 with ops.control_dependencies([grad]):
856 return grad * root_two_pi * math_ops.exp(
857 math_ops.square(op.outputs[0]) / 2.)
860@ops.RegisterGradient("Lgamma")
861def _LgammaGrad(op, grad):
862 """Returns grad * digamma(x)."""
863 x = op.inputs[0]
864 with ops.control_dependencies([grad]):
865 x = math_ops.conj(x)
866 return grad * math_ops.digamma(x)
869@ops.RegisterGradient("Digamma")
870def _DigammaGrad(op, grad):
871 """Compute gradient of the digamma function with respect to its argument."""
872 x = op.inputs[0]
873 with ops.control_dependencies([grad]):
874 x = math_ops.conj(x)
875 partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
876 return grad * partial_x
879@ops.RegisterGradient("Dawsn")
880def _DawsnGrad(op, grad):
881 """Compute gradient of dawsn(x) with respect to its argument."""
882 x = op.inputs[0]
883 y = op.outputs[0]
884 with ops.control_dependencies([grad]):
885 return grad * (1. - 2 * x * y)
888@ops.RegisterGradient("Expint")
889def _ExpintGrad(op, grad):
890 """Compute gradient of expint(x) with respect to its argument."""
891 x = op.inputs[0]
892 with ops.control_dependencies([grad]):
893 return grad * math_ops.exp(x) / x
896@ops.RegisterGradient("FresnelCos")
897def _FresnelCosGrad(op, grad):
898 """Compute gradient of fresnel_cos(x) with respect to its argument."""
899 x = op.inputs[0]
900 with ops.control_dependencies([grad]):
901 return grad * math_ops.cos((np.pi / 2.) * math_ops.square(x))
904@ops.RegisterGradient("FresnelSin")
905def _FresnelSinGrad(op, grad):
906 """Compute gradient of fresnel_sin(x) with respect to its argument."""
907 x = op.inputs[0]
908 with ops.control_dependencies([grad]):
909 return grad * math_ops.sin((np.pi / 2.) * math_ops.square(x))
912@ops.RegisterGradient("Spence")
913def _SpenceGrad(op, grad):
914 """Compute gradient of spence(x) with respect to its argument."""
915 x = op.inputs[0]
916 with ops.control_dependencies([grad]):
917 partial_x = math_ops.log(x) / (1 - x)
918 partial_x = array_ops.where(
919 math_ops.equal(x, 1.), -array_ops.ones_like(x), partial_x) # pylint: disable=invalid-unary-operand-type
920 return grad * partial_x
923@ops.RegisterGradient("BesselI0")
924def _BesselI0Grad(op, grad):
925 """Compute gradient of bessel_i0(x) with respect to its argument."""
926 x = op.inputs[0]
927 with ops.control_dependencies([grad]):
928 partial_x = special_math_ops.bessel_i1(x)
929 return grad * partial_x
932@ops.RegisterGradient("BesselI0e")
933def _BesselI0eGrad(op, grad):
934 """Compute gradient of bessel_i0e(x) with respect to its argument."""
935 x = op.inputs[0]
936 y = op.outputs[0]
937 with ops.control_dependencies([grad]):
938 partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
939 return grad * partial_x
942@ops.RegisterGradient("BesselI1")
943def _BesselI1Grad(op, grad):
944 """Compute gradient of bessel_i1(x) with respect to its argument."""
945 x = op.inputs[0]
946 y = op.outputs[0]
947 with ops.control_dependencies([grad]):
948 # For x = 0, the correct gradient is 0.5.
949 # However, the main branch gives NaN because of the division by x, so
950 # we impute the gradient manually.
951 # An alternative solution is to express the gradient via bessel_i0 and
952 # bessel_i2, but the latter is not yet implemented in Eigen.
953 dy_dx = array_ops.where_v2(
954 math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
955 special_math_ops.bessel_i0(x) - math_ops.div(y, x))
956 return grad * dy_dx
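# Illustrative note (added commentary, not part of TensorFlow): near zero,
# I1(x) ~ x / 2, so I1(x) / x -> 0.5 and the main branch
# I0(x) - I1(x) / x -> 1 - 0.5 = 0.5, which is the value imputed at x == 0
# above.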
959@ops.RegisterGradient("BesselI1e")
960def _BesselI1eGrad(op, grad):
961 """Compute gradient of bessel_i1e(x) with respect to its argument."""
962 x = op.inputs[0]
963 y = op.outputs[0]
964 with ops.control_dependencies([grad]):
965 # For x = 0, the correct gradient is 0.5.
966 # However, the main branch gives NaN because of the division by x, so
967 # we impute the gradient manually.
968 # An alternative solution is to express the gradient via bessel_i0e and
969 # bessel_i2e, but the latter is not yet implemented in Eigen.
970 dy_dx = array_ops.where_v2(
971 math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
972 special_math_ops.bessel_i0e(x) - y *
973 (math_ops.sign(x) + math_ops.reciprocal(x)))
974 return grad * dy_dx
977@ops.RegisterGradient("BesselK0")
978def _BesselK0Grad(op, grad):
979 """Compute gradient of bessel_k0(x) with respect to its argument."""
980 x = op.inputs[0]
981 with ops.control_dependencies([grad]):
982 partial_x = -special_math_ops.bessel_k1(x)
983 return grad * partial_x
986@ops.RegisterGradient("BesselK0e")
987def _BesselK0eGrad(op, grad):
988 """Compute gradient of bessel_k0e(x) with respect to its argument."""
989 x = op.inputs[0]
990 y = op.outputs[0]
991 with ops.control_dependencies([grad]):
992 partial_x = (y - special_math_ops.bessel_k1e(x))
993 return grad * partial_x
996@ops.RegisterGradient("BesselK1")
997def _BesselK1Grad(op, grad):
998 """Compute gradient of bessel_k1(x) with respect to its argument."""
999 x = op.inputs[0]
1000 y = op.outputs[0]
1001 with ops.control_dependencies([grad]):
1002 # At 0., this is NaN which is fine since the derivative is undefined
1003 # at 0.
1004 partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x)
1005 return grad * partial_x
1008@ops.RegisterGradient("BesselK1e")
1009def _BesselK1eGrad(op, grad):
1010 """Compute gradient of bessel_k1e(x) with respect to its argument."""
1011 x = op.inputs[0]
1012 y = op.outputs[0]
1013 with ops.control_dependencies([grad]):
1014 # At 0., this is NaN which is fine since the derivative is undefined
1015 # at 0.
1016 partial_x = (
1017 y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x))
1018 return grad * partial_x
1021@ops.RegisterGradient("BesselJ0")
1022def _BesselJ0Grad(op, grad):
1023 """Compute gradient of bessel_j0(x) with respect to its argument."""
1024 x = op.inputs[0]
1025 with ops.control_dependencies([grad]):
1026 partial_x = -special_math_ops.bessel_j1(x)
1027 return grad * partial_x
1030@ops.RegisterGradient("BesselJ1")
1031def _BesselJ1Grad(op, grad):
1032 """Compute gradient of bessel_j1(x) with respect to its argument."""
1033 x = op.inputs[0]
1034 y = op.outputs[0]
1035 with ops.control_dependencies([grad]):
1036 # For x = 0, the correct gradient is 0.5.
1037 # However, the main branch gives NaN because of the division by x, so
1038 # we impute the gradient manually.
1039 # An alternative solution is to express the gradient via bessel_j0 and
1040 # bessel_j2, but the latter is not yet implemented in Eigen.
1041 dy_dx = array_ops.where_v2(
1042 math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
1043 special_math_ops.bessel_j0(x) - math_ops.div(y, x))
1044 return grad * dy_dx
1047@ops.RegisterGradient("BesselY0")
1048def _BesselY0Grad(op, grad):
1049 """Compute gradient of bessel_y0(x) with respect to its argument."""
1050 x = op.inputs[0]
1051 with ops.control_dependencies([grad]):
1052 partial_x = -special_math_ops.bessel_y1(x)
1053 return grad * partial_x
1056@ops.RegisterGradient("BesselY1")
1057def _BesselY1Grad(op, grad):
1058 """Compute gradient of bessel_y1(x) with respect to its argument."""
1059 x = op.inputs[0]
1060 y = op.outputs[0]
1061 with ops.control_dependencies([grad]):
1062 # At 0., this is NaN which is fine since the derivative is undefined
1063 # at 0.
1064 partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x)
1065 return grad * partial_x
1068@ops.RegisterGradient("Igamma")
1069def _IgammaGrad(op, grad):
1070 """Returns gradient of igamma(a, x) with respect to a and x."""
1071 a = op.inputs[0]
1072 x = op.inputs[1]
1073 sa = array_ops.shape(a)
1074 sx = array_ops.shape(x)
1075 ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
1077 with ops.control_dependencies([grad]):
1078 partial_a = gen_math_ops.igamma_grad_a(a, x)
1079 # Perform operations in log space before summing, because Gamma(a)
1080 # and Gamma'(a) can grow large.
1081 partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
1082 math_ops.lgamma(a))
1083 return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
1084 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
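# Illustrative note (added commentary, not part of TensorFlow): d/dx of
# igamma(a, x) is x**(a - 1) * exp(-x) / Gamma(a); writing it as
# exp(-x + (a - 1) * log(x) - lgamma(a)) keeps the computation in log space so
# large values of Gamma(a) do not overflow before the final exp.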
1087@ops.RegisterGradient("Igammac")
1088def _IgammacGrad(op, grad):
1089 """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
1090 igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad)
1091 return (-igamma_grad_a, -igamma_grad_x)
1094@ops.RegisterGradient("Betainc")
1095def _BetaincGrad(op, grad):
1096 """Returns gradient of betainc(a, b, x) with respect to x."""
1097 # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
1098 a, b, x = op.inputs
1100 # two cases: x is a scalar and a/b are same-shaped tensors, or vice
1101 # versa; so it's sufficient to check against shape(a).
1102 sa = array_ops.shape(a)
1103 sx = array_ops.shape(x)
1104 _, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
1106 # Perform operations in log space before summing, because terms
1107 # can grow large.
1108 log_beta = (
1109 gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
1110 gen_math_ops.lgamma(a + b))
1111 # We use xlog1py and xlogy since the derivatives should tend to
1112 # zero in one of the tails when a is 1. or b is 1.
1113 partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) +
1114 math_ops.xlogy(a - 1, x) - log_beta)
1116 return (
1117 None, # da
1118 None, # db
1119 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
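# Illustrative note (added commentary, not part of TensorFlow): d/dx of
# betainc(a, b, x) is x**(a - 1) * (1 - x)**(b - 1) / B(a, b); the
# xlogy / xlog1py + exp form above evaluates this in log space, with
# log B(a, b) given by log_beta.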
1122@ops.RegisterGradient("Zeta")
1123def _ZetaGrad(op, grad):
1124 """Returns gradient of zeta(x, q) with respect to x and q."""
1125 # TODO(tillahoffmann): Add derivative with respect to x
1126 x = op.inputs[0]
1127 q = op.inputs[1]
1128 # Broadcast gradients
1129 sx = array_ops.shape(x)
1130 sq = array_ops.shape(q)
1131 unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq)
1132 # Evaluate gradient
1133 with ops.control_dependencies([grad]):
1134 x = math_ops.conj(x)
1135 q = math_ops.conj(q)
1136 partial_q = -x * math_ops.zeta(x + 1, q) # pylint: disable=invalid-unary-operand-type
1137 return (None,
1138 array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))
1141@ops.RegisterGradient("Polygamma")
1142def _PolygammaGrad(op, grad):
1143 """Returns gradient of psi(n, x) with respect to n and x."""
1144 # TODO(tillahoffmann): Add derivative with respect to n
1145 n = op.inputs[0]
1146 x = op.inputs[1]
1147 # Broadcast gradients
1148 sn = array_ops.shape(n)
1149 sx = array_ops.shape(x)
1150 unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx)
1151 # Evaluate gradient
1152 with ops.control_dependencies([grad]):
1153 n = math_ops.conj(n)
1154 x = math_ops.conj(x)
1155 partial_x = math_ops.polygamma(n + 1, x)
1156 return (None,
1157 array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
1160@ops.RegisterGradient("Sigmoid")
1161def _SigmoidGrad(op, grad):
1162 """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
1163 y = op.outputs[0] # y = sigmoid(x)
1164 with ops.control_dependencies([grad]):
1165 y = math_ops.conj(y)
1166 return gen_math_ops.sigmoid_grad(y, grad)
1169@ops.RegisterGradient("SigmoidGrad")
1170def _SigmoidGradGrad(op, grad):
1171 with ops.control_dependencies([grad]):
1172 a = math_ops.conj(op.inputs[0])
1173 b = math_ops.conj(op.inputs[1])
1174 gb = grad * b
1175 return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad)
1178@ops.RegisterGradient("Sign")
1179def _SignGrad(op, _):
1180 """Returns 0."""
1181 x = op.inputs[0]
1182 return array_ops.zeros_like(x)
1185@ops.RegisterGradient("Sin")
1186def _SinGrad(op, grad):
1187 """Returns grad * cos(x)."""
1188 x = op.inputs[0]
1189 with ops.control_dependencies([grad]):
1190 x = math_ops.conj(x)
1191 return grad * math_ops.cos(x)
1194@ops.RegisterGradient("Cos")
1195def _CosGrad(op, grad):
1196 """Returns grad * -sin(x)."""
1197 x = op.inputs[0]
1198 with ops.control_dependencies([grad]):
1199 x = math_ops.conj(x)
1200 return -grad * math_ops.sin(x)
1203@ops.RegisterGradient("Tan")
1204def _TanGrad(op, grad):
1205 """Returns grad * sec^2(x)."""
1206 x = op.inputs[0]
1207 with ops.control_dependencies([grad]):
1208 x = math_ops.conj(x)
1209 secx = math_ops.reciprocal(math_ops.cos(x))
1210 secx2 = math_ops.square(secx)
1211 return secx2 * grad
1214@ops.RegisterGradient("Asin")
1215def _AsinGrad(op, grad):
1216 """Returns grad * 1/sqrt(1-x^2)."""
1217 x = op.inputs[0]
1218 with ops.control_dependencies([grad]):
1219 x = math_ops.conj(x)
1220 x2 = math_ops.square(x)
1221 one = constant_op.constant(1, dtype=grad.dtype)
1222 den = math_ops.sqrt(math_ops.subtract(one, x2))
1223 inv = math_ops.reciprocal(den)
1224 return grad * inv
1227@ops.RegisterGradient("Acos")
1228def _AcosGrad(op, grad):
1229 """Returns grad * -1/sqrt(1-x^2)."""
1230 x = op.inputs[0]
1231 with ops.control_dependencies([grad]):
1232 x = math_ops.conj(x)
1233 x2 = math_ops.square(x)
1234 one = constant_op.constant(1, dtype=grad.dtype)
1235 den = math_ops.sqrt(math_ops.subtract(one, x2))
1236 inv = math_ops.reciprocal(den)
1237 return -grad * inv
1240@ops.RegisterGradient("Atan")
1241def _AtanGrad(op, grad):
1242 """Returns grad * 1/ (1 + x^2)."""
1243 x = op.inputs[0]
1244 with ops.control_dependencies([grad]):
1245 x = math_ops.conj(x)
1246 x2 = math_ops.square(x)
1247 one = constant_op.constant(1, dtype=grad.dtype)
1248 inv = math_ops.reciprocal(math_ops.add(one, x2))
1249 return grad * inv
1252@ops.RegisterGradient("Atan2")
1253def _Atan2Grad(op, grad):
1254 """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2)."""
1255 y = op.inputs[0]
1256 x = op.inputs[1]
1257 with ops.control_dependencies([grad]):
1258 grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
1259 return x * grad_inv, -y * grad_inv
1262@ops.RegisterGradient("AddN")
1263def _AddNGrad(op, grad):
1264 """Copies the gradient to all inputs."""
1265 # Not broadcasting.
1266 return [grad] * len(op.inputs)
1269def _ShapesFullySpecifiedAndEqual(x, y, grad):
1270 # pylint: disable=protected-access
1271 x_shape = x._shape_tuple()
1272 y_shape = y._shape_tuple()
1273 grad_shape = grad._shape_tuple()
1274 # pylint: enable=protected-access
1275 return (x_shape == y_shape and x_shape == grad_shape and
1276 x_shape is not None and None not in x_shape)
1279@ops.RegisterGradient("Add")
1280@ops.RegisterGradient("AddV2")
1281def _AddGrad(op, grad):
1282 """Gradient for Add."""
1283 y = op.inputs[1]
1284 skip_input_indices = None
1285 try:
1286 skip_input_indices = op.skip_input_indices
1287 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1288 y):
1289 return grad, None
1290 except AttributeError:
1291 # No gradient skipping, so do the full gradient computation
1292 pass
1293 x = op.inputs[0]
1294 if (isinstance(grad, ops.Tensor) and
1295 _ShapesFullySpecifiedAndEqual(x, y, grad)):
1296 return grad, grad
1297 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1298 SmartBroadcastGradientArgs(x, y, grad))
1299 if skip_input_indices is not None and 0 in skip_input_indices:
1300 gx = None
1301 elif not must_reduce_x:
1302 gx = grad
1303 else:
1304 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
1305 if skip_input_indices is not None and 1 in skip_input_indices:
1306 gy = None
1307 elif not must_reduce_y:
1308 gy = grad
1309 else:
1310 gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
1311 return (gx, gy)
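# Illustrative NumPy sketch (added commentary, not part of TensorFlow): when
# one addend was broadcast, its gradient must be summed over the broadcast
# dimensions before being reshaped, as above. The helper name is hypothetical.
def _example_add_grad_broadcast_reduction():
  grad = np.ones((2, 3))   # upstream gradient of x + y, x: (2, 3), y: (3,)
  gy = grad.sum(axis=0)    # sum over the broadcast leading dimension
  return gy                # [2., 2., 2.]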
1314@ops.RegisterGradient("Sub")
1315def _SubGrad(op, grad):
1316 """Gradient for Sub."""
1317 y = op.inputs[1]
1318 skip_input_indices = None
1319 try:
1320 skip_input_indices = op.skip_input_indices
1321 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1322 y):
1323 return grad, None
1324 except AttributeError:
1325 # No gradient skipping, so do the full gradient computation
1326 pass
1327 x = op.inputs[0]
1328 if (isinstance(grad, ops.Tensor) and
1329 _ShapesFullySpecifiedAndEqual(x, y, grad)):
1330 return grad, -grad
1331 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1332 SmartBroadcastGradientArgs(x, y, grad))
1333 if skip_input_indices is not None and 0 in skip_input_indices:
1334 gx = None
1335 elif not must_reduce_x:
1336 gx = grad
1337 else:
1338 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
1339 if skip_input_indices is not None and 1 in skip_input_indices:
1340 gy = None
1341 elif not must_reduce_y:
1342 gy = -grad
1343 else:
1344 gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
1345 return (gx, gy)
1348@ops.RegisterGradient("Mul")
1349def _MulGrad(op, grad):
1350 """The gradient of scalar multiplication."""
1351 y = op.inputs[1]
1352 skip_input_indices = None
1353 try:
1354 skip_input_indices = op.skip_input_indices
1355 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1356 y):
1357 return gen_math_ops.mul(grad, math_ops.conj(y)), None
1358 except AttributeError:
1359 # No gradient skipping, so do the full gradient computation
1360 pass
1361 x = op.inputs[0]
1362 if (isinstance(grad, ops.Tensor) and
1363 _ShapesFullySpecifiedAndEqual(x, y, grad) and
1364 grad.dtype in (dtypes.int32, dtypes.float32)):
1365 return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
1366 assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
1368 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1369 SmartBroadcastGradientArgs(x, y, grad))
1370 x = math_ops.conj(x)
1371 y = math_ops.conj(y)
1372 if skip_input_indices is not None and 0 in skip_input_indices:
1373 gx = None
1374 elif not must_reduce_x:
1375 gx = gen_math_ops.mul(grad, y)
1376 else:
1377 gx = array_ops.reshape(
1378 math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
1379 if skip_input_indices is not None and 1 in skip_input_indices:
1380 gy = None
1381 elif not must_reduce_y:
1382 gy = gen_math_ops.mul(x, grad)
1383 else:
1384 gy = array_ops.reshape(
1385 math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
1386 return (gx, gy)
1389@ops.RegisterGradient("MulNoNan")
1390def _MulNoNanGrad(op, grad):
1391 """The gradient of scalar multiplication with NaN-suppression."""
1392 x = op.inputs[0]
1393 y = op.inputs[1]
1394 if (isinstance(grad, ops.Tensor) and
1395 _ShapesFullySpecifiedAndEqual(x, y, grad)):
1396 return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
1397 assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
1398 sx = array_ops.shape(x)
1399 sy = array_ops.shape(y)
1400 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1401 return (array_ops.reshape(
1402 math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx),
1403 array_ops.reshape(
1404 math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy))
1407@ops.RegisterGradient("Div")
1408def _DivGrad(op, grad):
1409 """The gradient for the Div operator."""
1410 x = op.inputs[0]
1411 y = op.inputs[1]
1412 sx = array_ops.shape(x)
1413 sy = array_ops.shape(y)
1414 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1415 x = math_ops.conj(x)
1416 y = math_ops.conj(y)
1417 # pylint: disable=invalid-unary-operand-type
1418 return (
1419 array_ops.reshape(math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx),
1420 array_ops.reshape(
1421 math_ops.reduce_sum(grad * math_ops.divide(math_ops.divide(-x, y), y),
1422 ry), sy))
1425@ops.RegisterGradient("FloorDiv")
1426def _FloorDivGrad(_, unused_grad):
1427 """The gradient for the FloorDiv operator."""
1428 return None, None
1431@ops.RegisterGradient("FloorMod")
1432def _FloorModGrad(op, grad):
1433 """Returns grad * (1, -floor(x/y))."""
1434 x = math_ops.conj(op.inputs[0])
1435 y = math_ops.conj(op.inputs[1])
1437 sx = array_ops.shape(x)
1438 sy = array_ops.shape(y)
1439 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1440 floor_xy = math_ops.floor_div(x, y)
1441 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
1442 gy = array_ops.reshape(
1443 math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy)
1444 return gx, gy
1447@ops.RegisterGradient("TruncateDiv")
1448def _TruncateDivGrad(_, unused_grad):
1449 return None, None
1452@ops.RegisterGradient("RealDiv")
1453def _RealDivGrad(op, grad):
1454 """RealDiv op gradient."""
1455 x = op.inputs[0]
1456 y = op.inputs[1]
1457 sx = array_ops.shape(x)
1458 sy = array_ops.shape(y)
1459 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1460 x = math_ops.conj(x)
1461 y = math_ops.conj(y)
1462 return (array_ops.reshape(
1463 math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
1464 array_ops.reshape(
1465 math_ops.reduce_sum(
1466 grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) # pylint: disable=invalid-unary-operand-type
1469@ops.RegisterGradient("DivNoNan")
1470def _DivNoNanGrad(op, grad):
1471 """DivNoNan op gradient."""
1472 x = op.inputs[0]
1473 y = op.inputs[1]
1474 sx = array_ops.shape(x)
1475 sy = array_ops.shape(y)
1476 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1477 x = math_ops.conj(x)
1478 y = math_ops.conj(y)
1479 return (
1480 array_ops.reshape(
1481 math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
1482 array_ops.reshape(
1483 math_ops.reduce_sum(
1484 grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), # pylint: disable=invalid-unary-operand-type
1485 ry),
1486 sy))
1489@ops.RegisterGradient("Pow")
1490def _PowGrad(op, grad):
1491 """Returns grad * (y*x^(y-1), z*log(x))."""
1492 x = op.inputs[0]
1493 y = op.inputs[1]
1494 skip_input_indices = None
1495 try:
1496 skip_input_indices = op.skip_input_indices
1497 # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
1498 # constant `1` into a single constant op.
1499 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1500 y):
1501 x = math_ops.conj(x)
1502 y = math_ops.conj(y)
1503 return grad * y * math_ops.pow(x, y - 1), None
1505 except AttributeError:
1506 # No gradient skipping, so do the full gradient computation
1507 pass
1509 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1510 SmartBroadcastGradientArgs(x, y, grad))
1511 x = math_ops.conj(x)
1512 y = math_ops.conj(y)
1514 if skip_input_indices is None or 0 not in skip_input_indices:
1515 gx = grad * y * math_ops.pow(x, y - 1)
1516 if must_reduce_x:
1517 gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
1518 else:
1519 gx = None
1521 if skip_input_indices is None or 1 not in skip_input_indices:
1522 z = math_ops.conj(op.outputs[0])
1524 # Avoid false singularity at x = 0
1525 if x.dtype.is_complex:
1526 # real(x) < 0 is fine for the complex case
1527 mask = math_ops.not_equal(x, 0)
1528 else:
1529 # There's no sensible real value to return if x < 0, so return 0
1530 mask = x > 0
1531 safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
1532 log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
1533 gy = grad * z * log_x
1534 if must_reduce_y:
1535 gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
1536 else:
1537 gy = None
1539 return gx, gy
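# Illustrative note (added commentary, not part of TensorFlow): for z = x**y
# the partials are dz/dx = y * x**(y - 1) and dz/dy = z * log(x); the mask
# above substitutes zero for log(x) where it is undefined in the real case
# (x <= 0), matching the docstring formula grad * (y*x^(y-1), z*log(x)).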
1542def _MaximumMinimumGradInputOnly(op, grad, selector_op):
1543 x = op.inputs[0]
1544 y = op.inputs[1]
1545 zeros = array_ops.zeros_like(grad)
1546 xmask = selector_op(x, y)
1547 xgrad = array_ops.where_v2(xmask, grad, zeros)
1548 ygrad = None # Return None for ygrad since the config allows that.
1549 return (xgrad, ygrad)
1552def _MaximumMinimumGrad(op, grad, selector_op):
1553 """Factor out the code for the gradient of Maximum or Minimum."""
1554 y = op.inputs[1]
1555 skip_input_indices = None
1556 try:
1557 skip_input_indices = op.skip_input_indices
1558 if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1559 y):
1560 # When we want to get gradients for the first input only, and the second
1561 # input tensor is a scalar, we can do a much simpler calculation
1562 return _MaximumMinimumGradInputOnly(op, grad, selector_op)
1563 except AttributeError:
1564 # No gradient skipping, so do the full gradient computation
1565 pass
1566 x = op.inputs[0]
1567 sx = array_ops.shape(x)
1568 sy = array_ops.shape(y)
1569 zeros = array_ops.zeros_like(grad)
1570 xmask = selector_op(x, y)
1571 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1572 if skip_input_indices is not None and 0 in skip_input_indices:
1573 gx = None
1574 else:
1575 xgrad = array_ops.where_v2(xmask, grad, zeros)
1576 gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
1578 if skip_input_indices is not None and 1 in skip_input_indices:
1579 gy = None
1580 else:
1581 ygrad = array_ops.where_v2(xmask, zeros, grad)
1582 gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
1584 return (gx, gy)
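
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# The selector routes the incoming gradient to whichever input "won" the
# element-wise comparison; for Maximum, ties go to x because the selector is
# greater_equal. Helper name and values below are invented for illustration.
def _example_maximum_grad_check():  # hypothetical helper
  import tensorflow as tf
  x = tf.constant([1.0, 5.0, 2.0])
  y = tf.constant([3.0, 4.0, 2.0])
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, y])
    z = tf.maximum(x, y)
  # Expected: [0., 1., 1.] for x and [1., 0., 0.] for y.
  return tape.gradient(z, x), tape.gradient(z, y)
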
1587@ops.RegisterGradient("Maximum")
1588def _MaximumGrad(op, grad):
1589 """Returns grad*(x >= y, x < y) with type of grad."""
1590 return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
1593@ops.RegisterGradient("Minimum")
1594def _MinimumGrad(op, grad):
1595 """Returns grad*(x <= y, x > y) with type of grad."""
1596 return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
1599@ops.RegisterGradient("SquaredDifference")
1600def _SquaredDifferenceGrad(op, grad):
1601 """Returns the gradient for (x-y)^2."""
1602 x = op.inputs[0]
1603 y = op.inputs[1]
1604 skip_input_indices = None
1605 try:
1606 skip_input_indices = op.skip_input_indices
1607 except AttributeError:
1608 # No gradient skipping, so do the full gradient computation
1609 pass
1611 with ops.control_dependencies([grad]):
1612 # The parens ensure that if grad is IndexedSlices, it gets multiplied by a
1613 # Tensor (not a Python number like 2.0), which converts it to a Tensor.
1614 x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
1616 if (isinstance(grad, ops.Tensor) and
1617 _ShapesFullySpecifiedAndEqual(x, y, grad)):
1618 return x_grad, -x_grad
1620 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1621 SmartBroadcastGradientArgs(x, y, grad))
1623 if skip_input_indices is not None and 0 in skip_input_indices:
1624 gx = None
1625 elif must_reduce_x:
1626 gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
1627 else:
1628 gx = x_grad
1630 if skip_input_indices is not None and 1 in skip_input_indices:
1631 gy = None
1632 elif must_reduce_y:
1633 gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
1634 else:
1635 gy = -x_grad
1636 return (gx, gy)
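
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# For (x - y)**2 the two gradients are 2*(x - y) and -2*(x - y); when one
# input is a broadcast scalar, its gradient is additionally reduce-summed back
# to a scalar, as in the reduction branches above. Names and values invented.
def _example_squared_difference_grad_check():  # hypothetical helper
  import tensorflow as tf
  x = tf.constant([1.0, 4.0])
  y = tf.constant(2.0)  # broadcast scalar
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, y])
    z = tf.math.squared_difference(x, y)
  gx = tape.gradient(z, x)  # 2*(x - y)            -> [-2., 4.]
  gy = tape.gradient(z, y)  # sum over -2*(x - y)  -> -2.
  return gx, gy
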
1639# Logical operations have no gradients.
1640ops.NotDifferentiable("Less")
1641ops.NotDifferentiable("LessEqual")
1642ops.NotDifferentiable("Greater")
1643ops.NotDifferentiable("GreaterEqual")
1644ops.NotDifferentiable("Equal")
1645ops.NotDifferentiable("ApproximateEqual")
1646ops.NotDifferentiable("NotEqual")
1647ops.NotDifferentiable("LogicalAnd")
1648ops.NotDifferentiable("LogicalOr")
1649ops.NotDifferentiable("LogicalNot")
1652@ops.RegisterGradient("Select")
1653def _SelectGrad(op, grad):
1654 c = op.inputs[0]
1655 x = op.inputs[1]
1656 zeros = array_ops.zeros_like(x)
1657 return (None, array_ops.where(c, grad, zeros), array_ops.where(
1658 c, zeros, grad))
1661@ops.RegisterGradient("SelectV2")
1662def _SelectGradV2(op, grad):
1663 c = op.inputs[0]
1664 x = op.inputs[1]
1665 y = op.inputs[2]
1666 zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)
1667 gx = array_ops.where_v2(c, grad, zeros)
1668 x_shape = array_ops.shape(x)
1669 output_shape = array_ops.shape(op.outputs[0])
1670 # Reduce away broadcasted leading dims.
1671 reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape)
1672 gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x)
1673 gx = array_ops.reshape(gx, x_shape)
1675 gy = array_ops.where_v2(c, zeros, grad)
1676 y_shape = array_ops.shape(y)
1677 # Reduce away broadcasted leading dims.
1678 reduce_y, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape)
1679 gy = math_ops.reduce_sum(gy, keepdims=True, axis=reduce_y)
1680 gy = array_ops.reshape(gy, y_shape)
1682 return (None, gx, gy)
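
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# tf.where routes the gradient to x where the condition is True and to y where
# it is False; a broadcast branch gets its gradient reduce-summed back to its
# own shape, as in the broadcast_gradient_args reductions above. Illustrative
# helper; the names and values are not from the original file.
def _example_select_v2_grad_check():  # hypothetical helper
  import tensorflow as tf
  c = tf.constant([True, False, True])
  x = tf.constant([1.0, 2.0, 3.0])
  y = tf.constant(10.0)  # scalar branch, broadcast against x
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, y])
    z = tf.where(c, x, y)
  gx = tape.gradient(z, x)  # [1., 0., 1.]
  gy = tape.gradient(z, y)  # 1. (summed over the False positions)
  return gx, gy
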
1685def _MatMulGradAgainstFirstOnly(op, grad):
1686 """Gradient for MatMul, only for the first input."""
1687 t_a = op.get_attr("transpose_a")
1688 t_b = op.get_attr("transpose_b")
1689 b = math_ops.conj(op.inputs[1])
1690 if not t_a and not t_b:
1691 grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
1692 elif not t_a and t_b:
1693 grad_a = gen_math_ops.mat_mul(grad, b)
1694 elif t_a and not t_b:
1695 grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
1696 elif t_a and t_b:
1697 grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
1698 return grad_a, None
1701def _MatMulGradAgainstSecondOnly(op, grad):
1702 """Gradient for MatMul, only for the second input."""
1703 t_a = op.get_attr("transpose_a")
1704 t_b = op.get_attr("transpose_b")
1705 a = math_ops.conj(op.inputs[0])
1706 if not t_a and not t_b:
1707 grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
1708 elif not t_a and t_b:
1709 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
1710 elif t_a and not t_b:
1711 grad_b = gen_math_ops.mat_mul(a, grad)
1712 elif t_a and t_b:
1713 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
1714 return None, grad_b
1717@ops.RegisterGradient("MatMul")
1718def _MatMulGrad(op, grad):
1719 """Gradient for MatMul."""
1720 try:
1721 skip_input_indices = op.skip_input_indices
1722 if skip_input_indices is not None:
1723 if 1 in skip_input_indices:
1724 return _MatMulGradAgainstFirstOnly(op, grad)
1725 elif 0 in skip_input_indices:
1726 return _MatMulGradAgainstSecondOnly(op, grad)
1727 except AttributeError:
1728 # No gradient skipping, so do the full gradient computation
1729 pass
1731 t_a = op.get_attr("transpose_a")
1732 t_b = op.get_attr("transpose_b")
1733 a = math_ops.conj(op.inputs[0])
1734 b = math_ops.conj(op.inputs[1])
1735 if not t_a and not t_b:
1736 grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
1737 grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
1738 elif not t_a and t_b:
1739 grad_a = gen_math_ops.mat_mul(grad, b)
1740 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
1741 elif t_a and not t_b:
1742 grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
1743 grad_b = gen_math_ops.mat_mul(a, grad)
1744 elif t_a and t_b:
1745 grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
1746 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
1747 return grad_a, grad_b
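
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# For C = A @ B (no transposes), the rule above is dA = dC @ B^T and
# dB = A^T @ dC, so each gradient keeps the shape of its input. A quick shape
# check through the public API; the helper name is invented.
def _example_matmul_grad_shapes():  # hypothetical helper
  import tensorflow as tf
  a = tf.random.normal([2, 3])
  b = tf.random.normal([3, 4])
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([a, b])
    c = tf.matmul(a, b)             # shape [2, 4]
  grad_a = tape.gradient(c, a)      # ones([2, 4]) @ transpose(b) -> [2, 3]
  grad_b = tape.gradient(c, b)      # transpose(a) @ ones([2, 4]) -> [3, 4]
  return grad_a.shape, grad_b.shape
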
1750@ops.RegisterGradient("SparseMatMul")
1751def _SparseMatMulGrad(op, grad):
1752 """Gradient for SparseMatMul."""
1754 t_a = op.get_attr("transpose_a")
1755 t_b = op.get_attr("transpose_b")
1756 is_sparse = {}
1757 is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
1758 is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
1759 # Use a heuristic to figure out whether grad might be sparse.
1760 is_sparse[grad.ref()] = not context.executing_eagerly() and (
1761 grad.op.type == "ReluGrad")
1763 def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
1764 """Helper function to create SparseMatMul op."""
1766 assert t1.ref() in is_sparse and t2.ref() in is_sparse
1767 t1_sparse = is_sparse[t1.ref()]
1768 t2_sparse = is_sparse[t2.ref()]
1769 if transpose_b:
1770 t2 = array_ops.transpose(t2)
1771 transpose_b = False
1772 prod = math_ops.matmul(
1773 t1,
1774 t2,
1775 transpose_a=transpose_a,
1776 transpose_b=transpose_b,
1777 a_is_sparse=t1_sparse,
1778 b_is_sparse=t2_sparse)
1779 if prod.dtype != out_dtype:
1780 prod = math_ops.cast(prod, out_dtype)
1781 return prod
1783 dtype_a = op.inputs[0].dtype
1784 dtype_b = op.inputs[1].dtype
1785 if not t_a and not t_b:
1786 return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
1787 _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
1788 elif not t_a and t_b:
1789 return (_SparseMatMul(grad, op.inputs[1], dtype_a),
1790 _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
1791 elif t_a and not t_b:
1792 return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
1793 _SparseMatMul(op.inputs[0], grad, dtype_b))
1794 elif t_a and t_b:
1795 return (_SparseMatMul(
1796 op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True),
1797 _SparseMatMul(
1798 grad, op.inputs[0], dtype_b, transpose_a=True,
1799 transpose_b=True))
1802@ops.RegisterGradient("Floor")
1803def _FloorGrad(_, unused_grad):
1804 return [None]
1807@ops.RegisterGradient("Ceil")
1808def _CeilGrad(_, unused_grad):
1809 return [None]
1812@ops.RegisterGradient("Round")
1813def _RoundGrad(_, unused_grad):
1814 return [None]
1817@ops.RegisterGradient("Rint")
1818def _RintGrad(_, unused_grad):
1819 # The gradient of Rint is zero.
1820 return [None]
1823@ops.RegisterGradient("BatchMatMul")
1824def _BatchMatMul(op, grad):
1825 """Returns the gradient of x and y given the gradient of x * y."""
1826 x = op.inputs[0]
1827 y = op.inputs[1]
1828 adj_x = op.get_attr("adj_x")
1829 adj_y = op.get_attr("adj_y")
1831 if not adj_x:
1832 if not adj_y:
1833 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
1834 grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
1835 else:
1836 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
1837 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
1838 else:
1839 if not adj_y:
1840 grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
1841 grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
1842 else:
1843 grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
1844 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
1846 return grad_x, grad_y
1849@ops.RegisterGradient("BatchMatMulV2")
1850@ops.RegisterGradient("BatchMatMulV3")
1851def _BatchMatMulV2(op, grad):
1852 """Returns the gradient of x and y given the gradient of x * y."""
1853 x = op.inputs[0]
1854 y = op.inputs[1]
1855 adj_x = op.get_attr("adj_x")
1856 adj_y = op.get_attr("adj_y")
1858 if not adj_x:
1859 if not adj_y:
1860 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
1861 grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
1862 else:
1863 grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
1864 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
1865 else:
1866 if not adj_y:
1867 grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
1868 grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
1869 else:
1870 grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
1871 grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
1873 # Possibly reduce along the broadcasted batch dimensions, if broadcasting
1874 # is required.
1875 shape_x_static = x.get_shape()
1876 shape_y_static = y.get_shape()
1877 output_may_have_non_empty_batch_shape = (
1878 (shape_x_static.rank is None or shape_x_static.rank > 2) or
1879 (shape_y_static.rank is None or shape_y_static.rank > 2))
1880 batch_shapes_match = (
1881 shape_x_static[:-2].is_fully_defined() and
1882 shape_y_static[:-2].is_fully_defined() and
1883 shape_x_static[:-2] == shape_y_static[:-2])
1884 if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
1885 return grad_x, grad_y
1887 sx = array_ops.shape(x)
1888 sy = array_ops.shape(y)
1889 rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
1890 grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
1891 grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
1892 return grad_x, grad_y
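
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# When the batch dimensions are broadcast (here a rank-3 operand against a
# rank-2 one), the final reduce_sum/reshape above folds the broadcast batch
# entries back into the smaller operand's shape. Illustrative helper only;
# names and shapes are invented.
def _example_batch_matmul_broadcast_grad():  # hypothetical helper
  import tensorflow as tf
  x = tf.random.normal([5, 2, 3])   # batched operand
  y = tf.random.normal([3, 4])      # broadcast across the batch dimension
  with tf.GradientTape() as tape:
    tape.watch(y)
    z = tf.matmul(x, y)             # shape [5, 2, 4]
  # y's gradient is summed over the 5 broadcast batch entries -> shape [3, 4].
  return tape.gradient(z, y).shape
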
1895ops.NotDifferentiable("Range")
1896ops.NotDifferentiable("LinSpace")
1899@ops.RegisterGradient("Complex")
1900def _ComplexGrad(op, grad):
1901 """Returns the real and imaginary components of 'grad', respectively."""
1902 x = op.inputs[0]
1903 y = op.inputs[1]
1904 sx = array_ops.shape(x)
1905 sy = array_ops.shape(y)
1906 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1907 return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
1908 array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))
1911@ops.RegisterGradient("Real")
1912def _RealGrad(_, grad):
1913 """Returns 'grad' as the real part and set the imaginary part 0."""
1914 zero = constant_op.constant(0, dtype=grad.dtype)
1915 return math_ops.complex(grad, zero)
1918@ops.RegisterGradient("Imag")
1919def _ImagGrad(_, grad):
1920 """Returns 'grad' as the imaginary part and set the real part 0."""
1921 zero = constant_op.constant(0, dtype=grad.dtype)
1922 return math_ops.complex(zero, grad)
1925@ops.RegisterGradient("Angle")
1926def _AngleGrad(op, grad):
1927 """Returns -grad / (Im(x) + iRe(x))"""
1928 x = op.inputs[0]
1929 with ops.control_dependencies([grad]):
1930 re = math_ops.real(x)
1931 im = math_ops.imag(x)
1932 z = math_ops.reciprocal(math_ops.complex(im, re))
1933 zero = constant_op.constant(0, dtype=grad.dtype)
1934 complex_grad = math_ops.complex(grad, zero)
1935 return -complex_grad * z
1938@ops.RegisterGradient("Conj")
1939def _ConjGrad(_, grad):
1940 """Returns the complex conjugate of grad."""
1941 return math_ops.conj(grad)
1944@ops.RegisterGradient("ComplexAbs")
1945def _ComplexAbsGrad(op, grad):
1946 """Returns the gradient of ComplexAbs."""
1947 return math_ops.div_no_nan(
1948 math_ops.complex(
1949 grad, array_ops.zeros_like(grad)) * op.inputs[0],
1950 math_ops.complex(
1951 op.outputs[0], array_ops.zeros_like(op.outputs[0])))
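
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# For the real-valued |z| the formula above gives grad * z / |z|; with an
# upstream gradient of 1 this is the unit complex number in the direction of
# z. Helper name and value are invented.
def _example_complex_abs_grad_check():  # hypothetical helper
  import tensorflow as tf
  z = tf.constant(3.0 + 4.0j, dtype=tf.complex64)
  with tf.GradientTape() as tape:
    tape.watch(z)
    r = tf.abs(z)                   # 5.0, real-valued
  return tape.gradient(r, z)        # z / |z| == 0.6 + 0.8j
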
1954@ops.RegisterGradient("Cast")
1955def _CastGrad(op, grad):
1956 t = [
1957 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
1958 dtypes.complex64, dtypes.complex128
1959 ]
1960 src_type = op.inputs[0].dtype.base_dtype
1961 dst_type = grad.dtype.base_dtype
1962 if src_type in t and dst_type in t:
1963 return math_ops.cast(grad, src_type)
1964 else:
1965 return None
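
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# Casting between the float/complex dtypes listed above is differentiable:
# the incoming gradient is simply cast back to the source dtype. Casts
# involving any other dtype (e.g. an integer) fall into the else branch and
# yield no gradient. Illustrative helper; not from the original file.
def _example_cast_grad_check():  # hypothetical helper
  import tensorflow as tf
  x = tf.constant(2.0)              # float32
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = 3.0 * tf.cast(x, tf.float64)
  return tape.gradient(y, x)        # 3.0, cast back to float32
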
1968@ops.RegisterGradient("Cross")
1969def _CrossGrad(op, grad):
1970 u = op.inputs[0]
1971 v = op.inputs[1]
1972 return (math_ops.cross(v, grad), math_ops.cross(grad, u))
1975@ops.RegisterGradient("Cumsum")
1976def _CumsumGrad(op, grad):
1977 axis = op.inputs[1]
1978 exclusive = op.get_attr("exclusive")
1979 reverse = op.get_attr("reverse")
1980 return [
1981 math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse),
1982 None
1983 ]
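
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# Each x[i] contributes to every cumsum entry at or after position i, so the
# gradient is the incoming gradient cumulatively summed in the *opposite*
# direction, which is exactly what the reverse=not reverse flip above encodes.
def _example_cumsum_grad_check():  # hypothetical helper
  import tensorflow as tf
  x = tf.constant([1.0, 2.0, 3.0])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.cumsum(x)                # [1., 3., 6.]
  return tape.gradient(y, x)        # reversed cumsum of ones -> [3., 2., 1.]
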
1986@ops.RegisterGradient("Cumprod")
1987def _CumprodGrad(op, grad):
1988 x = op.inputs[0]
1989 axis = op.inputs[1]
1990 exclusive = op.get_attr("exclusive")
1991 reverse = op.get_attr("reverse")
1993 prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
1994 out = math_ops.cumsum(
1995 prod * grad, axis, exclusive=exclusive, reverse=not reverse)
1996 return [math_ops.div_no_nan(out, x), None]
1999@ops.RegisterGradient("CumulativeLogsumexp")
2000def _CumulativeLogsumexpGrad(op, grad):
2001 x = op.inputs[0]
2002 axis = op.inputs[1]
2003 cumulative_logsumexp = op.outputs[0]
2005 exclusive = op.get_attr("exclusive")
2006 reverse = op.get_attr("reverse")
2008 # Split the incoming gradient into its positive and negative parts
2009 # in order to take logs. This is required for numerically stable results.
2010 log_grad_positive = array_ops.where_v2(
2011 math_ops.greater(grad, 0),
2012 math_ops.log(grad),
2013 grad.dtype.min)
2015 log_grad_negative = array_ops.where_v2(
2016 math_ops.less(grad, 0),
2017 math_ops.log(-grad),
2018 grad.dtype.min)
2020 output_pos = math_ops.exp(
2021 math_ops.cumulative_logsumexp(
2022 log_grad_positive - cumulative_logsumexp,
2023 axis=axis, reverse=not reverse, exclusive=exclusive) + x)
2025 output_neg = math_ops.exp(
2026 math_ops.cumulative_logsumexp(
2027 log_grad_negative - cumulative_logsumexp,
2028 axis=axis, reverse=not reverse, exclusive=exclusive) + x)
2030 return [output_pos - output_neg, None]
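
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# The positive/negative split above exists only so that logs of the incoming
# gradient can be taken safely; the resulting gradient is the softmax-style
# weight exp(x - cumulative_logsumexp) accumulated over the prefixes that
# contain each element. A small numeric check; names and values invented.
def _example_cumulative_logsumexp_grad_check():  # hypothetical helper
  import tensorflow as tf
  x = tf.constant([0.0, 0.0])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.math.cumulative_logsumexp(x)   # [0., log(2)]
  # x[0] appears in both prefixes (weights 1 and 0.5), x[1] only in the
  # second (weight 0.5), so the expected gradient is [1.5, 0.5].
  return tape.gradient(y, x)
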
2033@ops.RegisterGradient("NextAfter")
2034def _NextAfterGrad(op, grad):
2035 """Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
2036 x1 = op.inputs[0]
2037 x2 = op.inputs[1]
2038 s_x1 = array_ops.shape(x1)
2039 s_x2 = array_ops.shape(x2)
2040 r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2)
2041 with ops.control_dependencies([grad]):
2042 partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype)
2043 partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype)
2044 return (array_ops.reshape(
2045 math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1),
2046 array_ops.reshape(
2047 math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))
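
# --- Editor's sketch (illustration only; not part of math_grad.py) ---
# nextafter(x1, x2) moves x1 by one ULP toward x2, so its partials are 1 with
# respect to x1 and 0 with respect to x2, which is what the ones/zeros above
# encode. Helper name and values are invented for illustration.
def _example_nextafter_grad_check():  # hypothetical helper
  import tensorflow as tf
  x1 = tf.constant(1.0)
  x2 = tf.constant(2.0)
  with tf.GradientTape() as tape:
    tape.watch([x1, x2])
    y = tf.math.nextafter(x1, x2)
  return tape.gradient(y, [x1, x2])  # (1.0, 0.0)
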