Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/solve_registrations.py: 57%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registrations for LinearOperator.solve.""" 

16 

17from tensorflow.python.ops.linalg import linear_operator 

18from tensorflow.python.ops.linalg import linear_operator_algebra 

19from tensorflow.python.ops.linalg import linear_operator_block_diag 

20from tensorflow.python.ops.linalg import linear_operator_circulant 

21from tensorflow.python.ops.linalg import linear_operator_composition 

22from tensorflow.python.ops.linalg import linear_operator_diag 

23from tensorflow.python.ops.linalg import linear_operator_identity 

24from tensorflow.python.ops.linalg import linear_operator_inversion 

25from tensorflow.python.ops.linalg import linear_operator_lower_triangular 

26from tensorflow.python.ops.linalg import registrations_util 

27 

28 

29# By default, use a LinearOperatorComposition to delay the computation. 

@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _solve_linear_operator(linop_a, linop_b):
  """Generic solve of two `LinearOperator`s.

  Represents `A^{-1} B` lazily as the composition
  `[LinearOperatorInversion(A), B]`, so no inverse is materialized here.

  Args:
    linop_a: The `LinearOperator` `A` that is inverted.
    linop_b: The `LinearOperator` `B` on the right-hand side.

  Returns:
    A `LinearOperatorComposition` of `LinearOperatorInversion(linop_a)` and
    `linop_b`, tagged with whatever structural hints can be inferred from
    the inputs.
  """
  square_hint = registrations_util.is_square(linop_a, linop_b)

  if square_hint:
    # Known square: only non-singularity can be derived from the operands.
    hints = dict(
        is_non_singular=registrations_util.combined_non_singular_hint(
            linop_a, linop_b),
        is_self_adjoint=None,
        is_positive_definite=None)
  elif square_hint is False:  # pylint:disable=g-bool-id-comparison
    # Known non-square: such an operator cannot be non-singular,
    # self-adjoint or positive-definite.
    hints = dict(
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False)
  else:
    # Squareness is unknown (None): make no structural claims.
    hints = dict(
        is_non_singular=None,
        is_self_adjoint=None,
        is_positive_definite=None)

  inverse_a = linear_operator_inversion.LinearOperatorInversion(linop_a)
  return linear_operator_composition.LinearOperatorComposition(
      operators=[inverse_a, linop_b],
      is_square=square_hint,
      **hints)

57 

58 

@linear_operator_algebra.RegisterSolve(
    linear_operator_inversion.LinearOperatorInversion,
    linear_operator.LinearOperator)
def _solve_inverse_linear_operator(linop_a, linop_b):
  """Solve inverse of generic `LinearOperator`s.

  `linop_a` represents `A^{-1}`, so solving against it gives
  `(A^{-1})^{-1} B = A B`. This is computed by multiplying the wrapped
  operator `A` with `linop_b`.

  Args:
    linop_a: A `LinearOperatorInversion` wrapping some operator `A`.
    linop_b: The right-hand side `LinearOperator`.

  Returns:
    The result of `linop_a.operator.matmul(linop_b)`.
  """
  wrapped_operator = linop_a.operator
  return wrapped_operator.matmul(linop_b)

65 

66 

67# Identity 

@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _solve_linear_operator_identity_left(identity, linop):
  """Solve `I^{-1} B` for an identity `I`: the result is just `B`.

  Args:
    identity: The `LinearOperatorIdentity` being inverted (unused).
    linop: The right-hand side `LinearOperator`, returned unchanged.

  Returns:
    `linop` itself.
  """
  del identity  # The identity is its own inverse, so it contributes nothing.
  return linop

74 

75 

@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _solve_linear_operator_identity_right(linop, identity):
  """Solve `A^{-1} I` for an identity right-hand side: the result is `A^{-1}`.

  Args:
    linop: The `LinearOperator` `A` being inverted.
    identity: The `LinearOperatorIdentity` right-hand side (unused).

  Returns:
    `linop.inverse()`.
  """
  del identity  # Multiplying by the identity is a no-op.
  inverse = linop.inverse()
  return inverse

82 

83 

@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_scaled_identity(linop_a, linop_b):
  """Solve of two ScaledIdentity `LinearOperators`.

  `(a I)^{-1} (b I) = (b / a) I`, so the result is another scaled identity
  whose multiplier is the ratio of the two multipliers.

  Args:
    linop_a: The `LinearOperatorScaledIdentity` being inverted.
    linop_b: The `LinearOperatorScaledIdentity` right-hand side.

  Returns:
    A square `LinearOperatorScaledIdentity` sized by `linop_a`'s domain.
  """
  rows = linop_a.domain_dimension_tensor()
  ratio = linop_b.multiplier / linop_a.multiplier
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=rows,
      multiplier=ratio,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)

100 

101 

102# Diag. 

103 

104 

@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag(linop_a, linop_b):
  """Solve `D_a^{-1} D_b` for two diagonal operators.

  The inverse of a diagonal is elementwise reciprocal, so the result is
  diagonal with entries `b_i / a_i`.

  Args:
    linop_a: The `LinearOperatorDiag` being inverted.
    linop_b: The `LinearOperatorDiag` right-hand side.

  Returns:
    A square `LinearOperatorDiag`.
  """
  quotient = linop_b.diag / linop_a.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_diag.LinearOperatorDiag(
      diag=quotient,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)

119 

120 

@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Solve `D^{-1} (c I)` for a diagonal `D` and a scaled identity `c I`.

  The result is diagonal with entries `c / d_i`.

  Args:
    linop_diag: The `LinearOperatorDiag` being inverted. Its `diag` carries
      a trailing diagonal axis of length `N`.
    linop_scaled_identity: The `LinearOperatorScaledIdentity` right-hand
      side. Per its documentation, `multiplier` has only batch dimensions
      (or is a scalar).

  Returns:
    A square `LinearOperatorDiag`.
  """
  # `multiplier` has no diagonal axis while `diag` does. Append a unit axis
  # so a batched multiplier broadcasts over the diagonal entries instead of
  # being aligned against the diagonal axis itself (which gives wrong
  # results or a shape error whenever the batch size is not 1).
  multiplier = linop_scaled_identity.multiplier[..., None]
  return linear_operator_diag.LinearOperatorDiag(
      diag=multiplier / linop_diag.diag,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)

136 

137 

@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Solve `(c I)^{-1} D` for a scaled identity `c I` and a diagonal `D`.

  The result is diagonal with entries `d_i / c`.

  Args:
    linop_scaled_identity: The `LinearOperatorScaledIdentity` being
      inverted. Per its documentation, `multiplier` has only batch
      dimensions (or is a scalar).
    linop_diag: The `LinearOperatorDiag` right-hand side. Its `diag` carries
      a trailing diagonal axis of length `N`.

  Returns:
    A square `LinearOperatorDiag`.
  """
  # Append a unit axis to the batch-only multiplier so it divides every
  # diagonal entry of the matching batch member, rather than being
  # broadcast against the diagonal axis of `diag`.
  multiplier = linop_scaled_identity.multiplier[..., None]
  return linear_operator_diag.LinearOperatorDiag(
      diag=linop_diag.diag / multiplier,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)

153 

154 

@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Solve `D^{-1} L` for a diagonal `D` and a lower-triangular `L`.

  Left-multiplying by `D^{-1}` scales row `i` of `L` by `1 / d_i`, which
  keeps the result lower-triangular.

  Args:
    linop_diag: The `LinearOperatorDiag` being inverted.
    linop_triangular: The `LinearOperatorLowerTriangular` right-hand side.

  Returns:
    A square `LinearOperatorLowerTriangular`.
  """
  # Give the diagonal a trailing unit axis so each entry divides a full row.
  row_scale = linop_diag.diag[..., None]
  scaled_rows = linop_triangular.to_dense() / row_scale
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # A lower-triangular matrix is self-adjoint only when it is diagonal, in
  # which case it commutes with `D`, so the commuting hint is safe here.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_rows,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)

169 

170 

171# Circulant. 

172 

173 

# pylint: disable=protected-access
@linear_operator_algebra.RegisterSolve(
    linear_operator_circulant._BaseLinearOperatorCirculant,
    linear_operator_circulant._BaseLinearOperatorCirculant)
def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
  """Solve `A^{-1} B` for two circulant operators of the same kind.

  Circulant operators of the same class share an eigenbasis, so the solve
  reduces to an elementwise division of spectra. When `linop_a` is not an
  instance of `linop_b`'s class, fall back to the generic lazy solve.

  Args:
    linop_a: The circulant `LinearOperator` being inverted.
    linop_b: The circulant `LinearOperator` right-hand side.

  Returns:
    An operator of `linop_a`'s class with spectrum `b / a` if the classes
    are compatible; otherwise the result of `_solve_linear_operator`.
  """
  if isinstance(linop_a, linop_b.__class__):
    circulant_cls = linop_a.__class__
    return circulant_cls(
        spectrum=linop_b.spectrum / linop_a.spectrum,
        is_non_singular=registrations_util.combined_non_singular_hint(
            linop_a, linop_b),
        is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
            linop_a, linop_b),
        is_positive_definite=(
            registrations_util.combined_commuting_positive_definite_hint(
                linop_a, linop_b)),
        is_square=True)
  return _solve_linear_operator(linop_a, linop_b)
# pylint: enable=protected-access

193 

194 

195# Block Diag 

196 

197 

@linear_operator_algebra.RegisterSolve(
    linear_operator_block_diag.LinearOperatorBlockDiag,
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _solve_linear_operator_block_diag_block_diag(linop_a, linop_b):
  """Solve `A^{-1} B` for two block-diagonal operators, block by block.

  Block `i` of the result is `linop_a.operators[i].solve(
  linop_b.operators[i])`. Blocks are paired positionally with `zip`.

  Args:
    linop_a: The `LinearOperatorBlockDiag` being inverted.
    linop_b: The `LinearOperatorBlockDiag` right-hand side.

  Returns:
    A square `LinearOperatorBlockDiag` of the per-block solves.
  """
  solved_blocks = []
  for block_a, block_b in zip(linop_a.operators, linop_b.operators):
    solved_blocks.append(block_a.solve(block_b))
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=solved_blocks,
      is_non_singular=non_singular,
      # In general, a solve of self-adjoint positive-definite block diagonal
      # matrices is not self-adjoint.
      is_self_adjoint=None,
      # In general, a solve of positive-definite block diagonal matrices is
      # not positive-definite.
      is_positive_definite=None,
      is_square=True)