Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/cholesky_registrations.py: 62%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.cholesky."""
17from tensorflow.python.ops import array_ops
18from tensorflow.python.ops import linalg_ops
19from tensorflow.python.ops import math_ops
20from tensorflow.python.ops.linalg import linear_operator
21from tensorflow.python.ops.linalg import linear_operator_algebra
22from tensorflow.python.ops.linalg import linear_operator_block_diag
23from tensorflow.python.ops.linalg import linear_operator_composition
24from tensorflow.python.ops.linalg import linear_operator_diag
25from tensorflow.python.ops.linalg import linear_operator_identity
26from tensorflow.python.ops.linalg import linear_operator_kronecker
27from tensorflow.python.ops.linalg import linear_operator_lower_triangular
28from tensorflow.python.ops.linalg import linear_operator_util
# Short module-level alias for the lower-triangular operator class: the
# registrations below both construct it and isinstance-check against it.
LinearOperatorLowerTriangular = (
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator)
def _cholesky_linear_operator(linop):
  """Default Cholesky registration for any `LinearOperator`.

  Materializes `linop` as a dense matrix, factors it with
  `linalg_ops.cholesky`, and wraps the factor in a
  `LinearOperatorLowerTriangular`. The registrations further down this module
  specialize this default for particular operator types.

  Args:
    linop: The `LinearOperator` to factor.

  Returns:
    A `LinearOperatorLowerTriangular` holding the dense Cholesky factor.
  """
  dense_factor = linalg_ops.cholesky(linop.to_dense())
  return LinearOperatorLowerTriangular(
      dense_factor,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=False)
def _is_llt_product(linop):
  """Determines if linop = L @ L.H for L = LinearOperatorLowerTriangular.

  Args:
    linop: An operator exposing `.operators` (the caller in this module passes
      a `LinearOperatorComposition`).

  Returns:
    True when `linop` has exactly two factors, `linear_operator_util.
    is_aat_form` accepts them, and the left factor is a
    `LinearOperatorLowerTriangular`; False otherwise.
  """
  factors = linop.operators
  if len(factors) == 2 and linear_operator_util.is_aat_form(factors):
    return isinstance(factors[0], LinearOperatorLowerTriangular)
  return False
@linear_operator_algebra.RegisterCholesky(
    linear_operator_composition.LinearOperatorComposition)
def _cholesky_linear_operator_composition(linop):
  """Computes Cholesky(LinearOperatorComposition).

  A composition of the form `L @ L.H`, with `L` a
  `LinearOperatorLowerTriangular`, is factored directly from `L` without
  densifying. Every other composition uses the generic dense fallback.

  Args:
    linop: A `LinearOperatorComposition`. When reached through
      `LinearOperator.cholesky()`, the base class has already verified that
      `linop` is positive definite.

  Returns:
    A `LinearOperatorLowerTriangular` `C` with `C @ C.H == linop`.
  """
  # L @ L.H is the one composition special-cased here. Diag @ Diag.H,
  # Diag @ TriL and TriL @ Diag are already collapsed to a single Diag or TriL
  # by the diag matmul registration (likewise Identity and ScaledIdentity), so
  # they only appear inside a LinearOperatorComposition when explicitly
  # constructed that way.
  if not _is_llt_product(linop):
    # Same dense path as the generic LinearOperator registration.
    return _cholesky_linear_operator(linop)

  left_op = linop.operators[0]

  # left_op.is_positive_definite ==> L already has a positive diagonal, so it
  # is itself the Cholesky factor.
  if left_op.is_positive_definite:
    return left_op

  # linop = L @ L^H is positive definite, so L is non-singular and its
  # diagonal has no zero entries. Make the diagonal positive by forming
  # A = L @ D with D = diag(1 / sign(L.diag_part())). D is diagonal with
  # unit-modulus entries, so D @ D^H = I and A @ A^H = L @ L^H = linop, while A
  # stays lower triangular with diagonal |L.diag_part()| > 0.
  # Right-multiplying by D divides *column* j of L by sign(L[j, j]). The sign
  # vector is therefore expanded to shape [..., 1, N] so that it broadcasts
  # across rows. This also works for complex L, since
  # sign(x + iy) = exp(i * angle(x + iy)).
  diag_sign = array_ops.expand_dims(
      math_ops.sign(left_op.diag_part()), axis=-2)

  # Self-adjointness hint for A = L @ D:
  #   L self-adjoint ==> L is diagonal with a real diagonal
  #     ==> A is diagonal with a positive diagonal ==> A is self-adjoint.
  #   L real and not self-adjoint ==> L has a nonzero off-diagonal entry
  #     ==> so does A ==> A is not self-adjoint.
  #   L complex and not self-adjoint: L may fail self-adjointness only because
  #     its diagonal is non-real (e.g. L = [[1j]]). In that case A has a real
  #     positive diagonal and can be self-adjoint, so False must not be
  #     propagated.
  if left_op.is_self_adjoint is False and left_op.dtype.is_complex:
    is_self_adjoint = None
  else:
    is_self_adjoint = left_op.is_self_adjoint

  return LinearOperatorLowerTriangular(
      tril=left_op.tril / diag_sign,
      # linop is positive definite ==> L, and hence L @ D, is non-singular.
      # This matches the hint set by the dense fallback path.
      is_non_singular=True,
      is_self_adjoint=is_self_adjoint,
      # L.is_positive_definite ==> L has positive diag ==> L = L @ D
      # ==> (L @ D).is_positive_definite. That case already returned above.
      # Otherwise L @ D may or may not be PD: for L = [[1, 0], [-2, 1]] and
      # x = [1, 1], the quadratic form x^T L x is 0. So the hint stays unset;
      # the expression is written out explicitly for clarity.
      is_positive_definite=True if left_op.is_positive_definite else None,
      is_square=True,
  )
@linear_operator_algebra.RegisterCholesky(
    linear_operator_diag.LinearOperatorDiag)
def _cholesky_diag(diag_operator):
  """Cholesky of a diagonal operator: elementwise sqrt of its diagonal.

  Args:
    diag_operator: The `LinearOperatorDiag` to factor.

  Returns:
    A `LinearOperatorDiag` whose diagonal is `sqrt(diag_operator.diag)`.
  """
  sqrt_diag = math_ops.sqrt(diag_operator.diag)
  return linear_operator_diag.LinearOperatorDiag(
      sqrt_diag,
      is_square=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_non_singular=True)
@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorIdentity)
def _cholesky_identity(identity_operator):
  """Cholesky of an identity: an identity with the same size, batch and dtype.

  Args:
    identity_operator: The `LinearOperatorIdentity` to factor.

  Returns:
    A new `LinearOperatorIdentity` with matching `num_rows`, `batch_shape`
    and `dtype`.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  return linear_operator_identity.LinearOperatorIdentity(
      num_rows=num_rows,
      batch_shape=identity_operator.batch_shape,
      dtype=identity_operator.dtype,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _cholesky_scaled_identity(identity_operator):
  """Cholesky of `c * I`: a scaled identity with multiplier `sqrt(c)`.

  Args:
    identity_operator: The `LinearOperatorScaledIdentity` to factor.

  Returns:
    A `LinearOperatorScaledIdentity` with the same `num_rows` and multiplier
    `sqrt(identity_operator.multiplier)`.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  sqrt_multiplier = math_ops.sqrt(identity_operator.multiplier)
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=sqrt_multiplier,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
@linear_operator_algebra.RegisterCholesky(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _cholesky_block_diag(block_diag_operator):
  """Cholesky of a block-diagonal operator, computed block by block.

  Args:
    block_diag_operator: The `LinearOperatorBlockDiag` to factor.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are the `.cholesky()` of the
    corresponding input blocks.
  """
  factor_blocks = [
      block.cholesky() for block in block_diag_operator.operators]
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=factor_blocks,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=None)  # None: let the operators passed in decide.
@linear_operator_algebra.RegisterCholesky(
    linear_operator_kronecker.LinearOperatorKronecker)
def _cholesky_kronecker(kronecker_operator):
  """Cholesky of a Kronecker product, computed factor by factor.

  Since (L1 L1^H) kron (L2 L2^H) = (L1 kron L2)(L1 kron L2)^H, and a Kronecker
  product of lower-triangular factors is lower triangular, the Kronecker
  product of the factors' Cholesky decompositions is the Cholesky
  decomposition of the product.

  Args:
    kronecker_operator: The `LinearOperatorKronecker` to factor.

  Returns:
    A `LinearOperatorKronecker` of each factor's `.cholesky()`.
  """
  factor_cholesky = [
      factor.cholesky() for factor in kronecker_operator.operators]
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=factor_cholesky,
      is_square=True,
      is_non_singular=True,
      is_self_adjoint=None)  # None: let the operators passed in decide.