Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_getitem.py: 16%
164 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python-style indexing and slicing for RaggedTensors."""
17from tensorflow.python.eager import context
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.framework import tensor_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import check_ops
24from tensorflow.python.ops import cond
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops.ragged import ragged_gather_ops
27from tensorflow.python.ops.ragged import ragged_math_ops
28from tensorflow.python.ops.ragged import ragged_tensor
29from tensorflow.python.util import dispatch
30from tensorflow.python.util.tf_export import tf_export
@tf_export("__operators__.ragged_getitem", v1=[])
@dispatch.add_dispatch_support
def ragged_tensor_getitem(rt_input, key):
  """Returns the specified piece of this RaggedTensor.

  Multidimensional indexing and slicing are both supported, subject to one
  restriction: a ragged inner dimension may be sliced, but it may not be
  indexed.  Indexing is problematic there because the requested value may
  exist in some rows and not in others, and it is not obvious whether the
  right response is to (1) report an IndexError; (2) substitute a default
  value; or (3) drop that value and return a tensor with fewer rows than the
  input.  Following the guiding principles of Python ("In the face of
  ambiguity, refuse the temptation to guess"), the operation is disallowed.

  Args:
    rt_input: The RaggedTensor to slice.
    key: Indicates which piece of the RaggedTensor to return, using standard
      Python semantics (e.g., negative values index from the end).  `key`
      may have any of the following types:

      * `int` constant
      * Scalar integer `Tensor`
      * `slice` containing integer constants and/or scalar integer
        `Tensor`s
      * `Ellipsis`
      * `tf.newaxis`
      * `tuple` containing any of the above (for multidimensional indexing)

  Returns:
    A `Tensor` or `RaggedTensor` object.  Values that include at least one
    ragged dimension are returned as `RaggedTensor`.  Values that include no
    ragged dimensions are returned as `Tensor`.  See the examples below for
    expressions that return `Tensor`s vs `RaggedTensor`s.

  Raises:
    ValueError: If `key` is out of bounds.
    ValueError: If `key` is not supported.
    TypeError: If `rt_input` is not a `RaggedTensor`, or if the indices in
      `key` have an unsupported type.
    IndexError: When executing eagerly, if an integer index into the
      outermost dimension lies past the last row.

  Examples:

  >>> # A 2-D ragged tensor with 1 ragged dimension.
  >>> rt = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e'], ['f'], ['g']])
  >>> rt[0].numpy()                 # First row (1-D `Tensor`)
  array([b'a', b'b', b'c'], dtype=object)
  >>> rt[:3].to_list()              # First three rows (2-D RaggedTensor)
  [[b'a', b'b', b'c'], [b'd', b'e'], [b'f']]
  >>> rt[3, 0].numpy()              # 1st element of 4th row (scalar)
  b'g'

  >>> # A 3-D ragged tensor with 2 ragged dimensions.
  >>> rt = tf.ragged.constant([[[1, 2, 3], [4]],
  ...                          [[5], [], [6]],
  ...                          [[7]],
  ...                          [[8, 9], [10]]])
  >>> rt[1].to_list()               # Second row (2-D RaggedTensor)
  [[5], [], [6]]
  >>> rt[3, 0].numpy()              # First element of fourth row (1-D Tensor)
  array([8, 9], dtype=int32)
  >>> rt[:, 1:3].to_list()          # Items 1-3 of each row (3-D RaggedTensor)
  [[[4]], [[], [6]], [], [[10]]]
  >>> rt[:, -1:].to_list()          # Last item of each row (3-D RaggedTensor)
  [[[4]], [[6]], [[7]], [[10]]]
  """
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise TypeError("Ragged __getitem__ expects a ragged_tensor.")

  # Every tensor that takes part in the lookup (the input plus any tensors
  # buried inside the key) is passed to the name scope.
  scope_tensors = [rt_input, *_tensors_in_key_list(key)]

  # The recursive helper always works on a list holding one key per dimension.
  key_list = list(key) if isinstance(key, (list, tuple)) else [key]

  with ops.name_scope(None, "RaggedGetItem", scope_tensors):
    return _ragged_getitem(rt_input, key_list)
108def _ragged_getitem(rt_input, key_list):
109 """Helper for indexing and slicing ragged tensors with __getitem__().
111 Extracts the specified piece of the `rt_input`. See
112 `RaggedTensor.__getitem__` for examples and restrictions.
114 Args:
115 rt_input: The `RaggedTensor` from which a piece should be returned.
116 key_list: The list of keys specifying which piece to return. Each key
117 corresponds with a separate dimension.
119 Returns:
120 The indicated piece of rt_input.
122 Raises:
123 ValueError: If `key_list` is not supported.
124 TypeError: If any keys in `key_list` have an unsupported type.
125 """
126 if not key_list:
127 return rt_input
128 row_key = key_list[0]
129 inner_keys = key_list[1:]
131 if row_key is Ellipsis:
132 expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims)
133 return _ragged_getitem(rt_input, expanded_key_list)
135 # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor
136 # that puts all values in a single row.
137 if row_key is array_ops.newaxis:
138 inner_rt = _ragged_getitem(rt_input, inner_keys)
139 nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0)
140 if nsplits.value is not None:
141 nsplits = nsplits.value
142 else:
143 nsplits = array_ops.shape(inner_rt.row_splits,
144 out_type=inner_rt.row_splits.dtype)[0]
145 return ragged_tensor.RaggedTensor.from_uniform_row_length(
146 inner_rt, nsplits - 1, nrows=1, validate=False)
148 # Slicing a range of rows: first slice the outer dimension, and then
149 # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
150 if isinstance(row_key, slice):
151 sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
152 if rt_input.uniform_row_length is not None:
153 # If the inner dimension has uniform_row_length, then preserve it (by
154 # re-wrapping the values in a new RaggedTensor). Note that the row
155 # length won't have changed, since we're slicing a range of rows (and not
156 # slicing the rows themselves).
157 sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length(
158 sliced_rt_input.values, rt_input.uniform_row_length,
159 nrows=sliced_rt_input.nrows())
160 return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)
162 # Indexing a single row: slice values to get the indicated row, and then
163 # use a recursive call to __getitem__ to handle the inner keys.
164 else:
165 starts = rt_input.row_splits[:-1]
166 limits = rt_input.row_splits[1:]
167 if context.executing_eagerly():
168 # In python, __getitem__ should throw IndexError for out of bound
169 # indices. This will allow iteration run correctly as python will
170 # translate IndexError into StopIteration for next()/__next__().
171 # Below is an example:
172 # import tensorflow as tf
173 # r = tf.ragged.constant([[1., 2.], [3., 4., 5.], [6.]])
174 # for elem in r:
175 # print(elem)
176 # In non eager mode, the exception is thrown when session runs
177 # so we don't know if out of bound happens before.
178 # In eager mode, however, it is possible to find out when to
179 # throw out of bound IndexError.
180 # In the following row_key >= len(starts) is checked. In case of
181 # TypeError which happens when row_key is not an integer, the exception
182 # will simply be ignored as it will be processed later anyway.
183 try:
184 if int(row_key) >= len(starts):
185 raise IndexError("Row key {} out of bounds".format(row_key))
186 except (TypeError, ValueError):
187 pass
188 row = rt_input.values[starts[row_key]:limits[row_key]]
189 return row.__getitem__(inner_keys)
def _slice_ragged_row_dimension(rt_input, row_key):
  """Applies a `slice` to the outermost (row) dimension of `rt_input`.

  Args:
    rt_input: The `RaggedTensor` whose rows should be sliced.
    row_key: The `slice` object selecting which rows to keep.

  Returns:
    A `RaggedTensor` holding the selected rows of `rt_input`.  When `row_key`
    is the trivial slice (`start`, `stop` and `step` all `None`), `rt_input`
    itself is returned.
  """
  if all(part is None
         for part in (row_key.start, row_key.stop, row_key.step)):
    return rt_input

  # Apply the slice to the per-row start and limit offsets.
  splits = rt_input.row_splits
  kept_starts = splits[:-1][row_key]
  kept_limits = splits[1:][row_key]
  zero_pad = array_ops.zeros([1], splits.dtype)

  # Strided slice: the kept rows are not contiguous in `rt_input.values`, so
  # gather each kept row (in full, hence the step of 1) by its value range.
  if not (row_key.step is None or row_key.step == 1):
    return _build_ragged_tensor_from_value_ranges(kept_starts, kept_limits, 1,
                                                  rt_input.values)

  # Unstrided slice: the kept rows form one continuous span of values.
  # Build the new splits.  If no rows were kept this reduces to [0];
  # otherwise it reduces to concat([[kept_starts[0]], kept_limits]).
  new_splits = array_ops.concat(
      [zero_pad[array_ops.size(kept_starts):], kept_starts[:1], kept_limits],
      axis=0)
  span_begin = new_splits[0]
  span_end = new_splits[-1]
  return ragged_tensor.RaggedTensor.from_row_splits(
      rt_input.values[span_begin:span_end], new_splits - span_begin,
      validate=False)
def _ragged_getitem_inner_dimensions(rt_input, key_list):
  """Applies `key_list` to the inner dimensions, leaving the outermost as is.

  Args:
    rt_input: The `RaggedTensor` or `Tensor` from which a piece should be
      extracted.
    key_list: The __getitem__ keys for slicing the inner dimensions.

  Returns:
    The selected piece of `rt_input`.  When `rt_input` is not a
    `RaggedTensor`, this is whatever its own `__getitem__` returns.

  Raises:
    ValueError: If key_list is not supported (in particular, if it indexes
      into an inner dimension that has no uniform row length).
    TypeError: If a slice's start or stop is not an `int`, a `Tensor`, or
      `None`.
  """
  if not key_list:
    return rt_input

  # Non-ragged input: delegate to its own __getitem__, with a full slice
  # prepended so that the outermost dimension is kept unchanged.
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    return rt_input.__getitem__([slice(None, None, None)] + key_list)

  column_key, remaining_keys = key_list[0], key_list[1:]

  if column_key is Ellipsis:
    return _ragged_getitem_inner_dimensions(
        rt_input, _expand_ellipsis(key_list, rt_input.values.shape.ndims))

  # Adding a new axis to a ragged inner dimension: recursively process the
  # remaining keys, and then wrap the result in a RaggedTensor that puts each
  # value in its own row.
  if column_key is array_ops.newaxis:
    wrapped = _ragged_getitem_inner_dimensions(rt_input, remaining_keys)
    wrapped_splits = wrapped.row_splits
    nsplits = tensor_shape.dimension_at_index(wrapped_splits.shape, 0).value
    if nsplits is None:
      # Not known statically; read the number of splits from the tensor.
      nsplits = array_ops.shape(wrapped_splits,
                                out_type=wrapped_splits.dtype)[0]
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        wrapped, 1, nrows=nsplits - 1, validate=False)

  # Slicing a range of columns in a ragged inner dimension.  The values are
  # processed by a recursive call and then reassembled into a RaggedTensor.
  if isinstance(column_key, slice):
    if (column_key.start is None and column_key.stop is None and
        column_key.step is None):
      # Trivial slice: the splits are unchanged; only the values recurse.
      return rt_input.with_values(
          _ragged_getitem_inner_dimensions(rt_input.values, remaining_keys))

    offset_types = (ops.Tensor, int, type(None))
    if not (isinstance(column_key.start, offset_types) and
            isinstance(column_key.stop, offset_types)):
      raise TypeError("slice offsets must be integers or None")

    # Nontrivial slice: gather the selected span of every row into a new
    # RaggedTensor, and then recursively process its values.
    row_starts = rt_input.row_splits[:-1]
    row_limits = rt_input.row_splits[1:]
    step = 1 if column_key.step is None else column_key.step
    lower_bound = _if_ge_zero(step, lambda: row_starts,
                              lambda: row_starts - 1)
    upper_bound = _if_ge_zero(step, lambda: row_limits,
                              lambda: row_limits - 1)

    def _gather_bound(offset, forward_default, backward_default):
      """Returns the per-row value index for one end of the slice."""
      if offset is None:
        # No explicit offset: the default depends on the stride direction.
        return _if_ge_zero(step, forward_default, backward_default)
      offset = math_ops.cast(offset, row_starts.dtype)
      # A non-negative offset counts from each row's start; a negative one
      # counts back from each row's limit.  Both are clamped to the row.
      return _if_ge_zero(
          offset,
          lambda: math_ops.minimum(row_starts + offset, upper_bound),
          lambda: math_ops.maximum(row_limits + offset, lower_bound))

    # gather_starts[i] = index to start gathering for row i.
    gather_starts = _gather_bound(column_key.start, lambda: row_starts,
                                  lambda: row_limits - 1)
    # gather_limits[i] = index to stop gathering for row i.
    gather_limits = _gather_bound(column_key.stop, lambda: row_limits,
                                  lambda: row_starts - 1)
    sliced = _build_ragged_tensor_from_value_ranges(
        gather_starts, gather_limits, column_key.step, rt_input.values)

    # If the row dimension is uniform, then compute the new uniform row
    # length and rebuild the result so that it keeps a uniform row length.
    if rt_input.uniform_row_length is not None:
      new_row_length = _slice_length(rt_input.uniform_row_length, column_key)
      sliced = ragged_tensor.RaggedTensor.from_uniform_row_length(
          sliced.values, new_row_length, rt_input.nrows())
    return sliced.with_values(
        _ragged_getitem_inner_dimensions(sliced.values, remaining_keys))

  # Indexing a single column in a ragged inner dimension: raise an Exception.
  # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
  # into a ragged inner dimension is problematic.
  row_length = rt_input.uniform_row_length
  if row_length is None:
    raise ValueError("Cannot index into an inner ragged dimension.")

  # Indexing a single column in a uniform inner dimension: check that the
  # index is in-bounds, and then take a strided slice over rt_input.values
  # to pick the indicated element out of each row.
  column_key = math_ops.cast(column_key, row_length.dtype)
  oob_err_msg = "Index out of bounds when indexing into a ragged tensor"
  oob_checks = [
      check_ops.assert_greater_equal(
          column_key, -row_length, message=oob_err_msg),
      check_ops.assert_less(column_key, row_length, message=oob_err_msg),
  ]
  with ops.control_dependencies(oob_checks):
    # Map a negative index onto its non-negative equivalent.
    offset = _if_ge_zero(column_key, lambda: column_key,
                         lambda: row_length + column_key)
    column_values = rt_input.values[offset::row_length]
    return _ragged_getitem_inner_dimensions(column_values, remaining_keys)
def _slice_length(value_length, slice_key):
  """Returns how many elements a slice selects from a value of given length.

  The result is the equivalent of: `len(range(value_length)[slice_key])`

  Args:
    value_length: Scalar int `Tensor`: the length of the value being sliced.
    slice_key: A `slice` object used to slice elements from the value.

  Returns:
    The number of elements in the sliced value, with the dtype of
    `value_length`.
  """
  # Note: the length could be derived arithmetically, with some variant of
  # (stop-start)//step, instead of materialising a placeholder tensor.  That
  # would need more ops (bounds checks, negative indices, negative steps,
  # etc), and this is expected to be an uncommon operation, so the simpler
  # approach of slicing a placeholder and measuring it is used.
  placeholder = array_ops.zeros(value_length, dtype=dtypes.bool)
  sliced_placeholder = placeholder[slice_key]
  return array_ops.size(sliced_placeholder, out_type=value_length.dtype)
def _expand_ellipsis(key_list, num_remaining_dims):
  """Expands the ellipsis at the start of `key_list`.

  Assumes that the first element of `key_list` is Ellipsis.  The Ellipsis is
  either dropped (when it stands for zero indices) or kept with a new
  `slice(None, None, None)` placed in front of it (when it stands for one or
  more indices), so that repeated calls consume one dimension at a time.

  Args:
    key_list: The arguments to `__getitem__()`.
    num_remaining_dims: The number of dimensions remaining.

  Returns:
    A copy of `key_list` with the ellipsis expanded.

  Raises:
    ValueError: If `num_remaining_dims` is None.
    IndexError: If there are too many elements in `key_list`.
  """
  if num_remaining_dims is None:
    raise ValueError("Ellipsis not supported for unknown shape RaggedTensors")

  # Count every key other than `newaxis`.  The leading Ellipsis is part of
  # this count, which is why the comparisons below allow one extra key.
  index_count = len([k for k in key_list if k is not array_ops.newaxis])
  max_count = num_remaining_dims + 1

  if index_count > max_count:
    raise IndexError("Too many indices for RaggedTensor")
  if index_count == max_count:
    # The other keys already cover every dimension: the Ellipsis is empty.
    return key_list[1:]
  # The Ellipsis covers at least one dimension: peel one off as a full slice.
  return [slice(None, None, None)] + key_list
def _tensors_in_key_list(key_list):
  """Yields every `Tensor` found in the given slice spec.

  The spec is walked recursively: a `Tensor` is yielded as is, a list or
  tuple contributes the tensors of each of its items in order, and a `slice`
  contributes the tensors of its start, stop and step (in that order).
  Anything else yields nothing.
  """
  if isinstance(key_list, ops.Tensor):
    yield key_list
  if isinstance(key_list, (list, tuple)):
    for item in key_list:
      yield from _tensors_in_key_list(item)
  if isinstance(key_list, slice):
    for bound in (key_list.start, key_list.stop, key_list.step):
      yield from _tensors_in_key_list(bound)
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Builds a `RaggedTensor` out of strided ranges taken from `values`.

  The result is a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences
      of values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences
      of values to include.
    step: Integer value specifying the step size for strided slices, or
      `None` for a step of 1.
    values: The set of values to select from (a `Tensor` or `RaggedTensor`).

  Returns:
    A `RaggedTensor`.

  Raises:
    TypeError: If `step` does not have an integer dtype.
  """
  # Normalise the step to an integer tensor with the same dtype as `starts`.
  step = ops.convert_to_tensor(1 if step is None else step, name="step")
  if not step.dtype.is_integer:
    raise TypeError("slice strides must be integers or None")
  step = math_ops.cast(step, starts.dtype)

  # A ragged range gives, for each row, the index of every value to include.
  value_indices = ragged_math_ops.range(starts, limits, step,
                                        row_splits_dtype=starts.dtype)

  # Collect the values with the gather op that matches the type of `values`.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gather = ragged_gather_ops.gather
  else:
    gather = array_ops.gather
  gathered_values = gather(params=values, indices=value_indices.values)

  # Reuse the splits of the index tensor to assemble the result.
  return value_indices.with_values(gathered_values)
def _if_ge_zero(value, true_fn, false_fn):
  """Returns `true_fn() if value >= 0 else false_fn()`.

  A control flow op is only built when `value` is a `Tensor` whose value
  cannot be determined statically; otherwise the branch is picked in python
  and only the chosen function is called.
  """
  if isinstance(value, ops.Tensor):
    static_value = tensor_util.constant_value(value)
    if static_value is None:
      return cond.cond(value >= 0, true_fn, false_fn)
    # Statically known: fall through and decide on the constant.
    value = static_value
  chosen_fn = true_fn if value >= 0 else false_fn
  return chosen_fn()