Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/inverse_registrations.py: 52%

65 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registrations for LinearOperator.inverse.""" 

16 

17from tensorflow.python.ops import math_ops 

18from tensorflow.python.ops.linalg import linear_operator 

19from tensorflow.python.ops.linalg import linear_operator_addition 

20from tensorflow.python.ops.linalg import linear_operator_algebra 

21from tensorflow.python.ops.linalg import linear_operator_block_diag 

22from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular 

23from tensorflow.python.ops.linalg import linear_operator_circulant 

24from tensorflow.python.ops.linalg import linear_operator_diag 

25from tensorflow.python.ops.linalg import linear_operator_full_matrix 

26from tensorflow.python.ops.linalg import linear_operator_householder 

27from tensorflow.python.ops.linalg import linear_operator_identity 

28from tensorflow.python.ops.linalg import linear_operator_inversion 

29from tensorflow.python.ops.linalg import linear_operator_kronecker 

30 

31 

32# By default, return a LinearOperatorInversion, which swaps the .matmul 

33# and .solve methods. 

@linear_operator_algebra.RegisterInverse(linear_operator.LinearOperator)
def _inverse_linear_operator(linop):
  """Generic inverse: wraps `linop` in a `LinearOperatorInversion`.

  Args:
    linop: The operator to invert.

  Returns:
    A `LinearOperatorInversion` built around `linop`, carrying over the
    non-singular, self-adjoint, positive-definite and square hints of `linop`.
  """
  forwarded_hints = dict(
      is_non_singular=linop.is_non_singular,
      is_self_adjoint=linop.is_self_adjoint,
      is_positive_definite=linop.is_positive_definite,
      is_square=linop.is_square)
  return linear_operator_inversion.LinearOperatorInversion(
      linop, **forwarded_hints)

42 

43 

@linear_operator_algebra.RegisterInverse(
    linear_operator_inversion.LinearOperatorInversion)
def _inverse_inverse_linear_operator(linop_inversion):
  """Inverse of an inversion: returns the operator that was wrapped."""
  wrapped_operator = linop_inversion.operator
  return wrapped_operator

48 

49 

@linear_operator_algebra.RegisterInverse(
    linear_operator_diag.LinearOperatorDiag)
def _inverse_diag(diag_operator):
  """Inverse of a diagonal operator: take the reciprocal of the diagonal.

  Args:
    diag_operator: Instance of `LinearOperatorDiag`.

  Returns:
    A `LinearOperatorDiag` whose diagonal is `1. / diag_operator.diag`. The
    non-singular, self-adjoint and positive-definite hints are copied from
    `diag_operator`; the result is always flagged as square.
  """
  reciprocal_diag = 1. / diag_operator.diag
  return linear_operator_diag.LinearOperatorDiag(
      reciprocal_diag,
      is_non_singular=diag_operator.is_non_singular,
      is_self_adjoint=diag_operator.is_self_adjoint,
      is_positive_definite=diag_operator.is_positive_definite,
      is_square=True)

59 

60 

@linear_operator_algebra.RegisterInverse(
    linear_operator_identity.LinearOperatorIdentity)
def _inverse_identity(identity_operator):
  """Inverse of the identity: the very same operator is handed back."""
  return identity_operator

65 

66 

@linear_operator_algebra.RegisterInverse(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _inverse_scaled_identity(identity_operator):
  """Inverse of a scaled identity: same size, reciprocal multiplier.

  Args:
    identity_operator: Instance of `LinearOperatorScaledIdentity`.

  Returns:
    A `LinearOperatorScaledIdentity` with the same `num_rows` and multiplier
    `1. / identity_operator.multiplier`. The non-singular, self-adjoint and
    positive-definite hints are copied from `identity_operator`; the result is
    always flagged as square.
  """
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      multiplier=1. / identity_operator.multiplier,
      is_non_singular=identity_operator.is_non_singular,
      # Forward the hint rather than hard-coding True: for a complex multiplier
      # `c`, the operator `c * I` (and therefore its inverse) is self-adjoint
      # only when `c` is real-valued, so an unconditional True can be wrong.
      # NOTE(review): for real dtypes the source operator's hint is presumably
      # already True -- confirm against the LinearOperatorScaledIdentity
      # constructor.
      is_self_adjoint=identity_operator.is_self_adjoint,
      is_positive_definite=identity_operator.is_positive_definite,
      is_square=True)

77 

78 

@linear_operator_algebra.RegisterInverse(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _inverse_block_diag(block_diag_operator):
  """Inverse of a block-diagonal operator, computed block by block.

  Args:
    block_diag_operator: Instance of `LinearOperatorBlockDiag`.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are the inverses of the blocks of
    `block_diag_operator`, in the same order. The non-singular, self-adjoint
    and positive-definite hints are copied over; the result is always flagged
    as square.
  """
  # Each diagonal block is inverted independently of the others.
  inverted_blocks = [
      diagonal_block.inverse()
      for diagonal_block in block_diag_operator.operators]
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=inverted_blocks,
      is_non_singular=block_diag_operator.is_non_singular,
      is_self_adjoint=block_diag_operator.is_self_adjoint,
      is_positive_definite=block_diag_operator.is_positive_definite,
      is_square=True)

90 

91 

@linear_operator_algebra.RegisterInverse(
    linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
def _inverse_block_lower_triangular(block_lower_triangular_operator):
  """Builds the inverse of a `LinearOperatorBlockLowerTriangular`.

  The inverse is assembled recursively, one block-row at a time, from the
  2-by-2 block identity (with `'` denoting inverse):

  ```none
  |A 0|'  =  |   A'   0 |
  |B C|      |-C'BA'  C'|
  ```

  Here `A` is n-by-n, `B` is m-by-n and `C` is m-by-m. Multiplying both sides
  out confirms the identity:

  ```none
  |A 0||   A'   0 |  =  |    AA'      0 |  =  |I 0|
  |B C||-C'BA'  C'|     |BA'-CC'BA'  CC'|     |0 I|
  ```

  Args:
    block_lower_triangular_operator: Instance of
      `LinearOperatorBlockLowerTriangular`.

  Returns:
    block_lower_triangular_operator_inverse: Instance of
      `LinearOperatorBlockLowerTriangular`, the inverse of
      `block_lower_triangular_operator`.
  """
  block_lt_cls = (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
  block_rows = block_lower_triangular_operator.operators
  num_block_rows = len(block_rows)

  # Base case of the recursion: a single block is simply inverted.
  if num_block_rows == 1:
    sole_block_inverse = block_rows[0][0].inverse()
    return block_lt_cls(
        [[sole_block_inverse]],
        is_non_singular=block_lower_triangular_operator.is_non_singular,
        is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
        is_positive_definite=(
            block_lower_triangular_operator.is_positive_definite),
        is_square=True)

  # `A'` from the docstring: the recursively computed inverse of the operator
  # formed by every block-row except the last one.
  leading_inverse = block_lt_cls(block_rows[:-1]).inverse()

  # `C'` from the docstring: the inverse of the bottom-right block.
  last_row = block_rows[-1]
  corner_inverse = last_row[-1].inverse()

  # Operator types that `add_operators` is declared to support; any other
  # product is converted to a dense full-matrix operator before summing.
  summable_types = tuple(linear_operator_addition.SUPPORTED_OPERATORS)

  # Assemble the last row of the inverse, `[-C'BA', C']`, where `B` is the
  # last row of the input without its final block. The outer loop walks the
  # column partitions of `A'`.
  new_last_row = []
  for col in range(num_block_rows - 1):
    # Terms whose sum is the `col`-th block of `BA'`. Rows above `col` are
    # skipped because `A'` is block lower triangular.
    terms = []
    for row in range(col, num_block_rows - 1):
      term = last_row[row].matmul(leading_inverse.operators[row][col])
      if not isinstance(term, summable_types):
        term = linear_operator_full_matrix.LinearOperatorFullMatrix(
            term.to_dense())
      terms.append(term)

    totals = linear_operator_addition.add_operators(terms)
    assert len(totals) == 1

    # Left-multiply by `C'`, then negate through a `-1`-scaled identity, to
    # obtain the `col`-th block of `-C'BA'`.
    c_inv_ba_inv = corner_inverse.matmul(totals[0])
    negation = linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=corner_inverse.domain_dimension_tensor(),
        multiplier=math_ops.cast(-1, dtype=c_inv_ba_inv.dtype))
    new_last_row.append(negation.matmul(c_inv_ba_inv))

  # The final block of the last row is `C'` itself.
  new_last_row.append(corner_inverse)

  return block_lt_cls(
      leading_inverse.operators + [new_last_row],
      is_non_singular=block_lower_triangular_operator.is_non_singular,
      is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
      is_positive_definite=(
          block_lower_triangular_operator.is_positive_definite),
      is_square=True)

189 

190 

@linear_operator_algebra.RegisterInverse(
    linear_operator_kronecker.LinearOperatorKronecker)
def _inverse_kronecker(kronecker_operator):
  """Inverse of a Kronecker product, computed factor by factor.

  Args:
    kronecker_operator: Instance of `LinearOperatorKronecker`.

  Returns:
    A `LinearOperatorKronecker` whose factors are the inverses of the factors
    of `kronecker_operator`, in the same order. The non-singular, self-adjoint
    and positive-definite hints are copied over; the result is always flagged
    as square.
  """
  # The inverse of a Kronecker product is the Kronecker product of the
  # inverses of its factors.
  inverted_factors = [
      factor.inverse() for factor in kronecker_operator.operators]
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=inverted_factors,
      is_non_singular=kronecker_operator.is_non_singular,
      is_self_adjoint=kronecker_operator.is_self_adjoint,
      is_positive_definite=kronecker_operator.is_positive_definite,
      is_square=True)

203 

204 

@linear_operator_algebra.RegisterInverse(
    linear_operator_circulant._BaseLinearOperatorCirculant)  # pylint: disable=protected-access
def _inverse_circulant(circulant_operator):
  """Inverse of a circulant operator: take the reciprocal of its spectrum.

  Args:
    circulant_operator: Instance of a `_BaseLinearOperatorCirculant` subclass.

  Returns:
    A new operator of the same class as `circulant_operator`, built from the
    spectrum `1. / circulant_operator.spectrum` and the same dtype. The
    non-singular, self-adjoint and positive-definite hints are copied over;
    the result is always flagged as square.
  """
  # Rebuild with the concrete class of the input so the result has the same
  # operator type.
  operator_cls = circulant_operator.__class__
  # Inverting the spectrum is all that is needed to invert the operator.
  reciprocal_spectrum = 1. / circulant_operator.spectrum
  return operator_cls(
      spectrum=reciprocal_spectrum,
      is_non_singular=circulant_operator.is_non_singular,
      is_self_adjoint=circulant_operator.is_self_adjoint,
      is_positive_definite=circulant_operator.is_positive_definite,
      is_square=True,
      input_output_dtype=circulant_operator.dtype)

216 

217 

@linear_operator_algebra.RegisterInverse(
    linear_operator_householder.LinearOperatorHouseholder)
def _inverse_householder(householder_operator):
  """Inverse of a Householder operator: the same operator is handed back."""
  return householder_operator