Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/distribute.py: 30% of 107 statements
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution Strategy-related dataset transformations."""

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.types import data as data_types
from tensorflow.python.util.tf_export import tf_export

SHARD_HINT = -1
tf_export("data.experimental.SHARD_HINT").export_constant(
    __name__, "SHARD_HINT")


class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that shards the `Dataset` automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario using graph rewrites.

  If the AutoShardPolicy is set to FILE, it walks up the dataset graph until
  it finds a reader dataset, then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the
  end of the input pipeline, before any terminal PrefetchDataset if there is
  one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it
  is rewritten to the legacy RebatchDataset for correctness reasons, since
  RebatchDatasetV2 is incompatible with data sharding.

  If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding.
  If it cannot find a reader dataset, it falls back to doing data-based
  sharding.

  If the AutoShardPolicy is set to OFF, it does nothing.

  Attributes:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. This is used
      only when sharding by data (either DATA or AUTO) in order to rewrite
      RebatchDatasetV2 to RebatchDataset.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset

    self._element_spec = input_dataset.element_spec
    variant_tensor = ged_ops.auto_shard_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        index=index,
        auto_shard_policy=int(
            input_dataset.options().experimental_distribute.auto_shard_policy),
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
  return dataset_ops.DatasetV1Adapter(
      _AutoShardDataset(input_dataset, num_workers, index, num_replicas))
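

# Editorial note: the sketch below is illustrative and not part of the original
# module. It shows the typical way auto-sharding is driven: the desired
# `AutoShardPolicy` is attached to the input dataset through `tf.data.Options`,
# and `_AutoShardDataset` then reads that policy and rewrites the dataset graph
# for worker `index` out of `num_workers`. The helper name
# `_example_auto_shard` is hypothetical.
def _example_auto_shard(dataset, num_workers, index):
  # Local import keeps this sketch self-contained; `options` is the module that
  # defines `tf.data.Options` and `AutoShardPolicy`.
  from tensorflow.python.data.ops import options as options_lib  # pylint: disable=g-import-not-at-top
  opts = options_lib.Options()
  # DATA sharding inserts a ShardDataset at the end of the pipeline; FILE
  # sharding would instead shard the reader's input files.
  opts.experimental_distribute.auto_shard_policy = (
      options_lib.AutoShardPolicy.DATA)
  dataset = dataset.with_options(opts)
  return _AutoShardDataset(dataset, num_workers=num_workers, index=index)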


class _LegacyRebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that divides its input batches into `num_replicas` sub-batches.

  For each batch in the input dataset, _LegacyRebatchDataset will produce
  `num_replicas` smaller batches whose sizes add up to the original batch size.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _LegacyRebatchDataset(ds, num_replicas=3)
  for elem in ds:
    print(elem)
  >> [0, 1], [2, 3], [], [4, 5], [6, 7], []
  ```
  """

  def __init__(self, input_dataset, num_replicas):
    """Creates a _LegacyRebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      num_replicas: A `tf.int64` scalar, representing the number of sub-batches
        to split each batch from `input_dataset` into.
    """

    def recalculate_batch_size(type_spec):
      """Recalculates the output_shape after dividing it by num_replicas."""
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(output_shape, tensor_shape.TensorShape):
        return None

      # If the output shape is unknown, we set the batch dimension to unknown.
      if output_shape.rank is None:
        return None

      if len(output_shape) < 1:
        raise ValueError(
            "Invalid `input_dataset`. Expected a dataset whose elements "
            "have rank >= 1 but found a dataset whose elements are scalars. "
            "Fix the issue by adding the `batch` transformation to the "
            "dataset.")
      output_dims = [d.value for d in output_shape.dims]

      if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
        return output_dims[0] // num_replicas

      # Set the batch dimension to unknown. If `num_replicas` does not evenly
      # divide the global batch size, the sub-batches may have different sizes.
      return None

    def rebatch(type_spec):
      # pylint: disable=protected-access
      batch_size = recalculate_batch_size(type_spec)
      return type_spec._unbatch()._batch(batch_size)
      # pylint: enable=protected-access

    self._element_spec = nest.map_structure(
        rebatch, dataset_ops.get_structure(input_dataset))

    # auto_shard rewrite assumes that there's normalize_to_dense before
    # rebatch_dataset.
    # LINT.IfChange
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_replicas=num_replicas,
        **self._flat_structure)
    # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
    super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec
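

# Editorial note: a minimal illustrative sketch (not part of the original
# module) that exercises `_LegacyRebatchDataset` exactly as in the class
# docstring above: each input batch of 4 is split into `num_replicas=3`
# sub-batches whose sizes add up to 4. The helper name is hypothetical.
def _example_legacy_rebatch():
  dataset = dataset_ops.Dataset.range(8).batch(4)
  rebatched = _LegacyRebatchDataset(dataset, num_replicas=3)
  # Yields [0, 1], [2, 3], [], [4, 5], [6, 7], [] -- sub-batch sizes 2, 2, 0
  # for each original batch of 4.
  return rebatched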


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset on a given `device` given a graph def."""

  def __init__(self, graph_def, device, element_spec):
    self._elem_spec = element_spec
    with ops.device(device):
      variant_tensor = ged_ops.dataset_from_graph(graph_def)
      super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def replicate(dataset, devices):
  """A transformation that replicates `dataset` onto a list of devices.

  Args:
    dataset: A `tf.data.Dataset` object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.
  """
  if not isinstance(dataset, data_types.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object but "
        f"got {type(dataset)}.")

  # pylint: disable=protected-access
  dataset_device = dataset._variant_tensor.device

  datasets = {}
  if len(devices) == 1 and devices[0] == dataset_device:
    datasets[devices[0]] = dataset
    return datasets

  with ops.colocate_with(dataset._variant_tensor):
    dataset = dataset._apply_debug_options()
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=ExternalStatePolicy.WARN)
  for device in devices:
    ds = _RemoteDataset(graph_def, device, dataset.element_spec)
    datasets[device] = ds

  return datasets
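

# Editorial note: an illustrative sketch (not part of the original module) of
# `replicate`: the dataset is serialized once and re-created on each requested
# device, and the result maps each device name to its per-device dataset. The
# helper name and device strings are hypothetical and assume a configured
# multi-worker cluster with those devices.
def _example_replicate():
  dataset = dataset_ops.Dataset.range(10)
  per_device = replicate(
      dataset,
      ["/job:worker/replica:0/task:0/device:CPU:0",
       "/job:worker/replica:0/task:1/device:CPU:0"])
  # `per_device` is a dict: {device_name: dataset re-created on that device}.
  return per_device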


def batch_sizes_for_worker(global_batch_size, num_workers,
                           num_replicas_per_worker, worker_index):
  """Determines how to rebatch a dataset for the given worker.

  Given the global batch size, number of workers, number of replicas per
  worker, and worker index, returns the correct batch sizes for rebatching a
  dataset on worker `worker_index` of `num_workers`, such that each global step
  (across all workers and replicas) will consume global_batch_size elements.
  The returned value should be passed as the `batch_sizes` input parameter to
  `tf.data.experimental.rebatch()`. The returned batch sizes meet the following
  constraints:

  Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker
  (A) for any worker, len(batch_sizes) = W * R
  (B) for any worker, sum(batch_sizes) == G
  (C) for any global step (i.e. R iterations on each worker), the sum of the
      batch sizes consumed by replicas across all workers is G.
  (D) any two batch sizes of any two replicas differ by at most one.

  For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two
  files which each contain 7 elements:

  ```python
  # WORKER 0
  batch_sizes_0 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=0)
  print(batch_sizes_0)
  >> [2, 2, 2, 1]

  dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_0 = dataset_0.shard(2, index=0)
  dataset_0 = dataset_0.batch(7)
  dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0))
  for elem in dataset_0:
    print(elem)
  >> [[A0, A1], [A2, A3], [A4, A5], [A6]]

  # WORKER 1
  batch_sizes_1 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=1)
  print(batch_sizes_1)
  >> [2, 1, 2, 2]

  dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_1 = dataset_1.shard(2, index=1)
  dataset_1 = dataset_1.batch(7)
  dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1))
  for elem in dataset_1:
    print(elem)
  >> [[B0, B1], [B2], [B3, B4], [B5, B6]]
  ```

  The above example will produce the following elements:

  Step 1:
    Worker 0 Replica 0: [A0, A1]
    Worker 0 Replica 1: [A2, A3]
    Worker 1 Replica 0: [B0, B1]
    Worker 1 Replica 1: [B2]
  Total batch size = 7

  Step 2:
    Worker 0 Replica 0: [A4, A5]
    Worker 0 Replica 1: [A6]
    Worker 1 Replica 0: [B3, B4]
    Worker 1 Replica 1: [B5, B6]
  Total batch size = 7

  Args:
    global_batch_size: A `tf.int64` scalar, representing the global batch size.
    num_workers: An integer representing the number of workers the dataset will
      be distributed across.
    num_replicas_per_worker: An integer representing the number of replicas per
      worker. All workers are assumed to have the same number of replicas.
    worker_index: An integer index of the worker to be rebatched.

  Returns:
    A `tf.int64` vector, representing the batch sizes to rebatch the dataset
    into.
  """
  # Constraint (A)
  num_subbatches = num_workers * num_replicas_per_worker

  offset = worker_index * num_replicas_per_worker

  const_value = tensor_util.constant_value(global_batch_size)
  if const_value is not None:
    # Use the constant global batch size for further calculations.
    global_batch_size = const_value

  # Let N = W * R. Constraints (B) and (D) jointly mean that the sub-batches
  # should have batch size either floor(G/N) or ceil(G/N). Namely, of the N
  # sub-batches a batch is split into, G - N * floor(G/N) of them will have
  # size ceil(G/N), and the rest will have size floor(G/N).
  floor = global_batch_size // num_subbatches
  num_ceil = global_batch_size - (num_subbatches * floor)

  # For worker 0, we assign the first num_ceil sub-batches to have size
  # ceil(G/N), and the remainder to have size floor(G/N). The other workers
  # will each be offset by R * worker_index in order to meet constraint (C).
  if const_value is not None:
    # If the global batch size is a known constant value, we return a constant
    # tensor directly instead of manipulating it with TF ops. This allows for
    # better downstream shape inference.
    worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil)
    return ops.convert_to_tensor(
        worker_0[offset:] + worker_0[:offset],
        dtype=dtypes.int64,
        name="batch_sizes")

  worker_0 = array_ops.ones(num_subbatches, dtype=dtypes.int64)
  worker_0 = floor * worker_0 + array_ops.concat(
      [
          array_ops.ones(num_ceil, dtype=dtypes.int64),
          array_ops.zeros(num_subbatches - num_ceil, dtype=dtypes.int64)
      ],
      axis=0)

  return array_ops.concat([worker_0[offset:], worker_0[:offset]], axis=0)
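

# Editorial note: a plain-Python restatement (not part of the original module)
# of the constant-value path above, using the docstring's numbers G=7, W=2,
# R=2. Of the N = W * R = 4 sub-batches, G - N * floor(G/N) = 3 get size
# ceil(G/N) = 2 and the rest get floor(G/N) = 1; each worker's list is the
# worker-0 list rotated by R * worker_index. The helper name is hypothetical.
def _example_batch_sizes(global_batch_size=7, num_workers=2,
                         num_replicas_per_worker=2, worker_index=0):
  num_subbatches = num_workers * num_replicas_per_worker
  floor = global_batch_size // num_subbatches
  num_ceil = global_batch_size - num_subbatches * floor
  worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil)
  offset = worker_index * num_replicas_per_worker
  # worker_index=0 -> [2, 2, 2, 1]; worker_index=1 -> [2, 1, 2, 2].
  return worker_0[offset:] + worker_0[:offset]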


def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in that
  it does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations, or
  if there are operations downstream from the batching transformation that may
  modify its batch size. In these cases, it returns a -1.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
  """

  def get_static_batch_dim(type_spec):
    try:
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    except NotImplementedError:
      return None
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None
    if output_shape.rank is None:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(type_spec)
      for type_spec in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  if all(d is not None for d in batch_dims):

    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
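

# Editorial note: an illustrative sketch (not part of the original module) of
# `compute_batch_size`. When the batch dimension is statically known and
# consistent across components, the result is a constant tensor; otherwise the
# runtime ComputeBatchSize op walks the dataset graph and returns -1 if the
# batch size cannot be determined. The helper name is hypothetical.
def _example_compute_batch_size():
  batched = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
  static_size = compute_batch_size(batched)  # Constant tensor holding 4.
  ragged = dataset_ops.Dataset.range(30).batch(4)  # Last batch is partial.
  dynamic_size = compute_batch_size(ragged)  # Falls back to the runtime op.
  return static_size, dynamic_size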


_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__