Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/adjoint_registrations.py: 62%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registrations for LinearOperator.adjoint.""" 

16 

17from tensorflow.python.ops import math_ops 

18from tensorflow.python.ops.linalg import linear_operator 

19from tensorflow.python.ops.linalg import linear_operator_adjoint 

20from tensorflow.python.ops.linalg import linear_operator_algebra 

21from tensorflow.python.ops.linalg import linear_operator_block_diag 

22from tensorflow.python.ops.linalg import linear_operator_circulant 

23from tensorflow.python.ops.linalg import linear_operator_diag 

24from tensorflow.python.ops.linalg import linear_operator_householder 

25from tensorflow.python.ops.linalg import linear_operator_identity 

26from tensorflow.python.ops.linalg import linear_operator_kronecker 

27 

28 

# Default rule, registered on the LinearOperator base class: wrap the operator
# in a LinearOperatorAdjoint, which switches the .matmul and .solve methods.
@linear_operator_algebra.RegisterAdjoint(linear_operator.LinearOperator)
def _adjoint_linear_operator(linop):
  """Returns a `LinearOperatorAdjoint` wrapping `linop`.

  Args:
    linop: The `LinearOperator` whose adjoint is requested.

  Returns:
    A `LinearOperatorAdjoint` around `linop`, carrying over its
    non-singular / self-adjoint / positive-definite / square hints unchanged.
  """
  # Hints are read in the same order as they are passed to the constructor.
  hints = dict(
      is_non_singular=linop.is_non_singular,
      is_self_adjoint=linop.is_self_adjoint,
      is_positive_definite=linop.is_positive_definite,
      is_square=linop.is_square)
  return linear_operator_adjoint.LinearOperatorAdjoint(linop, **hints)

39 

40 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_adjoint.LinearOperatorAdjoint)
def _adjoint_adjoint_linear_operator(linop):
  """Returns the operator wrapped by an adjoint operator.

  Taking the adjoint twice is a no-op, (A^H)^H = A, so instead of nesting a
  second wrapper we unwrap and return the inner operator.

  Args:
    linop: A `LinearOperatorAdjoint`.

  Returns:
    `linop.operator`.
  """
  wrapped = linop.operator
  return wrapped

45 

46 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_identity.LinearOperatorIdentity)
def _adjoint_identity(identity_operator):
  """Returns `identity_operator` unchanged.

  The identity is its own adjoint, so no new operator needs to be built.
  """
  return identity_operator

51 

52 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _adjoint_scaled_identity(identity_operator):
  """Returns the adjoint of a scaled identity `multiplier * I`.

  The adjoint is `conj(multiplier) * I`. For a real-valued multiplier the
  conjugation is skipped and the multiplier is reused as-is.

  Args:
    identity_operator: A `LinearOperatorScaledIdentity`.

  Returns:
    A new `LinearOperatorScaledIdentity` with the same number of rows, the
    (conjugated) multiplier, the original's hints, and `is_square=True`.
  """
  multiplier = identity_operator.multiplier
  conjugated = (
      math_ops.conj(multiplier) if multiplier.dtype.is_complex else multiplier)

  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      multiplier=conjugated,
      is_non_singular=identity_operator.is_non_singular,
      is_self_adjoint=identity_operator.is_self_adjoint,
      is_positive_definite=identity_operator.is_positive_definite,
      is_square=True)

67 

68 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_diag.LinearOperatorDiag)
def _adjoint_diag(diag_operator):
  """Returns the adjoint of a diagonal operator.

  The adjoint of `diag(d)` is `diag(conj(d))`; for a real-valued diagonal
  this is the same diagonal.

  Args:
    diag_operator: A `LinearOperatorDiag`.

  Returns:
    A new `LinearOperatorDiag` over the (conjugated) diagonal, carrying the
    original's hints and `is_square=True`.
  """
  diag = diag_operator.diag
  if not diag.dtype.is_complex:
    adjoint_diag = diag
  else:
    adjoint_diag = math_ops.conj(diag)

  return linear_operator_diag.LinearOperatorDiag(
      diag=adjoint_diag,
      is_non_singular=diag_operator.is_non_singular,
      is_self_adjoint=diag_operator.is_self_adjoint,
      is_positive_definite=diag_operator.is_positive_definite,
      is_square=True)

82 

83 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _adjoint_block_diag(block_diag_operator):
  """Returns the adjoint of a block-diagonal operator.

  Transposing a block-diagonal matrix keeps every block in place, so the
  adjoint is the block-diagonal operator of the per-block adjoints.

  Args:
    block_diag_operator: A `LinearOperatorBlockDiag`.

  Returns:
    A new `LinearOperatorBlockDiag` whose blocks are `op.adjoint()` for each
    original block (in order), with the original's hints and `is_square=True`.
  """
  adjoint_blocks = []
  for block in block_diag_operator.operators:
    adjoint_blocks.append(block.adjoint())

  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=adjoint_blocks,
      is_non_singular=block_diag_operator.is_non_singular,
      is_self_adjoint=block_diag_operator.is_self_adjoint,
      is_positive_definite=block_diag_operator.is_positive_definite,
      is_square=True)

95 

96 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_kronecker.LinearOperatorKronecker)
def _adjoint_kronecker(kronecker_operator):
  """Returns the adjoint of a Kronecker-product operator.

  Uses the identity (A ⊗ B)^H = A^H ⊗ B^H: the adjoint is the Kronecker
  product of the factors' adjoints, in the same order.

  Args:
    kronecker_operator: A `LinearOperatorKronecker`.

  Returns:
    A new `LinearOperatorKronecker` over the adjoint factors, with the
    original's hints and `is_square=True`.
  """
  factors = kronecker_operator.operators
  adjoint_factors = list(factor.adjoint() for factor in factors)

  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=adjoint_factors,
      is_non_singular=kronecker_operator.is_non_singular,
      is_self_adjoint=kronecker_operator.is_self_adjoint,
      is_positive_definite=kronecker_operator.is_positive_definite,
      is_square=True)

109 

110 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_circulant._BaseLinearOperatorCirculant)  # pylint: disable=protected-access
def _adjoint_circulant(circulant_operator):
  """Returns the adjoint of a circulant operator.

  Conjugating the spectrum is sufficient to get the adjoint. For a real-valued
  spectrum the conjugation is skipped. The result is built with the same
  concrete class as `circulant_operator`.

  Args:
    circulant_operator: An instance of a `_BaseLinearOperatorCirculant`
      subclass.

  Returns:
    A new operator of the same class over the (conjugated) spectrum, with the
    same dtype as `circulant_operator`, its hints, and `is_square=True`.
  """
  spectrum = circulant_operator.spectrum
  if spectrum.dtype.is_complex:
    spectrum = math_ops.conj(spectrum)

  # The adjoint acts on the same dtype as the original operator, so forward it
  # explicitly. NOTE(review): without this, the subclass constructor's default
  # `input_output_dtype` (complex64 per the public TF API docs) would be used,
  # e.g. turning the adjoint of a float32 circulant into a complex64 operator.
  return circulant_operator.__class__(
      spectrum=spectrum,
      input_output_dtype=circulant_operator.dtype,
      is_non_singular=circulant_operator.is_non_singular,
      is_self_adjoint=circulant_operator.is_self_adjoint,
      is_positive_definite=circulant_operator.is_positive_definite,
      is_square=True)

125 

126 

@linear_operator_algebra.RegisterAdjoint(
    linear_operator_householder.LinearOperatorHouseholder)
def _adjoint_householder(householder_operator):
  """Returns `householder_operator` unchanged.

  A Householder reflection is self-adjoint, so the operator already is its
  own adjoint.
  """
  return householder_operator