# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util


class RaggedTensorDynamicShape:
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

    * "Partitioned dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `nested_row_splits`.  The outermost partitioned
      dimension must be uniform, and the innermost partitioned dimension must
      be ragged.

    * "Inner dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.

  Dimension sizes are recorded using `partitioned_dim_sizes` and
  `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       | 0      |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      | 1      | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` | 1      | `2, (2, 1)`            | 2
  `[[[1, 2], [3]], [[4, 5]]]`    | 2      | `2, (2, 1), (2, 1, 2)` |
  """
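
  # Example (illustrative sketch, mirroring the third row of the table above):
  # `[[[1, 2], [3, 4]], [[5, 6]]]` has a uniform outer dimension of size 2, a
  # ragged dimension with row lengths (2, 1), and one uniform inner dimension
  # of size 2, so its shape could be built directly as:
  #
  #   RaggedTensorDynamicShape(partitioned_dim_sizes=[2, [2, 1]],
  #                            inner_dim_sizes=[2])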

  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
               dim_size_dtype=None):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))

    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      partitioned_dim_sizes = tuple(
          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
          for (i, size) in enumerate(partitioned_dim_sizes))
      inner_dim_sizes = ops.convert_to_tensor(
          inner_dim_sizes, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      # Convert dimension size tensors to a single dtype.
      if dim_size_dtype is None:
        dim_size_dtypes = set(
            p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1)
        if not dim_size_dtypes:
          dim_size_dtype = dtypes.int64
        elif len(dim_size_dtypes) == 1:
          dim_size_dtype = dim_size_dtypes.pop()
        else:
          if not ragged_config.auto_cast_partition_dtype():
            raise ValueError('partitioned_dim_sizes must have matching dtypes')
          dim_size_dtype = dtypes.int64
      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                    for p in partitioned_dim_sizes)
      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes

  def __repr__(self):
    return ('RaggedTensorDynamicShape'
            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
            (self._partitioned_dim_sizes, self._inner_dim_sizes))

  @staticmethod
  def from_dim_sizes(dim_sizes):
    """Constructs a ragged shape from a list of dimension sizes.

    This list contains a single tensor for each dimension, where the tensor
    is a scalar if the dimension is uniform, or a vector if the dimension is
    ragged.

    Args:
      dim_sizes: List of int32 or int64 scalars or vectors.

    Returns:
      A RaggedTensorDynamicShape.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
                        [dim_sizes]):
      dim_sizes = tuple(
          ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
                                name='dim_sizes') for size in dim_sizes)
      # Split the dimensions into partitioned & inner dimensions.
      inner_split = 0
      for dim, dim_size in enumerate(dim_sizes):
        if dim_size.shape.ndims == 1:
          inner_split = dim + 1
        elif dim_size.shape.ndims != 0:
          raise ValueError('Each dim_size must be a scalar or a vector')
      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
                                      dim_sizes[inner_split:])
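
  # Example (illustrative sketch): `from_dim_sizes` splits the list at the last
  # vector entry -- everything up to and including it becomes a partitioned
  # dimension, and any trailing scalars become inner dimensions:
  #
  #   RaggedTensorDynamicShape.from_dim_sizes([2, [2, 1], 2])
  #   # partitioned_dim_sizes -> (2, [2, 1]); inner_dim_sizes -> [2]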

  @classmethod
  def from_tensor(cls, rt_input, dim_size_dtype=None):
    """Constructs a ragged shape for a potentially ragged tensor."""
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
      if not ragged_tensor.is_ragged(rt_input):
        return cls([], array_ops.shape(rt_input), dim_size_dtype=dim_size_dtype)
      else:
        partitioned_dim_sizes = (
            (rt_input.nrows(),) + rt_input.nested_row_lengths())
        return RaggedTensorDynamicShape(
            partitioned_dim_sizes,
            array_ops.shape(rt_input.flat_values)[1:],
            dim_size_dtype=dim_size_dtype)
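
  # Example (illustrative sketch; assumes the public `tf.ragged.constant` API):
  #
  #   rt = tf.ragged.constant([[1, 2], [], [3, 4, 5]])
  #   shape = RaggedTensorDynamicShape.from_tensor(rt)
  #   # shape.partitioned_dim_sizes -> (3, [2, 0, 3])  (nrows, then row lengths)
  #   # shape.inner_dim_sizes       -> []              (flat_values is a vector)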

  def dimension_size(self, axis):
    """Returns the size of slices across the specified dimension."""
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    partitioned_ndims = len(self._partitioned_dim_sizes)
    if axis < partitioned_ndims:
      return self._partitioned_dim_sizes[axis]
    else:
      return self._inner_dim_sizes[axis - partitioned_ndims]

  def is_ragged(self, axis):
    """Returns true if the indicated dimension is ragged."""
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    rank = self.rank
    if axis < 0:
      raise ValueError('Negative axis values are not supported')
    elif rank is not None and axis >= rank:
      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
    else:
      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
              self._partitioned_dim_sizes[axis].shape.ndims == 1)

  @property
  def rank(self):
    """The number of dimensions in this shape, or None if unknown."""
    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
    if inner_ndims is None:
      return None
    else:
      return len(self._partitioned_dim_sizes) + inner_ndims

  @property
  def partitioned_dim_sizes(self):
    """The partitioned dimension sizes for this shape.

    Returns:
      A `list` of 0-D or 1-D integer `Tensor`.
    """
    return self._partitioned_dim_sizes

  @property
  def inner_dim_sizes(self):
    """The inner dimension sizes for this shape.

    Returns:
      A 1-D integer `Tensor`.
    """
    return self._inner_dim_sizes

  @property
  def num_partitioned_dimensions(self):
    """The number of partitioned dimensions in this shape."""
    return len(self._partitioned_dim_sizes)

  @property
  def num_inner_dimensions(self):
    """The number of inner dimensions, or `None` if not statically known."""
    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])

  @property
  def dim_size_dtype(self):
    """DType used by this shape for dimension sizes."""
    return self._inner_dim_sizes.dtype

  def broadcast_to_rank(self, rank):
    """Adds leading size-1 dimensions to broadcast `self` to the given rank.

    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
    is `[1, 1, 3, (D2), 4]`.

    Args:
      rank: The rank for the returned shape.

    Returns:
      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
      have the same size as `self` and whose outer dimensions have size `1`.

    Raises:
      ValueError: If `self.rank` is unknown or greater than `rank`.
    """
    if self.rank is None:
      raise ValueError('Unable to broadcast: self.rank is unknown')
    dims_to_add = rank - self.rank
    if dims_to_add < 0:
      raise ValueError('Unable to broadcast: rank=%d must be greater than '
                       'self.rank=%d.' % (rank, self.rank))
    elif dims_to_add == 0:
      return self
    elif self._partitioned_dim_sizes:
      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
      return RaggedTensorDynamicShape(partitioned_dims, self.inner_dim_sizes,
                                      self.dim_size_dtype)
    else:
      inner_dims = array_ops.concat(
          [array_ops.ones([dims_to_add], self.dim_size_dtype),
           self.inner_dim_sizes],
          axis=0)
      return RaggedTensorDynamicShape([], inner_dims, self.dim_size_dtype)
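
  # Example (illustrative sketch): if `shape` is `[3, (D2)]`, built e.g. with
  # `from_dim_sizes([3, [2, 0, 3]])`, then `shape.broadcast_to_rank(4)` prepends
  # two size-1 uniform dimensions, giving `[1, 1, 3, (D2)]` with partitioned
  # dimension sizes (1, 1, 3, [2, 0, 3]).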

  def broadcast_dimension(self, axis, lengths):
    """Returns a shape that is broadcast-compatible with self & lengths.

    * If dimension[axis] is uniform and lengths is a scalar, then check
      that lengths==1 or dimension[axis]==1 or dimension[axis]==lengths,
      and tile dimension[axis] with tf.where(dimension[axis]==1, lengths, 1)
      repeats.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension[axis]==1, and raggedly tile dimension[axis] with
      lengths repeats.  (Tiling could be skipped if the slice lengths are
      statically known to be 1.)

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    Args:
      axis: `int`.  The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.
    """
    lengths = ragged_util.convert_to_int_tensor(
        lengths, name='lengths', dtype=self.dim_size_dtype)
    # Check whether lengths is a scalar (for uniform dimensions) or
    # vector (for ragged dimensions).
    if lengths.shape.ndims is None:
      raise ValueError('lengths must have a known rank.')
    elif lengths.shape.ndims > 1:
      raise ValueError('lengths must be a scalar or vector')
    else:
      lengths_is_scalar = (lengths.shape.ndims == 0)

    # Verify that the shapes are compatible.
    if self.is_ragged(axis):
      if lengths_is_scalar:
        condition = math_ops.equal(lengths, 1)
      else:
        condition = math_ops.reduce_all(
            math_ops.equal(lengths, self.dimension_size(axis)))
    else:
      axis_dim_size = self.dimension_size(axis)
      if lengths_is_scalar:
        condition = (
            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
            | math_ops.equal(axis_dim_size, lengths))
      else:
        condition = math_ops.equal(axis_dim_size, 1)
    broadcast_err = [
        'Unable to broadcast: dimension size mismatch in dimension', axis,
        'lengths=', lengths, 'dim_size=',
        self.dimension_size(axis)
    ]
    broadcast_check = control_flow_assert.Assert(
        condition, data=broadcast_err, summarize=10)

    with ops.control_dependencies([broadcast_check]):
      # Partitioned dimensions:
      if axis < self.num_partitioned_dimensions:
        if self.is_ragged(axis):
          # Use an identity op to make sure the check actually gets run.
          return RaggedTensorDynamicShape(
              self._partitioned_dim_sizes,
              array_ops.identity(self.inner_dim_sizes), self.dim_size_dtype)
        else:
          return self._broadcast_uniform_partitioned_dimension(axis, lengths)

      # Inner dimensions:
      else:
        if lengths_is_scalar:
          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
        else:
          if axis == 0:
            raise ValueError('Unable to broadcast: '
                             'outermost dimension must be uniform.')
          return self._broadcast_inner_dimension_to_ragged(axis, lengths)
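
  # Example (illustrative sketch): if `shape` is `[3, (2, 0, 3)]`, then
  # `shape.broadcast_dimension(0, 3)` passes the compatibility check (the
  # uniform outer dimension already has size 3) and returns an equivalent
  # shape, while `shape.broadcast_dimension(1, [2, 0, 3])` checks elementwise
  # that the ragged row lengths match.  An incompatible request such as
  # `shape.broadcast_dimension(0, 4)` fails the Assert check at run time.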

  def num_slices_in_dimension(self, axis):
    """Returns the total number of slices across the indicated dimension."""
    if axis < 0:
      return constant_op.constant(1, dtype=self.dim_size_dtype)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)
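
  # Example (illustrative sketch): for the shape `[3, (2, 0, 3), 4]`,
  # `num_slices_in_dimension(0)` is 3, `num_slices_in_dimension(1)` is
  # 2 + 0 + 3 = 5 (a ragged dimension sums its row lengths), and
  # `num_slices_in_dimension(2)` is 5 * 4 = 20.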

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    if lengths.shape.ndims == 0:
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops_stack.stack([0, self.num_slices_in_dimension(axis)])
    else:
      splits = math_ops.range(
          array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes,
                                    self.dim_size_dtype)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
    """Broadcasts the inner dimension `axis` to match `length`."""
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat([
        self._inner_dim_sizes[:axis_in_inner_dims],
        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
        self._inner_dim_sizes[axis_in_inner_dims + 1:]
    ],
                                   axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes,
                                    self.dim_size_dtype)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def with_dim_size_dtype(self, dtype):
    """Returns this shape with its dimension sizes cast to `dtype`."""
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('dtype must be int32 or int64')
    if self.dim_size_dtype == dtype:
      return self
    return RaggedTensorDynamicShape(
        [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
        math_ops.cast(self._inner_dim_sizes, dtype))


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
  """
  if not isinstance(shape_x, RaggedTensorDynamicShape):
    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
  if not isinstance(shape_y, RaggedTensorDynamicShape):
    raise TypeError('shape_y must be a RaggedTensorDynamicShape')

  # Broadcast both shapes to have the same rank.
  if shape_x.rank is None or shape_y.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  broadcast_rank = max(shape_x.rank, shape_y.rank)
  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
  shape_y = shape_y.broadcast_to_rank(broadcast_rank)

  # Broadcast dimensions one at a time, starting from the outermost dimension.
  for axis in range(broadcast_rank):
    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))

  return shape_x
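
# Example (illustrative sketch; shapes written as in the class docstring):
# broadcasting `[3, (D2)]` against `[3, 1]` keeps the ragged row lengths:
#
#   x = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]])
#   y = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
#   broadcast_dynamic_shape(x, y)  # -> the shape [3, (2, 0, 3)]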


def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
  """Broadcasts a potentially ragged tensor to a ragged shape.

  Tiles `rt_input` as necessary to match the given shape.

  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.

  Args:
    rt_input: The potentially ragged tensor to broadcast.
    shape: A `RaggedTensorDynamicShape`
    broadcast_inner_dimensions: If false, then inner dimensions will not be
      tiled.

  Returns:
    A potentially ragged tensor whose values are taken from
    `rt_input`, and whose shape matches `shape`.
  """
  if not isinstance(shape, RaggedTensorDynamicShape):
    raise TypeError('shape must be a RaggedTensorDynamicShape')
  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)

  # Broadcasting to a uniform shape.
  if shape.num_partitioned_dimensions == 0:
    return _broadcast_to_uniform_shape(rt_input, shape,
                                       broadcast_inner_dimensions)
  else:
    return _broadcast_to_ragged_shape(rt_input, shape,
                                      broadcast_inner_dimensions)
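
# Example (illustrative sketch; assumes the public `tf.constant` API):
# broadcasting a length-1 vector to a ragged shape tiles it raggedly:
#
#   dst = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]])
#   broadcast_to(tf.constant([10]), dst)
#   # -> <tf.RaggedTensor [[10, 10], [], [10, 10, 10]]>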


def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the uniform shape `shape`."""
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise ValueError('Incompatible with shape: ragged rank mismatch')
  if broadcast_inner_dimensions:
    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
  else:
    return rt_input


def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's.
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops_stack.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    new_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(
            rt_input.flat_values, out_type=dst_shape.dim_size_dtype),
        array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
    rt_input = rt_input.with_flat_values(
        array_ops.broadcast_to(rt_input.flat_values, new_shape))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input


def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tile a dimension of a RaggedTensor to match a ragged shape."""
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)

  if axis > 1:
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                          row_splits_dtype))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
    splits = src_row_splits[0]

    dst_row_lengths = [repeats]
    for i in range(1, len(src_row_lengths)):
      dst_row_lengths.append(
          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
      splits = array_ops.gather(src_row_splits[i], splits)
    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
                                           repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
        dst_values, dst_row_lengths, validate=False)
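
# Example (illustrative sketch): `_ragged_tile_axis` raggedly tiles dimension
# `axis`, repeating each slice along that dimension to the corresponding length
# in `repeats`.  E.g. tiling axis 1 of `[[1], [2], [3]]` with repeats
# `[2, 0, 3]` yields `[[1, 1], [], [3, 3, 3]]`.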