Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_composition.py: 27%

113 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Composes one or more `LinearOperators`.""" 

16 

17from tensorflow.python.framework import common_shapes 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_shape 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import array_ops_stack 

23from tensorflow.python.ops import check_ops 

24from tensorflow.python.ops import control_flow_ops 

25from tensorflow.python.ops.linalg import linear_operator 

26from tensorflow.python.ops.linalg import linear_operator_util 

27from tensorflow.python.util.tf_export import tf_export 

28 

# Public API of this module: only the composition operator is exported.
__all__ = ["LinearOperatorComposition"]

30 

31 

@tf_export("linalg.LinearOperatorComposition")
@linear_operator.make_composite_tensor
class LinearOperatorComposition(linear_operator.LinearOperator):
  """Composes one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` with action defined by:

  ```
  op_composed(x) := op1(op2(...(opJ(x)...))
  ```

  If `opj` acts like [batch] matrix `Aj`, then `op_composed` acts like the
  [batch] matrix formed with the multiplication `A1 A2...AJ`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then we must have
  `N_j = M_{j+1}`, in which case the composed operator has shape equal to
  `broadcast_batch_shape + [M_1, N_J]`, where `broadcast_batch_shape` is the
  mutual broadcast of `batch_shape_j`, `j = 1,...,J`, assuming the intermediate
  batch shapes broadcast. Even if the composed shape is well defined, the
  composed operator's methods may fail due to lack of broadcasting ability in
  the defining operators' methods.

  ```python
  # Create a 2 x 2 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorComposition([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 4 x 6 operators.
  operator_46 = LinearOperatorComposition([operator_45, operator_56])

  # Create a shape [2, 3, 6, 2] Tensor.
  x = tf.random.normal(shape=[2, 3, 6, 2])
  operator_46.matmul(x)
  ==> Shape [2, 3, 4, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorComposition` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorComposition`.

    `LinearOperatorComposition` is initialized with a list of operators
    `[op_1,...,op_J]`.  For the `matmul` method to be well defined, the
    composition `op_i.matmul(op_{i+1}(x))` must be defined.  Other methods have
    similar constraints.

    Some hints are inferred from `operators` and override a `None` hint:
    `is_non_singular` when every operator is non-singular, `is_self_adjoint`
    and `is_square` when the operators have the form `A @ A.H`, and
    `is_positive_definite` when that form is also non-singular.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators' names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if a user-provided `False` hint
        contradicts a property inferred from `operators`.
    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype.  Found %s"
            % "   ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition of non-singular operators is always non-singular.")
      is_non_singular = True

    if _composition_must_be_self_adjoint(operators):
      if is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition was determined to be self-adjoint but user "
            "provided incorrect `False` hint.")
      is_self_adjoint = True

    # `is_aat_form` only inspects `operators`, so evaluate it once and reuse
    # the answer for both the square and positive-definite hints below.
    is_aat = linear_operator_util.is_aat_form(operators)

    if is_aat:
      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition was determined to have the form "
            "A @ A.H, hence it must be square. The user "
            "provided an incorrect `False` hint.")
      is_square = True

    if is_aat and is_non_singular:
      if is_positive_definite is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition was determined to be non-singular and have the "
            "form A @ A.H, hence it must be positive-definite. The user "
            "provided an incorrect `False` hint.")
      is_positive_definite = True

    # Initialization.

    if name is None:
      name = "_o_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of composed operators, in the order given at construction."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + [M_1, N_J]."""
    # Check that adjacent operators chain: the domain dimension of op_j must be
    # compatible with the range dimension of op_{j+1}.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension.assert_is_compatible_with(operator.range_dimension)
      domain_dimension = operator.domain_dimension

    # Range of the first operator, domain of the last.
    matrix_shape = tensor_shape.TensorShape(
        [self.operators[0].range_dimension,
         self.operators[-1].domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape as an int32 `Tensor`."""
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops_stack.stack([
        self.operators[0].range_dimension_tensor(),
        self.operators[-1].domain_dimension_tensor()
    ])

    # Dummy Tensor of zeros.  Will never be materialized.
    # Summing zeros of each batch shape yields the broadcast batch shape.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Applies the composed operator (or its adjoint) to `x`."""
    # If self.operators = [A, B], and not adjoint, then
    # matmul_order_list = [B, A].
    # As a result, we return A.matmul(B.matmul(x))
    # With adjoint, (A B)^H = B^H A^H, so A^H is applied first.
    if adjoint:
      matmul_order_list = self.operators
    else:
      matmul_order_list = list(reversed(self.operators))

    # `adjoint_arg` only concerns `x`, so it is passed to the first op only.
    result = matmul_order_list[0].matmul(
        x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in matmul_order_list[1:]:
      result = operator.matmul(result, adjoint=adjoint)
    return result

  def _determinant(self):
    """Product of the individual operators' determinants."""
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    """Sum of the individual operators' log-abs-determinants."""
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves with the composed operator (or its adjoint) on `rhs`."""
    # TODO(langmore) Implement solve using solve_ls if some intermediate
    # operator maps to a high dimensional space.
    # In that case, an exact solve may still be possible.

    # If self.operators = [A, B], and not adjoint, then
    # solve_order_list = [A, B].
    # As a result, we return B.solve(A.solve(x))
    # With adjoint, (A B)^{-H} = A^{-H} B^{-H}, so B is solved first.
    if adjoint:
      solve_order_list = list(reversed(self.operators))
    else:
      solve_order_list = self.operators

    # `adjoint_arg` only concerns `rhs`, so it is passed to the first op only.
    solution = solve_order_list[0].solve(
        rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in solve_order_list[1:]:
      solution = operator.solve(solution, adjoint=adjoint)
    return solution

  def _assert_non_singular(self):
    """Groups per-operator asserts when all are square; else uses the base."""
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    return super(LinearOperatorComposition, self)._assert_non_singular()

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # One entry per composed operator, each with matrix ndims 0.
    return {"operators": [0] * len(self.operators)}

320 

321 

def _composition_must_be_self_adjoint(operators):
  """Reports whether the composition of `operators` is provably self-adjoint.

  Only sufficient conditions are checked, so a `False` answer means "not shown
  to be self-adjoint" rather than "known not to be self-adjoint".

  Args:
    operators: List of LinearOperators.

  Returns:
    True if the composition must be SA.  False if it is not SA OR if we did not
    determine whether the composition is SA.
  """
  # Sufficient condition 1: a single operator already hinted as self-adjoint.
  lone_self_adjoint = len(operators) == 1 and operators[0].is_self_adjoint

  # Sufficient condition 2: a form like A @ A.H or (A1 @ A2) @ (A2.H @ A1.H).
  # The `or` short-circuits, so `is_aat_form` runs only if condition 1 fails.
  # Some SA compositions are still missed, e.g. (A @ I) @ A.H is SA but is
  # not in AAT form.
  return bool(lone_self_adjoint or linear_operator_util.is_aat_form(operators))