Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/cholesky_registrations.py: 62%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registrations for LinearOperator.cholesky.""" 

16 

17from tensorflow.python.ops import array_ops 

18from tensorflow.python.ops import linalg_ops 

19from tensorflow.python.ops import math_ops 

20from tensorflow.python.ops.linalg import linear_operator 

21from tensorflow.python.ops.linalg import linear_operator_algebra 

22from tensorflow.python.ops.linalg import linear_operator_block_diag 

23from tensorflow.python.ops.linalg import linear_operator_composition 

24from tensorflow.python.ops.linalg import linear_operator_diag 

25from tensorflow.python.ops.linalg import linear_operator_identity 

26from tensorflow.python.ops.linalg import linear_operator_kronecker 

27from tensorflow.python.ops.linalg import linear_operator_lower_triangular 

28from tensorflow.python.ops.linalg import linear_operator_util 

29 

# Module-level shorthand for the lower-triangular operator class, which the
# registrations in this file both return and type-check against.
_lower_triangular_module = linear_operator_lower_triangular
LinearOperatorLowerTriangular = (
    _lower_triangular_module.LinearOperatorLowerTriangular)

32 

33 

# Fallback registration for any LinearOperator: densify, take the Cholesky
# factor, and wrap it as a LowerTriangular operator. The registrations below
# specialize this for particular operator classes.
@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator)
def _cholesky_linear_operator(linop):
  """Computes Cholesky(LinearOperator) from the operator's dense matrix.

  Args:
    linop: The operator to factor. It is materialized with `to_dense()`.

  Returns:
    A `LinearOperatorLowerTriangular` wrapping the Cholesky factor of the
    dense matrix, hinted as non-singular, square and not self-adjoint.
  """
  dense_factor = linalg_ops.cholesky(linop.to_dense())
  return LinearOperatorLowerTriangular(
      dense_factor,
      is_non_singular=True,
      is_self_adjoint=False,
      is_square=True)

43 

44 

45def _is_llt_product(linop): 

46 """Determines if linop = L @ L.H for L = LinearOperatorLowerTriangular.""" 

47 if len(linop.operators) != 2: 

48 return False 

49 if not linear_operator_util.is_aat_form(linop.operators): 

50 return False 

51 return isinstance(linop.operators[0], LinearOperatorLowerTriangular) 

52 

53 

@linear_operator_algebra.RegisterCholesky(
    linear_operator_composition.LinearOperatorComposition)
def _cholesky_linear_operator_composition(linop):
  """Computes Cholesky(LinearOperatorComposition).

  Args:
    linop: The composition to factor.

  Returns:
    If `linop` is not of the form L @ L.H with L lower triangular, a
    `LinearOperatorLowerTriangular` built from the dense Cholesky factor.
    Otherwise L itself when it is hinted positive definite, or else a new
    `LinearOperatorLowerTriangular` whose columns are L's columns divided by
    the sign of the corresponding diagonal entry.
  """
  # Only L @ L.H gets special treatment. Products such as Diag @ Diag.H,
  # Diag @ TriL and TriL @ Diag (and likewise Identity / ScaledIdentity) are
  # already collapsed to Diag or TriL by the matmul registrations, so they only
  # show up inside a LinearOperatorComposition when built that way explicitly.
  # That leaves L @ L.H as the case most worth detecting.
  if _is_llt_product(linop):
    tril_factor = linop.operators[0]

    # A positive-definite L already has a positive diagonal: it is the answer.
    if tril_factor.is_positive_definite:
      return tril_factor

    # The base class has already checked linop.is_positive_definite (otherwise
    # linop.cholesky() would have raised), so the diagonal of L is nonzero.
    # Generic case: set A = L @ D with D = diag(sign(1 / L.diag_part())). The
    # sign vector gets a new axis at -2, so the division below broadcasts over
    # rows and rescales column j by the sign of L[j, j]. A is then lower
    # triangular with positive diagonal, and
    #   A @ A^H = L @ D @ D^H @ L^H = L @ L^H = linop.
    # Complex L works too, because sign(x + iy) = exp(i * angle(x + iy)).
    signs = math_ops.sign(tril_factor.diag_part())
    column_signs = array_ops.expand_dims(signs, axis=-2)
    normalized_tril = tril_factor.tril / column_signs

    # L.is_positive_definite ==> L has positive diag ==> L = L @ D
    #   ==> (L @ D).is_positive_definite.
    # When L is not known to be PD, L @ D may or may not be PD; consider
    # L = [[1, 0], [-2, 1]] and the quadratic form with x = [1, 1].
    # The PD case already returned above; the hint is spelled out in full
    # here to be explicit.
    if tril_factor.is_positive_definite:
      positive_definite_hint = True
    else:
      positive_definite_hint = None

    return LinearOperatorLowerTriangular(
        tril=normalized_tril,
        is_non_singular=tril_factor.is_non_singular,
        # L.is_self_adjoint ==> L is diagonal ==> L @ D is diagonal ==> SA.
        # L.is_self_adjoint is False ==> L not diagonal ==> L @ D not diag ...
        is_self_adjoint=tril_factor.is_self_adjoint,
        is_positive_definite=positive_definite_hint,
        is_square=True,
    )

  # Not an L @ L.H product: factor the dense matrix.
  dense_factor = linalg_ops.cholesky(linop.to_dense())
  return LinearOperatorLowerTriangular(
      dense_factor,
      is_non_singular=True,
      is_self_adjoint=False,
      is_square=True)

102 

103 

@linear_operator_algebra.RegisterCholesky(
    linear_operator_diag.LinearOperatorDiag)
def _cholesky_diag(diag_operator):
  """Computes Cholesky(LinearOperatorDiag) as the elementwise sqrt of `diag`.

  Args:
    diag_operator: The diagonal operator to factor.

  Returns:
    A `LinearOperatorDiag` built from `sqrt(diag_operator.diag)`, hinted as
    non-singular, self-adjoint, positive definite and square.
  """
  sqrt_diag = math_ops.sqrt(diag_operator.diag)
  return linear_operator_diag.LinearOperatorDiag(
      sqrt_diag,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)

113 

114 

@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorIdentity)
def _cholesky_identity(identity_operator):
  """Computes Cholesky(LinearOperatorIdentity) as another identity operator.

  Args:
    identity_operator: The identity operator to factor.

  Returns:
    A new `LinearOperatorIdentity` with the same `num_rows`, `batch_shape`
    and `dtype`, hinted as non-singular, self-adjoint, positive definite and
    square.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  return linear_operator_identity.LinearOperatorIdentity(
      num_rows=num_rows,
      batch_shape=identity_operator.batch_shape,
      dtype=identity_operator.dtype,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)

126 

127 

@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _cholesky_scaled_identity(identity_operator):
  """Computes Cholesky(LinearOperatorScaledIdentity) by rooting the multiplier.

  Args:
    identity_operator: The scaled-identity operator to factor.

  Returns:
    A `LinearOperatorScaledIdentity` with the same `num_rows` and multiplier
    `sqrt(identity_operator.multiplier)`, hinted as non-singular,
    self-adjoint, positive definite and square.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  sqrt_multiplier = math_ops.sqrt(identity_operator.multiplier)
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=sqrt_multiplier,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)

138 

139 

@linear_operator_algebra.RegisterCholesky(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _cholesky_block_diag(block_diag_operator):
  """Computes Cholesky(LinearOperatorBlockDiag) one diagonal block at a time.

  Args:
    block_diag_operator: The block-diagonal operator to factor.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are `cholesky()` of each
    operator in `block_diag_operator.operators`, in the same order.
  """
  block_factors = [
      block.cholesky() for block in block_diag_operator.operators]
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=block_factors,
      is_non_singular=True,
      # Left unset so the per-block factors determine self-adjointness.
      is_self_adjoint=None,
      is_square=True)

150 

151 

@linear_operator_algebra.RegisterCholesky(
    linear_operator_kronecker.LinearOperatorKronecker)
def _cholesky_kronecker(kronecker_operator):
  """Computes Cholesky(LinearOperatorKronecker) factor by factor.

  The Cholesky decomposition of a Kronecker product is the Kronecker product
  of the Cholesky decompositions of its factors.

  Args:
    kronecker_operator: The Kronecker-product operator to factor.

  Returns:
    A `LinearOperatorKronecker` whose factors are `cholesky()` of each
    operator in `kronecker_operator.operators`, in the same order.
  """
  factor_choleskys = [
      factor.cholesky() for factor in kronecker_operator.operators]
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=factor_choleskys,
      is_non_singular=True,
      # Left unset so the per-factor Choleskys determine self-adjointness.
      is_self_adjoint=None,
      is_square=True)