Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py: 20%

196 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Construct the Kronecker product of one or more `LinearOperators`.""" 

16 

17from tensorflow.python.framework import common_shapes 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import errors 

20from tensorflow.python.framework import ops 

21from tensorflow.python.framework import tensor_shape 

22from tensorflow.python.framework import tensor_util 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import check_ops 

25from tensorflow.python.ops import control_flow_ops 

26from tensorflow.python.ops import math_ops 

27from tensorflow.python.ops.linalg import linalg_impl as linalg 

28from tensorflow.python.ops.linalg import linear_operator 

29from tensorflow.python.util.tf_export import tf_export 

30 

31__all__ = ["LinearOperatorKronecker"] 

32 

33 

34def _prefer_static_shape(x): 

35 if x.shape.is_fully_defined(): 

36 return x.shape 

37 return array_ops.shape(x) 

38 

39 

def _prefer_static_concat_shape(first_shape, second_shape_int_list):
  """Appends integer dimensions to a shape, staying static when possible.

  Args:
    first_shape: `TensorShape` or `Tensor` instance. If a `TensorShape`,
      `first_shape.is_fully_defined()` must return `True`.
    second_shape_int_list: `list` of scalar integer `Tensor`s (Python integers
      are also accepted, as callers in this module pass them).

  Returns:
    A `TensorShape` when `first_shape` is a `TensorShape` and every entry of
    `second_shape_int_list` has a statically known value; otherwise a 1-D
    `Tensor` produced by `array_ops.concat`.
  """
  # `constant_value` yields None for entries whose value is only known at
  # runtime.
  static_tail = [tensor_util.constant_value(d) for d in second_shape_int_list]
  is_static_head = isinstance(first_shape, tensor_shape.TensorShape)
  if not (is_static_head and all(v is not None for v in static_tail)):
    return array_ops.concat([first_shape, second_shape_int_list], axis=0)
  return first_shape.concatenate(static_tail)

58 

59 

@tf_export("linalg.LinearOperatorKronecker")
@linear_operator.make_composite_tensor
class LinearOperatorKronecker(linear_operator.LinearOperator):
  """Kronecker product of one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` representing the Kronecker product:
  `op1 x op2 x .. opJ` (we omit parentheses as the Kronecker product is
  associative).

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the composed operator
  will have shape equal to `broadcast_batch_shape + [prod M_j, prod N_j]`,
  where the product is over all operators.

  ```python
  # Create a 4 x 4 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [2., 1.]])
  operator = LinearOperatorKronecker([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 0., 2., 0.],
       [2., 1., 4., 2.],
       [3., 0., 4., 0.],
       [6., 3., 8., 4.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [4, 2] Tensor
  operator.matmul(x)
  ==> Shape [4, 2] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 20 x 30 operators.
  operator_large = LinearOperatorKronecker([operator_45, operator_56])

  # Create a shape [2, 3, 30, 2] vector.
  x = tf.random.normal(shape=[2, 3, 30, 2])
  operator_large.matmul(x)
  ==> Shape [2, 3, 20, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorKronecker` on any operation is equal to
  the sum of the individual operators' operations.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorKronecker`.

    `LinearOperatorKronecker` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and broadcastable batch shapes, representing the
        Kronecker factors.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_x_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if a hint contradicts the
        hints implied by the factors.
    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(f"Argument `operators` must be a list of >=1 operators. "
                       f"Received: {operators}.")
    self._operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            f"Expected every operation in argument `operators` to have the "
            f"same dtype. Received {list(name_type)}.")

    # Auto-set and check hints.
    # A Kronecker product is invertible, if and only if all factors are
    # invertible.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            f"The Kronecker product of non-singular operators is always "
            f"non-singular. Expected argument `is_non_singular` to be True. "
            f"Received: {is_non_singular}.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            f"The Kronecker product of self-adjoint operators is always "
            f"self-adjoint. Expected argument `is_self_adjoint` to be True. "
            f"Received: {is_self_adjoint}.")
      is_self_adjoint = True

    # The eigenvalues of a Kronecker product are equal to the products of eigen
    # values of the corresponding factors.
    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            f"The Kronecker product of positive-definite operators is always "
            f"positive-definite. Expected argument `is_positive_definite` to "
            f"be True. Received: {is_positive_definite}.")
      is_positive_definite = True

    if name is None:
      name = operators[0].name
      for operator in operators[1:]:
        name += "_x_" + operator.name
    with ops.name_scope(name):
      super(LinearOperatorKronecker, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of Kronecker factors, in the order given at construction."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + [prod M_j, prod N_j]."""
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension

    range_dimension = self.operators[0].range_dimension
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension

    matrix_shape = tensor_shape.TensorShape([
        range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic counterpart of `_shape`, returned as a 1-D `Tensor`."""
    domain_dimension = self.operators[0].domain_dimension_tensor()
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension_tensor()

    range_dimension = self.operators[0].range_dimension_tensor()
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension_tensor()

    matrix_shape = [range_dimension, domain_dimension]

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _solve_matmul_internal(
      self,
      x,
      solve_matmul_fn,
      adjoint=False,
      adjoint_arg=False):
    """Shared implementation of `_matmul` and `_solve`.

    Args:
      x: Right-hand side `Tensor`.
      solve_matmul_fn: Callable `(operator, x, adjoint, adjoint_arg)` applied
        to each factor (either `matmul` or `solve`).
      adjoint: Whether to use the adjoint of this operator.
      adjoint_arg: Whether `x` is given in adjoint (conjugate-transposed) form.

    Returns:
      `Tensor` with the result of applying `solve_matmul_fn` for the full
      Kronecker product.
    """
    # We heavily rely on Roth's column Lemma [1]:
    # (A x B) * vec X = vec BXA^T
    # where vec stacks all the columns of the matrix under each other.
    # In our case, we use a variant of the lemma that is row-major
    # friendly: (A x B) * vec' X = vec' AXB^T
    # Where vec' reshapes a matrix into a vector. We can repeatedly apply this
    # for a collection of kronecker products.
    # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
    # use the above to compute multiplications, solves with any composition of
    # transposes.
    output = x

    # Bring the argument into transposed ("columns as rows") layout.
    if adjoint_arg:
      if self.dtype.is_complex:
        output = math_ops.conj(output)
    else:
      output = linalg.transpose(output)

    for o in reversed(self.operators):
      # Statically compute the reshape when possible.
      if adjoint:
        operator_dimension = o.range_dimension_tensor()
      else:
        operator_dimension = o.domain_dimension_tensor()
      output_shape = _prefer_static_shape(output)

      static_operator_dimension = tensor_util.constant_value(
          operator_dimension)
      if static_operator_dimension is not None:
        operator_dimension = static_operator_dimension
      num_rows = tensor_shape.dimension_value(output.shape[-2])
      num_cols = tensor_shape.dimension_value(output.shape[-1])

      if (static_operator_dimension is not None and
          num_rows is not None and num_cols is not None):
        # Everything needed is known statically; use only static values so we
        # never call `int` on a (possibly symbolic) `Tensor`.
        dim = int(num_rows * num_cols // static_operator_dimension)
      else:
        # At least one value is only known at runtime. `output_shape` is then
        # a dynamic shape `Tensor` (or a static shape combined with a dynamic
        # `operator_dimension`), so compute `dim` as a `Tensor`.
        dim = math_ops.cast(
            output_shape[-2] * output_shape[-1] // operator_dimension,
            dtype=dtypes.int32)

      output_shape = _prefer_static_concat_shape(
          output_shape[:-2], [dim, operator_dimension])
      output = array_ops.reshape(output, shape=output_shape)

      # Conjugate because we are trying to compute A @ B^T, but
      # `LinearOperator` only supports `adjoint_arg`.
      if self.dtype.is_complex:
        output = math_ops.conj(output)

      output = solve_matmul_fn(
          o, output, adjoint=adjoint, adjoint_arg=True)

    if adjoint_arg:
      col_dim = _prefer_static_shape(x)[-2]
    else:
      col_dim = _prefer_static_shape(x)[-1]

    if adjoint:
      row_dim = self.domain_dimension_tensor()
    else:
      row_dim = self.range_dimension_tensor()

    matrix_shape = [row_dim, col_dim]

    output = array_ops.reshape(
        output,
        _prefer_static_concat_shape(
            _prefer_static_shape(output)[:-2], matrix_shape))

    # Restore as much static shape information as the inputs allow.
    if x.shape.is_fully_defined():
      if adjoint_arg:
        column_dim = x.shape[-2]
      else:
        column_dim = x.shape[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          x.shape[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Matrix multiplication, delegated factor-wise to each operator."""
    def matmul_fn(o, x, adjoint, adjoint_arg):
      return o.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=x,
        solve_matmul_fn=matmul_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Linear solve, delegated factor-wise to each operator."""
    def solve_fn(o, rhs, adjoint, adjoint_arg):
      return o.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=rhs,
        solve_matmul_fn=solve_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _determinant(self):
    """Determinant as a product of factor determinants raised to powers."""
    # Note that we have |X1 x X2| = |X1| ** n * |X2| ** m, where X1 is an m x m
    # matrix, and X2 is an n x n matrix. We can iteratively apply this property
    # to get the determinant of |X1 x X2 x X3 ...|. If T is the product of the
    # domain dimension of all operators, then we have:
    # |X1 x X2 x X3 ...| =
    # |X1| ** (T / m) * |X2 x X3 ... | ** m =
    # |X1| ** (T / m) * |X2| ** (m * (T / m) / n) * ... =
    # |X1| ** (T / m) * |X2| ** (T / n) * | X3 x X4... | ** (m * n)
    # And by doing induction we have product(|X_i| ** (T / dim(X_i))).
    total = self.domain_dimension_tensor()
    determinant = 1.
    for operator in self.operators:
      determinant = determinant * operator.determinant() ** math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return determinant

  def _log_abs_determinant(self):
    """Log |det| as a weighted sum of the factors' log |det|."""
    # This will be sum((total / dim(x_i)) * log |X_i|)
    total = self.domain_dimension_tensor()
    log_abs_det = 0.
    for operator in self.operators:
      log_abs_det += operator.log_abs_determinant() * math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return log_abs_det

  def _trace(self):
    """Trace as the product of the factors' traces."""
    # tr(A x B) = tr(A) * tr(B)
    trace = 1.
    for operator in self.operators:
      trace = trace * operator.trace()
    return trace

  def _diag_part(self):
    """Diagonal as the flattened outer product of the factors' diagonals."""
    # NOTE(review): this equals the diagonal of the Kronecker product when the
    # factors are square; for non-square factors it is not the diagonal in
    # general. Confirm callers only rely on it for square factors.
    diag_part = self.operators[0].diag_part()
    for operator in self.operators[1:]:
      diag_part = diag_part[..., :, array_ops.newaxis]
      op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
      diag_part = diag_part * op_diag_part
      diag_part = array_ops.reshape(
          diag_part,
          shape=array_ops.concat(
              [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
    # The diagonal has min(range_dimension, domain_dimension) entries. Only
    # pin it statically when both dimensions are known; comparing against an
    # unknown dimension would otherwise pick one side arbitrarily.
    range_dimension = tensor_shape.dimension_value(self.range_dimension)
    domain_dimension = tensor_shape.dimension_value(self.domain_dimension)
    if range_dimension is not None and domain_dimension is not None:
      diag_dimension = min(range_dimension, domain_dimension)
    else:
      diag_dimension = None
    diag_part.set_shape(
        self.batch_shape.concatenate([diag_dimension]))
    return diag_part

  def _to_dense(self):
    """Materializes the Kronecker product as a dense [batch] matrix."""
    product = self.operators[0].to_dense()
    for operator in self.operators[1:]:
      # Product has shape [B, R1, 1, C1, 1].
      product = product[
          ..., :, array_ops.newaxis, :, array_ops.newaxis]
      # Operator has shape [B, 1, R2, 1, C2].
      op_to_mul = operator.to_dense()[
          ..., array_ops.newaxis, :, array_ops.newaxis, :]
      # This is now [B, R1, R2, C1, C2].
      product = product * op_to_mul
      # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
      product_shape = _prefer_static_shape(product)
      shape = _prefer_static_concat_shape(
          product_shape[:-4],
          [product_shape[-4] * product_shape[-3],
           product_shape[-2] * product_shape[-1]])

      product = array_ops.reshape(product, shape=shape)
    product.set_shape(self.shape)
    return product

  def _eigvals(self):
    """Eigenvalues as the flattened outer product of the factors' eigvals."""
    # This will be the kronecker product of all the eigenvalues.
    # Note: It doesn't matter which kronecker product it is, since every
    # kronecker product of the same matrices are similar.
    eigvals = [operator.eigvals() for operator in self.operators]
    # Now compute the kronecker product
    product = eigvals[0]
    for eigval in eigvals[1:]:
      # Product has shape [B, R1, 1].
      product = product[..., array_ops.newaxis]
      # Eigval has shape [B, 1, R2]. Produces shape [B, R1, R2].
      product = product * eigval[..., array_ops.newaxis, :]
      # Reshape to [B, R1 * R2]
      product = array_ops.reshape(
          product,
          shape=array_ops.concat([array_ops.shape(product)[:-2], [-1]], axis=0))
    product.set_shape(self.shape[:-1])
    return product

  def _assert_non_singular(self):
    """Groups each factor's non-singularity assertion (factors must be square).

    Raises:
      InvalidArgumentError: If any factor is not hinted `is_square=True`.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "invertible. Expected hint `is_square` to be True for every operator "
          "in argument `operators`.")

  def _assert_self_adjoint(self):
    """Groups each factor's self-adjointness assertion.

    Self-adjoint factors imply a self-adjoint product, so this check is
    sufficient (though not necessary) for the product to be self-adjoint.

    Raises:
      InvalidArgumentError: If any factor is not hinted `is_square=True`.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_self_adjoint() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "self-adjoint. Expected hint `is_square` to be True for every "
          "operator in argument `operators`.")

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"operators": [0] * len(self.operators)}