Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_addition.py: 33%
151 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Add one or more `LinearOperators` efficiently."""
17import abc
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import check_ops
23from tensorflow.python.ops.linalg import linear_operator
24from tensorflow.python.ops.linalg import linear_operator_diag
25from tensorflow.python.ops.linalg import linear_operator_full_matrix
26from tensorflow.python.ops.linalg import linear_operator_identity
27from tensorflow.python.ops.linalg import linear_operator_lower_triangular
# Empty `__all__`: nothing in this module is exported via `from ... import *`.
__all__ = []
def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2, ...]`, returns a (possibly shorter) list of
  operators `[B1, B2, ...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  Each `Bk` is formed by combining some of the `Ak`, as permitted by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag. Since both
  # A1 and A2 are Diag, they can use this Adder. The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators: Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name: String name for returned `LinearOperator`. Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects. All additions in tier `i` are attempted
      before moving on to tier `i + 1`.
    name: A name for this `Op`. Defaults to `add_operators`.

  Returns:
    Subclass of `LinearOperator`. Class and order of addition may change as new
    (and better) addition strategies emerge.

  Raises:
    ValueError: If `operators` argument is empty.
    ValueError: If shapes are incompatible.
  """
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Validate inputs before doing any work.
  check_ops.assert_proper_iterable(operators)
  # Reverse so that pop() below consumes operators in their given order,
  # which yields nicer default names for the results.
  operators = list(reversed(operators))
  if not operators:
    raise ValueError(
        f"Argument `operators` must contain at least one operator. "
        f"Received: {operators}.")
  if any(not isinstance(op, linear_operator.LinearOperator)
         for op in operators):
    raise TypeError(
        f"Argument `operators` must contain only LinearOperator instances. "
        f"Received: {operators}.")
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  with ops.name_scope(name or "add_operators"):

    # Greedily combine operators one tier at a time.  A sum produced at
    # tier k is reconsidered at tier k; whatever cannot be combined there
    # is handed to tier k + 1.
    remaining = list(operators)
    for tier in addition_tiers:
      pending, remaining = remaining, []
      while pending:
        op1 = pending.pop()
        op2, adder = _pop_a_match_at_tier(op1, pending, tier)
        if op2 is None:
          # No partner found at this tier; let the next tier try.
          remaining.append(op1)
        else:
          # Re-queue the sum so it may be combined again at this same tier.
          pending.append(adder.add(op1, op2, operator_name))

    return remaining
131def _pop_a_match_at_tier(op1, operator_list, tier):
132 # Search from the back of list to the front in order to create nice default
133 # order of operations.
134 for i in range(1, len(operator_list) + 1):
135 op2 = operator_list[-i]
136 for adder in tier:
137 if adder.can_add(op1, op2):
138 return operator_list.pop(-i), adder
139 return None, None
def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1: LinearOperator
    op2: LinearOperator
    hints: _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint.

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.
  # BUG FIX: the guard must check whether the *non-singular* hint (the one
  # being inferred here) was left unset -- not the positive-definite hint.
  # The old condition `hints.is_positive_definite is None` dropped the
  # "positive definite => non-singular" inference whenever callers passed an
  # explicit is_positive_definite=True override.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)
183def _static_check_for_same_dimensions(operators):
184 """ValueError if operators determined to have different dimensions."""
185 if len(operators) < 2:
186 return
188 domain_dimensions = [
189 (op.name, tensor_shape.dimension_value(op.domain_dimension))
190 for op in operators
191 if tensor_shape.dimension_value(op.domain_dimension) is not None]
192 if len(set(value for name, value in domain_dimensions)) > 1:
193 raise ValueError(f"All `operators` must have the same `domain_dimension`. "
194 f"Received: {domain_dimensions}.")
196 range_dimensions = [
197 (op.name, tensor_shape.dimension_value(op.range_dimension))
198 for op in operators
199 if tensor_shape.dimension_value(op.range_dimension) is not None]
200 if len(set(value for name, value in range_dimensions)) > 1:
201 raise ValueError(f"All operators must have the same `range_dimension`. "
202 f"Received: {range_dimensions}.")
205def _static_check_for_broadcastable_batch_shape(operators):
206 """ValueError if operators determined to have non-broadcastable shapes."""
207 if len(operators) < 2:
208 return
210 # This will fail if they cannot be broadcast together.
211 batch_shape = operators[0].batch_shape
212 for op in operators[1:]:
213 batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
216class _Hints:
217 """Holds 'is_X' flags that every LinearOperator is initialized with."""
219 def __init__(self,
220 is_non_singular=None,
221 is_positive_definite=None,
222 is_self_adjoint=None):
223 self.is_non_singular = is_non_singular
224 self.is_positive_definite = is_positive_definite
225 self.is_self_adjoint = is_self_adjoint
228################################################################################
229# Classes to add two linear operators.
230################################################################################
233class _Adder(metaclass=abc.ABCMeta):
234 """Abstract base class to add two operators.
236 Each `Adder` acts independently, adding everything it can, paying no attention
237 as to whether another `Adder` could have done the addition more efficiently.
238 """
240 @property
241 def name(self):
242 return self.__class__.__name__
244 @abc.abstractmethod
245 def can_add(self, op1, op2):
246 """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
247 pass
249 @abc.abstractmethod
250 def _add(self, op1, op2, operator_name, hints):
251 # Derived classes can assume op1 and op2 have been validated, e.g. they have
252 # the same dtype, and their domain/range dimensions match.
253 pass
255 def add(self, op1, op2, operator_name, hints=None):
256 """Return new `LinearOperator` acting like `op1 + op2`.
258 Args:
259 op1: `LinearOperator`
260 op2: `LinearOperator`, with `shape` and `dtype` such that adding to
261 `op1` is allowed.
262 operator_name: `String` name to give to returned `LinearOperator`
263 hints: `_Hints` object. Returned `LinearOperator` will be created with
264 these hints.
266 Returns:
267 `LinearOperator`
268 """
269 updated_hints = _infer_hints_allowing_override(op1, op2, hints)
271 if operator_name is None:
272 operator_name = "Add/" + op1.name + "__" + op2.name + "/"
274 scope_name = self.name
275 if scope_name.startswith("_"):
276 scope_name = scope_name[1:]
277 with ops.name_scope(scope_name):
278 return self._add(op1, op2, operator_name, updated_hints)
class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions within the Identity family of operators.

  The Identity family (`LinearOperatorScaledIdentity`,
  `LinearOperatorIdentity`) is closed under addition; this `Adder` respects
  that and always produces a `LinearOperatorScaledIdentity`.
  """

  def can_add(self, op1, op2):
    # Both operand types must lie inside the identity family.
    return {_type(op1), _type(op2)} <= _IDENTITY_FAMILY

  def _add(self, op1, op2, operator_name, hints):
    # The result's multiplier is the sum of the operands' multipliers; a
    # plain (un-scaled) identity contributes ones broadcast over its batch
    # shape.
    def _multiplier_of(op):
      if _type(op) == _SCALED_IDENTITY:
        return op.multiplier
      return array_ops.ones(op.batch_shape_tensor(), dtype=op.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=_multiplier_of(op1) + _multiplier_of(op2),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    # Both operands must be diagonal-like (diag, identity, scaled identity).
    return {_type(op1), _type(op2)} <= _DIAG_LIKE

  def _add(self, op1, op2, operator_name, hints):
    # Adding two diagonal-like operators reduces to summing their diagonals.
    summed_diag = op1.diag_part() + op2.diag_part()
    return linear_operator_diag.LinearOperatorDiag(
        diag=summed_diag,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    # Operands may be diagonal-like or lower-triangular.
    return {_type(op1), _type(op2)} <= (_DIAG_LIKE | {_TRIL})

  def _add(self, op1, op2, operator_name, hints):
    # Pick the operand with an efficient add_to_tensor() (a diagonal-like
    # one if present) to add itself onto the dense form of the other.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      efficient_op, other_op = op1, op2
    else:
      efficient_op, other_op = op2, op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=efficient_op.add_to_tensor(other_op.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    # Any pair of LinearOperators can be added densely; this is the
    # catch-all adder.
    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
        op2, linear_operator.LinearOperator)

  def _add(self, op1, op2, operator_name, hints):
    # Pick the operand with an efficient add_to_tensor() (a diagonal-like
    # one if present) to add itself onto the dense form of the other.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      efficient_op, other_op = op1, op2
    else:
      efficient_op, other_op = op2, op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=efficient_op.add_to_tensor(other_op.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
################################################################################
# Constants designating types of LinearOperators
################################################################################

# Type name constants for LinearOperator classes.  `_type()` below maps an
# operator instance to exactly one of these strings.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
# Operators whose dense form is diagonal.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
# The identity family is closed under addition (see _AddAndReturnScaledIdentity).
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE

# Supported LinearOperator classes.  These are exactly the classes `_type()`
# can classify.
SUPPORTED_OPERATORS = [
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_full_matrix.LinearOperatorFullMatrix,
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity
]
def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  # Ordered (class, type constant) dispatch table; first isinstance match
  # wins.
  type_dispatch = (
      (linear_operator_diag.LinearOperatorDiag, _DIAG),
      (linear_operator_lower_triangular.LinearOperatorLowerTriangular, _TRIL),
      (linear_operator_full_matrix.LinearOperatorFullMatrix, _MATRIX),
      (linear_operator_identity.LinearOperatorIdentity, _IDENTITY),
      (linear_operator_identity.LinearOperatorScaledIdentity,
       _SCALED_IDENTITY),
  )
  for operator_class, type_name in type_dispatch:
    if isinstance(operator, operator_class):
      return type_name
  raise TypeError(f"Expected operator to be one of [LinearOperatorDiag, "
                  f"LinearOperatorLowerTriangular, LinearOperatorFullMatrix, "
                  f"LinearOperatorIdentity, LinearOperatorScaledIdentity]. "
                  f"Received: {operator}")
################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before K+1.
#
# Organize tiers to
#   (i) reduce O(..) complexity of forming final operator, and
#   (ii) produce the "most efficient" final operator.
# Dev notes:
# * Results of addition at tier K will be added at tier K or higher.
# * Tiers may change, and we warn the user that it may change.
################################################################################

# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix. So it is sometimes very inefficient.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],  # identity-family pairs only
    [_AddAndReturnDiag()],            # diagonal-like pairs
    [_AddAndReturnTriL()],            # diagonal-like and lower-triangular pairs
    [_AddAndReturnMatrix()],          # any pair; densifies both operands
]