Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_zeros.py: 28%
167 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a zero matrix."""
17import numpy as np
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import errors
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import array_ops_stack
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.linalg import linalg_impl as linalg
30from tensorflow.python.ops.linalg import linear_operator
31from tensorflow.python.ops.linalg import linear_operator_util
32from tensorflow.python.util.tf_export import tf_export
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "LinearOperatorZeros",
]
@tf_export("linalg.LinearOperatorZeros")
@linear_operator.make_composite_tensor
class LinearOperatorZeros(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] zero matrix.

  This operator acts like a [batch] zero matrix `A` with shape
  `[B1,...,Bb, N, M]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x M` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorZeros` is initialized with `num_rows`, and optionally
  `num_columns`, `batch_shape`, and `dtype` arguments.  If `num_columns` is
  `None`, then this operator will be initialized as a square matrix.  If
  `batch_shape` is `None`, this operator efficiently passes through all
  arguments.  If `batch_shape` is provided, broadcasting may occur, which will
  require making copies.

  ```python
  # Create a 2 x 2 zero matrix.
  operator = LinearOperatorZeros(num_rows=2, dtype=tf.float32)

  operator.to_dense()
  ==> [[0., 0.]
       [0., 0.]]

  operator.shape
  ==> [2, 2]

  operator.determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor, same as tf.zeros_like(x).

  # Create a 2-batch of 2x2 zero matrices
  operator = LinearOperatorZeros(num_rows=2, batch_shape=[2])
  operator.to_dense()
  ==> [[[0., 0.]
        [0., 0.]],
       [[0., 0.]
        [0., 0.]]]

  # Here, even though the operator has a batch shape, the output has the same
  # shape as the input, so no broadcast (and no extra copy) is needed.  The
  # operator is able to detect that no broadcast is necessary because both x
  # and the operator have statically defined shape.
  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, same as tf.zeros_like(x)

  # Here the operator and x have different batch_shape, and are broadcast.
  # This requires a copy, since the output is different size than the input.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor of zeros, same as tf.zeros([2, 2, 3])
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, M],  with b >= 0
  x.shape =   [C1,...,Cc] + [M, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.

  A zero operator can never be non-singular or positive definite, and a square
  zero operator is always self-adjoint; the constructor rejects hints that
  contradict this.
  """

  def __init__(self,
               num_rows,
               num_columns=None,
               batch_shape=None,
               dtype=None,
               is_non_singular=False,
               is_self_adjoint=True,
               is_positive_definite=False,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorZeros"):
    r"""Initialize a `LinearOperatorZeros`.

    The `LinearOperatorZeros` is initialized with arguments defining `dtype`
    and shape.

    This operator is able to broadcast the leading (batch) dimensions, which
    sometimes requires copying data.  If `batch_shape` is `None`, the operator
    can take arguments of any batch shape without copying.  See examples.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding zero matrix.
      num_columns:  Scalar non-negative integer `Tensor`.  Number of columns in
        the corresponding zero matrix.  If `None`, defaults to the value of
        `num_rows`.
      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
        dimensions.  If `None`, this operator has no leading dimensions.
      dtype:  Data type of the matrix that this operator represents.  Defaults
        to `tf.float32` when `None`.
      is_non_singular:  Expect that this operator is non-singular.  Must not be
        `True` for a zero operator.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Must not be `False` when `is_square` is `True`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
        Must not be `True` for a zero operator.
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      TypeError:  If `num_rows`, `num_columns` or `batch_shape` is not of an
        integer dtype.
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `num_columns` is determined statically to be non-scalar,
        or negative.
      ValueError:  If `batch_shape` is determined statically to not be 1-D, or
        negative.
      ValueError:  If `is_non_singular` or `is_positive_definite` is `True`, or
        if `is_self_adjoint` is `False` while `is_square` is `True`.
      ValueError:  If `is_square` is `True` but `num_rows` and `num_columns`
        are statically known to differ.
    """
    parameters = dict(
        num_rows=num_rows,
        num_columns=num_columns,
        batch_shape=batch_shape,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name
    )

    dtype = dtype or dtypes.float32
    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name):
      dtype = dtypes.as_dtype(dtype)
      if not is_self_adjoint and is_square:
        raise ValueError("A zero operator is always self adjoint.")
      if is_non_singular:
        raise ValueError("A zero operator is always singular.")
      if is_positive_definite:
        raise ValueError("A zero operator is always not positive-definite.")

      super(LinearOperatorZeros, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
      linear_operator_util.assert_not_ref_type(num_columns, "num_columns")
      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)

      if num_columns is None:
        num_columns = num_rows

      self._num_columns = linear_operator_util.shape_tensor(
          num_columns, name="num_columns")
      self._num_columns_static = tensor_util.constant_value(self._num_columns)

      self._check_domain_range_possibly_add_asserts()

      if (self._num_rows_static is not None and
          self._num_columns_static is not None):
        if is_square and self._num_rows_static != self._num_columns_static:
          raise ValueError(
              "LinearOperatorZeros initialized as is_square=True, but got "
              "num_rows({}) != num_columns({})".format(
                  self._num_rows_static,
                  self._num_columns_static))

      if batch_shape is None:
        self._batch_shape_arg = None
        # Set for completeness so the attribute always exists.
        self._batch_shape_static = None
      else:
        self._batch_shape_arg = linear_operator_util.shape_tensor(
            batch_shape, name="batch_shape_arg")
        self._batch_shape_static = tensor_util.constant_value(
            self._batch_shape_arg)
      self._check_batch_shape_possibly_add_asserts()

  def _shape(self):
    """Static shape: `batch_shape + [num_rows, num_columns]`."""
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_columns_static))
    if self._batch_shape_arg is None:
      return matrix_shape

    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape: `batch_shape + [num_rows, num_columns]` as a Tensor."""
    matrix_shape = array_ops_stack.stack(
        (self._num_rows, self._num_columns), axis=0)
    if self._batch_shape_arg is None:
      return matrix_shape

    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)

  def _assert_non_singular(self):
    # A zero matrix is never invertible, so this check always fails eagerly.
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Zero operators are always "
        "non-invertible.")

  def _assert_positive_definite(self):
    # A zero matrix is never positive definite, so this check always fails.
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Zero operators are always "
        "non-positive definite.")

  def _assert_self_adjoint(self):
    # Nothing to check at runtime.
    return control_flow_ops.no_op("assert_self_adjoint")

  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through.
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape.  The final two dimensions of special_shape
    # are 1's: the second to last dimension of x has already been checked
    # against this operator, and the final dimension of 'x' can have any size.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
    if special_shape.is_fully_defined():
      # bshape.is_fully_defined iff special_shape.is_fully_defined.
      if bshape == x.shape:
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    #   Always add to an array of zeros, rather than using a "cond", since a
    #   cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Return zeros with the shape of `op(A) @ op(x)`, broadcast over batch."""
    if self._assert_proper_shapes:
      if adjoint_arg:
        # Materialize x^H so the check below sees the effective operand.  It
        # must not be adjointed a second time when the output shape is
        # computed, hence adjoint_arg is reset.
        x = linalg.adjoint(x)
        adjoint_arg = False
      # The effective operator is A (domain size num_columns) or, when
      # adjoint=True, A^H (domain size num_rows).
      expected_dim = self._num_rows if adjoint else self._num_columns
      aps = check_ops.assert_equal(
          array_ops.shape(x, out_type=expected_dim.dtype)[-2],
          expected_dim,
          message=("Dimensions are not compatible.  "
                   "shape[-2] of argument to be the same as this operator"))
      x = control_flow_ops.with_dependencies([aps], x)
    if self.is_square:
      # Note that adjoint has no effect since this matrix is self-adjoint.
      if adjoint_arg:
        output_shape = array_ops.concat([
            array_ops.shape(x)[:-2],
            [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
      else:
        output_shape = array_ops.shape(x)

      return self._possibly_broadcast_batch_shape(
          array_ops.zeros(shape=output_shape, dtype=x.dtype))

    x_shape = array_ops.shape(x)
    n = self._num_columns if adjoint else self._num_rows
    m = x_shape[-2] if adjoint_arg else x_shape[-1]

    output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

    zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
    return self._possibly_broadcast_batch_shape(zeros)

  def _zeros_with_batch_shape(self):
    """Tensor of zeros with this operator's batch shape and dtype."""
    if self.batch_shape.is_fully_defined():
      return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype)
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _determinant(self):
    return self._zeros_with_batch_shape()

  def _trace(self):
    # Get Tensor of all zeros of same shape as self.batch_shape.
    return self._zeros_with_batch_shape()

  def _diag_part(self):
    return self._zeros_diag()

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `0 + mat`.

    Since this operator is the zero matrix, the result equals `mat`, broadcast
    against this operator's batch shape when one was given.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      mat = ops.convert_to_tensor(mat, name="mat")
      return self._possibly_broadcast_batch_shape(mat)

  def _check_domain_range_possibly_add_asserts(self):
    """Static checks of `num_rows` and `num_columns`, possibly add asserts."""
    # Possibly add asserts.
    if self._assert_proper_shapes:
      self._num_rows = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_rows,
              0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ], self._num_rows)
      self._num_columns = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_columns,
              0,
              message="Argument num_columns must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_columns,
              message="Argument num_columns must be non-negative."),
      ], self._num_columns)

    # Static checks.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type.  Found:"
                      " %s" % self._num_rows)

    if not self._num_columns.dtype.is_integer:
      raise TypeError("Argument num_columns must be integer type.  Found:"
                      " %s" % self._num_columns)

    num_rows_static = self._num_rows_static
    num_columns_static = self._num_columns_static

    if num_rows_static is not None:
      if num_rows_static.ndim != 0:
        raise ValueError("Argument num_rows must be a 0-D Tensor.  Found:"
                         " %s" % num_rows_static)

      if num_rows_static < 0:
        raise ValueError("Argument num_rows must be non-negative.  Found:"
                         " %s" % num_rows_static)
    if num_columns_static is not None:
      if num_columns_static.ndim != 0:
        raise ValueError("Argument num_columns must be a 0-D Tensor.  Found:"
                         " %s" % num_columns_static)

      if num_columns_static < 0:
        raise ValueError("Argument num_columns must be non-negative.  Found:"
                         " %s" % num_columns_static)

  def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts."""
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._batch_shape_arg,
              1,
              message="Argument batch_shape must be a 1-D Tensor."),
          check_ops.assert_non_negative(
              self._batch_shape_arg,
              message="Argument batch_shape must be non-negative."),
      ], self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type.  Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      raise ValueError("Argument batch_shape must be non-negative.  Found:"
                       " %s" % self._batch_shape_static)

  def _min_matrix_dim(self):
    """Minimum of domain/range dimension, if statically available, else None."""
    domain_dim = tensor_shape.dimension_value(self.domain_dimension)
    range_dim = tensor_shape.dimension_value(self.range_dimension)
    if domain_dim is None or range_dim is None:
      return None
    return min(domain_dim, range_dim)

  def _min_matrix_dim_tensor(self):
    """Minimum of domain/range dimension, as a tensor."""
    return math_ops.reduce_min(self.shape_tensor()[-2:])

  def _zeros_diag(self):
    """Returns the diagonal of this operator as all zeros."""
    if self.shape.is_fully_defined():
      d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
    else:
      d_shape = array_ops.concat(
          [self.batch_shape_tensor(),
           [self._min_matrix_dim_tensor()]], axis=0)

    return array_ops.zeros(shape=d_shape, dtype=self.dtype)

  def _eigvals(self):
    return self._zeros_diag()

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows", "num_columns", "batch_shape")

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "num_columns", "batch_shape", "dtype",
            "assert_proper_shapes")

  def __getitem__(self, slices):
    # Slice the batch shape and return a new LinearOperatorZeros.
    # Slice a proxy tensor that has this operator's batch shape and use the
    # sliced proxy's shape as the new batch shape.  An operator built without
    # `batch_shape` has no batch dimensions, so its proxy is a scalar.
    if self._batch_shape_arg is None:
      proxy_shape = []
    else:
      proxy_shape = self._batch_shape_arg
    new_batch_shape = array_ops.shape(array_ops.ones(proxy_shape)[slices])
    parameters = dict(self.parameters, batch_shape=new_batch_shape)
    return LinearOperatorZeros(**parameters)