Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/shuffle_ops.py: 27%

84 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental shuffle ops."""

import functools
import numpy as np

from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import random_seed
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that fuses `shuffle` and `repeat`."""

  def __init__(self, input_dataset, buffer_size, count=None, seed=None):
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    if count is None:
      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
    else:
      self._count = ops.convert_to_tensor(
          count, dtype=dtypes.int64, name="count")
    self._seed, self._seed2 = random_seed.get_seed(seed)
    variant_tensor = gen_dataset_ops.shuffle_and_repeat_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        count=self._count,
        seed=self._seed,
        seed2=self._seed2,
        **self._flat_structure)
    super(_ShuffleAndRepeatDataset, self).__init__(input_dataset,
                                                   variant_tensor)

@deprecation.deprecated(
    None, "Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by "
    "`tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take "
    "care of using the fused implementation.")
@tf_export("data.experimental.shuffle_and_repeat")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
  """Shuffles and repeats a Dataset, reshuffling with each repetition.

  >>> d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> d = d.apply(tf.data.experimental.shuffle_and_repeat(2, count=2))
  >>> [elem.numpy() for elem in d]  # doctest: +SKIP
  [2, 3, 1, 1, 3, 2]

  ```python
  dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size, count, seed))
  ```

  produces the same output as

  ```python
  dataset.shuffle(
    buffer_size, seed=seed, reshuffle_each_iteration=True).repeat(count)
  ```

  In each repetition, this dataset fills a buffer with `buffer_size` elements,
  then randomly samples elements from this buffer, replacing the selected
  elements with new elements. For perfect shuffling, set the buffer size equal
  to the full size of the dataset.

  For instance, if your dataset contains 10,000 elements but `buffer_size` is
  set to 1,000, then `shuffle` will initially select a random element from
  only the first 1,000 elements in the buffer. Once an element is selected,
  its space in the buffer is replaced by the next (i.e. 1,001-st) element,
  maintaining the 1,000 element buffer.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
      number of elements that will be buffered when prefetching.
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the number
      of times the dataset should be repeated. The default behavior (if `count`
      is `None` or `-1`) is for the dataset to be repeated indefinitely.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)

  return _apply_fn

def _process_file_infos(file_infos):
  """Computes aggregate information about files to read.

  The method collects information about the files to read, the total number of
  elements, and arrays that can be used to account for elements to be skipped,
  which can be specified via the "skip" and "take" keys.

  To account for elements to skip, the range of each file can be divided into
  three regions:
  - S (elements to skip)
  - T (elements to read)
  - R (remainder of elements that will also be skipped)

  The `thresholds` and `offsets` arrays are initialized as follows:
  `thresholds = [0, T_1, T_1 + T_2, ...]` and
  `offsets = [S_1, S_1 + R_1 + S_2, S_1 + R_1 + S_2 + R_2 + S_3, ...]`

  This makes it possible to map an index from the contiguous range
  `(0...num_elements_to_read)` to an index in the range of all elements,
  skipping over elements as per the "skip" and "take" key values. In
  particular, for a given input index `X`, we find the greatest `thresholds`
  value that is less than or equal to `X`. Let `t(X)` denote the position of
  that value in the `thresholds` array. The output index is computed as
  `X + offsets[t(X)]`.

  Args:
    file_infos: See the `file_infos` argument of `index_shuffle` for details.

  Returns:
    A dictionary containing the following keys:
    - `files`, the vector of pathnames of files to read
    - `num_elements`, an integer identifying the total number of elements
    - `offsets`, the vector of offsets to use for index adjustment (in case
      any elements should be skipped)
    - `thresholds`, the vector of thresholds to use for index adjustment (in
      case any elements should be skipped)
  """
  files = []
  num_elements = 0
  offsets = np.int64([])
  offset_sum = 0
  thresholds = np.int64([])
  threshold_sum = 0
  adjustment_needed = False
  for file_info in file_infos:
    files.append(file_info["path"])
    skip = 0
    if "skip" in file_info:
      if file_info["skip"] < -1:
        raise ValueError("`skip` should be greater than `-1` but got {}".format(
            file_info["skip"]))
      if file_info["skip"] == -1:
        skip = file_info["num_elements"]
      else:
        skip = min(file_info["skip"], file_info["num_elements"])
    take = file_info["num_elements"] - skip
    if "take" in file_info:
      if file_info["take"] < -1:
        raise ValueError("`take` should be greater than `-1` but got {}".format(
            file_info["take"]))
      # `file_info["take"] == -1` is a no-op.
      if file_info["take"] != -1:
        take = min(file_info["take"], take)
    remainder = file_info["num_elements"] - skip - take
    if take != file_info["num_elements"]:
      adjustment_needed = True
    num_elements += take
    offsets = np.append(offsets, offset_sum + skip)
    offset_sum += skip + remainder
    thresholds = np.append(thresholds, threshold_sum)
    threshold_sum += take
  result = {"files": files, "num_elements": num_elements}
  if adjustment_needed:
    result["offsets"] = offsets
    result["thresholds"] = thresholds
  return result

def _adjust_index(index, thresholds, offsets):
  """Adjusts index to account for elements to be skipped."""
  t_index = array_ops.shape(
      array_ops.boolean_mask(
          thresholds,
          math_ops.less_equal(thresholds, index)))[0] - 1
  return index + array_ops.gather(offsets, t_index)
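
# A worked example (illustrative only; the paths and element counts below are
# hypothetical): given
#   file_infos = [{"path": "a", "num_elements": 5, "skip": 2},
#                 {"path": "b", "num_elements": 4, "take": 3}]
# `_process_file_infos` returns `files == ["a", "b"]`, `num_elements == 6`,
# `thresholds == [0, 3]` and `offsets == [2, 2]`. The contiguous index 4 (the
# fifth element to read) then maps to
# `_adjust_index(4, thresholds, offsets) == 4 + offsets[1] == 6`, i.e. the
# second taken element of file "b" in the concatenated range of all elements,
# with the offset of 2 accounting for the two skipped elements of file "a".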

# TODO(jsimsa): Expose this method in the public API. When we do, consider
# defining `FileInfo` as a public API to encapsulate the information provided
# through the `file_infos` argument.
def index_shuffle(file_infos,
                  reader_factory,
                  seed=None,
                  reshuffle_each_iteration=False,
                  num_parallel_calls=dataset_ops.AUTOTUNE):
  """Creates a (globally) shuffled dataset from the given set of files.

  Unlike `tf.data.Dataset.shuffle()`, which uses an in-memory buffer to shuffle
  elements of the input dataset in a streaming fashion,
  `tf.data.experimental.index_shuffle()` performs a global shuffle of element
  indices and then reads the data in a shuffled order. The advantage of
  `index_shuffle()` is that it can perform a global shuffle of datasets that do
  not fit into memory (as long as the array of their indices does) and that the
  shuffling logic it provides is compatible with symbolic checkpointing. The
  disadvantage of `index_shuffle()` is that reading data in a shuffled random
  order will in general not be as efficient as reading data sequentially.

  Args:
    file_infos: A list of dictionaries that describe each file of the input
      dataset. Each dictionary is expected to contain the "path" key, which
      identifies the path of the file, and the "num_elements" key, which
      identifies the number of elements in the file. In addition, the "skip"
      and "take" keys can be used to identify the number of elements to skip
      and take respectively. By default, no elements are skipped and all
      elements are taken.
    reader_factory: A function that maps a sequence of filenames to an instance
      of `tf.data.Dataset` that reads data from the files.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to shuffle the order of elements. Defaults to a
      non-deterministic seed.
    reshuffle_each_iteration: (Optional.) A `tf.bool` scalar `tf.Tensor`, that
      determines whether to change the shuffle order each iteration. Defaults
      to `False`.
    num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`, that
      determines the maximum number of random access operations to perform
      in parallel. By default, the tf.data runtime uses autotuning to determine
      the value dynamically.

  Returns:
    A `tf.data.Dataset` object, representing a globally shuffled dataset of
    the input data.
  """

  result = _process_file_infos(file_infos)

  def sequential_index_shuffle(seeds):
    dataset = dataset_ops.Dataset.range(result["num_elements"])

    def read_element(dataset, index):
      # 1) Shuffle the index.
      shuffled_index = stateless_random_ops.index_shuffle(
          index, seeds, result["num_elements"] - 1)
      # 2) If needed, adjust the index to the non-contiguous range.
      if "thresholds" in result and "offsets" in result:
        shuffled_index = _adjust_index(shuffled_index, result["thresholds"],
                                       result["offsets"])
      # 3) Perform the read.
      return random_access.at(dataset, shuffled_index)

    # We evaluate `reader_factory()` eagerly to prevent the dataset from being
    # created on every lookup.
    map_func = functools.partial(read_element, reader_factory(result["files"]))
    return dataset.map(map_func, num_parallel_calls=num_parallel_calls)

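  # Draw two scalar seeds from the random-seed dataset and batch them into a
  # single seed pair; `flat_map` below builds one shuffled pass over the data
  # from that pair. With `reshuffle_each_iteration=True` the seed dataset is
  # re-randomized, so every epoch draws a fresh pair and hence a fresh shuffle
  # order.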

  rng_ds = dataset_ops.Dataset.random(
      seed=seed,
      rerandomize_each_iteration=reshuffle_each_iteration)
  rng_ds = rng_ds.take(2).batch(2, drop_remainder=True)
  return rng_ds.flat_map(sequential_index_shuffle)
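
# Illustrative usage sketch (not part of the module): the paths and element
# counts below are hypothetical, and `index_shuffle` is not exported in the
# public API yet (see the TODO above). A caller that knows how many records
# each file holds could build a globally shuffled dataset along these lines:
#
#   file_infos = [
#       {"path": "/data/part-0.tfrecord", "num_elements": 10000},
#       {"path": "/data/part-1.tfrecord", "num_elements": 8000, "skip": 100},
#   ]
#   dataset = index_shuffle(
#       file_infos,
#       # Assumes the reader dataset supports random access via
#       # `tf.data.experimental.at`.
#       reader_factory=tf.data.TFRecordDataset,
#       seed=42,
#       reshuffle_each_iteration=True)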