Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py: 35%

1070 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python wrappers for Datasets.""" 

16import abc 

17import functools 

18import queue 

19import threading 

20import warnings 

21 

22import numpy as np 

23 

24from tensorflow.core.framework import dataset_metadata_pb2 

25from tensorflow.core.framework import dataset_options_pb2 

26from tensorflow.core.framework import graph_pb2 

27from tensorflow.core.protobuf import struct_pb2 

28from tensorflow.python import tf2 

29from tensorflow.python.data.ops import dataset_autograph 

30from tensorflow.python.data.ops import debug_mode 

31from tensorflow.python.data.ops import iterator_ops 

32from tensorflow.python.data.ops import options as options_lib 

33from tensorflow.python.data.ops import structured_function 

34from tensorflow.python.data.util import nest 

35from tensorflow.python.data.util import structure 

36from tensorflow.python.data.util import traverse 

37from tensorflow.python.eager import context 

38from tensorflow.python.framework import auto_control_deps 

39from tensorflow.python.framework import auto_control_deps_utils as acd_utils 

40from tensorflow.python.framework import composite_tensor 

41from tensorflow.python.framework import constant_op 

42from tensorflow.python.framework import dtypes 

43from tensorflow.python.framework import function 

44from tensorflow.python.framework import ops 

45from tensorflow.python.framework import random_seed as core_random_seed 

46from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 

47from tensorflow.python.framework import tensor_shape 

48from tensorflow.python.framework import tensor_spec 

49from tensorflow.python.framework import tensor_util 

50from tensorflow.python.framework import type_spec 

51from tensorflow.python.ops import array_ops 

52from tensorflow.python.ops import check_ops 

53from tensorflow.python.ops import cond 

54from tensorflow.python.ops import control_flow_assert 

55from tensorflow.python.ops import gen_dataset_ops 

56from tensorflow.python.ops import gen_io_ops 

57from tensorflow.python.ops import gen_parsing_ops 

58from tensorflow.python.ops import logging_ops 

59from tensorflow.python.ops import math_ops 

60from tensorflow.python.ops import random_ops 

61from tensorflow.python.ops import string_ops 

62from tensorflow.python.ops.ragged import ragged_tensor 

63from tensorflow.python.saved_model import nested_structure_coder 

64from tensorflow.python.trackable import asset 

65from tensorflow.python.trackable import base as tracking_base 

66from tensorflow.python.trackable import resource as resource_lib 

67from tensorflow.python.types import data as data_types 

68from tensorflow.python.types import trace 

69from tensorflow.python.util import deprecation 

70from tensorflow.python.util import lazy_loader 

71from tensorflow.python.util import nest as tf_nest 

72from tensorflow.python.util.compat import collections_abc 

73from tensorflow.python.util.tf_export import tf_export 

74 

75# Symbols forwarded for legacy access through dataset_ops.py. These forwarded 

76# symbols can be removed once all internal uses are updated. 

77StructuredFunctionWrapper = structured_function.StructuredFunctionWrapper 

78 

79# Loaded lazily due to a circular dependency (roughly 

80# tf.function->wrap_function->dataset->autograph->tf.function). 

81# TODO(b/133251390): Use a regular import. 

82wrap_function = lazy_loader.LazyLoader( 

83 "wrap_function", globals(), 

84 "tensorflow.python.eager.wrap_function") 

85# Loaded lazily due to a circular dependency 

86# dataset_ops->def_function->func_graph->autograph->dataset_ops 

87# TODO(kathywu): Use a regular import. 

88def_function = lazy_loader.LazyLoader( 

89 "def_function", globals(), 

90 "tensorflow.python.eager.def_function") 

91# TODO(b/240947712): Clean up the circular dependencies. 

92# Loaded lazily due to a circular dependency (dataset_ops -> 

93# prefetch_op -> dataset_ops). 

94prefetch_op = lazy_loader.LazyLoader( 

95 "prefetch_op", globals(), 

96 "tensorflow.python.data.ops.prefetch_op") 

97# Loaded lazily due to a circular dependency (dataset_ops -> 

98# shuffle_op -> dataset_ops). 

99shuffle_op = lazy_loader.LazyLoader( 

100 "shuffle_op", globals(), 

101 "tensorflow.python.data.ops.shuffle_op") 

102 

103 

104ops.NotDifferentiable("ReduceDataset") 

105 

106# A constant that can be used to enable auto-tuning. 

107AUTOTUNE = -1 

108tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE") 

109# TODO(b/168128531): Deprecate and remove this symbol. 

110tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE") 

111 

112# Constants representing infinite and unknown cardinalities. 

113INFINITE = -1 

114UNKNOWN = -2 

115COMPRESSION_GZIP = "GZIP" 

116COMPRESSION_SNAPPY = "NONE" 

117DATASET_SPEC_FILENAME = "dataset_spec.pb" 

118tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE") 

119tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN") 

120 

121 

122def _validate_and_encode(name): 

123 if not name.isidentifier(): 

124 raise ValueError("Invalid `name`. The argument `name` needs to be a valid " 

125 "identifier. Value is considered a valid identifier if it " 

126 "only contains alphanumeric characters (a-z), (A-Z), and " 

127 "(0-9), or underscores (_). A valid identifier cannot " 

128 "start with a number, or contain any spaces.") 

129 return name.encode("utf-8") 

130 

131 

132def get_type(value): 

133 """Returns the type of `value` if it is a TypeSpec.""" 

134 

135 if isinstance(value, type_spec.TypeSpec): 

136 return value.value_type() 

137 else: 

138 return type(value) 

139 

140 

141@tf_export("data.Dataset", v1=[]) 

142class DatasetV2( 

143 collections_abc.Iterable, 

144 tracking_base.Trackable, 

145 composite_tensor.CompositeTensor, 

146 data_types.DatasetV2, 

147 metaclass=abc.ABCMeta): 

148 """Represents a potentially large set of elements. 

149 

150 The `tf.data.Dataset` API supports writing descriptive and efficient input 

151 pipelines. `Dataset` usage follows a common pattern: 

152 

153 1. Create a source dataset from your input data. 

154 2. Apply dataset transformations to preprocess the data. 

155 3. Iterate over the dataset and process the elements. 

156 

157 Iteration happens in a streaming fashion, so the full dataset does not need to 

158 fit into memory. 

159 

160 Source Datasets: 

161 

162 The simplest way to create a dataset is to create it from a python `list`: 

163 

164 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

165 >>> for element in dataset: 

166 ... print(element) 

167 tf.Tensor(1, shape=(), dtype=int32) 

168 tf.Tensor(2, shape=(), dtype=int32) 

169 tf.Tensor(3, shape=(), dtype=int32) 

170 

171 To process lines from files, use `tf.data.TextLineDataset`: 

172 

173 >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"]) 

174 

175 To process records written in the `TFRecord` format, use `TFRecordDataset`: 

176 

177 >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"]) 

178 

179 To create a dataset of all files matching a pattern, use 

180 `tf.data.Dataset.list_files`: 

181 

182 ```python 

183 dataset = tf.data.Dataset.list_files("/path/*.txt") 

184 ``` 

185 

186 See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator` 

187 for more ways to create datasets. 

188 

189 Transformations: 

190 

191 Once you have a dataset, you can apply transformations to prepare the data for 

192 your model: 

193 

194 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

195 >>> dataset = dataset.map(lambda x: x*2) 

196 >>> list(dataset.as_numpy_iterator()) 

197 [2, 4, 6] 

198 

199 Common Terms: 

200 

201 **Element**: A single output from calling `next()` on a dataset iterator. 

202 Elements may be nested structures containing multiple components. For 

203 example, the element `(1, (3, "apple"))` has one tuple nested in another 

204 tuple. The components are `1`, `3`, and `"apple"`. 

205 

206 **Component**: The leaf in the nested structure of an element. 

207 

208 Supported types: 

209 

210 Elements can be nested structures of tuples, named tuples, and dictionaries. 

211 Note that Python lists are *not* treated as nested structures of components. 

212 Instead, lists are converted to tensors and treated as components. For 

213 example, the element `(1, [1, 2, 3])` has only two components; the tensor `1` 

214 and the tensor `[1, 2, 3]`. Element components can be of any type 

215 representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`, 

216 `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`. 

217 

218 ```python 

219 a = 1 # Integer element 

220 b = 2.0 # Float element 

221 c = (1, 2) # Tuple element with 2 components 

222 d = {"a": (2, 2), "b": 3} # Dict element with 3 components 

223 Point = collections.namedtuple("Point", ["x", "y"]) 

224 e = Point(1, 2) # Named tuple 

225 f = tf.data.Dataset.range(10) # Dataset element 

226 ``` 

227 

228 For more information, 

229 read [this guide](https://www.tensorflow.org/guide/data). 

230 """ 

231 

232 def __init__(self, variant_tensor): 

233 """Creates a DatasetV2 object. 

234 

235 This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not 

236 take anything in its constructor whereas in the DatasetV2, we expect 

237 subclasses to create a variant_tensor and pass it in to the super() call. 

238 

239 Args: 

240 variant_tensor: A DT_VARIANT tensor that represents the dataset. 

241 """ 

242 self._variant_tensor_attr = variant_tensor 

243 self._graph_attr = ops.get_default_graph() 

244 

245 # Initialize the options for this dataset and its inputs. 

246 self._options_attr = options_lib.Options() 

247 for input_dataset in self._inputs(): 

248 input_options = None 

249 if isinstance(input_dataset, data_types.DatasetV1): 

250 # If the V1 dataset does not have the `_dataset` attribute, we assume it 

251 # is a dataset source and hence does not have options. Otherwise, we 

252 # grab the options of `_dataset` object 

253 if hasattr(input_dataset, "_dataset"): 

254 if not isinstance(input_dataset._dataset, data_types.DatasetV2): 

255 raise TypeError( 

256 f"Each input of dataset {type(self)} should be a subclass of " 

257 f"`tf.data.Dataset` but encountered " 

258 f"{type(input_dataset._dataset)}.") 

259 input_options = input_dataset._dataset._options_attr 

260 elif isinstance(input_dataset, data_types.DatasetV2): 

261 input_options = input_dataset._options_attr 

262 else: 

263 raise TypeError( 

264 f"Each input of dataset {type(self)} should be a subclass of " 

265 f"`tf.data.Dataset` but encountered {type(input_dataset)}.") 

266 if input_options is not None: 

267 self._options_attr = self._options_attr.merge(input_options) 

268 self._options_attr._set_mutable(False) # pylint: disable=protected-access 

269 

270 @property 

271 def _variant_tensor(self): 

272 return self._variant_tensor_attr 

273 

274 @_variant_tensor.setter 

275 def _variant_tensor(self, _): 

276 raise ValueError("The `_variant_tensor` property cannot be modified.") 

277 

278 @deprecation.deprecated_args(None, "Use external_state_policy instead", 

279 "allow_stateful") 

280 def _as_serialized_graph( 

281 self, 

282 allow_stateful=None, 

283 strip_device_assignment=None, 

284 external_state_policy=options_lib.ExternalStatePolicy.WARN): 

285 """Produces serialized graph representation of the dataset. 

286 

287 Args: 

288 allow_stateful: If true, we allow stateful ops to be present in the graph 

289 def. In that case, the state in these ops would be thrown away. 

290 strip_device_assignment: If true, non-local (i.e. job and task) device 

291 assignment is stripped from ops in the serialized graph. 

292 external_state_policy: The ExternalStatePolicy enum that determines how we 

293 handle input pipelines that depend on external state. By default, its 

294 set to WARN. 

295 

296 Returns: 

297 A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a 

298 serialized graph. 

299 """ 

300 if external_state_policy: 

301 policy = external_state_policy.value 

302 return gen_dataset_ops.dataset_to_graph_v2( 

303 self._variant_tensor, 

304 external_state_policy=policy, 

305 strip_device_assignment=strip_device_assignment) 

306 if strip_device_assignment: 

307 return gen_dataset_ops.dataset_to_graph( 

308 self._variant_tensor, 

309 allow_stateful=allow_stateful, 

310 strip_device_assignment=strip_device_assignment) 

311 return gen_dataset_ops.dataset_to_graph( 

312 self._variant_tensor, allow_stateful=allow_stateful) 

313 

314 def _maybe_track_assets(self, graph_def): 

315 """Finds and tracks nodes in `graph_def` that refer to asset files. 

316 

317 Args: 

318 graph_def: Serialized graph representation of this dataset. 

319 

320 Returns: 

321 A dictionary mapping the node name of an asset constant to a tracked 

322 `asset.Asset` object. 

323 """ 

324 asset_tracker = {} 

325 for node in graph_def.node: 

326 if node.name.startswith("FileIdentity"): 

327 asset_tracker[node.input[0]] = None 

328 

329 if not asset_tracker: 

330 return {} 

331 

332 for node in graph_def.node: 

333 if node.name in asset_tracker: 

334 tensor_proto = node.attr["value"].tensor 

335 with context.eager_mode(), ops.device("CPU"): 

336 node_value = gen_parsing_ops.parse_tensor( 

337 tensor_proto.SerializeToString(), dtypes.string).numpy() 

338 asset_tracker[node.name] = ([ 

339 self._track_trackable(asset.Asset(n), 

340 name=node.name + "_" + str(i), overwrite=True) 

341 for i, n in enumerate(node_value) 

342 ]) 

343 return asset_tracker 

344 

345 def _trackable_children(self, 

346 save_type=tracking_base.SaveType.CHECKPOINT, 

347 **kwargs): 

348 if save_type != tracking_base.SaveType.SAVEDMODEL: 

349 return {} 

350 

351 # _trace_variant_creation only works when executing eagerly, so we don't 

352 # want to run it in the object initialization. 

353 @def_function.function(input_signature=[], autograph=False) 

354 def _creator(): 

355 resource = self._trace_variant_creation()() # pylint: disable=protected-access 

356 return resource 

357 _creator.get_concrete_function() # Trigger asset tracking 

358 

359 children = super(DatasetV2, self)._trackable_children(save_type, **kwargs) 

360 children["_variant_tracker"] = _VariantTracker(self._variant_tensor, 

361 _creator) 

362 return children 

363 

364 def _trace_variant_creation(self): 

365 """Traces a function which outputs a variant `tf.Tensor` for this dataset. 

366 

367 Note that creating this function involves evaluating an op, and is currently 

368 only supported when executing eagerly. 

369 

370 Returns: 

371 A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`. 

372 """ 

373 variant = self._variant_tensor 

374 if not isinstance(variant, ops.EagerTensor): 

375 raise NotImplementedError( 

376 "Constructing a tf.function that reproduces a given dataset is only " 

377 "supported for datasets created eagerly. Please file a feature " 

378 "request if this is important to you.") 

379 with context.eager_mode(), ops.device("CPU"): 

380 # pylint: disable=protected-access 

381 graph_def = graph_pb2.GraphDef().FromString( 

382 self._as_serialized_graph(external_state_policy=options_lib 

383 .ExternalStatePolicy.FAIL).numpy()) 

384 output_node_names = [] 

385 for node in graph_def.node: 

386 if node.op == "_Retval": 

387 output_node_names = node.input 

388 

389 if len(output_node_names) != 1: 

390 raise AssertionError( 

391 f"Dataset graph is expected to only have one return value but found " 

392 f"{len(output_node_names)} return values: {output_node_names}.") 

393 

394 output_node_name = output_node_names[0] 

395 

396 file_path_nodes = {} 

397 # When building a tf.function, track files as `saved_model.Asset`s. 

398 if ops.get_default_graph().building_function: 

399 asset_tracker = self._maybe_track_assets(graph_def) 

400 for key in asset_tracker: 

401 assets_list = [ 

402 array_ops.expand_dims(asset.asset_path, axis=0) 

403 for asset in asset_tracker[key] 

404 ] 

405 file_path_nodes[key] = array_ops.concat(assets_list, axis=0) 

406 

407 # Add functions used in this Dataset to the function's graph, since they 

408 # need to follow it around (and for example be added to a SavedModel which 

409 # references the dataset). 

410 variant_function = wrap_function.function_from_graph_def( 

411 graph_def, 

412 inputs=[], 

413 outputs=output_node_name + ":0", 

414 captures=file_path_nodes) 

415 for used_function in self._functions(): 

416 used_function.function.add_to_graph(variant_function.graph) 

417 return variant_function 

418 

419 @abc.abstractmethod 

420 def _inputs(self): 

421 """Returns a list of the input datasets of the dataset.""" 

422 

423 raise NotImplementedError(f"{type(self)}._inputs()") 

424 

425 @property 

426 def _graph(self): 

427 return self._graph_attr 

428 

429 @_graph.setter 

430 def _graph(self, _): 

431 raise ValueError("The `_graph` property cannot be modified.") 

432 

433 # TODO(jsimsa): Change this to be the transitive closure of functions used 

434 # by this dataset and its inputs. 

435 def _functions(self): 

436 """Returns a list of functions associated with this dataset. 

437 

438 Returns: 

439 A list of `StructuredFunctionWrapper` objects. 

440 """ 

441 return [] 

442 

443 def _options(self): 

444 """Returns the options tensor for this dataset.""" 

445 # pylint: disable=protected-access 

446 return gen_dataset_ops.get_options(self._variant_tensor) 

447 

448 @classmethod 

449 def _options_tensor_to_options(cls, serialized_options): 

450 """Converts options tensor to tf.data.Options object.""" 

451 options = options_lib.Options() 

452 if tensor_util.constant_value(serialized_options) is not None: 

453 pb = dataset_options_pb2.Options.FromString(tensor_util.constant_value( 

454 serialized_options)) 

455 options._from_proto(pb) # pylint: disable=protected-access 

456 return options 

457 

458 def options(self): 

459 """Returns the options for this dataset and its inputs. 

460 

461 Returns: 

462 A `tf.data.Options` object representing the dataset options. 

463 """ 

464 if context.executing_eagerly(): 

465 options = self._options_tensor_to_options(self._options()) 

466 options._set_mutable(False) # pylint: disable=protected-access 

467 return options 

468 warnings.warn("To make it possible to preserve tf.data options across " 

469 "serialization boundaries, their implementation has moved to " 

470 "be part of the TensorFlow graph. As a consequence, the " 

471 "options value is in general no longer known at graph " 

472 "construction time. Invoking this method in graph mode " 

473 "retains the legacy behavior of the original implementation, " 

474 "but note that the returned value might not reflect the " 

475 "actual value of the options.") 

476 return self._options_attr 

477 

478 def _apply_debug_options(self): 

479 if debug_mode.DEBUG_MODE: 

480 # Disable autotuning and static optimizations that could introduce 

481 # parallelism or asynchrony. 

482 options = options_lib.Options() 

483 options.autotune.enabled = False 

484 options.experimental_optimization.filter_parallelization = False 

485 options.experimental_optimization.map_and_batch_fusion = False 

486 options.experimental_optimization.map_parallelization = False 

487 dataset = _OptionsDataset(self, options) 

488 else: 

489 dataset = self 

490 

491 return dataset 

492 

493 def __iter__(self): 

494 """Creates an iterator for elements of this dataset. 

495 

496 The returned iterator implements the Python Iterator protocol. 

497 

498 Returns: 

499 An `tf.data.Iterator` for the elements of this dataset. 

500 

501 Raises: 

502 RuntimeError: If not inside of tf.function and not executing eagerly. 

503 """ 

504 if context.executing_eagerly() or ops.inside_function(): 

505 with ops.colocate_with(self._variant_tensor): 

506 return iterator_ops.OwnedIterator(self) 

507 else: 

508 raise RuntimeError("`tf.data.Dataset` only supports Python-style " 

509 "iteration in eager mode or within tf.function.") 

510 

511 def __bool__(self): 

512 return True # Required as __len__ is defined 

513 

514 __nonzero__ = __bool__ # Python 2 backward compatibility 

515 

516 def __len__(self): 

517 """Returns the length of the dataset if it is known and finite. 

518 

519 This method requires that you are running in eager mode, and that the 

520 length of the dataset is known and non-infinite. When the length may be 

521 unknown or infinite, or if you are running in graph mode, use 

522 `tf.data.Dataset.cardinality` instead. 

523 

524 Returns: 

525 An integer representing the length of the dataset. 

526 

527 Raises: 

528 RuntimeError: If the dataset length is unknown or infinite, or if eager 

529 execution is not enabled. 

530 """ 

531 if not context.executing_eagerly(): 

532 raise TypeError("`tf.data.Dataset` only supports `len` in eager mode. " 

533 "Use `tf.data.Dataset.cardinality()` instead.") 

534 length = self.cardinality() 

535 if length.numpy() == INFINITE: 

536 raise TypeError("The dataset is infinite.") 

537 if length.numpy() == UNKNOWN: 

538 raise TypeError("The dataset length is unknown.") 

539 return length 

540 

541 @abc.abstractproperty 

542 def element_spec(self): 

543 """The type specification of an element of this dataset. 

544 

545 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

546 >>> dataset.element_spec 

547 TensorSpec(shape=(), dtype=tf.int32, name=None) 

548 

549 For more information, 

550 read [this guide](https://www.tensorflow.org/guide/data#dataset_structure). 

551 

552 Returns: 

553 A (nested) structure of `tf.TypeSpec` objects matching the structure of an 

554 element of this dataset and specifying the type of individual components. 

555 """ 

556 raise NotImplementedError(f"{type(self)}.element_spec()") 

557 

558 def __repr__(self): 

559 type_ = type(self._dataset if isinstance(self, DatasetV1Adapter) else self) 

560 return f"<{type_.__name__} element_spec={self.element_spec}>" 

561 

562 def __debug_string__(self): 

563 """Returns a string showing the type of the dataset and its inputs. 

564 

565 This string is intended only for debugging purposes, and may change without 

566 warning. 

567 """ 

568 lines = [] 

569 to_process = [(self, 0)] # Stack of (dataset, depth) pairs. 

570 while to_process: 

571 dataset, depth = to_process.pop() 

572 lines.append("-"*2*depth + repr(dataset)) 

573 to_process.extend([(ds, depth+1) for ds in dataset._inputs()]) # pylint: disable=protected-access 

574 return "\n".join(lines) 

575 

576 def as_numpy_iterator(self): 

577 """Returns an iterator which converts all elements of the dataset to numpy. 

578 

579 Use `as_numpy_iterator` to inspect the content of your dataset. To see 

580 element shapes and types, print dataset elements directly instead of using 

581 `as_numpy_iterator`. 

582 

583 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

584 >>> for element in dataset: 

585 ... print(element) 

586 tf.Tensor(1, shape=(), dtype=int32) 

587 tf.Tensor(2, shape=(), dtype=int32) 

588 tf.Tensor(3, shape=(), dtype=int32) 

589 

590 This method requires that you are running in eager mode and the dataset's 

591 element_spec contains only `TensorSpec` components. 

592 

593 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

594 >>> for element in dataset.as_numpy_iterator(): 

595 ... print(element) 

596 1 

597 2 

598 3 

599 

600 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

601 >>> print(list(dataset.as_numpy_iterator())) 

602 [1, 2, 3] 

603 

604 `as_numpy_iterator()` will preserve the nested structure of dataset 

605 elements. 

606 

607 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 

608 ... 'b': [5, 6]}) 

609 >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5}, 

610 ... {'a': (2, 4), 'b': 6}] 

611 True 

612 

613 Returns: 

614 An iterable over the elements of the dataset, with their tensors converted 

615 to numpy arrays. 

616 

617 Raises: 

618 TypeError: if an element contains a non-`Tensor` value. 

619 RuntimeError: if eager execution is not enabled. 

620 """ 

621 if not context.executing_eagerly(): 

622 raise RuntimeError("`tf.data.Dataset.as_numpy_iterator()` is only " 

623 "supported in eager mode.") 

624 for component_spec in nest.flatten(self.element_spec): 

625 if not isinstance( 

626 component_spec, 

627 (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec, 

628 sparse_tensor_lib.SparseTensorSpec, structure.NoneTensorSpec)): 

629 raise TypeError( 

630 f"`tf.data.Dataset.as_numpy_iterator()` is not supported for " 

631 f"datasets that produce values of type {component_spec.value_type}") 

632 

633 return _NumpyIterator(self) 

634 

635 @property 

636 def _flat_shapes(self): 

637 """Returns a list `tf.TensorShapes`s for the element tensor representation. 

638 

639 Returns: 

640 A list `tf.TensorShapes`s for the element tensor representation. 

641 """ 

642 return structure.get_flat_tensor_shapes(self.element_spec) 

643 

644 @property 

645 def _flat_types(self): 

646 """Returns a list `tf.DType`s for the element tensor representation. 

647 

648 Returns: 

649 A list `tf.DType`s for the element tensor representation. 

650 """ 

651 return structure.get_flat_tensor_types(self.element_spec) 

652 

653 @property 

654 def _flat_structure(self): 

655 """Helper for setting `output_shapes` and `output_types` attrs of an op. 

656 

657 Most dataset op constructors expect `output_shapes` and `output_types` 

658 arguments that represent the flattened structure of an element. This helper 

659 function generates these attrs as a keyword argument dictionary, allowing 

660 `Dataset._variant_tensor` implementations to pass `**self._flat_structure` 

661 to the op constructor. 

662 

663 Returns: 

664 A dictionary of keyword arguments that can be passed to a dataset op 

665 constructor. 

666 """ 

667 return { 

668 "output_shapes": self._flat_shapes, 

669 "output_types": self._flat_types, 

670 } 

671 

672 @property 

673 def _metadata(self): 

674 """Helper for generating dataset metadata.""" 

675 metadata = dataset_metadata_pb2.Metadata() 

676 if self._name: 

677 metadata.name = _validate_and_encode(self._name) 

678 return metadata 

679 

680 @property 

681 def _common_args(self): 

682 """Helper for generating arguments that are common across most dataset ops. 

683 

684 Most dataset op constructors expect `output_shapes` and `output_types` 

685 arguments that represent the flattened structure of an element, as well as a 

686 `metadata` argument for additional metadata such as user-defined dataset 

687 name. This helper function generates common attributes as a keyword argument 

688 dictionary, allowing `Dataset._variant_tensor` implementations to pass 

689 `**self._common_args` to the op constructor. 

690 

691 Returns: 

692 A dictionary of keyword arguments that can be passed to a dataset op 

693 constructor. 

694 """ 

695 return { 

696 "metadata": self._metadata.SerializeToString(), 

697 "output_shapes": self._flat_shapes, 

698 "output_types": self._flat_types, 

699 } 

700 

701 @property 

702 def _type_spec(self): 

703 return DatasetSpec(self.element_spec) 

704 

705 @staticmethod 

706 def from_tensors(tensors, name=None): 

707 """Creates a `Dataset` with a single element, comprising the given tensors. 

708 

709 `from_tensors` produces a dataset containing only a single element. To slice 

710 the input tensor into multiple elements, use `from_tensor_slices` instead. 

711 

712 >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3]) 

713 >>> list(dataset.as_numpy_iterator()) 

714 [array([1, 2, 3], dtype=int32)] 

715 >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A')) 

716 >>> list(dataset.as_numpy_iterator()) 

717 [(array([1, 2, 3], dtype=int32), b'A')] 

718 

719 >>> # You can use `from_tensors` to produce a dataset which repeats 

720 >>> # the same example many times. 

721 >>> example = tf.constant([1,2,3]) 

722 >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2) 

723 >>> list(dataset.as_numpy_iterator()) 

724 [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)] 

725 

726 Note that if `tensors` contains a NumPy array, and eager execution is not 

727 enabled, the values will be embedded in the graph as one or more 

728 `tf.constant` operations. For large datasets (> 1 GB), this can waste 

729 memory and run into byte limits of graph serialization. If `tensors` 

730 contains one or more large NumPy arrays, consider the alternative described 

731 in [this 

732 guide](https://tensorflow.org/guide/data#consuming_numpy_arrays). 

733 

734 Args: 

735 tensors: A dataset "element". Supported values are documented 

736 [here](https://www.tensorflow.org/guide/data#dataset_structure). 

737 name: (Optional.) A name for the tf.data operation. 

738 

739 Returns: 

740 Dataset: A `Dataset`. 

741 """ 

742 # Loaded lazily due to a circular dependency (dataset_ops -> 

743 # from_tensors_op -> dataset_ops). 

744 # pylint: disable=g-import-not-at-top,protected-access 

745 from tensorflow.python.data.ops import from_tensors_op 

746 return from_tensors_op._from_tensors(tensors, name) 

747 # pylint: enable=g-import-not-at-top,protected-access 

748 

749 @staticmethod 

750 def from_tensor_slices(tensors, name=None): 

751 """Creates a `Dataset` whose elements are slices of the given tensors. 

752 

753 The given tensors are sliced along their first dimension. This operation 

754 preserves the structure of the input tensors, removing the first dimension 

755 of each tensor and using it as the dataset dimension. All input tensors 

756 must have the same size in their first dimensions. 

757 

758 >>> # Slicing a 1D tensor produces scalar tensor elements. 

759 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

760 >>> list(dataset.as_numpy_iterator()) 

761 [1, 2, 3] 

762 

763 >>> # Slicing a 2D tensor produces 1D tensor elements. 

764 >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]]) 

765 >>> list(dataset.as_numpy_iterator()) 

766 [array([1, 2], dtype=int32), array([3, 4], dtype=int32)] 

767 

768 >>> # Slicing a tuple of 1D tensors produces tuple elements containing 

769 >>> # scalar tensors. 

770 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6])) 

771 >>> list(dataset.as_numpy_iterator()) 

772 [(1, 3, 5), (2, 4, 6)] 

773 

774 >>> # Dictionary structure is also preserved. 

775 >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]}) 

776 >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3}, 

777 ... {'a': 2, 'b': 4}] 

778 True 

779 

780 >>> # Two tensors can be combined into one Dataset object. 

781 >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor 

782 >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor 

783 >>> dataset = Dataset.from_tensor_slices((features, labels)) 

784 >>> # Both the features and the labels tensors can be converted 

785 >>> # to a Dataset object separately and combined after. 

786 >>> features_dataset = Dataset.from_tensor_slices(features) 

787 >>> labels_dataset = Dataset.from_tensor_slices(labels) 

788 >>> dataset = Dataset.zip((features_dataset, labels_dataset)) 

789 >>> # A batched feature and label set can be converted to a Dataset 

790 >>> # in similar fashion. 

791 >>> batched_features = tf.constant([[[1, 3], [2, 3]], 

792 ... [[2, 1], [1, 2]], 

793 ... [[3, 3], [3, 2]]], shape=(3, 2, 2)) 

794 >>> batched_labels = tf.constant([['A', 'A'], 

795 ... ['B', 'B'], 

796 ... ['A', 'B']], shape=(3, 2, 1)) 

797 >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels)) 

798 >>> for element in dataset.as_numpy_iterator(): 

799 ... print(element) 

800 (array([[1, 3], 

801 [2, 3]], dtype=int32), array([[b'A'], 

802 [b'A']], dtype=object)) 

803 (array([[2, 1], 

804 [1, 2]], dtype=int32), array([[b'B'], 

805 [b'B']], dtype=object)) 

806 (array([[3, 3], 

807 [3, 2]], dtype=int32), array([[b'A'], 

808 [b'B']], dtype=object)) 

809 

810 Note that if `tensors` contains a NumPy array, and eager execution is not 

811 enabled, the values will be embedded in the graph as one or more 

812 `tf.constant` operations. For large datasets (> 1 GB), this can waste 

813 memory and run into byte limits of graph serialization. If `tensors` 

814 contains one or more large NumPy arrays, consider the alternative described 

815 in [this guide]( 

816 https://tensorflow.org/guide/data#consuming_numpy_arrays). 

817 

818 Args: 

819 tensors: A dataset element, whose components have the same first 

820 dimension. Supported values are documented 

821 [here](https://www.tensorflow.org/guide/data#dataset_structure). 

822 name: (Optional.) A name for the tf.data operation. 

823 

824 Returns: 

825 Dataset: A `Dataset`. 

826 """ 

827 # Loaded lazily due to a circular dependency (dataset_ops -> 

828 # from_tensor_slices_op -> dataset_ops). 

829 # pylint: disable=g-import-not-at-top,protected-access 

830 from tensorflow.python.data.ops import from_tensor_slices_op 

831 return from_tensor_slices_op._from_tensor_slices(tensors, name) 

832 # pylint: enable=g-import-not-at-top,protected-access 

833 

834 class _GeneratorState: 

835 """Stores outstanding iterators created from a Python generator. 

836 

837 This class keeps track of potentially multiple iterators that may have 

838 been created from a generator, e.g. in the case that the dataset is 

839 repeated, or nested within a parallel computation. 

840 """ 

841 

842 def __init__(self, generator): 

843 self._generator = generator 

844 self._lock = threading.Lock() 

845 self._next_id = 0 # GUARDED_BY(self._lock) 

846 self._args = {} 

847 self._iterators = {} 

848 

849 def _normalize_id(self, iterator_id): 

850 # In debug mode, iterator ids may be eagerly-generated np.arrays instead 

851 # of Tensors. We convert them to scalars to make them hashable. 

852 if isinstance(iterator_id, np.ndarray): 

853 return iterator_id.item() 

854 return iterator_id 

855 

856 def get_next_id(self, *args): 

857 with self._lock: 

858 ret = self._next_id 

859 self._next_id += 1 

860 self._args[ret] = args 

861 # NOTE(mrry): Explicitly create an array of `np.int64` because implicit 

862 # casting in `py_func()` will create an array of `np.int32` on Windows, 

863 # leading to a runtime error. 

864 return np.array(ret, dtype=np.int64) 

865 

866 def get_iterator(self, iterator_id): 

867 iterator_id = self._normalize_id(iterator_id) 

868 try: 

869 return self._iterators[iterator_id] 

870 except KeyError: 

871 iterator = iter(self._generator(*self._args.pop(iterator_id))) 

872 self._iterators[iterator_id] = iterator 

873 return iterator 

874 

875 def iterator_completed(self, iterator_id): 

876 del self._iterators[self._normalize_id(iterator_id)] 

877 

878 @staticmethod 

879 @deprecation.deprecated_args(None, "Use output_signature instead", 

880 "output_types", "output_shapes") 

881 def from_generator(generator, 

882 output_types=None, 

883 output_shapes=None, 

884 args=None, 

885 output_signature=None, 

886 name=None): 

887 """Creates a `Dataset` whose elements are generated by `generator`. 

888 

889 Note: The current implementation of `Dataset.from_generator()` uses 

890 `tf.numpy_function` and inherits the same constraints. In particular, it 

891 requires the dataset and iterator related operations to be placed 

892 on a device in the same process as the Python program that called 

893 `Dataset.from_generator()`. In particular, using `from_generator` will 

894 preclude the use of tf.data service for scaling out dataset processing. 

895 The body of `generator` will not be serialized in a `GraphDef`, and you 

896 should not use this method if you need to serialize your model and restore 

897 it in a different environment. 

898 

899 The `generator` argument must be a callable object that returns 

900 an object that supports the `iter()` protocol (e.g. a generator function). 

901 

902 The elements generated by `generator` must be compatible with either the 

903 given `output_signature` argument or with the given `output_types` and 

904 (optionally) `output_shapes` arguments, whichever was specified. 

905 

906 The recommended way to call `from_generator` is to use the 

907 `output_signature` argument. In this case the output will be assumed to 

908 consist of objects with the classes, shapes and types defined by 

909 `tf.TypeSpec` objects from `output_signature` argument: 

910 

911 >>> def gen(): 

912 ... ragged_tensor = tf.ragged.constant([[1, 2], [3]]) 

913 ... yield 42, ragged_tensor 

914 >>> 

915 >>> dataset = tf.data.Dataset.from_generator( 

916 ... gen, 

917 ... output_signature=( 

918 ... tf.TensorSpec(shape=(), dtype=tf.int32), 

919 ... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32))) 

920 >>> 

921 >>> list(dataset.take(1)) 

922 [(<tf.Tensor: shape=(), dtype=int32, numpy=42>, 

923 <tf.RaggedTensor [[1, 2], [3]]>)] 

924 

925 There is also a deprecated way to call `from_generator` by either with 

926 `output_types` argument alone or together with `output_shapes` argument. 

927 In this case the output of the function will be assumed to consist of 

928 `tf.Tensor` objects with the types defined by `output_types` and with the 

929 shapes which are either unknown or defined by `output_shapes`. 

930 

931 Note: If `generator` depends on mutable global variables or other external 

932 state, be aware that the runtime may invoke `generator` multiple times 

933 (in order to support repeating the `Dataset`) and at any time 

934 between the call to `Dataset.from_generator()` and the production of the 

935 first element from the generator. Mutating global variables or external 

936 state can cause undefined behavior, and we recommend that you explicitly 

937 cache any external state in `generator` before calling 

938 `Dataset.from_generator()`. 

939 

940 Note: While the `output_signature` parameter makes it possible to yield 

941 `Dataset` elements, the scope of `Dataset.from_generator()` should be 

942 limited to logic that cannot be expressed through tf.data operations. Using 

943 tf.data operations within the generator function is an anti-pattern and may 

944 result in incremental memory growth. 

945 

946 Args: 

947 generator: A callable object that returns an object that supports the 

948 `iter()` protocol. If `args` is not specified, `generator` must take no 

949 arguments; otherwise it must take as many arguments as there are values 

950 in `args`. 

951 output_types: (Optional.) A (nested) structure of `tf.DType` objects 

952 corresponding to each component of an element yielded by `generator`. 

953 output_shapes: (Optional.) A (nested) structure of `tf.TensorShape` 

954 objects corresponding to each component of an element yielded by 

955 `generator`. 

956 args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated 

957 and passed to `generator` as NumPy-array arguments. 

958 output_signature: (Optional.) A (nested) structure of `tf.TypeSpec` 

959 objects corresponding to each component of an element yielded by 

960 `generator`. 

961 name: (Optional.) A name for the tf.data operations used by 

962 `from_generator`. 

963 

964 Returns: 

965 Dataset: A `Dataset`. 

966 """ 

967 # Loaded lazily due to a circular dependency (dataset_ops -> 

968 # from_generator_op -> dataset_ops). 

969 # pylint: disable=g-import-not-at-top,protected-access 

970 from tensorflow.python.data.ops import from_generator_op 

971 return from_generator_op._from_generator(generator, output_types, 

972 output_shapes, args, 

973 output_signature, name) 

974 # pylint: enable=g-import-not-at-top,protected-access 

975 

976 @staticmethod 

977 def range(*args, **kwargs): 

978 """Creates a `Dataset` of a step-separated range of values. 

979 

980 >>> list(Dataset.range(5).as_numpy_iterator()) 

981 [0, 1, 2, 3, 4] 

982 >>> list(Dataset.range(2, 5).as_numpy_iterator()) 

983 [2, 3, 4] 

984 >>> list(Dataset.range(1, 5, 2).as_numpy_iterator()) 

985 [1, 3] 

986 >>> list(Dataset.range(1, 5, -2).as_numpy_iterator()) 

987 [] 

988 >>> list(Dataset.range(5, 1).as_numpy_iterator()) 

989 [] 

990 >>> list(Dataset.range(5, 1, -2).as_numpy_iterator()) 

991 [5, 3] 

992 >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator()) 

993 [2, 3, 4] 

994 >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator()) 

995 [1.0, 3.0] 

996 

997 Args: 

998 *args: follows the same semantics as python's range. 

999 len(args) == 1 -> start = 0, stop = args[0], step = 1. 

1000 len(args) == 2 -> start = args[0], stop = args[1], step = 1. 

1001 len(args) == 3 -> start = args[0], stop = args[1], step = args[2]. 

1002 **kwargs: 

1003 - output_type: Its expected dtype. (Optional, default: `tf.int64`). 

1004 - name: (Optional.) A name for the tf.data operation. 

1005 

1006 Returns: 

1007 Dataset: A `RangeDataset`. 

1008 

1009 Raises: 

1010 ValueError: if len(args) == 0. 

1011 """ 

1012 # Loaded lazily due to a circular dependency (dataset_ops -> range_op -> 

1013 # -> dataset_ops). 

1014 # pylint: disable=g-import-not-at-top,protected-access 

1015 from tensorflow.python.data.ops import range_op 

1016 return range_op._range(*args, **kwargs) 

1017 # pylint: enable=g-import-not-at-top,protected-access 

1018 

1019 @staticmethod 

1020 def zip(*args, datasets=None, name=None): 

1021 """Creates a `Dataset` by zipping together the given datasets. 

1022 

1023 This method has similar semantics to the built-in `zip()` function 

1024 in Python, with the main difference being that the `datasets` 

1025 argument can be a (nested) structure of `Dataset` objects. The supported 

1026 nesting mechanisms are documented 

1027 [here] (https://www.tensorflow.org/guide/data#dataset_structure). 

1028 

1029 >>> # The datasets or nested structure of datasets `*args` argument 

1030 >>> # determines the structure of elements in the resulting dataset. 

1031 >>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ] 

1032 >>> b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ] 

1033 >>> ds = tf.data.Dataset.zip(a, b) 

1034 >>> list(ds.as_numpy_iterator()) 

1035 [(1, 4), (2, 5), (3, 6)] 

1036 >>> ds = tf.data.Dataset.zip(b, a) 

1037 >>> list(ds.as_numpy_iterator()) 

1038 [(4, 1), (5, 2), (6, 3)] 

1039 >>> 

1040 >>> # The `datasets` argument may contain an arbitrary number of datasets. 

1041 >>> c = tf.data.Dataset.range(7, 13).batch(2) # ==> [ [7, 8], 

1042 ... # [9, 10], 

1043 ... # [11, 12] ] 

1044 >>> ds = tf.data.Dataset.zip(a, b, c) 

1045 >>> for element in ds.as_numpy_iterator(): 

1046 ... print(element) 

1047 (1, 4, array([7, 8])) 

1048 (2, 5, array([ 9, 10])) 

1049 (3, 6, array([11, 12])) 

1050 >>> 

1051 >>> # The number of elements in the resulting dataset is the same as 

1052 >>> # the size of the smallest dataset in `datasets`. 

1053 >>> d = tf.data.Dataset.range(13, 15) # ==> [ 13, 14 ] 

1054 >>> ds = tf.data.Dataset.zip(a, d) 

1055 >>> list(ds.as_numpy_iterator()) 

1056 [(1, 13), (2, 14)] 

1057 

1058 Args: 

1059 *args: Datasets or nested structures of datasets to zip together. This 

1060 can't be set if `datasets` is set. 

1061 datasets: A (nested) structure of datasets. This can't be set if `*args` 

1062 is set. Note that this exists only for backwards compatibility and it is 

1063 preferred to use *args. 

1064 name: (Optional.) A name for the tf.data operation. 

1065 

1066 Returns: 

1067 A new `Dataset` with the transformation applied as described above. 

1068 """ 

1069 # Loaded lazily due to a circular dependency (dataset_ops -> zip_op -> 

1070 # dataset_ops). 

1071 # pylint: disable=g-import-not-at-top,protected-access 

1072 from tensorflow.python.data.ops import zip_op 

1073 

1074 if not args and datasets is None: 

1075 raise TypeError("Must pass at least one dataset to `zip`.") 

1076 if args and datasets is not None: 

1077 raise TypeError("Both `*args` and `datasets` cannot be set.") 

1078 if len(args) == 1: 

1079 datasets = args[0] 

1080 elif len(args) > 1: 

1081 datasets = args 

1082 return zip_op._zip(datasets, name) 

1083 # pylint: enable=g-import-not-at-top,protected-access 

1084 

1085 def concatenate(self, dataset, name=None): 

1086 """Creates a `Dataset` by concatenating the given dataset with this dataset. 

1087 

1088 >>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ] 

1089 >>> b = tf.data.Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ] 

1090 >>> ds = a.concatenate(b) 

1091 >>> list(ds.as_numpy_iterator()) 

1092 [1, 2, 3, 4, 5, 6, 7] 

1093 >>> # The input dataset and dataset to be concatenated should have 

1094 >>> # compatible element specs. 

1095 >>> c = tf.data.Dataset.zip((a, b)) 

1096 >>> a.concatenate(c) 

1097 Traceback (most recent call last): 

1098 TypeError: Two datasets to concatenate have different types 

1099 <dtype: 'int64'> and (tf.int64, tf.int64) 

1100 >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"]) 

1101 >>> a.concatenate(d) 

1102 Traceback (most recent call last): 

1103 TypeError: Two datasets to concatenate have different types 

1104 <dtype: 'int64'> and <dtype: 'string'> 

1105 

1106 Args: 

1107 dataset: `Dataset` to be concatenated. 

1108 name: (Optional.) A name for the tf.data operation. 

1109 

1110 Returns: 

1111 A new `Dataset` with the transformation applied as described above. 

1112 """ 

1113 # Loaded lazily due to a circular dependency (dataset_ops -> 

1114 # concatenate_op -> dataset_ops). 

1115 # pylint: disable=g-import-not-at-top,protected-access 

1116 from tensorflow.python.data.ops import concatenate_op 

1117 return concatenate_op._concatenate(self, dataset, name) 

1118 # pylint: enable=g-import-not-at-top,protected-access 

1119 

1120 @staticmethod 

1121 def counter(start=0, step=1, dtype=dtypes.int64, name=None): 

1122 """Creates a `Dataset` that counts from `start` in steps of size `step`. 

1123 

1124 Unlike `tf.data.Dataset.range`, which stops at some ending number, 

1125 `tf.data.Dataset.counter` produces elements indefinitely. 

1126 

1127 >>> dataset = tf.data.experimental.Counter().take(5) 

1128 >>> list(dataset.as_numpy_iterator()) 

1129 [0, 1, 2, 3, 4] 

1130 >>> dataset.element_spec 

1131 TensorSpec(shape=(), dtype=tf.int64, name=None) 

1132 >>> dataset = tf.data.experimental.Counter(dtype=tf.int32) 

1133 >>> dataset.element_spec 

1134 TensorSpec(shape=(), dtype=tf.int32, name=None) 

1135 >>> dataset = tf.data.experimental.Counter(start=2).take(5) 

1136 >>> list(dataset.as_numpy_iterator()) 

1137 [2, 3, 4, 5, 6] 

1138 >>> dataset = tf.data.experimental.Counter(start=2, step=5).take(5) 

1139 >>> list(dataset.as_numpy_iterator()) 

1140 [2, 7, 12, 17, 22] 

1141 >>> dataset = tf.data.experimental.Counter(start=10, step=-1).take(5) 

1142 >>> list(dataset.as_numpy_iterator()) 

1143 [10, 9, 8, 7, 6] 

1144 

1145 Args: 

1146 start: (Optional.) The starting value for the counter. Defaults to 0. 

1147 step: (Optional.) The step size for the counter. Defaults to 1. 

1148 dtype: (Optional.) The data type for counter elements. Defaults to 

1149 `tf.int64`. 

1150 name: (Optional.) A name for the tf.data operation. 

1151 

1152 Returns: 

1153 A `Dataset` of scalar `dtype` elements. 

1154 """ 

1155 # Loaded lazily due to a circular dependency (dataset_ops -> counter_op 

1156 # -> dataset_ops). 

1157 # pylint: disable=g-import-not-at-top,protected-access 

1158 from tensorflow.python.data.ops import counter_op 

1159 return counter_op._counter(start, step, dtype, name=name) 

1160 # pylint: enable=g-import-not-at-top,protected-access 

1161 

1162 def rebatch(self, batch_size, drop_remainder=False, name=None): 

1163 """Creates a `Dataset` that rebatches the elements from this dataset. 

1164 

1165 `rebatch(N)` is functionally equivalent to `unbatch().batch(N)`, but is 

1166 more efficient, performing one copy instead of two. 

1167 

1168 >>> ds = tf.data.Dataset.range(6) 

1169 >>> ds = ds.batch(2) 

1170 >>> ds = ds.rebatch(3) 

1171 >>> list(ds.as_numpy_iterator()) 

1172 [array([0, 1, 2]), array([3, 4, 5])] 

1173 

1174 >>> ds = tf.data.Dataset.range(7) 

1175 >>> ds = ds.batch(4) 

1176 >>> ds = ds.rebatch(3) 

1177 >>> list(ds.as_numpy_iterator()) 

1178 [array([0, 1, 2]), array([3, 4, 5]), array([6])] 

1179 

1180 >>> ds = tf.data.Dataset.range(7) 

1181 >>> ds = ds.batch(2) 

1182 >>> ds = ds.rebatch(3, drop_remainder=True) 

1183 >>> list(ds.as_numpy_iterator()) 

1184 [array([0, 1, 2]), array([3, 4, 5])] 

1185 

1186 If the `batch_size` argument is a list, `rebatch` cycles through the list 

1187 to determine the size of each batch. 

1188 

1189 >>> ds = tf.data.Dataset.range(8) 

1190 >>> ds = ds.batch(4) 

1191 >>> ds = ds.rebatch([2, 1, 1]) 

1192 >>> list(ds.as_numpy_iterator()) 

1193 [array([0, 1]), array([2]), array([3]), array([4, 5]), array([6]), 

1194 array([7])] 

1195 

1196 Args: 

1197 batch_size: A `tf.int64` scalar or vector, representing the size of 

1198 batches to produce. If this argument is a vector, these values are 

1199 cycled through in round robin fashion. 

1200 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 

1201 whether the last batch should be dropped in the case it has fewer than 

1202 `batch_size[cycle_index]` elements; the default behavior is not to drop 

1203 the smaller batch. 

1204 name: (Optional.) A name for the tf.data operation. 

1205 

1206 Returns: 

1207 A `Dataset` of scalar `dtype` elements. 

1208 """ 

1209 # Loaded lazily due to a circular dependency (dataset_ops -> rebatch_op -> 

1210 # rebatch_op -> dataset_ops). 

1211 # pylint: disable=g-import-not-at-top,protected-access 

1212 from tensorflow.python.data.ops import rebatch_op 

1213 return rebatch_op._rebatch(self, batch_size, drop_remainder, name=name) 

1214 # pylint: enable=g-import-not-at-top,protected-access 

1215 

1216 def prefetch(self, buffer_size, name=None): 

1217 """Creates a `Dataset` that prefetches elements from this dataset. 

1218 

1219 Most dataset input pipelines should end with a call to `prefetch`. This 

1220 allows later elements to be prepared while the current element is being 

1221 processed. This often improves latency and throughput, at the cost of 

1222 using additional memory to store prefetched elements. 

1223 

1224 Note: Like other `Dataset` methods, prefetch operates on the 

1225 elements of the input dataset. It has no concept of examples vs. batches. 

1226 `examples.prefetch(2)` will prefetch two elements (2 examples), 

1227 while `examples.batch(20).prefetch(2)` will prefetch 2 elements 

1228 (2 batches, of 20 examples each). 

1229 

1230 >>> dataset = tf.data.Dataset.range(3) 

1231 >>> dataset = dataset.prefetch(2) 

1232 >>> list(dataset.as_numpy_iterator()) 

1233 [0, 1, 2] 

1234 

1235 Args: 

1236 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum 

1237 number of elements that will be buffered when prefetching. If the value 

1238 `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned. 

1239 name: Optional. A name for the tf.data transformation. 

1240 

1241 Returns: 

1242 A new `Dataset` with the transformation applied as described above. 

1243 """ 

1244 return prefetch_op._prefetch( # pylint: disable=protected-access 

1245 self, buffer_size, name=name) 

1246 

1247 @staticmethod 

1248 def list_files(file_pattern, shuffle=None, seed=None, name=None): 

1249 """A dataset of all files matching one or more glob patterns. 

1250 

1251 The `file_pattern` argument should be a small number of glob patterns. 

1252 If your filenames have already been globbed, use 

1253 `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every 

1254 filename with `list_files` may result in poor performance with remote 

1255 storage systems. 

1256 

1257 Note: The default behavior of this method is to return filenames in 

1258 a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False` 

1259 to get results in a deterministic order. 

1260 

1261 Example: 

1262 If we had the following files on our filesystem: 

1263 

1264 - /path/to/dir/a.txt 

1265 - /path/to/dir/b.py 

1266 - /path/to/dir/c.py 

1267 

1268 If we pass "/path/to/dir/*.py" as the directory, the dataset 

1269 would produce: 

1270 

1271 - /path/to/dir/b.py 

1272 - /path/to/dir/c.py 

1273 

1274 Args: 

1275 file_pattern: A string, a list of strings, or a `tf.Tensor` of string type 

1276 (scalar or vector), representing the filename glob (i.e. shell wildcard) 

1277 pattern(s) that will be matched. 

1278 shuffle: (Optional.) If `True`, the file names will be shuffled randomly. 

1279 Defaults to `True`. 

1280 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 

1281 seed that will be used to create the distribution. See 

1282 `tf.random.set_seed` for behavior. 

1283 name: Optional. A name for the tf.data operations used by `list_files`. 

1284 

1285 Returns: 

1286 Dataset: A `Dataset` of strings corresponding to file names. 

1287 """ 

1288 with ops.name_scope("list_files"): 

1289 if shuffle is None: 

1290 shuffle = True 

1291 file_pattern = ops.convert_to_tensor( 

1292 file_pattern, dtype=dtypes.string, name="file_pattern") 

1293 matching_files = gen_io_ops.matching_files(file_pattern) 

1294 

1295 # Raise an exception if `file_pattern` does not match any files. 

1296 condition = math_ops.greater(array_ops.shape(matching_files)[0], 0, 

1297 name="match_not_empty") 

1298 

1299 message = math_ops.add( 

1300 "No files matched pattern: ", 

1301 string_ops.reduce_join(file_pattern, separator=", "), name="message") 

1302 

1303 assert_not_empty = control_flow_assert.Assert( 

1304 condition, [message], summarize=1, name="assert_not_empty") 

1305 with ops.control_dependencies([assert_not_empty]): 

1306 matching_files = array_ops.identity(matching_files) 

1307 

1308 # TODO(b/240947712): Remove lazy import after this method is factored out. 

1309 # Loaded lazily due to a circular dependency (dataset_ops -> 

1310 # from_tensor_slices_op -> dataset_ops). 

1311 # pylint: disable=g-import-not-at-top,protected-access 

1312 from tensorflow.python.data.ops import from_tensor_slices_op 

1313 dataset = from_tensor_slices_op._TensorSliceDataset( 

1314 matching_files, is_files=True, name=name) 

1315 # pylint: enable=g-import-not-at-top,protected-access 

1316 if issubclass(Dataset, DatasetV1): 

1317 dataset = DatasetV1Adapter(dataset) 

1318 if shuffle: 

1319 # NOTE(mrry): The shuffle buffer size must be greater than zero, but the 

1320 # list of files might be empty. 

1321 buffer_size = math_ops.maximum( 

1322 array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1) 

1323 dataset = dataset.shuffle(buffer_size, seed=seed, name=name) 

1324 return dataset 

1325 

1326 def repeat(self, count=None, name=None): 

1327 """Repeats this dataset so each original value is seen `count` times. 

1328 

1329 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

1330 >>> dataset = dataset.repeat(3) 

1331 >>> list(dataset.as_numpy_iterator()) 

1332 [1, 2, 3, 1, 2, 3, 1, 2, 3] 

1333 

1334 Note: If the input dataset depends on global state (e.g. a random number 

1335 generator) or its output is non-deterministic (e.g. because of upstream 

1336 `shuffle`), then different repetitions may produce different elements. 

1337 

1338 Args: 

1339 count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 

1340 number of times the dataset should be repeated. The default behavior (if 

1341 `count` is `None` or `-1`) is for the dataset be repeated indefinitely. 

1342 name: (Optional.) A name for the tf.data operation. 

1343 

1344 Returns: 

1345 A new `Dataset` with the transformation applied as described above. 

1346 """ 

1347 # Loaded lazily due to a circular dependency (dataset_ops -> repeat_op -> 

1348 # dataset_ops). 

1349 # pylint: disable=g-import-not-at-top,protected-access,redefined-outer-name 

1350 from tensorflow.python.data.ops import repeat_op 

1351 return repeat_op._repeat(self, count, name) 

1352 # pylint: enable=g-import-not-at-top,protected-access,redefined-outer-name 

1353 

1354 def enumerate(self, start=0, name=None): 

1355 """Enumerates the elements of this dataset. 

1356 

1357 It is similar to python's `enumerate`. 

1358 

1359 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

1360 >>> dataset = dataset.enumerate(start=5) 

1361 >>> for element in dataset.as_numpy_iterator(): 

1362 ... print(element) 

1363 (5, 1) 

1364 (6, 2) 

1365 (7, 3) 

1366 

1367 >>> # The (nested) structure of the input dataset determines the 

1368 >>> # structure of elements in the resulting dataset. 

1369 >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)]) 

1370 >>> dataset = dataset.enumerate() 

1371 >>> for element in dataset.as_numpy_iterator(): 

1372 ... print(element) 

1373 (0, array([7, 8], dtype=int32)) 

1374 (1, array([ 9, 10], dtype=int32)) 

1375 

1376 Args: 

1377 start: A `tf.int64` scalar `tf.Tensor`, representing the start value for 

1378 enumeration. 

1379 name: Optional. A name for the tf.data operations used by `enumerate`. 

1380 

1381 Returns: 

1382 A new `Dataset` with the transformation applied as described above. 

1383 """ 

1384 

1385 max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max 

1386 range_dataset = Dataset.range(start, max_value, name=name) 

1387 # Replicate the range component so that each split is enumerated 

1388 # independently. This avoids the need for prohibitively expensive 

1389 # cross-split coordination. 

1390 range_dataset = _apply_rewrite(range_dataset, "replicate_on_split") 

1391 return Dataset.zip((range_dataset, self), name=name) 

1392 

1393 def shuffle(self, 

1394 buffer_size, 

1395 seed=None, 

1396 reshuffle_each_iteration=None, 

1397 name=None): 

1398 """Randomly shuffles the elements of this dataset. 

1399 

1400 This dataset fills a buffer with `buffer_size` elements, then randomly 

1401 samples elements from this buffer, replacing the selected elements with new 

1402 elements. For perfect shuffling, a buffer size greater than or equal to the 

1403 full size of the dataset is required. 

1404 

1405 For instance, if your dataset contains 10,000 elements but `buffer_size` is 

1406 set to 1,000, then `shuffle` will initially select a random element from 

1407 only the first 1,000 elements in the buffer. Once an element is selected, 

1408 its space in the buffer is replaced by the next (i.e. 1,001-st) element, 

1409 maintaining the 1,000 element buffer. 

1410 

1411 `reshuffle_each_iteration` controls whether the shuffle order should be 

1412 different for each epoch. In TF 1.X, the idiomatic way to create epochs 

1413 was through the `repeat` transformation: 

1414 

1415 ```python 

1416 dataset = tf.data.Dataset.range(3) 

1417 dataset = dataset.shuffle(3, reshuffle_each_iteration=True) 

1418 dataset = dataset.repeat(2) 

1419 # [1, 0, 2, 1, 2, 0] 

1420 

1421 dataset = tf.data.Dataset.range(3) 

1422 dataset = dataset.shuffle(3, reshuffle_each_iteration=False) 

1423 dataset = dataset.repeat(2) 

1424 # [1, 0, 2, 1, 0, 2] 

1425 ``` 

1426 

1427 In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it 

1428 possible to also create epochs through Python iteration: 

1429 

1430 ```python 

1431 dataset = tf.data.Dataset.range(3) 

1432 dataset = dataset.shuffle(3, reshuffle_each_iteration=True) 

1433 list(dataset.as_numpy_iterator()) 

1434 # [1, 0, 2] 

1435 list(dataset.as_numpy_iterator()) 

1436 # [1, 2, 0] 

1437 ``` 

1438 

1439 ```python 

1440 dataset = tf.data.Dataset.range(3) 

1441 dataset = dataset.shuffle(3, reshuffle_each_iteration=False) 

1442 list(dataset.as_numpy_iterator()) 

1443 # [1, 0, 2] 

1444 list(dataset.as_numpy_iterator()) 

1445 # [1, 0, 2] 

1446 ``` 

1447 

1448 ### Fully shuffling all the data 

1449 

1450 To shuffle an entire dataset, set `buffer_size=dataset.cardinality(). This 

1451 is equivalent to setting the `buffer_size` equal to the number of elements 

1452 in the dataset, resulting in uniform shuffle. 

1453 

1454 Note: `shuffle(dataset.cardinality())` loads the full dataset into memory so 

1455 that it can be shuffled. This will cause a memory overflow (OOM) error if 

1456 the dataset is too large, so full-shuffle should only be used for datasets 

1457 that are known to fit in the memory, such as datasets of filenames or other 

1458 small datasets. 

1459 

1460 ```python 

1461 dataset = tf.data.Dataset.range(20) 

1462 dataset = dataset.shuffle(dataset.cardinality()) 

1463 # [18, 4, 9, 2, 17, 8, 5, 10, 0, 6, 16, 3, 19, 7, 14, 11, 15, 13, 12, 1] 

1464 ``` 

1465 

1466 Args: 

1467 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 

1468 elements from this dataset from which the new dataset will sample. To 

1469 uniformly shuffle the entire dataset, use 

1470 `buffer_size=dataset.cardinality()`. 

1471 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 

1472 seed that will be used to create the distribution. See 

1473 `tf.random.set_seed` for behavior. 

1474 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 

1475 that the dataset should be pseudorandomly reshuffled each time it is 

1476 iterated over. (Defaults to `True`.) 

1477 name: (Optional.) A name for the tf.data operation. 

1478 

1479 Returns: 

1480 A new `Dataset` with the transformation applied as described above. 

1481 """ 

1482 return shuffle_op._shuffle( # pylint: disable=protected-access 

1483 self, buffer_size, seed, reshuffle_each_iteration, name=name) 

1484 

1485 def cache(self, filename="", name=None): 

1486 """Caches the elements in this dataset. 

1487 

1488 The first time the dataset is iterated over, its elements will be cached 

1489 either in the specified file or in memory. Subsequent iterations will 

1490 use the cached data. 

1491 

1492 Note: To guarantee that the cache gets finalized, the input dataset must be 

1493 iterated through in its entirety, until it raises StopIteration. Otherwise, 

1494 subsequent iterations may not use cached data. 

1495 

1496 >>> dataset = tf.data.Dataset.range(5) 

1497 >>> dataset = dataset.map(lambda x: x**2) 

1498 >>> dataset = dataset.cache() 

1499 >>> # The first time reading through the data will generate the data using 

1500 >>> # `range` and `map`. 

1501 >>> list(dataset.as_numpy_iterator()) 

1502 [0, 1, 4, 9, 16] 

1503 >>> # Subsequent iterations read from the cache. 

1504 >>> list(dataset.as_numpy_iterator()) 

1505 [0, 1, 4, 9, 16] 

1506 

1507 When caching to a file, the cached data will persist across runs. Even the 

1508 first iteration through the data will read from the cache file. Changing 

1509 the input pipeline before the call to `.cache()` will have no effect until 

1510 the cache file is removed or the filename is changed. 

1511 

1512 ```python 

1513 dataset = tf.data.Dataset.range(5) 

1514 dataset = dataset.cache("/path/to/file") 

1515 list(dataset.as_numpy_iterator()) 

1516 # [0, 1, 2, 3, 4] 

1517 dataset = tf.data.Dataset.range(10) 

1518 dataset = dataset.cache("/path/to/file") # Same file! 

1519 list(dataset.as_numpy_iterator()) 

1520 # [0, 1, 2, 3, 4] 

1521 ``` 

1522 

1523 Note: `cache` will produce exactly the same elements during each iteration 

1524 through the dataset. If you wish to randomize the iteration order, make sure 

1525 to call `shuffle` *after* calling `cache`. 

1526 

1527 Args: 

1528 filename: A `tf.string` scalar `tf.Tensor`, representing the name of a 

1529 directory on the filesystem to use for caching elements in this Dataset. 

1530 If a filename is not provided, the dataset will be cached in memory. 

1531 name: (Optional.) A name for the tf.data operation. 

1532 

1533 Returns: 

1534 A new `Dataset` with the transformation applied as described above. 

1535 """ 

1536 # Loaded lazily due to a circular dependency (dataset_ops -> cache_op -> 

1537 # -> dataset_ops). 

1538 # pylint: disable=g-import-not-at-top,protected-access 

1539 from tensorflow.python.data.ops import cache_op 

1540 return cache_op._cache(self, filename, name) 

1541 # pylint: enable=g-import-not-at-top,protected-access 

1542 

1543 def take(self, count, name=None): 

1544 """Creates a `Dataset` with at most `count` elements from this dataset. 

1545 

1546 >>> dataset = tf.data.Dataset.range(10) 

1547 >>> dataset = dataset.take(3) 

1548 >>> list(dataset.as_numpy_iterator()) 

1549 [0, 1, 2] 

1550 

1551 Args: 

1552 count: A `tf.int64` scalar `tf.Tensor`, representing the number of 

1553 elements of this dataset that should be taken to form the new dataset. 

1554 If `count` is -1, or if `count` is greater than the size of this 

1555 dataset, the new dataset will contain all elements of this dataset. 

1556 name: (Optional.) A name for the tf.data operation. 

1557 

1558 Returns: 

1559 A new `Dataset` with the transformation applied as described above. 

1560 """ 

1561 # Loaded lazily due to a circular dependency (dataset_ops -> 

1562 # take_op -> dataset_ops). 

1563 # pylint: disable=g-import-not-at-top,protected-access 

1564 from tensorflow.python.data.ops import take_op 

1565 return take_op._take(self, count, name=name) 

1566 # pylint: enable=g-import-not-at-top,protected-access 

1567 

1568 def skip(self, count, name=None): 

1569 """Creates a `Dataset` that skips `count` elements from this dataset. 

1570 

1571 >>> dataset = tf.data.Dataset.range(10) 

1572 >>> dataset = dataset.skip(7) 

1573 >>> list(dataset.as_numpy_iterator()) 

1574 [7, 8, 9] 

1575 

1576 Args: 

1577 count: A `tf.int64` scalar `tf.Tensor`, representing the number of 

1578 elements of this dataset that should be skipped to form the new dataset. 

1579 If `count` is greater than the size of this dataset, the new dataset 

1580 will contain no elements. If `count` is -1, skips the entire dataset. 

1581 name: (Optional.) A name for the tf.data operation. 

1582 

1583 Returns: 

1584 A new `Dataset` with the transformation applied as described above. 

1585 """ 

1586 # Loaded lazily due to a circular dependency (dataset_ops -> 

1587 # skip_op -> dataset_ops). 

1588 # pylint: disable=g-import-not-at-top,protected-access 

1589 from tensorflow.python.data.ops import skip_op 

1590 return skip_op._skip(self, count, name) 

1591 # pylint: enable=g-import-not-at-top,protected-access 

1592 

1593 def shard(self, num_shards, index, name=None): 

1594 """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. 

1595 

1596 `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will 

1597 contain all elements of A whose index mod n = i. 

1598 

1599 >>> A = tf.data.Dataset.range(10) 

1600 >>> B = A.shard(num_shards=3, index=0) 

1601 >>> list(B.as_numpy_iterator()) 

1602 [0, 3, 6, 9] 

1603 >>> C = A.shard(num_shards=3, index=1) 

1604 >>> list(C.as_numpy_iterator()) 

1605 [1, 4, 7] 

1606 >>> D = A.shard(num_shards=3, index=2) 

1607 >>> list(D.as_numpy_iterator()) 

1608 [2, 5, 8] 

1609 

1610 This dataset operator is very useful when running distributed training, as 

1611 it allows each worker to read a unique subset. 

1612 

1613 When reading a single input file, you can shard elements as follows: 

1614 

1615 ```python 

1616 d = tf.data.TFRecordDataset(input_file) 

1617 d = d.shard(num_workers, worker_index) 

1618 d = d.repeat(num_epochs) 

1619 d = d.shuffle(shuffle_buffer_size) 

1620 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 

1621 ``` 

1622 

1623 Important caveats: 

1624 

1625 - Be sure to shard before you use any randomizing operator (such as 

1626 shuffle). 

1627 - Generally it is best if the shard operator is used early in the dataset 

1628 pipeline. For example, when reading from a set of TFRecord files, shard 

1629 before converting the dataset to input samples. This avoids reading every 

1630 file on every worker. The following is an example of an efficient 

1631 sharding strategy within a complete pipeline: 

1632 

1633 ```python 

1634 d = Dataset.list_files(pattern, shuffle=False) 

1635 d = d.shard(num_workers, worker_index) 

1636 d = d.repeat(num_epochs) 

1637 d = d.shuffle(shuffle_buffer_size) 

1638 d = d.interleave(tf.data.TFRecordDataset, 

1639 cycle_length=num_readers, block_length=1) 

1640 d = d.map(parser_fn, num_parallel_calls=num_map_threads) 

1641 ``` 

1642 

1643 Args: 

1644 num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of 

1645 shards operating in parallel. 

1646 index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. 

1647 name: (Optional.) A name for the tf.data operation. 

1648 

1649 Returns: 

1650 A new `Dataset` with the transformation applied as described above. 

1651 

1652 Raises: 

1653 InvalidArgumentError: if `num_shards` or `index` are illegal values. 

1654 

1655 Note: error checking is done on a best-effort basis, and errors aren't 

1656 guaranteed to be caught upon dataset creation. (e.g. providing in a 

1657 placeholder tensor bypasses the early checking, and will instead result 

1658 in an error during a session.run call.) 

1659 """ 

1660 # pylint: disable=g-import-not-at-top,protected-access 

1661 from tensorflow.python.data.ops import shard_op 

1662 return shard_op._shard(self, num_shards, index, name=name) 

1663 # pylint: enable=g-import-not-at-top,protected-access 

1664 

1665 def save(self, 

1666 path, 

1667 compression=None, 

1668 shard_func=None, 

1669 checkpoint_args=None): 

1670 """Saves the content of the given dataset. 

1671 

1672 Example usage: 

1673 

1674 >>> import tempfile 

1675 >>> path = os.path.join(tempfile.gettempdir(), "saved_data") 

1676 >>> # Save a dataset 

1677 >>> dataset = tf.data.Dataset.range(2) 

1678 >>> dataset.save(path) 

1679 >>> new_dataset = tf.data.Dataset.load(path) 

1680 >>> for elem in new_dataset: 

1681 ... print(elem) 

1682 tf.Tensor(0, shape=(), dtype=int64) 

1683 tf.Tensor(1, shape=(), dtype=int64) 

1684 

1685 The saved dataset is saved in multiple file "shards". By default, the 

1686 dataset output is divided to shards in a round-robin fashion but custom 

1687 sharding can be specified via the `shard_func` function. For example, you 

1688 can save the dataset to using a single shard as follows: 

1689 

1690 ```python 

1691 dataset = make_dataset() 

1692 def custom_shard_func(element): 

1693 return np.int64(0) 

1694 dataset.save( 

1695 path="/path/to/data", ..., shard_func=custom_shard_func) 

1696 ``` 

1697 

1698 To enable checkpointing, pass in `checkpoint_args` to the `save` method 

1699 as follows: 

1700 

1701 ```python 

1702 dataset = tf.data.Dataset.range(100) 

1703 save_dir = "..." 

1704 checkpoint_prefix = "..." 

1705 step_counter = tf.Variable(0, trainable=False) 

1706 checkpoint_args = { 

1707 "checkpoint_interval": 50, 

1708 "step_counter": step_counter, 

1709 "directory": checkpoint_prefix, 

1710 "max_to_keep": 20, 

1711 } 

1712 dataset.save(dataset, save_dir, checkpoint_args=checkpoint_args) 

1713 ``` 

1714 

1715 NOTE: The directory layout and file format used for saving the dataset is 

1716 considered an implementation detail and may change. For this reason, 

1717 datasets saved through `tf.data.Dataset.save` should only be consumed 

1718 through `tf.data.Dataset.load`, which is guaranteed to be 

1719 backwards compatible. 

1720 

1721 Args: 

1722 path: Required. A directory to use for saving the dataset. 

1723 compression: Optional. The algorithm to use to compress data when writing 

1724 it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. 

1725 shard_func: Optional. A function to control the mapping of dataset 

1726 elements to file shards. The function is expected to map elements of 

1727 the input dataset to int64 shard IDs. If present, the function will be 

1728 traced and executed as graph computation. 

1729 checkpoint_args: Optional args for checkpointing which will be passed into 

1730 the `tf.train.CheckpointManager`. If `checkpoint_args` are not 

1731 specified, then checkpointing will not be performed. The `save()` 

1732 implementation creates a `tf.train.Checkpoint` object internally, so 

1733 users should not set the `checkpoint` argument in `checkpoint_args`. 

1734 

1735 Returns: 

1736 An operation which when executed performs the save. When writing 

1737 checkpoints, returns None. The return value is useful in unit tests. 

1738 

1739 Raises: 

1740 ValueError if `checkpoint` is passed into `checkpoint_args`. 

1741 """ 

1742 # Loaded lazily due to a circular dependency (dataset_ops -> save_op -> 

1743 # dataset_ops). 

1744 # pylint: disable=g-import-not-at-top,protected-access 

1745 from tensorflow.python.data.ops import save_op 

1746 return save_op._save(self, path, compression, shard_func, checkpoint_args) 

1747 # pylint: enable=g-import-not-at-top,protected-access 

1748 

1749 @staticmethod 

1750 def load(path, element_spec=None, compression=None, reader_func=None): 

1751 """Loads a previously saved dataset. 

1752 

1753 Example usage: 

1754 

1755 >>> import tempfile 

1756 >>> path = os.path.join(tempfile.gettempdir(), "saved_data") 

1757 >>> # Save a dataset 

1758 >>> dataset = tf.data.Dataset.range(2) 

1759 >>> tf.data.Dataset.save(dataset, path) 

1760 >>> new_dataset = tf.data.Dataset.load(path) 

1761 >>> for elem in new_dataset: 

1762 ... print(elem) 

1763 tf.Tensor(0, shape=(), dtype=int64) 

1764 tf.Tensor(1, shape=(), dtype=int64) 

1765 

1766 

1767 If the default option of sharding the saved dataset was used, the element 

1768 order of the saved dataset will be preserved when loading it. 

1769 

1770 The `reader_func` argument can be used to specify a custom order in which 

1771 elements should be loaded from the individual shards. The `reader_func` is 

1772 expected to take a single argument -- a dataset of datasets, each containing 

1773 elements of one of the shards -- and return a dataset of elements. For 

1774 example, the order of shards can be shuffled when loading them as follows: 

1775 

1776 ```python 

1777 def custom_reader_func(datasets): 

1778 datasets = datasets.shuffle(NUM_SHARDS) 

1779 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) 

1780 

1781 dataset = tf.data.Dataset.load( 

1782 path="/path/to/data", ..., reader_func=custom_reader_func) 

1783 ``` 

1784 

1785 Args: 

1786 path: Required. A path pointing to a previously saved dataset. 

1787 element_spec: Optional. A nested structure of `tf.TypeSpec` objects 

1788 matching the structure of an element of the saved dataset and specifying 

1789 the type of individual element components. If not provided, the nested 

1790 structure of `tf.TypeSpec` saved with the saved dataset is used. Note 

1791 that this argument is required in graph mode. 

1792 compression: Optional. The algorithm to use to decompress the data when 

1793 reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`. 

1794 reader_func: Optional. A function to control how to read data from shards. 

1795 If present, the function will be traced and executed as graph 

1796 computation. 

1797 

1798 Returns: 

1799 A `tf.data.Dataset` instance. 

1800 

1801 Raises: 

1802 FileNotFoundError: If `element_spec` is not specified and the saved nested 

1803 structure of `tf.TypeSpec` can not be located with the saved dataset. 

1804 ValueError: If `element_spec` is not specified and the method is executed 

1805 in graph mode. 

1806 """ 

1807 # Loaded lazily due to a circular dependency (dataset_ops -> load_op -> 

1808 # dataset_ops). 

1809 # pylint: disable=g-import-not-at-top,protected-access 

1810 from tensorflow.python.data.ops import load_op 

1811 return load_op._load( 

1812 path=path, 

1813 element_spec=element_spec, 

1814 compression=compression, 

1815 reader_func=reader_func) 

1816 # pylint: enable=g-import-not-at-top,protected-access 

1817 

1818 def batch(self, 

1819 batch_size, 

1820 drop_remainder=False, 

1821 num_parallel_calls=None, 

1822 deterministic=None, 

1823 name=None): 

1824 """Combines consecutive elements of this dataset into batches. 

1825 

1826 >>> dataset = tf.data.Dataset.range(8) 

1827 >>> dataset = dataset.batch(3) 

1828 >>> list(dataset.as_numpy_iterator()) 

1829 [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])] 

1830 

1831 >>> dataset = tf.data.Dataset.range(8) 

1832 >>> dataset = dataset.batch(3, drop_remainder=True) 

1833 >>> list(dataset.as_numpy_iterator()) 

1834 [array([0, 1, 2]), array([3, 4, 5])] 

1835 

1836 The components of the resulting element will have an additional outer 

1837 dimension, which will be `batch_size` (or `N % batch_size` for the last 

1838 element if `batch_size` does not divide the number of input elements `N` 

1839 evenly and `drop_remainder` is `False`). If your program depends on the 

1840 batches having the same outer dimension, you should set the `drop_remainder` 

1841 argument to `True` to prevent the smaller batch from being produced. 

1842 

1843 Note: If your program requires data to have a statically known shape (e.g., 

1844 when using XLA), you should use `drop_remainder=True`. Without 

1845 `drop_remainder=True` the shape of the output dataset will have an unknown 

1846 leading dimension due to the possibility of a smaller final batch. 

1847 

1848 Args: 

1849 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 

1850 consecutive elements of this dataset to combine in a single batch. 

1851 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 

1852 whether the last batch should be dropped in the case it has fewer than 

1853 `batch_size` elements; the default behavior is not to drop the smaller 

1854 batch. 

1855 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 

1856 representing the number of batches to compute asynchronously in 

1857 parallel. 

1858 If not specified, batches will be computed sequentially. If the value 

1859 `tf.data.AUTOTUNE` is used, then the number of parallel 

1860 calls is set dynamically based on available resources. 

1861 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 

1862 boolean is specified (`True` or `False`), it controls the order in which 

1863 the transformation produces elements. If set to `False`, the 

1864 transformation is allowed to yield elements out of order to trade 

1865 determinism for performance. If not specified, the 

1866 `tf.data.Options.deterministic` option (`True` by default) controls the 

1867 behavior. 

1868 name: (Optional.) A name for the tf.data operation. 

1869 

1870 Returns: 

1871 A new `Dataset` with the transformation applied as described above. 

1872 """ 

1873 # Loaded lazily due to a circular dependency (dataset_ops -> batch_op -> 

1874 # dataset_ops). 

1875 # pylint: disable=g-import-not-at-top,protected-access,redefined-outer-name 

1876 from tensorflow.python.data.ops import batch_op 

1877 return batch_op._batch(self, batch_size, drop_remainder, num_parallel_calls, 

1878 deterministic, name) 

1879 # pylint: enable=g-import-not-at-top,protected-access,redefined-outer-name 

1880 

1881 def padded_batch(self, 

1882 batch_size, 

1883 padded_shapes=None, 

1884 padding_values=None, 

1885 drop_remainder=False, 

1886 name=None): 

1887 """Combines consecutive elements of this dataset into padded batches. 

1888 

1889 This transformation combines multiple consecutive elements of the input 

1890 dataset into a single element. 

1891 

1892 Like `tf.data.Dataset.batch`, the components of the resulting element will 

1893 have an additional outer dimension, which will be `batch_size` (or 

1894 `N % batch_size` for the last element if `batch_size` does not divide the 

1895 number of input elements `N` evenly and `drop_remainder` is `False`). If 

1896 your program depends on the batches having the same outer dimension, you 

1897 should set the `drop_remainder` argument to `True` to prevent the smaller 

1898 batch from being produced. 

1899 

1900 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have 

1901 different shapes, and this transformation will pad each component to the 

1902 respective shape in `padded_shapes`. The `padded_shapes` argument 

1903 determines the resulting shape for each dimension of each component in an 

1904 output element: 

1905 

1906 * If the dimension is a constant, the component will be padded out to that 

1907 length in that dimension. 

1908 * If the dimension is unknown, the component will be padded out to the 

1909 maximum length of all elements in that dimension. 

1910 

1911 >>> A = (tf.data.Dataset 

1912 ... .range(1, 5, output_type=tf.int32) 

1913 ... .map(lambda x: tf.fill([x], x))) 

1914 >>> # Pad to the smallest per-batch size that fits all elements. 

1915 >>> B = A.padded_batch(2) 

1916 >>> for element in B.as_numpy_iterator(): 

1917 ... print(element) 

1918 [[1 0] 

1919 [2 2]] 

1920 [[3 3 3 0] 

1921 [4 4 4 4]] 

1922 >>> # Pad to a fixed size. 

1923 >>> C = A.padded_batch(2, padded_shapes=5) 

1924 >>> for element in C.as_numpy_iterator(): 

1925 ... print(element) 

1926 [[1 0 0 0 0] 

1927 [2 2 0 0 0]] 

1928 [[3 3 3 0 0] 

1929 [4 4 4 4 0]] 

1930 >>> # Pad with a custom value. 

1931 >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1) 

1932 >>> for element in D.as_numpy_iterator(): 

1933 ... print(element) 

1934 [[ 1 -1 -1 -1 -1] 

1935 [ 2 2 -1 -1 -1]] 

1936 [[ 3 3 3 -1 -1] 

1937 [ 4 4 4 4 -1]] 

1938 >>> # Components of nested elements can be padded independently. 

1939 >>> elements = [([1, 2, 3], [10]), 

1940 ... ([4, 5], [11, 12])] 

1941 >>> dataset = tf.data.Dataset.from_generator( 

1942 ... lambda: iter(elements), (tf.int32, tf.int32)) 

1943 >>> # Pad the first component of the tuple to length 4, and the second 

1944 >>> # component to the smallest size that fits. 

1945 >>> dataset = dataset.padded_batch(2, 

1946 ... padded_shapes=([4], [None]), 

1947 ... padding_values=(-1, 100)) 

1948 >>> list(dataset.as_numpy_iterator()) 

1949 [(array([[ 1, 2, 3, -1], [ 4, 5, -1, -1]], dtype=int32), 

1950 array([[ 10, 100], [ 11, 12]], dtype=int32))] 

1951 >>> # Pad with a single value and multiple components. 

1952 >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1) 

1953 >>> for element in E.as_numpy_iterator(): 

1954 ... print(element) 

1955 (array([[ 1, -1], 

1956 [ 2, 2]], dtype=int32), array([[ 1, -1], 

1957 [ 2, 2]], dtype=int32)) 

1958 (array([[ 3, 3, 3, -1], 

1959 [ 4, 4, 4, 4]], dtype=int32), array([[ 3, 3, 3, -1], 

1960 [ 4, 4, 4, 4]], dtype=int32)) 

1961 

1962 See also `tf.data.experimental.dense_to_sparse_batch`, which combines 

1963 elements that may have different shapes into a `tf.sparse.SparseTensor`. 

1964 

1965 Args: 

1966 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 

1967 consecutive elements of this dataset to combine in a single batch. 

1968 padded_shapes: (Optional.) A (nested) structure of `tf.TensorShape` or 

1969 `tf.int64` vector tensor-like objects representing the shape to which 

1970 the respective component of each input element should be padded prior 

1971 to batching. Any unknown dimensions will be padded to the maximum size 

1972 of that dimension in each batch. If unset, all dimensions of all 

1973 components are padded to the maximum size in the batch. `padded_shapes` 

1974 must be set if any component has an unknown rank. 

1975 padding_values: (Optional.) A (nested) structure of scalar-shaped 

1976 `tf.Tensor`, representing the padding values to use for the respective 

1977 components. None represents that the (nested) structure should be padded 

1978 with default values. Defaults are `0` for numeric types and the empty 

1979 string for string types. The `padding_values` should have the same 

1980 (nested) structure as the input dataset. If `padding_values` is a single 

1981 element and the input dataset has multiple components, then the same 

1982 `padding_values` will be used to pad every component of the dataset. 

1983 If `padding_values` is a scalar, then its value will be broadcasted 

1984 to match the shape of each component. 

1985 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 

1986 whether the last batch should be dropped in the case it has fewer than 

1987 `batch_size` elements; the default behavior is not to drop the smaller 

1988 batch. 

1989 name: (Optional.) A name for the tf.data operation. 

1990 

1991 Returns: 

1992 A new `Dataset` with the transformation applied as described above. 

1993 

1994 Raises: 

1995 ValueError: If a component has an unknown rank, and the `padded_shapes` 

1996 argument is not set. 

1997 TypeError: If a component is of an unsupported type. The list of supported 

1998 types is documented in 

1999 https://www.tensorflow.org/guide/data#dataset_structure. 

2000 """ 

2001 # Loaded lazily due to a circular dependency (dataset_ops -> 

2002 # padded_batch_op -> dataset_ops). 

2003 # pylint: disable=g-import-not-at-top,protected-access 

2004 from tensorflow.python.data.ops import padded_batch_op 

2005 return padded_batch_op._padded_batch(self, batch_size, padded_shapes, 

2006 padding_values, drop_remainder, name) 

2007 # pylint: enable=g-import-not-at-top,protected-access 

2008 

2009 def ragged_batch(self, 

2010 batch_size, 

2011 drop_remainder=False, 

2012 row_splits_dtype=dtypes.int64, 

2013 name=None): 

2014 """Combines consecutive elements of this dataset into `tf.RaggedTensor`s. 

2015 

2016 Like `tf.data.Dataset.batch`, the components of the resulting element will 

2017 have an additional outer dimension, which will be `batch_size` (or 

2018 `N % batch_size` for the last element if `batch_size` does not divide the 

2019 number of input elements `N` evenly and `drop_remainder` is `False`). If 

2020 your program depends on the batches having the same outer dimension, you 

2021 should set the `drop_remainder` argument to `True` to prevent the smaller 

2022 batch from being produced. 

2023 

2024 Unlike `tf.data.Dataset.batch`, the input elements to be batched may have 

2025 different shapes: 

2026 

2027 * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is 

2028 fully defined, then it is batched as normal. 

2029 * If an input element is a `tf.Tensor` whose static `tf.TensorShape` 

2030 contains one or more axes with unknown size (i.e., `shape[i]=None`), then 

2031 the output will contain a `tf.RaggedTensor` that is ragged up to any of such 

2032 dimensions. 

2033 * If an input element is a `tf.RaggedTensor` or any other type, then it is 

2034 batched as normal. 

2035 

2036 Example: 

2037 

2038 >>> dataset = tf.data.Dataset.range(6) 

2039 >>> dataset = dataset.map(lambda x: tf.range(x)) 

2040 >>> dataset.element_spec.shape 

2041 TensorShape([None]) 

2042 >>> dataset = dataset.ragged_batch(2) 

2043 >>> for batch in dataset: 

2044 ... print(batch) 

2045 <tf.RaggedTensor [[], [0]]> 

2046 <tf.RaggedTensor [[0, 1], [0, 1, 2]]> 

2047 <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]> 

2048 

2049 Args: 

2050 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 

2051 consecutive elements of this dataset to combine in a single batch. 

2052 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 

2053 whether the last batch should be dropped in the case it has fewer than 

2054 `batch_size` elements; the default behavior is not to drop the smaller 

2055 batch. 

2056 row_splits_dtype: The dtype that should be used for the `row_splits` of 

2057 any new ragged tensors. Existing `tf.RaggedTensor` elements do not have 

2058 their row_splits dtype changed. 

2059 name: (Optional.) A string indicating a name for the `tf.data` operation. 

2060 

2061 Returns: 

2062 A new `Dataset` with the transformation applied as described above. 

2063 """ 

2064 # Loaded lazily due to a circular dependency (dataset_ops -> 

2065 # ragged_batch_op -> dataset_ops). 

2066 # pylint: disable=g-import-not-at-top,protected-access 

2067 from tensorflow.python.data.ops import ragged_batch_op 

2068 return ragged_batch_op._ragged_batch(self, batch_size, drop_remainder, 

2069 row_splits_dtype, name) 

2070 # pylint: enable=g-import-not-at-top,protected-access 

2071 

2072 def sparse_batch(self, batch_size, row_shape, name=None): 

2073 """Combines consecutive elements into `tf.sparse.SparseTensor`s. 

2074 

2075 Like `Dataset.padded_batch()`, this transformation combines multiple 

2076 consecutive elements of the dataset, which might have different 

2077 shapes, into a single element. The resulting element has three 

2078 components (`indices`, `values`, and `dense_shape`), which 

2079 comprise a `tf.sparse.SparseTensor` that represents the same data. The 

2080 `row_shape` represents the dense shape of each row in the 

2081 resulting `tf.sparse.SparseTensor`, to which the effective batch size is 

2082 prepended. For example: 

2083 

2084 ```python 

2085 # NOTE: The following examples use `{ ... }` to represent the 

2086 # contents of a dataset. 

2087 a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } 

2088 

2089 a.apply(tf.data.experimental.dense_to_sparse_batch( 

2090 batch_size=2, row_shape=[6])) == 

2091 { 

2092 ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices 

2093 ['a', 'b', 'c', 'a', 'b'], # values 

2094 [2, 6]), # dense_shape 

2095 ([[0, 0], [0, 1], [0, 2], [0, 3]], 

2096 ['a', 'b', 'c', 'd'], 

2097 [1, 6]) 

2098 } 

2099 ``` 

2100 

2101 Args: 

2102 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 

2103 consecutive elements of this dataset to combine in a single batch. 

2104 row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object 

2105 representing the equivalent dense shape of a row in the resulting 

2106 `tf.sparse.SparseTensor`. Each element of this dataset must have the 

2107 same rank as `row_shape`, and must have size less than or equal to 

2108 `row_shape` in each dimension. 

2109 name: (Optional.) A string indicating a name for the `tf.data` operation. 

2110 

2111 Returns: 

2112 A new `Dataset` with the transformation applied as described above. 

2113 """ 

2114 # Loaded lazily due to a circular dependency (dataset_ops -> 

2115 # sparse_batch_op -> dataset_ops). 

2116 # pylint: disable=g-import-not-at-top,protected-access 

2117 from tensorflow.python.data.ops import sparse_batch_op 

2118 return sparse_batch_op._sparse_batch(self, batch_size, row_shape, name) 

2119 # pylint: disable=g-import-not-at-top,protected-access 

2120 

2121 def map(self, 

2122 map_func, 

2123 num_parallel_calls=None, 

2124 deterministic=None, 

2125 name=None): 

2126 """Maps `map_func` across the elements of this dataset. 

2127 

2128 This transformation applies `map_func` to each element of this dataset, and 

2129 returns a new dataset containing the transformed elements, in the same 

2130 order as they appeared in the input. `map_func` can be used to change both 

2131 the values and the structure of a dataset's elements. Supported structure 

2132 constructs are documented 

2133 [here](https://www.tensorflow.org/guide/data#dataset_structure). 

2134 

2135 For example, `map` can be used for adding 1 to each element, or projecting a 

2136 subset of element components. 

2137 

2138 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 

2139 >>> dataset = dataset.map(lambda x: x + 1) 

2140 >>> list(dataset.as_numpy_iterator()) 

2141 [2, 3, 4, 5, 6] 

2142 

2143 The input signature of `map_func` is determined by the structure of each 

2144 element in this dataset. 

2145 

2146 >>> dataset = Dataset.range(5) 

2147 >>> # `map_func` takes a single argument of type `tf.Tensor` with the same 

2148 >>> # shape and dtype. 

2149 >>> result = dataset.map(lambda x: x + 1) 

2150 

2151 >>> # Each element is a tuple containing two `tf.Tensor` objects. 

2152 >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")] 

2153 >>> dataset = tf.data.Dataset.from_generator( 

2154 ... lambda: elements, (tf.int32, tf.string)) 

2155 >>> # `map_func` takes two arguments of type `tf.Tensor`. This function 

2156 >>> # projects out just the first component. 

2157 >>> result = dataset.map(lambda x_int, y_str: x_int) 

2158 >>> list(result.as_numpy_iterator()) 

2159 [1, 2, 3] 

2160 

2161 >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects. 

2162 >>> elements = ([{"a": 1, "b": "foo"}, 

2163 ... {"a": 2, "b": "bar"}, 

2164 ... {"a": 3, "b": "baz"}]) 

2165 >>> dataset = tf.data.Dataset.from_generator( 

2166 ... lambda: elements, {"a": tf.int32, "b": tf.string}) 

2167 >>> # `map_func` takes a single argument of type `dict` with the same keys 

2168 >>> # as the elements. 

2169 >>> result = dataset.map(lambda d: str(d["a"]) + d["b"]) 

2170 

2171 The value or values returned by `map_func` determine the structure of each 

2172 element in the returned dataset. 

2173 

2174 >>> dataset = tf.data.Dataset.range(3) 

2175 >>> # `map_func` returns two `tf.Tensor` objects. 

2176 >>> def g(x): 

2177 ... return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) 

2178 >>> result = dataset.map(g) 

2179 >>> result.element_spec 

2180 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \ 

2181dtype=tf.string, name=None)) 

2182 >>> # Python primitives, lists, and NumPy arrays are implicitly converted to 

2183 >>> # `tf.Tensor`. 

2184 >>> def h(x): 

2185 ... return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64) 

2186 >>> result = dataset.map(h) 

2187 >>> result.element_spec 

2188 (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \ 

2189dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \ 

2190name=None)) 

2191 >>> # `map_func` can return nested structures. 

2192 >>> def i(x): 

2193 ... return (37.0, [42, 16]), "foo" 

2194 >>> result = dataset.map(i) 

2195 >>> result.element_spec 

2196 ((TensorSpec(shape=(), dtype=tf.float32, name=None), 

2197 TensorSpec(shape=(2,), dtype=tf.int32, name=None)), 

2198 TensorSpec(shape=(), dtype=tf.string, name=None)) 

2199 

2200 `map_func` can accept as arguments and return any type of dataset element. 

2201 

2202 Note that irrespective of the context in which `map_func` is defined (eager 

2203 vs. graph), tf.data traces the function and executes it as a graph. To use 

2204 Python code inside of the function you have a few options: 

2205 

2206 1) Rely on AutoGraph to convert Python code into an equivalent graph 

2207 computation. The downside of this approach is that AutoGraph can convert 

2208 some but not all Python code. 

2209 

2210 2) Use `tf.py_function`, which allows you to write arbitrary Python code but 

2211 will generally result in worse performance than 1). For example: 

2212 

2213 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 

2214 >>> # transform a string tensor to upper case string using a Python function 

2215 >>> def upper_case_fn(t: tf.Tensor): 

2216 ... return t.numpy().decode('utf-8').upper() 

2217 >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn, 

2218 ... inp=[x], Tout=tf.string)) 

2219 >>> list(d.as_numpy_iterator()) 

2220 [b'HELLO', b'WORLD'] 

2221 

2222 3) Use `tf.numpy_function`, which also allows you to write arbitrary 

2223 Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas 

2224 `tf.numpy_function` accepts numpy arrays and returns only numpy arrays. 

2225 For example: 

2226 

2227 >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world']) 

2228 >>> def upper_case_fn(t: np.ndarray): 

2229 ... return t.decode('utf-8').upper() 

2230 >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn, 

2231 ... inp=[x], Tout=tf.string)) 

2232 >>> list(d.as_numpy_iterator()) 

2233 [b'HELLO', b'WORLD'] 

2234 

2235 Note that the use of `tf.numpy_function` and `tf.py_function` 

2236 in general precludes the possibility of executing user-defined 

2237 transformations in parallel (because of Python GIL). 

2238 

2239 Performance can often be improved by setting `num_parallel_calls` so that 

2240 `map` will use multiple threads to process elements. If deterministic order 

2241 isn't required, it can also improve performance to set 

2242 `deterministic=False`. 

2243 

2244 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 

2245 >>> dataset = dataset.map(lambda x: x + 1, 

2246 ... num_parallel_calls=tf.data.AUTOTUNE, 

2247 ... deterministic=False) 

2248 

2249 The order of elements yielded by this transformation is deterministic if 

2250 `deterministic=True`. If `map_func` contains stateful operations and 

2251 `num_parallel_calls > 1`, the order in which that state is accessed is 

2252 undefined, so the values of output elements may not be deterministic 

2253 regardless of the `deterministic` flag value. 

2254 

2255 Args: 

2256 map_func: A function mapping a dataset element to another dataset element. 

2257 num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, 

2258 representing the number elements to process asynchronously in parallel. 

2259 If not specified, elements will be processed sequentially. If the value 

2260 `tf.data.AUTOTUNE` is used, then the number of parallel 

2261 calls is set dynamically based on available CPU. 

2262 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 

2263 boolean is specified (`True` or `False`), it controls the order in which 

2264 the transformation produces elements. If set to `False`, the 

2265 transformation is allowed to yield elements out of order to trade 

2266 determinism for performance. If not specified, the 

2267 `tf.data.Options.deterministic` option (`True` by default) controls the 

2268 behavior. 

2269 name: (Optional.) A name for the tf.data operation. 

2270 

2271 Returns: 

2272 A new `Dataset` with the transformation applied as described above. 

2273 """ 

2274 # Loaded lazily due to a circular dependency (dataset_ops -> map_op -> 

2275 # dataset_ops). 

2276 # pylint: disable=g-import-not-at-top,protected-access 

2277 from tensorflow.python.data.ops import map_op 

2278 return map_op._map_v2( 

2279 self, 

2280 map_func, 

2281 num_parallel_calls=num_parallel_calls, 

2282 deterministic=deterministic, 

2283 name=name) 

2284 # pylint: enable=g-import-not-at-top,protected-access 

2285 

2286 def flat_map(self, map_func, name=None): 

2287 """Maps `map_func` across this dataset and flattens the result. 

2288 

2289 The type signature is: 

2290 

2291 ``` 

2292 def flat_map( 

2293 self: Dataset[T], 

2294 map_func: Callable[[T], Dataset[S]] 

2295 ) -> Dataset[S] 

2296 ``` 

2297 

2298 Use `flat_map` if you want to make sure that the order of your dataset 

2299 stays the same. For example, to flatten a dataset of batches into a 

2300 dataset of their elements: 

2301 

2302 >>> dataset = tf.data.Dataset.from_tensor_slices( 

2303 ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 

2304 >>> dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices) 

2305 >>> list(dataset.as_numpy_iterator()) 

2306 [1, 2, 3, 4, 5, 6, 7, 8, 9] 

2307 

2308 `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since 

2309 `flat_map` produces the same output as 

2310 `tf.data.Dataset.interleave(cycle_length=1)` 

2311 

2312 Args: 

2313 map_func: A function mapping a dataset element to a dataset. 

2314 name: (Optional.) A name for the tf.data operation. 

2315 

2316 Returns: 

2317 A new `Dataset` with the transformation applied as described above. 

2318 """ 

2319 # Loaded lazily due to a circular dependency (dataset_ops -> flat_map_op -> 

2320 # dataset_ops). 

2321 # pylint: disable=g-import-not-at-top,protected-access 

2322 from tensorflow.python.data.ops import flat_map_op 

2323 return flat_map_op._flat_map(self, map_func, name=name) 

2324 # pylint: enable=g-import-not-at-top,protected-access 

2325 

2326 def ignore_errors(self, log_warning=False, name=None): 

2327 """Drops elements that cause errors. 

2328 

2329 >>> dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) 

2330 >>> dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, "")) 

2331 >>> list(dataset.as_numpy_iterator()) 

2332 Traceback (most recent call last): 

2333 ... 

2334 InvalidArgumentError: ... Tensor had Inf values 

2335 >>> dataset = dataset.ignore_errors() 

2336 >>> list(dataset.as_numpy_iterator()) 

2337 [1.0, 0.5, 0.25] 

2338 

2339 Args: 

2340 log_warning: (Optional.) A bool indicating whether or not ignored errors 

2341 should be logged to stderr. Defaults to `False`. 

2342 name: (Optional.) A string indicating a name for the `tf.data` operation. 

2343 

2344 Returns: 

2345 A new `Dataset` with the transformation applied as described above. 

2346 """ 

2347 # Loaded lazily due to a circular dependency (dataset_ops -> 

2348 # ignore_errors_op -> dataset_ops). 

2349 # pylint: disable=g-import-not-at-top,protected-access 

2350 from tensorflow.python.data.ops import ignore_errors_op 

2351 return ignore_errors_op._ignore_errors(self, log_warning, name) 

2352 # pylint: enable=g-import-not-at-top,protected-access 

2353 

2354 def interleave(self, 

2355 map_func, 

2356 cycle_length=None, 

2357 block_length=None, 

2358 num_parallel_calls=None, 

2359 deterministic=None, 

2360 name=None): 

2361 """Maps `map_func` across this dataset, and interleaves the results. 

2362 

2363 The type signature is: 

2364 

2365 ``` 

2366 def interleave( 

2367 self: Dataset[T], 

2368 map_func: Callable[[T], Dataset[S]] 

2369 ) -> Dataset[S] 

2370 ``` 

2371 

2372 For example, you can use `Dataset.interleave()` to process many input files 

2373 concurrently: 

2374 

2375 >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records 

2376 >>> # from each file. 

2377 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 

2378 ... "/var/data/file3.txt", "/var/data/file4.txt"] 

2379 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 

2380 >>> def parse_fn(filename): 

2381 ... return tf.data.Dataset.range(10) 

2382 >>> dataset = dataset.interleave(lambda x: 

2383 ... tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1), 

2384 ... cycle_length=4, block_length=16) 

2385 

2386 The `cycle_length` and `block_length` arguments control the order in which 

2387 elements are produced. `cycle_length` controls the number of input elements 

2388 that are processed concurrently. If you set `cycle_length` to 1, this 

2389 transformation will handle one input element at a time, and will produce 

2390 identical results to `tf.data.Dataset.flat_map`. In general, 

2391 this transformation will apply `map_func` to `cycle_length` input elements, 

2392 open iterators on the returned `Dataset` objects, and cycle through them 

2393 producing `block_length` consecutive elements from each iterator, and 

2394 consuming the next input element each time it reaches the end of an 

2395 iterator. 

2396 

2397 For example: 

2398 

2399 >>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ] 

2400 >>> # NOTE: New lines indicate "block" boundaries. 

2401 >>> dataset = dataset.interleave( 

2402 ... lambda x: Dataset.from_tensors(x).repeat(6), 

2403 ... cycle_length=2, block_length=4) 

2404 >>> list(dataset.as_numpy_iterator()) 

2405 [1, 1, 1, 1, 

2406 2, 2, 2, 2, 

2407 1, 1, 

2408 2, 2, 

2409 3, 3, 3, 3, 

2410 4, 4, 4, 4, 

2411 3, 3, 

2412 4, 4, 

2413 5, 5, 5, 5, 

2414 5, 5] 

2415 

2416 Note: The order of elements yielded by this transformation is 

2417 deterministic, as long as `map_func` is a pure function and 

2418 `deterministic=True`. If `map_func` contains any stateful operations, the 

2419 order in which that state is accessed is undefined. 

2420 

2421 Performance can often be improved by setting `num_parallel_calls` so that 

2422 `interleave` will use multiple threads to fetch elements. If determinism 

2423 isn't required, it can also improve performance to set 

2424 `deterministic=False`. 

2425 

2426 >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt", 

2427 ... "/var/data/file3.txt", "/var/data/file4.txt"] 

2428 >>> dataset = tf.data.Dataset.from_tensor_slices(filenames) 

2429 >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x), 

2430 ... cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE, 

2431 ... deterministic=False) 

2432 

2433 Args: 

2434 map_func: A function that takes a dataset element and returns a 

2435 `tf.data.Dataset`. 

2436 cycle_length: (Optional.) The number of input elements that will be 

2437 processed concurrently. If not set, the tf.data runtime decides what it 

2438 should be based on available CPU. If `num_parallel_calls` is set to 

2439 `tf.data.AUTOTUNE`, the `cycle_length` argument identifies 

2440 the maximum degree of parallelism. 

2441 block_length: (Optional.) The number of consecutive elements to produce 

2442 from each input element before cycling to another input element. If not 

2443 set, defaults to 1. 

2444 num_parallel_calls: (Optional.) If specified, the implementation creates a 

2445 threadpool, which is used to fetch inputs from cycle elements 

2446 asynchronously and in parallel. The default behavior is to fetch inputs 

2447 from cycle elements synchronously with no parallelism. If the value 

2448 `tf.data.AUTOTUNE` is used, then the number of parallel 

2449 calls is set dynamically based on available CPU. 

2450 deterministic: (Optional.) When `num_parallel_calls` is specified, if this 

2451 boolean is specified (`True` or `False`), it controls the order in which 

2452 the transformation produces elements. If set to `False`, the 

2453 transformation is allowed to yield elements out of order to trade 

2454 determinism for performance. If not specified, the 

2455 `tf.data.Options.deterministic` option (`True` by default) controls the 

2456 behavior. 

2457 name: (Optional.) A name for the tf.data operation. 

2458 

2459 Returns: 

2460 A new `Dataset` with the transformation applied as described above. 

2461 """ 

2462 # Loaded lazily due to a circular dependency ( 

2463 # dataset_ops -> interleave_op -> dataset_ops). 

2464 # pylint: disable=g-import-not-at-top,protected-access 

2465 from tensorflow.python.data.ops import interleave_op 

2466 return interleave_op._interleave(self, map_func, cycle_length, block_length, 

2467 num_parallel_calls, deterministic, name) 

2468 # pylint: enable=g-import-not-at-top,protected-access 

2469 

2470 def filter(self, predicate, name=None): 

2471 """Filters this dataset according to `predicate`. 

2472 

2473 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

2474 >>> dataset = dataset.filter(lambda x: x < 3) 

2475 >>> list(dataset.as_numpy_iterator()) 

2476 [1, 2] 

2477 >>> # `tf.math.equal(x, y)` is required for equality comparison 

2478 >>> def filter_fn(x): 

2479 ... return tf.math.equal(x, 1) 

2480 >>> dataset = dataset.filter(filter_fn) 

2481 >>> list(dataset.as_numpy_iterator()) 

2482 [1] 

2483 

2484 Args: 

2485 predicate: A function mapping a dataset element to a boolean. 

2486 name: (Optional.) A name for the tf.data operation. 

2487 

2488 Returns: 

2489 A new `Dataset` with the transformation applied as described above. 

2490 """ 

2491 # Loaded lazily due to a circular dependency (dataset_ops -> filter_op -> 

2492 # dataset_ops). 

2493 # pylint: disable=g-import-not-at-top,protected-access 

2494 from tensorflow.python.data.ops import filter_op 

2495 return filter_op._filter(self, predicate, name) 

2496 # pylint: enable=g-import-not-at-top,protected-access 

2497 

2498 def apply(self, transformation_func): 

2499 """Applies a transformation function to this dataset. 

2500 

2501 `apply` enables chaining of custom `Dataset` transformations, which are 

2502 represented as functions that take one `Dataset` argument and return a 

2503 transformed `Dataset`. 

2504 

2505 >>> dataset = tf.data.Dataset.range(100) 

2506 >>> def dataset_fn(ds): 

2507 ... return ds.filter(lambda x: x < 5) 

2508 >>> dataset = dataset.apply(dataset_fn) 

2509 >>> list(dataset.as_numpy_iterator()) 

2510 [0, 1, 2, 3, 4] 

2511 

2512 Args: 

2513 transformation_func: A function that takes one `Dataset` argument and 

2514 returns a `Dataset`. 

2515 

2516 Returns: 

2517 A new `Dataset` with the transformation applied as described above. 

2518 """ 

2519 dataset = transformation_func(self) 

2520 if not isinstance(dataset, data_types.DatasetV2): 

2521 raise TypeError( 

2522 f"`transformation_func` must return a `tf.data.Dataset` object. " 

2523 f"Got {type(dataset)}.") 

2524 dataset._input_datasets = [self] # pylint: disable=protected-access 

2525 return dataset 

2526 

2527 def window(self, size, shift=None, stride=1, drop_remainder=False, name=None): 

2528 """Returns a dataset of "windows". 

2529 

2530 Each "window" is a dataset that contains a subset of elements of the 

2531 input dataset. These are finite datasets of size `size` (or possibly fewer 

2532 if there are not enough input elements to fill the window and 

2533 `drop_remainder` evaluates to `False`). 

2534 

2535 For example: 

2536 

2537 >>> dataset = tf.data.Dataset.range(7).window(3) 

2538 >>> for window in dataset: 

2539 ... print(window) 

2540 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)> 

2541 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)> 

2542 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)> 

2543 

2544 Since windows are datasets, they can be iterated over: 

2545 

2546 >>> for window in dataset: 

2547 ... print(list(window.as_numpy_iterator())) 

2548 [0, 1, 2] 

2549 [3, 4, 5] 

2550 [6] 

2551 

2552 #### Shift 

2553 

2554 The `shift` argument determines the number of input elements to shift 

2555 between the start of each window. If windows and elements are both numbered 

2556 starting at 0, the first element in window `k` will be element `k * shift` 

2557 of the input dataset. In particular, the first element of the first window 

2558 will always be the first element of the input dataset. 

2559 

2560 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, 

2561 ... drop_remainder=True) 

2562 >>> for window in dataset: 

2563 ... print(list(window.as_numpy_iterator())) 

2564 [0, 1, 2] 

2565 [1, 2, 3] 

2566 [2, 3, 4] 

2567 [3, 4, 5] 

2568 [4, 5, 6] 

2569 

2570 #### Stride 

2571 

2572 The `stride` argument determines the stride between input elements within a 

2573 window. 

2574 

2575 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2, 

2576 ... drop_remainder=True) 

2577 >>> for window in dataset: 

2578 ... print(list(window.as_numpy_iterator())) 

2579 [0, 2, 4] 

2580 [1, 3, 5] 

2581 [2, 4, 6] 

2582 

2583 #### Nested elements 

2584 

2585 When the `window` transformation is applied to a dataset whos elements are 

2586 nested structures, it produces a dataset where the elements have the same 

2587 nested structure but each leaf is replaced by a window. In other words, 

2588 the nesting is applied outside of the windows as opposed inside of them. 

2589 

2590 The type signature is: 

2591 

2592 ``` 

2593 def window( 

2594 self: Dataset[Nest[T]], ... 

2595 ) -> Dataset[Nest[Dataset[T]]] 

2596 ``` 

2597 

2598 Applying `window` to a `Dataset` of tuples gives a tuple of windows: 

2599 

2600 >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5], 

2601 ... [6, 7, 8, 9, 10])) 

2602 >>> dataset = dataset.window(2) 

2603 >>> windows = next(iter(dataset)) 

2604 >>> windows 

2605 (<...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>, 

2606 <...Dataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>) 

2607 

2608 >>> def to_numpy(ds): 

2609 ... return list(ds.as_numpy_iterator()) 

2610 >>> 

2611 >>> for windows in dataset: 

2612 ... print(to_numpy(windows[0]), to_numpy(windows[1])) 

2613 [1, 2] [6, 7] 

2614 [3, 4] [8, 9] 

2615 [5] [10] 

2616 

2617 Applying `window` to a `Dataset` of dictionaries gives a dictionary of 

2618 `Datasets`: 

2619 

2620 >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3], 

2621 ... 'b': [4, 5, 6], 

2622 ... 'c': [7, 8, 9]}) 

2623 >>> dataset = dataset.window(2) 

2624 >>> def to_numpy(ds): 

2625 ... return list(ds.as_numpy_iterator()) 

2626 >>> 

2627 >>> for windows in dataset: 

2628 ... print(tf.nest.map_structure(to_numpy, windows)) 

2629 {'a': [1, 2], 'b': [4, 5], 'c': [7, 8]} 

2630 {'a': [3], 'b': [6], 'c': [9]} 

2631 

2632 #### Flatten a dataset of windows 

2633 

2634 The `Dataset.flat_map` and `Dataset.interleave` methods can be used to 

2635 flatten a dataset of windows into a single dataset. 

2636 

2637 The argument to `flat_map` is a function that takes an element from the 

2638 dataset and returns a `Dataset`. `flat_map` chains together the resulting 

2639 datasets sequentially. 

2640 

2641 For example, to turn each window into a dense tensor: 

2642 

2643 >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, 

2644 ... drop_remainder=True) 

2645 >>> batched = dataset.flat_map(lambda x:x.batch(3)) 

2646 >>> for batch in batched: 

2647 ... print(batch.numpy()) 

2648 [0 1 2] 

2649 [1 2 3] 

2650 [2 3 4] 

2651 [3 4 5] 

2652 [4 5 6] 

2653 

2654 Args: 

2655 size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements 

2656 of the input dataset to combine into a window. Must be positive. 

2657 shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 

2658 number of input elements by which the window moves in each iteration. 

2659 Defaults to `size`. Must be positive. 

2660 stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 

2661 stride of the input elements in the sliding window. Must be positive. 

2662 The default value of 1 means "retain every input element". 

2663 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 

2664 whether the last windows should be dropped if their size is smaller than 

2665 `size`. 

2666 name: (Optional.) A name for the tf.data operation. 

2667 

2668 Returns: 

2669 A new `Dataset` with the transformation applied as described above. 

2670 """ 

2671 # Loaded lazily due to a circular dependency (dataset_ops -> window_op -> 

2672 # dataset_ops). 

2673 # pylint: disable=g-import-not-at-top,protected-access 

2674 from tensorflow.python.data.ops import window_op 

2675 return window_op._window(self, size, shift, stride, drop_remainder, name) 

2676 # pylint: enable=g-import-not-at-top,protected-access 

2677 

2678 def reduce(self, initial_state, reduce_func, name=None): 

2679 """Reduces the input dataset to a single element. 

2680 

2681 The transformation calls `reduce_func` successively on every element of 

2682 the input dataset until the dataset is exhausted, aggregating information in 

2683 its internal state. The `initial_state` argument is used for the initial 

2684 state and the final state is returned as the result. 

2685 

2686 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy() 

2687 5 

2688 >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy() 

2689 10 

2690 

2691 Args: 

2692 initial_state: An element representing the initial state of the 

2693 transformation. 

2694 reduce_func: A function that maps `(old_state, input_element)` to 

2695 `new_state`. It must take two arguments and return a new element 

2696 The structure of `new_state` must match the structure of 

2697 `initial_state`. 

2698 name: (Optional.) A name for the tf.data operation. 

2699 

2700 Returns: 

2701 A dataset element corresponding to the final state of the transformation. 

2702 

2703 """ 

2704 

2705 with ops.name_scope("initial_state"): 

2706 initial_state = structure.normalize_element(initial_state) 

2707 state_structure = structure.type_spec_from_value(initial_state) 

2708 

2709 # Iteratively rerun the reduce function until reaching a fixed point on 

2710 # `state_structure`. 

2711 need_to_rerun = True 

2712 while need_to_rerun: 

2713 

2714 wrapped_func = structured_function.StructuredFunctionWrapper( 

2715 reduce_func, 

2716 "reduce()", 

2717 input_structure=(state_structure, self.element_spec), 

2718 add_to_graph=False) 

2719 

2720 # Extract and validate class information from the returned values. 

2721 output_classes = wrapped_func.output_classes 

2722 state_classes = nest.map_structure( 

2723 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 

2724 state_structure) 

2725 for new_state_class, state_class in zip( 

2726 nest.flatten(output_classes), nest.flatten(state_classes)): 

2727 if not issubclass(new_state_class, state_class): 

2728 raise TypeError( 

2729 f"The element classes for the new state must match the initial " 

2730 f"state. Expected {state_classes} but got " 

2731 f"{wrapped_func.output_classes}.") 

2732 

2733 # Extract and validate type information from the returned values. 

2734 output_types = wrapped_func.output_types 

2735 state_types = nest.map_structure( 

2736 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 

2737 state_structure) 

2738 for new_state_type, state_type in zip( 

2739 nest.flatten(output_types), nest.flatten(state_types)): 

2740 if new_state_type != state_type: 

2741 raise TypeError( 

2742 f"The element types for the new state must match the initial " 

2743 f"state. Expected {state_types} but got " 

2744 f"{wrapped_func.output_types}.") 

2745 

2746 # Extract shape information from the returned values. 

2747 output_shapes = wrapped_func.output_shapes 

2748 state_shapes = nest.map_structure( 

2749 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 

2750 state_structure) 

2751 flat_state_shapes = nest.flatten(state_shapes) 

2752 flat_new_state_shapes = nest.flatten(output_shapes) 

2753 weakened_state_shapes = [ 

2754 original.most_specific_compatible_shape(new) 

2755 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 

2756 ] 

2757 

2758 need_to_rerun = False 

2759 for original_shape, weakened_shape in zip(flat_state_shapes, 

2760 weakened_state_shapes): 

2761 if original_shape.ndims is not None and ( 

2762 weakened_shape.ndims is None or 

2763 original_shape.as_list() != weakened_shape.as_list()): 

2764 need_to_rerun = True 

2765 break 

2766 

2767 if need_to_rerun: 

2768 # TODO(b/110122868): Support a "most specific compatible structure" 

2769 # method for combining structures, to avoid using legacy structures 

2770 # here. 

2771 state_structure = structure.convert_legacy_structure( 

2772 state_types, 

2773 nest.pack_sequence_as(state_shapes, weakened_state_shapes), 

2774 state_classes) 

2775 

2776 reduce_func = wrapped_func.function 

2777 reduce_func.add_to_graph(ops.get_default_graph()) 

2778 

2779 dataset = self._apply_debug_options() 

2780 

2781 # pylint: disable=protected-access 

2782 metadata = dataset_metadata_pb2.Metadata() 

2783 if name: 

2784 metadata.name = _validate_and_encode(name) 

2785 return structure.from_compatible_tensor_list( 

2786 state_structure, 

2787 gen_dataset_ops.reduce_dataset( 

2788 dataset._variant_tensor, 

2789 structure.to_tensor_list(state_structure, initial_state), 

2790 reduce_func.captured_inputs, 

2791 f=reduce_func, 

2792 output_shapes=structure.get_flat_tensor_shapes(state_structure), 

2793 output_types=structure.get_flat_tensor_types(state_structure), 

2794 metadata=metadata.SerializeToString())) 

2795 

2796 def get_single_element(self, name=None): 

2797 """Returns the single element of the `dataset`. 

2798 

2799 The function enables you to use a `tf.data.Dataset` in a stateless 

2800 "tensor-in tensor-out" expression, without creating an iterator. 

2801 This facilitates the ease of data transformation on tensors using the 

2802 optimized `tf.data.Dataset` abstraction on top of them. 

2803 

2804 For example, lets consider a `preprocessing_fn` which would take as an 

2805 input the raw features and returns the processed feature along with 

2806 it's label. 

2807 

2808 ```python 

2809 def preprocessing_fn(raw_feature): 

2810 # ... the raw_feature is preprocessed as per the use-case 

2811 return feature 

2812 

2813 raw_features = ... # input batch of BATCH_SIZE elements. 

2814 dataset = (tf.data.Dataset.from_tensor_slices(raw_features) 

2815 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 

2816 .batch(BATCH_SIZE)) 

2817 

2818 processed_features = dataset.get_single_element() 

2819 ``` 

2820 

2821 In the above example, the `raw_features` tensor of length=BATCH_SIZE 

2822 was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was 

2823 mapped using the `preprocessing_fn` and the processed features were 

2824 grouped into a single batch. The final `dataset` contains only one element 

2825 which is a batch of all the processed features. 

2826 

2827 NOTE: The `dataset` should contain only one element. 

2828 

2829 Now, instead of creating an iterator for the `dataset` and retrieving the 

2830 batch of features, the `tf.data.get_single_element()` function is used 

2831 to skip the iterator creation process and directly output the batch of 

2832 features. 

2833 

2834 This can be particularly useful when your tensor transformations are 

2835 expressed as `tf.data.Dataset` operations, and you want to use those 

2836 transformations while serving your model. 

2837 

2838 #### Keras 

2839 

2840 ```python 

2841 

2842 model = ... # A pre-built or custom model 

2843 

2844 class PreprocessingModel(tf.keras.Model): 

2845 def __init__(self, model): 

2846 super().__init__(self) 

2847 self.model = model 

2848 

2849 @tf.function(input_signature=[...]) 

2850 def serving_fn(self, data): 

2851 ds = tf.data.Dataset.from_tensor_slices(data) 

2852 ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 

2853 ds = ds.batch(batch_size=BATCH_SIZE) 

2854 return tf.argmax(self.model(ds.get_single_element()), axis=-1) 

2855 

2856 preprocessing_model = PreprocessingModel(model) 

2857 your_exported_model_dir = ... # save the model to this path. 

2858 tf.saved_model.save(preprocessing_model, your_exported_model_dir, 

2859 signatures={'serving_default': preprocessing_model.serving_fn} 

2860 ) 

2861 ``` 

2862 

2863 #### Estimator 

2864 

2865 In the case of estimators, you need to generally define a `serving_input_fn` 

2866 which would require the features to be processed by the model while 

2867 inferencing. 

2868 

2869 ```python 

2870 def serving_input_fn(): 

2871 

2872 raw_feature_spec = ... # Spec for the raw_features 

2873 input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( 

2874 raw_feature_spec, default_batch_size=None) 

2875 ) 

2876 serving_input_receiver = input_fn() 

2877 raw_features = serving_input_receiver.features 

2878 

2879 def preprocessing_fn(raw_feature): 

2880 # ... the raw_feature is preprocessed as per the use-case 

2881 return feature 

2882 

2883 dataset = (tf.data.Dataset.from_tensor_slices(raw_features) 

2884 .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) 

2885 .batch(BATCH_SIZE)) 

2886 

2887 processed_features = dataset.get_single_element() 

2888 

2889 # Please note that the value of `BATCH_SIZE` should be equal to 

2890 # the size of the leading dimension of `raw_features`. This ensures 

2891 # that `dataset` has only element, which is a pre-requisite for 

2892 # using `dataset.get_single_element()`. 

2893 

2894 return tf.estimator.export.ServingInputReceiver( 

2895 processed_features, serving_input_receiver.receiver_tensors) 

2896 

2897 estimator = ... # A pre-built or custom estimator 

2898 estimator.export_saved_model(your_exported_model_dir, serving_input_fn) 

2899 ``` 

2900 

2901 Args: 

2902 name: (Optional.) A name for the tf.data operation. 

2903 

2904 Returns: 

2905 A nested structure of `tf.Tensor` objects, corresponding to the single 

2906 element of `dataset`. 

2907 

2908 Raises: 

2909 InvalidArgumentError: (at runtime) if `dataset` does not contain exactly 

2910 one element. 

2911 """ 

2912 

2913 metadata = dataset_metadata_pb2.Metadata() 

2914 if name: 

2915 metadata.name = _validate_and_encode(name) 

2916 return structure.from_compatible_tensor_list( 

2917 self.element_spec, 

2918 gen_dataset_ops.dataset_to_single_element( 

2919 self._variant_tensor, 

2920 metadata=metadata.SerializeToString(), 

2921 **self._flat_structure)) # pylint: disable=protected-access 

2922 

2923 def unbatch(self, name=None): 

2924 """Splits elements of a dataset into multiple elements. 

2925 

2926 For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, 

2927 where `B` may vary for each input element, then for each element in the 

2928 dataset, the unbatched dataset will contain `B` consecutive elements 

2929 of shape `[a0, a1, ...]`. 

2930 

2931 >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ] 

2932 >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64) 

2933 >>> dataset = dataset.unbatch() 

2934 >>> list(dataset.as_numpy_iterator()) 

2935 [1, 2, 3, 1, 2, 1, 2, 3, 4] 

2936 

2937 Note: `unbatch` requires a data copy to slice up the batched tensor into 

2938 smaller, unbatched tensors. When optimizing performance, try to avoid 

2939 unnecessary usage of `unbatch`. 

2940 

2941 Args: 

2942 name: (Optional.) A name for the tf.data operation. 

2943 

2944 Returns: 

2945 A new `Dataset` with the transformation applied as described above. 

2946 """ 

2947 # Loaded lazily due to a circular dependency ( 

2948 # dataset_ops -> unbatch_op -> dataset_ops). 

2949 # pylint: disable=g-import-not-at-top,protected-access 

2950 from tensorflow.python.data.ops import unbatch_op 

2951 return unbatch_op._unbatch(self, name=name) 

2952 # pylint: enable=g-import-not-at-top,protected-access 

2953 

2954 def with_options(self, options, name=None): 

2955 """Returns a new `tf.data.Dataset` with the given options set. 

2956 

2957 The options are "global" in the sense they apply to the entire dataset. 

2958 If options are set multiple times, they are merged as long as different 

2959 options do not use different non-default values. 

2960 

2961 >>> ds = tf.data.Dataset.range(5) 

2962 >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5), 

2963 ... cycle_length=3, 

2964 ... num_parallel_calls=3) 

2965 >>> options = tf.data.Options() 

2966 >>> # This will make the interleave order non-deterministic. 

2967 >>> options.deterministic = False 

2968 >>> ds = ds.with_options(options) 

2969 

2970 Args: 

2971 options: A `tf.data.Options` that identifies the options the use. 

2972 name: (Optional.) A name for the tf.data operation. 

2973 

2974 Returns: 

2975 A new `Dataset` with the transformation applied as described above. 

2976 

2977 Raises: 

2978 ValueError: when an option is set more than once to a non-default value 

2979 """ 

2980 return _OptionsDataset(self, options, name=name) 

2981 

2982 def cardinality(self): 

2983 """Returns the cardinality of the dataset, if known. 

2984 

2985 `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset 

2986 contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if 

2987 the analysis fails to determine the number of elements in the dataset 

2988 (e.g. when the dataset source is a file). 

2989 

2990 >>> dataset = tf.data.Dataset.range(42) 

2991 >>> print(dataset.cardinality().numpy()) 

2992 42 

2993 >>> dataset = dataset.repeat() 

2994 >>> cardinality = dataset.cardinality() 

2995 >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy()) 

2996 True 

2997 >>> dataset = dataset.filter(lambda x: True) 

2998 >>> cardinality = dataset.cardinality() 

2999 >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()) 

3000 True 

3001 

3002 Returns: 

3003 A scalar `tf.int64` `Tensor` representing the cardinality of the dataset. 

3004 If the cardinality is infinite or unknown, `cardinality` returns the 

3005 named constants `tf.data.INFINITE_CARDINALITY` and 

3006 `tf.data.UNKNOWN_CARDINALITY` respectively. 

3007 """ 

3008 return gen_dataset_ops.dataset_cardinality(self._variant_tensor) 

3009 

3010 def group_by_window(self, 

3011 key_func, 

3012 reduce_func, 

3013 window_size=None, 

3014 window_size_func=None, 

3015 name=None): 

3016 """Groups windows of elements by key and reduces them. 

3017 

3018 This transformation maps each consecutive element in a dataset to a key 

3019 using `key_func` and groups the elements by key. It then applies 

3020 `reduce_func` to at most `window_size_func(key)` elements matching the same 

3021 key. All except the final window for each key will contain 

3022 `window_size_func(key)` elements; the final window may be smaller. 

3023 

3024 You may provide either a constant `window_size` or a window size determined 

3025 by the key through `window_size_func`. 

3026 

3027 >>> dataset = tf.data.Dataset.range(10) 

3028 >>> window_size = 5 

3029 >>> key_func = lambda x: x%2 

3030 >>> reduce_func = lambda key, dataset: dataset.batch(window_size) 

3031 >>> dataset = dataset.group_by_window( 

3032 ... key_func=key_func, 

3033 ... reduce_func=reduce_func, 

3034 ... window_size=window_size) 

3035 >>> for elem in dataset.as_numpy_iterator(): 

3036 ... print(elem) 

3037 [0 2 4 6 8] 

3038 [1 3 5 7 9] 

3039 

3040 Args: 

3041 key_func: A function mapping a nested structure of tensors (having shapes 

3042 and types defined by `self.output_shapes` and `self.output_types`) to a 

3043 scalar `tf.int64` tensor. 

3044 reduce_func: A function mapping a key and a dataset of up to `window_size` 

3045 consecutive elements matching that key to another dataset. 

3046 window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 

3047 consecutive elements matching the same key to combine in a single batch, 

3048 which will be passed to `reduce_func`. Mutually exclusive with 

3049 `window_size_func`. 

3050 window_size_func: A function mapping a key to a `tf.int64` scalar 

3051 `tf.Tensor`, representing the number of consecutive elements matching 

3052 the same key to combine in a single batch, which will be passed to 

3053 `reduce_func`. Mutually exclusive with `window_size`. 

3054 name: (Optional.) A name for the tf.data operation. 

3055 

3056 Returns: 

3057 A new `Dataset` with the transformation applied as described above. 

3058 

3059 Raises: 

3060 ValueError: if neither or both of {`window_size`, `window_size_func`} are 

3061 passed. 

3062 """ 

3063 # Loaded lazily due to a circular dependency ( 

3064 # dataset_ops -> group_by_window_op -> dataset_ops). 

3065 # pylint: disable=g-import-not-at-top,protected-access 

3066 from tensorflow.python.data.ops import group_by_window_op 

3067 return group_by_window_op._group_by_window( 

3068 self, key_func, reduce_func, window_size, window_size_func, name=name) 

3069 # pylint: enable=g-import-not-at-top,protected-access 

3070 

3071 def bucket_by_sequence_length(self, 

3072 element_length_func, 

3073 bucket_boundaries, 

3074 bucket_batch_sizes, 

3075 padded_shapes=None, 

3076 padding_values=None, 

3077 pad_to_bucket_boundary=False, 

3078 no_padding=False, 

3079 drop_remainder=False, 

3080 name=None): 

3081 """A transformation that buckets elements in a `Dataset` by length. 

3082 

3083 Elements of the `Dataset` are grouped together by length and then are padded 

3084 and batched. 

3085 

3086 This is useful for sequence tasks in which the elements have variable 

3087 length. Grouping together elements that have similar lengths reduces the 

3088 total fraction of padding in a batch which increases training step 

3089 efficiency. 

3090 

3091 Below is an example to bucketize the input data to the 3 buckets 

3092 "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2. 

3093 

3094 >>> elements = [ 

3095 ... [0], [1, 2, 3, 4], [5, 6, 7], 

3096 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]] 

3097 >>> dataset = tf.data.Dataset.from_generator( 

3098 ... lambda: elements, tf.int64, output_shapes=[None]) 

3099 >>> dataset = dataset.bucket_by_sequence_length( 

3100 ... element_length_func=lambda elem: tf.shape(elem)[0], 

3101 ... bucket_boundaries=[3, 5], 

3102 ... bucket_batch_sizes=[2, 2, 2]) 

3103 >>> for elem in dataset.as_numpy_iterator(): 

3104 ... print(elem) 

3105 [[1 2 3 4] 

3106 [5 6 7 0]] 

3107 [[ 7 8 9 10 11 0] 

3108 [13 14 15 16 19 20]] 

3109 [[ 0 0] 

3110 [21 22]] 

3111 

3112 Args: 

3113 element_length_func: function from element in `Dataset` to `tf.int32`, 

3114 determines the length of the element, which will determine the bucket it 

3115 goes into. 

3116 bucket_boundaries: `list<int>`, upper length boundaries of the buckets. 

3117 bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be 

3118 `len(bucket_boundaries) + 1`. 

3119 padded_shapes: Nested structure of `tf.TensorShape` to pass to 

3120 `tf.data.Dataset.padded_batch`. If not provided, will use 

3121 `dataset.output_shapes`, which will result in variable length dimensions 

3122 being padded out to the maximum length in each batch. 

3123 padding_values: Values to pad with, passed to 

3124 `tf.data.Dataset.padded_batch`. Defaults to padding with 0. 

3125 pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown 

3126 size to maximum length in batch. If `True`, will pad dimensions with 

3127 unknown size to bucket boundary minus 1 (i.e., the maximum length in 

3128 each bucket), and caller must ensure that the source `Dataset` does not 

3129 contain any elements with length longer than `max(bucket_boundaries)`. 

3130 no_padding: `bool`, indicates whether to pad the batch features (features 

3131 need to be either of type `tf.sparse.SparseTensor` or of same shape). 

3132 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 

3133 whether the last batch should be dropped in the case it has fewer than 

3134 `batch_size` elements; the default behavior is not to drop the smaller 

3135 batch. 

3136 name: (Optional.) A name for the tf.data operation. 

3137 

3138 Returns: 

3139 A new `Dataset` with the transformation applied as described above. 

3140 

3141 Raises: 

3142 ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 

3143 """ 

3144 if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): 

3145 raise ValueError( 

3146 f"`len(bucket_batch_sizes)` must equal `len(bucket_boundaries) + 1` " 

3147 f"but `len(bucket_batch_sizes)={len(bucket_batch_sizes)}` and " 

3148 f"`len(bucket_boundaries)={len(bucket_boundaries)}`.") 

3149 

3150 batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) 

3151 

3152 def element_to_bucket_id(*args): 

3153 """Return int64 id of the length bucket for this element.""" 

3154 seq_length = element_length_func(*args) 

3155 

3156 boundaries = list(bucket_boundaries) 

3157 buckets_min = [np.iinfo(np.int32).min] + boundaries 

3158 buckets_max = boundaries + [np.iinfo(np.int32).max] 

3159 conditions_c = math_ops.logical_and( 

3160 math_ops.less_equal(buckets_min, seq_length), 

3161 math_ops.less(seq_length, buckets_max)) 

3162 bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) 

3163 

3164 return bucket_id 

3165 

3166 def window_size_fn(bucket_id): 

3167 # The window size is set to the batch size for this bucket 

3168 window_size = batch_sizes[bucket_id] 

3169 return window_size 

3170 

3171 def make_padded_shapes(shapes, none_filler=None): 

3172 padded = [] 

3173 for shape in nest.flatten(shapes): 

3174 shape = tensor_shape.TensorShape(shape) 

3175 shape = [ 

3176 none_filler if tensor_shape.dimension_value(d) is None else d 

3177 for d in shape 

3178 ] 

3179 padded.append(shape) 

3180 return nest.pack_sequence_as(shapes, padded) 

3181 

3182 def batching_fn(bucket_id, grouped_dataset): 

3183 """Batch elements in dataset.""" 

3184 batch_size = window_size_fn(bucket_id) 

3185 if no_padding: 

3186 return grouped_dataset.batch( 

3187 batch_size, drop_remainder=drop_remainder, name=name) 

3188 none_filler = None 

3189 if pad_to_bucket_boundary: 

3190 err_msg = ("When pad_to_bucket_boundary=True, elements must have " 

3191 "length < max(bucket_boundaries).") 

3192 check = check_ops.assert_less( 

3193 bucket_id, 

3194 constant_op.constant( 

3195 len(bucket_batch_sizes) - 1, dtype=dtypes.int64), 

3196 message=err_msg) 

3197 with ops.control_dependencies([check]): 

3198 boundaries = constant_op.constant( 

3199 bucket_boundaries, dtype=dtypes.int64) 

3200 bucket_boundary = boundaries[bucket_id] 

3201 none_filler = bucket_boundary - 1 

3202 input_shapes = get_legacy_output_shapes(grouped_dataset) 

3203 shapes = make_padded_shapes( 

3204 padded_shapes or input_shapes, none_filler=none_filler) 

3205 return grouped_dataset.padded_batch( 

3206 batch_size, 

3207 shapes, 

3208 padding_values, 

3209 drop_remainder=drop_remainder, 

3210 name=name) 

3211 

3212 return self.group_by_window( 

3213 key_func=element_to_bucket_id, 

3214 reduce_func=batching_fn, 

3215 window_size_func=window_size_fn, 

3216 name=name) 

3217 

3218 @staticmethod 

3219 def random(seed=None, rerandomize_each_iteration=None, name=None): 

3220 """Creates a `Dataset` of pseudorandom values. 

3221 

3222 The dataset generates a sequence of uniformly distributed integer values. 

3223 

3224 `rerandomize_each_iteration` controls whether the sequence of random number 

3225 generated should be re-randomized for each epoch. The default value is False 

3226 where the dataset generates the same sequence of random numbers for each 

3227 epoch. 

3228 

3229 >>> ds1 = tf.data.Dataset.random(seed=4).take(10) 

3230 >>> ds2 = tf.data.Dataset.random(seed=4).take(10) 

3231 >>> print(list(ds1.as_numpy_iterator())==list(ds2.as_numpy_iterator())) 

3232 True 

3233 

3234 >>> ds3 = tf.data.Dataset.random(seed=4).take(10) 

3235 >>> ds3_first_epoch = list(ds3.as_numpy_iterator()) 

3236 >>> ds3_second_epoch = list(ds3.as_numpy_iterator()) 

3237 >>> print(ds3_first_epoch == ds3_second_epoch) 

3238 True 

3239 

3240 >>> ds4 = tf.data.Dataset.random( 

3241 ... seed=4, rerandomize_each_iteration=True).take(10) 

3242 >>> ds4_first_epoch = list(ds4.as_numpy_iterator()) 

3243 >>> ds4_second_epoch = list(ds4.as_numpy_iterator()) 

3244 >>> print(ds4_first_epoch == ds4_second_epoch) 

3245 False 

3246 

3247 Args: 

3248 seed: (Optional) If specified, the dataset produces a deterministic 

3249 sequence of values. 

3250 rerandomize_each_iteration: (Optional) If set to False, the dataset 

3251 generates the same sequence of random numbers for each epoch. If set to 

3252 True, it generates a different deterministic sequence of random numbers 

3253 for each epoch. It is defaulted to False if left unspecified. 

3254 name: (Optional.) A name for the tf.data operation. 

3255 

3256 Returns: 

3257 Dataset: A `Dataset`. 

3258 """ 

3259 # Loaded lazily due to a circular dependency ( 

3260 # dataset_ops -> random_op -> dataset_ops). 

3261 # pylint: disable=g-import-not-at-top,protected-access 

3262 from tensorflow.python.data.ops import random_op 

3263 return random_op._random( 

3264 seed=seed, 

3265 rerandomize_each_iteration=rerandomize_each_iteration, 

3266 name=name) 

3267 # pylint: enable=g-import-not-at-top,protected-access 

3268 

3269 def snapshot(self, 

3270 path, 

3271 compression="AUTO", 

3272 reader_func=None, 

3273 shard_func=None, 

3274 name=None): 

3275 """API to persist the output of the input dataset. 

3276 

3277 The snapshot API allows users to transparently persist the output of their 

3278 preprocessing pipeline to disk, and materialize the pre-processed data on a 

3279 different training run. 

3280 

3281 This API enables repeated preprocessing steps to be consolidated, and allows 

3282 re-use of already processed data, trading off disk storage and network 

3283 bandwidth for freeing up more valuable CPU resources and accelerator compute 

3284 time. 

3285 

3286 https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md 

3287 has detailed design documentation of this feature. 

3288 

3289 Users can specify various options to control the behavior of snapshot, 

3290 including how snapshots are read from and written to by passing in 

3291 user-defined functions to the `reader_func` and `shard_func` parameters. 

3292 

3293 `shard_func` is a user specified function that maps input elements to 

3294 snapshot shards. 

3295 

3296 Users may want to specify this function to control how snapshot files should 

3297 be written to disk. Below is an example of how a potential `shard_func` 

3298 could be written. 

3299 

3300 ```python 

3301 dataset = ... 

3302 dataset = dataset.enumerate() 

3303 dataset = dataset.snapshot("/path/to/snapshot/dir", 

3304 shard_func=lambda x, y: x % NUM_SHARDS, ...) 

3305 dataset = dataset.map(lambda x, y: y) 

3306 ``` 

3307 

3308 `reader_func` is a user specified function that accepts a single argument: 

3309 (1) a Dataset of Datasets, each representing a "split" of elements of the 

3310 original dataset. The cardinality of the input dataset matches the 

3311 number of the shards specified in the `shard_func` (see above). The function 

3312 should return a Dataset of elements of the original dataset. 

3313 

3314 Users may want specify this function to control how snapshot files should be 

3315 read from disk, including the amount of shuffling and parallelism. 

3316 

3317 Here is an example of a standard reader function a user can define. This 

3318 function enables both dataset shuffling and parallel reading of datasets: 

3319 

3320 ```python 

3321 def user_reader_func(datasets): 

3322 # shuffle the datasets splits 

3323 datasets = datasets.shuffle(NUM_CORES) 

3324 # read datasets in parallel and interleave their elements 

3325 return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE) 

3326 

3327 dataset = dataset.snapshot("/path/to/snapshot/dir", 

3328 reader_func=user_reader_func) 

3329 ``` 

3330 

3331 By default, snapshot parallelizes reads by the number of cores available on 

3332 the system, but will not attempt to shuffle the data. 

3333 

3334 Args: 

3335 path: Required. A directory to use for storing / loading the snapshot to / 

3336 from. 

3337 compression: Optional. The type of compression to apply to the snapshot 

3338 written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None. 

3339 Defaults to `AUTO`, which attempts to pick an appropriate compression 

3340 algorithm for the dataset. 

3341 reader_func: Optional. A function to control how to read data from 

3342 snapshot shards. 

3343 shard_func: Optional. A function to control how to shard data when writing 

3344 a snapshot. 

3345 name: (Optional.) A name for the tf.data operation. 

3346 

3347 Returns: 

3348 A new `Dataset` with the transformation applied as described above. 

3349 """ 

3350 # Loaded lazily due to a circular dependency ( 

3351 # dataset_ops -> snapshot_op -> dataset_ops). 

3352 # pylint: disable=g-import-not-at-top,protected-access 

3353 from tensorflow.python.data.ops import snapshot_op 

3354 return snapshot_op._snapshot( 

3355 self, path, compression, reader_func, shard_func, name=name) 

3356 # pylint: enable=g-import-not-at-top,protected-access 

3357 

3358 def scan(self, initial_state, scan_func, name=None): 

3359 """A transformation that scans a function across an input dataset. 

3360 

3361 This transformation is a stateful relative of `tf.data.Dataset.map`. 

3362 In addition to mapping `scan_func` across the elements of the input dataset, 

3363 `scan()` accumulates one or more state tensors, whose initial values are 

3364 `initial_state`. 

3365 

3366 >>> dataset = tf.data.Dataset.range(10) 

3367 >>> initial_state = tf.constant(0, dtype=tf.int64) 

3368 >>> scan_func = lambda state, i: (state + i, state + i) 

3369 >>> dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func) 

3370 >>> list(dataset.as_numpy_iterator()) 

3371 [0, 1, 3, 6, 10, 15, 21, 28, 36, 45] 

3372 

3373 Args: 

3374 initial_state: A nested structure of tensors, representing the initial 

3375 state of the accumulator. 

3376 scan_func: A function that maps `(old_state, input_element)` to 

3377 `(new_state, output_element)`. It must take two arguments and return a 

3378 pair of nested structures of tensors. The `new_state` must match the 

3379 structure of `initial_state`. 

3380 name: (Optional.) A name for the tf.data operation. 

3381 

3382 Returns: 

3383 A new `Dataset` with the transformation applied as described above. 

3384 """ 

3385 

3386 # Loaded lazily due to a circular dependency (dataset_ops -> 

3387 # scan_op -> dataset_ops). 

3388 # pylint: disable=g-import-not-at-top,protected-access 

3389 from tensorflow.python.data.ops import scan_op 

3390 return scan_op._scan(self, initial_state, scan_func, name=name) 

3391 # pylint: enable=g-import-not-at-top,protected-access 

3392 

3393 def take_while(self, predicate, name=None): 

3394 """A transformation that stops dataset iteration based on a `predicate`. 

3395 

3396 >>> dataset = tf.data.Dataset.range(10) 

3397 >>> dataset = dataset.take_while(lambda x: x < 5) 

3398 >>> list(dataset.as_numpy_iterator()) 

3399 [0, 1, 2, 3, 4] 

3400 

3401 Args: 

3402 predicate: A function that maps a nested structure of tensors (having 

3403 shapes and types defined by `self.output_shapes` and 

3404 `self.output_types`) to a scalar `tf.bool` tensor. 

3405 name: (Optional.) A name for the tf.data operation. 

3406 

3407 Returns: 

3408 A new `Dataset` with the transformation applied as described above. 

3409 """ 

3410 # Loaded lazily due to a circular dependency ( 

3411 # dataset_ops -> take_while_op -> dataset_ops). 

3412 # pylint: disable=g-import-not-at-top,protected-access 

3413 from tensorflow.python.data.ops import take_while_op 

3414 return take_while_op._take_while(self, predicate, name=name) 

3415 # pylint: enable=g-import-not-at-top,protected-access 

3416 

3417 def unique(self, name=None): 

3418 """A transformation that discards duplicate elements of a `Dataset`. 

3419 

3420 Use this transformation to produce a dataset that contains one instance of 

3421 each unique element in the input. For example: 

3422 

3423 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1]) 

3424 >>> dataset = dataset.unique() 

3425 >>> sorted(list(dataset.as_numpy_iterator())) 

3426 [1, 2, 37] 

3427 

3428 Note: This transformation only supports datasets which fit into memory 

3429 and have elements of either `tf.int32`, `tf.int64` or `tf.string` type. 

3430 

3431 Args: 

3432 name: (Optional.) A name for the tf.data operation. 

3433 

3434 Returns: 

3435 A new `Dataset` with the transformation applied as described above. 

3436 """ 

3437 # Loaded lazily due to a circular dependency (dataset_ops -> unique_op -> 

3438 # dataset_ops). 

3439 # pylint: disable=g-import-not-at-top,protected-access 

3440 from tensorflow.python.data.ops import unique_op 

3441 return unique_op._unique(self, name) 

3442 # pylint: enable=g-import-not-at-top,protected-access 

3443 

3444 def rejection_resample(self, 

3445 class_func, 

3446 target_dist, 

3447 initial_dist=None, 

3448 seed=None, 

3449 name=None): 

3450 """Resamples elements to reach a target distribution. 

3451 

3452 Note: This implementation can reject **or repeat** elements in order to 

3453 reach the `target_dist`. So, in some cases, the output `Dataset` may be 

3454 larger than the input `Dataset`. 

3455 

3456 >>> initial_dist = [0.6, 0.4] 

3457 >>> n = 1000 

3458 >>> elems = np.random.choice(len(initial_dist), size=n, p=initial_dist) 

3459 >>> dataset = tf.data.Dataset.from_tensor_slices(elems) 

3460 >>> zero, one = np.bincount(list(dataset.as_numpy_iterator())) / n 

3461 

3462 Following from `initial_dist`, `zero` is ~0.6 and `one` is ~0.4. 

3463 

3464 >>> target_dist = [0.5, 0.5] 

3465 >>> dataset = dataset.rejection_resample( 

3466 ... class_func=lambda x: x, 

3467 ... target_dist=target_dist, 

3468 ... initial_dist=initial_dist) 

3469 >>> dataset = dataset.map(lambda class_func_result, data: data) 

3470 >>> zero, one = np.bincount(list(dataset.as_numpy_iterator())) / n 

3471 

3472 Following from `target_dist`, `zero` is ~0.5 and `one` is ~0.5. 

3473 

3474 Args: 

3475 class_func: A function mapping an element of the input dataset to a scalar 

3476 `tf.int32` tensor. Values should be in `[0, num_classes)`. 

3477 target_dist: A floating point type tensor, shaped `[num_classes]`. 

3478 initial_dist: (Optional.) A floating point type tensor, shaped 

3479 `[num_classes]`. If not provided, the true class distribution is 

3480 estimated live in a streaming fashion. 

3481 seed: (Optional.) Python integer seed for the resampler. 

3482 name: (Optional.) A name for the tf.data operation. 

3483 

3484 Returns: 

3485 A new `Dataset` with the transformation applied as described above. 

3486 """ 

3487 

3488 # TODO(b/245793127): Consider switching back to the 'v1' implementation. 

3489 

3490 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") 

3491 target_dist_t = math_ops.cast(target_dist_t, dtypes.float32) 

3492 

3493 # Get initial distribution. 

3494 if initial_dist is not None: 

3495 initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") 

3496 initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32) 

3497 acceptance_dist, prob_of_original = ( 

3498 _calculate_acceptance_probs_with_mixing(initial_dist_t, 

3499 target_dist_t)) 

3500 initial_dist_ds = DatasetV2.from_tensors( 

3501 initial_dist_t, name=name).repeat(name=name) 

3502 acceptance_dist_ds = DatasetV2.from_tensors( 

3503 acceptance_dist, name=name).repeat(name=name) 

3504 prob_of_original_ds = DatasetV2.from_tensors( 

3505 prob_of_original, name=name).repeat(name=name) 

3506 else: 

3507 initial_dist_ds = _estimate_initial_dist_ds( 

3508 target_dist_t, self.map(class_func, name=name), name=name) 

3509 acceptance_and_original_prob_ds = initial_dist_ds.map( 

3510 lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda 

3511 initial, target_dist_t), 

3512 name=name) 

3513 acceptance_dist_ds = acceptance_and_original_prob_ds.map( 

3514 lambda accept_prob, _: accept_prob, name=name) 

3515 prob_of_original_ds = acceptance_and_original_prob_ds.map( 

3516 lambda _, prob_original: prob_original, name=name) 

3517 filtered_ds = _filter_ds(self, acceptance_dist_ds, initial_dist_ds, 

3518 class_func, seed) 

3519 # Prefetch filtered dataset for speed. 

3520 filtered_ds = filtered_ds.prefetch(3, name=name) 

3521 

3522 prob_original_static = _get_prob_original_static( 

3523 initial_dist_t, target_dist_t) if initial_dist is not None else None 

3524 

3525 def add_class_value(*x): 

3526 if len(x) == 1: 

3527 return class_func(*x), x[0] 

3528 else: 

3529 return class_func(*x), x 

3530 

3531 if prob_original_static == 1: 

3532 return self.map(add_class_value, name=name) 

3533 elif prob_original_static == 0: 

3534 return filtered_ds 

3535 else: 

3536 return Dataset.sample_from_datasets( 

3537 [self.map(add_class_value), filtered_ds], 

3538 weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), 

3539 seed=seed, 

3540 stop_on_empty_dataset=True) 

3541 

3542 @staticmethod 

3543 def sample_from_datasets(datasets, 

3544 weights=None, 

3545 seed=None, 

3546 stop_on_empty_dataset=False, 

3547 rerandomize_each_iteration=None): 

3548 """Samples elements at random from the datasets in `datasets`. 

3549 

3550 Creates a dataset by interleaving elements of `datasets` with `weight[i]` 

3551 probability of picking an element from dataset `i`. Sampling is done without 

3552 replacement. For example, suppose we have 2 datasets: 

3553 

3554 ```python 

3555 dataset1 = tf.data.Dataset.range(0, 3) 

3556 dataset2 = tf.data.Dataset.range(100, 103) 

3557 ``` 

3558 

3559 Suppose that we sample from these 2 datasets with the following weights: 

3560 

3561 ```python 

3562 sample_dataset = tf.data.Dataset.sample_from_datasets( 

3563 [dataset1, dataset2], weights=[0.5, 0.5]) 

3564 ``` 

3565 

3566 One possible outcome of elements in sample_dataset is: 

3567 

3568 ``` 

3569 print(list(sample_dataset.as_numpy_iterator())) 

3570 # [100, 0, 1, 101, 2, 102] 

3571 ``` 

3572 

3573 Args: 

3574 datasets: A non-empty list of `tf.data.Dataset` objects with compatible 

3575 structure. 

3576 weights: (Optional.) A list or Tensor of `len(datasets)` floating-point 

3577 values where `weights[i]` represents the probability to sample from 

3578 `datasets[i]`, or a `tf.data.Dataset` object where each element is such 

3579 a list. Defaults to a uniform distribution across `datasets`. 

3580 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 

3581 seed that will be used to create the distribution. See 

3582 `tf.random.set_seed` for behavior. 

3583 stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty 

3584 dataset. If `False`, it continues sampling and skips any empty datasets. 

3585 It is recommended to set it to `True`. Otherwise, the distribution of 

3586 samples starts off as the user intends, but may change as input datasets 

3587 become empty. This can be difficult to detect since the dataset starts 

3588 off looking correct. Default to `False` for backward compatibility. 

3589 rerandomize_each_iteration: An optional `bool`. The boolean argument 

3590 controls whether the sequence of random numbers used to determine which 

3591 dataset to sample from will be rerandomized each epoch. That is, it 

3592 determinies whether datasets will be sampled in the same order across 

3593 different epochs (the default behavior) or not. 

3594 

3595 Returns: 

3596 A dataset that interleaves elements from `datasets` at random, according 

3597 to `weights` if provided, otherwise with uniform probability. 

3598 

3599 Raises: 

3600 TypeError: If the `datasets` or `weights` arguments have the wrong type. 

3601 ValueError: 

3602 - If `datasets` is empty, or 

3603 - If `weights` is specified and does not match the length of `datasets`. 

3604 """ 

3605 # Loaded lazily due to a circular dependency 

3606 # (dataset_ops -> sample_from_datasets_op -> dataset_ops). 

3607 # pylint: disable=g-import-not-at-top,protected-access 

3608 from tensorflow.python.data.ops import sample_from_datasets_op 

3609 return sample_from_datasets_op._sample_from_datasets( # pylint: disable=protected-access 

3610 datasets, 

3611 weights, 

3612 seed, 

3613 stop_on_empty_dataset, 

3614 rerandomize_each_iteration, 

3615 ) 

3616 # pylint: enable=g-import-not-at-top,protected-access 

3617 

3618 @staticmethod 

3619 def choose_from_datasets(datasets, 

3620 choice_dataset, 

3621 stop_on_empty_dataset=True): 

3622 """Creates a dataset that deterministically chooses elements from `datasets`. 

3623 

3624 For example, given the following datasets: 

3625 

3626 ```python 

3627 datasets = [tf.data.Dataset.from_tensors("foo").repeat(), 

3628 tf.data.Dataset.from_tensors("bar").repeat(), 

3629 tf.data.Dataset.from_tensors("baz").repeat()] 

3630 

3631 # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. 

3632 choice_dataset = tf.data.Dataset.range(3).repeat(3) 

3633 

3634 result = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset) 

3635 ``` 

3636 

3637 The elements of `result` will be: 

3638 

3639 ``` 

3640 "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" 

3641 ``` 

3642 

3643 Args: 

3644 datasets: A non-empty list of `tf.data.Dataset` objects with compatible 

3645 structure. 

3646 choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between 

3647 `0` and `len(datasets) - 1`. 

3648 stop_on_empty_dataset: If `True`, selection stops if it encounters an 

3649 empty dataset. If `False`, it skips empty datasets. It is recommended to 

3650 set it to `True`. Otherwise, the selected elements start off as the user 

3651 intends, but may change as input datasets become empty. This can be 

3652 difficult to detect since the dataset starts off looking correct. 

3653 Defaults to `True`. 

3654 

3655 Returns: 

3656 A new `Dataset` with the transformation applied as described above. 

3657 

3658 Raises: 

3659 TypeError: If `datasets` or `choice_dataset` has the wrong type. 

3660 ValueError: If `datasets` is empty. 

3661 """ 

3662 # Loaded lazily due to a circular dependency 

3663 # (dataset_ops -> choose_from_datasets_op -> dataset_ops). 

3664 # pylint: disable=g-import-not-at-top,protected-access 

3665 from tensorflow.python.data.ops import choose_from_datasets_op 

3666 return choose_from_datasets_op._choose_from_datasets( 

3667 datasets, choice_dataset, stop_on_empty_dataset) 

3668 # pylint: enable=g-import-not-at-top,protected-access 

3669 

3670 

3671@tf_export(v1=["data.Dataset"]) 

3672class DatasetV1(DatasetV2, data_types.DatasetV1): 

3673 """Represents a potentially large set of elements. 

3674 

3675 A `Dataset` can be used to represent an input pipeline as a 

3676 collection of elements and a "logical plan" of transformations that act on 

3677 those elements. 

3678 """ 

3679 

3680 def __init__(self): 

3681 try: 

3682 variant_tensor = self._as_variant_tensor() 

3683 except AttributeError as e: 

3684 if "_as_variant_tensor" in str(e): 

3685 raise AttributeError("Please use `_variant_tensor` instead of " 

3686 "`_as_variant_tensor()` to obtain the variant " 

3687 "associated with a dataset.") 

3688 raise AttributeError("{}: A likely cause of this error is that the super " 

3689 "call for this dataset is not the last line of the " 

3690 "`__init__` method. The base class invokes the " 

3691 "`_as_variant_tensor()` method in its constructor " 

3692 "and if that method uses attributes defined in the " 

3693 "`__init__` method, those attributes need to be " 

3694 "defined before the super call.".format(e)) 

3695 super(DatasetV1, self).__init__(variant_tensor) 

3696 

3697 @abc.abstractmethod 

3698 def _as_variant_tensor(self): 

3699 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. 

3700 

3701 Returns: 

3702 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. 

3703 """ 

3704 raise NotImplementedError(f"{type(self)}.as_variant_tensor()") 

3705 

3706 @deprecation.deprecated( 

3707 None, "This is a deprecated API that should only be used in TF 1 graph " 

3708 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In " 

3709 "all other situations -- namely, eager mode and inside `tf.function` -- " 

3710 "you can consume dataset elements using `for elem in dataset: ...` or " 

3711 "by explicitly creating iterator via `iterator = iter(dataset)` and " 

3712 "fetching its elements via `values = next(iterator)`. Furthermore, " 

3713 "this API is not available in TF 2. During the transition from TF 1 " 

3714 "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` " 

3715 "to create a TF 1 graph mode style iterator for a dataset created " 

3716 "through TF 2 APIs. Note that this should be a transient state of your " 

3717 "code base as there are in general no guarantees about the " 

3718 "interoperability of TF 1 and TF 2 code.") 

3719 def make_one_shot_iterator(self): 

3720 """Creates an iterator for elements of this dataset. 

3721 

3722 Note: The returned iterator will be initialized automatically. 

3723 A "one-shot" iterator does not currently support re-initialization. For 

3724 that see `make_initializable_iterator`. 

3725 

3726 Example: 

3727 

3728 ```python 

3729 # Building graph ... 

3730 dataset = ... 

3731 next_value = dataset.make_one_shot_iterator().get_next() 

3732 

3733 # ... from within a session ... 

3734 try: 

3735 while True: 

3736 value = sess.run(next_value) 

3737 ... 

3738 except tf.errors.OutOfRangeError: 

3739 pass 

3740 ``` 

3741 

3742 Returns: 

3743 An `tf.data.Iterator` for elements of this dataset. 

3744 """ 

3745 return self._make_one_shot_iterator() 

3746 

3747 def _make_one_shot_iterator(self): # pylint: disable=missing-docstring 

3748 if context.executing_eagerly(): 

3749 with ops.colocate_with(self._variant_tensor): 

3750 return iterator_ops.OwnedIterator(self) 

3751 

3752 _ensure_same_dataset_graph(self) 

3753 # Some ops (e.g. dataset ops) are marked as stateful but are stil safe to 

3754 # to capture by value. We must allowlist these ops so that the capturing 

3755 # logic captures the ops instead of raising an exception. 

3756 allowlisted_stateful_ops = traverse.obtain_capture_by_value_ops(self) 

3757 graph_level_seed, op_level_seed = core_random_seed.get_seed(None) 

3758 

3759 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 

3760 # a 0-argument function. 

3761 @function.Defun( 

3762 capture_by_value=True, 

3763 allowlisted_stateful_ops=allowlisted_stateful_ops) 

3764 def _make_dataset(): 

3765 """Factory function for a dataset.""" 

3766 # NOTE(mrry): `Defun` does not capture the graph-level seed from the 

3767 # enclosing graph, so if a graph-level seed is present we set the local 

3768 # graph seed based on a combination of the graph- and op-level seeds. 

3769 if graph_level_seed is not None: 

3770 assert op_level_seed is not None 

3771 core_random_seed.set_random_seed( 

3772 (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1)) 

3773 

3774 dataset = self._apply_debug_options() 

3775 return dataset._variant_tensor # pylint: disable=protected-access 

3776 

3777 try: 

3778 _make_dataset.add_to_graph(ops.get_default_graph()) 

3779 except ValueError as err: 

3780 if "Cannot capture a stateful node" in str(err): 

3781 raise ValueError( 

3782 "{}: A likely cause of this error is that the dataset for which " 

3783 "you are calling `make_one_shot_iterator()` captures a stateful " 

3784 "object, such as a `tf.Variable` or `tf.lookup.StaticHashTable`, " 

3785 "which is not supported. Use `make_initializable_iterator()` " 

3786 "instead.".format(err)) from None 

3787 else: 

3788 raise 

3789 

3790 with ops.colocate_with(self._variant_tensor): 

3791 # pylint: disable=protected-access 

3792 return iterator_ops.Iterator( 

3793 gen_dataset_ops.one_shot_iterator( 

3794 dataset_factory=_make_dataset, **self._flat_structure), None, 

3795 get_legacy_output_types(self), get_legacy_output_shapes(self), 

3796 get_legacy_output_classes(self)) 

3797 

3798 @deprecation.deprecated( 

3799 None, "This is a deprecated API that should only be used in TF 1 graph " 

3800 "mode and legacy TF 2 graph mode available through `tf.compat.v1`. " 

3801 "In all other situations -- namely, eager mode and inside `tf.function` " 

3802 "-- you can consume dataset elements using `for elem in dataset: ...` " 

3803 "or by explicitly creating iterator via `iterator = iter(dataset)` " 

3804 "and fetching its elements via `values = next(iterator)`. " 

3805 "Furthermore, this API is not available in TF 2. During the transition " 

3806 "from TF 1 to TF 2 you can use " 

3807 "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF " 

3808 "1 graph mode style iterator for a dataset created through TF 2 APIs. " 

3809 "Note that this should be a transient state of your code base as there " 

3810 "are in general no guarantees about the interoperability of TF 1 and TF " 

3811 "2 code.") 

3812 def make_initializable_iterator(self, shared_name=None): 

3813 """Creates an iterator for elements of this dataset. 

3814 

3815 Note: The returned iterator will be in an uninitialized state, 

3816 and you must run the `iterator.initializer` operation before using it: 

3817 

3818 ```python 

3819 # Building graph ... 

3820 dataset = ... 

3821 iterator = dataset.make_initializable_iterator() 

3822 next_value = iterator.get_next() # This is a Tensor. 

3823 

3824 # ... from within a session ... 

3825 sess.run(iterator.initializer) 

3826 try: 

3827 while True: 

3828 value = sess.run(next_value) 

3829 ... 

3830 except tf.errors.OutOfRangeError: 

3831 pass 

3832 ``` 

3833 

3834 Args: 

3835 shared_name: (Optional.) If non-empty, the returned iterator will be 

3836 shared under the given name across multiple sessions that share the same 

3837 devices (e.g. when using a remote server). 

3838 

3839 Returns: 

3840 A `tf.data.Iterator` for elements of this dataset. 

3841 

3842 Raises: 

3843 RuntimeError: If eager execution is enabled. 

3844 """ 

3845 return self._make_initializable_iterator(shared_name) 

3846 

3847 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=missing-docstring 

3848 if context.executing_eagerly(): 

3849 raise RuntimeError("`make_initializable_iterator()` is not supported in " 

3850 "eager mode. Use Python-style iteration instead.") 

3851 _ensure_same_dataset_graph(self) 

3852 dataset = self._apply_debug_options() 

3853 if shared_name is None: 

3854 shared_name = "" 

3855 

3856 with ops.colocate_with(self._variant_tensor): 

3857 iterator_resource = gen_dataset_ops.iterator_v2( 

3858 container="", shared_name=shared_name, **self._flat_structure) 

3859 

3860 initializer = gen_dataset_ops.make_iterator( 

3861 dataset._variant_tensor, # pylint: disable=protected-access 

3862 iterator_resource) 

3863 

3864 # pylint: disable=protected-access 

3865 return iterator_ops.Iterator(iterator_resource, initializer, 

3866 get_legacy_output_types(dataset), 

3867 get_legacy_output_shapes(dataset), 

3868 get_legacy_output_classes(dataset)) 

3869 

3870 @property 

3871 @deprecation.deprecated( 

3872 None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.") 

3873 def output_classes(self): 

3874 """Returns the class of each component of an element of this dataset. 

3875 

3876 Returns: 

3877 A (nested) structure of Python `type` objects corresponding to each 

3878 component of an element of this dataset. 

3879 """ 

3880 return nest.map_structure( 

3881 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 

3882 self.element_spec) 

3883 

3884 @property 

3885 @deprecation.deprecated( 

3886 None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.") 

3887 def output_shapes(self): 

3888 """Returns the shape of each component of an element of this dataset. 

3889 

3890 Returns: 

3891 A (nested) structure of `tf.TensorShape` objects corresponding to each 

3892 component of an element of this dataset. 

3893 """ 

3894 return nest.map_structure( 

3895 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 

3896 self.element_spec) 

3897 

3898 @property 

3899 @deprecation.deprecated( 

3900 None, "Use `tf.compat.v1.data.get_output_types(dataset)`.") 

3901 def output_types(self): 

3902 """Returns the type of each component of an element of this dataset. 

3903 

3904 Returns: 

3905 A (nested) structure of `tf.DType` objects corresponding to each component 

3906 of an element of this dataset. 

3907 """ 

3908 return nest.map_structure( 

3909 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 

3910 self.element_spec) 

3911 

3912 @property 

3913 def element_spec(self): 

3914 # TODO(b/110122868): Remove this override once all `Dataset` instances 

3915 # implement `element_structure`. 

3916 return structure.convert_legacy_structure( 

3917 self.output_types, self.output_shapes, self.output_classes) 

3918 

3919 @staticmethod 

3920 @functools.wraps(DatasetV2.from_tensors) 

3921 def from_tensors(tensors, name=None): 

3922 return DatasetV1Adapter(DatasetV2.from_tensors(tensors, name=name)) 

3923 

3924 @staticmethod 

3925 @functools.wraps(DatasetV2.from_tensor_slices) 

3926 def from_tensor_slices(tensors, name=None): 

3927 return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors, name=name)) 

3928 

3929 @staticmethod 

3930 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 

3931 def from_sparse_tensor_slices(sparse_tensor): 

3932 """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise. 

3933 

3934 Args: 

3935 sparse_tensor: A `tf.sparse.SparseTensor`. 

3936 

3937 Returns: 

3938 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 

3939 """ 

3940 # Loaded lazily due to a circular dependency (dataset_ops -> 

3941 # from_sparse_tensor_slices_op -> dataset_ops). 

3942 # pylint: disable=g-import-not-at-top,protected-access 

3943 from tensorflow.python.data.ops import from_sparse_tensor_slices_op 

3944 return from_sparse_tensor_slices_op._from_sparse_tensor_slices( 

3945 sparse_tensor) 

3946 # pylint: enable=g-import-not-at-top,protected-access 

3947 

3948 @staticmethod 

3949 @functools.wraps(DatasetV2.from_generator) 

3950 @deprecation.deprecated_args(None, "Use output_signature instead", 

3951 "output_types", "output_shapes") 

3952 def from_generator(generator, 

3953 output_types=None, 

3954 output_shapes=None, 

3955 args=None, 

3956 output_signature=None, 

3957 name=None): 

3958 # Calling DatasetV2.from_generator with output_shapes or output_types is 

3959 # deprecated, but this is already checked by the decorator on this function. 

3960 with deprecation.silence(): 

3961 return DatasetV1Adapter( 

3962 DatasetV2.from_generator( 

3963 generator, 

3964 output_types, 

3965 output_shapes, 

3966 args, 

3967 output_signature, 

3968 name=name)) 

3969 

3970 @staticmethod 

3971 @functools.wraps(DatasetV2.range) 

3972 def range(*args, **kwargs): 

3973 return DatasetV1Adapter(DatasetV2.range(*args, **kwargs)) 

3974 

3975 @staticmethod 

3976 @functools.wraps(DatasetV2.zip) 

3977 def zip(*args, datasets=None, name=None): 

3978 return DatasetV1Adapter(DatasetV2.zip(*args, datasets=datasets, name=name)) 

3979 

3980 @functools.wraps(DatasetV2.concatenate) 

3981 def concatenate(self, dataset, name=None): 

3982 return DatasetV1Adapter( 

3983 super(DatasetV1, self).concatenate(dataset, name=name)) 

3984 

3985 @functools.wraps(DatasetV2.prefetch) 

3986 def prefetch(self, buffer_size, name=None): 

3987 return DatasetV1Adapter( 

3988 super(DatasetV1, self).prefetch(buffer_size, name=name)) 

3989 

3990 @staticmethod 

3991 @functools.wraps(DatasetV2.list_files) 

3992 def list_files(file_pattern, shuffle=None, seed=None, name=None): 

3993 return DatasetV1Adapter( 

3994 DatasetV2.list_files(file_pattern, shuffle, seed, name=name)) 

3995 

3996 @functools.wraps(DatasetV2.repeat) 

3997 def repeat(self, count=None, name=None): 

3998 return DatasetV1Adapter(super(DatasetV1, self).repeat(count, name=name)) 

3999 

4000 @functools.wraps(DatasetV2.shuffle) 

4001 def shuffle(self, 

4002 buffer_size, 

4003 seed=None, 

4004 reshuffle_each_iteration=None, 

4005 name=None): 

4006 return DatasetV1Adapter( 

4007 super(DatasetV1, self).shuffle( 

4008 buffer_size, seed, reshuffle_each_iteration, name=name)) 

4009 

4010 @functools.wraps(DatasetV2.cache) 

4011 def cache(self, filename="", name=None): 

4012 return DatasetV1Adapter(super(DatasetV1, self).cache(filename, name=name)) 

4013 

4014 @functools.wraps(DatasetV2.take) 

4015 def take(self, count, name=None): 

4016 return DatasetV1Adapter(super(DatasetV1, self).take(count, name=name)) 

4017 

4018 @functools.wraps(DatasetV2.skip) 

4019 def skip(self, count, name=None): 

4020 return DatasetV1Adapter(super(DatasetV1, self).skip(count, name=name)) 

4021 

4022 @functools.wraps(DatasetV2.shard) 

4023 def shard(self, num_shards, index, name=None): 

4024 return DatasetV1Adapter( 

4025 super(DatasetV1, self).shard(num_shards, index, name=name)) 

4026 

4027 @functools.wraps(DatasetV2.batch) 

4028 def batch(self, 

4029 batch_size, 

4030 drop_remainder=False, 

4031 num_parallel_calls=None, 

4032 deterministic=None, 

4033 name=None): 

4034 return DatasetV1Adapter( 

4035 super(DatasetV1, self).batch( 

4036 batch_size, 

4037 drop_remainder, 

4038 num_parallel_calls, 

4039 deterministic, 

4040 name=name)) 

4041 

4042 @functools.wraps(DatasetV2.padded_batch) 

4043 def padded_batch(self, 

4044 batch_size, 

4045 padded_shapes=None, 

4046 padding_values=None, 

4047 drop_remainder=False, 

4048 name=None): 

4049 return DatasetV1Adapter( 

4050 super(DatasetV1, self).padded_batch( 

4051 batch_size, 

4052 padded_shapes, 

4053 padding_values, 

4054 drop_remainder, 

4055 name=name)) 

4056 

4057 @functools.wraps(DatasetV2.map) 

4058 def map(self, 

4059 map_func, 

4060 num_parallel_calls=None, 

4061 deterministic=None, 

4062 name=None): 

4063 # Loaded lazily due to a circular dependency (dataset_ops -> map_op -> 

4064 # dataset_ops). 

4065 # pylint: disable=g-import-not-at-top,protected-access 

4066 from tensorflow.python.data.ops import map_op 

4067 return map_op._map_v1( 

4068 self, 

4069 map_func, 

4070 num_parallel_calls=num_parallel_calls, 

4071 deterministic=deterministic) 

4072 # pylint: enable=g-import-not-at-top,protected-access 

4073 

4074 @deprecation.deprecated(None, "Use `tf.data.Dataset.map()") 

4075 def map_with_legacy_function(self, 

4076 map_func, 

4077 num_parallel_calls=None, 

4078 deterministic=None): 

4079 """Maps `map_func` across the elements of this dataset. 

4080 

4081 Note: This is an escape hatch for existing uses of `map` that do not work 

4082 with V2 functions. New uses are strongly discouraged and existing uses 

4083 should migrate to `map` as this method will be removed in V2. 

4084 

4085 Args: 

4086 map_func: A function mapping a (nested) structure of tensors (having 

4087 shapes and types defined by `self.output_shapes` and 

4088 `self.output_types`) to another (nested) structure of tensors. 

4089 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 

4090 representing the number elements to process asynchronously in parallel. 

4091 If not specified, elements will be processed sequentially. If the value 

4092 `tf.data.AUTOTUNE` is used, then the number of parallel calls is set 

4093 dynamically based on available CPU. 

4094 deterministic: (Optional.) When `num_parallel_calls` is specified, this 

4095 boolean controls the order in which the transformation produces 

4096 elements. If set to `False`, the transformation is allowed to yield 

4097 elements out of order to trade determinism for performance. If not 

4098 specified, the `tf.data.Options.deterministic` option (`True` by 

4099 default) controls the behavior. 

4100 

4101 Returns: 

4102 Dataset: A `Dataset`. 

4103 """ 

4104 # Loaded lazily due to a circular dependency (dataset_ops -> map_op -> 

4105 # dataset_ops). 

4106 # pylint: disable=g-import-not-at-top,protected-access 

4107 from tensorflow.python.data.ops import map_op 

4108 return map_op._map_v1_with_legacy_function( 

4109 self, 

4110 map_func, 

4111 num_parallel_calls=num_parallel_calls, 

4112 deterministic=deterministic) 

4113 # pylint: enable=g-import-not-at-top,protected-access 

4114 

4115 @functools.wraps(DatasetV2.flat_map) 

4116 def flat_map(self, map_func, name=None): 

4117 return DatasetV1Adapter( 

4118 super(DatasetV1, self).flat_map(map_func, name=name)) 

4119 

4120 @functools.wraps(DatasetV2.interleave) 

4121 def interleave(self, 

4122 map_func, 

4123 cycle_length=None, 

4124 block_length=None, 

4125 num_parallel_calls=None, 

4126 deterministic=None, 

4127 name=None): 

4128 return DatasetV1Adapter( 

4129 super(DatasetV1, self).interleave( 

4130 map_func, 

4131 cycle_length, 

4132 block_length, 

4133 num_parallel_calls, 

4134 deterministic, 

4135 name=name)) 

4136 

4137 @functools.wraps(DatasetV2.filter) 

4138 def filter(self, predicate, name=None): 

4139 return DatasetV1Adapter(super(DatasetV1, self).filter(predicate, name=name)) 

4140 

4141 @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()") 

4142 def filter_with_legacy_function(self, predicate): 

4143 """Filters this dataset according to `predicate`. 

4144 

4145 Note: This is an escape hatch for existing uses of `filter` that do not work 

4146 with V2 functions. New uses are strongly discouraged and existing uses 

4147 should migrate to `filter` as this method will be removed in V2. 

4148 

4149 Args: 

4150 predicate: A function mapping a (nested) structure of tensors (having 

4151 shapes and types defined by `self.output_shapes` and 

4152 `self.output_types`) to a scalar `tf.bool` tensor. 

4153 

4154 Returns: 

4155 Dataset: The `Dataset` containing the elements of this dataset for which 

4156 `predicate` is `True`. 

4157 """ 

4158 # Loaded lazily due to a circular dependency (dataset_ops -> filter_op -> 

4159 # dataset_ops). 

4160 # pylint: disable=g-import-not-at-top,protected-access 

4161 from tensorflow.python.data.ops import filter_op 

4162 return filter_op._FilterDataset(self, predicate, use_legacy_function=True) 

4163 # pylint: enable=g-import-not-at-top,protected-access 

4164 

4165 @functools.wraps(DatasetV2.apply) 

4166 def apply(self, transformation_func): 

4167 return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func)) 

4168 

4169 @functools.wraps(DatasetV2.window) 

4170 def window(self, size, shift=None, stride=1, drop_remainder=False, name=None): 

4171 return DatasetV1Adapter( 

4172 super(DatasetV1, 

4173 self).window(size, shift, stride, drop_remainder, name=name)) 

4174 

4175 @functools.wraps(DatasetV2.unbatch) 

4176 def unbatch(self, name=None): 

4177 return DatasetV1Adapter(super(DatasetV1, self).unbatch(name=name)) 

4178 

4179 @functools.wraps(DatasetV2.with_options) 

4180 def with_options(self, options, name=None): 

4181 return DatasetV1Adapter( 

4182 super(DatasetV1, self).with_options(options, name=name)) 

4183 

4184 

4185if tf2.enabled(): 

4186 Dataset = DatasetV2 

4187else: 

4188 Dataset = DatasetV1 

4189 

4190 

4191class DatasetV1Adapter(DatasetV1): 

4192 """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.""" 

4193 

4194 def __init__(self, dataset): 

4195 self._dataset = dataset 

4196 super(DatasetV1Adapter, self).__init__() 

4197 

4198 def _as_variant_tensor(self): 

4199 return self._dataset._variant_tensor # pylint: disable=protected-access 

4200 

4201 def _inputs(self): 

4202 return self._dataset._inputs() # pylint: disable=protected-access 

4203 

4204 def _functions(self): 

4205 return self._dataset._functions() # pylint: disable=protected-access 

4206 

4207 def options(self): 

4208 return self._dataset.options() 

4209 

4210 @property 

4211 def element_spec(self): 

4212 return self._dataset.element_spec # pylint: disable=protected-access 

4213 

4214 def __iter__(self): 

4215 return iter(self._dataset) 

4216 

4217 

4218def _ensure_same_dataset_graph(dataset): 

4219 """Walks the dataset graph to ensure all datasets come from the same graph.""" 

4220 # pylint: disable=protected-access 

4221 current_graph = ops.get_default_graph() 

4222 bfs_q = queue.Queue() 

4223 bfs_q.put(dataset) 

4224 visited = [] 

4225 while not bfs_q.empty(): 

4226 ds = bfs_q.get() 

4227 visited.append(ds) 

4228 ds_graph = ds._graph 

4229 if current_graph != ds_graph: 

4230 raise ValueError( 

4231 f"The graph {current_graph} of the iterator is different from the " 

4232 f"graph {ds_graph} the dataset: {ds._variant_tensor} was created in. " 

4233 f"If you are using the Estimator API, make sure that no part of the " 

4234 f"dataset returned by the `input_fn` function is defined outside the " 

4235 f"`input_fn` function. Otherwise, make sure that the dataset is " 

4236 f"created in the same graph as the iterator.") 

4237 for input_ds in ds._inputs(): 

4238 if input_ds not in visited: 

4239 bfs_q.put(input_ds) 

4240 

4241 

4242@tf_export(v1=["data.make_one_shot_iterator"]) 

4243def make_one_shot_iterator(dataset): 

4244 """Creates an iterator for elements of `dataset`. 

4245 

4246 Note: The returned iterator will be initialized automatically. 

4247 A "one-shot" iterator does not support re-initialization. 

4248 

4249 Args: 

4250 dataset: A `tf.data.Dataset`. 

4251 

4252 Returns: 

4253 A `tf.data.Iterator` for elements of `dataset`. 

4254 

4255 @compatibility(TF2) 

4256 This is a legacy API for consuming dataset elements and should only be used 

4257 during transition from TF 1 to TF 2. Note that using this API should be 

4258 a transient state of your code base as there are in general no guarantees 

4259 about the interoperability of TF 1 and TF 2 code. 

4260 

4261 In TF 2 datasets are Python iterables which means you can consume their 

4262 elements using `for elem in dataset: ...` or by explicitly creating iterator 

4263 via `iterator = iter(dataset)` and fetching its elements via 

4264 `values = next(iterator)`. 

4265 @end_compatibility 

4266 """ 

4267 try: 

4268 # Call the defined `_make_one_shot_iterator()` if there is one, because some 

4269 # datasets (e.g. for prefetching) override its behavior. 

4270 return dataset._make_one_shot_iterator() # pylint: disable=protected-access 

4271 except AttributeError: 

4272 return DatasetV1Adapter(dataset)._make_one_shot_iterator() # pylint: disable=protected-access 

4273 

4274 

4275@tf_export(v1=["data.make_initializable_iterator"]) 

4276def make_initializable_iterator(dataset, shared_name=None): 

4277 """Creates an iterator for elements of `dataset`. 

4278 

4279 Note: The returned iterator will be in an uninitialized state, 

4280 and you must run the `iterator.initializer` operation before using it: 

4281 

4282 ```python 

4283 dataset = ... 

4284 iterator = tf.compat.v1.data.make_initializable_iterator(dataset) 

4285 # ... 

4286 sess.run(iterator.initializer) 

4287 ``` 

4288 

4289 Args: 

4290 dataset: A `tf.data.Dataset`. 

4291 shared_name: (Optional.) If non-empty, the returned iterator will be shared 

4292 under the given name across multiple sessions that share the same devices 

4293 (e.g. when using a remote server). 

4294 

4295 Returns: 

4296 A `tf.data.Iterator` for elements of `dataset`. 

4297 

4298 Raises: 

4299 RuntimeError: If eager execution is enabled. 

4300 

4301 @compatibility(TF2) 

4302 This is a legacy API for consuming dataset elements and should only be used 

4303 during transition from TF 1 to TF 2. Note that using this API should be 

4304 a transient state of your code base as there are in general no guarantees 

4305 about the interoperability of TF 1 and TF 2 code. 

4306 

4307 In TF 2 datasets are Python iterables which means you can consume their 

4308 elements using `for elem in dataset: ...` or by explicitly creating iterator 

4309 via `iterator = iter(dataset)` and fetching its elements via 

4310 `values = next(iterator)`. 

4311 @end_compatibility 

4312 """ 

4313 try: 

4314 # Call the defined `_make_initializable_iterator()` if there is one, because 

4315 # some datasets (e.g. for prefetching) override its behavior. 

4316 return dataset._make_initializable_iterator(shared_name) # pylint: disable=protected-access 

4317 except AttributeError: 

4318 return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name) # pylint: disable=protected-access 

4319 

4320 

4321@tf_export("data.experimental.get_structure") 

4322def get_structure(dataset_or_iterator): 

4323 """Returns the type signature for elements of the input dataset / iterator. 

4324 

4325 For example, to get the structure of a `tf.data.Dataset`: 

4326 

4327 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

4328 >>> tf.data.experimental.get_structure(dataset) 

4329 TensorSpec(shape=(), dtype=tf.int32, name=None) 

4330 

4331 >>> dataset = tf.data.experimental.from_list([(1, 'a'), (2, 'b'), (3, 'c')]) 

4332 >>> tf.data.experimental.get_structure(dataset) 

4333 (TensorSpec(shape=(), dtype=tf.int32, name=None), 

4334 TensorSpec(shape=(), dtype=tf.string, name=None)) 

4335 

4336 To get the structure of an `tf.data.Iterator`: 

4337 

4338 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 

4339 >>> tf.data.experimental.get_structure(iter(dataset)) 

4340 TensorSpec(shape=(), dtype=tf.int32, name=None) 

4341 

4342 Args: 

4343 dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`. 

4344 

4345 Returns: 

4346 A (nested) structure of `tf.TypeSpec` objects matching the structure of an 

4347 element of `dataset_or_iterator` and specifying the type of individual 

4348 components. 

4349 

4350 Raises: 

4351 TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator` 

4352 object. 

4353 """ 

4354 try: 

4355 return dataset_or_iterator.element_spec # pylint: disable=protected-access 

4356 except AttributeError: 

4357 raise TypeError(f"Invalid `dataset_or_iterator`. `dataset_or_iterator` " 

4358 f"must be a `tf.data.Dataset` or tf.data.Iterator object, " 

4359 f"but got {type(dataset_or_iterator)}.") 

4360 

4361 

4362@tf_export(v1=["data.get_output_classes"]) 

4363def get_legacy_output_classes(dataset_or_iterator): 

4364 """Returns the output classes for elements of the input dataset / iterator. 

4365 

4366 Args: 

4367 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 

4368 

4369 Returns: 

4370 A (nested) structure of Python `type` objects matching the structure of the 

4371 dataset / iterator elements and specifying the class of the individual 

4372 components. 

4373 

4374 @compatibility(TF2) 

4375 This is a legacy API for inspecting the type signature of dataset elements. In 

4376 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead. 

4377 @end_compatibility 

4378 """ 

4379 return nest.map_structure( 

4380 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 

4381 get_structure(dataset_or_iterator)) 

4382 

4383 

4384@tf_export(v1=["data.get_output_shapes"]) 

4385def get_legacy_output_shapes(dataset_or_iterator): 

4386 """Returns the output shapes for elements of the input dataset / iterator. 

4387 

4388 Args: 

4389 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 

4390 

4391 Returns: 

4392 A (nested) structure of `tf.TensorShape` objects matching the structure of 

4393 the dataset / iterator elements and specifying the shape of the individual 

4394 components. 

4395 

4396 @compatibility(TF2) 

4397 This is a legacy API for inspecting the type signature of dataset elements. In 

4398 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead. 

4399 @end_compatibility 

4400 """ 

4401 return nest.map_structure( 

4402 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 

4403 get_structure(dataset_or_iterator)) 

4404 

4405 

4406@tf_export(v1=["data.get_output_types"]) 

4407def get_legacy_output_types(dataset_or_iterator): 

4408 """Returns the output shapes for elements of the input dataset / iterator. 

4409 

4410 Args: 

4411 dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`. 

4412 

4413 Returns: 

4414 A (nested) structure of `tf.DType` objects matching the structure of 

4415 dataset / iterator elements and specifying the shape of the individual 

4416 components. 

4417 

4418 @compatibility(TF2) 

4419 This is a legacy API for inspecting the type signature of dataset elements. In 

4420 TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead. 

4421 @end_compatibility 

4422 """ 

4423 return nest.map_structure( 

4424 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 

4425 get_structure(dataset_or_iterator)) 

4426 

4427 

4428class DatasetSource(DatasetV2): 

4429 """Abstract class representing a dataset with no inputs.""" 

4430 

4431 def _inputs(self): 

4432 return [] 

4433 

4434 

4435class UnaryDataset(DatasetV2): 

4436 """Abstract class representing a dataset with one input.""" 

4437 

4438 def __init__(self, input_dataset, variant_tensor): 

4439 self._input_dataset = input_dataset 

4440 super(UnaryDataset, self).__init__(variant_tensor) 

4441 

4442 def _inputs(self): 

4443 return [self._input_dataset] 

4444 

4445 

4446class UnaryUnchangedStructureDataset(UnaryDataset): 

4447 """Represents a unary dataset with the same input and output structure.""" 

4448 

4449 def __init__(self, input_dataset, variant_tensor): 

4450 self._input_dataset = input_dataset 

4451 super(UnaryUnchangedStructureDataset, self).__init__( 

4452 input_dataset, variant_tensor) 

4453 

4454 @property 

4455 def element_spec(self): 

4456 return self._input_dataset.element_spec 

4457 

4458 

4459class _VariantDataset(DatasetV2): 

4460 """A Dataset wrapper around a `tf.variant`-typed function argument.""" 

4461 

4462 def __init__(self, dataset_variant, element_spec): 

4463 self._element_spec = element_spec 

4464 super(_VariantDataset, self).__init__(dataset_variant) 

4465 

4466 def _inputs(self): 

4467 return [] 

4468 

4469 @property 

4470 def element_spec(self): 

4471 return self._element_spec 

4472 

4473 

4474class _NestedVariant(composite_tensor.CompositeTensor): 

4475 

4476 def __init__(self, variant_tensor, element_spec, dataset_shape): 

4477 self._variant_tensor = variant_tensor 

4478 self._element_spec = element_spec 

4479 self._dataset_shape = dataset_shape 

4480 

4481 @property 

4482 def _type_spec(self): 

4483 return DatasetSpec(self._element_spec, self._dataset_shape) 

4484 

4485 

4486@tf_export("data.experimental.from_variant") 

4487def from_variant(variant, structure): 

4488 """Constructs a dataset from the given variant and (nested) structure. 

4489 

4490 Args: 

4491 variant: A scalar `tf.variant` tensor representing a dataset. 

4492 structure: A (nested) structure of `tf.TypeSpec` objects representing the 

4493 structure of each element in the dataset. 

4494 

4495 Returns: 

4496 A `tf.data.Dataset` instance. 

4497 """ 

4498 return _VariantDataset(variant, structure) # pylint: disable=protected-access 

4499 

4500 

4501@tf_export("data.experimental.to_variant") 

4502def to_variant(dataset): 

4503 """Returns a variant representing the given dataset. 

4504 

4505 Args: 

4506 dataset: A `tf.data.Dataset`. 

4507 

4508 Returns: 

4509 A scalar `tf.variant` tensor representing the given dataset. 

4510 """ 

4511 return dataset._variant_tensor # pylint: disable=protected-access 

4512 

4513 

4514@tf_export( 

4515 "data.DatasetSpec", 

4516 v1=["data.DatasetSpec", "data.experimental.DatasetStructure"]) 

4517class DatasetSpec(type_spec.BatchableTypeSpec): 

4518 """Type specification for `tf.data.Dataset`. 

4519 

4520 See `tf.TypeSpec` for more information about TensorFlow type specifications. 

4521 

4522 >>> dataset = tf.data.Dataset.range(3) 

4523 >>> tf.data.DatasetSpec.from_value(dataset) 

4524 DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([])) 

4525 """ 

4526 

4527 __slots__ = ["_element_spec", "_dataset_shape"] 

4528 

4529 def __init__(self, element_spec, dataset_shape=()): 

4530 self._element_spec = element_spec 

4531 self._dataset_shape = tensor_shape.as_shape(dataset_shape) 

4532 

4533 @property 

4534 def value_type(self): 

4535 return Dataset 

4536 

4537 @property 

4538 def element_spec(self): 

4539 """The inner element spec.""" 

4540 return self._element_spec 

4541 

4542 def is_subtype_of(self, other): 

4543 """See base class.""" 

4544 if type(self) is not type(other): 

4545 return False 

4546 

4547 # TODO(b/220385675): _element_spec should always be a TypeSpec. 

4548 try: 

4549 tf_nest.assert_same_structure(self.element_spec, other.element_spec) 

4550 except (TypeError, ValueError): 

4551 return False 

4552 

4553 self_elements = tf_nest.flatten(self.element_spec) 

4554 other_elements = tf_nest.flatten(other.element_spec) 

4555 

4556 def is_subtype_or_equal(a, b): 

4557 if isinstance(a, trace.TraceType): 

4558 return a.is_subtype_of(b) 

4559 else: 

4560 return a == b 

4561 

4562 for self_element, other_element in zip(self_elements, other_elements): 

4563 if not is_subtype_or_equal(self_element, other_element): 

4564 return False 

4565 

4566 return self._dataset_shape.is_subtype_of(other._dataset_shape) # pylint: disable=protected-access 

4567 

4568 def most_specific_common_supertype(self, others): 

4569 """See base class.""" 

4570 if not all(type(self) is type(other) for other in others): 

4571 return None 

4572 

4573 try: 

4574 for other in others: 

4575 tf_nest.assert_same_structure(self.element_spec, other.element_spec) 

4576 except (TypeError, ValueError): 

4577 return None 

4578 

4579 self_components = tf_nest.flatten(self.element_spec) 

4580 others_components = [ 

4581 tf_nest.flatten(other.element_spec) for other in others 

4582 ] 

4583 common_components = [None] * len(self_components) 

4584 

4585 def common_supertype_or_equal(a, bs): 

4586 if isinstance(a, trace.TraceType): 

4587 return a.most_specific_common_supertype(bs) 

4588 else: 

4589 return a if all(a == b for b in bs) else None 

4590 

4591 for i, self_component in enumerate(self_components): 

4592 common_components[i] = common_supertype_or_equal( 

4593 self_component, 

4594 [other_components[i] for other_components in others_components]) 

4595 if self_component is not None and common_components[i] is None: 

4596 return None 

4597 common_element_spec = tf_nest.pack_sequence_as(self._element_spec, 

4598 common_components) 

4599 

4600 common_dataset_shape = self._dataset_shape.most_specific_common_supertype( 

4601 [other._dataset_shape for other in others]) # pylint: disable=protected-access 

4602 if common_dataset_shape is None: 

4603 return None 

4604 

4605 return DatasetSpec(common_element_spec, common_dataset_shape) 

4606 

4607 # TODO(b/220385675): Once _element_spec is guaranteed to be TypeSpec, the 

4608 # following functions do not need to be overloaded: is_subtype_of, 

4609 # most_specific_common_supertype, __hash__ and __eq__ 

4610 def _serialize(self): 

4611 return (self._element_spec, self._dataset_shape) 

4612 

4613 @property 

4614 def _component_specs(self): 

4615 return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant) 

4616 

4617 def _to_components(self, value): 

4618 return value._variant_tensor # pylint: disable=protected-access 

4619 

4620 def _from_components(self, components): 

4621 # pylint: disable=protected-access 

4622 if self._dataset_shape.ndims == 0: 

4623 return _VariantDataset(components, self._element_spec) 

4624 else: 

4625 return _NestedVariant(components, self._element_spec, self._dataset_shape) 

4626 

4627 def _to_tensor_list(self, value): 

4628 return [ 

4629 ops.convert_to_tensor( 

4630 tf_nest.map_structure(lambda x: x._variant_tensor, value)) # pylint: disable=protected-access 

4631 ] 

4632 

4633 @staticmethod 

4634 def from_value(value): 

4635 """Creates a `DatasetSpec` for the given `tf.data.Dataset` value.""" 

4636 return DatasetSpec(value.element_spec) # pylint: disable=protected-access 

4637 

4638 def _batch(self, batch_size): 

4639 return DatasetSpec( 

4640 self._element_spec, 

4641 tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape)) 

4642 

4643 def _unbatch(self): 

4644 if self._dataset_shape.ndims == 0: 

4645 raise ValueError("Slicing dataset elements is not supported for rank 0.") 

4646 return DatasetSpec(self._element_spec, self._dataset_shape[1:]) 

4647 

4648 def _to_batched_tensor_list(self, value): 

4649 if self._dataset_shape.ndims == 0: 

4650 raise ValueError("Slicing dataset elements is not supported for rank 0.") 

4651 return self._to_tensor_list(value) 

4652 

4653 def _to_legacy_output_types(self): 

4654 return self 

4655 

4656 def _to_legacy_output_shapes(self): 

4657 return self 

4658 

4659 def _to_legacy_output_classes(self): 

4660 return self 

4661 

4662 def __hash__(self): 

4663 # TODO(b/220385675): attributes can be dicts and hence unhashable. 

4664 return hash(DatasetSpec) 

4665 

4666 def __eq__(self, other): 

4667 return (isinstance(other, DatasetSpec) and 

4668 self._element_spec == other._element_spec and 

4669 self._dataset_shape == other._dataset_shape) 

4670 

4671 

4672nested_structure_coder.register_codec( 

4673 nested_structure_coder.BuiltInTypeSpecCodec( 

4674 DatasetSpec, struct_pb2.TypeSpecProto.DATA_DATASET_SPEC 

4675 ) 

4676) 

4677 

4678 

4679class _NumpyIterator(tracking_base.Trackable): 

4680 """Iterator over a dataset with elements converted to numpy.""" 

4681 

4682 __slots__ = ["_iterator"] 

4683 

4684 def __init__(self, dataset): 

4685 self._iterator = iter(dataset) 

4686 

4687 def __iter__(self): 

4688 return self 

4689 

4690 def __next__(self): 

4691 

4692 def to_numpy(x): 

4693 numpy = x._numpy() # pylint: disable=protected-access 

4694 if isinstance(numpy, np.ndarray): 

4695 # `numpy` shares the same underlying buffer as the `x` Tensor. 

4696 # Tensors are expected to be immutable, so we disable writes. 

4697 numpy.setflags(write=False) 

4698 return numpy 

4699 

4700 return nest.map_structure(to_numpy, next(self._iterator)) 

4701 

4702 def next(self): 

4703 return self.__next__() 

4704 

4705 # override 

4706 def _serialize_to_tensors(self): 

4707 # pylint: disable=protected-access 

4708 return self._iterator._serialize_to_tensors() 

4709 

4710 # override 

4711 def _restore_from_tensors(self, restored_tensors): 

4712 # pylint: disable=protected-access 

4713 return self._iterator._restore_from_tensors(restored_tensors) 

4714 

4715 def _save(self): 

4716 # pylint: disable=protected-access 

4717 return self._iterator._save() 

4718 

4719 def _restore(self, state): 

4720 # pylint: disable=protected-access 

4721 return self._iterator._restore(state) 

4722 

4723 

4724class _VariantTracker(resource_lib.CapturableResource): 

4725 """Allows export of functions capturing a Dataset in SavedModels. 

4726 

4727 When saving a SavedModel, `tf.saved_model.save` traverses the object 

4728 graph. Since Datasets reference _VariantTracker objects, that traversal will 

4729 find a _VariantTracker for each Dataset and so know how to save and restore 

4730 functions which reference the Dataset's variant Tensor. 

4731 """ 

4732 

4733 def __init__(self, variant_tensor, resource_creator): 

4734 """Record that `variant_tensor` is associated with `resource_creator`. 

4735 

4736 Args: 

4737 variant_tensor: The variant-dtype Tensor associated with the Dataset. This 

4738 Tensor will be a captured input to functions which use the Dataset, and 

4739 is used by saving code to identify the corresponding _VariantTracker. 

4740 resource_creator: A zero-argument function which creates a new 

4741 variant-dtype Tensor. This function will be included in SavedModels and 

4742 run to re-create the Dataset's variant Tensor on restore. 

4743 """ 

4744 super(_VariantTracker, self).__init__(device="CPU") 

4745 self._resource_handle = variant_tensor 

4746 if not isinstance(resource_creator, def_function.Function): 

4747 # Internal validation -- _VariantTracker assumes that resource creator is 

4748 # already a tf.function. 

4749 raise TypeError("Resource creator should already be a tf.function.") 

4750 self._create_resource = resource_creator 

4751 

4752 def _trackable_children(self, 

4753 save_type=tracking_base.SaveType.CHECKPOINT, 

4754 **kwargs): 

4755 if save_type != tracking_base.SaveType.SAVEDMODEL: 

4756 return {} 

4757 

4758 children = super(_VariantTracker, 

4759 self)._trackable_children(save_type, **kwargs) 

4760 # Overwrite the _create_resource function, since `self._create_resource` 

4761 # is already a tf.function. 

4762 children["_create_resource"] = self._create_resource 

4763 return children 

4764 

4765 

4766# TODO(b/254291122): Remove. 

4767# Loaded lazily due to a circular dependency (dataset_ops -> 

4768# batch_op -> dataset_ops). 

4769batch_op = lazy_loader.LazyLoader( 

4770 "batch_op", globals(), 

4771 "tensorflow.python.data.ops.batch_op") 

4772BatchDataset = batch_op._BatchDataset # pylint: disable=protected-access 

4773PrefetchDataset = prefetch_op._PrefetchDataset # pylint: disable=protected-access 

4774ShuffleDataset = shuffle_op._ShuffleDataset # pylint: disable=protected-access 

4775 

4776 

4777# TODO(b/254291122): Remove. 

4778# Loaded lazily due to a circular dependency (dataset_ops -> 

4779# repeat_op -> dataset_ops). 

4780repeat_op = lazy_loader.LazyLoader( 

4781 "repeat_op", globals(), 

4782 "tensorflow.python.data.ops.repeat_op") 

4783RepeatDataset = repeat_op._RepeatDataset # pylint: disable=protected-access 

4784 

4785 

4786class _OptionsDataset(UnaryUnchangedStructureDataset): 

4787 """An identity `Dataset` that stores options.""" 

4788 

4789 def __init__(self, input_dataset, options, name=None): 

4790 # pylint: disable=protected-access 

4791 self._input_dataset = input_dataset 

4792 options_pb = dataset_options_pb2.Options() 

4793 options_pb.CopyFrom(options._to_proto()) 

4794 self._name = name 

4795 with ops.colocate_with(input_dataset._variant_tensor): 

4796 variant_tensor = gen_dataset_ops.options_dataset( 

4797 input_dataset._variant_tensor, options_pb.SerializeToString(), 

4798 **self._common_args) 

4799 super(_OptionsDataset, self).__init__(input_dataset, variant_tensor) 

4800 

4801 if self._options_attr: 

4802 self._options_attr._set_mutable(True) 

4803 self._options_attr = self._options_attr.merge(options) 

4804 else: 

4805 self._options_attr = options 

4806 self._options_attr._set_mutable(False) 

4807 

4808 

4809def normalize_to_dense(dataset): 

4810 """Normalizes non-tensor components in a dataset to dense representations. 

4811 

4812 This is necessary for dataset transformations that slice along the batch 

4813 dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`. 

4814 

4815 Args: 

4816 dataset: Dataset to normalize. 

4817 

4818 Returns: 

4819 A dataset whose sparse and ragged tensors have been normalized to their 

4820 dense representations. 

4821 """ 

4822 

4823 # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all 

4824 # non-tensor components. 

4825 # 

4826 # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck. 

4827 if structured_function._should_unpack(dataset.element_spec): # pylint: disable=protected-access 

4828 

4829 def normalize(*args): 

4830 return structure.to_batched_tensor_list(dataset.element_spec, tuple(args)) 

4831 else: 

4832 def normalize(arg): 

4833 return structure.to_batched_tensor_list(dataset.element_spec, arg) 

4834 

4835 normalized_dataset = dataset.map(normalize) 

4836 

4837 # NOTE(mrry): Our `map()` has lost information about the structure of 

4838 # non-tensor components, so re-apply the structure of the original dataset. 

4839 return _RestructuredDataset(normalized_dataset, dataset.element_spec) 

4840 

4841 

4842class _RestructuredDataset(UnaryDataset): 

4843 """An internal helper for changing the element spec of a dataset.""" 

4844 

4845 def __init__(self, dataset, element_spec): 

4846 self._input_dataset = dataset 

4847 self._element_spec = element_spec 

4848 

4849 variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access 

4850 super(_RestructuredDataset, self).__init__(dataset, variant_tensor) 

4851 

4852 @property 

4853 def element_spec(self): 

4854 return self._element_spec 

4855 

4856 

4857def _get_prob_original_static(initial_dist_t, target_dist_t): 

4858 """Returns the static probability of sampling from the original. 

4859 

4860 `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters 

4861 an Op that it isn't defined for. We have some custom logic to avoid this. 

4862 

4863 Args: 

4864 initial_dist_t: A tensor of the initial distribution. 

4865 target_dist_t: A tensor of the target distribution. 

4866 

4867 Returns: 

4868 The probability of sampling from the original distribution as a constant, 

4869 if it is a constant, or `None`. 

4870 """ 

4871 init_static = tensor_util.constant_value(initial_dist_t) 

4872 target_static = tensor_util.constant_value(target_dist_t) 

4873 

4874 if init_static is None or target_static is None: 

4875 return None 

4876 else: 

4877 return np.min(target_static / init_static) 

4878 

4879 

4880def _filter_ds(dataset, 

4881 acceptance_dist_ds, 

4882 initial_dist_ds, 

4883 class_func, 

4884 seed, 

4885 name=None): 

4886 """Filters a dataset based on per-class acceptance probabilities. 

4887 

4888 Args: 

4889 dataset: The dataset to be filtered. 

4890 acceptance_dist_ds: A dataset of acceptance probabilities. 

4891 initial_dist_ds: A dataset of the initial probability distribution, given or 

4892 estimated. 

4893 class_func: A function mapping an element of the input dataset to a scalar 

4894 `tf.int32` tensor. Values should be in `[0, num_classes)`. 

4895 seed: (Optional.) Python integer seed for the resampler. 

4896 name: (Optional.) A name for the tf.data operation. 

4897 

4898 Returns: 

4899 A dataset of (class value, data) after filtering. 

4900 """ 

4901 

4902 def maybe_warn_on_large_rejection(accept_dist, initial_dist): 

4903 proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) 

4904 return cond.cond( 

4905 math_ops.less(proportion_rejected, .5), 

4906 lambda: accept_dist, 

4907 lambda: logging_ops.Print( # pylint: disable=g-long-lambda 

4908 accept_dist, [proportion_rejected, initial_dist, accept_dist], 

4909 message="Proportion of examples rejected by sampler is high: ", 

4910 summarize=100, 

4911 first_n=10)) 

4912 

4913 acceptance_dist_ds = ( 

4914 DatasetV2.zip((acceptance_dist_ds, initial_dist_ds), 

4915 name=name).map(maybe_warn_on_large_rejection, name=name)) 

4916 

4917 def _gather_and_copy(acceptance_prob, data): 

4918 if isinstance(data, tuple): 

4919 class_val = class_func(*data) 

4920 else: 

4921 class_val = class_func(data) 

4922 return class_val, array_ops.gather(acceptance_prob, class_val), data 

4923 

4924 current_probabilities_and_class_and_data_ds = DatasetV2.zip( 

4925 (acceptance_dist_ds, dataset), name=name).map( 

4926 _gather_and_copy, name=name) 

4927 

4928 def _reject(unused_class_val, p, unused_data): 

4929 return random_ops.random_uniform([], seed=seed, dtype=p.dtype) < p 

4930 

4931 filtered_ds = current_probabilities_and_class_and_data_ds.filter( 

4932 _reject, name=name) 

4933 return filtered_ds.map( 

4934 lambda class_value, _, data: (class_value, data), name=name) 

4935 

4936 

4937# pylint: disable=missing-function-docstring 

4938def _estimate_initial_dist_ds(target_dist_t, 

4939 class_values_ds, 

4940 dist_estimation_batch_size=32, 

4941 smoothing_constant=10, 

4942 name=None): 

4943 num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0]) 

4944 initial_examples_per_class_seen = array_ops.fill([num_classes], 

4945 np.int64(smoothing_constant)) 

4946 

4947 def update_estimate_and_tile(num_examples_per_class_seen, c): 

4948 updated_examples_per_class_seen, dist = _estimate_data_distribution( 

4949 c, num_examples_per_class_seen) 

4950 tiled_dist = array_ops.tile( 

4951 array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) 

4952 return updated_examples_per_class_seen, tiled_dist 

4953 

4954 initial_dist_ds = ( 

4955 class_values_ds.batch(dist_estimation_batch_size, name=name).scan( 

4956 initial_examples_per_class_seen, update_estimate_and_tile, 

4957 name=name).unbatch(name=name)) 

4958 

4959 return initial_dist_ds 

4960 

4961 

4962def _get_target_to_initial_ratio(initial_probs, target_probs): 

4963 # Add tiny to initial_probs to avoid divide by zero. 

4964 denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) 

4965 return target_probs / denom 

4966 

4967 

4968def _estimate_data_distribution(c, num_examples_per_class_seen): 

4969 """Estimate data distribution as labels are seen. 

4970 

4971 Args: 

4972 c: The class labels. Type `int32`, shape `[batch_size]`. 

4973 num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing 

4974 counts. 

4975 

4976 Returns: 

4977 num_examples_per_lass_seen: Updated counts. Type `int64`, shape 

4978 `[num_classes]`. 

4979 dist: The updated distribution. Type `float32`, shape `[num_classes]`. 

4980 """ 

4981 num_classes = num_examples_per_class_seen.get_shape()[0] 

4982 # Update the class-count based on what labels are seen in batch. 

4983 num_examples_per_class_seen = math_ops.add( 

4984 num_examples_per_class_seen, 

4985 math_ops.reduce_sum( 

4986 array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) 

4987 init_prob_estimate = math_ops.truediv( 

4988 num_examples_per_class_seen, 

4989 math_ops.reduce_sum(num_examples_per_class_seen)) 

4990 dist = math_ops.cast(init_prob_estimate, dtypes.float32) 

4991 return num_examples_per_class_seen, dist 

4992 

4993 

4994def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): 

4995 """Calculates the acceptance probabilities and mixing ratio. 

4996 

4997 In this case, we assume that we can *either* sample from the original data 

4998 distribution with probability `m`, or sample from a reshaped distribution 

4999 that comes from rejection sampling on the original distribution. This 

5000 rejection sampling is done on a per-class basis, with `a_i` representing the 

5001 probability of accepting data from class `i`. 

5002 

5003 This method is based on solving the following analysis for the reshaped 

5004 distribution: 

5005 

5006 Let F be the probability of a rejection (on any example). 

5007 Let p_i be the proportion of examples in the data in class i (init_probs) 

5008 Let a_i is the rate the rejection sampler should *accept* class i 

5009 Let t_i is the target proportion in the minibatches for class i (target_probs) 

5010 

5011 ``` 

5012 F = sum_i(p_i * (1-a_i)) 

5013 = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 

5014 ``` 

5015 

5016 An example with class `i` will be accepted if `k` rejections occur, then an 

5017 example with class `i` is seen by the rejector, and it is accepted. This can 

5018 be written as follows: 

5019 

5020 ``` 

5021 t_i = sum_k=0^inf(F^k * p_i * a_i) 

5022 = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 

5023 = p_i * a_i / sum_j(p_j * a_j) using F from above 

5024 ``` 

5025 

5026 Note that the following constraints hold: 

5027 ``` 

5028 0 <= p_i <= 1, sum_i(p_i) = 1 

5029 0 <= a_i <= 1 

5030 0 <= t_i <= 1, sum_i(t_i) = 1 

5031 ``` 

5032 

5033 A solution for a_i in terms of the other variables is the following: 

5034 ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` 

5035 

5036 If we try to minimize the amount of data rejected, we get the following: 

5037 

5038 M_max = max_i [ t_i / p_i ] 

5039 M_min = min_i [ t_i / p_i ] 

5040 

5041 The desired probability of accepting data if it comes from class `i`: 

5042 

5043 a_i = (t_i/p_i - m) / (M_max - m) 

5044 

5045 The desired probability of pulling a data element from the original dataset, 

5046 rather than the filtered one: 

5047 

5048 m = M_min 

5049 

5050 Args: 

5051 initial_probs: A Tensor of the initial probability distribution, given or 

5052 estimated. 

5053 target_probs: A Tensor of the corresponding classes. 

5054 

5055 Returns: 

5056 (A 1D Tensor with the per-class acceptance probabilities, the desired 

5057 probability of pull from the original distribution.) 

5058 """ 

5059 ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) 

5060 max_ratio = math_ops.reduce_max(ratio_l) 

5061 min_ratio = math_ops.reduce_min(ratio_l) 

5062 

5063 # Target prob to sample from original distribution. 

5064 m = min_ratio 

5065 

5066 # TODO(joelshor): Simplify fraction, if possible. 

5067 a_i = (ratio_l - m) / (max_ratio - m) 

5068 return a_i, m 

5069 

5070 

5071def _apply_rewrite(dataset, rewrite): 

5072 # pylint: disable=protected-access 

5073 return _VariantDataset( 

5074 gen_dataset_ops.rewrite_dataset(dataset._variant_tensor, rewrite, 

5075 **dataset._flat_structure), 

5076 dataset.element_spec) 

5077 

5078 

5079def _collect_resource_inputs(op): 

5080 """Collects resource inputs for the given ops (and its variant inputs).""" 

5081 

5082 def _process(op_queue, seen_ops): 

5083 """Processes the next element of the op queue. 

5084 

5085 Args: 

5086 op_queue: Queue of Dataset operations to process. 

5087 seen_ops: Already processed set of Operations. 

5088 

5089 Returns: 

5090 A 2-tuple containing sets of resource handles. The first tuple entry 

5091 contains read-only handles and the second entry contains read-write 

5092 handles. 

5093 """ 

5094 

5095 reads = [] 

5096 writes = [] 

5097 op = op_queue.pop() 

5098 if op in seen_ops: 

5099 return reads, writes 

5100 seen_ops.add(op) 

5101 # TODO(b/150139257): All resource inputs are in writes right now since we 

5102 # have not updated the functional ops to set the special attribute that ACD 

5103 # uses to figure out which of the op's inputs are read-only. 

5104 reads, writes = acd_utils.get_read_write_resource_inputs(op) 

5105 # Conservatively assume that any variant inputs are datasets. 

5106 op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant) 

5107 return reads, writes 

5108 

5109 op_queue = [op] 

5110 seen_ops = set() 

5111 all_reads = [] 

5112 all_writes = [] 

5113 while op_queue: 

5114 reads, writes = _process(op_queue, seen_ops) 

5115 all_reads.extend(reads) 

5116 all_writes.extend(writes) 

5117 

5118 return all_reads, all_writes 

5119 

5120 

5121@auto_control_deps.register_acd_resource_resolver 

5122def _resource_resolver(op, resource_reads, resource_writes): 

5123 """Updates resource inputs for tf.data ops with indirect dependencies.""" 

5124 

5125 updated = False 

5126 if op.type in [ 

5127 "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset" 

5128 ]: 

5129 reads, writes = _collect_resource_inputs(op) 

5130 for inp in reads: 

5131 if inp not in resource_reads: 

5132 updated = True 

5133 resource_reads.add(inp) 

5134 for inp in writes: 

5135 if inp not in resource_writes: 

5136 updated = True 

5137 resource_writes.add(inp) 

5138 

5139 if op.type in [ 

5140 "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional" 

5141 ]: 

5142 iterator_resource = op.inputs[0] 

5143 make_iterator_ops = [ 

5144 op for op in iterator_resource.consumers() if op.type == "MakeIterator" 

5145 ] 

5146 

5147 if len(make_iterator_ops) == 1: 

5148 reads, writes = _collect_resource_inputs(make_iterator_ops[0]) 

5149 for inp in reads: 

5150 if inp not in resource_reads: 

5151 updated = True 

5152 resource_reads.add(inp) 

5153 for inp in writes: 

5154 if inp not in resource_writes: 

5155 updated = True 

5156 resource_writes.add(inp) 

5157 

5158 return updated 

5159 

5160 

5161dataset_autograph.register_overrides()