Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/inverse_registrations.py: 52%
65 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.inverse."""
17from tensorflow.python.ops import math_ops
18from tensorflow.python.ops.linalg import linear_operator
19from tensorflow.python.ops.linalg import linear_operator_addition
20from tensorflow.python.ops.linalg import linear_operator_algebra
21from tensorflow.python.ops.linalg import linear_operator_block_diag
22from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular
23from tensorflow.python.ops.linalg import linear_operator_circulant
24from tensorflow.python.ops.linalg import linear_operator_diag
25from tensorflow.python.ops.linalg import linear_operator_full_matrix
26from tensorflow.python.ops.linalg import linear_operator_householder
27from tensorflow.python.ops.linalg import linear_operator_identity
28from tensorflow.python.ops.linalg import linear_operator_inversion
29from tensorflow.python.ops.linalg import linear_operator_kronecker
# By default, return a LinearOperatorInversion, which swaps the .matmul
# and .solve methods.
@linear_operator_algebra.RegisterInverse(linear_operator.LinearOperator)
def _inverse_linear_operator(linop):
  """Default inverse: wrap `linop` in a `LinearOperatorInversion`.

  The non-singular, self-adjoint, positive-definite and square hints of
  `linop` are forwarded unchanged to the wrapper.

  Args:
    linop: The `LinearOperator` to invert.

  Returns:
    A `LinearOperatorInversion` wrapping `linop`.
  """
  hints = dict(
      is_non_singular=linop.is_non_singular,
      is_self_adjoint=linop.is_self_adjoint,
      is_positive_definite=linop.is_positive_definite,
      is_square=linop.is_square)
  return linear_operator_inversion.LinearOperatorInversion(linop, **hints)
@linear_operator_algebra.RegisterInverse(
    linear_operator_inversion.LinearOperatorInversion)
def _inverse_inverse_linear_operator(linop_inversion):
  """Undo an inversion: `(A^{-1})^{-1} = A`.

  Args:
    linop_inversion: Instance of `LinearOperatorInversion`.

  Returns:
    The operator wrapped by `linop_inversion`; no new operator is built.
  """
  wrapped_operator = linop_inversion.operator
  return wrapped_operator
@linear_operator_algebra.RegisterInverse(
    linear_operator_diag.LinearOperatorDiag)
def _inverse_diag(diag_operator):
  """Inverse of `LinearOperatorDiag`: take the reciprocal of each diagonal entry.

  Args:
    diag_operator: Instance of `LinearOperatorDiag`.

  Returns:
    A `LinearOperatorDiag` holding `1 / diag`, with the input's non-singular,
    self-adjoint and positive-definite hints, and marked square.
  """
  reciprocal_diag = 1. / diag_operator.diag
  return linear_operator_diag.LinearOperatorDiag(
      reciprocal_diag,
      is_non_singular=diag_operator.is_non_singular,
      is_self_adjoint=diag_operator.is_self_adjoint,
      is_positive_definite=diag_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterInverse(
    linear_operator_identity.LinearOperatorIdentity)
def _inverse_identity(identity_operator):
  """Inverse of `LinearOperatorIdentity`: the identity is its own inverse.

  Args:
    identity_operator: Instance of `LinearOperatorIdentity`.

  Returns:
    `identity_operator` itself, unchanged.
  """
  self_inverse = identity_operator
  return self_inverse
@linear_operator_algebra.RegisterInverse(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _inverse_scaled_identity(identity_operator):
  """Inverse of `LinearOperatorScaledIdentity`: `(c I)^{-1} = (1 / c) I`.

  Args:
    identity_operator: Instance of `LinearOperatorScaledIdentity`.

  Returns:
    A `LinearOperatorScaledIdentity` with the same number of rows and the
    reciprocal multiplier. The non-singular and positive-definite hints are
    carried over from the input.
  """
  # `c I` with a real multiplier is always self-adjoint, and so is its inverse.
  # With a complex multiplier it is self-adjoint only when `c` happens to be
  # real (e.g. `i I` is not), so hard-coding `True` would be a false hint --
  # and `LinearOperator.adjoint()` returns `self` when the hint is True.
  # Since A^{-1} is self-adjoint iff A is, propagate the input's hint.
  if identity_operator.dtype.is_complex:
    is_self_adjoint = identity_operator.is_self_adjoint
  else:
    is_self_adjoint = True
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      multiplier=1. / identity_operator.multiplier,
      is_non_singular=identity_operator.is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=identity_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterInverse(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _inverse_block_diag(block_diag_operator):
  """Inverse of `LinearOperatorBlockDiag`.

  A block-diagonal matrix is inverted block by block: the result is the
  block-diagonal operator of the inverses of the diagonal blocks.

  Args:
    block_diag_operator: Instance of `LinearOperatorBlockDiag`.

  Returns:
    A `LinearOperatorBlockDiag` of the inverted blocks, with the input's
    non-singular, self-adjoint and positive-definite hints, marked square.
  """
  inverted_blocks = []
  for diagonal_block in block_diag_operator.operators:
    inverted_blocks.append(diagonal_block.inverse())
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=inverted_blocks,
      is_non_singular=block_diag_operator.is_non_singular,
      is_self_adjoint=block_diag_operator.is_self_adjoint,
      is_positive_definite=block_diag_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterInverse(
    linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
def _inverse_block_lower_triangular(block_lower_triangular_operator):
  """Inverse of LinearOperatorBlockLowerTriangular.

  We recursively apply the identity:

  ```none
  |A 0|'  =  |    A'     0|
  |B C|      |-C'BA'    C'|
  ```

  where `A` is n-by-n, `B` is m-by-n, `C` is m-by-m, and `'` denotes inverse.

  This identity can be verified through multiplication:

  ```none
  |A 0||    A'     0|
  |B C||-C'BA'    C'|

    = |     AA'         0|
      |BA'-CC'BA'     CC'|

    = |I 0|
      |0 I|
  ```

  Args:
    block_lower_triangular_operator: Instance of
      `LinearOperatorBlockLowerTriangular`.

  Returns:
    block_lower_triangular_operator_inverse: Instance of
      `LinearOperatorBlockLowerTriangular`, the inverse of
      `block_lower_triangular_operator`.

  Raises:
    ValueError: If the partial products forming one block of `BA'` cannot be
      combined into a single operator by
      `linear_operator_addition.add_operators`.
  """
  if len(block_lower_triangular_operator.operators) == 1:
    # Base case: a single diagonal block; wrap its inverse in a 1x1 operator.
    return (linear_operator_block_lower_triangular.
            LinearOperatorBlockLowerTriangular(
                [[block_lower_triangular_operator.operators[0][0].inverse()]],
                is_non_singular=block_lower_triangular_operator.is_non_singular,
                is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
                is_positive_definite=(block_lower_triangular_operator.
                                      is_positive_definite),
                is_square=True))

  blockwise_dim = len(block_lower_triangular_operator.operators)

  # Calculate the inverse of the `LinearOperatorBlockLowerTriangular`
  # representing all but the last row of `block_lower_triangular_operator` with
  # a recursive call (the matrix `A'` in the docstring definition).
  upper_left_inverse = (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular(
          block_lower_triangular_operator.operators[:-1]).inverse())

  bottom_row = block_lower_triangular_operator.operators[-1]
  bottom_right_inverse = bottom_row[-1].inverse()

  # Products whose type is not in `SUPPORTED_OPERATORS` are densified into a
  # `LinearOperatorFullMatrix` so that `add_operators` can combine them. The
  # tuple is built once; `isinstance` accepts a tuple of types directly.
  supported_operator_types = tuple(linear_operator_addition.SUPPORTED_OPERATORS)

  # Find the bottom row of the inverse (equal to `[-C'BA', C']` in the docstring
  # definition, where `C` is the bottom-right operator of
  # `block_lower_triangular_operator` and `B` is the set of operators in the
  # bottom row excluding `C`). To find `-C'BA'`, we first iterate over the
  # column partitions of `A'`.
  inverse_bottom_row = []
  for i in range(blockwise_dim - 1):
    # Find the `i`-th block of `BA'`. `A'` is block lower triangular, so its
    # block (j, i) is zero for j < i and only j >= i contributes to the sum.
    blocks = []
    for j in range(i, blockwise_dim - 1):
      result = bottom_row[j].matmul(upper_left_inverse.operators[j][i])
      if not isinstance(result, supported_operator_types):
        result = linear_operator_full_matrix.LinearOperatorFullMatrix(
            result.to_dense())
      blocks.append(result)

    summed_blocks = linear_operator_addition.add_operators(blocks)
    # Checked explicitly rather than with `assert`: asserts are stripped under
    # `python -O`, which would silently drop all but the first summand.
    if len(summed_blocks) != 1:
      raise ValueError(
          "Expected the partial products of block {} of the inverse's bottom "
          "row to sum to a single operator, but got {} operators.".format(
              i, len(summed_blocks)))
    block = summed_blocks[0]

    # Find the `i`-th block of `-C'BA'`.
    block = bottom_right_inverse.matmul(block)
    block = linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=bottom_right_inverse.domain_dimension_tensor(),
        multiplier=math_ops.cast(-1, dtype=block.dtype)).matmul(block)
    inverse_bottom_row.append(block)

  # `C'` is the last block of the inverted linear operator.
  inverse_bottom_row.append(bottom_right_inverse)

  return (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular(
          upper_left_inverse.operators + [inverse_bottom_row],
          is_non_singular=block_lower_triangular_operator.is_non_singular,
          is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
          is_positive_definite=(block_lower_triangular_operator.
                                is_positive_definite),
          is_square=True))
@linear_operator_algebra.RegisterInverse(
    linear_operator_kronecker.LinearOperatorKronecker)
def _inverse_kronecker(kronecker_operator):
  """Inverse of `LinearOperatorKronecker`.

  The inverse of a Kronecker product is the Kronecker product of the
  inverses of its factors: `(A_1 x ... x A_k)^{-1} = A_1^{-1} x ... x A_k^{-1}`.

  Args:
    kronecker_operator: Instance of `LinearOperatorKronecker`.

  Returns:
    A `LinearOperatorKronecker` of the inverted factors, with the input's
    non-singular, self-adjoint and positive-definite hints, marked square.
  """
  inverted_factors = []
  for factor in kronecker_operator.operators:
    inverted_factors.append(factor.inverse())
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=inverted_factors,
      is_non_singular=kronecker_operator.is_non_singular,
      is_self_adjoint=kronecker_operator.is_self_adjoint,
      is_positive_definite=kronecker_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterInverse(
    linear_operator_circulant._BaseLinearOperatorCirculant)  # pylint: disable=protected-access
def _inverse_circulant(circulant_operator):
  """Inverse of a circulant operator.

  The operator is built from its `spectrum`, and taking the reciprocal of the
  spectrum is sufficient to obtain the inverse. The result is constructed with
  the same concrete class as the input.

  Args:
    circulant_operator: Instance of a `_BaseLinearOperatorCirculant` subclass.

  Returns:
    An operator of the same class with spectrum `1 / spectrum`, the input's
    non-singular, self-adjoint and positive-definite hints, marked square, and
    `input_output_dtype` equal to the input's dtype.
  """
  operator_class = circulant_operator.__class__
  reciprocal_spectrum = 1. / circulant_operator.spectrum
  return operator_class(
      spectrum=reciprocal_spectrum,
      is_non_singular=circulant_operator.is_non_singular,
      is_self_adjoint=circulant_operator.is_self_adjoint,
      is_positive_definite=circulant_operator.is_positive_definite,
      is_square=True,
      input_output_dtype=circulant_operator.dtype)
@linear_operator_algebra.RegisterInverse(
    linear_operator_householder.LinearOperatorHouseholder)
def _inverse_householder(householder_operator):
  """Inverse of `LinearOperatorHouseholder`.

  A Householder reflection is an involution (reflecting twice gives the
  identity), so the operator is its own inverse.

  Args:
    householder_operator: Instance of `LinearOperatorHouseholder`.

  Returns:
    `householder_operator` itself, unchanged.
  """
  reflection = householder_operator
  return reflection