Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_diag.py: 42%
88 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a diagonal matrix."""
17from tensorflow.python.framework import ops
18from tensorflow.python.framework import tensor_conversion
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import check_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops.linalg import linalg_impl as linalg
23from tensorflow.python.ops.linalg import linear_operator
24from tensorflow.python.ops.linalg import linear_operator_util
25from tensorflow.python.util.tf_export import tf_export
__all__ = ["LinearOperatorDiag",]


@tf_export("linalg.LinearOperatorDiag")
@linear_operator.make_composite_tensor
class LinearOperatorDiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square diagonal matrix.

  This operator acts like a [batch] diagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix. This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorDiag` is initialized with a (batch) vector.

  ```python
  # Create a 2 x 2 diagonal linear operator.
  diag = [1., -1.]
  operator = LinearOperatorDiag(diag)

  operator.to_dense()
  ==> [[1.,  0.]
       [0., -1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  diag = tf.random.normal(shape=[2, 3, 4])
  operator = LinearOperatorDiag(diag)

  # Create a shape [2, 1, 4, 2] batch of matrices. Note that this shape is
  # compatible since the batch dimensions, [2, 1], are broadcast to
  # operator.batch_shape = [2, 3].
  y = tf.random.normal(shape=[2, 1, 4, 2])
  x = operator.solve(y)
  ==> operator.matmul(x) = y
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N], with b >= 0
  x.shape = [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
  and `x.shape = [N, R]`. Then

  * `operator.matmul(x)` involves `N * R` multiplications.
  * `operator.solve(x)` involves `N` divisions and `N * R` multiplications.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diag,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorDiag"):
    r"""Initialize a `LinearOperatorDiag`.

    Args:
      diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The diagonal of the operator. Allowed dtypes: `float16`, `float32`,
        `float64`, `complex64`, `complex128`. (The dtype is not checked by
        the code in this class; only the static rank of `diag` is checked.)
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose. If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
        Always auto-set to `True`; passing `False` is an error.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError: If `diag` has a statically known rank less than 1.
      ValueError: If `diag.dtype` is real and `is_self_adjoint` is `False`.
      ValueError: If `is_square` is `False`.
    """
    # Keep the constructor arguments as given; they are handed to the base
    # class via `parameters=`.
    parameters = dict(
        diag=diag,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diag]):
      self._diag = linear_operator_util.convert_nonref_to_tensor(
          diag, name="diag")
      self._check_diag(self._diag)

      # Check and auto-set hints.
      # A real diagonal matrix equals its own (hermitian) transpose, so the
      # hint may only be True (or unset, in which case it is set to True).
      if not self._diag.dtype.is_complex:
        if is_self_adjoint is False:
          raise ValueError("A real diagonal operator is always self adjoint.")
        is_self_adjoint = True

      # The operator's shape is built by repeating the last diag dimension
      # (see `_shape`), so it is square by construction.
      if is_square is False:
        raise ValueError("Only square diagonal operators currently supported.")
      is_square = True

      super(LinearOperatorDiag, self).__init__(
          dtype=self._diag.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_diag(self, diag):
    """Static check of diag: its rank, when statically known, must be >= 1.

    Raises:
      ValueError: If `diag` has a statically known rank less than 1.
    """
    if diag.shape.ndims is not None and diag.shape.ndims < 1:
      raise ValueError("Argument diag must have at least 1 dimension. "
                       "Found: %s" % diag)

  def _shape(self):
    """Static shape: the diag shape with its last dimension repeated."""
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._diag.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    """Dynamic shape: `shape(diag)` with its last entry appended again."""
    d_shape = array_ops.shape(self._diag)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  @property
  def diag(self):
    """The diagonal, as produced by `convert_nonref_to_tensor` in `__init__`."""
    return self._diag

  def _assert_non_singular(self):
    """Returns an op asserting that no diagonal entry has modulus zero."""
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator: Diagonal contained zero values.")

  def _assert_positive_definite(self):
    """Returns an op asserting that every diagonal entry has positive real part."""
    if self.dtype.is_complex:
      message = (
          "Diagonal operator had diagonal entries with non-positive real part, "
          "thus was not positive definite.")
    else:
      message = (
          "Real diagonal operator had non-positive diagonal entries, "
          "thus was not positive definite.")

    return check_ops.assert_positive(
        math_ops.real(self._diag),
        message=message)

  def _assert_self_adjoint(self):
    """Returns an op asserting that the diagonal has zero imaginary part."""
    return linear_operator_util.assert_zero_imag_part(
        self._diag,
        message=(
            "This diagonal operator contained non-zero imaginary values. "
            " Thus it was not self-adjoint."))

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Computes `A x` (or `A^H x`) by scaling row `i` of `x` by `diag[i]`."""
    # The adjoint of a diagonal matrix is its elementwise conjugate.
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    x = linalg.adjoint(x) if adjoint_arg else x
    # [..., N] -> [..., N, 1] so the product broadcasts across the R columns.
    diag_mat = array_ops.expand_dims(diag_term, -1)
    return diag_mat * x

  def _matvec(self, x, adjoint=False):
    """Computes `A x` (or `A^H x`) as an elementwise product with the diagonal."""
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    return diag_term * x

  def _determinant(self):
    """Product of the diagonal entries over the last axis."""
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    """Sum of `log|diag[i]|` over the last axis, returned in `self.dtype`."""
    log_det = math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])
    # `abs` yields a real tensor; cast back so the result matches the
    # operator's complex dtype.
    if self.dtype.is_complex:
      log_det = math_ops.cast(log_det, dtype=self.dtype)
    return log_det

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves `A X = rhs` (or `A^H X = rhs`) by scaling row `i` by `1 / diag[i]`."""
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    # [..., N] -> [..., N, 1] so the reciprocal broadcasts across columns.
    inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
    return rhs * inv_diag_mat

  def _to_dense(self):
    """Materializes the `[..., N, N]` matrix with `diag` on its diagonal."""
    return array_ops.matrix_diag(self._diag)

  def _diag_part(self):
    """The diagonal of the operator."""
    return self.diag

  def _add_to_tensor(self, x):
    """Returns `x` with `diag` added to its diagonal; off-diagonals unchanged."""
    x_diag = array_ops.matrix_diag_part(x)
    new_diag = self._diag + x_diag
    return array_ops.matrix_set_diag(x, new_diag)

  def _eigvals(self):
    """The eigenvalues of a diagonal matrix are its diagonal entries."""
    # NOTE(review): explicit conversion presumably because
    # `convert_nonref_to_tensor` may leave `diag` as a non-Tensor (e.g. a
    # variable) — confirm.
    return tensor_conversion.convert_to_tensor_v2_with_dispatch(self.diag)

  def _cond(self):
    """Condition number: `max|diag| / min|diag|` over the last axis."""
    abs_diag = math_ops.abs(self.diag)
    return (math_ops.reduce_max(abs_diag, axis=-1) /
            math_ops.reduce_min(abs_diag, axis=-1))

  @property
  def _composite_tensor_fields(self):
    # Constructor parameters that carry tensor data for the composite-tensor
    # representation.
    return ("diag",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # `diag` contributes one trailing (non-batch) dimension per matrix.
    return {"diag": 1}