Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py: 26%
917 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for storing ragged tensors and their values."""
17import functools
18import operator
20import typing
21import numpy as np
23from tensorflow.core.protobuf import struct_pb2
24from tensorflow.python import tf2
25from tensorflow.python.client import session
26from tensorflow.python.framework import composite_tensor
27from tensorflow.python.framework import composite_tensor_gradient
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.framework import tensor_conversion
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.framework import type_spec_registry
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import array_ops_stack
40from tensorflow.python.ops import check_ops
41from tensorflow.python.ops import cond
42from tensorflow.python.ops import control_flow_assert
43from tensorflow.python.ops import gen_ragged_conversion_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops.ragged import ragged_config
46from tensorflow.python.ops.ragged import ragged_tensor_value
47from tensorflow.python.ops.ragged import ragged_util
48from tensorflow.python.ops.ragged.row_partition import RowPartition
49from tensorflow.python.saved_model import nested_structure_coder
50from tensorflow.python.types import core as core_types
51from tensorflow.python.types import internal as internal_types
52from tensorflow.python.util import dispatch
53from tensorflow.python.util.tf_export import tf_export
54from tensorflow.tools.docs import doc_controls
# pylint: disable=protected-access
# Module-level alias: reuse RowPartition's converter for row-partitioning
# tensors (row_splits, row_lengths, etc.) without re-exposing the protected
# method at every call site.
_convert_row_partition = RowPartition._convert_row_partition
# pylint: enable=protected-access
60# ===============================================================================
61# RaggedTensor
62# ===============================================================================
65@tf_export("RaggedTensor")
66class RaggedTensor(composite_tensor.CompositeTensor,
67 internal_types.NativeObject):
68 """Represents a ragged tensor.
70 A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are
71 dimensions whose slices may have different lengths. For example, the inner
72 (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged,
73 since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths.
74 Dimensions whose slices all have the same length are called *uniform
75 dimensions*. The outermost dimension of a `RaggedTensor` is always uniform,
76 since it consists of a single slice (and so there is no possibility for
77 differing slice lengths).
79 The total number of dimensions in a `RaggedTensor` is called its *rank*,
80 and the number of ragged dimensions in a `RaggedTensor` is called its
81 *ragged-rank*. A `RaggedTensor`'s ragged-rank is fixed at graph creation
82 time: it can't depend on the runtime values of `Tensor`s, and can't vary
83 dynamically for different session runs.
85 Note that the `__init__` constructor is private. Please use one of the
86 following methods to construct a `RaggedTensor`:
88 * `tf.RaggedTensor.from_row_lengths`
89 * `tf.RaggedTensor.from_value_rowids`
90 * `tf.RaggedTensor.from_row_splits`
91 * `tf.RaggedTensor.from_row_starts`
92 * `tf.RaggedTensor.from_row_limits`
93 * `tf.RaggedTensor.from_nested_row_splits`
94 * `tf.RaggedTensor.from_nested_row_lengths`
95 * `tf.RaggedTensor.from_nested_value_rowids`
97 ### Potentially Ragged Tensors
99 Many ops support both `Tensor`s and `RaggedTensor`s
100 (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a
101 full listing). The term "potentially ragged tensor" may be used to refer to a
102 tensor that might be either a `Tensor` or a `RaggedTensor`. The ragged-rank
103 of a `Tensor` is zero.
105 ### Documenting RaggedTensor Shapes
107 When documenting the shape of a RaggedTensor, ragged dimensions can be
108 indicated by enclosing them in parentheses. For example, the shape of
109 a 3-D `RaggedTensor` that stores the fixed-size word embedding for each
110 word in a sentence, for each sentence in a batch, could be written as
111 `[num_sentences, (num_words), embedding_size]`. The parentheses around
112 `(num_words)` indicate that dimension is ragged, and that the length
113 of each element list in that dimension may vary for each item.
115 ### Component Tensors
117 Internally, a `RaggedTensor` consists of a concatenated list of values that
118 are partitioned into variable-length rows. In particular, each `RaggedTensor`
119 consists of:
121 * A `values` tensor, which concatenates the variable-length rows into a
122 flattened list. For example, the `values` tensor for
123 `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`.
125 * A `row_splits` vector, which indicates how those flattened values are
126 divided into rows. In particular, the values for row `rt[i]` are stored
127 in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
129 Example:
131 >>> print(tf.RaggedTensor.from_row_splits(
132 ... values=[3, 1, 4, 1, 5, 9, 2, 6],
133 ... row_splits=[0, 4, 4, 7, 8, 8]))
134 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
136 ### Alternative Row-Partitioning Schemes
138 In addition to `row_splits`, ragged tensors provide support for five other
139 row-partitioning schemes:
141 * `row_lengths`: a vector with shape `[nrows]`, which specifies the length
142 of each row.
144 * `value_rowids` and `nrows`: `value_rowids` is a vector with shape
145 `[nvals]`, corresponding one-to-one with `values`, which specifies
146 each value's row index. In particular, the row `rt[row]` consists of the
147 values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an
148 integer scalar that specifies the number of rows in the
149 `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.)
151 * `row_starts`: a vector with shape `[nrows]`, which specifies the start
152 offset of each row. Equivalent to `row_splits[:-1]`.
154 * `row_limits`: a vector with shape `[nrows]`, which specifies the stop
155 offset of each row. Equivalent to `row_splits[1:]`.
157 * `uniform_row_length`: A scalar tensor, specifying the length of every
158 row. This row-partitioning scheme may only be used if all rows have
159 the same length.
161 Example: The following ragged tensors are equivalent, and all represent the
162 nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`.
164 >>> values = [3, 1, 4, 1, 5, 9, 2, 6]
165 >>> RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
166 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
167 >>> RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
168 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
169 >>> RaggedTensor.from_value_rowids(
170 ... values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
171 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
172 >>> RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
173 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
174 >>> RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
175 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
176 >>> RaggedTensor.from_uniform_row_length(values, uniform_row_length=2)
177 <tf.RaggedTensor [[3, 1], [4, 1], [5, 9], [2, 6]]>
179 ### Multiple Ragged Dimensions
181 `RaggedTensor`s with multiple ragged dimensions can be defined by using
182 a nested `RaggedTensor` for the `values` tensor. Each nested `RaggedTensor`
183 adds a single ragged dimension.
185 >>> inner_rt = RaggedTensor.from_row_splits( # =rt1 from above
186 ... values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
187 >>> outer_rt = RaggedTensor.from_row_splits(
188 ... values=inner_rt, row_splits=[0, 3, 3, 5])
189 >>> print(outer_rt.to_list())
190 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
191 >>> print(outer_rt.ragged_rank)
192 2
194 The factory function `RaggedTensor.from_nested_row_splits` may be used to
195 construct a `RaggedTensor` with multiple ragged dimensions directly, by
196 providing a list of `row_splits` tensors:
198 >>> RaggedTensor.from_nested_row_splits(
199 ... flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
200 ... nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list()
201 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
203 ### Uniform Inner Dimensions
205 `RaggedTensor`s with uniform inner dimensions can be defined
206 by using a multidimensional `Tensor` for `values`.
208 >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32),
209 ... row_splits=[0, 2, 5])
210 >>> print(rt.to_list())
211 [[[1, 1, 1], [1, 1, 1]],
212 [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]
213 >>> print(rt.shape)
214 (2, None, 3)
216 ### Uniform Outer Dimensions
218 `RaggedTensor`s with uniform outer dimensions can be defined by using
219 one or more `RaggedTensor` with a `uniform_row_length` row-partitioning
220 tensor. For example, a `RaggedTensor` with shape `[2, 2, None]` can be
221 constructed with this method from a `RaggedTensor` values with shape
222 `[4, None]`:
224 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
225 >>> print(values.shape)
226 (4, None)
227 >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2)
228 >>> print(rt6)
229 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
230 >>> print(rt6.shape)
231 (2, 2, None)
233 Note that `rt6` only contains one ragged dimension (the innermost
234 dimension). In contrast, if `from_row_splits` is used to construct a similar
235 `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
237 >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
238 >>> print(rt7.shape)
239 (2, None, None)
241 Uniform and ragged outer dimensions may be interleaved, meaning that a
242 tensor with any combination of ragged and uniform dimensions may be created.
243 For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could
244 be constructed as follows:
246 ```python
247 t0 = tf.zeros([1000, 2]) # Shape: [1000, 2]
248 t1 = RaggedTensor.from_row_lengths(t0, [...]) # [160, None, 2]
249 t2 = RaggedTensor.from_uniform_row_length(t1, 8) # [20, 8, None, 2]
250 t3 = RaggedTensor.from_uniform_row_length(t2, 4) # [5, 4, 8, None, 2]
251 t4 = RaggedTensor.from_row_lengths(t3, [...]) # [3, None, 4, 8, None, 2]
252 ```
254 """
256 #=============================================================================
257 # Constructor (private)
258 #=============================================================================
  @doc_controls.do_not_generate_docs
  def __init__(self, values, row_partition, internal=False):
    """Creates a `RaggedTensor` with a specified partitioning for `values`.

    This constructor is private -- please use one of the following ops to
    build `RaggedTensor`s:

      * `tf.RaggedTensor.from_row_lengths`
      * `tf.RaggedTensor.from_value_rowids`
      * `tf.RaggedTensor.from_row_splits`
      * `tf.RaggedTensor.from_row_starts`
      * `tf.RaggedTensor.from_row_limits`
      * `tf.RaggedTensor.from_nested_row_splits`
      * `tf.RaggedTensor.from_nested_row_lengths`
      * `tf.RaggedTensor.from_nested_value_rowids`

    Args:
      values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
      row_partition: A `RowPartition` object, representing the arrangement of
        the lists at the top level.
      internal: True if the constructor is being called by one of the factory
        methods.  If false, an exception will be raised.

    Raises:
      ValueError: If internal = False. Note that this method is intended only
         for internal use.
      TypeError: If values is not a `RaggedTensor` or `Tensor`, or
        row_partition is not a `RowPartition`.
    """
    # The factory methods normalize/validate their arguments before calling
    # this constructor with internal=True; direct construction is rejected.
    if not internal:
      raise ValueError("RaggedTensor constructor is private; please use one "
                       "of the factory methods instead (e.g., "
                       "RaggedTensor.from_row_lengths())")
    _assert_is_supported_ragged_values_type(values)
    if not isinstance(row_partition, RowPartition):
      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
                      f"Received {row_partition}.")

    # Validate shapes.
    values.shape.with_rank_at_least(1)
    if isinstance(values, RaggedTensor):
      # Nested RaggedTensors must use a single row-partition dtype throughout
      # (the factory methods guarantee this, hence an assert, not a raise).
      # pylint: disable=protected-access
      assert row_partition.dtype == values._row_partition.dtype

    self._values = values
    self._row_partition = row_partition
307 #=============================================================================
308 # Factory Methods
309 #=============================================================================
  @classmethod
  def _from_row_partition(cls, values, row_partition, validate=True):
    """Creates a `RaggedTensor` with a row partition.

    This is used as a way for RaggedTensors to share row partitions.

    The outer dimension of values must be equal to `partition.nvals()`.

    Args:
      values: A potentially ragged tensor.
      row_partition: a `RowPartition`: can be shared between tensors.
      validate: If true, then use assertions to check that the arguments form a
        valid `RaggedTensor`.

    Returns:
      A `RaggedTensor`.  `result.rank = values.rank + 1`.
      `result.ragged_rank = values.ragged_rank + 1`.

    Raises:
      ValueError: If partition.nvals() != _nrows(values)
    """
    if not isinstance(row_partition, RowPartition):
      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
                      f"Received {row_partition}.")
    if not isinstance(validate, bool):
      raise TypeError(f"Argument `validate` must have type bool. "
                      f"Received {validate}.")
    values, row_partition = cls._convert_values_and_partition(
        values, row_partition, "partition")
    # Static (graph-construction-time) check: when value_rowids are already
    # computed, their shape must be compatible with the outer dim of values.
    if row_partition._has_precomputed_value_rowids():  # pylint: disable=protected-access
      value_rowids_shape = row_partition.value_rowids().shape
      values.shape[:1].assert_is_compatible_with(value_rowids_shape)
    if validate:
      # Runtime check: attach assertions as control dependencies on the
      # partition so they execute whenever the result is used.
      msg = "Arguments to _from_row_partition do not form a valid RaggedTensor"
      nvals = _nrows(values, row_partition.dtype)
      checks = [
          check_ops.assert_equal(
              math_ops.cast(row_partition.nvals(), row_partition.dtype),
              nvals,
              message=msg),
      ]
      if not isinstance(values, RaggedTensor):
        checks.append(check_ops.assert_rank_at_least(values, 1))
      row_partition = row_partition._with_dependencies(checks)  # pylint: disable=protected-access
    return cls(values=values, internal=True, row_partition=row_partition)
357 @classmethod
358 @dispatch.add_dispatch_support
359 def from_value_rowids(cls,
360 values,
361 value_rowids,
362 nrows=None,
363 name=None,
364 validate=True):
365 """Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
367 The returned `RaggedTensor` corresponds with the python list defined by:
369 ```python
370 result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
371 for row in range(nrows)]
372 ```
374 Args:
375 values: A potentially ragged tensor with shape `[nvals, ...]`.
376 value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
377 one-to-one with `values`, and specifies each value's row index. Must be
378 nonnegative, and must be sorted in ascending order.
379 nrows: An integer scalar specifying the number of rows. This should be
380 specified if the `RaggedTensor` may containing empty training rows. Must
381 be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
382 Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is empty).
383 name: A name prefix for the RaggedTensor (optional).
384 validate: If true, then use assertions to check that the arguments form
385 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
386 since they must be checked for each tensor value.
388 Returns:
389 A `RaggedTensor`. `result.rank = values.rank + 1`.
390 `result.ragged_rank = values.ragged_rank + 1`.
392 Raises:
393 ValueError: If `nrows` is incompatible with `value_rowids`.
395 #### Example:
397 >>> print(tf.RaggedTensor.from_value_rowids(
398 ... values=[3, 1, 4, 1, 5, 9, 2, 6],
399 ... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
400 ... nrows=5))
401 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
403 """
404 if not isinstance(validate, bool):
405 raise TypeError(f"Argument `validate` must have type bool. "
406 f"Received {validate}.")
408 with ops.name_scope(name, "RaggedFromValueRowIds",
409 [values, value_rowids, nrows]):
410 row_partition = RowPartition.from_value_rowids(
411 value_rowids=value_rowids,
412 nrows=nrows,
413 validate=validate,
414 dtype_hint=_get_optional_partition_dtype(values))
415 return cls._from_row_partition(values, row_partition, validate=validate)
417 @classmethod
418 @dispatch.add_dispatch_support
419 def from_row_splits(cls, values, row_splits, name=None, validate=True):
420 """Creates a `RaggedTensor` with rows partitioned by `row_splits`.
422 The returned `RaggedTensor` corresponds with the python list defined by:
424 ```python
425 result = [values[row_splits[i]:row_splits[i + 1]]
426 for i in range(len(row_splits) - 1)]
427 ```
429 Args:
430 values: A potentially ragged tensor with shape `[nvals, ...]`.
431 row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
432 empty, and must be sorted in ascending order. `row_splits[0]` must be
433 zero and `row_splits[-1]` must be `nvals`.
434 name: A name prefix for the RaggedTensor (optional).
435 validate: If true, then use assertions to check that the arguments form
436 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
437 since they must be checked for each tensor value.
439 Returns:
440 A `RaggedTensor`. `result.rank = values.rank + 1`.
441 `result.ragged_rank = values.ragged_rank + 1`.
443 Raises:
444 ValueError: If `row_splits` is an empty list.
446 #### Example:
448 >>> print(tf.RaggedTensor.from_row_splits(
449 ... values=[3, 1, 4, 1, 5, 9, 2, 6],
450 ... row_splits=[0, 4, 4, 7, 8, 8]))
451 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
453 """
454 if not isinstance(validate, bool):
455 raise TypeError(f"Argument `validate` must have type bool. "
456 f"Received {validate}.")
458 with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
459 row_partition = RowPartition.from_row_splits(
460 row_splits=row_splits,
461 validate=validate,
462 dtype_hint=_get_optional_partition_dtype(values))
463 return cls._from_row_partition(values, row_partition, validate=validate)
465 @classmethod
466 @dispatch.add_dispatch_support
467 def from_row_lengths(cls, values, row_lengths, name=None, validate=True):
468 """Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
470 The returned `RaggedTensor` corresponds with the python list defined by:
472 ```python
473 result = [[values.pop(0) for i in range(length)]
474 for length in row_lengths]
475 ```
477 Args:
478 values: A potentially ragged tensor with shape `[nvals, ...]`.
479 row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be
480 nonnegative. `sum(row_lengths)` must be `nvals`.
481 name: A name prefix for the RaggedTensor (optional).
482 validate: If true, then use assertions to check that the arguments form
483 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
484 since they must be checked for each tensor value.
486 Returns:
487 A `RaggedTensor`. `result.rank = values.rank + 1`.
488 `result.ragged_rank = values.ragged_rank + 1`.
490 #### Example:
492 >>> print(tf.RaggedTensor.from_row_lengths(
493 ... values=[3, 1, 4, 1, 5, 9, 2, 6],
494 ... row_lengths=[4, 0, 3, 1, 0]))
495 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
497 """
498 if not isinstance(validate, bool):
499 raise TypeError(f"Argument `validate` must have type bool. "
500 f"Received {validate}.")
502 with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
503 row_partition = RowPartition.from_row_lengths(
504 row_lengths=row_lengths,
505 validate=validate,
506 dtype_hint=_get_optional_partition_dtype(values))
507 return cls._from_row_partition(values, row_partition, validate=validate)
509 @classmethod
510 @dispatch.add_dispatch_support
511 def from_row_starts(cls, values, row_starts, name=None, validate=True):
512 """Creates a `RaggedTensor` with rows partitioned by `row_starts`.
514 Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`.
516 Args:
517 values: A potentially ragged tensor with shape `[nvals, ...]`.
518 row_starts: A 1-D integer tensor with shape `[nrows]`. Must be
519 nonnegative and sorted in ascending order. If `nrows>0`, then
520 `row_starts[0]` must be zero.
521 name: A name prefix for the RaggedTensor (optional).
522 validate: If true, then use assertions to check that the arguments form
523 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
524 since they must be checked for each tensor value.
526 Returns:
527 A `RaggedTensor`. `result.rank = values.rank + 1`.
528 `result.ragged_rank = values.ragged_rank + 1`.
530 #### Example:
532 >>> print(tf.RaggedTensor.from_row_starts(
533 ... values=[3, 1, 4, 1, 5, 9, 2, 6],
534 ... row_starts=[0, 4, 4, 7, 8]))
535 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
537 """
538 if not isinstance(validate, bool):
539 raise TypeError(f"Argument `validate` must have type bool. "
540 f"Received {validate}.")
541 with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
542 values = _convert_to_ragged_tensor_values(values)
543 row_partition = RowPartition.from_row_starts(
544 row_starts=row_starts,
545 nvals=_nrows(values),
546 validate=validate,
547 dtype_hint=_get_optional_partition_dtype(values))
548 return cls._from_row_partition(values, row_partition, validate=validate)
550 @classmethod
551 @dispatch.add_dispatch_support
552 def from_row_limits(cls, values, row_limits, name=None, validate=True):
553 """Creates a `RaggedTensor` with rows partitioned by `row_limits`.
555 Equivalent to: `from_row_splits(values, concat([0, row_limits]))`.
557 Args:
558 values: A potentially ragged tensor with shape `[nvals, ...]`.
559 row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in
560 ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`.
561 name: A name prefix for the RaggedTensor (optional).
562 validate: If true, then use assertions to check that the arguments form
563 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
564 since they must be checked for each tensor value.
566 Returns:
567 A `RaggedTensor`. `result.rank = values.rank + 1`.
568 `result.ragged_rank = values.ragged_rank + 1`.
570 #### Example:
572 >>> print(tf.RaggedTensor.from_row_limits(
573 ... values=[3, 1, 4, 1, 5, 9, 2, 6],
574 ... row_limits=[4, 4, 7, 8, 8]))
575 <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
577 """
578 if not isinstance(validate, bool):
579 raise TypeError(f"Argument `validate` must have type bool. "
580 f"Received {validate}.")
581 with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
582 values = _convert_to_ragged_tensor_values(values)
583 row_partition = RowPartition.from_row_limits(
584 row_limits=row_limits,
585 validate=validate,
586 dtype_hint=_get_optional_partition_dtype(values))
587 return cls._from_row_partition(values, row_partition, validate=validate)
589 @classmethod
590 @dispatch.add_dispatch_support
591 def from_uniform_row_length(cls,
592 values,
593 uniform_row_length,
594 nrows=None,
595 validate=True,
596 name=None):
597 """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`.
599 This method can be used to create `RaggedTensor`s with multiple uniform
600 outer dimensions. For example, a `RaggedTensor` with shape `[2, 2, None]`
601 can be constructed with this method from a `RaggedTensor` values with shape
602 `[4, None]`:
604 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
605 >>> print(values.shape)
606 (4, None)
607 >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
608 >>> print(rt1)
609 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
610 >>> print(rt1.shape)
611 (2, 2, None)
613 Note that `rt1` only contains one ragged dimension (the innermost
614 dimension). In contrast, if `from_row_splits` is used to construct a similar
615 `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
617 >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
618 >>> print(rt2.shape)
619 (2, None, None)
621 Args:
622 values: A potentially ragged tensor with shape `[nvals, ...]`.
623 uniform_row_length: A scalar integer tensor. Must be nonnegative. The
624 size of the outer axis of `values` must be evenly divisible by
625 `uniform_row_length`.
626 nrows: The number of rows in the constructed RaggedTensor. If not
627 specified, then it defaults to `nvals/uniform_row_length` (or `0` if
628 `uniform_row_length==0`). `nrows` only needs to be specified if
629 `uniform_row_length` might be zero. `uniform_row_length*nrows` must be
630 `nvals`.
631 validate: If true, then use assertions to check that the arguments form
632 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
633 since they must be checked for each tensor value.
634 name: A name prefix for the RaggedTensor (optional).
636 Returns:
637 A `RaggedTensor` that corresponds with the python list defined by:
639 ```python
640 result = [[values.pop(0) for i in range(uniform_row_length)]
641 for _ in range(nrows)]
642 ```
644 `result.rank = values.rank + 1`.
645 `result.ragged_rank = values.ragged_rank + 1`.
646 """
647 if not isinstance(validate, bool):
648 raise TypeError(f"Argument `validate` must have type bool. "
649 f"Received {validate}.")
650 with ops.name_scope(name, "RaggedFromUniformRowLength",
651 [values, uniform_row_length, nrows]):
652 values = _convert_to_ragged_tensor_values(values)
653 uniform_row_length = _convert_row_partition(
654 uniform_row_length, "UniformRowLength",
655 _get_optional_partition_dtype(values))
656 nvals = _nvals_uniform_row_length(values, uniform_row_length)
657 row_partition = RowPartition.from_uniform_row_length(
658 uniform_row_length=uniform_row_length,
659 nvals=nvals,
660 nrows=nrows,
661 validate=validate,
662 dtype_hint=_get_optional_partition_dtype(values))
663 return cls._from_row_partition(values, row_partition, validate=validate)
665 @classmethod
666 @dispatch.add_dispatch_support
667 def from_nested_value_rowids(cls,
668 flat_values,
669 nested_value_rowids,
670 nested_nrows=None,
671 name=None,
672 validate=True):
673 """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors.
675 Equivalent to:
677 ```python
678 result = flat_values
679 for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)):
680 result = from_value_rowids(result, rowids, nrows)
681 ```
683 Args:
684 flat_values: A potentially ragged tensor.
685 nested_value_rowids: A list of 1-D integer tensors. The `i`th tensor is
686 used as the `value_rowids` for the `i`th ragged dimension.
687 nested_nrows: A list of integer scalars. The `i`th scalar is used as the
688 `nrows` for the `i`th ragged dimension.
689 name: A name prefix for the RaggedTensor (optional).
690 validate: If true, then use assertions to check that the arguments form
691 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
692 since they must be checked for each tensor value.
694 Returns:
695 A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty).
697 Raises:
698 ValueError: If `len(nested_values_rowids) != len(nested_nrows)`.
699 """
700 if not isinstance(validate, bool):
701 raise TypeError(f"Argument `validate` must have type bool. "
702 f"Received {validate}.")
703 if isinstance(nested_value_rowids, ops.Tensor):
704 raise TypeError(f"Argument `nested_value_rowids` must be a list of "
705 f"Tensors. Received {nested_value_rowids}.")
706 if nested_nrows is None:
707 nested_nrows = [None] * len(nested_value_rowids)
708 else:
709 if isinstance(nested_nrows, ops.Tensor):
710 raise TypeError(f"Argument `nested_nrows` must be a list of "
711 f"Tensors. Received {nested_nrows}.")
712 if len(nested_nrows) != len(nested_value_rowids):
713 raise ValueError(
714 f"Argument `nested_nrows` must have the same length as "
715 f"argument `nested_value_rowids`. len(nested_nrows) = "
716 f"{len(nested_nrows)} vs. len(nested_values_rowids) = "
717 f"{len(nested_value_rowids)}.")
719 with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] +
720 list(nested_value_rowids) + list(nested_nrows)):
721 result = flat_values
722 for value_rowids, nrows in reversed(
723 list(zip(nested_value_rowids, nested_nrows))):
724 result = cls.from_value_rowids(
725 result, value_rowids, nrows, validate=validate)
726 return result
728 @classmethod
729 @dispatch.add_dispatch_support
730 def from_nested_row_splits(cls,
731 flat_values,
732 nested_row_splits,
733 name=None,
734 validate=True):
735 """Creates a `RaggedTensor` from a nested list of `row_splits` tensors.
737 Equivalent to:
739 ```python
740 result = flat_values
741 for row_splits in reversed(nested_row_splits):
742 result = from_row_splits(result, row_splits)
743 ```
745 Args:
746 flat_values: A potentially ragged tensor.
747 nested_row_splits: A list of 1-D integer tensors. The `i`th tensor is
748 used as the `row_splits` for the `i`th ragged dimension.
749 name: A name prefix for the RaggedTensor (optional).
750 validate: If true, then use assertions to check that the arguments form
751 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
752 since they must be checked for each tensor value.
754 Returns:
755 A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty).
756 """
757 if not isinstance(validate, bool):
758 raise TypeError(f"Argument `validate` must have type bool. "
759 f"Received {validate}.")
760 if isinstance(nested_row_splits, ops.Tensor):
761 raise TypeError(f"Argument `nested_row_splits` must be a list of "
762 f"Tensors. Received {nested_row_splits}.")
763 with ops.name_scope(name, "RaggedFromNestedRowSplits",
764 [flat_values] + list(nested_row_splits)):
765 result = flat_values
766 for splits in reversed(nested_row_splits):
767 result = cls.from_row_splits(result, splits, validate=validate)
768 return result
770 @classmethod
771 @dispatch.add_dispatch_support
772 def from_nested_row_lengths(cls,
773 flat_values,
774 nested_row_lengths,
775 name=None,
776 validate=True):
777 """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors.
779 Equivalent to:
781 ```python
782 result = flat_values
783 for row_lengths in reversed(nested_row_lengths):
784 result = from_row_lengths(result, row_lengths)
785 ```
787 Args:
788 flat_values: A potentially ragged tensor.
789 nested_row_lengths: A list of 1-D integer tensors. The `i`th tensor is
790 used as the `row_lengths` for the `i`th ragged dimension.
791 name: A name prefix for the RaggedTensor (optional).
792 validate: If true, then use assertions to check that the arguments form
793 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
794 since they must be checked for each tensor value.
796 Returns:
797 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
798 """
799 if not isinstance(validate, bool):
800 raise TypeError(f"Argument `validate` must have type bool. "
801 f"Received {validate}.")
802 if isinstance(nested_row_lengths, ops.Tensor):
803 raise TypeError(f"Argument `nested_row_lengths` must be a list of "
804 f"Tensors. Received {nested_row_lengths}.")
805 with ops.name_scope(name, "RaggedFromNestedRowlengths",
806 [flat_values] + list(nested_row_lengths)):
807 result = flat_values
808 for lengths in reversed(nested_row_lengths):
809 result = cls.from_row_lengths(result, lengths, validate=validate)
810 return result
812 @classmethod
813 def _from_nested_row_partitions(cls,
814 flat_values,
815 nested_row_partitions,
816 name=None,
817 validate=True):
818 """Creates a `RaggedTensor` from a nested list of row partitions.
820 Equivalent to:
822 ```python
823 result = flat_values
824 for row_partition in reversed(nested_row_partitions):
825 result = _from_row_partition(result, row_partition)
826 ```
828 Args:
829 flat_values: A potentially ragged tensor.
830 nested_row_partitions: A list of row partitions. The `i`th element is
831 used as the row partition for the `i`th ragged dimension.
832 name: A name prefix for the RaggedTensor (optional).
833 validate: If true, then use assertions to check that the arguments form
834 a valid `RaggedTensor`. Note: these assertions incur a runtime cost,
835 since they must be checked for each tensor value.
837 Returns:
838 A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
839 """
840 if not isinstance(validate, bool):
841 raise TypeError(f"Argument `validate` must have type bool. "
842 f"Received {validate}.")
843 if isinstance(nested_row_partitions, RowPartition):
844 raise TypeError(f"Argument `nested_row_partitions` must be a list of "
845 f"RowPartitions. Received {nested_row_partitions}.")
846 if isinstance(nested_row_partitions, ops.Tensor):
847 raise TypeError(f"Argument `nested_row_partitions` must be a list of "
848 f"RowPartitions. Received {nested_row_partitions}.")
849 with ops.name_scope(name, "RaggedFromNestedRowPartitions",
850 [flat_values] + list(nested_row_partitions)):
851 result = flat_values
852 for partition in reversed(nested_row_partitions):
853 result = cls._from_row_partition(result, partition, validate=validate)
854 return result
856 @classmethod
857 def _convert_values_and_partition(cls, values, row_partition, name):
858 """Converts `values` and `partition` to Tensors.
860 If `values` is a `RaggedTensor`, then converts `values` and `partition`
861 to have compatible row-partitioning dtypes. In particular, if any of the
862 row partitioning tensors are `int64`, then all of the other row
863 partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype()
864 is true) or an error will be raised (if auto_cast_partition_dtype() is
865 false).
867 Args:
868 values: The `values` for the `RaggedTensor` being constructed.
869 row_partition: A RowPartition object for the `RaggedTensor` being
870 constructed.
871 name: The name of the RowPartition object.
873 Returns:
874 A tuple (values, partition).
875 """
876 if not isinstance(row_partition, RowPartition):
877 raise TypeError(f"Argument `row_partition` must be a RowPartition. "
878 f"Received {row_partition}.")
879 if isinstance(values, RaggedTensor):
880 # pylint: disable=protected-access
881 if values._row_partition.dtype != row_partition.dtype:
882 if not ragged_config.auto_cast_partition_dtype():
883 # pylint: disable=protected-access
884 # TODO(edloper): get rid of the `name` parameter.
885 raise ValueError(
886 f"Argument `row_partition` of RaggedTensor with name: {name} "
887 f"must have same dtype as Argument `values`. "
888 f"({row_partition.dtype} vs. {values._row_partition.dtype}).")
889 values = values.with_row_splits_dtype(row_partition.dtype)
890 else:
891 values = _convert_to_ragged_tensor_values(values)
893 return (values, row_partition)
895 #=============================================================================
896 # Accessors
897 #=============================================================================
899 @property
900 def dtype(self):
901 """The `DType` of values in this tensor."""
902 return self._values.dtype
904 @property
905 def shape(self):
906 """The statically known shape of this ragged tensor.
908 Returns:
909 A `TensorShape` containing the statically known shape of this ragged
910 tensor. Ragged dimensions have a size of `None`.
912 Examples:
914 >>> tf.ragged.constant([[0], [1, 2]]).shape
915 TensorShape([2, None])
917 >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
918 TensorShape([2, None, 2])
920 """
921 nrows = self._row_partition.static_nrows
922 ncols = self._row_partition.static_uniform_row_length
923 value_shape = self._values.shape[1:]
924 return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape)
926 def get_shape(self):
927 """The statically known shape of this ragged tensor.
929 Returns:
930 A `TensorShape` containing the statically known shape of this ragged
931 tensor. Ragged dimensions have a size of `None`.
933 Alias for `shape` property.
935 Examples:
937 >>> tf.ragged.constant([[0], [1, 2]]).get_shape()
938 TensorShape([2, None])
940 >>> tf.ragged.constant(
941 ... [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape()
942 TensorShape([2, None, 2])
944 """
945 return self.shape
947 @property
948 def ragged_rank(self):
949 """The number of times the RaggedTensor's flat_values is partitioned.
951 Examples:
953 >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
954 >>> values.ragged_rank
955 1
957 >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2)
958 >>> rt.ragged_rank
959 2
961 Returns:
962 A Python `int` indicating the number of times the underlying `flat_values`
963 Tensor has been partitioned to add a new dimension.
964 I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
965 """
966 values_is_ragged = isinstance(self._values, RaggedTensor)
967 return self._values.ragged_rank + 1 if values_is_ragged else 1
969 @property
970 def values(self):
971 """The concatenated rows for this ragged tensor.
973 `rt.values` is a potentially ragged tensor formed by flattening the two
974 outermost dimensions of `rt` into a single dimension.
976 `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the
977 number of items in the outer two dimensions of `rt`).
979 `rt.ragged_rank = self.ragged_rank - 1`
981 Returns:
982 A potentially ragged tensor.
984 #### Example:
986 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
987 >>> print(rt.values)
988 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
990 """
991 return self._values
993 @property
994 def _nested_row_partitions(self):
995 """Returns the row partitions for this `RaggedTensor`."""
996 partitions = [self._row_partition]
997 rt_values = self.values
998 while isinstance(rt_values, RaggedTensor):
999 # pylint: disable=protected-access
1000 partitions.append(rt_values._row_partition)
1001 rt_values = rt_values.values
1002 return tuple(partitions)
1004 @property
1005 def row_splits(self):
1006 """The row-split indices for this ragged tensor's `values`.
1008 `rt.row_splits` specifies where the values for each row begin and end in
1009 `rt.values`. In particular, the values for row `rt[i]` are stored in
1010 the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
1012 Returns:
1013 A 1-D integer `Tensor` with shape `[self.nrows+1]`.
1014 The returned tensor is non-empty, and is sorted in ascending order.
1015 `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
1016 `self.values.shape[0]`.
1018 #### Example:
1020 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1021 >>> print(rt.row_splits) # indices of row splits in rt.values
1022 tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64)
1024 """
1025 return self._row_partition.row_splits()
1027 @property
1028 def uniform_row_length(self):
1029 """The length of each row in this ragged tensor, or None if rows are ragged.
1031 >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
1032 >>> print(rt1.uniform_row_length) # rows are ragged.
1033 None
1035 >>> rt2 = tf.RaggedTensor.from_uniform_row_length(
1036 ... values=rt1, uniform_row_length=2)
1037 >>> print(rt2)
1038 <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
1039 >>> print(rt2.uniform_row_length) # rows are not ragged (all have size 2).
1040 tf.Tensor(2, shape=(), dtype=int64)
1042 A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged)
1043 if it can be determined statically (at graph construction time) that the
1044 rows all have the same length.
1046 Returns:
1047 A scalar integer `Tensor`, specifying the length of every row in this
1048 ragged tensor (for ragged tensors whose rows are uniform); or `None`
1049 (for ragged tensors whose rows are ragged).
1050 """
1051 return self._row_partition.uniform_row_length()
1053 @property
1054 def flat_values(self):
1055 """The innermost `values` tensor for this ragged tensor.
1057 Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is
1058 `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`.
1060 Conceptually, `flat_values` is the tensor formed by flattening the
1061 outermost dimension and all of the ragged dimensions into a single
1062 dimension.
1064 `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]`
1065 (where `nvals` is the number of items in the flattened dimensions).
1067 Returns:
1068 A `Tensor`.
1070 #### Example:
1072 >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
1073 >>> print(rt.flat_values)
1074 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1076 """
1077 rt_values = self.values
1078 while isinstance(rt_values, RaggedTensor):
1079 rt_values = rt_values.values
1080 return rt_values
1082 @property
1083 def nested_row_splits(self):
1084 """A tuple containing the row_splits for all ragged dimensions.
1086 `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for
1087 all ragged dimensions in `rt`, ordered from outermost to innermost. In
1088 particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where:
1090 * `value_splits = ()` if `rt.values` is a `Tensor`.
1091 * `value_splits = rt.values.nested_row_splits` otherwise.
1093 Returns:
1094 A `tuple` of 1-D integer `Tensor`s.
1096 #### Example:
1098 >>> rt = tf.ragged.constant(
1099 ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1100 >>> for i, splits in enumerate(rt.nested_row_splits):
1101 ... print('Splits for dimension %d: %s' % (i+1, splits.numpy()))
1102 Splits for dimension 1: [0 3]
1103 Splits for dimension 2: [0 3 3 5]
1104 Splits for dimension 3: [0 4 4 7 8 8]
1106 """
1107 rt_nested_splits = [self.row_splits]
1108 rt_values = self.values
1109 while isinstance(rt_values, RaggedTensor):
1110 rt_nested_splits.append(rt_values.row_splits)
1111 rt_values = rt_values.values
1112 return tuple(rt_nested_splits)
1114 def value_rowids(self, name=None):
1115 """Returns the row indices for the `values` in this ragged tensor.
1117 `rt.value_rowids()` corresponds one-to-one with the outermost dimension of
1118 `rt.values`, and specifies the row containing each value. In particular,
1119 the row `rt[row]` consists of the values `rt.values[j]` where
1120 `rt.value_rowids()[j] == row`.
1122 Args:
1123 name: A name prefix for the returned tensor (optional).
1125 Returns:
1126 A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
1127 The returned tensor is nonnegative, and is sorted in ascending order.
1129 #### Example:
1131 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1132 >>> print(rt.values)
1133 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1134 >>> print(rt.value_rowids()) # corresponds 1:1 with rt.values
1135 tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64)
1137 """
1138 with ops.name_scope(name, "RaggedValueRowIds", [self]):
1139 return self._row_partition.value_rowids()
1141 def nested_value_rowids(self, name=None):
1142 """Returns a tuple containing the value_rowids for all ragged dimensions.
1144 `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors
1145 for
1146 all ragged dimensions in `rt`, ordered from outermost to innermost. In
1147 particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids`
1148 where:
1150 * `value_ids = ()` if `rt.values` is a `Tensor`.
1151 * `value_ids = rt.values.nested_value_rowids` otherwise.
1153 Args:
1154 name: A name prefix for the returned tensors (optional).
1156 Returns:
1157 A `tuple` of 1-D integer `Tensor`s.
1159 #### Example:
1161 >>> rt = tf.ragged.constant(
1162 ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1163 >>> for i, ids in enumerate(rt.nested_value_rowids()):
1164 ... print('row ids for dimension %d: %s' % (i+1, ids.numpy()))
1165 row ids for dimension 1: [0 0 0]
1166 row ids for dimension 2: [0 0 0 2 2]
1167 row ids for dimension 3: [0 0 0 0 2 2 2 3]
1169 """
1170 with ops.name_scope(name, "RaggedNestedValueRowIds", [self]):
1171 rt_nested_ids = [self.value_rowids()]
1172 rt_values = self.values
1173 while isinstance(rt_values, RaggedTensor):
1174 rt_nested_ids.append(rt_values.value_rowids())
1175 rt_values = rt_values.values
1176 return tuple(rt_nested_ids)
1178 def nrows(self, out_type=None, name=None):
1179 """Returns the number of rows in this ragged tensor.
1181 I.e., the size of the outermost dimension of the tensor.
1183 Args:
1184 out_type: `dtype` for the returned tensor. Defaults to
1185 `self.row_splits.dtype`.
1186 name: A name prefix for the returned tensor (optional).
1188 Returns:
1189 A scalar `Tensor` with dtype `out_type`.
1191 #### Example:
1193 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1194 >>> print(rt.nrows()) # rt has 5 rows.
1195 tf.Tensor(5, shape=(), dtype=int64)
1197 """
1198 with ops.name_scope(name, "RaggedNRows", [self]):
1199 if out_type is None:
1200 return self._row_partition.nrows()
1201 else:
1202 return math_ops.cast(self._row_partition.nrows(), dtype=out_type)
1204 def row_starts(self, name=None):
1205 """Returns the start indices for rows in this ragged tensor.
1207 These indices specify where the values for each row begin in
1208 `self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`.
1210 Args:
1211 name: A name prefix for the returned tensor (optional).
1213 Returns:
1214 A 1-D integer Tensor with shape `[nrows]`.
1215 The returned tensor is nonnegative, and is sorted in ascending order.
1217 #### Example:
1219 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1220 >>> print(rt.values)
1221 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1222 >>> print(rt.row_starts()) # indices of row starts in rt.values
1223 tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)
1225 """
1226 with ops.name_scope(name, "RaggedRowStarts", [self]):
1227 return self._row_partition.row_starts()
1229 def row_limits(self, name=None):
1230 """Returns the limit indices for rows in this ragged tensor.
1232 These indices specify where the values for each row end in
1233 `self.values`. `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`.
1235 Args:
1236 name: A name prefix for the returned tensor (optional).
1238 Returns:
1239 A 1-D integer Tensor with shape `[nrows]`.
1240 The returned tensor is nonnegative, and is sorted in ascending order.
1242 #### Example:
1244 >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1245 >>> print(rt.values)
1246 tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1247 >>> print(rt.row_limits()) # indices of row limits in rt.values
1248 tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64)
1250 """
1251 with ops.name_scope(name, "RaggedRowLimits", [self]):
1252 return self._row_partition.row_limits()
1254 def row_lengths(self, axis=1, name=None):
1255 """Returns the lengths of the rows in this ragged tensor.
1257 `rt.row_lengths()[i]` indicates the number of values in the
1258 `i`th row of `rt`.
1260 Args:
1261 axis: An integer constant indicating the axis whose row lengths should be
1262 returned.
1263 name: A name prefix for the returned tensor (optional).
1265 Returns:
1266 A potentially ragged integer Tensor with shape `self.shape[:axis]`.
1268 Raises:
1269 ValueError: If `axis` is out of bounds.
1271 #### Example:
1273 >>> rt = tf.ragged.constant(
1274 ... [[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []])
1275 >>> print(rt.row_lengths()) # lengths of rows in rt
1276 tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64)
1277 >>> print(rt.row_lengths(axis=2)) # lengths of axis=2 rows.
1278 <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]>
1280 """
1281 if axis == 0:
1282 return self._row_partition.nrows()
1284 if axis == 1:
1285 return self._row_partition.row_lengths()
1287 with ops.name_scope(name, "RaggedRowLengths", [self]):
1288 axis = array_ops.get_positive_axis(
1289 axis, self.shape.rank, ndims_name="rank(self)")
1290 if axis == 0:
1291 return self.nrows()
1292 elif axis == 1:
1293 splits = self.row_splits
1294 return splits[1:] - splits[:-1]
1295 elif isinstance(self.values, RaggedTensor):
1296 return self.with_values(self.values.row_lengths(axis - 1))
1297 else:
1298 shape = array_ops.shape(self.values, out_type=self._row_partition.dtype)
1299 return self.with_values(
1300 array_ops.ones(shape[:axis - 1], self._row_partition.dtype) *
1301 shape[axis - 1])
1303 def nested_row_lengths(self, name=None):
1304 """Returns a tuple containing the row_lengths for all ragged dimensions.
1306 `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors
1307 for all ragged dimensions in `rt`, ordered from outermost to innermost.
1309 Args:
1310 name: A name prefix for the returned tensors (optional).
1312 Returns:
1313 A `tuple` of 1-D integer `Tensors`. The length of the tuple is equal to
1314 `self.ragged_rank`.
1315 """
1316 with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
1317 rt_nested_row_lengths = []
1318 rt = self
1319 while isinstance(rt, RaggedTensor):
1320 rt_nested_row_lengths.append(rt.row_lengths())
1321 rt = rt.values
1322 return tuple(rt_nested_row_lengths)
  def bounding_shape(self, axis=None, name=None, out_type=None):
    """Returns the tight bounding box shape for this `RaggedTensor`.

    Args:
      axis: An integer scalar or vector indicating which axes to return the
        bounding box for.  If not specified, then the full bounding box is
        returned.
      name: A name prefix for the returned tensor (optional).
      out_type: `dtype` for the returned tensor.  Defaults to
        `self.row_splits.dtype`.

    Returns:
      An integer `Tensor` (`dtype=self.row_splits.dtype`).  If `axis` is not
      specified, then `output` is a vector with
      `output.shape=[self.shape.ndims]`.  If `axis` is a scalar, then the
      `output` is a scalar.  If `axis` is a vector, then `output` is a
      vector, where `output[i]` is the bounding size for dimension `axis[i]`.

    #### Example:

    >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]])
    >>> rt.bounding_shape().numpy()
    array([5, 4])

    """
    # Resolve the output dtype once; all shape tensors below use it.
    if out_type is None:
      out_type = self._row_partition.dtype
    else:
      out_type = dtypes.as_dtype(out_type)
    with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
      nested_splits = self.nested_row_splits
      rt_flat_values = self.flat_values

      # Optimized special cases for when axis=0 or axis=1:
      if isinstance(axis, int):
        if axis == 0:
          # Number of rows = len(outermost row_splits) - 1.
          return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1
        elif axis == 1:
          # Longest row; the max with 0 handles a tensor with zero rows,
          # where reduce_max over an empty tensor would not give 0.
          result = math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)
          if out_type != self._row_partition.dtype:
            result = math_ops.cast(result, out_type)
          return result

      splits_shape = array_ops.shape(self.row_splits, out_type=out_type)
      flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type)

      # One bounding size per ragged dimension: the outermost row count,
      # followed by the maximum row length at each nesting level (again
      # clamped to 0 for empty split tensors).
      ragged_dimensions = [splits_shape[0] - 1] + [
          math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
          for splits in nested_splits
      ]
      # The remaining (uniform) dimensions come straight from flat_values.
      inner_dimensions = flat_values_shape[1:]

      # row_lengths() above is in the partition dtype, so the ragged sizes
      # may need a cast to the requested out_type.
      if out_type != self._row_partition.dtype:
        ragged_dimensions = [
            math_ops.cast(d, out_type) for d in ragged_dimensions
        ]
      bbox = array_ops.concat(
          [array_ops_stack.stack(ragged_dimensions), inner_dimensions], axis=0)
      # For a non-int axis (None, Tensor, or list), gather the requested
      # entries from the full bounding box.
      return bbox if axis is None else array_ops.gather(bbox, axis)
1384 #=============================================================================
1385 # Transformation
1386 #=============================================================================
1388 def with_values(self, new_values):
1389 """Returns a copy of `self` with `values` replaced by `new_value`.
1391 Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1392 `self.cached_value_rowids` if they have values.
1394 Args:
1395 new_values: Potentially ragged tensor to use as the `values` for the
1396 returned `RaggedTensor`. Must have `rank > 0`, and must have the same
1397 number of rows as `self.values`.
1399 Returns:
1400 A `RaggedTensor`. `result.rank = 1 + new_values.rank`.
1401 `result.ragged_rank = 1 + new_values.ragged_rank`
1402 """
1403 new_values = _convert_to_ragged_tensor_values(new_values)
1404 new_values.shape.with_rank_at_least(1)
1405 self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
1406 if (isinstance(new_values, RaggedTensor) and
1407 self._row_partition.dtype != new_values.row_splits.dtype):
1408 if not ragged_config.auto_cast_partition_dtype():
1409 raise ValueError("self and new_values have mismatched row_splits "
1410 "dtypes; use RaggedTensor.with_row_splits_dtype() to "
1411 "convert them to compatible dtypes.")
1412 new_values = new_values.with_row_splits_dtype(dtypes.int64)
1413 return self.with_row_splits_dtype(dtypes.int64).with_values(new_values)
1414 return RaggedTensor(
1415 values=new_values, row_partition=self._row_partition, internal=True)
1417 def with_flat_values(self, new_values):
1418 """Returns a copy of `self` with `flat_values` replaced by `new_value`.
1420 Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1421 `self.cached_value_rowids` if they have values.
1423 Args:
1424 new_values: Potentially ragged tensor that should replace
1425 `self.flat_values`. Must have `rank > 0`, and must have the same number
1426 of rows as `self.flat_values`.
1428 Returns:
1429 A `RaggedTensor`.
1430 `result.rank = self.ragged_rank + new_values.rank`.
1431 `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`.
1432 """
1433 if isinstance(self._values, RaggedTensor):
1434 return self.with_values(self.values.with_flat_values(new_values))
1435 else:
1436 new_values = _convert_to_ragged_tensor_values(new_values)
1437 return self.with_values(new_values)
1439 def with_row_splits_dtype(self, dtype):
1440 """Returns a copy of this RaggedTensor with the given `row_splits` dtype.
1442 For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
1443 nested `RaggedTensor` objects are cast to the given dtype.
1445 Args:
1446 dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`.
1448 Returns:
1449 A copy of this RaggedTensor, with the `row_splits` cast to the given
1450 type.
1451 """
1452 dtype = dtypes.as_dtype(dtype)
1453 if dtype not in (dtypes.int32, dtypes.int64):
1454 raise ValueError(f"Argument `row_splits` dtype must be int32 or int64. "
1455 f"Received {dtype}.")
1456 if self._row_partition.dtype == dtype:
1457 return self
1458 current_values = self._values
1459 if isinstance(current_values, RaggedTensor):
1460 return RaggedTensor(
1461 values=current_values.with_row_splits_dtype(dtype),
1462 row_partition=self._row_partition.with_dtype(dtype),
1463 internal=True)
1464 else:
1465 return RaggedTensor(
1466 values=current_values,
1467 row_partition=self._row_partition.with_dtype(dtype),
1468 internal=True)
1470 def merge_dims(self, outer_axis, inner_axis):
1471 """Merges outer_axis...inner_axis into a single dimension.
1473 Returns a copy of this RaggedTensor with the specified range of dimensions
1474 flattened into a single dimension, with elements in row-major order.
1476 #### Examples:
1478 >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]])
1479 >>> print(rt.merge_dims(0, 1))
1480 <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]>
1481 >>> print(rt.merge_dims(1, 2))
1482 <tf.RaggedTensor [[1, 2, 3], [4, 5, 6]]>
1483 >>> print(rt.merge_dims(0, 2))
1484 tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
1486 To mimic the behavior of `np.flatten` (which flattens all dimensions), use
1487 `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which
1488 flattens all dimensions except the outermost batch dimension), use
1489 `rt.merge_dims(1, -1)`.
1491 Args:
1492 outer_axis: `int`: The first dimension in the range of dimensions to
1493 merge. May be negative if `self.shape.rank` is statically known.
1494 inner_axis: `int`: The last dimension in the range of dimensions to merge.
1495 May be negative if `self.shape.rank` is statically known.
1497 Returns:
1498 A copy of this tensor, with the specified dimensions merged into a
1499 single dimension. The shape of the returned tensor will be
1500 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1501 is the total number of slices in the merged dimensions.
1502 """
1503 outer_axis = array_ops.get_positive_axis(
1504 outer_axis,
1505 self.shape.rank,
1506 axis_name="outer_axis",
1507 ndims_name="rank(self)")
1508 inner_axis = array_ops.get_positive_axis(
1509 inner_axis,
1510 self.shape.rank,
1511 axis_name="inner_axis",
1512 ndims_name="rank(self)")
1513 if not outer_axis <= inner_axis:
1514 raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or "
1515 f"equal to inner_axis ({inner_axis}).")
1516 return merge_dims(self, outer_axis, inner_axis)
  def _set_shape(self, shape):
    """Updates the static shape of `self` to be `shape`.

    * If a dimension of `shape` has known rank, and is encoded via
      partitioning, then this will update the corresponding partition to
      define `_uniform_row_length` and `nrows`.
    * If a dimension of `shape` has a known rank, and is encoded as one
      of the `flat_values` dimensions, then `flat_values.set_shape()` will
      be used to update its shape.

    Warning: Using this method to assert an incorrect shape for a
    RaggedTensor (i.e., one that's not consistent with its actual shape) can
    cause segmentation faults and very difficult-to-diagnose behavior.  Only
    use this method if you are certain that the shape is correct.

    Args:
      shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.
    """
    # TODO(edloper): Refactor this to not directly access private members
    # of RowPartition.
    # pylint: disable=protected-access

    shape = tensor_shape.as_shape(shape)
    if shape.rank is None:
      return  # Nothing to do.

    shape = shape.as_list()

    # Outermost dimension: nrows rows implies nrows + 1 split entries.
    if shape[0] is not None:
      self._row_partition._row_splits.set_shape(shape[0] + 1)

    # Partitioned dimensions: shape[i + 1] describes the i-th ragged
    # dimension (shape[0] was the outermost dimension, handled above).
    dtype = self._row_partition.dtype
    for i, partition in enumerate(self._nested_row_partitions):
      size = shape[i + 1]
      if size is not None:
        if partition._uniform_row_length is not None:
          old_row_length = tensor_util.constant_value(
              partition._uniform_row_length)
          if old_row_length is not None:
            if size == old_row_length:
              continue  # already have shape info for this axis.
            else:
              # A known, conflicting uniform length is an error; see the
              # warning in the docstring about asserting incorrect shapes.
              raise ValueError(f"Inconsistent size for axis {i + 1}: "
                               f"{old_row_length} vs. {size}.")
        # Record the (now-known) uniform length, and derive nrows from the
        # existing splits tensor if it wasn't already set.
        partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
        if partition._nrows is None:
          partition._nrows = array_ops.size(
              partition._row_splits, out_type=dtype) - 1

    # self.flat_values could be a CompositeTensor and doesn't have set_shape.
    if hasattr(self.flat_values, "set_shape"):
      # Inner dimensions: leading None is the (unknown) flattened row count.
      flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
      self.flat_values.set_shape(flat_shape)
1575 #=============================================================================
1576 # Tensor Type Conversions
1577 #=============================================================================
1579 @classmethod
1580 @dispatch.add_dispatch_support
1581 def from_tensor(cls,
1582 tensor,
1583 lengths=None,
1584 padding=None,
1585 ragged_rank=1,
1586 name=None,
1587 row_splits_dtype=dtypes.int64):
1588 """Converts a `tf.Tensor` into a `RaggedTensor`.
1590 The set of absent/default values may be specified using a vector of lengths
1591 or a padding value (but not both). If `lengths` is specified, then the
1592 output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If
1593 'lengths' is a list of lists or tuple of lists, those lists will be used
1594 as nested row lengths. If `padding` is specified, then any row *suffix*
1595 consisting entirely of `padding` will be excluded from the returned
1596 `RaggedTensor`. If neither `lengths` nor `padding` is specified, then the
1597 returned `RaggedTensor` will have no absent/default values.
1599 Examples:
1601 >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
1602 >>> tf.RaggedTensor.from_tensor(dt)
1603 <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]>
1604 >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3])
1605 <tf.RaggedTensor [[5], [], [6, 0, 0]]>
1607 >>> tf.RaggedTensor.from_tensor(dt, padding=0)
1608 <tf.RaggedTensor [[5, 7], [0, 3], [6]]>
1610 >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]],
1611 ... [[0, 0], [3, 0], [0, 0]],
1612 ... [[6, 0], [0, 0], [0, 0]]])
1613 >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1]))
1614 <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]>
1616 Args:
1617 tensor: The `Tensor` to convert. Must have rank `ragged_rank + 1` or
1618 higher.
1619 lengths: An optional set of row lengths, specified using a 1-D integer
1620 `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows
1621 in `tensor`). If specified, then `output[row]` will contain
1622 `tensor[row][:lengths[row]]`. Negative lengths are treated as zero. You
1623 may optionally pass a list or tuple of lengths to this argument, which
1624 will be used as nested row lengths to construct a ragged tensor with
1625 multiple ragged dimensions.
1626 padding: An optional padding value. If specified, then any row suffix
1627 consisting entirely of `padding` will be excluded from the returned
1628 RaggedTensor. `padding` is a `Tensor` with the same dtype as `tensor`
1629 and with `shape=tensor.shape[ragged_rank + 1:]`.
1630 ragged_rank: Integer specifying the ragged rank for the returned
1631 `RaggedTensor`. Must be greater than zero.
1632 name: A name prefix for the returned tensors (optional).
1633 row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
1634 tensor. One of `tf.int32` or `tf.int64`.
1636 Returns:
1637 A `RaggedTensor` with the specified `ragged_rank`. The shape of the
1638 returned ragged tensor is compatible with the shape of `tensor`.
1640 Raises:
1641 ValueError: If both `lengths` and `padding` are specified.
1642 ValueError: If the rank of `tensor` is 0 or 1.
1643 """
1644 row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
1645 if lengths is not None and padding is not None:
1646 raise ValueError("Specify argument `lengths` or `padding`, but not both.")
1647 if not isinstance(ragged_rank, int):
1648 raise TypeError(f"Argument `ragged_rank` must be an int. "
1649 f"Received {ragged_rank}.")
1650 if ragged_rank <= 0:
1651 raise ValueError(f"Argument `ragged_rank` must be greater than 0. "
1652 f"Received {ragged_rank}.")
1654 with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
1655 tensor = ops.convert_to_tensor(tensor, name="tensor")
1656 if tensor.shape.rank is not None and tensor.shape.rank < 2:
1657 raise ValueError(f"The rank of a RaggedTensor must be greater than 1, "
1658 f"i.e., a list of scalars won't have ragged "
1659 f"dimensions. Received argument `tensor` with rank "
1660 f"{tensor.shape.rank}.")
1661 tensor.shape.with_rank_at_least(ragged_rank + 1)
1662 input_shape = array_ops.shape(tensor, out_type=row_splits_dtype)
1663 ncols = input_shape[1]
1665 # Handle nested row lengths.
1666 if (lengths is not None and isinstance(lengths, (list, tuple)) and
1667 len(lengths) and not isinstance(lengths[0], (int, float))):
1668 if ragged_rank not in (1, len(lengths)):
1669 # Note: we accept `ragged_rank=1` here because it's the default value;
1670 # i.e., if the user passes in a tuple of lengths, but doesn't specify
1671 # ragged_rank, then we should use that tuple to determine ragged_rank.
1672 # We only want to complain if they pass in an explicit ragged_rank
1673 # that doesn't match len(lengths).
1674 raise ValueError(f"If Argument `lengths` is a tuple of row_lengths, "
1675 f"argument `ragged_rank` must be "
1676 f"len(lengths): {len(lengths)}. Received "
1677 f"ragged_rank: {ragged_rank}.")
1678 # Rather than reconstructing the tensor mask directly, we can
1679 # recreate it as a boolean RaggedTensor, then densify that and use
1680 # that as the mask to clear out the unused data in the passed tensor.
1681 tensor.shape.with_rank_at_least(len(lengths) + 1)
1682 num_tokens = math_ops.reduce_sum(lengths[-1])
1683 ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool)
1684 ragged_mask = cls.from_nested_row_lengths(
1685 ones_mask, lengths, validate=False)
1686 dense_ragged_mask = ragged_mask.to_tensor(default_value=False)
1687 masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask)
1688 return cls.from_nested_row_lengths(masked_data, lengths, validate=False)
1690 # Handle ragged_rank>1 via recursion:
1691 # If the output should have multiple ragged dimensions, then first
1692 # flatten the tensor to eliminate all but the last ragged dimension,
1693 # and recursively convert that flattened tensor. Then add on the splits
1694 # for the dimensions that we flattened out.
1695 if ragged_rank > 1:
1696 if tensor.shape.is_fully_defined():
1697 input_shape = tensor.shape.as_list()
1698 # The total number of elements in each dimension. E.g., if
1699 # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
1700 dim_size = np.cumprod(input_shape)
1701 new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
1702 else:
1703 dim_size = math_ops.cumprod(input_shape)
1704 new_shape = array_ops.concat(
1705 [[dim_size[ragged_rank - 1]], input_shape[ragged_rank:]], axis=0)
1706 flattened = array_ops.reshape(tensor, new_shape)
1707 result = cls.from_tensor(
1708 flattened, lengths, padding, row_splits_dtype=row_splits_dtype)
1710 for axis in range(ragged_rank - 1, 0, -1):
1711 dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value
1712 if dim_len is None:
1713 dim_len = input_shape[axis]
1714 else:
1715 dim_len = constant_op.constant(dim_len, row_splits_dtype)
1716 result = RaggedTensor.from_uniform_row_length(
1717 values=result,
1718 uniform_row_length=dim_len,
1719 nrows=dim_size[axis - 1],
1720 validate=False)
1721 return result
1723 # If padding was specified, then use it to find row lengths.
1724 if padding is not None:
1725 padding = ops.convert_to_tensor(
1726 padding, name="padding", dtype=tensor.dtype)
1727 padding.shape.assert_is_compatible_with(tensor.shape[2:])
1729 # Find places where the padding is equal to the tensor. (This will
1730 # broadcast `padding` across the outermost 2 dimensions of `tensor`,
1731 # so `has_default_value.shape = tensor.shape`.)
1732 has_default_value = math_ops.equal(padding, tensor)
1734 # If the padding isn't a scalar, then require that all values in the
1735 # padding match each item in the tensor. After this block of code,
1736 # `has_default.shape = tensor.shape[:2]`. (Unfortunately, we can't just
1737 # use reduce_all for both cases, because when you pass an empty `axis`
1738 # list to reduce_all, it reduces all axes; but we want it to reduce no
1739 # axes -- i.e., to be a no-op.)
1740 tensor_rank = array_ops.rank(tensor)
1741 reduce_axis = math_ops.range(2, tensor_rank)
1742 has_default = cond.cond(
1743 tensor_rank > 2,
1744 lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis),
1745 lambda: has_default_value)
1746 has_default.set_shape(tensor_shape.TensorShape([None, None]))
1747 has_default.set_shape(tensor.shape[:2])
1749 # Use has_default to find the length of each row: for each
1750 # non-default item in a row, calculate the length that the row needs to
1751 # have to include that item; and then take the max of those values
1752 # (across each row).
1753 has_nondefault = math_ops.logical_not(has_default)
1754 has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype)
1755 length_for_nondefault_value = (
1756 has_nondefault *
1757 array_ops.expand_dims(math_ops.range(1, ncols + 1), 0))
1758 lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1)
1760 if lengths is not None:
1761 # If we have lengths (either directly supplied, or computed from
1762 # paddings), then use those to construct splits; and then use masking
1763 # to get the corresponding values.
1764 lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
1765 row_splits_dtype)
1766 lengths.shape.assert_has_rank(1)
1767 lengths = math_ops.minimum(lengths, ncols)
1768 lengths = math_ops.maximum(lengths, 0)
1769 limits = math_ops.cumsum(lengths)
1770 splits = array_ops.concat(
1771 [array_ops.zeros([1], row_splits_dtype), limits], axis=0)
1772 mask = array_ops.sequence_mask(lengths, maxlen=ncols)
1773 values = array_ops.boolean_mask(tensor, mask)
1774 return cls.from_row_splits(values, splits, validate=False)
1776 # If neither padding nor lengths were specified, then create a splits
1777 # vector that contains no default values, and reshape the input tensor
1778 # to form the values for the RaggedTensor.
1779 values_shape = array_ops.concat(
1780 [[input_shape[0] * input_shape[1]], input_shape[2:]], axis=0)
1781 values = array_ops.reshape(tensor, values_shape)
1782 const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
1783 const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
1784 if const_nrows is not None:
1785 nrows = constant_op.constant(const_nrows, row_splits_dtype)
1786 else:
1787 nrows = input_shape[0]
1788 if const_ncols is not None:
1789 ncols = constant_op.constant(const_ncols, row_splits_dtype)
1790 else:
1791 ncols = input_shape[1]
1792 return RaggedTensor.from_uniform_row_length(
1793 values=values, uniform_row_length=ncols, nrows=nrows, validate=False)
  def to_tensor(self, default_value=None, name=None, shape=None):
    """Converts this `RaggedTensor` into a `tf.Tensor`.

    If `shape` is specified, then the result is padded and/or truncated to
    the specified shape.

    Examples:

    >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]])
    >>> print(rt.to_tensor())
    tf.Tensor(
        [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32)
    >>> print(rt.to_tensor(shape=[5, 2]))
    tf.Tensor(
        [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32)

    Args:
      default_value: Value to set for indices not specified in `self`. Defaults
        to zero. `default_value` must be broadcastable to
        `self.shape[self.ragged_rank + 1:]`.
      name: A name prefix for the returned tensors (optional).
      shape: The shape of the resulting dense tensor. In particular,
        `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or
        `self.bounding_shape(i)` (otherwise). `shape.rank` must be `None` or
        equal to `self.rank`.

    Returns:
      A `Tensor` with shape `ragged.bounding_shape(self)` and the
      values specified by the non-empty values in `self`. Empty values are
      assigned `default_value`.
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)
      # Decompose the ragged structure into per-dimension (partition type,
      # partition tensor) pairs, in the form the conversion kernel expects.
      type_tensor_pairs = _get_row_partition_type_tensor_pairs(self)
      row_partition_types = [x[0] for x in type_tensor_pairs]
      row_partition_tensors = [x[1] for x in type_tensor_pairs]
      if default_value is None:
        default_value = array_ops.zeros((), self.dtype)

      # If `shape` is a list/tuple that mixes Python ints with Tensors, stack
      # it into a single rank-1 Tensor so `_shape_as_tensor` can handle it.
      if (isinstance(shape, (list, tuple)) and
          any(isinstance(v, ops.Tensor) for v in shape) and
          all(isinstance(v, (int, ops.Tensor)) for v in shape)):
        shape = array_ops_stack.stack(shape)

      shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
      tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
          shape=shape_tensor,
          values=self.flat_values,
          default_value=default_value,
          row_partition_types=row_partition_types,
          row_partition_tensors=row_partition_tensors)

      ragged_shape = self.shape

      if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
        # Merge self.shape and shape, favoring the second one as it takes
        # into account potential padding added to the output.
        shape = tensor_shape.as_shape(shape)
        if shape.rank is None:
          output_shape = ragged_shape
        else:
          # At this point we can assume that shape.rank == ragged_shape.rank
          # because otherwise it would have failed earlier.
          output_shape = [
              s1 if s1 is not None else s2
              for (s1, s2) in zip(shape.as_list(), ragged_shape.as_list())
          ]
        tensor.set_shape(output_shape)

      return tensor
  @classmethod
  @dispatch.add_dispatch_support
  def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
    """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.

    Each row of the `output` `RaggedTensor` will contain the explicit values
    from the same row in `st_input`. `st_input` must be ragged-right. If it
    is not ragged-right, then an error will be generated.

    Example:

    >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]]
    >>> st = tf.sparse.SparseTensor(indices=indices,
    ...                             values=[1, 2, 3, 4, 5],
    ...                             dense_shape=[4, 3])
    >>> tf.RaggedTensor.from_sparse(st).to_list()
    [[1, 2, 3], [4], [], [5]]

    Currently, only two-dimensional `SparseTensors` are supported.

    Args:
      st_input: The sparse tensor to convert.  Must have rank 2.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor. One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the same values as `st_input`.
      `output.ragged_rank = rank(st_input) - 1`.
      `output.shape = [st_input.dense_shape[0], None]`.
    Raises:
      ValueError: If the number of dimensions in `st_input` is not known
        statically, or is not two.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if not sparse_tensor.is_sparse(st_input):
      raise TypeError(f"Argument `st_input` must be of type SparseTensor, but "
                      f"is of type {type(st_input).__name__}.")
    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          st_input, name="st_input")

      # Try to determine the static rank from `dense_shape` and from
      # `indices`; either one is enough to validate that the input is 2-D.
      if st_input.dense_shape.shape.ndims is None:
        static_rank_from_dense_shape = None
      else:
        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value

      if st_input.indices.shape.ndims is None:
        static_rank_from_indices = None
      else:
        static_rank_from_indices = st_input.indices.shape.dims[1].value

      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
        raise ValueError("rank(st_input) must be 2.")

      with ops.control_dependencies(
          _assert_sparse_indices_are_ragged_right(st_input.indices)):
        # Treat sparse row indices as segment ids to generate a splits tensor
        # that we can pair with the sparse tensor values. (Ignore sparse column
        # indices.)
        segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
        num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
        return cls.from_value_rowids(
            st_input.values, segment_ids, num_segments, validate=False)
1933 def to_sparse(self, name=None):
1934 """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`.
1936 Example:
1938 >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
1939 >>> print(rt.to_sparse())
1940 SparseTensor(indices=tf.Tensor(
1941 [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]],
1942 shape=(6, 2), dtype=int64),
1943 values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32),
1944 dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64))
1946 Args:
1947 name: A name prefix for the returned tensors (optional).
1949 Returns:
1950 A SparseTensor with the same values as `self`.
1951 """
1952 with ops.name_scope(name, "RaggedToSparse", [self]):
1953 result = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
1954 self.nested_row_splits, self.flat_values, name=name)
1955 return sparse_tensor.SparseTensor(result.sparse_indices,
1956 result.sparse_values,
1957 result.sparse_dense_shape)
  @classmethod
  def _from_variant(cls,
                    variant,
                    dtype,
                    output_ragged_rank,
                    input_ragged_rank=None,
                    row_splits_dtype=dtypes.int64,
                    name=None):
    """Converts a `variant` Tensor into a `RaggedTensor`.

    The input `variant` could be a scalar, meaning it encodes a single
    `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it could
    have an arbitrary rank, in which case each element is decoded into a
    `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then
    stacked according to the input shape to output a single `RaggedTensor`
    with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not
    provided, it is inferred dynamically as `output_ragged_rank` -
    `rank(variant)`. If `input_ragged_rank` is provided, the following must be
    true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`.

    Example:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> et = rt._to_variant()
    >>> stacked_et = tf.stack([et, et])
    >>> tf.RaggedTensor._from_variant(  # scalar input.
    ...     et, dtype=tf.int32, output_ragged_rank=1).to_list()
    [[0], [1, 2]]
    >>> tf.RaggedTensor._from_variant(  # batched input.
    ...     stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list()
    [[[0], [1, 2]], [[0], [1, 2]]]

    Args:
      variant: A `variant` Tensor representing an encoded (possibly
        nested-batched) `RaggedTensor`.
      dtype: The dtype of the encoded `RaggedTensor`.
      output_ragged_rank: The expected ragged rank of the output `RaggedTensor`.
      input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is
        optional and inferred dynamically if not provided.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
        of `tf.int32` or `tf.int64`.
      name: A name prefix for the returned tensors (optional).

    Returns:
      A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`.

    Raises:
      ValueError: If the input rank is known, `input_ragged_rank` is provided
        and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does
        not hold.
    """
    variant = ops.convert_to_tensor(
        variant, name="variant", dtype=dtypes.variant)
    # Statically validate the rank relationship when both ranks are known.
    if (variant.shape.ndims is not None and input_ragged_rank is not None and
        output_ragged_rank != input_ragged_rank + variant.shape.ndims):
      raise ValueError(
          f"Argument `output_ragged_rank` ({output_ragged_rank}) must be equal "
          f"to `input_ragged_rank` + `variant.shape.ndims` "
          f"({input_ragged_rank} + {variant.shape.ndims}).")
    # -1 signals the kernel to infer the input ragged rank dynamically.
    input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank
    with ops.name_scope(
        name, "RaggedFromVariant",
        [variant, dtype, input_ragged_rank, output_ragged_rank]):
      result = gen_ragged_conversion_ops.ragged_tensor_from_variant(
          variant, input_ragged_rank, max(output_ragged_rank, 0), dtype,
          row_splits_dtype, name)
      return cls.from_nested_row_splits(
          result.output_dense_values,
          result.output_nested_splits,
          validate=False)
2030 def _to_variant(self, batched_input=False, name=None):
2031 """Converts this `RaggedTensor` into a `variant` Tensor.
2033 If `batched_input` is `True`, then the `RaggedTensor` is unbatched along the
2034 zero-th dimension, each component `RaggedTensor` is encoded into a scalar
2035 `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor.
2036 If `batched_input` is `False`, then the `RaggedTensor` is encoded as is and
2037 a scalar `variant` Tensor is returned.
2039 Example:
2040 >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]])
2041 >>> rt._to_variant().shape.as_list()
2042 []
2043 >>> rt._to_variant(batched_input=True).shape.as_list()
2044 [3]
2046 Args:
2047 batched_input: If `True`, the `RaggedTensor` is unbatched and converted to
2048 a `variant` vector. Set to `False` by default.
2049 name: A name prefix for the returned tensors (optional).
2051 Returns:
2052 A `variant` Tensor that encodes this `RaggedTensor`.
2053 """
2054 with ops.name_scope(name, "RaggedToVariant", [self, batched_input]):
2055 return gen_ragged_conversion_ops.ragged_tensor_to_variant(
2056 self.nested_row_splits, self.flat_values, batched_input, name)
2058 #=============================================================================
2059 # String Encoding
2060 #=============================================================================
2061 def __repr__(self):
2062 if self._is_eager():
2063 # The np.array2string in _formatter provides a separator argument, but
2064 # doesn't handle recursive calls correctly. The np.printoptions handles
2065 # recursive calls correctly, but doesn't provide a separator argument.
2066 # Combines them together to print elements separated by comma, while
2067 # avoiding the redundant array prefixes and dtypes. For example,
2068 # the value of tf.ragged.constant([[1, 2], [3, 4]]) will look like
2069 #
2070 # [[1, 2],
2071 # [3, 4]]
2072 with np.printoptions(formatter={"all": _formatter}):
2073 value_text = _formatter(self.numpy())
2074 return f"<tf.RaggedTensor {value_text}>"
2075 else:
2076 return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self.values,
2077 self.row_splits)
2079 #=============================================================================
2080 # Eager Execution Mode
2081 #=============================================================================
2083 def numpy(self):
2084 """Returns a numpy `array` with the values for this `RaggedTensor`.
2086 Requires that this `RaggedTensor` was constructed in eager execution mode.
2088 Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and
2089 `rank=1`, where each element is a single row.
2091 #### Examples
2093 In the following example, the value returned by `RaggedTensor.numpy()`
2094 contains three numpy `array` objects: one for each row (with `rank=1` and
2095 `dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`):
2097 >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy()
2098 array([array([1, 2, 3]), array([4, 5])], dtype=object)
2100 Uniform dimensions are encoded using multidimensional numpy `array`s. In
2101 the following example, the value returned by `RaggedTensor.numpy()` contains
2102 a single numpy `array` object, with `rank=2` and `dtype=int64`:
2104 >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy()
2105 array([[1, 2, 3], [4, 5, 6]])
2107 Returns:
2108 A numpy `array`.
2109 """
2110 if not self._is_eager():
2111 raise ValueError("RaggedTensor.numpy() is only supported in eager mode.")
2112 values = self.values.numpy()
2113 splits = self.row_splits.numpy()
2114 rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
2115 if not rows:
2116 return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype)
2117 # Note: if `rows` have ragged lengths, then they will be stored in a
2118 # np.ndarray with dtype=object and rank=1. If they have uniform lengths,
2119 # they will be combined into a single np.ndarray with dtype=row.dtype and
2120 # rank=row.rank+1.
2121 #
2122 # Manually set dtype as numpy now complains when given ragged rows.
2123 has_variable_length_rows = any(len(row) != len(rows[0]) for row in rows)
2124 dtype = np.object_ if has_variable_length_rows else None
2125 return np.array(rows, dtype=dtype)
2127 def to_list(self):
2128 """Returns a nested Python `list` with the values for this `RaggedTensor`.
2130 Requires that `rt` was constructed in eager execution mode.
2132 Returns:
2133 A nested Python `list`.
2134 """
2135 if not isinstance(self.row_splits, ops.EagerTensor):
2136 raise ValueError("to_list can only be used in eager mode.")
2137 row_splits = self.row_splits.numpy().tolist()
2138 values = self.values
2140 if isinstance(values, RaggedTensor):
2141 return [
2142 values[row_splits[i]:row_splits[i + 1]].to_list()
2143 for i in range(len(row_splits) - 1)
2144 ]
2145 else:
2146 # Convert values to a Python list.
2147 if hasattr(values, "numpy"):
2148 values_as_list = values.numpy().tolist()
2149 elif hasattr(values, "to_list"):
2150 values_as_list = values.to_list()
2151 else:
2152 raise ValueError("values must be convertible to a list")
2154 return [
2155 values_as_list[row_splits[i]:row_splits[i + 1]]
2156 for i in range(len(row_splits) - 1)
2157 ]
2159 def _eager_value(self):
2160 """Returns a RaggedTensorValue for self. Requires self._is_eager()=true."""
2161 value = self.flat_values.numpy()
2162 for row_splits in reversed(self.nested_row_splits):
2163 value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy())
2164 return value
2166 def _is_eager(self):
2167 """Returns True if values & row_splits Tensors are all `EagerTensor`s."""
2168 rt = self
2169 while isinstance(rt, RaggedTensor):
2170 if not isinstance(rt.row_splits, ops.EagerTensor):
2171 return False
2172 rt = rt.values
2173 return isinstance(rt, ops.EagerTensor)
2175 #=============================================================================
2176 # Operators
2177 #=============================================================================
2178 # To avoid circular dependencies, we define stub methods for operators here,
2179 # and then override them when the ragged_operators module is imported.
  def _overloaded_operator(name):  # pylint: disable=no-self-argument
    """Returns a stub that raises until `ragged_ops` installs the real op."""

    def stub(*args, **kwargs):
      del args, kwargs
      raise ValueError(
          f"You must import 'tensorflow.python.ops.ragged.ragged_ops' "
          f"before using RaggedTensor.{name}.")

    return stub

  # Indexing and comparison operators.
  __getitem__ = _overloaded_operator("__getitem__")
  __ge__ = _overloaded_operator("__ge__")
  __gt__ = _overloaded_operator("__gt__")
  __le__ = _overloaded_operator("__le__")
  __lt__ = _overloaded_operator("__lt__")
  # Logical (bitwise) operators.
  __and__ = _overloaded_operator("__and__")
  __rand__ = _overloaded_operator("__rand__")
  __invert__ = _overloaded_operator("__invert__")
  __ror__ = _overloaded_operator("__ror__")
  __or__ = _overloaded_operator("__or__")
  __xor__ = _overloaded_operator("__xor__")
  __rxor__ = _overloaded_operator("__rxor__")
  # Arithmetic operators.
  __abs__ = _overloaded_operator("__abs__")
  __add__ = _overloaded_operator("__add__")
  __radd__ = _overloaded_operator("__radd__")
  __div__ = _overloaded_operator("__div__")
  __rdiv__ = _overloaded_operator("__rdiv__")
  __floordiv__ = _overloaded_operator("__floordiv__")
  __rfloordiv__ = _overloaded_operator("__rfloordiv__")
  __mod__ = _overloaded_operator("__mod__")
  __rmod__ = _overloaded_operator("__rmod__")
  __mul__ = _overloaded_operator("__mul__")
  __rmul__ = _overloaded_operator("__rmul__")
  __neg__ = _overloaded_operator("__neg__")
  __pow__ = _overloaded_operator("__pow__")
  __rpow__ = _overloaded_operator("__rpow__")
  __sub__ = _overloaded_operator("__sub__")
  __rsub__ = _overloaded_operator("__rsub__")
  __truediv__ = _overloaded_operator("__truediv__")
  __rtruediv__ = _overloaded_operator("__rtruediv__")
  # Remove the factory so it doesn't linger as a (broken) instance method.
  del _overloaded_operator
2223 #=============================================================================
2224 # Name Scope
2225 #=============================================================================
2227 # This private function is used by ops.name_scope to ensure that all of the
2228 # input tensors for the scope belong to the same graph. Defining this means
2229 # that you may include `RaggedTensor` objects in the name_scope `values`
2230 # list.
2231 def _as_graph_element(self):
2232 """Convert `self` to a graph element."""
2233 values = self.values
2234 while isinstance(values, RaggedTensor):
2235 values = values.values
2236 return values
2238 #=============================================================================
2239 # Composite Tensor
2240 #=============================================================================
  @property
  def _type_spec(self):
    # CompositeTensor hook: the RaggedTensorSpec describing this value.
    return RaggedTensorSpec.from_value(self)
  def _shape_invariant_to_type_spec(self, shape):
    """Returns a RaggedTensorSpec for `self` using `shape` as the invariant."""
    return RaggedTensorSpec(shape, self.dtype, self.ragged_rank,
                            self.row_splits.dtype)
  def consumers(self):
    """Returns the consumers of this RaggedTensor (via `self._consumers()`)."""
    return self._consumers()
  # Gradient hook for composite tensors, implemented in terms of the
  # RaggedTensor's `values` component.
  __composite_gradient__ = (
      composite_tensor_gradient.WithValuesCompositeTensorGradient())
def is_ragged(value):
  """Returns true if `value` is a ragged tensor or ragged tensor value."""
  ragged_classes = (RaggedTensor, ragged_tensor_value.RaggedTensorValue)
  return isinstance(value, ragged_classes)
def match_row_splits_dtypes(*tensors, **kwargs):
  """Return a copy of `tensors` with row_splits all having the same dtype.

  Args:
    *tensors: A list of Tensors or RaggedTensors.
    **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
      where `dtype` is the data type used by row-splits, and `tensors` is the
      converted list of `Tensors` and `RaggedTensors`.

  Returns:
    The converted list of `Tensors` and `RaggedTensors`.
  """
  return_dtype = kwargs.pop("return_dtype", False)
  if kwargs:
    raise ValueError(f"Unexpected keyword args {kwargs}.")

  # Scan the inputs to see which row_splits dtypes are present.
  found_int32 = False
  found_other = False
  for t in tensors:
    if isinstance(t, RaggedTensor):
      if t.row_splits.dtype == dtypes.int32:
        found_int32 = True
      else:
        found_other = True

  if found_int32 and found_other:
    # Mixed dtypes: upcast everything to int64, if auto-casting is enabled.
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")
    dtype = dtypes.int64
    tensors = tuple(
        t.with_row_splits_dtype(dtypes.int64)
        if isinstance(t, RaggedTensor) else t
        for t in tensors)
  elif found_int32:
    dtype = dtypes.int32
  else:
    dtype = dtypes.int64

  return (dtype, tensors) if return_dtype else tensors
2310# ===============================================================================
2311# RaggedTensorSpec
2312# ===============================================================================
2313@tf_export("RaggedTensorSpec")
2314@type_spec_registry.register("tf.RaggedTensorSpec")
2315class RaggedTensorSpec(type_spec.BatchableTypeSpec):
2316 """Type specification for a `tf.RaggedTensor`."""
2318 __slots__ = [
2319 "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype",
2320 "_flat_values_spec"
2321 ]
  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the RaggedTensor.

    This is the dtype of the RaggedTensor's values (set at construction time).

    Examples:

    >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
    >>> tf.type_spec_from_value(rt).dtype
    tf.string

    Returns:
      A `tf.dtypes.DType` of the values in the RaggedTensor.
    """
    return self._dtype
  @property
  def shape(self):
    """The statically known shape of the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None])

    >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None, 2])

    Returns:
      A `tf.TensorShape` containing the statically known shape of the
      RaggedTensor. Ragged dimensions have a size of `None`.
    """
    return self._shape
  @property
  def ragged_rank(self):
    """The number of times the RaggedTensor's flat_values is partitioned.

    Defaults to `shape.ndims - 1` (see `__init__`).

    Examples:

    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
    >>> tf.type_spec_from_value(values).ragged_rank
    1

    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
    >>> tf.type_spec_from_value(rt1).ragged_rank
    2

    Returns:
      A Python `int` indicating the number of times the underlying `flat_values`
      Tensor has been partitioned to add a new dimension.
      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
    """
    return self._ragged_rank
  @property
  def row_splits_dtype(self):
    """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
    >>> tf.type_spec_from_value(rt).row_splits_dtype
    tf.int64

    Returns:
      A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
      of `tf.int32` or `tf.int64`.
    """
    return self._row_splits_dtype
  @property
  def flat_values_spec(self):
    """The `TypeSpec` of the flat_values of RaggedTensor.

    This is only set when the flat_values is a CompositeTensor (experimental;
    see `__init__`).

    Returns:
      - The TypeSpec of flat_values.
      - None when the flat_values is a Tensor.
    """
    return self._flat_values_spec
2407 @property
2408 def value_type(self):
2409 return RaggedTensor if self._ragged_rank > 0 else ops.Tensor
  def __init__(self,
               shape=None,
               dtype=dtypes.float32,
               ragged_rank=None,
               row_splits_dtype=dtypes.int64,
               flat_values_spec=None):
    """Constructs a type specification for a `tf.RaggedTensor`.

    Args:
      shape: The shape of the RaggedTensor, or `None` to allow any shape. If a
        shape is specified, then all ragged dimensions must have size `None`.
      dtype: `tf.DType` of values in the RaggedTensor.
      ragged_rank: Python integer, the number of times the RaggedTensor's
        flat_values is partitioned. Defaults to `shape.ndims - 1`.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
        of `tf.int32` or `tf.int64`.
      flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be
        provided when the flat_values is a CompositeTensor rather than Tensor.
        If both `dtype` and `flat_values_spec` are provided, `dtype` must
        be the same as `flat_values_spec.dtype`. (experimental)
    """
    self._shape = tensor_shape.as_shape(shape)
    self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if flat_values_spec is not None:
      # When a flat_values spec is given, `dtype` may be omitted (it is then
      # inferred from the spec), but must agree with the spec when supplied.
      if dtype is None:
        dtype = flat_values_spec.dtype
      elif dtype != flat_values_spec.dtype:
        raise ValueError("dtype must be the same as flat_values_spec.dtype")
    elif dtype is None:
      raise ValueError(
          "At least one of dtype or flat_values_spec must be provided")
    self._dtype = dtypes.as_dtype(dtype)
    self._flat_values_spec = flat_values_spec

    rank = self._shape.ndims
    if ragged_rank is None:
      # Infer ragged_rank from the shape's rank when not given explicitly.
      if rank is None:
        raise ValueError("Must specify ragged_rank or "
                         "a shape with a known rank.")
      ragged_rank = rank - 1
    self._ragged_rank = ragged_rank
    if not isinstance(self._ragged_rank, int):
      raise TypeError(f"Argument `ragged_rank` must be an int. "
                      f"Received {ragged_rank}.")

    if rank is not None:
      if ragged_rank >= rank:
        raise ValueError(f"Argument `ragged_rank` ({ragged_rank}) must be less "
                         f"than rank ({rank}).")
  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec.

    In addition to the default TypeSpec compatibility rules, a spec with
    `ragged_rank=0` may also be compatible with a raw flat-values value or
    spec (e.g. a plain `Tensor` / `TensorSpec`).
    """
    # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values.
    if self._ragged_rank == 0:
      if self._flat_values_spec is None:
        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
          # Compare as a dense TensorSpec built from this spec's shape/dtype.
          return tensor_spec.TensorSpec(
              self._shape, self._dtype).is_compatible_with(spec_or_value)
      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
        # Customized flat values: delegate to the flat_values spec.
        return self._flat_values_spec.is_compatible_with(spec_or_value)
    # Fall through to the standard structural compatibility check.
    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)
2472 def _serialize(self):
2473 if self._flat_values_spec is None:
2474 return (self._shape, self._dtype, self._ragged_rank,
2475 self._row_splits_dtype)
2476 else:
2477 return (self._shape, self._dtype, self._ragged_rank,
2478 self._row_splits_dtype, self._flat_values_spec)
  @property
  def _component_specs(self):
    """Specs for the component tensors: `[flat_values, *nested_row_splits]`."""
    if self._ragged_rank <= 0:
      # No ragged dimensions: the single component is the flat values itself.
      if self._flat_values_spec is not None:
        return [self._flat_values_spec]
      else:
        return [tensor_spec.TensorSpec(self._shape, self._dtype)]

    flat_values_spec = self._flat_values_spec
    if flat_values_spec is None:
      # flat_values has an unknown outer dimension followed by the dense
      # dimensions that come after the ragged ones.
      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
          self._shape[self._ragged_rank + 1:])
      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
    # The outermost row_splits vector has length nrows + 1 when nrows is known.
    outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
    outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
    inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

    specs = ([
        flat_values_spec,
        tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
    ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
    return specs
2503 def _to_components(self, value):
2504 if is_ragged(value):
2505 return [value.flat_values] + list(value.nested_row_splits)
2506 else:
2507 return [value]
  def _from_components(self, tensor_list):
    """Reassembles a value from `[flat_values, *nested_row_splits]`.

    Args:
      tensor_list: The component list produced by `_to_components`.

    Returns:
      A `RaggedTensor` (or a `RaggedTensorValue` when all components are
      numpy arrays and TF2 behavior is disabled), with this spec's static
      shape information applied.
    """
    result = tensor_list[0]
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      # TF1 numpy mode: build the (legacy) RaggedTensorValue wrapper instead
      # of a symbolic RaggedTensor.
      for row_splits in reversed(tensor_list[1:]):
        result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
    else:
      if isinstance(tensor_list[0], np.ndarray):
        tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
        result = tensor_list[0]
      # Wrap the flat values with the row partitions, innermost first.
      for row_splits in reversed(tensor_list[1:]):
        result = RaggedTensor(
            result,
            RowPartition.from_row_splits(row_splits, validate=False),
            internal=True)
    # Apply any static shape information carried by this spec.
    if self._shape.ndims is not None:
      if isinstance(result, RaggedTensor):
        result._set_shape(self._shape)  # pylint: disable=protected-access
        # TODO(xjun): MaskedTensor doesn't implement set_shape.
        if self.flat_values_spec is not None and hasattr(result.flat_values,
                                                         "set_shape"):
          result.flat_values.set_shape(self.flat_values_spec.shape)
      elif isinstance(result, ops.Tensor):
        result.set_shape(self._shape)
    return result
2535 # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
2536 # to (un)box the component tensors in a way that allows for batching &
2537 # unbatching.
2538 @property
2539 def _flat_tensor_specs(self):
2540 # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
2541 # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
2542 # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
2543 # etc.), so the flat shape must be unknown.
2544 return [tensor_spec.TensorSpec(None, dtypes.variant)]
2546 def _to_tensor_list(self, value):
2547 # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
2548 # from variant to include all of the row-partitioning tensors.
2549 if self._flat_values_spec is not None:
2550 raise ValueError("Customized value_type is not supported.")
2551 if isinstance(value, RaggedTensor):
2552 if value.ragged_rank != self._ragged_rank:
2553 raise ValueError(
2554 f"Ragged rank of value {value.ragged_rank} does not match "
2555 f"ragged rank of type {self._ragged_rank}.")
2556 # pylint: disable=protected-access
2557 return [value._to_variant(batched_input=False)]
2558 else:
2559 if self._ragged_rank > 0:
2560 raise ValueError(
2561 f"Expected a RaggedTensor if ragged rank={self._ragged_rank}"
2562 f" but got {type(value).__name__}."
2563 )
2564 return [
2565 gen_ragged_conversion_ops.ragged_tensor_to_variant(
2566 (), value, batched_input=False)
2567 ]
2569 def _to_batched_tensor_list(self, value):
2570 if self._flat_values_spec is not None:
2571 raise ValueError("Customized value_type is not supported.")
2572 if isinstance(value, RaggedTensor):
2573 if value.ragged_rank != self._ragged_rank:
2574 raise ValueError(
2575 f"Ragged rank of value {value.ragged_rank} does not match "
2576 f"ragged rank of type {self._ragged_rank}.")
2577 # pylint: disable=protected-access
2578 return [value._to_variant(batched_input=True)]
2579 else:
2580 if self._ragged_rank > 0:
2581 raise ValueError(
2582 f"Expected a RaggedTensor if ragged rank={self._ragged_rank}"
2583 f" but got {type(value).__name__}."
2584 )
2585 return [
2586 gen_ragged_conversion_ops.ragged_tensor_to_variant(
2587 rt_nested_splits=(), rt_dense_values=value, batched_input=True)
2588 ]
2590 def _from_compatible_tensor_list(self, tensor_list):
2591 if self._flat_values_spec is not None:
2592 raise ValueError("Customized value_type is not supported.")
2593 result = RaggedTensor._from_variant( # pylint: disable=protected-access
2594 tensor_list[0],
2595 dtype=self._dtype,
2596 row_splits_dtype=self._row_splits_dtype,
2597 output_ragged_rank=self._ragged_rank)
2598 if self._shape.ndims is not None:
2599 if isinstance(result, RaggedTensor):
2600 result._set_shape(self._shape) # pylint: disable=protected-access
2601 # TODO(xjun): MaskedTensor doesn't implement set_shape.
2602 if self.flat_values_spec is not None and hasattr(self.flat_values,
2603 "set_shape"):
2604 result.flat_values.set_shape(self.flat_values_spec.shape)
2605 else:
2606 result.set_shape(self._shape)
2607 return result
2609 def _batch(self, batch_size):
2610 if self._flat_values_spec is not None:
2611 raise ValueError("Customized value_type is not supported.")
2612 return RaggedTensorSpec(
2613 tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
2614 self._dtype, self._ragged_rank + 1, self._row_splits_dtype)
2616 def _unbatch(self):
2617 if self._flat_values_spec is not None:
2618 raise ValueError("Customized value_type is not supported.")
2619 # Note: Negative ragged_rank is allowed here because the dataset could be
2620 # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is
2621 # consistent. Errors are handled in
2622 # RaggedTensorSpec._from_compatible_tensor_list()
2623 return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1,
2624 self._row_splits_dtype)
  def _to_legacy_output_types(self):
    """Legacy (TF-1.x `tf.data`) output type: the value dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Legacy (TF-1.x `tf.data`) output shape: the value shape."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Legacy (TF-1.x `tf.data`) output class: this spec itself."""
    return self
2635 @classmethod
2636 def from_value(cls, value):
2637 if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or
2638 isinstance(value.flat_values, ops.Tensor)):
2639 return cls(
2640 shape=value.shape,
2641 dtype=value.values.dtype,
2642 ragged_rank=value.ragged_rank,
2643 row_splits_dtype=value.row_splits.dtype)
2644 else:
2645 flat_values_spec = type_spec.type_spec_from_value(value.flat_values)
2646 # Relax shape[0] to None, as it is connected to dynamic ragged shapes.
2647 flat_values_spec = flat_values_spec._unbatch()._batch(None) # pylint: disable=protected-access
2648 return cls(
2649 shape=value.shape,
2650 dtype=value.values.dtype,
2651 ragged_rank=value.ragged_rank,
2652 row_splits_dtype=value.row_splits.dtype,
2653 flat_values_spec=flat_values_spec)
# Register a codec so RaggedTensorSpec can be (de)serialized to/from
# TypeSpecProto (e.g. in SavedModel).
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        RaggedTensorSpec, struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC
    )
)


# Let `tf.type_spec_from_value` handle (legacy) RaggedTensorValue objects.
type_spec.register_type_spec_from_value_converter(
    ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value)
2667# ===============================================================================
2668# Convert value -> tensor
2669# ===============================================================================
def convert_to_tensor_or_ragged_tensor(value,
                                       dtype=None,
                                       preferred_dtype=None,
                                       name=None):
  """Converts value to a `RaggedTensor` or `Tensor`.

  * If `value` is a `RaggedTensor`, then return it as-is.
  * If `value` is a `RaggedTensorValue`, return a corresponding constant
    `RaggedTensor`.
  * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`.

  Args:
    value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has
      a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing the type
      is inferred from the type of `value`.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None. This argument has no effect if `value` is already a
      tensor, or when conversion is not possible.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `RaggedTensor`.

  Raises:
    ValueError: If `value` is a `RaggedTensor` whose dtype is incompatible
      with the requested `dtype`.
  """
  if isinstance(value, RaggedTensor):
    # Already ragged -- just validate the requested dtype.
    if dtype and not dtype.is_compatible_with(value.dtype):
      raise ValueError(f"Tensor conversion requested dtype {dtype.name} for "
                       f"RaggedTensor with dtype {value.dtype.name}: {value}.")
    return value
  if isinstance(value, ragged_tensor_value.RaggedTensorValue):
    # Build a constant RaggedTensor from the (legacy) value object.
    with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []):
      converted_values = ops.convert_to_tensor(
          value=value.flat_values,
          dtype=dtype,
          dtype_hint=preferred_dtype,
          name="flat_values")
      return RaggedTensor.from_nested_row_splits(
          converted_values, value.nested_row_splits, validate=False)
  # Anything else goes through the standard tensor-conversion path.
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(
      value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name)
def _convert_to_ragged_tensor_values(value):
  """Coerces `value` to a supported RaggedTensor value type.

  Objects that already have a supported value type are returned unchanged;
  anything else is converted to a `Tensor` or `RaggedTensor`.

  Args:
    value: A `Tensor`, `RaggedTensor`, registered RaggedTensor value type, or
      an object whose type has a registered `Tensor` conversion function.

  Returns:
    An object of `Tensor`, `RaggedTensor`, or a registered RaggedTensor
    value type.
  """
  if _is_supported_ragged_values_type(value):
    return value
  return convert_to_tensor_or_ragged_tensor(value, name="values")
2735# ===============================================================================
2736# Register RaggedTensor for use with session.run.
2737# ===============================================================================
def _ragged_tensor_value_from_components(components):
  """Rebuilds a RaggedTensorValue from `[*nested_row_splits, flat_values]`."""
  *splits, value = list(components)
  # Wrap the flat values with each splits vector, innermost first.
  for row_splits in reversed(splits):
    value = ragged_tensor_value.RaggedTensorValue(value, row_splits)
  return value
def _ragged_tensor_session_fetch(rt):
  """Returns (component tensors, rebuild fn) for fetching a RaggedTensor."""
  fetches = rt.nested_row_splits + (rt.flat_values,)
  return fetches, _ragged_tensor_value_from_components
def _ragged_tensor_session_feed(feed_key, feed_val):
  """Pairs each component of the fed RaggedTensor with its fed value."""
  keys = feed_key.nested_row_splits + (feed_key.flat_values,)
  vals = feed_val.nested_row_splits + (feed_val.flat_values,)
  return zip(keys, vals)
def _ragged_tensor_session_feed_for_partial_run(feed_key):
  """Returns the component tensors to declare as feeds for a partial run."""
  return feed_key.nested_row_splits + (feed_key.flat_values,)
# Teach tf.compat.v1.Session.run how to fetch and feed RaggedTensors.
session.register_session_run_conversion_functions(
    RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed,
    _ragged_tensor_session_feed_for_partial_run)
2766# ===============================================================================
2767# RaggedTensorType
2768# ===============================================================================
class RaggedTensorType:
  """Encoding of a static type for a `RaggedTensor`.

  Use this type to express/declare that an output must have the type of
  `RaggedTensor`.
  """

  def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64):
    """Initializes a RaggedTensorType object.

    Args:
      dtype: data type of the `RaggedTensor`'s inner values.
      ragged_rank: ragged_rank of the declared `RaggedTensor`.
      row_splits_dtype: data type for the `RaggedTensor`'s row splits.
        One of: `tf.int32` or `tf.int64`.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    self._dtype = dtype
    self._ragged_rank = ragged_rank
    self._row_splits_dtype = row_splits_dtype

  @property
  def dtype(self):
    """Data type of the `RaggedTensor`'s inner values."""
    return self._dtype

  @property
  def ragged_rank(self):
    """Declared ragged_rank of the `RaggedTensor`."""
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """Data type of the `RaggedTensor`'s row splits."""
    return self._row_splits_dtype

  def __repr__(self):
    return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank,
                                             self.row_splits_dtype)
2799# ===============================================================================
2800# Helper Functions
2801# ===============================================================================
def _assert_sparse_indices_are_ragged_right(indices):
  """Checks that the given SparseTensor.indices tensor is ragged-right.

  Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right
  because the entry `[3, 1]` skips a cell.

  Args:
    indices: The SparseTensor indices to check.

  Returns:
    A list of control dependency op tensors.
  """
  prefixes = indices[:, :-1]
  suffixes = indices[:, -1]

  # For each index after the first, decide whether it starts a new row in the
  # innermost dimension (its prefix differs from the previous index's prefix)
  # or continues the previous row.
  starts_new_row = math_ops.reduce_any(
      math_ops.not_equal(prefixes[1:], prefixes[:-1]), axis=1)

  # An index that starts a row must have suffix 0; one that continues a row
  # must have a suffix exactly one greater than its predecessor's.
  suffix_ok = array_ops.where(
      starts_new_row, math_ops.equal(suffixes[1:], 0),
      math_ops.equal(suffixes[1:], suffixes[:-1] + 1))

  # The very first index always starts a row (by definition), so its suffix
  # must be zero; the comparisons above skip it.
  is_ragged_right = math_ops.logical_and(
      math_ops.reduce_all(math_ops.equal(suffixes[:1], 0)),
      math_ops.reduce_all(suffix_ok))

  message = [
      "SparseTensor is not right-ragged", "SparseTensor.indices =", indices
  ]
  return [control_flow_assert.Assert(is_ragged_right, message)]
@ops.RegisterGradient("RaggedTensorToSparse")
def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad,
                                      sparse_values_grad,
                                      unused_sparse_shape_grad):
  """Gradient for RaggedTensorToSparse."""
  nested_row_splits = op.inputs[:-1]
  flat_values = op.inputs[-1]

  # The integer row-partitioning inputs get no gradient.
  splits_grads = [None] * len(nested_row_splits)

  # The flat_values gradient is the SparseTensor values gradient, reshaped
  # back to the flat_values' shape.
  values_grad = array_ops.reshape(sparse_values_grad,
                                  array_ops.shape(flat_values))

  return splits_grads + [values_grad]
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that `tensor`'s values are non-decreasing along axis 0."""
  deltas = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(deltas, message=message)
def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` equals zero."""
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the number of rows (size of dimension 0) of `tensor`."""
  if isinstance(tensor, RaggedTensor):
    return tensor.nrows(out_type=out_type)
  return array_ops.shape(tensor, out_type=out_type)[0]
def merge_dims(value, outer_axis, inner_axis):
  """Merges value[outer_axis...inner_axis] into a single dimension.

  See `RaggedTensor.merge_dims()` for more details. This helper differs from
  `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor.

  Args:
    value: A `RaggedTensor` or `Tensor`
    outer_axis: `int`
    inner_axis: `int`

  Returns:
    A flattened `RaggedTensor` or `Tensor`.
  """
  if outer_axis == inner_axis:
    return value

  # Flatten outer dimensions of a RaggedTensor by just taking its values.
  while outer_axis == 0 and isinstance(value, RaggedTensor):
    value = value.values
    inner_axis -= 1
    if inner_axis == 0:
      return value

  # Flatten non-Ragged tensors using tf.reshape().
  if not isinstance(value, RaggedTensor):
    if value.shape.is_fully_defined():
      old_shape = value.shape.as_list()
      new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:]
    else:
      old_shape = array_ops.shape(value)
      new_shape = array_ops.concat(
          [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0)
    return array_ops.reshape(value, new_shape)

  # Handle outer_axis>1 via recursion.
  if outer_axis > 1:
    return value.with_values(
        merge_dims(value.values, outer_axis - 1, inner_axis - 1))

  # At this point, we know outer_axis == 1, and value is a RaggedTensor.
  # So we need to flatten the values and build a corresponding splits tensor.
  new_values = value.values
  new_splits = value.row_splits
  for axis in range(outer_axis, inner_axis):
    if isinstance(new_values, RaggedTensor):
      # Flatten a single ragged dimension.  Composing the two splits
      # (gathering the inner splits at the outer split positions) maps each
      # outer row directly onto the inner values.
      new_splits = array_ops.gather(new_values.row_splits, new_splits)
      new_values = new_values.values
    else:
      # Flatten all remaining dense dimensions.
      shape_split = inner_axis - axis + 1
      if new_values.shape.is_fully_defined():
        old_shape = new_values.shape.as_list()
        new_shape = [-1] + old_shape[shape_split:]
        flat_size = _prod(old_shape[1:shape_split])
      else:
        old_shape = array_ops.shape(new_values)
        new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0)
        flat_size = math_ops.cast(
            math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype)
      # After reshaping, each row spans `flat_size` times as many values, so
      # the splits are scaled accordingly.
      new_values = array_ops.reshape(new_values, new_shape)
      new_splits = new_splits * flat_size
      break
  return RaggedTensor.from_row_splits(new_values, new_splits)
2947def _prod(lst):
2948 """Returns the product of the numbers in a list."""
2949 return functools.reduce(operator.mul, lst, 1)
def _get_row_partition_type_tensor_pairs_tail(partition):
  """Gets a row partition type/tensor pair for a non-outermost RowPartition.

  Precomputed value_rowids are preferred when available; otherwise the
  partition's row_splits are used.

  Args:
    partition: a RowPartition.

  Returns:
    A single (row_partition_type, row_partition_tensor) pair.
  """
  # pylint: disable=protected-access
  if partition._has_precomputed_value_rowids():
    return ("VALUE_ROWIDS", partition.value_rowids())
  return ("ROW_SPLITS", partition.row_splits())
def _get_row_partition_type_tensor_pairs(rt_input):
  """Gets a list of the row partitions for rt_input.

  If value_rowids are defined, then they are used. Otherwise, row_splits
  are used. If the outermost level has value_rowids defined, then nrows is
  also added.

  Args:
    rt_input: a ragged tensor.

  Returns:
    A list of (row_partition_type, row_partition_tensor) pairs.
  """
  partitions = rt_input._nested_row_partitions  # pylint: disable=protected-access
  tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]]

  # The outermost partition needs FIRST_DIM_SIZE alongside VALUE_ROWIDS,
  # since value_rowids alone cannot represent trailing empty rows.
  if partitions[0]._value_rowids is not None:  # pylint: disable=protected-access
    return [("FIRST_DIM_SIZE", partitions[0].nrows()),
            ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail
  else:
    return [("ROW_SPLITS", partitions[0].row_splits())] + tail
def _shape_as_tensor(shape, dtype):
  """Takes shape and coerces it to a shape as a tensor.

  If the object is already a tensor, simply passes it on (result is guaranteed
  to be int64 or int32, but not necessarily dtype).
  If not, creates a tensor of type dtype.

  Result is either a scalar equal to -1 if the shape is unknown_rank.
  Otherwise, it is a vector, where unknown dimensions are represented with a
  value of -1.

  In C++, see TensorShapeFromTensor for parsing shapes in kernels, and
  InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for
  use in the shape inference function.

  Args:
    shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]],
      Tuple[Optional[Int]].
    dtype: tf.int64 or tf.int32

  Returns:
    a scalar or vector tensor of dtype tf.int32 or tf.int64.

  Raises:
    ValueError: If `dtype` is not tf.int32 or tf.int64.
  """
  if dtype != dtypes.int64 and dtype != dtypes.int32:
    raise ValueError(f"Expected int64 or int32 for dtype: got {dtype}.")

  if isinstance(shape, ops.Tensor):
    # Only cast when the tensor isn't already one of the two supported dtypes;
    # otherwise pass it through unchanged (possibly not `dtype`).
    if shape.dtype != dtypes.int64 and shape.dtype != dtypes.int32:
      return math_ops.cast(shape, dtype)
    return shape
  shape = tensor_shape.as_shape(shape)
  if not shape:
    # NOTE(review): a falsy TensorShape is treated as unknown rank here --
    # this relies on TensorShape's truthiness semantics.
    # Imply rank is unknown using a -1 scalar.
    return constant_op.constant(-1, dtype=dtype)
  shape = [(-1 if x is None else x) for x in shape.as_list()]
  # At this point, shape is List[Int].
  return constant_op.constant(shape, dtype=dtype)
def _nvals_uniform_row_length(values, uniform_row_length):
  """Get the number of values for uniform row length constructor."""
  out_type = uniform_row_length.dtype
  static_nvals = tensor_shape.dimension_at_index(values.shape, 0).value
  if static_nvals is not None:
    # Statically known: emit a constant.
    return constant_op.constant(static_nvals, out_type)
  if isinstance(values, RaggedTensor):
    return values.nrows(out_type=out_type)
  return array_ops.shape(values, out_type=out_type)[0]
def _get_optional_partition_dtype(values):
  """Returns the partition dtype, or None if none exists."""
  if not isinstance(values, RaggedTensor):
    return None
  return values._row_partition.dtype  # pylint: disable=protected-access
# Value types that may be used as a RaggedTensor's `values`; extended at
# runtime via `_add_supported_value_type()`.
_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
# TODO(edloper): Consider whether we should change the registry to be on
# TypeSpecs rather than ValueTypes.
def _add_supported_value_type(cls):
  """Register the `cls` as supported value type of RaggedTensor.

  The cls must be a subclass of CompositeTensor, and must support:
   - Spec:
     The Spec must be a `BatchableTypeSpec`
   - Properties:
     - x.shape
     - x.dtype
   - Methods:
     - x.__getitem__(idx) (method: returns a supported value type)
     - x.set_shape(shape)
   - Ops:
     - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor.
     - tf.tile(x)
     - assert_rank_at_least(x)
     - tf.ones_like(x)
     - tf.gather(params=x, indices=Tensor)
     - tf.add(x, y)
     - tf.boolean_mask(x, ...)
     - @TODO(edloper): Complete this list

  Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not
  currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor:
    - rt.to_tensor()
    - rt.to_sparse_tensor()
    - rt._to_variant()
    - rt._from_variant()
    - tf.ragged.cross([rt])
    - tf.gather(params=x, indices=rt)  # rt used for indices
    - RaggedTensorSpec methods:
      - _batch
      - _unbatch
      - _to_tensor_list
      - _to_batched_tensor_list
      - _from_compatible_tensor_list

  Args:
    cls: The type to be added to supported value types.

  Raises:
    ValueError: If `cls` is not a CompositeTensor subclass, or lacks the
      `shape` / `dtype` properties.
  """
  if not issubclass(cls, composite_tensor.CompositeTensor):
    raise ValueError(f"cls ({cls}) must be a subclass of CompositeTensor.")
  if not hasattr(cls, "shape"):
    raise ValueError("cls must support the `shape` property.")
  if not hasattr(cls, "dtype"):
    raise ValueError("cls must support the `dtype` property.")
  # Extend the module-level registry tuple (rebinding, not mutating).
  global _SUPPORTED_RAGGED_VALUE_TYPES
  _SUPPORTED_RAGGED_VALUE_TYPES += (cls,)
def _is_supported_ragged_values_type(value):
  """Returns True if `value` may be used as a RaggedTensor's `values`."""
  return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES)
def _assert_is_supported_ragged_values_type(value):
  """Raises TypeError unless `value` is a supported RaggedTensor value type."""
  if _is_supported_ragged_values_type(value):
    return
  ok_types = ", ".join(cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES)
  raise TypeError(f"type(values) must be one of: {ok_types}, got {value}.")
3117def _formatter(x):
3118 """Separate Numpy array elements with comma."""
3119 if isinstance(x, np.ndarray):
3120 if x.size != 0:
3121 return np.array2string(x, separator=", ")
3122 else:
3123 # When x.size==0, np.array2string always returns `[]`. This isn't always
3124 # what we want. E.g., if `x.shape=[0, 3]`, then we want `[[], [], []]`.
3125 return repr(x.tolist())
3126 else:
3127 return str(x)
# Type annotation indicating that a value is ragged. Includes RaggedTensor
# as well as the (deprecated) RaggedTensorValue class from TF 1.x.
Ragged = typing.Union[RaggedTensor, ragged_tensor_value.RaggedTensorValue]

# Type annotation indicating that a value is a ragged tensor, a dense tensor,
# or a value that can be converted to a tensor (e.g. np.array).
# TODO(edloper): Add Variable to TensorLike, and remove it from here.
RaggedOrDense = typing.Union[Ragged, core_types.TensorLike]