# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Array operations for RaggedTensors."""

from typing import Optional
from typing import Union

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_ragged_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.ragged import dynamic_ragged_shape
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# ===============================================================================
# Masking
# ===============================================================================


@tf_export('ragged.boolean_mask')
@dispatch.add_dispatch_support
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor. `mask`'s shape must be a prefix
      of `data`'s shape. `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged. Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension. For each innermost cell, get the
        # number of values it contains. Then flatten that to get a list of
        # cell lengths, and convert it to splits. Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
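
# Usage sketch for the dense-mask branch above (illustrative, assumes eager
# execution): a rank-2 dense mask applied to rank-2 dense data produces a
# ragged result, since each row may retain a different number of elements.
#
#   >>> tf.ragged.boolean_mask(
#   ...     data=[[1, 2, 3], [4, 5, 6]],
#   ...     mask=[[True, False, True], [False, True, False]]).to_list()
#   [[1, 3], [5]]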


# ===============================================================================
# Tiling
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.tile)
def tile(input: ragged_tensor.Ragged, multiples, name=None):  # pylint: disable=redefined-builtin
  """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.

  The values of `input` are replicated `multiples[i]` times along the
  `i`th dimension (for each dimension `i`). For every dimension `axis` in
  `input`, the length of each output element in that dimension is the
  length of the corresponding input element multiplied by `multiples[axis]`.

  Args:
    input: A `RaggedTensor`.
    multiples: A 1-D integer `Tensor`. Length must be the same as the number of
      dimensions in `input`.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> tf.tile(rt, [3, 2]).to_list()
  [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
  """
  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if not ragged_tensor.is_ragged(input):
      return array_ops.tile(input, multiples, name)
    multiples = ragged_util.convert_to_int_tensor(
        multiples, name='multiples', dtype=input.row_splits.dtype)
    multiples.shape.assert_has_rank(1)

    # If the constant value of `multiples` is available, then we can use it
    # to skip tiling dimensions where `multiples=1`.
    const_multiples = tensor_util.constant_value(multiples)

    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        _tile_ragged_values(input, multiples, const_multiples),
        _tile_ragged_splits(input, multiples, const_multiples),
        validate=False)
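
# Usage sketch (illustrative, assumes eager execution): tiling only the inner
# (ragged) dimension repeats each row's values in place, and a constant
# `multiples` value lets `tile` skip the unchanged outer dimension.
#
#   >>> tf.tile(tf.ragged.constant([[1, 2], [3]]), [1, 2]).to_list()
#   [[1, 2, 1, 2], [3, 3]]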


def _tile_ragged_values(rt_input, multiples, const_multiples=None):
  """Builds flat_values tensor for a tiled `RaggedTensor`.

  Returns a tensor that repeats the values in `rt_input.flat_values` in the
  appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as
  specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` whose values should be repeated.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples. Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A `Tensor` with the same type and rank as `rt_input.flat_values`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy()
  array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32)
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # Pointers to the values in `rt_input.flat_values`.
  inner_value_ids = math_ops.range(nested_splits[-1][-1])

  # For each ragged dimension (working from the innermost to outermost),
  # expand `inner_value_ids` as necessary to tile that dimension.
  prev_splits = None
  for axis in range(ragged_rank, 0, -1):
    # Ragged splits for this dimension.
    splits = nested_splits[axis - 1]

    # Adjust splits so they point into `inner_value_ids` (instead of just
    # pointing into the next dimension's values).
    if prev_splits is not None:  # Not the first pass through the loop.
      splits = array_ops.gather(prev_splits * multiples[axis + 1], splits)

    # Repeat each element in this ragged dimension `multiples[axis]` times.
    if const_multiples is None or const_multiples[axis] != 1:
      inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits,
                                                  multiples[axis])

    prev_splits = splits

  # Gather the tiled inner values.
  ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids)

  # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus
  # `axis=range(ragged_rank, rank)`).
  inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]],
                                   axis=0)
  return array_ops.tile(ragged_tiled_values, inner_repeats)


def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
  """Builds nested_split tensors for a tiled `RaggedTensor`.

  Returns a list of split tensors that can be used to construct the
  `RaggedTensor` that tiles `rt_input` as specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` that is being tiled.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples. Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A list of 1-D integer `Tensor`s (one for each ragged dimension in
    `rt_input`).

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_splits(rt, [3, 2])
  [<tf.Tensor: shape=(7,), dtype=int64,
  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # projected_splits[src_axis, dst_axis] contains the split points that divide
  # the rows from src_axis in the list of dst_axis values. E.g.,
  # projected_splits[i, i] = nested_splits[i], and
  # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
  projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
  for src_axis in range(ragged_rank):
    for dst_axis in range(src_axis + 1, ragged_rank - 1):
      projected_splits[src_axis][dst_axis] = array_ops.gather(
          nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1])
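
  # Worked example (illustrative): for ragged_rank=3 with
  # nested_splits[0] = [0, 2, 3] and nested_splits[1] = [0, 1, 3, 6], the loop
  # above computes
  #   projected_splits[0][1] = gather([0, 1, 3, 6], [0, 2, 3]) = [0, 3, 6],
  # i.e. outer row 0 spans axis-2 rows [0, 3) and outer row 1 spans [3, 6).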

  # For each ragged dimension: nested_splits[axis] -> result_splits[axis].
  result_splits = []
  for axis in range(ragged_rank):
    # Get the length of each row for the input tensor for this dimension.
    input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1]

    # Multiply those lengths by the `multiples` of dimension axis+1, since
    # each value will be repeated that number of times.
    output_lengths = input_lengths * multiples[axis + 1]

    # Repeat ranges of the row lengths as necessary for them to be tiled in
    # each ragged dimension `d < axis`. (Start with dimension d=axis-1, and
    # work our way up to dimension d=0.)
    repeats = 1
    for d in range(axis - 1, -1, -1):
      if const_multiples is None or const_multiples[d + 1] != 1:
        splits = projected_splits[d][axis - 1] * repeats
        output_lengths = ragged_util.repeat_ranges(output_lengths, splits,
                                                   multiples[d + 1])
      repeats *= multiples[d + 1]

    # Tile splits for the outermost (uniform) dimension.
    output_lengths = array_ops.tile(output_lengths, multiples[:1])

    # Convert to splits.
    result_splits.append(ragged_util.lengths_to_splits(output_lengths))

  return result_splits


# ===============================================================================
# Reshaping
# ===============================================================================


@dispatch.dispatch_for_api(array_ops.expand_dims_v2)
def expand_dims(input: ragged_tensor.Ragged, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.

  Given a potentially ragged tensor `input`, this operation inserts a
  dimension with size 1 at the dimension `axis` of `input`'s shape.

  The following table gives some examples showing how `ragged.expand_dims`
  impacts the shapes of different input tensors. Ragged dimensions are
  indicated by enclosing them in parentheses.

  input.shape             | axis | result.shape
  ----------------------- | ---- | -----------------------------
  `[D1, D2]`              |  `0` | `[1, D1, D2]`
  `[D1, D2]`              |  `1` | `[D1, 1, D2]`
  `[D1, D2]`              |  `2` | `[D1, D2, 1]`
  `[D1, (D2), (D3), D4]`  |  `0` | `[1, D1, (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `1` | `[D1, 1, (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `2` | `[D1, (D2), 1, (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `3` | `[D1, (D2), (D3), 1, D4]`
  `[D1, (D2), (D3), D4]`  |  `4` | `[D1, (D2), (D3), D4, 1]`

  Args:
    input: The potentially ragged tensor that should be expanded with a new
      dimension.
    axis: An integer constant indicating where the new dimension should be
      inserted.
    name: A name for the operation (optional).

  Returns:
    A tensor with the same values as `input`, with an added dimension of
    size 1 at `axis`.

  #### Examples:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> print(rt.shape)
  (2, None)

  >>> expanded = tf.expand_dims(rt, axis=0)
  >>> print(expanded.shape, expanded)
  (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>

  >>> expanded = tf.expand_dims(rt, axis=1)
  >>> print(expanded.shape, expanded)
  (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>

  >>> expanded = tf.expand_dims(rt, axis=2)
  >>> print(expanded.shape, expanded)
  (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]>
  """
  with ops.name_scope(name, 'RaggedExpandDims', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')

    if not ragged_tensor.is_ragged(input):
      return array_ops.expand_dims(input, axis)

    ndims = None if input.shape.ndims is None else input.shape.ndims + 1
    axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)')

    if axis == 0:
      return ragged_tensor.RaggedTensor.from_uniform_row_length(
          input, uniform_row_length=input.nrows(), nrows=1, validate=False)
    elif axis == 1:
      return ragged_tensor.RaggedTensor.from_uniform_row_length(
          input, uniform_row_length=1, nrows=input.nrows(), validate=False)
    else:
      if ragged_tensor.is_ragged(input.values):
        return input.with_values(expand_dims(input.values, axis - 1))
      else:
        return input.with_values(array_ops.expand_dims(input.values, axis - 1))


@dispatch.dispatch_for_api(array_ops.expand_dims)
def _ragged_expand_dims_v1(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    axis=None,
    name=None,
    dim=None):
  if dim is not None:
    axis = dim
  return expand_dims(input=input, axis=axis, name=name)


# ===============================================================================
# RaggedTensor Size
# ===============================================================================


@dispatch.dispatch_for_api(array_ops.size_v2)
def size(input: ragged_tensor.Ragged, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a potentially ragged tensor.

  The size of a ragged tensor is the size of its inner values.

  #### Example:

  >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy()
  3

  Args:
    input: A potentially ragged `Tensor`.
    out_type: The numeric output type for the operation.
    name: A name for the operation (optional).

  Returns:
    A Tensor of type `out_type`.
  """
  if ragged_tensor.is_ragged(input):
    return array_ops.size(input.flat_values, out_type=out_type, name=name)
  else:
    return array_ops.size(input, out_type=out_type, name=name)


@dispatch.dispatch_for_api(array_ops.size)
def _ragged_size_v1(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    name=None,
    out_type=dtypes.int32):
  return size(input=input, out_type=out_type, name=name)


# ===============================================================================
# ragged.rank
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.rank)
def rank(input: ragged_tensor.Ragged, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  #### Example:

  >>> # shape of tensor 't' is [2, None, None]
  >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  >>> tf.rank(t).numpy()
  3

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as name:
    if not ragged_tensor.is_ragged(input):
      return array_ops.rank(input, name)

    return input.ragged_rank + array_ops.rank(input.flat_values)
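
# Note (illustrative): a RaggedTensor's rank is its number of ragged
# dimensions plus the rank of its dense flat_values. In the docstring example
# above, `t` has ragged_rank=2 and rank-1 flat_values, so tf.rank(t) is 3.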


# ===============================================================================
# ragged.one_hot
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.one_hot)
def ragged_one_hot(indices: ragged_tensor.Ragged,
                   depth,
                   on_value=None,
                   off_value=None,
                   axis=None,
                   dtype=None,
                   name=None):
  """Applies tf.one_hot along the values of a RaggedTensor."""
  # Get the adjusted axis value for the call to array_ops.one_hot.
  # Note: the only negative `axis` value supported by array_ops.one_hot is -1.
  if isinstance(axis, int) and axis >= 0:
    if axis <= indices.ragged_rank:
      raise ValueError('axis (%d) must be greater than indices.ragged_rank '
                       '(%d).' % (axis, indices.ragged_rank))
    axis -= indices.ragged_rank

  with ops.name_scope(name, 'RaggedOneHot',
                      [indices, depth, on_value, off_value, axis]):
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    return indices.with_flat_values(
        array_ops.one_hot(indices.flat_values, depth, on_value, off_value,
                          axis, dtype, name))
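
# Usage sketch (illustrative, assumes eager execution): the one-hot encoding
# is applied to the flat values, so the ragged row structure is preserved.
#
#   >>> tf.one_hot(tf.ragged.constant([[0, 2], [1]]), depth=3).to_list()
#   [[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]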


# ===============================================================================
# ragged.stack_dynamic_partitions
# ===============================================================================
@tf_export('ragged.stack_dynamic_partitions')
@dispatch.add_dispatch_support
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
  """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`. Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.

  #### Example:

  >>> data = ['a', 'b', 'c', 'd', 'e']
  >>> partitions = [3, 0, 2, 2, 3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
      partition that each slice of `data` should be added to.
      `partitions.shape` must be a prefix of `data.shape`. Values must be
      greater than or equal to zero, and less than `num_partitions`.
      `partitions` is not required to be sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output. This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.
  """
  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    row_splits_dtype = (
        data.row_splits.dtype
        if isinstance(data, ragged_tensor.RaggedTensor) else None)
    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        partitions, name='partitions', preferred_dtype=row_splits_dtype)
    num_partitions = ops.convert_to_tensor(
        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
    if row_splits_dtype is not None:
      partitions = math_ops.cast(partitions, row_splits_dtype)
    num_partitions = math_ops.cast(num_partitions, partitions.dtype)

    # Sanity-checks for shapes.
    partitions_rank = partitions.shape.ndims
    if partitions_rank is None:
      raise ValueError('partitions must have known rank.')
    num_partitions.shape.assert_has_rank(0)
    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])

    if partitions_rank == 0:
      # If partitions is a scalar, then just create a RaggedTensor containing
      # the complete `data` value in the specified row.
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=array_ops_stack.stack([data]),
          value_rowids=array_ops_stack.stack([partitions]),
          nrows=num_partitions,
          validate=False)

    elif partitions_rank == 1:
      # If partitions is a vector (the typical case): we can just use data and
      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
      # as long as we sort them first.
      permutation = sort_ops.argsort(partitions, stable=True)
      value_rowids = array_ops.gather(partitions, permutation)
      values = array_ops.gather(data, permutation)
      checks = [
          check_ops.assert_less(
              value_rowids[-1:], num_partitions,
              message='partitions must be less than num_partitions'),
          check_ops.assert_non_negative(
              partitions, message='partitions must be non-negative.')
      ]
      with ops.control_dependencies(checks):
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values, value_rowids, nrows=num_partitions, validate=False)

    else:
      # Handle higher-dimensional partitions via recursion.
      if not isinstance(data, ragged_tensor.RaggedTensor):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
      if not isinstance(partitions, ragged_tensor.RaggedTensor):
        partitions = ragged_tensor.RaggedTensor.from_tensor(
            partitions,
            row_splits_dtype=partitions.dtype,
            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
      check = check_ops.assert_equal(
          data.row_splits,
          partitions.row_splits,
          message='data and partitions have incompatible ragged shapes')
      with ops.control_dependencies([check]):
        return stack_dynamic_partitions(data.values, partitions.values,
                                        num_partitions)
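
# Usage sketch for the higher-rank branch above (illustrative, assumes eager
# execution): with rank-2 `partitions`, both inputs are flattened one level
# and the op recurses on their values.
#
#   >>> tf.ragged.stack_dynamic_partitions(
#   ...     data=[['a', 'b'], ['c', 'd']],
#   ...     partitions=[[0, 1], [1, 0]],
#   ...     num_partitions=2)
#   <tf.RaggedTensor [[b'a', b'd'], [b'b', b'c']]>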


# ===============================================================================
# Reverse
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.reverse)
def reverse(tensor: ragged_tensor.Ragged, axis, name=None):
  """Reverses a RaggedTensor along the specified axes.

  #### Example:

  >>> data = tf.ragged.constant([
  ...     [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
  >>> tf.reverse(data, axis=[0, 2])
  <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>

  Args:
    tensor: A `RaggedTensor` to reverse.
    axis: A list or tuple of `int` or a constant 1-D `tf.Tensor`. The indices
      of the axes to reverse.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor`.
  """
  type_error_msg = ('`axis` must be a list of int or a constant tensor '
                    'when reversing axes in a ragged tensor')

  with ops.name_scope(name, 'Reverse', [tensor, axis]):
    if isinstance(axis, ops.Tensor):
      axis = tensor_util.constant_value(axis)
      if axis is None:
        raise TypeError(type_error_msg)
    elif not (isinstance(axis, (list, tuple)) and
              all(isinstance(dim, int) for dim in axis)):
      raise TypeError(type_error_msg)

    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        tensor, name='tensor')

    # Allow usage of negative values to specify innermost axes.
    axis = [
        array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
                                    'rank(tensor)')
        for i, dim in enumerate(axis)
    ]

    # We only need to slice up to the max axis. If the axis list
    # is empty, it should be 0.
    slices = [slice(None)] * (max(axis) + 1 if axis else 0)

    for dim in axis:
      slices[dim] = slice(None, None, -1)

    return tensor[tuple(slices)]
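
# Usage sketch (illustrative): because reversal is implemented with
# negative-stride slices, reversing along a ragged axis reverses each row
# independently.
#
#   >>> tf.reverse(tf.ragged.constant([[1, 2, 3], [4]]), axis=[1]).to_list()
#   [[3, 2, 1], [4]]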


# ===============================================================================
# Cross
# ===============================================================================


@tf_export('ragged.cross')
@dispatch.add_dispatch_support
def cross(inputs, name=None):
  """Generates feature cross from a list of tensors.

  The input tensors must have `rank=2`, and must all have the same number of
  rows. The result is a `RaggedTensor` with the same number of rows as the
  inputs, where `result[row]` contains a list of all combinations of values
  formed by taking a single value from each input's corresponding row
  (`inputs[i][row]`). Values are combined by joining their strings with '_X_'.
  E.g.:

  >>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
  ...                  tf.ragged.constant([['d'], ['e']]),
  ...                  tf.ragged.constant([['f'], ['g']])])
  <tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>

  Args:
    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
    name: Optional name for the op.

  Returns:
    A 2D `RaggedTensor` of type `string`.
  """
  return _cross_internal(inputs=inputs, hashed_output=False, name=name)


@tf_export('ragged.cross_hashed')
@dispatch.add_dispatch_support
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
  """Generates hashed feature cross from a list of tensors.

  The input tensors must have `rank=2`, and must all have the same number of
  rows. The result is a `RaggedTensor` with the same number of rows as the
  inputs, where `result[row]` contains a list of all combinations of values
  formed by taking a single value from each input's corresponding row
  (`inputs[i][row]`). Values are combined by hashing together their
  fingerprints. E.g.:

  >>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
  ...                         tf.ragged.constant([['d'], ['e']]),
  ...                         tf.ragged.constant([['f'], ['g']])],
  ...                        num_buckets=100)
  <tf.RaggedTensor [[78], [66, 74]]>

  Args:
    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
    num_buckets: A non-negative `int` that is used to bucket the hashed values.
      If `num_buckets != 0`, then `output = hashed_value % num_buckets`.
    hash_key: Integer hash_key that will be used by the `FingerprintCat64`
      function. If not given, a default key is used.
    name: Optional name for the op.

  Returns:
    A 2D `RaggedTensor` of type `int64`.
  """
  return _cross_internal(
      inputs=inputs,
      hashed_output=True,
      num_buckets=num_buckets,
      hash_key=hash_key,
      name=name)


_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE


def _cross_internal(inputs,
                    hashed_output=False,
                    num_buckets=0,
                    hash_key=None,
                    name=None):
  """Generates feature cross from a list of ragged and dense tensors."""
  if not isinstance(inputs, (tuple, list)):
    raise TypeError('Inputs must be a list')

  if hash_key is None:
    hash_key = _DEFAULT_CROSS_HASH_KEY

  ragged_inputs = []
  sparse_inputs = []
  dense_inputs = []
  input_order = []
  with ops.name_scope(name, 'RaggedCross', inputs):
    for i, t in enumerate(inputs):
      if sparse_tensor.is_sparse(t):
        t = sparse_tensor.SparseTensor.from_value(t)
      else:
        t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
      if t.dtype.is_integer:
        t = math_ops.cast(t, dtypes.int64)
      elif t.dtype != dtypes.string:
        raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
      if isinstance(t, ragged_tensor.RaggedTensor):
        if t.ragged_rank != 1:
          raise ValueError('tf.ragged.cross only supports inputs with rank=2')
        ragged_inputs.append(t)
        input_order.append('R')
      elif isinstance(t, sparse_tensor.SparseTensor):
        sparse_inputs.append(t)
        input_order.append('S')
      else:
        dense_inputs.append(t)
        input_order.append('D')

    out_values_type = dtypes.int64 if hashed_output else dtypes.string
    if ragged_inputs and all(
        t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
      out_row_splits_type = dtypes.int32
    else:
      out_row_splits_type = dtypes.int64

    # Convert hash_key from uint64 -> int64, since we need to pass it via
    # an int64 attr.
    if hash_key > 2**63:
      hash_key -= 2**64

    values_out, splits_out = gen_ragged_array_ops.ragged_cross(
        ragged_values=[rt.values for rt in ragged_inputs],
        ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
        sparse_indices=[st.indices for st in sparse_inputs],
        sparse_values=[st.values for st in sparse_inputs],
        sparse_shape=[st.dense_shape for st in sparse_inputs],
        dense_inputs=dense_inputs,
        input_order=''.join(input_order),
        hashed_output=hashed_output,
        num_buckets=num_buckets,
        hash_key=hash_key,
        out_values_type=out_values_type.as_datatype_enum,
        out_row_splits_type=out_row_splits_type.as_datatype_enum,
        name=name)

    return ragged_tensor.RaggedTensor.from_row_splits(
        values_out, splits_out, validate=False)


def fill_empty_rows(ragged_input, default_value, name=None):
  """Fills empty rows in a rank-2 `RaggedTensor` with a default value.

  This op adds entries with the specified `default_value` for any row in the
  input that does not already have a value.

  The op also returns an indicator vector such that

      empty_row_indicator[i] = True iff row i was an empty row.

  Args:
    ragged_input: A `RaggedTensor` with rank 2.
    default_value: The value to fill for empty rows, with the same type as
      `ragged_input`.
    name: A name prefix for the returned tensors (optional).

  Returns:
    ragged_ordered_output: A `RaggedTensor` with all empty rows filled in with
      `default_value`.
    empty_row_indicator: A bool vector indicating whether each input row was
      empty.

  Raises:
    TypeError: If `ragged_input` is not a `RaggedTensor`.
  """
  with ops.name_scope(name, 'RaggedFillEmptyRows', [ragged_input]):
    if not isinstance(ragged_input, ragged_tensor.RaggedTensor):
      raise TypeError(
          'ragged_input must be RaggedTensor, got'
          f' {type(ragged_input)}'
      )
    default_value = ops.convert_to_tensor(
        default_value, dtype=ragged_input.dtype
    )
    (
        output_value_rowids,
        output_values,
        empty_row_indicator,
        unused_reverse_index_map,
    ) = gen_ragged_array_ops.ragged_fill_empty_rows(
        value_rowids=ragged_input.value_rowids(),
        values=ragged_input.values,
        nrows=ragged_input.nrows(),
        default_value=default_value,
    )
    return (
        ragged_tensor.RaggedTensor.from_value_rowids(
            values=output_values,
            value_rowids=output_value_rowids,
            validate=False,
        ),
        empty_row_indicator,
    )
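
# Usage sketch (illustrative, assumes eager execution):
#
#   >>> rt = tf.ragged.constant([[1, 2], [], [3]])
#   >>> filled, empty = fill_empty_rows(rt, default_value=0)
#   >>> filled.to_list(), empty.numpy().tolist()
#   ([[1, 2], [0], [3]], [False, True, False])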


@ops.RegisterGradient('RaggedFillEmptyRows')
def _ragged_fill_empty_rows_grad(
    op,
    unused_grad_output_indices,
    output_grad_values,
    unused_grad_empty_row_indicator,
    unused_grad_reverse_index_map,
):
  """Gradients for RaggedFillEmptyRows."""
  reverse_index_map = op.outputs[3]

  d_values, d_default_value = gen_ragged_array_ops.ragged_fill_empty_rows_grad(
      reverse_index_map=reverse_index_map, grad_values=output_grad_values
  )

  # d_value_rowids, d_values, d_nrows, d_default_value.
  return [None, d_values, None, d_default_value]


# ===============================================================================
# dynamic_partition
# ===============================================================================
@dispatch.dispatch_for_api(data_flow_ops.dynamic_partition)
def dynamic_partition(data: ragged_tensor.RaggedOrDense,
                      partitions: ragged_tensor.RaggedOrDense,
                      num_partitions,
                      name=None):
  """RaggedTensor dispatch override for tf.dynamic_partition."""
  if not isinstance(num_partitions, int) or num_partitions < 0:
    raise TypeError('num_partitions must be a non-negative integer')
  result = stack_dynamic_partitions(data, partitions, num_partitions, name)
  return [result[i] for i in range(num_partitions)]
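
# Usage sketch (illustrative, assumes eager execution): the override stacks
# the partitions into a single RaggedTensor and slices out its rows, so each
# returned element is itself potentially ragged.
#
#   >>> parts = tf.dynamic_partition(
#   ...     tf.ragged.constant([[1, 2], [3], [4, 5]]), [1, 0, 1], 2)
#   >>> [p.to_list() for p in parts]
#   [[[3]], [[1, 2], [4, 5]]]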


# ===============================================================================
# split
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.split)
def split(value: ragged_tensor.Ragged,
          num_or_size_splits,
          axis=0,
          num=None,
          name=None):
  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.

  If `num_or_size_splits` is an `int`, then it splits `value` along the
  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
  `len(num_or_size_splits)` elements. The shape of the `i`-th element has the
  same size as the `value` except along dimension `axis` where the size is
  `num_or_size_splits[i]`.

  Splitting along a ragged dimension is not allowed.

  For example:

  >>> rt = tf.RaggedTensor.from_row_lengths(
  ...     np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
  >>> rt.shape
  TensorShape([4, None, 3])
  >>>
  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
  >>> rt1.shape
  TensorShape([2, None, 3])
  >>> rt2.shape
  TensorShape([2, None, 3])
  >>>
  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
  >>> rt3.shape
  TensorShape([1, None, 3])
  >>> rt4.shape
  TensorShape([2, None, 3])
  >>> rt5.shape
  TensorShape([1, None, 3])
  >>>
  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
  >>> rt6.shape
  TensorShape([4, None, 1])
  >>> rt7.shape
  TensorShape([4, None, 2])

  Args:
    value: The `RaggedTensor` to split.
    num_or_size_splits: Either an `int` indicating the number of splits
      along `axis` or a 1-D integer `Tensor` or Python list containing the
      sizes of each output tensor along `axis`. If a Python int, then it must
      evenly divide `value.shape[axis]`; otherwise the sum of sizes along the
      split axis must match that of the `value`.
    axis: An `int` or scalar `int32` `Tensor`. The dimension along which
      to split. Must be in the range `[-rank(value), rank(value))`. Defaults
      to 0.
    num: An `int` used to specify the number of outputs when
      `num_or_size_splits` is a 1-D list or `Tensor` and its length is
      statically unknown, e.g., specifying `tf.TensorSpec(None)` with
      the `input_signature` argument of `tf.function` (optional).
    name: A name for the operation (optional).

  Returns:
    if `num_or_size_splits` is an `int` returns a list of `num_or_size_splits`
    `RaggedTensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
    `num_or_size_splits.get_shape[0]` `RaggedTensor` objects resulting from
    splitting `value`.

  Raises:
    ValueError: If the dimension `axis` of `value` is a ragged dimension.
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `num` is specified but doesn't match the length of
      `num_or_size_splits`.
    ValueError: If `num_or_size_splits` is an `int` and less than 1.
    TypeError: If `num_or_size_splits` is not an `int` or 1-D
      list or 1-D `Tensor`.
    InvalidArgumentError: If the `axis` dimension of `value` cannot be split
      exactly by `num_or_size_splits`.
    InvalidArgumentError: If `num_or_size_splits` contains negative integers.
    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown and
      its dynamic shape is inconsistent with `num`.
    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown and
      `axis` is a negative integer.
  """
  with ops.name_scope(name, 'RaggedSplit'):
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        value, name='value')
    if isinstance(num_or_size_splits, int) and num_or_size_splits == 1:
      return [value]

    # static assert
    check_ops.assert_integer_v2(
        num_or_size_splits,
        message=('`num_or_size_splits` must be an `int` or 1-D list or '
                 '`Tensor` of integers.'))
    value_shape = dynamic_ragged_shape.DynamicRaggedShape.from_tensor(value)
    axis = array_ops.get_positive_axis(axis, value_shape.rank)
    try:
      dim_size = value_shape[axis]
    except ValueError:
      raise ValueError('Cannot split a ragged dimension. Got `value` with '
                       f'shape {value_shape} and `axis` {axis}.')
    if isinstance(num_or_size_splits, int):
      # Uniform split
      num_splits = num_or_size_splits
      if num_splits < 1:
        raise ValueError('`num_or_size_splits` must be >=1 if it is an `int`. '
                         f'Received {num_or_size_splits}.')
      split_length = math_ops.floordiv(dim_size, num_splits)
      split_lengths = array_ops.repeat(split_length, num_splits)
    else:
      # Ragged split
      num_splits = None
      split_lengths = ops.convert_to_tensor(num_or_size_splits)
      if split_lengths.shape.ndims is not None:
        if split_lengths.shape.ndims != 1:
          raise TypeError('`num_or_size_splits` must be an `int` or 1-D list '
                          f'or `Tensor`. Received {num_or_size_splits}.')
        num_splits = tensor_shape.dimension_value(split_lengths.shape[0])

      if num_splits is None:
        if num is None:
          raise ValueError('`num` must be specified as an `int` when the '
                           'size of `num_or_size_split` is statically '
                           f'unknown. Received `num`: {num} and '
                           f'`num_or_size_split`: {num_or_size_splits}.')
        num_splits = num
      else:
        if num is not None and num != num_splits:
          raise ValueError('`num` does not match the size of '
                           f'`num_or_size_split`. Received `num`: {num} and '
                           f'size of `num_or_size_split`: {num_splits}.')

    splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0)
    checks = []
    checks.append(
        check_ops.assert_non_negative_v2(
            num_or_size_splits,
            message='`num_or_size_splits` must be non-negative.'))
    checks.append(
        check_ops.assert_equal_v2(
            num_splits,
            array_ops.shape(split_lengths)[0],
            message='`num` is inconsistent with `num_or_size_split.shape[0]`.'))
    checks.append(
        check_ops.assert_equal_v2(
            math_ops.cast(dim_size, splits.dtype),
            splits[-1],
            message=('Cannot exactly split the `axis` dimension of `value` '
                     'with the given `num_or_size_split`.')))
    splits = control_flow_ops.with_dependencies(checks, splits)
    split_rts = []
    slices = [slice(None)] * (axis + 1)
    for i in range(num_splits):
      slices[-1] = slice(splits[i], splits[i + 1])
      split_rts.append(value[tuple(slices)])
    return split_rts


# ===============================================================================
# RaggedTensor shape operations
# ===============================================================================


@dispatch.dispatch_for_api(array_ops.reshape)
def ragged_reshape(
    tensor: ragged_tensor.RaggedOrDense,
    shape: dynamic_ragged_shape.DenseOrRaggedShape
) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
  """Reshapes a tensor or ragged tensor."""
  tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      tensor, name='tensor')
  if isinstance(tensor, ragged_tensor.RaggedTensor):
    tensor = tensor.values

  if isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
    flat_values = array_ops.reshape(tensor, shape.inner_shape)
    return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
        flat_values,
        shape.row_partitions,
        validate=False)
  else:
    shape = ops.convert_to_tensor(shape, name='shape')
    return array_ops.reshape(tensor, shape)
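
# Usage sketch (illustrative, assumes eager execution): a DynamicRaggedShape
# built with tf.experimental.DynamicRaggedShape.from_lengths can be passed
# directly to tf.reshape.
#
#   >>> shape = tf.experimental.DynamicRaggedShape.from_lengths([2, [2, 1]])
#   >>> tf.reshape(tf.constant([1, 2, 3]), shape)
#   <tf.RaggedTensor [[1, 2], [3]]>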


@dispatch.dispatch_for_api(array_ops.broadcast_to)
def broadcast_to(
    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
    shape: dynamic_ragged_shape.DynamicRaggedShape
) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
  """Broadcasts a potentially ragged tensor to a ragged shape.

  Tiles `input` as necessary to match the given shape.

  Behavior is undefined if `input` is not broadcast-compatible with `shape`.

  Args:
    input: The potentially ragged tensor to broadcast.
    shape: A `DynamicRaggedShape`

  Returns:
    A potentially ragged tensor whose values are taken from
    `input`, and whose shape matches `shape`.
  """
  return dynamic_ragged_shape.broadcast_to(input, shape)


# Note: default value for out_type needs to be int32, to match the
# default for tf.shape's out_type parameter.
@dispatch.dispatch_for_api(array_ops.shape)
def ragged_shape(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    name: Optional[str] = None,
    out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the shape of a RaggedTensor.

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).
    out_type: dtype used to encode the shape.

  Returns:
    A `tf.experimental.DynamicRaggedShape`
  """
  with ops.name_scope(name, 'RaggedShape', [input]):
    return dynamic_ragged_shape.DynamicRaggedShape.from_tensor(input, out_type)


@dispatch.dispatch_for_api(array_ops.broadcast_dynamic_shape)
def broadcast_dynamic_shape(
    shape_x: dynamic_ragged_shape.DenseOrRaggedShape,
    shape_y: dynamic_ragged_shape.DenseOrRaggedShape
) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the shape formed by broadcasting two shapes to be compatible.

  1. If shape_x and shape_y both have row_partitions, then fail if their
     dtypes don't match.
  2. If neither has row_partitions and they have different dtypes,
     go with int64.
  3. If one has row_partitions, go with that dtype.

  Args:
    shape_x: A `DynamicRaggedShape`
    shape_y: A `DynamicRaggedShape`

  Returns:
    A `DynamicRaggedShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
  """
  if not isinstance(shape_x, dynamic_ragged_shape.DynamicRaggedShape):
    shape_x = dynamic_ragged_shape.DynamicRaggedShape([], shape_x)
  if not isinstance(shape_y, dynamic_ragged_shape.DynamicRaggedShape):
    shape_y = dynamic_ragged_shape.DynamicRaggedShape([], shape_y)
  return dynamic_ragged_shape.broadcast_dynamic_shape(shape_x, shape_y)


@dispatch.dispatch_for_api(array_ops.ones)
def ones(shape: dynamic_ragged_shape.DynamicRaggedShape,
         dtype=dtypes.float32,
         name=None) -> ragged_tensor.RaggedOrDense:
  """Returns a tensor of ones with the given potentially ragged shape."""
  flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name)
  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.zeros)
def zeros(shape: dynamic_ragged_shape.DynamicRaggedShape,
          dtype=dtypes.float32,
          name=None) -> ragged_tensor.RaggedOrDense:
  """Returns a tensor of zeros with the given potentially ragged shape."""
  flat_values = array_ops.zeros(shape.inner_shape, dtype=dtype, name=name)
  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.fill)
def fill(dims: dynamic_ragged_shape.DynamicRaggedShape,
         value: core_types.TensorLike,
         name: Optional[str] = None) -> ragged_tensor.RaggedOrDense:
  """Creates a tensor with shape `dims` and fills it with `value`."""
  flat_values = array_ops.fill(dims.inner_shape, value, name=name)
  return dims._add_row_partitions(flat_values)  # pylint: disable=protected-access
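
# Usage sketch for the factory overrides above (illustrative, assumes eager
# execution): each builds dense flat values for the inner shape and then
# re-attaches the ragged row partitions.
#
#   >>> shape = tf.experimental.DynamicRaggedShape.from_lengths([2, [2, 1]])
#   >>> tf.fill(shape, 7).to_list()
#   [[7, 7], [7]]
#   >>> tf.zeros(shape).to_list()
#   [[0.0, 0.0], [0.0]]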


# ===============================================================================
# bitcast
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.bitcast)
def bitcast(
    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
    type,  # pylint: disable=redefined-builtin
    name=None) -> ragged_tensor.RaggedOrDense:
  """RaggedTensor dispatch override for tf.bitcast."""
  type = dtypes.as_dtype(type)
  with ops.name_scope(name, 'Bitcast', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if (input.dtype.size < type.size and input.flat_values.shape.rank < 2):
      raise ValueError('`input.flat_values` is required to have rank >= 2 when '
                       'input.dtype.size < type.size. Actual rank: '
                       f'{input.flat_values.shape.rank}')
    return input.with_flat_values(array_ops.bitcast(input.flat_values, type))
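
# Usage sketch (illustrative): bitcasting between dtypes of equal size
# reinterprets the flat values and preserves the ragged structure.
#
#   >>> rt = tf.ragged.constant([[1], [2, 3]], dtype=tf.int32)
#   >>> tf.bitcast(rt, tf.float32).dtype
#   tf.float32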