Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_permutation.py: 41%

87 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""`LinearOperator` acting like a permutation matrix.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import ops 

21from tensorflow.python.framework import tensor_conversion 

22from tensorflow.python.framework import tensor_util 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import control_flow_ops 

25from tensorflow.python.ops import math_ops 

26from tensorflow.python.ops import sort_ops 

27from tensorflow.python.ops.linalg import linalg_impl as linalg 

28from tensorflow.python.ops.linalg import linear_operator 

29from tensorflow.python.ops.linalg import linear_operator_util 

30from tensorflow.python.util.tf_export import tf_export 

31 

# Names exported by `from ... import *` for this module.
__all__ = [
    "LinearOperatorPermutation",
]

33 

34 

@tf_export("linalg.LinearOperatorPermutation")
@linear_operator.make_composite_tensor
class LinearOperatorPermutation(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] of permutation matrices.

  This operator acts like a [batch] of permutations with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorPermutation` is initialized with a (batch) vector.

  A permutation is defined by an integer vector `v` of length `N` whose values
  are unique and lie in the range `[0, ..., N - 1]`. Applying the permutation
  to an input matrix has the following meaning: the value of `v` at index `i`
  says to move the `v[i]`-th row of the input matrix to the `i`-th row.
  Because all values are unique, this results in a permutation of the rows of
  the input matrix. Note that the permutation vector `v` has the same
  semantics as `tf.transpose`.

  ```python
  # Create a 3 x 3 permutation matrix that swaps the last two columns.
  vec = [0, 2, 1]
  operator = LinearOperatorPermutation(vec)

  operator.to_dense()
  ==> [[1., 0., 0.]
       [0., 0., 1.]
       [0., 1., 0.]]

  operator.shape
  ==> [3, 3]

  # This will be zero.
  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [3, 4] Tensor
  operator.matmul(x)
  ==> Shape [3, 4] Tensor
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.

  A permutation operator is always square and non-singular, so passing
  `is_square=False` or `is_non_singular=False` raises a `ValueError`.
  """

  def __init__(self,
               perm,
               dtype=dtypes.float32,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorPermutation"):
    r"""Initialize a `LinearOperatorPermutation`.

    Args:
      perm:  Shape `[B1,...,Bb, N]` integer `Tensor` with `b >= 0` and
        `N >= 0`.  An integer vector that represents the permutation to apply.
        Note that this argument has the same semantics as the `perm` argument
        of `tf.transpose`. However, this permutation is applied on the rows,
        while the permutation in `tf.transpose` is applied on the dimensions
        of the `Tensor`. `perm` is required to have unique entries from
        `{0, 1, ... N-1}`.
      dtype: The `dtype` of arguments to this operator. Default: `float32`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
        A permutation operator is always non-singular, so passing `False`
        raises a `ValueError`.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose. This hint is not auto-set. When it is `True`,
        `matmul(..., adjoint=True)` and `solve` apply `perm` directly instead
        of its inverse. Only set it when `perm` is its own inverse.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is auto-set to `True`. Passing `False` raises a `ValueError`.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_non_singular` or `is_square` is `False`, if `perm`
        is statically known to have rank 0, or if `perm` has a constant value
        whose entries are not a permutation of `{0, ..., N-1}`.
      TypeError:  If `perm` does not have an integer dtype.
    """
    # Record the constructor arguments exactly as passed (before any hint
    # auto-setting) so the operator can be rebuilt from `parameters`.
    parameters = dict(
        perm=perm,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[perm]):
      self._perm = linear_operator_util.convert_nonref_to_tensor(
          perm, name="perm")
      self._check_perm(self._perm)

      # Check and auto-set hints.
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always non-singular. "
                         f"Expected argument `is_non_singular` to be True. "
                         f"Received: {is_non_singular}.")

      if is_square is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(f"A Permutation operator is always square. "
                         f"Expected argument `is_square` to be True. "
                         f"Received: {is_square}.")
      is_square = True

      super(LinearOperatorPermutation, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_perm(self, perm):
    """Statically validates `perm`.

    Only properties known at graph-construction time are checked:
    - the rank is at least 1, when the rank is known;
    - the dtype is an integer dtype;
    - when `perm` has a constant value, each innermost vector is a
      permutation of `{0, ..., N-1}`.

    Args:
      perm: The `Tensor` passed to the constructor.

    Raises:
      ValueError: If `perm` has static rank 0, or a constant `perm` is not a
        valid permutation.
      TypeError: If `perm` is not an integer `Tensor`.
    """
    if (perm.shape.ndims is not None and perm.shape.ndims < 1):
      raise ValueError(f"Argument `perm` must have at least 1 dimension. "
                       f"Received: {perm}.")
    if not perm.dtype.is_integer:
      raise TypeError(f"Argument `perm` must be integer dtype. "
                      f"Received: {perm}.")
    # Check that the permutation satisfies the uniqueness constraint. Sorting
    # a valid permutation of length N yields exactly [0, 1, ..., N-1]; any
    # duplicate or out-of-range entry breaks that equality.
    static_perm = tensor_util.constant_value(perm)
    if static_perm is not None:
      sorted_perm = np.sort(static_perm, axis=-1)
      if np.any(sorted_perm != np.arange(0, static_perm.shape[-1])):
        raise ValueError(
            f"Argument `perm` must be a vector of unique integers from "
            f"0 to {static_perm.shape[-1] - 1}.")

  def _shape(self):
    # Static shape: perm's shape [B1,...,Bb, N] with N appended -> [..., N, N].
    perm_shape = self._perm.shape
    return perm_shape.concatenate(perm_shape[-1:])

  def _shape_tensor(self):
    # Dynamic counterpart of `_shape`: append the last dimension once more.
    perm_shape = array_ops.shape(self._perm)
    k = perm_shape[-1]
    return array_ops.concat((perm_shape, [k]), 0)

  def _assert_non_singular(self):
    # Every permutation matrix is invertible, so there is nothing to check.
    return control_flow_ops.no_op("assert_non_singular")

  def _domain_dimension_tensor(self, perm=None):
    # `perm` may be passed in by callers that already converted it to a
    # Tensor, avoiding a second conversion of `self.perm`.
    perm = perm if perm is not None else self.perm
    return array_ops.shape(perm)[-1]

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    perm = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.perm)
    # The adjoint of a permutation matrix is the inverse permutation. When the
    # caller promised `is_self_adjoint`, the inversion is skipped.
    if adjoint and not self.is_self_adjoint:
      # TODO(srvasude): invert_permutation doesn't work on batches so we use
      # argsort.
      perm = sort_ops.argsort(perm, axis=-1)
    x = linalg.adjoint(x) if adjoint_arg else x

    # We need to broadcast x and the permutation since tf.gather doesn't
    # broadcast.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x)[:-1], array_ops.shape(perm))
    k = array_ops.shape(x)[-1]
    broadcast_x_shape = array_ops.concat([broadcast_shape, [k]], axis=-1)
    x = array_ops.broadcast_to(x, broadcast_x_shape)
    perm = array_ops.broadcast_to(perm, broadcast_shape)

    # Flatten all batch dimensions into one so a single batched gather
    # (batch_dims=1) can pick row perm[i] of x for output row i.
    m = array_ops.shape(x)[-2]
    x = array_ops.reshape(x, [-1, m, k])
    perm = array_ops.reshape(perm, [-1, m])

    y = array_ops.gather(x, perm, axis=-2, batch_dims=1)
    return array_ops.reshape(y, broadcast_x_shape)

  def _log_abs_determinant(self):
    # Permutation matrices have determinant +/- 1, so log|det| is always 0.
    # TODO(srvasude): Permutation parity is equivalent to the determinant.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # The inverse of a permutation matrix is the transpose matrix.
    # Apply a matmul and flip the adjoint bit.
    return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)

  def _to_dense(self):
    # Row i of the dense matrix is one-hot at column perm[i]:
    # dense[..., i, j] = (j == perm[..., i]).
    perm = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.perm)
    return math_ops.cast(math_ops.equal(
        math_ops.range(0, self._domain_dimension_tensor(perm)),
        perm[..., array_ops.newaxis]), self.dtype)

  def _diag_part(self):
    # diag[..., i] = dense[..., i, i] = (perm[..., i] == i), i.e. 1 exactly at
    # the fixed points of the permutation.
    perm = tensor_conversion.convert_to_tensor_v2_with_dispatch(self.perm)
    return math_ops.cast(math_ops.equal(
        math_ops.range(0, self._domain_dimension_tensor(perm)),
        perm), self.dtype)

  def _cond(self):
    # Permutation matrices are orthogonal (not necessarily rotations: odd
    # permutations have determinant -1). All singular values equal 1, so the
    # condition number is 1.
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  @property
  def perm(self):
    """The integer permutation `Tensor` of shape `[B1,...,Bb, N]`."""
    return self._perm

  @property
  def _composite_tensor_fields(self):
    return ("perm", "dtype")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # `perm` contributes one trailing (non-batch) dimension to the matrix.
    return {"perm": 1}