Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_tensor_value.py: 47%
53 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Value for RaggedTensor."""
17import numpy as np
19from tensorflow.python.ops.ragged.row_partition import RowPartition
20from tensorflow.python.util import dispatch
21from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["ragged.RaggedTensorValue"])
@dispatch.register_dispatchable_type
class RaggedTensorValue:
  """Concrete (numpy-backed) value of a `RaggedTensor`.

  Warning: `RaggedTensorValue` should only be used in graph mode; in
  eager mode, the `tf.RaggedTensor` class contains its value directly.

  See `tf.RaggedTensor` for a description of ragged tensors.
  """

  def __init__(self, values, row_splits):
    """Builds a `RaggedTensorValue` from its values and row splits.

    Args:
      values: A numpy array of any type and shape; or a RaggedTensorValue.
      row_splits: A 1-D int32 or int64 numpy array.

    Raises:
      TypeError: If `row_splits` is not a 1-D int32/int64 numpy array, or if
        `values` is neither a numpy array nor a `RaggedTensorValue`.
      ValueError: If `values` is a `RaggedTensorValue` whose `row_splits`
        dtype differs from the dtype of `row_splits`.
    """
    # Checks run left to right so that `.dtype` / `.ndim` are only read once
    # `row_splits` is known to be a numpy object.
    if (not isinstance(row_splits, (np.ndarray, np.generic))
        or row_splits.dtype not in (np.int64, np.int32)
        or row_splits.ndim != 1):
      raise TypeError("row_splits must be a 1D int32 or int64 numpy array")

    if isinstance(values, RaggedTensorValue):
      # Nested ragged values must use the same splits dtype at every level.
      if row_splits.dtype != values.row_splits.dtype:
        raise ValueError("row_splits and values.row_splits must have "
                         "the same dtype")
    elif not isinstance(values, (np.ndarray, np.generic)):
      raise TypeError("values must be a numpy array or a RaggedTensorValue")

    self._values = values
    self._row_splits = row_splits

  @property
  def row_splits(self):
    """The split indices for the ragged tensor value."""
    return self._row_splits

  @property
  def values(self):
    """The concatenated values for all rows in this tensor."""
    return self._values

  @property
  def dtype(self):
    """The numpy dtype of values in this tensor."""
    return self._values.dtype

  @property
  def flat_values(self):
    """The innermost `values` array for this ragged tensor value."""
    innermost = self.values
    # Peel off ragged levels until a non-ragged value is reached.
    while isinstance(innermost, RaggedTensorValue):
      innermost = innermost.values
    return innermost

  @property
  def nested_row_splits(self):
    """The row_splits for all ragged dimensions in this ragged tensor value.

    Returns:
      A tuple of row_splits arrays, ordered from the outermost ragged
      dimension to the innermost one.
    """
    collected = []
    level = self
    # `self` is always a RaggedTensorValue, so its own splits come first.
    while isinstance(level, RaggedTensorValue):
      collected.append(level.row_splits)
      level = level.values
    return tuple(collected)

  @property
  def ragged_rank(self):
    """The number of ragged dimensions in this ragged tensor value."""
    inner = self._values
    if not isinstance(inner, RaggedTensorValue):
      return 1
    return inner.ragged_rank + 1

  @property
  def shape(self):
    """A tuple indicating the shape of this RaggedTensorValue.

    The first entry is the number of rows (`len(row_splits) - 1`), the second
    is `None` for the ragged dimension, and the rest is `values.shape[1:]`.
    """
    row_count = self._row_splits.shape[0] - 1
    return (row_count, None) + self._values.shape[1:]

  @property
  def _nested_row_partitions(self):
    """The row_partitions representing this shape."""
    return [
        RowPartition.from_row_splits(splits)
        for splits in self.nested_row_splits
    ]

  def __str__(self):
    return "<tf.RaggedTensorValue %s>" % (self.to_list(),)

  def __repr__(self):
    fields = (self._values, self._row_splits)
    return "tf.RaggedTensorValue(values=%r, row_splits=%r)" % fields

  def to_list(self):
    """Returns this ragged tensor value as a nested Python list."""
    inner = self._values
    flattened = (
        inner.to_list() if isinstance(inner, RaggedTensorValue)
        else inner.tolist())
    splits = self._row_splits
    # Each pair of consecutive splits delimits one row of `flattened`.
    return [
        flattened[start:limit]
        for start, limit in zip(splits[:-1], splits[1:])
    ]