Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/io.py: 87%

15 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python API for save and loading a dataset.""" 

16 

17from tensorflow.python.data.ops import dataset_ops 

18from tensorflow.python.util import deprecation 

19from tensorflow.python.util.tf_export import tf_export 

20 

# Compression algorithm identifiers accepted by `save`/`load`.
COMPRESSION_GZIP = "GZIP"
# Fix: this constant must name the Snappy codec; it was "NONE", which would
# silently select no compression for callers using COMPRESSION_SNAPPY.
COMPRESSION_SNAPPY = "SNAPPY"
# File name under the save directory holding the serialized element spec.
DATASET_SPEC_FILENAME = "dataset_spec.pb"

24 

25 

@tf_export("data.experimental.save", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.save(...)` instead.")
def save(dataset,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
  """Saves the content of the given dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  The saved dataset is saved in multiple file "shards". By default, the dataset
  output is divided into shards in a round-robin fashion but custom sharding
  can be specified via the `shard_func` function. For example, you can save the
  dataset using a single shard as follows:

  ```python
  dataset = make_dataset()
  def custom_shard_func(element):
    return np.int64(0)
  tf.data.experimental.save(
      dataset, path="/path/to/data", shard_func=custom_shard_func)
  ```

  To enable checkpointing, pass in `checkpoint_args` to the `save` method
  as follows:

  ```python
  dataset = tf.data.Dataset.range(100)
  save_dir = "..."
  checkpoint_prefix = "..."
  step_counter = tf.Variable(0, trainable=False)
  checkpoint_args = {
      "checkpoint_interval": 50,
      "step_counter": step_counter,
      "directory": checkpoint_prefix,
      "max_to_keep": 20,
  }
  dataset.save(save_dir, checkpoint_args=checkpoint_args)
  ```

  NOTE: The directory layout and file format used for saving the dataset is
  considered an implementation detail and may change. For this reason, datasets
  saved through `tf.data.experimental.save` should only be consumed through
  `tf.data.experimental.load`, which is guaranteed to be backwards compatible.

  Args:
    dataset: The dataset to save.
    path: Required. A directory to use for saving the dataset.
    compression: Optional. The algorithm to use to compress data when writing
      it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    shard_func: Optional. A function to control the mapping of dataset elements
      to file shards. The function is expected to map elements of the input
      dataset to int64 shard IDs. If present, the function will be traced and
      executed as graph computation.
    checkpoint_args: Optional args for checkpointing which will be passed into
      the `tf.train.CheckpointManager`. If `checkpoint_args` are not specified,
      then checkpointing will not be performed. The `save()` implementation
      creates a `tf.train.Checkpoint` object internally, so users should not
      set the `checkpoint` argument in `checkpoint_args`.

  Returns:
    An operation which when executed performs the save. When writing
    checkpoints, returns None. The return value is useful in unit tests.

  Raises:
    ValueError if `checkpoint` is passed into `checkpoint_args`.
  """
  # Deprecated shim: delegate directly to the `tf.data.Dataset.save` method.
  return dataset.save(path, compression, shard_func, checkpoint_args)

106 

107 

@tf_export("data.experimental.load", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.load(...)` instead.")
def load(path, element_spec=None, compression=None, reader_func=None):
  """Loads a previously saved dataset.

  Example usage:

  >>> import tempfile
  >>> path = os.path.join(tempfile.gettempdir(), "saved_data")
  >>> # Save a dataset
  >>> dataset = tf.data.Dataset.range(2)
  >>> tf.data.experimental.save(dataset, path)
  >>> new_dataset = tf.data.experimental.load(path)
  >>> for elem in new_dataset:
  ...   print(elem)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)


  When the dataset was saved with the default sharding behavior, loading it
  back preserves the original element order.

  A custom loading order for the individual shards can be requested through
  `reader_func`. That callable receives a single argument -- a dataset whose
  elements are themselves datasets, one per shard -- and must produce a
  dataset of plain elements. For instance, shards can be read in shuffled
  order like this:

  ```python
  def custom_reader_func(datasets):
    datasets = datasets.shuffle(NUM_SHARDS)
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = tf.data.experimental.load(
      path="/path/to/data", ..., reader_func=custom_reader_func)
  ```

  Args:
    path: Required. A path pointing to a previously saved dataset.
    element_spec: Optional. A nested structure of `tf.TypeSpec` objects
      describing each component of an element of the saved dataset. When
      omitted, the `tf.TypeSpec` structure stored alongside the saved dataset
      is used instead. Note that this argument is required in graph mode.
    compression: Optional. The algorithm to use to decompress the data when
      reading it. Supported options are `GZIP` and `NONE`. Defaults to `NONE`.
    reader_func: Optional. A function to control how to read data from shards.
      If present, the function will be traced and executed as graph
      computation.

  Returns:
    A `tf.data.Dataset` instance.

  Raises:
    FileNotFoundError: If `element_spec` is not specified and the saved nested
      structure of `tf.TypeSpec` can not be located with the saved dataset.
    ValueError: If `element_spec` is not specified and the method is executed
      in graph mode.
  """
  # Deprecated shim: forward every argument, by keyword, to the canonical
  # `tf.data.Dataset.load` implementation.
  return dataset_ops.Dataset.load(
      path=path,
      element_spec=element_spec,
      compression=compression,
      reader_func=reader_func)