Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_concat_ops.py: 23%
111 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Concat and stack operations for RaggedTensors."""
17import typing
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import array_ops_stack
23from tensorflow.python.ops import check_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.ragged import ragged_gather_ops
26from tensorflow.python.ops.ragged import ragged_tensor
27from tensorflow.python.ops.ragged import ragged_util
28from tensorflow.python.util import dispatch
29from tensorflow.python.util.tf_export import tf_export
@dispatch.dispatch_for_api(array_ops.concat)
def concat(values: typing.List[ragged_tensor.RaggedOrDense], axis, name=None):
  """Concatenates potentially ragged tensors along one dimension.

  Given a list of tensors that all have the same rank `K` (`K >= axis`),
  returns a rank-`K` `RaggedTensor` `result` where `result[i0...iaxis]` is
  the concatenation of `[rt[i0...iaxis] for rt in values]`.

  Args:
    values: A list of potentially ragged tensors. May not be empty. All
      `values` must have the same rank and the same dtype; but unlike
      `tf.concat`, they can have arbitrary shapes.
    axis: A python integer, indicating the dimension along which to
      concatenate. (Note: Unlike `tf.concat`, the `axis` parameter must be
      statically known.) Negative values are supported only if the rank of
      at least one `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `K`.
    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.

  #### Example:

  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.concat([t1, t2], axis=0)
  <tf.RaggedTensor [[1, 2], [3, 4, 5], [6], [7, 8, 9]]>
  >>> tf.concat([t1, t2], axis=1)
  <tf.RaggedTensor [[1, 2, 6], [3, 4, 5, 7, 8, 9]]>
  """
  # A lone tensor is treated as a single-element list of inputs.
  if isinstance(values, (list, tuple)):
    inputs = values
  else:
    inputs = [values]
  with ops.name_scope(name, 'RaggedConcat', inputs):
    return _ragged_stack_concat_helper(inputs, axis, stack_values=False)
@tf_export('ragged.stack')
@dispatch.add_dispatch_support
@dispatch.dispatch_for_api(array_ops_stack.stack)
def stack(values: typing.List[ragged_tensor.RaggedOrDense],
          axis=0,
          name=None):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.

  Given a list of tensors or ragged tensors that share a rank `R`
  (`R >= axis`), returns a rank-`R+1` `RaggedTensor` `result` where
  `result[i0...iaxis]` is `[value[i0...iaxis] for value in values]`.

  #### Examples:

  >>> # Stacking two ragged tensors.
  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.ragged.stack([t1, t2], axis=0)
  <tf.RaggedTensor [[[1, 2], [3, 4, 5]], [[6], [7, 8, 9]]]>
  >>> tf.ragged.stack([t1, t2], axis=1)
  <tf.RaggedTensor [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]>

  >>> # Stacking two dense tensors with different sizes.
  >>> t3 = tf.constant([[1, 2, 3], [4, 5, 6]])
  >>> t4 = tf.constant([[5], [6], [7]])
  >>> tf.ragged.stack([t3, t4], axis=0)
  <tf.RaggedTensor [[[1, 2, 3], [4, 5, 6]], [[5], [6], [7]]]>

  Args:
    values: A list of `tf.Tensor` or `tf.RaggedTensor`. May not be empty. All
      `values` must have the same rank and the same dtype; but unlike
      `tf.stack`, they can have arbitrary dimension sizes.
    axis: A python integer, indicating the dimension along which to stack.
      (Note: Unlike `tf.stack`, the `axis` parameter must be statically
      known.) Negative values are supported only if the rank of at least one
      `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `R+1` (if `R>0`).
    If `R==0`, then the result will be returned as a 1D `Tensor`, since
    `RaggedTensor` can only be used when `rank>1`.
    `result.ragged_rank=1+max(axis, max(rt.ragged_rank for rt in values))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.
  """
  # A lone tensor is treated as a single-element list of inputs.
  if isinstance(values, (list, tuple)):
    inputs = values
  else:
    inputs = [values]
  with ops.name_scope(name, 'RaggedConcat', inputs):
    return _ragged_stack_concat_helper(inputs, axis, stack_values=True)
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Shared implementation for ragged concat and stack.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert the inputs and make their row-splits dtypes agree.
  converted = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  row_splits_dtype, converted = ragged_tensor.match_row_splits_dtypes(
      *converted, return_dtype=True)
  rt_inputs = list(converted)

  # Concatenating a single input is the identity.
  if len(rt_inputs) == 1 and not stack_values:
    return rt_inputs[0]

  # Infer the static rank from the first input whose rank is known, and
  # assert that every subsequent input matches it.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  # Stacking adds one dimension to the result's rank.
  out_ndims = ndims + 1 if (ndims is not None and stack_values) else ndims
  axis = array_ops.get_positive_axis(axis, out_ndims)

  # Fast path: stacking rank-1 tensors along axis 0 is a from_row_lengths
  # build over the concatenated values.
  if stack_values and ndims == 1 and axis == 0:
    lengths = array_ops.concat([array_ops.shape(r) for r in rt_inputs], axis=0)
    return ragged_tensor.RaggedTensor.from_row_lengths(
        values=array_ops.concat(rt_inputs, axis=0), row_lengths=lengths)

  # Fast path: if every input is dense and we're combining the final
  # dimension, delegate to tf.stack/tf.concat and return a dense Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops_stack.stack(rt_inputs, axis)
      return array_ops.concat(rt_inputs, axis)

  # Promote any dense inputs to RaggedTensors, so dense and ragged inputs
  # can be combined together.
  rt_inputs = [
      rt if ragged_tensor.is_ragged(rt) else
      ragged_tensor.RaggedTensor.from_tensor(
          rt, ragged_rank=1, row_splits_dtype=row_splits_dtype)
      for rt in rt_inputs
  ]

  # Bring every input up to a common ragged_rank.
  target_ragged_rank = max(1, max(rt.ragged_rank for rt in rt_inputs))
  rt_inputs = [
      _increase_ragged_rank_to(rt, target_ragged_rank, row_splits_dtype)
      for rt in rt_inputs
  ]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  if axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  # axis > 1: the outer splits of all inputs must match; recurse on values.
  values = [rt.values for rt in rt_inputs]
  splits = [[rt_input.row_splits] for rt_input in rt_inputs]
  with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
    return ragged_tensor.RaggedTensor.from_row_splits(
        _ragged_stack_concat_helper(values, axis - 1, stack_values),
        splits[0][0], validate=False)
def _ragged_stack_concat_axis_0(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 0.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean. If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  # Join the innermost (flat) values of all inputs end-to-end.
  combined_flat_values = array_ops.concat(
      [rt.flat_values for rt in rt_inputs], axis=0)

  # For each ragged dimension, merge the per-input splits into one splits
  # vector (offset-adjusted by _concat_ragged_splits).
  splits_per_input = [rt.nested_row_splits for rt in rt_inputs]
  combined_nested_splits = []
  for dim in range(rt_inputs[0].ragged_rank):
    combined_nested_splits.append(
        _concat_ragged_splits([splits[dim] for splits in splits_per_input]))

  # Stacking introduces one extra outer ragged dimension grouping each
  # input's rows together.
  if stack_values:
    lengths = array_ops_stack.stack([rt.nrows() for rt in rt_inputs])
    outer_splits = ragged_util.lengths_to_splits(lengths)
    combined_nested_splits = [outer_splits] + combined_nested_splits

  return ragged_tensor.RaggedTensor.from_nested_row_splits(
      combined_flat_values, combined_nested_splits, validate=False)
def _ragged_stack_concat_axis_1(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 1.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean. If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  num_inputs = len(rt_inputs)
  rt_nrows = rt_inputs[0].nrows()

  # Every input must have the same number of rows as the first one.
  nrows_checks = [
      check_ops.assert_equal(
          rt_nrows,
          rt.nrows(),
          message=(
              f'Input tensors at index 0 (=x) and {index+1} (=y) have'
              ' incompatible shapes.'
          ),
      )
      for index, rt in enumerate(rt_inputs[1:])
  ]

  with ops.control_dependencies(nrows_checks):
    # First concatenate everything along axis 0 into one ragged tensor.
    concatenated_rt = _ragged_stack_concat_axis_0(rt_inputs, stack_values=False)

    # Use ragged.gather to permute the rows of concatenated_rt. In particular,
    #   permuted_rt = [rt_inputs[0][0], ..., rt_inputs[N][0],
    #                  rt_inputs[0][1], ..., rt_inputs[N][1],
    #                  ...,
    #                  rt_inputs[0][M], ..., rt_input[N][M]]
    # where `N=num_inputs-1` and `M=rt_nrows-1`.
    row_indices = math_ops.range(rt_nrows * num_inputs)
    index_matrix = array_ops.reshape(row_indices, [num_inputs, -1])
    row_permutation = array_ops.reshape(
        array_ops.transpose(index_matrix), [-1])
    permuted_rt = ragged_gather_ops.gather(concatenated_rt, row_permutation)

    if stack_values:
      # Group each run of `num_inputs` rows under a new outer dimension.
      stack_splits = math_ops.range(0, rt_nrows * num_inputs + 1, num_inputs)
      _copy_row_shape(rt_inputs, stack_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt, stack_splits, validate=False)
    else:
      # Merge adjacent rows by keeping only every num_inputs-th split index.
      concat_splits = permuted_rt.row_splits[::num_inputs]
      _copy_row_shape(rt_inputs, concat_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt.values, concat_splits, validate=False)
def _copy_row_shape(rt_inputs, splits):
  """Sets splits.shape to [rt.shape[0] + 1] for each rt in rt_inputs."""
  for rt_input in rt_inputs:
    # Only propagate a static shape when the input's row count is known.
    if rt_input.shape[0] is not None:
      splits.set_shape(tensor_shape.TensorShape(rt_input.shape[0] + 1))
def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
  if ragged_rank <= 0:
    return rt_input
  # Promote dense tensors so they have at least one ragged dimension.
  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, row_splits_dtype=row_splits_dtype)
  # Recurse into the values until the target ragged_rank is reached.
  if rt_input.ragged_rank < ragged_rank:
    deeper_values = _increase_ragged_rank_to(
        rt_input.values, ragged_rank - 1, row_splits_dtype)
    rt_input = rt_input.with_values(deeper_values)
  return rt_input
def _concat_ragged_splits(splits_list):
  """Concatenates a list of RaggedTensor splits to form a single splits."""
  first, *rest = splits_list
  pieces = [first]
  # Each later splits vector drops its leading 0 and is shifted by the
  # running total of values seen so far.
  offset = first[-1]
  for splits in rest:
    pieces.append(splits[1:] + offset)
    offset += splits[-1]
  return array_ops.concat(pieces, axis=0)