Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/segment_id_ops.py: 31%

51 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Ops for converting between row_splits and segment_ids.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import tensor_shape 

20from tensorflow.python.framework import tensor_util 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import math_ops 

23from tensorflow.python.ops.ragged import ragged_util 

24from tensorflow.python.util import dispatch 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

28# For background on "segments" and "segment ids", see: 

29# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation 

@tf_export("ragged.row_splits_to_segment_ids")
@dispatch.add_dispatch_support
def row_splits_to_segment_ids(splits, name=None, out_type=None):
  """Expands a RaggedTensor `row_splits` vector into per-value segment ids.

  The returned integer vector `segment_ids` labels every value with the index
  of the row that contains it: `segment_ids[i] == j` whenever
  `splits[j] <= i < splits[j+1]`. Example:

  >>> print(tf.ragged.row_splits_to_segment_ids([0, 3, 3, 5, 6, 9]))
  tf.Tensor([0 0 0 2 2 3 4 4 4], shape=(9,), dtype=int64)

  Args:
    splits: A sorted 1-D integer Tensor. `splits[0]` must be zero.
    name: A name prefix for the returned tensor (optional).
    out_type: The dtype for the return value. Defaults to `splits.dtype`,
      or `tf.int64` if `splits` does not have a dtype.

  Returns:
    A sorted 1-D integer Tensor, with `shape=[splits[-1]]`

  Raises:
    ValueError: If `splits` is invalid (its dtype is neither int32 nor int64,
      or it is statically known to be empty).
  """
  with ops.name_scope(name, "RaggedSplitsToSegmentIds", [splits]) as name:
    # Values without a dtype of their own (e.g. Python lists) become int64.
    splits = ops.convert_to_tensor(
        splits, name="splits", preferred_dtype=dtypes.int64)

    allowed_dtypes = (dtypes.int32, dtypes.int64)
    if splits.dtype not in allowed_dtypes:
      raise ValueError("splits must have dtype int32 or int64")

    splits.shape.assert_has_rank(1)
    # A valid row_splits vector always holds at least the leading zero.
    if tensor_shape.dimension_value(splits.shape[0]) == 0:
      raise ValueError("Invalid row_splits: []")

    out_type = splits.dtype if out_type is None else dtypes.as_dtype(out_type)

    # The length of row j is splits[j + 1] - splits[j].
    row_ends = splits[1:]
    row_starts = splits[:-1]
    lengths = row_ends - row_starts

    # There is one fewer row than there are split points.
    num_rows = array_ops.shape(splits, out_type=out_type)[-1] - 1
    row_ids = math_ops.range(num_rows)

    # Emit each row index once per value in that row.
    return ragged_util.repeat(row_ids, repeats=lengths, axis=0)

70 

71 

72# For background on "segments" and "segment ids", see: 

73# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation 

@tf_export("ragged.segment_ids_to_row_splits")
@dispatch.add_dispatch_support
def segment_ids_to_row_splits(segment_ids, num_segments=None,
                              out_type=None, name=None):
  """Builds the RaggedTensor `row_splits` vector that matches a segmentation.

  The returned integer vector `splits` starts at `splits[0] = 0`, and each
  later entry adds the number of values in the corresponding segment:
  `splits[i] = splits[i-1] + count(segment_ids==i)`. Example:

  >>> print(tf.ragged.segment_ids_to_row_splits([0, 0, 0, 2, 2, 3, 4, 4, 4]))
  tf.Tensor([0 3 3 5 6 9], shape=(6,), dtype=int64)

  Args:
    segment_ids: A 1-D integer Tensor.
    num_segments: A scalar integer indicating the number of segments. Defaults
      to `max(segment_ids) + 1` (or zero if `segment_ids` is empty).
    out_type: The dtype for the return value. When omitted, it is taken from
      `segment_ids.dtype` if `segment_ids` is a Tensor, otherwise from
      `num_segments.dtype` if `num_segments` is a Tensor, otherwise it is
      `tf.int64`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A sorted 1-D integer Tensor, with `shape=[num_segments + 1]`.
  """
  # Local import bincount_ops to avoid import-cycle.
  from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top

  # Resolve the output dtype before the inputs are converted below, so that
  # the dtype of a Tensor argument is honoured.
  if out_type is not None:
    out_type = dtypes.as_dtype(out_type)
  elif isinstance(segment_ids, ops.Tensor):
    out_type = segment_ids.dtype
  elif isinstance(num_segments, ops.Tensor):
    out_type = num_segments.dtype
  else:
    out_type = dtypes.int64

  with ops.name_scope(name, "SegmentIdsToRaggedSplits", [segment_ids]) as name:
    # Note: int64 tensors are cast to int32 here, since bincount currently
    # only supports int32 inputs.
    segment_ids = ragged_util.convert_to_int_tensor(
        segment_ids, "segment_ids", dtype=dtypes.int32)
    segment_ids.shape.assert_has_rank(1)

    num_segments_given = num_segments is not None
    if num_segments_given:
      num_segments = ragged_util.convert_to_int_tensor(
          num_segments, "num_segments", dtype=dtypes.int32)
      num_segments.shape.assert_has_rank(0)

    # Count the values in each segment; using num_segments as both the minimum
    # and the maximum length pins the number of bins when it is provided.
    segment_sizes = bincount_ops.bincount(
        segment_ids,
        minlength=num_segments,
        maxlength=num_segments,
        dtype=out_type)

    # Row splits are the running total of the segment sizes, led by a zero.
    running_total = math_ops.cumsum(segment_sizes)
    splits = array_ops.concat([[0], running_total], axis=0)

    # Update shape information, if possible.
    if num_segments_given:
      static_num_segments = tensor_util.constant_value(num_segments)
      if static_num_segments is not None:
        splits.set_shape(tensor_shape.TensorShape([static_num_segments + 1]))

    return splits