Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_householder.py: 41%

92 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""`LinearOperator` acting like a Householder transformation.""" 

16 

17from tensorflow.python.framework import errors 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import tensor_conversion 

20from tensorflow.python.ops import array_ops 

21from tensorflow.python.ops import control_flow_ops 

22from tensorflow.python.ops import math_ops 

23from tensorflow.python.ops import nn 

24from tensorflow.python.ops.linalg import linalg_impl as linalg 

25from tensorflow.python.ops.linalg import linear_operator 

26from tensorflow.python.ops.linalg import linear_operator_util 

27from tensorflow.python.util.tf_export import tf_export 

28 

# Names exported by `from <this module> import *`.
__all__ = ["LinearOperatorHouseholder"]

30 

31 

@tf_export("linalg.LinearOperatorHouseholder")
@linear_operator.make_composite_tensor
class LinearOperatorHouseholder(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of Householder transformations.

  This operator acts like a [batch] of Householder reflections with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorHouseholder` is initialized with a (batch) vector.

  A Householder reflection is defined via a vector `v`: it reflects points
  in `R^n` about the hyperplane orthogonal to `v` and through the origin.
  Only the direction of `v` matters, because `v` is l2-normalized before use.
  The operator therefore represents `I - 2 v v^H / (v^H v)`.

  ```python
  # Create a 2 x 2 householder transform.
  vec = [1 / np.sqrt(2), 1. / np.sqrt(2)]
  operator = LinearOperatorHouseholder(vec)

  operator.to_dense()
  ==> [[0.,  -1.]
       [-1., -0.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               reflection_axis,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorHouseholder"):
    r"""Initialize a `LinearOperatorHouseholder`.

    Args:
      reflection_axis:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
        The vector defining the hyperplane to reflect about.  It need not be
        unit length; it is l2-normalized along its last axis before use.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This is autoset to true.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
        This is autoset to false.
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is autoset to true.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  `reflection_axis` is statically known to have rank 0, or
        `is_self_adjoint` is `False`, `is_positive_definite` is `True`, or
        `is_square` is `False`.  (`None` is accepted for all three hints and
        replaced by the correct value.)
    """
    # Captured before any hint is auto-set, so `parameters` records exactly
    # what the caller passed.
    parameters = dict(
        reflection_axis=reflection_axis,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[reflection_axis]):
      self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
          reflection_axis, name="reflection_axis")
      self._check_reflection_axis(self._reflection_axis)

      # Check and auto-set hints.
      if is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always self adjoint.")
      else:
        is_self_adjoint = True

      if is_positive_definite is True:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "A Householder operator is always non-positive definite.")
      else:
        is_positive_definite = False

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always square.")
      is_square = True

      super(LinearOperatorHouseholder, self).__init__(
          dtype=self._reflection_axis.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_reflection_axis(self, reflection_axis):
    """Static check of reflection_axis: it must have at least one dimension.

    Only fires when the rank is statically known; unknown rank passes.
    """
    if (reflection_axis.shape.ndims is not None and
        reflection_axis.shape.ndims < 1):
      raise ValueError(
          "Argument reflection_axis must have at least 1 dimension.  "
          "Found: %s" % reflection_axis)

  def _normalized_reflection_axis(self):
    """Returns `reflection_axis` as a `Tensor`, l2-normalized on its last axis.

    This is the unit vector `u = v / ||v||` shared by `_matmul`, `_to_dense`
    and `_diag_part`.  The operator is then `I - 2 u u^H`.
    """
    reflection_axis = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis
    )
    return nn.l2_normalize(reflection_axis, axis=-1)

  def _shape(self):
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._reflection_axis.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    # Dynamic counterpart of `_shape`: append the last dim of the axis shape.
    d_shape = array_ops.shape(self._reflection_axis)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  def _assert_non_singular(self):
    # A reflection is always invertible (it is its own inverse), so there is
    # nothing to check at runtime.
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    # A reflection has a -1 eigenvalue, so it can never be positive definite.
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Householder operators are always "
        "non-positive definite.")

  def _assert_self_adjoint(self):
    # `I - 2 u u^H` equals its own conjugate transpose by construction.
    return control_flow_ops.no_op("assert_self_adjoint")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Given a vector `v`, we would like to reflect `x` about the hyperplane
    # orthogonal to `v` going through the origin.  We first project `x` to `v`
    # to get v * dot(v, x) / dot(v, v).  After we project, we can reflect the
    # projection about the hyperplane by flipping sign to get
    # -v * dot(v, x) / dot(v, v).  Finally, we can add back the component
    # that is orthogonal to v.  This is invariant under reflection, since the
    # whole hyperplane is invariant.  This component is equal to x - v * dot(v,
    # x) / dot(v, v), giving the formula x - 2 * v * dot(v, x) / dot(v, v)
    # for the reflection.

    # Note that because this is a reflection, it lies in O(n) (for real vector
    # spaces) or U(n) (for complex vector spaces), and thus is its own
    # adjoint.  That is why `adjoint` is ignored here.
    normalized_axis = self._normalized_reflection_axis()
    x = linalg.adjoint(x) if adjoint_arg else x
    # Column vector u of shape [..., N, 1].
    mat = normalized_axis[..., array_ops.newaxis]
    # u^H x, shape [..., 1, R].
    x_dot_normalized_v = math_ops.matmul(mat, x, adjoint_a=True)

    # Outer-product broadcast: [..., N, 1] * [..., 1, R] -> [..., N, R].
    return x - 2 * mat * x_dot_normalized_v

  def _trace(self):
    # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue, so the trace
    # is (n - 1) - 1 = n - 2 for every batch member.
    shape = self.shape_tensor()
    return math_ops.cast(
        self._domain_dimension_tensor(shape=shape) - 2,
        self.dtype) * array_ops.ones(
            shape=self._batch_shape_tensor(shape=shape), dtype=self.dtype)

  def _determinant(self):
    # For householder transformations, the determinant is -1: the product of
    # (n - 1) eigenvalues equal to +1 and one equal to -1.
    return -array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)  # pylint: disable=invalid-unary-operand-type

  def _log_abs_determinant(self):
    # Orthogonal/unitary matrix -> |det Q| = 1 -> log|det Q| = 0.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # A Householder reflection H is an involution (H @ H = I), so H^{-1} = H.
    # Since H is also self-adjoint, solving with H or H^H is just a matmul.
    return self._matmul(rhs, adjoint, adjoint_arg)

  def _to_dense(self):
    # Materialize I - 2 u u^H: build -2 u u^H, then add 1 to its diagonal.
    normalized_axis = self._normalized_reflection_axis()
    mat = normalized_axis[..., array_ops.newaxis]
    matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
    return array_ops.matrix_set_diag(
        matrix, 1. + array_ops.matrix_diag_part(matrix))

  def _diag_part(self):
    # Diagonal of I - 2 u u^H is 1 - 2 |u_i|^2, computed without
    # materializing the full matrix.
    normalized_axis = self._normalized_reflection_axis()
    return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)

  def _eigvals(self):
    # We have (n - 1) +1 eigenvalues and a single -1 eigenvalue; the -1 is
    # placed first along the last axis.
    result_shape = array_ops.shape(self.reflection_axis)
    n = result_shape[-1]
    ones_shape = array_ops.concat([result_shape[:-1], [n - 1]], axis=-1)
    neg_shape = array_ops.concat([result_shape[:-1], [1]], axis=-1)
    eigvals = array_ops.ones(shape=ones_shape, dtype=self.dtype)
    eigvals = array_ops.concat(
        [-array_ops.ones(shape=neg_shape, dtype=self.dtype), eigvals], axis=-1)  # pylint: disable=invalid-unary-operand-type
    return eigvals

  def _cond(self):
    # Householder matrices are orthogonal (unitary in the complex case)
    # reflections: all singular values are 1, so the condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def reflection_axis(self):
    """The (possibly unnormalized) axis `Tensor` given at construction."""
    return self._reflection_axis

  @property
  def _composite_tensor_fields(self):
    return ("reflection_axis",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # `reflection_axis` contributes one trailing (non-batch) dimension.
    return {"reflection_axis": 1}