Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/dynamic_ragged_shape.py: 14%
1340 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Shapes & broadcasting for RaggedTensors.
17TODO(martinz): make this suitable for output for tf.shape
18TODO(martinz): replace ragged_tensor_shape with this.
19"""
21import abc
22from typing import Any, Iterable, Optional, Sequence, Tuple, Union
24import numpy as np
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import extension_type
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import array_ops_stack
34from tensorflow.python.ops import check_ops
35from tensorflow.python.ops import cond
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops.ragged import ragged_tensor
39from tensorflow.python.ops.ragged.row_partition import RowPartition
40from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec
41from tensorflow.python.types import core
42from tensorflow.python.util import dispatch
43from tensorflow.python.util.tf_export import tf_export
46class _DynamicRaggedShapeBatchEncoder(extension_type.ExtensionTypeBatchEncoder):
47 """A batch encoder for DynamicRaggedShape below."""
49 def batch(self, spec: "DynamicRaggedShape.Spec",
50 batch_size) -> "DynamicRaggedShape.Spec":
51 if spec.num_row_partitions:
52 new_head = _batch_rp_spec_head(spec._row_partitions[0], batch_size) # pylint:disable=protected-access
53 new_tail = [_batch_rp_spec(rp, batch_size) for rp in spec._row_partitions] # pylint:disable=protected-access
54 new_rp = [new_head] + new_tail
55 new_static_inner_shape = _batch_static_inner_shape(
56 spec._static_inner_shape, batch_size) # pylint:disable=protected-access
58 return DynamicRaggedShape.Spec(
59 row_partitions=new_rp,
60 static_inner_shape=new_static_inner_shape,
61 dtype=spec.dtype)
62 elif batch_size is None:
63 if spec.inner_rank == 0:
64 return DynamicRaggedShape.Spec._from_tensor_shape( # pylint:disable=protected-access
65 [None],
66 0,
67 dtype=spec.dtype)
68 else:
69 # Might be None
70 new_head = RowPartitionSpec(
71 uniform_row_length=spec._dimension(0), # pylint:disable=protected-access
72 dtype=spec.dtype)
73 new_static_inner_shape = _batch_static_inner_shape(
74 spec._static_inner_shape, batch_size) # pylint:disable=protected-access
75 return DynamicRaggedShape.Spec(
76 row_partitions=[new_head],
77 static_inner_shape=new_static_inner_shape,
78 dtype=spec.dtype)
79 else:
81 return DynamicRaggedShape.Spec(
82 row_partitions=[],
83 static_inner_shape=_batch_tensor_shape(
84 spec._static_inner_shape, # pylint:disable=protected-access
85 batch_size),
86 dtype=spec.dtype)
88 def unbatch(self,
89 spec: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
90 if spec.num_row_partitions:
91 result = []
92 head = spec._row_partitions[0] # pylint:disable=protected-access
93 scale = None if head.uniform_row_length is None else head.nrows
95 for rp in spec._row_partitions[1:]: # pylint:disable=protected-access
96 if scale is None:
97 result.append(
98 RowPartitionSpec(
99 nrows=None,
100 nvals=None,
101 uniform_row_length=rp.uniform_row_length,
102 dtype=spec.dtype))
103 else:
104 nrows = None if rp.nrows is None else rp.nrows // scale
105 if rp.uniform_row_length is None:
106 scale = None
107 result.append(
108 RowPartitionSpec(
109 nrows=nrows,
110 nvals=None,
111 uniform_row_length=None,
112 dtype=spec.dtype))
113 else:
114 result.append(
115 RowPartitionSpec(
116 nrows=nrows,
117 nvals=rp.nvals // scale,
118 uniform_row_length=rp.uniform_row_length,
119 dtype=spec.dtype))
120 return DynamicRaggedShape.Spec(
121 row_partitions=result,
122 static_inner_shape=_unbatch_static_inner_shape(
123 spec._static_inner_shape, scale), # pylint:disable=protected-access
124 dtype=spec.dtype)
125 else: # spec.num_row_partitions == 0
126 return DynamicRaggedShape.Spec(
127 row_partitions=[],
128 static_inner_shape=spec._static_inner_shape[1:], # pylint:disable=protected-access
129 dtype=spec.dtype)
131 def decode(self, spec: "DynamicRaggedShape.Spec",
132 encoding) -> "DynamicRaggedShape":
133 return DynamicRaggedShape.from_tensor(encoding, dtype=spec.dtype)
135 def encode(self,
136 spec: "DynamicRaggedShape.Spec",
137 value,
138 minimum_rank=0) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
139 return ones(value, dtype=dtypes.bool)
141 def encoding_specs(
142 self, spec: "DynamicRaggedShape.Spec"
143 ) -> Union[ragged_tensor.RaggedTensorSpec, tensor_spec.TensorSpec]:
144 if spec.rank != 0:
145 ragged_rank = spec.num_row_partitions
146 else:
147 # special case: need to unbatch twice to get ragged tensor.
148 ragged_rank = -1
149 return ragged_tensor.RaggedTensorSpec(
150 shape=spec._to_tensor_shape(), # pylint:disable=protected-access
151 dtype=dtypes.bool,
152 ragged_rank=ragged_rank,
153 row_splits_dtype=spec.dtype)
156# TODO(martinz): allow inner_shape to be a fully defined TensorShape.
157# A "fully defined TensorShape" means one where the rank and all dimensions are
158# known.
159# Allowing inner_shape might mean allowing inner_shape to be initialized by
160# a fully defined TensorShape, or it might mean that you can actually store
161# TensorShape in the inner_shape field. This could conceivably construct
162# a DynamicRaggedShape that was dtype agnostic.
163#
164# TODO(martinz): unify the impl of the determination of index type across
165# RowPartition and DynamicRaggedShape.
166@tf_export("experimental.DynamicRaggedShape")
167class DynamicRaggedShape(extension_type.BatchableExtensionType):
168 """The shape of a ragged or dense tensor.
170 Ragged shapes are encoded using two fields:
172 * `inner_shape`: An integer vector giving the shape of a dense tensor.
173 * `row_partitions`: A list of `RowPartition` objects, describing how
174 that flat shape should be partitioned to add ragged axes.
176 If a DynamicRaggedShape is the shape of a RaggedTensor rt, then:
177 1. row_partitions = rt._nested_row_partitions
178 (and thus len(row_partitions) > 0)
179 2. inner_shape is the shape of rt.flat_values
181 If a DynamicRaggedShape is the shape of a dense tensor t, then:
182 1. row_partitions = []
183 2. inner_shape is the shape of t.
185 Examples:
187 The following table gives a few examples (where `RP(lengths)` is short
188 for `RowPartition.from_lengths(lengths)`):
190 Row Partitions | Inner Shape | Example Tensor
191 --------------------------- | ------------ | ----------------------------
192 [] | [2, 3] | `[[1, 2, 3], [4, 5, 6]]`
193 [RP([2, 0, 3])] | [5] | `[[1, 2], [], [3, 4, 5]]`
194 [RP([2, 1])] | [3, 2] | `[[[1, 2], [3, 4]], [[5, 6]]]`
195 [RP([2, 1]), RP([2, 1, 2])] | [5] | `[[[1, 2], [3]], [[4, 5]]]`
196 """
197 _row_partitions: Tuple[RowPartition, ...]
198 _inner_shape: ops.Tensor
199 _static_inner_shape: tensor_shape.TensorShape
200 __batch_encoder__ = _DynamicRaggedShapeBatchEncoder()
201 __name__ = "tf.DynamicRaggedShape"
203 def __init__(self,
204 row_partitions: Sequence[RowPartition],
205 inner_shape: core.TensorLike,
206 dtype: Optional[dtypes.DType] = None,
207 validate: bool = False,
208 static_inner_shape: ... = None):
209 """Core constructor for a DynamicRaggedShape.
211 Create a DynamicRaggedShape. This can be used to construct a
212 DynamicRaggedShape representing a ragged or dense shape. If row_partitions
213 is an empty list, then this is equivalent to a dense shape.
215 If row_partitions is specified, then the num_row_partitions will be equal
216 to len(row_partitions). There are several checks made.
217 Specifically:
218 1. Consecutive row_partitions must have consistent nvals and nrows.
219 2. The last row_partitions must have nvals equal to the first element of
220 inner_shape.
222 The inner_shape is converted to a tensor.
223 All row_partitions and the inner_shape are converted to the same dtype
224 (int64 or int32).
226 Args:
227 row_partitions: the row_partitions of the shape.
228 inner_shape: if len(row_partitions) > 0, the shape of the flat_values.
229 Otherwise, the shape of the tensor.
230 dtype: tf.int64, tf.int32, or None representing the preferred dtype.
231 validate: if true, dynamic validation is applied to the shape.
232 static_inner_shape: if len(row_partitions) > 0, the static shape of the
233 flat_values. Otherwise, the static shape of the tensor. Should be
234 convertible to a TensorShape.
235 """
236 if not isinstance(row_partitions, Iterable):
237 raise TypeError(
238 "row_partitions should be a list of row partitions. Instead, got " +
239 str(row_partitions))
240 for x in row_partitions:
241 if not isinstance(x, RowPartition):
242 raise TypeError("row_partitions contains " + str(x) +
243 " which is not a RowPartition")
244 dtype = _find_dtype_iterable(row_partitions, dtype)
245 dtype = _find_dtype(inner_shape, dtype)
246 if (isinstance(inner_shape, np.ndarray) and
247 inner_shape.dtype == np.int32 and dtype is None):
248 dtype = dtypes.int32
249 dtype = _find_dtype(dtypes.int64, dtype)
251 row_partitions = tuple([rp.with_dtype(dtype) for rp in row_partitions])
252 self._row_partitions = row_partitions
253 self._inner_shape = ops.convert_to_tensor(
254 inner_shape, dtype_hint=dtype, name="inner_dim_sizes")
255 if self._inner_shape.dtype != dtype:
256 self._inner_shape = math_ops.cast(self._inner_shape, dtype)
258 checks = []
259 # Validate shapes.
260 if self._row_partitions:
261 for axis, rp in enumerate(self._row_partitions):
262 if axis > 0:
263 previous_row_partition = self._row_partitions[axis - 1]
264 msg = ("RowPartitions in DynamicRaggedShape do not align "
265 f"between {axis - 1} and {axis}")
266 static_nrows = rp.static_nrows
267 static_nvals = previous_row_partition.static_nvals
268 if (static_nrows is not None) and (static_nvals is not None):
269 if static_nrows != static_nvals:
270 raise ValueError(msg)
271 else:
272 continue
273 if validate:
274 checks.append(
275 check_ops.assert_equal(
276 previous_row_partition.nvals(), rp.nrows(), message=msg))
278 self._inner_shape.shape.assert_has_rank(1)
280 self._static_inner_shape = tensor_util.constant_value_as_shape(
281 self._inner_shape)
282 if static_inner_shape is not None:
283 self._static_inner_shape = self._static_inner_shape.merge_with(
284 static_inner_shape)
286 if row_partitions:
287 last_row_partition = row_partitions[-1]
288 static_nvals = last_row_partition.static_nvals
289 static_inner_shape_nvals = tensor_shape.dimension_value(
290 self._static_inner_shape[0])
291 if static_nvals is not None and static_inner_shape_nvals is not None:
292 if static_nvals != static_inner_shape_nvals:
293 raise ValueError("Last row partition does not match inner_shape.")
294 elif validate:
295 checks.append(
296 check_ops.assert_equal(
297 last_row_partition.nvals(),
298 self._inner_shape[0],
299 message="Last row partition does not match inner_shape."))
300 if checks:
301 self._inner_shape = control_flow_ops.with_dependencies(
302 checks, self._inner_shape, name="inner_shape_validated")
303 self._row_partitions = [
304 rp._with_dependencies(checks) for rp in self._row_partitions # pylint: disable=protected-access
305 ]
307 @classmethod
308 def from_lengths(cls,
309 lengths: Sequence[Union[Sequence[int], int]],
310 num_row_partitions=None,
311 dtype=dtypes.int64):
312 """Creates a shape with the given lengths and num_row_partitions.
314 The lengths can either be a nonnegative int or a list of nonnegative ints.
316 If num_row_partitions is None, then the minimal num_row_partitions is used.
318 For example, [2, (3, 2)] is the shape of [[0, 0, 0], [0, 0]], and
319 [2, 2] is the shape of [[0, 0], [0, 0]]
321 This chooses the minimal num_row_partitions required (including zero).
323 The following table gives a few examples (where `RP(lengths)` is short
324 for `RowPartition.from_lengths(lengths)`):
326 For example:
327 from_lengths | row_partitions | inner_shape
328 ---------------------- | --------------------------| -------------
329 [] | [] | []
330 [2, (3, 2)] | [RP([3, 2])] | [5]
331 [2, 2] | [] | [2, 2]
332 [2, (3, 2), 7] | [RP([3, 2])] | [5, 7]
333 [2, (2, 2), 3] | [RP([2, 2])] | [4, 3]
334 [2, 2, 3] | [] | [2, 2, 3]
335 [2, (2, 1), (2, 0, 3)] | [RP(2, 1), RP([2, 0, 3])] | [5]
337 If we want the row partitions to end with uniform row partitions, then
338 we can set num_row_partitions.
340 For example,
341 below URP(3, 12) is RowPartition.from_uniform_row_length(3, 12)
343 from_lengths | num_row_partitions | row_partitions | inner_shape
344 ---------------| -------------------|--------------------------|------------
345 [2, (3, 2), 2] | 2 | [RP([3, 2]), URP(2, 10)] | [10]
346 [2, 2] | 1 | [URP(2, 4)] | [4]
347 [2, 2, 3] | 0 | [] | [2, 2, 3]
348 [2, 2, 3] | 1 | [URP(2, 4)] | [4, 3]
349 [2, 2, 3] | 2 | [URP(2, 4), URP(3, 12)] | [12]
353 Representing the shapes from init():
355 from_lengths | Tensor Example
356 ------------------------ | ------------------------------
357 `[2, 3]` | `[[1, 2, 3], [4, 5, 6]]`
358 `[3, (2, 0, 3)]` | `[[1, 2], [], [3, 4, 5]]`
359 `[2, (2, 1), 2]` | `[[[1, 2], [3, 4]], [[5, 6]]]`
360 `[2, (2, 1), (2, 1, 2)]` | `[[[1, 2], [3]], [[4, 5]]]`
362 Args:
363 lengths: the lengths of sublists along each axis.
364 num_row_partitions: the num_row_partitions of the result or None
365 indicating the minimum number of row_partitions.
366 dtype: the dtype of the shape (tf.int32 or tf.int64).
368 Returns:
369 a new DynamicRaggedShape
370 """
371 if not isinstance(lengths, list):
372 raise ValueError("lengths should be a list")
373 for x in lengths:
374 if not _is_int_or_tuple_of_ints(x):
375 raise ValueError(
376 "element of lengths should be int or tuple of ints: instead %r" %
377 (x,))
379 if num_row_partitions is None:
380 # Calculate the minimal num_row_partitions.
381 is_list = [not isinstance(x, int) for x in lengths]
382 if any(is_list):
383 # Last index when not a list.
384 num_row_partitions = len(is_list) - is_list[-1::-1].index(True) - 1
385 else:
386 num_row_partitions = 0
388 if not isinstance(num_row_partitions, int):
389 raise ValueError("num_row_partitions should be an int or None")
391 if not lengths:
392 if num_row_partitions > 0:
393 raise ValueError("num_row_partitions==0 for a scalar shape")
394 return DynamicRaggedShape([], [], dtype=dtype)
396 if not num_row_partitions < len(lengths):
397 raise ValueError("num_row_partitions should be less than `len(lengths)` "
398 "if shape is not scalar.")
400 if num_row_partitions > 0:
401 (row_partitions, nvals) = _to_row_partitions_and_nvals_from_lengths(
402 lengths[:num_row_partitions + 1])
403 inner_shape = [nvals] + lengths[num_row_partitions + 1:]
404 return DynamicRaggedShape(row_partitions, inner_shape, dtype=dtype)
405 else:
406 return DynamicRaggedShape([], lengths, dtype=dtype)
408 @classmethod
409 def from_row_partitions(cls, row_partitions, dtype=None):
410 """Create a shape from row_partitions.
412 Args:
413 row_partitions: a nonempty list of RowPartition objects.
414 dtype: the dtype to use, or None to use the row_partitions dtype.
416 Returns:
417 a DynamicRaggedShape with inner_rank==1.
418 """
419 if not row_partitions:
420 raise ValueError("row_partitions cannot be empty")
421 inner_shape = [row_partitions[-1].nvals()]
422 return DynamicRaggedShape(row_partitions, inner_shape, dtype=dtype)
424 @classmethod
425 def _from_inner_shape(cls, inner_shape, dtype=None):
426 """Create a shape from inner_shape, where num_row_partitions == 0."""
427 return DynamicRaggedShape([], inner_shape, dtype=dtype)
429 # pylint: disable=protected-access
430 @classmethod
431 def from_tensor(cls, t, dtype=None):
432 """Constructs a ragged shape for a potentially ragged tensor."""
433 if ragged_tensor.is_ragged(t):
434 return DynamicRaggedShape(
435 t._nested_row_partitions, _flat_values_shape(t), dtype=dtype)
436 else:
437 return DynamicRaggedShape._from_inner_shape(
438 array_ops.shape(t), dtype=dtype)
440 @property
441 def row_partitions(self):
442 """The row_partitions of the shape."""
443 return self._row_partitions
445 @property
446 def num_row_partitions(self):
447 """The number of row_partitions of the shape."""
448 return len(self._row_partitions)
450 @property
451 def dtype(self):
452 """The dtype of the shape -- one of tf.int32 or tf.int64."""
453 return self._inner_shape.dtype
455 def _static_inner_shape_as_list(self, truncate_first):
456 """Returns the lengths of the inner shape (if rank known), or [...]."""
457 if self._static_inner_shape.rank is None:
458 return [...]
459 result = self._static_inner_shape.as_list()
460 if truncate_first:
461 return result[1:]
462 return result
464 def static_lengths(self, ragged_lengths=True):
465 """Returns a list of statically known axis lengths.
467 This represents what values are known. For each row partition, it presents
468 either the uniform row length (if statically known),
469 the list of row lengths, or none if it is not statically known.
470 For the inner shape, if the rank is known, then each dimension is reported
471 if known, and None otherwise. If the rank of the inner shape is not known,
472 then the returned list ends with an ellipsis.
474 Args:
475 ragged_lengths: If false, returns None for all ragged dimensions.
477 Returns:
478 A Sequence[Union[Sequence[int],int, None]] of lengths, with a possible
479 Ellipsis at the end.
480 """
481 if self.num_row_partitions == 0:
482 return self._static_inner_shape_as_list(False)
483 first_dim = self.row_partitions[0].static_nrows
484 if isinstance(first_dim, tensor_shape.Dimension):
485 first_dim = first_dim.value
486 rp_dims = [first_dim]
487 for rp in self.row_partitions:
488 if rp.is_uniform():
489 rp_dims.append(rp.static_uniform_row_length)
490 elif ragged_lengths:
491 const_vals = tensor_util.constant_value(rp.row_lengths())
492 if const_vals is None:
493 rp_dims.append(None)
494 else:
495 rp_dims.append(tuple(const_vals.tolist()))
496 else:
497 rp_dims.append(None)
499 return rp_dims + self._static_inner_shape_as_list(True)
501 def __repr__(self):
502 lengths = _list_with_ellipsis_to_str(self.static_lengths())
503 return ("<DynamicRaggedShape "
504 "lengths=%s num_row_partitions=%r>" %
505 (lengths, self.num_row_partitions))
507 def _to_tensor_shape(self) -> tensor_shape.TensorShape:
508 """Returns a TensorShape representation of the shape."""
509 lengths = self.static_lengths(ragged_lengths=False)
510 if not lengths:
511 return tensor_shape.TensorShape(())
512 if lengths[-1] == Ellipsis:
513 return tensor_shape.TensorShape(None)
514 return tensor_shape.TensorShape(lengths)
516 def _slice_shape(self, start, stop):
517 """Returns a shape self[start:stop].
519 If start == 0, then this truncates dimensions after stop.
520 If start != 0, then this will return a shape with num_row_partitions == 0.
522 See __getitem__.
524 Args:
525 start: the first dimension. 0 <= start <= rank
526 stop: the last dimension (exclusive). 0 <= stop <= rank
527 """
528 if stop <= start:
529 return DynamicRaggedShape._from_inner_shape([])
530 elif start == 0:
531 if stop <= self.num_row_partitions:
532 if stop == 1:
533 return DynamicRaggedShape._from_inner_shape(
534 [self.row_partitions[0].nrows()])
535 new_row_partitions = self.row_partitions[:stop - 1]
536 new_inner_shape = [new_row_partitions[-1].nvals()]
537 return DynamicRaggedShape(new_row_partitions, new_inner_shape)
538 else:
539 if self.rank is None:
540 new_inner_rank = stop - self.num_row_partitions
541 new_inner_shape = self.inner_shape[:new_inner_rank]
542 return DynamicRaggedShape(
543 row_partitions=self.row_partitions,
544 inner_shape=new_inner_shape,
545 static_inner_shape=None,
546 validate=False)
548 elif self.rank <= stop:
549 return self
550 new_inner_rank = stop - self.num_row_partitions
551 new_inner_shape = self.inner_shape[:new_inner_rank]
552 return DynamicRaggedShape(
553 row_partitions=self.row_partitions,
554 inner_shape=new_inner_shape,
555 static_inner_shape=tensor_shape.TensorShape([None] *
556 new_inner_rank),
557 validate=False)
558 else:
559 if self.rank is None or stop < self.rank:
560 partial = self._slice_shape(0, stop)
561 else:
562 partial = self
564 for x in partial.row_partitions:
565 if not x.is_uniform():
566 raise ValueError("All relevant dimensions must be uniform")
567 if partial.rank is None:
568 # TODO(martinz): Implement _with_num_row_partitions(0) if rank is
569 # unknown, and remove.
570 raise NotImplementedError(
571 "__getitem__[start:stop] where start > 0 not implemented")
573 return DynamicRaggedShape._from_inner_shape(
574 partial._with_num_row_partitions(0).inner_shape[start:])
576 def _dimension(self, index):
577 """Return a dimension, if the dimension is not ragged (see __getitem__)."""
578 rank = self.rank
579 if not isinstance(index, int):
580 raise TypeError("index should be an int")
581 if (self.num_row_partitions == 0 or index > self.num_row_partitions + 1):
582 # If num_row_partitions > 0 and index <= num_row_partitions + 1, then
583 # we are safe.
584 if rank is None:
585 raise ValueError(
586 "Rank must be known to use __getitem__ on a large index.")
587 if index >= rank:
588 raise IndexError("Index is too big: " + str(index) + ">=" + str(rank))
589 if index < 0:
590 raise IndexError("Index must be non-negative: " + str(index))
591 elif not self.is_uniform(index):
592 raise ValueError("Index " + str(index) + " is not uniform")
593 elif index == 0 and self.num_row_partitions > 0:
594 static_nrows = self.row_partitions[0].static_nrows
595 if static_nrows is not None:
596 return constant_op.constant(static_nrows, dtype=self.dtype)
597 return self.row_partitions[0].nrows()
598 elif self.num_row_partitions == 0:
599 static_result = tensor_shape.dimension_value(
600 self._static_inner_shape[index])
601 if static_result is not None:
602 return constant_op.constant(static_result, dtype=self.dtype)
603 return self.inner_shape[index]
604 elif index > self.num_row_partitions:
605 static_result = tensor_shape.dimension_value(
606 self._static_inner_shape[index - self.num_row_partitions])
607 if static_result is not None:
608 return constant_op.constant(static_result, dtype=self.dtype)
610 return self.inner_shape[index - self.num_row_partitions]
611 else:
612 return self.row_partitions[index - 1].uniform_row_length()
614 def __getitem__(self, index):
615 """Returns a dimension or a slice of the shape.
617 Ragged shapes can have ragged dimensions that depend upon other dimensions.
618 Therefore, if you ask for a dimension that is ragged, this function returns
619 a ValueError. For similar reasons, if a slice is selected that includes
620 a ragged dimension without including the zero dimension, then this fails.
622 Any slice that does not start at zero will return a shape
623 with num_row_partitions == 0.
625 Args:
626 index: the index: can be an int or a slice.
628 Raises:
629 IndexError: if the index is not in range.
630 ValueError: if the rank is unknown, or a ragged rank is requested
631 incorrectly.
632 """
633 rank = self.rank
634 if isinstance(index, slice):
636 if (index.step is not None) and (index.step != 1):
637 raise IndexError("Cannot stride through a shape")
638 start = index.start
639 stop = index.stop
640 if start is None:
641 start = 0
642 start = _fix_start_index(start, rank, self.num_row_partitions)
643 stop = _fix_stop_index(stop, rank)
644 return self._slice_shape(start, stop)
645 elif isinstance(index, int):
646 if index < 0:
647 if rank is None:
648 raise ValueError(
649 "Rank must be known to use __getitem__ with a negative index.")
650 return self._dimension(rank + index)
651 return self._dimension(index)
652 else:
653 raise TypeError("Argument is not an int or a slice")
655 def _num_elements(self):
656 """Number of elements in a shape.
658 Returns:
659 The number of elements in the shape.
661 """
662 return math_ops.reduce_prod(self.inner_shape)
664 def _num_slices_in_dimension(self, axis):
665 """The total size of a dimension (like nvals).
667 Effectively, this is self[:axis+1]._num_elements()
669 Example:
670 shape = DynamicRaggedShape._from_inner_shape([2, 3, 4])
671 shape._num_slices_in_dimension(0) = 2
672 shape._num_slices_in_dimension(1) = 6
673 shape._num_slices_in_dimension(2) = 24
674 shape._num_slices_in_dimension(-1) = 24
675 shape._num_slices_in_dimension(-2) = 6
676 shape._num_slices_in_dimension(-2) = 2
678 Args:
679 axis: the last axis to include in the number of elements. If negative,
680 then axis = axis + rank.
682 Returns:
683 The number of elements in the shape.
684 """
685 if not isinstance(axis, int):
686 raise TypeError("axis must be an integer")
687 if axis < 0:
688 rank = self.rank
689 if rank is None:
690 raise ValueError(
691 "You can't use negative values if the rank is undefined")
692 axis = axis + rank
693 if axis == 0:
694 return self._dimension(0)
695 if axis <= self.num_row_partitions:
696 return self.row_partitions[axis - 1].nvals()
697 # If self.num_row_partitions = 1, and
698 # self.inner_shape=[3,5,6], and axis=2, then you want:
699 # 15 = 3 * 5 = math_ops.reduce_prod(self.inner_shape[:2])
700 # 2 = axis - (self.num_row_partitions - 1)
701 # If num_row_partitions=0, and
702 # self.inner_shape=[3,5,6] and axis=2, then you want:
703 # 90 = 3 * 5 * 6 = math_ops.reduce_prod(self.inner_shape[:3])
704 # 3 = axis - (self.num_row_partitions - 1)
705 remainder = axis - (self.num_row_partitions - 1)
706 return _reduce_prod_patch(self.inner_shape[:remainder])
708 def is_uniform(self, axis):
709 """Returns true if the indicated dimension is uniform."""
710 if not isinstance(axis, int):
711 raise TypeError("axis must be an integer")
712 rank = self.rank
713 if axis < 0:
714 raise IndexError("Negative axis values are not supported")
715 elif rank is not None and axis >= rank:
716 raise IndexError("Expected axis=%s < rank=%s" % (axis, rank))
717 else:
718 return ((axis == 0 or axis > len(self._row_partitions)) # pylint:disable=superfluous-parens
719 or self._row_partitions[axis - 1].is_uniform())
721 @property
722 def rank(self):
723 """The number of dimensions in this shape, or None if unknown."""
724 inner_rank = self.inner_rank
725 if inner_rank is None:
726 return None
727 else:
728 return self.num_row_partitions + inner_rank
730 @property
731 def inner_shape(self):
732 """The inner dimension sizes for this shape.
734 Returns:
735 A 1-D integer `Tensor`.
736 """
737 return self._inner_shape
739 @property
740 def inner_rank(self):
741 """The rank of inner_shape."""
742 return tensor_shape.dimension_value(self._static_inner_shape.rank)
744 def _alt_inner_shape(self, new_inner_rank):
745 """Get an alternative inner shape with higher or lower rank.
747 For the rank of the inner shape to be be higher, the last few ragged
748 dimensions must have uniform_row_length.
750 Args:
751 new_inner_rank: the new rank of the inner_shape
753 Returns:
754 A new inner_shape of rank new_inner_rank.
755 """
756 if new_inner_rank == 0:
757 raise ValueError("new_inner_rank cannot be zero")
758 elif self.inner_rank == 0:
759 raise ValueError("old inner_rank cannot be zero")
760 elif new_inner_rank == self.inner_rank:
761 return self.inner_shape
762 elif new_inner_rank < self.inner_rank:
763 if self._static_inner_shape.is_fully_defined():
764 return _alt_inner_shape_from_tensor_shape(self._static_inner_shape,
765 self.dtype, new_inner_rank)
766 first_dimension = self._num_slices_in_dimension(-new_inner_rank)
767 if new_inner_rank == 1:
768 return array_ops.expand_dims(first_dimension, 0)
769 remaining_dimensions = self.inner_shape[1 - new_inner_rank:]
770 return array_ops.concat(
771 [array_ops.expand_dims(first_dimension, 0), remaining_dimensions],
772 axis=0)
773 else:
774 assert new_inner_rank > self.inner_rank
775 new_dimensions = new_inner_rank - self.inner_rank
776 if any(
777 [not x.is_uniform() for x in self.row_partitions[-new_dimensions:]]):
778 raise ValueError("Cannot get an inner shape over a ragged dimension")
779 first_dimension = self._num_slices_in_dimension(-new_inner_rank)
780 new_dimensions = new_inner_rank - self.inner_rank
781 new_dims = [first_dimension] + [
782 x.uniform_row_length() for x in self.row_partitions[-new_dimensions:]
783 ]
784 return array_ops.concat(
785 [array_ops_stack.stack(new_dims), self.inner_shape[1:]], axis=0)
787 def _inner_shape_dim(self, dimension):
788 """Returns an int or a tensor representing _inner_shape[dimension]."""
789 result = tensor_shape.dimension_value(self._static_inner_shape[dimension])
790 return self._inner_shape[dimension] if result is None else result
792 def _with_inner_rank(self, inner_rank):
793 """Returns the same shape but a different inner_rank.
795 All dimensions that are to be represented in the inner_shape must be dense.
796 See inner_rank.
798 Args:
799 inner_rank: the new inner_rank of the shape.
801 Returns:
802 the same shape but a different inner_rank
804 Raises:
805 ValueError if the new dense rank is invalid, or the old rank is unknown.
806 """
807 rank = self.rank
808 if rank is None:
809 raise ValueError("Rank must be known to adjust inner_rank")
810 elif rank < 2:
811 if inner_rank == rank:
812 return self
813 raise ValueError("Cannot change inner_rank if rank < 2")
814 else:
815 # When self.rank is not None:
816 # self.rank = self.inner_rank + self.num_row_partitions
817 new_num_row_partitions = rank - inner_rank
818 return self._with_num_row_partitions(new_num_row_partitions)
820 def _with_num_row_partitions(self, num_row_partitions):
821 """Creates an identical shape with the given num_row_partitions.
823 Note that the shape must be statically refactorable to this rank.
824 In particular:
825 * rank must be known.
826 * num_row_partitions must be a nonnegative int.
827 * num_row_partitions must be less than the rank of the shape
828 * num_row_partitions must be greater or equal to the index of any ragged
829 dimension.
831 Note that if the num_row_partitions is the same, self is returned.
833 Args:
834 num_row_partitions: the target num_row_partitions (must be a nonnegative
835 int).
837 Returns:
838 a shape with a (possibly) different num_row_partitions.
840 Raises:
841 ValueError: if the rank is unknown, the argument is not a nonnegative int,
842 or there is a dimension that is nonuniform.
843 """
844 rank = self.rank
845 if rank is None:
846 raise ValueError("Rank must be known to adjust num_row_partitions")
847 if not isinstance(num_row_partitions, int):
848 raise ValueError("num_row_partitions must be an int")
849 if num_row_partitions < 0:
850 raise ValueError("num_row_partitions must be nonnegative")
851 if num_row_partitions == self.num_row_partitions:
852 return self
853 if num_row_partitions >= rank:
854 raise ValueError("num_row_partitions must be less than rank")
855 if num_row_partitions > self.num_row_partitions:
856 num_row_partitions_diff = num_row_partitions - self.num_row_partitions
857 new_inner_rank = self.rank - num_row_partitions
858 nvals = self._inner_shape_dim(0)
859 more_rp = []
860 for i in range(num_row_partitions_diff):
861 nrows = nvals
862 row_length = self._inner_shape_dim(i + 1)
863 nvals = nrows * row_length
864 rp = RowPartition.from_uniform_row_length(
865 row_length, nrows=nrows, dtype=self.dtype)
866 more_rp.append(rp)
867 alt_inner = self._alt_inner_shape(new_inner_rank)
868 return DynamicRaggedShape(list(self.row_partitions) + more_rp, alt_inner)
869 else:
870 assert num_row_partitions < self.num_row_partitions
871 return DynamicRaggedShape(
872 self.row_partitions[:num_row_partitions],
873 self._alt_inner_shape(self.rank - num_row_partitions))
875 def _merge_dims(self, outer_axis: int,
876 inner_axis: int) -> "DynamicRaggedShape":
877 """Merges outer_axis...inner_axis into a single dimension.
879 Returns a copy of this shape with the specified range of dimensions
880 flattened into a single dimension, with elements in row-major order.
882 #### Examples:
884 >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1),
885 ... (1,2,3)])._merge_dims(0, 1)
886 <DynamicRaggedShape lengths=[3, (1, 2, 3)] num_row_partitions=1>
887 >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1),
888 ... (1,2,3)])._merge_dims(1, 2)
889 <DynamicRaggedShape lengths=[2, (3, 3)] num_row_partitions=1>
890 >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1),
891 ... (1,2,3)])._merge_dims(0, 2)
892 <DynamicRaggedShape lengths=[6] num_row_partitions=0>
894 To mimic the behavior of `np.flatten` (which flattens all dimensions), use
895 `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which
896 flattens all dimensions except the outermost batch dimension), use
897 `rt.merge_dims(1, -1)`.
899 Args:
900 outer_axis: `int`: The first dimension in the range of dimensions to
901 merge. May be negative if `self.shape.rank` is statically known.
902 inner_axis: `int`: The last dimension in the range of dimensions to merge.
903 May be negative if `self.shape.rank` is statically known.
905 Returns:
906 A copy of this shape, with the specified dimensions merged into a
907 single dimension. The returned shape will be
908 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
909 is the total number of slices in the merged dimensions.
910 """
911 outer_axis = array_ops.get_positive_axis(
912 outer_axis, self.rank, axis_name="outer_axis", ndims_name="rank(self)")
913 inner_axis = array_ops.get_positive_axis(
914 inner_axis, self.rank, axis_name="inner_axis", ndims_name="rank(self)")
915 if not outer_axis <= inner_axis:
916 raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or "
917 f"equal to inner_axis ({inner_axis}).")
918 if outer_axis == inner_axis:
919 return self
920 if self.num_row_partitions == 0:
921 # A dense tensor.
922 (new_inner_shape,
923 new_static_inner_shape) = _merge_inner_shape(self._inner_shape,
924 self._static_inner_shape,
925 outer_axis, inner_axis)
926 return DynamicRaggedShape([],
927 new_inner_shape,
928 dtype=self.dtype,
929 static_inner_shape=new_static_inner_shape)
930 if inner_axis <= self.num_row_partitions:
931 # Here, we are merging the row_partitions,
932 # but the inner_shape is unchanged.
933 if outer_axis == 0:
934 # There is no need to merge axes before the first, just truncate them.
935 return DynamicRaggedShape(
936 self._row_partitions[inner_axis:],
937 self.inner_shape,
938 dtype=self.dtype,
939 static_inner_shape=self._static_inner_shape)
940 prefix_rp = self._row_partitions[:outer_axis - 1]
941 suffix_rp = self._row_partitions[inner_axis:]
942 internal_rp = self._row_partitions[outer_axis - 1:inner_axis]
943 new_rp = prefix_rp + (_merge_row_partitions(internal_rp),) + suffix_rp
945 return DynamicRaggedShape(
946 new_rp,
947 self.inner_shape,
948 dtype=self.dtype,
949 static_inner_shape=self._static_inner_shape)
950 elif outer_axis > self.num_row_partitions:
951 # In this scenario, only the inner_shape is changed.
952 # Example #1:
953 # if [2, (1, 2), 5, 3], num_row_partitions=1, outer_axis=2, inner_axis=3.
954 # Result: [2, (1, 2), 15], num_row_partitions=1, outer_axis=2,
955 # inner_axis=3.
956 (new_inner_shape, new_static_inner_shape) = _merge_inner_shape(
957 self._inner_shape, self._static_inner_shape,
958 outer_axis - self.num_row_partitions,
959 inner_axis - self.num_row_partitions)
960 return DynamicRaggedShape(
961 self._row_partitions,
962 new_inner_shape,
963 dtype=self.dtype,
964 static_inner_shape=new_static_inner_shape)
965 else:
966 # Here, both inner_shape and row_partitions are changed.
967 rank = self.rank
968 if rank is None:
969 raise ValueError("Cannot merge_dims of the inner shape if the " +
970 "dimension of inner_shape is unknown")
971 if outer_axis == 0:
972 new_inner_shape = self._alt_inner_shape(rank - inner_axis)
973 return DynamicRaggedShape._from_inner_shape(new_inner_shape)
974 else:
975 prefix = self._row_partitions[:outer_axis - 1]
976 suffix = _merge_row_partitions(self._row_partitions[outer_axis - 1:])
977 new_inner_shape = self._alt_inner_shape(rank - inner_axis)
978 num_merged_inner = inner_axis - self.num_row_partitions
979 prod = _reduce_prod_patch(self._inner_shape[1:num_merged_inner + 1])
980 tail_suffix = RowPartition.from_row_splits(suffix.row_splits() * prod)
981 return DynamicRaggedShape(prefix + (tail_suffix,), new_inner_shape)
983 def with_dtype(self, dtype):
984 """Change the dtype of the shape."""
985 if dtype == self.dtype:
986 return self
987 else:
988 return DynamicRaggedShape(
989 self.row_partitions, self.inner_shape, dtype=dtype)
991 def _merge_with(self, other: "DynamicRaggedShape") -> "DynamicRaggedShape":
992 """Merge two shapes that are equal modulo num_row_partitions.
994 The resulting num_row_partitions is the maximum of the two
995 num_row_partitions.
997 Args:
998 other: a DynamicRaggedShape representing the same shape with a possibly
999 different number of row partitions.
1001 Returns:
1002 A DynamicRaggedShape with the same shape and the maximum of the
1003 num_row_partitions of the two shapes.
1004 """
1005 max_num_row_partitions = max(self.num_row_partitions,
1006 other.num_row_partitions)
1007 a = self._with_num_row_partitions(max_num_row_partitions)
1008 b = other._with_num_row_partitions(max_num_row_partitions)
1009 new_row_partitions = [
1010 rp_a._merge_precomputed_encodings(rp_b)
1011 for (rp_a, rp_b) in zip(a._row_partitions, b._row_partitions)
1012 ]
1013 new_dtype = b.dtype if a.dtype == dtypes.int32 else dtypes.int64
1015 new_static_inner_shape = a._static_inner_shape.merge_with(
1016 b._static_inner_shape)
1017 new_inner_shape = a._inner_shape
1018 return DynamicRaggedShape(new_row_partitions, new_inner_shape, new_dtype,
1019 True, new_static_inner_shape)
1021 def _merge_with_spec(
1022 self, other: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape":
1023 """Merge a spec with a DynamicRaggedShape."""
1024 # TODO(martinz): add tests for dynamic inconsistencies.
1025 max_num_row_partitions = max(self.num_row_partitions,
1026 other.num_row_partitions)
1027 a = self._with_num_row_partitions(max_num_row_partitions)
1028 b = other._with_num_row_partitions(max_num_row_partitions)
1029 new_row_partitions = [
1030 rp_a._merge_with_spec(rp_b)
1031 for (rp_a, rp_b) in zip(a._row_partitions, b._row_partitions)
1032 ]
1033 new_dtype = b.dtype if a.dtype == dtypes.int32 else dtypes.int64
1035 new_static_inner_shape = a._static_inner_shape.merge_with(
1036 b._static_inner_shape)
1037 new_inner_shape = a._inner_shape
1038 return DynamicRaggedShape(new_row_partitions, new_inner_shape, new_dtype,
1039 True, new_static_inner_shape)
1041 def _as_row_partitions(self):
1042 """Returns row partitions representing this shape.
1044 In order to represent a shape as row partitions, the rank of the shape
1045 must be known, and the shape must have rank at least one.
1047 Returns:
1048 A list of RowPartition objects.
1049 Raises:
1050 ValueError, if the shape cannot be represented by RowPartitions.
1051 """
1052 rank = self.rank
1053 if rank is None:
1054 raise ValueError("rank must be known for _as_row_partitions")
1055 elif rank < 1:
1056 raise ValueError("rank must be >= 1 for _as_row_partitions")
1057 fully_ragged = self._with_num_row_partitions(rank - 1)
1058 return fully_ragged.row_partitions
1060 def _validate_flat_values_dynamically(self, flat_values):
1061 """Test if flat_values have the right nvals dynamically."""
1062 if self.row_partitions:
1063 assert_op = check_ops.assert_equal(
1064 self.row_partitions[-1].nvals(),
1065 array_ops.shape(flat_values, out_type=self.dtype)[0],
1066 message="Last row partition does not match flat_values.")
1067 return control_flow_ops.with_dependencies([assert_op], flat_values)
1068 return flat_values
1070 def _validate_flat_values(self, flat_values):
1071 """Test if flat_values have the right nvals."""
1072 if not isinstance(flat_values, ops.Tensor):
1073 return flat_values
1074 if self.row_partitions:
1075 last_row_partition = self.row_partitions[-1]
1076 flat_values_shape = flat_values.shape
1077 if flat_values_shape is None:
1078 return self._validate_flat_values_dynamically(flat_values)
1079 first_dim_flat_values = flat_values_shape[0]
1080 if isinstance(first_dim_flat_values, tensor_shape.Dimension):
1081 first_dim_flat_values = first_dim_flat_values.value
1082 if first_dim_flat_values is None:
1083 return self._validate_flat_values_dynamically(flat_values)
1084 static_nvals = last_row_partition.static_nvals
1085 if static_nvals is None:
1086 return self._validate_flat_values_dynamically(flat_values)
1087 if first_dim_flat_values != static_nvals:
1088 raise ValueError("Last row partition does not match flat_values.")
1089 return flat_values
1091 def _add_row_partitions(self, flat_values, validate=False):
1092 """Add row partitions to flat_values, if necessary.
1094 If the shape is truly ragged, then this adds the row_partitions.
1096 The shape is dense, then this just returns flat_values.
1098 Args:
1099 flat_values: the flat_values of a ragged tensor with this shape, or a
1100 dense tensor with this shape.
1101 validate: validate the flat_values have the right first dimension.
1103 Returns:
1104 flat_values reshaped to have row_partitions.
1105 """
1106 if self.row_partitions:
1107 if validate:
1108 flat_values = self._validate_flat_values(flat_values)
1109 return ragged_tensor.RaggedTensor._from_nested_row_partitions(
1110 flat_values, self.row_partitions, validate=False)
1111 else:
1112 return flat_values
1114 class Spec:
1115 """A Spec for DynamicRaggedShape: similar to a static shape."""
1117 def __init__(self, row_partitions: Tuple[RowPartitionSpec, ...],
1118 static_inner_shape: tensor_shape.TensorShape,
1119 dtype: dtypes.DType):
1120 """Create a Spec given row partitions, a static inner shape, and a dtype.
1122 Args:
1123 row_partitions: A sequence of `RowPartitionSpec`s describing how the
1124 ragged shape is partitioned.
1125 static_inner_shape: The static shape of the flat_values.
1126 dtype: The DType used to encode the shape (tf.int64 or tf.int32).
1127 """
1128 # Independent validation and coercion of each argument.
1129 if not isinstance(row_partitions, Iterable):
1130 raise TypeError("row_partitions should be an Iterable")
1132 row_partitions = tuple(row_partitions)
1134 static_inner_shape = tensor_shape.as_shape(static_inner_shape)
1136 dtype = dtypes.as_dtype(dtype)
1138 if not all(isinstance(rp, RowPartitionSpec) for rp in row_partitions):
1139 raise TypeError(
1140 "row_partitions should be an Iterable of RowPartitionSpecs")
1142 if dtype != dtypes.int32 and dtype != dtypes.int64:
1143 raise ValueError("dtype must be tf.int32 or tf.int64")
1145 # All fields are now typechecked and internally consistent.
1146 for spec in row_partitions:
1147 if spec.dtype != dtype:
1148 raise ValueError(
1149 f"dtype of {spec!r} is {spec.dtype!r}: expected {dtype!r}")
1151 row_partitions = tuple(row_partitions)
1153 inner_rank = static_inner_shape.rank
1155 if inner_rank == 0:
1156 if row_partitions:
1157 raise ValueError(
1158 "If row_partitions are provided, must have inner_rank > 0")
1159 else:
1160 num_slices_in_dimension = [] # type: Sequence[tensor_shape.Dimension]
1162 # We first attempt to calculate num_slices_in_dimension through a
1163 # forward pass, using nrows[k] = nrows[k-1] * uniform_row_length
1164 # and other tricks.
1165 for i in range(len(row_partitions)):
1166 rp = row_partitions[i]
1167 result = tensor_shape.Dimension(rp.nrows)
1168 if i > 0:
1169 previous_rp = row_partitions[i - 1]
1170 result = result.merge_with(previous_rp.nvals)
1171 result = result.merge_with(num_slices_in_dimension[-1] *
1172 previous_rp.uniform_row_length)
1173 num_slices_in_dimension.append(result)
1174 # In the last step of the forward pass,
1175 # we combine nvals and the first dimension in static_inner_shape.
1176 if row_partitions:
1177 last_rp = row_partitions[-1]
1178 result = (num_slices_in_dimension[-1] *
1179 last_rp.uniform_row_length).merge_with(last_rp.nvals)
1180 if inner_rank is not None:
1181 result = result.merge_with(
1182 tensor_shape.dimension_at_index(static_inner_shape, 0))
1183 static_inner_shape = result + static_inner_shape[1:]
1184 num_slices_in_dimension.append(result)
1186 # Now, we start a backward pass.
1187 for i in range(len(num_slices_in_dimension) - 1, 0, -1):
1188 num_slices_in_dimension[i - 1] = num_slices_in_dimension[
1189 i - 1].merge_with(
1190 _safe_floor_div(num_slices_in_dimension[i],
1191 row_partitions[i - 1].uniform_row_length))
1193 # Finally, we construct the partitions.
1194 row_partitions = [
1195 RowPartitionSpec( # pylint: disable=g-complex-comprehension
1196 nrows=num_slices_in_dimension[i].value,
1197 uniform_row_length=rp.uniform_row_length,
1198 nvals=num_slices_in_dimension[i + 1].value,
1199 dtype=rp.dtype) for i, rp in enumerate(row_partitions)
1200 ]
1202 self._static_inner_shape = static_inner_shape
1203 self._inner_shape = tensor_spec.TensorSpec([inner_rank], dtype=dtype)
1204 self._row_partitions = row_partitions
1206 def __repr__(self):
1207 return (
1208 f"DynamicRaggedShape.Spec(row_partitions={self._row_partitions!r}, " +
1209 f"static_inner_shape={self._static_inner_shape!r}, " +
1210 f"dtype={self.dtype!r})")
1212 @classmethod
1213 def from_value(cls, value: Any) -> "DynamicRaggedShape.Spec":
1214 """Create a Spec from a DynamicRaggedShape."""
1215 # super().from_value(...) creates an object, but there is no validation.
1216 # No methods can be trusted on the object, just the properties.
1217 initial = super(DynamicRaggedShape.Spec, cls).from_value(value)
1219 # However, since value is a DynamicRaggedShape, we
1220 # can guarantee that initial._inner_shape.shape.rank == 1
1222 # Moreover, if inner_shape.shape[0] is not None, then
1223 # static_inner_shape.rank is not None.
1225 return DynamicRaggedShape.Spec(
1226 row_partitions=initial._row_partitions,
1227 static_inner_shape=initial._static_inner_shape,
1228 dtype=initial._inner_shape.dtype)
1230 # TODO(martinz): it is unclear what the default uniformity of RowPartitions
1231 # should be, so I am moving this to experimental until we figure it out.
1232 # Also, while I have specified this is meant to represent a shape of a
1233 # proper Tensor instead of a RaggedTensor, this is also subject to
1234 # interpretation.
1235 @classmethod
1236 def _from_tensor_shape(cls, shape: Any, num_row_partitions: int,
1237 dtype: dtypes.DType) -> "DynamicRaggedShape.Spec":
1238 """Creates a `DynamicRaggedShape.Spec` corresponding to a `tf.TensorShape`.
1240 It is assumed that this is a `tf.TensorShape` coming from a
1241 `tf.TensorSpec`, not from `RaggedTensor.shape`.
1243 In addition to the shape, we need to know the number of row partitions,
1244 and the dtype used in the shape (tf.int32 or tf.int64).
1246 Within the dimensions that are partitioned, all dimensions are assumed
1247 to be uniform.
1249 Args:
1250 shape: a TensorShape.
1251 num_row_partitions: the ragged rank of the RaggedShape.
1252 dtype: the dtype of the shape (not the tensor); tf.int64 or tf.int32.
1254 Returns:
1255 a DynamicRaggedShape.Spec representing a TensorShape.
1256 """
1257 if dtype != dtypes.int32 and dtype != dtypes.int64:
1258 raise ValueError("dtype must be tf.int32 or tf.int64")
1260 shape = tensor_shape.as_shape(shape)
1261 if shape.rank is None:
1262 row_partitions = [
1263 RowPartitionSpec(dtype=dtype) for _ in range(num_row_partitions)
1264 ]
1265 return DynamicRaggedShape.Spec(
1266 row_partitions=row_partitions,
1267 static_inner_shape=tensor_shape.TensorShape(None),
1268 dtype=dtype)
1270 if shape.rank <= 1:
1271 # Create a scalar or vector shape.
1272 if num_row_partitions:
1273 raise ValueError("num_row_partitions should be zero " +
1274 "if shape is a scalar or vector.")
1275 return DynamicRaggedShape.Spec(
1276 row_partitions=[], static_inner_shape=shape, dtype=dtype)
1278 if shape.rank <= num_row_partitions:
1279 raise ValueError("num_row_partitions must be less than rank")
1281 num_elements_so_far = tensor_shape.dimension_value(shape[0])
1282 rp_specs = []
1283 for i in range(num_row_partitions):
1284 current_dim = tensor_shape.dimension_value(shape[i + 1])
1285 if current_dim is None or num_elements_so_far is None:
1286 nvals = None
1287 else:
1288 nvals = num_elements_so_far * current_dim
1289 rp_specs.append(
1290 RowPartitionSpec(
1291 nrows=num_elements_so_far,
1292 nvals=nvals,
1293 uniform_row_length=current_dim,
1294 dtype=dtype))
1295 num_elements_so_far = nvals
1297 static_inner_shape = tensor_shape.TensorShape(
1298 [num_elements_so_far]) + shape[num_row_partitions + 1:]
1299 return DynamicRaggedShape.Spec(
1300 row_partitions=rp_specs,
1301 static_inner_shape=static_inner_shape,
1302 dtype=dtype)
1304 @classmethod
1305 def _from_spec(
1306 cls,
1307 spec: Union["DynamicRaggedShape.Spec", ragged_tensor.RaggedTensorSpec,
1308 tensor_spec.TensorSpec],
1309 dtype: dtypes.DType = dtypes.int64) -> "DynamicRaggedShape.Spec":
1310 """Create a TypeSpec for the shape of an object with a given TypeSpec.
1312 I.e., if `x_spec = tf.type_spec_from_value(x)`, then
1313 `DynamicRaggedShape.from_spec(x_spec)` returns a TypeSpec compatible with
1314 `tf.type_spec_from_value(tf.shape(x))`.
1316 >>> rt = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
1317 >>> rt_spec = tf.type_spec_from_value(rt)
1318 >>> rt_shape = DynamicRaggedShape.from_tensor(rt)
1320 >>> shape_spec_1 = tf.type_spec_from_value(rt_shape)
1321 >>> shape_spec_2 = DynamicRaggedShape.Spec._from_spec(rt_spec)
1322 >>> assert shape_spec_1.is_compatible_with(shape_spec_2)
1324 Args:
1325 spec: a Spec of a Tensor or RaggedTensor.
1326 dtype: the default dtype (if necessary).
1328 Returns:
1329 A Spec of the shape of a Tensor or RaggedTensor.
1331 """
1332 # TODO(martinz): Add StructuredTensor.Spec when its easy.
1333 if isinstance(spec, DynamicRaggedShape.Spec):
1334 return spec
1335 elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
1336 return cls._from_tensor_shape(spec.shape, spec.ragged_rank,
1337 spec.row_splits_dtype)
1338 elif isinstance(spec, tensor_spec.TensorSpec):
1339 return cls._from_tensor_shape(
1340 shape=spec.shape, num_row_partitions=0, dtype=dtype)
1342 @property
1343 def dtype(self) -> dtypes.DType:
1344 return self._inner_shape.dtype
1346 @property
1347 def inner_rank(self) -> Optional[int]:
1348 if self._static_inner_shape.rank is not None:
1349 return self._static_inner_shape.rank
1350 if self._inner_shape.shape.rank is None:
1351 return None
1352 return tensor_shape.dimension_value(self._inner_shape.shape[0])
1354 @property
1355 def num_row_partitions(self) -> int:
1356 return len(self._row_partitions)
1358 @property
1359 def rank(self) -> Optional[int]:
1360 inner_rank = self.inner_rank
1361 return None if inner_rank is None else inner_rank + self.num_row_partitions
1363 def _dimension(self, index: int) -> Optional[int]:
1364 """Get the size of dimension index, if known statically."""
1365 if index == 0:
1366 if self._row_partitions:
1367 return self._row_partitions[0].nrows
1368 elif self.inner_rank is None:
1369 return None
1370 elif self.inner_rank == 0:
1371 raise ValueError("Index out of range: 0.")
1372 else:
1373 return tensor_shape.dimension_value(self._static_inner_shape[0])
1374 if index <= len(self._row_partitions):
1375 return self._row_partitions[index - 1].uniform_row_length
1377 relative_index = index - self.num_row_partitions
1379 if self.inner_rank is None:
1380 return None
1381 elif self.inner_rank <= relative_index:
1382 raise ValueError(f"Index out of range: {index}.")
1383 else:
1384 return tensor_shape.dimension_value(
1385 self._static_inner_shape[relative_index])
1387 def _num_slices_in_dimension(self, axis: int) -> Optional[int]:
1388 """The total size of a dimension (like nvals).
1390 This is a static version of DynamicRaggedShape._num_slices_in_dimension()
1392 Example:
1394 ```
1395 shape = DynamicRaggedShape.Spec(
1396 _row_partitions=[
1397 RowPartitionSpec(nrows=3, nvals=14, dtype=tf.int32)
1398 RowPartitionSpec(nrows=14, nvals=25, dtype=tf.int32)
1400 ],
1401 _static_inner_shape=tf.TensorShape([25, 3, 4]),
1402 _inner_shape=tf.TensorSpec(tf.TensorShape([3]), dtype=tf.int32))
1403 shape._num_slices_in_dimension(0) = 3
1404 shape._num_slices_in_dimension(1) = 14
1405 shape._num_slices_in_dimension(2) = 25
1406 shape._num_slices_in_dimension(3) = 3
1407 shape._num_slices_in_dimension(4) = 4
1408 shape._num_slices_in_dimension(-2) = 3
1409 ```
1411 Args:
1412 axis: the last dimension to include.
1414 Returns:
1415 the number of values in a dimension.
1416 """
1417 if not isinstance(axis, int):
1418 raise TypeError("axis must be an integer")
1419 axis = array_ops.get_positive_axis(axis, self.rank, ndims_name="rank")
1421 if axis == 0:
1422 return self._dimension(0)
1423 if axis <= self.num_row_partitions:
1424 # TODO(martinz): use nvals OR nrows, whichever is defined.
1425 return self._row_partitions[axis - 1].nvals
1426 remainder = axis - (self.num_row_partitions - 1)
1427 head_inner_shape = self._static_inner_shape[:remainder]
1428 return head_inner_shape.num_elements()
1430 def with_dtype(self, dtype: dtypes.DType) -> "DynamicRaggedShape.Spec":
1431 """Return the same spec, but with a different DType."""
1432 new_rp_specs = [rp.with_dtype(dtype) for rp in self._row_partitions]
1433 return DynamicRaggedShape.Spec(
1434 row_partitions=new_rp_specs,
1435 static_inner_shape=self._static_inner_shape,
1436 dtype=dtype)
1438 def _merge_with(
1439 self, other: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
1440 """Merges all information between two specs.
1442 Specs are expected to represent the same information modulo
1443 num_row_partitons.
1445 If the specs are of different ranks, then fail.
1447 Args:
1448 other: another Spec of the same rank.
1450 Returns:
1451 a Spec with the union of information.
1452 """
1453 max_num_row_partitions = max(self.num_row_partitions,
1454 other.num_row_partitions)
1455 a = self._with_num_row_partitions(max_num_row_partitions)
1456 b = other._with_num_row_partitions(max_num_row_partitions)
1458 new_rp = [
1459 a._merge_with(b)
1460 for (a, b) in zip(a._row_partitions, b._row_partitions)
1461 ]
1463 new_static_inner_shape = a._static_inner_shape.merge_with(
1464 b._static_inner_shape)
1466 dtype = b.dtype if (a.dtype == dtypes.int32) else dtypes.int64
1468 return DynamicRaggedShape.Spec(
1469 new_rp, new_static_inner_shape, dtype=dtype)
1471 def _with_num_row_partitions(
1472 self, new_num_row_partitions: int) -> "DynamicRaggedShape.Spec":
1473 """Change the number of row partitions in the spec."""
1474 rank = self.rank
1475 if rank is None:
1476 raise ValueError(
1477 "Changing num_row_partitions with unknown rank unsupported")
1478 if new_num_row_partitions > max(rank - 1, 0):
1479 raise ValueError("Number of row partitions too large")
1480 if new_num_row_partitions < 0:
1481 raise ValueError("Number of row partitions negative")
1482 if self.num_row_partitions == new_num_row_partitions:
1483 return self
1484 elif self.num_row_partitions < new_num_row_partitions:
1485 # TODO(martinz): Consider swapping.
1486 rp_delta = new_num_row_partitions - self.num_row_partitions
1487 tail_shape = DynamicRaggedShape.Spec._from_tensor_shape(
1488 self._static_inner_shape, rp_delta, self.dtype)
1489 return DynamicRaggedShape.Spec(
1490 row_partitions=self._row_partitions + tail_shape._row_partitions,
1491 static_inner_shape=tail_shape._static_inner_shape,
1492 dtype=self.dtype)
1493 else:
1494 assert self.num_row_partitions > new_num_row_partitions
1495 new_row_partitions = self._row_partitions[:new_num_row_partitions]
1496 last_row_partition = new_row_partitions[-1]
1497 old_row_partitions = self._row_partitions[new_num_row_partitions:]
1498 new_static_inner_shape = (
1499 tensor_shape.TensorShape(
1500 [last_row_partition.nvals] +
1501 [x.uniform_row_length for x in old_row_partitions]) +
1502 self._static_inner_shape[1:])
1503 return DynamicRaggedShape.Spec(new_row_partitions,
1504 new_static_inner_shape, self.dtype)
1506 def _set_rank_if_unknown(self, new_rank: int) -> "DynamicRaggedShape.Spec":
1507 """Ensures this has a known rank at least new_rank."""
1508 if new_rank is None:
1509 raise TypeError("new_rank is None, but expected int")
1510 if new_rank < 0:
1511 raise ValueError("Rank must be non-negative")
1512 current_rank = self.rank
1513 if current_rank is not None and current_rank < new_rank:
1514 raise ValueError(
1515 "Rank is {current_rank}, expected at least {new_rank}.".format(
1516 current_rank=current_rank, new_rank=new_rank))
1518 if current_rank is not None:
1519 return self
1521 if self._row_partitions:
1522 new_inner_rank = max(new_rank - self.num_row_partitions, 1)
1523 first_dim = self._row_partitions[-1].nvals
1524 static_inner_shape = tensor_shape.TensorShape([first_dim] + [None] *
1525 (new_inner_rank - 1))
1526 else:
1527 static_inner_shape = tensor_shape.TensorShape([None] * new_rank)
1529 return DynamicRaggedShape.Spec(
1530 row_partitions=self._row_partitions,
1531 static_inner_shape=static_inner_shape,
1532 dtype=self.dtype)
1534 def _truncate(self, new_rank: int) -> "DynamicRaggedShape.Spec":
1535 """Truncate a ragged shape spec.
1537 For example, if the original spec s was for a shape:
1538 [3, [4, 1], 2, 7]
1540 Then truncate_dynamic_ragged_shape_spec(s, 3) is a spec for:
1541 [3, [4, 1], 2]
1543 Args:
1544 new_rank: the new rank
1546 Returns:
1547 A truncated DynamicRaggedShape.Spec.
1548 """
1549 if self.rank is None:
1550 return self._set_rank_if_unknown(new_rank)._truncate(new_rank)
1552 if new_rank == 0:
1553 return DynamicRaggedShape.Spec._from_tensor_shape([], 0, self.dtype)
1555 if new_rank == 1:
1556 vector_size = self._dimension(0)
1557 return DynamicRaggedShape.Spec._from_tensor_shape([vector_size], 0,
1558 self.dtype)
1560 if new_rank < self.num_row_partitions + 1:
1561 new_row_partitions = self._row_partitions[:new_rank - 1]
1562 new_static_inner_shape = tensor_shape.TensorShape(
1563 [new_row_partitions[-1].nvals])
1564 return DynamicRaggedShape.Spec(
1565 row_partitions=new_row_partitions,
1566 static_inner_shape=new_static_inner_shape,
1567 dtype=self.dtype)
1568 else:
1569 remainder = new_rank - self.num_row_partitions
1570 new_static_inner_shape = self._static_inner_shape[:remainder]
1571 return DynamicRaggedShape.Spec(
1572 row_partitions=self._row_partitions,
1573 static_inner_shape=new_static_inner_shape,
1574 dtype=self.dtype)
1576 def _to_tensor_shape(self):
1577 """Get a tensor shape corresponding to this type."""
1578 alt = self
1579 if alt._static_inner_shape.rank is None:
1580 return tensor_shape.TensorShape(None)
1581 if alt._static_inner_shape.rank == 0:
1582 assert not alt._row_partitions
1583 return alt._static_inner_shape
1584 prefix = [alt._dimension(0)]
1585 prefix.extend([rp.uniform_row_length for rp in alt._row_partitions])
1586 suffix = alt._static_inner_shape[1:]
1587 return tensor_shape.TensorShape(prefix) + suffix
1590def broadcast_dynamic_shape(shape_x: DynamicRaggedShape,
1591 shape_y: DynamicRaggedShape) -> DynamicRaggedShape:
1592 """Returns the shape formed by broadcasting two shapes to be compatible.
1594 1. If shape_x and shape_y both have row_partitions, then fail if their dtypes
1595 don't match.
1596 2. If neither has row_partitions and they have different dtypes,
1597 go with int64.
1598 3. If one has row_partitions, go with that dtype.
1600 Args:
1601 shape_x: A `DynamicRaggedShape`
1602 shape_y: A `DynamicRaggedShape`
1604 Returns:
1605 A `DynamicRaggedShape`.
1606 Raises:
1607 ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
1608 """
1609 if not isinstance(shape_x, DynamicRaggedShape):
1610 raise TypeError("shape_x must be a DynamicRaggedShape")
1611 if not isinstance(shape_y, DynamicRaggedShape):
1612 raise TypeError("shape_y must be a DynamicRaggedShape")
1614 return broadcast_dynamic_shape_extended(shape_x, shape_y)[0]
1617def broadcast_to(rt_input, shape: DynamicRaggedShape):
1618 """Broadcasts a potentially ragged tensor to a ragged shape.
1620 Tiles `rt_input` as necessary to match the given shape.
1622 Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.
1624 Args:
1625 rt_input: The potentially ragged tensor to broadcast.
1626 shape: A `DynamicRaggedShape`
1628 Returns:
1629 A potentially ragged tensor whose values are taken from
1630 `rt_input`, and whose shape matches `shape`.
1631 """
1632 if not isinstance(shape, DynamicRaggedShape):
1633 raise TypeError("shape must be a DynamicRaggedShape")
1634 rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
1635 origin_shape = None
1636 if ragged_tensor.is_ragged(rt_input):
1637 if shape.num_row_partitions != 0:
1638 if rt_input.row_splits.dtype != shape.dtype:
1639 raise ValueError("Cannot coerce row_splits.dtype")
1640 else:
1641 shape = shape.with_dtype(rt_input.row_splits.dtype)
1642 origin_shape = DynamicRaggedShape.from_tensor(rt_input)
1643 else:
1644 if shape.num_row_partitions != 0:
1645 origin_shape = DynamicRaggedShape.from_tensor(rt_input, dtype=shape.dtype)
1646 else:
1647 origin_shape = DynamicRaggedShape.from_tensor(
1648 rt_input, dtype=dtypes.int64)
1649 shape = shape.with_dtype(dtype=dtypes.int64)
1651 broadcaster = _get_broadcaster(origin_shape, shape)
1652 return broadcaster.broadcast(rt_input)
1655def broadcast_dynamic_shape_extended(
1656 a: DynamicRaggedShape, b: DynamicRaggedShape
1657): # -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]
1658 """Gets the smallest shape to which a and b can broadcast.
1660 In order to create the smallest shape, one must also do most of the
1661 work to figure out how to transform from the shapes given. Thus, in addition
1662 to returning the shape, it also creates transformations from the
1663 original shapes to the result.
1665 This is the equivalent of:
1667 c = broadcast_dynamic_shape(a, b)
1668 ac = get_broadcaster(a, c)
1669 bc = get_broadcaster(b, c)
1670 return (c, ac, bc)
1672 Args:
1673 a: a DynamicRaggedShape
1674 b: a DynamicRaggedShape
1676 Returns:
1677 A triple of a shape and two broadcasters.
1678 """
1679 if a.row_partitions and b.row_partitions:
1680 if a.dtype != b.dtype:
1681 raise ValueError("Dtypes don't match")
1682 elif a.dtype != b.dtype:
1683 if a.row_partitions:
1684 b = b.with_dtype(a.dtype)
1685 elif b.row_partitions:
1686 a = a.with_dtype(b.dtype)
1687 else:
1688 a = a.with_dtype(dtypes.int64)
1689 b = b.with_dtype(dtypes.int64)
1691 if (a.rank is None or b.rank is None):
1692 raise ValueError("Unable to broadcast: unknown rank")
1693 elif a.rank == 0:
1694 return (b, _Broadcaster(a, b, []), _get_identity_broadcaster(b))
1695 elif b.rank == 0:
1696 return (a, _get_identity_broadcaster(a), _Broadcaster(b, a, []))
1697 elif a.rank == 1 and b.rank == 1:
1698 [a_layer, b_layer,
1699 target] = _broadcast_dynamic_shape_one_layer(a.inner_shape, b.inner_shape)
1700 target_shape = DynamicRaggedShape._from_inner_shape(target) # pylint: disable=protected-access
1701 return (target_shape, _Broadcaster(a, target_shape, [a_layer]),
1702 _Broadcaster(b, target_shape, [b_layer]))
1704 if a.rank > b.rank:
1705 (c, bc, ac) = _broadcast_dynamic_shape_extended_helper(b, a) # pylint: disable=arguments-out-of-order
1707 return (c, ac, bc)
1709 return _broadcast_dynamic_shape_extended_helper(a, b)
1712def _row_partitions_identical(shape_a, shape_b):
1713 """Returns True iff all row_partitions in shapes are identical."""
1714 return ((shape_a.num_row_partitions == shape_b.num_row_partitions) and all(
1715 a is b for a, b in zip(shape_a.row_partitions, shape_b.row_partitions)))
1718# TODO(martinz): Preserve shapes better (see CL/414806185)
1719@dispatch.dispatch_for_binary_elementwise_apis(ragged_tensor.RaggedOrDense,
1720 ragged_tensor.RaggedOrDense)
1721def ragged_binary_elementwise_op_impl(op, x, y):
1722 """Binary elementwise api handler for RaggedTensors."""
1723 x_is_ragged = ragged_tensor.is_ragged(x)
1724 y_is_ragged = ragged_tensor.is_ragged(y)
1726 # Convert args to tensors.
1727 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1728 x, preferred_dtype=(y.dtype if y_is_ragged else None))
1729 y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1730 y, preferred_dtype=x.dtype)
1732 if x_is_ragged and y_is_ragged:
1733 x, y = ragged_tensor.match_row_splits_dtypes(x, y)
1735 if ((x_is_ragged and y_is_ragged) or
1736 (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
1737 (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
1738 shape_x = DynamicRaggedShape.from_tensor(x)
1739 shape_y = DynamicRaggedShape.from_tensor(y)
1740 if shape_x.dtype != shape_y.dtype:
1741 if not x_is_ragged:
1742 shape_x = shape_x.with_dtype(shape_y.dtype)
1743 elif not y_is_ragged:
1744 shape_y = shape_y.with_dtype(shape_x.dtype)
1746 if _row_partitions_identical(shape_x, shape_y):
1747 # At this point, both x and y must be ragged.
1748 return shape_x._add_row_partitions( # pylint: disable=protected-access
1749 op(x.flat_values, y.flat_values),
1750 validate=False)
1752 (shape_z, bcast_xz,
1753 bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y)
1754 x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False)
1755 y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False)
1756 z_flat = op(x_new_flat, y_new_flat)
1757 return shape_z._add_row_partitions(z_flat, validate=True) # pylint: disable=protected-access
1759 x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
1760 y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
1761 mapped_values = op(x_values, y_values)
1762 if isinstance(mapped_values, bool):
1763 return mapped_values # Special case for tensor_equals.
1764 if ragged_tensor.is_ragged(x):
1765 return x.with_flat_values(mapped_values)
1766 else:
1767 return y.with_flat_values(mapped_values)
1770@dispatch.dispatch_for_binary_elementwise_assert_apis(
1771 ragged_tensor.RaggedOrDense, ragged_tensor.RaggedOrDense)
1772def ragged_binary_elementwise_assert_op_impl(op, x, y):
1773 """Binary elementwise assert api handler for RaggedTensors.
1775 This handles binary assert operations for ragged tensors. Compared with
1776 `ragged_binary_elementwise_op_impl`, this handler does not compute a ragged
1777 tensor as output. Instead, it applies the assert operation `op` to input
1778 tensors based on their ragged shapes and flat_values, and returns the result
1779 of the assertion operation.
1781 Args:
1782 op: a binary assert operation on Tensors.
1783 x: something that can be coerced to a Tensor or RaggedTensor.
1784 y: something that can be coerced to a Tensor or RaggedTensor.
1786 Returns:
1787 the result of the assertion operation.
1789 """
1790 x_is_ragged = ragged_tensor.is_ragged(x)
1791 y_is_ragged = ragged_tensor.is_ragged(y)
1793 # Convert args to tensors.
1794 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1795 x, preferred_dtype=(y.dtype if y_is_ragged else None))
1796 y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1797 y, preferred_dtype=x.dtype)
1799 if x_is_ragged and y_is_ragged:
1800 x, y = ragged_tensor.match_row_splits_dtypes(x, y)
1802 if ((x_is_ragged and y_is_ragged) or
1803 (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
1804 (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
1805 shape_x = DynamicRaggedShape.from_tensor(x)
1806 shape_y = DynamicRaggedShape.from_tensor(y)
1807 if shape_x.dtype != shape_y.dtype:
1808 if not x_is_ragged:
1809 shape_x = shape_x.with_dtype(shape_y.dtype)
1810 elif not y_is_ragged:
1811 shape_y = shape_y.with_dtype(shape_x.dtype)
1813 if _row_partitions_identical(shape_x, shape_y):
1814 # At this point, both x and y must be ragged.
1815 return op(x.flat_values, y.flat_values)
1817 (_, bcast_xz, bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y)
1818 x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False)
1819 y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False)
1820 return op(x_new_flat, y_new_flat)
1822 x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
1823 y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
1824 return op(x_values, y_values)
1827def _find_dtype_helper(value, preferred):
1828 """Helper for _find_dtype."""
1829 if preferred is not None:
1830 return preferred
1831 elif isinstance(value, RowPartition):
1832 return value.dtype
1833 elif isinstance(value, dtypes.DType):
1834 return value
1835 elif isinstance(value, int):
1836 return None
1837 elif isinstance(value, list):
1838 return None
1839 elif isinstance(value, tuple):
1840 return None
1841 elif isinstance(value, core.Tensor):
1842 return value.dtype
1843 return value.dtype
1846def _find_dtype(value, preferred):
1847 """Returns the preferred dtype of value or preferred if preferred != None.
1849 This is used as an operator to pass over multiple objects in decreasing order
1850 of priority until there is a preferred dtype for one. For example, if you were
1851 adding three tensor-ish things (some tensors, some lists), and needed a
1852 preferred dtype, you could use this as:
1854 def adding(a, b, c, dtype = None):
1855 dtype = _find_dtype(a, dtype)
1856 dtype = _find_dtype(b, dtype)
1857 dtype = _find_dtype(c, dtype)
1858 if dtype is None:
1859 dtype = tf.float32
1860 ...Code continues here...
1862 Args:
1863 value: a list, value, RowPartition, or tensor.
1864 preferred: a given dtype. If not None, this will be returned.
1866 Returns:
1867 an optional dtype.
1868 """
1869 result = _find_dtype_helper(value, preferred)
1870 if (result == dtypes.int64 or result == dtypes.int32 or result is None):
1871 return result
1872 raise ValueError("Illegal dtype: " + str(result))
1875def _find_dtype_iterable(
1876 iterable: Iterable[Any],
1877 dtype: Optional[dtypes.DType]) -> Optional[dtypes.DType]:
1878 """Find the preferred dtype of a list of objects.
1880 This will go over the iterable, and use the first object with a preferred
1881 dtype. The dtype passed has highest priority if it is not None.
1883 Args:
1884 iterable: an iterable with things that might have a dtype.
1885 dtype: an overriding dtype, or None.
1887 Returns:
1888 an optional dtype.
1889 """
1890 if dtype is not None:
1891 return dtype
1892 for x in iterable:
1893 dtype = _find_dtype(x, dtype)
1894 return dtype
1897class _LayerBroadcaster(abc.ABC):
1898 """A broadcaster of a single layer.
1900 Although this class does not literally contain a gather_index, the reference
1901 implementation is defined through a gather_index. Thus, any subclasses should
1902 first define the gather_index property. Other functions can be overridden
1903 for optimization, but it should not change the behavior.
1904 """
1906 @property
1907 @abc.abstractmethod
1908 def gather_index(self):
1909 """Returns a 1D tensor.
1911 The size of the 1D tensor is equal to the destination size.
1913 The ith element of the result is the index of the source of the ith element.
1914 """
1915 pass
1917 @property
1918 def dtype(self):
1919 """Returns the dtype of the broadcast."""
1920 return self.gather_index.dtype
1922 @abc.abstractmethod
1923 def with_dtype(self, dtype):
1924 """Returns an identical _LayerBroadcaster with a different dtype."""
1925 pass
1927 def __repr__(self):
1928 return str(self.gather_index)
1930 @classmethod
1931 def from_gather_index(cls, gather_index):
1932 """Create a broadcaster from a gather_index."""
1933 return _GatherLayerBroadcaster(gather_index)
1935 @classmethod
1936 def first_layer(cls, nrows_source, nrows_target):
1937 """Create a broadcaster from a gather_index."""
1938 gather_index = _first_layer_gather_index(nrows_source, nrows_target)
1939 return _LayerBroadcaster.from_gather_index(gather_index)
1941 @classmethod
1942 def get_singleton_broadcaster(cls, target_size):
1943 """Broadcast from 1 element to target_size elements."""
1944 return _LayerBroadcaster.from_gather_index(
1945 array_ops.zeros(target_size, dtype=target_size.dtype))
1947 @abc.abstractmethod
1948 def with_dependencies(self, checks):
1949 """Add dependencies to a _LayerBroadcaster.
1951 Args:
1952 checks: a list of ops that need to be run before any tensors from the
1953 Broadcaster are used.
1955 Returns:
1956 a copy of this _LayerBroadcaster with dependencies added.
1957 """
1958 pass
1960 @classmethod
1961 def get_identity_broadcaster(cls, nvals, dtype=None):
1962 """Create an identity broadcaster.
1964 TODO(martinz): an identity broadcaster can be far more efficient than a
1965 generic broadcaster. Add an optimized implementation.
1966 Args:
1967 nvals: the number of values for the broadcaster.
1968 dtype: the dtype of the broadcaster, or None to use the dtype of nvals.
1970 Returns:
1971 an identity broadcaster from [0....nvals-1] to [0...nvals-1]
1972 """
1973 return _GatherLayerBroadcaster(math_ops.range(nvals, dtype=dtype))
1975 def broadcast_tensor(self, tensor):
1976 """Broadcast from a dense tensor.
1978 It is assumed that the first axis of the dense tensor is indexed by the
1979 source shape, and at the end, the first axis of the dense tensor is
1980 indexed by the destination shape.
1982 Args:
1983 tensor: a dense tensor.
1985 Returns:
1986 A dense tensor.
1987 """
1988 return array_ops.gather(tensor, self.gather_index)
1990 def dest_nrows(self):
1991 """Return the number of rows in the resulting gather, or None if tiling."""
1992 return math_ops.cast(
1993 array_ops.shape(self.gather_index)[0], dtype=self.dtype)
1995 def broadcast_row_partition(self, rp):
1996 """Return a new shape where the rows are broadcasted.
1998 *--self--->*
1999 | |
2000 rp result
2001 | |
2002 V V
2003 *--------->*
2005 This is equivalent to:
2006 return RowPartition.from_row_lengths(self.broadcast(rp.row_lengths()))
2008 However, if the shape has uniform row length, then that property is
2009 maintained.
2011 Args:
2012 rp: a row partition.
2014 Returns:
2015 a RowPartition representing a broadcast version of this row partition.
2016 """
2017 if not rp.is_uniform():
2018 return RowPartition.from_row_lengths(
2019 self.broadcast_tensor(rp.row_lengths()))
2020 else:
2021 return RowPartition.from_uniform_row_length(
2022 rp.uniform_row_length(),
2023 nvals=rp.uniform_row_length() * self.dest_nrows(),
2024 nrows=self.dest_nrows())
2026 def next_layer(self, original_rp, broadcast_rp):
2027 r"""Create the next layer gather_index whether or not a broadcast happens.
2029 *---------self------->*
2030 | |
2031 original_rp broadcast_rp
2032 | |
2033 \|/ \|/
2034 *--next_broadcaster-->*
2035 Args:
2036 original_rp: the original row partition.
2037 broadcast_rp: the target row partition.
2039 Returns:
2040 the gather_index for next_broadcaster.
2042 """
2043 gather_index = _next_layer_gather_index(self, original_rp, broadcast_rp)
2044 return _LayerBroadcaster.from_gather_index(gather_index)
2047class _GatherLayerBroadcaster(_LayerBroadcaster):
2048 """Implements _LayerBroadcaster with an explicit gather_index.
2050 For example, suppose that the source shape is:
2051 [*],[*,*]
2052 And the target shape is:
2053 [*],[*,*],[*],[*,*]
2054 Then, this can be represented with a map:
2055 [0,1,2,0,1,2]
2057 """
2059 def __init__(self, gather_index):
2060 gather_index = ops.convert_to_tensor(gather_index)
2061 if (gather_index.dtype != dtypes.int64 and
2062 gather_index.dtype != dtypes.int32):
2063 raise ValueError("gather_index must be int64 or int32")
2064 self._gather_index = gather_index
2066 @property
2067 def gather_index(self):
2068 return self._gather_index
2070 def with_dtype(self, dtype):
2071 return _GatherLayerBroadcaster(math_ops.cast(self._gather_index, dtype))
2073 def with_dependencies(self, checks):
2074 new_gather_index = control_flow_ops.with_dependencies(
2075 checks, self._gather_index)
2076 return _GatherLayerBroadcaster(new_gather_index)
2079class _Broadcaster:
2080 """A _Broadcaster represents a transformation from one shape to another.
2082 It provides a transform for each axis of the source shape to the
2083 corresponding axis of the destination shape.
2085 """
2087 def __init__(self,
2088 source_shape,
2089 target_shape,
2090 layer_broadcasters,
2091 dtype=None):
2092 """Create a broadcaster.
2094 Do not call directly.
2095 The source_shape, target_shape, and layer_broadcasters are converted
2096 to have the same dtype.
2098 Note: source_shape.rank and target_shape.rank must be known.
2099 Args:
2100 source_shape: the source DynamicRaggedShape
2101 target_shape: the target DynamicRaggedShape
2102 layer_broadcasters: List[_LayerBroadcaster] of length source_shape.rank.
2103 dtype: the preferred dtype of the broadcaster.
2105 Raises:
2106 TypeError: if the input types don't match.
2107 """
2108 if not isinstance(source_shape, DynamicRaggedShape):
2109 raise TypeError("source_shape is not a DynamicRaggedShape")
2110 if not isinstance(target_shape, DynamicRaggedShape):
2111 raise TypeError("target_shape is not a DynamicRaggedShape")
2112 if not isinstance(layer_broadcasters, list):
2113 raise TypeError("layer_broadcasters not a list: " +
2114 str(layer_broadcasters))
2115 for bc in layer_broadcasters:
2116 if not isinstance(bc, _LayerBroadcaster):
2117 raise TypeError("Not a LayerBroadcaster: " + str(bc))
2119 dtype = _find_dtype(source_shape, dtype)
2120 dtype = _find_dtype(target_shape, dtype)
2121 dtype = _find_dtype_iterable(layer_broadcasters, dtype)
2122 dtype = _find_dtype(dtypes.int64, dtype)
2123 self._source_shape = source_shape.with_dtype(dtype)
2124 self._target_shape = target_shape.with_dtype(dtype)
2125 self._layer_broadcasters = [x.with_dtype(dtype) for x in layer_broadcasters]
2127 def __repr__(self):
2128 return ("{src_shape:" + str(self._source_shape) + ", target_shape:" +
2129 str(self._target_shape) + " layer_broadcasters: " +
2130 str(self._layer_broadcasters) + "}")
2132 def with_dtype(self, dtype):
2133 """Return a copy of this Broadcaster with a different dtype."""
2134 return _Broadcaster(self._source_shape, self._target_shape,
2135 self._layer_broadcasters, dtype)
2137 @property
2138 def source_shape(self):
2139 return self._source_shape
2141 @property
2142 def target_shape(self):
2143 return self._target_shape
2145 @property
2146 def dtype(self):
2147 return self._source_shape.dtype
2149 def _target_inner_shape_int32(self):
2150 new_inner_shape = self.target_shape.inner_shape
2151 if new_inner_shape.dtype == dtypes.int64:
2152 new_inner_shape = math_ops.cast(new_inner_shape, dtype=dtypes.int32)
2153 return new_inner_shape
2155 # pylint:disable=protected-access
2156 def broadcast_flat_values(self, rt, inner_dimensions=True):
2157 """flat_values of a ragged tensor broadcast to target_shape.
2159 If inner_dimensions==True, then the result is a dense tensor with shape
2160 target_shape.inner_shape, the flat values of the broadcasted shape.
2162 If you add target_shape.row_partitions, you will get the full broadcasted
2163 shape.
2165 If inner_dimensions==False, the result is a dense tensor that satsifies
2166 certain properties:
2167 1. broadcast_to(result, target_shape.inner_shape) will give the result
2168 if inner_dimensions==True.
2169 2. Either (a) (result.rank < target_shape.inner_rank)
2170 or (b) (result.shape[0] == target_shape.inner_shape[0]).
2171 3. result.rank = min(target_shape.inner_rank, rt.rank)
2172 4. For i < target_shape.inner_rank - 1, and i < rt.rank,
2173 and if rt.shape[-i]!=1, then result.shape[-i]=target_shape[-i].
2174 Args:
2175 rt: a ragged or dense tensor.
2176 inner_dimensions: if true, broadcast the inner dimensions as well.
2178 Returns:
2179 a dense tensor
2180 """
2181 if ragged_tensor.is_ragged(rt):
2182 rt = rt.flat_values
2183 # If rt was a regular tensor, it is its own flat_values.
2184 if self.target_shape.rank == 0:
2185 return rt
2186 inner_rank = self.target_shape.inner_rank
2187 if inner_rank > self._source_shape.rank:
2188 # The dense rank is larger than the whole shape. So, we make the shape
2189 # dense.
2190 if self.source_shape.num_row_partitions > 0:
2191 rt = array_ops.reshape(
2192 rt, self.source_shape._alt_inner_shape(self.source_shape.rank))
2193 # rt.rank == self._source_shape.rank < inner_rank
2194 # Here, property 2a holds.
2195 if inner_dimensions:
2196 return array_ops.broadcast_to(rt, self._target_inner_shape_int32())
2197 return rt
2198 else:
2199 if self._source_shape.inner_rank != inner_rank:
2200 rt = array_ops.reshape(rt,
2201 self._source_shape._alt_inner_shape(inner_rank)) # pylint:disable=protected-access
2202 # After the reshape, rt is flat_values with inner_rank.
2203 flat_broadcaster = self._layer_broadcasters[-inner_rank]
2204 rt = flat_broadcaster.broadcast_tensor(rt)
2205 # Here, property 2b holds.
2206 if inner_dimensions:
2207 rt = array_ops.broadcast_to(rt, self._target_inner_shape_int32())
2208 return rt
2210 def broadcast(self, rt):
2211 """Broadcast a tensor of source_shape to target_shape."""
2212 flat_values = self.broadcast_flat_values(rt)
2213 return self.target_shape._add_row_partitions(flat_values) # pylint:disable=protected-access
2216def _get_layer_broadcasters_from_rps(zero_broadcaster, source_rps, target_rps):
2217 """Get LayerBroadcasters from RowPartitions.
2219 *--zero_broadcaster->*
2220 | |
2221 source_rps[0] target_rps[0]
2222 | |
2223 V V
2224 *---result[1]------->*
2225 | |
2226 source_rps[1] target_rps[1]
2227 | |
2228 V V
2229 *---result[2]------->*
2230 .
2231 .
2232 .
2233 *---result[k-1]----->*
2234 | |
2235 source_rps[k] target_rps[k]
2236 | |
2237 V V
2238 *---result[k]------->*
2240 Note: result[0] = zero_broadcaster
2242 Args:
2243 zero_broadcaster: a broadcaster between the source and target row
2244 partitions' rows, and equal to result[0].
2245 source_rps: source row partitions.
2246 target_rps: target row partitions (same length as source_rps).
2248 Returns:
2249 result: a list of LayerBroadcasters.
2250 """
2251 if not isinstance(zero_broadcaster, _LayerBroadcaster):
2252 raise TypeError("Not a _LayerBroadcaster: " + str(zero_broadcaster))
2253 assert len(source_rps) == len(target_rps)
2254 if not source_rps:
2255 return [zero_broadcaster]
2256 next_broadcaster = zero_broadcaster.next_layer(source_rps[0], target_rps[0])
2257 tail_broadcasters = _get_layer_broadcasters_from_rps(next_broadcaster,
2258 source_rps[1:],
2259 target_rps[1:])
2260 return [zero_broadcaster] + tail_broadcasters
2263def _get_broadcaster(source_shape, target_shape):
2264 """Get a _Broadcaster from source_shape to target_shape."""
2265 if source_shape.dtype != target_shape.dtype:
2266 raise ValueError("The source and target row_split dtypes should be equal")
2268 if (source_shape.rank is None or target_shape.rank is None):
2269 raise ValueError("Rank of source and target must be statically known")
2270 elif source_shape.rank > target_shape.rank:
2271 raise ValueError("Cannot broadcast to a shape with smaller rank")
2272 elif source_shape.rank == 0:
2273 return _Broadcaster(source_shape, target_shape, [])
2274 elif target_shape.rank == 1:
2275 assert source_shape.rank == 1
2276 layer = _LayerBroadcaster.first_layer(source_shape.inner_shape[0],
2277 target_shape.inner_shape[0])
2278 return _Broadcaster(source_shape, target_shape, [layer])
2280 assert source_shape.rank <= target_shape.rank
2281 assert target_shape.rank >= 2
2282 assert source_shape.rank >= 1
2284 source_rps = source_shape._as_row_partitions() # pylint: disable=protected-access
2286 target_rps = target_shape._as_row_partitions() # pylint: disable=protected-access
2288 assert len(target_rps) >= 1
2289 assert len(source_rps) <= len(target_rps)
2290 source_nrows = source_shape[0]
2291 if len(source_rps) < len(target_rps):
2292 # Note: this includes the case where len(source_rps)==0.
2293 # Here we begin at -1, one dimension before source_rps[0].
2294 # neg_one_source_rp | neg_one_target_rp=target_rps[-(len(source_rps)+1)]
2295 # source_rps[0] | target_rps[-len(source_rps)]
2296 # source_rps[1] | target_rps[1-len(source_rps)]
2297 # ... | ...
2298 # source_rps[-1] | target_rps[-1]
2299 neg_one_source_rp = RowPartition.from_uniform_row_length(
2300 uniform_row_length=source_nrows, nrows=1, nvals=source_nrows)
2301 neg_one_target_rp = target_rps[-(len(source_rps) + 1)]
2302 neg_one_broadcaster = _LayerBroadcaster.get_singleton_broadcaster(
2303 neg_one_target_rp.nrows())
2304 zeroth_broadcaster = neg_one_broadcaster.next_layer(neg_one_source_rp,
2305 neg_one_target_rp)
2306 target_rps_tail = target_rps[-len(source_rps):] if len(
2307 source_rps) >= 1 else []
2309 layers = _get_layer_broadcasters_from_rps(zeroth_broadcaster, source_rps,
2310 target_rps_tail)
2311 return _Broadcaster(source_shape, target_shape, layers)
2312 else:
2313 assert len(target_rps) == len(source_rps)
2314 zeroth_broadcaster = _LayerBroadcaster.first_layer(source_rps[0].nrows(),
2315 target_rps[0].nrows())
2316 layers = _get_layer_broadcasters_from_rps(zeroth_broadcaster, source_rps,
2317 target_rps)
2319 return _Broadcaster(source_shape, target_shape, layers)
2322def _get_identity_broadcaster(shape):
2323 """Gets a Broadcaster for two identical shapes."""
2324 if shape.rank is None:
2325 raise ValueError("Shape must have a defined rank")
2326 layers = [
2327 _LayerBroadcaster.get_identity_broadcaster(
2328 shape._num_slices_in_dimension(i)) for i in range(shape.rank) # pylint: disable=protected-access
2329 ]
2330 return _Broadcaster(shape, shape, layers)
2333def _broadcast_dynamic_shape_one_layer(a, b):
2334 """Broadcast two vectors, given their shapes.
2336 Args:
2337 a: the number of rows in a.
2338 b: the number of rows in b.
2340 Returns:
2341 (layer_a, layer_b, target_shape)
2342 layer_a is a _LayerBroadcaster from a to the target_shape.
2343 layer_b is a _LayerBroadcaster from b to the target_shape.
2344 target_shape is the target_shape
2346 Raises:
2347 InvalidArgumentError if the shapes are not consistent.
2348 """
2349 a_0 = a[0]
2350 b_0 = b[0]
2352 def broadcast_from_a():
2353 # Assumes a_0 == 1
2354 a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
2355 b_layer = math_ops.range(b_0)
2356 target = b
2357 return [a_layer, b_layer, target]
2359 a_static = tensor_util.constant_value(a)
2360 if a_static is not None and a_static[0] == 1:
2361 [a_gi, b_gi, target] = broadcast_from_a()
2362 a_layer = _LayerBroadcaster.from_gather_index(a_gi)
2363 b_layer = _LayerBroadcaster.from_gather_index(b_gi)
2364 return [a_layer, b_layer, target]
2366 def broadcast_from_b():
2367 # Assumes b_0 == 1
2368 a_layer = math_ops.range(a_0)
2369 b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
2370 target = a
2371 return [a_layer, b_layer, target]
2373 b_static = tensor_util.constant_value(b)
2374 if b_static is not None and b_static[0] == 1:
2375 [a_gi, b_gi, target] = broadcast_from_b()
2376 a_layer = _LayerBroadcaster.from_gather_index(a_gi)
2377 b_layer = _LayerBroadcaster.from_gather_index(b_gi)
2378 return [a_layer, b_layer, target]
2380 def broadcast_noop():
2381 # Assumes a_0 == 1
2382 a_layer = math_ops.range(a_0)
2383 b_layer = math_ops.range(b_0)
2384 target = b
2385 return [a_layer, b_layer, target]
2387 can_broadcast_from_a = math_ops.equal(a_0, 1)
2388 can_broadcast_from_b = math_ops.equal(b_0, 1)
2390 def broadcast_not_from_a():
2391 return cond.cond(
2392 can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
2394 nrows_equal = math_ops.equal(a_0, b_0)
2395 can_broadcast = math_ops.logical_or(
2396 can_broadcast_from_a,
2397 math_ops.logical_or(can_broadcast_from_b, nrows_equal))
2399 check_can_broadcast = check_ops.assert_equal(
2400 can_broadcast, True, message="Cannot broadcast")
2402 results = cond.cond(
2403 can_broadcast_from_a,
2404 true_fn=broadcast_from_a,
2405 false_fn=broadcast_not_from_a)
2407 results = [
2408 control_flow_ops.with_dependencies([check_can_broadcast], x)
2409 for x in results
2410 ]
2411 [a_gi, b_gi, target] = results
2412 a_layer = _LayerBroadcaster.from_gather_index(a_gi)
2413 b_layer = _LayerBroadcaster.from_gather_index(b_gi)
2414 return [a_layer, b_layer, target]
2417def _broadcast_dynamic_shape_first_layer(a_0, b_0):
2418 """Broadcast the first layer of two dynamic shapes given the dimensions.
2420 Args:
2421 a_0: the number of rows in a.
2422 b_0: the number of rows in b.
2424 Returns:
2425 (use_a, layer_a, layer_b)
2426 where use_a is true if the target provably equals a, false otherwise.
2427 layer_a is a _LayerBroadcaster from a to the target.
2428 layer_b is a _LayerBroadcaster from b to the target.
2429 """
2431 def broadcast_from_a():
2432 # Assumes a_0 == 1
2433 a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
2434 b_layer = math_ops.range(b_0)
2435 return [a_layer, b_layer]
2437 static_a_0 = tensor_util.constant_value(a_0)
2438 static_b_0 = tensor_util.constant_value(b_0)
2439 if static_a_0 is not None:
2440 if static_a_0 == static_b_0:
2441 id_broadcaster = _LayerBroadcaster.get_identity_broadcaster(
2442 static_a_0, dtype=a_0.dtype)
2443 return [id_broadcaster, id_broadcaster]
2444 elif static_a_0 == 1:
2445 return [
2446 _LayerBroadcaster.get_singleton_broadcaster(b_0),
2447 _LayerBroadcaster.get_identity_broadcaster(b_0)
2448 ]
2450 if static_b_0 == 1:
2451 return [
2452 _LayerBroadcaster.get_identity_broadcaster(a_0),
2453 _LayerBroadcaster.get_singleton_broadcaster(a_0)
2454 ]
2456 def broadcast_from_b():
2457 # Assumes b_0 == 1
2458 a_layer = math_ops.range(a_0)
2459 b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
2460 return [a_layer, b_layer]
2462 def broadcast_noop():
2463 # Assumes a_0 == b_0
2464 a_layer = math_ops.range(a_0)
2465 b_layer = math_ops.range(b_0)
2466 return [a_layer, b_layer]
2468 can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype))
2469 can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype))
2471 def broadcast_not_from_a():
2472 return cond.cond(
2473 can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
2475 # Ideally, this would only block control flow on broadcast_noop, but
2476 # the control flow doesn't seem to work.
2477 can_broadcast = math_ops.logical_or(
2478 math_ops.logical_or(can_broadcast_from_a, can_broadcast_from_b),
2479 math_ops.equal(a_0, b_0))
2481 result = cond.cond(
2482 can_broadcast_from_a,
2483 true_fn=broadcast_from_a,
2484 false_fn=broadcast_not_from_a)
2486 return [
2487 _LayerBroadcaster.from_gather_index(
2488 control_flow_ops.with_dependencies(
2489 [check_ops.assert_equal(can_broadcast, True)], x)) for x in result
2490 ]
2493def _broadcast_half(
2494 ac_0: _LayerBroadcaster,
2495 a_1: RowPartition) -> Tuple[_LayerBroadcaster, RowPartition]:
2496 """Does a NOOP broadcast of a_1.
2498 *-ac_0-->*
2499 | |
2500 a_1 c_1
2501 | |
2502 V V
2503 *-ac_1-->*
2505 Note that by definition this cannot fail: there is always a well-defined
2506 NOOP broadcast. This is usually intended as half of broadcasting two shapes
2507 together.
2508 Args:
2509 ac_0: previous LayerBroadcaster
2510 a_1: previous RowPartition
2512 Returns:
2513 [ac_1, c_1] where ac_1 is the next LayerBroadcaster, and c_1 is the
2514 broadcast RowPartition
2515 """
2516 c_1 = ac_0.broadcast_row_partition(a_1)
2517 old_value_rowids = array_ops.gather(ac_0.gather_index, c_1.value_rowids())
2518 old_row_starts = array_ops.gather(a_1.row_splits(), old_value_rowids)
2519 gather_index = old_row_starts + c_1.offsets_in_rows()
2520 return [_LayerBroadcaster.from_gather_index(gather_index), c_1]
2523def _broadcast_dynamic_shape_next_layer_half_ragged(
2524 ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition,
2525 b_1: RowPartition
2526) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]:
2527 r"""Broadcast target and next layer broadcaster of two dynamic shapes.
2529 a_1 is uniform, and b_1 is ragged.
2530 *--ac_0-->*<--bc_0--*
2531 | | |
2532 a_1 c_1 b_1
2533 | | |
2534 V V V
2535 *--ac_1-->*<--bc_1--*
2537 Args:
2538 ac_0: _LayerBroadcaster from a to c in the previous layer.
2539 bc_0: _LayerBroadcaster from b to c in the previous layer.
2540 a_1: a uniform RowPartition for the next layer of a.
2541 b_1: a ragged RowPartition for the next layer of b.
2543 Returns:
2544 (c_1, ac_1, bc_1)
2545 c_1: a RowPartition for the next layer of the dynamic shape.
2546 ac_1: _LayerBroadcaster from a to c in the next layer.
2547 bc_1: _LayerBroadcaster from b to c in the next layer.
2548 """
2549 if not isinstance(ac_0, _LayerBroadcaster):
2550 raise TypeError("ac_0 should be a _LayerBroadcaster")
2551 if not isinstance(bc_0, _LayerBroadcaster):
2552 raise TypeError("bc_0 should be a _LayerBroadcaster")
2553 if not isinstance(a_1, RowPartition):
2554 raise TypeError("a_1 should be a RowPartition")
2555 if not isinstance(b_1, RowPartition):
2556 raise TypeError("b_1 should be a RowPartition")
2558 assert a_1.is_uniform()
2559 assert not b_1.is_uniform()
2561 static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
2562 if static_a_1 == 1:
2563 [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2564 ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
2565 c_1 = RowPartition.from_row_splits(c_1b.row_splits())
2566 ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
2567 bc_1 = _LayerBroadcaster.from_gather_index(bc_1.gather_index)
2568 return [c_1, ac_1, bc_1]
2570 def broadcast_noop():
2571 # The sides must be "equal".
2572 [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2573 [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2574 checks = [check_ops.assert_equal(c_1a.row_splits(), c_1b.row_splits())]
2575 return [
2576 control_flow_ops.with_dependencies(checks, x)
2577 for x in [a_1.row_splits(), ac_1.gather_index, bc_1.gather_index]
2578 ]
2580 def broadcast_a():
2581 [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2582 ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
2583 return [
2584 c_1b.row_splits(),
2585 ac_1_gather_index,
2586 bc_1.gather_index,
2587 ]
2589 can_broadcast_a = math_ops.equal(a_1.uniform_row_length(), 1)
2591 [c_1_row_splits, ac_1_gather_index,
2592 bc_1_gather_index] = cond.cond(
2593 can_broadcast_a, true_fn=broadcast_a, false_fn=broadcast_noop)
2595 c_1 = RowPartition.from_row_splits(c_1_row_splits)
2596 ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
2597 bc_1 = _LayerBroadcaster.from_gather_index(bc_1_gather_index)
2598 return [c_1, ac_1, bc_1]
2601def _broadcast_dynamic_shape_next_layer_both_uniform(
2602 ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition,
2603 b_1: RowPartition
2604) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]:
2605 r"""Broadcast target and next layer broadcaster of two uniform dynamic shapes.
2607 *--ac_0-->*<--bc_0--*
2608 | | |
2609 a_1 c_1 b_1
2610 | | |
2611 V V V
2612 *--ac_1-->*<--bc_1--*
2614 Args:
2615 ac_0: _LayerBroadcaster from a to c in the previous layer.
2616 bc_0: _LayerBroadcaster from b to c in the previous layer.
2617 a_1: a RowPartition for the next layer of a.
2618 b_1: a RowPartition for the next layer of b.
2620 Returns:
2621 (c_1, ac_1, bc_1)
2622 c_1: a RowPartition for the next layer of the dynamic shape.
2623 ac_1: _LayerBroadcaster from a to c in the next layer.
2624 bc_1: _LayerBroadcaster from b to c in the next layer.
2625 """
2626 if not isinstance(ac_0, _LayerBroadcaster):
2627 raise TypeError("ac_0 should be a _LayerBroadcaster")
2628 if not isinstance(bc_0, _LayerBroadcaster):
2629 raise TypeError("bc_0 should be a _LayerBroadcaster")
2630 if not isinstance(a_1, RowPartition):
2631 raise TypeError("a_1 should be a RowPartition")
2632 if not isinstance(b_1, RowPartition):
2633 raise TypeError("b_1 should be a RowPartition")
2634 assert a_1.is_uniform()
2635 assert b_1.is_uniform()
2637 static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
2638 static_b_1 = tensor_util.constant_value(b_1.uniform_row_length())
2640 if static_a_1 is not None:
2641 if static_a_1 == static_b_1:
2642 # Here, this dimension is the same, but we may have to broadcast previous
2643 # dimensions.
2644 [ac_1, _] = _broadcast_half(ac_0, a_1)
2645 [bc_1, _] = _broadcast_half(bc_0, b_1)
2646 c_1 = RowPartition.from_uniform_row_length(
2647 static_a_1, nrows=ac_0.dest_nrows())
2648 return [c_1, ac_1, bc_1]
2649 elif static_a_1 == 1:
2650 [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2651 ac_1 = _LayerBroadcaster.from_gather_index(
2652 array_ops.gather(ac_0.gather_index, c_1b.value_rowids()))
2653 c_1 = RowPartition.from_uniform_row_length(
2654 b_1.uniform_row_length(), nrows=bc_0.dest_nrows())
2655 return [c_1, ac_1, bc_1]
2657 if static_b_1 == 1:
2658 [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2659 bc_1 = _LayerBroadcaster.from_gather_index(
2660 array_ops.gather(bc_0.gather_index, c_1a.value_rowids()))
2661 c_1 = RowPartition.from_uniform_row_length(
2662 a_1.uniform_row_length(), nrows=ac_0.dest_nrows())
2663 return [c_1, ac_1, bc_1]
2665 def broadcast_noop():
2666 # Assumes a_1.uniform_row_length() == b_1.uniform_row_length()
2667 # Both sides broadcast to a single shape.
2668 [ac_1, _] = _broadcast_half(ac_0, a_1)
2669 [bc_1, _] = _broadcast_half(bc_0, b_1)
2670 return [a_1.uniform_row_length(), ac_1.gather_index, bc_1.gather_index]
2672 def broadcast_a():
2673 [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2674 ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
2675 return [
2676 b_1.uniform_row_length(),
2677 ac_1_gather_index,
2678 bc_1.gather_index,
2679 ]
2681 def broadcast_b():
2682 [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2683 bc_1_gather_index = array_ops.gather(bc_0.gather_index, c_1a.value_rowids())
2684 return [a_1.uniform_row_length(), ac_1.gather_index, bc_1_gather_index]
2686 can_broadcast_b = math_ops.equal(b_1.uniform_row_length(), 1)
2688 def no_broadcast_a():
2689 return cond.cond(
2690 can_broadcast_b, true_fn=broadcast_b, false_fn=broadcast_noop)
2692 can_broadcast_a = math_ops.equal(a_1.uniform_row_length(), 1)
2694 broadcast_asserts = [
2695 check_ops.assert_equal(
2696 math_ops.logical_or(
2697 math_ops.logical_or(can_broadcast_a, can_broadcast_b),
2698 math_ops.equal(a_1.uniform_row_length(),
2699 b_1.uniform_row_length())), True)
2700 ]
2702 result = cond.cond(
2703 can_broadcast_a, true_fn=broadcast_a, false_fn=no_broadcast_a)
2705 [c_1_uniform_row_length, ac_1_gather_index, bc_1_gather_index] = [
2706 control_flow_ops.with_dependencies(broadcast_asserts, x) for x in result
2707 ]
2709 c_1 = RowPartition.from_uniform_row_length(
2710 c_1_uniform_row_length,
2711 nvals=c_1_uniform_row_length * ac_0.dest_nrows(),
2712 nrows=ac_0.dest_nrows())
2713 ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
2714 bc_1 = _LayerBroadcaster.from_gather_index(bc_1_gather_index)
2715 return [c_1, ac_1, bc_1]
2718def _broadcast_dynamic_shape_next_layer(
2719 ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition,
2720 b_1: RowPartition
2721) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]:
2722 r"""Broadcast target and next layer broadcaster of two dynamic shapes.
2724 *--ac_0-->*<--bc_0--*
2725 | | |
2726 a_1 c_1 b_1
2727 | | |
2728 V V V
2729 *--ac_1-->*<--bc_1--*
2731 Args:
2732 ac_0: _LayerBroadcaster from a to c in the previous layer.
2733 bc_0: _LayerBroadcaster from b to c in the previous layer.
2734 a_1: a RowPartition for the next layer of a.
2735 b_1: a RowPartition for the next layer of b.
2737 Returns:
2738 (c_1, ac_1, bc_1)
2739 c_1: a RowPartition for the next layer of the dynamic shape.
2740 ac_1: _LayerBroadcaster from a to c in the next layer.
2741 bc_1: _LayerBroadcaster from b to c in the next layer.
2742 """
2743 if not isinstance(ac_0, _LayerBroadcaster):
2744 raise TypeError("ac_0 should be a _LayerBroadcaster")
2745 if not isinstance(bc_0, _LayerBroadcaster):
2746 raise TypeError("bc_0 should be a _LayerBroadcaster")
2747 if not isinstance(a_1, RowPartition):
2748 raise TypeError("a_1 should be a RowPartition")
2749 if not isinstance(b_1, RowPartition):
2750 raise TypeError("b_1 should be a RowPartition")
2752 if a_1.is_uniform():
2753 if b_1.is_uniform():
2754 return _broadcast_dynamic_shape_next_layer_both_uniform(
2755 ac_0, bc_0, a_1, b_1)
2756 else:
2757 return _broadcast_dynamic_shape_next_layer_half_ragged(
2758 ac_0, bc_0, a_1, b_1)
2759 else:
2760 if b_1.is_uniform():
2761 [c_1, bc_1, ac_1] = _broadcast_dynamic_shape_next_layer_half_ragged( # pylint: disable=arguments-out-of-order
2762 bc_0, ac_0, b_1, a_1)
2763 return (c_1, ac_1, bc_1)
2764 else:
2765 # If neither shape is uniform, we cannot broadcast the dimension.
2766 [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2767 [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2768 check_valid = [
2769 check_ops.assert_equal(c_1a.row_splits(), c_1b.row_splits())
2770 ]
2771 return (
2772 c_1a._with_dependencies(check_valid), # pylint: disable=protected-access
2773 ac_1.with_dependencies(check_valid),
2774 bc_1.with_dependencies(check_valid))
2777def _broadcast_dynamic_shape_from_rps(
2778 a_zero: _LayerBroadcaster, b_zero: _LayerBroadcaster,
2779 a_rps: Sequence[RowPartition], b_rps: Sequence[RowPartition]
2780) -> Tuple[Sequence[RowPartition], Sequence[_LayerBroadcaster],
2781 Sequence[_LayerBroadcaster]]:
2782 """Create BroadcastLayers from two shapes to a target shape.
2785 *--a_zero->*<-b_zero-*
2786 | | |
2787 a_rps[0] c_rps[0] b_rps[0]
2788 | | |
2789 V V V
2790 *--ac[1]-->*<-bc[1]--*
2791 | | |
2792 a_rps[1] c_rps[0] b_rps[1]
2793 | | |
2794 V V V
2795 *--ac[2]-->*<-bc[2]--*
2797 Note: ac[0]=a_zero, and bc[0]=b_zero.
2798 Args:
2799 a_zero: broadcaster from rows of a_rps[0] to target shape.
2800 b_zero: broadcaster from rows of b_rps[0] to target shape.
2801 a_rps: RowPartitions of first shape.
2802 b_rps: RowPartitions of second shape, equal in length to a_rps.
2804 Returns:
2805 (c_rps, ac, bc) where:
2806 c_rps: RowPartitions of target shape.
2807 ac: layers broadcasting from the first shape.
2808 bc: layers broadcasting from the second shape.
2809 """
2810 assert len(a_rps) == len(b_rps)
2811 if a_rps:
2812 (c_1, ac_1,
2813 bc_1) = _broadcast_dynamic_shape_next_layer(a_zero, b_zero, a_rps[0],
2814 b_rps[0])
2815 (c_suffix, a_layers,
2816 b_layers) = _broadcast_dynamic_shape_from_rps(ac_1, bc_1, a_rps[1:],
2817 b_rps[1:])
2819 return ([c_1] + c_suffix, [ac_1] + a_layers, [bc_1] + b_layers)
2820 else:
2821 return ([], [], [])
2824def _get_broadcast_num_row_partitions(a: DynamicRaggedShape,
2825 b: DynamicRaggedShape):
2826 """Returns broadcast_dynamic_shape(a, b).num_row_partitions."""
2827 # Assumes rank and num_row_partitions are not None.
2828 if (a.num_row_partitions == 0 and b.num_row_partitions == 0):
2829 return 0
2830 expanded_num_row_partitions_a = a.num_row_partitions + max(0, b.rank - a.rank)
2831 expanded_num_row_partitions_b = b.num_row_partitions + max(0, a.rank - b.rank)
2833 if a.num_row_partitions == 0:
2834 return expanded_num_row_partitions_b
2836 if b.num_row_partitions == 0:
2837 return expanded_num_row_partitions_a
2839 return max(expanded_num_row_partitions_a, expanded_num_row_partitions_b)
2842# pylint: disable=protected-access
2843def _broadcast_dynamic_shape_extended_complete(
2844 a: DynamicRaggedShape, b: DynamicRaggedShape, b_rps: Sequence[RowPartition],
2845 c_suffix: Sequence[RowPartition], ac: Sequence[_LayerBroadcaster],
2846 bc_suffix: Sequence[_LayerBroadcaster]
2847) -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]:
2848 """Helper for broadcast_dynamic_shape_extended."""
2849 c_prefix = b_rps[:-len(c_suffix)]
2850 bc_prefix_length = b.rank - len(bc_suffix)
2851 bc_prefix = [
2852 _LayerBroadcaster.get_identity_broadcaster(b._num_slices_in_dimension(i))
2853 for i in range(bc_prefix_length)
2854 ]
2855 c_num_row_partitions = _get_broadcast_num_row_partitions(a, b)
2857 c_raw = DynamicRaggedShape.from_row_partitions(c_prefix + tuple(c_suffix))
2858 c = c_raw._with_num_row_partitions(c_num_row_partitions)
2859 return (c, _Broadcaster(a, c, ac), _Broadcaster(b, c, bc_prefix + bc_suffix))
2862def _broadcast_dynamic_shape_extended_helper(
2863 a: DynamicRaggedShape, b: DynamicRaggedShape
2864) -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]:
2865 """Helper for broadcast_dynamic_shape_extended.
2867 Here, we force:
2868 a.rank <= b.rank
2869 2 <= b.rank
2870 1 <= a.rank
2871 Args:
2872 a: a DynamicRaggedShape
2873 b: a DynamicRaggedShape
2875 Returns:
2876 A triple of a shape and two broadcasters.
2877 """
2878 assert a.rank <= b.rank
2879 assert 2 <= b.rank
2880 assert 1 <= a.rank
2881 a_rps = a._as_row_partitions() # pylint: disable=protected-access
2882 b_rps = b._as_row_partitions() # pylint: disable=protected-access
2884 if len(a_rps) < len(b_rps):
2885 # Note: this includes the case where len(a_rps)==0.
2886 # Here we begin at -1, one dimension before a_rps[0].
2887 # neg_one_a_rp | b_rps[-(len(a_rps)+1)]
2888 # a_rps[0] | b_rps[-len(a_rps)]
2889 # a_rps[1] | b_rps[1-len(a_rps)]
2890 # ... | ...
2891 # a_rps[-1] | b_rps[-1]
2893 a_nrows = a[0]
2894 a_nrows_static = tensor_util.constant_value(a_nrows)
2895 if a_nrows_static is not None:
2896 a_nrows = a_nrows_static
2898 neg_one_a_rp = RowPartition.from_uniform_row_length(
2899 uniform_row_length=a_nrows, nrows=1, nvals=a_nrows)
2900 neg_one_b_rp = b_rps[-(len(a_rps) + 1)]
2901 (neg_one_ac, neg_one_bc) = _broadcast_dynamic_shape_first_layer(
2902 constant_op.constant(1, dtype=b_rps[0].dtype), neg_one_b_rp.nrows())
2904 # The first part of the solution.
2905 (c_zero, ac_zero,
2906 bc_zero) = _broadcast_dynamic_shape_next_layer(neg_one_ac, neg_one_bc,
2907 neg_one_a_rp, neg_one_b_rp)
2908 b_rps_tail = b_rps[-len(a_rps):] if len(a_rps) >= 1 else []
2910 (c_suffix, ac_layers,
2911 bc_layers) = _broadcast_dynamic_shape_from_rps(ac_zero, bc_zero, a_rps,
2912 b_rps_tail)
2914 return _broadcast_dynamic_shape_extended_complete(
2915 a=a,
2916 b=b,
2917 b_rps=b_rps,
2918 c_suffix=[c_zero] + c_suffix,
2919 ac=[ac_zero] + ac_layers,
2920 bc_suffix=[neg_one_bc, bc_zero] + bc_layers)
2922 else:
2923 assert len(a_rps) == len(b_rps)
2924 (ac_zero,
2925 bc_zero) = _broadcast_dynamic_shape_first_layer(a_rps[0].nrows(),
2926 b_rps[0].nrows())
2928 (c_rps, a_layers,
2929 b_layers) = _broadcast_dynamic_shape_from_rps(ac_zero, bc_zero, a_rps,
2930 b_rps)
2931 return _broadcast_dynamic_shape_extended_complete(
2932 a=a,
2933 b=b,
2934 b_rps=b_rps,
2935 c_suffix=c_rps,
2936 ac=[ac_zero] + a_layers,
2937 bc_suffix=[bc_zero] + b_layers)
2940def _fix_start_index(index, rank, num_row_partitions):
2941 """Slice indexes are always silently truncated."""
2942 if index < 0:
2943 if rank is None:
2944 raise ValueError(
2945 "Rank must be known to use __getitem__ on a negative index.")
2946 index = rank + index
2947 if index < 0:
2948 index = 0
2949 if (num_row_partitions > 0 and index <= num_row_partitions + 1):
2950 # The rank is always >= num_row_partitions + 1 if num_row_partitions > 0.
2951 return index
2952 if index == 0:
2953 return index
2954 if rank is None:
2955 raise ValueError("Rank must be known to use __getitem__ on a large index.")
2956 if index >= rank:
2957 index = rank
2958 return index
2961def _fix_stop_index(index, rank):
2962 """Slice indexes are always silently truncated."""
2963 if index is None:
2964 if rank is None:
2965 raise ValueError("Rank must be known to use __getitem__ without a stop.")
2966 index = rank
2967 if index < 0:
2968 if rank is None:
2969 raise ValueError(
2970 "Rank must be known to use __getitem__ on a negative index.")
2971 index = rank + index
2972 if index < 0:
2973 index = 0
2974 if rank is not None:
2975 index = min(rank, index)
2976 return index
2979def _first_layer_gather_index(nrows_source, nrows_target):
2980 """Return the first layer gather_index.
2982 Args:
2983 nrows_source: the number of rows in the source.
2984 nrows_target: the number of rows in the target.
2986 Returns:
2987 A tensor, usable as a gather_index for a _LayerBroadcaster.
2988 """
2990 def gi_broadcast_first():
2991 return array_ops.zeros(nrows_target, dtype=nrows_target.dtype)
2993 def gi_no_broadcast_first():
2994 gather_index = math_ops.range(nrows_target, dtype=nrows_target.dtype)
2995 return gather_index
2997 do_broadcast = math_ops.equal(nrows_source,
2998 constant_op.constant(1, nrows_source.dtype))
2999 nrows_equal = math_ops.equal(nrows_source, nrows_target)
3000 can_broadcast = check_ops.assert_equal(
3001 math_ops.logical_or(do_broadcast, nrows_equal),
3002 True,
3003 message="Cannot broadcast")
3005 gather_index = cond.cond(
3006 do_broadcast, true_fn=gi_broadcast_first, false_fn=gi_no_broadcast_first)
3008 return control_flow_ops.with_dependencies([can_broadcast], gather_index)
3011def _next_layer_gather_index(bc, original_rp, broadcast_rp):
3012 r"""Create the next layer gather_index whether or not a broadcast happens.
3014 *----------bc-------->*
3015 | |
3016 original_rp broadcast_rp
3017 | |
3018 \|/ \|/
3019 *--next_broadcaster-->*
3021 Args:
3022 bc: the old broadcaster.
3023 original_rp: the original row partition.
3024 broadcast_rp: the target row partition.
3026 Returns:
3027 the gather_index for next_broadcaster.
3028 Raises:
3029 InvalidArgumentError if the shapes are incompatible.
3030 """
3031 old_value_rowids = array_ops.gather(bc.gather_index,
3032 broadcast_rp.value_rowids())
3034 def gi_no_broadcast():
3035 # TODO(martinz): decide if row_splits or row_starts should be used here.
3036 old_row_starts = array_ops.gather(original_rp.row_splits(),
3037 old_value_rowids)
3038 expected_row_lengths = array_ops.gather(
3039 params=original_rp.row_lengths(), indices=bc.gather_index)
3040 actual_row_lengths = broadcast_rp.row_lengths()
3041 check_valid = check_ops.assert_equal(
3042 expected_row_lengths, actual_row_lengths, message="Cannot broadcast")
3043 gather_index = old_row_starts + broadcast_rp.offsets_in_rows()
3044 return control_flow_ops.with_dependencies([check_valid], gather_index)
3046 def gi_broadcast():
3047 # Several optimizations can occur here.
3048 # old_row_starts == old_value_rowids, because:
3049 # if you are broadcasting, then the source has uniform row length of 1,
3050 # implying original_rp.row_splits == tf.range(orgininal_rp.nvals + 1)
3051 # When broadcasting, there is no need to add offsets to the
3052 # source, because the source has size 1.
3053 # Also, this is always valid, because we enforce source and destination
3054 # have uniform_row_length.
3055 return old_value_rowids
3057 if not original_rp.is_uniform():
3058 return gi_no_broadcast()
3060 do_broadcast = math_ops.equal(original_rp.uniform_row_length(),
3061 constant_op.constant(1, original_rp.dtype))
3062 gather_index = cond.cond(
3063 do_broadcast, true_fn=gi_broadcast, false_fn=gi_no_broadcast)
3065 return gather_index
3068def _flat_values_shape(rt):
3069 if isinstance(rt, ragged_tensor.RaggedTensor):
3070 return array_ops.shape(rt.flat_values)
3071 return rt.flat_values.shape
3074def _to_row_partitions_and_nvals_from_lengths(
3075 lengths: Sequence[Union[int, Sequence[int]]],
3076 dtype=None) -> Tuple[Sequence[RowPartition], int]:
3077 """Allow ragged and uniform shapes to be specified.
3079 For example, [2, [2,1], 2] represents a shape like:
3080 [[[0, 0], [0, 0]], [[0, 0]]]
3082 Args:
3083 lengths: a list of integers and lists of integers.
3084 dtype: dtype of the shape (tf.int32 or tf.int64)
3086 Returns:
3087 a sequence of RowPartitions, and the number of values of the last partition.
3088 """
3089 size_so_far = lengths[0]
3090 result = []
3091 for current_lengths in lengths[1:]:
3092 if isinstance(current_lengths, int):
3093 nrows = size_so_far
3094 nvals = current_lengths * nrows
3095 size_so_far = nvals
3096 result.append(
3097 RowPartition.from_uniform_row_length(
3098 current_lengths, nvals, nrows=nrows, dtype_hint=dtype))
3099 else:
3100 if size_so_far != len(current_lengths):
3101 raise ValueError("Shape not consistent.")
3102 result.append(
3103 RowPartition.from_row_lengths(current_lengths, dtype_hint=dtype))
3104 size_so_far = sum(current_lengths)
3105 return (result, size_so_far)
3108def _element_to_string(x):
3109 """element to a string within a list."""
3110 if x is Ellipsis:
3111 return "..."
3112 if isinstance(x, str):
3113 return "'" + x + "'"
3114 return str(x)
3117def _list_tail_with_ellipsis(arr):
3118 """Print the tail of a list where the list might have an ellipsis."""
3119 if not arr:
3120 return "]"
3121 else:
3122 return ", " + _element_to_string(arr[0]) + _list_tail_with_ellipsis(arr[1:])
3125def _list_with_ellipsis_to_str(arr):
3126 """Print a list that might have ellipsis."""
3127 if not arr:
3128 return "[]"
3129 return "[" + _element_to_string(arr[0]) + _list_tail_with_ellipsis(arr[1:])
3132def _is_int_or_tuple_of_ints(x):
3133 if isinstance(x, int):
3134 return True
3135 if not isinstance(x, tuple):
3136 return False
3137 for y in x:
3138 if not isinstance(y, int):
3139 return False
3140 return True
3143def _alt_inner_shape_from_tensor_shape(shape, dtype, new_inner_rank):
3144 """Helper for _alt_inner_shape, used directly in _with_num_row_partitions."""
3145 if new_inner_rank == 1:
3146 return constant_op.constant([shape.num_elements()], dtype=dtype)
3147 new_inner_rank_tail_length = new_inner_rank - 1
3148 inner_shape_tail = shape[-new_inner_rank_tail_length:].as_list()
3149 first_dim = shape[:-new_inner_rank_tail_length].num_elements()
3150 return constant_op.constant([first_dim] + inner_shape_tail, dtype=dtype)
3153def _safe_floor_div(dividend: tensor_shape.Dimension,
3154 divisor: tensor_shape.Dimension) -> tensor_shape.Dimension:
3155 if tensor_shape.dimension_value(divisor) == 0:
3156 return None
3157 return dividend // divisor
3160# TODO(b/218932570)
3161def _reduce_prod_patch(x):
3162 if x.dtype == dtypes.int64:
3163 return math_ops.cast(
3164 math_ops.reduce_prod(math_ops.cast(x, dtypes.int32)), dtypes.int64)
3165 return math_ops.reduce_prod(x)
3168# Type alias for shape encoded as a DynamicRaggedShape or a Tensor.
3169DenseOrRaggedShape = Union[DynamicRaggedShape, core.TensorLike]
3172def _merge_row_partitions(
3173 row_partitions: Sequence[RowPartition]) -> RowPartition:
3174 # TODO(martinz): handle uniform splits.
3175 # TODO(martinz): consider using value_row_ids if present.
3176 # Note: this probably won't be called with len(row_partitions)==1, so no
3177 # need to optimize.
3178 row_splits = row_partitions[0].row_splits()
3179 for rp in row_partitions[1:]:
3180 row_splits = array_ops.gather(rp.row_splits(), row_splits)
3181 return RowPartition.from_row_splits(row_splits)
3184def _merge_inner_shape(
3185 inner_shape: ops.Tensor, static_inner_shape: tensor_shape.TensorShape,
3186 outer_axis: int,
3187 inner_axis: int) -> Tuple[ops.Tensor, tensor_shape.TensorShape]:
3188 """Merge the inner shape of a DynamicRaggedShape."""
3189 prefix = inner_shape[:outer_axis]
3190 suffix = inner_shape[inner_axis + 1:]
3192 internal = inner_shape[outer_axis:inner_axis + 1]
3193 internal_value = [_reduce_prod_patch(internal)]
3194 new_internal = array_ops.concat([prefix, internal_value, suffix], axis=0)
3195 prefix_static = static_inner_shape[:outer_axis]
3196 suffix_static = static_inner_shape[inner_axis + 1:]
3197 internal_static = static_inner_shape[outer_axis:inner_axis + 1]
3198 internal_value_static = tensor_shape.TensorShape(
3199 [internal_static.num_elements()])
3200 new_internal_static = prefix_static + internal_value_static + suffix_static
3202 return (new_internal, new_internal_static)
3205def _batch_rp_spec(rp_spec: RowPartitionSpec,
3206 batch_size: Optional[int]) -> RowPartitionSpec:
3207 """Batches a RowPartitionSpec.
3209 Given a RowPartitionSpec and a batch_size, create a RowPartitionSpec that
3210 will be the spec for the concatenation of batch_size RowPartitions.
3212 A RowPartition can be considered a transformation from a list of a given
3213 length to a list of lists. Assume rp_a is a map from list_a to nlist_a,
3214 And rp_b is a map from list_b to nlist_b. concat(rp_a, rp_b) is a
3215 transform of concat(list_a, list_b) to concat(nlist_a, nlist_b).
3217 If batch_size is None, then have the spec be able to handle an arbitrary
3218 number of RowPartitions.
3220 Args:
3221 rp_spec: a RowPartitionSpec for all the RowPartitions to be concatenated.
3222 batch_size: the number of rp_specs to be concatenated.
3224 Returns:
3225 a batched RowPartitionSpec.
3226 """
3227 if batch_size is None:
3228 return RowPartitionSpec(
3229 uniform_row_length=rp_spec.uniform_row_length, dtype=rp_spec.dtype)
3230 nrows = None if rp_spec.nrows is None else rp_spec.nrows * batch_size
3231 nvals = None if rp_spec.nvals is None else rp_spec.nvals * batch_size
3232 return RowPartitionSpec(
3233 nrows=nrows,
3234 nvals=nvals,
3235 uniform_row_length=rp_spec.uniform_row_length,
3236 dtype=rp_spec.dtype)
3239def _batch_rp_spec_head(old_head: RowPartitionSpec,
3240 batch_size: Optional[int]) -> RowPartitionSpec:
3241 """Creates a RowPartitionSpec representing the new dimension created."""
3242 nvals = None if (old_head.nrows is None or
3243 batch_size is None) else batch_size * old_head.nrows
3244 return RowPartitionSpec(
3245 nrows=batch_size,
3246 nvals=nvals,
3247 uniform_row_length=old_head.nrows,
3248 dtype=old_head.dtype)
3251def _batch_static_inner_shape(
3252 old_shape: tensor_shape.TensorShape,
3253 batch_size: Optional[int]) -> tensor_shape.TensorShape:
3254 """Returns a copy of old_shape with axis=0 multiplied by batch_size.
3256 Only use if this is the inner_shape of a DynamicRaggedShape.Spec with one
3257 or more row partitions.
3259 Args:
3260 old_shape: the original inner_shape.
3261 batch_size: the batch size.
3263 Returns:
3264 a new shape.
3265 """
3266 head_dim = tensor_shape.dimension_at_index(old_shape, 0) * batch_size
3267 return head_dim + old_shape[1:]
3270def _batch_tensor_shape(old_shape: tensor_shape.TensorShape,
3271 batch_size: int) -> tensor_shape.TensorShape:
3272 return tensor_shape.TensorShape([batch_size]) + old_shape
3275def _unbatch_static_inner_shape(
3276 old_shape: tensor_shape.TensorShape,
3277 batch_size: Optional[int]) -> tensor_shape.TensorShape:
3278 """Unbatch a static_inner_shape when num_row_partitions > 0."""
3279 head_dim = tensor_shape.dimension_at_index(old_shape, 0) // batch_size
3280 return head_dim + old_shape[1:]
3283# Copied from ragged_array_ops.py
3284def ones(shape: DynamicRaggedShape,
3285 dtype=dtypes.float32,
3286 name: Optional[str] = None) -> ragged_tensor.RaggedOrDense:
3287 """Returns ones shaped like x."""
3288 flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name)
3289 return ragged_tensor.RaggedTensor._from_nested_row_partitions( # pylint: disable=protected-access
3290 flat_values, shape.row_partitions)