Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator.py: 34%

493 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for linear operators.""" 

16 

17import abc 

18import contextlib 

19 

20import numpy as np 

21 

22from tensorflow.python.framework import composite_tensor 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import tensor_conversion 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_spec 

28from tensorflow.python.framework import tensor_util 

29from tensorflow.python.framework import type_spec 

30from tensorflow.python.framework import type_spec_registry 

31from tensorflow.python.module import module 

32from tensorflow.python.ops import array_ops 

33from tensorflow.python.ops import check_ops 

34from tensorflow.python.ops import linalg_ops 

35from tensorflow.python.ops import math_ops 

36from tensorflow.python.ops import resource_variable_ops 

37from tensorflow.python.ops import variables 

38from tensorflow.python.ops.linalg import linalg_impl as linalg 

39from tensorflow.python.ops.linalg import linear_operator_algebra 

40from tensorflow.python.ops.linalg import linear_operator_util 

41from tensorflow.python.ops.linalg import slicing 

42from tensorflow.python.platform import tf_logging as logging 

43from tensorflow.python.trackable import data_structures 

44from tensorflow.python.util import deprecation 

45from tensorflow.python.util import dispatch 

46from tensorflow.python.util import nest 

47from tensorflow.python.util import variable_utils 

48from tensorflow.python.util.tf_export import tf_export 

49 

# Only the abstract base class is exported as this module's public API.
__all__ = [
  "LinearOperator",
]

51 

52 

53# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices. 

54@tf_export("linalg.LinearOperator") 

55class LinearOperator( 

56 module.Module, composite_tensor.CompositeTensor, metaclass=abc.ABCMeta): 

57 """Base class defining a [batch of] linear operator[s]. 

58 

59 Subclasses of `LinearOperator` provide access to common methods on a 

60 (batch) matrix, without the need to materialize the matrix. This allows: 

61 

62 * Matrix free computations 

63 * Operators that take advantage of special structure, while providing a 

64 consistent API to users. 

65 

66 #### Subclassing 

67 

68 To enable a public method, subclasses should implement the leading-underscore 

69 version of the method. The argument signature should be identical except for 

70 the omission of `name="..."`. For example, to enable 

71 `matmul(x, adjoint=False, name="matmul")` a subclass should implement 

72 `_matmul(x, adjoint=False)`. 

73 

74 #### Performance contract 

75 

76 Subclasses should only implement the assert methods 

77 (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)` 

78 time. 

79 

80 Class docstrings should contain an explanation of computational complexity. 

81 Since this is a high-performance library, attention should be paid to detail, 

82 and explanations can include constants as well as Big-O notation. 

83 

84 #### Shape compatibility 

85 

86 `LinearOperator` subclasses should operate on a [batch] matrix with 

87 compatible shape. Class docstrings should define what is meant by compatible 

88 shape. Some subclasses may not support batching. 

89 

90 Examples: 

91 

92 `x` is a batch matrix with compatible shape for `matmul` if 

93 

94 ``` 

95 operator.shape = [B1,...,Bb] + [M, N], b >= 0, 

96 x.shape = [B1,...,Bb] + [N, R] 

97 ``` 

98 

99 `rhs` is a batch matrix with compatible shape for `solve` if 

100 

101 ``` 

102 operator.shape = [B1,...,Bb] + [M, N], b >= 0, 

103 rhs.shape = [B1,...,Bb] + [M, R] 

104 ``` 

105 

106 #### Example docstring for subclasses. 

107 

108 This operator acts like a (batch) matrix `A` with shape 

109 `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a 

110 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is

111 an `m x n` matrix. Again, this matrix `A` may not be materialized, but for 

112 purposes of identifying and working with compatible arguments the shape is 

113 relevant. 

114 

115 Examples: 

116 

117 ```python 

118 some_tensor = ... shape = ???? 

119 operator = MyLinOp(some_tensor) 

120 

121 operator.shape

122 ==> [2, 4, 4] 

123 

124 operator.log_abs_determinant() 

125 ==> Shape [2] Tensor 

126 

127 x = ... Shape [2, 4, 5] Tensor 

128 

129 operator.matmul(x) 

130 ==> Shape [2, 4, 5] Tensor 

131 ``` 

132 

133 #### Shape compatibility 

134 

135 This operator acts on batch matrices with compatible shape. 

136 FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE 

137 

138 #### Performance 

139 

140 FILL THIS IN 

141 

142 #### Matrix property hints 

143 

144 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 

145 for `X = non_singular, self_adjoint, positive_definite, square`. 

146 These have the following meaning: 

147 

148 * If `is_X == True`, callers should expect the operator to have the 

149 property `X`. This is a promise that should be fulfilled, but is *not* a 

150 runtime assert. For example, finite floating point precision may result 

151 in these promises being violated. 

152 * If `is_X == False`, callers should expect the operator to not have `X`. 

153 * If `is_X == None` (the default), callers should have no expectation either 

154 way. 

155 

156 #### Initialization parameters 

157 

158 All subclasses of `LinearOperator` are expected to pass a `parameters` 

159 argument to `super().__init__()`. This should be a `dict` containing 

160 the unadulterated arguments passed to the subclass `__init__`. For example, 

161 `MyLinearOperator` with an initializer should look like: 

162 

163 ```python 

164 def __init__(self, operator, is_square=False, name=None): 

165 parameters = dict( 

166 operator=operator, 

167 is_square=is_square, 

168 name=name 

169 ) 

170 ... 

171 super().__init__(..., parameters=parameters) 

172 ``` 

173 

174 Users can then access `my_linear_operator.parameters` to see all arguments 

175 passed to its initializer. 

176 """ 

177 

178 # TODO(b/143910018) Remove graph_parents in V3. 

@deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will "
                             " no longer be used.", "graph_parents")
def __init__(self,
             dtype,
             graph_parents=None,
             is_non_singular=None,
             is_self_adjoint=None,
             is_positive_definite=None,
             is_square=None,
             name=None,
             parameters=None):
  """Initialize the `LinearOperator`.

  **This is a private method for subclass use.**
  **Subclasses should copy-paste this `__init__` documentation.**

  Args:
    dtype: The type of this `LinearOperator`. Arguments to `matmul` and
      `solve` will have to be this type.
    graph_parents: (Deprecated) Python list of graph prerequisites of this
      `LinearOperator`, typically the tensors passed during initialization.
    is_non_singular: Expect that this operator is non-singular.
    is_self_adjoint: Expect that this operator is equal to its hermitian
      transpose. If `dtype` is real, this is equivalent to being symmetric.
    is_positive_definite: Expect that this operator is positive definite,
      meaning the quadratic form `x^H A x` has positive real part for all
      nonzero `x`. Note that we do not require the operator to be
      self-adjoint to be positive-definite. See:
      https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
    is_square: Expect that this operator acts like square [batch] matrices.
    name: A name for this `LinearOperator`. Defaults to the class name.
    parameters: Python `dict` of parameters used to instantiate this
      `LinearOperator`.

  Raises:
    ValueError: If any member of graph_parents is `None` or not a `Tensor`.
    ValueError: If hints are set incorrectly.
  """
  # Hints imply one another: positive definite => non-singular, and both
  # non-singular and self-adjoint => square. An explicit `False` that
  # contradicts an implied `True` is rejected.
  if is_positive_definite:
    if is_non_singular is False:
      raise ValueError("A positive definite matrix is always non-singular.")
    is_non_singular = True

  square_implications = (
      (is_non_singular, "A non-singular matrix is always square."),
      (is_self_adjoint, "A self-adjoint matrix is always square."),
  )
  for implies_square, contradiction_message in square_implications:
    if not implies_square:
      continue
    if is_square is False:
      raise ValueError(contradiction_message)
    is_square = True

  # The final squareness answer is resolved lazily by the `is_square`
  # property, because subclasses may only know their shape after init.
  self._is_square_set_or_implied_by_hints = is_square

  if graph_parents is None:
    self._graph_parents = []
  else:
    self._set_graph_parents(graph_parents)
  self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
  self._is_non_singular = is_non_singular
  self._is_self_adjoint = is_self_adjoint
  self._is_positive_definite = is_positive_definite
  # `_no_dependency` (defined on a base class) presumably keeps this dict out
  # of trackable dependency tracking -- verify against `module.Module`.
  self._parameters = self._no_dependency(parameters)
  self._parameters_sanitized = False
  self._name = name or type(self).__name__

246 

@contextlib.contextmanager
def _name_scope(self, name=None):  # pylint: disable=method-hidden
  """Open a TF name scope rooted at this operator's name.

  Args:
    name: Optional sub-scope. When given, the scope is `"<self.name>/<name>"`;
      otherwise it is just `self.name`.

  Yields:
    The scope string produced by `ops.name_scope`.
  """
  base_name = self.name
  scope_name = base_name if name is None else base_name + "/" + name
  with ops.name_scope(scope_name) as scope:
    yield scope

255 

@property
def parameters(self):
  """Arguments that were used to construct this `LinearOperator`.

  Returns:
    A new shallow-copied `dict`; mutating it does not affect the operator.
  """
  init_args = self._parameters
  return dict(init_args)

260 

@property
def dtype(self):
  """The (base) `DType` that `matmul`/`solve` arguments must have."""
  operator_dtype = self._dtype
  return operator_dtype

265 

@property
def name(self):
  """Prefix used for the name scopes of ops this operator creates."""
  operator_name = self._name
  return operator_name

270 

@property
@deprecation.deprecated(None, "Do not call `graph_parents`.")
def graph_parents(self):
  """(Deprecated) Graph dependencies recorded for this `LinearOperator`.

  Returns:
    The stored list; empty when no `graph_parents` were passed at init.
  """
  parents = self._graph_parents
  return parents

276 

@property
def is_non_singular(self):
  """Non-singularity hint: `True`, `False`, or `None` (unknown).

  It is set to `True` automatically when `is_positive_definite` was hinted.
  """
  hint = self._is_non_singular
  return hint

280 

@property
def is_self_adjoint(self):
  """Self-adjointness hint: `True`, `False`, or `None` (unknown)."""
  hint = self._is_self_adjoint
  return hint

284 

@property
def is_positive_definite(self):
  """Positive-definiteness hint: `True`, `False`, or `None` (unknown)."""
  hint = self._is_positive_definite
  return hint

288 

@property
def is_square(self):
  """Whether this operator is square.

  Returns:
    The square hint (set explicitly or implied by other hints) when one
    exists; otherwise the result of `domain_dimension == range_dimension`.

  Raises:
    ValueError: If the hint is `False` but the static dimensions match.
  """
  # Checked here instead of in __init__ because derived classes may only
  # know their domain/range dimensions after their own init has run.
  hint = self._is_square_set_or_implied_by_hints
  dims_match = self.domain_dimension == self.range_dimension
  if hint is None:
    return dims_match
  if hint is False and dims_match:
    raise ValueError(
        "User set is_square hint to False, but the operator was square.")
  return hint

302 

@abc.abstractmethod
def _shape(self):
  """Static `TensorShape` of the operator; derived classes must implement.

  Implementing this enables `shape` and every static accessor built on it.
  """
  raise NotImplementedError("_shape is not implemented.")

307 

@property
def shape(self):
  """Static `TensorShape` of this `LinearOperator`.

  For an operator acting like a batch matrix `A` with
  `A.shape = [B1,...,Bb, M, N]`, this is `TensorShape([B1,...,Bb, M, N])`.

  Returns:
    `TensorShape`, determined statically; may be partially or fully unknown.
  """
  static_shape = self._shape()
  return static_shape

320 

def _shape_tensor(self):
  """Runtime shape hook for derived classes.

  Deliberately not an `abc.abstractmethod`: subclasses may override it with
  extra optional keyword arguments (e.g. to avoid repeated
  `convert_to_tensor` calls).
  """
  raise NotImplementedError("_shape_tensor is not implemented.")

326 

def shape_tensor(self, name="shape_tensor"):
  """Runtime shape of this `LinearOperator`.

  For an operator acting like a batch matrix `A` with
  `A.shape = [B1,...,Bb, M, N]`, returns a `Tensor` holding
  `[B1,...,Bb, M, N]`, like `tf.shape(A)`.

  Args:
    name: A name for this `Op`.

  Returns:
    `int32` `Tensor`.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    static_shape = self.shape
    if not static_shape.is_fully_defined():
      return self._shape_tensor()
    # A fully known static shape can be emitted without the subclass hook.
    return linear_operator_util.shape_tensor(static_shape.as_list())

346 

@property
def batch_shape(self):
  """Static shape of the batch dimensions.

  For `A.shape = [B1,...,Bb, M, N]` this is `TensorShape([B1,...,Bb])`,
  i.e. `A.shape[:-2]`.

  Returns:
    `TensorShape`, determined statically; may be undefined.
  """
  # Available automatically once `shape` works.
  full_shape = self.shape
  return full_shape[:-2]

360 

def batch_shape_tensor(self, name="batch_shape_tensor"):
  """Runtime shape of the batch dimensions of this operator.

  For `A.shape = [B1,...,Bb, M, N]`, returns a `Tensor` holding
  `[B1,...,Bb]`.

  Args:
    name: A name for this `Op`.

  Returns:
    `int32` `Tensor`.
  """
  # Subclasses get this for free from their shape implementation.
  with self._name_scope(name):  # pylint: disable=not-callable
    batch_dims = self._batch_shape_tensor()
    return batch_dims

377 

def _batch_shape_tensor(self, shape=None):
  """Default runtime batch shape.

  Args:
    shape: Optional precomputed runtime shape `Tensor`, which callers can pass
      to avoid recomputing `shape_tensor()`.

  Returns:
    `int32` `Tensor` of the batch dimensions.
  """
  static_batch = self.batch_shape
  if not static_batch.is_fully_defined():
    if shape is None:
      shape = self.shape_tensor()
    return shape[:-2]
  return linear_operator_util.shape_tensor(
      static_batch.as_list(), name="batch_shape")

387 

@property
def tensor_rank(self, name="tensor_rank"):
  """Rank (in the sense of tensors) of matrix corresponding to this operator.

  If this operator acts like the batch matrix `A` with
  `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

  Note: this is a property, so callers cannot pass `name`. The parameter
  exists only for backward compatibility, and the default `"tensor_rank"` is
  always used for the name scope.

  Returns:
    Python integer, or None if the tensor rank is undefined.
  """
  # Derived classes get this "for free" once `_shape` is implemented.
  # The name scope is kept (even though no ops are created) so that scope
  # naming stays identical to previous releases.
  with self._name_scope(name):  # pylint: disable=not-callable
    return self.shape.ndims

404 

def tensor_rank_tensor(self, name="tensor_rank_tensor"):
  """Runtime tensor rank of the matrix this operator represents.

  For `A.shape = [B1,...,Bb, M, N]` the result is `b + 2`.

  Args:
    name: A name for this `Op`.

  Returns:
    `int32` `Tensor`, determined at runtime.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    rank_tensor = self._tensor_rank_tensor()
    return rank_tensor

420 

def _tensor_rank_tensor(self, shape=None):
  """Default runtime rank.

  Args:
    shape: Optional precomputed runtime shape `Tensor`, which callers can pass
      to avoid recomputing `shape_tensor()`.

  Returns:
    Scalar `Tensor` holding the rank.
  """
  if self.tensor_rank is None:
    # Static rank unknown: count the entries of the runtime shape instead.
    if shape is None:
      shape = self.shape_tensor()
    return array_ops.size(shape)
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(
      self.tensor_rank
  )

431 

@property
def domain_dimension(self):
  """Vector-space dimension of the operator's domain (the `N` in `[..., M, N]`).

  Returns:
    `Dimension` object; `Dimension(None)` when the static rank is unknown.
  """
  static_shape = self.shape
  if static_shape.rank is not None:
    return static_shape.dims[-1]
  return tensor_shape.Dimension(None)

447 

def domain_dimension_tensor(self, name="domain_dimension_tensor"):
  """Runtime dimension of the operator's domain.

  For `A.shape = [B1,...,Bb, M, N]` the result is `N`.

  Args:
    name: A name for this `Op`.

  Returns:
    `int32` `Tensor`.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    domain_dim = self._domain_dimension_tensor()
    return domain_dim

465 

def _domain_dimension_tensor(self, shape=None):
  """Default runtime domain dimension.

  Args:
    shape: Optional precomputed runtime shape `Tensor`, which callers can pass
      to avoid recomputing `shape_tensor()`.

  Returns:
    Scalar `Tensor` holding `N`.
  """
  static_dim = tensor_shape.dimension_value(self.domain_dimension)
  if static_dim is None:
    runtime_shape = self.shape_tensor() if shape is None else shape
    return runtime_shape[-1]
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(static_dim)

475 

@property
def range_dimension(self):
  """Vector-space dimension of the operator's range (the `M` in `[..., M, N]`).

  Returns:
    `Dimension` object; `Dimension(None)` when the static dims are unknown
    or empty.
  """
  static_dims = self.shape.dims
  if not static_dims:
    return tensor_shape.Dimension(None)
  return static_dims[-2]

491 

def range_dimension_tensor(self, name="range_dimension_tensor"):
  """Runtime dimension of the operator's range.

  For `A.shape = [B1,...,Bb, M, N]` the result is `M`.

  Args:
    name: A name for this `Op`.

  Returns:
    `int32` `Tensor`.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    range_dim = self._range_dimension_tensor()
    return range_dim

509 

def _range_dimension_tensor(self, shape=None):
  """Default runtime range dimension.

  Args:
    shape: Optional precomputed runtime shape `Tensor`, which callers can pass
      to avoid recomputing `shape_tensor()`.

  Returns:
    Scalar `Tensor` holding `M`.
  """
  static_dim = tensor_shape.dimension_value(self.range_dimension)
  if static_dim is None:
    runtime_shape = self.shape_tensor() if shape is None else shape
    return runtime_shape[-2]
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(static_dim)

519 

def _assert_non_singular(self):
  """Default (dense, O(N^3)) non-singularity check.

  Delegates to `assert_positive_definite` when Cholesky is usable. Otherwise
  it compares the SVD-based condition number against
  `_max_condition_number_to_be_non_singular()`.
  """
  logging.warn(
      "Using (possibly slow) default implementation of assert_non_singular."
      " Requires conversion to a dense matrix and O(N^3) operations.")
  if self._can_use_cholesky():
    return self.assert_positive_definite()
  singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
  # TODO(langmore) Add .eig and .cond as methods.
  largest = math_ops.reduce_max(singular_values, axis=-1)
  smallest = math_ops.reduce_min(singular_values, axis=-1)
  condition_number = largest / smallest
  return check_ops.assert_less(
      condition_number,
      self._max_condition_number_to_be_non_singular(),
      message="Singular matrix up to precision epsilon.")

536 

def _max_condition_number_to_be_non_singular(self):
  """Return the maximum condition number that we consider non-singular.

  The threshold is `1 / (max(100, M, N) * eps)`, where `eps` is the machine
  epsilon of this operator's dtype and `M`, `N` are the range and domain
  dimensions.

  The result uses the operator's *real* dtype (e.g. `float32` for a
  `complex64` operator; identical to `self.dtype` for real operators). It is
  compared against condition numbers built from singular values, which are
  always real, and `reduce_max` does not accept complex inputs.
  """
  with ops.name_scope("max_nonsingular_condition_number"):
    real_dtype = self.dtype.real_dtype
    dtype_eps = np.finfo(real_dtype.as_numpy_dtype).eps
    eps = math_ops.cast(
        math_ops.reduce_max([
            100.,
            math_ops.cast(self.range_dimension_tensor(), real_dtype),
            math_ops.cast(self.domain_dimension_tensor(), real_dtype)
        ]), real_dtype) * dtype_eps
    return 1. / eps

548 

def assert_non_singular(self, name="assert_non_singular"):
  """Returns an `Op` that asserts this operator is non singular.

  With the default implementation, this operator is considered non-singular
  if

  ```
  ConditionNumber < 1 / (max{100, range_dimension, domain_dimension} * eps),
  eps := np.finfo(self.dtype.as_numpy_dtype).eps
  ```

  When a Cholesky factorization is usable, the default implementation
  instead asserts positive definiteness. Subclasses may override
  `_assert_non_singular` with a cheaper or exact check.

  Args:
    name: A string name to prepend to created ops.

  Returns:
    An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
    the operator is singular.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    return self._assert_non_singular()

568 

def _assert_positive_definite(self):
  """Default (dense, O(N^3)) positive-definiteness check.

  Only supported for self-adjoint operators. For those, a successful Cholesky
  factorization with a positive diagonal is necessary and sufficient.

  Raises:
    NotImplementedError: If the operator is not hinted self-adjoint.
  """
  logging.warn(
      "Using (possibly slow) default implementation of "
      "assert_positive_definite."
      " Requires conversion to a dense matrix and O(N^3) operations.")
  if not self.is_self_adjoint:
    # There is no generic positive-definiteness check for this case.
    raise NotImplementedError("assert_positive_definite is not implemented.")
  cholesky_factor = linalg_ops.cholesky(self.to_dense())
  return check_ops.assert_positive(
      array_ops.matrix_diag_part(cholesky_factor),
      message="Matrix was not positive definite.")

584 

def assert_positive_definite(self, name="assert_positive_definite"):
  """Returns an `Op` asserting that this operator is positive definite.

  Positive definite here means the quadratic form `x^H A x` has a positive
  real part for every nonzero `x`. The operator does not need to be
  self-adjoint to qualify.

  Args:
    name: A name to give this `Op`.

  Returns:
    An `Assert` `Op` that, when run, raises `InvalidArgumentError` if the
    operator is not positive definite.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    assertion = self._assert_positive_definite()
    return assertion

601 

def _assert_self_adjoint(self):
  """Default dense check that the operator exactly equals its adjoint."""
  dense_matrix = self.to_dense()
  logging.warn(
      "Using (possibly slow) default implementation of assert_self_adjoint."
      " Requires conversion to a dense matrix.")
  dense_adjoint = linalg.adjoint(dense_matrix)
  return check_ops.assert_equal(
      dense_matrix,
      dense_adjoint,
      message="Matrix was not equal to its adjoint.")

611 

def assert_self_adjoint(self, name="assert_self_adjoint"):
  """Returns an `Op` asserting this operator equals its hermitian transpose.

  The comparison is exact, with no tolerance.

  Args:
    name: A string name to prepend to created ops.

  Returns:
    An `Assert` `Op` that, when run, raises `InvalidArgumentError` if the
    operator is not self-adjoint.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    assertion = self._assert_self_adjoint()
    return assertion

627 

def _check_input_dtype(self, arg):
  """Raise `TypeError` unless `arg.dtype`'s base dtype equals `self.dtype`."""
  if arg.dtype.base_dtype != self.dtype:
    # `!s` matches the `%s` conversion (plain `str()`) byte-for-byte.
    raise TypeError(
        f"Expected argument to have dtype {self.dtype!s}. "
        f"Found: {arg.dtype!s} in tensor {arg!s}")

634 

@abc.abstractmethod
def _matmul(self, x, adjoint=False, adjoint_arg=False):
  """Subclass hook backing `matmul`; every concrete operator implements it."""
  raise NotImplementedError("_matmul is not implemented.")

638 

def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
  """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

  ```python
  # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
  operator = LinearOperator(...)
  operator.shape = [..., M, N]

  X = ... # shape [..., N, R], batch matrix, R > 0.

  Y = operator.matmul(X)
  Y.shape
  ==> [..., M, R]

  Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
  ```

  Args:
    x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
      `self`. See class docstring for definition of compatibility.
    adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
    adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
      the hermitian transpose (transposition and complex conjugation).
    name: A name for this `Op`.

  Returns:
    A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
      as `self`.

  Raises:
    ValueError: If `x` is a `LinearOperator` whose (possibly adjointed) range
      dimension is statically known to differ from this operator's (possibly
      adjointed) domain dimension, or if the static shape of a `Tensor` `x`
      is incompatible with this operator.
    TypeError: If a `Tensor` `x` does not have this operator's dtype.
  """
  if isinstance(x, LinearOperator):
    left_operator = self.adjoint() if adjoint else self
    right_operator = x.adjoint() if adjoint_arg else x

    # `domain_dimension`/`range_dimension` return `Dimension` objects, which
    # are never `None` themselves. Compare their integer values so unknown
    # dimensions are skipped explicitly.
    left_dim = tensor_shape.dimension_value(left_operator.domain_dimension)
    right_dim = tensor_shape.dimension_value(right_operator.range_dimension)
    if (left_dim is not None and right_dim is not None and
        left_dim != right_dim):
      raise ValueError(
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.matmul(left_operator, right_operator)

  with self._name_scope(name):  # pylint: disable=not-callable
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
    self._check_input_dtype(x)

    # The contracted axis of `self` is its last (or, when adjointed, second
    # to last) axis; for `x` it is the second to last (or last) axis.
    self_dim = -2 if adjoint else -1
    arg_dim = -1 if adjoint_arg else -2
    tensor_shape.dimension_at_index(
        self.shape, self_dim).assert_is_compatible_with(
            x.shape[arg_dim])

    return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

693 

def __matmul__(self, other):
  """Support the `@` operator: `operator @ other` equals `operator.matmul(other)`."""
  product = self.matmul(other)
  return product

696 

def _matvec(self, x, adjoint=False):
  """Default `matvec`: treat `x` as a one-column matrix and reuse `matmul`."""
  as_column = array_ops.expand_dims(x, axis=-1)
  column_product = self.matmul(as_column, adjoint=adjoint)
  return array_ops.squeeze(column_product, axis=-1)

701 

def matvec(self, x, adjoint=False, name="matvec"):
  """Transform [batch] vector `x` with left multiplication: `x --> Ax`.

  ```python
  # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
  operator = LinearOperator(...)

  X = ... # shape [..., N], batch vector

  Y = operator.matvec(X)
  Y.shape
  ==> [..., M]

  Y[..., :] = sum_j A[..., :, j] X[..., j]
  ```

  Args:
    x: `Tensor` with compatible shape and the same `dtype` as `self`. Every
      set of leading dimensions indexes one vector held in the last dimension.
      See class docstring for definition of compatibility.
    adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
    name: A name for this `Op`.

  Returns:
    A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
    self._check_input_dtype(x)
    contracted_axis = -2 if adjoint else -1
    operator_dim = tensor_shape.dimension_at_index(self.shape, contracted_axis)
    operator_dim.assert_is_compatible_with(x.shape[-1])
    return self._matvec(x, adjoint=adjoint)

736 

def _determinant(self):
  """Default (dense, O(N^3)) determinant.

  Uses `exp(log_abs_determinant())` when Cholesky is usable, and
  `matrix_determinant` on the dense matrix otherwise.
  """
  logging.warn(
      "Using (possibly slow) default implementation of determinant."
      " Requires conversion to a dense matrix and O(N^3) operations.")
  if not self._can_use_cholesky():
    return linalg_ops.matrix_determinant(self.to_dense())
  return math_ops.exp(self.log_abs_determinant())

744 

def determinant(self, name="det"):
  """Determinant of every batch member.

  Args:
    name: A name for this `Op`.

  Returns:
    `Tensor` with shape `self.batch_shape` and the same `dtype` as `self`.

  Raises:
    NotImplementedError: If `self.is_square` is `False`.
  """
  square = self.is_square
  if square is False:
    raise NotImplementedError(
        "Determinant not implemented for an operator that is expected to "
        "not be square.")
  with self._name_scope(name):  # pylint: disable=not-callable
    return self._determinant()

763 

def _log_abs_determinant(self):
  """Default (dense, O(N^3)) log of the absolute value of the determinant.

  When Cholesky is usable, this factors `A = L L^H`, so
  `log|det A| = 2 * sum(log(diag(L)))`. Otherwise it falls back to `slogdet`
  on the dense matrix.
  """
  logging.warn(
      "Using (possibly slow) default implementation of log_abs_determinant."
      " Requires conversion to a dense matrix and O(N^3) operations.")
  if self._can_use_cholesky():
    cholesky_diag = array_ops.matrix_diag_part(
        linalg_ops.cholesky(self.to_dense()))
    return 2 * math_ops.reduce_sum(math_ops.log(cholesky_diag), axis=[-1])
  _, log_abs_det = linalg.slogdet(self.to_dense())
  return log_abs_det

773 

def log_abs_determinant(self, name="log_abs_det"):
  """Log of the absolute value of the determinant of every batch member.

  Args:
    name: A name for this `Op`.

  Returns:
    `Tensor` with shape `self.batch_shape` and the same `dtype` as `self`.

  Raises:
    NotImplementedError: If `self.is_square` is `False`.
  """
  square = self.is_square
  if square is False:
    raise NotImplementedError(
        "Determinant not implemented for an operator that is expected to "
        "not be square.")
  with self._name_scope(name):  # pylint: disable=not-callable
    return self._log_abs_determinant()

792 

def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Solve `A X = rhs` (or the adjoint system) by densifying the operator.

  Raises:
    NotImplementedError: If the operator is known to be non-square.
  """
  if self.is_square is False:  # pylint: disable=g-bool-id-comparison
    raise NotImplementedError(
        "Solve is not yet implemented for non-square operators.")
  if adjoint_arg:
    rhs = linalg.adjoint(rhs)
  if not self._can_use_cholesky():
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)
  # `adjoint` is not forwarded on this path. Presumably `_can_use_cholesky()`
  # implies self-adjointness, so A^H == A -- TODO confirm.
  return linalg_ops.cholesky_solve(
      linalg_ops.cholesky(self.to_dense()), rhs)

804 

def _solve(self, rhs, adjoint=False, adjoint_arg=False):
  """Default `_solve`: warn, then delegate to the dense solver."""
  logging.warn(
      "Using (possibly slow) default implementation of solve."
      " Requires conversion to a dense matrix and O(N^3) operations.")
  solution = self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
  return solution

811 

def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The result is close to an exact solution when `A` is well conditioned;
    otherwise closeness will vary. See the class docstring for details.

    If `rhs` is itself a `LinearOperator`, the work is delegated to
    `linear_operator_algebra.solve` and its result is returned.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape, or
        a `LinearOperator`. A `Tensor` `rhs` is treated like a [batch] matrix:
        for every set of leading dimensions, the last two dimensions define a
        matrix. See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
      ValueError: If `rhs` is a `LinearOperator` whose range dimension is known
        and differs from the (possibly adjointed) operator's domain dimension.
    """
    # Refuse up front when the hints rule out an exact solve.
    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
        raise NotImplementedError(
            "Exact solve not implemented for an operator that is expected to "
            "be singular.")
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
        raise NotImplementedError(
            "Exact solve not implemented for an operator that is expected to "
            "not be square.")

    if isinstance(rhs, LinearOperator):
        lhs_op = self.adjoint() if adjoint else self
        rhs_op = rhs.adjoint() if adjoint_arg else rhs
        actual_dim = rhs_op.range_dimension
        expected_dim = lhs_op.domain_dimension
        if (actual_dim is not None and expected_dim is not None
                and actual_dim != expected_dim):
            raise ValueError(
                "Operators are incompatible. Expected `rhs` to have dimension"
                " {} but got {}.".format(expected_dim, actual_dim))
        with self._name_scope(name):  # pylint: disable=not-callable
            return linear_operator_algebra.solve(lhs_op, rhs_op)

    with self._name_scope(name):  # pylint: disable=not-callable
        rhs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
            rhs, name="rhs")
        self._check_input_dtype(rhs)

        # The operator axis that must line up with `rhs` depends on which side
        # is adjointed.
        operator_axis = -1 if adjoint else -2
        rhs_axis = -1 if adjoint_arg else -2
        operator_dim = tensor_shape.dimension_at_index(self.shape, operator_axis)
        operator_dim.assert_is_compatible_with(rhs.shape[rhs_axis])

        return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

888 

def _solvevec(self, rhs, adjoint=False):
    """Default implementation of _solvevec.

    Treats the vector `rhs` as a one-column matrix, delegates to `solve`, and
    drops the trailing unit axis from the result.
    """
    as_column = array_ops.expand_dims(rhs, axis=-1)
    column_solution = self.solve(as_column, adjoint=adjoint)
    return array_ops.squeeze(column_solution, axis=-1)

894 

def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The result is close to an exact solution when `A` is well conditioned;
    otherwise closeness will vary. See the class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator. It is treated like a
        [batch] vector: for every set of leading dimensions, the last
        dimension defines a vector. See class docstring for definition of
        compatibility regarding batch dimensions.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
        This comes from `solve`, which the default `_solvevec` calls; a
        subclass that overrides `_solvevec` may not perform this check.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
        rhs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
            rhs, name="rhs")
        self._check_input_dtype(rhs)
        operator_axis = -1 if adjoint else -2
        operator_dim = tensor_shape.dimension_at_index(self.shape, operator_axis)
        operator_dim.assert_is_compatible_with(rhs.shape[-1])
        return self._solvevec(rhs, adjoint=adjoint)

944 

def adjoint(self, name="adjoint"):
    """Returns the adjoint of the current `LinearOperator`.

    Given `A` representing this `LinearOperator`, return `A*`. Calling
    `self.adjoint()` is equivalent to reading `self.H`. An operator hinted
    as self-adjoint is returned as-is.

    Args:
      name: A name for this `Op`.

    Returns:
      `LinearOperator` which represents the adjoint of this `LinearOperator`.
    """
    if self.is_self_adjoint is not True:  # pylint: disable=g-bool-id-comparison
        with self._name_scope(name):  # pylint: disable=not-callable
            return linear_operator_algebra.adjoint(self)
    # A self-adjoint operator is its own adjoint.
    return self

# self.H is equivalent to self.adjoint().
H = property(adjoint)

964 

def inverse(self, name="inverse"):
    """Returns the Inverse of this `LinearOperator`.

    Given `A` representing this `LinearOperator`, return a `LinearOperator`
    representing `A^-1`.

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `LinearOperator` representing inverse of this matrix.

    Raises:
      ValueError: When the `LinearOperator` is hinted to be non-square
        (`is_square` is `False`) or singular (`is_non_singular` is `False`).
        Unset (`None`) hints do not raise.
    """
    # Only an explicit `False` hint is rejected; `None` means "unknown".
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
        raise ValueError("Cannot take the Inverse: This operator represents "
                         "a non square matrix.")
    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
        raise ValueError("Cannot take the Inverse: This operator represents "
                         "a singular matrix.")

    with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.inverse(self)

989 

def cholesky(self, name="cholesky"):
    """Returns a Cholesky factor as a `LinearOperator`.

    Given `A` representing this `LinearOperator`, if `A` is positive definite
    self-adjoint, return the lower-triangular `L` with `A = L L^H` (which is
    `L L^T` for real dtypes), i.e. the Cholesky decomposition.

    Args:
      name: A name for this `Op`.

    Returns:
      `LinearOperator` which represents the lower triangular matrix
      in the Cholesky decomposition.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be positive
        definite and self adjoint.
    """
    if self._can_use_cholesky():
        with self._name_scope(name):  # pylint: disable=not-callable
            return linear_operator_algebra.cholesky(self)
    raise ValueError("Cannot take the Cholesky decomposition: "
                     "Not a positive definite self adjoint matrix.")

1014 

def _to_dense(self):
    """Generic and often inefficient implementation. Override often.

    Materializes the operator by applying it to a (batch of) identity
    matrices of size `domain_dimension`. Static shape information is used
    when available; otherwise dynamic shape tensors are used.
    """
    batch_shape = (self.batch_shape if self.batch_shape.is_fully_defined()
                   else self.batch_shape_tensor())

    n = tensor_shape.dimension_value(self.domain_dimension)
    if n is None:
        n = self.domain_dimension_tensor()

    identity = linalg_ops.eye(num_rows=n, batch_shape=batch_shape,
                              dtype=self.dtype)
    return self.matmul(identity)

1030 

def to_dense(self, name="to_dense"):
    """Return a dense (batch) matrix representing this operator.

    Args:
      name: A name scope for the ops created here.

    Returns:
      The result of `self._to_dense()`.
    """
    scope = self._name_scope(name)  # pylint: disable=not-callable
    with scope:
        return self._to_dense()

1035 

def _diag_part(self):
    """Generic and often inefficient implementation. Override often.

    Densifies the operator and reads off its main diagonal.
    """
    dense = self.to_dense()
    return array_ops.matrix_diag_part(dense)

1039 

def diag_part(self, name="diag_part"):
    """Efficiently get the [batch] diagonal part of this operator.

    For an operator of shape `[B1,...,Bb, M, N]` the result is a `Tensor`
    `diagonal` of shape `[B1,...,Bb, min(M, N)]` with
    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.

    ```
    my_operator = LinearOperatorDiag([1., 2.])

    # Efficiently get the diagonal
    my_operator.diag_part()
    ==> [1., 2.]

    # Equivalent, but inefficient method
    tf.linalg.diag_part(my_operator.to_dense())
    ==> [1., 2.]
    ```

    Args:
      name: A name for this `Op`.

    Returns:
      diag_part: A `Tensor` of same `dtype` as self.
    """
    scope = self._name_scope(name)  # pylint: disable=not-callable
    with scope:
        return self._diag_part()

1067 

def _trace(self):
    """Default trace: sum of `self.diag_part()` over the last axis."""
    diagonal = self.diag_part()
    return math_ops.reduce_sum(diagonal, axis=-1)

1070 

def trace(self, name="trace"):
    """Trace of the linear operator, equal to sum of `self.diag_part()`.

    For a square operator this is also the sum of the eigenvalues.

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    scope = self._name_scope(name)  # pylint: disable=not-callable
    with scope:
        return self._trace()

1084 

def _add_to_tensor(self, x):
    """Default: densify and add `x`. Override if something cheaper exists."""
    dense = self.to_dense()
    return dense + x

1088 

def add_to_tensor(self, x, name="add_to_tensor"):
    """Add matrix represented by this operator to `x`. Equivalent to `A + x`.

    `x` is converted to a `Tensor` and its dtype is checked against this
    operator before delegating to `_add_to_tensor`.

    Args:
      x: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
      name: A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
        converted = tensor_conversion.convert_to_tensor_v2_with_dispatch(
            x, name="x")
        self._check_input_dtype(converted)
        return self._add_to_tensor(converted)

1103 

def _eigvals(self):
    """Default: eigenvalues of the densified operator (self-adjoint solver)."""
    dense = self.to_dense()
    return linalg_ops.self_adjoint_eigvals(dense)

1106 

def eigvals(self, name="eigvals"):
    """Returns the eigenvalues of this linear operator.

    Only operators hinted as self-adjoint (via `is_self_adjoint`) are
    currently supported; for those the computation can be more efficient.

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`.

    Raises:
      NotImplementedError: If the operator is not hinted to be self-adjoint.
    """
    if self.is_self_adjoint:
        with self._name_scope(name):  # pylint: disable=not-callable
            return self._eigvals()
    raise NotImplementedError("Only self-adjoint matrices are supported.")

1125 

def _cond(self):
    """Default condition number: max / min of the singular-value magnitudes."""
    if self.is_self_adjoint:
        # For self-adjoint (and, more generally, normal) matrices the absolute
        # eigenvalues can be used in place of singular values.
        magnitudes = math_ops.abs(self._eigvals())
    else:
        # General case: ratio of the largest and smallest singular values.
        magnitudes = linalg_ops.svd(self.to_dense(), compute_uv=False)

    largest = math_ops.reduce_max(magnitudes, axis=-1)
    smallest = math_ops.reduce_min(magnitudes, axis=-1)
    return largest / smallest

1138 

def cond(self, name="cond"):
    """Returns the condition number of this linear operator.

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    scope = self._name_scope(name)  # pylint: disable=not-callable
    with scope:
        return self._cond()

1150 

def _can_use_cholesky(self):
    """Whether the hints permit a Cholesky path (self-adjoint and PD).

    Like the `and` expression it mirrors, this returns the first falsy hint
    (which may be `None`) rather than a strict `bool`.
    """
    self_adjoint = self.is_self_adjoint
    if not self_adjoint:
        return self_adjoint
    return self.is_positive_definite

1153 

def _set_graph_parents(self, graph_parents):
    """Set self._graph_parents. Called during derived class init.

    Lets derived classes set `graph_parents` without triggering the
    deprecation warning that passing `graph_parents` to `__init__` invokes.

    Args:
      graph_parents: Iterable over Tensors, or `None` (treated as an empty
        list).

    Raises:
      ValueError: If an item is `None`, or is neither a ref (per
        `linear_operator_util.is_ref`) nor a TF type (per
        `tensor_util.is_tf_type`).
    """
    # TODO(b/143910018) Remove this function in V3.
    if graph_parents is None:
        graph_parents = []
    for index, parent in enumerate(graph_parents):
        is_tensor_like = parent is not None and (
            linear_operator_util.is_ref(parent) or
            tensor_util.is_tf_type(parent))
        if not is_tensor_like:
            raise ValueError(
                "Graph parent item %d is not a Tensor; %s." % (index, parent))
    self._graph_parents = graph_parents

1171 

@property
def _composite_tensor_fields(self):
    """A tuple of parameter names to rebuild the `LinearOperator`.

    These are the names of kwargs to the `LinearOperator`'s constructor that
    the `TypeSpec` needs in order to rebuild the instance. The hint fields
    "is_non_singular", "is_self_adjoint", "is_positive_definite", and
    "is_square" are common to all `LinearOperator` subclasses and may be
    omitted. The base class lists none.
    """
    return tuple()

1184 

@property
def _composite_tensor_prefer_static_fields(self):
    """A tuple of names referring to parameters that may be treated statically.

    This is a subset of `_composite_tensor_fields`. It contains the names of
    `Tensor`-like args to the `LinearOperator`'s constructor that may be
    stored as static values when they are statically known. These are
    typically shapes or axis values. The base class lists none.
    """
    return tuple()

1195 

@property
def _type_spec(self):
    # Placeholder only: the `@make_composite_tensor` decorator replaces this
    # property. It must exist so that a valid subclass of the `ABCMeta` class
    # `CompositeTensor` can be built and handed to that decorator.
    return None

1203 

def _convert_variables_to_tensors(self):
    """Recursively converts ResourceVariables in the LinearOperator to Tensors.

    Calling `self._type_spec._from_components` on a Tensor-only structure
    technically breaks the `CompositeTensor` contract, because `self._type_spec`
    may describe `ResourceVariable`s. It is valid here for two reasons.
    `LinearOperator._from_components` just forwards the nested structure to
    `__init__`. Any `LinearOperator` that accepts `ResourceVariable`s also
    accepts `Tensor`s.

    Returns:
      tensor_operator: `self` with all internal Variables converted to Tensors.
    """
    # pylint: disable=protected-access
    components = self._type_spec._to_components(self)
    converted = variable_utils.convert_variables_to_tensors(components)
    return self._type_spec._from_components(converted)
    # pylint: enable=protected-access

1225 

def __getitem__(self, slices):
    """Slices the operator via `slicing.batch_slice`, with no param overrides."""
    no_overrides = {}
    return slicing.batch_slice(self, params_overrides=no_overrides,
                               slices=slices)

1228 

@property
def _experimental_parameter_ndims_to_matrix_ndims(self):
    """A dict of names to number of dimensions contributing to an operator.

    This is a dictionary of parameter names to `int`s specifying the
    number of right-most dimensions contributing to the **matrix** shape of the
    densified operator.
    If the parameter is a `Tensor`, this is mapped to an `int`.
    If the parameter is a `LinearOperator` (called `A`), this specifies the
    number of batch dimensions of `A` contributing to this `LinearOperator`s
    matrix shape.
    If the parameter is a structure, this is a structure of the same type of
    `int`s.

    The base class has no such parameters and returns an empty dict.
    """
    # An empty dict, not `()`: callers treat this as a mapping (e.g. via
    # `.items()`), and an empty dict also iterates and tests falsy like `()`.
    return {}

1244 

1245 

class _LinearOperatorSpec(type_spec.BatchableTypeSpec):
    """A tf.TypeSpec for `LinearOperator` objects.

    The spec holds what is needed to rebuild an operator from its components:
    `TypeSpec`s for the `Tensor`-like constructor kwargs (`_param_specs`), the
    plain values of all other kwargs (`_non_tensor_params`), and the names of
    kwargs that may be held statically (`_prefer_static_fields`). Concrete
    per-class subclasses are created by `make_composite_tensor`, which sets
    `value_type`.
    """

    __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")

    def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
        """Initializes a new `_LinearOperatorSpec`.

        Args:
          param_specs: Python `dict` of `tf.TypeSpec` instances that describe
            kwargs to the `LinearOperator`'s constructor that are `Tensor`-like
            or `CompositeTensor` subclasses.
          non_tensor_params: Python `dict` containing non-`Tensor` and non-
            `CompositeTensor` kwargs to the `LinearOperator`'s constructor.
          prefer_static_fields: Python `tuple` of strings corresponding to the
            names of `Tensor`-like args to the `LinearOperator`s constructor
            that may be stored as static values, if known. These are typically
            shapes, indices, or axis values.
        """
        self._param_specs = param_specs
        self._non_tensor_params = non_tensor_params
        self._prefer_static_fields = prefer_static_fields

    @classmethod
    def from_operator(cls, operator):
        """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.

        Every field named in `operator._composite_tensor_fields`, plus the four
        hint fields shared by all operators, is extracted with `_extract_attrs`.
        A field whose flattened value consists only of `TypeSpec`s goes into
        `param_specs`. A field with no `TypeSpec`s goes unchanged into
        `non_tensor_params`.

        Args:
          operator: An instance of `LinearOperator`.

        Returns:
          linear_operator_spec: An instance of `_LinearOperatorSpec` to be used
            as the `TypeSpec` of `operator`.

        Raises:
          NotImplementedError: If a field mixes `Tensor` and non-`Tensor`
            values.
        """
        validation_fields = ("is_non_singular", "is_self_adjoint",
                             "is_positive_definite", "is_square")
        kwargs = _extract_attrs(
            operator,
            keys=set(operator._composite_tensor_fields + validation_fields))  # pylint: disable=protected-access

        non_tensor_params = {}
        param_specs = {}
        for k, v in kwargs.items():
            type_spec_or_v = _extract_type_spec_recursively(v)
            is_tensor = [isinstance(x, type_spec.TypeSpec)
                         for x in nest.flatten(type_spec_or_v)]
            if all(is_tensor):
                param_specs[k] = type_spec_or_v
            elif not any(is_tensor):
                non_tensor_params[k] = v
            else:
                raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
                                          f"non-`Tensor` values.")

        return cls(
            param_specs=param_specs,
            non_tensor_params=non_tensor_params,
            prefer_static_fields=operator._composite_tensor_prefer_static_fields)  # pylint: disable=protected-access

    def _to_components(self, obj):
        """Returns the `Tensor`-like constructor kwargs of `obj`, keyed by name."""
        return _extract_attrs(obj, keys=list(self._param_specs))

    def _from_components(self, components):
        """Rebuilds an operator from `components` plus the stored non-Tensor kwargs."""
        kwargs = dict(self._non_tensor_params, **components)
        return self.value_type(**kwargs)

    @property
    def _component_specs(self):
        return self._param_specs

    def _serialize(self):
        return (self._param_specs,
                self._non_tensor_params,
                self._prefer_static_fields)

    def _copy(self, **overrides):
        """Returns a new spec of the same type with selected fields replaced."""
        kwargs = {
            "param_specs": self._param_specs,
            "non_tensor_params": self._non_tensor_params,
            "prefer_static_fields": self._prefer_static_fields
        }
        kwargs.update(overrides)
        return type(self)(**kwargs)

    def _batch(self, batch_size):
        """Returns a TypeSpec representing a batch of objects with this TypeSpec."""
        return self._copy(
            param_specs=nest.map_structure(
                lambda spec: spec._batch(batch_size),  # pylint: disable=protected-access
                self._param_specs))

    def _unbatch(self, batch_size=None):  # pylint: disable=unused-argument
        """Returns a TypeSpec representing a single element of this TypeSpec.

        The unbatch protocol passes no arguments, as the nested
        `spec._unbatch()` calls below show. If `batch_size` were required, a
        nested `_LinearOperatorSpec` would raise `TypeError`. The argument is
        kept, optional and ignored, only for backward compatibility.
        """
        return self._copy(
            param_specs=nest.map_structure(
                lambda spec: spec._unbatch(),  # pylint: disable=protected-access
                self._param_specs))

1343 

1344 

def make_composite_tensor(cls, module_name="tf.linalg"):
    """Class decorator to convert `LinearOperator`s to `CompositeTensor`.

    The decorator does three things:

    * Creates a `_LinearOperatorSpec` subclass named `<ClassName>Spec` whose
      `value_type` is `cls`.
    * Registers it with `type_spec_registry` under
      `"<module_name>.<ClassName>Spec"`.
    * Installs `spec_type.from_operator` as the `_type_spec` property of `cls`.

    Args:
      cls: The `LinearOperator` subclass to decorate.
      module_name: Prefix used for the registered spec name.

    Returns:
      `cls`, modified in place.
    """
    spec_name = f"{cls.__name__}Spec"
    spec_type = type(spec_name, (_LinearOperatorSpec,), {"value_type": cls})
    register = type_spec_registry.register(f"{module_name}.{spec_name}")
    register(spec_type)
    cls._type_spec = property(spec_type.from_operator)  # pylint: disable=protected-access
    return cls

1353 

1354 

def _extract_attrs(op, keys):
    """Extract constructor kwargs to reconstruct `op`.

    For each key `k`, the value is taken from the first of these that exists:
    the attribute `op.k`, the attribute `op._k`, or the entry
    `op.parameters[k]`. Post-processing then happens in two steps:

    * If `k` is listed in `op._composite_tensor_prefer_static_fields` and its
      value is a `Tensor` with a statically known value, that static value is
      used instead.
    * NumPy arrays and scalars are converted to Python lists/scalars with
      `tolist()`.

    Args:
      op: A `LinearOperator` instance.
      keys: An iterable of strings naming the constructor kwargs to extract
        from `op`.

    Returns:
      kwargs: A Python `dict` of kwargs to `op`'s constructor, keyed by `keys`.

    Raises:
      ValueError: If no value can be found for some key.
    """
    kwargs = {}
    not_found = object()
    for k in keys:
        candidates = (
            getattr(op, k, not_found),
            getattr(op, "_" + k, not_found),
            getattr(op, "parameters", {}).get(k, not_found),
        )
        value = next((v for v in candidates if v is not not_found), not_found)
        if value is not_found:
            raise ValueError(
                f"Could not determine an appropriate value for field `{k}` in object "
                f"`{op}`. Looked for \n"
                f" 1. an attr called `{k}`,\n"
                f" 2. an attr called `_{k}`,\n"
                f" 3. an entry in `op.parameters` with key '{k}'.")
        if (k in op._composite_tensor_prefer_static_fields  # pylint: disable=protected-access
                and value is not None and tensor_util.is_tensor(value)):
            static_val = tensor_util.constant_value(value)
            if static_val is not None:
                value = static_val
        # NumPy values (e.g. from `constant_value`) become plain Python objects.
        if isinstance(value, (np.ndarray, np.generic)):
            value = value.tolist()
        kwargs[k] = value
    return kwargs

1391 

1392 

def _extract_type_spec_recursively(value):
    """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.

    The result depends on what `value` is:

    * A `CompositeTensor`, `Variable`, or `Tensor`: its `TypeSpec`.
    * A list, tuple, dict, or trackable data structure: a collection of
      parallel structure in which `Tensor` values are recursively replaced by
      their `TypeSpec`s.
    * Anything else: returned unchanged.

    Args:
      value: a Python `object` to (possibly) turn into a (collection of)
        `tf.TypeSpec`(s).

    Returns:
      spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to
        `value`, or `value` itself if no `Tensor`s are found.
    """
    # Leaves that map directly to a `TypeSpec`.
    if isinstance(value, composite_tensor.CompositeTensor):
        return value._type_spec  # pylint: disable=protected-access
    if isinstance(value, variables.Variable):
        return resource_variable_ops.VariableSpec(
            value.shape, dtype=value.dtype, trainable=value.trainable)
    if tensor_util.is_tensor(value):
        return tensor_spec.TensorSpec(value.shape, value.dtype)

    recurse = _extract_type_spec_recursively
    # Trackable wrappers are unwrapped to satisfy `TypeSpec._serialize`.
    # `list` is tested before `TrackableDataStructure` so that `ListWrapper`s
    # come back as plain `list`s; other trackable data structures are
    # unwrapped via `__wrapped__`.
    if isinstance(value, list):
        return [recurse(item) for item in value]
    if isinstance(value, data_structures.TrackableDataStructure):
        return recurse(value.__wrapped__)
    if isinstance(value, tuple):
        return type(value)(recurse(elem) for elem in value)
    if isinstance(value, dict):
        return type(value)((key, recurse(item)) for key, item in value.items())
    return value

1430 

1431 

# Overrides for tf.linalg functions. These allow a LinearOperator to be used in
# place of a Tensor; for instance, tf.linalg.trace(linop) and linop.trace() both
# work.

1435 

1436 

@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator)
def _adjoint(matrix, name=None):
    """Dispatch target: `tf.linalg.adjoint(linop)` -> `linop.adjoint(name)`."""
    result = matrix.adjoint(name)
    return result

1440 

1441 

@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator)
def _cholesky(input, name=None):  # pylint:disable=redefined-builtin
    """Dispatch target: `tf.linalg.cholesky(linop)` -> `linop.cholesky(name)`."""
    factor = input.cholesky(name)
    return factor

1445 

1446 

# The signature has to match the one in python/ops/array_ops.py, so `k`,
# `padding_value`, and `align` are accepted even though they are unused here.
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
    """Dispatch target: `tf.linalg.diag_part(linop)` -> `linop.diag_part(name)`."""
    diagonal = input.diag_part(name)
    return diagonal
# pylint:enable=unused-argument

1459 

1460 

@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
    """Dispatch target: `tf.linalg.det(linop)` -> `linop.determinant(name)`."""
    det = input.determinant(name)
    return det

1464 

1465 

@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):  # pylint:disable=redefined-builtin
    """Dispatch target for `tf.linalg.inv` on a `LinearOperator`.

    Returns `input.inverse(name)`, or its adjoint when `adjoint` is `True`.
    """
    inverted = input.inverse(name)
    return inverted.adjoint() if adjoint else inverted

1472 

1473 

@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
    """Dispatch target for `tf.linalg.logdet` on a `LinearOperator`.

    Only operators hinted as both positive definite and self-adjoint are
    accepted; for them `log_abs_determinant` is returned.

    Raises:
      ValueError: If either hint is not truthy.
    """
    if not (matrix.is_positive_definite and matrix.is_self_adjoint):
        raise ValueError("Expected matrix to be self-adjoint positive definite.")
    return matrix.log_abs_determinant(name)

1479 

1480 

@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    output_type=None,  # pylint: disable=unused-argument
    name=None):
    # Dispatch target for `tf.matmul` when either argument is a LinearOperator.
    if transpose_a or transpose_b:
        raise ValueError("Transposing not supported at this time.")
    if a_is_sparse or b_is_sparse:
        raise ValueError("Sparse methods not supported at this time.")
    if isinstance(a, LinearOperator):
        return a.matmul(b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)
    # `a` is not a LinearOperator, so `b` is assumed to be one. Compute
    # (op(B)^H op(A)^H)^H = op(A) op(B) using `b.matmul`.
    product_h = b.matmul(
        a,
        adjoint=not adjoint_b,
        adjoint_arg=not adjoint_a,
        name=name)
    return linalg.adjoint(product_h)

1507 

1508 

@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(
    matrix,
    rhs,
    adjoint=False,
    name=None):
    """Dispatch target for `tf.linalg.solve` involving a `LinearOperator`.

    Only a `LinearOperator` `matrix` is supported; it delegates to
    `matrix.solve`.

    Raises:
      ValueError: If `matrix` is not a `LinearOperator`.
    """
    if isinstance(matrix, LinearOperator):
        return matrix.solve(rhs, adjoint=adjoint, name=name)
    raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                     "LinearOperator is not supported.")

1519 

1520 

@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
    """Dispatch target: `tf.linalg.trace(linop)` -> `linop.trace(name)`."""
    total = x.trace(name)
    return total