Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py: 35%
1070 statements
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16import abc
17import functools
18import queue
19import threading
20import warnings
22import numpy as np
24from tensorflow.core.framework import dataset_metadata_pb2
25from tensorflow.core.framework import dataset_options_pb2
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.protobuf import struct_pb2
28from tensorflow.python import tf2
29from tensorflow.python.data.ops import dataset_autograph
30from tensorflow.python.data.ops import debug_mode
31from tensorflow.python.data.ops import iterator_ops
32from tensorflow.python.data.ops import options as options_lib
33from tensorflow.python.data.ops import structured_function
34from tensorflow.python.data.util import nest
35from tensorflow.python.data.util import structure
36from tensorflow.python.data.util import traverse
37from tensorflow.python.eager import context
38from tensorflow.python.framework import auto_control_deps
39from tensorflow.python.framework import auto_control_deps_utils as acd_utils
40from tensorflow.python.framework import composite_tensor
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import function
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import random_seed as core_random_seed
46from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.framework import type_spec
51from tensorflow.python.ops import array_ops
52from tensorflow.python.ops import check_ops
53from tensorflow.python.ops import cond
54from tensorflow.python.ops import control_flow_assert
55from tensorflow.python.ops import gen_dataset_ops
56from tensorflow.python.ops import gen_io_ops
57from tensorflow.python.ops import gen_parsing_ops
58from tensorflow.python.ops import logging_ops
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import random_ops
61from tensorflow.python.ops import string_ops
62from tensorflow.python.ops.ragged import ragged_tensor
63from tensorflow.python.saved_model import nested_structure_coder
64from tensorflow.python.trackable import asset
65from tensorflow.python.trackable import base as tracking_base
66from tensorflow.python.trackable import resource as resource_lib
67from tensorflow.python.types import data as data_types
68from tensorflow.python.types import trace
69from tensorflow.python.util import deprecation
70from tensorflow.python.util import lazy_loader
71from tensorflow.python.util import nest as tf_nest
72from tensorflow.python.util.compat import collections_abc
73from tensorflow.python.util.tf_export import tf_export
75# Symbols forwarded for legacy access through dataset_ops.py. These forwarded
76# symbols can be removed once all internal uses are updated.
77StructuredFunctionWrapper = structured_function.StructuredFunctionWrapper
79# Loaded lazily due to a circular dependency (roughly
80# tf.function->wrap_function->dataset->autograph->tf.function).
81# TODO(b/133251390): Use a regular import.
82wrap_function = lazy_loader.LazyLoader(
83 "wrap_function", globals(),
84 "tensorflow.python.eager.wrap_function")
85# Loaded lazily due to a circular dependency
86# dataset_ops->def_function->func_graph->autograph->dataset_ops
87# TODO(kathywu): Use a regular import.
88def_function = lazy_loader.LazyLoader(
89 "def_function", globals(),
90 "tensorflow.python.eager.def_function")
91# TODO(b/240947712): Clean up the circular dependencies.
92# Loaded lazily due to a circular dependency (dataset_ops ->
93# prefetch_op -> dataset_ops).
94prefetch_op = lazy_loader.LazyLoader(
95 "prefetch_op", globals(),
96 "tensorflow.python.data.ops.prefetch_op")
97# Loaded lazily due to a circular dependency (dataset_ops ->
98# shuffle_op -> dataset_ops).
99shuffle_op = lazy_loader.LazyLoader(
100 "shuffle_op", globals(),
101 "tensorflow.python.data.ops.shuffle_op")
104ops.NotDifferentiable("ReduceDataset")
106# A constant that can be used to enable auto-tuning.
107AUTOTUNE = -1
108tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
109# TODO(b/168128531): Deprecate and remove this symbol.
110tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
112# Constants representing infinite and unknown cardinalities.
113INFINITE = -1
114UNKNOWN = -2
115COMPRESSION_GZIP = "GZIP"
116COMPRESSION_SNAPPY = "NONE"
117DATASET_SPEC_FILENAME = "dataset_spec.pb"
118tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
119tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
122def _validate_and_encode(name):
123 if not name.isidentifier():
124 raise ValueError("Invalid `name`. The argument `name` needs to be a valid "
125 "identifier. Value is considered a valid identifier if it "
126 "only contains alphanumeric characters (a-z), (A-Z), and "
127 "(0-9), or underscores (_). A valid identifier cannot "
128 "start with a number, or contain any spaces.")
129 return name.encode("utf-8")
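# Illustrative sketch only (not executed): names accepted by
# `_validate_and_encode`.
#   _validate_and_encode("my_dataset")   # returns b"my_dataset"
#   _validate_and_encode("my dataset")   # raises ValueError (not an identifier)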
132def get_type(value):
133 """Returns the type of `value` if it is a TypeSpec."""
135 if isinstance(value, type_spec.TypeSpec):
136 return value.value_type
137 else:
138 return type(value)
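# Illustrative sketch only (not executed): `get_type` maps a `tf.TypeSpec` to
# its value type and falls back to `type()` for ordinary Python values.
#   get_type(tf.TensorSpec(shape=(), dtype=tf.int32))  # the `tf.Tensor` class
#   get_type(3)                                        # int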
141@tf_export("data.Dataset", v1=[])
142class DatasetV2(
143 collections_abc.Iterable,
144 tracking_base.Trackable,
145 composite_tensor.CompositeTensor,
146 data_types.DatasetV2,
147 metaclass=abc.ABCMeta):
148 """Represents a potentially large set of elements.
150 The `tf.data.Dataset` API supports writing descriptive and efficient input
151 pipelines. `Dataset` usage follows a common pattern:
153 1. Create a source dataset from your input data.
154 2. Apply dataset transformations to preprocess the data.
155 3. Iterate over the dataset and process the elements.
157 Iteration happens in a streaming fashion, so the full dataset does not need to
158 fit into memory.
160 Source Datasets:
162 The simplest way to create a dataset is to create it from a python `list`:
164 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
165 >>> for element in dataset:
166 ... print(element)
167 tf.Tensor(1, shape=(), dtype=int32)
168 tf.Tensor(2, shape=(), dtype=int32)
169 tf.Tensor(3, shape=(), dtype=int32)
171 To process lines from files, use `tf.data.TextLineDataset`:
173 >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
175 To process records written in the `TFRecord` format, use `TFRecordDataset`:
177 >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
179 To create a dataset of all files matching a pattern, use
180 `tf.data.Dataset.list_files`:
182 ```python
183 dataset = tf.data.Dataset.list_files("/path/*.txt")
184 ```
186 See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
187 for more ways to create datasets.
189 Transformations:
191 Once you have a dataset, you can apply transformations to prepare the data for
192 your model:
194 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
195 >>> dataset = dataset.map(lambda x: x*2)
196 >>> list(dataset.as_numpy_iterator())
197 [2, 4, 6]
199 Common Terms:
201 **Element**: A single output from calling `next()` on a dataset iterator.
202 Elements may be nested structures containing multiple components. For
203 example, the element `(1, (3, "apple"))` has one tuple nested in another
204 tuple. The components are `1`, `3`, and `"apple"`.
206 **Component**: The leaf in the nested structure of an element.
208 Supported types:
210 Elements can be nested structures of tuples, named tuples, and dictionaries.
211 Note that Python lists are *not* treated as nested structures of components.
212 Instead, lists are converted to tensors and treated as components. For
213 example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
214 and the tensor `[1, 2, 3]`. Element components can be of any type
215 representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
216 `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.
218 ```python
219 a = 1 # Integer element
220 b = 2.0 # Float element
221 c = (1, 2) # Tuple element with 2 components
222 d = {"a": (2, 2), "b": 3} # Dict element with 3 components
223 Point = collections.namedtuple("Point", ["x", "y"])
224 e = Point(1, 2) # Named tuple
225 f = tf.data.Dataset.range(10) # Dataset element
226 ```
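To illustrate how Python lists become single tensor components (a minimal
sketch; `from_tensors` is described below):

```python
dataset = tf.data.Dataset.from_tensors((1, [1, 2, 3]))
print(dataset.element_spec)
# (TensorSpec(shape=(), dtype=tf.int32, name=None),
#  TensorSpec(shape=(3,), dtype=tf.int32, name=None))
```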
228 For more information,
229 read [this guide](https://www.tensorflow.org/guide/data).
230 """
232 def __init__(self, variant_tensor):
233 """Creates a DatasetV2 object.
235 This is a difference between DatasetV1 and DatasetV2: DatasetV1 does not
236 take anything in its constructor, whereas DatasetV2 expects its
237 subclasses to create a variant_tensor and pass it in to the super() call.
239 Args:
240 variant_tensor: A DT_VARIANT tensor that represents the dataset.
241 """
242 self._variant_tensor_attr = variant_tensor
243 self._graph_attr = ops.get_default_graph()
245 # Initialize the options for this dataset and its inputs.
246 self._options_attr = options_lib.Options()
247 for input_dataset in self._inputs():
248 input_options = None
249 if isinstance(input_dataset, data_types.DatasetV1):
250 # If the V1 dataset does not have the `_dataset` attribute, we assume it
251 # is a dataset source and hence does not have options. Otherwise, we
252 # grab the options of the `_dataset` object.
253 if hasattr(input_dataset, "_dataset"):
254 if not isinstance(input_dataset._dataset, data_types.DatasetV2):
255 raise TypeError(
256 f"Each input of dataset {type(self)} should be a subclass of "
257 f"`tf.data.Dataset` but encountered "
258 f"{type(input_dataset._dataset)}.")
259 input_options = input_dataset._dataset._options_attr
260 elif isinstance(input_dataset, data_types.DatasetV2):
261 input_options = input_dataset._options_attr
262 else:
263 raise TypeError(
264 f"Each input of dataset {type(self)} should be a subclass of "
265 f"`tf.data.Dataset` but encountered {type(input_dataset)}.")
266 if input_options is not None:
267 self._options_attr = self._options_attr.merge(input_options)
268 self._options_attr._set_mutable(False) # pylint: disable=protected-access
270 @property
271 def _variant_tensor(self):
272 return self._variant_tensor_attr
274 @_variant_tensor.setter
275 def _variant_tensor(self, _):
276 raise ValueError("The `_variant_tensor` property cannot be modified.")
278 @deprecation.deprecated_args(None, "Use external_state_policy instead",
279 "allow_stateful")
280 def _as_serialized_graph(
281 self,
282 allow_stateful=None,
283 strip_device_assignment=None,
284 external_state_policy=options_lib.ExternalStatePolicy.WARN):
285 """Produces serialized graph representation of the dataset.
287 Args:
288 allow_stateful: If true, we allow stateful ops to be present in the graph
289 def. In that case, the state in these ops would be thrown away.
290 strip_device_assignment: If true, non-local (i.e. job and task) device
291 assignment is stripped from ops in the serialized graph.
292 external_state_policy: The ExternalStatePolicy enum that determines how we
293 handle input pipelines that depend on external state. By default, it
294 is set to WARN.
296 Returns:
297 A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
298 serialized graph.
299 """
300 if external_state_policy:
301 policy = external_state_policy.value
302 return gen_dataset_ops.dataset_to_graph_v2(
303 self._variant_tensor,
304 external_state_policy=policy,
305 strip_device_assignment=strip_device_assignment)
306 if strip_device_assignment:
307 return gen_dataset_ops.dataset_to_graph(
308 self._variant_tensor,
309 allow_stateful=allow_stateful,
310 strip_device_assignment=strip_device_assignment)
311 return gen_dataset_ops.dataset_to_graph(
312 self._variant_tensor, allow_stateful=allow_stateful)
314 def _maybe_track_assets(self, graph_def):
315 """Finds and tracks nodes in `graph_def` that refer to asset files.
317 Args:
318 graph_def: Serialized graph representation of this dataset.
320 Returns:
321 A dictionary mapping the node name of an asset constant to a tracked
322 `asset.Asset` object.
323 """
324 asset_tracker = {}
325 for node in graph_def.node:
326 if node.name.startswith("FileIdentity"):
327 asset_tracker[node.input[0]] = None
329 if not asset_tracker:
330 return {}
332 for node in graph_def.node:
333 if node.name in asset_tracker:
334 tensor_proto = node.attr["value"].tensor
335 with context.eager_mode(), ops.device("CPU"):
336 node_value = gen_parsing_ops.parse_tensor(
337 tensor_proto.SerializeToString(), dtypes.string).numpy()
338 asset_tracker[node.name] = ([
339 self._track_trackable(asset.Asset(n),
340 name=node.name + "_" + str(i), overwrite=True)
341 for i, n in enumerate(node_value)
342 ])
343 return asset_tracker
345 def _trackable_children(self,
346 save_type=tracking_base.SaveType.CHECKPOINT,
347 **kwargs):
348 if save_type != tracking_base.SaveType.SAVEDMODEL:
349 return {}
351 # _trace_variant_creation only works when executing eagerly, so we don't
352 # want to run it in the object initialization.
353 @def_function.function(input_signature=[], autograph=False)
354 def _creator():
355 resource = self._trace_variant_creation()() # pylint: disable=protected-access
356 return resource
357 _creator.get_concrete_function() # Trigger asset tracking
359 children = super(DatasetV2, self)._trackable_children(save_type, **kwargs)
360 children["_variant_tracker"] = _VariantTracker(self._variant_tensor,
361 _creator)
362 return children
364 def _trace_variant_creation(self):
365 """Traces a function which outputs a variant `tf.Tensor` for this dataset.
367 Note that creating this function involves evaluating an op, and is currently
368 only supported when executing eagerly.
370 Returns:
371 A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
372 """
373 variant = self._variant_tensor
374 if not isinstance(variant, ops.EagerTensor):
375 raise NotImplementedError(
376 "Constructing a tf.function that reproduces a given dataset is only "
377 "supported for datasets created eagerly. Please file a feature "
378 "request if this is important to you.")
379 with context.eager_mode(), ops.device("CPU"):
380 # pylint: disable=protected-access
381 graph_def = graph_pb2.GraphDef().FromString(
382 self._as_serialized_graph(external_state_policy=options_lib
383 .ExternalStatePolicy.FAIL).numpy())
384 output_node_names = []
385 for node in graph_def.node:
386 if node.op == "_Retval":
387 output_node_names = node.input
389 if len(output_node_names) != 1:
390 raise AssertionError(
391 f"Dataset graph is expected to only have one return value but found "
392 f"{len(output_node_names)} return values: {output_node_names}.")
394 output_node_name = output_node_names[0]
396 file_path_nodes = {}
397 # When building a tf.function, track files as `saved_model.Asset`s.
398 if ops.get_default_graph().building_function:
399 asset_tracker = self._maybe_track_assets(graph_def)
400 for key in asset_tracker:
401 assets_list = [
402 array_ops.expand_dims(asset.asset_path, axis=0)
403 for asset in asset_tracker[key]
404 ]
405 file_path_nodes[key] = array_ops.concat(assets_list, axis=0)
407 # Add functions used in this Dataset to the function's graph, since they
408 # need to follow it around (and for example be added to a SavedModel which
409 # references the dataset).
410 variant_function = wrap_function.function_from_graph_def(
411 graph_def,
412 inputs=[],
413 outputs=output_node_name + ":0",
414 captures=file_path_nodes)
415 for used_function in self._functions():
416 used_function.function.add_to_graph(variant_function.graph)
417 return variant_function
419 @abc.abstractmethod
420 def _inputs(self):
421 """Returns a list of the input datasets of the dataset."""
423 raise NotImplementedError(f"{type(self)}._inputs()")
425 @property
426 def _graph(self):
427 return self._graph_attr
429 @_graph.setter
430 def _graph(self, _):
431 raise ValueError("The `_graph` property cannot be modified.")
433 # TODO(jsimsa): Change this to be the transitive closure of functions used
434 # by this dataset and its inputs.
435 def _functions(self):
436 """Returns a list of functions associated with this dataset.
438 Returns:
439 A list of `StructuredFunctionWrapper` objects.
440 """
441 return []
443 def _options(self):
444 """Returns the options tensor for this dataset."""
445 # pylint: disable=protected-access
446 return gen_dataset_ops.get_options(self._variant_tensor)
448 @classmethod
449 def _options_tensor_to_options(cls, serialized_options):
450 """Converts options tensor to tf.data.Options object."""
451 options = options_lib.Options()
452 if tensor_util.constant_value(serialized_options) is not None:
453 pb = dataset_options_pb2.Options.FromString(tensor_util.constant_value(
454 serialized_options))
455 options._from_proto(pb) # pylint: disable=protected-access
456 return options
458 def options(self):
459 """Returns the options for this dataset and its inputs.
461 Returns:
462 A `tf.data.Options` object representing the dataset options.
463 """
464 if context.executing_eagerly():
465 options = self._options_tensor_to_options(self._options())
466 options._set_mutable(False) # pylint: disable=protected-access
467 return options
468 warnings.warn("To make it possible to preserve tf.data options across "
469 "serialization boundaries, their implementation has moved to "
470 "be part of the TensorFlow graph. As a consequence, the "
471 "options value is in general no longer known at graph "
472 "construction time. Invoking this method in graph mode "
473 "retains the legacy behavior of the original implementation, "
474 "but note that the returned value might not reflect the "
475 "actual value of the options.")
476 return self._options_attr
478 def _apply_debug_options(self):
479 if debug_mode.DEBUG_MODE:
480 # Disable autotuning and static optimizations that could introduce
481 # parallelism or asynchrony.
482 options = options_lib.Options()
483 options.autotune.enabled = False
484 options.experimental_optimization.filter_parallelization = False
485 options.experimental_optimization.map_and_batch_fusion = False
486 options.experimental_optimization.map_parallelization = False
487 dataset = _OptionsDataset(self, options)
488 else:
489 dataset = self
491 return dataset
493 def __iter__(self):
494 """Creates an iterator for elements of this dataset.
496 The returned iterator implements the Python Iterator protocol.
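For example (a minimal sketch, assuming eager execution):

```python
dataset = tf.data.Dataset.range(2)
iterator = iter(dataset)
print(next(iterator))  # tf.Tensor(0, shape=(), dtype=int64)
print(next(iterator))  # tf.Tensor(1, shape=(), dtype=int64)
```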
498 Returns:
499 A `tf.data.Iterator` for the elements of this dataset.
501 Raises:
502 RuntimeError: If not inside a `tf.function` and not executing eagerly.
503 """
504 if context.executing_eagerly() or ops.inside_function():
505 with ops.colocate_with(self._variant_tensor):
506 return iterator_ops.OwnedIterator(self)
507 else:
508 raise RuntimeError("`tf.data.Dataset` only supports Python-style "
509 "iteration in eager mode or within tf.function.")
511 def __bool__(self):
512 return True # Required as __len__ is defined
514 __nonzero__ = __bool__ # Python 2 backward compatibility
516 def __len__(self):
517 """Returns the length of the dataset if it is known and finite.
519 This method requires that you are running in eager mode, and that the
520 length of the dataset is known and non-infinite. When the length may be
521 unknown or infinite, or if you are running in graph mode, use
522 `tf.data.Dataset.cardinality` instead.
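For example (a minimal sketch, assuming eager execution):

```python
print(len(tf.data.Dataset.range(5)))  # 5

# For infinite or unknown cardinality, use `cardinality` instead of `len`:
cardinality = tf.data.Dataset.range(5).repeat().cardinality()
print(cardinality == tf.data.INFINITE_CARDINALITY)  # tf.Tensor(True, shape=(), dtype=bool)
```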
524 Returns:
525 An integer representing the length of the dataset.
527 Raises:
528 TypeError: If the dataset length is unknown or infinite, or if eager
529 execution is not enabled.
530 """
531 if not context.executing_eagerly():
532 raise TypeError("`tf.data.Dataset` only supports `len` in eager mode. "
533 "Use `tf.data.Dataset.cardinality()` instead.")
534 length = self.cardinality()
535 if length.numpy() == INFINITE:
536 raise TypeError("The dataset is infinite.")
537 if length.numpy() == UNKNOWN:
538 raise TypeError("The dataset length is unknown.")
539 return length.numpy()
541 @abc.abstractproperty
542 def element_spec(self):
543 """The type specification of an element of this dataset.
545 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
546 >>> dataset.element_spec
547 TensorSpec(shape=(), dtype=tf.int32, name=None)
549 For more information,
550 read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).
552 Returns:
553 A (nested) structure of `tf.TypeSpec` objects matching the structure of an
554 element of this dataset and specifying the type of individual components.
555 """
556 raise NotImplementedError(f"{type(self)}.element_spec()")
558 def __repr__(self):
559 type_ = type(self._dataset if isinstance(self, DatasetV1Adapter) else self)
560 return f"<{type_.__name__} element_spec={self.element_spec}>"
562 def __debug_string__(self):
563 """Returns a string showing the type of the dataset and its inputs.
565 This string is intended only for debugging purposes, and may change without
566 warning.
567 """
568 lines = []
569 to_process = [(self, 0)] # Stack of (dataset, depth) pairs.
570 while to_process:
571 dataset, depth = to_process.pop()
572 lines.append("-"*2*depth + repr(dataset))
573 to_process.extend([(ds, depth+1) for ds in dataset._inputs()]) # pylint: disable=protected-access
574 return "\n".join(lines)
576 def as_numpy_iterator(self):
577 """Returns an iterator which converts all elements of the dataset to numpy.
579 Use `as_numpy_iterator` to inspect the content of your dataset. To see
580 element shapes and types, print dataset elements directly instead of using
581 `as_numpy_iterator`.
583 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
584 >>> for element in dataset:
585 ... print(element)
586 tf.Tensor(1, shape=(), dtype=int32)
587 tf.Tensor(2, shape=(), dtype=int32)
588 tf.Tensor(3, shape=(), dtype=int32)
590 This method requires that you are running in eager mode and the dataset's
591 element_spec contains only `TensorSpec` components.
593 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
594 >>> for element in dataset.as_numpy_iterator():
595 ... print(element)
596 1
597 2
598 3
600 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
601 >>> print(list(dataset.as_numpy_iterator()))
602 [1, 2, 3]
604 `as_numpy_iterator()` will preserve the nested structure of dataset
605 elements.
607 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
608 ... 'b': [5, 6]})
609 >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
610 ... {'a': (2, 4), 'b': 6}]
611 True
613 Returns:
614 An iterable over the elements of the dataset, with their tensors converted
615 to numpy arrays.
617 Raises:
618 TypeError: if an element contains a non-`Tensor` value.
619 RuntimeError: if eager execution is not enabled.
620 """
621 if not context.executing_eagerly():
622 raise RuntimeError("`tf.data.Dataset.as_numpy_iterator()` is only "
623 "supported in eager mode.")
624 for component_spec in nest.flatten(self.element_spec):
625 if not isinstance(
626 component_spec,
627 (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec,
628 sparse_tensor_lib.SparseTensorSpec, structure.NoneTensorSpec)):
629 raise TypeError(
630 f"`tf.data.Dataset.as_numpy_iterator()` is not supported for "
631 f"datasets that produce values of type {component_spec.value_type}")
633 return _NumpyIterator(self)
635 @property
636 def _flat_shapes(self):
637 """Returns a list `tf.TensorShapes`s for the element tensor representation.
639 Returns:
640 A list `tf.TensorShapes`s for the element tensor representation.
641 """
642 return structure.get_flat_tensor_shapes(self.element_spec)
644 @property
645 def _flat_types(self):
646 """Returns a list `tf.DType`s for the element tensor representation.
648 Returns:
649 A list `tf.DType`s for the element tensor representation.
650 """
651 return structure.get_flat_tensor_types(self.element_spec)
653 @property
654 def _flat_structure(self):
655 """Helper for setting `output_shapes` and `output_types` attrs of an op.
657 Most dataset op constructors expect `output_shapes` and `output_types`
658 arguments that represent the flattened structure of an element. This helper
659 function generates these attrs as a keyword argument dictionary, allowing
660 `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
661 to the op constructor.
663 Returns:
664 A dictionary of keyword arguments that can be passed to a dataset op
665 constructor.
666 """
667 return {
668 "output_shapes": self._flat_shapes,
669 "output_types": self._flat_types,
670 }
672 @property
673 def _metadata(self):
674 """Helper for generating dataset metadata."""
675 metadata = dataset_metadata_pb2.Metadata()
676 if self._name:
677 metadata.name = _validate_and_encode(self._name)
678 return metadata
680 @property
681 def _common_args(self):
682 """Helper for generating arguments that are common across most dataset ops.
684 Most dataset op constructors expect `output_shapes` and `output_types`
685 arguments that represent the flattened structure of an element, as well as a
686 `metadata` argument for additional metadata such as user-defined dataset
687 name. This helper function generates common attributes as a keyword argument
688 dictionary, allowing `Dataset._variant_tensor` implementations to pass
689 `**self._common_args` to the op constructor.
691 Returns:
692 A dictionary of keyword arguments that can be passed to a dataset op
693 constructor.
694 """
695 return {
696 "metadata": self._metadata.SerializeToString(),
697 "output_shapes": self._flat_shapes,
698 "output_types": self._flat_types,
699 }
701 @property
702 def _type_spec(self):
703 return DatasetSpec(self.element_spec)
705 @staticmethod
706 def from_tensors(tensors, name=None):
707 """Creates a `Dataset` with a single element, comprising the given tensors.
709 `from_tensors` produces a dataset containing only a single element. To slice
710 the input tensor into multiple elements, use `from_tensor_slices` instead.
712 >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
713 >>> list(dataset.as_numpy_iterator())
714 [array([1, 2, 3], dtype=int32)]
715 >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
716 >>> list(dataset.as_numpy_iterator())
717 [(array([1, 2, 3], dtype=int32), b'A')]
719 >>> # You can use `from_tensors` to produce a dataset which repeats
720 >>> # the same example many times.
721 >>> example = tf.constant([1,2,3])
722 >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
723 >>> list(dataset.as_numpy_iterator())
724 [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]
726 Note that if `tensors` contains a NumPy array, and eager execution is not
727 enabled, the values will be embedded in the graph as one or more
728 `tf.constant` operations. For large datasets (> 1 GB), this can waste
729 memory and run into byte limits of graph serialization. If `tensors`
730 contains one or more large NumPy arrays, consider the alternative described
731 in [this
732 guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).
734 Args:
735 tensors: A dataset "element". Supported values are documented
736 [here](https://www.tensorflow.org/guide/data#dataset_structure).
737 name: (Optional.) A name for the tf.data operation.
739 Returns:
740 Dataset: A `Dataset`.
741 """
742 # Loaded lazily due to a circular dependency (dataset_ops ->
743 # from_tensors_op -> dataset_ops).
744 # pylint: disable=g-import-not-at-top,protected-access
745 from tensorflow.python.data.ops import from_tensors_op
746 return from_tensors_op._from_tensors(tensors, name)
747 # pylint: enable=g-import-not-at-top,protected-access
749 @staticmethod
750 def from_tensor_slices(tensors, name=None):
751 """Creates a `Dataset` whose elements are slices of the given tensors.
753 The given tensors are sliced along their first dimension. This operation
754 preserves the structure of the input tensors, removing the first dimension
755 of each tensor and using it as the dataset dimension. All input tensors
756 must have the same size in their first dimensions.
758 >>> # Slicing a 1D tensor produces scalar tensor elements.
759 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
760 >>> list(dataset.as_numpy_iterator())
761 [1, 2, 3]
763 >>> # Slicing a 2D tensor produces 1D tensor elements.
764 >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
765 >>> list(dataset.as_numpy_iterator())
766 [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
768 >>> # Slicing a tuple of 1D tensors produces tuple elements containing
769 >>> # scalar tensors.
770 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
771 >>> list(dataset.as_numpy_iterator())
772 [(1, 3, 5), (2, 4, 6)]
774 >>> # Dictionary structure is also preserved.
775 >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
776 >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
777 ... {'a': 2, 'b': 4}]
778 True
780 >>> # Two tensors can be combined into one Dataset object.
781 >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
782 >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor
783 >>> dataset = Dataset.from_tensor_slices((features, labels))
784 >>> # Both the features and the labels tensors can be converted
785 >>> # to a Dataset object separately and combined after.
786 >>> features_dataset = Dataset.from_tensor_slices(features)
787 >>> labels_dataset = Dataset.from_tensor_slices(labels)
788 >>> dataset = Dataset.zip((features_dataset, labels_dataset))
789 >>> # A batched feature and label set can be converted to a Dataset
790 >>> # in similar fashion.
791 >>> batched_features = tf.constant([[[1, 3], [2, 3]],
792 ... [[2, 1], [1, 2]],
793 ... [[3, 3], [3, 2]]], shape=(3, 2, 2))
794 >>> batched_labels = tf.constant([['A', 'A'],
795 ... ['B', 'B'],
796 ... ['A', 'B']], shape=(3, 2, 1))
797 >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
798 >>> for element in dataset.as_numpy_iterator():
799 ... print(element)
800 (array([[1, 3],
801 [2, 3]], dtype=int32), array([[b'A'],
802 [b'A']], dtype=object))
803 (array([[2, 1],
804 [1, 2]], dtype=int32), array([[b'B'],
805 [b'B']], dtype=object))
806 (array([[3, 3],
807 [3, 2]], dtype=int32), array([[b'A'],
808 [b'B']], dtype=object))
810 Note that if `tensors` contains a NumPy array, and eager execution is not
811 enabled, the values will be embedded in the graph as one or more
812 `tf.constant` operations. For large datasets (> 1 GB), this can waste
813 memory and run into byte limits of graph serialization. If `tensors`
814 contains one or more large NumPy arrays, consider the alternative described
815 in [this guide](
816 https://tensorflow.org/guide/data#consuming_numpy_arrays).
818 Args:
819 tensors: A dataset element, whose components have the same first
820 dimension. Supported values are documented
821 [here](https://www.tensorflow.org/guide/data#dataset_structure).
822 name: (Optional.) A name for the tf.data operation.
824 Returns:
825 Dataset: A `Dataset`.
826 """
827 # Loaded lazily due to a circular dependency (dataset_ops ->
828 # from_tensor_slices_op -> dataset_ops).
829 # pylint: disable=g-import-not-at-top,protected-access
830 from tensorflow.python.data.ops import from_tensor_slices_op
831 return from_tensor_slices_op._from_tensor_slices(tensors, name)
832 # pylint: enable=g-import-not-at-top,protected-access
834 class _GeneratorState:
835 """Stores outstanding iterators created from a Python generator.
837 This class keeps track of potentially multiple iterators that may have
838 been created from a generator, e.g. in the case that the dataset is
839 repeated, or nested within a parallel computation.
840 """
842 def __init__(self, generator):
843 self._generator = generator
844 self._lock = threading.Lock()
845 self._next_id = 0 # GUARDED_BY(self._lock)
846 self._args = {}
847 self._iterators = {}
849 def _normalize_id(self, iterator_id):
850 # In debug mode, iterator ids may be eagerly-generated np.arrays instead
851 # of Tensors. We convert them to scalars to make them hashable.
852 if isinstance(iterator_id, np.ndarray):
853 return iterator_id.item()
854 return iterator_id
856 def get_next_id(self, *args):
857 with self._lock:
858 ret = self._next_id
859 self._next_id += 1
860 self._args[ret] = args
861 # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
862 # casting in `py_func()` will create an array of `np.int32` on Windows,
863 # leading to a runtime error.
864 return np.array(ret, dtype=np.int64)
866 def get_iterator(self, iterator_id):
867 iterator_id = self._normalize_id(iterator_id)
868 try:
869 return self._iterators[iterator_id]
870 except KeyError:
871 iterator = iter(self._generator(*self._args.pop(iterator_id)))
872 self._iterators[iterator_id] = iterator
873 return iterator
875 def iterator_completed(self, iterator_id):
876 del self._iterators[self._normalize_id(iterator_id)]
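# Illustrative sketch only (not executed): how the bookkeeping above is used.
#   state = DatasetV2._GeneratorState(lambda: iter(range(3)))
#   iterator_id = state.get_next_id()      # reserves an id and records args
#   it = state.get_iterator(iterator_id)   # instantiates the Python iterator
#   next(it)                               # 0
#   state.iterator_completed(iterator_id)  # drops the iterator for this id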
878 @staticmethod
879 @deprecation.deprecated_args(None, "Use output_signature instead",
880 "output_types", "output_shapes")
881 def from_generator(generator,
882 output_types=None,
883 output_shapes=None,
884 args=None,
885 output_signature=None,
886 name=None):
887 """Creates a `Dataset` whose elements are generated by `generator`.
889 Note: The current implementation of `Dataset.from_generator()` uses
890 `tf.numpy_function` and inherits the same constraints. In particular, it
891 requires the dataset and iterator related operations to be placed
892 on a device in the same process as the Python program that called
893 `Dataset.from_generator()`. Furthermore, using `from_generator` will
894 preclude the use of tf.data service for scaling out dataset processing.
895 The body of `generator` will not be serialized in a `GraphDef`, and you
896 should not use this method if you need to serialize your model and restore
897 it in a different environment.
899 The `generator` argument must be a callable object that returns
900 an object that supports the `iter()` protocol (e.g. a generator function).
902 The elements generated by `generator` must be compatible with either the
903 given `output_signature` argument or with the given `output_types` and
904 (optionally) `output_shapes` arguments, whichever was specified.
906 The recommended way to call `from_generator` is to use the
907 `output_signature` argument. In this case the output will be assumed to
908 consist of objects with the classes, shapes and types defined by
909 `tf.TypeSpec` objects from `output_signature` argument:
911 >>> def gen():
912 ... ragged_tensor = tf.ragged.constant([[1, 2], [3]])
913 ... yield 42, ragged_tensor
914 >>>
915 >>> dataset = tf.data.Dataset.from_generator(
916 ... gen,
917 ... output_signature=(
918 ... tf.TensorSpec(shape=(), dtype=tf.int32),
919 ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
920 >>>
921 >>> list(dataset.take(1))
922 [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
923 <tf.RaggedTensor [[1, 2], [3]]>)]
925 There is also a deprecated way to call `from_generator`, either with the
926 `output_types` argument alone or together with the `output_shapes` argument.
927 In this case the output of the function will be assumed to consist of
928 `tf.Tensor` objects with the types defined by `output_types` and with the
929 shapes which are either unknown or defined by `output_shapes`.
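For example (a minimal sketch of the deprecated calling convention):

```python
def gen():
  for i in range(3):
    yield i, i * 10

dataset = tf.data.Dataset.from_generator(
    gen, output_types=(tf.int64, tf.int64), output_shapes=((), ()))
print(list(dataset.as_numpy_iterator()))  # [(0, 0), (1, 10), (2, 20)]
```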
931 Note: If `generator` depends on mutable global variables or other external
932 state, be aware that the runtime may invoke `generator` multiple times
933 (in order to support repeating the `Dataset`) and at any time
934 between the call to `Dataset.from_generator()` and the production of the
935 first element from the generator. Mutating global variables or external
936 state can cause undefined behavior, and we recommend that you explicitly
937 cache any external state in `generator` before calling
938 `Dataset.from_generator()`.
940 Note: While the `output_signature` parameter makes it possible to yield
941 `Dataset` elements, the scope of `Dataset.from_generator()` should be
942 limited to logic that cannot be expressed through tf.data operations. Using
943 tf.data operations within the generator function is an anti-pattern and may
944 result in incremental memory growth.
946 Args:
947 generator: A callable object that returns an object that supports the
948 `iter()` protocol. If `args` is not specified, `generator` must take no
949 arguments; otherwise it must take as many arguments as there are values
950 in `args`.
951 output_types: (Optional.) A (nested) structure of `tf.DType` objects
952 corresponding to each component of an element yielded by `generator`.
953 output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
954 objects corresponding to each component of an element yielded by
955 `generator`.
956 args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
957 and passed to `generator` as NumPy-array arguments.
958 output_signature: (Optional.) A (nested) structure of `tf.TypeSpec`
959 objects corresponding to each component of an element yielded by
960 `generator`.
961 name: (Optional.) A name for the tf.data operations used by
962 `from_generator`.
964 Returns:
965 Dataset: A `Dataset`.
966 """
967 # Loaded lazily due to a circular dependency (dataset_ops ->
968 # from_generator_op -> dataset_ops).
969 # pylint: disable=g-import-not-at-top,protected-access
970 from tensorflow.python.data.ops import from_generator_op
971 return from_generator_op._from_generator(generator, output_types,
972 output_shapes, args,
973 output_signature, name)
974 # pylint: enable=g-import-not-at-top,protected-access
976 @staticmethod
977 def range(*args, **kwargs):
978 """Creates a `Dataset` of a step-separated range of values.
980 >>> list(Dataset.range(5).as_numpy_iterator())
981 [0, 1, 2, 3, 4]
982 >>> list(Dataset.range(2, 5).as_numpy_iterator())
983 [2, 3, 4]
984 >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
985 [1, 3]
986 >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
987 []
988 >>> list(Dataset.range(5, 1).as_numpy_iterator())
989 []
990 >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
991 [5, 3]
992 >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
993 [2, 3, 4]
994 >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
995 [1.0, 3.0]
997 Args:
998 *args: follows the same semantics as python's range.
999 len(args) == 1 -> start = 0, stop = args[0], step = 1.
1000 len(args) == 2 -> start = args[0], stop = args[1], step = 1.
1001 len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
1002 **kwargs:
1003 - output_type: Its expected dtype. (Optional, default: `tf.int64`).
1004 - name: (Optional.) A name for the tf.data operation.
1006 Returns:
1007 Dataset: A `RangeDataset`.
1009 Raises:
1010 ValueError: if len(args) == 0.
1011 """
1012 # Loaded lazily due to a circular dependency (dataset_ops -> range_op ->
1013 # -> dataset_ops).
1014 # pylint: disable=g-import-not-at-top,protected-access
1015 from tensorflow.python.data.ops import range_op
1016 return range_op._range(*args, **kwargs)
1017 # pylint: enable=g-import-not-at-top,protected-access
1019 @staticmethod
1020 def zip(*args, datasets=None, name=None):
1021 """Creates a `Dataset` by zipping together the given datasets.
1023 This method has similar semantics to the built-in `zip()` function
1024 in Python, with the main difference being that the `datasets`
1025 argument can be a (nested) structure of `Dataset` objects. The supported
1026 nesting mechanisms are documented
1027 [here](https://www.tensorflow.org/guide/data#dataset_structure).
1029 >>> # The datasets or nested structure of datasets `*args` argument
1030 >>> # determines the structure of elements in the resulting dataset.
1031 >>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
1032 >>> b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
1033 >>> ds = tf.data.Dataset.zip(a, b)
1034 >>> list(ds.as_numpy_iterator())
1035 [(1, 4), (2, 5), (3, 6)]
1036 >>> ds = tf.data.Dataset.zip(b, a)
1037 >>> list(ds.as_numpy_iterator())
1038 [(4, 1), (5, 2), (6, 3)]
1039 >>>
1040 >>> # The `datasets` argument may contain an arbitrary number of datasets.
1041 >>> c = tf.data.Dataset.range(7, 13).batch(2) # ==> [ [7, 8],
1042 ... # [9, 10],
1043 ... # [11, 12] ]
1044 >>> ds = tf.data.Dataset.zip(a, b, c)
1045 >>> for element in ds.as_numpy_iterator():
1046 ... print(element)
1047 (1, 4, array([7, 8]))
1048 (2, 5, array([ 9, 10]))
1049 (3, 6, array([11, 12]))
1050 >>>
1051 >>> # The number of elements in the resulting dataset is the same as
1052 >>> # the size of the smallest dataset in `datasets`.
1053 >>> d = tf.data.Dataset.range(13, 15) # ==> [ 13, 14 ]
1054 >>> ds = tf.data.Dataset.zip(a, d)
1055 >>> list(ds.as_numpy_iterator())
1056 [(1, 13), (2, 14)]
1058 Args:
1059 *args: Datasets or nested structures of datasets to zip together. This
1060 can't be set if `datasets` is set.
1061 datasets: A (nested) structure of datasets. This can't be set if `*args`
1062 is set. Note that this exists only for backwards compatibility and it is
1063 preferred to use `*args`.
1064 name: (Optional.) A name for the tf.data operation.
1066 Returns:
1067 A new `Dataset` with the transformation applied as described above.
1068 """
1069 # Loaded lazily due to a circular dependency (dataset_ops -> zip_op ->
1070 # dataset_ops).
1071 # pylint: disable=g-import-not-at-top,protected-access
1072 from tensorflow.python.data.ops import zip_op
1074 if not args and datasets is None:
1075 raise TypeError("Must pass at least one dataset to `zip`.")
1076 if args and datasets is not None:
1077 raise TypeError("Both `*args` and `datasets` cannot be set.")
1078 if len(args) == 1:
1079 datasets = args[0]
1080 elif len(args) > 1:
1081 datasets = args
1082 return zip_op._zip(datasets, name)
1083 # pylint: enable=g-import-not-at-top,protected-access
1085 def concatenate(self, dataset, name=None):
1086 """Creates a `Dataset` by concatenating the given dataset with this dataset.
1088 >>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
1089 >>> b = tf.data.Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ]
1090 >>> ds = a.concatenate(b)
1091 >>> list(ds.as_numpy_iterator())
1092 [1, 2, 3, 4, 5, 6, 7]
1093 >>> # The input dataset and dataset to be concatenated should have
1094 >>> # compatible element specs.
1095 >>> c = tf.data.Dataset.zip((a, b))
1096 >>> a.concatenate(c)
1097 Traceback (most recent call last):
1098 TypeError: Two datasets to concatenate have different types
1099 <dtype: 'int64'> and (tf.int64, tf.int64)
1100 >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
1101 >>> a.concatenate(d)
1102 Traceback (most recent call last):
1103 TypeError: Two datasets to concatenate have different types
1104 <dtype: 'int64'> and <dtype: 'string'>
1106 Args:
1107 dataset: `Dataset` to be concatenated.
1108 name: (Optional.) A name for the tf.data operation.
1110 Returns:
1111 A new `Dataset` with the transformation applied as described above.
1112 """
1113 # Loaded lazily due to a circular dependency (dataset_ops ->
1114 # concatenate_op -> dataset_ops).
1115 # pylint: disable=g-import-not-at-top,protected-access
1116 from tensorflow.python.data.ops import concatenate_op
1117 return concatenate_op._concatenate(self, dataset, name)
1118 # pylint: enable=g-import-not-at-top,protected-access
1120 @staticmethod
1121 def counter(start=0, step=1, dtype=dtypes.int64, name=None):
1122 """Creates a `Dataset` that counts from `start` in steps of size `step`.
1124 Unlike `tf.data.Dataset.range`, which stops at some ending number,
1125 `tf.data.Dataset.counter` produces elements indefinitely.
1127 >>> dataset = tf.data.experimental.Counter().take(5)
1128 >>> list(dataset.as_numpy_iterator())
1129 [0, 1, 2, 3, 4]
1130 >>> dataset.element_spec
1131 TensorSpec(shape=(), dtype=tf.int64, name=None)
1132 >>> dataset = tf.data.experimental.Counter(dtype=tf.int32)
1133 >>> dataset.element_spec
1134 TensorSpec(shape=(), dtype=tf.int32, name=None)
1135 >>> dataset = tf.data.experimental.Counter(start=2).take(5)
1136 >>> list(dataset.as_numpy_iterator())
1137 [2, 3, 4, 5, 6]
1138 >>> dataset = tf.data.experimental.Counter(start=2, step=5).take(5)
1139 >>> list(dataset.as_numpy_iterator())
1140 [2, 7, 12, 17, 22]
1141 >>> dataset = tf.data.experimental.Counter(start=10, step=-1).take(5)
1142 >>> list(dataset.as_numpy_iterator())
1143 [10, 9, 8, 7, 6]
1145 Args:
1146 start: (Optional.) The starting value for the counter. Defaults to 0.
1147 step: (Optional.) The step size for the counter. Defaults to 1.
1148 dtype: (Optional.) The data type for counter elements. Defaults to
1149 `tf.int64`.
1150 name: (Optional.) A name for the tf.data operation.
1152 Returns:
1153 A `Dataset` of scalar `dtype` elements.
1154 """
1155 # Loaded lazily due to a circular dependency (dataset_ops -> counter_op
1156 # -> dataset_ops).
1157 # pylint: disable=g-import-not-at-top,protected-access
1158 from tensorflow.python.data.ops import counter_op
1159 return counter_op._counter(start, step, dtype, name=name)
1160 # pylint: enable=g-import-not-at-top,protected-access
1162 def rebatch(self, batch_size, drop_remainder=False, name=None):
1163 """Creates a `Dataset` that rebatches the elements from this dataset.
1165 `rebatch(N)` is functionally equivalent to `unbatch().batch(N)`, but is
1166 more efficient, performing one copy instead of two.
1168 >>> ds = tf.data.Dataset.range(6)
1169 >>> ds = ds.batch(2)
1170 >>> ds = ds.rebatch(3)
1171 >>> list(ds.as_numpy_iterator())
1172 [array([0, 1, 2]), array([3, 4, 5])]
1174 >>> ds = tf.data.Dataset.range(7)
1175 >>> ds = ds.batch(4)
1176 >>> ds = ds.rebatch(3)
1177 >>> list(ds.as_numpy_iterator())
1178 [array([0, 1, 2]), array([3, 4, 5]), array([6])]
1180 >>> ds = tf.data.Dataset.range(7)
1181 >>> ds = ds.batch(2)
1182 >>> ds = ds.rebatch(3, drop_remainder=True)
1183 >>> list(ds.as_numpy_iterator())
1184 [array([0, 1, 2]), array([3, 4, 5])]
1186 If the `batch_size` argument is a list, `rebatch` cycles through the list
1187 to determine the size of each batch.
1189 >>> ds = tf.data.Dataset.range(8)
1190 >>> ds = ds.batch(4)
1191 >>> ds = ds.rebatch([2, 1, 1])
1192 >>> list(ds.as_numpy_iterator())
1193 [array([0, 1]), array([2]), array([3]), array([4, 5]), array([6]),
1194 array([7])]
1196 Args:
1197 batch_size: A `tf.int64` scalar or vector, representing the size of
1198 batches to produce. If this argument is a vector, these values are
1199 cycled through in round robin fashion.
1200 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1201 whether the last batch should be dropped in the case it has fewer than
1202 `batch_size[cycle_index]` elements; the default behavior is not to drop
1203 the smaller batch.
1204 name: (Optional.) A name for the tf.data operation.
1206 Returns:
1207 A new `Dataset` with the transformation applied as described above.
1208 """
1209 # Loaded lazily due to a circular dependency (dataset_ops -> rebatch_op ->
1210 # dataset_ops).
1211 # pylint: disable=g-import-not-at-top,protected-access
1212 from tensorflow.python.data.ops import rebatch_op
1213 return rebatch_op._rebatch(self, batch_size, drop_remainder, name=name)
1214 # pylint: enable=g-import-not-at-top,protected-access
1216 def prefetch(self, buffer_size, name=None):
1217 """Creates a `Dataset` that prefetches elements from this dataset.
1219 Most dataset input pipelines should end with a call to `prefetch`. This
1220 allows later elements to be prepared while the current element is being
1221 processed. This often improves latency and throughput, at the cost of
1222 using additional memory to store prefetched elements.
1224 Note: Like other `Dataset` methods, prefetch operates on the
1225 elements of the input dataset. It has no concept of examples vs. batches.
1226 `examples.prefetch(2)` will prefetch two elements (2 examples),
1227 while `examples.batch(20).prefetch(2)` will prefetch 2 elements
1228 (2 batches, of 20 examples each).
1230 >>> dataset = tf.data.Dataset.range(3)
1231 >>> dataset = dataset.prefetch(2)
1232 >>> list(dataset.as_numpy_iterator())
1233 [0, 1, 2]
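Most input pipelines simply end with an autotuned prefetch; a minimal sketch
(assuming the usual `batch` transformation):

```python
dataset = tf.data.Dataset.range(100)
dataset = dataset.batch(20)                   # elements are now batches of 20
dataset = dataset.prefetch(tf.data.AUTOTUNE)  # buffer size is tuned dynamically
```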
1235 Args:
1236 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
1237 number of elements that will be buffered when prefetching. If the value
1238 `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.
1239 name: Optional. A name for the tf.data transformation.
1241 Returns:
1242 A new `Dataset` with the transformation applied as described above.
1243 """
1244 return prefetch_op._prefetch( # pylint: disable=protected-access
1245 self, buffer_size, name=name)
1247 @staticmethod
1248 def list_files(file_pattern, shuffle=None, seed=None, name=None):
1249 """A dataset of all files matching one or more glob patterns.
1251 The `file_pattern` argument should be a small number of glob patterns.
1252 If your filenames have already been globbed, use
1253 `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
1254 filename with `list_files` may result in poor performance with remote
1255 storage systems.
1257 Note: The default behavior of this method is to return filenames in
1258 a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
1259 to get results in a deterministic order.
1261 Example:
1262 If we had the following files on our filesystem:
1264 - /path/to/dir/a.txt
1265 - /path/to/dir/b.py
1266 - /path/to/dir/c.py
1268 If we pass "/path/to/dir/*.py" as the `file_pattern`, the dataset
1269 would produce:
1271 - /path/to/dir/b.py
1272 - /path/to/dir/c.py
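In code (a minimal sketch; the paths are illustrative):

```python
dataset = tf.data.Dataset.list_files("/path/to/dir/*.py", shuffle=False)
for path in dataset:
  print(path)  # e.g. tf.Tensor(b'/path/to/dir/b.py', shape=(), dtype=string)
```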
1274 Args:
1275 file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
1276 (scalar or vector), representing the filename glob (i.e. shell wildcard)
1277 pattern(s) that will be matched.
1278 shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
1279 Defaults to `True`.
1280 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1281 seed that will be used to create the distribution. See
1282 `tf.random.set_seed` for behavior.
1283 name: Optional. A name for the tf.data operations used by `list_files`.
1285 Returns:
1286 Dataset: A `Dataset` of strings corresponding to file names.
1287 """
1288 with ops.name_scope("list_files"):
1289 if shuffle is None:
1290 shuffle = True
1291 file_pattern = ops.convert_to_tensor(
1292 file_pattern, dtype=dtypes.string, name="file_pattern")
1293 matching_files = gen_io_ops.matching_files(file_pattern)
1295 # Raise an exception if `file_pattern` does not match any files.
1296 condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
1297 name="match_not_empty")
1299 message = math_ops.add(
1300 "No files matched pattern: ",
1301 string_ops.reduce_join(file_pattern, separator=", "), name="message")
1303 assert_not_empty = control_flow_assert.Assert(
1304 condition, [message], summarize=1, name="assert_not_empty")
1305 with ops.control_dependencies([assert_not_empty]):
1306 matching_files = array_ops.identity(matching_files)
1308 # TODO(b/240947712): Remove lazy import after this method is factored out.
1309 # Loaded lazily due to a circular dependency (dataset_ops ->
1310 # from_tensor_slices_op -> dataset_ops).
1311 # pylint: disable=g-import-not-at-top,protected-access
1312 from tensorflow.python.data.ops import from_tensor_slices_op
1313 dataset = from_tensor_slices_op._TensorSliceDataset(
1314 matching_files, is_files=True, name=name)
1315 # pylint: enable=g-import-not-at-top,protected-access
1316 if issubclass(Dataset, DatasetV1):
1317 dataset = DatasetV1Adapter(dataset)
1318 if shuffle:
1319 # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
1320 # list of files might be empty.
1321 buffer_size = math_ops.maximum(
1322 array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
1323 dataset = dataset.shuffle(buffer_size, seed=seed, name=name)
1324 return dataset
1326 def repeat(self, count=None, name=None):
1327 """Repeats this dataset so each original value is seen `count` times.
1329 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1330 >>> dataset = dataset.repeat(3)
1331 >>> list(dataset.as_numpy_iterator())
1332 [1, 2, 3, 1, 2, 3, 1, 2, 3]
1334 Note: If the input dataset depends on global state (e.g. a random number
1335 generator) or its output is non-deterministic (e.g. because of upstream
1336 `shuffle`), then different repetitions may produce different elements.
1338 Args:
1339 count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1340 number of times the dataset should be repeated. The default behavior (if
1341 `count` is `None` or `-1`) is for the dataset to be repeated indefinitely.
1342 name: (Optional.) A name for the tf.data operation.
1344 Returns:
1345 A new `Dataset` with the transformation applied as described above.
1346 """
1347 # Loaded lazily due to a circular dependency (dataset_ops -> repeat_op ->
1348 # dataset_ops).
1349 # pylint: disable=g-import-not-at-top,protected-access,redefined-outer-name
1350 from tensorflow.python.data.ops import repeat_op
1351 return repeat_op._repeat(self, count, name)
1352 # pylint: enable=g-import-not-at-top,protected-access,redefined-outer-name
1354 def enumerate(self, start=0, name=None):
1355 """Enumerates the elements of this dataset.
1357 It is similar to python's `enumerate`.
1359 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1360 >>> dataset = dataset.enumerate(start=5)
1361 >>> for element in dataset.as_numpy_iterator():
1362 ... print(element)
1363 (5, 1)
1364 (6, 2)
1365 (7, 3)
1367 >>> # The (nested) structure of the input dataset determines the
1368 >>> # structure of elements in the resulting dataset.
1369 >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
1370 >>> dataset = dataset.enumerate()
1371 >>> for element in dataset.as_numpy_iterator():
1372 ... print(element)
1373 (0, array([7, 8], dtype=int32))
1374 (1, array([ 9, 10], dtype=int32))
1376 Args:
1377 start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
1378 enumeration.
1379 name: Optional. A name for the tf.data operations used by `enumerate`.
1381 Returns:
1382 A new `Dataset` with the transformation applied as described above.
1383 """
1385 max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
1386 range_dataset = Dataset.range(start, max_value, name=name)
1387 # Replicate the range component so that each split is enumerated
1388 # independently. This avoids the need for prohibitively expensive
1389 # cross-split coordination.
1390 range_dataset = _apply_rewrite(range_dataset, "replicate_on_split")
1391 return Dataset.zip((range_dataset, self), name=name)
1393 def shuffle(self,
1394 buffer_size,
1395 seed=None,
1396 reshuffle_each_iteration=None,
1397 name=None):
1398 """Randomly shuffles the elements of this dataset.
1400 This dataset fills a buffer with `buffer_size` elements, then randomly
1401 samples elements from this buffer, replacing the selected elements with new
1402 elements. For perfect shuffling, a buffer size greater than or equal to the
1403 full size of the dataset is required.
1405 For instance, if your dataset contains 10,000 elements but `buffer_size` is
1406 set to 1,000, then `shuffle` will initially select a random element from
1407 only the first 1,000 elements in the buffer. Once an element is selected,
1408 its space in the buffer is replaced by the next (i.e. 1,001-st) element,
1409 maintaining the 1,000 element buffer.
1411 `reshuffle_each_iteration` controls whether the shuffle order should be
1412 different for each epoch. In TF 1.X, the idiomatic way to create epochs
1413 was through the `repeat` transformation:
1415 ```python
1416 dataset = tf.data.Dataset.range(3)
1417 dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1418 dataset = dataset.repeat(2)
1419 # [1, 0, 2, 1, 2, 0]
1421 dataset = tf.data.Dataset.range(3)
1422 dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1423 dataset = dataset.repeat(2)
1424 # [1, 0, 2, 1, 0, 2]
1425 ```
1427 In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
1428 possible to also create epochs through Python iteration:
1430 ```python
1431 dataset = tf.data.Dataset.range(3)
1432 dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1433 list(dataset.as_numpy_iterator())
1434 # [1, 0, 2]
1435 list(dataset.as_numpy_iterator())
1436 # [1, 2, 0]
1437 ```
1439 ```python
1440 dataset = tf.data.Dataset.range(3)
1441 dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1442 list(dataset.as_numpy_iterator())
1443 # [1, 0, 2]
1444 list(dataset.as_numpy_iterator())
1445 # [1, 0, 2]
1446 ```
1448 ### Fully shuffling all the data
1450 To shuffle an entire dataset, set `buffer_size=dataset.cardinality()`. This
1451 is equivalent to setting the `buffer_size` equal to the number of elements
1452 in the dataset, resulting in a uniform shuffle.
1454 Note: `shuffle(dataset.cardinality())` loads the full dataset into memory so
1455 that it can be shuffled. This will cause a memory overflow (OOM) error if
1456 the dataset is too large, so a full shuffle should only be used for datasets
1457 that are known to fit in memory, such as datasets of filenames or other
1458 small datasets.
1460 ```python
1461 dataset = tf.data.Dataset.range(20)
1462 dataset = dataset.shuffle(dataset.cardinality())
1463 # [18, 4, 9, 2, 17, 8, 5, 10, 0, 6, 16, 3, 19, 7, 14, 11, 15, 13, 12, 1]
1464 ```
1466 Args:
1467 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1468 elements from this dataset from which the new dataset will sample. To
1469 uniformly shuffle the entire dataset, use
1470 `buffer_size=dataset.cardinality()`.
1471 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1472 seed that will be used to create the distribution. See
1473 `tf.random.set_seed` for behavior.
1474 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
1475 that the dataset should be pseudorandomly reshuffled each time it is
1476 iterated over. (Defaults to `True`.)
1477 name: (Optional.) A name for the tf.data operation.
1479 Returns:
1480 A new `Dataset` with the transformation applied as described above.
1481 """
1482 return shuffle_op._shuffle( # pylint: disable=protected-access
1483 self, buffer_size, seed, reshuffle_each_iteration, name=name)
1485 def cache(self, filename="", name=None):
1486 """Caches the elements in this dataset.
1488 The first time the dataset is iterated over, its elements will be cached
1489 either in the specified file or in memory. Subsequent iterations will
1490 use the cached data.
1492 Note: To guarantee that the cache gets finalized, the input dataset must be
1493 iterated through in its entirety, until it raises `StopIteration`. Otherwise,
1494 subsequent iterations may not use cached data.
1496 >>> dataset = tf.data.Dataset.range(5)
1497 >>> dataset = dataset.map(lambda x: x**2)
1498 >>> dataset = dataset.cache()
1499 >>> # The first time reading through the data will generate the data using
1500 >>> # `range` and `map`.
1501 >>> list(dataset.as_numpy_iterator())
1502 [0, 1, 4, 9, 16]
1503 >>> # Subsequent iterations read from the cache.
1504 >>> list(dataset.as_numpy_iterator())
1505 [0, 1, 4, 9, 16]
1507 When caching to a file, the cached data will persist across runs. Even the
1508 first iteration through the data will read from the cache file. Changing
1509 the input pipeline before the call to `.cache()` will have no effect until
1510 the cache file is removed or the filename is changed.
1512 ```python
1513 dataset = tf.data.Dataset.range(5)
1514 dataset = dataset.cache("/path/to/file")
1515 list(dataset.as_numpy_iterator())
1516 # [0, 1, 2, 3, 4]
1517 dataset = tf.data.Dataset.range(10)
1518 dataset = dataset.cache("/path/to/file") # Same file!
1519 list(dataset.as_numpy_iterator())
1520 # [0, 1, 2, 3, 4]
1521 ```
1523 Note: `cache` will produce exactly the same elements during each iteration
1524 through the dataset. If you wish to randomize the iteration order, make sure
1525 to call `shuffle` *after* calling `cache`.
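For example, a minimal sketch of that ordering:

```python
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache()      # cache the deterministic part of the pipeline once
dataset = dataset.shuffle(10)  # reshuffle the cached elements on every iteration
```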
1527 Args:
1528 filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
1529 directory on the filesystem to use for caching elements in this Dataset.
1530 If a filename is not provided, the dataset will be cached in memory.
1531 name: (Optional.) A name for the tf.data operation.
1533 Returns:
1534 A new `Dataset` with the transformation applied as described above.
1535 """
1536 # Loaded lazily due to a circular dependency (dataset_ops -> cache_op ->
1537 # dataset_ops).
1538 # pylint: disable=g-import-not-at-top,protected-access
1539 from tensorflow.python.data.ops import cache_op
1540 return cache_op._cache(self, filename, name)
1541 # pylint: enable=g-import-not-at-top,protected-access
1543 def take(self, count, name=None):
1544 """Creates a `Dataset` with at most `count` elements from this dataset.
1546 >>> dataset = tf.data.Dataset.range(10)
1547 >>> dataset = dataset.take(3)
1548 >>> list(dataset.as_numpy_iterator())
1549 [0, 1, 2]
1551 Args:
1552 count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1553 elements of this dataset that should be taken to form the new dataset.
1554 If `count` is -1, or if `count` is greater than the size of this
1555 dataset, the new dataset will contain all elements of this dataset.
1556 name: (Optional.) A name for the tf.data operation.
1558 Returns:
1559 A new `Dataset` with the transformation applied as described above.
1560 """
1561 # Loaded lazily due to a circular dependency (dataset_ops ->
1562 # take_op -> dataset_ops).
1563 # pylint: disable=g-import-not-at-top,protected-access
1564 from tensorflow.python.data.ops import take_op
1565 return take_op._take(self, count, name=name)
1566 # pylint: enable=g-import-not-at-top,protected-access
1568 def skip(self, count, name=None):
1569 """Creates a `Dataset` that skips `count` elements from this dataset.
1571 >>> dataset = tf.data.Dataset.range(10)
1572 >>> dataset = dataset.skip(7)
1573 >>> list(dataset.as_numpy_iterator())
1574 [7, 8, 9]
1576 Args:
1577 count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1578 elements of this dataset that should be skipped to form the new dataset.
1579 If `count` is greater than the size of this dataset, the new dataset
1580 will contain no elements. If `count` is -1, skips the entire dataset.
1581 name: (Optional.) A name for the tf.data operation.
1583 Returns:
1584 A new `Dataset` with the transformation applied as described above.
1585 """
1586 # Loaded lazily due to a circular dependency (dataset_ops ->
1587 # skip_op -> dataset_ops).
1588 # pylint: disable=g-import-not-at-top,protected-access
1589 from tensorflow.python.data.ops import skip_op
1590 return skip_op._skip(self, count, name)
1591 # pylint: enable=g-import-not-at-top,protected-access
1593 def shard(self, num_shards, index, name=None):
1594 """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
1596 `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will
1597 contain all elements of `A` whose index mod `n` equals `i`.
1599 >>> A = tf.data.Dataset.range(10)
1600 >>> B = A.shard(num_shards=3, index=0)
1601 >>> list(B.as_numpy_iterator())
1602 [0, 3, 6, 9]
1603 >>> C = A.shard(num_shards=3, index=1)
1604 >>> list(C.as_numpy_iterator())
1605 [1, 4, 7]
1606 >>> D = A.shard(num_shards=3, index=2)
1607 >>> list(D.as_numpy_iterator())
1608 [2, 5, 8]
1610 This dataset operator is very useful when running distributed training, as
1611 it allows each worker to read a unique subset.
1613 When reading a single input file, you can shard elements as follows:
1615 ```python
1616 d = tf.data.TFRecordDataset(input_file)
1617 d = d.shard(num_workers, worker_index)
1618 d = d.repeat(num_epochs)
1619 d = d.shuffle(shuffle_buffer_size)
1620 d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1621 ```
1623 Important caveats:
1625 - Be sure to shard before you use any randomizing operator (such as
1626 shuffle).
1627 - Generally it is best if the shard operator is used early in the dataset
1628 pipeline. For example, when reading from a set of TFRecord files, shard
1629 before converting the dataset to input samples. This avoids reading every
1630 file on every worker. The following is an example of an efficient
1631 sharding strategy within a complete pipeline:
1633 ```python
1634 d = Dataset.list_files(pattern, shuffle=False)
1635 d = d.shard(num_workers, worker_index)
1636 d = d.repeat(num_epochs)
1637 d = d.shuffle(shuffle_buffer_size)
1638 d = d.interleave(tf.data.TFRecordDataset,
1639 cycle_length=num_readers, block_length=1)
1640 d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1641 ```
1643 Args:
1644 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
1645 shards operating in parallel.
1646 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
1647 name: (Optional.) A name for the tf.data operation.
1649 Returns:
1650 A new `Dataset` with the transformation applied as described above.
1652 Raises:
1653 InvalidArgumentError: if `num_shards` or `index` are illegal values.
1655 Note: error checking is done on a best-effort basis, and errors aren't
1656 guaranteed to be caught upon dataset creation. (e.g. passing in a
1657 placeholder tensor bypasses the early checking, and will instead result
1658 in an error during a `session.run` call.)
1659 """
1660 # pylint: disable=g-import-not-at-top,protected-access
1661 from tensorflow.python.data.ops import shard_op
1662 return shard_op._shard(self, num_shards, index, name=name)
1663 # pylint: enable=g-import-not-at-top,protected-access
1665 def save(self,
1666 path,
1667 compression=None,
1668 shard_func=None,
1669 checkpoint_args=None):
1670 """Saves the content of the given dataset.
1672 Example usage:
1674 >>> import tempfile
1675 >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
1676 >>> # Save a dataset
1677 >>> dataset = tf.data.Dataset.range(2)
1678 >>> dataset.save(path)
1679 >>> new_dataset = tf.data.Dataset.load(path)
1680 >>> for elem in new_dataset:
1681 ... print(elem)
1682 tf.Tensor(0, shape=(), dtype=int64)
1683 tf.Tensor(1, shape=(), dtype=int64)
1685 The saved dataset is saved in multiple file "shards". By default, the
1686 dataset output is divided into shards in a round-robin fashion, but custom
1687 sharding can be specified via the `shard_func` function. For example, you
1688 can save the dataset to a single shard as follows:
1690 ```python
1691 dataset = make_dataset()
1692 def custom_shard_func(element):
1693 return np.int64(0)
1694 dataset.save(
1695 path="/path/to/data", ..., shard_func=custom_shard_func)
1696 ```
1698 To enable checkpointing, pass in `checkpoint_args` to the `save` method
1699 as follows:
1701 ```python
1702 dataset = tf.data.Dataset.range(100)
1703 save_dir = "..."
1704 checkpoint_prefix = "..."
1705 step_counter = tf.Variable(0, trainable=False)
1706 checkpoint_args = {
1707 "checkpoint_interval": 50,
1708 "step_counter": step_counter,
1709 "directory": checkpoint_prefix,
1710 "max_to_keep": 20,
1711 }
1712 dataset.save(save_dir, checkpoint_args=checkpoint_args)
1713 ```
1715 NOTE: The directory layout and file format used for saving the dataset is
1716 considered an implementation detail and may change. For this reason,
1717 datasets saved through `tf.data.Dataset.save` should only be consumed
1718 through `tf.data.Dataset.load`, which is guaranteed to be
1719 backwards compatible.
1721 Args:
1722 path: Required. A directory to use for saving the dataset.
1723 compression: Optional. The algorithm to use to compress data when writing
1724 it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
1725 shard_func: Optional. A function to control the mapping of dataset
1726 elements to file shards. The function is expected to map elements of
1727 the input dataset to int64 shard IDs. If present, the function will be
1728 traced and executed as graph computation.
1729 checkpoint_args: Optional args for checkpointing which will be passed into
1730 the `tf.train.CheckpointManager`. If `checkpoint_args` are not
1731 specified, then checkpointing will not be performed. The `save()`
1732 implementation creates a `tf.train.Checkpoint` object internally, so
1733 users should not set the `checkpoint` argument in `checkpoint_args`.
1735 Returns:
1736 An operation which when executed performs the save. When writing
1737 checkpoints, returns None. The return value is useful in unit tests.
1739 Raises:
1740 ValueError: If `checkpoint` is passed into `checkpoint_args`.
1741 """
1742 # Loaded lazily due to a circular dependency (dataset_ops -> save_op ->
1743 # dataset_ops).
1744 # pylint: disable=g-import-not-at-top,protected-access
1745 from tensorflow.python.data.ops import save_op
1746 return save_op._save(self, path, compression, shard_func, checkpoint_args)
1747 # pylint: enable=g-import-not-at-top,protected-access
1749 @staticmethod
1750 def load(path, element_spec=None, compression=None, reader_func=None):
1751 """Loads a previously saved dataset.
1753 Example usage:
1755 >>> import tempfile
1756 >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
1757 >>> # Save a dataset
1758 >>> dataset = tf.data.Dataset.range(2)
1759 >>> tf.data.Dataset.save(dataset, path)
1760 >>> new_dataset = tf.data.Dataset.load(path)
1761 >>> for elem in new_dataset:
1762 ... print(elem)
1763 tf.Tensor(0, shape=(), dtype=int64)
1764 tf.Tensor(1, shape=(), dtype=int64)
1767 If the default option of sharding the saved dataset was used, the element
1768 order of the saved dataset will be preserved when loading it.
1770 The `reader_func` argument can be used to specify a custom order in which
1771 elements should be loaded from the individual shards. The `reader_func` is
1772 expected to take a single argument -- a dataset of datasets, each containing
1773 elements of one of the shards -- and return a dataset of elements. For
1774 example, the order of shards can be shuffled when loading them as follows:
1776 ```python
1777 def custom_reader_func(datasets):
1778 datasets = datasets.shuffle(NUM_SHARDS)
1779 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
1781 dataset = tf.data.Dataset.load(
1782 path="/path/to/data", ..., reader_func=custom_reader_func)
1783 ```
1785 Args:
1786 path: Required. A path pointing to a previously saved dataset.
1787 element_spec: Optional. A nested structure of `tf.TypeSpec` objects
1788 matching the structure of an element of the saved dataset and specifying
1789 the type of individual element components. If not provided, the nested
1790 structure of `tf.TypeSpec` saved with the saved dataset is used. Note
1791 that this argument is required in graph mode.
1792 compression: Optional. The algorithm to use to decompress the data when
1793 reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
1794 reader_func: Optional. A function to control how to read data from shards.
1795 If present, the function will be traced and executed as graph
1796 computation.
1798 Returns:
1799 A `tf.data.Dataset` instance.
1801 Raises:
1802 FileNotFoundError: If `element_spec` is not specified and the saved nested
1803 structure of `tf.TypeSpec` can not be located with the saved dataset.
1804 ValueError: If `element_spec` is not specified and the method is executed
1805 in graph mode.
1806 """
1807 # Loaded lazily due to a circular dependency (dataset_ops -> load_op ->
1808 # dataset_ops).
1809 # pylint: disable=g-import-not-at-top,protected-access
1810 from tensorflow.python.data.ops import load_op
1811 return load_op._load(
1812 path=path,
1813 element_spec=element_spec,
1814 compression=compression,
1815 reader_func=reader_func)
1816 # pylint: enable=g-import-not-at-top,protected-access
1818 def batch(self,
1819 batch_size,
1820 drop_remainder=False,
1821 num_parallel_calls=None,
1822 deterministic=None,
1823 name=None):
1824 """Combines consecutive elements of this dataset into batches.
1826 >>> dataset = tf.data.Dataset.range(8)
1827 >>> dataset = dataset.batch(3)
1828 >>> list(dataset.as_numpy_iterator())
1829 [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
1831 >>> dataset = tf.data.Dataset.range(8)
1832 >>> dataset = dataset.batch(3, drop_remainder=True)
1833 >>> list(dataset.as_numpy_iterator())
1834 [array([0, 1, 2]), array([3, 4, 5])]
1836 The components of the resulting element will have an additional outer
1837 dimension, which will be `batch_size` (or `N % batch_size` for the last
1838 element if `batch_size` does not divide the number of input elements `N`
1839 evenly and `drop_remainder` is `False`). If your program depends on the
1840 batches having the same outer dimension, you should set the `drop_remainder`
1841 argument to `True` to prevent the smaller batch from being produced.
1843 Note: If your program requires data to have a statically known shape (e.g.,
1844 when using XLA), you should use `drop_remainder=True`. Without
1845 `drop_remainder=True` the shape of the output dataset will have an unknown
1846 leading dimension due to the possibility of a smaller final batch.
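For illustration, a minimal sketch of how `drop_remainder` affects the static shape:

```python
dataset = tf.data.Dataset.range(8)
dataset.batch(3).element_spec
# TensorSpec(shape=(None,), dtype=tf.int64, name=None)
dataset.batch(3, drop_remainder=True).element_spec
# TensorSpec(shape=(3,), dtype=tf.int64, name=None)
```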
1848 Args:
1849 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1850 consecutive elements of this dataset to combine in a single batch.
1851 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1852 whether the last batch should be dropped in the case it has fewer than
1853 `batch_size` elements; the default behavior is not to drop the smaller
1854 batch.
1855 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
1856 representing the number of batches to compute asynchronously in
1857 parallel.
1858 If not specified, batches will be computed sequentially. If the value
1859 `tf.data.AUTOTUNE` is used, then the number of parallel
1860 calls is set dynamically based on available resources.
1861 deterministic: (Optional.) When `num_parallel_calls` is specified, if this
1862 boolean is specified (`True` or `False`), it controls the order in which
1863 the transformation produces elements. If set to `False`, the
1864 transformation is allowed to yield elements out of order to trade
1865 determinism for performance. If not specified, the
1866 `tf.data.Options.deterministic` option (`True` by default) controls the
1867 behavior.
1868 name: (Optional.) A name for the tf.data operation.
1870 Returns:
1871 A new `Dataset` with the transformation applied as described above.
1872 """
1873 # Loaded lazily due to a circular dependency (dataset_ops -> batch_op ->
1874 # dataset_ops).
1875 # pylint: disable=g-import-not-at-top,protected-access,redefined-outer-name
1876 from tensorflow.python.data.ops import batch_op
1877 return batch_op._batch(self, batch_size, drop_remainder, num_parallel_calls,
1878 deterministic, name)
1879 # pylint: enable=g-import-not-at-top,protected-access,redefined-outer-name
1881 def padded_batch(self,
1882 batch_size,
1883 padded_shapes=None,
1884 padding_values=None,
1885 drop_remainder=False,
1886 name=None):
1887 """Combines consecutive elements of this dataset into padded batches.
1889 This transformation combines multiple consecutive elements of the input
1890 dataset into a single element.
1892 Like `tf.data.Dataset.batch`, the components of the resulting element will
1893 have an additional outer dimension, which will be `batch_size` (or
1894 `N % batch_size` for the last element if `batch_size` does not divide the
1895 number of input elements `N` evenly and `drop_remainder` is `False`). If
1896 your program depends on the batches having the same outer dimension, you
1897 should set the `drop_remainder` argument to `True` to prevent the smaller
1898 batch from being produced.
1900 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
1901 different shapes, and this transformation will pad each component to the
1902 respective shape in `padded_shapes`. The `padded_shapes` argument
1903 determines the resulting shape for each dimension of each component in an
1904 output element:
1906 * If the dimension is a constant, the component will be padded out to that
1907 length in that dimension.
1908 * If the dimension is unknown, the component will be padded out to the
1909 maximum length of all elements in that dimension.
1911 >>> A = (tf.data.Dataset
1912 ... .range(1, 5, output_type=tf.int32)
1913 ... .map(lambda x: tf.fill([x], x)))
1914 >>> # Pad to the smallest per-batch size that fits all elements.
1915 >>> B = A.padded_batch(2)
1916 >>> for element in B.as_numpy_iterator():
1917 ... print(element)
1918 [[1 0]
1919 [2 2]]
1920 [[3 3 3 0]
1921 [4 4 4 4]]
1922 >>> # Pad to a fixed size.
1923 >>> C = A.padded_batch(2, padded_shapes=5)
1924 >>> for element in C.as_numpy_iterator():
1925 ... print(element)
1926 [[1 0 0 0 0]
1927 [2 2 0 0 0]]
1928 [[3 3 3 0 0]
1929 [4 4 4 4 0]]
1930 >>> # Pad with a custom value.
1931 >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
1932 >>> for element in D.as_numpy_iterator():
1933 ... print(element)
1934 [[ 1 -1 -1 -1 -1]
1935 [ 2 2 -1 -1 -1]]
1936 [[ 3 3 3 -1 -1]
1937 [ 4 4 4 4 -1]]
1938 >>> # Components of nested elements can be padded independently.
1939 >>> elements = [([1, 2, 3], [10]),
1940 ... ([4, 5], [11, 12])]
1941 >>> dataset = tf.data.Dataset.from_generator(
1942 ... lambda: iter(elements), (tf.int32, tf.int32))
1943 >>> # Pad the first component of the tuple to length 4, and the second
1944 >>> # component to the smallest size that fits.
1945 >>> dataset = dataset.padded_batch(2,
1946 ... padded_shapes=([4], [None]),
1947 ... padding_values=(-1, 100))
1948 >>> list(dataset.as_numpy_iterator())
1949 [(array([[ 1, 2, 3, -1], [ 4, 5, -1, -1]], dtype=int32),
1950 array([[ 10, 100], [ 11, 12]], dtype=int32))]
1951 >>> # Pad with a single value and multiple components.
1952 >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1)
1953 >>> for element in E.as_numpy_iterator():
1954 ... print(element)
1955 (array([[ 1, -1],
1956 [ 2, 2]], dtype=int32), array([[ 1, -1],
1957 [ 2, 2]], dtype=int32))
1958 (array([[ 3, 3, 3, -1],
1959 [ 4, 4, 4, 4]], dtype=int32), array([[ 3, 3, 3, -1],
1960 [ 4, 4, 4, 4]], dtype=int32))
1962 See also `tf.data.experimental.dense_to_sparse_batch`, which combines
1963 elements that may have different shapes into a `tf.sparse.SparseTensor`.
1965 Args:
1966 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1967 consecutive elements of this dataset to combine in a single batch.
1968 padded_shapes: (Optional.) A (nested) structure of `tf.TensorShape` or
1969 `tf.int64` vector tensor-like objects representing the shape to which
1970 the respective component of each input element should be padded prior
1971 to batching. Any unknown dimensions will be padded to the maximum size
1972 of that dimension in each batch. If unset, all dimensions of all
1973 components are padded to the maximum size in the batch. `padded_shapes`
1974 must be set if any component has an unknown rank.
1975 padding_values: (Optional.) A (nested) structure of scalar-shaped
1976 `tf.Tensor`, representing the padding values to use for the respective
1977 components. None represents that the (nested) structure should be padded
1978 with default values. Defaults are `0` for numeric types and the empty
1979 string for string types. The `padding_values` should have the same
1980 (nested) structure as the input dataset. If `padding_values` is a single
1981 element and the input dataset has multiple components, then the same
1982 `padding_values` will be used to pad every component of the dataset.
1983 If `padding_values` is a scalar, then its value will be broadcasted
1984 to match the shape of each component.
1985 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1986 whether the last batch should be dropped in the case it has fewer than
1987 `batch_size` elements; the default behavior is not to drop the smaller
1988 batch.
1989 name: (Optional.) A name for the tf.data operation.
1991 Returns:
1992 A new `Dataset` with the transformation applied as described above.
1994 Raises:
1995 ValueError: If a component has an unknown rank, and the `padded_shapes`
1996 argument is not set.
1997 TypeError: If a component is of an unsupported type. The list of supported
1998 types is documented in
1999 https://www.tensorflow.org/guide/data#dataset_structure.
2000 """
2001 # Loaded lazily due to a circular dependency (dataset_ops ->
2002 # padded_batch_op -> dataset_ops).
2003 # pylint: disable=g-import-not-at-top,protected-access
2004 from tensorflow.python.data.ops import padded_batch_op
2005 return padded_batch_op._padded_batch(self, batch_size, padded_shapes,
2006 padding_values, drop_remainder, name)
2007 # pylint: enable=g-import-not-at-top,protected-access
2009 def ragged_batch(self,
2010 batch_size,
2011 drop_remainder=False,
2012 row_splits_dtype=dtypes.int64,
2013 name=None):
2014 """Combines consecutive elements of this dataset into `tf.RaggedTensor`s.
2016 Like `tf.data.Dataset.batch`, the components of the resulting element will
2017 have an additional outer dimension, which will be `batch_size` (or
2018 `N % batch_size` for the last element if `batch_size` does not divide the
2019 number of input elements `N` evenly and `drop_remainder` is `False`). If
2020 your program depends on the batches having the same outer dimension, you
2021 should set the `drop_remainder` argument to `True` to prevent the smaller
2022 batch from being produced.
2024 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
2025 different shapes:
2027 * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
2028 fully defined, then it is batched as normal.
2029 * If an input element is a `tf.Tensor` whose static `tf.TensorShape`
2030 contains one or more axes with unknown size (i.e., `shape[i]=None`), then
2031 the output will contain a `tf.RaggedTensor` that is ragged up to any such
2032 dimension.
2033 * If an input element is a `tf.RaggedTensor` or any other type, then it is
2034 batched as normal.
2036 Example:
2038 >>> dataset = tf.data.Dataset.range(6)
2039 >>> dataset = dataset.map(lambda x: tf.range(x))
2040 >>> dataset.element_spec.shape
2041 TensorShape([None])
2042 >>> dataset = dataset.ragged_batch(2)
2043 >>> for batch in dataset:
2044 ... print(batch)
2045 <tf.RaggedTensor [[], [0]]>
2046 <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
2047 <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>
2049 Args:
2050 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
2051 consecutive elements of this dataset to combine in a single batch.
2052 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2053 whether the last batch should be dropped in the case it has fewer than
2054 `batch_size` elements; the default behavior is not to drop the smaller
2055 batch.
2056 row_splits_dtype: The dtype that should be used for the `row_splits` of
2057 any new ragged tensors. Existing `tf.RaggedTensor` elements do not have
2058 their row_splits dtype changed.
2059 name: (Optional.) A string indicating a name for the `tf.data` operation.
2061 Returns:
2062 A new `Dataset` with the transformation applied as described above.
2063 """
2064 # Loaded lazily due to a circular dependency (dataset_ops ->
2065 # ragged_batch_op -> dataset_ops).
2066 # pylint: disable=g-import-not-at-top,protected-access
2067 from tensorflow.python.data.ops import ragged_batch_op
2068 return ragged_batch_op._ragged_batch(self, batch_size, drop_remainder,
2069 row_splits_dtype, name)
2070 # pylint: enable=g-import-not-at-top,protected-access
2072 def sparse_batch(self, batch_size, row_shape, name=None):
2073 """Combines consecutive elements into `tf.sparse.SparseTensor`s.
2075 Like `Dataset.padded_batch()`, this transformation combines multiple
2076 consecutive elements of the dataset, which might have different
2077 shapes, into a single element. The resulting element has three
2078 components (`indices`, `values`, and `dense_shape`), which
2079 comprise a `tf.sparse.SparseTensor` that represents the same data. The
2080 `row_shape` represents the dense shape of each row in the
2081 resulting `tf.sparse.SparseTensor`, to which the effective batch size is
2082 prepended. For example:
2084 ```python
2085 # NOTE: The following examples use `{ ... }` to represent the
2086 # contents of a dataset.
2087 a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
2089 a.sparse_batch(
2090 batch_size=2, row_shape=[6]) ==
2091 {
2092 ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices
2093 ['a', 'b', 'c', 'a', 'b'], # values
2094 [2, 6]), # dense_shape
2095 ([[0, 0], [0, 1], [0, 2], [0, 3]],
2096 ['a', 'b', 'c', 'd'],
2097 [1, 6])
2098 }
2099 ```
2101 Args:
2102 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
2103 consecutive elements of this dataset to combine in a single batch.
2104 row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
2105 representing the equivalent dense shape of a row in the resulting
2106 `tf.sparse.SparseTensor`. Each element of this dataset must have the
2107 same rank as `row_shape`, and must have size less than or equal to
2108 `row_shape` in each dimension.
2109 name: (Optional.) A string indicating a name for the `tf.data` operation.
2111 Returns:
2112 A new `Dataset` with the transformation applied as described above.
2113 """
2114 # Loaded lazily due to a circular dependency (dataset_ops ->
2115 # sparse_batch_op -> dataset_ops).
2116 # pylint: disable=g-import-not-at-top,protected-access
2117 from tensorflow.python.data.ops import sparse_batch_op
2118 return sparse_batch_op._sparse_batch(self, batch_size, row_shape, name)
2119 # pylint: enable=g-import-not-at-top,protected-access
2121 def map(self,
2122 map_func,
2123 num_parallel_calls=None,
2124 deterministic=None,
2125 name=None):
2126 """Maps `map_func` across the elements of this dataset.
2128 This transformation applies `map_func` to each element of this dataset, and
2129 returns a new dataset containing the transformed elements, in the same
2130 order as they appeared in the input. `map_func` can be used to change both
2131 the values and the structure of a dataset's elements. Supported structure
2132 constructs are documented
2133 [here](https://www.tensorflow.org/guide/data#dataset_structure).
2135 For example, `map` can be used for adding 1 to each element, or projecting a
2136 subset of element components.
2138 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
2139 >>> dataset = dataset.map(lambda x: x + 1)
2140 >>> list(dataset.as_numpy_iterator())
2141 [2, 3, 4, 5, 6]
2143 The input signature of `map_func` is determined by the structure of each
2144 element in this dataset.
2146 >>> dataset = Dataset.range(5)
2147 >>> # `map_func` takes a single argument of type `tf.Tensor` with the same
2148 >>> # shape and dtype.
2149 >>> result = dataset.map(lambda x: x + 1)
2151 >>> # Each element is a tuple containing two `tf.Tensor` objects.
2152 >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")]
2153 >>> dataset = tf.data.Dataset.from_generator(
2154 ... lambda: elements, (tf.int32, tf.string))
2155 >>> # `map_func` takes two arguments of type `tf.Tensor`. This function
2156 >>> # projects out just the first component.
2157 >>> result = dataset.map(lambda x_int, y_str: x_int)
2158 >>> list(result.as_numpy_iterator())
2159 [1, 2, 3]
2161 >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects.
2162 >>> elements = ([{"a": 1, "b": "foo"},
2163 ... {"a": 2, "b": "bar"},
2164 ... {"a": 3, "b": "baz"}])
2165 >>> dataset = tf.data.Dataset.from_generator(
2166 ... lambda: elements, {"a": tf.int32, "b": tf.string})
2167 >>> # `map_func` takes a single argument of type `dict` with the same keys
2168 >>> # as the elements.
2169 >>> result = dataset.map(lambda d: str(d["a"]) + d["b"])
2171 The value or values returned by `map_func` determine the structure of each
2172 element in the returned dataset.
2174 >>> dataset = tf.data.Dataset.range(3)
2175 >>> # `map_func` returns two `tf.Tensor` objects.
2176 >>> def g(x):
2177 ... return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
2178 >>> result = dataset.map(g)
2179 >>> result.element_spec
2180 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \
2181dtype=tf.string, name=None))
2182 >>> # Python primitives, lists, and NumPy arrays are implicitly converted to
2183 >>> # `tf.Tensor`.
2184 >>> def h(x):
2185 ... return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
2186 >>> result = dataset.map(h)
2187 >>> result.element_spec
2188 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \
2189dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \
2190name=None))
2191 >>> # `map_func` can return nested structures.
2192 >>> def i(x):
2193 ... return (37.0, [42, 16]), "foo"
2194 >>> result = dataset.map(i)
2195 >>> result.element_spec
2196 ((TensorSpec(shape=(), dtype=tf.float32, name=None),
2197 TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
2198 TensorSpec(shape=(), dtype=tf.string, name=None))
2200 `map_func` can accept as arguments and return any type of dataset element.
2202 Note that irrespective of the context in which `map_func` is defined (eager
2203 vs. graph), tf.data traces the function and executes it as a graph. To use
2204 Python code inside of the function you have a few options:
2206 1) Rely on AutoGraph to convert Python code into an equivalent graph
2207 computation. The downside of this approach is that AutoGraph can convert
2208 some but not all Python code.
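For illustration, a minimal sketch of option 1 (the function name `map_fn` is arbitrary):

```python
def map_fn(x):
  # AutoGraph rewrites this data-dependent `if` into graph control flow.
  if x % 2 == 0:
    return x
  return -x

dataset = tf.data.Dataset.range(5).map(map_fn)
# [0, -1, 2, -3, 4]
```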
2210 2) Use `tf.py_function`, which allows you to write arbitrary Python code but
2211 will generally result in worse performance than 1). For example:
2213 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
2214 >>> # transform a string tensor to upper case string using a Python function
2215 >>> def upper_case_fn(t: tf.Tensor):
2216 ... return t.numpy().decode('utf-8').upper()
2217 >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn,
2218 ... inp=[x], Tout=tf.string))
2219 >>> list(d.as_numpy_iterator())
2220 [b'HELLO', b'WORLD']
2222 3) Use `tf.numpy_function`, which also allows you to write arbitrary
2223 Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas
2224 `tf.numpy_function` accepts numpy arrays and returns only numpy arrays.
2225 For example:
2227 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
2228 >>> def upper_case_fn(t: np.ndarray):
2229 ... return t.decode('utf-8').upper()
2230 >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn,
2231 ... inp=[x], Tout=tf.string))
2232 >>> list(d.as_numpy_iterator())
2233 [b'HELLO', b'WORLD']
2235 Note that the use of `tf.numpy_function` and `tf.py_function`
2236 in general precludes the possibility of executing user-defined
2237 transformations in parallel (because of the Python GIL).
2239 Performance can often be improved by setting `num_parallel_calls` so that
2240 `map` will use multiple threads to process elements. If deterministic order
2241 isn't required, it can also improve performance to set
2242 `deterministic=False`.
2244 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
2245 >>> dataset = dataset.map(lambda x: x + 1,
2246 ... num_parallel_calls=tf.data.AUTOTUNE,
2247 ... deterministic=False)
2249 The order of elements yielded by this transformation is deterministic if
2250 `deterministic=True`. If `map_func` contains stateful operations and
2251 `num_parallel_calls > 1`, the order in which that state is accessed is
2252 undefined, so the values of output elements may not be deterministic
2253 regardless of the `deterministic` flag value.
2255 Args:
2256 map_func: A function mapping a dataset element to another dataset element.
2257 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
2258 representing the number of elements to process asynchronously in parallel.
2259 If not specified, elements will be processed sequentially. If the value
2260 `tf.data.AUTOTUNE` is used, then the number of parallel
2261 calls is set dynamically based on available CPU.
2262 deterministic: (Optional.) When `num_parallel_calls` is specified, if this
2263 boolean is specified (`True` or `False`), it controls the order in which
2264 the transformation produces elements. If set to `False`, the
2265 transformation is allowed to yield elements out of order to trade
2266 determinism for performance. If not specified, the
2267 `tf.data.Options.deterministic` option (`True` by default) controls the
2268 behavior.
2269 name: (Optional.) A name for the tf.data operation.
2271 Returns:
2272 A new `Dataset` with the transformation applied as described above.
2273 """
2274 # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
2275 # dataset_ops).
2276 # pylint: disable=g-import-not-at-top,protected-access
2277 from tensorflow.python.data.ops import map_op
2278 return map_op._map_v2(
2279 self,
2280 map_func,
2281 num_parallel_calls=num_parallel_calls,
2282 deterministic=deterministic,
2283 name=name)
2284 # pylint: enable=g-import-not-at-top,protected-access
2286 def flat_map(self, map_func, name=None):
2287 """Maps `map_func` across this dataset and flattens the result.
2289 The type signature is:
2291 ```
2292 def flat_map(
2293 self: Dataset[T],
2294 map_func: Callable[[T], Dataset[S]]
2295 ) -> Dataset[S]
2296 ```
2298 Use `flat_map` if you want to make sure that the order of your dataset
2299 stays the same. For example, to flatten a dataset of batches into a
2300 dataset of their elements:
2302 >>> dataset = tf.data.Dataset.from_tensor_slices(
2303 ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2304 >>> dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)
2305 >>> list(dataset.as_numpy_iterator())
2306 [1, 2, 3, 4, 5, 6, 7, 8, 9]
2308 `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
2309 `flat_map` produces the same output as
2310 `tf.data.Dataset.interleave(cycle_length=1)`.
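For instance, a minimal sketch of that equivalence:

```python
a = tf.data.Dataset.range(3).flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(2))
b = tf.data.Dataset.range(3).interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(2), cycle_length=1)
# Both produce [0, 0, 1, 1, 2, 2].
```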
2312 Args:
2313 map_func: A function mapping a dataset element to a dataset.
2314 name: (Optional.) A name for the tf.data operation.
2316 Returns:
2317 A new `Dataset` with the transformation applied as described above.
2318 """
2319 # Loaded lazily due to a circular dependency (dataset_ops -> flat_map_op ->
2320 # dataset_ops).
2321 # pylint: disable=g-import-not-at-top,protected-access
2322 from tensorflow.python.data.ops import flat_map_op
2323 return flat_map_op._flat_map(self, map_func, name=name)
2324 # pylint: enable=g-import-not-at-top,protected-access
2326 def ignore_errors(self, log_warning=False, name=None):
2327 """Drops elements that cause errors.
2329 >>> dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
2330 >>> dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, ""))
2331 >>> list(dataset.as_numpy_iterator())
2332 Traceback (most recent call last):
2333 ...
2334 InvalidArgumentError: ... Tensor had Inf values
2335 >>> dataset = dataset.ignore_errors()
2336 >>> list(dataset.as_numpy_iterator())
2337 [1.0, 0.5, 0.25]
2339 Args:
2340 log_warning: (Optional.) A bool indicating whether or not ignored errors
2341 should be logged to stderr. Defaults to `False`.
2342 name: (Optional.) A string indicating a name for the `tf.data` operation.
2344 Returns:
2345 A new `Dataset` with the transformation applied as described above.
2346 """
2347 # Loaded lazily due to a circular dependency (dataset_ops ->
2348 # ignore_errors_op -> dataset_ops).
2349 # pylint: disable=g-import-not-at-top,protected-access
2350 from tensorflow.python.data.ops import ignore_errors_op
2351 return ignore_errors_op._ignore_errors(self, log_warning, name)
2352 # pylint: enable=g-import-not-at-top,protected-access
2354 def interleave(self,
2355 map_func,
2356 cycle_length=None,
2357 block_length=None,
2358 num_parallel_calls=None,
2359 deterministic=None,
2360 name=None):
2361 """Maps `map_func` across this dataset, and interleaves the results.
2363 The type signature is:
2365 ```
2366 def interleave(
2367 self: Dataset[T],
2368 map_func: Callable[[T], Dataset[S]]
2369 ) -> Dataset[S]
2370 ```
2372 For example, you can use `Dataset.interleave()` to process many input files
2373 concurrently:
2375 >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records
2376 >>> # from each file.
2377 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
2378 ... "/var/data/file3.txt", "/var/data/file4.txt"]
2379 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
2380 >>> def parse_fn(filename):
2381 ... return tf.data.Dataset.range(10)
2382 >>> dataset = dataset.interleave(lambda x:
2383 ... tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
2384 ... cycle_length=4, block_length=16)
2386 The `cycle_length` and `block_length` arguments control the order in which
2387 elements are produced. `cycle_length` controls the number of input elements
2388 that are processed concurrently. If you set `cycle_length` to 1, this
2389 transformation will handle one input element at a time, and will produce
2390 identical results to `tf.data.Dataset.flat_map`. In general,
2391 this transformation will apply `map_func` to `cycle_length` input elements,
2392 open iterators on the returned `Dataset` objects, and cycle through them
2393 producing `block_length` consecutive elements from each iterator, and
2394 consuming the next input element each time it reaches the end of an
2395 iterator.
2397 For example:
2399 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
2400 >>> # NOTE: New lines indicate "block" boundaries.
2401 >>> dataset = dataset.interleave(
2402 ... lambda x: Dataset.from_tensors(x).repeat(6),
2403 ... cycle_length=2, block_length=4)
2404 >>> list(dataset.as_numpy_iterator())
2405 [1, 1, 1, 1,
2406 2, 2, 2, 2,
2407 1, 1,
2408 2, 2,
2409 3, 3, 3, 3,
2410 4, 4, 4, 4,
2411 3, 3,
2412 4, 4,
2413 5, 5, 5, 5,
2414 5, 5]
2416 Note: The order of elements yielded by this transformation is
2417 deterministic, as long as `map_func` is a pure function and
2418 `deterministic=True`. If `map_func` contains any stateful operations, the
2419 order in which that state is accessed is undefined.
2421 Performance can often be improved by setting `num_parallel_calls` so that
2422 `interleave` will use multiple threads to fetch elements. If determinism
2423 isn't required, it can also improve performance to set
2424 `deterministic=False`.
2426 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
2427 ... "/var/data/file3.txt", "/var/data/file4.txt"]
2428 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
2429 >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
2430 ... cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE,
2431 ... deterministic=False)
2433 Args:
2434 map_func: A function that takes a dataset element and returns a
2435 `tf.data.Dataset`.
2436 cycle_length: (Optional.) The number of input elements that will be
2437 processed concurrently. If not set, the tf.data runtime decides what it
2438 should be based on available CPU. If `num_parallel_calls` is set to
2439 `tf.data.AUTOTUNE`, the `cycle_length` argument identifies
2440 the maximum degree of parallelism.
2441 block_length: (Optional.) The number of consecutive elements to produce
2442 from each input element before cycling to another input element. If not
2443 set, defaults to 1.
2444 num_parallel_calls: (Optional.) If specified, the implementation creates a
2445 threadpool, which is used to fetch inputs from cycle elements
2446 asynchronously and in parallel. The default behavior is to fetch inputs
2447 from cycle elements synchronously with no parallelism. If the value
2448 `tf.data.AUTOTUNE` is used, then the number of parallel
2449 calls is set dynamically based on available CPU.
2450 deterministic: (Optional.) When `num_parallel_calls` is specified, if this
2451 boolean is specified (`True` or `False`), it controls the order in which
2452 the transformation produces elements. If set to `False`, the
2453 transformation is allowed to yield elements out of order to trade
2454 determinism for performance. If not specified, the
2455 `tf.data.Options.deterministic` option (`True` by default) controls the
2456 behavior.
2457 name: (Optional.) A name for the tf.data operation.
2459 Returns:
2460 A new `Dataset` with the transformation applied as described above.
2461 """
2462 # Loaded lazily due to a circular dependency (
2463 # dataset_ops -> interleave_op -> dataset_ops).
2464 # pylint: disable=g-import-not-at-top,protected-access
2465 from tensorflow.python.data.ops import interleave_op
2466 return interleave_op._interleave(self, map_func, cycle_length, block_length,
2467 num_parallel_calls, deterministic, name)
2468 # pylint: enable=g-import-not-at-top,protected-access
2470 def filter(self, predicate, name=None):
2471 """Filters this dataset according to `predicate`.
2473 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
2474 >>> dataset = dataset.filter(lambda x: x < 3)
2475 >>> list(dataset.as_numpy_iterator())
2476 [1, 2]
2477 >>> # `tf.math.equal(x, y)` is required for equality comparison
2478 >>> def filter_fn(x):
2479 ... return tf.math.equal(x, 1)
2480 >>> dataset = dataset.filter(filter_fn)
2481 >>> list(dataset.as_numpy_iterator())
2482 [1]
2484 Args:
2485 predicate: A function mapping a dataset element to a boolean.
2486 name: (Optional.) A name for the tf.data operation.
2488 Returns:
2489 A new `Dataset` with the transformation applied as described above.
2490 """
2491 # Loaded lazily due to a circular dependency (dataset_ops -> filter_op ->
2492 # dataset_ops).
2493 # pylint: disable=g-import-not-at-top,protected-access
2494 from tensorflow.python.data.ops import filter_op
2495 return filter_op._filter(self, predicate, name)
2496 # pylint: enable=g-import-not-at-top,protected-access
2498 def apply(self, transformation_func):
2499 """Applies a transformation function to this dataset.
2501 `apply` enables chaining of custom `Dataset` transformations, which are
2502 represented as functions that take one `Dataset` argument and return a
2503 transformed `Dataset`.
2505 >>> dataset = tf.data.Dataset.range(100)
2506 >>> def dataset_fn(ds):
2507 ... return ds.filter(lambda x: x < 5)
2508 >>> dataset = dataset.apply(dataset_fn)
2509 >>> list(dataset.as_numpy_iterator())
2510 [0, 1, 2, 3, 4]
2512 Args:
2513 transformation_func: A function that takes one `Dataset` argument and
2514 returns a `Dataset`.
2516 Returns:
2517 A new `Dataset` with the transformation applied as described above.
2518 """
2519 dataset = transformation_func(self)
2520 if not isinstance(dataset, data_types.DatasetV2):
2521 raise TypeError(
2522 f"`transformation_func` must return a `tf.data.Dataset` object. "
2523 f"Got {type(dataset)}.")
2524 dataset._input_datasets = [self] # pylint: disable=protected-access
2525 return dataset
2527 def window(self, size, shift=None, stride=1, drop_remainder=False, name=None):
2528 """Returns a dataset of "windows".
2530 Each "window" is a dataset that contains a subset of elements of the
2531 input dataset. These are finite datasets of size `size` (or possibly fewer
2532 if there are not enough input elements to fill the window and
2533 `drop_remainder` evaluates to `False`).
2535 For example:
2537 >>> dataset = tf.data.Dataset.range(7).window(3)
2538 >>> for window in dataset:
2539 ... print(window)
2540 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
2541 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
2542 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
2544 Since windows are datasets, they can be iterated over:
2546 >>> for window in dataset:
2547 ... print(list(window.as_numpy_iterator()))
2548 [0, 1, 2]
2549 [3, 4, 5]
2550 [6]
2552 #### Shift
2554 The `shift` argument determines the number of input elements between the
2555 starts of consecutive windows. If windows and elements are both numbered
2556 starting at 0, the first element in window `k` will be element `k * shift`
2557 of the input dataset. In particular, the first element of the first window
2558 will always be the first element of the input dataset.
2560 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1,
2561 ... drop_remainder=True)
2562 >>> for window in dataset:
2563 ... print(list(window.as_numpy_iterator()))
2564 [0, 1, 2]
2565 [1, 2, 3]
2566 [2, 3, 4]
2567 [3, 4, 5]
2568 [4, 5, 6]
2570 #### Stride
2572 The `stride` argument determines the stride between input elements within a
2573 window.
2575 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2,
2576 ... drop_remainder=True)
2577 >>> for window in dataset:
2578 ... print(list(window.as_numpy_iterator()))
2579 [0, 2, 4]
2580 [1, 3, 5]
2581 [2, 4, 6]
2583 #### Nested elements
2585 When the `window` transformation is applied to a dataset whose elements are
2586 nested structures, it produces a dataset where the elements have the same
2587 nested structure but each leaf is replaced by a window. In other words,
2588 the nesting is applied outside of the windows, as opposed to inside of them.
2590 The type signature is:
2592 ```
2593 def window(
2594 self: Dataset[Nest[T]], ...
2595 ) -> Dataset[Nest[Dataset[T]]]
2596 ```
2598 Applying `window` to a `Dataset` of tuples gives a tuple of windows:
2600 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5],
2601 ... [6, 7, 8, 9, 10]))
2602 >>> dataset = dataset.window(2)
2603 >>> windows = next(iter(dataset))
2604 >>> windows
2605 (<...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>,
2606 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>)
2608 >>> def to_numpy(ds):
2609 ... return list(ds.as_numpy_iterator())
2610 >>>
2611 >>> for windows in dataset:
2612 ... print(to_numpy(windows[0]), to_numpy(windows[1]))
2613 [1, 2] [6, 7]
2614 [3, 4] [8, 9]
2615 [5] [10]
2617 Applying `window` to a `Dataset` of dictionaries gives a dictionary of
2618 `Datasets`:
2620 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3],
2621 ... 'b': [4, 5, 6],
2622 ... 'c': [7, 8, 9]})
2623 >>> dataset = dataset.window(2)
2624 >>> def to_numpy(ds):
2625 ... return list(ds.as_numpy_iterator())
2626 >>>
2627 >>> for windows in dataset:
2628 ... print(tf.nest.map_structure(to_numpy, windows))
2629 {'a': [1, 2], 'b': [4, 5], 'c': [7, 8]}
2630 {'a': [3], 'b': [6], 'c': [9]}
2632 #### Flatten a dataset of windows
2634 The `Dataset.flat_map` and `Dataset.interleave` methods can be used to
2635 flatten a dataset of windows into a single dataset.
2637 The argument to `flat_map` is a function that takes an element from the
2638 dataset and returns a `Dataset`. `flat_map` chains together the resulting
2639 datasets sequentially.
2641 For example, to turn each window into a dense tensor:
2643 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1,
2644 ... drop_remainder=True)
2645 >>> batched = dataset.flat_map(lambda x:x.batch(3))
2646 >>> for batch in batched:
2647 ... print(batch.numpy())
2648 [0 1 2]
2649 [1 2 3]
2650 [2 3 4]
2651 [3 4 5]
2652 [4 5 6]
2654 Args:
2655 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
2656 of the input dataset to combine into a window. Must be positive.
2657 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2658 number of input elements by which the window moves in each iteration.
2659 Defaults to `size`. Must be positive.
2660 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2661 stride of the input elements in the sliding window. Must be positive.
2662 The default value of 1 means "retain every input element".
2663 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2664 whether the last windows should be dropped if their size is smaller than
2665 `size`.
2666 name: (Optional.) A name for the tf.data operation.
2668 Returns:
2669 A new `Dataset` with the transformation applied as described above.
2670 """
2671 # Loaded lazily due to a circular dependency (dataset_ops -> window_op ->
2672 # dataset_ops).
2673 # pylint: disable=g-import-not-at-top,protected-access
2674 from tensorflow.python.data.ops import window_op
2675 return window_op._window(self, size, shift, stride, drop_remainder, name)
2676 # pylint: enable=g-import-not-at-top,protected-access
2678 def reduce(self, initial_state, reduce_func, name=None):
2679 """Reduces the input dataset to a single element.
2681 The transformation calls `reduce_func` successively on every element of
2682 the input dataset until the dataset is exhausted, aggregating information in
2683 its internal state. The `initial_state` argument is used for the initial
2684 state and the final state is returned as the result.
2686 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
2687 5
2688 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
2689 10
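The state may be any (nested) structure of values; for illustration, a minimal sketch that tracks a running count and sum in a single pass:

```python
count, total = tf.data.Dataset.range(5).reduce(
    (np.int64(0), np.int64(0)),
    lambda state, x: (state[0] + 1, state[1] + x))
# count.numpy() == 5, total.numpy() == 10
```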
2691 Args:
2692 initial_state: An element representing the initial state of the
2693 transformation.
2694 reduce_func: A function that maps `(old_state, input_element)` to
2695 `new_state`. It must take two arguments and return a new state.
2696 The structure of `new_state` must match the structure of
2697 `initial_state`.
2698 name: (Optional.) A name for the tf.data operation.
2700 Returns:
2701 A dataset element corresponding to the final state of the transformation.
2703 """
2705 with ops.name_scope("initial_state"):
2706 initial_state = structure.normalize_element(initial_state)
2707 state_structure = structure.type_spec_from_value(initial_state)
2709 # Iteratively rerun the reduce function until reaching a fixed point on
2710 # `state_structure`.
2711 need_to_rerun = True
2712 while need_to_rerun:
2714 wrapped_func = structured_function.StructuredFunctionWrapper(
2715 reduce_func,
2716 "reduce()",
2717 input_structure=(state_structure, self.element_spec),
2718 add_to_graph=False)
2720 # Extract and validate class information from the returned values.
2721 output_classes = wrapped_func.output_classes
2722 state_classes = nest.map_structure(
2723 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
2724 state_structure)
2725 for new_state_class, state_class in zip(
2726 nest.flatten(output_classes), nest.flatten(state_classes)):
2727 if not issubclass(new_state_class, state_class):
2728 raise TypeError(
2729 f"The element classes for the new state must match the initial "
2730 f"state. Expected {state_classes} but got "
2731 f"{wrapped_func.output_classes}.")
2733 # Extract and validate type information from the returned values.
2734 output_types = wrapped_func.output_types
2735 state_types = nest.map_structure(
2736 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
2737 state_structure)
2738 for new_state_type, state_type in zip(
2739 nest.flatten(output_types), nest.flatten(state_types)):
2740 if new_state_type != state_type:
2741 raise TypeError(
2742 f"The element types for the new state must match the initial "
2743 f"state. Expected {state_types} but got "
2744 f"{wrapped_func.output_types}.")
2746 # Extract shape information from the returned values.
2747 output_shapes = wrapped_func.output_shapes
2748 state_shapes = nest.map_structure(
2749 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
2750 state_structure)
2751 flat_state_shapes = nest.flatten(state_shapes)
2752 flat_new_state_shapes = nest.flatten(output_shapes)
2753 weakened_state_shapes = [
2754 original.most_specific_compatible_shape(new)
2755 for original, new in zip(flat_state_shapes, flat_new_state_shapes)
2756 ]
2758 need_to_rerun = False
2759 for original_shape, weakened_shape in zip(flat_state_shapes,
2760 weakened_state_shapes):
2761 if original_shape.ndims is not None and (
2762 weakened_shape.ndims is None or
2763 original_shape.as_list() != weakened_shape.as_list()):
2764 need_to_rerun = True
2765 break
2767 if need_to_rerun:
2768 # TODO(b/110122868): Support a "most specific compatible structure"
2769 # method for combining structures, to avoid using legacy structures
2770 # here.
2771 state_structure = structure.convert_legacy_structure(
2772 state_types,
2773 nest.pack_sequence_as(state_shapes, weakened_state_shapes),
2774 state_classes)
2776 reduce_func = wrapped_func.function
2777 reduce_func.add_to_graph(ops.get_default_graph())
2779 dataset = self._apply_debug_options()
2781 # pylint: disable=protected-access
2782 metadata = dataset_metadata_pb2.Metadata()
2783 if name:
2784 metadata.name = _validate_and_encode(name)
2785 return structure.from_compatible_tensor_list(
2786 state_structure,
2787 gen_dataset_ops.reduce_dataset(
2788 dataset._variant_tensor,
2789 structure.to_tensor_list(state_structure, initial_state),
2790 reduce_func.captured_inputs,
2791 f=reduce_func,
2792 output_shapes=structure.get_flat_tensor_shapes(state_structure),
2793 output_types=structure.get_flat_tensor_types(state_structure),
2794 metadata=metadata.SerializeToString()))
2796 def get_single_element(self, name=None):
2797 """Returns the single element of the `dataset`.
2799 The function enables you to use a `tf.data.Dataset` in a stateless
2800 "tensor-in tensor-out" expression, without creating an iterator.
2801 This makes it easy to perform data transformations on tensors using the
2802 optimized `tf.data.Dataset` abstraction on top of them.
2804 For example, let's consider a `preprocessing_fn` which takes the raw
2805 features as input and returns the processed feature along with
2806 its label.
2808 ```python
2809 def preprocessing_fn(raw_feature):
2810 # ... the raw_feature is preprocessed as per the use-case
2811 return feature
2813 raw_features = ... # input batch of BATCH_SIZE elements.
2814 dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
2815 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2816 .batch(BATCH_SIZE))
2818 processed_features = dataset.get_single_element()
2819 ```
2821 In the above example, the `raw_features` tensor of length=BATCH_SIZE
2822 was converted to a `tf.data.Dataset`. Next, each `raw_feature` was
2823 mapped using the `preprocessing_fn` and the processed features were
2824 grouped into a single batch. The final `dataset` contains only one element
2825 which is a batch of all the processed features.
2827 NOTE: The `dataset` should contain only one element.
2829 Now, instead of creating an iterator for the `dataset` and retrieving the
2830 batch of features, the `tf.data.Dataset.get_single_element()` method is used
2831 to skip the iterator creation process and directly output the batch of
2832 features.
2834 This can be particularly useful when your tensor transformations are
2835 expressed as `tf.data.Dataset` operations, and you want to use those
2836 transformations while serving your model.
2838 #### Keras
2840 ```python
2842 model = ... # A pre-built or custom model
2844 class PreprocessingModel(tf.keras.Model):
2845 def __init__(self, model):
2846 super().__init__()
2847 self.model = model
2849 @tf.function(input_signature=[...])
2850 def serving_fn(self, data):
2851 ds = tf.data.Dataset.from_tensor_slices(data)
2852 ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2853 ds = ds.batch(batch_size=BATCH_SIZE)
2854 return tf.argmax(self.model(ds.get_single_element()), axis=-1)
2856 preprocessing_model = PreprocessingModel(model)
2857 your_exported_model_dir = ... # save the model to this path.
2858 tf.saved_model.save(preprocessing_model, your_exported_model_dir,
2859 signatures={'serving_default': preprocessing_model.serving_fn}
2860 )
2861 ```
2863 #### Estimator
2865 In the case of estimators, you generally need to define a `serving_input_fn`
2866 which processes the features before they are passed to the model
2867 at inference time.
2869 ```python
2870 def serving_input_fn():
2872 raw_feature_spec = ... # Spec for the raw_features
2873 input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
2874 raw_feature_spec, default_batch_size=None)
2876 serving_input_receiver = input_fn()
2877 raw_features = serving_input_receiver.features
2879 def preprocessing_fn(raw_feature):
2880 # ... the raw_feature is preprocessed as per the use-case
2881 return feature
2883 dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
2884 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2885 .batch(BATCH_SIZE))
2887 processed_features = dataset.get_single_element()
2889 # Please note that the value of `BATCH_SIZE` should be equal to
2890 # the size of the leading dimension of `raw_features`. This ensures
2891 # that `dataset` has only one element, which is a prerequisite for
2892 # using `dataset.get_single_element()`.
2894 return tf.estimator.export.ServingInputReceiver(
2895 processed_features, serving_input_receiver.receiver_tensors)
2897 estimator = ... # A pre-built or custom estimator
2898 estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
2899 ```
2901 Args:
2902 name: (Optional.) A name for the tf.data operation.
2904 Returns:
2905 A nested structure of `tf.Tensor` objects, corresponding to the single
2906 element of `dataset`.
2908 Raises:
2909 InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
2910 one element.
2911 """
2913 metadata = dataset_metadata_pb2.Metadata()
2914 if name:
2915 metadata.name = _validate_and_encode(name)
2916 return structure.from_compatible_tensor_list(
2917 self.element_spec,
2918 gen_dataset_ops.dataset_to_single_element(
2919 self._variant_tensor,
2920 metadata=metadata.SerializeToString(),
2921 **self._flat_structure)) # pylint: disable=protected-access
2923 def unbatch(self, name=None):
2924 """Splits elements of a dataset into multiple elements.
2926 For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
2927 where `B` may vary for each input element, then for each element in the
2928 dataset, the unbatched dataset will contain `B` consecutive elements
2929 of shape `[a0, a1, ...]`.
2931 >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
2932 >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
2933 >>> dataset = dataset.unbatch()
2934 >>> list(dataset.as_numpy_iterator())
2935 [1, 2, 3, 1, 2, 1, 2, 3, 4]
2937 Note: `unbatch` requires a data copy to slice up the batched tensor into
2938 smaller, unbatched tensors. When optimizing performance, try to avoid
2939 unnecessary usage of `unbatch`.
2941 Args:
2942 name: (Optional.) A name for the tf.data operation.
2944 Returns:
2945 A new `Dataset` with the transformation applied as described above.
2946 """
2947 # Loaded lazily due to a circular dependency (
2948 # dataset_ops -> unbatch_op -> dataset_ops).
2949 # pylint: disable=g-import-not-at-top,protected-access
2950 from tensorflow.python.data.ops import unbatch_op
2951 return unbatch_op._unbatch(self, name=name)
2952 # pylint: enable=g-import-not-at-top,protected-access
2954 def with_options(self, options, name=None):
2955 """Returns a new `tf.data.Dataset` with the given options set.
2957 The options are "global" in the sense they apply to the entire dataset.
2958 If options are set multiple times, they are merged as long as they do not
2959 set the same option to different non-default values.
2961 >>> ds = tf.data.Dataset.range(5)
2962 >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
2963 ... cycle_length=3,
2964 ... num_parallel_calls=3)
2965 >>> options = tf.data.Options()
2966 >>> # This will make the interleave order non-deterministic.
2967 >>> options.deterministic = False
2968 >>> ds = ds.with_options(options)
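Options set on separate calls compose as long as they touch different
fields; a minimal sketch (assuming the two settings below do not conflict):

```python
options1 = tf.data.Options()
options1.deterministic = False
options2 = tf.data.Options()
options2.threading.private_threadpool_size = 4
ds = ds.with_options(options1)
ds = ds.with_options(options2)  # both settings remain in effect
```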
2970 Args:
2971 options: A `tf.data.Options` that identifies the options to use.
2972 name: (Optional.) A name for the tf.data operation.
2974 Returns:
2975 A new `Dataset` with the transformation applied as described above.
2977 Raises:
2978 ValueError: when an option is set more than once to a non-default value.
2979 """
2980 return _OptionsDataset(self, options, name=name)
2982 def cardinality(self):
2983 """Returns the cardinality of the dataset, if known.
2985 `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
2986 contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
2987 the analysis fails to determine the number of elements in the dataset
2988 (e.g. when the dataset source is a file).
2990 >>> dataset = tf.data.Dataset.range(42)
2991 >>> print(dataset.cardinality().numpy())
2992 42
2993 >>> dataset = dataset.repeat()
2994 >>> cardinality = dataset.cardinality()
2995 >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
2996 True
2997 >>> dataset = dataset.filter(lambda x: True)
2998 >>> cardinality = dataset.cardinality()
2999 >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
3000 True
3002 Returns:
3003 A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
3004 If the cardinality is infinite or unknown, `cardinality` returns the
3005 named constants `tf.data.INFINITE_CARDINALITY` and
3006 `tf.data.UNKNOWN_CARDINALITY` respectively.
3007 """
3008 return gen_dataset_ops.dataset_cardinality(self._variant_tensor)
3010 def group_by_window(self,
3011 key_func,
3012 reduce_func,
3013 window_size=None,
3014 window_size_func=None,
3015 name=None):
3016 """Groups windows of elements by key and reduces them.
3018 This transformation maps each consecutive element in a dataset to a key
3019 using `key_func` and groups the elements by key. It then applies
3020 `reduce_func` to at most `window_size_func(key)` elements matching the same
3021 key. All except the final window for each key will contain
3022 `window_size_func(key)` elements; the final window may be smaller.
3024 You may provide either a constant `window_size` or a window size determined
3025 by the key through `window_size_func`.
3027 >>> dataset = tf.data.Dataset.range(10)
3028 >>> window_size = 5
3029 >>> key_func = lambda x: x%2
3030 >>> reduce_func = lambda key, dataset: dataset.batch(window_size)
3031 >>> dataset = dataset.group_by_window(
3032 ... key_func=key_func,
3033 ... reduce_func=reduce_func,
3034 ... window_size=window_size)
3035 >>> for elem in dataset.as_numpy_iterator():
3036 ... print(elem)
3037 [0 2 4 6 8]
3038 [1 3 5 7 9]
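A window size that depends on the key can be supplied through
`window_size_func`; a minimal sketch (the per-key sizes below are chosen
purely for illustration):

```python
dataset = tf.data.Dataset.range(10)
window_size_func = lambda key: 2 * key + 2  # key 0 -> windows of 2, key 1 -> windows of 4
dataset = dataset.group_by_window(
    key_func=lambda x: x % 2,
    reduce_func=lambda key, ds: ds.batch(window_size_func(key)),
    window_size_func=window_size_func)
```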
3040 Args:
3041 key_func: A function mapping a nested structure of tensors (having shapes
3042 and types defined by `self.output_shapes` and `self.output_types`) to a
3043 scalar `tf.int64` tensor.
3044 reduce_func: A function mapping a key and a dataset of up to `window_size`
3045 consecutive elements matching that key to another dataset.
3046 window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
3047 consecutive elements matching the same key to combine in a single batch,
3048 which will be passed to `reduce_func`. Mutually exclusive with
3049 `window_size_func`.
3050 window_size_func: A function mapping a key to a `tf.int64` scalar
3051 `tf.Tensor`, representing the number of consecutive elements matching
3052 the same key to combine in a single batch, which will be passed to
3053 `reduce_func`. Mutually exclusive with `window_size`.
3054 name: (Optional.) A name for the tf.data operation.
3056 Returns:
3057 A new `Dataset` with the transformation applied as described above.
3059 Raises:
3060 ValueError: if neither or both of {`window_size`, `window_size_func`} are
3061 passed.
3062 """
3063 # Loaded lazily due to a circular dependency (
3064 # dataset_ops -> group_by_window_op -> dataset_ops).
3065 # pylint: disable=g-import-not-at-top,protected-access
3066 from tensorflow.python.data.ops import group_by_window_op
3067 return group_by_window_op._group_by_window(
3068 self, key_func, reduce_func, window_size, window_size_func, name=name)
3069 # pylint: enable=g-import-not-at-top,protected-access
3071 def bucket_by_sequence_length(self,
3072 element_length_func,
3073 bucket_boundaries,
3074 bucket_batch_sizes,
3075 padded_shapes=None,
3076 padding_values=None,
3077 pad_to_bucket_boundary=False,
3078 no_padding=False,
3079 drop_remainder=False,
3080 name=None):
3081 """A transformation that buckets elements in a `Dataset` by length.
3083 Elements of the `Dataset` are grouped together by length and then are padded
3084 and batched.
3086 This is useful for sequence tasks in which the elements have variable
3087 length. Grouping together elements that have similar lengths reduces the
3088 total fraction of padding in a batch, which increases training step
3089 efficiency.
3091 Below is an example that bucketizes the input data into the 3 buckets
3092 "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2.
3094 >>> elements = [
3095 ... [0], [1, 2, 3, 4], [5, 6, 7],
3096 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
3097 >>> dataset = tf.data.Dataset.from_generator(
3098 ... lambda: elements, tf.int64, output_shapes=[None])
3099 >>> dataset = dataset.bucket_by_sequence_length(
3100 ... element_length_func=lambda elem: tf.shape(elem)[0],
3101 ... bucket_boundaries=[3, 5],
3102 ... bucket_batch_sizes=[2, 2, 2])
3103 >>> for elem in dataset.as_numpy_iterator():
3104 ... print(elem)
3105 [[1 2 3 4]
3106 [5 6 7 0]]
3107 [[ 7 8 9 10 11 0]
3108 [13 14 15 16 19 20]]
3109 [[ 0 0]
3110 [21 22]]
3112 Args:
3113 element_length_func: a function from an element of the `Dataset` to
3114 `tf.int32`; it returns the length of the element, which determines the
3115 bucket it goes into.
3116 bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
3117 bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
3118 `len(bucket_boundaries) + 1`.
3119 padded_shapes: Nested structure of `tf.TensorShape` to pass to
3120 `tf.data.Dataset.padded_batch`. If not provided, will use
3121 `dataset.output_shapes`, which will result in variable length dimensions
3122 being padded out to the maximum length in each batch.
3123 padding_values: Values to pad with, passed to
3124 `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
3125 pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
3126 size to maximum length in batch. If `True`, will pad dimensions with
3127 unknown size to bucket boundary minus 1 (i.e., the maximum length in
3128 each bucket), and caller must ensure that the source `Dataset` does not
3129 contain any elements with length of `max(bucket_boundaries)` or longer.
3130 no_padding: `bool`, if `True`, the features are batched without padding
3131 (they then need to be either of type `tf.sparse.SparseTensor` or of same shape).
3132 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
3133 whether the last batch should be dropped in the case it has fewer than
3134 `batch_size` elements; the default behavior is not to drop the smaller
3135 batch.
3136 name: (Optional.) A name for the tf.data operation.
3138 Returns:
3139 A new `Dataset` with the transformation applied as described above.
3141 Raises:
3142 ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
3143 """
3144 if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
3145 raise ValueError(
3146 f"`len(bucket_batch_sizes)` must equal `len(bucket_boundaries) + 1` "
3147 f"but `len(bucket_batch_sizes)={len(bucket_batch_sizes)}` and "
3148 f"`len(bucket_boundaries)={len(bucket_boundaries)}`.")
3150 batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
3152 def element_to_bucket_id(*args):
3153 """Return int64 id of the length bucket for this element."""
3154 seq_length = element_length_func(*args)
3156 boundaries = list(bucket_boundaries)
3157 buckets_min = [np.iinfo(np.int32).min] + boundaries
3158 buckets_max = boundaries + [np.iinfo(np.int32).max]
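# An element belongs to bucket `i` iff
# buckets_min[i] <= seq_length < buckets_max[i]; `reduce_min(where(...))`
# below returns the index of that (single) matching bucket.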
3159 conditions_c = math_ops.logical_and(
3160 math_ops.less_equal(buckets_min, seq_length),
3161 math_ops.less(seq_length, buckets_max))
3162 bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
3164 return bucket_id
3166 def window_size_fn(bucket_id):
3167 # The window size is set to the batch size for this bucket
3168 window_size = batch_sizes[bucket_id]
3169 return window_size
3171 def make_padded_shapes(shapes, none_filler=None):
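# Replaces unknown (None) dimensions with `none_filler`, so that
# `padded_batch` pads those dimensions to the bucket boundary when
# `pad_to_bucket_boundary=True` (otherwise they remain None and are padded
# to the longest element in the batch).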
3172 padded = []
3173 for shape in nest.flatten(shapes):
3174 shape = tensor_shape.TensorShape(shape)
3175 shape = [
3176 none_filler if tensor_shape.dimension_value(d) is None else d
3177 for d in shape
3178 ]
3179 padded.append(shape)
3180 return nest.pack_sequence_as(shapes, padded)
3182 def batching_fn(bucket_id, grouped_dataset):
3183 """Batch elements in dataset."""
3184 batch_size = window_size_fn(bucket_id)
3185 if no_padding:
3186 return grouped_dataset.batch(
3187 batch_size, drop_remainder=drop_remainder, name=name)
3188 none_filler = None
3189 if pad_to_bucket_boundary:
3190 err_msg = ("When pad_to_bucket_boundary=True, elements must have "
3191 "length < max(bucket_boundaries).")
3192 check = check_ops.assert_less(
3193 bucket_id,
3194 constant_op.constant(
3195 len(bucket_batch_sizes) - 1, dtype=dtypes.int64),
3196 message=err_msg)
3197 with ops.control_dependencies([check]):
3198 boundaries = constant_op.constant(
3199 bucket_boundaries, dtype=dtypes.int64)
3200 bucket_boundary = boundaries[bucket_id]
3201 none_filler = bucket_boundary - 1
3202 input_shapes = get_legacy_output_shapes(grouped_dataset)
3203 shapes = make_padded_shapes(
3204 padded_shapes or input_shapes, none_filler=none_filler)
3205 return grouped_dataset.padded_batch(
3206 batch_size,
3207 shapes,
3208 padding_values,
3209 drop_remainder=drop_remainder,
3210 name=name)
3212 return self.group_by_window(
3213 key_func=element_to_bucket_id,
3214 reduce_func=batching_fn,
3215 window_size_func=window_size_fn,
3216 name=name)
3218 @staticmethod
3219 def random(seed=None, rerandomize_each_iteration=None, name=None):
3220 """Creates a `Dataset` of pseudorandom values.
3222 The dataset generates a sequence of uniformly distributed integer values.
3224 `rerandomize_each_iteration` controls whether the sequence of random numbers
3225 generated should be re-randomized for each epoch. The default value is False,
3226 in which case the dataset generates the same sequence of random numbers for
3227 each epoch.
3229 >>> ds1 = tf.data.Dataset.random(seed=4).take(10)
3230 >>> ds2 = tf.data.Dataset.random(seed=4).take(10)
3231 >>> print(list(ds1.as_numpy_iterator())==list(ds2.as_numpy_iterator()))
3232 True
3234 >>> ds3 = tf.data.Dataset.random(seed=4).take(10)
3235 >>> ds3_first_epoch = list(ds3.as_numpy_iterator())
3236 >>> ds3_second_epoch = list(ds3.as_numpy_iterator())
3237 >>> print(ds3_first_epoch == ds3_second_epoch)
3238 True
3240 >>> ds4 = tf.data.Dataset.random(
3241 ... seed=4, rerandomize_each_iteration=True).take(10)
3242 >>> ds4_first_epoch = list(ds4.as_numpy_iterator())
3243 >>> ds4_second_epoch = list(ds4.as_numpy_iterator())
3244 >>> print(ds4_first_epoch == ds4_second_epoch)
3245 False
3247 Args:
3248 seed: (Optional) If specified, the dataset produces a deterministic
3249 sequence of values.
3250 rerandomize_each_iteration: (Optional) If set to False, the dataset
3251 generates the same sequence of random numbers for each epoch. If set to
3252 True, it generates a different deterministic sequence of random numbers
3253 for each epoch. Defaults to False if left unspecified.
3254 name: (Optional.) A name for the tf.data operation.
3256 Returns:
3257 Dataset: A `Dataset`.
3258 """
3259 # Loaded lazily due to a circular dependency (
3260 # dataset_ops -> random_op -> dataset_ops).
3261 # pylint: disable=g-import-not-at-top,protected-access
3262 from tensorflow.python.data.ops import random_op
3263 return random_op._random(
3264 seed=seed,
3265 rerandomize_each_iteration=rerandomize_each_iteration,
3266 name=name)
3267 # pylint: enable=g-import-not-at-top,protected-access
3269 def snapshot(self,
3270 path,
3271 compression="AUTO",
3272 reader_func=None,
3273 shard_func=None,
3274 name=None):
3275 """API to persist the output of the input dataset.
3277 The snapshot API allows users to transparently persist the output of their
3278 preprocessing pipeline to disk, and materialize the pre-processed data on a
3279 different training run.
3281 This API enables repeated preprocessing steps to be consolidated, and allows
3282 re-use of already processed data, trading off disk storage and network
3283 bandwidth for freeing up more valuable CPU resources and accelerator compute
3284 time.
3286 https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
3287 has detailed design documentation of this feature.
3289 Users can specify various options to control the behavior of snapshot,
3290 including how snapshots are read and written, by passing in
3291 user-defined functions to the `reader_func` and `shard_func` parameters.
3293 `shard_func` is a user-specified function that maps input elements to
3294 snapshot shards.
3296 Users may want to specify this function to control how snapshot files should
3297 be written to disk. Below is an example of how a potential `shard_func`
3298 could be written.
3300 ```python
3301 dataset = ...
3302 dataset = dataset.enumerate()
3303 dataset = dataset.snapshot("/path/to/snapshot/dir",
3304 shard_func=lambda x, y: x % NUM_SHARDS, ...)
3305 dataset = dataset.map(lambda x, y: y)
3306 ```
3308 `reader_func` is a user-specified function that accepts a single argument:
3309 (1) a Dataset of Datasets, each representing a "split" of elements of the
3310 original dataset. The cardinality of the input dataset matches the
3311 number of shards specified in `shard_func` (see above). The function
3312 should return a Dataset of elements of the original dataset.
3314 Users may want to specify this function to control how snapshot files should be
3315 read from disk, including the amount of shuffling and parallelism.
3317 Here is an example of a standard reader function a user can define. This
3318 function enables both dataset shuffling and parallel reading of datasets:
3320 ```python
3321 def user_reader_func(datasets):
3322 # shuffle the datasets splits
3323 datasets = datasets.shuffle(NUM_CORES)
3324 # read datasets in parallel and interleave their elements
3325 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
3327 dataset = dataset.snapshot("/path/to/snapshot/dir",
3328 reader_func=user_reader_func)
3329 ```
3331 By default, snapshot parallelizes reads by the number of cores available on
3332 the system, but will not attempt to shuffle the data.
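In its simplest form, a snapshot can be taken with the default compression,
sharding and reading behavior (a minimal sketch; the path is only a
placeholder):

```python
dataset = tf.data.Dataset.range(100)
dataset = dataset.snapshot("/path/to/snapshot/dir")
# The first full iteration writes the snapshot; subsequent runs read from it.
```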
3334 Args:
3335 path: Required. A directory to use for storing / loading the snapshot to /
3336 from.
3337 compression: Optional. The type of compression to apply to the snapshot
3338 written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
3339 Defaults to `AUTO`, which attempts to pick an appropriate compression
3340 algorithm for the dataset.
3341 reader_func: Optional. A function to control how to read data from
3342 snapshot shards.
3343 shard_func: Optional. A function to control how to shard data when writing
3344 a snapshot.
3345 name: (Optional.) A name for the tf.data operation.
3347 Returns:
3348 A new `Dataset` with the transformation applied as described above.
3349 """
3350 # Loaded lazily due to a circular dependency (
3351 # dataset_ops -> snapshot_op -> dataset_ops).
3352 # pylint: disable=g-import-not-at-top,protected-access
3353 from tensorflow.python.data.ops import snapshot_op
3354 return snapshot_op._snapshot(
3355 self, path, compression, reader_func, shard_func, name=name)
3356 # pylint: enable=g-import-not-at-top,protected-access
3358 def scan(self, initial_state, scan_func, name=None):
3359 """A transformation that scans a function across an input dataset.
3361 This transformation is a stateful relative of `tf.data.Dataset.map`.
3362 In addition to mapping `scan_func` across the elements of the input dataset,
3363 `scan()` accumulates one or more state tensors, whose initial values are
3364 `initial_state`.
3366 >>> dataset = tf.data.Dataset.range(10)
3367 >>> initial_state = tf.constant(0, dtype=tf.int64)
3368 >>> scan_func = lambda state, i: (state + i, state + i)
3369 >>> dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)
3370 >>> list(dataset.as_numpy_iterator())
3371 [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
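The state may also be a (nested) structure of tensors. A minimal sketch that
tracks both a running sum and a running count in a two-component state:

```python
dataset = tf.data.Dataset.range(10)
initial_state = (tf.constant(0, dtype=tf.int64), tf.constant(0, dtype=tf.int64))
def scan_func(state, i):
  total, count = state
  new_state = (total + i, count + 1)
  return new_state, new_state  # emit (running_sum, running_count)
dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)
```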
3373 Args:
3374 initial_state: A nested structure of tensors, representing the initial
3375 state of the accumulator.
3376 scan_func: A function that maps `(old_state, input_element)` to
3377 `(new_state, output_element)`. It must take two arguments and return a
3378 pair of nested structures of tensors. The `new_state` must match the
3379 structure of `initial_state`.
3380 name: (Optional.) A name for the tf.data operation.
3382 Returns:
3383 A new `Dataset` with the transformation applied as described above.
3384 """
3386 # Loaded lazily due to a circular dependency (dataset_ops ->
3387 # scan_op -> dataset_ops).
3388 # pylint: disable=g-import-not-at-top,protected-access
3389 from tensorflow.python.data.ops import scan_op
3390 return scan_op._scan(self, initial_state, scan_func, name=name)
3391 # pylint: enable=g-import-not-at-top,protected-access
3393 def take_while(self, predicate, name=None):
3394 """A transformation that stops dataset iteration based on a `predicate`.
3396 >>> dataset = tf.data.Dataset.range(10)
3397 >>> dataset = dataset.take_while(lambda x: x < 5)
3398 >>> list(dataset.as_numpy_iterator())
3399 [0, 1, 2, 3, 4]
3401 Args:
3402 predicate: A function that maps a nested structure of tensors (having
3403 shapes and types defined by `self.output_shapes` and
3404 `self.output_types`) to a scalar `tf.bool` tensor.
3405 name: (Optional.) A name for the tf.data operation.
3407 Returns:
3408 A new `Dataset` with the transformation applied as described above.
3409 """
3410 # Loaded lazily due to a circular dependency (
3411 # dataset_ops -> take_while_op -> dataset_ops).
3412 # pylint: disable=g-import-not-at-top,protected-access
3413 from tensorflow.python.data.ops import take_while_op
3414 return take_while_op._take_while(self, predicate, name=name)
3415 # pylint: enable=g-import-not-at-top,protected-access
3417 def unique(self, name=None):
3418 """A transformation that discards duplicate elements of a `Dataset`.
3420 Use this transformation to produce a dataset that contains one instance of
3421 each unique element in the input. For example:
3423 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
3424 >>> dataset = dataset.unique()
3425 >>> sorted(list(dataset.as_numpy_iterator()))
3426 [1, 2, 37]
3428 Note: This transformation only supports datasets which fit into memory
3429 and have elements of either `tf.int32`, `tf.int64` or `tf.string` type.
3431 Args:
3432 name: (Optional.) A name for the tf.data operation.
3434 Returns:
3435 A new `Dataset` with the transformation applied as described above.
3436 """
3437 # Loaded lazily due to a circular dependency (dataset_ops -> unique_op ->
3438 # dataset_ops).
3439 # pylint: disable=g-import-not-at-top,protected-access
3440 from tensorflow.python.data.ops import unique_op
3441 return unique_op._unique(self, name)
3442 # pylint: enable=g-import-not-at-top,protected-access
3444 def rejection_resample(self,
3445 class_func,
3446 target_dist,
3447 initial_dist=None,
3448 seed=None,
3449 name=None):
3450 """Resamples elements to reach a target distribution.
3452 Note: This implementation can reject **or repeat** elements in order to
3453 reach the `target_dist`. So, in some cases, the output `Dataset` may be
3454 larger than the input `Dataset`.
3456 >>> initial_dist = [0.6, 0.4]
3457 >>> n = 1000
3458 >>> elems = np.random.choice(len(initial_dist), size=n, p=initial_dist)
3459 >>> dataset = tf.data.Dataset.from_tensor_slices(elems)
3460 >>> zero, one = np.bincount(list(dataset.as_numpy_iterator())) / n
3462 Following from `initial_dist`, `zero` is ~0.6 and `one` is ~0.4.
3464 >>> target_dist = [0.5, 0.5]
3465 >>> dataset = dataset.rejection_resample(
3466 ... class_func=lambda x: x,
3467 ... target_dist=target_dist,
3468 ... initial_dist=initial_dist)
3469 >>> dataset = dataset.map(lambda class_func_result, data: data)
3470 >>> zero, one = np.bincount(list(dataset.as_numpy_iterator())) / n
3472 Following from `target_dist`, `zero` is ~0.5 and `one` is ~0.5.
3474 Args:
3475 class_func: A function mapping an element of the input dataset to a scalar
3476 `tf.int32` tensor. Values should be in `[0, num_classes)`.
3477 target_dist: A floating point type tensor, shaped `[num_classes]`.
3478 initial_dist: (Optional.) A floating point type tensor, shaped
3479 `[num_classes]`. If not provided, the true class distribution is
3480 estimated live in a streaming fashion.
3481 seed: (Optional.) Python integer seed for the resampler.
3482 name: (Optional.) A name for the tf.data operation.
3484 Returns:
3485 A new `Dataset` with the transformation applied as described above.
3486 """
3488 # TODO(b/245793127): Consider switching back to the 'v1' implementation.
3490 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
3491 target_dist_t = math_ops.cast(target_dist_t, dtypes.float32)
3493 # Get initial distribution.
3494 if initial_dist is not None:
3495 initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
3496 initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32)
3497 acceptance_dist, prob_of_original = (
3498 _calculate_acceptance_probs_with_mixing(initial_dist_t,
3499 target_dist_t))
3500 initial_dist_ds = DatasetV2.from_tensors(
3501 initial_dist_t, name=name).repeat(name=name)
3502 acceptance_dist_ds = DatasetV2.from_tensors(
3503 acceptance_dist, name=name).repeat(name=name)
3504 prob_of_original_ds = DatasetV2.from_tensors(
3505 prob_of_original, name=name).repeat(name=name)
3506 else:
3507 initial_dist_ds = _estimate_initial_dist_ds(
3508 target_dist_t, self.map(class_func, name=name), name=name)
3509 acceptance_and_original_prob_ds = initial_dist_ds.map(
3510 lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
3511 initial, target_dist_t),
3512 name=name)
3513 acceptance_dist_ds = acceptance_and_original_prob_ds.map(
3514 lambda accept_prob, _: accept_prob, name=name)
3515 prob_of_original_ds = acceptance_and_original_prob_ds.map(
3516 lambda _, prob_original: prob_original, name=name)
3517 filtered_ds = _filter_ds(self, acceptance_dist_ds, initial_dist_ds,
3518 class_func, seed)
3519 # Prefetch filtered dataset for speed.
3520 filtered_ds = filtered_ds.prefetch(3, name=name)
3522 prob_original_static = _get_prob_original_static(
3523 initial_dist_t, target_dist_t) if initial_dist is not None else None
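# The resampled branch yields `(class, element)` pairs, so the
# non-resampled paths below attach the value computed by `class_func` to
# keep the output structure consistent across branches.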
3525 def add_class_value(*x):
3526 if len(x) == 1:
3527 return class_func(*x), x[0]
3528 else:
3529 return class_func(*x), x
3531 if prob_original_static == 1:
3532 return self.map(add_class_value, name=name)
3533 elif prob_original_static == 0:
3534 return filtered_ds
3535 else:
3536 return Dataset.sample_from_datasets(
3537 [self.map(add_class_value), filtered_ds],
3538 weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
3539 seed=seed,
3540 stop_on_empty_dataset=True)
3542 @staticmethod
3543 def sample_from_datasets(datasets,
3544 weights=None,
3545 seed=None,
3546 stop_on_empty_dataset=False,
3547 rerandomize_each_iteration=None):
3548 """Samples elements at random from the datasets in `datasets`.
3550 Creates a dataset by interleaving elements of `datasets` with `weights[i]`
3551 probability of picking an element from dataset `i`. Sampling is done without
3552 replacement. For example, suppose we have 2 datasets:
3554 ```python
3555 dataset1 = tf.data.Dataset.range(0, 3)
3556 dataset2 = tf.data.Dataset.range(100, 103)
3557 ```
3559 Suppose that we sample from these 2 datasets with the following weights:
3561 ```python
3562 sample_dataset = tf.data.Dataset.sample_from_datasets(
3563 [dataset1, dataset2], weights=[0.5, 0.5])
3564 ```
3566 One possible outcome of elements in sample_dataset is:
3568 ```
3569 print(list(sample_dataset.as_numpy_iterator()))
3570 # [100, 0, 1, 101, 2, 102]
3571 ```
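Non-uniform weights follow the same pattern. A sketch that favors the first
dataset and stops as soon as either input runs out of elements:

```python
sample_dataset = tf.data.Dataset.sample_from_datasets(
    [dataset1, dataset2], weights=[0.75, 0.25], seed=42,
    stop_on_empty_dataset=True)
```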
3573 Args:
3574 datasets: A non-empty list of `tf.data.Dataset` objects with compatible
3575 structure.
3576 weights: (Optional.) A list or Tensor of `len(datasets)` floating-point
3577 values where `weights[i]` represents the probability to sample from
3578 `datasets[i]`, or a `tf.data.Dataset` object where each element is such
3579 a list. Defaults to a uniform distribution across `datasets`.
3580 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
3581 seed that will be used to create the distribution. See
3582 `tf.random.set_seed` for behavior.
3583 stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty
3584 dataset. If `False`, it continues sampling and skips any empty datasets.
3585 It is recommended to set it to `True`. Otherwise, the distribution of
3586 samples starts off as the user intends, but may change as input datasets
3587 become empty. This can be difficult to detect since the dataset starts
3588 off looking correct. Defaults to `False` for backward compatibility.
3589 rerandomize_each_iteration: An optional `bool`. The boolean argument
3590 controls whether the sequence of random numbers used to determine which
3591 dataset to sample from will be rerandomized each epoch. That is, it
3592 determines whether datasets will be sampled in the same order across
3593 different epochs (the default behavior) or not.
3595 Returns:
3596 A dataset that interleaves elements from `datasets` at random, according
3597 to `weights` if provided, otherwise with uniform probability.
3599 Raises:
3600 TypeError: If the `datasets` or `weights` arguments have the wrong type.
3601 ValueError:
3602 - If `datasets` is empty, or
3603 - If `weights` is specified and does not match the length of `datasets`.
3604 """
3605 # Loaded lazily due to a circular dependency
3606 # (dataset_ops -> sample_from_datasets_op -> dataset_ops).
3607 # pylint: disable=g-import-not-at-top,protected-access
3608 from tensorflow.python.data.ops import sample_from_datasets_op
3609 return sample_from_datasets_op._sample_from_datasets( # pylint: disable=protected-access
3610 datasets,
3611 weights,
3612 seed,
3613 stop_on_empty_dataset,
3614 rerandomize_each_iteration,
3615 )
3616 # pylint: enable=g-import-not-at-top,protected-access
3618 @staticmethod
3619 def choose_from_datasets(datasets,
3620 choice_dataset,
3621 stop_on_empty_dataset=True):
3622 """Creates a dataset that deterministically chooses elements from `datasets`.
3624 For example, given the following datasets:
3626 ```python
3627 datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
3628 tf.data.Dataset.from_tensors("bar").repeat(),
3629 tf.data.Dataset.from_tensors("baz").repeat()]
3631 # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
3632 choice_dataset = tf.data.Dataset.range(3).repeat(3)
3634 result = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset)
3635 ```
3637 The elements of `result` will be:
3639 ```
3640 "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
3641 ```
3643 Args:
3644 datasets: A non-empty list of `tf.data.Dataset` objects with compatible
3645 structure.
3646 choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
3647 `0` and `len(datasets) - 1`.
3648 stop_on_empty_dataset: If `True`, selection stops if it encounters an
3649 empty dataset. If `False`, it skips empty datasets. It is recommended to
3650 set it to `True`. Otherwise, the selected elements start off as the user
3651 intends, but may change as input datasets become empty. This can be
3652 difficult to detect since the dataset starts off looking correct.
3653 Defaults to `True`.
3655 Returns:
3656 A new `Dataset` with the transformation applied as described above.
3658 Raises:
3659 TypeError: If `datasets` or `choice_dataset` has the wrong type.
3660 ValueError: If `datasets` is empty.
3661 """
3662 # Loaded lazily due to a circular dependency
3663 # (dataset_ops -> choose_from_datasets_op -> dataset_ops).
3664 # pylint: disable=g-import-not-at-top,protected-access
3665 from tensorflow.python.data.ops import choose_from_datasets_op
3666 return choose_from_datasets_op._choose_from_datasets(
3667 datasets, choice_dataset, stop_on_empty_dataset)
3668 # pylint: enable=g-import-not-at-top,protected-access
3671@tf_export(v1=["data.Dataset"])
3672class DatasetV1(DatasetV2, data_types.DatasetV1):
3673 """Represents a potentially large set of elements.
3675 A `Dataset` can be used to represent an input pipeline as a
3676 collection of elements and a "logical plan" of transformations that act on
3677 those elements.
3678 """
3680 def __init__(self):
3681 try:
3682 variant_tensor = self._as_variant_tensor()
3683 except AttributeError as e:
3684 if "_as_variant_tensor" in str(e):
3685 raise AttributeError("Please use `_variant_tensor` instead of "
3686 "`_as_variant_tensor()` to obtain the variant "
3687 "associated with a dataset.")
3688 raise AttributeError("{}: A likely cause of this error is that the super "
3689 "call for this dataset is not the last line of the "
3690 "`__init__` method. The base class invokes the "
3691 "`_as_variant_tensor()` method in its constructor "
3692 "and if that method uses attributes defined in the "
3693 "`__init__` method, those attributes need to be "
3694 "defined before the super call.".format(e))
3695 super(DatasetV1, self).__init__(variant_tensor)
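# Note: subclasses must define any attributes that `_as_variant_tensor()`
# reads *before* invoking this constructor. A hypothetical sketch:
#
#   class _MyDatasetV1(DatasetV1):
#     def __init__(self, wrapped):
#       self._wrapped = wrapped               # read by `_as_variant_tensor()`
#       super(_MyDatasetV1, self).__init__()  # super call must come last
#
#     def _as_variant_tensor(self):
#       return self._wrapped._variant_tensor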
3697 @abc.abstractmethod
3698 def _as_variant_tensor(self):
3699 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
3701 Returns:
3702 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
3703 """
3704 raise NotImplementedError(f"{type(self)}._as_variant_tensor()")
3706 @deprecation.deprecated(
3707 None, "This is a deprecated API that should only be used in TF 1 graph "
3708 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In "
3709 "all other situations -- namely, eager mode and inside `tf.function` -- "
3710 "you can consume dataset elements using `for elem in dataset: ...` or "
3711 "by explicitly creating iterator via `iterator = iter(dataset)` and "
3712 "fetching its elements via `values = next(iterator)`. Furthermore, "
3713 "this API is not available in TF 2. During the transition from TF 1 "
3714 "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` "
3715 "to create a TF 1 graph mode style iterator for a dataset created "
3716 "through TF 2 APIs. Note that this should be a transient state of your "
3717 "code base as there are in general no guarantees about the "
3718 "interoperability of TF 1 and TF 2 code.")
3719 def make_one_shot_iterator(self):
3720 """Creates an iterator for elements of this dataset.
3722 Note: The returned iterator will be initialized automatically.
3723 A "one-shot" iterator does not currently support re-initialization. For
3724 that see `make_initializable_iterator`.
3726 Example:
3728 ```python
3729 # Building graph ...
3730 dataset = ...
3731 next_value = dataset.make_one_shot_iterator().get_next()
3733 # ... from within a session ...
3734 try:
3735 while True:
3736 value = sess.run(next_value)
3737 ...
3738 except tf.errors.OutOfRangeError:
3739 pass
3740 ```
3742 Returns:
3743 A `tf.data.Iterator` for elements of this dataset.
3744 """
3745 return self._make_one_shot_iterator()
3747 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
3748 if context.executing_eagerly():
3749 with ops.colocate_with(self._variant_tensor):
3750 return iterator_ops.OwnedIterator(self)
3752 _ensure_same_dataset_graph(self)
3753 # Some ops (e.g. dataset ops) are marked as stateful but are still safe
3754 # to capture by value. We must allowlist these ops so that the capturing
3755 # logic captures the ops instead of raising an exception.
3756 allowlisted_stateful_ops = traverse.obtain_capture_by_value_ops(self)
3757 graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
3759 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
3760 # a 0-argument function.
3761 @function.Defun(
3762 capture_by_value=True,
3763 allowlisted_stateful_ops=allowlisted_stateful_ops)
3764 def _make_dataset():
3765 """Factory function for a dataset."""
3766 # NOTE(mrry): `Defun` does not capture the graph-level seed from the
3767 # enclosing graph, so if a graph-level seed is present we set the local
3768 # graph seed based on a combination of the graph- and op-level seeds.
3769 if graph_level_seed is not None:
3770 assert op_level_seed is not None
3771 core_random_seed.set_random_seed(
3772 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
3774 dataset = self._apply_debug_options()
3775 return dataset._variant_tensor # pylint: disable=protected-access
3777 try:
3778 _make_dataset.add_to_graph(ops.get_default_graph())
3779 except ValueError as err:
3780 if "Cannot capture a stateful node" in str(err):
3781 raise ValueError(
3782 "{}: A likely cause of this error is that the dataset for which "
3783 "you are calling `make_one_shot_iterator()` captures a stateful "
3784 "object, such as a `tf.Variable` or `tf.lookup.StaticHashTable`, "
3785 "which is not supported. Use `make_initializable_iterator()` "
3786 "instead.".format(err)) from None
3787 else:
3788 raise
3790 with ops.colocate_with(self._variant_tensor):
3791 # pylint: disable=protected-access
3792 return iterator_ops.Iterator(
3793 gen_dataset_ops.one_shot_iterator(
3794 dataset_factory=_make_dataset, **self._flat_structure), None,
3795 get_legacy_output_types(self), get_legacy_output_shapes(self),
3796 get_legacy_output_classes(self))
3798 @deprecation.deprecated(
3799 None, "This is a deprecated API that should only be used in TF 1 graph "
3800 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. "
3801 "In all other situations -- namely, eager mode and inside `tf.function` "
3802 "-- you can consume dataset elements using `for elem in dataset: ...` "
3803 "or by explicitly creating iterator via `iterator = iter(dataset)` "
3804 "and fetching its elements via `values = next(iterator)`. "
3805 "Furthermore, this API is not available in TF 2. During the transition "
3806 "from TF 1 to TF 2 you can use "
3807 "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF "
3808 "1 graph mode style iterator for a dataset created through TF 2 APIs. "
3809 "Note that this should be a transient state of your code base as there "
3810 "are in general no guarantees about the interoperability of TF 1 and TF "
3811 "2 code.")
3812 def make_initializable_iterator(self, shared_name=None):
3813 """Creates an iterator for elements of this dataset.
3815 Note: The returned iterator will be in an uninitialized state,
3816 and you must run the `iterator.initializer` operation before using it:
3818 ```python
3819 # Building graph ...
3820 dataset = ...
3821 iterator = dataset.make_initializable_iterator()
3822 next_value = iterator.get_next() # This is a Tensor.
3824 # ... from within a session ...
3825 sess.run(iterator.initializer)
3826 try:
3827 while True:
3828 value = sess.run(next_value)
3829 ...
3830 except tf.errors.OutOfRangeError:
3831 pass
3832 ```
3834 Args:
3835 shared_name: (Optional.) If non-empty, the returned iterator will be
3836 shared under the given name across multiple sessions that share the same
3837 devices (e.g. when using a remote server).
3839 Returns:
3840 A `tf.data.Iterator` for elements of this dataset.
3842 Raises:
3843 RuntimeError: If eager execution is enabled.
3844 """
3845 return self._make_initializable_iterator(shared_name)
3847 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring
3848 if context.executing_eagerly():
3849 raise RuntimeError("`make_initializable_iterator()` is not supported in "
3850 "eager mode. Use Python-style iteration instead.")
3851 _ensure_same_dataset_graph(self)
3852 dataset = self._apply_debug_options()
3853 if shared_name is None:
3854 shared_name = ""
3856 with ops.colocate_with(self._variant_tensor):
3857 iterator_resource = gen_dataset_ops.iterator_v2(
3858 container="", shared_name=shared_name, **self._flat_structure)
3860 initializer = gen_dataset_ops.make_iterator(
3861 dataset._variant_tensor, # pylint: disable=protected-access
3862 iterator_resource)
3864 # pylint: disable=protected-access
3865 return iterator_ops.Iterator(iterator_resource, initializer,
3866 get_legacy_output_types(dataset),
3867 get_legacy_output_shapes(dataset),
3868 get_legacy_output_classes(dataset))
3870 @property
3871 @deprecation.deprecated(
3872 None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.")
3873 def output_classes(self):
3874 """Returns the class of each component of an element of this dataset.
3876 Returns:
3877 A (nested) structure of Python `type` objects corresponding to each
3878 component of an element of this dataset.
3879 """
3880 return nest.map_structure(
3881 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
3882 self.element_spec)
3884 @property
3885 @deprecation.deprecated(
3886 None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.")
3887 def output_shapes(self):
3888 """Returns the shape of each component of an element of this dataset.
3890 Returns:
3891 A (nested) structure of `tf.TensorShape` objects corresponding to each
3892 component of an element of this dataset.
3893 """
3894 return nest.map_structure(
3895 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
3896 self.element_spec)
3898 @property
3899 @deprecation.deprecated(
3900 None, "Use `tf.compat.v1.data.get_output_types(dataset)`.")
3901 def output_types(self):
3902 """Returns the type of each component of an element of this dataset.
3904 Returns:
3905 A (nested) structure of `tf.DType` objects corresponding to each component
3906 of an element of this dataset.
3907 """
3908 return nest.map_structure(
3909 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
3910 self.element_spec)
3912 @property
3913 def element_spec(self):
3914 # TODO(b/110122868): Remove this override once all `Dataset` instances
3915 # implement `element_structure`.
3916 return structure.convert_legacy_structure(
3917 self.output_types, self.output_shapes, self.output_classes)
3919 @staticmethod
3920 @functools.wraps(DatasetV2.from_tensors)
3921 def from_tensors(tensors, name=None):
3922 return DatasetV1Adapter(DatasetV2.from_tensors(tensors, name=name))
3924 @staticmethod
3925 @functools.wraps(DatasetV2.from_tensor_slices)
3926 def from_tensor_slices(tensors, name=None):
3927 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors, name=name))
3929 @staticmethod
3930 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
3931 def from_sparse_tensor_slices(sparse_tensor):
3932 """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise.
3934 Args:
3935 sparse_tensor: A `tf.sparse.SparseTensor`.
3937 Returns:
3938 Dataset: A `Dataset` of rank-(N-1) sparse tensors.
3939 """
3940 # Loaded lazily due to a circular dependency (dataset_ops ->
3941 # from_sparse_tensor_slices_op -> dataset_ops).
3942 # pylint: disable=g-import-not-at-top,protected-access
3943 from tensorflow.python.data.ops import from_sparse_tensor_slices_op
3944 return from_sparse_tensor_slices_op._from_sparse_tensor_slices(
3945 sparse_tensor)
3946 # pylint: enable=g-import-not-at-top,protected-access
3948 @staticmethod
3949 @functools.wraps(DatasetV2.from_generator)
3950 @deprecation.deprecated_args(None, "Use output_signature instead",
3951 "output_types", "output_shapes")
3952 def from_generator(generator,
3953 output_types=None,
3954 output_shapes=None,
3955 args=None,
3956 output_signature=None,
3957 name=None):
3958 # Calling DatasetV2.from_generator with output_shapes or output_types is
3959 # deprecated, but this is already checked by the decorator on this function.
3960 with deprecation.silence():
3961 return DatasetV1Adapter(
3962 DatasetV2.from_generator(
3963 generator,
3964 output_types,
3965 output_shapes,
3966 args,
3967 output_signature,
3968 name=name))
3970 @staticmethod
3971 @functools.wraps(DatasetV2.range)
3972 def range(*args, **kwargs):
3973 return DatasetV1Adapter(DatasetV2.range(*args, **kwargs))
3975 @staticmethod
3976 @functools.wraps(DatasetV2.zip)
3977 def zip(*args, datasets=None, name=None):
3978 return DatasetV1Adapter(DatasetV2.zip(*args, datasets=datasets, name=name))
3980 @functools.wraps(DatasetV2.concatenate)
3981 def concatenate(self, dataset, name=None):
3982 return DatasetV1Adapter(
3983 super(DatasetV1, self).concatenate(dataset, name=name))
3985 @functools.wraps(DatasetV2.prefetch)
3986 def prefetch(self, buffer_size, name=None):
3987 return DatasetV1Adapter(
3988 super(DatasetV1, self).prefetch(buffer_size, name=name))
3990 @staticmethod
3991 @functools.wraps(DatasetV2.list_files)
3992 def list_files(file_pattern, shuffle=None, seed=None, name=None):
3993 return DatasetV1Adapter(
3994 DatasetV2.list_files(file_pattern, shuffle, seed, name=name))
3996 @functools.wraps(DatasetV2.repeat)
3997 def repeat(self, count=None, name=None):
3998 return DatasetV1Adapter(super(DatasetV1, self).repeat(count, name=name))
4000 @functools.wraps(DatasetV2.shuffle)
4001 def shuffle(self,
4002 buffer_size,
4003 seed=None,
4004 reshuffle_each_iteration=None,
4005 name=None):
4006 return DatasetV1Adapter(
4007 super(DatasetV1, self).shuffle(
4008 buffer_size, seed, reshuffle_each_iteration, name=name))
4010 @functools.wraps(DatasetV2.cache)
4011 def cache(self, filename="", name=None):
4012 return DatasetV1Adapter(super(DatasetV1, self).cache(filename, name=name))
4014 @functools.wraps(DatasetV2.take)
4015 def take(self, count, name=None):
4016 return DatasetV1Adapter(super(DatasetV1, self).take(count, name=name))
4018 @functools.wraps(DatasetV2.skip)
4019 def skip(self, count, name=None):
4020 return DatasetV1Adapter(super(DatasetV1, self).skip(count, name=name))
4022 @functools.wraps(DatasetV2.shard)
4023 def shard(self, num_shards, index, name=None):
4024 return DatasetV1Adapter(
4025 super(DatasetV1, self).shard(num_shards, index, name=name))
4027 @functools.wraps(DatasetV2.batch)
4028 def batch(self,
4029 batch_size,
4030 drop_remainder=False,
4031 num_parallel_calls=None,
4032 deterministic=None,
4033 name=None):
4034 return DatasetV1Adapter(
4035 super(DatasetV1, self).batch(
4036 batch_size,
4037 drop_remainder,
4038 num_parallel_calls,
4039 deterministic,
4040 name=name))
4042 @functools.wraps(DatasetV2.padded_batch)
4043 def padded_batch(self,
4044 batch_size,
4045 padded_shapes=None,
4046 padding_values=None,
4047 drop_remainder=False,
4048 name=None):
4049 return DatasetV1Adapter(
4050 super(DatasetV1, self).padded_batch(
4051 batch_size,
4052 padded_shapes,
4053 padding_values,
4054 drop_remainder,
4055 name=name))
4057 @functools.wraps(DatasetV2.map)
4058 def map(self,
4059 map_func,
4060 num_parallel_calls=None,
4061 deterministic=None,
4062 name=None):
4063 # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
4064 # dataset_ops).
4065 # pylint: disable=g-import-not-at-top,protected-access
4066 from tensorflow.python.data.ops import map_op
4067 return map_op._map_v1(
4068 self,
4069 map_func,
4070 num_parallel_calls=num_parallel_calls,
4071 deterministic=deterministic)
4072 # pylint: enable=g-import-not-at-top,protected-access
4074 @deprecation.deprecated(None, "Use `tf.data.Dataset.map()`.")
4075 def map_with_legacy_function(self,
4076 map_func,
4077 num_parallel_calls=None,
4078 deterministic=None):
4079 """Maps `map_func` across the elements of this dataset.
4081 Note: This is an escape hatch for existing uses of `map` that do not work
4082 with V2 functions. New uses are strongly discouraged and existing uses
4083 should migrate to `map` as this method will be removed in V2.
4085 Args:
4086 map_func: A function mapping a (nested) structure of tensors (having
4087 shapes and types defined by `self.output_shapes` and
4088 `self.output_types`) to another (nested) structure of tensors.
4089 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
4090 representing the number of elements to process asynchronously in parallel.
4091 If not specified, elements will be processed sequentially. If the value
4092 `tf.data.AUTOTUNE` is used, then the number of parallel calls is set
4093 dynamically based on available CPU.
4094 deterministic: (Optional.) When `num_parallel_calls` is specified, this
4095 boolean controls the order in which the transformation produces
4096 elements. If set to `False`, the transformation is allowed to yield
4097 elements out of order to trade determinism for performance. If not
4098 specified, the `tf.data.Options.deterministic` option (`True` by
4099 default) controls the behavior.
4101 Returns:
4102 Dataset: A `Dataset`.
4103 """
4104 # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
4105 # dataset_ops).
4106 # pylint: disable=g-import-not-at-top,protected-access
4107 from tensorflow.python.data.ops import map_op
4108 return map_op._map_v1_with_legacy_function(
4109 self,
4110 map_func,
4111 num_parallel_calls=num_parallel_calls,
4112 deterministic=deterministic)
4113 # pylint: enable=g-import-not-at-top,protected-access
4115 @functools.wraps(DatasetV2.flat_map)
4116 def flat_map(self, map_func, name=None):
4117 return DatasetV1Adapter(
4118 super(DatasetV1, self).flat_map(map_func, name=name))
4120 @functools.wraps(DatasetV2.interleave)
4121 def interleave(self,
4122 map_func,
4123 cycle_length=None,
4124 block_length=None,
4125 num_parallel_calls=None,
4126 deterministic=None,
4127 name=None):
4128 return DatasetV1Adapter(
4129 super(DatasetV1, self).interleave(
4130 map_func,
4131 cycle_length,
4132 block_length,
4133 num_parallel_calls,
4134 deterministic,
4135 name=name))
4137 @functools.wraps(DatasetV2.filter)
4138 def filter(self, predicate, name=None):
4139 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate, name=name))
4141 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()`.")
4142 def filter_with_legacy_function(self, predicate):
4143 """Filters this dataset according to `predicate`.
4145 Note: This is an escape hatch for existing uses of `filter` that do not work
4146 with V2 functions. New uses are strongly discouraged and existing uses
4147 should migrate to `filter` as this method will be removed in V2.
4149 Args:
4150 predicate: A function mapping a (nested) structure of tensors (having
4151 shapes and types defined by `self.output_shapes` and
4152 `self.output_types`) to a scalar `tf.bool` tensor.
4154 Returns:
4155 Dataset: The `Dataset` containing the elements of this dataset for which
4156 `predicate` is `True`.
4157 """
4158 # Loaded lazily due to a circular dependency (dataset_ops -> filter_op ->
4159 # dataset_ops).
4160 # pylint: disable=g-import-not-at-top,protected-access
4161 from tensorflow.python.data.ops import filter_op
4162 return filter_op._FilterDataset(self, predicate, use_legacy_function=True)
4163 # pylint: enable=g-import-not-at-top,protected-access
4165 @functools.wraps(DatasetV2.apply)
4166 def apply(self, transformation_func):
4167 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
4169 @functools.wraps(DatasetV2.window)
4170 def window(self, size, shift=None, stride=1, drop_remainder=False, name=None):
4171 return DatasetV1Adapter(
4172 super(DatasetV1,
4173 self).window(size, shift, stride, drop_remainder, name=name))
4175 @functools.wraps(DatasetV2.unbatch)
4176 def unbatch(self, name=None):
4177 return DatasetV1Adapter(super(DatasetV1, self).unbatch(name=name))
4179 @functools.wraps(DatasetV2.with_options)
4180 def with_options(self, options, name=None):
4181 return DatasetV1Adapter(
4182 super(DatasetV1, self).with_options(options, name=name))
4185if tf2.enabled():
4186 Dataset = DatasetV2
4187else:
4188 Dataset = DatasetV1
4191class DatasetV1Adapter(DatasetV1):
4192 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
4194 def __init__(self, dataset):
4195 self._dataset = dataset
4196 super(DatasetV1Adapter, self).__init__()
4198 def _as_variant_tensor(self):
4199 return self._dataset._variant_tensor # pylint: disable=protected-access
4201 def _inputs(self):
4202 return self._dataset._inputs() # pylint: disable=protected-access
4204 def _functions(self):
4205 return self._dataset._functions() # pylint: disable=protected-access
4207 def options(self):
4208 return self._dataset.options()
4210 @property
4211 def element_spec(self):
4212 return self._dataset.element_spec # pylint: disable=protected-access
4214 def __iter__(self):
4215 return iter(self._dataset)
4218def _ensure_same_dataset_graph(dataset):
4219 """Walks the dataset graph to ensure all datasets come from the same graph."""
4220 # pylint: disable=protected-access
4221 current_graph = ops.get_default_graph()
4222 bfs_q = queue.Queue()
4223 bfs_q.put(dataset)
4224 visited = []
4225 while not bfs_q.empty():
4226 ds = bfs_q.get()
4227 visited.append(ds)
4228 ds_graph = ds._graph
4229 if current_graph != ds_graph:
4230 raise ValueError(
4231 f"The graph {current_graph} of the iterator is different from the "
4232 f"graph {ds_graph} the dataset: {ds._variant_tensor} was created in. "
4233 f"If you are using the Estimator API, make sure that no part of the "
4234 f"dataset returned by the `input_fn` function is defined outside the "
4235 f"`input_fn` function. Otherwise, make sure that the dataset is "
4236 f"created in the same graph as the iterator.")
4237 for input_ds in ds._inputs():
4238 if input_ds not in visited:
4239 bfs_q.put(input_ds)
4242@tf_export(v1=["data.make_one_shot_iterator"])
4243def make_one_shot_iterator(dataset):
4244 """Creates an iterator for elements of `dataset`.
4246 Note: The returned iterator will be initialized automatically.
4247 A "one-shot" iterator does not support re-initialization.
4249 Args:
4250 dataset: A `tf.data.Dataset`.
4252 Returns:
4253 A `tf.data.Iterator` for elements of `dataset`.
4255 @compatibility(TF2)
4256 This is a legacy API for consuming dataset elements and should only be used
4257 during the transition from TF 1 to TF 2. Note that reliance on this API
4258 should be transient, as there are in general no guarantees about the
4259 interoperability of TF 1 and TF 2 code.
4261 In TF 2, datasets are Python iterables, which means you can consume their
4262 elements using `for elem in dataset: ...` or by explicitly creating an
4263 iterator via `iterator = iter(dataset)` and fetching its elements via
4264 `values = next(iterator)`.
4265 @end_compatibility
4266 """
4267 try:
4268 # Call the defined `_make_one_shot_iterator()` if there is one, because some
4269 # datasets (e.g. for prefetching) override its behavior.
4270 return dataset._make_one_shot_iterator() # pylint: disable=protected-access
4271 except AttributeError:
4272 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access
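# Illustrative sketch (not part of this module): typical TF1 graph-mode use of
# `tf.compat.v1.data.make_one_shot_iterator`, assuming eager execution has
# been disabled so that a `Session` is available.
#
#   import tensorflow as tf
#   tf.compat.v1.disable_eager_execution()
#   dataset = tf.data.Dataset.range(3)
#   iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
#   next_element = iterator.get_next()
#   with tf.compat.v1.Session() as sess:
#     print(sess.run(next_element))  # 0
#     print(sess.run(next_element))  # 1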
4275@tf_export(v1=["data.make_initializable_iterator"])
4276def make_initializable_iterator(dataset, shared_name=None):
4277 """Creates an iterator for elements of `dataset`.
4279 Note: The returned iterator will be in an uninitialized state,
4280 and you must run the `iterator.initializer` operation before using it:
4282 ```python
4283 dataset = ...
4284 iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
4285 # ...
4286 sess.run(iterator.initializer)
4287 ```
4289 Args:
4290 dataset: A `tf.data.Dataset`.
4291 shared_name: (Optional.) If non-empty, the returned iterator will be shared
4292 under the given name across multiple sessions that share the same devices
4293 (e.g. when using a remote server).
4295 Returns:
4296 A `tf.data.Iterator` for elements of `dataset`.
4298 Raises:
4299 RuntimeError: If eager execution is enabled.
4301 @compatibility(TF2)
4302 This is a legacy API for consuming dataset elements and should only be used
4303 during the transition from TF 1 to TF 2. Note that reliance on this API
4304 should be transient, as there are in general no guarantees about the
4305 interoperability of TF 1 and TF 2 code.
4307 In TF 2, datasets are Python iterables, which means you can consume their
4308 elements using `for elem in dataset: ...` or by explicitly creating an
4309 iterator via `iterator = iter(dataset)` and fetching its elements via
4310 `values = next(iterator)`.
4311 @end_compatibility
4312 """
4313 try:
4314 # Call the defined `_make_initializable_iterator()` if there is one, because
4315 # some datasets (e.g. for prefetching) override its behavior.
4316 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access
4317 except AttributeError:
4318 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access
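# Illustrative sketch (not part of this module): unlike the one-shot variant
# above, the initializable iterator must be initialized explicitly, which also
# allows it to be re-initialized later. Assumes TF1 graph mode, as in the
# docstring snippet.
#
#   import tensorflow as tf
#   tf.compat.v1.disable_eager_execution()
#   dataset = tf.data.Dataset.from_tensor_slices([10, 20, 30])
#   iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
#   next_element = iterator.get_next()
#   with tf.compat.v1.Session() as sess:
#     sess.run(iterator.initializer)
#     print(sess.run(next_element))  # 10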
4321@tf_export("data.experimental.get_structure")
4322def get_structure(dataset_or_iterator):
4323 """Returns the type signature for elements of the input dataset / iterator.
4325 For example, to get the structure of a `tf.data.Dataset`:
4327 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
4328 >>> tf.data.experimental.get_structure(dataset)
4329 TensorSpec(shape=(), dtype=tf.int32, name=None)
4331 >>> dataset = tf.data.experimental.from_list([(1, 'a'), (2, 'b'), (3, 'c')])
4332 >>> tf.data.experimental.get_structure(dataset)
4333 (TensorSpec(shape=(), dtype=tf.int32, name=None),
4334 TensorSpec(shape=(), dtype=tf.string, name=None))
4336 To get the structure of a `tf.data.Iterator`:
4338 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
4339 >>> tf.data.experimental.get_structure(iter(dataset))
4340 TensorSpec(shape=(), dtype=tf.int32, name=None)
4342 Args:
4343 dataset_or_iterator: A `tf.data.Dataset` or a `tf.data.Iterator`.
4345 Returns:
4346 A (nested) structure of `tf.TypeSpec` objects matching the structure of an
4347 element of `dataset_or_iterator` and specifying the type of individual
4348 components.
4350 Raises:
4351 TypeError: If the input is not a `tf.data.Dataset` or a `tf.data.Iterator`
4352 object.
4353 """
4354 try:
4355 return dataset_or_iterator.element_spec # pylint: disable=protected-access
4356 except AttributeError:
4357 raise TypeError(f"Invalid `dataset_or_iterator`. `dataset_or_iterator` "
4358 f"must be a `tf.data.Dataset` or tf.data.Iterator object, "
4359 f"but got {type(dataset_or_iterator)}.")
4362@tf_export(v1=["data.get_output_classes"])
4363def get_legacy_output_classes(dataset_or_iterator):
4364 """Returns the output classes for elements of the input dataset / iterator.
4366 Args:
4367 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
4369 Returns:
4370 A (nested) structure of Python `type` objects matching the structure of the
4371 dataset / iterator elements and specifying the class of the individual
4372 components.
4374 @compatibility(TF2)
4375 This is a legacy API for inspecting the type signature of dataset elements. In
4376 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
4377 @end_compatibility
4378 """
4379 return nest.map_structure(
4380 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
4381 get_structure(dataset_or_iterator))
4384@tf_export(v1=["data.get_output_shapes"])
4385def get_legacy_output_shapes(dataset_or_iterator):
4386 """Returns the output shapes for elements of the input dataset / iterator.
4388 Args:
4389 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
4391 Returns:
4392 A (nested) structure of `tf.TensorShape` objects matching the structure of
4393 the dataset / iterator elements and specifying the shape of the individual
4394 components.
4396 @compatibility(TF2)
4397 This is a legacy API for inspecting the type signature of dataset elements. In
4398 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
4399 @end_compatibility
4400 """
4401 return nest.map_structure(
4402 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
4403 get_structure(dataset_or_iterator))
4406@tf_export(v1=["data.get_output_types"])
4407def get_legacy_output_types(dataset_or_iterator):
4408 """Returns the output shapes for elements of the input dataset / iterator.
4410 Args:
4411 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
4413 Returns:
4414 A (nested) structure of `tf.DType` objects matching the structure of
4415 dataset / iterator elements and specifying the type of the individual
4416 components.
4418 @compatibility(TF2)
4419 This is a legacy API for inspecting the type signature of dataset elements. In
4420 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
4421 @end_compatibility
4422 """
4423 return nest.map_structure(
4424 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
4425 get_structure(dataset_or_iterator))
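# Illustrative sketch (not part of this module): the three legacy getters
# above split apart the information that `element_spec` carries in TF 2.
#
#   import tensorflow as tf
#   ds = tf.data.Dataset.from_tensor_slices(([1, 2], [1.0, 2.0]))
#   tf.compat.v1.data.get_output_types(ds)    # (tf.int32, tf.float32)
#   tf.compat.v1.data.get_output_shapes(ds)   # (TensorShape([]), TensorShape([]))
#   tf.compat.v1.data.get_output_classes(ds)  # roughly (tf.Tensor, tf.Tensor)
#   ds.element_spec                           # the TF 2 equivalent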
4428class DatasetSource(DatasetV2):
4429 """Abstract class representing a dataset with no inputs."""
4431 def _inputs(self):
4432 return []
4435class UnaryDataset(DatasetV2):
4436 """Abstract class representing a dataset with one input."""
4438 def __init__(self, input_dataset, variant_tensor):
4439 self._input_dataset = input_dataset
4440 super(UnaryDataset, self).__init__(variant_tensor)
4442 def _inputs(self):
4443 return [self._input_dataset]
4446class UnaryUnchangedStructureDataset(UnaryDataset):
4447 """Represents a unary dataset with the same input and output structure."""
4449 def __init__(self, input_dataset, variant_tensor):
4450 self._input_dataset = input_dataset
4451 super(UnaryUnchangedStructureDataset, self).__init__(
4452 input_dataset, variant_tensor)
4454 @property
4455 def element_spec(self):
4456 return self._input_dataset.element_spec
4459class _VariantDataset(DatasetV2):
4460 """A Dataset wrapper around a `tf.variant`-typed function argument."""
4462 def __init__(self, dataset_variant, element_spec):
4463 self._element_spec = element_spec
4464 super(_VariantDataset, self).__init__(dataset_variant)
4466 def _inputs(self):
4467 return []
4469 @property
4470 def element_spec(self):
4471 return self._element_spec
4474class _NestedVariant(composite_tensor.CompositeTensor):
4476 def __init__(self, variant_tensor, element_spec, dataset_shape):
4477 self._variant_tensor = variant_tensor
4478 self._element_spec = element_spec
4479 self._dataset_shape = dataset_shape
4481 @property
4482 def _type_spec(self):
4483 return DatasetSpec(self._element_spec, self._dataset_shape)
4486@tf_export("data.experimental.from_variant")
4487def from_variant(variant, structure):
4488 """Constructs a dataset from the given variant and (nested) structure.
4490 Args:
4491 variant: A scalar `tf.variant` tensor representing a dataset.
4492 structure: A (nested) structure of `tf.TypeSpec` objects representing the
4493 structure of each element in the dataset.
4495 Returns:
4496 A `tf.data.Dataset` instance.
4497 """
4498 return _VariantDataset(variant, structure) # pylint: disable=protected-access
4501@tf_export("data.experimental.to_variant")
4502def to_variant(dataset):
4503 """Returns a variant representing the given dataset.
4505 Args:
4506 dataset: A `tf.data.Dataset`.
4508 Returns:
4509 A scalar `tf.variant` tensor representing the given dataset.
4510 """
4511 return dataset._variant_tensor # pylint: disable=protected-access
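# Illustrative sketch (not part of this module): `to_variant` and
# `from_variant` round-trip a dataset through its scalar variant tensor,
# provided the original element structure is supplied on the way back.
#
#   import tensorflow as tf
#   ds = tf.data.Dataset.range(3)
#   variant = tf.data.experimental.to_variant(ds)
#   restored = tf.data.experimental.from_variant(
#       variant, structure=ds.element_spec)
#   print(list(restored.as_numpy_iterator()))  # [0, 1, 2]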
4514@tf_export(
4515 "data.DatasetSpec",
4516 v1=["data.DatasetSpec", "data.experimental.DatasetStructure"])
4517class DatasetSpec(type_spec.BatchableTypeSpec):
4518 """Type specification for `tf.data.Dataset`.
4520 See `tf.TypeSpec` for more information about TensorFlow type specifications.
4522 >>> dataset = tf.data.Dataset.range(3)
4523 >>> tf.data.DatasetSpec.from_value(dataset)
4524 DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([]))
4525 """
4527 __slots__ = ["_element_spec", "_dataset_shape"]
4529 def __init__(self, element_spec, dataset_shape=()):
4530 self._element_spec = element_spec
4531 self._dataset_shape = tensor_shape.as_shape(dataset_shape)
4533 @property
4534 def value_type(self):
4535 return Dataset
4537 @property
4538 def element_spec(self):
4539 """The inner element spec."""
4540 return self._element_spec
4542 def is_subtype_of(self, other):
4543 """See base class."""
4544 if type(self) is not type(other):
4545 return False
4547 # TODO(b/220385675): _element_spec should always be a TypeSpec.
4548 try:
4549 tf_nest.assert_same_structure(self.element_spec, other.element_spec)
4550 except (TypeError, ValueError):
4551 return False
4553 self_elements = tf_nest.flatten(self.element_spec)
4554 other_elements = tf_nest.flatten(other.element_spec)
4556 def is_subtype_or_equal(a, b):
4557 if isinstance(a, trace.TraceType):
4558 return a.is_subtype_of(b)
4559 else:
4560 return a == b
4562 for self_element, other_element in zip(self_elements, other_elements):
4563 if not is_subtype_or_equal(self_element, other_element):
4564 return False
4566 return self._dataset_shape.is_subtype_of(other._dataset_shape) # pylint: disable=protected-access
4568 def most_specific_common_supertype(self, others):
4569 """See base class."""
4570 if not all(type(self) is type(other) for other in others):
4571 return None
4573 try:
4574 for other in others:
4575 tf_nest.assert_same_structure(self.element_spec, other.element_spec)
4576 except (TypeError, ValueError):
4577 return None
4579 self_components = tf_nest.flatten(self.element_spec)
4580 others_components = [
4581 tf_nest.flatten(other.element_spec) for other in others
4582 ]
4583 common_components = [None] * len(self_components)
4585 def common_supertype_or_equal(a, bs):
4586 if isinstance(a, trace.TraceType):
4587 return a.most_specific_common_supertype(bs)
4588 else:
4589 return a if all(a == b for b in bs) else None
4591 for i, self_component in enumerate(self_components):
4592 common_components[i] = common_supertype_or_equal(
4593 self_component,
4594 [other_components[i] for other_components in others_components])
4595 if self_component is not None and common_components[i] is None:
4596 return None
4597 common_element_spec = tf_nest.pack_sequence_as(self._element_spec,
4598 common_components)
4600 common_dataset_shape = self._dataset_shape.most_specific_common_supertype(
4601 [other._dataset_shape for other in others]) # pylint: disable=protected-access
4602 if common_dataset_shape is None:
4603 return None
4605 return DatasetSpec(common_element_spec, common_dataset_shape)
4607 # TODO(b/220385675): Once _element_spec is guaranteed to be TypeSpec, the
4608 # following functions do not need to be overloaded: is_subtype_of,
4609 # most_specific_common_supertype, __hash__ and __eq__
4610 def _serialize(self):
4611 return (self._element_spec, self._dataset_shape)
4613 @property
4614 def _component_specs(self):
4615 return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant)
4617 def _to_components(self, value):
4618 return value._variant_tensor # pylint: disable=protected-access
4620 def _from_components(self, components):
4621 # pylint: disable=protected-access
4622 if self._dataset_shape.ndims == 0:
4623 return _VariantDataset(components, self._element_spec)
4624 else:
4625 return _NestedVariant(components, self._element_spec, self._dataset_shape)
4627 def _to_tensor_list(self, value):
4628 return [
4629 ops.convert_to_tensor(
4630 tf_nest.map_structure(lambda x: x._variant_tensor, value)) # pylint: disable=protected-access
4631 ]
4633 @staticmethod
4634 def from_value(value):
4635 """Creates a `DatasetSpec` for the given `tf.data.Dataset` value."""
4636 return DatasetSpec(value.element_spec) # pylint: disable=protected-access
4638 def _batch(self, batch_size):
4639 return DatasetSpec(
4640 self._element_spec,
4641 tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
4643 def _unbatch(self):
4644 if self._dataset_shape.ndims == 0:
4645 raise ValueError("Slicing dataset elements is not supported for rank 0.")
4646 return DatasetSpec(self._element_spec, self._dataset_shape[1:])
4648 def _to_batched_tensor_list(self, value):
4649 if self._dataset_shape.ndims == 0:
4650 raise ValueError("Slicing dataset elements is not supported for rank 0.")
4651 return self._to_tensor_list(value)
4653 def _to_legacy_output_types(self):
4654 return self
4656 def _to_legacy_output_shapes(self):
4657 return self
4659 def _to_legacy_output_classes(self):
4660 return self
4662 def __hash__(self):
4663 # TODO(b/220385675): attributes can be dicts and hence unhashable.
4664 return hash(DatasetSpec)
4666 def __eq__(self, other):
4667 return (isinstance(other, DatasetSpec) and
4668 self._element_spec == other._element_spec and
4669 self._dataset_shape == other._dataset_shape)
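# Illustrative sketch (not part of this module): a `DatasetSpec` can serve as
# a `tf.function` input signature, so the traced function accepts any dataset
# whose element spec matches.
#
#   import tensorflow as tf
#   spec = tf.data.DatasetSpec(tf.TensorSpec(shape=(), dtype=tf.int64))
#
#   @tf.function(input_signature=[spec])
#   def cardinality(ds):
#     return ds.cardinality()
#
#   print(cardinality(tf.data.Dataset.range(5)).numpy())  # 5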
4672nested_structure_coder.register_codec(
4673 nested_structure_coder.BuiltInTypeSpecCodec(
4674 DatasetSpec, struct_pb2.TypeSpecProto.DATA_DATASET_SPEC
4675 )
4676)
4679class _NumpyIterator(tracking_base.Trackable):
4680 """Iterator over a dataset with elements converted to numpy."""
4682 __slots__ = ["_iterator"]
4684 def __init__(self, dataset):
4685 self._iterator = iter(dataset)
4687 def __iter__(self):
4688 return self
4690 def __next__(self):
4692 def to_numpy(x):
4693 numpy = x._numpy() # pylint: disable=protected-access
4694 if isinstance(numpy, np.ndarray):
4695 # `numpy` shares the same underlying buffer as the `x` Tensor.
4696 # Tensors are expected to be immutable, so we disable writes.
4697 numpy.setflags(write=False)
4698 return numpy
4700 return nest.map_structure(to_numpy, next(self._iterator))
4702 def next(self):
4703 return self.__next__()
4705 # override
4706 def _serialize_to_tensors(self):
4707 # pylint: disable=protected-access
4708 return self._iterator._serialize_to_tensors()
4710 # override
4711 def _restore_from_tensors(self, restored_tensors):
4712 # pylint: disable=protected-access
4713 return self._iterator._restore_from_tensors(restored_tensors)
4715 def _save(self):
4716 # pylint: disable=protected-access
4717 return self._iterator._save()
4719 def _restore(self, state):
4720 # pylint: disable=protected-access
4721 return self._iterator._restore(state)
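# Illustrative note (not part of this module): `_NumpyIterator` is the object
# handed back by `tf.data.Dataset.as_numpy_iterator` (defined earlier in this
# file) when running eagerly; the arrays it yields are marked read-only, as
# `to_numpy` above shows.
#
#   import tensorflow as tf
#   ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
#   for row in ds.as_numpy_iterator():
#     print(row)  # [1 2], then [3 4], each a read-only np.ndarray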
4724class _VariantTracker(resource_lib.CapturableResource):
4725 """Allows export of functions capturing a Dataset in SavedModels.
4727 When saving a SavedModel, `tf.saved_model.save` traverses the object
4728 graph. Since Datasets reference _VariantTracker objects, that traversal will
4729 find a _VariantTracker for each Dataset and so know how to save and restore
4730 functions which reference the Dataset's variant Tensor.
4731 """
4733 def __init__(self, variant_tensor, resource_creator):
4734 """Record that `variant_tensor` is associated with `resource_creator`.
4736 Args:
4737 variant_tensor: The variant-dtype Tensor associated with the Dataset. This
4738 Tensor will be a captured input to functions which use the Dataset, and
4739 is used by saving code to identify the corresponding _VariantTracker.
4740 resource_creator: A zero-argument function which creates a new
4741 variant-dtype Tensor. This function will be included in SavedModels and
4742 run to re-create the Dataset's variant Tensor on restore.
4743 """
4744 super(_VariantTracker, self).__init__(device="CPU")
4745 self._resource_handle = variant_tensor
4746 if not isinstance(resource_creator, def_function.Function):
4747 # Internal validation -- _VariantTracker assumes that resource creator is
4748 # already a tf.function.
4749 raise TypeError("Resource creator should already be a tf.function.")
4750 self._create_resource = resource_creator
4752 def _trackable_children(self,
4753 save_type=tracking_base.SaveType.CHECKPOINT,
4754 **kwargs):
4755 if save_type != tracking_base.SaveType.SAVEDMODEL:
4756 return {}
4758 children = super(_VariantTracker,
4759 self)._trackable_children(save_type, **kwargs)
4760 # Overwrite the _create_resource function, since `self._create_resource`
4761 # is already a tf.function.
4762 children["_create_resource"] = self._create_resource
4763 return children
4766# TODO(b/254291122): Remove.
4767# Loaded lazily due to a circular dependency (dataset_ops ->
4768# batch_op -> dataset_ops).
4769batch_op = lazy_loader.LazyLoader(
4770 "batch_op", globals(),
4771 "tensorflow.python.data.ops.batch_op")
4772BatchDataset = batch_op._BatchDataset # pylint: disable=protected-access
4773PrefetchDataset = prefetch_op._PrefetchDataset # pylint: disable=protected-access
4774ShuffleDataset = shuffle_op._ShuffleDataset # pylint: disable=protected-access
4777# TODO(b/254291122): Remove.
4778# Loaded lazily due to a circular dependency (dataset_ops ->
4779# repeat_op -> dataset_ops).
4780repeat_op = lazy_loader.LazyLoader(
4781 "repeat_op", globals(),
4782 "tensorflow.python.data.ops.repeat_op")
4783RepeatDataset = repeat_op._RepeatDataset # pylint: disable=protected-access
4786class _OptionsDataset(UnaryUnchangedStructureDataset):
4787 """An identity `Dataset` that stores options."""
4789 def __init__(self, input_dataset, options, name=None):
4790 # pylint: disable=protected-access
4791 self._input_dataset = input_dataset
4792 options_pb = dataset_options_pb2.Options()
4793 options_pb.CopyFrom(options._to_proto())
4794 self._name = name
4795 with ops.colocate_with(input_dataset._variant_tensor):
4796 variant_tensor = gen_dataset_ops.options_dataset(
4797 input_dataset._variant_tensor, options_pb.SerializeToString(),
4798 **self._common_args)
4799 super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
4801 if self._options_attr:
4802 self._options_attr._set_mutable(True)
4803 self._options_attr = self._options_attr.merge(options)
4804 else:
4805 self._options_attr = options
4806 self._options_attr._set_mutable(False)
4809def normalize_to_dense(dataset):
4810 """Normalizes non-tensor components in a dataset to dense representations.
4812 This is necessary for dataset transformations that slice along the batch
4813 dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`.
4815 Args:
4816 dataset: Dataset to normalize.
4818 Returns:
4819 A dataset whose sparse and ragged tensors have been normalized to their
4820 dense representations.
4821 """
4823 # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all
4824 # non-tensor components.
4825 #
4826 # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck.
4827 if structured_function._should_unpack(dataset.element_spec): # pylint: disable=protected-access
4829 def normalize(*args):
4830 return structure.to_batched_tensor_list(dataset.element_spec, tuple(args))
4831 else:
4832 def normalize(arg):
4833 return structure.to_batched_tensor_list(dataset.element_spec, arg)
4835 normalized_dataset = dataset.map(normalize)
4837 # NOTE(mrry): Our `map()` has lost information about the structure of
4838 # non-tensor components, so re-apply the structure of the original dataset.
4839 return _RestructuredDataset(normalized_dataset, dataset.element_spec)
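# Illustrative sketch (not part of this module): `normalize_to_dense` changes
# only the internal encoding; by construction above, the element spec of its
# output matches the input's.
#
#   import tensorflow as tf
#   ds = tf.data.Dataset.from_tensor_slices(tf.sparse.eye(4)).batch(2)
#   dense_ds = normalize_to_dense(ds)
#   assert dense_ds.element_spec == ds.element_spec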
4842class _RestructuredDataset(UnaryDataset):
4843 """An internal helper for changing the element spec of a dataset."""
4845 def __init__(self, dataset, element_spec):
4846 self._input_dataset = dataset
4847 self._element_spec = element_spec
4849 variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access
4850 super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
4852 @property
4853 def element_spec(self):
4854 return self._element_spec
4857def _get_prob_original_static(initial_dist_t, target_dist_t):
4858 """Returns the static probability of sampling from the original.
4860 `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
4861 an Op that it isn't defined for. We have some custom logic to avoid this.
4863 Args:
4864 initial_dist_t: A tensor of the initial distribution.
4865 target_dist_t: A tensor of the target distribution.
4867 Returns:
4868 The probability of sampling from the original distribution as a constant,
4869 if it is a constant, or `None`.
4870 """
4871 init_static = tensor_util.constant_value(initial_dist_t)
4872 target_static = tensor_util.constant_value(target_dist_t)
4874 if init_static is None or target_static is None:
4875 return None
4876 else:
4877 return np.min(target_static / init_static)
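# Worked example (illustrative): with a constant initial distribution
# [0.5, 0.5] and target distribution [0.8, 0.2], the element-wise ratios
# target / initial are [1.6, 0.4], so the function above returns
# min(1.6, 0.4) = 0.4 as the static probability of sampling from the original.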
4880def _filter_ds(dataset,
4881 acceptance_dist_ds,
4882 initial_dist_ds,
4883 class_func,
4884 seed,
4885 name=None):
4886 """Filters a dataset based on per-class acceptance probabilities.
4888 Args:
4889 dataset: The dataset to be filtered.
4890 acceptance_dist_ds: A dataset of acceptance probabilities.
4891 initial_dist_ds: A dataset of the initial probability distribution, given or
4892 estimated.
4893 class_func: A function mapping an element of the input dataset to a scalar
4894 `tf.int32` tensor. Values should be in `[0, num_classes)`.
4895 seed: (Optional.) Python integer seed for the resampler.
4896 name: (Optional.) A name for the tf.data operation.
4898 Returns:
4899 A dataset of (class value, data) after filtering.
4900 """
4902 def maybe_warn_on_large_rejection(accept_dist, initial_dist):
4903 proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
4904 return cond.cond(
4905 math_ops.less(proportion_rejected, .5),
4906 lambda: accept_dist,
4907 lambda: logging_ops.Print( # pylint: disable=g-long-lambda
4908 accept_dist, [proportion_rejected, initial_dist, accept_dist],
4909 message="Proportion of examples rejected by sampler is high: ",
4910 summarize=100,
4911 first_n=10))
4913 acceptance_dist_ds = (
4914 DatasetV2.zip((acceptance_dist_ds, initial_dist_ds),
4915 name=name).map(maybe_warn_on_large_rejection, name=name))
4917 def _gather_and_copy(acceptance_prob, data):
4918 if isinstance(data, tuple):
4919 class_val = class_func(*data)
4920 else:
4921 class_val = class_func(data)
4922 return class_val, array_ops.gather(acceptance_prob, class_val), data
4924 current_probabilities_and_class_and_data_ds = DatasetV2.zip(
4925 (acceptance_dist_ds, dataset), name=name).map(
4926 _gather_and_copy, name=name)
4928 def _reject(unused_class_val, p, unused_data):
4929 return random_ops.random_uniform([], seed=seed, dtype=p.dtype) < p
4931 filtered_ds = current_probabilities_and_class_and_data_ds.filter(
4932 _reject, name=name)
4933 return filtered_ds.map(
4934 lambda class_value, _, data: (class_value, data), name=name)
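# Illustrative sketch (not part of this module): per element, the pipeline
# above computes (class, acceptance_prob, data), keeps the element with
# probability acceptance_prob via a uniform draw, and then drops the
# probability column. The keep decision in isolation is roughly:
#
#   import tensorflow as tf
#   keep = tf.random.uniform([], seed=seed) < acceptance_prob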
4937# pylint: disable=missing-function-docstring
4938def _estimate_initial_dist_ds(target_dist_t,
4939 class_values_ds,
4940 dist_estimation_batch_size=32,
4941 smoothing_constant=10,
4942 name=None):
4943 num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0])
4944 initial_examples_per_class_seen = array_ops.fill([num_classes],
4945 np.int64(smoothing_constant))
4947 def update_estimate_and_tile(num_examples_per_class_seen, c):
4948 updated_examples_per_class_seen, dist = _estimate_data_distribution(
4949 c, num_examples_per_class_seen)
4950 tiled_dist = array_ops.tile(
4951 array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
4952 return updated_examples_per_class_seen, tiled_dist
4954 initial_dist_ds = (
4955 class_values_ds.batch(dist_estimation_batch_size, name=name).scan(
4956 initial_examples_per_class_seen, update_estimate_and_tile,
4957 name=name).unbatch(name=name))
4959 return initial_dist_ds
4962def _get_target_to_initial_ratio(initial_probs, target_probs):
4963 # Add tiny to initial_probs to avoid divide by zero.
4964 denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
4965 return target_probs / denom
4968def _estimate_data_distribution(c, num_examples_per_class_seen):
4969 """Estimate data distribution as labels are seen.
4971 Args:
4972 c: The class labels. Type `int32`, shape `[batch_size]`.
4973 num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing
4974 counts.
4976 Returns:
4977 num_examples_per_class_seen: Updated counts. Type `int64`, shape
4978 `[num_classes]`.
4979 dist: The updated distribution. Type `float32`, shape `[num_classes]`.
4980 """
4981 num_classes = num_examples_per_class_seen.get_shape()[0]
4982 # Update the class-count based on what labels are seen in batch.
4983 num_examples_per_class_seen = math_ops.add(
4984 num_examples_per_class_seen,
4985 math_ops.reduce_sum(
4986 array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
4987 init_prob_estimate = math_ops.truediv(
4988 num_examples_per_class_seen,
4989 math_ops.reduce_sum(num_examples_per_class_seen))
4990 dist = math_ops.cast(init_prob_estimate, dtypes.float32)
4991 return num_examples_per_class_seen, dist
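# Worked example (illustrative): with 2 classes, smoothed starting counts
# [10, 10], and a batch of labels c = [0, 0, 1], the one-hot sum is [2, 1],
# the updated counts are [12, 11], and the returned distribution is
# [12/23, 11/23] ~= [0.522, 0.478].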
4994def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
4995 """Calculates the acceptance probabilities and mixing ratio.
4997 In this case, we assume that we can *either* sample from the original data
4998 distribution with probability `m`, or sample from a reshaped distribution
4999 that comes from rejection sampling on the original distribution. This
5000 rejection sampling is done on a per-class basis, with `a_i` representing the
5001 probability of accepting data from class `i`.
5003 This method is based on the following analysis of the reshaped
5004 distribution:
5006 Let F be the probability of a rejection (on any example).
5007 Let p_i be the proportion of examples in the data in class i (init_probs)
5008 Let a_i be the rate at which the rejection sampler should *accept* class i
5009 Let t_i be the target proportion in the minibatches for class i (target_probs)
5011 ```
5012 F = sum_i(p_i * (1-a_i))
5013 = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
5014 ```
5016 An example with class `i` will be accepted if `k` rejections occur, then an
5017 example with class `i` is seen by the rejector, and it is accepted. This can
5018 be written as follows:
5020 ```
5021 t_i = sum_k=0^inf(F^k * p_i * a_i)
5022 = p_i * a_i / (1 - F) using geometric series identity, since 0 <= F < 1
5023 = p_i * a_i / sum_j(p_j * a_j) using F from above
5024 ```
5026 Note that the following constraints hold:
5027 ```
5028 0 <= p_i <= 1, sum_i(p_i) = 1
5029 0 <= a_i <= 1
5030 0 <= t_i <= 1, sum_i(t_i) = 1
5031 ```
5033 A solution for a_i in terms of the other variables is the following:
5034 ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
5036 If we try to minimize the amount of data rejected, we get the following:
5038 M_max = max_i [ t_i / p_i ]
5039 M_min = min_i [ t_i / p_i ]
5041 The desired probability of accepting data if it comes from class `i`:
5043 a_i = (t_i/p_i - m) / (M_max - m)
5045 The desired probability of pulling a data element from the original dataset,
5046 rather than the filtered one:
5048 m = M_min
5050 Args:
5051 initial_probs: A Tensor of the initial probability distribution, given or
5052 estimated.
5053 target_probs: A Tensor of the target probability distribution over the classes.
5055 Returns:
5056 A 2-tuple of (a 1D Tensor with the per-class acceptance probabilities, the
5057 desired probability of pulling from the original distribution).
5058 """
5059 ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
5060 max_ratio = math_ops.reduce_max(ratio_l)
5061 min_ratio = math_ops.reduce_min(ratio_l)
5063 # Target prob to sample from original distribution.
5064 m = min_ratio
5066 # TODO(joelshor): Simplify fraction, if possible.
5067 a_i = (ratio_l - m) / (max_ratio - m)
5068 return a_i, m
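# Worked example (illustrative): with initial_probs p = [0.9, 0.1] and
# target_probs t = [0.5, 0.5], the ratios t / p are [0.556, 5.0], so
# m = 0.556 and M_max = 5.0. The acceptance probabilities are
# a = (t/p - m) / (M_max - m) = [0.0, 1.0]: class 1 is always accepted by the
# rejection sampler, while class 0 (already over-represented) only reaches the
# output through the "sample from the original with probability m" branch.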
5071def _apply_rewrite(dataset, rewrite):
5072 # pylint: disable=protected-access
5073 return _VariantDataset(
5074 gen_dataset_ops.rewrite_dataset(dataset._variant_tensor, rewrite,
5075 **dataset._flat_structure),
5076 dataset.element_spec)
5079def _collect_resource_inputs(op):
5080 """Collects resource inputs for the given ops (and its variant inputs)."""
5082 def _process(op_queue, seen_ops):
5083 """Processes the next element of the op queue.
5085 Args:
5086 op_queue: Queue of Dataset operations to process.
5087 seen_ops: Already processed set of Operations.
5089 Returns:
5090 A 2-tuple containing sets of resource handles. The first tuple entry
5091 contains read-only handles and the second entry contains read-write
5092 handles.
5093 """
5095 reads = []
5096 writes = []
5097 op = op_queue.pop()
5098 if op in seen_ops:
5099 return reads, writes
5100 seen_ops.add(op)
5101 # TODO(b/150139257): All resource inputs are in writes right now since we
5102 # have not updated the functional ops to set the special attribute that ACD
5103 # uses to figure out which of the op's inputs are read-only.
5104 reads, writes = acd_utils.get_read_write_resource_inputs(op)
5105 # Conservatively assume that any variant inputs are datasets.
5106 op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant)
5107 return reads, writes
5109 op_queue = [op]
5110 seen_ops = set()
5111 all_reads = []
5112 all_writes = []
5113 while op_queue:
5114 reads, writes = _process(op_queue, seen_ops)
5115 all_reads.extend(reads)
5116 all_writes.extend(writes)
5118 return all_reads, all_writes
5121@auto_control_deps.register_acd_resource_resolver
5122def _resource_resolver(op, resource_reads, resource_writes):
5123 """Updates resource inputs for tf.data ops with indirect dependencies."""
5125 updated = False
5126 if op.type in [
5127 "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
5128 ]:
5129 reads, writes = _collect_resource_inputs(op)
5130 for inp in reads:
5131 if inp not in resource_reads:
5132 updated = True
5133 resource_reads.add(inp)
5134 for inp in writes:
5135 if inp not in resource_writes:
5136 updated = True
5137 resource_writes.add(inp)
5139 if op.type in [
5140 "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
5141 ]:
5142 iterator_resource = op.inputs[0]
5143 make_iterator_ops = [
5144 op for op in iterator_resource.consumers() if op.type == "MakeIterator"
5145 ]
5147 if len(make_iterator_ops) == 1:
5148 reads, writes = _collect_resource_inputs(make_iterator_ops[0])
5149 for inp in reads:
5150 if inp not in resource_reads:
5151 updated = True
5152 resource_reads.add(inp)
5153 for inp in writes:
5154 if inp not in resource_writes:
5155 updated = True
5156 resource_writes.add(inp)
5158 return updated
5161dataset_autograph.register_overrides()