Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_util.py: 38%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Private convenience functions for RaggedTensors. 

16 

17None of these methods are exposed in the main "ragged" package. 

18""" 

19 

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import math_ops

25 

26 

def assert_splits_match(nested_splits_lists):
  """Checks that several lists of ragged splits are identical.

  The number of splits tensors in each list is compared eagerly, in Python.
  The splits tensors themselves are compared by `check_ops.assert_equal`
  calls. Their results are returned so that the caller can attach them as
  control dependencies.

  Args:
    nested_splits_lists: A list of splits lists. Each splits list is a list
      of `splits` tensors from a `RaggedTensor`, ordered from outermost ragged
      dimension to innermost ragged dimension.

  Returns:
    A list with the results of `check_ops.assert_equal`. Each later splits
    list is compared, tensor by tensor, against the first one. The list is
    empty when fewer than two splits lists are given.

  Raises:
    ValueError: If the splits lists do not all contain the same number of
      splits tensors.
  """
  message = "Inputs must have identical ragged splits"
  if not nested_splits_lists:
    return []

  reference = nested_splits_lists[0]
  num_splits = len(reference)
  # Static check: every list must hold the same number of splits tensors.
  if any(len(other) != num_splits for other in nested_splits_lists):
    raise ValueError(message)

  # Dynamic check: build one equality assertion per (reference, other) pair.
  equality_checks = []
  for other in nested_splits_lists[1:]:
    for expected, actual in zip(reference, other):
      equality_checks.append(
          check_ops.assert_equal(expected, actual, message=message))
  return equality_checks

53 

54 

# Aliases for helpers that are implemented in array_ops. Per the original
# note, they are taken from array_ops here to avoid a circular dependency
# with array_ops. `repeat` is this module's name for
# `array_ops.repeat_with_axis`.
get_positive_axis, convert_to_int_tensor, repeat = (
    array_ops.get_positive_axis,
    array_ops.convert_to_int_tensor,
    array_ops.repeat_with_axis,
)

59 

60 

def lengths_to_splits(lengths):
  """Converts a vector of row lengths into the equivalent row splits.

  The result is 0 followed by the running total of `lengths`, so it has one
  more element than `lengths`. For example, lengths `[2, 0, 3]` produce
  splits `[0, 2, 2, 5]`.

  Args:
    lengths: The row lengths to convert. Concatenating with the 1-D `[0]`
      requires this to be a 1-D tensor. It is presumably of integer dtype.

  Returns:
    A tensor equal to `[0]` followed by `cumsum(lengths)`.
  """
  running_totals = math_ops.cumsum(lengths)
  leading_zero = [0]
  return array_ops.concat([leading_zero, running_totals], axis=-1)

64 

65 

def repeat_ranges(params, splits, repeats):
  """Repeats each range of `params` (as specified by `splits`) `repeats` times.

  Let the `i`th range of `params` be defined as
  `params[splits[i]:splits[i + 1]]`. This function returns a tensor
  containing range 0 repeated `repeats[0]` times, followed by range 1 repeated
  `repeats[1]` times, ..., followed by the last range repeated `repeats[-1]`
  times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor, or a value convertible to one such as a Python
      list, indicating the ranges of `params` that should be repeated.
      Elements should be non-negative integers.
    repeats: The number of times each range should be repeated. Supports
      broadcasting from a scalar value, and a Python int is accepted.
      Elements should be non-negative integers.

  Returns:
    A `Tensor` with the same rank and type as `params`.

  Raises:
    TypeError: If `splits` or `repeats` does not have an integer dtype. This
      is raised by `check_ops.assert_integer`, which checks the dtype
      statically.

  #### Example:

  >>> print(repeat_ranges(
  ...     params=tf.constant(['a', 'b', 'c']),
  ...     splits=tf.constant([0, 2, 3]),
  ...     repeats=tf.constant(3)))
  tf.Tensor([b'a' b'b' b'a' b'b' b'a' b'b' b'c' b'c' b'c'],
      shape=(9,), dtype=string)
  """
  # Convert up front. The validation messages read `.dtype`, and the scalar
  # fast path reads `.shape`. Without conversion, Python lists and scalars
  # (e.g. `repeats=3`) would fail with AttributeError before any check ran.
  splits = ops.convert_to_tensor(splits, name="splits")
  repeats = ops.convert_to_tensor(repeats, name="repeats")

  # Validate the inputs. The integer-dtype check is built first because it is
  # decided statically. Running it before the non-negativity comparison means
  # a non-numeric input fails with this targeted message rather than inside
  # the comparison op.
  splits_checks = [
      check_ops.assert_integer(
          splits,
          message=(
              "Input argument 'splits' must be integer, but got"
              f" {splits.dtype} instead"
          ),
      ),
      check_ops.assert_non_negative(
          splits, message="Input argument 'splits' must be non-negative"
      ),
  ]
  repeats_checks = [
      check_ops.assert_integer(
          repeats,
          message=(
              "Input argument 'repeats' must be integer, but got"
              f" {repeats.dtype} instead"
          ),
      ),
      check_ops.assert_non_negative(
          repeats, message="Input argument 'repeats' must be non-negative"
      ),
  ]
  splits = control_flow_ops.with_dependencies(splits_checks, splits)
  repeats = control_flow_ops.with_dependencies(repeats_checks, repeats)

  # Divide `splits` into starts and limits, and repeat them `repeats` times.
  # The non-scalar path is also taken when the rank of `repeats` is unknown
  # (`ndims is None`).
  if repeats.shape.ndims != 0:
    repeated_starts = repeat(splits[:-1], repeats, axis=0)
    repeated_limits = repeat(splits[1:], repeats, axis=0)
  else:
    # Optimization: with a scalar `repeats`, repeat the whole splits vector
    # once. Then drop the last `repeats` entries to get the starts, and the
    # first `repeats` entries to get the limits.
    repeated_splits = repeat(splits, repeats, axis=0)
    n_splits = array_ops.shape(repeated_splits, out_type=repeats.dtype)[0]
    repeated_starts = repeated_splits[:n_splits - repeats]
    repeated_limits = repeated_splits[repeats:]

  # Get indices for each range from starts to limits, and use those to gather
  # the values in the desired repetition pattern.
  one = array_ops.ones((), repeated_starts.dtype)
  offsets = gen_ragged_math_ops.ragged_range(
      repeated_starts, repeated_limits, one)
  return array_ops.gather(params, offsets.rt_dense_values)