Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py: 37%

154 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CSR Sparse Matrix Operations."""

import abc
import collections


# pylint: disable=g-direct-tensorflow-import, wildcard-import
from tensorflow.python.eager import context
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.linalg.sparse import gen_sparse_csr_matrix_ops as sm_ops
from tensorflow.python.ops.linalg.sparse.gen_sparse_csr_matrix_ops import *


__all__ = [
    "SparseMatrix",
    "CSRSparseMatrix",
    "matmul",
    "dense_shape_and_type",
]
# pylint: disable=invalid-name
__all__ += [_x for _x in dir(sm_ops) if not _x.startswith("_")]


class DenseShapeAndType(
    collections.namedtuple("DenseShapeAndType", ("shape", "dtype"))):
  pass


def _get_handle_data(tensor):
  return resource_variable_ops.get_eager_safe_handle_data(tensor)


def _create_handle_data_proto(shape_proto, dtype_enum):
  """Create handle data based on shape and dtype protos."""
  variant_shape_and_type_data = \
      cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
  variant_shape_and_type_data.is_set = True
  # NOTE(ebrevdo): shape_and_type lacks append() in some versions of protobuf.
  variant_shape_and_type_data.shape_and_type.extend([
      cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
          shape=shape_proto, dtype=dtype_enum)
  ])
  return variant_shape_and_type_data


def _make_handle_data(tensor):
  """Create handle data based on tensor shape and dtype."""
  return _create_handle_data_proto(tensor.shape.as_proto(),
                                   tensor.dtype.as_datatype_enum)


def get_shape_and_type(matrix):
  """Return matrix's shape and type if available."""
  handle_data = getattr(matrix, "_handle_data", None)
  if handle_data is None:
    return None
  if len(handle_data.shape_and_type) != 1:
    raise ValueError(
        "shape_and_type array in _handle_data must have length one, but saw: %d"
        % len(handle_data.shape_and_type))
  return handle_data.shape_and_type[0]


def dense_shape_and_type(matrix):
  """Get dense shape and dtype of the tf.Tensor containing the matrix.

  Args:
    matrix: A `tf.Tensor` of type `tf.variant` storing a sparse matrix.

  Returns:
    An instance of `DenseShapeAndType` with properties `shape`
    (a `tf.TensorShape`) and `dtype` (a `tf.DType`).

  Raises:
    TypeError: if `matrix` is not a tensor or its dtype is not variant.
    ValueError: if `matrix` lacks static handle data containing the dense
      shape and dtype.
  """
  if not isinstance(matrix, ops.Tensor):
    raise TypeError("matrix should be a tensor, but saw: %s" % (matrix,))
  if matrix.dtype != dtypes.variant:
    raise TypeError(
        "expected matrix to be type tf.variant, but saw: %s" % (matrix.dtype,))
  handle_data = _get_handle_data(matrix)
  if not handle_data or not handle_data.is_set:
    raise ValueError("matrix has missing handle data: %s" % (matrix,))
  if len(handle_data.shape_and_type) != 1:
    raise ValueError("len(matrix.handle_data.shape_and_type) != 1: '%s'" %
                     (handle_data.shape_and_type,))
  return DenseShapeAndType(
      tensor_shape.TensorShape(handle_data.shape_and_type[0].shape),
      dtypes.DType(handle_data.shape_and_type[0].dtype))
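

# --- Illustrative usage sketch (added for exposition; not part of the original
# module). It shows how the exported `dense_shape_and_type` helper reads the
# static shape and dtype recorded in a CSR matrix's variant handle; this is the
# same lookup the `SparseMatrix.shape` and `.dtype` properties perform. The
# helper name `_example_dense_shape_and_type` and the sample values are
# assumptions, not TensorFlow API.
def _example_dense_shape_and_type():
  import tensorflow as tf  # Local import so the sketch stays self-contained.
  a_sm = CSRSparseMatrix(tf.constant([[1.0, 0.0], [0.0, 2.0]]))
  # pylint: disable=protected-access
  info = dense_shape_and_type(a_sm._matrix)
  # info.shape -> TensorShape([2, 2]); info.dtype -> tf.float32
  return info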



def matmul_shape_inference(a, b, c, transpose_a, transpose_b, adjoint_a,
                           adjoint_b):
  """Helper function for matmul to set the result matrix's handle data."""
  c_handle = getattr(c, "_handle_data", None)
  a_shape_and_type = get_shape_and_type(a)
  b_shape_and_type = get_shape_and_type(b)
  if (c_handle is None and a_shape_and_type is not None and
      b_shape_and_type is not None):

    transpose_a = transpose_a or adjoint_a
    transpose_b = transpose_b or adjoint_b

    a_shape = a_shape_and_type.shape
    b_shape = b_shape_and_type.shape
    rank = len(a_shape.dim)

    # Creates the output shape.
    c_rows = a_shape.dim[rank - (1 if transpose_a else 2)].size
    c_cols = b_shape.dim[rank - (2 if transpose_b else 1)].size
    c_shape = tensor_shape.TensorShape(a_shape)
    c_shape = tensor_shape.TensorShape(c_shape[:rank - 2] + [c_rows, c_cols])
    c_handle = _create_handle_data_proto(c_shape.as_proto(),
                                         a_shape_and_type.dtype)
  return c_handle
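
# Worked example of the shape arithmetic above (added for exposition): for
# batched operands with handle-data shapes a_shape = [2, 3, 5] and
# b_shape = [2, 5, 7] and all transpose/adjoint flags False, rank = 3, so
# c_rows = a_shape.dim[1].size = 3 and c_cols = b_shape.dim[2].size = 7,
# giving c_shape = [2, 3, 7] for the result's handle data.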



def matmul(a,
           b,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           name=None):
  """Perform a sparse matrix matmul between `a` and `b`.

  Performs a contraction between `a` and `b` along the two innermost
  dimensions. If both `a` and `b` are instances of `SparseMatrix`, returns a
  new instance of `SparseMatrix` (same type as `a`). If one is not an instance
  of `SparseMatrix`, returns a dense `Tensor`:

  ```
  c = opA(a) . opB(b)
  ```
  where `opA` (resp. `opB`) is the identity, transpose, or hermitian transpose
  of `a` (resp. `b`), depending on the values of `transpose_a` (resp.
  `transpose_b`) and `adjoint_a` (resp. `adjoint_b`).

  Args:
    a: `Tensor` or `SparseMatrix`, having rank `2` or `3`.
    b: `Tensor` or `SparseMatrix`, having rank `2` or `3`.
    transpose_a: Python `bool`.
    transpose_b: Python `bool`.
    adjoint_a: Python `bool`.
    adjoint_b: Python `bool`.
    name: Optional name to use when creating ops.

  Returns:
    A `SparseMatrix` if both `a` and `b` are instances of `SparseMatrix`,
    otherwise a dense `Tensor`.
  """

  if not isinstance(a, SparseMatrix) and not isinstance(b, SparseMatrix):
    return math_ops.matmul(
        a,
        b,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b,
        name=name)

  # pylint: disable=protected-access
  a_matrix = a._matrix if isinstance(a, SparseMatrix) else a
  b_matrix = b._matrix if isinstance(b, SparseMatrix) else b
  with ops.name_scope(name, "SparseMatrixMatMul", [a_matrix, b_matrix]):
    if isinstance(a, SparseMatrix) and isinstance(b, SparseMatrix):
      if not (isinstance(a, type(b)) or isinstance(b, type(a))):
        raise TypeError("SparseMatrix types don't inherit from each other: "
                        "%s and %s" % (type(a), type(b)))
      c = sm_ops.sparse_matrix_sparse_mat_mul(
          a_matrix,
          b_matrix,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b,
          type=a.dtype)

      # In eager mode, shape inference functions are not called, and the output
      # shape is not set. We have to infer the output shape here.
      # TODO(penporn): Set this from the C++ kernel instead.
      c_handle = matmul_shape_inference(a_matrix, b_matrix, c, transpose_a,
                                        transpose_b, adjoint_a, adjoint_b)
      return a._from_matrix(c, handle_data=c_handle)

    elif isinstance(a, SparseMatrix):
      return sm_ops.sparse_matrix_mat_mul(
          a_matrix,
          b,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
    else:
      # opA(A) . opB(B) = t(nopB(B) . nopA(A))
      if not adjoint_a and not adjoint_b:
        return sm_ops.sparse_matrix_mat_mul(
            b_matrix,
            a,
            transpose_a=not transpose_b,
            transpose_b=not transpose_a,
            transpose_output=True)
      elif not transpose_a and not transpose_b:
        return sm_ops.sparse_matrix_mat_mul(
            b_matrix,
            a,
            adjoint_a=not adjoint_b,
            adjoint_b=not adjoint_a,
            transpose_output=True,
            conjugate_output=True)
      else:
        return sm_ops.sparse_matrix_mat_mul(
            b_matrix,
            math_ops.conj(a),
            transpose_output=True,
            conjugate_output=adjoint_b)
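

# --- Illustrative usage sketch (added for exposition; not part of the original
# module). It exercises the two main dispatch paths of `matmul`: sparse-sparse,
# which returns a `CSRSparseMatrix`, and sparse-dense, which returns a dense
# `Tensor`. The helper name `_example_matmul` and the sample values are
# assumptions.
def _example_matmul():
  import tensorflow as tf  # Local import so the sketch stays self-contained.
  a_sm = CSRSparseMatrix(tf.constant([[1.0, 0.0], [0.0, 2.0]]))
  b_sm = CSRSparseMatrix(tf.constant([[3.0, 0.0], [0.0, 4.0]]))
  b_dense = tf.constant([[3.0, 0.0], [0.0, 4.0]])
  c_sm = matmul(a_sm, b_sm)        # CSRSparseMatrix result.
  c_dense = matmul(a_sm, b_dense)  # Dense Tensor result.
  return c_sm.to_dense(), c_dense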



class SparseMatrix(metaclass=abc.ABCMeta):
  """Abstract class for sparse matrix types."""

  @abc.abstractmethod
  def __init__(self):
    self._eager_mode = context.executing_eagerly()

  @abc.abstractproperty
  def _matrix(self):
    pass

  @abc.abstractmethod
  def _from_matrix(self, matrix, handle_data=None):
    pass

  @abc.abstractmethod
  def to_dense(self):
    pass

  @abc.abstractmethod
  def to_sparse_tensor(self):
    pass

  @property
  def graph(self):
    return self._matrix.graph

  @property
  def shape(self):
    return dense_shape_and_type(self._matrix).shape

  @property
  def dtype(self):
    return dense_shape_and_type(self._matrix).dtype

  @property
  def eager_handle_data(self):
    """Return the matrix's handle data iff in eager mode."""
    return _get_handle_data(self._matrix) if self._eager_mode else None

  def conj(self):
    return self._from_matrix(
        math_ops.conj(self._matrix), self.eager_handle_data)

  def hermitian_transpose(self):
    """Return the hermitian transpose of the matrix."""
    return self._from_matrix(
        sm_ops.sparse_matrix_transpose(
            self._matrix, conjugate=True, type=self.dtype),
        self.eager_handle_data)

  def nnz(self):
    """Number of stored values, including explicit zeros."""
    return sm_ops.sparse_matrix_nnz(self._matrix)

  nonzero = nnz

  def sorted_indices(self):
    # TODO(ebrevdo): A more efficient implementation?
    return self.to_sparse_tensor().indices

  def transpose(self):
    return self._from_matrix(
        sm_ops.sparse_matrix_transpose(self._matrix, type=self.dtype),
        self.eager_handle_data)
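

# --- Illustrative usage sketch (added for exposition; not part of the original
# module). It shows the generic `SparseMatrix` interface (nnz, transpose,
# sorted_indices, to_dense) via the concrete `CSRSparseMatrix` subclass defined
# below. The helper name `_example_sparse_matrix_interface` and the sample
# values are assumptions.
def _example_sparse_matrix_interface():
  import tensorflow as tf  # Local import so the sketch stays self-contained.
  m = CSRSparseMatrix(tf.constant([[0.0, 1.0], [2.0, 0.0]]))
  n = m.nnz()           # Number of stored values (here 2).
  mt = m.transpose()    # Still a CSRSparseMatrix.
  return n, mt.to_dense(), m.sorted_indices()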



class CSRSparseMatrix(SparseMatrix):
  """(Optionally batched) CSR Sparse Matrix."""

  def __init__(self, value, indices=None, name=None):
    """Construct a CSRSparseMatrix from a dense matrix or SparseTensor.

    Args:
      value: A dense `2D` or `3D` Tensor or `SparseTensor`.
      indices: The nonzero indices of `value`
        (if `value` is not a `SparseTensor`).
      name: Optional op name.

    Raises:
      ValueError: if `value` is a `SparseTensor` and `indices` is not `None`.
    """
    del name  # Unused.
    super(CSRSparseMatrix, self).__init__()
    if isinstance(value, sparse_tensor.SparseTensor):
      if indices is not None:
        raise ValueError("indices must be None if value is a SparseTensor.")
      self._dtype = value.dtype
      self._csr_matrix = sm_ops.sparse_tensor_to_csr_sparse_matrix(
          indices=value.indices,
          values=value.values,
          dense_shape=value.dense_shape)
    else:
      value = ops.convert_to_tensor(value)
      self._dtype = value.dtype
      if indices is not None:
        indices = ops.convert_to_tensor(indices, dtype=dtypes.int64)
      else:
        indices = array_ops.stop_gradient(array_ops.where(value))
      self._csr_matrix = sm_ops.dense_to_csr_sparse_matrix(value, indices)

    # Eager mode doesn't call shape inference functions, so we have to set the
    # shape and dtype handle data directly.
    if self._eager_mode:
      # pylint: disable=protected-access
      self._csr_matrix._handle_data = _make_handle_data(value)
      # pylint: enable=protected-access

  @property
  def _matrix(self):
    return self._csr_matrix

  def _from_matrix(self, matrix, handle_data=None):
    assert isinstance(matrix, ops.Tensor) and matrix.dtype == dtypes.variant
    ret = type(self).__new__(type(self))
    # pylint: disable=protected-access
    ret._dtype = self._dtype
    if self._eager_mode:
      if matrix._handle_data is None:
        matrix._handle_data = handle_data
      assert matrix._handle_data is not None
    ret._csr_matrix = matrix
    # pylint: enable=protected-access
    return ret

  def to_dense(self):
    return sm_ops.csr_sparse_matrix_to_dense(self._matrix, type=self.dtype)

  def to_sparse_tensor(self):
    r = sm_ops.csr_sparse_matrix_to_sparse_tensor(self._matrix, type=self.dtype)
    return sparse_tensor.SparseTensor(
        indices=r.indices, values=r.values, dense_shape=r.dense_shape)
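

# --- Illustrative usage sketch (added for exposition; not part of the original
# module). It shows the two construction paths of `CSRSparseMatrix` (from a
# dense `Tensor`, where nonzero positions are found internally, and from a
# `SparseTensor`, whose indices/values/dense_shape are converted directly) and
# the round trip back through `to_sparse_tensor` / `to_dense`. The helper name
# `_example_csr_construction` and the sample values are assumptions.
def _example_csr_construction():
  import tensorflow as tf  # Local import so the sketch stays self-contained.
  from_dense = CSRSparseMatrix(tf.constant([[1.0, 0.0], [0.0, 2.0]]))
  st = tf.sparse.SparseTensor(
      indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[2, 2])
  from_sparse = CSRSparseMatrix(st)
  return from_dense.to_sparse_tensor(), from_sparse.to_dense()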