Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_toeplitz.py: 40%
87 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a Toeplitz matrix."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_conversion
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import check_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops.linalg import linalg_impl as linalg
24from tensorflow.python.ops.linalg import linear_operator
25from tensorflow.python.ops.linalg import linear_operator_circulant
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.ops.signal import fft_ops
28from tensorflow.python.util.tf_export import tf_export
30__all__ = ["LinearOperatorToeplitz",]
@tf_export("linalg.LinearOperatorToeplitz")
@linear_operator.make_composite_tensor
class LinearOperatorToeplitz(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of toeplitz matrices.

  This operator acts like a [batch] Toeplitz matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  #### Description in terms of toeplitz matrices

  Toeplitz means that `A` has constant diagonals. Hence, `A` can be generated
  with two vectors. One represents the first column of the matrix, and the
  other represents the first row.

  Below is a 4 x 4 example:

  ```
  A = |a b c d|
      |e a b c|
      |f e a b|
      |g f e a|
  ```

  #### Example of a Toeplitz operator.

  ```python
  # Create a 3 x 3 Toeplitz operator.
  col = [1., 2., 3.]
  row = [1., 4., -9.]
  operator = LinearOperatorToeplitz(col, row)

  operator.to_dense()
  ==> [[1., 4., -9.],
       [2., 1., 4.],
       [3., 2., 1.]]

  operator.shape
  ==> [3, 3]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               col,
               row,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorToeplitz"):
    r"""Initialize a `LinearOperatorToeplitz`.

    Args:
      col: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The first column of the operator. Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`. Note that the first entry of
          `col` is assumed to be the same as the first entry of `row`.
      row: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The first row of the operator. Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`. Note that the first entry of
          `row` is assumed to be the same as the first entry of `col`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose. For a Toeplitz operator this holds when `row` is the
        complex conjugate of `col`. The hint is passed through unchanged; it
        is never inferred from the dtype.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError: If `is_square` is `False`, if `row` or `col` is statically
        known to be a scalar, or if their last dimensions are statically known
        and differ.
    """
    parameters = dict(
        col=col,
        row=row,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[row, col]):
      self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
      self._col = linear_operator_util.convert_nonref_to_tensor(col, name="col")
      self._check_row_col(self._row, self._col)

      # Only an explicit `False` is rejected; `None` is promoted to `True`.
      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("Only square Toeplitz operators currently supported.")
      is_square = True

      super(LinearOperatorToeplitz, self).__init__(
          dtype=self._row.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_row_col(self, row, col):
    """Static check of row and column."""
    for name, tensor in (("row", row), ("col", col)):
      if tensor.shape.ndims is not None and tensor.shape.ndims < 1:
        raise ValueError("Argument {} must have at least 1 dimension.  "
                         "Found: {}".format(name, tensor))

    # Only compare the trailing dimensions when both are statically known.
    if row.shape[-1] is not None and col.shape[-1] is not None:
      if row.shape[-1] != col.shape[-1]:
        raise ValueError(
            "Expected square matrix, got row and col with mismatched "
            "dimensions.")

  def _shape(self):
    # If v_shape = [5, 3], we return [5, 3, 3].
    v_shape = array_ops.broadcast_static_shape(
        self.row.shape, self.col.shape)
    return v_shape.concatenate(v_shape[-1:])

  def _shape_tensor(self, row=None, col=None):
    # Callers that already converted `row`/`col` to tensors may pass them in
    # so the properties are not read again.
    row = self.row if row is None else row
    col = self.col if col is None else col
    v_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(row),
        array_ops.shape(col))
    k = v_shape[-1]
    return array_ops.concat((v_shape, [k]), 0)

  def _assert_self_adjoint(self):
    # A Toeplitz matrix equals its Hermitian transpose exactly when its first
    # row is the elementwise complex conjugate of its first column.  `conj`
    # leaves real tensors unchanged, so real operators still compare `row`
    # against `col` directly.
    return check_ops.assert_equal(
        self.row,
        math_ops.conj(self.col),
        message=("row and col are not the same, and "
                 "so this operator is not self-adjoint."))

  # TODO(srvasude): Add efficient solver and determinant calculations to this
  # class (based on Levinson recursion.)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Given a Toeplitz matrix, we can embed it in a Circulant matrix to perform
    # efficient matrix multiplications. Given a Toeplitz matrix with first row
    # [t_0, t_1, ... t_{n-1}] and first column [t0, t_{-1}, ..., t_{-(n-1)}],
    # let C be the circulant matrix with first column [t0, t_{-1}, ...,
    # t_{-(n-1)}, 0, t_{n-1}, ..., t_1]. Also adjoin to our input vector `x`
    # `n` zeros, to make it a vector of length `2n` (call it y). It can be shown
    # that if we take the first n entries of `Cy`, this is equal to the Toeplitz
    # multiplication. See:
    # http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
    # for more details.
    x = linalg.adjoint(x) if adjoint_arg else x
    expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
    col = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.col)
    row = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.row)
    circulant_col = array_ops.concat(
        [col,
         array_ops.zeros_like(col[..., 0:1]),
         array_ops.reverse(row[..., 1:], axis=[-1])], axis=-1)
    circulant = linear_operator_circulant.LinearOperatorCirculant(
        fft_ops.fft(_to_complex(circulant_col)),
        input_output_dtype=row.dtype)
    # `x` was already adjointed above if requested, hence adjoint_arg=False.
    result = circulant.matmul(expanded_x, adjoint=adjoint, adjoint_arg=False)

    # Keep only the first `n` rows of the length-`2n` circulant product.
    shape = self._shape_tensor(row=row, col=col)
    return math_ops.cast(
        result[..., :self._domain_dimension_tensor(shape=shape), :],
        self.dtype)

  def _trace(self):
    # Every diagonal entry equals col[..., 0], so the trace is n times it.
    return math_ops.cast(
        self.domain_dimension_tensor(),
        dtype=self.dtype) * self.col[..., 0]

  def _diag_part(self):
    # Tile the single diagonal value col[..., 0] along a length-n axis.
    diag_entry = self.col[..., 0:1]
    return diag_entry * array_ops.ones(
        [self.domain_dimension_tensor()], self.dtype)

  def _to_dense(self):
    row = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.row)
    col = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.col)
    total_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(row), array_ops.shape(col))
    n = array_ops.shape(row)[-1]
    row = array_ops.broadcast_to(row, total_shape)
    col = array_ops.broadcast_to(col, total_shape)
    # We concatenate the column in reverse order to the row.
    # This gives us 2*n - 1 elements.
    elements = array_ops.concat(
        [array_ops.reverse(col, axis=[-1]), row[..., 1:]], axis=-1)
    # Given the above vector, the i-th row of the Toeplitz matrix
    # is the last n elements of the above vector shifted i right
    # (hence the first row is just the row vector provided, and
    # the first element of each row will belong to the column vector).
    # We construct these set of indices below.
    indices = math_ops.mod(
        # How much to shift right. This corresponds to `i`.
        math_ops.range(0, n) +
        # Specifies the last `n` indices.
        math_ops.range(n - 1, -1, -1)[..., array_ops.newaxis],
        # Mod out by the total number of elements to ensure the index is
        # non-negative (for tf.gather) and < 2 * n - 1.
        2 * n - 1)
    return array_ops.gather(elements, indices, axis=-1)

  @property
  def col(self):
    """The first column of the operator, as stored at construction."""
    return self._col

  @property
  def row(self):
    """The first row of the operator, as stored at construction."""
    return self._row

  @property
  def _composite_tensor_fields(self):
    return ("col", "row")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"col": 1, "row": 1}
def _to_complex(x):
  """Casts `x` to a complex dtype wide enough for its current dtype.

  `float64` and `complex128` inputs are cast to `complex128`; every other
  dtype is cast to `complex64`.

  Args:
    x: Object exposing a `dtype` attribute and accepted by `math_ops.cast`.

  Returns:
    The result of `math_ops.cast` on `x` with the selected complex dtype.
  """
  is_double_precision = x.dtype in [dtypes.float64, dtypes.complex128]
  target_dtype = (
      dtypes.complex128 if is_double_precision else dtypes.complex64)
  return math_ops.cast(x, target_dtype)