Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/solve_registrations.py: 57%
58 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.solve."""
17from tensorflow.python.ops.linalg import linear_operator
18from tensorflow.python.ops.linalg import linear_operator_algebra
19from tensorflow.python.ops.linalg import linear_operator_block_diag
20from tensorflow.python.ops.linalg import linear_operator_circulant
21from tensorflow.python.ops.linalg import linear_operator_composition
22from tensorflow.python.ops.linalg import linear_operator_diag
23from tensorflow.python.ops.linalg import linear_operator_identity
24from tensorflow.python.ops.linalg import linear_operator_inversion
25from tensorflow.python.ops.linalg import linear_operator_lower_triangular
26from tensorflow.python.ops.linalg import registrations_util
# Fallback for any pair of operators: represent inv(A) @ B lazily with a
# LinearOperatorComposition so that no dense computation happens up front.
@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _solve_linear_operator(linop_a, linop_b):
  """Generic solve of two `LinearOperator`s.

  Builds an operator equivalent to `inv(linop_a) @ linop_b` as the
  composition `[LinearOperatorInversion(linop_a), linop_b]`.

  Args:
    linop_a: The `LinearOperator` being inverted (left operand).
    linop_b: The right-hand `LinearOperator`.

  Returns:
    A `LinearOperatorComposition` whose hints are derived from both inputs.
  """
  square = registrations_util.is_square(linop_a, linop_b)

  if square is False:  # pylint:disable=g-bool-id-comparison
    # A result known to be non-square can be neither non-singular,
    # self-adjoint, nor positive-definite.
    non_singular = self_adjoint = positive_definite = False
  elif square:
    non_singular = registrations_util.combined_non_singular_hint(
        linop_a, linop_b)
    self_adjoint = positive_definite = None
  else:
    # Squareness is unknown, so every hint is left unset.
    non_singular = self_adjoint = positive_definite = None

  inverse_a = linear_operator_inversion.LinearOperatorInversion(linop_a)
  return linear_operator_composition.LinearOperatorComposition(
      operators=[inverse_a, linop_b],
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=square,
  )
@linear_operator_algebra.RegisterSolve(
    linear_operator_inversion.LinearOperatorInversion,
    linear_operator.LinearOperator)
def _solve_inverse_linear_operator(linop_a, linop_b):
  """Solve inverse of generic `LinearOperator`s.

  `linop_a` represents `inv(C)` for its wrapped operator `C`. Solving against
  it is `inv(inv(C)) @ linop_b`, which reduces to the product `C @ linop_b`.

  Args:
    linop_a: A `LinearOperatorInversion` wrapping some operator `C`.
    linop_b: The right-hand `LinearOperator`.

  Returns:
    The result of `C.matmul(linop_b)`.
  """
  wrapped = linop_a.operator
  return wrapped.matmul(linop_b)
# Identity
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _solve_linear_operator_identity_left(identity, linop):
  """Solve `I @ X = linop`; the solution is `linop` itself.

  Args:
    identity: The `LinearOperatorIdentity`; only its type matters here.
    linop: The right-hand `LinearOperator`.

  Returns:
    `linop`, unchanged.
  """
  del identity  # inv(I) @ linop == linop, so the identity is not needed.
  return linop
@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _solve_linear_operator_identity_right(linop, identity):
  """Solve `linop @ X = I`; the solution is `inv(linop)`.

  Args:
    linop: The `LinearOperator` being inverted.
    identity: The `LinearOperatorIdentity`; only its type matters here.

  Returns:
    The result of `linop.inverse()`.
  """
  del identity  # Used for dispatch only.
  inverted = linop.inverse()
  return inverted
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_scaled_identity(linop_a, linop_b):
  """Solve of two ScaledIdentity `LinearOperators`.

  `inv(a * I) @ (b * I)` equals `(b / a) * I`, so the result is another
  scaled identity whose multiplier is the ratio of the two multipliers.

  Args:
    linop_a: The `LinearOperatorScaledIdentity` being inverted.
    linop_b: The right-hand `LinearOperatorScaledIdentity`.

  Returns:
    A square `LinearOperatorScaledIdentity` sized by `linop_a`'s domain.
  """
  size = linop_a.domain_dimension_tensor()
  ratio = linop_b.multiplier / linop_a.multiplier
  hints = dict(
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
  )
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=size, multiplier=ratio, is_square=True, **hints)
# Diag.


@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag(linop_a, linop_b):
  """Solve of two `LinearOperatorDiag`s: `inv(diag(a)) @ diag(b)`.

  The result is diagonal, with entries equal to the elementwise quotient
  `b / a`.

  Args:
    linop_a: The `LinearOperatorDiag` being inverted.
    linop_b: The right-hand `LinearOperatorDiag`.

  Returns:
    A square `LinearOperatorDiag` with diagonal `linop_b.diag / linop_a.diag`.
  """
  quotient = linop_b.diag / linop_a.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_diag.LinearOperatorDiag(
      diag=quotient,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Solve `diag(d) @ X = m * I`, giving `diag(m / d)`.

  Per the `LinearOperatorScaledIdentity` API, `multiplier` is a scalar or has
  only batch dimensions `[B1, ..., Bb]`. `linop_diag.diag` has an extra
  trailing row dimension `[B1, ..., Bb, N]`. The multiplier therefore gets a
  trailing axis so that it broadcasts against the batch dimensions of `diag`,
  not against its row dimension.

  Args:
    linop_diag: The `LinearOperatorDiag` being inverted.
    linop_scaled_identity: The right-hand `LinearOperatorScaledIdentity`.

  Returns:
    A square `LinearOperatorDiag` with diagonal `multiplier / diag`.
  """
  return linear_operator_diag.LinearOperatorDiag(
      # `[..., None]` aligns the (batch-only) multiplier with `diag`'s batch
      # dimensions. A scalar multiplier becomes shape [1] and broadcasts as
      # before.
      diag=linop_scaled_identity.multiplier[..., None] / linop_diag.diag,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Solve `(m * I) @ X = diag(d)`, giving `diag(d / m)`.

  Per the `LinearOperatorScaledIdentity` API, `multiplier` is a scalar or has
  only batch dimensions `[B1, ..., Bb]`. `linop_diag.diag` has an extra
  trailing row dimension `[B1, ..., Bb, N]`. The multiplier therefore gets a
  trailing axis so that it broadcasts against the batch dimensions of `diag`,
  not against its row dimension.

  Args:
    linop_scaled_identity: The `LinearOperatorScaledIdentity` being inverted.
    linop_diag: The right-hand `LinearOperatorDiag`.

  Returns:
    A square `LinearOperatorDiag` with diagonal `diag / multiplier`.
  """
  return linear_operator_diag.LinearOperatorDiag(
      # `[..., None]` aligns the (batch-only) multiplier with `diag`'s batch
      # dimensions. A scalar multiplier becomes shape [1] and broadcasts as
      # before.
      diag=linop_diag.diag / linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)
@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Solve `diag(d) @ X = L` for a lower-triangular `L`.

  `inv(diag(d)) @ L` divides row `i` of `L` by `d[i]`. Scaling rows keeps the
  matrix lower-triangular, so the result is rebuilt as a
  `LinearOperatorLowerTriangular` from the rescaled dense matrix.

  Args:
    linop_diag: The `LinearOperatorDiag` being inverted.
    linop_triangular: The right-hand `LinearOperatorLowerTriangular`.

  Returns:
    A square `LinearOperatorLowerTriangular`; positive-definiteness is left
    unset.
  """
  # Trailing axis on `diag` so that each entry scales a whole row.
  row_scaled = linop_triangular.to_dense() / linop_diag.diag[..., None]
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # Using the commuting hint is safe: a lower-triangular matrix is
  # self-adjoint only when it is diagonal, and then it commutes with
  # `linop_diag`.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=row_scaled,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
# Circulant.


# pylint: disable=protected-access
@linear_operator_algebra.RegisterSolve(
    linear_operator_circulant._BaseLinearOperatorCirculant,
    linear_operator_circulant._BaseLinearOperatorCirculant)
def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
  """Solve of two circulant `LinearOperator`s.

  When `linop_a` is an instance of `linop_b`'s class, the result is built as
  that same circulant class with spectrum `linop_b.spectrum /
  linop_a.spectrum`. Otherwise, such as when mixing circulant classes of
  different block depth, the generic lazily composed solve is used.

  Args:
    linop_a: The circulant operator being inverted.
    linop_b: The right-hand circulant operator.

  Returns:
    A circulant operator of `linop_a`'s class and dtype, or the generic
    `_solve_linear_operator` result when the classes do not match.
  """
  if not isinstance(linop_a, linop_b.__class__):
    return _solve_linear_operator(linop_a, linop_b)

  return linop_a.__class__(
      spectrum=linop_b.spectrum / linop_a.spectrum,
      # Without this argument the constructor falls back to its default
      # `input_output_dtype` (complex64), so a solve between real-dtype
      # circulants would yield a complex operator. Keep the operand's dtype.
      input_output_dtype=linop_a.dtype,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)
# pylint: enable=protected-access
# Block Diag.


@linear_operator_algebra.RegisterSolve(
    linear_operator_block_diag.LinearOperatorBlockDiag,
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _solve_linear_operator_block_diag_block_diag(linop_a, linop_b):
  """Blockwise solve of two `LinearOperatorBlockDiag`s.

  When both operators have the same number of blocks, the result is the
  block-diagonal operator whose i-th block is `a_i.solve(b_i)`. When the
  block counts differ, the blocks cannot be paired: `zip` would silently drop
  the surplus blocks and give a wrong result. In that case the generic,
  lazily composed solve is used instead.

  Args:
    linop_a: The block-diagonal operator being inverted.
    linop_b: The right-hand block-diagonal operator.

  Returns:
    A `LinearOperatorBlockDiag` of per-block solves, or the generic
    `_solve_linear_operator` result when the block counts differ.
  """
  if len(linop_a.operators) != len(linop_b.operators):
    return _solve_linear_operator(linop_a, linop_b)

  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          o1.solve(o2) for o1, o2 in zip(
              linop_a.operators, linop_b.operators)],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      # In general, a solve of self-adjoint positive-definite block diagonal
      # matrices is not self-adjoint.
      is_self_adjoint=None,
      # In general, a solve of positive-definite block diagonal matrices is
      # not positive-definite.
      is_positive_definite=None,
      is_square=True)