Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_gather_ops.py: 15%
187 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Gather operations for RaggedTensors."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import indexed_slices
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import gen_ragged_array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops.ragged import ragged_array_ops
25from tensorflow.python.ops.ragged import ragged_math_ops
26from tensorflow.python.ops.ragged import ragged_tensor
27from tensorflow.python.util import dispatch
#===============================================================================
# ragged_gather
#===============================================================================
@dispatch.dispatch_for_api(array_ops.gather_v2)
def gather(params: ragged_tensor.RaggedOrDense,
           indices: ragged_tensor.RaggedOrDense,
           validate_indices=None,
           axis=None,
           batch_dims=0,
           name=None):
  """Gathers ragged slices from `params` along `axis` according to `indices`.

  See `tf.gather` for full documentation.  (This version has the same API
  as `tf.gather`, but supports ragged `params` and `indices`.)

  Examples:

  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> tf.gather(params, ragged_indices)
  <tf.RaggedTensor [[b'd', b'b', b'c'], [b'b'], [], [b'a']]>

  >>> tf.gather(ragged_params, indices)
  <tf.RaggedTensor [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]>

  >>> tf.gather(ragged_params, ragged_indices)
  <tf.RaggedTensor [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]>

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the half-open
      range `[0, params.shape[axis])`.
    validate_indices: Ignored.
    axis: The axis in `params` to gather `indices` from.  Defaults to
      `batch_dims`; must be greater than or equal to `batch_dims`.  Negative
      values are interpreted relative to `rank(params)`.
    batch_dims: The number of batch dimensions.  Must be less than
      `rank(params)` and at most `rank(indices)`; negative values are
      interpreted relative to `rank(indices)`.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with `output.dtype == params.dtype`.  For the
    default `axis=0, batch_dims=0`,
    `output.shape == indices.shape + params.shape[1:]`.

  Raises:
    ValueError: If `indices.shape.ndims` is not known statically (when it is
      needed), if `batch_dims >= rank(params)`, if `axis < batch_dims`, or if
      `batch_dims` is not between 0 and `rank(indices)`.
  """
  del validate_indices

  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    # `batch_dims == rank(indices)` is legal but would be rejected by
    # get_positive_axis (which requires axis < ndims), so skip it in that case.
    if batch_dims != indices.shape.rank:
      batch_dims = array_ops.get_positive_axis(
          batch_dims,
          indices.shape.rank,
          axis_name='batch_dims',
          ndims_name='rank(indices)')
    if params.shape.rank is not None and batch_dims >= params.shape.rank:
      raise ValueError('batch_dims must be less than rank(params)')
    if axis is None:
      axis = batch_dims
    axis = array_ops.get_positive_axis(
        axis, params.shape.rank, ndims_name='rank(params)')
    if axis < batch_dims:
      raise ValueError('axis must be greater than or equal to batch_dims')
    if indices.shape.rank is not None:
      if not 0 <= batch_dims <= indices.shape.rank:
        raise ValueError(
            'batch_dims=%s must be between 0 and rank(indices)=%s' %
            (batch_dims, indices.shape.rank))

    return _gather(params, indices, axis, batch_dims)
def _gather(params, indices, axis, batch_dims):
  """Recursive worker that implements the body of ragged `gather()`.

  The caller must already have converted `params` and `indices` to `Tensor`s
  or `RaggedTensor`s and normalized `axis` and `batch_dims` to non-negative
  values; recursive calls rely on this to skip redoing that work.

  Args:
    params: The tensor from which to gather values.
    indices: The indices of values to gather.
    axis: The (already normalized) axis in `params` to gather from.
    batch_dims: The (already normalized) number of batch dimensions.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If `params` is ragged, `indices` is dense, and the rank of
      `indices` is not statically known (directly or via the helpers this
      function dispatches to).
  """
  indices_ragged = ragged_tensor.is_ragged(indices)
  if not indices_ragged and not ragged_tensor.is_ragged(params):
    # Nothing ragged: the dense op handles everything.
    return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)

  if batch_dims > 0:
    return _batch_gather(params, indices, axis, batch_dims)
  if axis > 0:
    return _axis_gather(params, indices, axis)
  if indices_ragged:
    # Gather through the flat values of `indices`, then re-apply its row
    # partitioning on top of the result.
    return indices.with_values(_gather(params, indices.values, 0, 0))

  # Remaining case: `params` is ragged and `indices` is dense, with
  # axis == batch_dims == 0.  Use the RaggedGather kernel.
  indices_rank = indices.shape.ndims
  if indices_rank is None:
    raise ValueError('rank(indices) must be known statically')

  kernel_out = gen_ragged_array_ops.ragged_gather(
      indices=indices,
      params_dense_values=params.flat_values,
      params_nested_splits=params.nested_row_splits,
      OUTPUT_RAGGED_RANK=indices_rank + len(params.nested_row_splits) - 1)
  output = ragged_tensor.RaggedTensor.from_nested_row_splits(
      kernel_out.output_dense_values,
      kernel_out.output_nested_splits,
      validate=False)

  if indices_rank > 1:
    # The outer dimensions that came from dense `indices` are uniform; record
    # their uniform row lengths (and row counts) on each level of `output`.
    # TODO(edloper): Change this to construct the result using RowPartition
    # objects instead, so we don't need to modify private variables.
    dims = array_ops.shape(indices, out_type=params.row_splits.dtype)
    nrows_per_level = math_ops.cumprod(dims)
    level = output
    for d in range(indices_rank - 1):
      # pylint: disable=protected-access
      level._cached_nrows = nrows_per_level[d]
      level._uniform_row_length = dims[d + 1]
      level = level.values

  return output
def _batch_gather(params, indices, axis, batch_dims):
  """Implements ragged `gather()` for the case `batch_dims > 0`.

  The problem is reduced step by step until `batch_dims == 1` and
  `axis == 1`.  Extra batch dimensions are peeled off by recursing on
  `.values`.  An `axis` beyond the batch dimensions is turned into an
  additional batch dimension by tiling `indices`.

  Args:
    params: The tensor from which to gather values.
    indices: The indices of values to gather.
    axis: The axis in `params` to gather `indices` from.
    batch_dims: The number of batch dimensions.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If the batch dimensions of `params` and `indices` are
      statically incompatible, or if `rank(indices)` is unknown when needed.
  """
  # Only *static* batch-shape compatibility is checked.  Matching row-splits
  # are deliberately not verified at runtime (to avoid the cost), so
  # incompatible inputs may yield a nonsensical RaggedTensor.
  indices_batch_shape = indices.shape[:batch_dims]
  if not params.shape[:batch_dims].is_compatible_with(indices_batch_shape):
    raise ValueError('batch shape from indices %s does not match params '
                     'shape %s' % (indices_batch_shape, params.shape))

  if batch_dims > 1:
    # Merge the two outermost batch dimensions into one and recurse.  That
    # requires both inputs to be RaggedTensors, so a dense input is
    # converted -- which is only valid if the other input's outermost
    # dimension is uniform.
    if not isinstance(params, ragged_tensor.RaggedTensor):
      if indices.uniform_row_length is None:
        raise ValueError(
            'batch shape from indices does not match params shape: ragged '
            'indices dimension corresponds to uniform params dimension')
      params = ragged_tensor.RaggedTensor.from_tensor(
          params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
    if not isinstance(indices, ragged_tensor.RaggedTensor):
      if params.uniform_row_length is None:
        raise ValueError(
            'batch shape from indices does not match params shape: ragged '
            'params dimension corresponds to uniform indices dimension')
      indices = ragged_tensor.RaggedTensor.from_tensor(
          indices, ragged_rank=1, row_splits_dtype=params.row_splits.dtype)
    inner = _gather(params.values, indices.values, axis - 1, batch_dims - 1)
    return params.with_values(inner)

  if axis > 1:
    # Treat dimension 1 of `params` as an extra batch dimension by giving
    # `indices` a matching dimension.  E.g. with params [B, P1, P2] and
    # indices [B, I1, I2], indices is tiled to [B, P1, I1, I2].
    if isinstance(indices, ragged_tensor.RaggedTensor):
      if not isinstance(params, ragged_tensor.RaggedTensor):
        params = ragged_tensor.RaggedTensor.from_tensor(
            params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
      # Select row b of `indices` once for each element of params[b].
      tiled_indices = _gather(
          indices,
          params.with_values(
              array_ops.repeat(
                  math_ops.range(params.nrows()), params.row_lengths())),
          0, 0)
    else:
      tiled_indices = params.with_values(
          array_ops.repeat(indices, params.row_lengths(), 0))
    return _batch_gather(params, tiled_indices, axis, batch_dims + 1)

  if indices.shape.rank is None:
    raise ValueError('rank(indices) must be known statically')

  assert batch_dims == 1
  # With params.shape=[B, P1...PN] and indices.shape=[B, I1...IM]:
  #
  #   output[b, i1...im, p2...pn] = params[b, indices[b, i1...im], p2...pn]
  #
  # Flatten the outer two dims of `params`, shift each batch's indices by the
  # offset of that batch's first row, and gather from the flattened values.
  flat_params = _flatten_dims_0_and_1(params)
  batch_offsets = _row_starts(params, indices.dtype)
  # Add trailing size-1 dims so the offsets broadcast against `indices`.
  batch_offsets = _increase_rank_to(batch_offsets, indices.shape.ndims)
  return _gather(flat_params, indices + batch_offsets, axis - 1, 0)
def _axis_gather(params, indices, axis):
  """Implements ragged `gather()` for the case `axis > 0`, `batch_dims == 0`.

  Args:
    params: The tensor from which to gather values.
    indices: The indices of values to gather.
    axis: The axis in `params` to gather `indices` from.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If `rank(indices)` is not known statically.
  """
  if axis > 1:
    if not isinstance(params, ragged_tensor.RaggedTensor):
      params = ragged_tensor.RaggedTensor.from_tensor(
          params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
    # Peel off the outermost dimension of `params` only; `indices` is passed
    # through to the recursive call unchanged.
    inner = _gather(params.values, indices, axis - 1, 0)
    return params.with_values(inner)

  if indices.shape.rank is None:
    raise ValueError('rank(indices) must be known statically')

  # NOTE: indices are not bounds-checked here; out-of-range values may give
  # nonsensical results.
  assert axis == 1
  # With params.shape=[P1...PN] and indices.shape=[I1...IM]:
  #
  #   output[p1, i1...im, p3...pn] = params[p1, indices[i1...im], p3...pn]
  #
  # Flatten the outer two dims of `params`, and build indices with one extra
  # leading dimension (one copy per row p1, shifted by that row's start
  # offset) that point into the flattened values.
  flat_params = _flatten_dims_0_and_1(params)
  row_offsets = _row_starts(params, indices.dtype)
  row_offsets = _increase_rank_to(row_offsets, indices.shape.ndims + 1)
  return _gather(flat_params, indices + row_offsets, axis - 1, 0)
def _flatten_dims_0_and_1(t):
  """Returns `t` with its two outermost dimensions merged into one."""
  if not isinstance(t, ragged_tensor.RaggedTensor):
    inner_dims = array_ops.shape(t)[2:]
    merged_shape = array_ops.concat([[-1], inner_dims], axis=0)
    return array_ops.reshape(t, merged_shape)
  # A RaggedTensor's `values` already concatenates all of its rows.
  return t.values
def _row_starts(t, dtype):
  """Returns the start offset of each row of `t`, as a `dtype` vector."""
  if not isinstance(t, ragged_tensor.RaggedTensor):
    # Dense rows all have length shape[1], so row i starts at i * shape[1].
    dims = array_ops.shape(t, out_type=dtype)
    return math_ops.range(dims[0]) * dims[1]
  return math_ops.cast(t.row_starts(), dtype)
def _increase_rank_to(t, rank):
  """Adds *trailing* size-1 dimensions to `t` until it has the given rank.

  Args:
    t: A `Tensor` or `RaggedTensor`.
    rank: The total rank the result should have.

  Returns:
    `t` reshaped with trailing size-1 dimensions appended.  For a
    `RaggedTensor`, the dimensions are appended to its values.
  """
  if isinstance(t, ragged_tensor.RaggedTensor):
    # The outermost dimension accounts for one rank.  Recurse into
    # `t.values`, which has one fewer dimension (recursing on `t` itself would
    # never terminate).
    return t.with_values(_increase_rank_to(t.values, rank - 1))
  else:
    old_dims = array_ops.shape(t)
    new_dims = array_ops.ones([rank - array_ops.rank(t)], old_dims.dtype)
    new_shape = array_ops.concat([old_dims, new_dims], axis=0)
    return array_ops.reshape(t, new_shape)
@dispatch.dispatch_for_api(array_ops.gather)
def _ragged_gather_v1(params: ragged_tensor.RaggedOrDense,
                      indices: ragged_tensor.RaggedOrDense,
                      validate_indices=None,
                      name=None,
                      axis=0,
                      batch_dims=0):
  """Dispatch target for the v1 `tf.gather` API on ragged inputs.

  The v1 signature puts `name` before `axis` and `batch_dims`.  Every argument
  is forwarded to `gather` by keyword so the differing order cannot mix them
  up.
  """
  return gather(
      params=params,
      indices=indices,
      validate_indices=validate_indices,
      axis=axis,
      batch_dims=batch_dims,
      name=name)
#===============================================================================
# ragged.gather_nd
#===============================================================================
@dispatch.dispatch_for_api(array_ops.gather_nd_v2)
def gather_nd(params: ragged_tensor.RaggedOrDense,
              indices: ragged_tensor.RaggedOrDense,
              batch_dims=0,
              name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.  Its
      innermost dimension must not be ragged, and its size `I` must be known
      statically.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  Raises:
    ValueError: If `batch_dims` is not the integer 0, if the rank of
      `indices` is unknown or zero, or if the innermost dimension of
      `indices` is ragged or has a statically unknown size.

  #### Examples:

  >>> params = tf.ragged.constant(
  ...     [ [ ['000', '001'], ['010'              ]          ],
  ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
  ...       [ [            ], ['210'              ]          ] ])

  >>> # Gather 2D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2], [0]])
  <tf.RaggedTensor [[[], [b'210']], [[b'000', b'001'], [b'010']]]>

  >>> # Gather 1D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2, 1], [0, 0]])
  <tf.RaggedTensor [[b'210'], [b'000', b'001']]>

  >>> # Gather scalars from a 3D tensor
  >>> tf.gather_nd(params, [[0, 0, 1], [1, 1, 2]]).numpy()
  array([b'001', b'112'], dtype=object)
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = ragged_tensor.RaggedTensor.to_tensor(result)
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      for _ in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ],
                                   axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        if not ragged_tensor.is_ragged(flattened_params):
          # The remaining dimensions are uniform, so finish with a dense
          # gather_nd.  Use axis=-1 (not 1) so this also works when `indices`
          # is 1-D and `flattened_index_tuples` is therefore a scalar; for
          # 2-D `indices` the two are equivalent.
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=-1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=-1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)
@dispatch.dispatch_for_api(array_ops.gather_nd)
def _ragged_gather_nd_v1(params: ragged_tensor.RaggedOrDense,
                         indices: ragged_tensor.RaggedOrDense,
                         name=None,
                         batch_dims=0):
  """Dispatch target for the v1 `tf.gather_nd` API on ragged inputs.

  The v1 signature puts `name` before `batch_dims`, so arguments are forwarded
  to `gather_nd` by keyword.
  """
  return gather_nd(
      params=params, indices=indices, batch_dims=batch_dims, name=name)
#===============================================================================
# Gradient for the RaggedGather kernel
#===============================================================================
@ops.RegisterGradient('RaggedGather')
def _ragged_gather_grad(op, *grads):
  """Gradient for the RaggedGather op.

  The op's inputs are the params' nested row-splits, then the params' dense
  values, then `indices`.  Only the dense values receive a gradient (as an
  `IndexedSlices`); every other input gets `None`.
  """
  splits_list = op.inputs[:-2]
  inner_values = op.inputs[-2]
  indices = op.inputs[-1]
  flat_values_grad = grads[-1]

  # Compose all row-splits levels so that outermost row `i` of `params` covers
  # `inner_values[outer_to_inner[i]:outer_to_inner[i + 1]]`.
  outer_to_inner = splits_list[0]
  for splits in splits_list[1:]:
    outer_to_inner = array_ops.gather(splits, outer_to_inner)

  # The outer dimensions of `indices` correspond 1:1 with the outer dimensions
  # of the ragged output gradient (encoded by its nested splits).  Therefore
  # the flattened `indices` line up with `flat_values_grad`, after each index
  # is expanded to the range of inner values covered by the row it selects.
  selected_rows = array_ops.reshape(indices, [-1])
  row_begin = array_ops.gather(outer_to_inner, selected_rows)
  row_end = array_ops.gather(outer_to_inner[1:], selected_rows)
  value_positions = ragged_math_ops.range(row_begin, row_end).values

  values_grad = indexed_slices.IndexedSlices(
      values=flat_values_grad,
      indices=value_positions,
      dense_shape=array_ops.shape(inner_values))
  return [None] * len(splits_list) + [values_grad, None]