Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/sparse_tensor.py: 46%
198 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sparse tensors."""
16# pylint: disable=g-bad-name
17import collections
19import numpy as np
21from tensorflow.core.protobuf import struct_pb2
22from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
23from tensorflow.python import tf2
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.framework import type_spec
32from tensorflow.python.framework import type_spec_registry
33from tensorflow.python.ops import array_ops_stack
34from tensorflow.python.ops import gen_sparse_ops
35from tensorflow.python.saved_model import nested_structure_coder
36from tensorflow.python.types import internal
37from tensorflow.python.util import _pywrap_utils
38from tensorflow.python.util.tf_export import tf_export
# Module-level aliases for private helpers from `ops`; `SparseTensor.eval` and
# `SparseTensor._override_operator` below call them.
# pylint: disable=protected-access
(_eval_using_default_session,
 _override_helper) = (ops._eval_using_default_session, ops._override_helper)
# pylint: enable=protected-access
@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)`
  comprises the following components, where `N` and `ndims` are the number
  of values and number of dimensions in the `SparseTensor`, respectively:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`, which specifies the
    indices of the elements in the sparse tensor that contain nonzero values
    (elements are zero-indexed). For example, `indices=[[1,3], [2,4]]` specifies
    that the elements with indexes of [1,3] and [2,4] have nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]`, which supplies the
    values for each element in `indices`. For example, given `indices=[[1,3],
    [2,4]]`, the parameter `values=[18, 3.6]` specifies that element [1,3] of
    the sparse tensor has a value of 18, and element [2,4] of the tensor has a
    value of 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`, which specifies the
    dense_shape of the sparse tensor. Takes a list indicating the number of
    elements in each dimension. For example, `dense_shape=[3,6]` specifies a
    two-dimensional 3x6 tensor, `dense_shape=[2,3,4]` specifies a
    three-dimensional 2x3x4 tensor, and `dense_shape=[9]` specifies a
    one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most ops assume correct ordering.
  If the ordering of sparse tensor `st` is wrong, a fixed version can be
  obtained by calling `tf.sparse.reorder(st)`.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a `SparseTensor` or `SparseTensorValue`.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A new `SparseTensor` built from the input's `indices`, `values` and
      `dense_shape`.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if not is_sparse(sparse_tensor_value):
      raise TypeError(f"Argument sparse_tensor_value={sparse_tensor_value} "
                      "is neither a SparseTensor nor SparseTensorValue.")
    return SparseTensor(
        indices=sparse_tensor_value.indices,
        values=sparse_tensor_value.values,
        dense_shape=sparse_tensor_value.dense_shape)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.

    Raises:
      ValueError: If `indices` is not rank 2, if `values` or `dense_shape` is
        not rank 1, if the number of rows of `indices` is incompatible with the
        number of `values`, or if the number of columns of `indices` is
        incompatible with the length of `dense_shape`.
    """
    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.convert_to_tensor(values, name="values")

      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
      # Static (possibly partially known) shape of the represented dense
      # tensor, derived from whatever of `dense_shape` is constant.
      dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)

    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape
    self._dense_shape_default = dense_shape_default

    indices_shape = indices.shape.with_rank(2)
    values_shape = values.shape.with_rank(1)
    dense_shape_shape = dense_shape.shape.with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape.dims[0].assert_is_compatible_with(values_shape.dims[0])
    # Assert number of columns in indices matches the number of elements in
    # dense_shape.
    indices_shape.dims[1].assert_is_compatible_with(dense_shape_shape.dims[0])

  def get_shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with dense_shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  def with_values(self, new_values):
    """Returns a copy of `self` with `values` replaced by `new_values`.

    This method produces a new `SparseTensor` that has the same nonzero
    `indices` and same `dense_shape`, but updated values.

    Args:
      new_values: The values of the new `SparseTensor`. Needs to have the same
        shape as the current `.values` `Tensor`. May have a different type than
        the current `values`.

    Returns:
      A `SparseTensor` with identical indices and shape but updated values.

    Example usage:

    >>> st = tf.sparse.from_dense([[1, 0, 2, 0], [3, 0, 0, 4]])
    >>> tf.sparse.to_dense(st.with_values([10, 20, 30, 40]))  # 4 nonzero values
    <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
    array([[10,  0, 20,  0],
           [30,  0,  0, 40]], dtype=int32)>

    """
    return SparseTensor(self._indices, new_values, self._dense_shape)

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  def set_shape(self, shape):
    """Updates the `TensorShape` representing the shape of the dense tensor.

    With eager execution this operates as a shape assertion.
    Here the shapes match:

    >>> st = tf.SparseTensor(
    ...   indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    >>> st.set_shape([3, 4])

    Passing a `None` in the new shape allows any value for that axis:

    >>> st.set_shape([3, None])

    An error is raised if an incompatible shape is passed.

    >>> st.set_shape([1, 4])
    Traceback (most recent call last):
    ...
    ValueError: Tensor's shape (3, 4) is not compatible with supplied
    shape [1, 4]

    When executing in a `tf.function`, or building a model using
    `tf.keras.Input`, `SparseTensor.set_shape` will *merge* the given `shape`
    with the current shape of this tensor, and set the tensor's shape to the
    merged value (see `tf.TensorShape.merge_with` for details):

    >>> st = tf.keras.Input(shape=[None, None, 3], sparse=True)
    >>> print(st.shape)
    (None, None, None, 3)

    Dimensions set to `None` are not updated:

    >>> st.set_shape([None, 224, 224, None])
    >>> print(st.shape)
    (None, 224, 224, 3)

    The main use case for this is to provide additional shape information
    that cannot be inferred from the graph alone.

    Caution: `set_shape` ensures that the applied shape is compatible with
    the existing shape, but it does not check at runtime. Setting
    incorrect shapes can result in inconsistencies between the
    statically-known graph and the runtime value of tensors.

    Args:
      shape: A `TensorShape` representing the shape of this tensor, a
        `TensorShapeProto`, a list, a tuple, or None.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    self._dense_shape_default = self._dense_shape_default.merge_with(shape)

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    return self._indices.graph

  def __repr__(self):
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
        self._indices, self._values, self._dense_shape)

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    indices, values, dense_shape = _eval_using_default_session(
        [self.indices, self.values, self.dense_shape], feed_dict, self.graph,
        session)
    return SparseTensorValue(indices, values, dense_shape)

  @staticmethod
  def _override_operator(operator, func):
    """Registers `func` as the implementation of `operator` on this class."""
    _override_helper(SparseTensor, operator, func)

  @property
  def _type_spec(self):
    """A `SparseTensorSpec` built from this tensor's static shape and dtype."""
    return SparseTensorSpec(self.shape, self.dtype)

  def _shape_invariant_to_type_spec(self, shape):
    """Converts a `tf.while_loop` shape invariant into a `SparseTensorSpec`."""
    # From the tf.while_loop docs: "If a loop variable is a SparseTensor, the
    # shape invariant must be TensorShape([r]) where r is the rank of the dense
    # tensor represented by the sparse tensor. It means the shapes of the three
    # tensors of the SparseTensor are ([None], [None, r], [r]). NOTE: The shape
    # invariant here is the shape of the SparseTensor.dense_shape property. It
    # must be the shape of a vector.
    if shape.ndims is not None and shape.ndims != 1:
      raise ValueError(f"Expected a shape with 1 dimension. Obtained: {shape} "
                       f"which has {shape.ndims} dimensions.")
    rank = tensor_shape.dimension_value(shape[0])
    return SparseTensorSpec(tensor_shape.unknown_shape(rank), self.dtype)

  def consumers(self):
    """Returns the consumers of this sparse tensor.

    Delegates to `self._consumers()`, which is not defined in this class
    (presumably inherited from `CompositeTensor` -- confirm).
    """
    return self._consumers()

  def _numpy(self):
    """Returns a numpy `array` with the values for this `SparseTensor`.

    Requires that this `SparseTensor` was constructed in eager execution mode.
    Positions not listed in `indices` are zero; if an index is repeated, the
    value that appears last in `values` is the one kept.

    Raises:
      ValueError: If any component is not an eager tensor.
    """
    if not self._is_eager():
      raise ValueError("SparseTensor.numpy() is only supported in eager mode.")
    # `as_numpy_dtype` is already a numpy type and must not be called: calling
    # it yields a scalar (numeric types), `object()` (tf.string) or fails
    # outright (quantized dtypes are `np.dtype` instances).
    arr = np.zeros(self.dense_shape.numpy(), dtype=self.dtype.as_numpy_dtype)
    # Materialize the components once instead of iterating EagerTensors.
    for index, value in zip(self.indices.numpy(), self.values.numpy()):
      arr[tuple(index)] = value

    return arr

  def _is_eager(self):
    """Returns True if this `SparseTensor` was constructed in eager execution.

    Requires that each individual component of `SparseTensor`
    (`indices`, `values` and `dense_shape`) is an instance of `EagerTensor`.
    """

    return all(
        isinstance(t, ops.EagerTensor)
        for t in (self.indices, self.values, self.dense_shape))
# Plain-Python (numpy-backed) counterpart of `SparseTensor`, e.g. as returned
# by `SparseTensor.eval`. Fields mirror the three `SparseTensor` components.
SparseTensorValue = collections.namedtuple(
    "SparseTensorValue", ("indices", "values", "dense_shape"))
# Expose as `tf.compat.v1.SparseTensorValue` and register the type with the
# native utilities under the name "SparseTensorValue".
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
@tf_export("SparseTensorSpec")
@type_spec_registry.register("tf.SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.sparse.SparseTensor`."""

  __slots__ = ["_shape", "_dtype"]

  value_type = property(lambda self: SparseTensor)

  def __init__(self, shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.sparse.SparseTensor`.

    Args:
      shape: The dense shape of the `SparseTensor`, or `None` to allow any dense
        shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    return (self._shape, self._dtype)

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
    return self._dtype

  @property
  def shape(self):
    """The `tf.TensorShape` specified by this type for the SparseTensor."""
    return self._shape

  @property
  def _component_specs(self):
    """Specs for (indices, values, dense_shape); the value count is unknown."""
    rank = self._shape.ndims
    num_values = None
    return [
        tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
        tensor_spec.TensorSpec([num_values], self._dtype),
        tensor_spec.TensorSpec([rank], dtypes.int64)]

  def _to_components(self, value):
    if isinstance(value, SparseTensorValue):
      value = SparseTensor.from_value(value)
    return [value.indices, value.values, value.dense_shape]

  def _from_components(self, tensor_list):
    """Rebuilds a value from its components.

    Returns a `SparseTensorValue` when every component is a numpy array and TF2
    behavior is disabled; otherwise a `SparseTensor` whose static dense shape is
    merged with this spec's shape.
    """
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      return SparseTensorValue(*tensor_list)
    else:
      result = SparseTensor(*tensor_list)
      # Augment the static dense shape with the shape carried by the spec.
      result._dense_shape_default = result._dense_shape_default.merge_with(  # pylint: disable=protected-access
          self._shape)
      return result

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    value = SparseTensor.from_value(value)
    return [gen_sparse_ops.serialize_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _to_batched_tensor_list(self, value):
    # Normalize `SparseTensorValue` inputs (as `_to_components` and
    # `_to_tensor_list` do) so `value.dense_shape` is a Tensor before it is
    # handed to `constant_value_as_shape`.
    if isinstance(value, SparseTensorValue):
      value = SparseTensor.from_value(value)
    dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    if self._shape.merge_with(dense_shape).ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1. "
          f"Obtained input: {value}.")
    return [gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _from_compatible_tensor_list(self, tensor_list):
    tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
    indices, values, dense_shape = tensor_list
    rank = self._shape.ndims
    indices.set_shape([None, rank])
    # We restore the dense_shape from the SparseTypeSpec. This is necessary
    # for shape inference when using placeholder SparseTensors in function
    # tracing.
    if self._shape.is_fully_defined():
      dense_shape = ops.convert_to_tensor(
          self._shape, dtype=dtypes.int64, name="shape")
    elif (self._shape.rank is not None and
          any(dim.value is not None for dim in self._shape.dims)):
      # Partially known shape: keep the deserialized entries for unknown
      # dimensions and substitute constants for the known ones.
      pieces = array_ops_stack.unstack(dense_shape, num=self._shape.rank)
      for i, dim in enumerate(self._shape.dims):
        if dim.value is not None:
          pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
      dense_shape = array_ops_stack.stack(pieces)
    else:
      dense_shape.set_shape([rank])

    return SparseTensor(indices, values, dense_shape)

  def _batch(self, batch_size):
    return SparseTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return SparseTensorSpec(self._shape[1:], self._dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    """Builds a spec describing a `SparseTensor` or `SparseTensorValue`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if isinstance(value, SparseTensorValue):
      if isinstance(value.values, np.ndarray):
        return cls(value.dense_shape, value.values.dtype)
      else:
        return cls.from_value(SparseTensor.from_value(value))
    else:
      raise TypeError("Expected SparseTensor or SparseTensorValue. Received: "
                      f"{value} of type {type(value).__name__}.")
# Let `nested_structure_coder` (de)serialize `SparseTensorSpec` objects using
# the SPARSE_TENSOR_SPEC entry of `TypeSpecProto`.
_SPARSE_TENSOR_SPEC_CODEC = nested_structure_coder.BuiltInTypeSpecCodec(
    SparseTensorSpec, struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC)
nested_structure_coder.register_codec(_SPARSE_TENSOR_SPEC_CODEC)
del _SPARSE_TENSOR_SPEC_CODEC
# Map both sparse value types to `SparseTensorSpec.from_value`, registering
# `SparseTensor` first and `SparseTensorValue` second.
# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic).  Do *not* delete the SparseTensorValue registration.
for _sparse_value_cls in (SparseTensor, SparseTensorValue):
  type_spec.register_type_spec_from_value_converter(
      _sparse_value_cls, SparseTensorSpec.from_value)
del _sparse_value_cls
@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created. Not used when
      `value` is sparse.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If result type is incompatible with `dtype`.
  """
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)
  if isinstance(value, SparseTensor):
    # Sparse inputs are returned as-is; only verify the requested dtype.
    if dtype is not None and not dtype.is_compatible_with(value.dtype):
      raise RuntimeError(f"Sparse dtype mismatch. Requested: {dtype.name}, "
                         f"Actual: {value.dtype.name}")
    return value
  return ops.convert_to_tensor(value, dtype=dtype, name=name)
def is_sparse(x):
  """Reports whether `x` is one of the sparse tensor container types.

  The recognized types are `tf.sparse.SparseTensor` and
  `tf.compat.v1.SparseTensorValue` (subclasses included).

  Args:
    x: Any python object.

  Returns:
    `True` if `x` is an instance of `tf.sparse.SparseTensor` or
    `tf.compat.v1.SparseTensorValue`, otherwise `False`.
  """
  sparse_types = (SparseTensor, SparseTensorValue)
  return isinstance(x, sparse_types)