Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/snapshot.py: 44%

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset snapshot and related functionality."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
COMPRESSION_NONE = None


class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A Dataset that captures a snapshot or reads from a snapshot."""

  def __init__(self,
               input_dataset,
               path,
               compression=None,
               reader_path_prefix=None,
               writer_path_prefix=None,
               shard_size_bytes=None,
               pending_snapshot_expiry_seconds=None,
               num_reader_threads=None,
               reader_buffer_size=None,
               num_writer_threads=None,
               writer_buffer_size=None,
               shuffle_on_read=None,
               shuffle_seed=None,
               mode=None,
               snapshot_name=None):

    # Replace unset (None) arguments with the sentinel values that the
    # underlying snapshot op expects: "" for string attributes, -1 for numeric
    # attributes, False for shuffle_on_read, and "auto" for mode.
    self._compression = compression if compression is not None else ""
    self._reader_path_prefix = (
        reader_path_prefix if reader_path_prefix is not None else "")
    self._writer_path_prefix = (
        writer_path_prefix if writer_path_prefix is not None else "")
    self._shard_size_bytes = (
        shard_size_bytes if shard_size_bytes is not None else -1)
    self._pending_snapshot_expiry_seconds = (
        pending_snapshot_expiry_seconds
        if pending_snapshot_expiry_seconds is not None else -1)
    self._num_reader_threads = (
        num_reader_threads if num_reader_threads is not None else -1)
    self._reader_buffer_size = (
        reader_buffer_size if reader_buffer_size is not None else -1)
    self._num_writer_threads = (
        num_writer_threads if num_writer_threads is not None else -1)
    self._writer_buffer_size = (
        writer_buffer_size if writer_buffer_size is not None else -1)
    self._shuffle_on_read = (
        shuffle_on_read if shuffle_on_read is not None else False)
    self._mode = (mode if mode is not None else "auto")
    self._snapshot_name = (snapshot_name if snapshot_name is not None else "")

    self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)

    self._input_dataset = input_dataset
    self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

    variant_tensor = ged_ops.snapshot_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        path=self._path,
        compression=self._compression,
        reader_path_prefix=self._reader_path_prefix,
        writer_path_prefix=self._writer_path_prefix,
        shard_size_bytes=self._shard_size_bytes,
        pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
        num_reader_threads=self._num_reader_threads,
        reader_buffer_size=self._reader_buffer_size,
        num_writer_threads=self._num_writer_threads,
        writer_buffer_size=self._writer_buffer_size,
        shuffle_on_read=self._shuffle_on_read,
        seed=self._seed,
        seed2=self._seed2,
        mode=self._mode,
        snapshot_name=self._snapshot_name,
        **self._flat_structure)

    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(None, "Use `tf.data.Dataset.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
                    writer_path_prefix=None,
                    shard_size_bytes=None,
                    pending_snapshot_expiry_seconds=None,
                    num_reader_threads=None,
                    reader_buffer_size=None,
                    num_writer_threads=None,
                    writer_buffer_size=None,
                    shuffle_on_read=None,
                    shuffle_seed=None,
                    mode=None,
                    snapshot_name=None):
  """Writes to/reads from a snapshot of a dataset.

  This function attempts to determine whether a valid snapshot exists at the
  `path`, and reads from the snapshot if so. If not, it will run the
  preprocessing pipeline as usual and write out a snapshot of the processed
  data for future use.

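  For illustration, a minimal sketch of applying this transformation via
  `Dataset.apply` (the snapshot directory is a placeholder path, and the
  `snapshot_ops` alias is introduced here only for the example):

  ```python
  import tensorflow as tf
  from tensorflow.python.data.experimental.ops import snapshot as snapshot_ops

  dataset = tf.data.Dataset.range(10)
  # Reads from the snapshot if a valid one exists at the path; otherwise runs
  # the pipeline and writes a new snapshot there.
  dataset = dataset.apply(
      snapshot_ops.legacy_snapshot("/path/to/snapshot/dir", compression="GZIP"))
  ```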

  Args:
    path: A directory where we want to save our snapshots and/or read from a
      previously saved snapshot.
    compression: The type of compression to apply to the Dataset. Currently
      supports "GZIP" or None. Defaults to None (no compression).
    reader_path_prefix: A prefix to add to the path when reading from
      snapshots. Defaults to None.
    writer_path_prefix: A prefix to add to the path when writing to snapshots.
      Defaults to None.
    shard_size_bytes: The size of each shard to be written by the snapshot
      dataset op. Defaults to 10 GiB.
    pending_snapshot_expiry_seconds: How long to wait (in seconds) before the
      snapshot op considers a previously unfinished snapshot to be stale.
    num_reader_threads: Number of threads to parallelize reading from the
      snapshot. Especially useful if compression is turned on, since the
      decompression operation tends to be intensive. Defaults to 1. If > 1,
      this might introduce non-determinism, i.e. the order in which elements
      are read from the snapshot may differ from the order in which they were
      written.
    reader_buffer_size: Maximum number of elements we can prefetch while
      reading from the snapshot. Defaults to 1. Increasing this might improve
      performance but will increase memory consumption.
    num_writer_threads: Number of threads to parallelize writing the snapshot.
      We'll open up `num_writer_threads` files and write to them in parallel.
      Especially useful if compression is turned on, since the compression
      operation tends to be intensive. Defaults to 1. If > 1, this might
      introduce non-determinism, i.e. the order in which elements are written
      may differ from the order in which they are produced by the upstream
      iterator.
    writer_buffer_size: Maximum number of pipeline elements to fill up the
      buffer before writing them out using `num_writer_threads`.
    shuffle_on_read: If this is True, then the order in which examples are
      produced when reading from a snapshot will be random. Defaults to False.
    shuffle_seed: Optional. If shuffle_seed is set, the random number generator
      used for shuffling (when shuffle_on_read is turned on) is seeded by the
      given seed. Otherwise, it is seeded by a random seed that differs for
      every run.
    mode: The mode in which snapshot should operate. Valid options are "auto",
      "read", "write", and "passthrough". The default mode is "auto", where
      the snapshot op will automatically determine what mode to operate in.
    snapshot_name: If set, use the supplied string as a named snapshot name
      instead of introspecting the data pipeline and automatically generating
      a unique identifier for the snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _LegacySnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_path_prefix=reader_path_prefix,
        writer_path_prefix=writer_path_prefix,
        shard_size_bytes=shard_size_bytes,
        pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
        num_reader_threads=num_reader_threads,
        reader_buffer_size=reader_buffer_size,
        num_writer_threads=num_writer_threads,
        writer_buffer_size=writer_buffer_size,
        shuffle_on_read=shuffle_on_read,
        shuffle_seed=shuffle_seed,
        mode=mode,
        snapshot_name=snapshot_name)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.snapshot(...)`.")
@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading off disk storage and network
  bandwidth for freeing up more valuable CPU resources and accelerator compute
  time.

  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read and written, by passing user-defined
  functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user-specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func`
  could be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot("/path/to/snapshot/dir",
      shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user-specified function that accepts a single argument: a
  Dataset of Datasets, each representing a "split" of elements of the original
  dataset. The cardinality of the input dataset matches the number of shards
  specified in the `shard_func` (see above). The function should return a
  Dataset of elements of the original dataset.

  Users may want to specify this function to control how snapshot files should
  be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define. This
  function enables both dataset shuffling and parallel reading of datasets:

  ```python
  def user_reader_func(datasets):
    # shuffle the dataset splits
    datasets = datasets.shuffle(NUM_CORES)
    # read datasets in parallel and interleave their elements
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot("/path/to/snapshot/dir",
      reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

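  For completeness, a minimal sketch of this default usage, with neither
  `reader_func` nor `shard_func` supplied (the snapshot directory is a
  placeholder path):

  ```python
  import tensorflow as tf

  dataset = tf.data.Dataset.range(100)
  # With the defaults, reads are parallelized across the available cores and
  # the data is not shuffled.
  dataset = dataset.apply(
      tf.data.experimental.snapshot("/path/to/snapshot/dir", compression="AUTO"))
  ```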

  Args:
    path: Required. A directory to use for storing / loading the snapshot to /
      from.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
      Defaults to AUTO, which attempts to pick an appropriate compression
      algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from
      snapshot shards.
    shard_func: Optional. A function to control how to shard data when writing
      a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Actual dataset transformation."""
    return dataset.snapshot(
        path=path,
        compression=compression,
        reader_func=reader_func,
        shard_func=shard_func)

  return _apply_fn