Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/matmul_registrations.py: 54%
68 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.matmul."""
17from tensorflow.python.ops.linalg import linear_operator
18from tensorflow.python.ops.linalg import linear_operator_algebra
19from tensorflow.python.ops.linalg import linear_operator_block_diag
20from tensorflow.python.ops.linalg import linear_operator_circulant
21from tensorflow.python.ops.linalg import linear_operator_composition
22from tensorflow.python.ops.linalg import linear_operator_diag
23from tensorflow.python.ops.linalg import linear_operator_identity
24from tensorflow.python.ops.linalg import linear_operator_lower_triangular
25from tensorflow.python.ops.linalg import linear_operator_zeros
26from tensorflow.python.ops.linalg import registrations_util
29# By default, use a LinearOperatorComposition to delay the computation.
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _matmul_linear_operator(linop_a, linop_b):
  """Generic matmul of two `LinearOperator`s.

  Wraps the pair in a lazy `LinearOperatorComposition` instead of computing
  the product eagerly.

  Hints attached to the result:
    * `is_square` comes from `registrations_util.is_square`.
    * Known square: non-singularity is combined from the operands' hints;
      self-adjointness and positive-definiteness stay unknown (`None`).
    * Known non-square: all three hints are `False`, since a non-square
      matrix can be none of those things.
    * Unknown squareness: every hint stays `None`.
  """
  square = registrations_util.is_square(linop_a, linop_b)

  if square:
    non_singular = registrations_util.combined_non_singular_hint(
        linop_a, linop_b)
    self_adjoint = None
    positive_definite = None
  elif square is False:  # pylint:disable=g-bool-id-comparison
    # Identity check on purpose: an unknown (`None`) squareness must not
    # be treated as "non-square".
    non_singular = False
    self_adjoint = False
    positive_definite = False
  else:
    non_singular = None
    self_adjoint = None
    positive_definite = None

  return linear_operator_composition.LinearOperatorComposition(
      operators=[linop_a, linop_b],
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=square,
  )
55# Identity
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _matmul_linear_operator_identity_left(identity, linop):
  """Matmul `identity @ linop`, which is just `linop` returned unchanged."""
  del identity  # Multiplying by the identity on the left is a no-op.
  return linop
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _matmul_linear_operator_identity_right(linop, identity):
  """Matmul `linop @ identity`, which is just `linop` returned unchanged."""
  del identity  # Multiplying by the identity on the right is a no-op.
  return linop
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_scaled_identity(linop_a, linop_b):
  """Matmul of two ScaledIdentity `LinearOperators`.

  `(a * I) @ (b * I) == (a * b) * I`: the result is again a scaled identity
  whose multiplier is the product of the two multipliers. The row count is
  taken from `linop_a`'s domain dimension. Scaled identities commute, so the
  commuting-hint helpers are used for self-adjointness and
  positive-definiteness.
  """
  num_rows = linop_a.domain_dimension_tensor()
  product_multiplier = linop_a.multiplier * linop_b.multiplier
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=product_multiplier,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
92# Zeros
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_zeros.LinearOperatorZeros)
def _matmul_linear_operator_zeros_right(linop, zeros):
  """Matmul `linop @ zeros`; returns the `zeros` operator itself.

  Only square operands are supported.

  Raises:
    ValueError: If `zeros` or `linop` is not square.
  """
  # `zeros.is_square` is evaluated first; `linop.is_square` only if needed.
  both_square = zeros.is_square and linop.is_square
  if both_square:
    return zeros
  raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                   "`LinearOperatorZeros` not supported at this time.")
@linear_operator_algebra.RegisterMatmul(
    linear_operator_zeros.LinearOperatorZeros,
    linear_operator.LinearOperator)
def _matmul_linear_operator_zeros_left(zeros, linop):
  """Matmul `zeros @ linop`; returns the `zeros` operator itself.

  Only square operands are supported.

  Raises:
    ValueError: If `zeros` or `linop` is not square.
  """
  # `zeros.is_square` is evaluated first; `linop.is_square` only if needed.
  both_square = zeros.is_square and linop.is_square
  if both_square:
    return zeros
  raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                   "`LinearOperatorZeros` not supported at this time.")
115# Diag.
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag(linop_a, linop_b):
  """Matmul of two `LinearOperatorDiag`s.

  The product of two diagonal operators is diagonal, with entries equal to
  the elementwise product of the two diagonals. Diagonal operators commute,
  so the commuting-hint helpers apply.
  """
  product_diag = linop_a.diag * linop_b.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_diag.LinearOperatorDiag(
      diag=product_diag,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Matmul `diag @ scaled_identity`, returned as a `LinearOperatorDiag`.

  Right-multiplying a diagonal by `c * I` scales every diagonal entry by the
  multiplier `c`. The two operators commute, so the commuting-hint helpers
  apply.
  """
  scaled_diag = linop_diag.diag * linop_scaled_identity.multiplier
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_scaled_identity)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_scaled_identity)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_diag, linop_scaled_identity))
  return linear_operator_diag.LinearOperatorDiag(
      diag=scaled_diag,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Matmul `scaled_identity @ diag`, returned as a `LinearOperatorDiag`.

  A scaled identity commutes with a diagonal operator, so this product is
  built exactly like `diag @ scaled_identity`. The work is delegated to the
  right-hand registration with the same operands and argument roles.
  """
  return _matmul_linear_operator_diag_scaled_identity_right(
      linop_diag, linop_scaled_identity)
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Matmul `diag @ tril`, returned as a `LinearOperatorLowerTriangular`.

  The diagonal is broadcast along a new trailing axis, so row `i` of the
  dense triangular matrix is scaled by `diag[i]`. This keeps the result
  lower-triangular. Positive-definiteness is left unknown.
  """
  row_scales = linop_diag.diag[..., None]
  scaled_tril = row_scales * linop_triangular.to_dense()
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # Using the commuting hint is safe: a triangular matrix is self-adjoint
  # only when it is diagonal, in which case it commutes with `linop_diag`.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_tril,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
@linear_operator_algebra.RegisterMatmul(
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
  """Matmul `tril @ diag`, returned as a `LinearOperatorLowerTriangular`.

  The diagonal broadcasts against the last axis of the dense triangular
  matrix, so column `j` is scaled by `diag[j]`. This keeps the result
  lower-triangular. Positive-definiteness is left unknown.
  """
  scaled_tril = linop_triangular.to_dense() * linop_diag.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # Using the commuting hint is safe: a triangular matrix is self-adjoint
  # only when it is diagonal, in which case it commutes with `linop_diag`.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_tril,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
199# Circulant.
202# pylint: disable=protected-access
@linear_operator_algebra.RegisterMatmul(
    linear_operator_circulant._BaseLinearOperatorCirculant,
    linear_operator_circulant._BaseLinearOperatorCirculant)
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
  """Matmul of two circulant `LinearOperator`s.

  When `linop_a` is an instance of `linop_b`'s class, the product is built
  as a new operator of `linop_a`'s class. Its spectrum is the elementwise
  product of the two spectra, and its dtype is `linop_a`'s dtype. Otherwise
  (e.g. a 1-D circulant times a 2-D circulant) the generic lazy composition
  is returned instead.

  Args:
    linop_a: Left circulant operator.
    linop_b: Right circulant operator.

  Returns:
    A circulant operator of `linop_a`'s class, or the generic composition
    built by `_matmul_linear_operator`.
  """
  if not isinstance(linop_a, linop_b.__class__):
    return _matmul_linear_operator(linop_a, linop_b)

  return linop_a.__class__(
      spectrum=linop_a.spectrum * linop_b.spectrum,
      # Keep the operands' dtype explicitly. Without this, the constructor
      # falls back to its default `input_output_dtype` (complex64 for
      # TensorFlow's circulant operators), so e.g. a product of two
      # real-valued circulants would silently come back as complex64.
      input_output_dtype=linop_a.dtype,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)
220# pylint: enable=protected-access
222# Block Diag
def _block_diag_partitions_align(linop_a, linop_b):
  """Returns whether `linop_a @ linop_b` can be computed block by block.

  Returns False when the two operators have a different number of blocks,
  or when a pair of corresponding blocks has statically known but mismatched
  inner dimensions (domain of the `linop_a` block vs. range of the `linop_b`
  block). Dimensions that are unknown statically are treated as compatible.
  """
  from tensorflow.python.framework import tensor_shape  # pylint: disable=g-import-not-at-top

  blocks_a = linop_a.operators
  blocks_b = linop_b.operators
  if len(blocks_a) != len(blocks_b):
    return False
  for block_a, block_b in zip(blocks_a, blocks_b):
    inner_a = tensor_shape.dimension_value(block_a.domain_dimension)
    inner_b = tensor_shape.dimension_value(block_b.range_dimension)
    if inner_a is not None and inner_b is not None and inner_a != inner_b:
      return False
  return True


@linear_operator_algebra.RegisterMatmul(
    linear_operator_block_diag.LinearOperatorBlockDiag,
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _matmul_linear_operator_block_diag_block_diag(linop_a, linop_b):
  """Matmul of two `LinearOperatorBlockDiag`s.

  If the block partitions line up, the result is block-diagonal with blocks
  `a_i.matmul(b_i)`. Otherwise the generic lazy composition from
  `_matmul_linear_operator` is returned.

  Args:
    linop_a: Left block-diagonal operator.
    linop_b: Right block-diagonal operator.

  Returns:
    A `LinearOperatorBlockDiag`, or the generic composition when the block
    structures are incompatible.
  """
  if not _block_diag_partitions_align(linop_a, linop_b):
    # `zip` would silently drop trailing blocks when the counts differ, or
    # pair blocks whose inner dimensions don't match. Defer to the generic
    # composition, which only needs the overall shapes to be compatible.
    return _matmul_linear_operator(linop_a, linop_b)

  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          o1.matmul(o2) for o1, o2 in zip(
              linop_a.operators, linop_b.operators)],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      # In general, a product of self-adjoint positive-definite block diagonal
      # matrices is not self-adjoint.
      is_self_adjoint=None,
      # In general, a product of positive-definite block diagonal matrices is
      # not positive-definite.
      is_positive_definite=None,
      is_square=True)