# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for linear operators."""
import abc
import contextlib

import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import type_spec_registry
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.ops.linalg import slicing
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperator"]


# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
@tf_export("linalg.LinearOperator")
class LinearOperator(
    module.Module, composite_tensor.CompositeTensor, metaclass=abc.ABCMeta):
  """Base class defining a [batch of] linear operator[s].

  Subclasses of `LinearOperator` provide access to common methods on a
  (batch) matrix, without the need to materialize the matrix. This allows:

  * Matrix free computations
  * Operators that take advantage of special structure, while providing a
    consistent API to users.

  #### Subclassing

  To enable a public method, subclasses should implement the leading-underscore
  version of the method. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable
  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
  `_matmul(x, adjoint=False)`.
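
  As a minimal sketch (illustrative only; `MyLinOp`, the dense-`Tensor`
  backing, and the omission of hints and validation are assumptions made for
  this example), a subclass could look like:

  ```python
  class MyLinOp(LinearOperator):
    # Hypothetical operator wrapping a dense [batch] matrix.

    def __init__(self, matrix):
      self._matrix = matrix
      super().__init__(dtype=matrix.dtype, parameters=dict(matrix=matrix))

    def _shape(self):
      return self._matrix.shape

    def _matmul(self, x, adjoint=False, adjoint_arg=False):
      return tf.linalg.matmul(
          self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
  ```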

  #### Performance contract

  Subclasses should only implement the assert methods
  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
  time.

  Class docstrings should contain an explanation of computational complexity.
  Since this is a high-performance library, attention should be paid to detail,
  and explanations can include constants as well as Big-O notation.

  #### Shape compatibility

  `LinearOperator` subclasses should operate on a [batch] matrix with
  compatible shape. Class docstrings should define what is meant by compatible
  shape. Some subclasses may not support batching.

  Examples:

  `x` is a batch matrix with compatible shape for `matmul` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  x.shape = [B1,...,Bb] + [N, R]
  ```

  `rhs` is a batch matrix with compatible shape for `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  rhs.shape = [B1,...,Bb] + [M, R]
  ```

  #### Example docstring for subclasses.

  This operator acts like a (batch) matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `m x n` matrix. Again, this matrix `A` may not be materialized, but for
  purposes of identifying and working with compatible arguments the shape is
  relevant.

  Examples:

  ```python
  some_tensor = ... shape = ????
  operator = MyLinOp(some_tensor)

  operator.shape
  ==> [2, 4, 4]

  operator.log_abs_determinant()
  ==> Shape [2] Tensor

  x = ... Shape [2, 4, 5] Tensor

  operator.matmul(x)
  ==> Shape [2, 4, 5] Tensor
  ```

  #### Shape compatibility

  This operator acts on batch matrices with compatible shape.
  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE

  #### Performance

  FILL THIS IN

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.

  #### Initialization parameters

  All subclasses of `LinearOperator` are expected to pass a `parameters`
  argument to `super().__init__()`. This should be a `dict` containing
  the unadulterated arguments passed to the subclass `__init__`. For example,
  `MyLinearOperator` with an initializer should look like:

  ```python
  def __init__(self, operator, is_square=False, name=None):
    parameters = dict(
        operator=operator,
        is_square=is_square,
        name=name
    )
    ...
    super().__init__(..., parameters=parameters)
  ```

  Users can then access `my_linear_operator.parameters` to see all arguments
  passed to its initializer.
  """

  # TODO(b/143910018) Remove graph_parents in V3.
  @deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will "
                               "no longer be used.", "graph_parents")
  def __init__(self,
               dtype,
               graph_parents=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None,
               parameters=None):
190 """Initialize the `LinearOperator`.
192 **This is a private method for subclass use.**
193 **Subclasses should copy-paste this `__init__` documentation.**
195 Args:
196 dtype: The type of the this `LinearOperator`. Arguments to `matmul` and
197 `solve` will have to be this type.
198 graph_parents: (Deprecated) Python list of graph prerequisites of this
199 `LinearOperator` Typically tensors that are passed during initialization
200 is_non_singular: Expect that this operator is non-singular.
201 is_self_adjoint: Expect that this operator is equal to its hermitian
202 transpose. If `dtype` is real, this is equivalent to being symmetric.
203 is_positive_definite: Expect that this operator is positive definite,
204 meaning the quadratic form `x^H A x` has positive real part for all
205 nonzero `x`. Note that we do not require the operator to be
206 self-adjoint to be positive-definite. See:
207 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
208 is_square: Expect that this operator acts like square [batch] matrices.
209 name: A name for this `LinearOperator`.
210 parameters: Python `dict` of parameters used to instantiate this
211 `LinearOperator`.
213 Raises:
214 ValueError: If any member of graph_parents is `None` or not a `Tensor`.
215 ValueError: If hints are set incorrectly.
216 """
    # Check and auto-set flags.
    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    if graph_parents is not None:
      self._set_graph_parents(graph_parents)
    else:
      self._graph_parents = []
    self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
    self._is_non_singular = is_non_singular
    self._is_self_adjoint = is_self_adjoint
    self._is_positive_definite = is_positive_definite
    self._parameters = self._no_dependency(parameters)
    self._parameters_sanitized = False
    self._name = name or type(self).__name__

  @contextlib.contextmanager
  def _name_scope(self, name=None):  # pylint: disable=method-hidden
    """Helper function to standardize op scope."""
    full_name = self.name
    if name is not None:
      full_name += "/" + name
    with ops.name_scope(full_name) as scope:
      yield scope

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `LinearOperator`."""
    return dict(self._parameters)

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
    return self._dtype

  @property
  def name(self):
    """Name prepended to all ops created by this `LinearOperator`."""
    return self._name

  @property
  @deprecation.deprecated(None, "Do not call `graph_parents`.")
  def graph_parents(self):
    """List of graph dependencies of this `LinearOperator`."""
    return self._graph_parents

  @property
  def is_non_singular(self):
    return self._is_non_singular

  @property
  def is_self_adjoint(self):
    return self._is_self_adjoint

  @property
  def is_positive_definite(self):
    return self._is_positive_definite

  @property
  def is_square(self):
    """Return `True/False` depending on whether this operator is square."""
    # Static checks done after __init__. Why? Because domain/range dimension
    # sometimes requires lots of work done in the derived class after init.
    auto_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
      raise ValueError(
          "User set is_square hint to False, but the operator was square.")
    if self._is_square_set_or_implied_by_hints is None:
      return auto_square_check

    return self._is_square_set_or_implied_by_hints

  @abc.abstractmethod
  def _shape(self):
    # Write this in derived class to enable all static shape methods.
    raise NotImplementedError("_shape is not implemented.")

  @property
  def shape(self):
    """`TensorShape` of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    return self._shape()

  def _shape_tensor(self):
    # This is not an abstractmethod, since we want derived classes to be able
    # to override this with optional kwargs, which can reduce the number of
    # `convert_to_tensor` calls. See derived classes for examples.
    raise NotImplementedError("_shape_tensor is not implemented.")

  def shape_tensor(self, name="shape_tensor"):
    """Shape of this `LinearOperator`, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Prefer to use statically defined shape if available.
      if self.shape.is_fully_defined():
        return linear_operator_util.shape_tensor(self.shape.as_list())
      else:
        return self._shape_tensor()

  @property
  def batch_shape(self):
    """`TensorShape` of batch dimensions of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape[:-2]

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of batch dimensions of this operator, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb]`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._batch_shape_tensor()

  def _batch_shape_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.batch_shape.is_fully_defined():
      return linear_operator_util.shape_tensor(
          self.batch_shape.as_list(), name="batch_shape")
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[:-2]

  @property
  def tensor_rank(self, name="tensor_rank"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name: A name for this `Op`.

    Returns:
      Python integer, or None if the tensor rank is undefined.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self.shape.ndims

  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`, determined at runtime.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._tensor_rank_tensor()

  def _tensor_rank_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.tensor_rank is not None:
      return tensor_conversion.convert_to_tensor_v2_with_dispatch(
          self.tensor_rank
      )
    else:
      shape = self.shape_tensor() if shape is None else shape
      return array_ops.size(shape)

  @property
  def domain_dimension(self):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.rank is None:
      return tensor_shape.Dimension(None)
    else:
      return self.shape.dims[-1]

  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._domain_dimension_tensor()

  def _domain_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      return tensor_conversion.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-1]

  @property
  def range_dimension(self):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.dims:
      return self.shape.dims[-2]
    else:
      return tensor_shape.Dimension(None)

  def range_dimension_tensor(self, name="range_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._range_dimension_tensor()

  def _range_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.range_dimension)
    if dim_value is not None:
      return tensor_conversion.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-2]

  def _assert_non_singular(self):
    """Private default implementation of _assert_non_singular."""
    logging.warn(
        "Using (possibly slow) default implementation of assert_non_singular."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return self.assert_positive_definite()
    else:
      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
      # TODO(langmore) Add .eig and .cond as methods.
      cond = (math_ops.reduce_max(singular_values, axis=-1) /
              math_ops.reduce_min(singular_values, axis=-1))
      return check_ops.assert_less(
          cond,
          self._max_condition_number_to_be_non_singular(),
          message="Singular matrix up to precision epsilon.")

  def _max_condition_number_to_be_non_singular(self):
    """Return the maximum condition number that we consider nonsingular."""
    with ops.name_scope("max_nonsingular_condition_number"):
      dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
      eps = math_ops.cast(
          math_ops.reduce_max([
              100.,
              math_ops.cast(self.range_dimension_tensor(), self.dtype),
              math_ops.cast(self.domain_dimension_tensor(), self.dtype)
          ]), self.dtype) * dtype_eps
      return 1. / eps

  def assert_non_singular(self, name="assert_non_singular"):
    """Returns an `Op` that asserts this operator is non singular.

    This operator is considered non-singular if

    ```
    ConditionNumber < 1 / (max{100, range_dimension, domain_dimension} * eps),
    eps := np.finfo(self.dtype.as_numpy_dtype).eps
    ```
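
    For example, a sketch of how the default assert behaves (illustrative;
    `LinearOperatorFullMatrix` and the nearly singular matrix are assumptions
    chosen for this example):

    ```python
    operator = LinearOperatorFullMatrix([[1., 0.], [0., 1e-12]])
    # The condition number (~1e12) exceeds the float32 threshold
    # 1 / (100 * np.finfo(np.float32).eps) ~ 8.4e4, so running this op
    # raises an InvalidArgumentError.
    operator.assert_non_singular()
    ```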

    Args:
      name: A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
      the operator is singular.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_non_singular()

  def _assert_positive_definite(self):
    """Default implementation of _assert_positive_definite."""
    logging.warn(
        "Using (possibly slow) default implementation of "
        "assert_positive_definite."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    # If the operator is self-adjoint, then checking that
    # Cholesky decomposition succeeds + results in positive diag is necessary
    # and sufficient.
    if self.is_self_adjoint:
      return check_ops.assert_positive(
          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
          message="Matrix was not positive definite.")
    # We have no generic check for positive definite.
    raise NotImplementedError("assert_positive_definite is not implemented.")

  def assert_positive_definite(self, name="assert_positive_definite"):
    """Returns an `Op` that asserts this operator is positive definite.

    Here, positive definite means that the quadratic form `x^H A x` has
    positive real part for all nonzero `x`. Note that we do not require the
    operator to be self-adjoint to be positive definite.

    Args:
      name: A name to give this `Op`.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
      the operator is not positive definite.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_positive_definite()

  def _assert_self_adjoint(self):
    dense = self.to_dense()
    logging.warn(
        "Using (possibly slow) default implementation of assert_self_adjoint."
        " Requires conversion to a dense matrix.")
    return check_ops.assert_equal(
        dense,
        linalg.adjoint(dense),
        message="Matrix was not equal to its adjoint.")

  def assert_self_adjoint(self, name="assert_self_adjoint"):
    """Returns an `Op` that asserts this operator is self-adjoint.

    Here we check that this operator is *exactly* equal to its hermitian
    transpose.

    Args:
      name: A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
      the operator is not self-adjoint.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_self_adjoint()

  def _check_input_dtype(self, arg):
    """Check that arg.dtype == self.dtype."""
    if arg.dtype.base_dtype != self.dtype:
      raise TypeError(
          "Expected argument to have dtype %s. Found: %s in tensor %s" %
          (self.dtype, arg.dtype, arg))

  @abc.abstractmethod
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    raise NotImplementedError("_matmul is not implemented.")

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
        `self`. See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name: A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`.
    """
    if isinstance(x, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension,
                right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)

      self_dim = -2 if adjoint else -1
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              x.shape[arg_dim])

      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def __matmul__(self, other):
    return self.matmul(other)

  def _matvec(self, x, adjoint=False):
    x_mat = array_ops.expand_dims(x, axis=-1)
    y_mat = self.matmul(x_mat, adjoint=adjoint)
    return array_ops.squeeze(y_mat, axis=-1)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        `x` is treated as a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      name: A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      self_dim = -2 if adjoint else -1
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(x.shape[-1])
      return self._matvec(x, adjoint=adjoint)

  def _determinant(self):
    logging.warn(
        "Using (possibly slow) default implementation of determinant."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return math_ops.exp(self.log_abs_determinant())
    return linalg_ops.matrix_determinant(self.to_dense())

  def determinant(self, name="det"):
    """Determinant for every batch member.

    Args:
      name: A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError: If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._determinant()

  def _log_abs_determinant(self):
    logging.warn(
        "Using (possibly slow) default implementation of log_abs_determinant."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
      return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
    _, log_abs_det = linalg.slogdet(self.to_dense())
    return log_abs_det

  def log_abs_determinant(self, name="log_abs_det"):
    """Log absolute value of determinant for every batch member.

    Args:
      name: A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError: If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._log_abs_determinant()

  def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve by conversion to a dense matrix."""
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise NotImplementedError(
          "Solve is not yet implemented for non-square operators.")
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._can_use_cholesky():
      return linalg_ops.cholesky_solve(
          linalg_ops.cholesky(self.to_dense()), rhs)
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Default implementation of _solve."""
    logging.warn(
        "Using (possibly slow) default implementation of solve."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
        `rhs` is treated like a [batch] matrix meaning for every set of leading
        dimensions, the last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension,
                right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          rhs, name="rhs"
      )
      self._check_input_dtype(rhs)

      self_dim = -1 if adjoint else -2
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              rhs.shape[arg_dim])

      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _solvevec(self, rhs, adjoint=False):
    """Default implementation of _solvevec."""
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator.
        `rhs` is treated like a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector. See class docstring
        for definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          rhs, name="rhs"
      )
      self._check_input_dtype(rhs)
      self_dim = -1 if adjoint else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1])

      return self._solvevec(rhs, adjoint=adjoint)

  def adjoint(self, name="adjoint"):
    """Returns the adjoint of the current `LinearOperator`.

    Given `A` representing this `LinearOperator`, return `A*`.
    Note that `self.adjoint()` and `self.H` are equivalent.
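
    For example (an illustrative sketch; `LinearOperatorFullMatrix` is an
    assumption of this example):

    ```python
    operator = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
    operator.adjoint().to_dense()  # Same as operator.H.to_dense()
    ==> [[1., 3.], [2., 4.]]
    ```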

    Args:
      name: A name for this `Op`.

    Returns:
      `LinearOperator` which represents the adjoint of this `LinearOperator`.
    """
    if self.is_self_adjoint is True:  # pylint: disable=g-bool-id-comparison
      return self
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.adjoint(self)

  # self.H is equivalent to self.adjoint().
  H = property(adjoint, None)

  def inverse(self, name="inverse"):
    """Returns the inverse of this `LinearOperator`.

    Given `A` representing this `LinearOperator`, return a `LinearOperator`
    representing `A^-1`.
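
    For example (an illustrative sketch; the concrete operator and the
    `is_non_singular` hint are assumptions of this example):

    ```python
    operator = LinearOperatorFullMatrix(
        [[2., 0.], [0., 4.]], is_non_singular=True)
    operator.inverse().to_dense()
    ==> [[0.5, 0.], [0., 0.25]]
    ```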

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `LinearOperator` representing the inverse of this matrix.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be `non_singular`
        or is hinted to be non-square.
    """
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the Inverse: This operator represents "
                       "a non square matrix.")
    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the Inverse: This operator represents "
                       "a singular matrix.")

    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.inverse(self)

  def cholesky(self, name="cholesky"):
    """Returns a Cholesky factor as a `LinearOperator`.

    Given `A` representing this `LinearOperator`, if `A` is positive definite
    self-adjoint, return `L`, where `A = L L^T`, i.e. the Cholesky
    decomposition.
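
    For example (an illustrative sketch; the concrete operator and hints are
    assumptions of this example):

    ```python
    operator = LinearOperatorFullMatrix(
        [[4., 0.], [0., 9.]], is_self_adjoint=True, is_positive_definite=True)
    operator.cholesky().to_dense()
    ==> [[2., 0.], [0., 3.]]
    ```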

    Args:
      name: A name for this `Op`.

    Returns:
      `LinearOperator` which represents the lower triangular matrix
      in the Cholesky decomposition.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be positive
        definite and self adjoint.
    """
    if not self._can_use_cholesky():
      raise ValueError("Cannot take the Cholesky decomposition: "
                       "Not a positive definite self adjoint matrix.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.cholesky(self)

  def _to_dense(self):
    """Generic and often inefficient implementation. Override often."""
    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape
    else:
      batch_shape = self.batch_shape_tensor()

    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      n = dim_value
    else:
      n = self.domain_dimension_tensor()

    eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
    return self.matmul(eye)

  def to_dense(self, name="to_dense"):
    """Return a dense (batch) matrix representing this operator."""
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._to_dense()

  def _diag_part(self):
    """Generic and often inefficient implementation. Override often."""
    return array_ops.matrix_diag_part(self.to_dense())

  def diag_part(self, name="diag_part"):
    """Efficiently get the [batch] diagonal part of this operator.

    If this operator has shape `[B1,...,Bb, M, N]`, this returns a
    `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where
    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.

    ```
    my_operator = LinearOperatorDiag([1., 2.])

    # Efficiently get the diagonal
    my_operator.diag_part()
    ==> [1., 2.]

    # Equivalent, but inefficient method
    tf.linalg.diag_part(my_operator.to_dense())
    ==> [1., 2.]
    ```

    Args:
      name: A name for this `Op`.

    Returns:
      diag_part: A `Tensor` of same `dtype` as self.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._diag_part()

  def _trace(self):
    return math_ops.reduce_sum(self.diag_part(), axis=-1)

  def trace(self, name="trace"):
    """Trace of the linear operator, equal to sum of `self.diag_part()`.

    If the operator is square, this is also the sum of the eigenvalues.
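
    For example (an illustrative sketch; `LinearOperatorDiag` is an assumption
    of this example):

    ```python
    operator = LinearOperatorDiag([1., 2., 3.])
    operator.trace()
    ==> 6.
    ```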

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._trace()

  def _add_to_tensor(self, x):
    # Override if a more efficient implementation is available.
    return self.to_dense() + x

  def add_to_tensor(self, x, name="add_to_tensor"):
    """Add matrix represented by this operator to `x`. Equivalent to `A + x`.

    Args:
      x: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
      name: A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      return self._add_to_tensor(x)

  def _eigvals(self):
    return linalg_ops.self_adjoint_eigvals(self.to_dense())

  def eigvals(self, name="eigvals"):
    """Returns the eigenvalues of this linear operator.

    If the operator is marked as self-adjoint (via `is_self_adjoint`)
    this computation can be more efficient.

    Note: This currently only supports self-adjoint operators.
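
    For example (an illustrative sketch; the concrete operator and the
    `is_self_adjoint` hint are assumptions of this example):

    ```python
    operator = LinearOperatorFullMatrix(
        [[2., 0.], [0., 5.]], is_self_adjoint=True)
    # The default implementation returns eigenvalues in ascending order.
    operator.eigvals()
    ==> [2., 5.]
    ```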

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`.
    """
    if not self.is_self_adjoint:
      raise NotImplementedError("Only self-adjoint matrices are supported.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._eigvals()

  def _cond(self):
    if not self.is_self_adjoint:
      # In general the condition number is the ratio of the
      # absolute value of the largest and smallest singular values.
      vals = linalg_ops.svd(self.to_dense(), compute_uv=False)
    else:
      # For self-adjoint matrices, and in general normal matrices,
      # we can use eigenvalues.
      vals = math_ops.abs(self._eigvals())

    return (math_ops.reduce_max(vals, axis=-1) /
            math_ops.reduce_min(vals, axis=-1))

  def cond(self, name="cond"):
    """Returns the condition number of this linear operator.

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._cond()

  def _can_use_cholesky(self):
    return self.is_self_adjoint and self.is_positive_definite

  def _set_graph_parents(self, graph_parents):
    """Set self._graph_parents. Called during derived class init.

    This method allows derived classes to set graph_parents, without triggering
    a deprecation warning (which is invoked if `graph_parents` is passed during
    `__init__`).

    Args:
      graph_parents: Iterable over Tensors.
    """
    # TODO(b/143910018) Remove this function in V3.
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not (linear_operator_util.is_ref(t) or
                           tensor_util.is_tf_type(t)):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._graph_parents = graph_parents

  @property
  def _composite_tensor_fields(self):
    """A tuple of parameter names to rebuild the `LinearOperator`.

    The tuple contains the names of kwargs to the `LinearOperator`'s
    constructor that the `TypeSpec` needs to rebuild the `LinearOperator`
    instance.

    "is_non_singular", "is_self_adjoint", "is_positive_definite", and
    "is_square" are common to all `LinearOperator` subclasses and may be
    omitted.
    """
    return ()

  @property
  def _composite_tensor_prefer_static_fields(self):
    """A tuple of names referring to parameters that may be treated statically.

    This is a subset of `_composite_tensor_fields`, and contains the names of
    `Tensor`-like args to the `LinearOperator`'s constructor that may be
    stored as static values, if they are statically known. These are typically
    shapes or axis values.
    """
    return ()

  @property
  def _type_spec(self):
    # This property will be overwritten by the `@make_composite_tensor`
    # decorator. However, we need it so that a valid subclass of the `ABCMeta`
    # class `CompositeTensor` can be constructed and passed to the
    # `@make_composite_tensor` decorator.
    pass

  def _convert_variables_to_tensors(self):
    """Recursively converts ResourceVariables in the LinearOperator to Tensors.

    The usage of `self._type_spec._from_components` violates the contract of
    `CompositeTensor`, since it is called on a different nested structure
    (one containing only `Tensor`s) than `self.type_spec` specifies (one that
    may contain `ResourceVariable`s). Since `LinearOperator`'s
    `_from_components` method just passes the contents of the nested structure
    to `__init__` to rebuild the operator, and any `LinearOperator` that may be
    instantiated with `ResourceVariable`s may also be instantiated with
    `Tensor`s, this usage is valid.

    Returns:
      tensor_operator: `self` with all internal Variables converted to Tensors.
    """
    # pylint: disable=protected-access
    components = self._type_spec._to_components(self)
    tensor_components = variable_utils.convert_variables_to_tensors(
        components)
    return self._type_spec._from_components(tensor_components)
    # pylint: enable=protected-access

  def __getitem__(self, slices):
    return slicing.batch_slice(self, params_overrides={}, slices=slices)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    """A dict of names to number of dimensions contributing to an operator.

    This is a dictionary of parameter names to `int`s specifying the
    number of right-most dimensions contributing to the **matrix** shape of the
    densified operator.
    If the parameter is a `Tensor`, this is mapped to an `int`.
    If the parameter is a `LinearOperator` (called `A`), this specifies the
    number of batch dimensions of `A` contributing to this `LinearOperator`'s
    matrix shape.
    If the parameter is a structure, this is a structure of the same type
    containing `int`s.
    """
    return ()


class _LinearOperatorSpec(type_spec.BatchableTypeSpec):
  """A tf.TypeSpec for `LinearOperator` objects."""

  __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")

  def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
    """Initializes a new `_LinearOperatorSpec`.

    Args:
      param_specs: Python `dict` of `tf.TypeSpec` instances that describe
        kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or
        `CompositeTensor` subclasses.
      non_tensor_params: Python `dict` containing non-`Tensor` and non-
        `CompositeTensor` kwargs to the `LinearOperator`'s constructor.
      prefer_static_fields: Python `tuple` of strings corresponding to the
        names of `Tensor`-like args to the `LinearOperator`'s constructor that
        may be stored as static values, if known. These are typically shapes,
        indices, or axis values.
    """
    self._param_specs = param_specs
    self._non_tensor_params = non_tensor_params
    self._prefer_static_fields = prefer_static_fields

  @classmethod
  def from_operator(cls, operator):
    """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.

    Args:
      operator: An instance of `LinearOperator`.

    Returns:
      linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as
        the `TypeSpec` of `operator`.
    """
    validation_fields = ("is_non_singular", "is_self_adjoint",
                         "is_positive_definite", "is_square")
    kwargs = _extract_attrs(
        operator,
        keys=set(operator._composite_tensor_fields + validation_fields))  # pylint: disable=protected-access

    non_tensor_params = {}
    param_specs = {}
    for k, v in list(kwargs.items()):
      type_spec_or_v = _extract_type_spec_recursively(v)
      is_tensor = [isinstance(x, type_spec.TypeSpec)
                   for x in nest.flatten(type_spec_or_v)]
      if all(is_tensor):
        param_specs[k] = type_spec_or_v
      elif not any(is_tensor):
        non_tensor_params[k] = v
      else:
        raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
                                  f"non-`Tensor` values.")

    return cls(
        param_specs=param_specs,
        non_tensor_params=non_tensor_params,
        prefer_static_fields=operator._composite_tensor_prefer_static_fields)  # pylint: disable=protected-access

  def _to_components(self, obj):
    return _extract_attrs(obj, keys=list(self._param_specs))

  def _from_components(self, components):
    kwargs = dict(self._non_tensor_params, **components)
    return self.value_type(**kwargs)

  @property
  def _component_specs(self):
    return self._param_specs

  def _serialize(self):
    return (self._param_specs,
            self._non_tensor_params,
            self._prefer_static_fields)

  def _copy(self, **overrides):
    kwargs = {
        "param_specs": self._param_specs,
        "non_tensor_params": self._non_tensor_params,
        "prefer_static_fields": self._prefer_static_fields
    }
    kwargs.update(overrides)
    return type(self)(**kwargs)

  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec."""
    return self._copy(
        param_specs=nest.map_structure(
            lambda spec: spec._batch(batch_size),  # pylint: disable=protected-access
            self._param_specs))

  def _unbatch(self, batch_size):
    """Returns a TypeSpec representing a single element of this TypeSpec."""
    return self._copy(
        param_specs=nest.map_structure(
            lambda spec: spec._unbatch(),  # pylint: disable=protected-access
            self._param_specs))


def make_composite_tensor(cls, module_name="tf.linalg"):
  """Class decorator to convert `LinearOperator`s to `CompositeTensor`."""

  spec_name = "{}Spec".format(cls.__name__)
  spec_type = type(spec_name, (_LinearOperatorSpec,), {"value_type": cls})
  type_spec_registry.register("{}.{}".format(module_name, spec_name))(spec_type)
  cls._type_spec = property(spec_type.from_operator)  # pylint: disable=protected-access
  return cls
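

# An illustrative sketch of how `make_composite_tensor` is typically applied
# (the subclass name here is an assumption, not part of this module):
#
#   @make_composite_tensor
#   class MyOperator(LinearOperator):
#     ...
#
# after which instances of `MyOperator` can flow through APIs that expect
# `CompositeTensor`s, e.g. `tf.nest` utilities and `tf.function` tracing.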


def _extract_attrs(op, keys):
  """Extract constructor kwargs to reconstruct `op`.

  Args:
    op: A `LinearOperator` instance.
    keys: A Python `tuple` of strings indicating the names of the constructor
      kwargs to extract from `op`.

  Returns:
    kwargs: A Python `dict` of kwargs to `op`'s constructor, keyed by `keys`.
  """

  kwargs = {}
  not_found = object()
  for k in keys:
    srcs = [
        getattr(op, k, not_found), getattr(op, "_" + k, not_found),
        getattr(op, "parameters", {}).get(k, not_found),
    ]
    if any(v is not not_found for v in srcs):
      kwargs[k] = [v for v in srcs if v is not not_found][0]
    else:
      raise ValueError(
          f"Could not determine an appropriate value for field `{k}` in object "
          f"`{op}`. Looked for \n"
          f" 1. an attr called `{k}`,\n"
          f" 2. an attr called `_{k}`,\n"
          f" 3. an entry in `op.parameters` with key '{k}'.")
    if k in op._composite_tensor_prefer_static_fields and kwargs[k] is not None:  # pylint: disable=protected-access
      if tensor_util.is_tensor(kwargs[k]):
        static_val = tensor_util.constant_value(kwargs[k])
        if static_val is not None:
          kwargs[k] = static_val
    if isinstance(kwargs[k], (np.ndarray, np.generic)):
      kwargs[k] = kwargs[k].tolist()
  return kwargs


def _extract_type_spec_recursively(value):
  """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
  `value` is a collection containing `Tensor` values, recursively supplant them
  with their respective `TypeSpec`s in a collection of parallel structure.

  If `value` is none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
      `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
      or `value`, if no `Tensor`s are found.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, variables.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tensor_util.is_tensor(value):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  # Unwrap trackable data structures to comply with `TypeSpec._serialize`
  # requirements. `ListWrapper`s are converted to `list`s, and for other
  # trackable data structures, the `__wrapped__` attribute is used.
  if isinstance(value, list):
    return list(_extract_type_spec_recursively(v) for v in value)
  if isinstance(value, data_structures.TrackableDataStructure):
    return _extract_type_spec_recursively(value.__wrapped__)
  if isinstance(value, tuple):
    return type(value)(_extract_type_spec_recursively(x) for x in value)
  if isinstance(value, dict):
    return type(value)((k, _extract_type_spec_recursively(v))
                       for k, v in value.items())
  return value


# Overrides for tf.linalg functions. This allows a LinearOperator to be used in
# place of a Tensor.
# For instance tf.trace(linop) and linop.trace() both work.


@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator)
def _adjoint(matrix, name=None):
  return matrix.adjoint(name)


@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator)
def _cholesky(input, name=None):  # pylint:disable=redefined-builtin
  return input.cholesky(name)


# The signature has to match with the one in python/op/array_ops.py,
# so we have k, padding_value, and align even though we don't use them here.
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  return input.diag_part(name)
# pylint:enable=unused-argument


@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
  return input.determinant(name)


@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):  # pylint:disable=redefined-builtin
  inv = input.inverse(name)
  if adjoint:
    inv = inv.adjoint()
  return inv


@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
  if matrix.is_positive_definite and matrix.is_self_adjoint:
    return matrix.log_abs_determinant(name)
  raise ValueError("Expected matrix to be self-adjoint positive definite.")


@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    output_type=None,  # pylint: disable=unused-argument
    name=None):
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if not isinstance(a, LinearOperator):
    # We use the identity (B^H A^H)^H = A B.
    adjoint_matmul = b.matmul(
        a,
        adjoint=(not adjoint_b),
        adjoint_arg=(not adjoint_a),
        name=name)
    return linalg.adjoint(adjoint_matmul)
  return a.matmul(
      b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)


@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(
    matrix,
    rhs,
    adjoint=False,
    name=None):
  if not isinstance(matrix, LinearOperator):
    raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                     "LinearOperator is not supported.")
  return matrix.solve(rhs, adjoint=adjoint, name=name)


@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  return x.trace(name)