Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_identity.py: 29%
256 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like the identity matrix."""
17import numpy as np
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_conversion
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import array_ops_stack
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.linalg import linalg_impl as linalg
30from tensorflow.python.ops.linalg import linear_operator
31from tensorflow.python.ops.linalg import linear_operator_util
32from tensorflow.python.util.tf_export import tf_export
34__all__ = [
35 "LinearOperatorIdentity",
36 "LinearOperatorScaledIdentity",
37]
class BaseLinearOperatorIdentity(linear_operator.LinearOperator):
  """Common helpers shared by the identity-like operators in this module.

  The helpers read `self._num_rows`, `self._num_rows_static` and
  `self._assert_proper_shapes`, which the subclass constructors assign.
  """

  def _check_num_rows_possibly_add_asserts(self):
    """Validate init arg `num_rows` statically; optionally add graph asserts.

    Raises:
      TypeError: If `num_rows` does not have an integer dtype.
      ValueError: If `num_rows` is statically known to be non-scalar or
        negative.
    """
    # Runtime checks are only attached on request.
    if self._assert_proper_shapes:
      rank_check = check_ops.assert_rank(
          self._num_rows,
          0,
          message="Argument num_rows must be a 0-D Tensor.")
      sign_check = check_ops.assert_non_negative(
          self._num_rows,
          message="Argument num_rows must be non-negative.")
      self._num_rows = control_flow_ops.with_dependencies(
          [rank_check, sign_check], self._num_rows)

    # From here on everything is checked at construction time.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type. Found:"
                      " %s" % self._num_rows)

    static_rows = self._num_rows_static
    if static_rows is None:
      # Value unknown until run time; nothing more can be verified here.
      return

    if static_rows.ndim != 0:
      raise ValueError("Argument num_rows must be a 0-D Tensor. Found:"
                       " %s" % static_rows)

    if static_rows < 0:
      raise ValueError("Argument num_rows must be non-negative. Found:"
                       " %s" % static_rows)

  def _min_matrix_dim(self):
    """Smaller of the domain/range dimensions, or `None` if not static."""
    static_dims = (
        tensor_shape.dimension_value(self.domain_dimension),
        tensor_shape.dimension_value(self.range_dimension),
    )
    if any(dim is None for dim in static_dims):
      return None
    return min(static_dims)

  def _min_matrix_dim_tensor(self):
    """Smaller of the domain/range dimensions, computed as a `Tensor`."""
    trailing_dims = self.shape_tensor()[-2:]
    return math_ops.reduce_min(trailing_dims)

  def _ones_diag(self):
    """Build an all-ones `Tensor` shaped like this operator's diagonal."""
    if not self.shape.is_fully_defined():
      # Shape only known at run time: assemble it from tensors.
      diag_shape = array_ops.concat(
          [self.batch_shape_tensor(),
           [self._min_matrix_dim_tensor()]], axis=0)
    else:
      diag_shape = self.batch_shape.concatenate([self._min_matrix_dim()])

    return array_ops.ones(shape=diag_shape, dtype=self.dtype)
@tf_export("linalg.LinearOperatorIdentity")
@linear_operator.make_composite_tensor
class LinearOperatorIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a [batch] square identity matrix.

  This operator acts like a [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorIdentity` is initialized with `num_rows`, and optionally
  `batch_shape`, and `dtype` arguments.  If `batch_shape` is `None`, this
  operator efficiently passes through all arguments.  If `batch_shape` is
  provided, broadcasting may occur, which will require making copies.

  ```python
  # Create a 2 x 2 identity matrix.
  operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32)

  operator.to_dense()
  ==> [[1., 0.]
       [0., 1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor, same as x.

  y = tf.random.normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  # This broadcast does NOT require copying data, since we can infer that y
  # will be passed through without changing shape.  We are always able to infer
  # this if the operator has no batch_shape.
  x = operator.solve(y)
  ==> Shape [3, 2, 4] Tensor, same as y.

  # Create a 2-batch of 2x2 identity matrices
  operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2])
  operator.to_dense()
  ==> [[[1., 0.]
        [0., 1.]],
       [[1., 0.]
        [0., 1.]]]

  # Here, even though the operator has a batch shape, the input is the same as
  # the output, so x can be passed through without a copy.  The operator is able
  # to detect that no broadcast is necessary because both x and the operator
  # have statically defined shape.
  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, same as x

  # Here the operator and x have different batch_shape, and are broadcast.
  # This requires a copy, since the output is different size than the input.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, equal to [x, x]
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  If `batch_shape` initialization arg is `None`:

  * `operator.matmul(x)` is `O(1)`
  * `operator.solve(x)` is `O(1)`
  * `operator.determinant()` is `O(1)`

  If `batch_shape` initialization arg is provided, and static checks cannot
  rule out the need to broadcast:

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(B1*...*Bb)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None`, callers should have no expectation either way.

  For this operator every hint defaults to `True`, and the constructor raises
  `ValueError` if any of them is passed a falsy value (including `None`).
  """

  def __init__(self,
               num_rows,
               batch_shape=None,
               dtype=None,
               is_non_singular=True,
               is_self_adjoint=True,
               is_positive_definite=True,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorIdentity"):
    r"""Initialize a `LinearOperatorIdentity`.

    The `LinearOperatorIdentity` is initialized with arguments defining `dtype`
    and shape.

    This operator is able to broadcast the leading (batch) dimensions, which
    sometimes requires copying data.  If `batch_shape` is `None`, the operator
    can take arguments of any batch shape without copying.  See examples.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
        dimensions.  If `None`, this operator has no leading dimensions.
      dtype:  Data type of the matrix that this operator represents.  Falls
        back to `float32` when not given.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `batch_shape` is determined statically to not be 1-D, or
        negative.
      ValueError:  If any of the following is not `True`:
        `{is_self_adjoint, is_non_singular, is_positive_definite, is_square}`.
      TypeError:  If `num_rows` or `batch_shape` is ref-type (e.g. Variable),
        or does not have an integer dtype.
    """
    # Captured before any argument is normalised, so the operator can be
    # rebuilt from exactly what the caller passed (see `__getitem__`).
    parameters = dict(
        num_rows=num_rows,
        batch_shape=batch_shape,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name)

    dtype = dtype or dtypes.float32
    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name):
      dtype = dtypes.as_dtype(dtype)
      # An identity matrix has all four properties by construction, so a
      # hint claiming otherwise is a caller error.
      if not is_self_adjoint:
        raise ValueError("An identity operator is always self adjoint.")
      if not is_non_singular:
        raise ValueError("An identity operator is always non-singular.")
      if not is_positive_definite:
        raise ValueError("An identity operator is always positive-definite.")
      if not is_square:
        raise ValueError("An identity operator is always square.")

      super(LinearOperatorIdentity, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      # Inherited from BaseLinearOperatorIdentity.
      self._check_num_rows_possibly_add_asserts()

      if batch_shape is None:
        # `_batch_shape_static` is deliberately left unset in this branch;
        # every reader checks `_batch_shape_arg is None` first.
        self._batch_shape_arg = None
      else:
        self._batch_shape_arg = linear_operator_util.shape_tensor(
            batch_shape, name="batch_shape_arg")
        self._batch_shape_static = tensor_util.constant_value(
            self._batch_shape_arg)
        self._check_batch_shape_possibly_add_asserts()

  def _shape(self):
    """Static shape: optional batch dims followed by `[N, N]`."""
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))
    if self._batch_shape_arg is None:
      return matrix_shape

    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape: optional batch dims followed by `[N, N]`."""
    matrix_shape = array_ops_stack.stack(
        (self._num_rows, self._num_rows), axis=0)
    if self._batch_shape_arg is None:
      return matrix_shape

    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)

  def _assert_non_singular(self):
    # Nothing to verify for an identity matrix.
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    # Nothing to verify for an identity matrix.
    return control_flow_ops.no_op("assert_positive_definite")

  def _assert_self_adjoint(self):
    # Nothing to verify for an identity matrix.
    return control_flow_ops.no_op("assert_self_adjoint")

  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape.  Note that
    #   We have already verified the second to last dimension of self.shape
    #   matches x's shape in assert_compatible_matrix_dimensions.
    #   Also, the final dimension of 'x' can have any shape.
    #   Therefore, the final two dimensions of special_shape are 1's.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
    if special_shape.is_fully_defined():
      # bshape.is_fully_defined iff special_shape.is_fully_defined.
      if bshape == x.shape:
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    #   Always add to an array of zeros, rather than using a "cond", since a
    #   cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Return `x` (adjointed if `adjoint_arg`), broadcast to the batch shape."""
    # Note that adjoint has no effect since this matrix is self-adjoint.
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    return self._possibly_broadcast_batch_shape(x)

  def _determinant(self):
    return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _log_abs_determinant(self):
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # Solving against the identity is the same as multiplying by it.
    return self._matmul(rhs, adjoint_arg=adjoint_arg)

  def _trace(self):
    """Trace of each batch member: the matrix dimension, as `self.dtype`."""
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    if self._min_matrix_dim() is not None:
      return self._min_matrix_dim() * batch_of_ones
    else:
      return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) *
              batch_of_ones)

  def _diag_part(self):
    return self._ones_diag()

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      mat = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          mat, name="mat"
      )
      # Only the diagonal changes, so bump it by one and write it back
      # instead of materializing a dense identity.
      mat_diag = array_ops.matrix_diag_part(mat)
      new_diag = 1 + mat_diag
      return array_ops.matrix_set_diag(mat, new_diag)

  def _eigvals(self):
    return self._ones_diag()

  def _cond(self):
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  # `_check_num_rows_possibly_add_asserts` is inherited from
  # `BaseLinearOperatorIdentity`; no override is needed here.

  def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts.

    Raises:
      TypeError: If `batch_shape` does not have an integer dtype.
      ValueError: If `batch_shape` is statically known to not be 1-D, or to
        contain a negative entry.
    """
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._batch_shape_arg,
              1,
              message="Argument batch_shape must be a 1-D Tensor."),
          check_ops.assert_non_negative(
              self._batch_shape_arg,
              message="Argument batch_shape must be non-negative."),
      ], self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type. Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      raise ValueError("Argument batch_shape must be non-negative. Found:"
                       " %s" % self._batch_shape_static)

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows", "batch_shape")

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "batch_shape", "dtype", "assert_proper_shapes")

  def __getitem__(self, slices):
    """Slice the batch dimensions, returning a new `LinearOperatorIdentity`."""
    # Slice the batch shape and return a new LinearOperatorIdentity.
    # Use a proxy shape and slice it. Use this as the new batch shape
    # NOTE(review): `_batch_shape_arg` is None when the operator was built
    # without a batch shape, and is handed to `array_ops.ones` unchanged;
    # confirm whether slicing such an operator is expected to be supported.
    new_batch_shape = array_ops.shape(
        array_ops.ones(self._batch_shape_arg)[slices])
    parameters = dict(self.parameters, batch_shape=new_batch_shape)
    return LinearOperatorIdentity(**parameters)
@tf_export("linalg.LinearOperatorScaledIdentity")
@linear_operator.make_composite_tensor
class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`.

  This operator acts like a scaled [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  a scaled version of the `N x N` identity matrix.

  `LinearOperatorScaledIdentity` is initialized with `num_rows`, and a
  `multiplier` (a `Tensor`) of shape `[B1,...,Bb]`.  `N` is set to `num_rows`,
  and the `multiplier` determines the scale for each batch member.

  ```python
  # Create a 2 x 2 scaled identity matrix.
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)

  operator.to_dense()
  ==> [[3., 0.]
       [0., 3.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 2 * Log[3]

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> 3 * x

  y = tf.random.normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  x = operator.solve(y)
  ==> y / 3

  # Create a 2-batch of 2x2 scaled identity matrices
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=[5., 5.])
  operator.to_dense()
  ==> [[[5., 0.]
        [0., 5.]],
       [[5., 0.]
        [0., 5.]]]

  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> 5 * x

  # Here the operator and x have different batch_shape, and are broadcast.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> 5 * x
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(D1*...*Dd)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               multiplier,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorScaledIdentity"):
    r"""Initialize a `LinearOperatorScaledIdentity`.

    The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
    determines the size of each identity matrix, and a `multiplier`,
    which defines `dtype`, batch shape, and scale of each matrix.

    This operator is able to broadcast the leading (batch) dimensions.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      multiplier:  `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Set to `True` automatically when `multiplier` is not
        complex.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `multiplier` is not complex and `is_self_adjoint` is
        `False`.
      ValueError:  If `is_square` is not `True`.
    """
    # Captured before any argument is normalised, so the operator can be
    # rebuilt from exactly what the caller passed.
    parameters = dict(
        num_rows=num_rows,
        multiplier=multiplier,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name)

    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name, values=[multiplier, num_rows]):
      self._multiplier = linear_operator_util.convert_nonref_to_tensor(
          multiplier, name="multiplier")

      # Check and auto-set hints.
      if not self._multiplier.dtype.is_complex:
        if is_self_adjoint is False:  # pylint: disable=g-bool-id-comparison
          raise ValueError("A real diagonal operator is always self adjoint.")
        else:
          is_self_adjoint = True

      if not is_square:
        raise ValueError("A ScaledIdentity operator is always square.")

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")

      super(LinearOperatorScaledIdentity, self).__init__(
          dtype=self._multiplier.dtype.base_dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      self._check_num_rows_possibly_add_asserts()
      # Pre-cast copies of `num_rows`, used by `_determinant` and
      # `_log_abs_determinant` respectively.
      self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype)
      self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows,
                                                        self.dtype.real_dtype)

  def _shape(self):
    """Static shape: the multiplier's shape followed by `[N, N]`."""
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))

    batch_shape = self.multiplier.shape
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape: the multiplier's shape followed by `[N, N]`."""
    matrix_shape = array_ops_stack.stack(
        (self._num_rows, self._num_rows), axis=0)

    batch_shape = array_ops.shape(self.multiplier)
    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _assert_non_singular(self):
    # `c I` is singular exactly when `c` is zero.
    return check_ops.assert_positive(
        math_ops.abs(self.multiplier), message="LinearOperator was singular")

  def _assert_positive_definite(self):
    return check_ops.assert_positive(
        math_ops.real(self.multiplier),
        message="LinearOperator was not positive definite.")

  def _assert_self_adjoint(self):
    # Self-adjoint requires the multiplier to have zero imaginary part.
    imag_multiplier = math_ops.imag(self.multiplier)
    return check_ops.assert_equal(
        array_ops.zeros_like(imag_multiplier),
        imag_multiplier,
        message="LinearOperator was not self-adjoint")

  def _make_multiplier_matrix(self, conjugate=False):
    """Return the multiplier with two trailing size-1 dims appended."""
    # Shape [B1,...Bb, 1, 1]
    multiplier_matrix = array_ops.expand_dims(
        array_ops.expand_dims(self.multiplier, -1), -1)
    if conjugate:
      multiplier_matrix = math_ops.conj(multiplier_matrix)
    return multiplier_matrix

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Scale `x` (adjointed if `adjoint_arg`) by the multiplier."""
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    # The adjoint of `c I` is `conj(c) I`.
    return x * self._make_multiplier_matrix(conjugate=adjoint)

  def _determinant(self):
    # det(c I) = c ** N.
    return self.multiplier**self._num_rows_cast_to_dtype

  def _log_abs_determinant(self):
    # log|det(c I)| = N * log|c|.
    return self._num_rows_cast_to_real_dtype * math_ops.log(
        math_ops.abs(self.multiplier))

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Divide `rhs` (adjointed if `adjoint_arg`) by the multiplier."""
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
      rhs = control_flow_ops.with_dependencies([aps], rhs)
    return rhs / self._make_multiplier_matrix(conjugate=adjoint)

  def _trace(self):
    """Trace of each batch member: multiplier times the matrix dimension."""
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    if self._min_matrix_dim() is not None:
      return self.multiplier * self._min_matrix_dim() * batch_of_ones
    else:
      return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(),
                                              self.dtype) * batch_of_ones)

  def _diag_part(self):
    # Trailing axis added so the multiplier broadcasts along the diagonal.
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `c I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Shape [B1,...,Bb, 1]
      multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

      # Shape [C1,...,Cc, M, M]
      mat = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          mat, name="mat"
      )

      # Shape [C1,...,Cc, M]
      mat_diag = array_ops.matrix_diag_part(mat)

      # multiplier_vector broadcasts here.
      new_diag = multiplier_vector + mat_diag

      return array_ops.matrix_set_diag(mat, new_diag)

  def _eigvals(self):
    # Every eigenvalue of `c I` equals `c`.
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def _cond(self):
    # Condition number for a scalar times identity matrix is one, except when
    # the scalar is zero.
    return array_ops.where_v2(
        math_ops.equal(self._multiplier, 0.),
        math_ops.cast(np.nan, dtype=self.dtype),
        math_ops.cast(1., dtype=self.dtype))

  @property
  def multiplier(self):
    """The [batch] scalar `Tensor`, `c` in `cI`."""
    return self._multiplier

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows",)

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "multiplier", "assert_proper_shapes")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"multiplier": 0}