Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py: 20%
196 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Construct the Kronecker product of one or more `LinearOperators`."""
17from tensorflow.python.framework import common_shapes
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import errors
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.linalg import linalg_impl as linalg
28from tensorflow.python.ops.linalg import linear_operator
29from tensorflow.python.util.tf_export import tf_export
31__all__ = ["LinearOperatorKronecker"]
34def _prefer_static_shape(x):
35 if x.shape.is_fully_defined():
36 return x.shape
37 return array_ops.shape(x)
def _prefer_static_concat_shape(first_shape, second_shape_int_list):
  """Concatenate a shape with a list of integers as statically as possible.

  Args:
    first_shape: `TensorShape` or `Tensor` instance. If a `TensorShape`,
      `first_shape.is_fully_defined()` must return `True`.
    second_shape_int_list: `list` of scalar integer `Tensor`s.

  Returns:
    `Tensor` representing concatenating `first_shape` and
      `second_shape_int_list` as statically as possible.
  """
  # Try to resolve every trailing entry to a constant; `None` marks an entry
  # whose value could not be determined statically.
  static_tail = [
      tensor_util.constant_value(dim) for dim in second_shape_int_list]
  head_is_static = isinstance(first_shape, tensor_shape.TensorShape)
  if not head_is_static or any(dim is None for dim in static_tail):
    # At least one piece is only known dynamically: build the shape with ops.
    return array_ops.concat([first_shape, second_shape_int_list], axis=0)
  return first_shape.concatenate(static_tail)
@tf_export("linalg.LinearOperatorKronecker")
@linear_operator.make_composite_tensor
class LinearOperatorKronecker(linear_operator.LinearOperator):
  """Kronecker product of one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` representing the Kronecker product:
  `op1 x op2 x .. opJ` (we omit parentheses as the Kronecker product is
  associative).

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the composed operator
  will have shape equal to `broadcast_batch_shape + [prod M_j, prod N_j]`,
  where the product is over all operators.

  ```python
  # Create a 4 x 4 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [2., 1.]])
  operator = LinearOperatorKronecker([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 0., 2., 0.],
       [2., 1., 4., 2.],
       [3., 0., 4., 0.],
       [6., 3., 8., 4.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [4, 2] Tensor
  operator.matmul(x)
  ==> Shape [4, 2] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 20 x 30 operators.
  operator_large = LinearOperatorKronecker([operator_45, operator_56])

  # Create a shape [2, 3, 30, 2] matrix.
  x = tf.random.normal(shape=[2, 3, 30, 2])
  operator_large.matmul(x)
  ==> Shape [2, 3, 20, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorKronecker` on any operation is equal to
  the sum of the individual operators' operations.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorKronecker`.

    `LinearOperatorKronecker` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape, representing the Kronecker
        factors.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_x_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if a hint is explicitly `False`
        while every factor carries the corresponding hint as true.
    """
    # Captured before `operators` is converted to a list so that the recorded
    # parameters are exactly what the caller passed in.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(f"Argument `operators` must be a list of >=1 operators. "
                       f"Received: {operators}.")
    self._operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            f"Expected every operation in argument `operators` to have the "
            f"same dtype. Received {list(name_type)}.")

    # Auto-set and check hints.
    # A Kronecker product is invertible, if and only if all factors are
    # invertible.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            f"The Kronecker product of non-singular operators is always "
            f"non-singular. Expected argument `is_non_singular` to be True. "
            f"Received: {is_non_singular}.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            f"The Kronecker product of self-adjoint operators is always "
            f"self-adjoint. Expected argument `is_self_adjoint` to be True. "
            f"Received: {is_self_adjoint}.")
      is_self_adjoint = True

    # The eigenvalues of a Kronecker product are equal to the products of eigen
    # values of the corresponding factors.
    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            f"The Kronecker product of positive-definite operators is always "
            f"positive-definite. Expected argument `is_positive_definite` to "
            f"be True. Received: {is_positive_definite}.")
      is_positive_definite = True

    if name is None:
      # Default name: the factor names joined with "_x_".
      name = "_x_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorKronecker, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of Kronecker factors this operator was built from."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + [prod range, prod domain]."""
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension

    range_dimension = self.operators[0].range_dimension
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension

    matrix_shape = tensor_shape.TensorShape([
        range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic counterpart of `_shape`, built from the factors' shape ops."""
    domain_dimension = self.operators[0].domain_dimension_tensor()
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension_tensor()

    range_dimension = self.operators[0].range_dimension_tensor()
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension_tensor()

    matrix_shape = [range_dimension, domain_dimension]

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _solve_matmul_internal(
      self,
      x,
      solve_matmul_fn,
      adjoint=False,
      adjoint_arg=False):
    """Shared implementation of `_matmul` and `_solve`.

    Args:
      x: The right-hand argument of the matmul / solve.
      solve_matmul_fn: Callable `(operator, x, adjoint, adjoint_arg)` applied
        once per Kronecker factor (either a matmul or a solve).
      adjoint: Whether to apply the adjoint of this operator.
      adjoint_arg: Whether `x` should be treated as its hermitian transpose.

    Returns:
      The result of applying `solve_matmul_fn` factor by factor, reshaped to
      `[..., row_dim, col_dim]`.
    """
    # We heavily rely on Roth's column Lemma [1]:
    # (A x B) * vec X = vec BXA^T
    # where vec stacks all the columns of the matrix under each other.
    # In our case, we use a variant of the lemma that is row-major
    # friendly: (A x B) * vec' X = vec' AXB^T
    # Where vec' reshapes a matrix into a vector. We can repeatedly apply this
    # for a collection of kronecker products.
    # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
    # use the above to compute multiplications, solves with any composition of
    # transposes.
    output = x

    # The loop below needs the plain transpose of the effective argument.
    # With `adjoint_arg` the effective argument is x^H, whose transpose is
    # conj(x); otherwise it is x, whose transpose must be taken explicitly.
    if adjoint_arg:
      if self.dtype.is_complex:
        output = math_ops.conj(output)
    else:
      output = linalg.transpose(output)

    for o in reversed(self.operators):
      # Statically compute the reshape.
      if adjoint:
        operator_dimension = o.range_dimension_tensor()
      else:
        operator_dimension = o.domain_dimension_tensor()
      output_shape = _prefer_static_shape(output)

      static_operator_dimension = tensor_util.constant_value(
          operator_dimension)
      if static_operator_dimension is not None:
        operator_dimension = static_operator_dimension
      num_rows = tensor_shape.dimension_value(output.shape[-2])
      num_cols = tensor_shape.dimension_value(output.shape[-1])

      # `dim` must be assigned on every path: use the static value only when
      # the operator dimension and both trailing dims are known, otherwise
      # compute it dynamically.
      if (static_operator_dimension is not None and
          num_rows is not None and num_cols is not None):
        dim = int(num_rows * num_cols // operator_dimension)
      else:
        dim = math_ops.cast(
            output_shape[-2] * output_shape[-1] // operator_dimension,
            dtype=dtypes.int32)

      output_shape = _prefer_static_concat_shape(
          output_shape[:-2], [dim, operator_dimension])
      output = array_ops.reshape(output, shape=output_shape)

      # Conjugate because we are trying to compute A @ B^T, but
      # `LinearOperator` only supports `adjoint_arg`.
      if self.dtype.is_complex:
        output = math_ops.conj(output)

      output = solve_matmul_fn(
          o, output, adjoint=adjoint, adjoint_arg=True)

    if adjoint_arg:
      col_dim = _prefer_static_shape(x)[-2]
    else:
      col_dim = _prefer_static_shape(x)[-1]

    if adjoint:
      row_dim = self.domain_dimension_tensor()
    else:
      row_dim = self.range_dimension_tensor()

    matrix_shape = [row_dim, col_dim]

    output = array_ops.reshape(
        output,
        _prefer_static_concat_shape(
            _prefer_static_shape(output)[:-2], matrix_shape))

    # When the input shape is fully known, pin the static shape of the result.
    if x.shape.is_fully_defined():
      if adjoint_arg:
        column_dim = x.shape[-2]
      else:
        column_dim = x.shape[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          x.shape[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiply by `x` by delegating to each factor's `matmul`."""
    def matmul_fn(o, x, adjoint, adjoint_arg):
      return o.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=x,
        solve_matmul_fn=matmul_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against `rhs` by delegating to each factor's `solve`."""
    def solve_fn(o, rhs, adjoint, adjoint_arg):
      return o.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=rhs,
        solve_matmul_fn=solve_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _determinant(self):
    """Determinant as the product of factor determinants raised to powers."""
    # Note that we have |X1 x X2| = |X1| ** n * |X2| ** m, where X1 is an m x m
    # matrix, and X2 is an n x n matrix. We can iteratively apply this property
    # to get the determinant of |X1 x X2 x X3 ...|. If T is the product of the
    # domain dimension of all operators, then we have:
    # |X1 x X2 x X3 ...| =
    #   |X1| ** (T / m) * |X2 x X3 ... | ** m =
    #   |X1| ** (T / m) * |X2| ** (m * (T / m) / n) * ... =
    #   |X1| ** (T / m) * |X2| ** (T / n) * | X3 x X4... | ** (m * n)
    #   And by doing induction we have product(|X_i| ** (T / dim(X_i))).
    total = self.domain_dimension_tensor()
    determinant = 1.
    for operator in self.operators:
      determinant = determinant * operator.determinant() ** math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return determinant

  def _log_abs_determinant(self):
    """Log |det| as a weighted sum of the factors' log |det| values."""
    # This will be sum((total / dim(x_i)) * log |X_i|)
    total = self.domain_dimension_tensor()
    log_abs_det = 0.
    for operator in self.operators:
      log_abs_det += operator.log_abs_determinant() * math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return log_abs_det

  def _trace(self):
    """Trace as the product of the factors' traces."""
    # tr(A x B) = tr(A) * tr(B)
    trace = 1.
    for operator in self.operators:
      trace = trace * operator.trace()
    return trace

  def _diag_part(self):
    """Diagonal built from the outer products of the factors' diagonals."""
    diag_part = self.operators[0].diag_part()
    for operator in self.operators[1:]:
      # Outer product over the last axis, then flatten the two new axes.
      diag_part = diag_part[..., :, array_ops.newaxis]
      op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
      diag_part = diag_part * op_diag_part
      diag_part = array_ops.reshape(
          diag_part,
          shape=array_ops.concat(
              [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
    # The static length of the diagonal is the smaller matrix dimension.
    if self.range_dimension > self.domain_dimension:
      diag_dimension = self.domain_dimension
    else:
      diag_dimension = self.range_dimension
    diag_part.set_shape(
        self.batch_shape.concatenate(diag_dimension))
    return diag_part

  def _to_dense(self):
    """Dense Kronecker product of the factors' dense matrices."""
    product = self.operators[0].to_dense()
    for operator in self.operators[1:]:
      # Product has shape [B, R1, 1, C1, 1].
      product = product[
          ..., :, array_ops.newaxis, :, array_ops.newaxis]
      # Operator has shape [B, 1, R2, 1, C2].
      op_to_mul = operator.to_dense()[
          ..., array_ops.newaxis, :, array_ops.newaxis, :]
      # This is now [B, R1, R2, C1, C2].
      product = product * op_to_mul
      # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
      product_shape = _prefer_static_shape(product)
      shape = _prefer_static_concat_shape(
          product_shape[:-4],
          [product_shape[-4] * product_shape[-3],
           product_shape[-2] * product_shape[-1]])

      product = array_ops.reshape(product, shape=shape)
    product.set_shape(self.shape)
    return product

  def _eigvals(self):
    """Eigenvalues as the Kronecker product of the factors' eigenvalues."""
    # This will be the kronecker product of all the eigenvalues.
    # Note: It doesn't matter which kronecker product it is, since every
    # kronecker product of the same matrices are similar.
    eigvals = [operator.eigvals() for operator in self.operators]
    # Now compute the kronecker product
    product = eigvals[0]
    for eigval in eigvals[1:]:
      # Product has shape [B, R1, 1].
      product = product[..., array_ops.newaxis]
      # Eigval has shape [B, 1, R2]. Produces shape [B, R1, R2].
      product = product * eigval[..., array_ops.newaxis, :]
      # Reshape to [B, R1 * R2]
      product = array_ops.reshape(
          product,
          shape=array_ops.concat([array_ops.shape(product)[:-2], [-1]], axis=0))
    product.set_shape(self.shape[:-1])
    return product

  def _assert_non_singular(self):
    """Group the factors' non-singularity asserts; requires square factors."""
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "invertible. Expected hint `is_square` to be True for every operator "
          "in argument `operators`.")

  def _assert_self_adjoint(self):
    """Group the factors' self-adjoint asserts; requires square factors."""
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_self_adjoint() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "self-adjoint. Expected hint `is_square` to be True for every "
          "operator in argument `operators`.")

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"operators": [0] * len(self.operators)}