# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.util import nest


################################################################################
# To make more friendly for TF2.
################################################################################


def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  except if the input has reference semantics. Reference semantics are
  characterized by `is_ref`: any object which is a `tf.Variable` or an
  instance of `tf.Module`. This function accepts any input which
  `tf.convert_to_tensor` would also.

  Note: This function diverges from default Numpy behavior for `float` and
  `string` types when `None` is present in a Python list or scalar. Rather
  than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference. If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.

  #### Examples:

  ```python
  x = tf.Variable(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> False
  tf.equal(y, 13.37)
  # ==> True
  ```

  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = base_dtype(dtype)
    value_dtype_base = base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError(
          f"Argument `value` must be of dtype `{dtype_name(dtype_base)}`. "
          f"Received: `{dtype_name(value_dtype_base)}`.")
    return value
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name
  )


def base_dtype(dtype):
  """Returns a non-reference `dtype` based on this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "base_dtype"):
    return dtype.base_dtype
  return dtype


def dtype_name(dtype):
  """Returns the string name for this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "name"):
    return dtype.name
  if hasattr(dtype, "__name__"):
    return dtype.__name__
  return str(dtype)


def check_dtype(arg, dtype):
  """Check that `arg.dtype == dtype`."""
  if arg.dtype.base_dtype != dtype:
    raise TypeError(
        f"Expected argument to have dtype {dtype}. Found: {arg.dtype} in "
        f"tensor {arg}.")


def is_ref(x):
  """Evaluates if the object has reference semantics.

  An object is deemed "reference" if it is a `tf.Variable` instance or is
  derived from a `tf.Module` with `dtype` and `shape` properties.

  Args:
    x: Any object.

  Returns:
    is_ref: Python `bool` indicating whether the input has reference semantics,
      i.e., is a `tf.Variable` or a `tf.Module` with `dtype` and `shape`
      properties.
  """
  return (
      # Note: we check that tf.Variable is a class because we might be using a
      # different backend other than TF.
      isinstance(x, variables_module.Variable) or
      (isinstance(x, module.Module) and hasattr(x, "dtype") and
       hasattr(x, "shape")))
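
# A minimal sketch (illustration only, not part of the library) of how
# `is_ref` and `convert_nonref_to_tensor` interact; `v` and `c` are
# hypothetical names:
#
#   v = tf.Variable(1.0)               # reference type
#   c = tf.constant(1.0)               # plain Tensor
#   is_ref(v), is_ref(c)               # ==> True, False
#   convert_nonref_to_tensor(v) is v   # ==> True (refs are returned as-is)
#   convert_nonref_to_tensor([1., 2.]) # ==> a new tf.Tensor([1., 2.])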


def assert_not_ref_type(x, arg_name):
  if is_ref(x):
    raise TypeError(
        f"Argument {arg_name} cannot be reference type. Found: {type(x)}.")


################################################################################
# Asserts.
################################################################################


def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with modulus zero.
  """
  with ops.name_scope(name, values=[x]):
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype
    should_be_nonzero = math_ops.abs(x)
    zero = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        0, dtype=dtype.real_dtype
    )
    return check_ops.assert_less(zero, should_be_nonzero, message=message)


def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no non-zero imaginary parts.
  """
  with ops.name_scope(name, values=[x]):
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype

    if dtype.is_floating:
      return control_flow_ops.no_op()

    zero = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        0, dtype=dtype.real_dtype
    )
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
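
# A rough usage sketch (assumed, not library code) of the two assert helpers
# above; `diag` is a hypothetical complex diagonal:
#
#   diag = tf.constant([1 + 0j, 2 + 0j])
#   with tf.control_dependencies([
#       assert_no_entries_with_modulus_zero(diag, message="diag has zeros"),
#       assert_zero_imag_part(diag, message="diag is not real")]):
#     diag = tf.identity(diag)   # runs only if both assertions pass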


def assert_compatible_matrix_dimensions(operator, x):
  """Assert that an argument to solve/matmul has proper domain dimension.

  If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
  `operator.matmul(x)` is defined only if `N = Q`. This `Op` returns an
  `Assert` that "fires" if this is not the case. Static checks are already
  done by the base class `LinearOperator`.

  Args:
    operator: `LinearOperator`.
    x: `Tensor`.

  Returns:
    `Assert` `Op`.
  """
  # Static checks are done in the base class. Only tensor asserts here.
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      # This error message is made to look similar to the error raised by the
      # static check in the base class.
      message=("Dimensions are not compatible. Expected shape[-2] of argument "
               "to be the same as this operator's domain dimension."))

  return assert_same_dd
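
# Hedged sketch (not part of the library) of when this assert fires, using the
# public `tf.linalg.LinearOperatorFullMatrix` API:
#
#   op = tf.linalg.LinearOperatorFullMatrix(tf.eye(3))   # shape [3, 3]
#   ok = tf.ones([3, 2])    # shape[-2] == 3 == domain dimension: passes
#   bad = tf.ones([4, 2])   # shape[-2] == 4 != 3: the Assert fires
#   assert_compatible_matrix_dimensions(op, ok)    # passes at run time
#   assert_compatible_matrix_dimensions(op, bad)   # raises InvalidArgumentError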


def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher."""
  sh = tensor.shape
  if sh.ndims is not None and sh.ndims < 2:
    raise ValueError(
        f"Expected [batch] matrix to have at least two dimensions. Found: "
        f"{tensor}.")


def shape_tensor(shape, name=None):
  """Convert Tensor using default type, unless empty list or tuple."""
  # Works just like random_ops._ShapeTensor.
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int32
  else:
    dtype = None
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(
      shape, dtype=dtype, name=name
  )


################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################


def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]  # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          mat
      )
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    #   x.shape = [2, j, k]     (batch shape =    [2])
    #   y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices


def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations."""
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        matrix, name="matrix"
    )
    rhs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        rhs, name="rhs", dtype=matrix.dtype
    )

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)
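
# Hedged usage sketch (assumed shapes, not library code): the batch dims of
# `matrix` and `rhs` broadcast against each other before the solve.
#
#   matrix = tf.random.normal([2, 1, 3, 3])   # batch shape [2, 1]
#   rhs = tf.random.normal([1, 5, 3, 4])      # batch shape [1, 5]
#   x = matrix_solve_with_broadcast(matrix, rhs)
#   x.shape  # ==> [2, 5, 3, 4], with matrix @ x ~= rhs under broadcasting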


def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map. For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code. Why? To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = array_ops.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = array_ops.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = array_ops.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = array_ops.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
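
# Hedged illustration (assumed shapes, not library code) of what the reshape
# trick buys: with a.shape = [3, 3] and b.shape = [100, 3, 1], the 100 "extra"
# dims of b are moved to the end and squashed into a single [3, 100] right-hand
# side, so one solve/matmul replaces 100 broadcast copies of `a`.
#
#   a = tf.eye(3)
#   b = tf.random.normal([100, 3, 1])
#   a2, b2, reshape_inv, _ = _reshape_for_efficiency(a, b)
#   b2.shape                      # ==> [3, 100]
#   y = tf.linalg.solve(a2, b2)   # one [3, 3] solve against 100 columns
#   reshape_inv(y).shape          # ==> [100, 3, 1], matching b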


################################################################################
# Helpers for hints.
################################################################################


def is_adjoint_pair(x, y):
  """True iff x and y are adjoints of each other (by id, not entries)."""
  if x is y:  # Note that if x is y then all of their hints are the same!
    if x.is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
      return False
    if x.is_self_adjoint:
      return True
  # Use the fact that if x = LinearOperatorAdjoint(y), then x.H is y.
  return x.H is y or y.H is x


def is_aat_form(operators):
  """Returns True if operators is of the form A @ A.H, possibly recursively."""
  operators = list(operators)
  if not operators:
    raise ValueError("AAT form is undefined for empty operators")

  if len(operators) % 2:
    return False

  # Check for forms like (A1 @ A2) @ (A2.H @ A1.H)
  return all(
      is_adjoint_pair(operators[i], operators[-1 - i])
      for i in range(len(operators) // 2))
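
# Hedged sketch (assumed, not library code): a composition [A, B, B.H, A.H] is
# in "AAT form" because each operator is paired with its adjoint mirrored about
# the middle of the list.
#
#   A = tf.linalg.LinearOperatorFullMatrix(tf.random.normal([3, 3]))
#   B = tf.linalg.LinearOperatorFullMatrix(tf.random.normal([3, 3]))
#   is_aat_form([A, B, B.H, A.H])   # ==> True
#   is_aat_form([A, B, A.H, B.H])   # ==> False (pairs are not mirrored)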


def use_operator_or_provided_hint_unless_contradicting(
    operator, hint_attr_name, provided_hint_value, message):
  """Get combined hint in the case where operator.hint should equal hint.

  Args:
    operator: LinearOperator that a meta-operator was initialized with.
    hint_attr_name: String name for the attribute.
    provided_hint_value: Bool or None. Value passed by user in initialization.
    message: Error message to print if hints contradict.

  Returns:
    True, False, or None.

  Raises:
    ValueError: If hints contradict.
  """
  op_hint = getattr(operator, hint_attr_name)
  # pylint: disable=g-bool-id-comparison
  if op_hint is False and provided_hint_value:
    raise ValueError(message)
  if op_hint and provided_hint_value is False:
    raise ValueError(message)
  if op_hint or provided_hint_value:
    return True
  if op_hint is False or provided_hint_value is False:
    return False
  # pylint: enable=g-bool-id-comparison
  return None
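
# Hedged sketch (assumed values, not library code) of how hints combine: the
# operator's own hint and the user-provided hint must agree, and a known value
# (True/False) wins over an unknown one (None).
#
#   op = tf.linalg.LinearOperatorIdentity(3, is_positive_definite=True)
#   use_operator_or_provided_hint_unless_contradicting(
#       op, "is_positive_definite", None, "hints contradict")   # ==> True
#   use_operator_or_provided_hint_unless_contradicting(
#       op, "is_positive_definite", False, "hints contradict")  # ValueError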


################################################################################
# Utilities for blockwise operators.
################################################################################


def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
  """Detect if input should be interpreted as a list of blocks."""
  # Tuples and lists of length equal to the number of operators may be
  # blockwise.
  if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
    # If the elements of the iterable are not nested, interpret the input as
    # blockwise.
    if not any(nest.is_nested(x) for x in arg):
      return True
    else:
      arg_dims = [
          tensor_conversion.convert_to_tensor_v2_with_dispatch(x).shape[
              arg_split_dim
          ]
          for x in arg
      ]
      self_dims = [dim.value for dim in block_dimensions]

      # If none of the operator dimensions are known, interpret the input as
      # blockwise if its matching dimensions are unequal.
      if all(self_d is None for self_d in self_dims):

        # A nested tuple/list with a single outermost element is not blockwise
        if len(arg_dims) == 1:
          return False
        elif any(dim != arg_dims[0] for dim in arg_dims):
          return True
        else:
          raise ValueError(
              "Parsing of the input structure is ambiguous. Please input "
              "a blockwise iterable of `Tensor`s or a single `Tensor`.")

      # If input dimensions equal the respective (known) blockwise operator
      # dimensions, then the input is blockwise.
      if all(self_d == arg_d or self_d is None
             for self_d, arg_d in zip(self_dims, arg_dims)):
        return True

      # If the input dimensions are all equal and at least as large as the sum
      # of the known operator dimensions, interpret the input as a single
      # (non-blockwise) `Tensor`.
      self_dim = sum(self_d for self_d in self_dims if self_d is not None)
      if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
        return False

      # If none of these conditions is met, the input shape is mismatched.
      raise ValueError("Input dimension does not match operator dimension.")
  else:
    return False
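
# Hedged sketch (assumed shapes, not library code) for a blockwise operator
# with two blocks of domain size 2 and 3, splitting arguments on axis -2:
#
#   block_dims = [tf.TensorShape([2]).dims[0], tf.TensorShape([3]).dims[0]]
#   arg_is_blockwise(block_dims, [tf.ones([2, 1]), tf.ones([3, 1])], -2)
#   # ==> True: one Tensor per block
#   arg_is_blockwise(block_dims, tf.ones([5, 1]), -2)
#   # ==> False: a single dense Tensor covering all blocks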


def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
  """Split `arg` into blocks matching `operators`'s `domain_dimension`.

  Specifically, if we have a blockwise lower-triangular matrix, with block
  sizes along the diagonal `[M_j, M_j], j = 0, 1, ..., J`, this method splits
  `arg` on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

  Args:
    block_dims: Iterable of `TensorShapes`.
    block_dims_fn: Callable returning an iterable of `Tensor`s.
    arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `Integer` representing the axis to split `arg` on.

  Returns:
    A list of `Tensor`s.
  """
  block_sizes = [dim.value for dim in block_dims]
  if any(d is None for d in block_sizes):
    block_sizes = block_dims_fn()
  return array_ops.split(arg, block_sizes, axis=axis)
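
# Hedged sketch (assumed sizes, not library code): splitting a right-hand side
# with 5 rows into per-block pieces of 2 and 3 rows along axis -2.
#
#   block_dims = tf.TensorShape([2, 3]).dims           # Dimensions 2 and 3
#   blocks = split_arg_into_blocks(
#       block_dims, lambda: [2, 3], tf.ones([5, 1]), axis=-2)
#   [b.shape for b in blocks]   # ==> [[2, 1], [3, 1]]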