Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_diag.py: 42%

88 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""`LinearOperator` acting like a diagonal matrix.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.framework import tensor_conversion 

19from tensorflow.python.ops import array_ops 

20from tensorflow.python.ops import check_ops 

21from tensorflow.python.ops import math_ops 

22from tensorflow.python.ops.linalg import linalg_impl as linalg 

23from tensorflow.python.ops.linalg import linear_operator 

24from tensorflow.python.ops.linalg import linear_operator_util 

25from tensorflow.python.util.tf_export import tf_export 

26 

# Public API of this module: only the diagonal operator class is exported.
__all__ = ["LinearOperatorDiag"]

28 

29 

@tf_export("linalg.LinearOperatorDiag")
@linear_operator.make_composite_tensor
class LinearOperatorDiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square diagonal matrix.

  This operator acts like a [batch] diagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorDiag` is initialized with a (batch) vector.

  ```python
  # Create a 2 x 2 diagonal linear operator.
  diag = [1., -1.]
  operator = LinearOperatorDiag(diag)

  operator.to_dense()
  ==> [[1.,  0.]
       [0., -1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  diag = tf.random.normal(shape=[2, 3, 4])
  operator = LinearOperatorDiag(diag)

  # Create a shape [2, 1, 4, 2] batch of matrices.  Note that this shape is
  # compatible since the batch dimensions, [2, 1], are broadcast to
  # operator.batch_shape = [2, 3].
  y = tf.random.normal(shape=[2, 1, 4, 2])
  x = operator.solve(y)
  ==> operator.matmul(x) = y
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N * R` multiplications.
  * `operator.solve(x)` involves `N` divisions and `N * R` multiplications.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diag,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorDiag"):
    r"""Initialize a `LinearOperatorDiag`.

    Args:
      diag:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
        The diagonal of the operator.  Allowed dtypes: `float16`, `float32`,
          `float64`, `complex64`, `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        Only `None` or `True` is accepted; the hint is always set to `True`.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `diag.dtype` is not an allowed type.
      ValueError:  If `diag` is statically known to have rank 0.
      ValueError:  If `diag.dtype` is real and `is_self_adjoint` is `False`.
      ValueError:  If `is_square` is `False`.
    """
    # Constructor arguments exactly as the caller passed them (before the
    # hint auto-setting below); forwarded to the base class as `parameters`.
    parameters = dict(
        diag=diag,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diag]):
      self._diag = linear_operator_util.convert_nonref_to_tensor(
          diag, name="diag")
      self._check_diag(self._diag)

      # Check and auto-set hints.  A real diagonal matrix always equals its
      # (hermitian) transpose, so `False` is contradictory and `None` can be
      # upgraded to `True`.
      if not self._diag.dtype.is_complex:
        if is_self_adjoint is False:
          raise ValueError("A real diagonal operator is always self adjoint.")
        else:
          is_self_adjoint = True

      # The shape is always [..., N, N] (see `_shape`), so the operator is
      # square by construction.
      if is_square is False:
        raise ValueError("Only square diagonal operators currently supported.")
      is_square = True

      # Explicit two-argument `super` is kept deliberately: the class object
      # is wrapped by decorators above.
      super(LinearOperatorDiag, self).__init__(
          dtype=self._diag.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_diag(self, diag):
    """Static check of diag: raise if its rank is known and less than 1."""
    if diag.shape.ndims is not None and diag.shape.ndims < 1:
      raise ValueError("Argument diag must have at least 1 dimension.  "
                       "Found: %s" % diag)

  def _shape(self):
    """Static shape: the diagonal's shape with its last dim repeated."""
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._diag.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    """Dynamic shape: `shape(diag) + [N]`, with `N = shape(diag)[-1]`."""
    d_shape = array_ops.shape(self._diag)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  @property
  def diag(self):
    """The diagonal, as converted from the `diag` constructor argument."""
    return self._diag

  def _assert_non_singular(self):
    """Op asserting that no diagonal entry has modulus zero."""
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _assert_positive_definite(self):
    """Op asserting that every diagonal entry has positive real part."""
    if self.dtype.is_complex:
      message = (
          "Diagonal operator had diagonal entries with non-positive real part, "
          "thus was not positive definite.")
    else:
      message = (
          "Real diagonal operator had non-positive diagonal entries, "
          "thus was not positive definite.")

    # Matches the class-level definition of positive definite: Re(x^H A x) > 0,
    # which for a diagonal A reduces to Re(d_i) > 0 for every i.
    return check_ops.assert_positive(
        math_ops.real(self._diag),
        message=message)

  def _assert_self_adjoint(self):
    """Op asserting that the diagonal has zero imaginary part."""
    # Adjacent literals are concatenated; exactly one space separates the
    # two sentences.
    return linear_operator_util.assert_zero_imag_part(
        self._diag,
        message=(
            "This diagonal operator contained non-zero imaginary values.  "
            "Thus it was not self-adjoint."))

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Compute `A x` (or `A^H x`, optionally with `x^H`) by row scaling."""
    # The adjoint of a diagonal matrix is its elementwise conjugate.
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    x = linalg.adjoint(x) if adjoint_arg else x
    # [..., N] -> [..., N, 1] so the diagonal broadcasts across the columns
    # of x, scaling row i by d_i.
    diag_mat = array_ops.expand_dims(diag_term, -1)
    return diag_mat * x

  def _matvec(self, x, adjoint=False):
    """Compute `A x` (or `A^H x`) for a vector `x` by elementwise product."""
    # Assumes x has the diagonal's trailing shape [..., N] (vector input).
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    return diag_term * x

  def _determinant(self):
    """Determinant: product of the diagonal entries."""
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    """log|det A| = sum_i log|d_i|."""
    log_det = math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])
    # abs() yields a real tensor; cast back so the result dtype matches the
    # operator's dtype when the operator is complex.
    if self.dtype.is_complex:
      log_det = math_ops.cast(log_det, dtype=self.dtype)
    return log_det

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve `A X = rhs` (or with `A^H` / `rhs^H`) by scaling rows by 1/d_i."""
    diag_term = math_ops.conj(self._diag) if adjoint else self._diag
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    # [..., N] -> [..., N, 1] to broadcast the reciprocal across rhs columns.
    inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
    return rhs * inv_diag_mat

  def _to_dense(self):
    """Materialize the [batch] diagonal matrix."""
    return array_ops.matrix_diag(self._diag)

  def _diag_part(self):
    """The diagonal is stored directly, so return it as-is."""
    return self.diag

  def _add_to_tensor(self, x):
    """Return `x + A` by adding `diag` onto the main diagonal of `x`."""
    # Only the diagonal of x changes, so A is never materialized.
    x_diag = array_ops.matrix_diag_part(x)
    new_diag = self._diag + x_diag
    return array_ops.matrix_set_diag(x, new_diag)

  def _eigvals(self):
    """Eigenvalues of a diagonal matrix are its diagonal entries."""
    return tensor_conversion.convert_to_tensor_v2_with_dispatch(self.diag)

  def _cond(self):
    """Condition number: max|d_i| / min|d_i| over the last axis."""
    abs_diag = math_ops.abs(self.diag)
    return (math_ops.reduce_max(abs_diag, axis=-1) /
            math_ops.reduce_min(abs_diag, axis=-1))

  @property
  def _composite_tensor_fields(self):
    # Names of constructor parameters consumed by the composite-tensor
    # machinery (presumably those treated as tensor components — confirm
    # against `linear_operator.make_composite_tensor`).
    return ("diag",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # `diag` has one non-batch (trailing) dimension: shape [B1,...,Bb, N].
    return {"diag": 1}

134 name: A name for this `LinearOperator`. 

135 

136 Raises: 

137 TypeError: If `diag.dtype` is not an allowed type. 

138 ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`. 

139 """ 

140 parameters = dict( 

141 diag=diag, 

142 is_non_singular=is_non_singular, 

143 is_self_adjoint=is_self_adjoint, 

144 is_positive_definite=is_positive_definite, 

145 is_square=is_square, 

146 name=name 

147 ) 

148 

149 with ops.name_scope(name, values=[diag]): 

150 self._diag = linear_operator_util.convert_nonref_to_tensor( 

151 diag, name="diag") 

152 self._check_diag(self._diag) 

153 

154 # Check and auto-set hints. 

155 if not self._diag.dtype.is_complex: 

156 if is_self_adjoint is False: 

157 raise ValueError("A real diagonal operator is always self adjoint.") 

158 else: 

159 is_self_adjoint = True 

160 

161 if is_square is False: 

162 raise ValueError("Only square diagonal operators currently supported.") 

163 is_square = True 

164 

165 super(LinearOperatorDiag, self).__init__( 

166 dtype=self._diag.dtype, 

167 is_non_singular=is_non_singular, 

168 is_self_adjoint=is_self_adjoint, 

169 is_positive_definite=is_positive_definite, 

170 is_square=is_square, 

171 parameters=parameters, 

172 name=name) 

173 

174 def _check_diag(self, diag): 

175 """Static check of diag.""" 

176 if diag.shape.ndims is not None and diag.shape.ndims < 1: 

177 raise ValueError("Argument diag must have at least 1 dimension. " 

178 "Found: %s" % diag) 

179 

180 def _shape(self): 

181 # If d_shape = [5, 3], we return [5, 3, 3]. 

182 d_shape = self._diag.shape 

183 return d_shape.concatenate(d_shape[-1:]) 

184 

185 def _shape_tensor(self): 

186 d_shape = array_ops.shape(self._diag) 

187 k = d_shape[-1] 

188 return array_ops.concat((d_shape, [k]), 0) 

189 

190 @property 

191 def diag(self): 

192 return self._diag 

193 

194 def _assert_non_singular(self): 

195 return linear_operator_util.assert_no_entries_with_modulus_zero( 

196 self._diag, 

197 message="Singular operator: Diagonal contained zero values.") 

198 

199 def _assert_positive_definite(self): 

200 if self.dtype.is_complex: 

201 message = ( 

202 "Diagonal operator had diagonal entries with non-positive real part, " 

203 "thus was not positive definite.") 

204 else: 

205 message = ( 

206 "Real diagonal operator had non-positive diagonal entries, " 

207 "thus was not positive definite.") 

208 

209 return check_ops.assert_positive( 

210 math_ops.real(self._diag), 

211 message=message) 

212 

213 def _assert_self_adjoint(self): 

214 return linear_operator_util.assert_zero_imag_part( 

215 self._diag, 

216 message=( 

217 "This diagonal operator contained non-zero imaginary values. " 

218 " Thus it was not self-adjoint.")) 

219 

220 def _matmul(self, x, adjoint=False, adjoint_arg=False): 

221 diag_term = math_ops.conj(self._diag) if adjoint else self._diag 

222 x = linalg.adjoint(x) if adjoint_arg else x 

223 diag_mat = array_ops.expand_dims(diag_term, -1) 

224 return diag_mat * x 

225 

226 def _matvec(self, x, adjoint=False): 

227 diag_term = math_ops.conj(self._diag) if adjoint else self._diag 

228 return diag_term * x 

229 

230 def _determinant(self): 

231 return math_ops.reduce_prod(self._diag, axis=[-1]) 

232 

233 def _log_abs_determinant(self): 

234 log_det = math_ops.reduce_sum( 

235 math_ops.log(math_ops.abs(self._diag)), axis=[-1]) 

236 if self.dtype.is_complex: 

237 log_det = math_ops.cast(log_det, dtype=self.dtype) 

238 return log_det 

239 

240 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 

241 diag_term = math_ops.conj(self._diag) if adjoint else self._diag 

242 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 

243 inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1) 

244 return rhs * inv_diag_mat 

245 

246 def _to_dense(self): 

247 return array_ops.matrix_diag(self._diag) 

248 

249 def _diag_part(self): 

250 return self.diag 

251 

252 def _add_to_tensor(self, x): 

253 x_diag = array_ops.matrix_diag_part(x) 

254 new_diag = self._diag + x_diag 

255 return array_ops.matrix_set_diag(x, new_diag) 

256 

257 def _eigvals(self): 

258 return tensor_conversion.convert_to_tensor_v2_with_dispatch(self.diag) 

259 

260 def _cond(self): 

261 abs_diag = math_ops.abs(self.diag) 

262 return (math_ops.reduce_max(abs_diag, axis=-1) / 

263 math_ops.reduce_min(abs_diag, axis=-1)) 

264 

265 @property 

266 def _composite_tensor_fields(self): 

267 return ("diag",) 

268 

269 @property 

270 def _experimental_parameter_ndims_to_matrix_ndims(self): 

271 return {"diag": 1}