Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/io.py: 87%
15 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Python API for saving and loading a dataset."""
17from tensorflow.python.data.ops import dataset_ops
18from tensorflow.python.util import deprecation
19from tensorflow.python.util.tf_export import tf_export
# Compression option strings forwarded to the underlying save/load ops.
COMPRESSION_GZIP = "GZIP"
# NOTE(review): the value is "NONE" even though the name says SNAPPY. This
# reproduces what the code literally contains, but it looks like a possible
# bug -- confirm the intended Snappy compression string against the snapshot
# op before relying on this constant.
COMPRESSION_SNAPPY = "NONE"
# File name under which the saved dataset's element spec is stored
# (presumably the serialized `tf.TypeSpec` structure mentioned in `load`'s
# `element_spec` documentation -- TODO confirm).
DATASET_SPEC_FILENAME = "dataset_spec.pb"
@tf_export("data.experimental.save", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.save(...)` instead.")
def save(dataset,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
  """Saves the content of the given dataset.

  Example usage:

  >>> import os
  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The saved dataset is saved in multiple file "shards". By default, the dataset
  output is divided to shards in a round-robin fashion but custom sharding can
  be specified via the `shard_func` function. For example, you can save the
  dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return np.int64(0)
  tf.data.experimental.save(
      dataset, path="/path/to/data", shard_func=custom_shard_func)
  ```

  To enable checkpointing, pass in `checkpoint_args` to the `save` method
  as follows:

  ```python
  dataset = tf.data.Dataset.range(100)
  save_dir = "..."
  checkpoint_prefix = "..."
  step_counter = tf.Variable(0, trainable=False)
  checkpoint_args = {
    "checkpoint_interval": 50,
    "step_counter": step_counter,
    "directory": checkpoint_prefix,
    "max_to_keep": 20,
  }
  tf.data.experimental.save(dataset, save_dir, checkpoint_args=checkpoint_args)
  ```

  NOTE: The directory layout and file format used for saving the dataset is
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
    checkpoint_args: Optional args for checkpointing which will be passed into
      the `tf.train.CheckpointManager`. If `checkpoint_args` are not specified,
      then checkpointing will not be performed. The `save()` implementation
      creates a `tf.train.Checkpoint` object internally, so users should not
      set the `checkpoint` argument in `checkpoint_args`.

  Returns:
    An operation which when executed performs the save. When writing
    checkpoints, returns None. The return value is useful in unit tests.

  Raises:
    ValueError: If `checkpoint` is passed into `checkpoint_args`.
  """
  # Deprecated thin wrapper: all logic lives on `tf.data.Dataset.save`.
  return dataset.save(path, compression, shard_func, checkpoint_args)
@tf_export("data.experimental.load", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.load(...)` instead.")
def load(path, element_spec=None, compression=None, reader_func=None):
  """Loads a previously saved dataset.

  Example usage:

  >>> import os
  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  If the dataset was saved with the default sharding option, loading it
  preserves the element order of the original dataset.

  A custom read order can be imposed via `reader_func`, which receives a
  single argument -- a dataset of datasets, one per shard -- and must return
  a dataset of elements. For instance, shards can be read in shuffled order:

  ```python
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(NUM_SHARDS)
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = tf.data.experimental.load(
      path="/path/to/data", ..., reader_func=custom_reader_func)
  ```

  Args:
    path: Required. A path pointing to a previously saved dataset.
    element_spec: Optional. A nested structure of `tf.TypeSpec` objects matching
      the structure of an element of the saved dataset and specifying the type
      of individual element components. If not provided, the nested structure of
      `tf.TypeSpec` saved with the saved dataset is used. Note that this
      argument is required in graph mode.
    compression: Optional. The algorithm to use to decompress the data when
      reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    reader_func: Optional. A function to control how to read data from shards.
      If present, the function will be traced and executed as graph computation.

  Returns:
    A `tf.data.Dataset` instance.

  Raises:
    FileNotFoundError: If `element_spec` is not specified and the saved nested
      structure of `tf.TypeSpec` can not be located with the saved dataset.
    ValueError: If `element_spec` is not specified and the method is executed
      in graph mode.
  """
  # Deprecated thin wrapper: forward everything to `tf.data.Dataset.load`.
  return dataset_ops.Dataset.load(
      path,
      element_spec=element_spec,
      compression=compression,
      reader_func=reader_func)