Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_householder.py: 41%
92 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a Householder transformation."""

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorHouseholder",]


@tf_export("linalg.LinearOperatorHouseholder")
@linear_operator.make_composite_tensor
class LinearOperatorHouseholder(linear_operator.LinearOperator):
35 """`LinearOperator` acting like a [batch] of Householder transformations.
37 This operator acts like a [batch] of householder reflections with shape
38 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
39 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
40 an `N x N` matrix. This matrix `A` is not materialized, but for
41 purposes of broadcasting this shape will be relevant.
43 `LinearOperatorHouseholder` is initialized with a (batch) vector.
45 A Householder reflection, defined via a vector `v`, which reflects points
46 in `R^n` about the hyperplane orthogonal to `v` and through the origin.

  ```python
  # Create a 2 x 2 Householder transform.
  vec = [1. / np.sqrt(2), 1. / np.sqrt(2)]
  operator = LinearOperatorHouseholder(vec)

  operator.to_dense()
  ==> [[0., -1.]
       [-1., -0.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```
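
  A batch example illustrating broadcasting (an illustrative sketch; the
  shapes and `tf.random.normal` inputs below are chosen for demonstration and
  are not part of the original example):

  ```python
  # Three independent 4 x 4 Householder reflections.
  vec = tf.random.normal(shape=[3, 4])
  operator = LinearOperatorHouseholder(vec)

  operator.shape
  ==> [3, 4, 4]

  x = tf.random.normal(shape=[3, 4, 2])
  operator.matmul(x)
  ==> Shape [3, 4, 2] Tensor

  # A reflection is its own inverse, so `solve` applies the same reflection.
  operator.solve(x)
  ==> Shape [3, 4, 2] Tensor
  ```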

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
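
  For example (a brief sketch added for illustration; `vec` is any valid
  reflection axis as above):

  ```python
  # Hints that agree with the operator's structure are accepted.
  operator = LinearOperatorHouseholder(vec, is_non_singular=True)

  # Hints that contradict it raise at construction time: a Householder
  # reflection has a -1 eigenvalue, so it is never positive definite.
  operator = LinearOperatorHouseholder(vec, is_positive_definite=True)
  ==> ValueError
  ```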
92 """
94 def __init__(self,
95 reflection_axis,
96 is_non_singular=None,
97 is_self_adjoint=None,
98 is_positive_definite=None,
99 is_square=None,
100 name="LinearOperatorHouseholder"):
101 r"""Initialize a `LinearOperatorHouseholder`.
103 Args:
104 reflection_axis: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
105 The vector defining the hyperplane to reflect about.
106 Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
107 `complex128`.
108 is_non_singular: Expect that this operator is non-singular.
109 is_self_adjoint: Expect that this operator is equal to its hermitian
110 transpose. This is autoset to true
111 is_positive_definite: Expect that this operator is positive definite,
112 meaning the quadratic form `x^H A x` has positive real part for all
113 nonzero `x`. Note that we do not require the operator to be
114 self-adjoint to be positive-definite. See:
115 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
116 This is autoset to false.
117 is_square: Expect that this operator acts like square [batch] matrices.
118 This is autoset to true.
119 name: A name for this `LinearOperator`.
121 Raises:
122 ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
123 not `False` or `is_square` is not `True`.
124 """
    parameters = dict(
        reflection_axis=reflection_axis,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[reflection_axis]):
      self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
          reflection_axis, name="reflection_axis")
      self._check_reflection_axis(self._reflection_axis)

      # Check and auto-set hints.
      if is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always self-adjoint.")
      else:
        is_self_adjoint = True

      if is_positive_definite is True:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "A Householder operator is never positive definite.")
      else:
        is_positive_definite = False

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError("A Householder operator is always square.")
      is_square = True

      super(LinearOperatorHouseholder, self).__init__(
          dtype=self._reflection_axis.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_reflection_axis(self, reflection_axis):
    """Static check of reflection_axis."""
    if (reflection_axis.shape.ndims is not None and
        reflection_axis.shape.ndims < 1):
      raise ValueError(
          "Argument reflection_axis must have at least 1 dimension. "
          "Found: %s" % reflection_axis)

  def _shape(self):
    # If d_shape = [5, 3], we return [5, 3, 3].
    d_shape = self._reflection_axis.shape
    return d_shape.concatenate(d_shape[-1:])

  def _shape_tensor(self):
    d_shape = array_ops.shape(self._reflection_axis)
    k = d_shape[-1]
    return array_ops.concat((d_shape, [k]), 0)

  def _assert_non_singular(self):
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    raise errors.InvalidArgumentError(
        node_def=None, op=None, message="Householder operators are never "
        "positive definite.")

  def _assert_self_adjoint(self):
    return control_flow_ops.no_op("assert_self_adjoint")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Given a vector `v`, we would like to reflect `x` about the hyperplane
    # orthogonal to `v` going through the origin. We first project `x` onto
    # `v` to get v * dot(v, x) / dot(v, v). After we project, we can reflect
    # the projection about the hyperplane by flipping its sign, giving
    # -v * dot(v, x) / dot(v, v). Finally, we add back the component of `x`
    # that is orthogonal to v. This component is invariant under reflection,
    # since the whole hyperplane is invariant, and it is equal to
    # x - v * dot(v, x) / dot(v, v). This gives the formula
    # x - 2 * v * dot(v, x) / dot(v, v) for the reflection.
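    # As a quick numeric check (added for illustration, matching the 2 x 2
    # example in the class docstring): with v = [1, 1] and x = [1, 0],
    # dot(v, x) / dot(v, v) = 1 / 2, so the reflection is
    # x - 2 * v * (1 / 2) = [1, 0] - [1, 1] = [0, -1].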

    # Note that because this is a reflection, it lies in O(n) (for real vector
    # spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
    reflection_axis = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis
    )
    x = linalg.adjoint(x) if adjoint_arg else x
    normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
    mat = normalized_axis[..., array_ops.newaxis]
    x_dot_normalized_v = math_ops.matmul(mat, x, adjoint_a=True)

    return x - 2 * mat * x_dot_normalized_v

  def _trace(self):
    # We have (n - 1) eigenvalues equal to +1 and a single -1 eigenvalue.
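    # Hence trace = (n - 1) * (+1) + (-1) = n - 2, broadcast over the batch.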
    shape = self.shape_tensor()
    return math_ops.cast(
        self._domain_dimension_tensor(shape=shape) - 2,
        self.dtype) * array_ops.ones(
            shape=self._batch_shape_tensor(shape=shape), dtype=self.dtype)

  def _determinant(self):
    # For Householder transformations, the determinant is -1.
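    # It is the product of the eigenvalues: (+1)**(n - 1) * (-1) = -1.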
    return -array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)  # pylint: disable=invalid-unary-operand-type

  def _log_abs_determinant(self):
    # Orthogonal/unitary matrix -> |det| = 1, so log|det| = 0.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # A Householder reflection is an involution: applying it twice gives the
    # identity, so it is its own inverse and solving reduces to a matmul.
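    # The operator is also self-adjoint, so the `adjoint` case needs no
    # special handling beyond what `_matmul` already does.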
    return self._matmul(rhs, adjoint, adjoint_arg)

  def _to_dense(self):
    reflection_axis = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis
    )
    normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
    mat = normalized_axis[..., array_ops.newaxis]
    matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
    return array_ops.matrix_set_diag(
        matrix, 1. + array_ops.matrix_diag_part(matrix))

  def _diag_part(self):
    reflection_axis = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self.reflection_axis
    )
    normalized_axis = nn.l2_normalize(reflection_axis, axis=-1)
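    # The dense matrix is I - 2 * v * v^H for the normalized axis v, so the
    # i-th diagonal entry is 1 - 2 * v_i * conj(v_i).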
    return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)

  def _eigvals(self):
    # We have (n - 1) eigenvalues equal to +1 and a single -1 eigenvalue.
    result_shape = array_ops.shape(self.reflection_axis)
    n = result_shape[-1]
    ones_shape = array_ops.concat([result_shape[:-1], [n - 1]], axis=-1)
    neg_shape = array_ops.concat([result_shape[:-1], [1]], axis=-1)
    eigvals = array_ops.ones(shape=ones_shape, dtype=self.dtype)
    eigvals = array_ops.concat(
        [-array_ops.ones(shape=neg_shape, dtype=self.dtype), eigvals], axis=-1)  # pylint: disable=invalid-unary-operand-type
    return eigvals

  def _cond(self):
    # Householder matrices are reflections: they are orthogonal (unitary in
    # the complex case) and therefore have condition number 1.
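    # Every singular value equals 1, so the ratio of the largest to the
    # smallest singular value is 1 for each batch member.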
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def reflection_axis(self):
    return self._reflection_axis

  @property
  def _composite_tensor_fields(self):
    return ("reflection_axis",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"reflection_axis": 1}