Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_config.py: 20%
302 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Feature configuration for tf.io.parse_example."""
17import collections
18import re
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import check_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import sparse_ops
28from tensorflow.python.ops.ragged import ragged_math_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.platform import tf_logging
31from tensorflow.python.util.tf_export import tf_export
34# TODO(b/122887740) Refactor code:
35# * Move input verification to feature configuration objects (e.g.,
36# VarLenFeature should check that dtype is a valid dtype).
37# * Add an _add_feature() method to each feature configuration object
38# (rather than using a dispatch table in _ParseOpParams._add_feature).
39# * Update _construct_tensors_for_composite_features() to call a method
40# on the feature object (rather than using dispatch).
@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"])
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
  """Configuration for parsing a variable-length input feature.

  A feature configured with `VarLenFeature` is registered as a sparse key by
  the parsing parameters, so its parsed result is a `SparseTensor`.

  Fields:
    dtype: Data type of input.
  """
@tf_export("io.RaggedFeature")
class RaggedFeature(
    collections.namedtuple(
        "RaggedFeature",
        ["dtype", "value_key", "partitions", "row_splits_dtype", "validate"])):
  """Configuration for parsing a RaggedTensor input feature.

  `value_key` specifies the feature key for a variable-length list of values;
  and `partitions` specifies zero or more feature keys for partitioning those
  values into higher dimensions.  Each element of `partitions` must be one of
  the following:

    * `tf.io.RaggedFeature.RowSplits(key: string)`
    * `tf.io.RaggedFeature.RowLengths(key: string)`
    * `tf.io.RaggedFeature.RowStarts(key: string)`
    * `tf.io.RaggedFeature.RowLimits(key: string)`
    * `tf.io.RaggedFeature.ValueRowIds(key: string)`
    * `tf.io.RaggedFeature.UniformRowLength(length: int)`.

  Where `key` is a feature key whose values are used to partition the values.
  Partitions are listed from outermost to innermost.

  * If `len(partitions) == 0` (the default), then:

    * A feature from a single `tf.Example` is parsed into a 1D `tf.Tensor`.
    * A feature from a batch of `tf.Example`s is parsed into a 2D
      `tf.RaggedTensor`, where the outer dimension is the batch dimension, and
      the inner (ragged) dimension is the feature length in each example.

  * If `len(partitions) == 1`, then:

    * A feature from a single `tf.Example` is parsed into a 2D
      `tf.RaggedTensor`, where the values taken from the `value_key` are
      separated into rows using the partition key.
    * A feature from a batch of `tf.Example`s is parsed into a 3D
      `tf.RaggedTensor`, where the outer dimension is the batch dimension,
      the two inner dimensions are formed by separating the `value_key` values
      from each example into rows using that example's partition key.

  * If `len(partitions) > 1`, then:

    * A feature from a single `tf.Example` is parsed into a `tf.RaggedTensor`
      whose rank is `len(partitions)+1`, and whose ragged_rank is
      `len(partitions)`.

    * A feature from a batch of `tf.Example`s is parsed into a `tf.RaggedTensor`
      whose rank is `len(partitions)+2` and whose ragged_rank is
      `len(partitions)+1`, where the outer dimension is the batch dimension.

  There is one exception: if the final (i.e., innermost) element(s) of
  `partitions` are `UniformRowLength`s, then the values are simply reshaped (as
  a higher-dimensional `tf.Tensor`), rather than being wrapped in a
  `tf.RaggedTensor`.

  #### Examples

  >>> import google.protobuf.text_format as pbtext
  >>> example_batch = [
  ...   pbtext.Merge(r'''
  ...     features {
  ...       feature {key: "v" value {int64_list {value: [3, 1, 4, 1, 5, 9]}}}
  ...       feature {key: "s1" value {int64_list {value: [0, 2, 3, 3, 6]}}}
  ...       feature {key: "s2" value {int64_list {value: [0, 2, 3, 4]}}}
  ...     }''', tf.train.Example()).SerializeToString(),
  ...   pbtext.Merge(r'''
  ...     features {
  ...       feature {key: "v" value {int64_list {value: [2, 7, 1, 8, 2, 8, 1]}}}
  ...       feature {key: "s1" value {int64_list {value: [0, 3, 4, 5, 7]}}}
  ...       feature {key: "s2" value {int64_list {value: [0, 1, 1, 4]}}}
  ...     }''', tf.train.Example()).SerializeToString()]

  >>> features = {
  ...     # Zero partitions: returns 1D tf.Tensor for each Example.
  ...     'f1': tf.io.RaggedFeature(value_key="v", dtype=tf.int64),
  ...     # One partition: returns 2D tf.RaggedTensor for each Example.
  ...     'f2': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
  ...         tf.io.RaggedFeature.RowSplits("s1")]),
  ...     # Two partitions: returns 3D tf.RaggedTensor for each Example.
  ...     'f3': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
  ...         tf.io.RaggedFeature.RowSplits("s2"),
  ...         tf.io.RaggedFeature.RowSplits("s1")])
  ... }

  >>> feature_dict = tf.io.parse_single_example(example_batch[0], features)
  >>> for (name, val) in sorted(feature_dict.items()):
  ...   print('%s: %s' % (name, val))
  f1: tf.Tensor([3 1 4 1 5 9], shape=(6,), dtype=int64)
  f2: <tf.RaggedTensor [[3, 1], [4], [], [1, 5, 9]]>
  f3: <tf.RaggedTensor [[[3, 1], [4]], [[]], [[1, 5, 9]]]>

  >>> feature_dict = tf.io.parse_example(example_batch, features)
  >>> for (name, val) in sorted(feature_dict.items()):
  ...   print('%s: %s' % (name, val))
  f1: <tf.RaggedTensor [[3, 1, 4, 1, 5, 9],
                        [2, 7, 1, 8, 2, 8, 1]]>
  f2: <tf.RaggedTensor [[[3, 1], [4], [], [1, 5, 9]],
                        [[2, 7, 1], [8], [2], [8, 1]]]>
  f3: <tf.RaggedTensor [[[[3, 1], [4]], [[]], [[1, 5, 9]]],
                        [[[2, 7, 1]], [], [[8], [2], [8, 1]]]]>

  Fields:
    dtype: Data type of the `RaggedTensor`.  Must be one of:
      `tf.dtypes.int64`, `tf.dtypes.float32`, `tf.dtypes.string`.
    value_key: (Optional.) Key for a `Feature` in the input `Example`, whose
      parsed `Tensor` will be the resulting `RaggedTensor.flat_values`.  If
      not specified, then it defaults to the key for this `RaggedFeature`.
    partitions: (Optional.) A list of objects specifying the row-partitioning
      tensors (from outermost to innermost).  Each entry in this list must be
      one of:
        * `tf.io.RaggedFeature.RowSplits(key: string)`
        * `tf.io.RaggedFeature.RowLengths(key: string)`
        * `tf.io.RaggedFeature.RowStarts(key: string)`
        * `tf.io.RaggedFeature.RowLimits(key: string)`
        * `tf.io.RaggedFeature.ValueRowIds(key: string)`
        * `tf.io.RaggedFeature.UniformRowLength(length: int)`.
      Where `key` is a key for a `Feature` in the input `Example`, whose parsed
      `Tensor` will be the resulting row-partitioning tensor.
    row_splits_dtype: (Optional.) Data type for the row-partitioning tensor(s).
      One of `int32` or `int64`.  Defaults to `int32`.
    validate: (Optional.) Boolean indicating whether or not to validate that
      the input values form a valid RaggedTensor.  Defaults to `False`.
  """

  # pylint: disable=invalid-name
  RowSplits = collections.namedtuple("RowSplits", ["key"])
  RowLengths = collections.namedtuple("RowLengths", ["key"])
  RowStarts = collections.namedtuple("RowStarts", ["key"])
  RowLimits = collections.namedtuple("RowLimits", ["key"])
  ValueRowIds = collections.namedtuple("ValueRowIds", ["key"])
  UniformRowLength = collections.namedtuple("UniformRowLength", ["length"])
  # pylint: enable=invalid-name

  # The only object types accepted as elements of `partitions`.
  _PARTITION_TYPES = (RowSplits, RowLengths, RowStarts, RowLimits, ValueRowIds,
                      UniformRowLength)

  def __new__(cls,
              dtype,
              value_key=None,
              partitions=(),
              row_splits_dtype=dtypes.int32,
              validate=False):
    """Validates the arguments and builds the `RaggedFeature` tuple.

    `dtype` and `row_splits_dtype` are normalized with `dtypes.as_dtype`
    before being stored.

    Raises:
      ValueError: if `value_key` is given but is not a non-empty string, if
        `dtype` is not int64, float32, or string, or if `row_splits_dtype` is
        not int32 or int64.
      TypeError: if `partitions` is not a list or tuple, if one of its
        elements is not a partition object, or if `validate` is not a bool.
    """
    if value_key is not None:
      if not isinstance(value_key, str):
        raise ValueError(
            f"Argument `value_key` must be a string; got {value_key}")
      if not value_key:
        raise ValueError("Argument `value_key` must not be empty")
    dtype = dtypes.as_dtype(dtype)
    if dtype not in (dtypes.int64, dtypes.float32, dtypes.string):
      raise ValueError("Argument `dtype` must be int64, float32, or bytes; got "
                       f"{dtype!r}")
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if row_splits_dtype not in (dtypes.int32, dtypes.int64):
      # The trailing space keeps "got" separated from the dtype repr.
      raise ValueError("Argument `row_splits_dtype` must be int32 or int64; "
                       f"got {row_splits_dtype!r}")
    if not isinstance(partitions, (list, tuple)):
      # The trailing space keeps "Received" separated from "partitions=".
      raise TypeError("Argument `partitions` must be a list or tuple. Received "
                      f"partitions={partitions} of type "
                      f"{type(partitions).__name__}.")
    for partition in partitions:
      if not isinstance(partition, cls._PARTITION_TYPES):
        raise TypeError("Argument `partitions` must be a list of partition "
                        f"objects {cls._PARTITION_TYPES}; got: {partition!r}")
    if not isinstance(validate, bool):
      raise TypeError(f"Argument `validate` must be a bool; got {validate!r}")
    return super(RaggedFeature, cls).__new__(cls, dtype, value_key, partitions,
                                             row_splits_dtype, validate)
@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"])
class SparseFeature(
    collections.namedtuple(
        "SparseFeature",
        ["index_key", "value_key", "dtype", "size", "already_sorted"])):
  """Configuration for parsing a sparse input feature from an `Example`.

  Prefer `VarLenFeature` (possibly in combination with a `SequenceExample`)
  over `SparseFeature` when parsing out `SparseTensor`s, because
  `VarLenFeature` is simpler.

  A `SparseFeature` closely mirrors the `SparseTensor` obtained by parsing an
  `Example` with it.  It contains:

  * `value_key`: The name of the key for a `Feature` in the `Example` whose
    parsed `Tensor` will be the resulting `SparseTensor.values`.

  * `index_key`: A list of names, one for each dimension of the resulting
    `SparseTensor`.  `indices[i][dim]`, the position of the `i`-th value in
    the `dim` dimension, equals the `i`-th value in the `Feature` whose key is
    `index_key[dim]` in the `Example`.

  * `size`: A list of ints for the resulting `SparseTensor.dense_shape`.

  For example, we can represent the following 2D `SparseTensor`

  ```python
  SparseTensor(indices=[[3, 1], [20, 0]],
               values=[0.5, -1.0],
               dense_shape=[100, 3])
  ```

  with an `Example` input proto

  ```python
  features {
    feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
    feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } }
    feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } }
  }
  ```

  and `SparseFeature` config with 2 `index_key`s

  ```python
  SparseFeature(index_key=["ix0", "ix1"],
                value_key="val",
                dtype=tf.float32,
                size=[100, 3])
  ```

  Fields:
    index_key: A single string name or a list of string names of index
      features.  For each key the underlying feature's type must be `int64`
      and its length must always match that of the `value_key` feature.
      To represent `SparseTensor`s with a `dense_shape` of `rank` higher than
      1, a list of length `rank` should be used.
    value_key: Name of value feature.  The underlying feature's type must
      be `dtype` and its length must always match that of all the
      `index_key`s' features.
    dtype: Data type of the `value_key` feature.
    size: A Python int or list thereof specifying the dense shape.  Should be
      a list if and only if `index_key` is a list, in which case its length
      must equal the length of `index_key`.  For each entry `i`, all values in
      the `index_key[i]` feature must be in `[0, size[i])`.
    already_sorted: A Python boolean to specify whether the values in
      `value_key` are already sorted by their index position.  If so, sorting
      is skipped.  False by default (optional).
  """

  def __new__(cls, index_key, value_key, dtype, size, already_sorted=False):
    """Builds the tuple; exists only to give `already_sorted` a default."""
    return super().__new__(
        cls, index_key, value_key, dtype, size, already_sorted)
@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"])
class FixedLenFeature(
    collections.namedtuple("FixedLenFeature",
                           ["shape", "dtype", "default_value"])):
  """Configuration for parsing a fixed-length input feature.

  Supply a `default_value` to treat sparse input as dense.  Without one, the
  parse functions fail on any example that is missing this feature.

  Fields:
    shape: Shape of input data.
    dtype: Data type of input.
    default_value: Value to be used if an example is missing this feature.  It
      must be compatible with `dtype` and of the specified `shape`.
  """

  def __new__(cls, shape, dtype, default_value=None):
    """Builds the tuple; exists only to give `default_value` a default."""
    return super().__new__(cls, shape, dtype, default_value)
@tf_export("io.FixedLenSequenceFeature",
           v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"])
class FixedLenSequenceFeature(
    collections.namedtuple(
        "FixedLenSequenceFeature",
        ["shape", "dtype", "allow_missing", "default_value"])):
  """Configuration for parsing a variable-length input feature into a `Tensor`.

  Parsing a single `SequenceExample` or `Example` yields a `Tensor` with a
  static `shape` of `[None] + shape` and the specified `dtype`.
  Parsing `batch_size` many `Example`s yields a `Tensor` with a static `shape`
  of `[batch_size, None] + shape` and the specified `dtype`.
  Entries in the `batch` that come from different `Examples` are padded with
  `default_value` to the maximum length present in the `batch`.

  Pass `allow_missing=True` to treat a sparse input as dense.  Otherwise the
  parse functions fail on any example that is missing this feature.

  Fields:
    shape: Shape of input data for dimension 2 and higher.  First dimension is
      of variable length `None`.
    dtype: Data type of input.
    allow_missing: Whether to allow this feature to be missing from a feature
      list item.  Is available only for parsing `SequenceExample` not for
      parsing `Examples`.
    default_value: Scalar value to be used to pad multiple `Example`s to their
      maximum length.  Irrelevant for parsing a single `Example` or
      `SequenceExample`.  Defaults to "" for dtype string and 0 otherwise
      (optional).
  """

  def __new__(cls, shape, dtype, allow_missing=False, default_value=None):
    """Builds the tuple, defaulting `allow_missing` and `default_value`."""
    return super().__new__(cls, shape, dtype, allow_missing, default_value)
class _ParseOpParams:
  """Raw parameters used by `gen_parsing_ops`.

  Attributes:
    sparse_keys: A list of string keys in the examples' features. The results
      for these keys will be returned as `SparseTensor` objects.
    sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only
      `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
      (`BytesList`) are supported.
    dense_keys: A list of string keys in the examples' features. The results for
      these keys will be returned as `Tensor`s
    dense_types: A list of DTypes of the same length as `dense_keys`. Only
      `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
      (`BytesList`) are supported.
    dense_defaults: A dict mapping string keys to `Tensor`s. The keys of the
      dict must match the dense_keys of the feature.
    dense_shapes: A list of tuples with the same length as `dense_keys`. The
      shape of the data for each dense feature referenced by `dense_keys`.
      Required for any input tensors identified by `dense_keys`.  Must be either
      fully defined, or may contain an unknown first dimension. An unknown first
      dimension means the feature is treated as having a variable number of
      blocks, and the output shape along this dimension is considered unknown at
      graph build time.  Padding is applied for minibatch elements smaller than
      the maximum number of blocks for the given feature along this dimension.
    ragged_keys: A list of string keys in the examples' features.  The
      results for these keys will be returned as `RaggedTensor` objects.
    ragged_value_types: A list of `DTypes` of the same length as `ragged_keys`,
      specifying the value type for each ragged feature.  Must be one of:
      `tf.float32`, `tf.int64`, `tf.string`.
    ragged_split_types: A list of `DTypes` of the same length as `ragged_keys`,
      specifying the row_splits type for each ragged feature.  Must be one of:
      `tf.int32`, `tf.int64`.
    dense_shapes_as_proto: dense_shapes converted to TensorShapeProto.
    dense_defaults_vec: A vector of `Tensor`s containing the default values,
      corresponding 1:1 with `dense_keys`.
    num_features: The total number of feature keys.
  """

  def __init__(self,
               sparse_keys=None,
               sparse_types=None,
               dense_keys=None,
               dense_types=None,
               dense_defaults=None,
               dense_shapes=None,
               ragged_keys=None,
               ragged_value_types=None,
               ragged_split_types=None):
    """Stores the raw parameters and validates that they are consistent.

    Every argument that is `None` is replaced by an empty collection, except
    `dense_shapes`, which defaults to one empty shape (`[]`) per dense key.
    Types are normalized with `dtypes.as_dtype` and shapes with
    `tensor_shape.as_shape`.  The key lists and `dense_defaults` are stored as
    given (not copied).

    Raises:
      ValueError: if the parameters fail `_validate`.
    """
    # Note: we use an OrderedDict for dense_defaults, to ensure consistent
    # graph construction order for _e2e_test.
    dense_defaults = (
        collections.OrderedDict() if dense_defaults is None else dense_defaults)
    sparse_keys = [] if sparse_keys is None else sparse_keys
    sparse_types = [] if sparse_types is None else sparse_types
    dense_keys = [] if dense_keys is None else dense_keys
    dense_types = [] if dense_types is None else dense_types
    if dense_shapes is None:
      dense_shapes = [[]] * len(dense_keys)
    ragged_keys = [] if ragged_keys is None else ragged_keys
    ragged_value_types = ([]
                          if ragged_value_types is None else ragged_value_types)
    ragged_split_types = ([]
                          if ragged_split_types is None else ragged_split_types)
    self.sparse_keys = sparse_keys
    self.sparse_types = [dtypes.as_dtype(t) for t in sparse_types]
    self.dense_keys = dense_keys
    self.dense_types = [dtypes.as_dtype(t) for t in dense_types]
    self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes]
    self.dense_defaults = dense_defaults
    self.ragged_keys = ragged_keys
    self.ragged_value_types = [dtypes.as_dtype(t) for t in ragged_value_types]
    self.ragged_split_types = [dtypes.as_dtype(t) for t in ragged_split_types]
    self._validate()

  @classmethod
  def from_features(cls, features, types):
    """Builds _ParseOpParams for a given set of features and allowed types.

    Args:
      features: A `dict` mapping feature keys to objects of a type in `types`.
      types: Type of features to allow, among `FixedLenFeature`,
        `VarLenFeature`, `SparseFeature`, `FixedLenSequenceFeature`, and
        `RaggedFeature`.

    Returns:
      A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`.

    Raises:
      ValueError: if `features` contains an item not in `types`, or an invalid
        feature.
      ValueError: if sparse and dense key sets intersect.
      ValueError: if input lengths do not match up.
    """
    params = cls()
    if features:
      # NOTE: We iterate over sorted keys to keep things deterministic.
      for key in sorted(features.keys()):
        feature = features[key]
        if not isinstance(feature, tuple(types)):
          raise ValueError(
              f"Unsupported {type(feature).__name__} {feature} for key '{key}'")
        params._add_feature(key, feature)  # pylint: disable=protected-access
    params._validate()  # pylint: disable=protected-access
    return params

  @property
  def dense_shapes_as_proto(self):
    """The entries of `dense_shapes`, each converted with `as_proto()`."""
    return [shape.as_proto() for shape in self.dense_shapes]

  @property
  def num_features(self):
    """Total number of dense, sparse and ragged keys."""
    return len(self.dense_keys) + len(self.sparse_keys) + len(self.ragged_keys)

  @property
  def dense_defaults_vec(self):
    """Default-value tensors, one per entry of `dense_keys`, in order."""
    return [
        self._make_dense_default(k, s, t)
        for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types)
    ]

  def _make_dense_default(self, key, shape, dtype):
    """Construct the default value tensor for a specified dense feature.

    Args:
      key: The key string identifying the dense feature.
      shape: The dense feature's shape.
      dtype: The dense feature's dtype.

    Returns:
      A Tensor.
    """
    default_value = self.dense_defaults.get(key)
    if (shape.ndims is not None and shape.ndims > 0 and
        shape.dims[0].value is None):
      # Variable stride dense shape, the default value should be a
      # scalar padding value.
      if default_value is None:
        default_value = ops.convert_to_tensor(
            "" if dtype == dtypes.string else 0, dtype=dtype)
      else:
        # Reshape to a scalar to ensure user gets an error if they
        # provide a tensor that's not intended to be a padding value
        # (0 or 2+ elements).
        key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dtype, name=key_name)
        default_value = array_ops.reshape(default_value, [])
    else:
      if default_value is None:
        # No default supplied: use an empty tensor of the feature's dtype.
        default_value = constant_op.constant([], dtype=dtype)
      elif not isinstance(default_value, ops.Tensor):
        key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dtype, name=key_name)
        default_value = array_ops.reshape(default_value, shape)

    return default_value

  def _add_feature(self, key, feature):
    """Adds the specified feature to this ParseOpParams.

    Dispatches on the feature's configuration class.

    Raises:
      ValueError: if `feature` is not one of the known configuration types.
    """
    if isinstance(feature, VarLenFeature):
      self._add_varlen_feature(key, feature)
    elif isinstance(feature, SparseFeature):
      self._add_sparse_feature(key, feature)
    elif isinstance(feature, FixedLenFeature):
      self._add_fixed_len_feature(key, feature)
    elif isinstance(feature, FixedLenSequenceFeature):
      self._add_fixed_len_sequence_feature(key, feature)
    elif isinstance(feature, RaggedFeature):
      self._add_ragged_feature(key, feature)
    else:
      raise ValueError(f"Invalid feature {key}:{feature}.")

  def _add_varlen_feature(self, key, feature):
    """Adds a VarLenFeature."""
    if not feature.dtype:
      raise ValueError(
          f"Missing type for feature {key}. Received feature={feature}")
    self._add_sparse_key(key, feature.dtype)

  def _add_sparse_key(self, key, dtype):
    """Adds a sparse key & dtype, checking for duplicates.

    A key that is already registered is accepted again only with the same
    dtype; otherwise a `ValueError` is raised.
    """
    if key in self.sparse_keys:
      original_dtype = self.sparse_types[self.sparse_keys.index(key)]
      if original_dtype != dtype:
        raise ValueError(
            f"Conflicting type {original_dtype} vs {dtype} for feature {key}.")
    else:
      self.sparse_keys.append(key)
      self.sparse_types.append(dtype)

  def _add_sparse_feature(self, key, feature):
    """Adds a SparseFeature.

    Registers each index key as an int64 sparse key and the value key as a
    sparse key of the feature's dtype.
    """

    if not feature.index_key:
      raise ValueError(f"Missing index_key for SparseFeature {feature}.")
    if not feature.value_key:
      raise ValueError(f"Missing value_key for SparseFeature {feature}.")
    if not feature.dtype:
      raise ValueError(f"Missing type for feature {key}. Received feature="
                       f"{feature}.")
    index_keys = feature.index_key
    if isinstance(index_keys, str):
      index_keys = [index_keys]
    elif len(index_keys) > 1:
      tf_logging.warning("SparseFeature is a complicated feature config "
                         "and should only be used after careful "
                         "consideration of VarLenFeature.")
    for index_key in sorted(index_keys):
      self._add_sparse_key(index_key, dtypes.int64)
    self._add_sparse_key(feature.value_key, feature.dtype)

  def _add_fixed_len_feature(self, key, feature):
    """Adds a FixedLenFeature.

    Raises:
      ValueError: if the feature has no dtype or shape, or if its shape is not
        fully defined.
    """
    if not feature.dtype:
      raise ValueError(f"Missing type for feature {key}. Received feature="
                       f"{feature}.")
    if feature.shape is None:
      raise ValueError(f"Missing shape for feature {key}. Received feature="
                       f"{feature}.")
    feature_tensor_shape = tensor_shape.as_shape(feature.shape)
    if (feature.shape and feature_tensor_shape.ndims and
        feature_tensor_shape.dims[0].value is None):
      raise ValueError(f"First dimension of shape for feature {key} unknown. "
                       "Consider using FixedLenSequenceFeature. Received "
                       f"feature={feature}.")
    # `feature.shape` is known to be non-None here (checked above).
    if not feature_tensor_shape.is_fully_defined():
      raise ValueError(f"All dimensions of shape for feature {key} need to be "
                       f"known but received {feature.shape!s}.")
    self.dense_keys.append(key)
    self.dense_shapes.append(feature_tensor_shape)
    self.dense_types.append(feature.dtype)
    if feature.default_value is not None:
      self.dense_defaults[key] = feature.default_value

  def _add_fixed_len_sequence_feature(self, key, feature):
    """Adds a FixedLenSequenceFeature.

    Raises:
      ValueError: if the feature has no dtype or shape.
    """
    if not feature.dtype:
      raise ValueError(f"Missing type for feature {key}. Received feature="
                       f"{feature}.")
    if feature.shape is None:
      raise ValueError(f"Missing shape for feature {key}. Received feature="
                       f"{feature}.")
    self.dense_keys.append(key)
    self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
    self.dense_types.append(feature.dtype)
    if feature.allow_missing:
      # Record the key with no default; an explicit default_value (below)
      # overrides this entry.
      self.dense_defaults[key] = None
    if feature.default_value is not None:
      self.dense_defaults[key] = feature.default_value

  def _add_ragged_key(self, key, value_type, split_type):
    """Adds a ragged key & dtype, checking for duplicates.

    A key that is already registered is accepted again only with the same
    value type and split type; otherwise a `ValueError` is raised.
    """
    if key in self.ragged_keys:
      position = self.ragged_keys.index(key)
      original_value_type = self.ragged_value_types[position]
      original_split_type = self.ragged_split_types[position]
      if original_value_type != value_type:
        raise ValueError(f"Conflicting type {original_value_type} vs "
                         f"{value_type} for feature {key}.")
      if original_split_type != split_type:
        raise ValueError(f"Conflicting partition type {original_split_type} vs "
                         f"{split_type} for feature {key}.")
    else:
      self.ragged_keys.append(key)
      self.ragged_value_types.append(value_type)
      self.ragged_split_types.append(split_type)

  def _add_ragged_feature(self, key, feature):
    """Adds a RaggedFeature.

    Registers the value key (falling back to `key` when the feature has no
    `value_key`) and an int64 ragged key for every partition that carries a
    key, i.e. every partition other than `UniformRowLength`.
    """
    value_key = key if feature.value_key is None else feature.value_key
    self._add_ragged_key(value_key, feature.dtype, feature.row_splits_dtype)
    for partition in feature.partitions:
      if not isinstance(partition, RaggedFeature.UniformRowLength):
        self._add_ragged_key(partition.key, dtypes.int64,
                             feature.row_splits_dtype)

  def _validate(self):
    """Validates the features in this ParseOpParams.

    Raises:
      ValueError: if a types/shapes list differs in length from its key list,
        or if any two of the dense, sparse and ragged key sets intersect.
    """
    if len(self.dense_shapes) != len(self.dense_keys):
      raise ValueError("len(self.dense_shapes) != len(self.dense_keys): "
                       f"{len(self.dense_shapes)} vs {len(self.dense_keys)}.")
    if len(self.dense_types) != len(self.dense_keys):
      raise ValueError("len(self.dense_types) != len(self.dense_keys): "
                       f"{len(self.dense_types)} vs {len(self.dense_keys)}.")
    if len(self.sparse_types) != len(self.sparse_keys):
      raise ValueError("len(self.sparse_types) != len(self.sparse_keys): "
                       f"{len(self.sparse_types)} vs {len(self.sparse_keys)}.")
    if len(self.ragged_value_types) != len(self.ragged_keys):
      raise ValueError(
          "len(self.ragged_value_types) != len(self.ragged_keys): "
          f"{len(self.ragged_value_types)} vs {len(self.ragged_keys)}.")
    if len(self.ragged_split_types) != len(self.ragged_keys):
      raise ValueError(
          "len(self.ragged_split_types) != len(self.ragged_keys): "
          f"{len(self.ragged_split_types)} vs {len(self.ragged_keys)}.")

    dense_key_set = set(self.dense_keys)
    sparse_key_set = set(self.sparse_keys)
    ragged_key_set = set(self.ragged_keys)
    if not dense_key_set.isdisjoint(sparse_key_set):
      raise ValueError(
          "Dense and sparse keys must not intersect; dense_keys: "
          f"{self.dense_keys}, sparse_keys: {self.sparse_keys}, intersection: "
          f"{dense_key_set.intersection(sparse_key_set)}")
    if not dense_key_set.isdisjoint(ragged_key_set):
      # The literals are concatenated into one message; a comma between them
      # would pass two arguments and make the error render as a tuple.
      raise ValueError(
          "Dense and ragged keys must not intersect; dense_keys: "
          f"{self.dense_keys}, ragged_keys: {self.ragged_keys}, intersection: "
          f"{dense_key_set.intersection(ragged_key_set)}")
    if not ragged_key_set.isdisjoint(sparse_key_set):
      raise ValueError(
          "Ragged and sparse keys must not intersect; ragged_keys: "
          f"{self.ragged_keys}, sparse_keys: {self.sparse_keys}, intersection: "
          f"{ragged_key_set.intersection(sparse_key_set)}")
668def _construct_tensors_for_composite_features(features, tensor_dict):
669 """Creates tensors for SparseFeatures and RaggedFeatures.
671 Constructs new dict based on `tensor_dict`.
673 For each key in `features` whose value is a `SparseFeature`:
675 * Looks up that SparseFeature's value_key and index_keys in tensor_dict.
676 * Uses those tensors to construct a single SparseTensor.
677 * Stores that SparseTensor in the output dict under the same key.
679 For each key in `features` whose value is a `RaggedFeature`:
681 * Looks up that RaggedFeature's value_key and partition keys in tensor_dict.
682 * Uses those tensors to construct a single RaggedTensor.
683 * Stores that RaggedTensor in the output dict under the same key.
685 For any other key in `features`:
687 * Copies that key and its value from tensor_dict to the output dictionary.
689 Args:
690 features: A `dict` mapping feature keys to `SparseFeature` or
691 `RaggedFeature` values. Values of other types will be ignored.
692 tensor_dict: A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
693 `RaggedTensor` values. Expected to contain keys of the `SparseFeature`s'
694 `index_key`s and `value_key`s and mapping them to `SparseTensor`s.
696 Returns:
697 A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
698 `RaggedTensor` values. Similar to `tensor_dict` except each `SparseFeature`
699 in `features` results in a single `SparseTensor`; and each `RaggedFeature`
700 in `features` results in a single `RaggedTensor`.
701 """
702 tensor_dict = dict(tensor_dict) # Do not modify argument passed in.
703 updates = {}
704 for key in sorted(features.keys()):
705 feature = features[key]
706 if isinstance(feature, SparseFeature):
707 # Construct SparseTensors for SparseFeatures
708 if isinstance(feature.index_key, str):
709 sp_ids = tensor_dict[feature.index_key]
710 else:
711 sp_ids = [tensor_dict[index_key] for index_key in feature.index_key]
712 sp_values = tensor_dict[feature.value_key]
713 updates[key] = sparse_ops.sparse_merge(
714 sp_ids,
715 sp_values,
716 vocab_size=feature.size,
717 already_sorted=feature.already_sorted)
718 elif isinstance(feature, RaggedFeature):
719 # Construct RaggedTensors for RaggedFeatures.
720 value_key = key if feature.value_key is None else feature.value_key
721 rt = tensor_dict[value_key]
722 if isinstance(rt, ragged_tensor.RaggedTensor):
723 # We processed a batch of tf.Example or tf.SequenceExample, or single
724 # tf.SequenceExample.
725 if rt.ragged_rank > 1:
726 # We're processing a batch of SequenceExample, and we effectively have
727 # two batch dimensions. Collapse those batch dimensions here, and
728 # restore them below (using outer_splits).
729 outer_splits = rt.row_splits
730 rt = rt.values
731 else:
732 outer_splits = None
733 for partition in reversed(feature.partitions):
734 rt = _add_batched_ragged_partition(rt, partition, tensor_dict,
735 key, feature.validate,
736 outer_splits)
737 if outer_splits is not None:
738 rt = ragged_tensor.RaggedTensor.from_row_splits(
739 rt, outer_splits, validate=feature.validate)
740 else:
741 # We processed a single tf.Example.
742 for partition in reversed(feature.partitions):
743 rt = _add_ragged_partition(rt, partition, tensor_dict,
744 feature.row_splits_dtype, feature.validate)
745 updates[key] = rt
747 # Process updates after all composite tensors have been constructed (in case
748 # multiple features use the same value_key, and one uses that key as its
749 # feature key).
750 tensor_dict.update(updates)
752 # Remove tensors from dictionary that were only used to construct
753 # tensors for SparseFeature or RaggedTensor.
754 for key in set(tensor_dict) - set(features):
755 del tensor_dict[key]
756 return tensor_dict
def _add_ragged_partition(values, partition, tensor_dict, row_splits_dtype,
                          validate):
  """Partitions `values` into rows as described by `partition`.

  Args:
    values: The values tensor for the new RaggedTensor.
    partition: The partition configuration object.  For every kind other than
      `RaggedFeature.UniformRowLength`, `partition.key` is looked up in
      `tensor_dict` to obtain the partition tensor; a `UniformRowLength`
      supplies `partition.length` directly and uses no partition tensor.
    tensor_dict: The dictionary mapping keys to tensors.
    row_splits_dtype: The dtype the partition tensor is cast to (also used
      for the row length when a uniform length is applied to ragged values).
    validate: Passed as `validate` to the `RaggedTensor` factory methods.

  Returns:
    A new RaggedTensor built from `values` and the partition.  When
    `partition` is a `UniformRowLength` and `values` is not a RaggedTensor,
    the output of `array_ops.reshape` is returned instead.

  Raises:
    ValueError: If `partition` matches none of the handled partition types
      (raised only after `partition.key` has been looked up).
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    if not isinstance(values, ragged_tensor.RaggedTensor):
      # Non-ragged values: reshape so that the leading dimension is split
      # into rows holding `partition.length` entries each.
      trailing_dims = array_ops.shape(values)[1:]
      new_shape = array_ops.concat(
          [[-1, partition.length], trailing_dims], axis=0)
      return array_ops.reshape(values, new_shape)
    row_length = ops.convert_to_tensor(
        partition.length, dtype=row_splits_dtype)
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        values, row_length, validate=validate)

  # Every remaining partition kind names its partition tensor via `key`.
  row_partition = math_ops.cast(tensor_dict[partition.key], row_splits_dtype)
  rt_cls = ragged_tensor.RaggedTensor
  factories = (
      (RaggedFeature.RowSplits, rt_cls.from_row_splits),
      (RaggedFeature.RowLengths, rt_cls.from_row_lengths),
      (RaggedFeature.RowStarts, rt_cls.from_row_starts),
      (RaggedFeature.RowLimits, rt_cls.from_row_limits),
      (RaggedFeature.ValueRowIds, rt_cls.from_value_rowids),
  )
  for partition_type, build in factories:
    if isinstance(partition, partition_type):
      return build(values, row_partition, validate=validate)
  raise ValueError(f"Unhandled partition type {partition!r}")
def _add_batched_ragged_partition(rt, partition, tensor_dict, feature_key,
                                  validate, outer_splits=None):
  """Applies a per-batch-item ragged partition to a batched RaggedTensor.

  Args:
    rt: A RaggedTensor with shape [batch_size, ...].
    partition: The partition configuration object.  Unless it is a
      RaggedFeature.UniformRowLength (which has no partition tensor), its
      `key` names the partition tensor in `tensor_dict`; that tensor must
      have shape [batch_size, ...].
    tensor_dict: The dictionary mapping keys to tensors.
    feature_key: The name of the feature being parsed (for error messages).
    validate: Whether to validate that the values form a valid RaggedTensor.
    outer_splits: If not None, there are two batch dimensions and this is the
      row-splits of the collapsed batch dimension.  Every partition tensor
      must have an outer row_splits that matches this value.

  Returns:
    A new RaggedTensor where each batch item `rt[i]` has been partitioned
    using `partition_t[i]`.

  Raises:
    ValueError: If `partition` matches none of the handled partition types.
  """
  rt_cls = ragged_tensor.RaggedTensor

  if isinstance(partition, RaggedFeature.UniformRowLength):
    if rt.ragged_rank > 1:
      length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype)
      uniform_rows = rt_cls.from_uniform_row_length(
          rt.values, length, validate=validate)
      return rt_cls.from_row_splits(
          uniform_rows, rt.row_splits // length, validate=validate)
    flat_values = rt.values
    target_shape = array_ops.concat(
        [[-1, partition.length], array_ops.shape(flat_values)[1:]], axis=0)
    reshaped_vals = array_ops.reshape(flat_values, target_shape)
    return rt_cls.from_row_splits(
        reshaped_vals, rt.row_splits // partition.length, validate=validate)

  partition_t = tensor_dict[partition.key]
  splits_dtype = rt.row_splits.dtype
  if partition_t.values.dtype != splits_dtype:
    partition_t = math_ops.cast(partition_t, splits_dtype)

  alignment_checks = []
  if outer_splits is not None:
    if validate:
      error_message = (
          "Feature %s: values and partitions are not aligned" % feature_key)
      alignment_checks.append(check_ops.assert_equal(
          outer_splits, partition_t.row_splits, message=error_message))
    # Drop the outer batch dimension of the partition tensor; the caller has
    # already collapsed the matching dimension of `rt`.
    partition_t = partition_t.values

  with ops.control_dependencies(alignment_checks):
    if isinstance(partition,
                  (RaggedFeature.RowSplits, RaggedFeature.RowLimits)):
      if isinstance(partition, RaggedFeature.RowSplits):
        # Discard the first split of every batch item so that the remaining
        # entries can be handled exactly like row limits.
        partition_t = partition_t[:, 1:]
      # Shift each batch item's limits by where that item starts in `rt`.
      item_offsets = array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      adjusted_limits = partition_t.values + item_offsets
      return partition_t.with_values(
          rt_cls.from_row_limits(
              rt.values, adjusted_limits, validate=validate))

    if isinstance(partition, RaggedFeature.RowStarts):
      # Shift each batch item's starts by where that item starts in `rt`.
      item_offsets = array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      adjusted_starts = partition_t.values + item_offsets
      return partition_t.with_values(
          rt_cls.from_row_starts(
              rt.values, adjusted_starts, validate=validate))

    if isinstance(partition, RaggedFeature.RowLengths):
      # Row lengths need no per-item offset.
      return partition_t.with_values(
          rt_cls.from_row_lengths(
              rt.values, partition_t.values, validate=validate))

    if isinstance(partition, RaggedFeature.ValueRowIds):
      # Number of rows in each batch item: largest row id plus one, floored
      # at zero.
      largest_id_plus_one = ragged_math_ops.reduce_max(partition_t + 1, axis=1)
      nrows = math_ops.maximum(largest_id_plus_one, 0)
      # Offset each item's row ids by the number of rows in earlier items.
      rowid_offsets = array_ops.repeat(
          math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths())
      adjusted_rowids = partition_t.values + rowid_offsets
      rows = rt_cls.from_value_rowids(
          rt.values, adjusted_rowids, validate=validate)
      return rt_cls.from_row_lengths(rows, nrows, validate=validate)

  raise ValueError(f"Unhandled partition type {partition!r}")
def _build_ragged_tensors(serialized_shape,
                          ragged_values,
                          ragged_row_splits,
                          ragged_inner_splits=None):
  """Builds RaggedTensors from the outputs of a parse op.

  Args:
    serialized_shape: Object exposing `ndims`; when `ndims == 0` the outer
      `ragged_row_splits` are not applied.
    ragged_values: Sequence of values, one entry per ragged feature.
    ragged_row_splits: Sequence of row-splits paired element-wise with the
      values; only used when `serialized_shape.ndims` is not 0.
    ragged_inner_splits: Optional sequence of row-splits that, when given, is
      applied to the values before `ragged_row_splits`.

  Returns:
    `ragged_values` itself when no splits are applied; otherwise a new list
    of RaggedTensors created with `RaggedTensor.from_row_splits` (validation
    disabled).
  """

  def _apply_row_splits(values_seq, splits_seq):
    # Pair every value with its splits and wrap it in one ragged dimension.
    return [
        ragged_tensor.RaggedTensor.from_row_splits(item, item_splits,
                                                   validate=False)
        for item, item_splits in zip(values_seq, splits_seq)
    ]

  built = ragged_values
  if ragged_inner_splits is not None:
    built = _apply_row_splits(built, ragged_inner_splits)
  if serialized_shape.ndims != 0:
    built = _apply_row_splits(built, ragged_row_splits)
  return built