Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_concat_ops.py: 23%

111 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Concat and stack operations for RaggedTensors.""" 

16 

17import typing 

18 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_shape 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import array_ops_stack 

23from tensorflow.python.ops import check_ops 

24from tensorflow.python.ops import math_ops 

25from tensorflow.python.ops.ragged import ragged_gather_ops 

26from tensorflow.python.ops.ragged import ragged_tensor 

27from tensorflow.python.ops.ragged import ragged_util 

28from tensorflow.python.util import dispatch 

29from tensorflow.python.util.tf_export import tf_export 

30 

31 

@dispatch.dispatch_for_api(array_ops.concat)
def concat(values: typing.List[ragged_tensor.RaggedOrDense], axis, name=None):
  """Concatenates potentially ragged tensors along one dimension.

  Given a list of tensors that all have the same rank `K` (`K >= axis`),
  returns a rank-`K` `RaggedTensor` `result` where `result[i0...iaxis]` is
  formed by concatenating `[rt[i0...iaxis] for rt in values]`.

  Args:
    values: A list of potentially ragged tensors.  May not be empty.  All
      `values` must have the same rank and the same dtype; but unlike
      `tf.concat`, they can have arbitrary shapes.
    axis: A python integer giving the dimension along which to concatenate.
      (Note: Unlike `tf.concat`, the `axis` parameter must be statically
      known.)  Negative values are supported only if the rank of at least one
      `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `K`.
    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values]))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.

  #### Example:

  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.concat([t1, t2], axis=0)
  <tf.RaggedTensor [[1, 2], [3, 4, 5], [6], [7, 8, 9]]>
  >>> tf.concat([t1, t2], axis=1)
  <tf.RaggedTensor [[1, 2, 6], [3, 4, 5, 7, 8, 9]]>
  """
  # A single (non-sequence) tensor is treated as a one-element list.
  values = values if isinstance(values, (list, tuple)) else [values]
  with ops.name_scope(name, 'RaggedConcat', values):
    # All the heavy lifting is shared with `stack`; only `stack_values`
    # differs.
    return _ragged_stack_concat_helper(values, axis, stack_values=False)

71 

72 

@tf_export('ragged.stack')
@dispatch.add_dispatch_support
@dispatch.dispatch_for_api(array_ops_stack.stack)
def stack(values: typing.List[ragged_tensor.RaggedOrDense],
          axis=0,
          name=None):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.

  Given a list of tensors or ragged tensors that all have rank `R`
  (`R >= axis`), returns a rank-`R+1` `RaggedTensor` `result` where
  `result[i0...iaxis]` is `[value[i0...iaxis] for value in values]`.

  #### Examples:

  >>> # Stacking two ragged tensors.
  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.ragged.stack([t1, t2], axis=0)
  <tf.RaggedTensor [[[1, 2], [3, 4, 5]], [[6], [7, 8, 9]]]>
  >>> tf.ragged.stack([t1, t2], axis=1)
  <tf.RaggedTensor [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]>

  >>> # Stacking two dense tensors with different sizes.
  >>> t3 = tf.constant([[1, 2, 3], [4, 5, 6]])
  >>> t4 = tf.constant([[5], [6], [7]])
  >>> tf.ragged.stack([t3, t4], axis=0)
  <tf.RaggedTensor [[[1, 2, 3], [4, 5, 6]], [[5], [6], [7]]]>

  Args:
    values: A list of `tf.Tensor` or `tf.RaggedTensor`.  May not be empty.
      All `values` must have the same rank and the same dtype; but unlike
      `tf.stack`, they can have arbitrary dimension sizes.
    axis: A python integer giving the dimension along which to stack.
      (Note: Unlike `tf.stack`, the `axis` parameter must be statically
      known.)  Negative values are supported only if the rank of at least one
      `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `R+1` (if `R>0`).
    If `R==0`, then the result will be returned as a 1D `Tensor`, since
    `RaggedTensor` can only be used when `rank>1`.
    `result.ragged_rank=1+max(axis, max(rt.ragged_rank for rt in values]))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.
  """
  # A single (non-sequence) tensor is treated as a one-element list.
  values = values if isinstance(values, (list, tuple)) else [values]
  with ops.name_scope(name, 'RaggedConcat', values):
    # Shares its implementation with `concat`; only `stack_values` differs.
    return _ragged_stack_concat_helper(values, axis, stack_values=True)

125 

126 

def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Helper function to concatenate or stack ragged tensors.

  Normalizes the inputs (common row_splits dtype, common rank, common
  ragged_rank), then dispatches on `axis`: axis 0 and axis 1 have dedicated
  implementations; axis > 1 is handled by recursing on the inputs' values.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor (or, in the all-dense final-axis special case, a Tensor).
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  # Validate parameters.
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert input tensors.
  rt_inputs = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  # All inputs must agree on the row_splits dtype (int32 vs int64) before
  # their splits can be combined.
  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
      *rt_inputs, return_dtype=True)
  rt_inputs = list(rt_inputs)

  # Special case: if there's only one input, then return it as-is.
  # (Stacking a single input still adds a dimension, so it can't short-cut.)
  if len(rt_inputs) == 1 and not stack_values:
    return rt_inputs[0]

  # Check the rank (number of dimensions) of the input tensors.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  # Stacking adds one dimension to the output, so normalize `axis` against
  # the *output* rank.
  out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
  axis = array_ops.get_positive_axis(axis, out_ndims)

  # Fast path: stacking rank-1 tensors along axis 0 -- each input becomes one
  # row, so the row lengths are just the input lengths.
  if stack_values and ndims == 1 and axis == 0:
    return ragged_tensor.RaggedTensor.from_row_lengths(
        values=array_ops.concat(rt_inputs, axis=0),
        row_lengths=array_ops.concat([array_ops.shape(r) for r in rt_inputs],
                                     axis=0))

  # If all the inputs are Tensors, and we're combining the final dimension,
  # then we can delegate to the tf.stack/tf.concat operation, and return a
  # Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops_stack.stack(rt_inputs, axis)
      else:
        return array_ops.concat(rt_inputs, axis)

  # Convert any Tensor inputs to RaggedTensors.  This makes it
  # possible to concatenate Tensors and RaggedTensors together.
  for i in range(len(rt_inputs)):
    if not ragged_tensor.is_ragged(rt_inputs[i]):
      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

  # Convert the input tensors to all have the same ragged_rank.
  ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
               for rt in rt_inputs]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  elif axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  else:  # axis > 1: recurse.
    # For axis > 1 the outermost row_splits of all inputs must match; the
    # combine then happens one dimension deeper, and the shared splits are
    # reattached to the result.
    values = [rt.values for rt in rt_inputs]
    splits = [[rt_input.row_splits] for rt_input in rt_inputs]
    with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
      return ragged_tensor.RaggedTensor.from_row_splits(
          _ragged_stack_concat_helper(values, axis - 1, stack_values),
          splits[0][0], validate=False)

208 

209 

def _ragged_stack_concat_axis_0(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 0.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  # Join the innermost (dense) values of every input into one tensor.
  combined_flat_values = array_ops.concat(
      [rt.flat_values for rt in rt_inputs], axis=0)

  # For each ragged dimension, splice the inputs' row splits end-to-end,
  # offsetting each input's splits to account for the values before it.
  num_ragged_dims = rt_inputs[0].ragged_rank
  all_nested_splits = [rt.nested_row_splits for rt in rt_inputs]
  combined_nested_splits = []
  for dim in range(num_ragged_dims):
    combined_nested_splits.append(
        _concat_ragged_splits([ns[dim] for ns in all_nested_splits]))

  # Stacking introduces one extra outer ragged dimension, whose row lengths
  # are the inputs' row counts.
  if stack_values:
    outer_lengths = array_ops_stack.stack([rt.nrows() for rt in rt_inputs])
    combined_nested_splits.insert(
        0, ragged_util.lengths_to_splits(outer_lengths))

  return ragged_tensor.RaggedTensor.from_nested_row_splits(
      combined_flat_values, combined_nested_splits, validate=False)

243 

244 

def _ragged_stack_concat_axis_1(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 1.

  Works by concatenating all inputs along axis 0, then using ragged gather to
  interleave their rows, and finally regrouping the interleaved rows.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  num_inputs = len(rt_inputs)

  # Combining along axis 1 requires every input to have the same number of
  # rows; add runtime checks against the first input.
  nrows_checks = []
  rt_nrows = rt_inputs[0].nrows()
  for index, rt in enumerate(rt_inputs[1:]):
    nrows_checks.append(
        check_ops.assert_equal(
            rt_nrows,
            rt.nrows(),
            message=(
                f'Input tensors at index 0 (=x) and {index+1} (=y) have'
                ' incompatible shapes.'
            ),
        )
    )

  with ops.control_dependencies(nrows_checks):
    # Concatenate the inputs together to put them in a single ragged tensor.
    concatenated_rt = _ragged_stack_concat_axis_0(rt_inputs, stack_values=False)

    # Use ragged.gather to permute the rows of concatenated_rt.  In particular,
    #   permuted_rt = [rt_inputs[0][0], ..., rt_inputs[N][0],
    #                  rt_inputs[0][1], ..., rt_inputs[N][1],
    #                  ...,
    #                  rt_inputs[0][M], ..., rt_input[N][M]]
    # where `N=num_inputs-1` and `M=rt_nrows-1`.
    # The permutation is built by reshaping [0 .. nrows*num_inputs) into a
    # [num_inputs, nrows] matrix and transposing it.
    row_indices = math_ops.range(rt_nrows * num_inputs)
    row_index_matrix = array_ops.reshape(row_indices, [num_inputs, -1])
    transposed_row_index_matrix = array_ops.transpose(row_index_matrix)
    row_permutation = array_ops.reshape(transposed_row_index_matrix, [-1])
    permuted_rt = ragged_gather_ops.gather(concatenated_rt, row_permutation)

    if stack_values:
      # Add a new splits tensor to group together the values: each output row
      # groups `num_inputs` consecutive permuted rows.
      stack_splits = math_ops.range(0, rt_nrows * num_inputs + 1, num_inputs)
      _copy_row_shape(rt_inputs, stack_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt, stack_splits, validate=False)
    else:
      # Merge together adjacent rows by dropping the row-split indices that
      # separate them (keeping every num_inputs-th split).
      concat_splits = permuted_rt.row_splits[::num_inputs]
      _copy_row_shape(rt_inputs, concat_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt.values, concat_splits, validate=False)

301 

302 

def _copy_row_shape(rt_inputs, splits):
  """Sets splits.shape to [rt.shape[0] + 1] for each rt in rt_inputs.

  Propagates any statically-known outer dimension size from the inputs onto
  the new row-splits tensor (a splits vector always has one more element than
  the number of rows it describes).
  """
  for rt in rt_inputs:
    if rt.shape[0] is not None:
      splits.set_shape(tensor_shape.TensorShape(rt.shape[0] + 1))

308 

309 

def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
  # Nothing to add: a non-positive target rank leaves the input untouched.
  if ragged_rank <= 0:
    return rt_input
  # A dense Tensor first gains one ragged dimension.
  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, row_splits_dtype=row_splits_dtype)
  # Recurse into the values until the target ragged_rank is reached.
  if rt_input.ragged_rank < ragged_rank:
    deeper_values = _increase_ragged_rank_to(
        rt_input.values, ragged_rank - 1, row_splits_dtype)
    rt_input = rt_input.with_values(deeper_values)
  return rt_input

321 

322 

def _concat_ragged_splits(splits_list):
  """Concatenates a list of RaggedTensor splits to form a single splits."""
  adjusted_pieces = [splits_list[0]]
  # Running total of values covered so far; each later splits vector is
  # shifted by this amount so its offsets index into the combined values.
  offset = splits_list[0][-1]
  for next_splits in splits_list[1:]:
    # Drop the redundant leading 0 of every subsequent splits vector.
    adjusted_pieces.append(next_splits[1:] + offset)
    offset += next_splits[-1]
  return array_ops.concat(adjusted_pieces, axis=0)