Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linalg_impl.py: 16%

574 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operations for linear algebra.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import constant_op 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_shape 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import array_ops_stack 

25from tensorflow.python.ops import check_ops 

26from tensorflow.python.ops import cond as tf_cond 

27from tensorflow.python.ops import control_flow_ops 

28from tensorflow.python.ops import gen_linalg_ops 

29from tensorflow.python.ops import linalg_ops 

30from tensorflow.python.ops import map_fn 

31from tensorflow.python.ops import math_ops 

32from tensorflow.python.ops import special_math_ops 

33from tensorflow.python.ops import stateless_random_ops 

34from tensorflow.python.ops import while_loop 

35from tensorflow.python.util import dispatch 

36from tensorflow.python.util.tf_export import tf_export 

37 

38# Linear algebra ops. 

39band_part = array_ops.matrix_band_part 

40cholesky = linalg_ops.cholesky 

41cholesky_solve = linalg_ops.cholesky_solve 

42det = linalg_ops.matrix_determinant 

43slogdet = gen_linalg_ops.log_matrix_determinant 

44tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet)) 

45diag = array_ops.matrix_diag 

46diag_part = array_ops.matrix_diag_part 

47eigh = linalg_ops.self_adjoint_eig 

48eigvalsh = linalg_ops.self_adjoint_eigvals 

49einsum = special_math_ops.einsum 

50eye = linalg_ops.eye 

51inv = linalg_ops.matrix_inverse 

52logm = gen_linalg_ops.matrix_logarithm 

53lu = gen_linalg_ops.lu 

54tf_export('linalg.logm')(dispatch.add_dispatch_support(logm)) 

55lstsq = linalg_ops.matrix_solve_ls 

56norm = linalg_ops.norm 

57qr = linalg_ops.qr 

58set_diag = array_ops.matrix_set_diag 

59solve = linalg_ops.matrix_solve 

60sqrtm = linalg_ops.matrix_square_root 

61svd = linalg_ops.svd 

62tensordot = math_ops.tensordot 

63trace = math_ops.trace 

64transpose = array_ops.matrix_transpose 

65triangular_solve = linalg_ops.matrix_triangular_solve 
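
The aliases above simply re-export existing ops under the `tf.linalg` namespace. As a quick illustrative sketch (public API in eager mode; not part of this module), `slogdet` returns a sign and a log-magnitude that recombine into the ordinary determinant:

```python
import tensorflow as tf

a = tf.constant([[2., 1.],
                 [1., 3.]])
sign, log_abs_det = tf.linalg.slogdet(a)
# sign * exp(log_abs_det) should reproduce tf.linalg.det(a) == 5.0.
print((sign * tf.exp(log_abs_det)).numpy())
print(tf.linalg.det(a).numpy())
```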

66 

67 

68@tf_export('linalg.logdet') 

69@dispatch.add_dispatch_support 

70def logdet(matrix, name=None): 

71 """Computes log of the determinant of a hermitian positive definite matrix. 

72 

73 ```python 

74 # Compute the determinant of a matrix while reducing the chance of over- or

75 # underflow:

76 A = ... # shape 10 x 10 

77 det = tf.exp(tf.linalg.logdet(A)) # scalar 

78 ``` 

79 

80 Args: 

81 matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, 

82 or `complex128` with shape `[..., M, M]`. 

83 name: A name to give this `Op`. Defaults to `logdet`. 

84 

85 Returns: 

86 The natural log of the determinant of `matrix`. 

87 

88 @compatibility(numpy) 

89 Equivalent to numpy.linalg.slogdet, although no sign is returned since only 

90 Hermitian positive definite matrices are supported.

91 @end_compatibility 

92 """ 

93 # This uses the property that log det(A) = 2*sum(log(real(diag(C))))

94 # where C is the Cholesky decomposition of A.

95 with ops.name_scope(name, 'logdet', [matrix]): 

96 chol = gen_linalg_ops.cholesky(matrix) 

97 return 2.0 * math_ops.reduce_sum( 

98 math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))), 

99 axis=[-1]) 
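
A minimal usage sketch (illustrative only; assumes the public `tf.linalg` API in eager mode): for a symmetric positive definite input, `logdet` matches `log(det(...))` while skipping the intermediate determinant that can over- or underflow.

```python
import tensorflow as tf

a = tf.constant([[2., 1.],
                 [1., 2.]])  # symmetric positive definite, det(a) == 3
print(tf.linalg.logdet(a).numpy())            # ~1.0986 == log(3)
print(tf.math.log(tf.linalg.det(a)).numpy())  # same value, computed less robustly
```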

100 

101 

102@tf_export('linalg.adjoint') 

103@dispatch.add_dispatch_support 

104def adjoint(matrix, name=None): 

105 """Transposes the last two dimensions of and conjugates tensor `matrix`. 

106 

107 For example: 

108 

109 ```python 

110 x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j], 

111 [4 + 4j, 5 + 5j, 6 + 6j]]) 

112 tf.linalg.adjoint(x) # [[1 - 1j, 4 - 4j], 

113 # [2 - 2j, 5 - 5j], 

114 # [3 - 3j, 6 - 6j]] 

115 ``` 

116 

117 Args: 

118 matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, 

119 or `complex128` with shape `[..., M, M]`. 

120 name: A name to give this `Op` (optional). 

121 

122 Returns: 

123 The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of 

124 matrix. 

125 """ 

126 with ops.name_scope(name, 'adjoint', [matrix]): 

127 matrix = ops.convert_to_tensor(matrix, name='matrix') 

128 return array_ops.matrix_transpose(matrix, conjugate=True) 
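
Because the implementation above is just a conjugating matrix transpose, a real-valued input reduces to the plain transpose; a small illustrative sketch (public API, eager mode):

```python
import tensorflow as tf

x = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])
# For real dtypes, adjoint(x) equals tf.linalg.matrix_transpose(x).
print(tf.linalg.adjoint(x).numpy())  # shape (3, 2)
```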

129 

130 

131# This section is ported nearly verbatim from Eigen's implementation: 

132# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html 
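# A brief note on how the helpers below are used (inferred from the code in
# `matrix_exponential`): each `_matrix_exp_padeN` returns the odd part U and
# the even part V of the degree-N Pade numerator, so the approximant is
#   r_N(A) = (V - U)^{-1} (U + V) ~= exp(A),
# which `matrix_exponential` evaluates via `matrix_solve(-u + v, u + v)` after
# scaling the input by 2**-squarings and before repeatedly squaring the result.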

133def _matrix_exp_pade3(matrix): 

134 """3rd-order Pade approximant for matrix exponential.""" 

135 b = [120.0, 60.0, 12.0] 

136 b = [constant_op.constant(x, matrix.dtype) for x in b] 

137 ident = linalg_ops.eye( 

138 array_ops.shape(matrix)[-2], 

139 batch_shape=array_ops.shape(matrix)[:-2], 

140 dtype=matrix.dtype) 

141 matrix_2 = math_ops.matmul(matrix, matrix) 

142 tmp = matrix_2 + b[1] * ident 

143 matrix_u = math_ops.matmul(matrix, tmp) 

144 matrix_v = b[2] * matrix_2 + b[0] * ident 

145 return matrix_u, matrix_v 

146 

147 

148def _matrix_exp_pade5(matrix): 

149 """5th-order Pade approximant for matrix exponential.""" 

150 b = [30240.0, 15120.0, 3360.0, 420.0, 30.0] 

151 b = [constant_op.constant(x, matrix.dtype) for x in b] 

152 ident = linalg_ops.eye( 

153 array_ops.shape(matrix)[-2], 

154 batch_shape=array_ops.shape(matrix)[:-2], 

155 dtype=matrix.dtype) 

156 matrix_2 = math_ops.matmul(matrix, matrix) 

157 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 

158 tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident 

159 matrix_u = math_ops.matmul(matrix, tmp) 

160 matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident 

161 return matrix_u, matrix_v 

162 

163 

164def _matrix_exp_pade7(matrix): 

165 """7th-order Pade approximant for matrix exponential.""" 

166 b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0] 

167 b = [constant_op.constant(x, matrix.dtype) for x in b] 

168 ident = linalg_ops.eye( 

169 array_ops.shape(matrix)[-2], 

170 batch_shape=array_ops.shape(matrix)[:-2], 

171 dtype=matrix.dtype) 

172 matrix_2 = math_ops.matmul(matrix, matrix) 

173 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 

174 matrix_6 = math_ops.matmul(matrix_4, matrix_2) 

175 tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident 

176 matrix_u = math_ops.matmul(matrix, tmp) 

177 matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident 

178 return matrix_u, matrix_v 

179 

180 

181def _matrix_exp_pade9(matrix): 

182 """9th-order Pade approximant for matrix exponential.""" 

183 b = [ 

184 17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0, 

185 2162160.0, 110880.0, 3960.0, 90.0 

186 ] 

187 b = [constant_op.constant(x, matrix.dtype) for x in b] 

188 ident = linalg_ops.eye( 

189 array_ops.shape(matrix)[-2], 

190 batch_shape=array_ops.shape(matrix)[:-2], 

191 dtype=matrix.dtype) 

192 matrix_2 = math_ops.matmul(matrix, matrix) 

193 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 

194 matrix_6 = math_ops.matmul(matrix_4, matrix_2) 

195 matrix_8 = math_ops.matmul(matrix_6, matrix_2) 

196 tmp = ( 

197 matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + 

198 b[1] * ident) 

199 matrix_u = math_ops.matmul(matrix, tmp) 

200 matrix_v = ( 

201 b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + 

202 b[0] * ident) 

203 return matrix_u, matrix_v 

204 

205 

206def _matrix_exp_pade13(matrix): 

207 """13th-order Pade approximant for matrix exponential.""" 

208 b = [ 

209 64764752532480000.0, 32382376266240000.0, 7771770303897600.0, 

210 1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0, 

211 33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0 

212 ] 

213 b = [constant_op.constant(x, matrix.dtype) for x in b] 

214 ident = linalg_ops.eye( 

215 array_ops.shape(matrix)[-2], 

216 batch_shape=array_ops.shape(matrix)[:-2], 

217 dtype=matrix.dtype) 

218 matrix_2 = math_ops.matmul(matrix, matrix) 

219 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 

220 matrix_6 = math_ops.matmul(matrix_4, matrix_2) 

221 tmp_u = ( 

222 math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) + 

223 b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident) 

224 matrix_u = math_ops.matmul(matrix, tmp_u) 

225 tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2 

226 matrix_v = ( 

227 math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 + 

228 b[2] * matrix_2 + b[0] * ident) 

229 return matrix_u, matrix_v 

230 

231 

232@tf_export('linalg.expm') 

233@dispatch.add_dispatch_support 

234def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin 

235 r"""Computes the matrix exponential of one or more square matrices. 

236 

237 $$exp(A) = \sum_{n=0}^\infty A^n/n!$$ 

238 

239 The exponential is computed using a combination of the scaling and squaring 

240 method and the Pade approximation. Details can be found in: 

241 Nicholas J. Higham, "The scaling and squaring method for the matrix 

242 exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. 

243 

244 The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions 

245 form square matrices. The output is a tensor of the same shape as the input 

246 containing the exponential for all input submatrices `[..., :, :]`. 

247 

248 Args: 

249 input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or 

250 `complex128` with shape `[..., M, M]`. 

251 name: A name to give this `Op` (optional). 

252 

253 Returns: 

254 the matrix exponential of the input. 

255 

256 Raises: 

257 ValueError: An unsupported type is provided as input. 

258 

259 @compatibility(scipy) 

260 Equivalent to scipy.linalg.expm 

261 @end_compatibility 

262 """ 

263 with ops.name_scope(name, 'matrix_exponential', [input]): 

264 matrix = ops.convert_to_tensor(input, name='input') 

265 if matrix.shape[-2:] == [0, 0]: 

266 return matrix 

267 batch_shape = matrix.shape[:-2] 

268 if not batch_shape.is_fully_defined(): 

269 batch_shape = array_ops.shape(matrix)[:-2] 

270 

271 # reshaping the batch makes the where statements work better 

272 matrix = array_ops.reshape( 

273 matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0)) 

274 l1_norm = math_ops.reduce_max( 

275 math_ops.reduce_sum( 

276 math_ops.abs(matrix), 

277 axis=array_ops.size(array_ops.shape(matrix)) - 2), 

278 axis=-1)[..., array_ops.newaxis, array_ops.newaxis] 

279 

280 const = lambda x: constant_op.constant(x, l1_norm.dtype) 

281 

282 def _nest_where(vals, cases): 

283 assert len(vals) == len(cases) - 1 

284 if len(vals) == 1: 

285 return array_ops.where_v2( 

286 math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1]) 

287 else: 

288 return array_ops.where_v2( 

289 math_ops.less(l1_norm, const(vals[0])), cases[0], 

290 _nest_where(vals[1:], cases[1:])) 

291 

292 if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]: 

293 maxnorm = const(3.925724783138660) 

294 squarings = math_ops.maximum( 

295 math_ops.floor( 

296 math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) 

297 u3, v3 = _matrix_exp_pade3(matrix) 

298 u5, v5 = _matrix_exp_pade5(matrix) 

299 u7, v7 = _matrix_exp_pade7( 

300 matrix / 

301 math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype)) 

302 conds = (4.258730016922831e-001, 1.880152677804762e+000) 

303 u = _nest_where(conds, (u3, u5, u7)) 

304 v = _nest_where(conds, (v3, v5, v7)) 

305 elif matrix.dtype in [dtypes.float64, dtypes.complex128]: 

306 maxnorm = const(5.371920351148152) 

307 squarings = math_ops.maximum( 

308 math_ops.floor( 

309 math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) 

310 u3, v3 = _matrix_exp_pade3(matrix) 

311 u5, v5 = _matrix_exp_pade5(matrix) 

312 u7, v7 = _matrix_exp_pade7(matrix) 

313 u9, v9 = _matrix_exp_pade9(matrix) 

314 u13, v13 = _matrix_exp_pade13( 

315 matrix / 

316 math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype)) 

317 conds = (1.495585217958292e-002, 2.539398330063230e-001, 

318 9.504178996162932e-001, 2.097847961257068e+000) 

319 u = _nest_where(conds, (u3, u5, u7, u9, u13)) 

320 v = _nest_where(conds, (v3, v5, v7, v9, v13)) 

321 else: 

322 raise ValueError('tf.linalg.expm does not support matrices of type %s' % 

323 matrix.dtype) 

324 

325 is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm)) 

326 nan = constant_op.constant(np.nan, matrix.dtype) 

327 result = tf_cond.cond( 

328 is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v), 

329 lambda: array_ops.fill(array_ops.shape(matrix), nan)) 

330 max_squarings = math_ops.reduce_max(squarings) 

331 i = const(0.0) 

332 

333 def c(i, _): 

334 return tf_cond.cond(is_finite, 

335 lambda: math_ops.less(i, max_squarings), 

336 lambda: constant_op.constant(False)) 

337 

338 def b(i, r): 

339 return i + 1, array_ops.where_v2( 

340 math_ops.less(i, squarings), math_ops.matmul(r, r), r) 

341 

342 _, result = while_loop.while_loop(c, b, [i, result]) 

343 if not matrix.shape.is_fully_defined(): 

344 return array_ops.reshape( 

345 result, 

346 array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0)) 

347 return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:])) 
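
A short usage sketch (illustrative; public API, eager mode): for a diagonal input the matrix exponential is the elementwise exponential of the diagonal, which gives a quick sanity check of the scaling-and-squaring code above.

```python
import tensorflow as tf

a = tf.linalg.diag(tf.constant([0., 1., 2.]))
print(tf.linalg.expm(a).numpy())
# Expected: diag(exp(0), exp(1), exp(2)) ~= diag(1., 2.7183, 7.3891)
```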

348 

349 

350@tf_export('linalg.banded_triangular_solve', v1=[]) 

351def banded_triangular_solve( 

352 bands, 

353 rhs, 

354 lower=True, 

355 adjoint=False, # pylint: disable=redefined-outer-name 

356 name=None): 

357 r"""Solve triangular systems of equations with a banded solver. 

358 

359 `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number 

360 of bands stored. This corresponds to a batch of `M` by `M` matrices, whose 

361 `K` subdiagonals (when `lower` is `True`) are stored. 

362 

363 This operator broadcasts the batch dimensions of `bands` and the batch 

364 dimensions of `rhs`. 

365 

366 

367 Examples: 

368 

369 Storing 2 bands of a 3x3 matrix. 

370 Note that the first element in the second row is ignored due to

371 the 'LEFT_RIGHT' padding.

372 

373 >>> x = [[2., 3., 4.], [1., 2., 3.]] 

374 >>> x2 = [[2., 3., 4.], [10000., 2., 3.]] 

375 >>> y = tf.zeros([3, 3]) 

376 >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0)) 

377 >>> z 

378 <tf.Tensor: shape=(3, 3), dtype=float32, numpy= 

379 array([[2., 0., 0.], 

380 [2., 3., 0.], 

381 [0., 3., 4.]], dtype=float32)> 

382 >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1])) 

383 >>> soln 

384 <tf.Tensor: shape=(3, 1), dtype=float32, numpy= 

385 array([[0.5 ], 

386 [0. ], 

387 [0.25]], dtype=float32)> 

388 >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1])) 

389 >>> tf.reduce_all(are_equal).numpy() 

390 True 

391 >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1])) 

392 >>> tf.reduce_all(are_equal).numpy() 

393 True 

394 

395 Storing 2 superdiagonals of a 4x4 matrix. Because of the 'LEFT_RIGHT' padding 

396 the last element of the first row is ignored. 

397 

398 >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]] 

399 >>> y = tf.zeros([4, 4]) 

400 >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1)) 

401 >>> z 

402 <tf.Tensor: shape=(4, 4), dtype=float32, numpy= 

403 array([[-1., 2., 0., 0.], 

404 [ 0., -2., 3., 0.], 

405 [ 0., 0., -3., 4.], 

406 [ 0., 0., -0., -4.]], dtype=float32)> 

407 >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False) 

408 >>> soln 

409 <tf.Tensor: shape=(4, 1), dtype=float32, numpy= 

410 array([[-4. ], 

411 [-1.5 ], 

412 [-0.6666667], 

413 [-0.25 ]], dtype=float32)> 

414 >>> are_equal = (soln == tf.linalg.triangular_solve( 

415 ... z, tf.ones([4, 1]), lower=False)) 

416 >>> tf.reduce_all(are_equal).numpy() 

417 True 

418 

419 

420 Args: 

421 bands: A `Tensor` describing the bands of the left hand side, with shape 

422 `[..., K, M]`. The `K` rows correspond to the diagonal to the `K - 1`-th 

423 diagonal (the diagonal is the top row) when `lower` is `True` and 

424 otherwise the `K - 1`-th superdiagonal to the diagonal (the diagonal is 

425 the bottom row) when `lower` is `False`. The bands are stored with 

426 'LEFT_RIGHT' alignment, where the superdiagonals are padded on the right 

427 and subdiagonals are padded on the left. This is the alignment cuSPARSE 

428 uses. See `tf.linalg.set_diag` for more details. 

429 rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as 

430 `bands`. Note that if the shape of `rhs` and/or `bands` isn't known

431 statically, `rhs` will be treated as a matrix rather than a vector. 

432 lower: An optional `bool`. Defaults to `True`. Boolean indicating whether 

433 `bands` represents a lower or upper triangular matrix. 

434 adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether 

435 to solve with the matrix's block-wise adjoint. 

436 name: A name to give this `Op` (optional). 

437 

438 Returns: 

439 A `Tensor` of shape [..., M] or [..., M, N] containing the solutions. 

440 """ 

441 with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]): 

442 return gen_linalg_ops.banded_triangular_solve( 

443 bands, rhs, lower=lower, adjoint=adjoint) 

444 

445 

446@tf_export('linalg.tridiagonal_solve') 

447@dispatch.add_dispatch_support 

448def tridiagonal_solve(diagonals, 

449 rhs, 

450 diagonals_format='compact', 

451 transpose_rhs=False, 

452 conjugate_rhs=False, 

453 name=None, 

454 partial_pivoting=True, 

455 perturb_singular=False): 

456 r"""Solves tridiagonal systems of equations. 

457 

458 The input can be supplied in various formats: `matrix`, `sequence` and 

459 `compact`, specified by the `diagonals_format` arg. 

460 

461 In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with 

462 two inner-most dimensions representing the square tridiagonal matrices. 

463 Elements outside of the three diagonals will be ignored. 

464 

465 In `sequence` format, `diagonals` are supplied as a tuple or list of three 

466 tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing 

467 superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either 

468 `M-1` or `M`; in the latter case, the last element of superdiagonal and the 

469 first element of subdiagonal will be ignored. 

470 

471 In `compact` format the three diagonals are brought together into one tensor 

472 of shape `[..., 3, M]`, with last two dimensions containing superdiagonals, 

473 diagonals, and subdiagonals, in order. Similarly to `sequence` format, 

474 elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored. 

475 

476 The `compact` format is recommended as the one with the best performance. In case

477 you need to convert a tensor into the compact format manually, use `tf.gather_nd`.

478 An example for a tensor of shape [m, m]: 

479 

480 ```python 

481 rhs = tf.constant([...]) 

482 matrix = tf.constant([[...]]) 

483 m = matrix.shape[0] 

484 dummy_idx = [0, 0] # An arbitrary element to use as a dummy 

485 indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx], # Superdiagonal 

486 [[i, i] for i in range(m)], # Diagonal 

487 [dummy_idx] + [[i + 1, i] for i in range(m - 1)]] # Subdiagonal 

488 diagonals=tf.gather_nd(matrix, indices) 

489 x = tf.linalg.tridiagonal_solve(diagonals, rhs) 

490 ``` 

491 

492 Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or 

493 `[..., M, K]`. The latter allows solving K systems simultaneously with the

494 same left-hand sides and K different right-hand sides. If `transpose_rhs`

495 is set to `True`, the expected shape is `[..., M]` or `[..., K, M]`.

496 

497 The batch dimensions, denoted as `...`, must be the same in `diagonals` and 

498 `rhs`. 

499 

500 The output is a tensor of the same shape as `rhs`: either `[..., M]` or 

501 `[..., M, K]`. 

502 

503 The op isn't guaranteed to raise an error if the input matrix is not 

504 invertible. `tf.debugging.check_numerics` can be applied to the output to 

505 detect invertibility problems. 

506 

507 **Note**: with large batch sizes, the computation on the GPU may be slow if

508 either `partial_pivoting=True` or there are multiple right-hand sides

509 (`K > 1`). If this issue arises, consider whether it is possible to disable

510 pivoting and have `K = 1`, or, alternatively, consider using the CPU.

511 

512 On CPU, the solution is computed via Gaussian elimination with or without partial

513 pivoting, depending on `partial_pivoting` parameter. On GPU, Nvidia's cuSPARSE 

514 library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv 

515 

516 Args: 

517 diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The 

518 shape depends on `diagonals_format`; see the description above. Must be

519 `float32`, `float64`, `complex64`, or `complex128`. 

520 rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as 

521 `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't known

522 statically, `rhs` will be treated as a matrix rather than a vector. 

523 diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is 

524 `compact`. 

525 transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect 

526 if the shape of rhs is [..., M]). 

527 conjugate_rhs: If `True`, `rhs` is conjugated before solving. 

528 name: A name to give this `Op` (optional). 

529 partial_pivoting: whether to perform partial pivoting. `True` by default. 

530 Partial pivoting makes the procedure more stable, but slower. Partial 

531 pivoting is unnecessary in some cases, including diagonally dominant and 

532 symmetric positive definite matrices (see e.g. theorem 9.12 in [1]). 

533 perturb_singular: whether to perturb singular matrices to return a finite 

534 result. `False` by default. If true, solutions to systems involving 

535 a singular matrix will be computed by perturbing near-zero pivots in 

536 the partially pivoted LU decomposition. Specifically, tiny pivots are 

537 perturbed by an amount of order `eps * max_{ij} |U(i,j)|` to avoid 

538 overflow. Here `U` is the upper triangular part of the LU decomposition, 

539 and `eps` is the machine precision. This is useful for solving 

540 numerically singular systems when computing eigenvectors by inverse 

541 iteration. 

542 If `partial_pivoting` is `False`, `perturb_singular` must be `False` as 

543 well. 

544 

545 Returns: 

546 A `Tensor` of shape [..., M] or [..., M, K] containing the solutions. 

547 If the input matrix is singular, the result is undefined. 

548 

549 Raises: 

550 ValueError: Is raised if any of the following conditions hold: 

551 1. An unsupported type is provided as input, 

552 2. the input tensors have incorrect shapes, 

553 3. `perturb_singular` is `True` but `partial_pivoting` is not. 

554 UnimplementedError: Whenever `partial_pivoting` is true and the backend is 

555 XLA, or whenever `perturb_singular` is true and the backend is 

556 XLA or GPU. 

557 

558 [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms: 

559 Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7. 

560 

561 """ 

562 if perturb_singular and not partial_pivoting: 

563 raise ValueError('partial_pivoting must be True if perturb_singular is.') 

564 

565 if diagonals_format == 'compact': 

566 return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 

567 conjugate_rhs, partial_pivoting, 

568 perturb_singular, name) 

569 

570 if diagonals_format == 'sequence': 

571 if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3: 

572 raise ValueError('Expected diagonals to be a sequence of length 3.') 

573 

574 superdiag, maindiag, subdiag = diagonals 

575 if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or 

576 not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])): 

577 raise ValueError( 

578 'Tensors representing the three diagonals must have the same shape, '

579 'except for the last dimension, got {}, {}, {}'.format( 

580 subdiag.shape, maindiag.shape, superdiag.shape)) 

581 

582 m = tensor_shape.dimension_value(maindiag.shape[-1]) 

583 

584 def pad_if_necessary(t, name, last_dim_padding): 

585 n = tensor_shape.dimension_value(t.shape[-1]) 

586 if not n or n == m: 

587 return t 

588 if n == m - 1: 

589 paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] + 

590 [last_dim_padding]) 

591 return array_ops.pad(t, paddings) 

592 raise ValueError('Expected {} to have length {} or {}, got {}.'.format(

593 name, m, m - 1, n)) 

594 

595 subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0]) 

596 superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1]) 

597 

598 diagonals = array_ops_stack.stack((superdiag, maindiag, subdiag), axis=-2) 

599 return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 

600 conjugate_rhs, partial_pivoting, 

601 perturb_singular, name) 

602 

603 if diagonals_format == 'matrix': 

604 m1 = tensor_shape.dimension_value(diagonals.shape[-1]) 

605 m2 = tensor_shape.dimension_value(diagonals.shape[-2]) 

606 if m1 and m2 and m1 != m2: 

607 raise ValueError( 

608 'Expected last two dimensions of diagonals to be the same, got {} and {}'

609 .format(m1, m2)) 

610 m = m1 or m2 

611 diagonals = array_ops.matrix_diag_part( 

612 diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT') 

613 return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 

614 conjugate_rhs, partial_pivoting, 

615 perturb_singular, name) 

616 

617 raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format)) 
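
A small end-to-end sketch of the `compact` format (illustrative values; public API, eager mode), solving the classic [-1, 2, -1] tridiagonal system:

```python
import tensorflow as tf

# Rows are [superdiag, maindiag, subdiag]; the last superdiagonal entry and
# the first subdiagonal entry are ignored, as described above.
diagonals = tf.constant([[-1., -1.,  0.],
                         [ 2.,  2.,  2.],
                         [ 0., -1., -1.]])
rhs = tf.constant([[1.], [1.], [1.]])
x = tf.linalg.tridiagonal_solve(diagonals, rhs)
print(x.numpy())  # approximately [[1.5], [2.0], [1.5]]
```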

618 

619 

620def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 

621 conjugate_rhs, partial_pivoting, 

622 perturb_singular, name): 

623 """Helper function used after the input has been cast to compact form.""" 

624 diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank 

625 

626 # If we know the rank of the diagonal tensor, do some static checking. 

627 if diags_rank: 

628 if diags_rank < 2: 

629 raise ValueError( 

630 'Expected diagonals to have rank at least 2, got {}'.format( 

631 diags_rank)) 

632 if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1: 

633 raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format( 

634 diags_rank - 1, diags_rank, rhs_rank)) 

635 if (rhs_rank and not diagonals.shape[:-2].is_compatible_with( 

636 rhs.shape[:diags_rank - 2])): 

637 raise ValueError('Batch shapes {} and {} are incompatible'.format( 

638 diagonals.shape[:-2], rhs.shape[:diags_rank - 2])) 

639 

640 if diagonals.shape[-2] and diagonals.shape[-2] != 3: 

641 raise ValueError('Expected 3 diagonals got {}'.format(diagonals.shape[-2])) 

642 

643 def check_num_lhs_matches_num_rhs(): 

644 if (diagonals.shape[-1] and rhs.shape[-2] and 

645 diagonals.shape[-1] != rhs.shape[-2]): 

646 raise ValueError('Expected number of left-hand sides and right-hand '

647 'sides to be equal, got {} and {}'.format( 

648 diagonals.shape[-1], rhs.shape[-2])) 

649 

650 if rhs_rank and diags_rank and rhs_rank == diags_rank - 1: 

651 # Rhs provided as a vector, ignoring transpose_rhs 

652 if conjugate_rhs: 

653 rhs = math_ops.conj(rhs) 

654 rhs = array_ops.expand_dims(rhs, -1) 

655 check_num_lhs_matches_num_rhs() 

656 return array_ops.squeeze( 

657 linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, 

658 perturb_singular, name), -1) 

659 

660 if transpose_rhs: 

661 rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs) 

662 elif conjugate_rhs: 

663 rhs = math_ops.conj(rhs) 

664 

665 check_num_lhs_matches_num_rhs() 

666 return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, 

667 perturb_singular, name) 

668 

669 

670@tf_export('linalg.tridiagonal_matmul') 

671@dispatch.add_dispatch_support 

672def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None): 

673 r"""Multiplies tridiagonal matrix by matrix. 

674 

675 `diagonals` is a representation of a tridiagonal NxN matrix, whose form depends on

676 `diagonals_format`.

677 

678 In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with 

679 two inner-most dimensions representing the square tridiagonal matrices. 

680 Elements outside of the three diagonals will be ignored. 

681 

682 In `sequence` format, `diagonals` is a list or tuple of three tensors:

683 `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last element

684 of `superdiag` and the first element of `subdiag` are ignored.

685 

686 In `compact` format the three diagonals are brought together into one tensor 

687 of shape `[..., 3, M]`, with last two dimensions containing superdiagonals, 

688 diagonals, and subdiagonals, in order. Similarly to `sequence` format, 

689 elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored. 

690 

691 The `sequence` format is recommended as the one with the best performance. 

692 

693 `rhs` is the matrix on the right-hand side of the multiplication. It has shape `[..., M, N]`.

694 

695 Example: 

696 

697 ```python 

698 superdiag = tf.constant([-1, -1, 0], dtype=tf.float64) 

699 maindiag = tf.constant([2, 2, 2], dtype=tf.float64) 

700 subdiag = tf.constant([0, -1, -1], dtype=tf.float64) 

701 diagonals = [superdiag, maindiag, subdiag] 

702 rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64) 

703 x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence') 

704 ``` 

705 

706 Args: 

707 diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The 

708 shape depends on `diagonals_format`; see the description above. Must be

709 `float32`, `float64`, `complex64`, or `complex128`. 

710 rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`. 

711 diagonals_format: one of `sequence`, or `compact`. Default is `compact`. 

712 name: A name to give this `Op` (optional). 

713 

714 Returns: 

715 A `Tensor` of shape [..., M, N] containing the result of multiplication. 

716 

717 Raises: 

718 ValueError: An unsupported type is provided as input, or the input

719 tensors have incorrect shapes.

720 """ 

721 if diagonals_format == 'compact': 

722 superdiag = diagonals[..., 0, :] 

723 maindiag = diagonals[..., 1, :] 

724 subdiag = diagonals[..., 2, :] 

725 elif diagonals_format == 'sequence': 

726 superdiag, maindiag, subdiag = diagonals 

727 elif diagonals_format == 'matrix': 

728 m1 = tensor_shape.dimension_value(diagonals.shape[-1]) 

729 m2 = tensor_shape.dimension_value(diagonals.shape[-2]) 

730 if m1 and m2 and m1 != m2: 

731 raise ValueError( 

732 'Expected last two dimensions of diagonals to be same, got {} and {}' 

733 .format(m1, m2)) 

734 diags = array_ops.matrix_diag_part( 

735 diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT') 

736 superdiag = diags[..., 0, :] 

737 maindiag = diags[..., 1, :] 

738 subdiag = diags[..., 2, :] 

739 else: 

740 raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format) 

741 

742 # C++ backend requires matrices. 

743 # Converting 1-dimensional vectors to matrices with 1 row. 

744 superdiag = array_ops.expand_dims(superdiag, -2) 

745 maindiag = array_ops.expand_dims(maindiag, -2) 

746 subdiag = array_ops.expand_dims(subdiag, -2) 

747 

748 return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name) 
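
A hedged sketch (illustrative; public API, eager mode) checking the `matrix` format of `tridiagonal_matmul` against an ordinary dense product:

```python
import tensorflow as tf

a = tf.constant([[ 2., -1.,  0.],
                 [-1.,  2., -1.],
                 [ 0., -1.,  2.]])
rhs = tf.ones([3, 2])
y = tf.linalg.tridiagonal_matmul(a, rhs, diagonals_format='matrix')
# Should match the dense product, since `a` is already tridiagonal.
print(y.numpy())
print(tf.matmul(a, rhs).numpy())
```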

749 

750 

751def _maybe_validate_matrix(a, validate_args): 

752 """Checks that input is a `float` matrix.""" 

753 assertions = [] 

754 if not a.dtype.is_floating: 

755 raise TypeError('Input `a` must have `float`-like `dtype` ' 

756 '(saw {}).'.format(a.dtype.name)) 

757 if a.shape is not None and a.shape.rank is not None: 

758 if a.shape.rank < 2: 

759 raise ValueError('Input `a` must have at least 2 dimensions ' 

760 '(saw: {}).'.format(a.shape.rank)) 

761 elif validate_args: 

762 assertions.append( 

763 check_ops.assert_rank_at_least( 

764 a, rank=2, message='Input `a` must have at least 2 dimensions.')) 

765 return assertions 

766 

767 

768@tf_export('linalg.matrix_rank') 

769@dispatch.add_dispatch_support 

770def matrix_rank(a, tol=None, validate_args=False, name=None): 

771 """Compute the matrix rank of one or more matrices. 

772 

773 Args: 

774 a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be

775 computed.

776 tol: Threshold below which a singular value is counted as 'zero'.

777 Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).

778 validate_args: When `True`, additional assertions might be embedded in the 

779 graph. 

780 Default value: `False` (i.e., no graph assertions are added). 

781 name: Python `str` prefixed to ops created by this function. 

782 Default value: 'matrix_rank'. 

783 

784 Returns: 

785 matrix_rank: (Batch of) `int32` scalars representing the number of non-zero 

786 singular values. 

787 """ 

788 with ops.name_scope(name or 'matrix_rank'): 

789 a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a') 

790 assertions = _maybe_validate_matrix(a, validate_args) 

791 if assertions: 

792 with ops.control_dependencies(assertions): 

793 a = array_ops.identity(a) 

794 s = svd(a, compute_uv=False) 

795 if tol is None: 

796 if (a.shape[-2:]).is_fully_defined(): 

797 m = np.max(a.shape[-2:].as_list()) 

798 else: 

799 m = math_ops.reduce_max(array_ops.shape(a)[-2:]) 

800 eps = np.finfo(a.dtype.as_numpy_dtype).eps 

801 tol = ( 

802 eps * math_ops.cast(m, a.dtype) * 

803 math_ops.reduce_max(s, axis=-1, keepdims=True)) 

804 return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1) 
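
A minimal sketch (illustrative; public API, eager mode) showing the rank of a rank-deficient matrix versus a full-rank identity:

```python
import tensorflow as tf

a = tf.constant([[1., 2.],
                 [2., 4.]])  # second row is twice the first, so rank 1
print(tf.linalg.matrix_rank(a).numpy())          # 1
print(tf.linalg.matrix_rank(tf.eye(3)).numpy())  # 3
```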

805 

806 

807@tf_export('linalg.pinv') 

808@dispatch.add_dispatch_support 

809def pinv(a, rcond=None, validate_args=False, name=None): 

810 """Compute the Moore-Penrose pseudo-inverse of one or more matrices. 

811 

812 Calculate the [generalized inverse of a matrix]( 

813 https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its 

814 singular-value decomposition (SVD) and including all large singular values. 

815 

816 The pseudo-inverse of a matrix `A` is defined as 'the matrix that "solves"

817 [the least-squares problem] `A @ x = b`', i.e., if `x_hat` is a solution, then

818 `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if

819 `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then

820 `A_pinv = V @ inv(Sigma) @ U^T`. [(Strang, 1980)][1]

821 

822 This function is analogous to [`numpy.linalg.pinv`]( 

823 https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html). 

824 It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the 

825 default `rcond` is `1e-15`. Here the default is 

826 `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`. 

827 

828 Args: 

829 a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be 

830 pseudo-inverted. 

831 rcond: `Tensor` of small singular value cutoffs. Singular values smaller 

832 (in modulus) than `rcond` * largest_singular_value (again, in modulus) are 

833 set to zero. Must broadcast against `tf.shape(a)[:-2]`. 

834 Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`. 

835 validate_args: When `True`, additional assertions might be embedded in the 

836 graph. 

837 Default value: `False` (i.e., no graph assertions are added). 

838 name: Python `str` prefixed to ops created by this function. 

839 Default value: 'pinv'. 

840 

841 Returns: 

842 a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a` except 

843 rightmost two dimensions are transposed. 

844 

845 Raises: 

846 TypeError: if input `a` does not have `float`-like `dtype`. 

847 ValueError: if input `a` has fewer than 2 dimensions. 

848 

849 #### Examples 

850 

851 ```python 

852 import tensorflow as tf 

853 import tensorflow_probability as tfp 

854 

855 a = tf.constant([[1., 0.4, 0.5], 

856 [0.4, 0.2, 0.25], 

857 [0.5, 0.25, 0.35]]) 

858 tf.matmul(tf.linalg.pinv(a), a) 

859 # ==> array([[1., 0., 0.], 

860 [0., 1., 0.], 

861 [0., 0., 1.]], dtype=float32) 

862 

863 a = tf.constant([[1., 0.4, 0.5, 1.], 

864 [0.4, 0.2, 0.25, 2.], 

865 [0.5, 0.25, 0.35, 3.]]) 

866 tf.matmul(tf.linalg.pinv(a), a) 

867 # ==> array([[ 0.76, 0.37, 0.21, -0.02], 

868 [ 0.37, 0.43, -0.33, 0.02], 

869 [ 0.21, -0.33, 0.81, 0.01], 

870 [-0.02, 0.02, 0.01, 1. ]], dtype=float32) 

871 ``` 

872 

873 #### References 

874 

875 [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press, 

876 Inc., 1980, pp. 139-142. 

877 """ 

878 with ops.name_scope(name or 'pinv'): 

879 a = ops.convert_to_tensor(a, name='a') 

880 

881 assertions = _maybe_validate_matrix(a, validate_args) 

882 if assertions: 

883 with ops.control_dependencies(assertions): 

884 a = array_ops.identity(a) 

885 

886 dtype = a.dtype.as_numpy_dtype 

887 

888 if rcond is None: 

889 

890 def get_dim_size(dim): 

891 dim_val = tensor_shape.dimension_value(a.shape[dim]) 

892 if dim_val is not None: 

893 return dim_val 

894 return array_ops.shape(a)[dim] 

895 

896 num_rows = get_dim_size(-2) 

897 num_cols = get_dim_size(-1) 

898 if isinstance(num_rows, int) and isinstance(num_cols, int): 

899 max_rows_cols = float(max(num_rows, num_cols)) 

900 else: 

901 max_rows_cols = math_ops.cast( 

902 math_ops.maximum(num_rows, num_cols), dtype) 

903 rcond = 10. * max_rows_cols * np.finfo(dtype).eps 

904 

905 rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond') 

906 

907 # Calculate pseudo inverse via SVD. 

908 # Note: if a is Hermitian then u == v. (We might observe additional 

909 # performance by explicitly setting `v = u` in such cases.) 

910 [ 

911 singular_values, # Sigma 

912 left_singular_vectors, # U 

913 right_singular_vectors, # V 

914 ] = svd( 

915 a, full_matrices=False, compute_uv=True) 

916 

917 # Saturate small singular values to inf. This has the effect of making

918 # `1. / s = 0.` while not resulting in `NaN` gradients. 

919 cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1) 

920 singular_values = array_ops.where_v2( 

921 singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values, 

922 np.array(np.inf, dtype)) 

923 

924 # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse 

925 # is defined as `pinv(a) == v @ inv(s) @ u^H`. 

926 a_pinv = math_ops.matmul( 

927 right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2), 

928 left_singular_vectors, 

929 adjoint_b=True) 

930 

931 if a.shape is not None and a.shape.rank is not None: 

932 a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]])) 

933 

934 return a_pinv 
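
A short sketch (illustrative; public API, eager mode) of the defining Moore-Penrose property `A @ pinv(A) @ A == A` for a wide matrix, where the ordinary inverse does not exist:

```python
import tensorflow as tf

a = tf.constant([[1., 0., 2.],
                 [0., 1., 1.]])      # 2 x 3, no ordinary inverse
a_pinv = tf.linalg.pinv(a)           # shape (3, 2)
# a @ pinv(a) @ a recovers a up to floating-point rounding.
print(tf.matmul(a, tf.matmul(a_pinv, a)).numpy())
```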

935 

936 

937@tf_export('linalg.lu_solve') 

938@dispatch.add_dispatch_support 

939def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None): 

940 """Solves systems of linear eqns `A X = RHS`, given LU factorizations. 

941 

942 Note: this function does not verify the implied matrix is actually invertible 

943 nor is this condition checked even when `validate_args=True`. 

944 

945 Args: 

946 lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, 

947 matmul(L, U)) = X` then `lower_upper = L + U - eye`. 

948 perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =

949 X` then `perm = argmax(P)`. 

950 rhs: Matrix-shaped float `Tensor` representing targets for which to solve; 

951 `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[..., 

952 tf.newaxis])[..., 0]`. 

953 validate_args: Python `bool` indicating whether arguments should be checked 

954 for correctness. Note: this function does not verify the implied matrix is 

955 actually invertible, even when `validate_args=True`. 

956 Default value: `False` (i.e., don't validate arguments). 

957 name: Python `str` name given to ops managed by this object. 

958 Default value: `None` (i.e., 'lu_solve'). 

959 

960 Returns: 

961 x: The `X` in `A @ X = RHS`. 

962 

963 #### Examples 

964 

965 ```python 

966 import numpy as np 

967 import tensorflow as tf 

968 import tensorflow_probability as tfp 

969 

970 x = [[[1., 2], 

971 [3, 4]], 

972 [[7, 8], 

973 [3, 4]]] 

974 inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2)) 

975 tf.assert_near(tf.matrix_inverse(x), inv_x) 

976 # ==> True 

977 ``` 

978 

979 """ 

980 

981 with ops.name_scope(name or 'lu_solve'): 

982 lower_upper = ops.convert_to_tensor( 

983 lower_upper, dtype_hint=dtypes.float32, name='lower_upper') 

984 perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm') 

985 rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs') 

986 

987 assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args) 

988 if assertions: 

989 with ops.control_dependencies(assertions): 

990 lower_upper = array_ops.identity(lower_upper) 

991 perm = array_ops.identity(perm) 

992 rhs = array_ops.identity(rhs) 

993 

994 if (rhs.shape.rank == 2 and perm.shape.rank == 1): 

995 # Both rhs and perm have scalar batch_shape. 

996 permuted_rhs = array_ops.gather(rhs, perm, axis=-2) 

997 else: 

998 # Either rhs or perm have non-scalar batch_shape or we can't determine 

999 # this information statically. 

1000 rhs_shape = array_ops.shape(rhs) 

1001 broadcast_batch_shape = array_ops.broadcast_dynamic_shape( 

1002 rhs_shape[:-2], 

1003 array_ops.shape(perm)[:-1]) 

1004 d, m = rhs_shape[-2], rhs_shape[-1] 

1005 rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]], 

1006 axis=0) 

1007 

1008 # Tile out rhs. 

1009 broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape) 

1010 broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m]) 

1011 

1012 # Tile out perm and add batch indices. 

1013 broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1]) 

1014 broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d]) 

1015 broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape) 

1016 broadcast_batch_indices = array_ops.broadcast_to( 

1017 math_ops.range(broadcast_batch_size)[:, array_ops.newaxis], 

1018 [broadcast_batch_size, d]) 

1019 broadcast_perm = array_ops_stack.stack( 

1020 [broadcast_batch_indices, broadcast_perm], axis=-1) 

1021 

1022 permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm) 

1023 permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape) 

1024 

1025 lower = set_diag( 

1026 band_part(lower_upper, num_lower=-1, num_upper=0), 

1027 array_ops.ones( 

1028 array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype)) 

1029 return triangular_solve( 

1030 lower_upper, # Only upper is accessed. 

1031 triangular_solve(lower, permuted_rhs), 

1032 lower=False) 
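
A small sketch (illustrative values; public API, eager mode) of solving a single right-hand side from an LU factorization, rather than building the full inverse as in the docstring example:

```python
import tensorflow as tf

a = tf.constant([[4., 3.],
                 [6., 3.]])
rhs = tf.constant([[1.], [0.]])
lu, p = tf.linalg.lu(a)
x = tf.linalg.lu_solve(lu, p, rhs)
# x solves a @ x = rhs, so this matmul reproduces rhs (up to rounding).
print(tf.matmul(a, x).numpy())
```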

1033 

1034 

1035@tf_export('linalg.lu_matrix_inverse') 

1036@dispatch.add_dispatch_support 

1037def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): 

1038 """Computes the inverse given the LU decomposition(s) of one or more matrices. 

1039 

1040 This op is conceptually identical to:

1041 

1042 ```python 

1043 inv_X = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(X))

1044 tf.assert_near(tf.matrix_inverse(X), inv_X) 

1045 # ==> True 

1046 ``` 

1047 

1048 Note: this function does not verify the implied matrix is actually invertible 

1049 nor is this condition checked even when `validate_args=True`. 

1050 

1051 Args: 

1052 lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, 

1053 matmul(L, U)) = X` then `lower_upper = L + U - eye`. 

1054 perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =

1055 X` then `perm = argmax(P)`. 

1056 validate_args: Python `bool` indicating whether arguments should be checked 

1057 for correctness. Note: this function does not verify the implied matrix is 

1058 actually invertible, even when `validate_args=True`. 

1059 Default value: `False` (i.e., don't validate arguments). 

1060 name: Python `str` name given to ops managed by this object. 

1061 Default value: `None` (i.e., 'lu_matrix_inverse'). 

1062 

1063 Returns: 

1064 inv_x: The matrix_inv, i.e., 

1065 `tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`. 

1066 

1067 #### Examples 

1068 

1069 ```python 

1070 import numpy as np 

1071 import tensorflow as tf 

1072 import tensorflow_probability as tfp 

1073 

1074 x = [[[3., 4], [1, 2]], 

1075 [[7., 8], [3, 4]]] 

1076 inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x)) 

1077 tf.assert_near(tf.matrix_inverse(x), inv_x) 

1078 # ==> True 

1079 ``` 

1080 

1081 """ 

1082 

1083 with ops.name_scope(name or 'lu_matrix_inverse'): 

1084 lower_upper = ops.convert_to_tensor( 

1085 lower_upper, dtype_hint=dtypes.float32, name='lower_upper') 

1086 perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm') 

1087 assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args) 

1088 if assertions: 

1089 with ops.control_dependencies(assertions): 

1090 lower_upper = array_ops.identity(lower_upper) 

1091 perm = array_ops.identity(perm) 

1092 shape = array_ops.shape(lower_upper) 

1093 return lu_solve( 

1094 lower_upper, 

1095 perm, 

1096 rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype), 

1097 validate_args=False) 

1098 

1099 

1100@tf_export('linalg.lu_reconstruct') 

1101@dispatch.add_dispatch_support 

1102def lu_reconstruct(lower_upper, perm, validate_args=False, name=None): 

1103 """The reconstruct one or more matrices from their LU decomposition(s). 

1104 

1105 Args: 

1106 lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, 

1107 matmul(L, U)) = X` then `lower_upper = L + U - eye`. 

1108 perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =

1109 X` then `perm = argmax(P)`. 

1110 validate_args: Python `bool` indicating whether arguments should be checked 

1111 for correctness. 

1112 Default value: `False` (i.e., don't validate arguments). 

1113 name: Python `str` name given to ops managed by this object. 

1114 Default value: `None` (i.e., 'lu_reconstruct'). 

1115 

1116 Returns: 

1117 x: The original input to `tf.linalg.lu`, i.e., `x` as in, 

1118 `lu_reconstruct(*tf.linalg.lu(x))`. 

1119 

1120 #### Examples 

1121 

1122 ```python 

1123 import numpy as np 

1124 import tensorflow as tf 

1125 import tensorflow_probability as tfp 

1126 

1127 x = [[[3., 4], [1, 2]], 

1128 [[7., 8], [3, 4]]] 

1129 x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x)) 

1130 tf.assert_near(x, x_reconstructed) 

1131 # ==> True 

1132 ``` 

1133 

1134 """ 

1135 with ops.name_scope(name or 'lu_reconstruct'): 

1136 lower_upper = ops.convert_to_tensor( 

1137 lower_upper, dtype_hint=dtypes.float32, name='lower_upper') 

1138 perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm') 

1139 

1140 assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args) 

1141 if assertions: 

1142 with ops.control_dependencies(assertions): 

1143 lower_upper = array_ops.identity(lower_upper) 

1144 perm = array_ops.identity(perm) 

1145 

1146 shape = array_ops.shape(lower_upper) 

1147 

1148 lower = set_diag( 

1149 band_part(lower_upper, num_lower=-1, num_upper=0), 

1150 array_ops.ones(shape[:-1], dtype=lower_upper.dtype)) 

1151 upper = band_part(lower_upper, num_lower=0, num_upper=-1) 

1152 x = math_ops.matmul(lower, upper) 

1153 

1154 if (lower_upper.shape is None or lower_upper.shape.rank is None or 

1155 lower_upper.shape.rank != 2): 

1156 # We either don't know the batch rank or there are >0 batch dims. 

1157 batch_size = math_ops.reduce_prod(shape[:-2]) 

1158 d = shape[-1] 

1159 x = array_ops.reshape(x, [batch_size, d, d]) 

1160 perm = array_ops.reshape(perm, [batch_size, d]) 

1161 perm = map_fn.map_fn(array_ops.invert_permutation, perm) 

1162 batch_indices = array_ops.broadcast_to( 

1163 math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d]) 

1164 x = array_ops.gather_nd( 

1165 x, array_ops_stack.stack([batch_indices, perm], axis=-1)) 

1166 x = array_ops.reshape(x, shape) 

1167 else: 

1168 x = array_ops.gather(x, array_ops.invert_permutation(perm)) 

1169 

1170 x.set_shape(lower_upper.shape) 

1171 return x 

1172 

1173 

1174def lu_reconstruct_assertions(lower_upper, perm, validate_args): 

1175 """Returns list of assertions related to `lu_reconstruct` assumptions.""" 

1176 assertions = [] 

1177 

1178 message = 'Input `lower_upper` must have at least 2 dimensions.' 

1179 if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2: 

1180 raise ValueError(message) 

1181 elif validate_args: 

1182 assertions.append( 

1183 check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message)) 

1184 

1185 message = '`rank(lower_upper)` must equal `rank(perm) + 1`' 

1186 if lower_upper.shape.rank is not None and perm.shape.rank is not None: 

1187 if lower_upper.shape.rank != perm.shape.rank + 1: 

1188 raise ValueError(message) 

1189 elif validate_args: 

1190 assertions.append( 

1191 check_ops.assert_rank( 

1192 lower_upper, rank=array_ops.rank(perm) + 1, message=message)) 

1193 

1194 message = '`lower_upper` must be square.' 

1195 if lower_upper.shape[:-2].is_fully_defined(): 

1196 if lower_upper.shape[-2] != lower_upper.shape[-1]: 

1197 raise ValueError(message) 

1198 elif validate_args: 

1199 m, n = array_ops.split( 

1200 array_ops.shape(lower_upper)[-2:], num_or_size_splits=2) 

1201 assertions.append(check_ops.assert_equal(m, n, message=message)) 

1202 

1203 return assertions 

1204 

1205 

1206def _lu_solve_assertions(lower_upper, perm, rhs, validate_args): 

1207 """Returns list of assertions related to `lu_solve` assumptions.""" 

1208 assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args) 

1209 

1210 message = 'Input `rhs` must have at least 2 dimensions.' 

1211 if rhs.shape.ndims is not None: 

1212 if rhs.shape.ndims < 2: 

1213 raise ValueError(message) 

1214 elif validate_args: 

1215 assertions.append( 

1216 check_ops.assert_rank_at_least(rhs, rank=2, message=message)) 

1217 

1218 message = '`lower_upper.shape[-1]` must equal `rhs.shape[-1]`.' 

1219 if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None): 

1220 if lower_upper.shape[-1] != rhs.shape[-2]: 

1221 raise ValueError(message) 

1222 elif validate_args: 

1223 assertions.append( 

1224 check_ops.assert_equal( 

1225 array_ops.shape(lower_upper)[-1], 

1226 array_ops.shape(rhs)[-2], 

1227 message=message)) 

1228 

1229 return assertions 

1230 

1231 

1232@tf_export('linalg.eigh_tridiagonal') 

1233@dispatch.add_dispatch_support 

1234def eigh_tridiagonal(alpha, 

1235 beta, 

1236 eigvals_only=True, 

1237 select='a', 

1238 select_range=None, 

1239 tol=None, 

1240 name=None): 

1241 """Computes the eigenvalues of a Hermitian tridiagonal matrix. 

1242 

1243 Args: 

1244 alpha: A real or complex tensor of shape (n), the diagonal elements of the 

1245 matrix. NOTE: If alpha is complex, the imaginary part is ignored (assumed 

1246 zero) to satisfy the requirement that the matrix be Hermitian. 

1247 beta: A real or complex tensor of shape (n-1), containing the elements of 

1248 the first super-diagonal of the matrix. If beta is complex, the first 

1249 sub-diagonal of the matrix is assumed to be the conjugate of beta to 

1250 satisfy the requirement that the matrix be Hermitian 

1251 eigvals_only: If False, both eigenvalues and corresponding eigenvectors are 

1252 computed. If True, only eigenvalues are computed. Default is True. 

1253 select: Optional string with values in {'a', 'v', 'i'} (default is 'a') that

1254 determines which eigenvalues to calculate:

1255 'a': all eigenvalues.

1256 'v': eigenvalues in the interval (min, max] given by `select_range`.

1257 'i': eigenvalues with indices min <= i <= max.

1258 select_range: Size 2 tuple or list or tensor specifying the range of 

1259 eigenvalues to compute together with select. If select is 'a', 

1260 select_range is ignored. 

1261 tol: Optional scalar. The absolute tolerance to which each eigenvalue is 

1262 required. An eigenvalue (or cluster) is considered to have converged if it 

1263 lies in an interval of this width. If tol is None (default), the value 

1264 eps*|T|_2 is used where eps is the machine precision, and |T|_2 is the 

1265 2-norm of the matrix T. 

1266 name: Optional name of the op. 

1267 

1268 Returns: 

1269 eig_vals: The eigenvalues of the matrix in non-decreasing order. 

1270 eig_vectors: If `eigvals_only` is False the eigenvectors are returned in 

1271 the second output argument. 

1272 

1273 Raises: 

1274 ValueError: If input values are invalid. 

1275 NotImplemented: Computing eigenvectors for `eigvals_only` = False is 

1276 not implemented yet. 

1277 

1278 This op implements a subset of the functionality of 

1279 scipy.linalg.eigh_tridiagonal. 

1280 

1281 Note: The result is undefined if the input contains +/-inf or NaN, or if 

1282 any value in beta has a magnitude greater than 

1283 `numpy.sqrt(numpy.finfo(beta.dtype.as_numpy_dtype).max)`. 

1284 

1285 

1286 TODO(b/187527398): 

1287 Add support for outer batch dimensions. 

1288 

1289 #### Examples 

1290 

1291 ```python 

1292 import numpy 

1293 eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0]) 

1294 eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)] 

1295 tf.assert_near(eigvals_expected, eigvals) 

1296 # ==> True 

1297 ``` 

1298 

1299 """ 

1300 with ops.name_scope(name or 'eigh_tridiagonal'): 

1301 

1302 def _compute_eigenvalues(alpha, beta): 

1303 """Computes all eigenvalues of a Hermitian tridiagonal matrix.""" 

1304 

1305 def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x): 

1306 """Implements the Sturm sequence recurrence.""" 

1307 with ops.name_scope('sturm'): 

1308 n = alpha.shape[0] 

1309 zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32) 

1310 ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32) 

1311 

1312 # The first step in the Sturm sequence recurrence 

1313 # requires special care if x is equal to alpha[0]. 

1314 def sturm_step0(): 

1315 q = alpha[0] - x 

1316 count = array_ops.where(q < 0, ones, zeros) 

1317 q = array_ops.where( 

1318 math_ops.equal(alpha[0], x), alpha0_perturbation, q) 

1319 return q, count 

1320 

1321 # Subsequent steps all take this form: 

1322 def sturm_step(i, q, count): 

1323 q = alpha[i] - beta_sq[i - 1] / q - x 

1324 count = array_ops.where(q <= pivmin, count + 1, count) 

1325 q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q) 

1326 return q, count 

1327 

1328 # The first step initializes q and count. 

1329 q, count = sturm_step0() 

1330 

1331 # Peel off ((n-1) % blocksize) steps from the main loop, so we can run 

1332 # the bulk of the iterations unrolled by a factor of blocksize. 

1333 blocksize = 16 

1334 i = 1 

1335 peel = (n - 1) % blocksize 

1336 unroll_cnt = peel 

1337 

1338 def unrolled_steps(start, q, count): 

1339 for j in range(unroll_cnt): 

1340 q, count = sturm_step(start + j, q, count) 

1341 return start + unroll_cnt, q, count 

1342 

1343 i, q, count = unrolled_steps(i, q, count) 

1344 

1345 # Run the remaining steps of the Sturm sequence using a partially 

1346 # unrolled while loop. 

1347 unroll_cnt = blocksize 

1348 cond = lambda i, q, count: math_ops.less(i, n) 

1349 _, _, count = while_loop.while_loop( 

1350 cond, unrolled_steps, [i, q, count], back_prop=False) 

1351 return count 

1352 

1353 with ops.name_scope('compute_eigenvalues'): 

1354 if alpha.dtype.is_complex: 

1355 alpha = math_ops.real(alpha) 

1356 beta_sq = math_ops.real(math_ops.conj(beta) * beta) 

1357 beta_abs = math_ops.sqrt(beta_sq) 

1358 else: 

1359 beta_sq = math_ops.square(beta) 

1360 beta_abs = math_ops.abs(beta) 

1361 

1362 # Estimate the largest and smallest eigenvalues of T using the 

1363 # Gershgorin circle theorem. 
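# For a Hermitian tridiagonal matrix, each Gershgorin disc is the interval
# [alpha[i] - r_i, alpha[i] + r_i] with radius r_i = |beta[i-1]| + |beta[i]|
# (taking the missing off-diagonal entries at the ends to be zero), so the
# extreme endpoints of these intervals bound the whole spectrum.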

1364 finfo = np.finfo(alpha.dtype.as_numpy_dtype) 

1365 off_diag_abs_row_sum = array_ops.concat( 

1366 [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0) 

1367 lambda_est_max = math_ops.minimum( 

1368 finfo.max, math_ops.reduce_max(alpha + off_diag_abs_row_sum)) 

1369 lambda_est_min = math_ops.maximum( 

1370 finfo.min, math_ops.reduce_min(alpha - off_diag_abs_row_sum)) 

1371 # Upper bound on 2-norm of T. 

1372 t_norm = math_ops.maximum( 

1373 math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max)) 

1374 

1375 # Compute the smallest allowed pivot in the Sturm sequence to avoid 

1376 # overflow. 

1377 one = np.ones([], dtype=alpha.dtype.as_numpy_dtype) 

1378 safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny) 

1379 pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq)) 

1380 alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0]) 

1381 abs_tol = finfo.eps * t_norm 

1382 if tol is not None: 

1383 abs_tol = math_ops.maximum(tol, abs_tol) 

1384 # In the worst case, when the absolute tolerance is eps*lambda_est_max 

1385 # and lambda_est_max = -lambda_est_min, we have to take as many 

1386 # bisection steps as there are bits in the mantissa plus 1. 

1387 max_it = finfo.nmant + 1 
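# For example, float32 has nmant = 23, so at most 24 bisection iterations
# are taken; float64 has nmant = 52, giving at most 53 iterations.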

1388 

1389 # Determine the indices of the desired eigenvalues, based on select 

1390 # and select_range. 

1391 asserts = None 

1392 if select == 'a': 

1393 target_counts = math_ops.range(n) 

1394 elif select == 'i': 

1395 asserts = check_ops.assert_less_equal( 

1396 select_range[0], 

1397 select_range[1], 

1398 message='Got empty index range in select_range.') 

1399 target_counts = math_ops.range(select_range[0], select_range[1] + 1) 

1400 elif select == 'v': 

1401 asserts = check_ops.assert_less( 

1402 select_range[0], 

1403 select_range[1], 

1404 message='Got empty interval in select_range.') 

1405 else: 

1406 raise ValueError("'select' must have a value in {'a', 'i', 'v'}.") 

1407 

1408 if asserts: 

1409 with ops.control_dependencies([asserts]): 

1410 alpha = array_ops.identity(alpha) 

1411 

1412 # Run binary search for all desired eigenvalues in parallel, starting 

1413 # from an interval slightly wider than the estimated 

1414 # [lambda_est_min, lambda_est_max]. 

1415 fudge = 2.1  # We widen the starting Gershgorin interval a bit. 

1416 norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm 

1417 if select in {'a', 'i'}: 

1418 lower = lambda_est_min - norm_slack - 2 * fudge * pivmin 

1419 upper = lambda_est_max + norm_slack + fudge * pivmin 

1420 else: 

1421 # Count the number of eigenvalues in the given range. 

1422 lower = select_range[0] - norm_slack - 2 * fudge * pivmin 

1423 upper = select_range[1] + norm_slack + fudge * pivmin 

1424 first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower) 

1425 last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper) 

1426 target_counts = math_ops.range(first, last) 

1427 

1428 # Pre-broadcast the scalars used in the Sturm sequence for improved 

1429 # performance. 

1430 upper = math_ops.minimum(upper, finfo.max) 

1431 lower = math_ops.maximum(lower, finfo.min) 

1432 target_shape = array_ops.shape(target_counts) 

1433 lower = array_ops.broadcast_to(lower, shape=target_shape) 

1434 upper = array_ops.broadcast_to(upper, shape=target_shape) 

1435 pivmin = array_ops.broadcast_to(pivmin, target_shape) 

1436 alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation, 

1437 target_shape) 

1438 

1439 # We compute the midpoint as 0.5*lower + 0.5*upper to avoid overflow in 

1440 # (lower + upper) or (upper - lower) when the matrix has eigenvalues 

1441 # with magnitude greater than finfo.max / 2. 

1442 def midpoint(lower, upper): 

1443 return (0.5 * lower) + (0.5 * upper) 

1444 

1445 def continue_binary_search(i, lower, upper): 

1446 return math_ops.logical_and( 

1447 math_ops.less(i, max_it), 

1448 math_ops.less(abs_tol, math_ops.reduce_max(upper - lower))) 

1449 

1450 def binary_search_step(i, lower, upper): 

1451 mid = midpoint(lower, upper) 

1452 counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid) 

1453 lower = array_ops.where(counts <= target_counts, mid, lower) 

1454 upper = array_ops.where(counts > target_counts, mid, upper) 

1455 return i + 1, lower, upper 

1456 

1457 # Start parallel binary searches. 

1458 _, lower, upper = while_loop.while_loop(continue_binary_search, 

1459 binary_search_step, 

1460 [0, lower, upper]) 

1461 return midpoint(lower, upper) 

1462 

1463 def _compute_eigenvectors(alpha, beta, eigvals): 

1464 """Implements inverse iteration to compute eigenvectors.""" 

1465 with ops.name_scope('compute_eigenvectors'): 

1466 k = array_ops.size(eigvals) 

1467 n = array_ops.size(alpha) 

1468 alpha = math_ops.cast(alpha, dtype=beta.dtype) 

1469 

1470 # Eigenvectors corresponding to cluster of close eigenvalues are 

1471 # not unique and need to be explicitly orthogonalized. Here we 

1472 # identify such clusters. Note: This function assumes that 

1473 # eigenvalues are sorted in non-decreasing order. 

1474 gap = eigvals[1:] - eigvals[:-1] 

1475 eps = np.finfo(eigvals.dtype.as_numpy_dtype).eps 

1476 t_norm = math_ops.maximum( 

1477 math_ops.abs(eigvals[0]), math_ops.abs(eigvals[-1])) 

1478 gaptol = np.sqrt(eps) * t_norm 

1479 # Find the beginning and end of runs of eigenvectors corresponding 

1480 # to eigenvalues closer than "gaptol", which will need to be 

1481 # orthogonalized against each other. 

1482 close = math_ops.less(gap, gaptol) 

1483 left_neighbor_close = array_ops.concat([[False], close], axis=0) 

1484 right_neighbor_close = array_ops.concat([close, [False]], axis=0) 

1485 ortho_interval_start = math_ops.logical_and( 

1486 math_ops.logical_not(left_neighbor_close), right_neighbor_close) 

1487 ortho_interval_start = array_ops.squeeze( 

1488 array_ops.where_v2(ortho_interval_start), axis=-1) 

1489 ortho_interval_end = math_ops.logical_and( 

1490 left_neighbor_close, math_ops.logical_not(right_neighbor_close)) 

1491 ortho_interval_end = array_ops.squeeze( 

1492 array_ops.where_v2(ortho_interval_end), axis=-1) + 1 

1493 num_clusters = array_ops.size(ortho_interval_end) 

1494 

1495 # We perform inverse iteration for all eigenvectors in parallel, 

1496 # starting from a random set of vectors, until all have converged. 

1497 v0 = math_ops.cast( 

1498 stateless_random_ops.stateless_random_normal( 

1499 shape=(k, n), seed=[7, 42]), 

1500 dtype=beta.dtype) 

1501 nrm_v = norm(v0, axis=1) 

1502 v0 = v0 / nrm_v[:, array_ops.newaxis] 

1503 zero_nrm = constant_op.constant(0, shape=nrm_v.shape, dtype=nrm_v.dtype) 

1504 

1505 # Replicate alpha - eigvals(i) and beta across the k eigenvectors so we 

1506 # can solve the k systems 

1507 # [T - eigvals(i)*eye(n)] x_i = r_i 

1508 # simultaneously using the batching mechanism. 

1509 eigvals_cast = math_ops.cast(eigvals, dtype=beta.dtype) 

1510 alpha_shifted = ( 

1511 alpha[array_ops.newaxis, :] - eigvals_cast[:, array_ops.newaxis]) 

1512 beta = array_ops.tile(beta[array_ops.newaxis, :], [k, 1]) 

1513 diags = [beta, alpha_shifted, math_ops.conj(beta)] 

1514 

1515 def orthogonalize_close_eigenvectors(eigenvectors): 

1516 # Eigenvectors corresponding to a cluster of close eigenvalues are not 

1517 # uniquely defined, but the subspace they span is. To avoid numerical 

1518 # instability, we explicitly mutually orthogonalize such eigenvectors 

1519 # after each step of inverse iteration. It is customary to use 

1520 # modified Gram-Schmidt for this, but this is not very efficient 

1521 # on some platforms, so here we defer to the QR decomposition in 

1522 # TensorFlow. 

1523 def orthogonalize_cluster(cluster_idx, eigenvectors): 

1524 start = ortho_interval_start[cluster_idx] 

1525 end = ortho_interval_end[cluster_idx] 

1526 update_indices = array_ops.expand_dims( 

1527 math_ops.range(start, end), -1) 

1528 vectors_in_cluster = eigenvectors[start:end, :] 

1529 # We use the builtin QR factorization to orthonormalize the 

1530 # vectors in the cluster. 

1531 q, _ = qr(transpose(vectors_in_cluster)) 

1532 vectors_to_update = transpose(q) 

1533 eigenvectors = array_ops.tensor_scatter_nd_update( 

1534 eigenvectors, update_indices, vectors_to_update) 

1535 return cluster_idx + 1, eigenvectors 

1536 

1537 _, eigenvectors = while_loop.while_loop( 

1538 lambda i, ev: math_ops.less(i, num_clusters), 

1539 orthogonalize_cluster, [0, eigenvectors]) 

1540 return eigenvectors 

1541 

1542 def continue_iteration(i, _, nrm_v, nrm_v_old): 

1543 max_it = 5 # Taken from LAPACK xSTEIN. 

1544 min_norm_growth = 0.1 

1545 norm_growth_factor = constant_op.constant( 

1546 1 + min_norm_growth, dtype=nrm_v.dtype) 

1547 # We stop the inverse iteration when we reach the maximum number of 

1548 # iterations or the norm growth is less than 10%. 

1549 return math_ops.logical_and( 

1550 math_ops.less(i, max_it), 

1551 math_ops.reduce_any( 

1552 math_ops.greater_equal( 

1553 math_ops.real(nrm_v), 

1554 math_ops.real(norm_growth_factor * nrm_v_old)))) 

1555 

1556 def inverse_iteration_step(i, v, nrm_v, nrm_v_old): 

1557 v = tridiagonal_solve( 

1558 diags, 

1559 v, 

1560 diagonals_format='sequence', 

1561 partial_pivoting=True, 

1562 perturb_singular=True) 

1563 nrm_v_old = nrm_v 

1564 nrm_v = norm(v, axis=1) 

1565 v = v / nrm_v[:, array_ops.newaxis] 

1566 v = orthogonalize_close_eigenvectors(v) 

1567 return i + 1, v, nrm_v, nrm_v_old 

1568 

1569 _, v, nrm_v, _ = while_loop.while_loop(continue_iteration, 

1570 inverse_iteration_step, 

1571 [0, v0, nrm_v, zero_nrm]) 

1572 return transpose(v) 

1573 

1574 alpha = ops.convert_to_tensor(alpha, name='alpha') 

1575 n = alpha.shape[0] 

1576 if n <= 1: 

1577 return math_ops.real(alpha) 

1578 beta = ops.convert_to_tensor(beta, name='beta') 

1579 

1580 if alpha.dtype != beta.dtype: 

1581 raise ValueError("'alpha' and 'beta' must have the same type.") 

1582 

1583 eigvals = _compute_eigenvalues(alpha, beta) 

1584 if eigvals_only: 

1585 return eigvals 

1586 

1587 eigvectors = _compute_eigenvectors(alpha, beta, eigvals) 

1588 return eigvals, eigvectors