Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_functional_ops.py: 21%

57 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Support for ragged tensors.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import tensor_shape 

19from tensorflow.python.ops.ragged import ragged_config 

20from tensorflow.python.ops.ragged import ragged_tensor 

21from tensorflow.python.util import dispatch 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

@tf_export("ragged.map_flat_values")
@dispatch.add_dispatch_support
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor (which collapses all ragged dimensions), and then calls `op`.  Returns
  a `RaggedTensor` that is constructed from the input `RaggedTensor`s'
  `nested_row_splits` and the value returned by the `op`.  `RaggedTensor`s
  nested inside lists, tuples, and dicts are replaced as well.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  This operation is generally used to apply elementwise operations to each value
  in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.  Any `RaggedTensor` (including ones nested in
      lists, tuples, or dicts) is replaced by its `flat_values`.
    **kwargs: Keyword arguments for `op`; processed the same way as `args`.

  Returns:
    If at least one `RaggedTensor` is found in `args` or `kwargs`: a
    `RaggedTensor` whose `flat_values` is the value returned by `op` and whose
    row partitions are merged from the input `RaggedTensor`s, so its
    `ragged_rank` matches the `ragged_rank` of all input `RaggedTensor`s.

    If no `RaggedTensor` is found: the result of calling `op` on the
    unmodified `args` and `kwargs`, returned as-is.

  Raises:
    ValueError: If any of the following holds:
      * The statically-known outer-dimension sizes of the input
        `RaggedTensor`s' `flat_values` differ.
      * The inputs' row partitions have different dtypes and
        `ragged_config.auto_cast_partition_dtype()` is false.
      * The statically-known outer dimension of `op`'s output does not match
        that of the inputs' `flat_values`.
      * The input `RaggedTensor`s do not all have the same `ragged_rank`.
    Further compatibility of the inputs' row partitions is checked when they
    are merged (`RowPartition._merge_precomputed_encodings`); some of those
    checks can only happen at runtime.
  """
  # Replace RaggedTensors with their flat_values.  As a side effect this
  # collects each RaggedTensor's row partitions, plus the static outer size of
  # its flat_values when that size is known.
  partition_lists = []
  flat_values_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, partition_lists,
                                                flat_values_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, partition_lists,
                                                  flat_values_nrows)
  if not partition_lists:
    # No ragged inputs: there is nothing to re-partition, so `op` is applied
    # to the original arguments and its result is returned unchanged.
    return op(*args, **kwargs)

  # If we can statically determine that the inputs are incompatible, then raise
  # an error.  (We can't guarantee full compatibility statically, so we need to
  # perform some runtime checks too; but this allows us to fail sooner in some
  # cases.)
  expected_nrows = None
  if flat_values_nrows:
    distinct_nrows = set(flat_values_nrows)
    if len(distinct_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size. Got sizes: %s" %
                       distinct_nrows)
    (expected_nrows,) = distinct_nrows  # Exactly one element remains.

  # Each RaggedTensor's partitions share a dtype, so inspecting the first
  # partition of each list is enough to detect a dtype mismatch.
  partition_dtypes = {p[0].dtype for p in partition_lists}
  if len(partition_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row partition "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")

    # Auto-casting is enabled: promote every partition to int64 so that the
    # lists can be merged below.
    partition_lists = [
        [p.with_dtype(dtypes.int64) for p in partition_list]  # pylint: disable=g-complex-comprehension
        for partition_list in partition_lists
    ]

  # Delegate to `op`.
  op_output = op(*inner_args, **inner_kwargs)

  # Check that the result has the expected outer size (if statically known).
  if expected_nrows is not None:
    if not op_output.shape[:1].is_compatible_with([expected_nrows]):
      raise ValueError(
          "tf.ragged.map_flat_values requires that the output of `op` have "
          "the same outer-dimension size as flat_values of any ragged "
          "inputs. (output shape: %s; expected outer dimension size: %s)" %
          (op_output.shape, expected_nrows))

  # Compose the result from the transformed values and the merged partitions.
  # The partitions were taken from existing RaggedTensors, so validation is
  # skipped here.
  return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
      op_output,
      _merge_partition_lists(partition_lists),
      validate=False)

137 

138 

def _replace_ragged_with_flat_values(value, partition_lists, flat_values_nrows):
  """Returns `value` with every nested RaggedTensor swapped for its flat_values.

  Walks `value` recursively, descending into lists, tuples, and dicts in their
  iteration order.  For each RaggedTensor encountered, this function does the
  following:

  * Converts it with `convert_to_tensor_or_ragged_tensor`.
  * Appends its `_nested_row_partitions` to `partition_lists`.
  * Appends the static size of dimension 0 of its `flat_values` to
    `flat_values_nrows`, when that size is known.

  Anything that is neither ragged nor one of those container types is
  returned untouched.

  Args:
    value: The structure to transform.
    partition_lists: Output list; one entry (that RaggedTensor's row
      partitions) is appended per RaggedTensor found.
    flat_values_nrows: Output list of ints; receives the statically-known
      outer dimension size of each replacement `flat_values`.

  Returns:
    A copy of `value` with nested RaggedTensors replaced by their
    `flat_values`.  Each container level is rebuilt as a plain list, tuple, or
    dict.  For example, a namedtuple comes back as a plain tuple.
  """
  if not ragged_tensor.is_ragged(value):
    # Recursive cases: rebuild containers; pass everything else through.
    def _visit(item):
      return _replace_ragged_with_flat_values(item, partition_lists,
                                              flat_values_nrows)

    if isinstance(value, list):
      return [_visit(item) for item in value]
    if isinstance(value, tuple):
      return tuple(_visit(item) for item in value)
    if isinstance(value, dict):
      return {key: _visit(item) for key, item in value.items()}
    return value

  # Base case: a ragged value.  Normalize it to a RaggedTensor first.
  rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  partition_lists.append(rt._nested_row_partitions)  # pylint: disable=protected-access
  outer_size = tensor_shape.dimension_at_index(rt.flat_values.shape, 0).value
  if outer_size is not None:
    flat_values_nrows.append(outer_size)
  return rt.flat_values

180 

181 

182def _merge_partition_lists(partition_lists): 

183 """Merges the given list of lists of RowPartitions. 

184 

185 Args: 

186 partition_lists: A list of lists of RowPartition. 

187 

188 Returns: 

189 A list of RowPartitions, where `result[i]` is formed by merging 

190 `partition_lists[j][i]` for all `j`, using 

191 `RowPartition._merge_precomputed_encodings`. 

192 """ 

193 dst = list(partition_lists[0]) 

194 for src in partition_lists[1:]: 

195 if len(src) != len(dst): 

196 raise ValueError("All ragged inputs must have the same ragged_rank.") 

197 for i in range(len(dst)): 

198 # pylint: disable=protected-access 

199 dst[i] = dst[i]._merge_precomputed_encodings(src[i]) 

200 return dst