Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/row_partition.py: 23%
494 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class used to partition a sequence into contiguous subsequences ("rows").
16"""
19# TODO(edloper): Make into a ExtensionType (if possible)
22import numpy as np
24from tensorflow.core.protobuf import struct_pb2
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_conversion
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.framework import type_spec
34from tensorflow.python.framework import type_spec_registry
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import check_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import gen_ragged_math_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops.ragged import segment_id_ops
41from tensorflow.python.saved_model import nested_structure_coder
42from tensorflow.python.util.tf_export import tf_export
44# ===============================================================================
45# RowPartition
46# ===============================================================================
47# TODO(edloper): Consider removing row_starts and row_limits factory methods
48# and accessors from RowPartition. In particular, these two encodings are
49# "second-class citizens": we never cache them, and if you do construct a
50# RowPartition from them then it may be more expensive than you might expect
51# (because we append a value to the beginning/end to transform them into
52# splits). If we do remove them from RowPartition, then we would still keep
53# the from_row_starts and from_row_limits factory methods in RaggedTensor.
56@tf_export("experimental.RowPartition")
57class RowPartition(composite_tensor.CompositeTensor):
58 """Partitioning of a sequence of values into contiguous subsequences ("rows").
60 A `RowPartition` describes how a sequence with `nvals` items should be
61 divided into `nrows` contiguous subsequences ("rows"). For example, a
62 `RowPartition` could be used to partition the vector `[1, 2, 3, 4, 5]` into
63 subsequences `[[1, 2], [3], [], [4, 5]]`. Note that `RowPartition` stores
64 information about how values are partitioned, but does not include the
65 partitioned values themselves. `tf.RaggedTensor` is used to pair a `values`
66 tensor with one or more `RowPartition`s, providing a complete encoding for a
67 ragged tensor (i.e. a tensor with variable-length dimensions).
69 `RowPartition`s may be defined using several different schemes:
71 * `row_lengths`: an integer vector with shape `[nrows]`, which specifies
72 the length of each row.
74 * `row_splits`: an integer vector with shape `[nrows+1]`, specifying the
75 "split points" between each row.
77 * `row_starts`: an integer vector with shape `[nrows]`, which specifies
78 the start offset for each row. Equivalent to `row_splits[:-1]`.
80 * `row_limits`: an integer vector with shape `[nrows]`, which specifies
81 the stop offset for each row. Equivalent to `row_splits[1:]`.
83 * `value_rowids` is an integer vector with shape `[nvals]`, corresponding
84 one-to-one with sequence values, which specifies the row that each value
85 belongs to. If the partition has empty trailing rows, then `nrows`
86 must also be specified.
88 * `uniform_row_length` is an integer scalar, specifying the length of every
89 row. This scheme may only be used if all rows have the same length.
91 For example, the following `RowPartition`s all represent the partitioning of
92 8 values into 5 sublists as follows: `[[*, *, *, *], [], [*, *, *], [*], []]`.
94 >>> p1 = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
95 >>> p2 = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])
96 >>> p3 = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8)
97 >>> p4 = RowPartition.from_row_limits([4, 4, 7, 8, 8])
98 >>> p5 = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
100 For more information about each scheme, see the documentation for the
101 its factory method. For additional examples, see the documentation on
102 `tf.RaggedTensor`.
104 ### Precomputed Encodings
106 `RowPartition` always stores at least one encoding of the partitioning, but
107 it can be configured to cache additional encodings as well. This can
108 avoid unnecessary recomputation in eager mode. (In graph mode, optimizations
109 such as common subexpression elimination will typically prevent these
110 unnecessary recomputations.) To check which encodings are precomputed, use
111 `RowPartition.has_precomputed_<encoding>`. To cache an additional
112 encoding, use `RowPartition.with_precomputed_<encoding>`.
113 """
115 # =============================================================================
116 # Constructor (private)
117 # =============================================================================
  def __init__(self,
               row_splits,
               row_lengths=None,
               value_rowids=None,
               nrows=None,
               uniform_row_length=None,
               nvals=None,
               internal=False):
    """Creates a `RowPartition` from the specified encoding tensor(s).

    This constructor is private -- please use one of the following ops to
    build `RowPartition`s:

      * `RowPartition.from_row_lengths`
      * `RowPartition.from_value_rowids`
      * `RowPartition.from_row_splits`
      * `RowPartition.from_row_starts`
      * `RowPartition.from_row_limits`
      * `RowPartition.from_uniform_row_length`

    If `row_splits` has a constant value, then all other arguments should
    have a constant value.

    Args:
      row_splits: A 1-D integer tensor with shape `[nrows+1]`.
      row_lengths: A 1-D integer tensor with shape `[nrows]`
      value_rowids: A 1-D integer tensor with shape `[nvals]`.
      nrows: A 1-D integer scalar tensor.
      uniform_row_length: A scalar tensor.
      nvals: A scalar tensor.
      internal: Private key value, required to ensure that this private
        constructor is *only* called from the factory methods.

    Raises:
      TypeError: If a row partitioning tensor has an inappropriate dtype.
      TypeError: If exactly one row partitioning argument was not specified.
      ValueError: If a row partitioning tensor has an inappropriate shape.
      ValueError: If multiple partitioning arguments are specified.
      ValueError: If nrows is specified but value_rowids is not None.
    """
    # The `internal` sentinel guarantees construction only via the classmethod
    # factories, which are responsible for producing consistent encodings.
    if internal is not _row_partition_factory_key:
      raise ValueError("RowPartition constructor is private; please use one "
                       "of the factory methods instead (e.g., "
                       "RowPartition.from_row_lengths())")

    # Validate the arguments.
    if not isinstance(row_splits, ops.Tensor):
      raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
                      row_splits)
    if row_splits.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("Row-partitioning argument must be int32 or int64")

    # Validate shapes & dtypes.
    row_splits.shape.assert_has_rank(1)
    # Erase any static length information; only the rank is guaranteed.
    row_splits.set_shape([None])
    self._row_splits = row_splits

    # Store any cached tensors.  These are used to avoid unnecessary
    # round-trip conversions when a RowPartition is constructed from
    # lengths or rowids, and we later want those lengths/rowids back.
    # Every cached encoding must share row_splits' dtype so accessors can
    # return them interchangeably.
    for tensor in [row_lengths, value_rowids, nrows, uniform_row_length, nvals]:
      if tensor is not None:
        if not isinstance(tensor, ops.Tensor):
          raise TypeError("Cached value must be a Tensor or None.")
        elif tensor.dtype != row_splits.dtype:
          raise ValueError(f"Inconsistent dtype for encoding tensors: "
                           f"{tensor} vs {row_splits}")

    self._row_lengths = row_lengths
    self._value_rowids = value_rowids
    self._nrows = nrows
    self._uniform_row_length = uniform_row_length
    self._nvals = nvals
191 # =============================================================================
192 # Factory Methods
193 # =============================================================================
  @classmethod
  def from_value_rowids(cls,
                        value_rowids,
                        nrows=None,
                        validate=True,
                        dtype=None,
                        dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `value_rowids`.

    This `RowPartition` divides a sequence `values` into rows by specifying
    which row each value should be added to:

    ```python
    partitioned_rows = [[] for _ in nrows]
    for (value, rowid) in zip(values, value_rowids):
      partitioned_rows[rowid].append(value)
    ```

    Args:
      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
        one-to-one with `values`, and specifies each value's row index.  Must be
        nonnegative, and must be sorted in ascending order.
      nrows: An integer scalar specifying the number of rows.  This should be
        specified if the `RowPartition` may contain empty trailing rows.  Must
        be greater than `value_rowids[-1]` (or greater than or equal to zero if
        `value_rowids` is empty).  Defaults to `value_rowids[-1] + 1` (or zero
        if `value_rowids` is empty).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `value_rowids`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `nrows` is incompatible with `value_rowids`.

    #### Example:

    >>> print(RowPartition.from_value_rowids(
    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
    ...     nrows=4))
    tf.RowPartition(row_splits=[0 4 4 7 8])
    """
    # Local import bincount_ops to avoid import-cycle since bincount_ops
    # imports ragged_tensor.
    from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromValueRowIds",
                        [value_rowids, nrows]):
      value_rowids = cls._convert_row_partition(
          value_rowids, "value_rowids", dtype_hint=dtype_hint, dtype=dtype)
      if nrows is None:
        const_rowids = tensor_util.constant_value(value_rowids)
        if const_rowids is None:
          # Dynamic default: last rowid + 1; the appended [-1] makes an empty
          # value_rowids yield nrows == -1 + 1 == 0.
          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
          const_nrows = None
        else:
          # Statically-known rowids: compute nrows eagerly in Python.
          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
          nrows = ops.convert_to_tensor(
              const_nrows, value_rowids.dtype, name="nrows")
      else:
        nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
        const_nrows = tensor_util.constant_value(nrows)
        if const_nrows is not None:
          # When nrows is statically known, validate eagerly (raising here is
          # friendlier than a runtime assertion failure).
          if const_nrows < 0:
            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
          const_rowids = tensor_util.constant_value(value_rowids)
          if const_rowids is not None and const_rowids.size > 0:
            if not const_nrows >= const_rowids[-1] + 1:
              raise ValueError(
                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))

      value_rowids.shape.assert_has_rank(1)
      nrows.shape.assert_has_rank(0)

      if validate:
        msg = ("Arguments to from_value_rowids do not form a valid "
               "RowPartition")
        checks = [
            check_ops.assert_rank(value_rowids, 1, message=msg),
            check_ops.assert_rank(nrows, 0, message=msg),
            # Monotonicity plus a nonnegative first element implies all
            # rowids are nonnegative.
            check_ops.assert_non_negative(value_rowids[:1], message=msg),
            _assert_monotonic_increasing(value_rowids, message=msg),
            check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
        ]
        value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)

      # Convert value_rowids & nrows to row_splits.
      # Note: we don't use segment_ids_to_row_splits() here because we want
      # to save the intermediate value `row_lengths`, so we can cache it.
      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
      # cast.
      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
      # minlength == maxlength pins the bincount output length to nrows, so
      # empty trailing rows get explicit zero counts.
      row_lengths = bincount_ops.bincount(
          value_rowids_int32,
          minlength=nrows_int32,
          maxlength=nrows_int32,
          dtype=value_rowids.dtype)
      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
      if const_nrows is not None:
        row_lengths.set_shape([const_nrows])
        row_splits.set_shape([const_nrows + 1])

      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          value_rowids=value_rowids,
          nrows=nrows,
          internal=_row_partition_factory_key)
315 @classmethod
316 def from_row_splits(cls,
317 row_splits,
318 validate=True,
319 dtype=None,
320 dtype_hint=None):
321 """Creates a `RowPartition` with rows partitioned by `row_splits`.
323 This `RowPartition` divides a sequence `values` into rows by indicating
324 where each row begins and ends:
326 ```python
327 partitioned_rows = []
328 for i in range(len(row_splits) - 1):
329 row_start = row_splits[i]
330 row_end = row_splits[i + 1]
331 partitioned_rows.append(values[row_start:row_end])
332 ```
334 Args:
335 row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
336 empty, and must be sorted in ascending order. `row_splits[0]` must be
337 zero.
338 validate: If true, then use assertions to check that the arguments form a
339 valid `RowPartition`.
340 dtype: Optional dtype for the RowPartition. If missing, the type
341 is inferred from the type of `row_splits`, dtype_hint, or tf.int64.
342 dtype_hint: Optional dtype for the RowPartition, used when dtype
343 is None. In some cases, a caller may not have a dtype in mind when
344 converting to a tensor, so dtype_hint can be used as a soft preference.
345 If the conversion to `dtype_hint` is not possible, this argument has no
346 effect.
348 Returns:
349 A `RowPartition`.
351 Raises:
352 ValueError: If `row_splits` is an empty list.
353 """
354 if not isinstance(validate, bool):
355 raise TypeError("validate must have type bool")
356 if isinstance(row_splits, (list, tuple)) and not row_splits:
357 raise ValueError("row_splits tensor may not be empty.")
358 if isinstance(row_splits, tensor_spec.TensorSpec):
359 return cls(row_splits=row_splits, internal=_row_partition_factory_key)
361 with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]):
362 row_splits = cls._convert_row_partition(
363 row_splits, "row_splits", dtype_hint=dtype_hint, dtype=dtype)
364 row_splits.shape.assert_has_rank(1)
366 if validate:
367 msg = "Arguments to from_row_splits do not form a valid RaggedTensor:"
368 checks = [
369 check_ops.assert_rank(row_splits, 1, message=(msg + "rank")),
370 _assert_zero(row_splits[0], message=(msg + "zero")),
371 _assert_monotonic_increasing(
372 row_splits, message=(msg + "monotonic")),
373 ]
374 row_splits = control_flow_ops.with_dependencies(checks, row_splits)
376 return cls(row_splits=row_splits, internal=_row_partition_factory_key)
378 @classmethod
379 def from_row_lengths(cls,
380 row_lengths,
381 validate=True,
382 dtype=None,
383 dtype_hint=None):
384 """Creates a `RowPartition` with rows partitioned by `row_lengths`.
386 This `RowPartition` divides a sequence `values` into rows by indicating
387 the length of each row:
389 ```python
390 partitioned_rows = [[values.pop(0) for _ in range(length)]
391 for length in row_lengths]
392 ```
394 Args:
395 row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be
396 nonnegative.
397 validate: If true, then use assertions to check that the arguments form a
398 valid `RowPartition`.
400 dtype: Optional dtype for the RowPartition. If missing, the type
401 is inferred from the type of `row_lengths`, dtype_hint, or tf.int64.
402 dtype_hint: Optional dtype for the RowPartition, used when dtype
403 is None. In some cases, a caller may not have a dtype in mind when
404 converting to a tensor, so dtype_hint can be used as a soft preference.
405 If the conversion to `dtype_hint` is not possible, this argument has no
406 effect.
408 Returns:
409 A `RowPartition`.
410 """
411 if not isinstance(validate, bool):
412 raise TypeError("validate must have type bool")
413 with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]):
414 row_lengths = cls._convert_row_partition(
415 row_lengths, "row_lengths", dtype_hint=dtype_hint, dtype=dtype)
416 row_lengths.shape.assert_has_rank(1)
418 if validate:
419 msg = "Arguments to from_row_lengths do not form a valid RowPartition"
420 checks = [
421 check_ops.assert_rank(row_lengths, 1, message=msg),
422 check_ops.assert_non_negative(row_lengths, message=msg),
423 ]
424 row_lengths = control_flow_ops.with_dependencies(checks, row_lengths)
426 row_limits = math_ops.cumsum(row_lengths)
427 row_splits = array_ops.concat([[0], row_limits], axis=0)
428 return cls(
429 row_splits=row_splits,
430 row_lengths=row_lengths,
431 internal=_row_partition_factory_key)
  @classmethod
  def from_row_starts(cls,
                      row_starts,
                      nvals,
                      validate=True,
                      dtype=None,
                      dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `row_starts`.

    Equivalent to: `from_row_splits(concat([row_starts, nvals], axis=0))`.

    Args:
      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
        nonnegative and sorted in ascending order.  If `nrows>0`, then
        `row_starts[0]` must be zero.
      nvals: A scalar tensor indicating the number of values.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `row_starts`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromRowStarts", [row_starts]):
      row_starts = cls._convert_row_partition(
          row_starts, "row_starts", dtype_hint=dtype_hint, dtype=dtype)
      row_starts.shape.assert_has_rank(1)
      # TODO(martinz): nvals and row_starts could be inconsistent at call time,
      # even though they eventually end up the same type.
      nvals = math_ops.cast(nvals, row_starts.dtype)
      if validate:
        msg = "Arguments to from_row_starts do not form a valid RaggedTensor"
        checks = [
            check_ops.assert_rank(row_starts, 1, message=msg),
            # [:1]/[-1:] slicing keeps the checks valid for empty row_starts.
            _assert_zero(row_starts[:1], message=msg),
            _assert_monotonic_increasing(row_starts, message=msg),
            check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg),
        ]
        row_starts = control_flow_ops.with_dependencies(checks, row_starts)

      # Appending nvals turns starts into splits: splits[i] = starts[i],
      # splits[-1] = nvals.
      row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
      return cls(row_splits=row_splits, nvals=nvals,
                 internal=_row_partition_factory_key)
485 @classmethod
486 def from_row_limits(cls,
487 row_limits,
488 validate=True,
489 dtype=None,
490 dtype_hint=None):
491 """Creates a `RowPartition` with rows partitioned by `row_limits`.
493 Equivalent to: `from_row_splits(values, concat([0, row_limits], axis=0))`.
495 Args:
496 row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in
497 ascending order.
498 validate: If true, then use assertions to check that the arguments form a
499 valid `RowPartition`.
500 dtype: Optional dtype for the RowPartition. If missing, the type
501 is inferred from the type of `row_limits`, dtype_hint, or tf.int64.
502 dtype_hint: Optional dtype for the RowPartition, used when dtype
503 is None. In some cases, a caller may not have a dtype in mind when
504 converting to a tensor, so dtype_hint can be used as a soft preference.
505 If the conversion to `dtype_hint` is not possible, this argument has no
506 effect.
508 Returns:
509 A `RowPartition`.
510 """
511 if not isinstance(validate, bool):
512 raise TypeError("validate must have type bool")
513 with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]):
514 row_limits = cls._convert_row_partition(
515 row_limits, "row_limits", dtype_hint=dtype_hint, dtype=dtype)
516 row_limits.shape.assert_has_rank(1)
518 if validate:
519 msg = "Arguments to from_row_limits do not form a valid RaggedTensor"
520 checks = [
521 check_ops.assert_rank(row_limits, 1, message=msg),
522 check_ops.assert_non_negative(row_limits[:1], message=msg),
523 _assert_monotonic_increasing(row_limits, message=msg),
524 ]
525 row_limits = control_flow_ops.with_dependencies(checks, row_limits)
527 zero = array_ops.zeros([1], row_limits.dtype)
528 row_splits = array_ops.concat([zero, row_limits], axis=0)
529 return cls(row_splits=row_splits, internal=_row_partition_factory_key)
  @classmethod
  def from_uniform_row_length(cls,
                              uniform_row_length,
                              nvals=None,
                              nrows=None,
                              validate=True,
                              dtype=None,
                              dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `uniform_row_length`.

    This `RowPartition` divides a sequence `values` into rows that all have
    the same length:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(uniform_row_length)]
                        for _ in range(nrows)]
    ```

    Note that either or both of nvals and nrows must be specified.

    Args:
      uniform_row_length: A scalar integer tensor.  Must be nonnegative.  The
        size of the outer axis of `values` must be evenly divisible by
        `uniform_row_length`.
      nvals: a non-negative scalar integer tensor for the number of values.
        Must be specified if nrows is not specified.  If not specified,
        defaults to uniform_row_length*nrows
      nrows: The number of rows in the constructed RowPartition.  If not
        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
        `uniform_row_length==0`).  `nrows` only needs to be specified if
        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must
        be `nvals`.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `uniform_row_length`, dtype_hint,
        or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    if nrows is None and nvals is None:
      raise ValueError("Either (or both) of nvals and nrows must be specified")
    with ops.name_scope(None, "RowPartitionFromUniformRowLength",
                        [uniform_row_length, nrows]):
      [uniform_row_length, nvals, nrows
      ] = _convert_all_to_tensors([(uniform_row_length, "uniform_row_length"),
                                   (nvals, "nvals"), (nrows, "nrows")],
                                  dtype=dtype,
                                  dtype_hint=dtype_hint)

      uniform_row_length.shape.assert_has_rank(0)

      # Find nrows.  Prefer a statically-known row length so that nrows (and
      # later row_splits) can be computed eagerly.
      const_row_length = tensor_util.constant_value(uniform_row_length)
      if nrows is None:
        if const_row_length is None:
          # Avoid division by zero if uniform_row_length==0 (and nvals==0).
          rowlen_or_1 = math_ops.maximum(
              uniform_row_length,
              constant_op.constant(1, uniform_row_length.dtype))
          nrows = nvals // rowlen_or_1
        elif const_row_length == 0:
          nrows = constant_op.constant(0, dtype=uniform_row_length.dtype)
        else:
          nrows = nvals // const_row_length
      const_nrows = None if nrows is None else tensor_util.constant_value(nrows)
      const_nvals = None if nvals is None else tensor_util.constant_value(nvals)
      const_uniform_row_length = tensor_util.constant_value(uniform_row_length)

      # Runtime assertions accumulated below; attached to row_splits at the
      # end when validate is True.
      checks = []

      if const_nvals is None and const_nrows is not None and const_uniform_row_length is not None:
        # nvals can be derived statically from nrows * row_length; if the
        # caller also passed a dynamic nvals, assert it agrees at runtime.
        const_nvals = const_nrows * const_uniform_row_length
        if nvals is not None and validate:
          checks.append(check_ops.assert_equal(nvals, const_nvals))
        nvals = constant_op.constant(const_nvals, uniform_row_length.dtype)

      if nvals is None:
        nvals = nrows * uniform_row_length

      # Find row_splits.
      if const_nrows is not None and const_row_length is not None:
        # Fully static: build the splits as a Python list.
        row_splits = [v * const_row_length for v in range(const_nrows + 1)]
        row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
      else:
        row_splits = math_ops.range(
            nrows + 1, dtype=uniform_row_length.dtype) * uniform_row_length

      if validate:

        if (const_nrows is None or const_row_length is None or
            const_nvals is None):
          # Not fully static: defer the nrows*row_length == nvals check to
          # runtime.
          checks.append(
              check_ops.assert_equal(
                  nrows * uniform_row_length, nvals,
                  ("uniform_row_length", uniform_row_length, "times nrows",
                   nrows, "must equal nvals", nvals)))
        else:
          if const_nrows * const_row_length != const_nvals:
            raise ValueError(
                "uniform_row_length=%d times nrows=%d must equal nvals=%d" %
                (const_row_length, const_nrows, const_nvals))

        if uniform_row_length.shape.rank is None:
          checks.append(
              check_ops.assert_rank(
                  uniform_row_length,
                  0,
                  message="uniform_row_length must be a scalar."))

        const_row_length = tensor_util.constant_value(uniform_row_length)
        if const_row_length is None:
          checks.append(
              check_ops.assert_greater_equal(
                  uniform_row_length,
                  constant_op.constant(0, uniform_row_length.dtype),
                  message="uniform_row_length must be >= 0."))
        else:
          if const_row_length < 0:
            raise ValueError("uniform_row_length must be >= 0.")

        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(
          row_splits=row_splits,
          uniform_row_length=uniform_row_length,
          nrows=nrows,
          nvals=nvals,
          internal=_row_partition_factory_key)
  @classmethod
  def _convert_row_partition(cls, partition, name, dtype=None, dtype_hint=None):
    """Converts `partition` to Tensors.

    Args:
      partition: A row-partitioning tensor for the `RowPartition` being
        constructed.  I.e., one of: row_splits, row_lengths, row_starts,
        row_limits, value_rowids, uniform_row_length.
      name: The name of the row-partitioning tensor.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `uniform_row_length`, dtype_hint,
        or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A tensor equivalent to partition.

    Raises:
      ValueError: if dtype is not int32 or int64.
    """
    if dtype_hint is None:
      dtype_hint = dtypes.int64
    # Special case: an int32 numpy array keeps its int32 dtype (instead of
    # being widened to the int64 dtype_hint) when no explicit dtype is given.
    if (isinstance(partition, np.ndarray) and
        partition.dtype == np.int32 and dtype is None):
      partition = ops.convert_to_tensor(partition, name=name)
    else:
      partition = tensor_conversion.convert_to_tensor_v2(
          partition, dtype_hint=dtype_hint, dtype=dtype, name=name
      )
    if partition.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("%s must have dtype int32 or int64" % name)

    return partition
  def _with_dependencies(self, dependencies):
    """Returns a new RowPartition equal to self with control dependencies.

    Specifically, self._row_splits is gated by the given control dependencies.
    Used to add sanity checks to the constructors.

    Args:
      dependencies: a list of tensors to use as dependencies.

    Returns:
      A new RowPartition object.
    """
    new_row_splits = control_flow_ops.with_dependencies(dependencies,
                                                        self._row_splits)
    # NOTE(review): the cached `self._nvals` is not forwarded to the new
    # RowPartition here (all other cached encodings are) — presumably
    # intentional, but worth confirming.
    return RowPartition(
        row_splits=new_row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)
729 # =============================================================================
730 # Accessors
731 # =============================================================================
733 @property
734 def dtype(self):
735 """The `DType` used to encode the row partition (either int32 or int64)."""
736 return self._row_splits.dtype
738 def row_splits(self):
739 """Returns the row-split indices for this row partition.
741 `row_splits` specifies where the values for each row begin and end.
742 In particular, the values for row `i` are stored in the slice
743 `values[row_splits[i]:row_splits[i+1]]`.
745 Returns:
746 A 1-D integer `Tensor` with shape `[self.nrows+1]`.
747 The returned tensor is non-empty, and is sorted in ascending order.
748 `self.row_splits()[0] == 0`.
749 `self.row_splits()[-1] == self.nvals()`.
750 """
751 return self._row_splits
  def value_rowids(self):
    """Returns the row indices for this row partition.

    `value_rowids` specifies the row index for each value.  In particular,
    `value_rowids[i]` is the row index for `values[i]`.

    Returns:
      A 1-D integer `Tensor` with shape `[self.nvals()]`.
      The returned tensor is nonnegative, and is sorted in ascending order.
    """
    # Prefer the cached encoding; otherwise derive it from row_splits.
    if self._value_rowids is not None:
      return self._value_rowids
    return segment_id_ops.row_splits_to_segment_ids(self._row_splits)
  def nvals(self):
    """Returns the number of values partitioned by this `RowPartition`.

    If the sequence partitioned by this `RowPartition` is a tensor, then
    `nvals` is the size of that tensor's outermost dimension -- i.e.,
    `nvals == values.shape[0]`.

    Returns:
      scalar integer Tensor
    """
    # TODO(martinz): Uncomment these lines.
    # if self._nvals is not None:
    #   return self._nvals
    # The last split is always the total number of values.
    return self._row_splits[-1]
782 def nrows(self):
783 """Returns the number of rows created by this `RowPartition`.
785 Returns:
786 scalar integer Tensor
787 """
788 if self._nrows is not None:
789 return self._nrows
790 nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
791 if nsplits.value is None:
792 return array_ops.shape(self._row_splits, out_type=self.dtype)[0] - 1
793 else:
794 return constant_op.constant(nsplits.value - 1, dtype=self.dtype)
  def uniform_row_length(self):
    """Returns the length of each row in this partition, if rows are uniform.

    If all rows in this `RowPartition` have the same length, then this returns
    that length as a scalar integer `Tensor`. Otherwise, it returns `None`.

    Returns:
      scalar Tensor with `type=self.dtype`, or `None`.
    """
    # `None` means the partition was not *constructed* as uniform; rows that
    # merely happen to have equal lengths still yield `None` here (see
    # `is_uniform`).
    return self._uniform_row_length
807 def row_starts(self):
808 """Returns the start indices for rows in this row partition.
810 These indices specify where the values for each row begin.
811 `partition.row_starts()` is equal to `partition.row_splits()[:-1]`.
813 Returns:
814 A 1-D integer Tensor with shape `[self.nrows()]`.
815 The returned tensor is nonnegative, and is sorted in ascending order.
816 `self.row_starts()[0] == 0`.
817 `self.row_starts()[-1] <= self.nvals()`.
818 """
819 return self._row_splits[:-1]
821 def row_limits(self):
822 """Returns the limit indices for rows in this row partition.
824 These indices specify where the values for each row end.
825 `partition.row_limits()` is equal to `partition.row_splits()[:-1]`.
827 Returns:
828 A 1-D integer Tensor with shape `[self.nrows]`.
829 The returned tensor is nonnegative, and is sorted in ascending order.
830 `self.row_limits()[-1] == self.nvals()`.
831 """
832 return self._row_splits[1:]
834 def row_lengths(self):
835 """Returns the lengths of rows in this `RowPartition`.
837 Returns:
838 A 1-D integer Tensor with shape `[self.nrows]`.
839 The returned tensor is nonnegative.
840 `tf.reduce_sum(self.row_lengths) == self.nvals()`.
841 """
842 if self._row_lengths is not None:
843 return self._row_lengths
844 splits = self._row_splits
845 return splits[1:] - splits[:-1]
847 @property
848 def static_nrows(self):
849 """The number of rows in this partition, if statically known.
851 ```python
852 self.row_lengths().shape == [self.static_nrows]
853 self.row_starts().shape == [self.static_nrows]
854 self.row_limits().shape == [self.static_nrows]
855 self.row_splits().shape == [self.static_nrows + 1]
856 ```
858 Returns:
859 The number of rows in this partition as an `int` (if statically known);
860 or `None` (otherwise).
861 """
862 if self._row_splits is not None:
863 nrows_plus_one = tensor_shape.dimension_value(self._row_splits.shape[0])
864 if nrows_plus_one is not None:
865 return nrows_plus_one - 1
866 if self._row_lengths is not None:
867 nrows = tensor_shape.dimension_value(self._row_lengths.shape[0])
868 if nrows is not None:
869 return nrows
870 if self._nrows is not None:
871 return tensor_util.constant_value(self._nrows)
872 return None
874 @property
875 def static_nvals(self):
876 """The number of values in this partition, if statically known.
878 ```python
879 self.value_rowids().shape == [self.static_vals]
880 ```
882 Returns:
883 The number of values in this partition as an `int` (if statically known);
884 or `None` (otherwise).
885 """
886 if self._nvals is not None:
887 nvals = tensor_util.constant_value(self._nvals)
888 if nvals is not None:
889 return nvals
890 if self._value_rowids is not None:
891 nvals = tensor_shape.dimension_at_index(self._value_rowids.shape, 0)
892 if nvals.value is not None:
893 return nvals.value
894 return None
896 @property
897 def static_uniform_row_length(self):
898 """The number of values in each row of this partition, if statically known.
900 Returns:
901 The number of values in each row of this partition as an `int` (if
902 statically known); or `None` (otherwise).
903 """
904 if self._uniform_row_length is not None:
905 return tensor_util.constant_value(self._uniform_row_length)
906 return None
908 def offsets_in_rows(self):
909 """Return the offset of each value.
911 RowPartition takes an array x and converts it into sublists.
912 offsets[i] is the index of x[i] in its sublist.
913 Given a shape, such as:
914 [*,*,*],[*,*],[],[*,*]
915 This returns:
916 0,1,2,0,1,0,1
918 Returns:
919 an offset for every value.
920 """
921 return gen_ragged_math_ops.ragged_range(
922 starts=constant_op.constant(0, self.dtype),
923 limits=self.row_lengths(),
924 deltas=constant_op.constant(1, self.dtype)).rt_dense_values
  def is_uniform(self):
    """Returns true if the partition is known to be uniform statically.

    This is based upon the existence of self._uniform_row_length. For example:
    RowPartition.from_row_lengths([3,3,3]).is_uniform()==false
    RowPartition.from_uniform_row_length(5, nvals=20).is_uniform()==true
    RowPartition.from_row_lengths([2,0,2]).is_uniform()==false

    Returns:
      Whether a RowPartition is known to be uniform statically.
    """
    # Uniformity is tracked purely by whether the encoding was recorded at
    # construction time, not by inspecting row lengths.
    return self._uniform_row_length is not None
939 def _static_check(self):
940 """Checks if the object is internally consistent.
942 Raises:
943 ValueError if inconsistent.
944 """
945 my_dtype = self.dtype
946 if self._uniform_row_length is not None:
947 if self._uniform_row_length.dtype != my_dtype:
948 raise ValueError("_uniform_row_length.dtype=" +
949 str(self._uniform_row_length.dtype) + ", not " +
950 str(my_dtype))
952 if self._row_lengths is not None and self._row_lengths.dtype != my_dtype:
953 raise ValueError("_row_lengths.dtype=" + str(self._row_lengths.dtype) +
954 ", not " + str(my_dtype))
956 if self._value_rowids is not None and self._value_rowids.dtype != my_dtype:
957 raise ValueError("_value_rowids.dtype=" + str(self._value_rowids.dtype) +
958 ", not " + str(my_dtype))
960 if self._nrows is not None and self._nrows.dtype != my_dtype:
961 raise ValueError("_nrows.dtype=" + str(self._nrows.dtype) + ", not " +
962 str(my_dtype))
964 # =============================================================================
965 # Transformation
966 # =============================================================================
  def with_dtype(self, dtype):
    """Returns a copy of this RowPartition with the given encoding dtype.

    Args:
      dtype: The dtype for encoding tensors, such as `row_splits` and `nrows`.
        One of `tf.int32` or `tf.int64`.

    Returns:
      A copy of this RowPartition, with the encoding tensors cast to the given
      type.

    Raises:
      ValueError: If `dtype` is not `tf.int32` or `tf.int64`.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be int32 or int64")
    if self.dtype == dtype:
      # Nothing to cast; safe to return self since encodings are not mutated.
      return self

    # Cast every encoding that is present; absent ones stay None.
    return RowPartition(
        row_splits=_cast_if_not_none(self._row_splits, dtype),
        row_lengths=_cast_if_not_none(self._row_lengths, dtype),
        value_rowids=_cast_if_not_none(self._value_rowids, dtype),
        nrows=_cast_if_not_none(self._nrows, dtype),
        uniform_row_length=_cast_if_not_none(self._uniform_row_length, dtype),
        internal=_row_partition_factory_key)
993 # =============================================================================
994 # String Encoding
995 # =============================================================================
997 def __repr__(self):
998 if self._uniform_row_length is not None:
999 return (f"tf.RowPartition(nrows={self._nrows}, "
1000 f"uniform_row_length={self._uniform_row_length})")
1001 else:
1002 return f"tf.RowPartition(row_splits={self._row_splits})"
1004 # =============================================================================
1005 # Precomputed Encodings
1006 # =============================================================================
  def _has_precomputed_row_splits(self):
    """Returns true if `row_splits` has already been computed.

    If true, then `self.row_splits()` will return its value without calling
    any TensorFlow ops.

    Returns:
      bool
    """
    return self._row_splits is not None
  def _has_precomputed_row_lengths(self):
    """Returns true if `row_lengths` has already been computed.

    If true, then `self.row_lengths()` will return its value without calling
    any TensorFlow ops.

    Returns:
      bool
    """
    return self._row_lengths is not None
  def _has_precomputed_value_rowids(self):
    """Returns true if `value_rowids` has already been computed.

    If true, then `self.value_rowids()` will return its value without calling
    any TensorFlow ops.

    Returns:
      bool
    """
    return self._value_rowids is not None
  def _has_precomputed_nrows(self):
    """Returns true if `nrows` has already been computed.

    If true, then `self.nrows()` will return its value without calling
    any TensorFlow ops.

    Returns:
      bool
    """
    return self._nrows is not None
  def _has_precomputed_nvals(self):
    """Returns true if `nvals` has already been computed.

    If true, then `self.nvals()` will return its value without calling
    any TensorFlow ops.

    Returns:
      bool
    """
    return self._nvals is not None
  def _with_precomputed_row_splits(self):
    """Returns a copy of `self` with `row_splits` precomputed."""
    # `self.row_splits()` computes the encoding if it is missing; all other
    # encodings are passed through unchanged.
    return RowPartition(
        row_splits=self.row_splits(),
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        nvals=self._nvals,
        internal=_row_partition_factory_key)
  def _with_precomputed_row_lengths(self):
    """Returns a copy of `self` with `row_lengths` precomputed."""
    # `self.row_lengths()` computes the encoding if it is missing; all other
    # encodings are passed through unchanged.
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self.row_lengths(),
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        nvals=self._nvals,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)
  def _with_precomputed_value_rowids(self):
    """Returns a copy of `self` with `value_rowids` precomputed."""
    # `self.value_rowids()` computes the encoding if it is missing; all other
    # encodings are passed through unchanged.
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self.value_rowids(),
        nrows=self._nrows,
        nvals=self._nvals,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)
  def _with_precomputed_nrows(self):
    """Returns a copy of `self` with `nrows` precomputed."""
    # `self.nrows()` computes the encoding if it is missing; all other
    # encodings are passed through unchanged.
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self.nrows(),
        nvals=self._nvals,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)
  def _with_precomputed_nvals(self):
    """Returns a copy of `self` with `nvals` precomputed."""
    # Note: `self.row_splits()` is used (not `self._row_splits`), so this copy
    # also has row_splits materialized, since nvals is derived from it.
    return RowPartition(
        row_splits=self.row_splits(),
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        nvals=self.nvals(),
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)
  def _merge_with_spec(self, b):
    """Merge with a TypeSpec to create a new RowPartition.

    Args:
      b: A `RowPartitionSpec` that is compatible with `self._type_spec`.

    Returns:
      A `RowPartition` whose `nrows`, `nvals`, and `uniform_row_length`
      encodings are taken from `b` where `b` specifies them statically, and
      from `self` otherwise.

    Raises:
      ValueError: If `b` is not compatible with this partition's spec.
    """
    a_spec = self._type_spec
    if not a_spec.is_compatible_with(b):
      # TODO(martinz): Should a dynamic check be used here?
      raise ValueError("RowPartition and RowPartitionSpec are not compatible")
    # For each dimension, prefer the spec's statically-known value (wrapped as
    # a constant of our dtype); otherwise keep our existing encoding tensor.
    nrows = constant_op.constant(
        b.nrows, self.dtype) if b.nrows is not None else self._nrows
    nvals = constant_op.constant(
        b.nvals, self.dtype) if b.nvals is not None else self._nvals
    uniform_row_length = constant_op.constant(
        b.uniform_row_length, self.dtype
    ) if b.uniform_row_length is not None else self._uniform_row_length
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nvals=nvals,
        uniform_row_length=uniform_row_length,
        nrows=nrows,
        internal=_row_partition_factory_key)
  def _merge_precomputed_encodings(self, other, validate=True):
    """Returns a RowPartition that merges encodings from `self` and `other`.

    Requires that `self` and `other` describe the same partition.

    Args:
      other: A `RowPartition` that encodes the same partition as `self`.
      validate: If true, then add runtime checks to verify that `self` and
        `other` encode the same row partition.

    Returns:
      A `RowPartition`.
    """
    # pylint: disable=protected-access
    if (self is other or  # Fast path if row partitions are equal.
        (self._row_splits is other._row_splits and
         self._row_lengths is other._row_lengths and
         self._value_rowids is other._value_rowids and
         self._nrows is other._nrows and
         self._nvals is other._nvals and
         self._uniform_row_length is other._uniform_row_length)):
      return self

    # Merge the component tensors.  We only need to validate one encoding.
    # We merge less-expensive encodings first (to avoid expensive validation).
    # Once any one encoding pair has been validated (together with nrows where
    # needed), `validate` is cleared so later merges skip redundant checks.
    nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
                                            validate)
    nvals, _ = _merge_tensors(self._nvals, other._nvals, "nvals", validate)
    uniform_row_length, uniform_row_length_validated = _merge_tensors(
        self._uniform_row_length, other._uniform_row_length,
        "uniform_row_length", validate)
    if uniform_row_length_validated and nrows_validated:
      validate = False  # Validation complete.
    row_splits, row_splits_validated = _merge_tensors(self._row_splits,
                                                      other._row_splits,
                                                      "row_splits", validate)
    if row_splits_validated:
      validate = False  # Validation complete.
    row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
                                                        other._row_lengths,
                                                        "row_lengths", validate)
    if row_lengths_validated:
      validate = False  # Validation complete.
    value_rowids, value_rowids_validated = _merge_tensors(
        self._value_rowids, other._value_rowids, "value_rowids", validate)
    if value_rowids_validated and nrows_validated:
      validate = False  # Validation complete.
    # TODO(edloper): If we make the row_splits encoding optional, then there
    # will be cases where we need to do validation at this point -- e.g. if
    # self has only row_splits and other has only value_rowids.  But for
    # now, we are guaranteed to have done validation by this point.

    # Avoid creating new RowPartition objects if we don't need to.
    if (row_splits is self._row_splits and row_lengths is self._row_lengths and
        value_rowids is self._value_rowids and nrows is self._nrows and
        uniform_row_length is self._uniform_row_length):
      return self
    if (row_splits is other._row_splits and
        row_lengths is other._row_lengths and
        value_rowids is other._value_rowids and nrows is other._nrows and
        uniform_row_length is other._uniform_row_length):
      return other

    return RowPartition(
        row_splits=row_splits,
        row_lengths=row_lengths,
        value_rowids=value_rowids,
        nrows=nrows,
        uniform_row_length=uniform_row_length,
        nvals=nvals,
        internal=_row_partition_factory_key)
1197 # =============================================================================
1198 # Composite Tensor
1199 # =============================================================================
  @property
  def _type_spec(self):
    """A `RowPartitionSpec` describing this value (CompositeTensor API)."""
    return RowPartitionSpec.from_value(self)
1206# ===============================================================================
1207# RowPartitionSpec
1208# ===============================================================================
1209# TODO(edloper): Consider refactoring RowPartitionSpec to allow any combination
1210# of precomputed row-partition encodings (rather than always using row_splits).
@type_spec_registry.register("tf.RowPartitionSpec")
class RowPartitionSpec(type_spec.TypeSpec):
  """Type specification for a `tf.RowPartition`."""

  # Each dimension is stored wrapped in a rank-1 TensorShape (see __init__).
  __slots__ = ["_nrows", "_nvals", "_uniform_row_length", "_dtype"]

  value_type = property(lambda self: RowPartition)

  def __init__(self,
               nrows=None,
               nvals=None,
               uniform_row_length=None,
               dtype=dtypes.int64):
    """Constructs a new RowPartitionSpec.

    Args:
      nrows: The number of rows in the RowPartition, or `None` if unspecified.
      nvals: The number of values partitioned by the RowPartition, or `None` if
        unspecified.
      uniform_row_length: The number of values in each row for this
        RowPartition, or `None` if rows are ragged or row length is unspecified.
      dtype: The data type used to encode the partition. One of `tf.int64` or
        `tf.int32`.

    Raises:
      ValueError: If `dtype` is not `tf.int32` or `tf.int64`, or if the given
        dimensions are mutually inconsistent.
    """
    # Wrap dimension sizes in 1D TensorShapes so the default implementations
    # of TypeSpec methods such as `is_compatible_with` will work.
    nrows = tensor_shape.TensorShape([nrows])
    nvals = tensor_shape.TensorShape([nvals])
    if not isinstance(uniform_row_length, tensor_shape.TensorShape):
      uniform_row_length = tensor_shape.TensorShape([uniform_row_length])
    else:
      uniform_row_length = uniform_row_length.with_rank(1)

    self._nrows = nrows
    self._nvals = nvals
    self._uniform_row_length = uniform_row_length
    self._dtype = dtypes.as_dtype(dtype)
    if self._dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be tf.int32 or tf.int64")

    # Check dimension consistency, & infer dimensions when possible.
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0:  # no rows -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with nrows=%s" %
                         (nvals, nrows))
    if ncols == 0:  # there are no values in each row -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s" % (nvals, uniform_row_length))
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s (doesn't divide evenly)" % (nvals, ncols))
      if nrows is not None and nvals != ncols * nrows:
        raise ValueError("nvals=%s is not compatible with nrows=%s and "
                         "uniform_row_length=%s" % (nvals, nrows, ncols))
      if nrows is None and ncols != 0:
        # Infer nrows; the division is exact (checked above).
        self._nrows = tensor_shape.TensorShape([nvals // ncols])
    if ncols is not None and nrows is not None and nvals is None:
      # Infer nvals from a fully-specified uniform partition.
      self._nvals = tensor_shape.TensorShape([ncols * nrows])

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this RowPartitionSpec."""
    if not super(RowPartitionSpec, self).is_compatible_with(other):
      return False
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)
    return self._dimensions_compatible(nrows, nvals, ncols)

  def _serialize(self):
    """Returns the state tuple used by TypeSpec equality/serialization."""
    return (self._nrows, self._nvals, self._uniform_row_length, self._dtype)

  @classmethod
  def _deserialize(cls, serialization):
    """Rebuilds a spec from the output of `_serialize`."""
    # Remove TensorShape wrappers from serialization.
    (nrows, nvals, uniform_row_length, dtype) = serialization
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    return cls(nrows, nvals, uniform_row_length, dtype)

  @property
  def nrows(self):
    # Unwrap the stored rank-1 TensorShape back to an int (or None).
    return tensor_shape.dimension_value(self._nrows[0])

  @property
  def nvals(self):
    return tensor_shape.dimension_value(self._nvals[0])

  @property
  def uniform_row_length(self):
    return tensor_shape.dimension_value(self._uniform_row_length[0])

  @property
  def dtype(self):
    return self._dtype

  @property
  def _component_specs(self):
    """A `TensorSpec` for the single `row_splits` component tensor."""
    row_splits_shape = tensor_shape.TensorShape(
        [tensor_shape.dimension_at_index(self._nrows, 0) + 1])
    return tensor_spec.TensorSpec(row_splits_shape, self._dtype)

  def _to_components(self, value):
    """Encodes a RowPartition as its `row_splits` tensor."""
    return value.row_splits()

  def _from_components(self, tensor):
    """Rebuilds a RowPartition from a `row_splits` tensor."""
    return RowPartition.from_row_splits(tensor, validate=False)

  @classmethod
  def from_value(cls, value):
    """Returns the most specific spec that statically describes `value`."""
    if not isinstance(value, RowPartition):
      raise TypeError("Expected `value` to be a `RowPartition`")
    return cls(value.static_nrows, value.static_nvals,
               value.static_uniform_row_length, value.dtype)

  def __repr__(self):
    return ("RowPartitionSpec(nrows=%s, nvals=%s, uniform_row_length=%s, "
            "dtype=%r)" % (self.nrows, self.nvals, self.uniform_row_length,
                           self.dtype))

  @staticmethod
  def _dimensions_compatible(nrows, nvals, uniform_row_length):
    """Returns true if the given dimensions are compatible."""
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0 and nvals not in (0, None):
      return False  # can't have values if we have no rows.
    if ncols == 0 and nvals not in (0, None):
      return False  # can't have values if we have no values in each row.
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        return False  # rows aren't uniform.
      if nrows is not None and nvals != ncols * nrows:
        return False  # inconsistent number of values.
    return True

  def _merge_with(self, other):
    """Merge two RowPartitionSpecs."""
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)

    if not RowPartitionSpec._dimensions_compatible(nrows, nvals, ncols):
      raise ValueError("Merging incompatible RowPartitionSpecs")

    # NOTE: if the dtypes are unequal, behavior is unspecified.
    if self.dtype != other.dtype:
      raise ValueError("Merging RowPartitionSpecs with incompatible dtypes")

    return RowPartitionSpec(nrows=nrows[0],
                            nvals=nvals[0],
                            uniform_row_length=ncols[0],
                            dtype=self.dtype)

  def with_dtype(self, dtype):
    """Returns a copy of this spec using the given encoding `dtype`."""
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    return RowPartitionSpec(nrows, nvals, self._uniform_row_length, dtype)

  def __deepcopy__(self, memo):
    # The spec's state is immutable scalars; rebuilding via the constructor
    # is a sufficient deep copy (and re-runs the consistency checks).
    del memo
    dtype = self.dtype
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    uniform_row_length = (None if self._uniform_row_length is None else
                          tensor_shape.dimension_value(
                              self._uniform_row_length[0]))
    return RowPartitionSpec(nrows, nvals, uniform_row_length, dtype)
# Register a proto codec so RowPartitionSpec values can round-trip through
# the structured-value encoding (used e.g. by SavedModel signatures).
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        RowPartitionSpec, struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC
    )
)
1398# ===============================================================================
1399# Helper Functions
1400# ===============================================================================
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that a 1-D `tensor` is sorted in non-decreasing order."""
  # Monotonic iff every adjacent difference is non-negative.
  deltas = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(deltas, message=message)
def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` equals zero."""
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)
def _cast_if_not_none(tensor, dtype):
  """Casts `tensor` to `dtype`, passing `None` through unchanged."""
  if tensor is None:
    return None
  return math_ops.cast(tensor, dtype)
def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (for error messages)
    validate: If true, then check that `t1` is compatible with `t2` (if both
      are non-None).

  Returns:
    A pair `(merged_value, validated)`:
    * `merged_value` is `t1` if it is not None; or `t2` otherwise.
    * `validated` is true if we validated that t1 and t2 are equal (either
      by adding a check, or because t1 is t2).
  """
  # With at most one tensor present, there is nothing to cross-check.
  if t1 is None:
    return t2, False
  if t2 is None:
    return t1, False
  # The same tensor object is trivially equal to itself.
  if t1 is t2:
    return t1, True
  err_msg = ("RowPartition._merge_precomputed_encodings: partitions "
             "have incompatible %s" % name)
  # A static shape mismatch is always an error, regardless of `validate`.
  if not t1.shape.is_compatible_with(t2.shape):
    raise ValueError(err_msg)
  if not validate:
    return t1, False
  checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
  return control_flow_ops.with_dependencies(checks, t1), True
1450_row_partition_factory_key = object() # unique private object
def _get_dtype_or_none(value):
  """Returns `value.dtype` if `value` is a Tensor; otherwise `None`."""
  return value.dtype if isinstance(value, ops.Tensor) else None
def _get_target_dtype(values, dtype=None, dtype_hint=None):
  """Gets the target dtype of a family of values."""
  # An explicitly requested dtype always wins.
  if dtype is not None:
    return dtype

  # Otherwise the first Tensor's dtype decides, then the first ndarray's.
  for candidate in values:
    if isinstance(candidate, ops.Tensor):
      return candidate.dtype

  for candidate in values:
    if isinstance(candidate, np.ndarray):
      return dtypes.as_dtype(candidate.dtype)

  # Fall back to the hint, and finally to int64.
  return dtype_hint if dtype_hint is not None else dtypes.int64
def _convert_all_to_tensors(values, dtype=None, dtype_hint=None):
  """Convert a list of objects to tensors of the same dtype."""
  target_dtype = _get_target_dtype([x for (x, _) in values], dtype, dtype_hint)

  # With no explicit dtype we use convert behavior (strict); with an explicit
  # dtype we use cast behavior.
  if dtype is None:
    return [
        None if x is None else ops.convert_to_tensor(
            x, dtype=target_dtype, name=name) for (x, name) in values
    ]
  return [
      None if x is None else math_ops.cast(x, dtype=target_dtype, name=name)
      for (x, name) in values
  ]