Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_composition.py: 27%
113 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Composes one or more `LinearOperators`."""
17from tensorflow.python.framework import common_shapes
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import array_ops_stack
23from tensorflow.python.ops import check_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops.linalg import linear_operator
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.util.tf_export import tf_export
29__all__ = ["LinearOperatorComposition"]
@tf_export("linalg.LinearOperatorComposition")
@linear_operator.make_composite_tensor
class LinearOperatorComposition(linear_operator.LinearOperator):
  """Composes one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` with action defined by:

  ```
  op_composed(x) := op1(op2(...(opJ(x))...))
  ```

  If `opj` acts like [batch] matrix `Aj`, then `op_composed` acts like the
  [batch] matrix formed with the multiplication `A1 A2...AJ`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then we must have
  `N_j = M_{j+1}`, in which case the composed operator has shape equal to
  `broadcast_batch_shape + [M_1, N_J]`, where `broadcast_batch_shape` is the
  mutual broadcast of `batch_shape_j`, `j = 1,...,J`, assuming the intermediate
  batch shapes broadcast.  Even if the composed shape is well defined, the
  composed operator's methods may fail due to lack of broadcasting ability in
  the defining operators' methods.

  ```python
  # Create a 2 x 2 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorComposition([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 4 x 6 operators.
  operator_46 = LinearOperatorComposition([operator_45, operator_56])

  # Create a shape [2, 3, 6, 2] vector.
  x = tf.random.normal(shape=[2, 3, 6, 2])
  operator_46.matmul(x)
  ==> Shape [2, 3, 4, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorComposition` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorComposition`.

    `LinearOperatorComposition` is initialized with a list of operators
    `[op_1,...,op_J]`.  For the `matmul` method to be well defined, the
    composition `op_i.matmul(op_{i+1}(x))` must be defined.  Other methods have
    similar constraints.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if a hint is explicitly `False`
        while the operators imply that the property must hold.
    """
    # Captured before any hint is auto-set below, so `parameters` records the
    # arguments exactly as the caller passed them.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype.  Found %s"
            % "   ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition of non-singular operators is always non-singular.")
      is_non_singular = True

    if _composition_must_be_self_adjoint(operators):
      if is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition was determined to be self-adjoint but user "
            "provided incorrect `False` hint.")
      is_self_adjoint = True

    # Both the squareness and the positive-definiteness checks depend on the
    # same structural test, so evaluate it once.
    is_aat_form = linear_operator_util.is_aat_form(operators)

    if is_aat_form:
      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition was determined to have the form "
            "A @ A.H, hence it must be square. The user "
            "provided an incorrect `False` hint.")
      is_square = True

    if is_aat_form and is_non_singular:
      if is_positive_definite is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition was determined to be non-singular and have the "
            "form A @ A.H, hence it must be positive-definite. The user "
            "provided an incorrect `False` hint.")
      is_positive_definite = True

    # Initialization.

    if name is None:
      name = "_o_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of operators being composed, in the order given."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + `[M_1, N_J]`.

    Also asserts that each operator's domain dimension is compatible with the
    range dimension of the operator that follows it.
    """
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension.assert_is_compatible_with(operator.range_dimension)
      domain_dimension = operator.domain_dimension

    matrix_shape = tensor_shape.TensorShape(
        [self.operators[0].range_dimension,
         self.operators[-1].domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Shape as a `Tensor`; uses the static shape when it is fully defined."""
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops_stack.stack([
        self.operators[0].range_dimension_tensor(),
        self.operators[-1].domain_dimension_tensor()
    ])

    # Dummy Tensor of zeros.  Will never be materialized.
    # Adding one zeros tensor per operator lets broadcasting work out the
    # combined batch shape, which is then read back with `shape`.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Applies the operators to `x` one after another.

    `adjoint_arg` is forwarded only to the first operator applied, since only
    that one receives the caller's `x`; later operators receive intermediate
    results.
    """
    # If self.operators = [A, B], and not adjoint, then
    # matmul_order_list = [B, A].
    # As a result, we return A.matmul(B.matmul(x))
    if adjoint:
      matmul_order_list = self.operators
    else:
      matmul_order_list = list(reversed(self.operators))

    result = matmul_order_list[0].matmul(
        x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in matmul_order_list[1:]:
      result = operator.matmul(result, adjoint=adjoint)
    return result

  def _determinant(self):
    """Product of the determinants of the individual operators."""
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    """Sum of the log-abs-determinants of the individual operators."""
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves by calling each operator's `solve` in turn.

    `adjoint_arg` is forwarded only to the first solve, since only that one
    receives the caller's `rhs`.
    """
    # TODO(langmore) Implement solve using solve_ls if some intermediate
    # operator maps to a high dimensional space.
    # In that case, an exact solve may still be possible.

    # If self.operators = [A, B], and not adjoint, then
    # solve_order_list = [A, B].
    # As a result, we return B.solve(A.solve(x))
    if adjoint:
      solve_order_list = list(reversed(self.operators))
    else:
      solve_order_list = self.operators

    solution = solve_order_list[0].solve(
        rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in solve_order_list[1:]:
      solution = operator.solve(solution, adjoint=adjoint)
    return solution

  def _assert_non_singular(self):
    """Groups the per-operator asserts when every operator is square.

    Otherwise defers to the base class implementation.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    return super(LinearOperatorComposition, self)._assert_non_singular()

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # One entry (0) per composed operator.
    return {"operators": [0] * len(self.operators)}
322def _composition_must_be_self_adjoint(operators):
323 """Runs some checks to see if composition operators must be SA.
325 Args:
326 operators: List of LinearOperators.
328 Returns:
329 True if the composition must be SA. False if it is not SA OR if we did not
330 determine whether the composition is SA.
331 """
332 if len(operators) == 1 and operators[0].is_self_adjoint:
333 return True
335 # Check for forms like A @ A.H or (A1 @ A2) @ (A2.H @ A1.H) or ...
336 if linear_operator_util.is_aat_form(operators):
337 return True
339 # Done checking...could still be SA.
340 # We may not catch some cases. E.g. (A @ I) @ A.H is SA, but is not AAT form.
341 return False