Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_permutation.py: 41%
87 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a permutation matrix."""
17import numpy as np
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_conversion
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import sort_ops
27from tensorflow.python.ops.linalg import linalg_impl as linalg
28from tensorflow.python.ops.linalg import linear_operator
29from tensorflow.python.ops.linalg import linear_operator_util
30from tensorflow.python.util.tf_export import tf_export
__all__ = ["LinearOperatorPermutation",]


@tf_export("linalg.LinearOperatorPermutation")
@linear_operator.make_composite_tensor
class LinearOperatorPermutation(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of permutation matrices.

  This operator acts like a [batch] of permutations with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorPermutation` is initialized with a (batch) vector.

  A permutation is defined by an integer vector `v` whose values are unique
  and are in the range `[0, ..., N - 1]`. Applying the permutation on an input
  matrix has the following meaning: the value of `v` at index `i`
  says to move the `v[i]`-th row of the input matrix to the `i`-th row.
  Because all values are unique, this will result in a permutation of the
  rows of the input matrix. Note that the permutation vector `v` has the same
  semantics as `tf.transpose`.

  ```python
  # Create a 3 x 3 permutation matrix that swaps the last two columns.
  vec = [0, 2, 1]
  operator = LinearOperatorPermutation(vec)

  operator.to_dense()
  ==> [[1., 0., 0.]
       [0., 0., 1.]
       [0., 1., 0.]]

  operator.shape
  ==> [3, 3]

  # This will be zero.
  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               perm,
               dtype=dtypes.float32,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorPermutation"):
    r"""Initialize a `LinearOperatorPermutation`.

    Args:
      perm:  Shape `[B1,...,Bb, N]` Integer `Tensor` with `b >= 0`
        `N >= 0`. An integer vector that represents the permutation to apply.
        Note that this argument has the same semantics as in `tf.transpose`.
        However, this permutation is applied on the rows, while the
        permutation in `tf.transpose` is applied on the dimensions of the
        `Tensor`. `perm` is required to have unique entries from
        `{0, 1, ... N-1}`.
      dtype: The `dtype` of arguments to this operator. Default: `float32`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular. A
        permutation matrix is always invertible, so this must not be `False`;
        it is autoset to `True`.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose. For a permutation this holds only when the permutation is
        its own inverse. When `True`, adjoint matmuls reuse `perm` directly
        instead of inverting it, so only pass `True` if that is the case.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        This must not be `False`; it is autoset to `True`.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  `is_non_singular` is `False`, `is_square` is `False`,
        `perm` has rank 0, or a statically known `perm` is not a permutation
        of `{0, ..., N-1}`.
      TypeError:  `perm` does not have an integer dtype.
    """
    parameters = dict(
        perm=perm,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[perm]):
      self._perm = linear_operator_util.convert_nonref_to_tensor(
          perm, name="perm")
      self._check_perm(self._perm)

      # Check and auto-set hints.
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always non-singular. "
                         f"Expected argument `is_non_singular` to be True. "
                         f"Received: {is_non_singular}.")
      # Every permutation matrix is invertible (its inverse is its transpose),
      # so this hint is always true; auto-set it like `is_square` below.
      is_non_singular = True

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always square. "
                         f"Expected argument `is_square` to be True. "
                         f"Received: {is_square}.")
      is_square = True

      super(LinearOperatorPermutation, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_perm(self, perm):
    """Static check of perm.

    Validates that `perm` has rank >= 1 and an integer dtype. If its value is
    known statically, also validates that every batch member is a permutation
    of `0 .. N-1`.
    """
    if (perm.shape.ndims is not None and perm.shape.ndims < 1):
      raise ValueError(f"Argument `perm` must have at least 1 dimension. "
                       f"Received: {perm}.")
    if not perm.dtype.is_integer:
      raise TypeError(f"Argument `perm` must be integer dtype. "
                      f"Received: {perm}.")
    # Check that the permutation satisfies the uniqueness constraint.
    static_perm = tensor_util.constant_value(perm)
    if static_perm is not None:
      # A vector is a permutation of 0..N-1 iff sorting it yields 0..N-1.
      sorted_perm = np.sort(static_perm, axis=-1)
      if np.any(sorted_perm != np.arange(0, static_perm.shape[-1])):
        raise ValueError(
            f"Argument `perm` must be a vector of unique integers from "
            f"0 to {static_perm.shape[-1] - 1}.")

  def _shape(self):
    # [B1,...,Bb, N] -> [B1,...,Bb, N, N]
    perm_shape = self._perm.shape
    return perm_shape.concatenate(perm_shape[-1:])

  def _shape_tensor(self):
    # Dynamic counterpart of `_shape`: append N to the shape of `perm`.
    perm_shape = array_ops.shape(self._perm)
    k = perm_shape[-1]
    return array_ops.concat((perm_shape, [k]), 0)

  def _assert_non_singular(self):
    # Permutation matrices are always invertible; nothing to check at runtime.
    return control_flow_ops.no_op("assert_non_singular")

  def _domain_dimension_tensor(self, perm=None):
    """Returns N, the size of the last dimension of `perm`."""
    perm = perm if perm is not None else self.perm
    return array_ops.shape(perm)[-1]

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    perm = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.perm)
    if adjoint and not self.is_self_adjoint:
      # The adjoint of a permutation matrix is the inverse permutation.
      # TODO(srvasude): invert_permutation doesn't work on batches so we use
      # argsort.
      perm = sort_ops.argsort(perm, axis=-1)
    x = linalg.adjoint(x) if adjoint_arg else x

    # We need to broadcast x and the permutation since tf.gather doesn't
    # broadcast.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x)[:-1], array_ops.shape(perm))
    k = array_ops.shape(x)[-1]
    broadcast_x_shape = array_ops.concat([broadcast_shape, [k]], axis=-1)
    x = array_ops.broadcast_to(x, broadcast_x_shape)
    perm = array_ops.broadcast_to(perm, broadcast_shape)

    # Flatten all batch dimensions into one so a single batched gather can
    # reorder the rows: y[b, i, :] = x[b, perm[b, i], :].
    m = array_ops.shape(x)[-2]
    x = array_ops.reshape(x, [-1, m, k])
    perm = array_ops.reshape(perm, [-1, m])

    y = array_ops.gather(x, perm, axis=-2, batch_dims=1)
    return array_ops.reshape(y, broadcast_x_shape)

  # TODO(srvasude): Permutation parity is equivalent to the determinant.

  def _log_abs_determinant(self):
    # Permutation matrices have determinant +/- 1, so log|det| is always 0.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # The inverse of a permutation matrix is the transpose matrix.
    # Apply a matmul and flip the adjoint bit.
    return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)

  def _to_dense(self):
    perm = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.perm)
    # `range` yields int32 because its limit comes from `shape`. Cast `perm`
    # to match, otherwise `equal` fails for e.g. int64 permutations. Entries
    # are < N, so the cast is lossless for a valid permutation.
    columns = math_ops.range(0, self._domain_dimension_tensor(perm))
    perm = math_ops.cast(perm, columns.dtype)
    # Row i has a single one, at column perm[i].
    return math_ops.cast(math_ops.equal(
        columns, perm[..., array_ops.newaxis]), self.dtype)

  def _diag_part(self):
    perm = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.perm)
    # Same dtype alignment as in `_to_dense`.
    indices = math_ops.range(0, self._domain_dimension_tensor(perm))
    perm = math_ops.cast(perm, indices.dtype)
    # The diagonal is 1 exactly where the permutation has a fixed point.
    return math_ops.cast(math_ops.equal(indices, perm), self.dtype)

  def _cond(self):
    # Permutation matrices are orthogonal: all singular values are 1, so the
    # condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def perm(self):
    """The permutation `Tensor`, shape `[B1,...,Bb, N]`."""
    return self._perm

  @property
  def _composite_tensor_fields(self):
    return ("perm", "dtype")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"perm": 1}