Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py: 43%

291 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python wrappers for Iterators.""" 

16import abc 

17import threading 

18import warnings 

19 

20from tensorflow.core.protobuf import struct_pb2 

21from tensorflow.python.checkpoint import saveable_compat 

22from tensorflow.python.data.ops import iterator_autograph 

23from tensorflow.python.data.ops import optional_ops 

24from tensorflow.python.data.ops import options as options_lib 

25from tensorflow.python.data.util import nest 

26from tensorflow.python.data.util import structure 

27from tensorflow.python.eager import context 

28from tensorflow.python.framework import composite_tensor 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import errors 

31from tensorflow.python.framework import ops 

32from tensorflow.python.framework import tensor_shape 

33from tensorflow.python.framework import tensor_spec 

34from tensorflow.python.framework import type_spec 

35from tensorflow.python.framework import type_utils 

36from tensorflow.python.ops import gen_dataset_ops 

37from tensorflow.python.ops import parsing_ops 

38from tensorflow.python.saved_model import nested_structure_coder 

39from tensorflow.python.trackable import base as trackable 

40from tensorflow.python.training.saver import BaseSaverBuilder 

41from tensorflow.python.util import _pywrap_utils 

42from tensorflow.python.util import deprecation 

43from tensorflow.python.util import lazy_loader 

44from tensorflow.python.util.compat import collections_abc 

45from tensorflow.python.util.tf_export import tf_export 

46 

47 

# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple
# times, e.g. when you are distributing different elements to multiple
# devices in a single step. However, a common pitfall arises when
# users call `Iterator.get_next()` in each iteration of their training
# loop. `Iterator.get_next()` adds ops to the graph, and executing
# each op allocates resources (including threads); as a consequence,
# invoking it in every iteration of a training loop causes slowdown
# and eventual resource exhaustion. To guard against this outcome, we
# log a warning when the number of uses crosses a threshold of suspicion.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called inside "
    "a training loop, which will cause gradual slowdown and eventual resource "
    "exhaustion. If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, and use "
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")

# NOTE(jsimsa): Threshold used as a heuristic to check for infinite loop during
# tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls was "
    "detected. This suggests that the `for elem in dataset: ...` idiom is used "
    "within tf.function with AutoGraph disabled. This idiom is only supported "
    "when AutoGraph is enabled.")

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"


# NOTE(review): loaded lazily, presumably to break a circular import between
# tf.data and autograph -- confirm before changing to a direct import.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")

85 

86 

def _device_stack_is_empty():
  """Returns True if no device placement is currently in effect."""
  if not context.executing_eagerly():
    # In graph mode, inspect the default graph's stack of device functions.
    # pylint: disable=protected-access
    graph_device_fns = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access
    return not graph_device_fns
  # In eager mode, a `None` device name means no device was requested.
  return context.context().device_name is None

94 

95 

@saveable_compat.legacy_saveable_name("ITERATOR")
@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A (nested) structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator.

    Raises:
      ValueError: If `output_types`, `output_shapes`, or `output_classes` is
        not specified.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if (output_types is None or output_shapes is None
        or output_classes is None):
      raise ValueError(
          "All of `output_types`, `output_shapes`, and `output_classes` "
          "must be specified to create an iterator. Got "
          f"`output_types` = {output_types!r}, "
          f"`output_shapes` = {output_shapes!r}, "
          f"`output_classes` = {output_classes!r}.")
    # Convert the legacy (types, shapes, classes) triple into a single nested
    # structure of `tf.TypeSpec` objects, and cache the flat representations
    # used by `get_next()`.
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    self._flat_tensor_types = structure.get_flat_tensor_types(
        self._element_spec)

    # Created eagerly so `string_handle()` can return a cached tensor when no
    # custom op name is requested.
    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    # Counts `get_next()` calls to warn about the common misuse of calling it
    # inside a training loop (see GET_NEXT_CALL_WARNING_MESSAGE).
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      # Default every component to a fully-unknown shape.
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(
            output_structure))
    # `initializer` is None: the caller must bind a dataset later via
    # `make_initializer()`.
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @staticmethod
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and
    you would feed it with the value of `tf.data.Iterator.string_handle` in
    each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
        a handle produced by the `Iterator.string_handle()` method.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      # Default every component to a fully-unknown shape.
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(output_structure))
    # `initializer` is None: the underlying iterator is initialized elsewhere
    # (by whichever iterator produced the handle).
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.

    Returns:
      A `tf.Operation` that should be run to initialize this iterator

    Raises:
      ValueError: If this iterator initializes itself automatically.
    """
    if self._initializer is not None:
      return self._initializer
    else:
      # TODO(mrry): Consider whether one-shot iterators should have
      # initializers that simply reset their state to the beginning.
      raise ValueError(
          "The iterator does not have an initializer. This means it was likely "
          "created using `tf.data.Dataset.make_one_shot_iterator()`. For an "
          "initializable iterator, use "
          "`tf.data.Dataset.make_initializable_iterator()` instead.")

  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` whose `element_spec` is compatible with this
        iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        `element_spec`.
    """
    with ops.name_scope(name, "make_initializer") as name:
      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
      # to that creating a circular dependency.
      # pylint: disable=protected-access
      dataset_output_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),
          dataset.element_spec)
      dataset_output_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),
          dataset.element_spec)
      dataset_output_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),
          dataset.element_spec)
      # pylint: enable=protected-access

      # Validate compatibility of the dataset's element spec with this
      # iterator: structures must match, classes and dtypes must be equal,
      # and shapes need only be compatible (not identical).
      nest.assert_same_structure(self.output_types, dataset_output_types)
      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
      for iterator_class, dataset_class in zip(
          nest.flatten(self.output_classes),
          nest.flatten(dataset_output_classes)):
        if iterator_class is not dataset_class:
          raise TypeError(
              f"Expected output classes {self.output_classes!r} but got "
              f"dataset with output classes {dataset_output_classes!r}.")
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              f"Expected output types {self.output_types!r} but got dataset "
              f"with output types {dataset_output_types!r}.")
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self.output_shapes), nest.flatten(
              dataset_output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError(
              f"Expected output shapes compatible with {self.output_shapes!r} "
              f"but got dataset with output shapes {dataset_output_shapes!r}.")

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          dataset._variant_tensor, self._iterator_resource, name=name)

  def get_next(self, name=None):
    """Returns the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    `tf.Session.run` on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.compat.v1.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.
    """
    # Each call adds new ops to the graph; warn when the count suggests the
    # method is being (mis)used inside a training loop.
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      flat_ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_tensor_types,
          output_shapes=self._flat_tensor_shapes,
          name=name)
      # Repack the flat list of tensors into the iterator's nested structure.
      return structure.from_tensor_list(self._element_spec, flat_ret)

  def get_next_as_optional(self):
    """Returns the next element wrapped in a `tf.experimental.Optional`."""
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    if name is None:
      # Reuse the handle created in `__init__` rather than adding a new op.
      return self._string_handle
    else:
      return gen_dataset_ops.iterator_to_string_handle(
          self._iterator_resource, name=name)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A (nested) structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator and specifying the type of individual components.
    """

    return self._element_spec

  # override
  def _serialize_to_tensors(self):
    # Serializes the iterator state for checkpointing. The FAIL policy is
    # passed for external state handling; presumably serialization fails if
    # the input pipeline contains external state -- TODO confirm.
    serialized_iterator = gen_dataset_ops.serialize_iterator(
        self._iterator_resource,
        options_lib.ExternalStatePolicy.FAIL.value)
    return {"_STATE": serialized_iterator}

  # override
  def _restore_from_tensors(self, restored_tensors):
    # Restores iterator state from the "_STATE" tensor produced by
    # `_serialize_to_tensors`.
    with ops.colocate_with(self._iterator_resource):
      return [gen_dataset_ops.deserialize_iterator(
          self._iterator_resource, restored_tensors["_STATE"])]

548 

549 

550_uid_counter = 0 

551_uid_lock = threading.Lock() 

552 

553 

554def _generate_shared_name(prefix): 

555 with _uid_lock: 

556 global _uid_counter 

557 uid = _uid_counter 

558 _uid_counter += 1 

559 return "{}{}".format(prefix, uid) 

560 

561 

@tf_export("data.Iterator", v1=[])
class IteratorBase(
    collections_abc.Iterator,
    trackable.Trackable,
    composite_tensor.CompositeTensor,
    metaclass=abc.ABCMeta):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  # NOTE: `abc.abstractproperty` is deprecated (since Python 3.3, removed in
  # 3.12); stacking `property` over `abc.abstractmethod` is the documented
  # equivalent.
  @property
  @abc.abstractmethod
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError("Iterator.get_next_as_optional()")

660 

661 

662@saveable_compat.legacy_saveable_name("ITERATOR") 

663class OwnedIterator(IteratorBase): 

664 """An iterator producing tf.Tensor objects from a tf.data.Dataset. 

665 

666 The iterator resource created through `OwnedIterator` is owned by the Python 

667 object and the life time of the underlying resource is tied to the life time 

668 of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use 

669 in eager mode and inside of tf.functions. 

670 """ 

671 

672 def __init__(self, dataset=None, components=None, element_spec=None): 

673 """Creates a new iterator from the given dataset. 

674 

675 If `dataset` is not specified, the iterator will be created from the given 

676 tensor components and element structure. In particular, the alternative for 

677 constructing the iterator is used when the iterator is reconstructed from 

678 it `CompositeTensor` representation. 

679 

680 Args: 

681 dataset: A `tf.data.Dataset` object. 

682 components: Tensor components to construct the iterator from. 

683 element_spec: A (nested) structure of `TypeSpec` objects that 

684 represents the type specification of elements of the iterator. 

685 

686 Raises: 

687 ValueError: If `dataset` is not provided and either `components` or 

688 `element_spec` is not provided. Or `dataset` is provided and either 

689 `components` and `element_spec` is provided. 

690 """ 

691 super(OwnedIterator, self).__init__() 

692 

693 if dataset is None: 

694 if (components is None or element_spec is None): 

695 raise ValueError( 

696 "When `dataset` is not provided, both `components` and " 

697 "`element_spec` must be specified.") 

698 # pylint: disable=protected-access 

699 self._element_spec = element_spec 

700 self._flat_output_types = structure.get_flat_tensor_types( 

701 self._element_spec) 

702 self._flat_output_shapes = structure.get_flat_tensor_shapes( 

703 self._element_spec) 

704 self._iterator_resource, = components 

705 else: 

706 if (components is not None or element_spec is not None): 

707 raise ValueError( 

708 "When `dataset` is provided, `element_spec` and `components` must " 

709 "not be specified.") 

710 self._create_iterator(dataset) 

711 

712 self._get_next_call_count = 0 

713 

  def _create_iterator(self, dataset):
    """Creates an anonymous iterator resource and initializes it on `dataset`."""
    # pylint: disable=protected-access
    dataset = dataset._apply_debug_options()

    # Store dataset reference to ensure that dataset is alive when this iterator
    # is being used. For example, `tf.data.Dataset.from_generator` registers
    # a few py_funcs that are needed in `self._next_internal`. If the dataset
    # is deleted, this iterator crashes on `self.__next__(...)` call.
    self._dataset = dataset

    ds_variant = dataset._variant_tensor
    self._element_spec = dataset.element_spec
    # Flat type/shape lists are precomputed here so `get_next` calls do not
    # re-derive them on every element.
    self._flat_output_types = structure.get_flat_tensor_types(
        self._element_spec)
    self._flat_output_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    # Colocate the iterator resource with the dataset variant tensor.
    with ops.colocate_with(ds_variant):
      self._iterator_resource = (
          gen_dataset_ops.anonymous_iterator_v3(
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      if not context.executing_eagerly():
        # Add full type information to the graph so host memory types inside
        # variants stay on CPU, e.g, ragged string tensors.
        # TODO(b/224776031) Remove this when AnonymousIterateV3 can use
        # (reverse) type inference and all other ops that are needed to
        # provide type information to the AnonymousIterateV3 also support
        # type inference (esp. cross-function type inference) instead of
        # setting the full type information manually.
        fulltype = type_utils.iterator_full_type_from_spec(
            self._element_spec)
        # fulltype is PRODUCT[ITERATOR[PRODUCT[...]]]
        assert len(fulltype.args[0].args[0].args) == len(
            self._flat_output_types)
        self._iterator_resource.op.experimental_set_type(fulltype)
      # Bind the dataset to the freshly created iterator resource.
      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)

750 

  def __iter__(self):
    # An `OwnedIterator` is itself an iterator, so `iter()` is the identity.
    return self

753 

  def next(self):  # For Python 2 compatibility
    # Alias of `__next__`; Python 2's iterator protocol used `next()`.
    return self.__next__()

756 

def _next_internal(self):
  """Returns a (nested) structure of tensors for the next element.

  In graph mode the result is built from `IteratorGetNext` ops; in eager mode
  the op executes synchronously so that exhaustion can be signaled via an
  out-of-range status.

  Raises:
    ValueError: in graph mode with autograph disabled, after the per-iterator
      call count exceeds `GET_NEXT_CALL_ERROR_THRESHOLD` (presumably a guard
      against accidentally building one graph op per Python loop step —
      threshold/message constants are defined elsewhere in this module).
  """
  autograph_status = autograph_ctx.control_status_ctx().status
  autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
  # Count graph-mode calls only when autograph is off; error out past the
  # threshold rather than silently growing the graph.
  if not context.executing_eagerly() and autograph_disabled:
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
      raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)

  if not context.executing_eagerly():
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
    return structure.from_compatible_tensor_list(self._element_spec, ret)

  # TODO(b/77291417): This runs in sync mode as iterators use an error status
  # to communicate that there is no more data to iterate over.
  with context.execution_mode(context.SYNC):
    ret = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)

  try:
    # Fast path for the case `self._structure` is not a nested structure.
    return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
  except AttributeError:
    # Spec has no fast-path method; fall back to the generic reconstruction.
    return structure.from_compatible_tensor_list(self._element_spec, ret)

787 

def _save(self):
  """Serializes the iterator state into a string tensor for saving."""
  # Resolve the external-state policy from the dataset's options, if any;
  # otherwise pass None through to the serialize op.
  external_state_policy = None
  if self._dataset:
    policy = self._dataset.options().experimental_external_state_policy
    if policy:
      external_state_policy = policy.value
  state_variant = gen_dataset_ops.serialize_iterator(
      self._iterator_resource, external_state_policy)
  return parsing_ops.serialize_tensor(state_variant)

801 

def _restore(self, state):
  """Restores iterator state from a string tensor produced by `_save`."""
  deserialized_state = parsing_ops.parse_tensor(state, dtypes.variant)
  return gen_dataset_ops.deserialize_iterator(self._iterator_resource,
                                              deserialized_state)

807 

@property
def _type_spec(self):
  """The `IteratorSpec` describing this iterator's elements."""
  spec = self.element_spec
  return IteratorSpec(spec)

811 

def __next__(self):
  """Returns the next element, mapping end-of-sequence to `StopIteration`."""
  try:
    element = self._next_internal()
  except errors.OutOfRangeError:
    raise StopIteration
  return element

817 

@property
@deprecation.deprecated(
    None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
def output_classes(self):
  """Returns the class of each component of an element of this iterator.

  The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

  Returns:
    A (nested) structure of Python `type` objects, one per component of an
    element of this dataset.
  """
  def _legacy_classes(component_spec):
    return component_spec._to_legacy_output_classes()  # pylint: disable=protected-access

  return nest.map_structure(_legacy_classes, self._element_spec)

833 

@property
@deprecation.deprecated(
    None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
def output_shapes(self):
  """Returns the shape of each component of an element of this iterator.

  Returns:
    A (nested) structure of `tf.TensorShape` objects, one per component of an
    element of this dataset.
  """
  def _legacy_shapes(component_spec):
    return component_spec._to_legacy_output_shapes()  # pylint: disable=protected-access

  return nest.map_structure(_legacy_shapes, self._element_spec)

847 

@property
@deprecation.deprecated(
    None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
def output_types(self):
  """Returns the type of each component of an element of this iterator.

  Returns:
    A (nested) structure of `tf.DType` objects, one per component of an
    element of this dataset.
  """
  def _legacy_types(component_spec):
    return component_spec._to_legacy_output_types()  # pylint: disable=protected-access

  return nest.map_structure(_legacy_types, self._element_spec)

861 

@property
def element_spec(self):
  """The (nested) `tf.TypeSpec` structure of this iterator's elements."""
  return self._element_spec

865 

def get_next(self):
  """Returns the next element of the iterator (raises OutOfRangeError at end)."""
  return self._next_internal()

868 

def get_next_as_optional(self):
  """Returns a `tf.experimental.Optional` holding the next element, if any."""
  spec = self.element_spec
  # TODO(b/169442955): Investigate the need for this colocation constraint.
  with ops.colocate_with(self._iterator_resource):
    # pylint: disable=protected-access
    optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
        self._iterator_resource,
        output_types=structure.get_flat_tensor_types(spec),
        output_shapes=structure.get_flat_tensor_shapes(spec))
    return optional_ops._OptionalImpl(optional_variant, spec)

879 

def _serialize_to_tensors(self):
  """Returns a single-entry dict with the serialized iterator state."""
  # Default to FAIL (refuse to serialize external state) unless the dataset's
  # options specify an explicit external-state policy.
  policy_value = options_lib.ExternalStatePolicy.FAIL.value
  if self._dataset:
    dataset_policy = self._dataset.options().experimental_external_state_policy
    if dataset_policy:
      policy_value = dataset_policy.value
  serialized_iterator = gen_dataset_ops.serialize_iterator(
      self._iterator_resource, policy_value)
  return {"_STATE": serialized_iterator}

892 

def _restore_from_tensors(self, restored_tensors):
  """Restores iterator state saved by `_serialize_to_tensors`."""
  with ops.colocate_with(self._iterator_resource):
    restore_op = gen_dataset_ops.deserialize_iterator(
        self._iterator_resource, restored_tensors["_STATE"])
  return [restore_op]

897 

def __tf_tracing_type__(self, _):
  """Returns the TraceType (the iterator's TypeSpec) used by tf.function."""
  return self._type_spec

900 

901 

@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
  takes `tf.data.Iterator` as an input argument:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `tf.TypeSpec` objects that represents
      the type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    self._element_spec = element_spec

  @property
  def value_type(self):
    # Values described by this spec are `OwnedIterator` instances.
    return OwnedIterator

  def _serialize(self):
    # The element spec alone determines this TypeSpec's serialized form.
    return (self._element_spec,)

  @property
  def _component_specs(self):
    # An iterator is represented by a single scalar resource tensor.
    return (tensor_spec.TensorSpec([], dtypes.resource),)

  def _to_components(self, value):
    # Decompose an OwnedIterator into its underlying resource tensor.
    return (value._iterator_resource,)  # pylint: disable=protected-access

  def _from_components(self, components):
    # Rebuild an OwnedIterator around an existing resource tensor; no dataset
    # is attached in this path.
    return OwnedIterator(
        dataset=None,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # Builds the spec describing an existing iterator instance.
    return IteratorSpec(value.element_spec)  # pylint: disable=protected-access

952 

953 

# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=options_lib.ExternalStatePolicy.FAIL):
    # Serialize the iterator's state into a single variant tensor, saved
    # under the "<name>_STATE" key on the resource's device.
    state_tensor = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    save_spec = BaseSaverBuilder.SaveSpec(
        state_tensor,
        "",
        name + "_STATE",
        device=iterator_resource.device)
    super().__init__(iterator_resource, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # Unused.
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])

977 

978 

# Register a codec so IteratorSpec can be encoded to / decoded from its
# TypeSpecProto representation by the nested structure coder.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        IteratorSpec, struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC
    )
)

984 

985 

@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  # Deprecated free-function form; simply delegate to the iterator's method.
  return iterator.get_next_as_optional()

1003 

1004 

# Register OwnedIterator with the pywrap type registry under its class name.
_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator)
# Install the autograph overrides provided by iterator_autograph (enables
# iterating tf.data iterators inside autograph-converted code).
iterator_autograph.register_overrides()