# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a tridiagonal matrix."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ['LinearOperatorTridiag',]

_COMPACT = 'compact'
_MATRIX = 'matrix'
_SEQUENCE = 'sequence'
_DIAGONAL_FORMATS = frozenset({_COMPACT, _MATRIX, _SEQUENCE})


@tf_export('linalg.LinearOperatorTridiag')
@linear_operator.make_composite_tensor
class LinearOperatorTridiag(linear_operator.LinearOperator):
42 """`LinearOperator` acting like a [batch] square tridiagonal matrix.
44 This operator acts like a [batch] square tridiagonal matrix `A` with shape
45 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
46 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
47 an `N x M` matrix. This matrix `A` is not materialized, but for
48 purposes of broadcasting this shape will be relevant.
  Example usage:

  Create a 3 x 3 tridiagonal linear operator.

  >>> superdiag = [3., 4., 5.]
  >>> diag = [1., -1., 2.]
  >>> subdiag = [6., 7., 8.]
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...    [superdiag, diag, subdiag],
  ...    diagonals_format='sequence')
  >>> operator.to_dense()
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[ 1.,  3.,  0.],
         [ 7., -1.,  4.],
         [ 0.,  8.,  2.]], dtype=float32)>
  >>> operator.shape
  TensorShape([3, 3])

  Scalar Tensor output.

  >>> operator.log_abs_determinant()
  <tf.Tensor: shape=(), dtype=float32, numpy=4.3307333>

  Create a [2, 3] batch of 4 x 4 linear operators.

  >>> diagonals = tf.random.normal(shape=[2, 3, 3, 4])
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...    diagonals,
  ...    diagonals_format='compact')

  Create a shape [2, 1, 4, 2] right-hand side. Note that this shape is
  compatible since the batch dimensions, [2, 1], are broadcast to
  operator.batch_shape = [2, 3].

  >>> y = tf.random.normal(shape=[2, 1, 4, 2])
  >>> x = operator.solve(y)
  >>> x
  <tf.Tensor: shape=(2, 3, 4, 2), dtype=float32, numpy=...,
  dtype=float32)>

  #### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N], with b >= 0
  x.shape = [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb].
  ```
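
  For example, the batch operator constructed above has shape `[2, 3, 4, 4]`,
  so a right-hand side of shape `[2, 3, 4, 2]` satisfies this rule (a minimal
  sketch; the all-ones input is arbitrary):

  >>> x = tf.ones([2, 3, 4, 2])
  >>> operator.matmul(x).shape
  TensorShape([2, 3, 4, 2])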

  #### Performance

  Suppose `operator` is a `LinearOperatorTridiag` of shape `[N, N]`,
  and `x.shape = [N, R]`. Then

  * `operator.matmul(x)` will take O(N * R) time.
  * `operator.solve(x)` will take O(N * R) time.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
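
  For instance, a hint supplied at construction is recorded and exposed as a
  property of the operator, without any runtime verification (a short sketch
  reusing the diagonals defined above):

  >>> hinted = tf.linalg.LinearOperatorTridiag(
  ...    [superdiag, diag, subdiag],
  ...    diagonals_format='sequence',
  ...    is_non_singular=True)
  >>> hinted.is_non_singular
  True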
125 """
127 def __init__(self,
128 diagonals,
129 diagonals_format=_COMPACT,
130 is_non_singular=None,
131 is_self_adjoint=None,
132 is_positive_definite=None,
133 is_square=None,
134 name='LinearOperatorTridiag'):
135 r"""Initialize a `LinearOperatorTridiag`.
137 Args:
138 diagonals: `Tensor` or list of `Tensor`s depending on `diagonals_format`.
140 If `diagonals_format=sequence`, this is a list of three `Tensor`'s each
141 with shape `[B1, ..., Bb, N]`, `b >= 0, N >= 0`, representing the
142 superdiagonal, diagonal and subdiagonal in that order. Note the
143 superdiagonal is padded with an element in the last position, and the
144 subdiagonal is padded with an element in the front.
146 If `diagonals_format=matrix` this is a `[B1, ... Bb, N, N]` shaped
147 `Tensor` representing the full tridiagonal matrix.
149 If `diagonals_format=compact` this is a `[B1, ... Bb, 3, N]` shaped
150 `Tensor` with the second to last dimension indexing the
151 superdiagonal, diagonal and subdiagonal in that order. Note the
152 superdiagonal is padded with an element in the last position, and the
153 subdiagonal is padded with an element in the front.
155 In every case, these `Tensor`s are all floating dtype.
156 diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
157 `compact`.
158 is_non_singular: Expect that this operator is non-singular.
159 is_self_adjoint: Expect that this operator is equal to its hermitian
160 transpose. If `diag.dtype` is real, this is auto-set to `True`.
161 is_positive_definite: Expect that this operator is positive definite,
162 meaning the quadratic form `x^H A x` has positive real part for all
163 nonzero `x`. Note that we do not require the operator to be
164 self-adjoint to be positive-definite. See:
165 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
166 is_square: Expect that this operator acts like square [batch] matrices.
167 name: A name for this `LinearOperator`.
169 Raises:
170 TypeError: If `diag.dtype` is not an allowed type.
171 ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
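
    For example, the `matrix` format takes the full (mostly zero) matrix
    directly; a minimal sketch, using the dense form of the 3 x 3 operator
    from the class docstring:

    >>> full = tf.constant(
    ...     [[1., 3., 0.],
    ...      [7., -1., 4.],
    ...      [0., 8., 2.]])
    >>> operator = tf.linalg.LinearOperatorTridiag(
    ...     full, diagonals_format='matrix')
    >>> operator.shape
    TensorShape([3, 3])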
172 """
173 parameters = dict(
174 diagonals=diagonals,
175 diagonals_format=diagonals_format,
176 is_non_singular=is_non_singular,
177 is_self_adjoint=is_self_adjoint,
178 is_positive_definite=is_positive_definite,
179 is_square=is_square,
180 name=name
181 )
183 with ops.name_scope(name, values=[diagonals]):
184 if diagonals_format not in _DIAGONAL_FORMATS:
185 raise ValueError(
186 f'Argument `diagonals_format` must be one of compact, matrix, or '
187 f'sequence. Received : {diagonals_format}.')
188 if diagonals_format == _SEQUENCE:
189 self._diagonals = [linear_operator_util.convert_nonref_to_tensor(
190 d, name='diag_{}'.format(i)) for i, d in enumerate(diagonals)]
191 dtype = self._diagonals[0].dtype
192 else:
193 self._diagonals = linear_operator_util.convert_nonref_to_tensor(
194 diagonals, name='diagonals')
195 dtype = self._diagonals.dtype
196 self._diagonals_format = diagonals_format
198 super(LinearOperatorTridiag, self).__init__(
199 dtype=dtype,
200 is_non_singular=is_non_singular,
201 is_self_adjoint=is_self_adjoint,
202 is_positive_definite=is_positive_definite,
203 is_square=is_square,
204 parameters=parameters,
205 name=name)

  def _shape(self):
    if self.diagonals_format == _MATRIX:
      return self.diagonals.shape
    if self.diagonals_format == _COMPACT:
      # Remove the second to last dimension that contains the value 3.
      d_shape = self.diagonals.shape[:-2].concatenate(
          self.diagonals.shape[-1])
    else:
      broadcast_shape = array_ops.broadcast_static_shape(
          self.diagonals[0].shape[:-1],
          self.diagonals[1].shape[:-1])
      broadcast_shape = array_ops.broadcast_static_shape(
          broadcast_shape,
          self.diagonals[2].shape[:-1])
      d_shape = broadcast_shape.concatenate(self.diagonals[1].shape[-1])
    return d_shape.concatenate(d_shape[-1])

  def _shape_tensor(self, diagonals=None):
    diagonals = diagonals if diagonals is not None else self.diagonals
    if self.diagonals_format == _MATRIX:
      return array_ops.shape(diagonals)
    if self.diagonals_format == _COMPACT:
      d_shape = array_ops.shape(diagonals[..., 0, :])
    else:
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          array_ops.shape(self.diagonals[0])[:-1],
          array_ops.shape(self.diagonals[1])[:-1])
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          broadcast_shape,
          array_ops.shape(self.diagonals[2])[:-1])
      d_shape = array_ops.concat(
          [broadcast_shape, [array_ops.shape(self.diagonals[1])[-1]]], axis=0)
    return array_ops.concat([d_shape, [d_shape[-1]]], axis=-1)

  def _assert_self_adjoint(self):
    # Check that the diagonal has zero imaginary part, and that the sub- and
    # superdiagonals are complex conjugates of each other.

    asserts = []
    diag_message = (
        'This tridiagonal operator contained non-zero '
        'imaginary values on the diagonal.')
    off_diag_message = (
        'This tridiagonal operator has non-conjugate '
        'subdiagonal and superdiagonal.')

    if self.diagonals_format == _MATRIX:
      asserts += [check_ops.assert_equal(
          self.diagonals, linalg.adjoint(self.diagonals),
          message='Matrix was not equal to its adjoint.')]
    elif self.diagonals_format == _COMPACT:
      diagonals = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          self.diagonals
      )
      asserts += [linear_operator_util.assert_zero_imag_part(
          diagonals[..., 1, :], message=diag_message)]
      # Roll the subdiagonal so the shifted argument is at the end.
      subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          diagonals[..., 0, :-1],
          message=off_diag_message)]
    else:
      asserts += [linear_operator_util.assert_zero_imag_part(
          self.diagonals[1], message=diag_message)]
      subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          self.diagonals[0][..., :-1],
          message=off_diag_message)]
    return control_flow_ops.group(asserts)

  def _construct_adjoint_diagonals(self, diagonals):
    # Constructs adjoint tridiagonal matrix from diagonals.
    if self.diagonals_format == _SEQUENCE:
      diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
      # The subdiag and the superdiag swap places, so we need to shift the
      # padding argument.
      diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
      diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
      return diagonals
    elif self.diagonals_format == _MATRIX:
      return linalg.adjoint(diagonals)
    else:
      diagonals = math_ops.conj(diagonals)
      superdiag, diag, subdiag = array_ops_stack.unstack(
          diagonals, num=3, axis=-2)
      # The subdiag and the superdiag swap places, so we need
      # to shift all arguments.
      new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
      new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
      return array_ops_stack.stack([new_superdiag, diag, new_subdiag], axis=-2)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)
    x = linalg.adjoint(x) if adjoint_arg else x
    return linalg.tridiagonal_matmul(
        diagonals, x,
        diagonals_format=self.diagonals_format)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)

    # TODO(b/144860784): Remove the broadcasting code below once
    # tridiagonal_solve broadcasts.

    rhs_shape = array_ops.shape(rhs)
    k = self._shape_tensor(diagonals)[-1]
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
    rhs = array_ops.broadcast_to(
        rhs, array_ops.concat(
            [broadcast_shape, rhs_shape[-2:]], axis=-1))
    if self.diagonals_format == _MATRIX:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [k, k]], axis=-1))
    elif self.diagonals_format == _COMPACT:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [3, k]], axis=-1))
    else:
      diagonals = [
          array_ops.broadcast_to(d, array_ops.concat(
              [broadcast_shape, [k]], axis=-1)) for d in diagonals]

    y = linalg.tridiagonal_solve(
        diagonals, rhs,
        diagonals_format=self.diagonals_format,
        transpose_rhs=adjoint_arg,
        conjugate_rhs=adjoint_arg)
    return y

  def _diag_part(self):
    if self.diagonals_format == _MATRIX:
      return array_ops.matrix_diag_part(self.diagonals)
    elif self.diagonals_format == _SEQUENCE:
      diagonal = self.diagonals[1]
      return array_ops.broadcast_to(
          diagonal, self.shape_tensor()[:-1])
    else:
      return self.diagonals[..., 1, :]

  def _to_dense(self):
    if self.diagonals_format == _MATRIX:
      return self.diagonals

    if self.diagonals_format == _COMPACT:
      return gen_array_ops.matrix_diag_v3(
          self.diagonals,
          k=(-1, 1),
          num_rows=-1,
          num_cols=-1,
          align='LEFT_RIGHT',
          padding_value=0.)

    diagonals = [
        tensor_conversion.convert_to_tensor_v2_with_dispatch(d)
        for d in self.diagonals
    ]
    diagonals = array_ops_stack.stack(diagonals, axis=-2)

    return gen_array_ops.matrix_diag_v3(
        diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

  @property
  def diagonals(self):
    return self._diagonals

  @property
  def diagonals_format(self):
    return self._diagonals_format

  @property
  def _composite_tensor_fields(self):
    return ('diagonals', 'diagonals_format')

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    diagonal_event_ndims = 2
    if self.diagonals_format == _SEQUENCE:
      # For the diagonal and the super/sub diagonals.
      diagonal_event_ndims = [1, 1, 1]
    return {
        'diagonals': diagonal_event_ndims,
    }