Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/matmul_registrations.py: 54%

68 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registrations for LinearOperator.matmul.""" 

16 

17from tensorflow.python.ops.linalg import linear_operator 

18from tensorflow.python.ops.linalg import linear_operator_algebra 

19from tensorflow.python.ops.linalg import linear_operator_block_diag 

20from tensorflow.python.ops.linalg import linear_operator_circulant 

21from tensorflow.python.ops.linalg import linear_operator_composition 

22from tensorflow.python.ops.linalg import linear_operator_diag 

23from tensorflow.python.ops.linalg import linear_operator_identity 

24from tensorflow.python.ops.linalg import linear_operator_lower_triangular 

25from tensorflow.python.ops.linalg import linear_operator_zeros 

26from tensorflow.python.ops.linalg import registrations_util 

27 

28 

29# By default, use a LinearOperatorComposition to delay the computation. 

@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _matmul_linear_operator(linop_a, linop_b):
  """Generic matmul of two `LinearOperator`s.

  The product is not materialized; it is represented lazily by a
  `LinearOperatorComposition` of `[linop_a, linop_b]`.

  Hints attached to the result depend on what is known about squareness
  (as reported by `registrations_util.is_square`):
    * known square: `is_non_singular` is the combined non-singular hint of
      the two factors; self-adjointness and positive-definiteness are left
      unknown (`None`).
    * known non-square: all three hints are `False`.
    * unknown: all three hints are left as `None`.

  Args:
    linop_a: Left `LinearOperator` factor.
    linop_b: Right `LinearOperator` factor.

  Returns:
    A `LinearOperatorComposition` representing `linop_a @ linop_b`.
  """
  square = registrations_util.is_square(linop_a, linop_b)

  if square:
    hints = dict(
        is_non_singular=registrations_util.combined_non_singular_hint(
            linop_a, linop_b),
        is_self_adjoint=None,
        is_positive_definite=None)
  elif square is False:  # pylint:disable=g-bool-id-comparison
    # A non-square operator cannot be non-singular, self-adjoint or PD.
    hints = dict(
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False)
  else:
    # Squareness is unknown, so make no claims.
    hints = dict(
        is_non_singular=None,
        is_self_adjoint=None,
        is_positive_definite=None)

  return linear_operator_composition.LinearOperatorComposition(
      operators=[linop_a, linop_b],
      is_square=square,
      **hints)

54 

55# Identity 

56 

57 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _matmul_linear_operator_identity_left(identity, linop):
  """Matmul `identity @ linop`: the identity factor is simply dropped.

  Args:
    identity: The `LinearOperatorIdentity` on the left (unused).
    linop: Any `LinearOperator` on the right.

  Returns:
    `linop`, unchanged.
  """
  del identity  # Unused: I @ A == A.
  return linop

64 

65 

@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _matmul_linear_operator_identity_right(linop, identity):
  """Matmul `linop @ identity`: the identity factor is simply dropped.

  Args:
    linop: Any `LinearOperator` on the left.
    identity: The `LinearOperatorIdentity` on the right (unused).

  Returns:
    `linop`, unchanged.
  """
  del identity  # Unused: A @ I == A.
  return linop

72 

73 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_scaled_identity(linop_a, linop_b):
  """Matmul of two ScaledIdentity `LinearOperators`.

  `(c1 * I) @ (c2 * I) == (c1 * c2) * I`, so the result is again a scaled
  identity. It uses `linop_a`'s number of rows and a multiplier that is the
  product of the two multipliers.

  Args:
    linop_a: Left `LinearOperatorScaledIdentity`.
    linop_b: Right `LinearOperatorScaledIdentity`.

  Returns:
    A square `LinearOperatorScaledIdentity` equal to the product.
  """
  rows = linop_a.domain_dimension_tensor()
  combined_multiplier = linop_a.multiplier * linop_b.multiplier

  # Scaled identities always commute, so the "commuting" hint combiners apply.
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))

  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=rows,
      multiplier=combined_multiplier,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)

90 

91 

92# Zeros 

93 

94 

@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_zeros.LinearOperatorZeros)
def _matmul_linear_operator_zeros_right(linop, zeros):
  """Matmul `linop @ zeros`; the result is the `zeros` operator itself.

  Only square operands are supported.

  Args:
    linop: Any `LinearOperator` on the left.
    zeros: The `LinearOperatorZeros` on the right.

  Returns:
    `zeros`, unchanged.

  Raises:
    ValueError: If either `zeros.is_square` or `linop.is_square` is not
      truthy. This includes the unknown (`None`) case.
  """
  operands_square = zeros.is_square and linop.is_square
  if not operands_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros

103 

104 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_zeros.LinearOperatorZeros,
    linear_operator.LinearOperator)
def _matmul_linear_operator_zeros_left(zeros, linop):
  """Matmul `zeros @ linop`; the result is the `zeros` operator itself.

  Only square operands are supported.

  Args:
    zeros: The `LinearOperatorZeros` on the left.
    linop: Any `LinearOperator` on the right.

  Returns:
    `zeros`, unchanged.

  Raises:
    ValueError: If either `zeros.is_square` or `linop.is_square` is not
      truthy. This includes the unknown (`None`) case.
  """
  operands_square = zeros.is_square and linop.is_square
  if not operands_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros

113 

114 

115# Diag. 

116 

117 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag(linop_a, linop_b):
  """Matmul of two `LinearOperatorDiag`s.

  The product of two diagonal operators is diagonal. Its diagonal is the
  elementwise product of the two input diagonals.

  Args:
    linop_a: Left `LinearOperatorDiag`.
    linop_b: Right `LinearOperatorDiag`.

  Returns:
    A square `LinearOperatorDiag` equal to the product.
  """
  entrywise = linop_a.diag * linop_b.diag

  # Diagonal operators commute, so the "commuting" hint combiners apply.
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))

  return linear_operator_diag.LinearOperatorDiag(
      diag=entrywise,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)

132 

133 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Matmul `diag @ (c * I)`, returning a `LinearOperatorDiag`.

  Multiplying a diagonal operator by a scaled identity scales every diagonal
  entry by the multiplier `c`.

  Shapes involved (per the public `LinearOperatorDiag` /
  `LinearOperatorScaledIdentity` API docs):
    * `diag` has shape `[B1, ..., Bb, N]`.
    * `multiplier` has only batch shape `[B1, ..., Bb]`, or `[]`.
  The multiplier therefore gets a trailing unit axis before the product, so
  batch axes line up with batch axes. Multiplying the raw multiplier would
  align its last batch axis with the `N` axis of `diag`. That either fails
  to broadcast, or (when the sizes coincide) silently mixes batch members.

  Args:
    linop_diag: `LinearOperatorDiag` on the left.
    linop_scaled_identity: `LinearOperatorScaledIdentity` on the right.

  Returns:
    A square `LinearOperatorDiag` equal to the product.
  """
  return linear_operator_diag.LinearOperatorDiag(
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      # A diagonal operator and a scaled identity commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)

149 

150 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Matmul `(c * I) @ diag`, returning a `LinearOperatorDiag`.

  Multiplying a diagonal operator by a scaled identity scales every diagonal
  entry by the multiplier `c`.

  Shapes involved (per the public `LinearOperatorDiag` /
  `LinearOperatorScaledIdentity` API docs):
    * `diag` has shape `[B1, ..., Bb, N]`.
    * `multiplier` has only batch shape `[B1, ..., Bb]`, or `[]`.
  The multiplier therefore gets a trailing unit axis before the product, so
  batch axes line up with batch axes. Multiplying the raw multiplier would
  align its last batch axis with the `N` axis of `diag`. That either fails
  to broadcast, or (when the sizes coincide) silently mixes batch members.

  Args:
    linop_scaled_identity: `LinearOperatorScaledIdentity` on the left.
    linop_diag: `LinearOperatorDiag` on the right.

  Returns:
    A square `LinearOperatorDiag` equal to the product.
  """
  return linear_operator_diag.LinearOperatorDiag(
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      # A diagonal operator and a scaled identity commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)

166 

167 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Matmul `diag @ tril`, returning a `LinearOperatorLowerTriangular`.

  Left-multiplying by a diagonal scales row `i` of the triangular matrix by
  `diag[i]`. This is done by giving the diagonal a trailing unit axis
  (`diag[..., None]`) and multiplying it elementwise against the dense
  triangular matrix. Row scaling keeps the result lower triangular.

  Args:
    linop_diag: `LinearOperatorDiag` on the left.
    linop_triangular: `LinearOperatorLowerTriangular` on the right.

  Returns:
    A square `LinearOperatorLowerTriangular` equal to the product. Its
    positive-definiteness is left unknown.
  """
  as_column = linop_diag.diag[..., None]
  scaled_rows = as_column * linop_triangular.to_dense()

  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # The commuting combiner is safe here: a triangular matrix is only
  # self-adjoint when it is diagonal, and diagonal matrices commute.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)

  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_rows,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)

182 

183 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
  """Matmul `tril @ diag`, returning a `LinearOperatorLowerTriangular`.

  Right-multiplying by a diagonal scales column `j` of the triangular matrix
  by `diag[j]`. Column scaling keeps the result lower triangular.

  How the diagonal is broadcast:
    * The diagonal is given a unit row axis (`diag[..., None, :]`), so its
      trailing axes `[1, N]` line up with the `[N, N]` matrix axes of
      `to_dense()`.
    * Its batch axes then line up with the matrix's batch axes.
    * Without the extra axis, a batched diagonal (shape `[B1, ..., Bb, N]`
      per the `LinearOperatorDiag` docs) would have its last batch axis
      aligned with the row axis of the dense matrix.
  For an unbatched diagonal the result is the same as plain broadcasting.

  Args:
    linop_triangular: `LinearOperatorLowerTriangular` on the left.
    linop_diag: `LinearOperatorDiag` on the right.

  Returns:
    A square `LinearOperatorLowerTriangular` equal to the product. Its
    positive-definiteness is left unknown.
  """
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=linop_triangular.to_dense() * linop_diag.diag[..., None, :],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
      # This is safe to do since the Triangular matrix is only self-adjoint
      # when it is a diagonal matrix, and hence commutes.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
      is_positive_definite=None,
      is_square=True)

197 is_square=True) 

198 

199# Circulant. 

200 

201 

202# pylint: disable=protected-access 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_circulant._BaseLinearOperatorCirculant,
    linear_operator_circulant._BaseLinearOperatorCirculant)
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
  """Matmul of two circulant operators of the same class and dtype.

  The product is rebuilt as an operator of `linop_a`'s class. Its spectrum is
  the elementwise product of the two spectra.

  The original operator dtype is passed explicitly as `input_output_dtype`.
  Otherwise the circulant constructors fall back to their documented default
  (`complex64`), which would silently turn e.g. the product of two `float32`
  circulants into a `complex64` operator. Such an operator would then reject
  `float32` inputs.

  If the operands are of different circulant classes, or have different
  dtypes, the generic lazy composition (`_matmul_linear_operator`) is used
  instead.

  Args:
    linop_a: Left circulant operator.
    linop_b: Right circulant operator.

  Returns:
    A `LinearOperator` equal to `linop_a @ linop_b`.
  """
  if (not isinstance(linop_a, linop_b.__class__) or
      linop_a.dtype != linop_b.dtype):
    return _matmul_linear_operator(linop_a, linop_b)

  return linop_a.__class__(
      spectrum=linop_a.spectrum * linop_b.spectrum,
      input_output_dtype=linop_a.dtype,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      # Circulant operators of the same class commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)

220# pylint: enable=protected-access 

221 

222# Block Diag 

223 

224 

@linear_operator_algebra.RegisterMatmul(
    linear_operator_block_diag.LinearOperatorBlockDiag,
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _matmul_linear_operator_block_diag_block_diag(linop_a, linop_b):
  """Matmul of two `LinearOperatorBlockDiag`s, computed block by block.

  When both operands have the same number of blocks, the result is a block
  diagonal operator whose i-th block is `a_i.matmul(b_i)`.

  When the block counts differ, the blocks cannot be paired one-to-one.
  Pairing them with `zip` would silently drop the trailing blocks of the
  longer operand. In that case the generic lazy composition
  (`_matmul_linear_operator`) is used instead.

  Args:
    linop_a: Left `LinearOperatorBlockDiag`.
    linop_b: Right `LinearOperatorBlockDiag`.

  Returns:
    A `LinearOperator` equal to `linop_a @ linop_b`.
  """
  if len(linop_a.operators) != len(linop_b.operators):
    return _matmul_linear_operator(linop_a, linop_b)

  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          o1.matmul(o2) for o1, o2 in zip(
              linop_a.operators, linop_b.operators)],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      # In general, a product of self-adjoint positive-definite block diagonal
      # matrices is not self-adjoint.
      is_self_adjoint=None,
      # In general, a product of positive-definite block diagonal matrices is
      # not positive-definite.
      is_positive_definite=None,
      is_square=True)