Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_tensor_value.py: 47%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Value for RaggedTensor.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.ops.ragged.row_partition import RowPartition 

20from tensorflow.python.util import dispatch 

21from tensorflow.python.util.tf_export import tf_export 

22 

23 

@tf_export(v1=["ragged.RaggedTensorValue"])
@dispatch.register_dispatchable_type
class RaggedTensorValue:
  """Represents the value of a `RaggedTensor`.

  Warning: `RaggedTensorValue` should only be used in graph mode; in
  eager mode, the `tf.RaggedTensor` class contains its value directly.

  A value pairs a `values` payload (a numpy array, or another
  `RaggedTensorValue` for each additional ragged dimension) with a 1-D
  `row_splits` array; row `i` is `values[row_splits[i]:row_splits[i + 1]]`.

  See `tf.RaggedTensor` for a description of ragged tensors.
  """

  def __init__(self, values, row_splits):
    """Creates a `RaggedTensorValue`.

    Args:
      values: A numpy array of any type and shape; or a RaggedTensorValue.
      row_splits: A 1-D int32 or int64 numpy array.

    Raises:
      TypeError: If `row_splits` is not a 1-D int32/int64 numpy array, or if
        `values` is neither a numpy array nor a `RaggedTensorValue`.
      ValueError: If `values` is a `RaggedTensorValue` whose `row_splits`
        dtype differs from the dtype of `row_splits`.
    """
    # Short-circuit order matters: `.dtype`/`.ndim` are only read once we
    # know `row_splits` is a numpy object.
    splits_are_valid = (
        isinstance(row_splits, (np.ndarray, np.generic))
        and row_splits.dtype in (np.int64, np.int32)
        and row_splits.ndim == 1)
    if not splits_are_valid:
      raise TypeError("row_splits must be a 1D int32 or int64 numpy array")

    values_are_ragged = isinstance(values, RaggedTensorValue)
    if not (values_are_ragged or
            isinstance(values, (np.ndarray, np.generic))):
      raise TypeError("values must be a numpy array or a RaggedTensorValue")

    # Every ragged level in a nested value must share one splits dtype.
    if values_are_ragged and row_splits.dtype != values.row_splits.dtype:
      raise ValueError("row_splits and values.row_splits must have "
                       "the same dtype")

    self._values = values
    self._row_splits = row_splits

  @property
  def row_splits(self):
    """The split indices for the ragged tensor value."""
    return self._row_splits

  @property
  def values(self):
    """The concatenated values for all rows in this tensor."""
    return self._values

  @property
  def dtype(self):
    """The numpy dtype of values in this tensor."""
    return self._values.dtype

  @property
  def flat_values(self):
    """The innermost `values` array for this ragged tensor value."""
    # Descend through nested RaggedTensorValues until reaching a non-ragged
    # payload.
    node = self
    while isinstance(node, RaggedTensorValue):
      node = node.values
    return node

  @property
  def nested_row_splits(self):
    """The row_splits for all ragged dimensions in this ragged tensor value.

    Returns:
      A tuple of splits arrays, outermost ragged dimension first.
    """
    collected = []
    node = self
    while isinstance(node, RaggedTensorValue):
      collected.append(node.row_splits)
      node = node.values
    return tuple(collected)

  @property
  def ragged_rank(self):
    """The number of ragged dimensions in this ragged tensor value."""
    inner = self._values
    if isinstance(inner, RaggedTensorValue):
      return inner.ragged_rank + 1
    return 1

  @property
  def shape(self):
    """A tuple indicating the shape of this RaggedTensorValue.

    The leading entry is the row count (`len(row_splits) - 1`), the second is
    `None` because row lengths vary, and the remainder is `values.shape[1:]`.
    """
    num_rows = self._row_splits.shape[0] - 1
    return (num_rows, None) + self._values.shape[1:]

  @property
  def _nested_row_partitions(self):
    """The row_partitions representing this shape."""
    return list(map(RowPartition.from_row_splits, self.nested_row_splits))

  def __str__(self):
    return "<tf.RaggedTensorValue %s>" % (self.to_list(),)

  def __repr__(self):
    format_args = (self._values, self._row_splits)
    return "tf.RaggedTensorValue(values=%r, row_splits=%r)" % format_args

  def to_list(self):
    """Returns this ragged tensor value as a nested Python list."""
    inner = self._values
    if isinstance(inner, RaggedTensorValue):
      flat = inner.to_list()
    else:
      flat = inner.tolist()
    splits = self._row_splits
    # Pair each split with its successor: (splits[i], splits[i + 1]) bounds
    # row i within the flattened list.
    return [flat[start:stop] for start, stop in zip(splits[:-1], splits[1:])]