Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_tridiag.py: 30%

132 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""`LinearOperator` acting like a tridiagonal matrix.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.framework import tensor_conversion 

19from tensorflow.python.ops import array_ops 

20from tensorflow.python.ops import array_ops_stack 

21from tensorflow.python.ops import check_ops 

22from tensorflow.python.ops import control_flow_ops 

23from tensorflow.python.ops import gen_array_ops 

24from tensorflow.python.ops import manip_ops 

25from tensorflow.python.ops import math_ops 

26from tensorflow.python.ops.linalg import linalg_impl as linalg 

27from tensorflow.python.ops.linalg import linear_operator 

28from tensorflow.python.ops.linalg import linear_operator_util 

29from tensorflow.python.util.tf_export import tf_export 

30 

# Public API of this module.
__all__ = ['LinearOperatorTridiag']

# Accepted values for the `diagonals_format` constructor argument.
_COMPACT = 'compact'
_MATRIX = 'matrix'
_SEQUENCE = 'sequence'
# The closed set of valid formats, used for argument validation.
_DIAGONAL_FORMATS = frozenset((_COMPACT, _MATRIX, _SEQUENCE))

37 

38 

@tf_export('linalg.LinearOperatorTridiag')
@linear_operator.make_composite_tensor
class LinearOperatorTridiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square tridiagonal matrix.

  This operator acts like a [batch] square tridiagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
  an `N x N` matrix. This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  Example usage:

  Create a 3 x 3 tridiagonal linear operator.

  >>> superdiag = [3., 4., 5.]
  >>> diag = [1., -1., 2.]
  >>> subdiag = [6., 7., 8]
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...    [superdiag, diag, subdiag],
  ...    diagonals_format='sequence')
  >>> operator.to_dense()
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[ 1.,  3.,  0.],
         [ 7., -1.,  4.],
         [ 0.,  8.,  2.]], dtype=float32)>
  >>> operator.shape
  TensorShape([3, 3])

  Scalar Tensor output.

  >>> operator.log_abs_determinant()
  <tf.Tensor: shape=(), dtype=float32, numpy=4.3307333>

  Create a [2, 3] batch of 4 x 4 linear operators.

  >>> diagonals = tf.random.normal(shape=[2, 3, 3, 4])
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...   diagonals,
  ...   diagonals_format='compact')

  Create a shape [2, 1, 4, 2] vector.  Note that this shape is compatible
  since the batch dimensions, [2, 1], are broadcast to
  operator.batch_shape = [2, 3].

  >>> y = tf.random.normal(shape=[2, 1, 4, 2])
  >>> x = operator.solve(y)
  >>> x
  <tf.Tensor: shape=(2, 3, 4, 2), dtype=float32, numpy=...,
  dtype=float32)>

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb].
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorTridiag` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` will take O(N * R) time.
  * `operator.solve(x)` will take O(N * R) time.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diagonals,
               diagonals_format=_COMPACT,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name='LinearOperatorTridiag'):
    r"""Initialize a `LinearOperatorTridiag`.

    Args:
      diagonals: `Tensor` or list of `Tensor`s depending on `diagonals_format`.

        If `diagonals_format=sequence`, this is a list of three `Tensor`'s each
        with shape `[B1, ..., Bb, N]`, `b >= 0, N >= 0`, representing the
        superdiagonal, diagonal and subdiagonal in that order. Note the
        superdiagonal is padded with an element in the last position, and the
        subdiagonal is padded with an element in the front.

        If `diagonals_format=matrix` this is a `[B1, ... Bb, N, N]` shaped
        `Tensor` representing the full tridiagonal matrix.

        If `diagonals_format=compact` this is a `[B1, ... Bb, 3, N]` shaped
        `Tensor` with the second to last dimension indexing the
        superdiagonal, diagonal and subdiagonal in that order. Note the
        superdiagonal is padded with an element in the last position, and the
        subdiagonal is padded with an element in the front.

        In every case, these `Tensor`s are all floating dtype.
      diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
        `compact`.
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose. If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError: If `diag.dtype` is not an allowed type.
      ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
    """
    parameters = dict(
        diagonals=diagonals,
        diagonals_format=diagonals_format,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diagonals]):
      if diagonals_format not in _DIAGONAL_FORMATS:
        raise ValueError(
            f'Argument `diagonals_format` must be one of compact, matrix, or '
            f'sequence. Received : {diagonals_format}.')
      if diagonals_format == _SEQUENCE:
        # Sequence format: keep a list of three tensors
        # (superdiag, diag, subdiag); dtype is taken from the first.
        self._diagonals = [linear_operator_util.convert_nonref_to_tensor(
            d, name='diag_{}'.format(i)) for i, d in enumerate(diagonals)]
        dtype = self._diagonals[0].dtype
      else:
        # Matrix / compact formats: a single tensor.
        self._diagonals = linear_operator_util.convert_nonref_to_tensor(
            diagonals, name='diagonals')
        dtype = self._diagonals.dtype
      self._diagonals_format = diagonals_format

      super(LinearOperatorTridiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _shape(self):
    """Returns the static `TensorShape`, `batch_shape + [N, N]`."""
    if self.diagonals_format == _MATRIX:
      return self.diagonals.shape
    if self.diagonals_format == _COMPACT:
      # Remove the second to last dimension that contains the value 3.
      d_shape = self.diagonals.shape[:-2].concatenate(
          self.diagonals.shape[-1])
    else:
      # Sequence format: batch shape is the broadcast of the three
      # diagonals' batch shapes; N comes from the main diagonal.
      broadcast_shape = array_ops.broadcast_static_shape(
          self.diagonals[0].shape[:-1],
          self.diagonals[1].shape[:-1])
      broadcast_shape = array_ops.broadcast_static_shape(
          broadcast_shape,
          self.diagonals[2].shape[:-1])
      d_shape = broadcast_shape.concatenate(self.diagonals[1].shape[-1])
    # The operator is square, so append N once more.
    return d_shape.concatenate(d_shape[-1])

  def _shape_tensor(self, diagonals=None):
    """Returns the dynamic shape, optionally computed from `diagonals`.

    Args:
      diagonals: Optional override for `self.diagonals`; used by `_solve` to
        query the shape of the (possibly adjoint-transformed) diagonals.
        Defaults to `self.diagonals`.

    Returns:
      1-D integer `Tensor` equal to `batch_shape + [N, N]`.
    """
    diagonals = diagonals if diagonals is not None else self.diagonals
    if self.diagonals_format == _MATRIX:
      return array_ops.shape(diagonals)
    if self.diagonals_format == _COMPACT:
      d_shape = array_ops.shape(diagonals[..., 0, :])
    else:
      # NOTE: use the `diagonals` argument here (not `self.diagonals`), for
      # consistency with the matrix/compact branches above.  For the adjoint
      # diagonals passed by `_solve` the shapes are identical either way.
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          array_ops.shape(diagonals[0])[:-1],
          array_ops.shape(diagonals[1])[:-1])
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          broadcast_shape,
          array_ops.shape(diagonals[2])[:-1])
      d_shape = array_ops.concat(
          [broadcast_shape, [array_ops.shape(diagonals[1])[-1]]], axis=0)
    return array_ops.concat([d_shape, [d_shape[-1]]], axis=-1)

  def _assert_self_adjoint(self):
    """Returns an op asserting this operator equals its adjoint."""
    # Check the diagonal has zero imaginary part, and the superdiagonal and
    # subdiagonal are complex conjugates of each other.

    asserts = []
    diag_message = (
        'This tridiagonal operator contained non-zero '
        'imaginary values on the diagonal.')
    off_diag_message = (
        'This tridiagonal operator has non-conjugate '
        'subdiagonal and superdiagonal.')

    if self.diagonals_format == _MATRIX:
      asserts += [check_ops.assert_equal(
          self.diagonals, linalg.adjoint(self.diagonals),
          message='Matrix was not equal to its adjoint.')]
    elif self.diagonals_format == _COMPACT:
      diagonals = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          self.diagonals
      )
      asserts += [linear_operator_util.assert_zero_imag_part(
          diagonals[..., 1, :], message=diag_message)]
      # Roll the subdiagonal so the shifted argument is at the end.
      subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          diagonals[..., 0, :-1],
          message=off_diag_message)]
    else:
      asserts += [linear_operator_util.assert_zero_imag_part(
          self.diagonals[1], message=diag_message)]
      subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          self.diagonals[0][..., :-1],
          message=off_diag_message)]
    return control_flow_ops.group(asserts)

  def _construct_adjoint_diagonals(self, diagonals):
    """Returns diagonals of the adjoint operator, in the same format."""
    if self.diagonals_format == _SEQUENCE:
      diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
      # The subdiag and the superdiag swap places, so we need to shift the
      # padding argument.
      diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
      diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
      return diagonals
    elif self.diagonals_format == _MATRIX:
      return linalg.adjoint(diagonals)
    else:
      diagonals = math_ops.conj(diagonals)
      superdiag, diag, subdiag = array_ops_stack.unstack(
          diagonals, num=3, axis=-2)
      # The subdiag and the superdiag swap places, so we need
      # to shift all arguments.
      new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
      new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
      return array_ops_stack.stack([new_superdiag, diag, new_subdiag], axis=-2)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Computes `A x` (or adjoint variants) via `tridiagonal_matmul`."""
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)
    x = linalg.adjoint(x) if adjoint_arg else x
    return linalg.tridiagonal_matmul(
        diagonals, x,
        diagonals_format=self.diagonals_format)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves `A x = rhs` (or adjoint variants) via `tridiagonal_solve`."""
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)

    # TODO(b/144860784): Remove the broadcasting code below once
    # tridiagonal_solve broadcasts.

    rhs_shape = array_ops.shape(rhs)
    # Compute the operator's dynamic shape once; `k` is N, the matrix size.
    diagonals_shape = self._shape_tensor(diagonals)
    k = diagonals_shape[-1]
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        diagonals_shape[:-2], rhs_shape[:-2])
    rhs = array_ops.broadcast_to(
        rhs, array_ops.concat(
            [broadcast_shape, rhs_shape[-2:]], axis=-1))
    if self.diagonals_format == _MATRIX:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [k, k]], axis=-1))
    elif self.diagonals_format == _COMPACT:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [3, k]], axis=-1))
    else:
      diagonals = [
          array_ops.broadcast_to(d, array_ops.concat(
              [broadcast_shape, [k]], axis=-1)) for d in diagonals]

    y = linalg.tridiagonal_solve(
        diagonals, rhs,
        diagonals_format=self.diagonals_format,
        transpose_rhs=adjoint_arg,
        conjugate_rhs=adjoint_arg)
    return y

  def _diag_part(self):
    """Returns the main diagonal, shape `batch_shape + [N]`."""
    if self.diagonals_format == _MATRIX:
      return array_ops.matrix_diag_part(self.diagonals)
    elif self.diagonals_format == _SEQUENCE:
      diagonal = self.diagonals[1]
      # Broadcast up to the full batch shape, which may be larger than the
      # main diagonal's own batch shape.
      return array_ops.broadcast_to(
          diagonal, self.shape_tensor()[:-1])
    else:
      return self.diagonals[..., 1, :]

  def _to_dense(self):
    """Materializes the operator as a dense `[..., N, N]` `Tensor`."""
    if self.diagonals_format == _MATRIX:
      return self.diagonals

    if self.diagonals_format == _COMPACT:
      # The compact layout already matches matrix_diag's input for
      # k=(-1, 1) with LEFT_RIGHT alignment.
      return gen_array_ops.matrix_diag_v3(
          self.diagonals,
          k=(-1, 1),
          num_rows=-1,
          num_cols=-1,
          align='LEFT_RIGHT',
          padding_value=0.)

    # Sequence format: stack into the compact layout first.
    diagonals = [
        tensor_conversion.convert_to_tensor_v2_with_dispatch(d)
        for d in self.diagonals
    ]
    diagonals = array_ops_stack.stack(diagonals, axis=-2)

    return gen_array_ops.matrix_diag_v3(
        diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

  @property
  def diagonals(self):
    """The diagonals, in the layout given by `diagonals_format`."""
    return self._diagonals

  @property
  def diagonals_format(self):
    """One of `compact`, `matrix` or `sequence`."""
    return self._diagonals_format

  @property
  def _composite_tensor_fields(self):
    return ('diagonals', 'diagonals_format')

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    diagonal_event_ndims = 2
    if self.diagonals_format == _SEQUENCE:
      # For the diagonal and the super/sub diagonals.
      diagonal_event_ndims = [1, 1, 1]
    return {
        'diagonals': diagonal_event_ndims,
    }