Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg_grad.py: 10%
450 statements
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is (Mike Giles, 2008).

Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).

References:
  An extended collection of matrix derivative results for
    forward and reverse mode automatic differentiation:
    [Mike Giles, 2008]
    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
  Matrix Backpropagation for Deep Networks with Structured Layers
    [Ionescu et al., 2015]
    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
  Training Deep Networks with Structured Layers by Matrix Backpropagation:
    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  op_adjoint = op.get_attr("adjoint")
  return -math_ops.matmul(  # pylint: disable=invalid-unary-operand-type
      ainv,
      math_ops.matmul(grad, ainv, adjoint_a=op_adjoint,
                      adjoint_b=not op_adjoint),
      adjoint_a=not op_adjoint)
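

# Illustrative note (assumption): for adjoint=False the formula above follows
# from the differential d(A^{-1}) = -A^{-1} dA A^{-1}, whose vector-Jacobian
# product is grad_A = -A^{-H} @ grad @ A^{-H}, i.e. exactly the nested matmul
# returned above; the adjoint=True branch is analogous.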


@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c` is
    `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`).
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd` this returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.
    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops_stack.stack(
        [input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b'
    # and 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the
    # reduced subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. given the equation "aabbcd->ca",
    # we'd first obtain the VJP for "bdca->ca", and then the VJP for
    # "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed
    # (viz. "bdca->aabbcd") since the output axis labels now appear in the
    # input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the operand which
        is not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    # where the equation involves only Tensor contractions, generalized traces
    # and transposes, the input gradients are given by the vector-jacobian
    # products (VJPs):
    #
    #   grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #   grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    # where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    # x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    # traces ("ii->i"), the linear mapping's Jacobian at input x is given
    # by the function itself. We can verify that the linear maps given by the
    # VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    # where the latter represents 'un-tracing', i.e. filling the diagonal with
    # the input while setting the off-diagonal entries to zero.
    # Furthermore, recall that matrix multiplication, which is
    # represented by the equation "ab,bc->ac", has its VJPs given by the
    # einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example
    # https://math.stackexchange.com/a/2755680). Combined with transposes and
    # traces we can rewrite Tensor contractions as regular matrix
    # multiplication. Since each of these operations has its VJP described
    # by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions
    # (e.g. "abc,cd->ad"), have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain the VJP for "ac,cd->ad" and then we compute the
    # VJP of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscripts. E.g. the set {'b'}
    # for the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt
    # the input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, then there is no need to
    # unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts
  # 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  # axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y
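

# Illustrative sketch (assumption, not part of the original op registrations):
# for a pure contraction such as "ab,bc->ac", the VJP rule quoted in
# _GetGradWrt reduces to two further einsums. The helper name
# _example_einsum_vjps is hypothetical and is never called by this module.
def _example_einsum_vjps(x, y, grad_z):
  """Returns (grad_x, grad_y) for z = einsum("ab,bc->ac", x, y), real dtypes."""
  grad_x = gen_linalg_ops.einsum([grad_z, y], "ac,bc->ab")
  grad_y = gen_linalg_ops.einsum([x, grad_z], "ab,ac->bc")
  return grad_x, grad_y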


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv
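

# Illustrative note (assumption): the expression above is Jacobi's formula,
# d det(A) = det(A) * tr(A^{-1} dA), whose VJP is grad * det(A) * A^{-H}. The
# helper below spells this out for a single unbatched matrix; the name
# _example_determinant_vjp is hypothetical and is never called by this module.
def _example_determinant_vjp(a, grad):
  det = gen_linalg_ops.matrix_determinant(a)
  return grad * det * linalg_ops.matrix_inverse(a, adjoint=True)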


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)
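

# Illustrative note (assumption): _MatrixSquareRootGrad relies on the
# vectorization identity vec(A X B) = (B^T kron A) vec(X) (column-major vec).
# The forward differential dA = R dR + dR R therefore reads
#   (I kron R + R^T kron I) vec(dR) = vec(dA),
# and the VJP solves the transposed system, whose matrix is exactly
# ksum = kron(I, R^T) + kron(R, I) built above; the matrix_transpose/reshape
# pairs convert TensorFlow's row-major reshape into this column-major
# convention.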


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
  # QR and LQ Decomposition Matrix Backpropagation Algorithms for
  # Square, Wide, and Deep, Real and Complex, Matrices and Their Software
  # Implementation
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes. "
                              f"Received r.shape: {r.shape}")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true. Received r.shape="
                              f"{r.shape} with nrows={r.shape.dims[-2]} "
                              f"and ncols={r.shape.dims[-1]}.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrices with num_rows >= num_cols and full_matrices=False."""
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # need to add a correction to the gradient formula for complex case
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  return (grad_a, grad_b)
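

# Illustrative note (assumption): for X = A^{-1} B (adjoint=False),
# dX = -A^{-1} dA X + A^{-1} dB, so the VJP above is
#   grad_B = A^{-H} grad   (the matrix_solve with adjoint=not adjoint_a)
#   grad_A = -grad_B X^H   (the matmul with adjoint_b=True);
# in the adjoint_a=True branch the roles of grad_B and the solution C are
# interchanged.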


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)  # pylint: disable=invalid-unary-operand-type
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return cond.cond(matrix_shape[-2] >= matrix_shape[-1],
                     lambda: _Overdetermined(op, grad),
                     lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid nan in cases with degenerate eigenvalues or
# degenerate/zero singular values in calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)
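

# Illustrative note (assumption): _SafeReciprocal(x) equals x / (x**2 + eps),
# the real part of 1 / (x + i * sqrt(eps)). It behaves like 1 / x when
# |x| >> sqrt(eps) and tends smoothly to 0 rather than diverging as x -> 0,
# which keeps the eigenvalue/singular-value difference terms below finite.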


@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from the paper by
  Christoph Boeddeker et al.
  https://arxiv.org/abs/1701.00392
  See also
  "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
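

# Illustrative note (assumption): the SelfAdjointEigV2 formula above follows
# from first-order perturbation theory for a Hermitian matrix
# A = V diag(e) V^H with distinct eigenvalues:
#   de_i = v_i^H dA v_i,
#   dv_i = sum_{j != i} v_j (v_j^H dA v_i) / (e_i - e_j),
# which is where the matrix f(i, j) = 1 / (e_i - e_j), i != j, comes from.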


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a
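
  # Illustrative note (assumption): the compute_uv=False branch above uses the
  # first-order identity ds_i = Re(u_i^H dA v_i), i.e. grad_A is
  # U diag(grad_s) V^H, which is exactly the matmul performed there.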

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          f"when full_matrices is True. Received: m={m} and n={n} from "
          f"op input={a} with shape={a_shape}.")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)
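

# Illustrative example (assumption): for a single 3 x 2 input,
#   _LeftShift([[a, b], [c, d], [e, f]])  == [[c, d], [e, f], [0, 0]]
#   _RightShift([[a, b], [c, d], [e, f]]) == [[0, 0], [a, b], [c, d]]
# i.e. the rows (next-to-last dimension) slide by one position with zero fill.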


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")
  perturb_singular = op.get_attr("perturb_singular")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(
      diags_transposed,
      grad,
      partial_pivoting=partial_pivoting,
      perturb_singular=perturb_singular)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)  # pylint: disable=invalid-unary-operand-type
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops_stack.stack([superdiag, diag, subdiag], axis=-2)
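

# Illustrative example (assumption): in the compact form used by
# linalg_ops.tridiagonal_solve, diags[..., 0, :] holds the superdiagonal (last
# entry unused), diags[..., 1, :] the main diagonal, and diags[..., 2, :] the
# subdiagonal (first entry unused). For the 3 x 3 matrix
#   [[d0, u0,  0],
#    [l1, d1, u1],
#    [ 0, l2, d2]]
# the transpose has superdiagonal [l1, l2, *], the same main diagonal, and
# subdiagonal [*, u0, u1], which is what the shifted slices of rows 2 and 0
# above produce.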


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)
  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops_stack.stack([superdiag, diag, subdiag], axis=-2)
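

# Illustrative note (assumption): writing y = matrix_transpose(y_tr), the three
# outputs of _MatmulExtractingThreeDiagonals are
#   diag[..., i]      = (x @ y)[..., i, i]
#   superdiag[..., i] = (x @ y)[..., i, i + 1] = sum_k x[..., i, k] * y_tr[..., i + 1, k]
#   subdiag[..., i]   = (x @ y)[..., i, i - 1] = sum_k x[..., i, k] * y_tr[..., i - 1, k]
# with out-of-range rows of y_tr treated as zero, which is what the zero
# padding above implements.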