# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental shuffle ops."""
import functools
import numpy as np

from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import random_seed
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


class _ShuffleAndRepeatDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that fuses `shuffle` and `repeat`."""

  def __init__(self, input_dataset, buffer_size, count=None, seed=None):
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    if count is None:
      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
    else:
      self._count = ops.convert_to_tensor(
          count, dtype=dtypes.int64, name="count")
    self._seed, self._seed2 = random_seed.get_seed(seed)
    variant_tensor = gen_dataset_ops.shuffle_and_repeat_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        count=self._count,
        seed=self._seed,
        seed2=self._seed2,
        **self._flat_structure)
    super(_ShuffleAndRepeatDataset, self).__init__(input_dataset,
                                                   variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by "
    "`tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take "
    "care of using the fused implementation.")
@tf_export("data.experimental.shuffle_and_repeat")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
64 """Shuffles and repeats a Dataset, reshuffling with each repetition.
66 >>> d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
67 >>> d = d.apply(tf.data.experimental.shuffle_and_repeat(2, count=2))
68 >>> [elem.numpy() for elem in d] # doctest: +SKIP
69 [2, 3, 1, 1, 3, 2]
71 ```python
72 dataset.apply(
73 tf.data.experimental.shuffle_and_repeat(buffer_size, count, seed))
74 ```
76 produces the same output as
78 ```python
79 dataset.shuffle(
80 buffer_size, seed=seed, reshuffle_each_iteration=True).repeat(count)
81 ```
83 In each repetition, this dataset fills a buffer with `buffer_size` elements,
84 then randomly samples elements from this buffer, replacing the selected
85 elements with new elements. For perfect shuffling, set the buffer size equal
86 to the full size of the dataset.
88 For instance, if your dataset contains 10,000 elements but `buffer_size` is
89 set to 1,000, then `shuffle` will initially select a random element from
90 only the first 1,000 elements in the buffer. Once an element is selected,
91 its space in the buffer is replaced by the next (i.e. 1,001-st) element,
92 maintaining the 1,000 element buffer.
94 Args:
95 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
96 number elements that will be buffered when prefetching.
97 count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the number
98 of times the dataset should be repeated. The default behavior (if `count`
99 is `None` or `-1`) is for the dataset be repeated indefinitely.
100 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
101 seed that will be used to create the distribution. See
102 `tf.random.set_seed` for behavior.
104 Returns:
105 A `Dataset` transformation function, which can be passed to
106 `tf.data.Dataset.apply`.
107 """

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)

  return _apply_fn


def _process_file_infos(file_infos):
  """Computes aggregate information about files to read.

  The method collects information about the files to read, the total number of
  elements, and arrays that can be used to account for elements to be skipped,
  which can be specified via the "skip" and "take" keys.

  To account for elements to skip, the range of each file can be divided into
  three regions:
  - S (elements to skip)
  - T (elements to read)
  - R (remainder of elements that will also be skipped)

  The `thresholds` and `offsets` arrays are initialized as follows:
  `thresholds = [0, T_1, T_1 + T_2, ...]` and
  `offsets = [S_1, S_1 + R_1 + S_2, S_1 + R_1 + S_2 + R_2 + S_3, ...]`

  This makes it possible to map an index from a contiguous range
  `(0...num_elements_to_read)` to an index in the range of all elements,
  skipping over elements as per the "skip" and "take" key values. In
  particular, for a given input index `X`, we find the greatest `thresholds`
  value that is less than or equal to `X`. Let `t(X)` denote the position of
  that value in the `thresholds` array. The output index is computed as
  `X + offsets[t(X)]`. A worked example follows the function body below.

  Args:
    file_infos: See `file_infos` argument of `index_shuffle` for details.

  Returns:
    A dictionary containing the following keys:
    - `files`, the vector of pathnames of files to read
    - `num_elements`, an integer identifying the total number of elements
    - `offsets`, the vector of offsets to use for index adjustment (in case
      any elements should be skipped)
    - `thresholds`, the vector of thresholds to use for index adjustment (in
      case any elements should be skipped)
  """
  files = []
  num_elements = 0
  offsets = np.int64([])
  offset_sum = 0
  thresholds = np.int64([])
  threshold_sum = 0
  adjustment_needed = False
  for file_info in file_infos:
    files.append(file_info["path"])
    skip = 0
    if "skip" in file_info:
      if file_info["skip"] < -1:
        raise ValueError(
            "`skip` should be greater than or equal to `-1` but got {}".format(
                file_info["skip"]))
      if file_info["skip"] == -1:
        skip = file_info["num_elements"]
      else:
        skip = min(file_info["skip"], file_info["num_elements"])
    take = file_info["num_elements"] - skip
    if "take" in file_info:
      if file_info["take"] < -1:
        raise ValueError(
            "`take` should be greater than or equal to `-1` but got {}".format(
                file_info["take"]))
      # `file_info["take"] == -1` is a no-op.
      if file_info["take"] != -1:
        take = min(file_info["take"], take)
    remainder = file_info["num_elements"] - skip - take
    if take != file_info["num_elements"]:
      adjustment_needed = True
    num_elements += take
    offsets = np.append(offsets, offset_sum + skip)
    offset_sum += skip + remainder
    thresholds = np.append(thresholds, threshold_sum)
    threshold_sum += take
  result = {"files": files, "num_elements": num_elements}
  if adjustment_needed:
    result["offsets"] = offsets
    result["thresholds"] = thresholds
  return result
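
# A minimal worked example of `_process_file_infos` (illustrative only; the
# file names and element counts below are hypothetical). Given
#
#   file_infos = [{"path": "a", "num_elements": 10, "skip": 2},
#                 {"path": "b", "num_elements": 5, "take": 3}]
#
# we get S_1 = 2, T_1 = 8, R_1 = 0 for file "a" and S_2 = 0, T_2 = 3, R_2 = 2
# for file "b", so the function returns `num_elements == 11`,
# `thresholds == [0, 8]`, and `offsets == [2, 2]`. Contiguous indices 0..7 map
# to global indices 2..9 (file "a" minus the skipped prefix), and indices
# 8..10 map to 10..12 (the first three elements of file "b").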


def _adjust_index(index, thresholds, offsets):
  """Adjusts index to account for elements to be skipped."""
  t_index = array_ops.shape(
      array_ops.boolean_mask(
          thresholds,
          math_ops.less_equal(thresholds, index)))[0] - 1
  return index + array_ops.gather(offsets, t_index)
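
# For instance, with `thresholds == [0, 8]` and `offsets == [2, 2]` (the
# values computed in the example above), `_adjust_index(9, thresholds,
# offsets)` finds the greatest threshold <= 9 at position 1 and returns
# 9 + offsets[1] == 11.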


# TODO(jsimsa): Expose this method in the public API. When we do, consider
# defining `FileInfo` as a public API to encapsulate the information provided
# through the `file_infos` argument.
def index_shuffle(file_infos,
                  reader_factory,
                  seed=None,
                  reshuffle_each_iteration=False,
                  num_parallel_calls=dataset_ops.AUTOTUNE):
  """Creates a (globally) shuffled dataset from the given set of files.

  Unlike `tf.data.Dataset.shuffle()`, which uses an in-memory buffer to shuffle
  elements of the input dataset in a streaming fashion,
  `tf.data.experimental.index_shuffle()` performs a global shuffle of element
  indices and then reads the data in a shuffled order. The advantage of
  `index_shuffle()` is that it can perform a global shuffle of datasets that do
  not fit into memory (as long as the array of their indices does) and that the
  shuffling logic it provides is compatible with symbolic checkpointing. The
  disadvantage of `index_shuffle()` is that reading data in a shuffled random
  order will in general not be as efficient as reading data sequentially.

  Args:
    file_infos: A list of dictionaries that describe each file of the input
      dataset. Each dictionary is expected to contain the "path" key, which
      identifies the path of the file, and the "num_elements" key, which
      identifies the number of elements in the file. In addition, the "skip"
      and "take" keys can be used to identify the number of elements to skip
      and take, respectively. By default, no elements are skipped and all
      elements are taken.
    reader_factory: A function that maps a sequence of filenames to an instance
      of `tf.data.Dataset` that reads data from the files.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to shuffle the order of elements. Defaults to a
      non-deterministic seed.
    reshuffle_each_iteration: (Optional.) A `tf.bool` scalar `tf.Tensor` that
      determines whether to change the shuffle order each iteration. Defaults
      to `False`.
    num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor` that
      determines the maximum number of random access operations to perform
      in parallel. By default, the tf.data runtime uses autotuning to determine
      the value dynamically.

  Returns:
    A `tf.data.Dataset` object, representing a globally shuffled dataset of
    the input data.
  """
  result = _process_file_infos(file_infos)

  def sequential_index_shuffle(seeds):
    dataset = dataset_ops.Dataset.range(result["num_elements"])

    def read_element(dataset, index):
      # 1) Shuffle the index.
      shuffled_index = stateless_random_ops.index_shuffle(
          index, seeds, result["num_elements"] - 1)
      # 2) If needed, adjust the index to the non-contiguous range.
      if "thresholds" in result and "offsets" in result:
        shuffled_index = _adjust_index(shuffled_index, result["thresholds"],
                                       result["offsets"])
      # 3) Perform the read.
      return random_access.at(dataset, shuffled_index)

    # We evaluate `reader_factory()` eagerly to prevent the dataset from being
    # created on every lookup.
    map_func = functools.partial(read_element, reader_factory(result["files"]))
    return dataset.map(map_func, num_parallel_calls=num_parallel_calls)

  rng_ds = dataset_ops.Dataset.random(
      seed=seed,
      rerandomize_each_iteration=reshuffle_each_iteration)
  rng_ds = rng_ds.take(2).batch(2, drop_remainder=True)
  return rng_ds.flat_map(sequential_index_shuffle)
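

# A minimal usage sketch of `index_shuffle`. The file names, element counts,
# and reader below are hypothetical, and the sketch assumes the dataset
# produced by `reader_factory` supports random access via
# `tf.data.experimental.at`:
#
#   file_infos = [
#       {"path": "/data/shard-0.tfrecord", "num_elements": 1000},
#       {"path": "/data/shard-1.tfrecord", "num_elements": 1000, "skip": 10},
#   ]
#
#   def reader_factory(files):
#     return tf.data.TFRecordDataset(files)
#
#   dataset = index_shuffle(file_infos, reader_factory, seed=42,
#                           reshuffle_each_iteration=True)
#   for record in dataset:  # yields all 1,990 records in shuffled order
#     ...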