Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_identity.py: 29%

256 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""`LinearOperator` acting like the identity matrix.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import ops 

21from tensorflow.python.framework import tensor_conversion 

22from tensorflow.python.framework import tensor_shape 

23from tensorflow.python.framework import tensor_util 

24from tensorflow.python.ops import array_ops 

25from tensorflow.python.ops import array_ops_stack 

26from tensorflow.python.ops import check_ops 

27from tensorflow.python.ops import control_flow_ops 

28from tensorflow.python.ops import math_ops 

29from tensorflow.python.ops.linalg import linalg_impl as linalg 

30from tensorflow.python.ops.linalg import linear_operator 

31from tensorflow.python.ops.linalg import linear_operator_util 

32from tensorflow.python.util.tf_export import tf_export 

33 

# Public names exported by this module. `BaseLinearOperatorIdentity` is not
# listed here.
__all__ = ["LinearOperatorIdentity", "LinearOperatorScaledIdentity"]

38 

39 

class BaseLinearOperatorIdentity(linear_operator.LinearOperator):
  """Base class for Identity operators.

  Provides the `num_rows` validation and the all-ones diagonal helpers shared
  by the identity-like operators in this module. Subclasses set
  `self._num_rows`, `self._num_rows_static` and `self._assert_proper_shapes`
  before calling `_check_num_rows_possibly_add_asserts`.
  """

  def _check_num_rows_possibly_add_asserts(self):
    """Static check of init arg `num_rows`, possibly add asserts.

    When `self._assert_proper_shapes` is set, `self._num_rows` is rebound to a
    tensor carrying runtime rank-0 and non-negativity assertions. Static checks
    on the dtype and (if known) the constant value are always performed.

    Raises:
      TypeError: If `num_rows` does not have an integer dtype.
      ValueError: If the statically known `num_rows` is not a scalar, or is
        negative.
    """
    if self._assert_proper_shapes:
      runtime_checks = [
          check_ops.assert_rank(
              self._num_rows, 0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ]
      self._num_rows = control_flow_ops.with_dependencies(
          runtime_checks, self._num_rows)

    if not self._num_rows.dtype.is_integer:
      raise TypeError(
          "Argument num_rows must be integer type. Found: %s" % self._num_rows)

    static_value = self._num_rows_static
    # Only a statically known value can be checked further.
    if static_value is not None:
      if static_value.ndim != 0:
        raise ValueError(
            "Argument num_rows must be a 0-D Tensor. Found: %s" % static_value)
      if static_value < 0:
        raise ValueError(
            "Argument num_rows must be non-negative. Found: %s" % static_value)

  def _min_matrix_dim(self):
    """Minimum of domain/range dimension, if statically available, else None."""
    dims = (tensor_shape.dimension_value(self.domain_dimension),
            tensor_shape.dimension_value(self.range_dimension))
    if any(d is None for d in dims):
      return None
    return min(dims)

  def _min_matrix_dim_tensor(self):
    """Minimum of domain/range dimension, as a tensor."""
    # The last two entries of the dynamic shape are the matrix dimensions.
    matrix_dims = self.shape_tensor()[-2:]
    return math_ops.reduce_min(matrix_dims)

  def _ones_diag(self):
    """Returns the diagonal of this operator as all ones.

    The result has shape `batch_shape + [min(domain_dim, range_dim)]`, built
    statically when the operator shape is fully defined and dynamically
    otherwise.
    """
    if not self.shape.is_fully_defined():
      diag_shape = array_ops.concat(
          [self.batch_shape_tensor(), [self._min_matrix_dim_tensor()]],
          axis=0)
    else:
      diag_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
    return array_ops.ones(shape=diag_shape, dtype=self.dtype)

97 

98 

@tf_export("linalg.LinearOperatorIdentity")
@linear_operator.make_composite_tensor
class LinearOperatorIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a [batch] square identity matrix.

  This operator acts like a [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorIdentity` is initialized with `num_rows`, and optionally
  `batch_shape`, and `dtype` arguments.  If `batch_shape` is `None`, this
  operator efficiently passes through all arguments.  If `batch_shape` is
  provided, broadcasting may occur, which will require making copies.

  ```python
  # Create a 2 x 2 identity matrix.
  operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32)

  operator.to_dense()
  ==> [[1., 0.]
       [0., 1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor, same as x.

  y = tf.random.normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  # This broadcast does NOT require copying data, since we can infer that y
  # will be passed through without changing shape.  We are always able to infer
  # this if the operator has no batch_shape.
  x = operator.solve(y)
  ==> Shape [3, 2, 4] Tensor, same as y.

  # Create a 2-batch of 2x2 identity matrices
  operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2])
  operator.to_dense()
  ==> [[[1., 0.]
        [0., 1.]],
       [[1., 0.]
        [0., 1.]]]

  # Here, even though the operator has a batch shape, the input is the same as
  # the output, so x can be passed through without a copy.  The operator is able
  # to detect that no broadcast is necessary because both x and the operator
  # have statically defined shape.
  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, same as x

  # Here the operator and x have different batch_shape, and are broadcast.
  # This requires a copy, since the output is different size than the input.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, equal to [x, x]
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  If `batch_shape` initialization arg is `None`:

  * `operator.matmul(x)` is `O(1)`
  * `operator.solve(x)` is `O(1)`
  * `operator.determinant()` is `O(1)`

  If `batch_shape` initialization arg is provided, and static checks cannot
  rule out the need to broadcast:

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(B1*...*Bb)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               batch_shape=None,
               dtype=None,
               is_non_singular=True,
               is_self_adjoint=True,
               is_positive_definite=True,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorIdentity"):
    r"""Initialize a `LinearOperatorIdentity`.

    The `LinearOperatorIdentity` is initialized with arguments defining `dtype`
    and shape.

    This operator is able to broadcast the leading (batch) dimensions, which
    sometimes requires copying data.  If `batch_shape` is `None`, the operator
    can take arguments of any batch shape without copying.  See examples.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
        dimensions.  If `None`, this operator has no leading dimensions.
      dtype:  Data type of the matrix that this operator represents.
        Defaults to `tf.float32` when `None`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `batch_shape` is determined statically to not be 1-D, or
        negative.
      ValueError:  If any of the following is not `True`:
        `{is_self_adjoint, is_non_singular, is_positive_definite, is_square}`.
      TypeError:  If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
      TypeError:  If `num_rows` or `batch_shape` does not have an integer
        dtype.
    """
    parameters = dict(
        num_rows=num_rows,
        batch_shape=batch_shape,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name)

    dtype = dtype or dtypes.float32
    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name):
      dtype = dtypes.as_dtype(dtype)
      # The identity matrix has all of these properties; reject explicit
      # hints that say otherwise.
      if not is_self_adjoint:
        raise ValueError("An identity operator is always self adjoint.")
      if not is_non_singular:
        raise ValueError("An identity operator is always non-singular.")
      if not is_positive_definite:
        raise ValueError("An identity operator is always positive-definite.")
      if not is_square:
        raise ValueError("An identity operator is always square.")

      super(LinearOperatorIdentity, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      # Inherited from BaseLinearOperatorIdentity.
      self._check_num_rows_possibly_add_asserts()

      if batch_shape is None:
        self._batch_shape_arg = None
      else:
        self._batch_shape_arg = linear_operator_util.shape_tensor(
            batch_shape, name="batch_shape_arg")
        self._batch_shape_static = tensor_util.constant_value(
            self._batch_shape_arg)
        self._check_batch_shape_possibly_add_asserts()

  def _shape(self):
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))
    if self._batch_shape_arg is None:
      return matrix_shape

    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    matrix_shape = array_ops_stack.stack(
        (self._num_rows, self._num_rows), axis=0)
    if self._batch_shape_arg is None:
      return matrix_shape

    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)

  def _assert_non_singular(self):
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    return control_flow_ops.no_op("assert_positive_definite")

  def _assert_self_adjoint(self):
    return control_flow_ops.no_op("assert_self_adjoint")

  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through.
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape.  Its final two dimensions are 1's because:
    #   the second to last dimension of x must match self.shape (checked in
    #   `_matmul` via assert_compatible_matrix_dimensions when
    #   `assert_proper_shapes` is set), and the final dimension of 'x' can
    #   have any size.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
    if special_shape.is_fully_defined():
      # bshape.is_fully_defined iff special_shape.is_fully_defined.
      if bshape == x.shape:
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    # Always add to an array of zeros, rather than using a "cond", since a
    # cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Note that adjoint has no effect since this matrix is self-adjoint.
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    return self._possibly_broadcast_batch_shape(x)

  def _determinant(self):
    return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _log_abs_determinant(self):
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # The identity is its own inverse (and adjoint), so solve == matmul.
    return self._matmul(rhs, adjoint_arg=adjoint_arg)

  def _trace(self):
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    if self._min_matrix_dim() is not None:
      return self._min_matrix_dim() * batch_of_ones
    else:
      return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) *
              batch_of_ones)

  def _diag_part(self):
    return self._ones_diag()

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      mat = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          mat, name="mat"
      )
      # Adding I only changes the diagonal: add one to it.
      mat_diag = array_ops.matrix_diag_part(mat)
      new_diag = 1 + mat_diag
      return array_ops.matrix_set_diag(mat, new_diag)

  def _eigvals(self):
    return self._ones_diag()

  def _cond(self):
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  # NOTE: `_check_num_rows_possibly_add_asserts` is inherited unchanged from
  # `BaseLinearOperatorIdentity`; a byte-identical copy used to live here.

  def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts.

    Raises:
      TypeError: If `batch_shape` does not have an integer dtype.
      ValueError: If the statically known `batch_shape` is not 1-D, or has a
        negative entry.
    """
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._batch_shape_arg,
              1,
              message="Argument batch_shape must be a 1-D Tensor."),
          check_ops.assert_non_negative(
              self._batch_shape_arg,
              message="Argument batch_shape must be non-negative."),
      ], self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type.  Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      raise ValueError("Argument batch_shape must be non-negative.  Found:"
                       " %s" % self._batch_shape_static)

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows", "batch_shape")

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "batch_shape", "dtype", "assert_proper_shapes")

  def __getitem__(self, slices):
    """Slices the batch dimensions and returns a new `LinearOperatorIdentity`.

    A proxy tensor of ones with this operator's batch shape is sliced, and the
    shape of the result is used as the new `batch_shape`.  When the operator
    was built without a `batch_shape` (no batch dimensions), a scalar proxy is
    used instead of passing `None` to `ones`, which previously failed.
    Indices that are invalid for the proxy are left to the slicing op to
    reject.
    """
    proxy_shape = (
        [] if self._batch_shape_arg is None else self._batch_shape_arg)
    new_batch_shape = array_ops.shape(array_ops.ones(proxy_shape)[slices])
    parameters = dict(self.parameters, batch_shape=new_batch_shape)
    return LinearOperatorIdentity(**parameters)

502 

503 

@tf_export("linalg.LinearOperatorScaledIdentity")
@linear_operator.make_composite_tensor
class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`.

  This operator acts like a scaled [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  a scaled version of the `N x N` identity matrix.

  `LinearOperatorScaledIdentity` is initialized with `num_rows`, and a
  `multiplier` (a `Tensor`) of shape `[B1,...,Bb]`.  `N` is set to `num_rows`,
  and the `multiplier` determines the scale for each batch member.

  ```python
  # Create a 2 x 2 scaled identity matrix.
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)

  operator.to_dense()
  ==> [[3., 0.]
       [0., 3.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 2 * Log[3]

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> 3 * x

  y = tf.random.normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  x = operator.solve(y)
  ==> y / 3  (Shape [3, 2, 4] Tensor)

  # Create a 2-batch of 2x2 scaled identity matrices
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=[5., 5.])
  operator.to_dense()
  ==> [[[5., 0.]
        [0., 5.]],
       [[5., 0.]
        [0., 5.]]]

  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> 5 * x

  # Here the operator and x have different batch_shape, and are broadcast.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> 5 * x  (broadcast to a Shape [2, 2, 3] Tensor)
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(B1*...*Bb)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               multiplier,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorScaledIdentity"):
    r"""Initialize a `LinearOperatorScaledIdentity`.

    The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
    determines the size of each identity matrix, and a `multiplier`,
    which defines `dtype`, batch shape, and scale of each matrix.

    This operator is able to broadcast the leading (batch) dimensions.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      multiplier:  `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Automatically set to `True` when `multiplier` is real.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `multiplier` is real and `is_self_adjoint` is `False`.
      ValueError:  If `is_square` is not `True`.
      TypeError:  If `num_rows` is ref-type (e.g. Variable), or does not have
        an integer dtype.
    """
    parameters = dict(
        num_rows=num_rows,
        multiplier=multiplier,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name)

    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name, values=[multiplier, num_rows]):
      self._multiplier = linear_operator_util.convert_nonref_to_tensor(
          multiplier, name="multiplier")

      # Check and auto-set hints.
      if not self._multiplier.dtype.is_complex:
        if is_self_adjoint is False:  # pylint: disable=g-bool-id-comparison
          raise ValueError("A real diagonal operator is always self adjoint.")
        else:
          is_self_adjoint = True

      if not is_square:
        raise ValueError("A ScaledIdentity operator is always square.")

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")

      super(LinearOperatorScaledIdentity, self).__init__(
          dtype=self._multiplier.dtype.base_dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      self._check_num_rows_possibly_add_asserts()
      # Cached casts used by `_determinant` (operator dtype) and
      # `_log_abs_determinant` (real dtype).
      self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype)
      self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows,
                                                        self.dtype.real_dtype)

  def _shape(self):
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))

    batch_shape = self.multiplier.shape
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    matrix_shape = array_ops_stack.stack(
        (self._num_rows, self._num_rows), axis=0)

    batch_shape = array_ops.shape(self.multiplier)
    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _assert_non_singular(self):
    return check_ops.assert_positive(
        math_ops.abs(self.multiplier), message="LinearOperator was singular")

  def _assert_positive_definite(self):
    return check_ops.assert_positive(
        math_ops.real(self.multiplier),
        message="LinearOperator was not positive definite.")

  def _assert_self_adjoint(self):
    # cI is self-adjoint iff c is real, i.e. imag(c) == 0.
    imag_multiplier = math_ops.imag(self.multiplier)
    return check_ops.assert_equal(
        array_ops.zeros_like(imag_multiplier),
        imag_multiplier,
        message="LinearOperator was not self-adjoint")

  def _make_multiplier_matrix(self, conjugate=False):
    # Shape [B1,...Bb, 1, 1]
    multiplier_matrix = array_ops.expand_dims(
        array_ops.expand_dims(self.multiplier, -1), -1)
    if conjugate:
      multiplier_matrix = math_ops.conj(multiplier_matrix)
    return multiplier_matrix

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    # The adjoint of cI is conj(c) I.
    return x * self._make_multiplier_matrix(conjugate=adjoint)

  def _determinant(self):
    return self.multiplier**self._num_rows_cast_to_dtype

  def _log_abs_determinant(self):
    return self._num_rows_cast_to_real_dtype * math_ops.log(
        math_ops.abs(self.multiplier))

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
      rhs = control_flow_ops.with_dependencies([aps], rhs)
    return rhs / self._make_multiplier_matrix(conjugate=adjoint)

  def _trace(self):
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    if self._min_matrix_dim() is not None:
      return self.multiplier * self._min_matrix_dim() * batch_of_ones
    else:
      return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(),
                                              self.dtype) * batch_of_ones)

  def _diag_part(self):
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `cI + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Shape [B1,...,Bb, 1]
      multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

      # Shape [C1,...,Cc, M, M]
      mat = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          mat, name="mat"
      )

      # Shape [C1,...,Cc, M]
      mat_diag = array_ops.matrix_diag_part(mat)

      # multiplier_vector broadcasts here.
      new_diag = multiplier_vector + mat_diag

      return array_ops.matrix_set_diag(mat, new_diag)

  def _eigvals(self):
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def _cond(self):
    # Condition number for a scalar times identity matrix is one, except when
    # the scalar is zero (then it is reported as NaN).
    return array_ops.where_v2(
        math_ops.equal(self._multiplier, 0.),
        math_ops.cast(np.nan, dtype=self.dtype),
        math_ops.cast(1., dtype=self.dtype))

  @property
  def multiplier(self):
    """The [batch] scalar `Tensor`, `c` in `cI`."""
    return self._multiplier

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows",)

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "multiplier", "assert_proper_shapes")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"multiplier": 0}