# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is (Mike Giles, 2008).

Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).

References:
  An extended collection of matrix derivative results for
  forward and reverse mode automatic differentiation:
    [Mike Giles, 2008]
    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
  Matrix Backpropagation for Deep Networks with Structured Layers
    [Ionescu et al., 2015]
    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
  Training Deep Networks with Structured Layers by Matrix Backpropagation:
    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg

@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  op_adjoint = op.get_attr("adjoint")
  return -math_ops.matmul(  # pylint: disable=invalid-unary-operand-type
      ainv,
      math_ops.matmul(grad, ainv, adjoint_a=op_adjoint,
                      adjoint_b=not op_adjoint),
      adjoint_a=not op_adjoint)
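# The rule above follows from d(A^{-1}) = -A^{-1} dA A^{-1}: in reverse mode it
# becomes grad_A = -A^{-H} @ grad_C @ A^{-H} (plain transposes for real
# inputs), with the adjoints flipped when the forward op ran with adjoint=True.
# As an illustrative, eager-mode sketch (not part of this module), the rule can
# be sanity-checked with something like:
#   a = tf.random.normal([3, 3], dtype=tf.float64)
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   analytic = tape.gradient(loss, a)  # routed through _MatrixInverseGrad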


@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c` is
    `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`).
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None


  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd`, returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end


  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.
    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops_stack.stack(
        [input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes


  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`.)
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`.)
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b'
    # and 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the
    # reduced subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. Given the equation "aabbcd->ca"
    # we'd first obtain the VJP for "bdca->ca", and then the VJP for
    # "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz.
    # "bdca->aabbcd") since the output axis labels now appear in the input
    # subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))


  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    # where the equation involves only Tensor contractions, generalized traces
    # and transposes, the input gradients are given by the vector-jacobian
    # products (VJPs):
    #
    #   grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #   grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    # where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    # x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    # traces ("ii->i"), the linear mapping's Jacobian at input x is given
    # by the function itself. We can verify that the linear maps given by the
    # VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    # where the latter represents 'un-tracing', or filling the diagonal with
    # the input axis while the non-diagonal entries are zeros.
    # Furthermore, recall that matrix multiplication, which is
    # represented by the equation "ab,bc->ac", has its VJPs given by the
    # einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example,
    # https://math.stackexchange.com/a/2755680). Combined with transposes and
    # traces we can rewrite Tensor contractions as regular matrix
    # multiplication. Since each of these operations has its VJP described
    # by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions, e.g.
    # "abc,cd->ad", have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain the VJP for "ac,cd->ad" and then we compute the
    # VJP of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)


  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt
    # the input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels, i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, then there is no need to
    # unbroadcast.
    return grad_x, grad_y

  # Below we handle the case where broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts
  # 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  # axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y
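
# For example, for z = einsum("abc,cd->ad", x, y) the gradient wrt y is
# einsum("ad,abc->cd", grad, x), computed in a single step, while the gradient
# wrt x needs two steps because the label 'b' is reduced: first
# einsum("ad,cd->ac", grad, y), then broadcasting the result along the 'b'
# axis back to the shape of x (see _GetGradWrt and _GetGradReduced above).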


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv
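
# By Jacobi's formula, d(det A) = det(A) * tr(A^{-1} dA), so for real inputs
# the gradient above is grad * det(A) * A^{-T}; matrix_inverse(a, adjoint=True)
# supplies that transposed inverse, and `multipliers` reshapes the scalar
# factor grad * det(A) so it broadcasts over the trailing matrix dimensions.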


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR
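  # Vectorizing with the identity vec(A X B) = (B^T kron A) vec(X) turns that
  # Sylvester equation into an ordinary linear system whose matrix is a sum of
  # Kronecker products of R (or its transpose) with the identity; the system
  # is assembled below via _KroneckerProduct and inverted by a single
  # matrix_solve.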


  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv
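
# Here only the gradient wrt the second output (log |det A|) is propagated:
# d(log det A) = tr(A^{-1} dA), so for real inputs the gradient is simply
# grad_b * A^{-T}, with no det(A) factor this time.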


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)
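  # `middle` now equals (l^H @ grad) * (tril(ones) - 1/2 * eye): the lower
  # triangle of l^H @ grad with its diagonal entries halved.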


  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071:
  # "QR and LQ Decomposition Matrix Backpropagation Algorithms for
  # Square, Wide, and Deep, Real and Complex, Matrices and Their Software
  # Implementation".
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes. "
                              f"Received r.shape: {r.shape}")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true. Received r.shape="
                              f"{r.shape} with nrows={r.shape.dims[-2]} "
                              f"and ncols={r.shape.dims[-1]}.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrices with num_rows >= num_cols when full_matrices is
    False.
    """
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # Need to add a correction to the gradient formula for the complex case.
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case.
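  # Since q @ u = x is itself a square QR factorization, the gradient wrt x
  # follows from the square case with dq augmented by y @ dv^H, and the
  # gradient wrt y is simply q @ dv.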

  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  return (grad_a, grad_b)
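
# For C = A^{-1} B, differentiating A C = B gives dC = A^{-1} (dB - dA C), so
# the VJPs are grad_B = A^{-H} grad_C and grad_A = -grad_B C^H; when the
# forward op ran with adjoint=True, the roles of the adjoints are swapped
# accordingly, which is what the branches above implement.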


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from the forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    grad_b = math_ops.matmul(a, z)
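    # No gradient is propagated to the l2_regularizer input, hence the trailing
    # None.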

    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    which (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)  # pylint: disable=invalid-unary-operand-type
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return cond.cond(matrix_shape[-2] >= matrix_shape[-1],
                     lambda: _Overdetermined(op, grad),
                     lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
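  # `a` is stored in banded form, so only the num_bands relevant diagonals of
  # the dense gradient are extracted and returned in banded form as well.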

  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid NaNs in cases with degenerate eigenvalues or
# degenerate/zero singular values in the calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)
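
# The broadened reciprocal x / (x**2 + epsilon) agrees with 1/x whenever
# |x| >> sqrt(epsilon) and tends smoothly to zero as x -> 0, so the
# reciprocal-of-difference terms in the Eig/SVD gradients below stay finite
# when eigenvalues or singular values (nearly) coincide.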


@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from the paper by
  Christoph Boeddeker et al.
  https://arxiv.org/abs/1701.00392
  See also
  "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here
    # we symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for the complex-valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
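    # When only singular values were requested, the gradient is
    # u @ diag(grad_s) @ v^H, so u and v are recomputed here even though the
    # forward op did not return them.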

    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          f"when full_matrices is True. Received: m={m} and n={n} from "
          f"op input={a} with shape={a_shape}.")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)
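
# For a single matrix, e.g. x = [[1, 2], [3, 4]]:
#   _LeftShift(x)  == [[3, 4], [0, 0]]
#   _RightShift(x) == [[0, 0], [1, 2]]
# These provide the row alignment needed when accumulating the super- and
# subdiagonal gradients below.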


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")
  perturb_singular = op.get_attr("perturb_singular")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(
      diags_transposed,
      grad,
      partial_pivoting=partial_pivoting,
      perturb_singular=perturb_singular)
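  # The gradient wrt the diagonals is minus the tridiagonal part of
  # grad_rhs @ x^T, extracted without materializing the full dense product.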

  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)  # pylint: disable=invalid-unary-operand-type
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops_stack.stack([superdiag, diag, subdiag], axis=-2)
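
# For example, a 3 x 3 tridiagonal matrix [[d0, s0, 0], [u1, d1, s1],
# [0, u2, d2]] is stored compactly as [[s0, s1, *], [d0, d1, d2], [*, u1, u2]]
# (starred entries are padding). Its transpose is then stored as
# [[u1, u2, *], [d0, d1, d2], [*, s0, s1]], which is exactly the shift-and-pad
# performed above.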


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve).
  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops_stack.stack([superdiag, diag, subdiag], axis=-2)