Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py: 37%
154 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CSR Sparse Matrix Operations."""
17import abc
18import collections
20# pylint: disable=g-direct-tensorflow-import, wildcard-import
21from tensorflow.python.eager import context
22from tensorflow.python.framework import cpp_shape_inference_pb2
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.ops.linalg.sparse import gen_sparse_csr_matrix_ops as sm_ops
31from tensorflow.python.ops.linalg.sparse.gen_sparse_csr_matrix_ops import *
# Hand-written public API of this module.
__all__ = [
    "SparseMatrix",
    "CSRSparseMatrix",
    "matmul",
    "dense_shape_and_type",
]
# pylint: disable=invalid-name
# Also re-export every public name of the generated op module, matching the
# wildcard import above.
__all__.extend(
    op_name for op_name in dir(sm_ops) if not op_name.startswith("_"))
class DenseShapeAndType(
    collections.namedtuple("DenseShapeAndType", ["shape", "dtype"])):
  """Named pair holding the dense `shape` and `dtype` of a sparse matrix.

  `dense_shape_and_type` fills `shape` with a `tf.TensorShape` and `dtype`
  with a `tf.DType`.
  """
def _get_handle_data(tensor):
  """Return `tensor`'s handle data via the eager-safe resource-variable helper.

  Args:
    tensor: The tensor whose handle data should be fetched.

  Returns:
    Whatever `resource_variable_ops.get_eager_safe_handle_data` returns for
    `tensor`.
  """
  handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor)
  return handle_data
def _create_handle_data_proto(shape_proto, dtype_enum):
  """Build a `HandleData` proto holding exactly one shape-and-type entry.

  Args:
    shape_proto: Shape proto stored in the entry's `shape` field.
    dtype_enum: Datatype enum stored in the entry's `dtype` field.

  Returns:
    A `CppShapeInferenceResult.HandleData` with `is_set=True` and a single
    `HandleShapeAndType` entry.
  """
  result_cls = cpp_shape_inference_pb2.CppShapeInferenceResult
  handle_data = result_cls.HandleData()
  handle_data.is_set = True
  entry = result_cls.HandleShapeAndType(shape=shape_proto, dtype=dtype_enum)
  # NOTE(ebrevdo): shape_and_type lacks append() in some versions of protobuf.
  handle_data.shape_and_type.extend([entry])
  return handle_data
def _make_handle_data(tensor):
  """Build handle data describing `tensor`'s static shape and dtype.

  Args:
    tensor: Object exposing `.shape.as_proto()` and
      `.dtype.as_datatype_enum`.

  Returns:
    The proto produced by `_create_handle_data_proto` for that shape/dtype.
  """
  shape_proto = tensor.shape.as_proto()
  dtype_enum = tensor.dtype.as_datatype_enum
  return _create_handle_data_proto(shape_proto, dtype_enum)
def get_shape_and_type(matrix):
  """Return the single shape-and-type entry of `matrix`'s handle data.

  Args:
    matrix: Any object; its optional `_handle_data` attribute is inspected.

  Returns:
    The only element of `matrix._handle_data.shape_and_type`, or `None` when
    `matrix` has no `_handle_data` attribute or it is `None`.

  Raises:
    ValueError: if `shape_and_type` does not have exactly one element.
  """
  handle_data = getattr(matrix, "_handle_data", None)
  if handle_data is not None:
    if len(handle_data.shape_and_type) == 1:
      return handle_data.shape_and_type[0]
    raise ValueError(
        "shape_and_type array in _handle_data must have length one, but saw: %d"
        % len(handle_data.shape_and_type))
  return None
def dense_shape_and_type(matrix):
  """Get dense shape and dtype of the tf.Tensor containing the matrix.

  Args:
    matrix: A `tf.Tensor` of type `tf.variant` storing a sparse matrix.

  Returns:
    An instance of `DenseShapeAndType` with properties `shape` (a
    `tf.TensorShape`) and `dtype` (a `tf.DType`).

  Raises:
    TypeError: if `matrix` is not a tensor or its dtype is not variant.
    ValueError: if `matrix` lacks static handle data containing the dense
      shape and dtype.
  """
  if not isinstance(matrix, ops.Tensor):
    raise TypeError("matrix should be a tensor, but saw: %s" % (matrix,))
  if matrix.dtype != dtypes.variant:
    raise TypeError(
        "expected matrix to be type tf.variant, but saw: %s" % (matrix.dtype,))
  handle_data = _get_handle_data(matrix)
  # The dense shape/dtype live only in the handle data attached to the
  # variant tensor, so it must be present and populated.
  if not handle_data or not handle_data.is_set:
    raise ValueError("matrix has missing handle data: %s" % (matrix,))
  if len(handle_data.shape_and_type) != 1:
    raise ValueError("len(matrix.handle_data.shape_and_type) != 1: '%s'" %
                     (handle_data.shape_and_type,))
  return DenseShapeAndType(
      tensor_shape.TensorShape(handle_data.shape_and_type[0].shape),
      dtypes.DType(handle_data.shape_and_type[0].dtype))
def matmul_shape_inference(a, b, c, transpose_a, transpose_b, adjoint_a,
                           adjoint_b):
  """Helper function for matmul to set the result matrix's handle data.

  Args:
    a: Left operand; its `_handle_data` (if any) supplies its dense shape.
    b: Right operand; its `_handle_data` (if any) supplies its dense shape.
    c: The product tensor.
    transpose_a: Whether `a` is transposed in the product.
    transpose_b: Whether `b` is transposed in the product.
    adjoint_a: Whether `a` is adjointed in the product.
    adjoint_b: Whether `b` is adjointed in the product.

  Returns:
    `c`'s existing `_handle_data` if it is not `None`. Otherwise, when both
    `a` and `b` carry a shape-and-type entry, a new handle-data proto with
    the product's dense shape and `a`'s dtype. Otherwise `None`.

  Raises:
    ValueError: propagated from `get_shape_and_type` if an operand's handle
      data does not hold exactly one shape-and-type entry.
  """
  c_handle = getattr(c, "_handle_data", None)
  a_shape_and_type = get_shape_and_type(a)
  b_shape_and_type = get_shape_and_type(b)
  if (c_handle is None and a_shape_and_type is not None and
      b_shape_and_type is not None):

    def _static_size(dim):
      # TensorShapeProto encodes an unknown dimension as size -1; hand
      # TensorShape `None` (unknown) instead of the raw negative size.
      return dim.size if dim.size >= 0 else None

    # For shape purposes an adjoint is just a transpose.
    transpose_a = transpose_a or adjoint_a
    transpose_b = transpose_b or adjoint_b

    a_shape = a_shape_and_type.shape
    b_shape = b_shape_and_type.shape
    a_rank = len(a_shape.dim)
    b_rank = len(b_shape.dim)

    # Rows of the product come from opA(a), columns from opB(b); each operand
    # is indexed relative to its own rank.
    c_rows = _static_size(a_shape.dim[a_rank - (1 if transpose_a else 2)])
    c_cols = _static_size(b_shape.dim[b_rank - (2 if transpose_b else 1)])
    # Leading (batch) dimensions are taken from `a`.
    c_shape = tensor_shape.TensorShape(a_shape)
    c_shape = tensor_shape.TensorShape(c_shape[:a_rank - 2] + [c_rows, c_cols])
    c_handle = _create_handle_data_proto(c_shape.as_proto(),
                                         a_shape_and_type.dtype)
  return c_handle
def matmul(a,
           b,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           name=None):
  """Perform a sparse matrix matmul between `a` and `b`.

  Performs a contraction between `a` and `b` along the two innermost dimensions.
  If both `a` and `b` are instances of `SparseMatrix`, returns a new instance
  of `SparseMatrix` (same type as `a`). If one is not an instance of
  `SparseMatrix`, returns a dense `Tensor`:

  ```
  c = opA(a) . opB(b)
  ```
  where `opA` (resp. `opB`) is the transpose or hermitian transpose depending
  on the values of `transpose_a` (resp. `transpose_b`) and `adjoint_a`
  (resp. `adjoint_b`).

  Args:
    a: `Tensor` or `SparseMatrix`, having rank `2` or `3`.
    b: `Tensor` or `SparseMatrix`, having rank `2` or `3`.
    transpose_a: Python `bool`.
    transpose_b: Python `bool`.
    adjoint_a: Python `bool`.
    adjoint_b: Python `bool`.
    name: Optional name to use when creating ops.

  Returns:
    A `SparseMatrix` if both `a` and `b` are instances of `SparseMatrix`,
    otherwise a dense `Tensor`.

  Raises:
    TypeError: if `a` and `b` are both `SparseMatrix` instances whose types
      do not inherit from each other.
    ValueError: if `a` is dense, `b` is a `SparseMatrix`, and both
      `transpose_a` and `adjoint_a` (or both `transpose_b` and `adjoint_b`)
      are True.
  """
  if not isinstance(a, SparseMatrix) and not isinstance(b, SparseMatrix):
    return math_ops.matmul(
        a,
        b,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b,
        name=name)

  # pylint: disable=protected-access
  a_matrix = a._matrix if isinstance(a, SparseMatrix) else a
  b_matrix = b._matrix if isinstance(b, SparseMatrix) else b
  with ops.name_scope(name, "SparseMatrixMatMul", [a_matrix, b_matrix]):
    if isinstance(a, SparseMatrix) and isinstance(b, SparseMatrix):
      if not (isinstance(a, type(b)) or isinstance(b, type(a))):
        raise TypeError("SparseMatrix types don't inherit from each other: "
                        "%s and %s" % (type(a), type(b)))
      c = sm_ops.sparse_matrix_sparse_mat_mul(
          a_matrix,
          b_matrix,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b,
          type=a.dtype)

      # In eager mode, shape inference functions are not called, and the output
      # shape is not set. We have to infer the output shape here.
      # TODO(penporn): Set this from the C++ kernel instead.
      c_handle = matmul_shape_inference(a_matrix, b_matrix, c, transpose_a,
                                        transpose_b, adjoint_a, adjoint_b)
      return a._from_matrix(c, handle_data=c_handle)

    elif isinstance(a, SparseMatrix):
      return sm_ops.sparse_matrix_mat_mul(
          a_matrix,
          b,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
    else:
      # The rewrites below assume at most one of transpose/adjoint per operand;
      # an invalid combination would otherwise silently fall into the final
      # branch and produce a wrong result.
      if transpose_a and adjoint_a:
        raise ValueError("Only one of transpose_a and adjoint_a can be True.")
      if transpose_b and adjoint_b:
        raise ValueError("Only one of transpose_b and adjoint_b can be True.")
      # opA(A) . opB(B) = t(nopB(B) . nopA(A))
      if not adjoint_a and not adjoint_b:
        return sm_ops.sparse_matrix_mat_mul(
            b_matrix,
            a,
            transpose_a=not transpose_b,
            transpose_b=not transpose_a,
            transpose_output=True)
      elif not transpose_a and not transpose_b:
        return sm_ops.sparse_matrix_mat_mul(
            b_matrix,
            a,
            adjoint_a=not adjoint_b,
            adjoint_b=not adjoint_a,
            transpose_output=True,
            conjugate_output=True)
      else:
        # Remaining valid cases: (transpose_a, adjoint_b) or
        # (adjoint_a, transpose_b).
        return sm_ops.sparse_matrix_mat_mul(
            b_matrix,
            math_ops.conj(a),
            transpose_output=True,
            conjugate_output=adjoint_b)
class SparseMatrix(metaclass=abc.ABCMeta):
  """Abstract class for sparse matrix types.

  Subclasses wrap a `tf.variant` tensor (exposed through `_matrix`). The dense
  shape and dtype are read from that tensor's handle data.
  """

  @abc.abstractmethod
  def __init__(self):
    # Recorded at construction; eager_handle_data consults it because eager
    # execution does not run shape inference for the variant output.
    self._eager_mode = context.executing_eagerly()

  @property
  @abc.abstractmethod
  def _matrix(self):
    """The underlying `tf.variant` tensor holding the sparse matrix."""

  @abc.abstractmethod
  def _from_matrix(self, matrix, handle_data=None):
    """Wrap a variant `matrix` tensor in a new instance of this type."""

  @abc.abstractmethod
  def to_dense(self):
    """Convert the sparse matrix to a dense `Tensor`."""

  @abc.abstractmethod
  def to_sparse_tensor(self):
    """Convert the sparse matrix to a `SparseTensor`."""

  @property
  def graph(self):
    return self._matrix.graph

  @property
  def shape(self):
    return dense_shape_and_type(self._matrix).shape

  @property
  def dtype(self):
    return dense_shape_and_type(self._matrix).dtype

  @property
  def eager_handle_data(self):
    """Return the matrix's handle data iff in eager mode."""
    return _get_handle_data(self._matrix) if self._eager_mode else None

  def _transposed_eager_handle_data(self):
    """Return eager handle data with the two innermost dimensions swapped.

    Falls back to `eager_handle_data` unchanged when it is `None`, does not
    hold exactly one shape-and-type entry, or has unknown rank or rank < 2.
    """
    handle_data = self.eager_handle_data
    if handle_data is None or len(handle_data.shape_and_type) != 1:
      return handle_data
    shape_and_type = handle_data.shape_and_type[0]
    shape = tensor_shape.TensorShape(shape_and_type.shape)
    if shape.rank is None or shape.rank < 2:
      return handle_data
    dims = shape.as_list()
    dims[-2], dims[-1] = dims[-1], dims[-2]
    return _create_handle_data_proto(
        tensor_shape.TensorShape(dims).as_proto(), shape_and_type.dtype)

  def conj(self):
    # Conjugation keeps the dense shape, so the handle data carries over.
    return self._from_matrix(
        math_ops.conj(self._matrix), self.eager_handle_data)

  def hermitian_transpose(self):
    """Return the hermitian transpose of the matrix."""
    # The result's two innermost dims are swapped relative to `self`.
    return self._from_matrix(
        sm_ops.sparse_matrix_transpose(
            self._matrix, conjugate=True, type=self.dtype),
        self._transposed_eager_handle_data())

  def nnz(self):
    """Number of stored values, including explicit zeros."""
    return sm_ops.sparse_matrix_nnz(self._matrix)

  nonzero = nnz

  def sorted_indices(self):
    # TODO(ebrevdo): A more efficient implementation?
    return self.to_sparse_tensor().indices

  def transpose(self):
    """Return the transpose of the matrix."""
    # The result's two innermost dims are swapped relative to `self`.
    return self._from_matrix(
        sm_ops.sparse_matrix_transpose(self._matrix, type=self.dtype),
        self._transposed_eager_handle_data())
class CSRSparseMatrix(SparseMatrix):
  """(Optionally batched) CSR Sparse Matrix."""

  def __init__(self, value, indices=None, name=None):
    """Construct a CSRSparseMatrix from a dense matrix or SparseTensor.

    Args:
      value: A dense `2D` or `3D` Tensor or `SparseTensor`.
      indices: The nonzero indices of `value`
        (if `value` is not a `SparseTensor`).
      name: Optional op name.

    Raises:
      ValueError: if `value` is a `SparseTensor` and `indices` is not `None`.
    """
    del name  # Unused.
    super(CSRSparseMatrix, self).__init__()
    if isinstance(value, sparse_tensor.SparseTensor):
      if indices is not None:
        raise ValueError("indices must be None if value is a SparseTensor.")
      self._dtype = value.dtype
      self._csr_matrix = sm_ops.sparse_tensor_to_csr_sparse_matrix(
          indices=value.indices,
          values=value.values,
          dense_shape=value.dense_shape)
    else:
      value = ops.convert_to_tensor(value)
      self._dtype = value.dtype
      if indices is not None:
        indices = ops.convert_to_tensor(indices, dtype=dtypes.int64)
      else:
        # Derive the nonzero positions from the dense value itself; no
        # gradient should flow through the index computation.
        indices = array_ops.stop_gradient(array_ops.where(value))
      self._csr_matrix = sm_ops.dense_to_csr_sparse_matrix(value, indices)

    # Eager mode doesn't call shape inference functions, so we have to set the
    # shape and dtype handle data directly.
    if self._eager_mode:
      # pylint: disable=protected-access
      self._csr_matrix._handle_data = _make_handle_data(value)
      # pylint: enable=protected-access

  @property
  def _matrix(self):
    return self._csr_matrix

  def _from_matrix(self, matrix, handle_data=None):
    """Wrap variant tensor `matrix` in a new instance of this class.

    Args:
      matrix: A `tf.Tensor` of dtype `tf.variant`.
      handle_data: Handle data to attach to `matrix` in eager mode if it has
        none of its own.

    Returns:
      A new instance of `type(self)` wrapping `matrix`.
    """
    assert isinstance(matrix, ops.Tensor) and matrix.dtype == dtypes.variant
    ret = type(self).__new__(type(self))
    # pylint: disable=protected-access
    ret._dtype = self._dtype
    # __new__ bypasses __init__, so carry the execution mode over explicitly;
    # eager_handle_data (used by conj/transpose) reads it.
    ret._eager_mode = self._eager_mode
    if self._eager_mode:
      if matrix._handle_data is None:
        matrix._handle_data = handle_data
      assert matrix._handle_data is not None
    ret._csr_matrix = matrix
    # pylint: enable=protected-access
    return ret

  def to_dense(self):
    return sm_ops.csr_sparse_matrix_to_dense(self._matrix, type=self.dtype)

  def to_sparse_tensor(self):
    r = sm_ops.csr_sparse_matrix_to_sparse_tensor(self._matrix, type=self.dtype)
    return sparse_tensor.SparseTensor(
        indices=r.indices, values=r.values, dense_shape=r.dense_shape)