Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py: 43%
291 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Iterators."""
16import abc
17import threading
18import warnings
20from tensorflow.core.protobuf import struct_pb2
21from tensorflow.python.checkpoint import saveable_compat
22from tensorflow.python.data.ops import iterator_autograph
23from tensorflow.python.data.ops import optional_ops
24from tensorflow.python.data.ops import options as options_lib
25from tensorflow.python.data.util import nest
26from tensorflow.python.data.util import structure
27from tensorflow.python.eager import context
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_spec
34from tensorflow.python.framework import type_spec
35from tensorflow.python.framework import type_utils
36from tensorflow.python.ops import gen_dataset_ops
37from tensorflow.python.ops import parsing_ops
38from tensorflow.python.saved_model import nested_structure_coder
39from tensorflow.python.trackable import base as trackable
40from tensorflow.python.training.saver import BaseSaverBuilder
41from tensorflow.python.util import _pywrap_utils
42from tensorflow.python.util import deprecation
43from tensorflow.python.util import lazy_loader
44from tensorflow.python.util.compat import collections_abc
45from tensorflow.python.util.tf_export import tf_export
# NOTE(mrry): Calling `Iterator.get_next()` more than once is legitimate,
# e.g. when distributing different elements to multiple devices in a single
# step. The common pitfall is calling `Iterator.get_next()` inside a training
# loop: each call adds ops to the graph, and executing each op allocates
# resources (including threads), so per-iteration calls cause gradual
# slowdown and eventual resource exhaustion. We guard against that outcome
# by logging a warning once the call count crosses this threshold.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called inside "
    "a training loop, which will cause gradual slowdown and eventual resource "
    "exhaustion. If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, and use "
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")

# NOTE(jsimsa): Heuristic threshold used to detect an infinite loop during
# tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls was "
    "detected. This suggests that the `for elem in dataset: ...` idiom is used "
    "within tf.function with AutoGraph disabled. This idiom is only supported "
    "when AutoGraph is enabled.")

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"
# Loaded lazily to avoid a circular import with the AutoGraph package at
# module-load time.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
def _device_stack_is_empty():
  """Returns True if no device has been explicitly requested by the caller."""
  if context.executing_eagerly():
    # In eager mode the "device stack" is just the context's current device.
    return context.context().device_name is None
  # pylint: disable=protected-access
  device_stack = ops.get_default_graph()._device_functions_outer_to_inner
  # pylint: enable=protected-access
  return not device_stack
@saveable_compat.legacy_saveable_name("ITERATOR")
@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A (nested) structure of `tf.DType` objects corresponding
        to each component of an element of this iterator.
      output_shapes: A (nested) structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator.

    Raises:
      ValueError: If any of `output_types`, `output_shapes`, or
        `output_classes` is not specified.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if output_types is None or output_shapes is None or output_classes is None:
      raise ValueError(
          "All of `output_types`, `output_shapes`, and `output_classes` "
          "must be specified to create an iterator. Got "
          f"`output_types` = {output_types!r}, "
          f"`output_shapes` = {output_shapes!r}, "
          f"`output_classes` = {output_classes!r}.")
    # Unify the three legacy structures into a single element spec, and cache
    # the flat shapes/types used by the `iterator_get_next` op.
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    self._flat_tensor_types = structure.get_flat_tensor_types(
        self._element_spec)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
148 @staticmethod
149 def from_structure(output_types,
150 output_shapes=None,
151 shared_name=None,
152 output_classes=None):
153 """Creates a new, uninitialized `Iterator` with the given structure.
155 This iterator-constructing method can be used to create an iterator that
156 is reusable with many different datasets.
158 The returned iterator is not bound to a particular dataset, and it has
159 no `initializer`. To initialize the iterator, run the operation returned by
160 `Iterator.make_initializer(dataset)`.
162 The following is an example
164 ```python
165 iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))
167 dataset_range = Dataset.range(10)
168 range_initializer = iterator.make_initializer(dataset_range)
170 dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
171 evens_initializer = iterator.make_initializer(dataset_evens)
173 # Define a model based on the iterator; in this example, the model_fn
174 # is expected to take scalar tf.int64 Tensors as input (see
175 # the definition of 'iterator' above).
176 prediction, loss = model_fn(iterator.get_next())
178 # Train for `num_epochs`, where for each epoch, we first iterate over
179 # dataset_range, and then iterate over dataset_evens.
180 for _ in range(num_epochs):
181 # Initialize the iterator to `dataset_range`
182 sess.run(range_initializer)
183 while True:
184 try:
185 pred, loss_val = sess.run([prediction, loss])
186 except tf.errors.OutOfRangeError:
187 break
189 # Initialize the iterator to `dataset_evens`
190 sess.run(evens_initializer)
191 while True:
192 try:
193 pred, loss_val = sess.run([prediction, loss])
194 except tf.errors.OutOfRangeError:
195 break
196 ```
198 Args:
199 output_types: A (nested) structure of `tf.DType` objects corresponding to
200 each component of an element of this dataset.
201 output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
202 objects corresponding to each component of an element of this dataset.
203 If omitted, each component will have an unconstrainted shape.
204 shared_name: (Optional.) If non-empty, this iterator will be shared under
205 the given name across multiple sessions that share the same devices
206 (e.g. when using a remote server).
207 output_classes: (Optional.) A (nested) structure of Python `type` objects
208 corresponding to each component of an element of this iterator. If
209 omitted, each component is assumed to be of type `tf.Tensor`.
211 Returns:
212 An `Iterator`.
214 Raises:
215 TypeError: If the structures of `output_shapes` and `output_types` are
216 not the same.
217 """
218 output_types = nest.map_structure(dtypes.as_dtype, output_types)
219 if output_shapes is None:
220 output_shapes = nest.map_structure(
221 lambda _: tensor_shape.TensorShape(None), output_types)
222 else:
223 output_shapes = nest.map_structure_up_to(output_types,
224 tensor_shape.as_shape,
225 output_shapes)
226 if output_classes is None:
227 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
228 nest.assert_same_structure(output_types, output_shapes)
229 output_structure = structure.convert_legacy_structure(
230 output_types, output_shapes, output_classes)
231 if shared_name is None:
232 shared_name = ""
233 iterator_resource = gen_dataset_ops.iterator_v2(
234 container="",
235 shared_name=shared_name,
236 output_types=structure.get_flat_tensor_types(output_structure),
237 output_shapes=structure.get_flat_tensor_shapes(
238 output_structure))
239 return Iterator(iterator_resource, None, output_types, output_shapes,
240 output_classes)
242 @staticmethod
243 def from_string_handle(string_handle,
244 output_types,
245 output_shapes=None,
246 output_classes=None):
247 """Creates a new, uninitialized `Iterator` based on the given handle.
249 This method allows you to define a "feedable" iterator where you can choose
250 between concrete iterators by feeding a value in a `tf.Session.run` call.
251 In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and you
252 would
253 feed it with the value of `tf.data.Iterator.string_handle` in each step.
255 For example, if you had two iterators that marked the current position in
256 a training dataset and a test dataset, you could choose which to use in
257 each step as follows:
259 ```python
260 train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
261 train_iterator_handle = sess.run(train_iterator.string_handle())
263 test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
264 test_iterator_handle = sess.run(test_iterator.string_handle())
266 handle = tf.compat.v1.placeholder(tf.string, shape=[])
267 iterator = tf.data.Iterator.from_string_handle(
268 handle, train_iterator.output_types)
270 next_element = iterator.get_next()
271 loss = f(next_element)
273 train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
274 test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
275 ```
277 Args:
278 string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
279 a handle produced by the `Iterator.string_handle()` method.
280 output_types: A (nested) structure of `tf.DType` objects corresponding to
281 each component of an element of this dataset.
282 output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
283 objects corresponding to each component of an element of this dataset.
284 If omitted, each component will have an unconstrainted shape.
285 output_classes: (Optional.) A (nested) structure of Python `type` objects
286 corresponding to each component of an element of this iterator. If
287 omitted, each component is assumed to be of type `tf.Tensor`.
289 Returns:
290 An `Iterator`.
291 """
292 output_types = nest.map_structure(dtypes.as_dtype, output_types)
293 if output_shapes is None:
294 output_shapes = nest.map_structure(
295 lambda _: tensor_shape.TensorShape(None), output_types)
296 else:
297 output_shapes = nest.map_structure_up_to(output_types,
298 tensor_shape.as_shape,
299 output_shapes)
300 if output_classes is None:
301 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
302 nest.assert_same_structure(output_types, output_shapes)
303 output_structure = structure.convert_legacy_structure(
304 output_types, output_shapes, output_classes)
305 string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
306 iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
307 string_handle,
308 output_types=structure.get_flat_tensor_types(output_structure),
309 output_shapes=structure.get_flat_tensor_shapes(output_structure))
310 return Iterator(iterator_resource, None, output_types, output_shapes,
311 output_classes)
313 @property
314 def initializer(self):
315 """A `tf.Operation` that should be run to initialize this iterator.
317 Returns:
318 A `tf.Operation` that should be run to initialize this iterator
320 Raises:
321 ValueError: If this iterator initializes itself automatically.
322 """
323 if self._initializer is not None:
324 return self._initializer
325 else:
326 # TODO(mrry): Consider whether one-shot iterators should have
327 # initializers that simply reset their state to the beginning.
328 raise ValueError(
329 "The iterator does not have an initializer. This means it was likely "
330 "created using `tf.data.Dataset.make_one_shot_iterator()`. For an "
331 "initializable iterator, use "
332 "`tf.data.Dataset.make_initializable_iterator()` instead.")
334 def make_initializer(self, dataset, name=None):
335 """Returns a `tf.Operation` that initializes this iterator on `dataset`.
337 Args:
338 dataset: A `Dataset` whose `element_spec` if compatible with this
339 iterator.
340 name: (Optional.) A name for the created operation.
342 Returns:
343 A `tf.Operation` that can be run to initialize this iterator on the given
344 `dataset`.
346 Raises:
347 TypeError: If `dataset` and this iterator do not have a compatible
348 `element_spec`.
349 """
350 with ops.name_scope(name, "make_initializer") as name:
351 # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
352 # to that creating a circular dependency.
353 # pylint: disable=protected-access
354 dataset_output_types = nest.map_structure(
355 lambda component_spec: component_spec._to_legacy_output_types(),
356 dataset.element_spec)
357 dataset_output_shapes = nest.map_structure(
358 lambda component_spec: component_spec._to_legacy_output_shapes(),
359 dataset.element_spec)
360 dataset_output_classes = nest.map_structure(
361 lambda component_spec: component_spec._to_legacy_output_classes(),
362 dataset.element_spec)
363 # pylint: enable=protected-access
365 nest.assert_same_structure(self.output_types, dataset_output_types)
366 nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
367 for iterator_class, dataset_class in zip(
368 nest.flatten(self.output_classes),
369 nest.flatten(dataset_output_classes)):
370 if iterator_class is not dataset_class:
371 raise TypeError(
372 f"Expected output classes {self.output_classes!r} but got "
373 f"dataset with output classes {dataset_output_classes!r}.")
374 for iterator_dtype, dataset_dtype in zip(
375 nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
376 if iterator_dtype != dataset_dtype:
377 raise TypeError(
378 f"Expected output types {self.output_types!r} but got dataset "
379 f"with output types {dataset_output_types!r}.")
380 for iterator_shape, dataset_shape in zip(
381 nest.flatten(self.output_shapes), nest.flatten(
382 dataset_output_shapes)):
383 if not iterator_shape.is_compatible_with(dataset_shape):
384 raise TypeError(
385 f"Expected output shapes compatible with {self.output_shapes!r} "
386 f"but got dataset with output shapes {dataset_output_shapes!r}.")
388 # TODO(b/169442955): Investigate the need for this colocation constraint.
389 with ops.colocate_with(self._iterator_resource):
390 # pylint: disable=protected-access
391 return gen_dataset_ops.make_iterator(
392 dataset._variant_tensor, self._iterator_resource, name=name)
394 def get_next(self, name=None):
395 """Returns the next element.
397 In graph mode, you should typically call this method *once* and use its
398 result as the input to another computation. A typical loop will then call
399 `tf.Session.run` on the result of that computation. The loop will terminate
400 when the `Iterator.get_next()` operation raises
401 `tf.errors.OutOfRangeError`. The following skeleton shows how to use
402 this method when building a training loop:
404 ```python
405 dataset = ... # A `tf.data.Dataset` object.
406 iterator = dataset.make_initializable_iterator()
407 next_element = iterator.get_next()
409 # Build a TensorFlow graph that does something with each element.
410 loss = model_function(next_element)
411 optimizer = ... # A `tf.compat.v1.train.Optimizer` object.
412 train_op = optimizer.minimize(loss)
414 with tf.compat.v1.Session() as sess:
415 try:
416 while True:
417 sess.run(train_op)
418 except tf.errors.OutOfRangeError:
419 pass
420 ```
422 NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
423 when you are distributing different elements to multiple devices in a single
424 step. However, a common pitfall arises when users call `Iterator.get_next()`
425 in each iteration of their training loop. `Iterator.get_next()` adds ops to
426 the graph, and executing each op allocates resources (including threads); as
427 a consequence, invoking it in every iteration of a training loop causes
428 slowdown and eventual resource exhaustion. To guard against this outcome, we
429 log a warning when the number of uses crosses a fixed threshold of
430 suspiciousness.
432 Args:
433 name: (Optional.) A name for the created operation.
435 Returns:
436 A (nested) structure of values matching `tf.data.Iterator.element_spec`.
437 """
438 self._get_next_call_count += 1
439 if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
440 warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
442 # TODO(b/169442955): Investigate the need for this colocation constraint.
443 with ops.colocate_with(self._iterator_resource):
444 # pylint: disable=protected-access
445 flat_ret = gen_dataset_ops.iterator_get_next(
446 self._iterator_resource,
447 output_types=self._flat_tensor_types,
448 output_shapes=self._flat_tensor_shapes,
449 name=name)
450 return structure.from_tensor_list(self._element_spec, flat_ret)
452 def get_next_as_optional(self):
453 # TODO(b/169442955): Investigate the need for this colocation constraint.
454 with ops.colocate_with(self._iterator_resource):
455 # pylint: disable=protected-access
456 return optional_ops._OptionalImpl(
457 gen_dataset_ops.iterator_get_next_as_optional(
458 self._iterator_resource,
459 output_types=structure.get_flat_tensor_types(self.element_spec),
460 output_shapes=structure.get_flat_tensor_shapes(
461 self.element_spec)), self.element_spec)
463 def string_handle(self, name=None):
464 """Returns a string-valued `tf.Tensor` that represents this iterator.
466 Args:
467 name: (Optional.) A name for the created operation.
469 Returns:
470 A scalar `tf.Tensor` of type `tf.string`.
471 """
472 if name is None:
473 return self._string_handle
474 else:
475 return gen_dataset_ops.iterator_to_string_handle(
476 self._iterator_resource, name=name)
478 @property
479 @deprecation.deprecated(
480 None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
481 def output_classes(self):
482 """Returns the class of each component of an element of this iterator.
484 The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.
486 Returns:
487 A (nested) structure of Python `type` objects corresponding to each
488 component of an element of this dataset.
489 """
490 return nest.map_structure(
491 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
492 self._element_spec)
494 @property
495 @deprecation.deprecated(
496 None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
497 def output_shapes(self):
498 """Returns the shape of each component of an element of this iterator.
500 Returns:
501 A (nested) structure of `tf.TensorShape` objects corresponding to each
502 component of an element of this dataset.
503 """
504 return nest.map_structure(
505 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
506 self._element_spec)
508 @property
509 @deprecation.deprecated(
510 None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
511 def output_types(self):
512 """Returns the type of each component of an element of this iterator.
514 Returns:
515 A (nested) structure of `tf.DType` objects corresponding to each component
516 of an element of this dataset.
517 """
518 return nest.map_structure(
519 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
520 self._element_spec)
522 @property
523 def element_spec(self):
524 """The type specification of an element of this iterator.
526 For more information,
527 read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).
529 Returns:
530 A (nested) structure of `tf.TypeSpec` objects matching the structure of an
531 element of this iterator and specifying the type of individual components.
532 """
534 return self._element_spec
536 # override
537 def _serialize_to_tensors(self):
538 serialized_iterator = gen_dataset_ops.serialize_iterator(
539 self._iterator_resource,
540 options_lib.ExternalStatePolicy.FAIL.value)
541 return {"_STATE": serialized_iterator}
543 # override
544 def _restore_from_tensors(self, restored_tensors):
545 with ops.colocate_with(self._iterator_resource):
546 return [gen_dataset_ops.deserialize_iterator(
547 self._iterator_resource, restored_tensors["_STATE"])]
550_uid_counter = 0
551_uid_lock = threading.Lock()
554def _generate_shared_name(prefix):
555 with _uid_lock:
556 global _uid_counter
557 uid = _uid_counter
558 _uid_counter += 1
559 return "{}{}".format(prefix, uid)
@tf_export("data.Iterator", v1=[])
class IteratorBase(
    collections_abc.Iterator,
    trackable.Trackable,
    composite_tensor.CompositeTensor,
    metaclass=abc.ABCMeta):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  # NOTE: `abc.abstractproperty` is deprecated (since Python 3.3); the
  # equivalent modern spelling is `@property` stacked on `@abc.abstractmethod`.
  @property
  @abc.abstractmethod
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of
      an element of this iterator, specifying the type of individual
      components.
    """
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been
        reached.
    """
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError("Iterator.get_next_as_optional()")
@saveable_compat.legacy_saveable_name("ITERATOR")
class OwnedIterator(IteratorBase):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset.

  The iterator resource created through `OwnedIterator` is owned by the Python
  object and the life time of the underlying resource is tied to the life time
  of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for
  use in eager mode and inside of tf.functions.
  """

  def __init__(self, dataset=None, components=None, element_spec=None):
    """Creates a new iterator from the given dataset.

    If `dataset` is not specified, the iterator is created from the given
    tensor components and element structure. This alternative construction
    path is used when the iterator is reconstructed from its
    `CompositeTensor` representation.

    Args:
      dataset: A `tf.data.Dataset` object.
      components: Tensor components to construct the iterator from.
      element_spec: A (nested) structure of `TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      ValueError: If `dataset` is not provided and either `components` or
        `element_spec` is not provided. Or `dataset` is provided and either
        `components` and `element_spec` is provided.
    """
    super().__init__()

    if dataset is None:
      # Reconstruction path: both pieces of the composite representation are
      # required.
      if components is None or element_spec is None:
        raise ValueError(
            "When `dataset` is not provided, both `components` and "
            "`element_spec` must be specified.")
      # pylint: disable=protected-access
      self._element_spec = element_spec
      self._flat_output_types = structure.get_flat_tensor_types(
          self._element_spec)
      self._flat_output_shapes = structure.get_flat_tensor_shapes(
          self._element_spec)
      # `components` holds exactly one tensor: the iterator resource.
      self._iterator_resource, = components
    else:
      if components is not None or element_spec is not None:
        raise ValueError(
            "When `dataset` is provided, `element_spec` and `components` must "
            "not be specified.")
      self._create_iterator(dataset)

    self._get_next_call_count = 0
714 def _create_iterator(self, dataset):
715 # pylint: disable=protected-access
716 dataset = dataset._apply_debug_options()
718 # Store dataset reference to ensure that dataset is alive when this iterator
719 # is being used. For example, `tf.data.Dataset.from_generator` registers
720 # a few py_funcs that are needed in `self._next_internal`. If the dataset
721 # is deleted, this iterator crashes on `self.__next__(...)` call.
722 self._dataset = dataset
724 ds_variant = dataset._variant_tensor
725 self._element_spec = dataset.element_spec
726 self._flat_output_types = structure.get_flat_tensor_types(
727 self._element_spec)
728 self._flat_output_shapes = structure.get_flat_tensor_shapes(
729 self._element_spec)
730 with ops.colocate_with(ds_variant):
731 self._iterator_resource = (
732 gen_dataset_ops.anonymous_iterator_v3(
733 output_types=self._flat_output_types,
734 output_shapes=self._flat_output_shapes))
735 if not context.executing_eagerly():
736 # Add full type information to the graph so host memory types inside
737 # variants stay on CPU, e.g, ragged string tensors.
738 # TODO(b/224776031) Remove this when AnonymousIterateV3 can use
739 # (reverse) type inference and all other ops that are needed to
740 # provide type information to the AnonymousIterateV3 also support
741 # type inference (esp. cross-function type inference) instead of
742 # setting the full type information manually.
743 fulltype = type_utils.iterator_full_type_from_spec(
744 self._element_spec)
745 # fulltype is PRODUCT[ITERATOR[PRODUCT[...]]]
746 assert len(fulltype.args[0].args[0].args) == len(
747 self._flat_output_types)
748 self._iterator_resource.op.experimental_set_type(fulltype)
749 gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
751 def __iter__(self):
752 return self
754 def next(self): # For Python 2 compatibility
755 return self.__next__()
757 def _next_internal(self):
758 autograph_status = autograph_ctx.control_status_ctx().status
759 autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
760 if not context.executing_eagerly() and autograph_disabled:
761 self._get_next_call_count += 1
762 if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
763 raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)
765 if not context.executing_eagerly():
766 # TODO(b/169442955): Investigate the need for this colocation constraint.
767 with ops.colocate_with(self._iterator_resource):
768 ret = gen_dataset_ops.iterator_get_next(
769 self._iterator_resource,
770 output_types=self._flat_output_types,
771 output_shapes=self._flat_output_shapes)
772 return structure.from_compatible_tensor_list(self._element_spec, ret)
774 # TODO(b/77291417): This runs in sync mode as iterators use an error status
775 # to communicate that there is no more data to iterate over.
776 with context.execution_mode(context.SYNC):
777 ret = gen_dataset_ops.iterator_get_next(
778 self._iterator_resource,
779 output_types=self._flat_output_types,
780 output_shapes=self._flat_output_shapes)
782 try:
783 # Fast path for the case `self._structure` is not a nested structure.
784 return self._element_spec._from_compatible_tensor_list(ret) # pylint: disable=protected-access
785 except AttributeError:
786 return structure.from_compatible_tensor_list(self._element_spec, ret)
788 def _save(self):
789 external_state_policy = None
790 if (
791 self._dataset
792 and self._dataset.options().experimental_external_state_policy
793 ):
794 external_state_policy = (
795 self._dataset.options().experimental_external_state_policy.value
796 )
797 state_variant = gen_dataset_ops.serialize_iterator(
798 self._iterator_resource, external_state_policy
799 )
800 return parsing_ops.serialize_tensor(state_variant)
802 def _restore(self, state):
803 state_variant = parsing_ops.parse_tensor(state, dtypes.variant)
804 return gen_dataset_ops.deserialize_iterator(
805 self._iterator_resource, state_variant
806 )
808 @property
809 def _type_spec(self):
810 return IteratorSpec(self.element_spec)
812 def __next__(self):
813 try:
814 return self._next_internal()
815 except errors.OutOfRangeError:
816 raise StopIteration
818 @property
819 @deprecation.deprecated(
820 None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
821 def output_classes(self):
822 """Returns the class of each component of an element of this iterator.
824 The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.
826 Returns:
827 A (nested) structure of Python `type` objects corresponding to each
828 component of an element of this dataset.
829 """
830 return nest.map_structure(
831 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
832 self._element_spec)
834 @property
835 @deprecation.deprecated(
836 None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
837 def output_shapes(self):
838 """Returns the shape of each component of an element of this iterator.
840 Returns:
841 A (nested) structure of `tf.TensorShape` objects corresponding to each
842 component of an element of this dataset.
843 """
844 return nest.map_structure(
845 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
846 self._element_spec)
848 @property
849 @deprecation.deprecated(
850 None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
851 def output_types(self):
852 """Returns the type of each component of an element of this iterator.
854 Returns:
855 A (nested) structure of `tf.DType` objects corresponding to each component
856 of an element of this dataset.
857 """
858 return nest.map_structure(
859 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
860 self._element_spec)
862 @property
863 def element_spec(self):
864 return self._element_spec
866 def get_next(self):
867 return self._next_internal()
869 def get_next_as_optional(self):
870 # TODO(b/169442955): Investigate the need for this colocation constraint.
871 with ops.colocate_with(self._iterator_resource):
872 # pylint: disable=protected-access
873 return optional_ops._OptionalImpl(
874 gen_dataset_ops.iterator_get_next_as_optional(
875 self._iterator_resource,
876 output_types=structure.get_flat_tensor_types(self.element_spec),
877 output_shapes=structure.get_flat_tensor_shapes(
878 self.element_spec)), self.element_spec)
880 def _serialize_to_tensors(self):
881 serialized_iterator = None
882 if (self._dataset and
883 self._dataset.options().experimental_external_state_policy):
884 serialized_iterator = gen_dataset_ops.serialize_iterator(
885 self._iterator_resource,
886 self._dataset.options().experimental_external_state_policy.value)
887 else:
888 serialized_iterator = gen_dataset_ops.serialize_iterator(
889 self._iterator_resource,
890 options_lib.ExternalStatePolicy.FAIL.value)
891 return {"_STATE": serialized_iterator}
893 def _restore_from_tensors(self, restored_tensors):
894 with ops.colocate_with(self._iterator_resource):
895 return [gen_dataset_ops.deserialize_iterator(
896 self._iterator_resource, restored_tensors["_STATE"])]
898 def __tf_tracing_type__(self, _):
899 return self._type_spec
@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
  takes `tf.data.Iterator` as an input argument:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `tf.TypeSpec` objects that represents
      the type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    self._element_spec = element_spec

  @property
  def value_type(self):
    # Values described by this spec are `OwnedIterator` instances.
    return OwnedIterator

  def _serialize(self):
    # The element spec is the only state needed to reconstruct this spec.
    return (self._element_spec,)

  @property
  def _component_specs(self):
    # An iterator decomposes into a single scalar resource tensor.
    return (tensor_spec.TensorSpec([], dtypes.resource),)

  def _to_components(self, value):
    return (value._iterator_resource,)  # pylint: disable=protected-access

  def _from_components(self, components):
    # Rebuild the iterator from its resource handle; no dataset is attached.
    return OwnedIterator(
        dataset=None, components=components, element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    return IteratorSpec(value.element_spec)  # pylint: disable=protected-access
# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(self,
               iterator_resource,
               name,
               external_state_policy=options_lib.ExternalStatePolicy.FAIL):
    # Serialize eagerly so the save spec captures the iterator state tensor.
    serialized = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    save_spec = BaseSaverBuilder.SaveSpec(
        serialized,
        "",
        name + "_STATE",
        device=iterator_resource.device)
    super().__init__(iterator_resource, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
# Register a built-in codec so `IteratorSpec` can be encoded to / decoded from
# a `TypeSpecProto` by the nested structure coder.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        IteratorSpec, struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC
    )
)
@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  # Deprecated free function; delegates to the iterator's own method.
  return iterator.get_next_as_optional()
# Register `OwnedIterator` with the C++ type registry and install the
# autograph overrides that make iterators usable in converted control flow.
_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator)
iterator_autograph.register_overrides()