1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed inputs."""
17import functools
18import sys
19import time
21import six
23from tensorflow.python.autograph.operators import py_builtins
24from tensorflow.python.data.experimental.ops import batching
25from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
26from tensorflow.python.data.experimental.ops import distribute
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.data.ops import iterator_ops
29from tensorflow.python.data.ops import multi_device_iterator_ops
30from tensorflow.python.data.ops import optional_ops
31from tensorflow.python.distribute import device_util
32from tensorflow.python.distribute import distribute_lib
33from tensorflow.python.distribute import distribute_utils
34from tensorflow.python.distribute import input_ops
35from tensorflow.python.distribute import reduce_util
36from tensorflow.python.distribute import values
37from tensorflow.python.distribute.distribute_lib import InputReplicationMode
38from tensorflow.python.eager import context
39from tensorflow.python.eager import monitoring
40from tensorflow.python.framework import composite_tensor
41from tensorflow.python.framework import device as tf_device
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_shape
47from tensorflow.python.framework import tensor_util
48from tensorflow.python.framework import type_spec
49from tensorflow.python.ops import array_ops
50from tensorflow.python.ops import cond as tf_cond
51from tensorflow.python.ops import math_ops
52from tensorflow.python.ops import while_loop
53from tensorflow.python.ops.ragged import ragged_tensor
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.types import distribute as distribute_types
56from tensorflow.python.util import nest
57from tensorflow.python.util.compat import collections_abc
60_distributed_dataset_initialization_time_milliseconds = monitoring.Sampler(
61 "/tensorflow/api/distribution_strategy/"
62 "distributed_dataset_initialization_time_milliseconds",
63 monitoring.ExponentialBuckets(scale=1, growth_factor=2, bucket_count=26),
64 "Track the time (in milliseconds) to initialize distributed datasets.",
65 "strategy", "workers")
67_distributed_dataset_from_function_initialization_time_milliseconds = (
68 monitoring.Sampler(
69 "/tensorflow/api/distribution_strategy/"
70 "distributed_dataset_from_function_initialization_time_milliseconds",
71 monitoring.ExponentialBuckets(
72 scale=1, growth_factor=2, bucket_count=26),
73 "Track the time (in milliseconds) to initialize distributed datasets "
74 "from function.",
75 "strategy", "workers"))
78def get_iterator_spec_from_dataset(strategy, dataset):
79 """Returns an iterator spec from dataset function.
81 This function constructs type spec for iterator obtained from
82 iter(dataset).
84 Args:
85 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
86 handle last partial batch.
87 dataset: A tf.data.Dataset instance. If using a function that returns a
88 tf.data.Dataset instance, pass dataset_fn.structured_outputs.
90 Returns:
    A type spec for the iterator of the dataset instance.
93 """
94 # pylint: disable=protected-access
95 output_element_spec = dataset.element_spec
96 if isinstance(dataset._type_spec,
97 (DistributedDatasetSpec,
98 DistributedDatasetsFromFunctionSpec)):
99 iterator_type_spec = DistributedIteratorSpec(
100 strategy.extended._input_workers_with_options(),
101 output_element_spec,
102 strategy.extended._container_strategy(),
103 options=None,
104 cardinality=dataset.cardinality,
105 enable_get_next_as_optional=True)
106 else:
107 if strategy.extended._num_gpus_per_worker:
108 logging.warning(
109 f"{strategy.extended._num_gpus_per_worker} GPUs "
110 "are allocated per worker. Please use DistributedDataset by "
111 "calling strategy.experimental_distribute_dataset or strategy."
112 "distribute_datasets_from_function to make best use of GPU "
113 "resources"
114 )
115 iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec)
116 return iterator_type_spec
117 # pylint: enable=protected-access
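# A minimal usage sketch for `get_iterator_spec_from_dataset` (hedged; the
# strategy/dataset names below are illustrative, not from this module):
#
#   strategy = tf.distribute.MirroredStrategy()
#   dataset = tf.data.Dataset.range(8).batch(2)
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#   spec = get_iterator_spec_from_dataset(strategy, dist_dataset)
#   # `spec` can then be used, e.g., as part of an input_signature for a
#   # tf.function that consumes the iterator.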
120class InputWorkers(object):
121 """A 1-to-many mapping from input worker devices to compute devices."""
123 # TODO(ishark): Remove option canonicalize_devices and make all the callers
124 # pass canonicalized or raw device strings as relevant from strategy.
125 def __init__(self,
126 worker_device_pairs,
127 canonicalize_devices=True):
128 """Initialize an `InputWorkers` object.
130 Args:
131 worker_device_pairs: A sequence of pairs: `(input device, a tuple of
132 compute devices fed by that input device)`.
133 canonicalize_devices: Whether to canonicalize devices for workers fully or
134 partially. If False, it will partially canonicalize devices by removing
135 job and task.
136 """
137 self._worker_device_pairs = worker_device_pairs
138 self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
139 self._canonicalize_devices = canonicalize_devices
140 if canonicalize_devices:
141 self._fed_devices = tuple(
142 tuple(device_util.canonicalize(d)
143 for d in f)
144 for _, f in self._worker_device_pairs)
145 else:
146 self._fed_devices = tuple(
147 tuple(device_util.canonicalize_without_job_and_task(d)
148 for d in f)
149 for _, f in self._worker_device_pairs)
151 @property
152 def num_workers(self):
153 return len(self._input_worker_devices)
155 @property
156 def worker_devices(self):
157 return self._input_worker_devices
159 def compute_devices_for_worker(self, worker_index):
160 return self._fed_devices[worker_index]
162 def __repr__(self):
163 devices = self.worker_devices
164 debug_repr = ",\n".join(" %d %s: %s" %
165 (i, devices[i], self._fed_devices[i])
166 for i in range(len(devices)))
167 return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
169 def serialize(self):
170 return (self._worker_device_pairs, self._canonicalize_devices)
172 def deserialize(self, serialized):
    return InputWorkers(*serialized)
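# Illustrative sketch of the 1-to-many mapping `InputWorkers` represents
# (device names below are assumptions, not taken from this module):
#
#   workers = InputWorkers(
#       worker_device_pairs=(
#           ("/device:CPU:0", ("/device:GPU:0", "/device:GPU:1")),
#       ))
#   workers.num_workers                    # -> 1
#   workers.worker_devices                 # -> ("/device:CPU:0",)
#   workers.compute_devices_for_worker(0)  # -> canonicalized GPU:0 and GPU:1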
176def _calculate_replicas_with_values(strategy, input_workers, optional_list):
177 """Calcualates the number of replicas that have values.
179 Args:
180 strategy: the `tf.distribute.Strategy`.
181 input_workers: the `InputWorkers`.
    optional_list: a list of lists of `tf.experimental.Optional`. The values
      from each compute device, grouped by the input device.
185 Returns:
186 A scalar Tensor.
187 """
188 worker_has_values = []
189 for worker, optionals in zip(input_workers.worker_devices, optional_list):
190 with ops.device(worker):
191 device_has_values = [
192 math_ops.cast(v.has_value(), dtypes.int64) for v in optionals
193 ]
194 worker_has_values.append(
195 math_ops.reduce_sum(device_has_values, keepdims=True))
196 client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True)
197 if strategy.extended._in_multi_worker_mode(): # pylint: disable=protected-access
198 global_has_values = strategy.reduce(
199 reduce_util.ReduceOp.SUM, client_has_values, axis=None)
200 return array_ops.reshape(global_has_values, [])
201 else:
202 return array_ops.reshape(client_has_values, [])
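# Shape of the inputs to `_calculate_replicas_with_values`, as a sketch: with
# one input worker feeding two compute devices, `optional_list` looks like
# [[opt_gpu0, opt_gpu1]] (one inner list per input worker), and the returned
# scalar is the number of those optionals that currently hold a value, summed
# across all workers when running in multi-worker mode.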
205def _is_statically_shaped(element_spec):
206 """Test if an iterator output is statically shaped.
208 For sparse and ragged tensors this only tests the batch dimension.
210 Args:
211 element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
212 dataset of the iterator.
214 Returns:
215 True if the shape is static, false otherwise.
216 """
218 for spec in nest.flatten(element_spec):
219 if isinstance(
220 spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
221 # For sparse or ragged tensor, we should only check the first
222 # dimension in order to get_next_as_optional. This is because
223 # when these tensors get batched by dataset only the batch dimension
224 # is set.
225 if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
226 return False
227 else:
228 for component in spec._flat_tensor_specs: # pylint: disable=protected-access
229 if not component.shape.is_fully_defined():
230 return False
231 return True
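# Hedged examples of what `_is_statically_shaped` treats as static
# (`tf.TensorSpec`/`tf.RaggedTensorSpec` spelled with the public names for
# readability; values are illustrative):
#
#   _is_statically_shaped(tf.TensorSpec([32, 10], tf.float32))        # True
#   _is_statically_shaped(tf.TensorSpec([None, 10], tf.float32))      # False
#   # For ragged/sparse specs only the leading (batch) dimension matters:
#   _is_statically_shaped(tf.RaggedTensorSpec([8, None], tf.int64))   # True
#   _is_statically_shaped(tf.RaggedTensorSpec([None, None], tf.int64))# False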
234class DistributedIteratorBase(collections_abc.Iterator,
235 distribute_types.DistributedIteratorInterface):
236 """Common implementation for all input iterators."""
238 # pylint: disable=super-init-not-called
239 def __init__(
240 self,
241 input_workers,
242 iterators,
243 strategy,
244 cardinality,
245 enable_get_next_as_optional,
246 replica_order=None,
247 ):
248 assert isinstance(input_workers, InputWorkers)
249 if not input_workers.worker_devices:
250 raise ValueError("Should have at least one worker for input iterator.")
252 self._iterators = iterators
253 self._input_workers = input_workers
254 self._strategy = strategy
255 self._cardinality = cardinality
256 self._enable_get_next_as_optional = enable_get_next_as_optional
257 self._replica_order = replica_order
259 def next(self):
260 return self.__next__()
262 def __next__(self):
263 try:
264 return self.get_next()
265 except errors.OutOfRangeError:
266 raise StopIteration
268 def __iter__(self):
269 return self
271 def get_next_as_optional(self):
272 # Ideally get_next_as_optional() should be consistent with get_next(), but
273 # we used to always do partial batch handling in get_next_as_optional(). We
    # are keeping this behavior for now until we understand the impact.
276 # Skip partial batch handling when the dataset is infinite or empty, as
277 # there won't be any partial batches in those cases. This gives the user
278 # more static shapes as it avoids the tf.cond. Note that for empty datasets,
279 # we can only skip in single client mode, as the dataset can be non-empty on
280 # other workers.
281 if self._cardinality == cardinality_lib.INFINITE:
282 return optional_ops.Optional.from_value(
283 self._get_next_no_partial_batch_handling())
284 if (self._cardinality == 0 and
285 not self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
286 return optional_ops.Optional.empty(self._element_spec)
288 optional_list = []
289 for i, worker in enumerate(self._input_workers.worker_devices):
290 with ops.device(worker):
291 optional_list.append(self._iterators[i].get_next_as_optional_list())
293 def _create_optional_with_dummy():
294 value_list = _get_value_or_dummy(
295 self._input_workers, optional_list, produce_dummy=True)
297 if self._replica_order is not None:
298 value_list = self._reorder_replicas(value_list)
300 per_replica = _create_per_replica(value_list, self._strategy)
301 return optional_ops.Optional.from_value(per_replica)
303 def _create_empty_optional():
304 return optional_ops.Optional.empty(self._element_spec)
306 num_replicas_with_values = _calculate_replicas_with_values(
307 self._strategy, self._input_workers, optional_list)
309 return tf_cond.cond(
310 num_replicas_with_values > 0,
311 _create_optional_with_dummy,
312 _create_empty_optional,
313 strict=True)
315 def get_next(self, name=None):
316 """Returns the next input from the iterator for all replicas."""
317 with distribute_lib.enter_or_assert_strategy(
318 self._strategy):
319 if distribute_lib.get_replica_context() is not None:
320 raise ValueError("next(iterator) should be called from outside of "
321 "replica_fn. e.g. strategy.run(replica_fn, "
322 "args=(next(iterator),))")
324 if not self._enable_get_next_as_optional:
325 return self._get_next_no_partial_batch_handling(name)
327 optional_list = []
328 for i, worker in enumerate(self._input_workers.worker_devices):
329 with ops.device(worker):
330 optional_list.append(self._iterators[i].get_next_as_optional_list())
331 num_replicas_with_values = _calculate_replicas_with_values(
332 self._strategy, self._input_workers, optional_list)
334 def _value_or_dummy():
335 value_list = _get_value_or_dummy(
336 self._input_workers, optional_list, produce_dummy=True)
338 if self._replica_order is not None:
339 value_list = self._reorder_replicas(value_list)
341 return _create_per_replica(value_list, self._strategy)
343 def _eof():
344 # Optional.get_value raises InvalidArgumentError when there's no value,
345 # so we need to call GetNext to raise EOFError.
346 return self._get_next_no_partial_batch_handling()
348 return tf_cond.cond(
349 num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)
351 def _get_next_no_partial_batch_handling(self, name=None):
352 replicas = []
353 for i, worker in enumerate(self._input_workers.worker_devices):
354 if name is not None:
355 d = tf_device.DeviceSpec.from_string(worker)
356 new_name = "%s_%s_%d" % (name, d.job, d.task)
357 else:
358 new_name = None
359 with ops.device(worker):
360 # Make `replicas` a flat list of values across all replicas.
361 replicas.extend(self._iterators[i].get_next_as_list(new_name))
363 if self._replica_order is not None:
364 replicas = self._reorder_replicas(replicas)
366 return _create_per_replica(replicas, self._strategy)
368 def _reorder_replicas(self, replicas):
369 assert len(self._replica_order) == len(
370 replicas
371 ), "replica order size ({}) != replicas size ({})!".format(
372 len(self._replica_order), len(replicas)
373 )
374 return [replicas[i] for i in self._replica_order]
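# Typical consumption pattern for the iterators defined above, as a sketch
# (assumes `strategy` and `dist_dataset` already exist; `replica_fn` and
# `steps` are user-defined and purely illustrative):
#
#   iterator = iter(dist_dataset)
#   for _ in range(steps):
#     strategy.run(replica_fn, args=(next(iterator),))
#
# `next(iterator)` must be called outside `replica_fn`, as enforced by
# `get_next()` above.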
377class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec):
378 """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction."""
380 __slots__ = [
381 "_input_workers", "_element_spec", "_strategy", "_cardinality",
382 "_enable_get_next_as_optional", "_options", "_canonicalize_devices"
383 ]
385 def __init__(
386 self,
387 input_workers,
388 element_spec,
389 strategy,
390 options,
391 cardinality=cardinality_lib.UNKNOWN,
392 enable_get_next_as_optional=None,
393 replica_order=None,
394 ):
395 # We don't want to allow deserialization of this class because we don't
396 # serialize the strategy object. Currently the only places where
397 # _deserialize is called is when we save/restore using SavedModels.
398 if isinstance(input_workers, tuple):
399 raise NotImplementedError("DistributedIteratorSpec does not have support "
400 "for deserialization.")
401 else:
402 self._input_workers = input_workers
403 self._element_spec = element_spec
404 self._strategy = strategy
405 self._cardinality = cardinality
406 self._enable_get_next_as_optional = enable_get_next_as_optional
407 self._options = options
408 if self._strategy:
409 self._canonicalize_devices = getattr(self._strategy,
410 "_canonicalize_devices", True)
411 else:
412 self._canonicalize_devices = True
413 self._replica_order = replica_order
415 def _serialize(self):
416 # We cannot serialize the strategy object so we convert it to an id that we
417 # can use for comparison.
418 return (self._input_workers.serialize(), self._element_spec,
419 id(self._strategy), id(self._options))
421 def _deserialize(self):
422 raise ValueError(
423 f"Deserialization is currently unsupported for {type(self)}.")
425 def sanity_check_type(self, other):
426 """Returns the most specific TypeSpec compatible with `self` and `other`.
428 Args:
429 other: A `TypeSpec`.
431 Raises:
432 ValueError: If there is no TypeSpec that is compatible with both `self`
433 and `other`.
434 """
435 # pylint: disable=protected-access
436 if type(self) is not type(other):
437 raise ValueError("No TypeSpec is compatible with both %s and %s" %
438 (self, other))
439 if self._input_workers.serialize() != other._input_workers.serialize():
440 raise ValueError("_input_workers is not compatible with both %s "
441 "and %s" % (self, other))
442 if self._strategy is not other._strategy:
443 raise ValueError("tf.distribute strategy is not compatible with both %s "
444 "and %s" % (self, other))
446 def is_subtype_of(self, other):
447 """Returns True if `self` is subtype of `other`.
449 Args:
450 other: A `TypeSpec`.
451 """
452 try:
453 self.sanity_check_type(other)
454 nest.assert_same_structure(self._element_spec, other._element_spec) # pylint: disable=protected-access
455 except (TypeError, ValueError):
456 return False
458 self_elements = nest.flatten(self._element_spec)
459 other_elements = nest.flatten(other._element_spec) # pylint: disable=protected-access
461 return all(
462 self_element.is_subtype_of(other_element)
463 for (self_element, other_element) in zip(self_elements, other_elements))
465 def most_specific_common_supertype(self, others):
466 """Returns the most specific supertype of `self` and `others`.
468 Args:
469 others: A Sequence of `TypeSpec`.
471 Returns `None` if a supertype does not exist.
472 """
473 try:
474 for other in others:
475 self.sanity_check_type(other)
476 nest.assert_same_structure(self._element_spec, other._element_spec) # pylint: disable=protected-access
477 except (TypeError, ValueError):
478 return None
480 self_elements = nest.flatten(self._element_spec)
481 others_elements = [nest.flatten(other._element_spec) for other in others] # pylint: disable=protected-access
482 common_elements = [None] * len(self_elements)
484 for i, self_element in enumerate(self_elements):
485 common_elements[i] = self_element.most_specific_common_supertype(
486 [other_elements[i] for other_elements in others_elements])
487 if common_elements[i] is None:
488 return None
489 common_element_spec = nest.pack_sequence_as(self._element_spec,
490 common_elements)
491 return type(self)(
492 self._input_workers,
493 common_element_spec,
494 self._strategy,
495 self._options,
496 cardinality=self._cardinality,
497 enable_get_next_as_optional=self._enable_get_next_as_optional)
499 def _with_tensor_ranks_only(self):
500 element_spec = nest.map_structure(
501 lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access
502 self._element_spec)
503 return type(self)(
504 self._input_workers,
505 element_spec,
506 self._strategy,
507 self._options,
508 cardinality=self._cardinality,
509 enable_get_next_as_optional=self._enable_get_next_as_optional)
511 # TODO(b/206014848): Remove once names are not used.
512 def _without_tensor_names(self):
513 element_spec = nest.map_structure(
514 lambda s: s._without_tensor_names(), # pylint: disable=protected-access
515 self._element_spec)
516 return type(self)(
517 self._input_workers,
518 element_spec,
519 self._strategy,
520 self._options,
521 cardinality=self._cardinality,
522 enable_get_next_as_optional=self._enable_get_next_as_optional)
525class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec):
526 """Type specification for `DistributedIterator`."""
528 @property
529 def value_type(self):
530 return DistributedIterator
532 @property
533 def _component_specs(self):
534 specs = []
535 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
537 for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
538 element_spec = nest.map_structure(
539 functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
540 specs.append(
541 _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
542 element_spec, self._options,
543 self._canonicalize_devices))
544 return specs
546 def _to_components(self, value):
547 return value._iterators # pylint: disable=protected-access
549 def _from_components(self, components):
550 return DistributedIterator(
551 input_workers=self._input_workers,
552 iterators=None,
553 components=components,
554 element_spec=self._element_spec,
555 strategy=self._strategy,
556 cardinality=self._cardinality,
557 enable_get_next_as_optional=self._enable_get_next_as_optional,
558 options=self._options,
559 replica_order=self._replica_order,
560 )
562 @staticmethod
563 def from_value(value):
564 # pylint: disable=protected-access
565 return DistributedIteratorSpec(
566 value._input_workers,
567 value._element_spec,
568 value._strategy,
569 value._options,
570 cardinality=value._cardinality,
571 enable_get_next_as_optional=value._enable_get_next_as_optional)
574class DistributedIterator(DistributedIteratorBase,
575 composite_tensor.CompositeTensor):
576 """Input Iterator for a distributed dataset."""
578 def __init__(
579 self,
580 input_workers=None,
581 iterators=None,
582 strategy=None,
583 components=None,
584 element_spec=None,
585 cardinality=cardinality_lib.UNKNOWN,
586 enable_get_next_as_optional=False,
587 options=None,
588 replica_order=None,
589 ):
590 if input_workers is None:
591 raise ValueError("`input_workers` should be "
592 "provided.")
594 error_message = ("Either `input_workers` or "
595 "both `components` and `element_spec` need to be "
596 "provided.")
597 self._options = options
599 if iterators is None:
600 if (components is None or element_spec is None):
601 raise ValueError(error_message)
602 self._element_spec = element_spec
603 self._input_workers = input_workers
604 self._iterators = components
605 self._strategy = strategy
606 self._cardinality = cardinality
607 self._enable_get_next_as_optional = enable_get_next_as_optional
608 self._replica_order = replica_order
609 else:
610 if (components is not None and element_spec is not None):
611 raise ValueError(error_message)
613 super(DistributedIterator, self).__init__(
614 input_workers,
615 iterators,
616 strategy,
617 cardinality,
618 enable_get_next_as_optional,
619 replica_order,
620 )
622 @property
623 def element_spec(self):
624 # When partial batch handling is enabled, always set the batch dimension to
625 # None, otherwise we just follow element_spec of the underlying dataset
626 # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
628 if (self._enable_get_next_as_optional and
629 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
630 return nest.map_structure(
631 _rebatch_as_dynamic, self._element_spec, expand_composites=False)
632 return self._element_spec
634 @property
635 def _type_spec(self):
636 # Note that we use actual element_spec instead of the rebatched-as-dynamic
637 # one to create DistributedIteratorSpec, to be consistent with the
638 # underlying iterators' specs.
639 return DistributedIteratorSpec(
640 self._input_workers,
641 self._element_spec,
642 self._strategy,
643 self._options,
644 self._cardinality,
645 self._enable_get_next_as_optional,
646 self._replica_order,
647 )
650class _IterableInput(collections_abc.Iterable,
651 distribute_types.DistributedDatasetInterface):
652 """Base class for iterable inputs for distribution strategies."""
654 # pylint: disable=super-init-not-called
655 def __init__(self, input_workers):
656 assert isinstance(input_workers, InputWorkers)
657 self._input_workers = input_workers
659 def __iter__(self):
660 raise NotImplementedError("must be implemented in descendants")
662 def reduce(self, initial_state, reduce_fn):
663 """Execute a `reduce_fn` over all the elements of the input."""
664 iterator = iter(self)
665 optional_data = iterator.get_next_as_optional()
667 def cond(optional_data, state):
668 del state # Unused.
669 return optional_data.has_value()
671 def loop_body(optional_data, state):
672 """Executes `reduce_fn` in a loop till the dataset is empty."""
673 state = reduce_fn(state, optional_data.get_value())
674 optional_data = iterator.get_next_as_optional()
675 return optional_data, state
677 optional_data, final_state = while_loop.while_loop(
678 cond,
679 loop_body, [optional_data, initial_state],
680 parallel_iterations=1,
681 return_same_structure=True)
682 return final_state
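# Sketch of `reduce` usage on a distributed dataset (hedged; `dist_dataset`
# is assumed to be a `DistributedDataset` obtained from a strategy):
#
#   num_batches = dist_dataset.reduce(
#       initial_state=tf.constant(0, dtype=tf.int64),
#       reduce_fn=lambda count, _: count + 1)  # counts per-replica batches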
685class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec):
686 """Type specification for `DistributedDataset."""
688 @property
689 def value_type(self):
690 return DistributedDataset
692 @property
693 def _component_specs(self):
694 specs = []
695 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
697 for i, _ in enumerate(worker_device_pairs):
698 element_spec = nest.map_structure(
699 functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
700 specs.append(dataset_ops.DatasetSpec(element_spec))
701 return specs
703 def _to_components(self, value):
704 return value._cloned_datasets # pylint: disable=protected-access
706 def _from_components(self, components):
707 return DistributedDataset(
708 input_workers=self._input_workers,
709 strategy=self._strategy,
710 components=components,
711 element_spec=self._element_spec,
712 enable_get_next_as_optional=self._enable_get_next_as_optional,
713 options=self._options,
714 replica_order=self._replica_order,
715 )
717 @staticmethod
718 def from_value(value):
719 # pylint: disable=protected-access
720 return DistributedDatasetSpec(
721 value._input_workers,
722 value._element_spec,
723 value._strategy,
724 value._options,
725 enable_get_next_as_optional=value._enable_get_next_as_optional)
726 # pylint: enable=protected-access
729class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
730 """Distributed dataset that supports prefetching to multiple devices."""
732 def __init__(
733 self,
734 input_workers,
735 strategy,
736 dataset=None,
737 num_replicas_in_sync=None,
738 input_context=None,
739 components=None,
740 element_spec=None,
741 enable_get_next_as_optional=None,
742 build=True,
743 options=None,
744 replica_order=None,
745 ):
746 """Distribute the dataset on all workers.
748 If `num_replicas_in_sync` is not None, we split each batch of the dataset
749 into `num_replicas_in_sync` smaller batches, to be distributed among that
750 worker's replicas, so that the batch size for a global step (across all
751 workers and replicas) is as expected.
753 Args:
754 input_workers: an `InputWorkers` object.
755 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
756 handle last partial batch.
757 dataset: `tf.data.Dataset` that will be used as the input source. Either
758 dataset or components field should be passed when constructing
        DistributedDataset. Use this when constructing DistributedDataset from a
760 new `tf.data.Dataset`. Use components when constructing using
761 DistributedDatasetSpec.
762 num_replicas_in_sync: Optional integer. If this is not None, the value is
763 used to decide how to rebatch datasets into smaller batches so that the
764 total batch size for each step (across all workers and replicas) adds up
765 to `dataset`'s batch size.
766 input_context: `InputContext` for sharding. Only pass this in for between
767 graph multi-worker cases where there is only one `input_worker`. In
768 these cases, we will shard based on the `input_pipeline_id` and
769 `num_input_pipelines` in the `InputContext`.
770 components: datasets when DistributedDataset is constructed from
771 DistributedDatasetSpec. Either field dataset or components should be
772 passed.
773 element_spec: element spec for DistributedDataset when constructing from
774 DistributedDatasetSpec. This will be used to set the element_spec for
775 DistributedDataset and verified against element_spec from components.
776 enable_get_next_as_optional: this is required when components is passed
777 instead of dataset.
778 build: whether to build underlying datasets when this object is created.
779 This is only useful for `ParameterServerStrategy` now.
780 options: `tf.distribute.InputOptions` used to control options on how this
781 dataset is distributed.
782 replica_order: the order of the replicas, which will be used to reorder
783 the iterators to match the device order.
784 """
785 super(DistributedDataset, self).__init__(input_workers=input_workers)
786 if input_workers is None or strategy is None:
787 raise ValueError("input_workers and strategy are required arguments")
788 if dataset is not None and components is not None:
789 raise ValueError("Only one of dataset or components should be present")
790 if dataset is None and components is None:
791 raise ValueError("At least one of dataset or components should be passed")
793 self._input_workers = input_workers
794 self._strategy = strategy
795 self._options = options
796 self._input_context = input_context
797 self._num_replicas_in_sync = num_replicas_in_sync
798 self._replica_order = replica_order
800 if dataset is not None:
801 self._original_dataset = dataset
802 self._built = False
803 if build:
804 self.build()
805 else:
806 if not build:
807 raise ValueError(
808 "When constructing DistributedDataset with components, build "
809 "should not be False. This is an internal error. Please file a "
810 "bug.")
811 if enable_get_next_as_optional is None:
812 raise ValueError(
813 "When constructing DistributedDataset with components, " +
814 "enable_get_next_as_optional should also be passed")
815 self._cloned_datasets = components
816 self._cardinality = _cardinality(self._cloned_datasets[0])
817 self._enable_get_next_as_optional = enable_get_next_as_optional
819 assert element_spec is not None
820 if element_spec != _create_distributed_tensor_spec(
821 self._strategy, self._cloned_datasets[0].element_spec):
822 raise ValueError("Mismatched element_spec from the passed components")
823 self._element_spec = element_spec
825 self._built = True
827 def build(self, dataset_to_replace=None):
828 assert not self._built
829 dataset = dataset_to_replace or self._original_dataset
830 self._cardinality = _cardinality(dataset)
831 self._enable_get_next_as_optional = _enable_get_next_as_optional(
832 self._strategy, dataset, self._cardinality)
833 distribute_start_time_ns = time.time_ns()
834 self._create_cloned_datasets_from_dataset(dataset, self._input_context,
835 self._input_workers,
836 self._strategy,
837 self._num_replicas_in_sync)
838 if context.executing_eagerly():
839 # Records the time to initialize the distributed dataset.
840 context.async_wait()
841 distribute_duration_ms = (time.time_ns() -
842 distribute_start_time_ns) // 1_000_000
843 _distributed_dataset_initialization_time_milliseconds.get_cell(
844 self._strategy.__class__.__name__,
845 str(self._input_workers.num_workers)).add(distribute_duration_ms)
846 self._element_spec = _create_distributed_tensor_spec(
847 self._strategy, self._cloned_datasets[0].element_spec)
848 self._built = True
850 def auto_shard(self, num_shards, shard_ix):
851 assert (
852 len(self._cloned_datasets) == len(self._input_workers.worker_devices)
853 ), (
854 f"datasets: {len(self._cloned_datasets)}, "
855 f"input workers: {len(self._input_workers.worker_devices)}"
856 )
857 sharded_datasets = []
858 for i in range(len(self._input_workers.worker_devices)):
859 with ops.colocate_with(self._cloned_datasets[i]._variant_tensor): # pylint:disable=protected-access
860 sharded_datasets.append(
861 input_ops.auto_shard_dataset(
862 self._cloned_datasets[i], num_shards, shard_ix,
863 self._num_replicas_in_sync
864 ))
865 return DistributedDataset(
866 self._input_workers,
867 self._strategy,
868 components=sharded_datasets,
869 element_spec=self._element_spec,
870 options=self._options,
871 enable_get_next_as_optional=self._enable_get_next_as_optional)
873 @property
874 def cardinality(self):
875 if not self._built:
876 raise ValueError(
877 "Cannot get the cardinality of a dataset that is not built")
878 return self._cardinality
880 def _create_cloned_datasets_from_dataset(self, dataset, input_context,
881 input_workers, strategy,
882 num_replicas_in_sync):
883 # We clone and shard the dataset on each worker. The current setup tries to
884 # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to shard
886 # the final input such that each worker will run the entire preprocessing
887 # pipeline and only receive its own shard of the dataset.
889 # Additionally, we rebatch the dataset on each worker into
890 # `num_replicas_in_sync` smaller batches to be distributed among that
891 # worker's replicas, so that the batch size for a global step (across all
892 # workers and replicas) adds up to the original dataset's batch size.
893 if num_replicas_in_sync is not None and num_replicas_in_sync > 1:
894 num_workers = input_context.num_input_pipelines if input_context else len(
895 input_workers.worker_devices)
896 rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
897 num_replicas_in_sync)
898 else:
899 rebatch_fn = None
900 self._cloned_datasets = []
901 if input_context:
902 # Between-graph where we rely on the input_context for sharding
903 assert input_workers.num_workers == 1
904 if rebatch_fn is not None:
905 dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
906 dataset = input_ops.auto_shard_dataset(dataset,
907 input_context.num_input_pipelines,
908 input_context.input_pipeline_id,
909 num_replicas_in_sync)
910 self._cloned_datasets.append(dataset)
911 else:
912 replicated_ds = distribute.replicate(dataset,
913 input_workers.worker_devices)
914 for i, worker in enumerate(input_workers.worker_devices):
915 with ops.device(worker):
916 cloned_dataset = replicated_ds[worker]
917 if rebatch_fn is not None:
918 cloned_dataset = rebatch_fn(cloned_dataset, i)
919 cloned_dataset = input_ops.auto_shard_dataset(
920 cloned_dataset, len(input_workers.worker_devices), i,
921 num_replicas_in_sync)
922 self._cloned_datasets.append(cloned_dataset)
924 def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
925 """Returns a callable that rebatches the input dataset.
927 Args:
928 dataset: A `tf.data.Dataset` representing the dataset to be distributed.
929 num_workers: An integer representing the number of workers to distribute
930 `dataset` among.
931 num_replicas_in_sync: An integer representing the number of replicas in
932 sync across all workers.
933 """
934 if num_replicas_in_sync % num_workers:
935 raise ValueError(
936 "tf.distribute expects every worker to have the same number of "
937 "replicas. However, encountered `num_replicas_in_sync` ({}) that "
938 "cannot be divided by `num_workers` ({})".format(
939 num_replicas_in_sync, num_workers))
941 num_replicas_per_worker = num_replicas_in_sync // num_workers
942 with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access
943 batch_size = distribute.compute_batch_size(dataset)
945 def rebatch_fn(dataset, worker_index):
946 try:
948 def apply_rebatch():
949 batch_sizes = distribute.batch_sizes_for_worker(
950 batch_size, num_workers, num_replicas_per_worker, worker_index)
951 return dataset.rebatch(batch_sizes).prefetch(num_replicas_per_worker)
953 # pylint: disable=protected-access
954 def apply_legacy_rebatch():
955 return distribute._LegacyRebatchDataset(
956 dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
958 with ops.colocate_with(dataset._variant_tensor):
959 return tf_cond.cond(
960 math_ops.not_equal(batch_size, -1),
961 true_fn=apply_rebatch,
962 false_fn=apply_legacy_rebatch)
963 except errors.InvalidArgumentError as e:
964 if "without encountering a batch" in str(e):
965 six.reraise(
966 ValueError,
967 ValueError(
968 "Call the `batch` method on the input Dataset in order to be "
969 "able to split your input across {} replicas.\n Please see "
970 "the tf.distribute.Strategy guide. {}".format(
971 num_replicas_in_sync, e)),
972 sys.exc_info()[2])
973 else:
974 raise
976 return rebatch_fn
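  # Worked example of the rebatching arithmetic above, as a sketch: with a
  # global batch size of 64, num_workers=2 and num_replicas_in_sync=8, each
  # worker gets num_replicas_per_worker = 8 // 2 = 4, and
  # `distribute.batch_sizes_for_worker` splits that worker's share of the
  # global batch into 4 per-replica batches of 64 / 8 = 8 elements each.
  # When the batch size cannot be determined statically (batch_size == -1),
  # the `_LegacyRebatchDataset` fallback is used instead.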
978 def __iter__(self):
979 if not (context.executing_eagerly() or
980 ops.get_default_graph().building_function):
981 raise RuntimeError("__iter__() is only supported inside of tf.function "
982 "or when eager execution is enabled.")
983 if not self._built:
984 raise ValueError("To use this dataset, you need to pass this dataset to "
985 "ClusterCoordinator.create_per_worker_dataset.")
987 canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
988 True)
990 worker_iterators = _create_iterators_per_worker(
991 self._cloned_datasets,
992 self._input_workers,
993 options=self._options,
994 canonicalize_devices=canonicalize_devices)
995 iterator = DistributedIterator(
996 self._input_workers,
997 worker_iterators,
998 self._strategy,
999 cardinality=self._cardinality,
1000 enable_get_next_as_optional=self._enable_get_next_as_optional,
1001 options=self._options,
1002 replica_order=self._replica_order,
1003 )
1004 iterator._element_spec = self._element_spec # pylint: disable=protected-access
1006 # When async eager is enabled, sometimes the iterator may not finish
1007 # initialization before passing to a multi device function, add a sync point
1008 # here to make sure all underlying iterators are initialized.
1009 if context.executing_eagerly():
1010 context.async_wait()
1012 return iterator
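  # End-to-end sketch of how a `DistributedDataset` is typically obtained and
  # iterated (assumes eager execution; `step_fn` is illustrative):
  #
  #   strategy = tf.distribute.MirroredStrategy()
  #   dataset = tf.data.Dataset.range(16).batch(4)
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset)
  #   for per_replica_batch in dist_dataset:
  #     strategy.run(step_fn, args=(per_replica_batch,))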
1014 @property
1015 def element_spec(self):
1016 """The type specification of an element of this dataset."""
1017 # When partial batch handling is enabled, always set the batch dimension to
1018 # None, otherwise we just follow element_spec of the underlying dataset
1019 # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
1021 if (self._enable_get_next_as_optional and
1022 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
1023 return nest.map_structure(
1024 _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1025 return self._element_spec
1027 @property
1028 def _type_spec(self):
1029 return DistributedDatasetSpec(
1030 self._input_workers,
1031 self._element_spec,
1032 self._strategy,
1033 self._options,
1034 enable_get_next_as_optional=self._enable_get_next_as_optional)
1037class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec):
1038 """Type specification for `DistributedDatasetsFromFunction."""
1040 @property
1041 def value_type(self):
1042 return DistributedDatasetsFromFunction
1044 @property
1045 def _component_specs(self):
1046 specs = []
1047 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
1049 for i, _ in enumerate(worker_device_pairs):
1050 element_spec = nest.map_structure(
1051 functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
1052 specs.append(dataset_ops.DatasetSpec(element_spec))
1053 return specs
1055 def _to_components(self, value):
1056 return value._datasets # pylint: disable=protected-access
1058 def _from_components(self, components):
1059 return DistributedDatasetsFromFunction(
1060 input_workers=self._input_workers,
1061 strategy=self._strategy,
1062 components=components,
1063 element_spec=self._element_spec,
1064 options=self._options)
1066 @staticmethod
1067 def from_value(value):
1068 # pylint: disable=protected-access
1069 return DistributedDatasetsFromFunctionSpec(
1070 input_workers=value._input_workers,
1071 element_spec=value._element_spec,
1072 strategy=value._strategy,
1073 options=value._options)
1076# TODO(priyag): Add other replication modes.
1077class DistributedDatasetsFromFunction(_IterableInput,
1078 composite_tensor.CompositeTensor):
1079 """Inputs created from dataset function."""
1081 def __init__(
1082 self,
1083 input_workers,
1084 strategy,
1085 input_contexts=None,
1086 dataset_fn=None,
1087 options=None,
1088 components=None,
1089 element_spec=None,
1090 build=True,
1091 replica_order=None,
1092 ):
1093 """Makes an iterable from datasets created by the given function.
1095 Args:
1096 input_workers: an `InputWorkers` object.
1097 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
1098 handle last partial batch.
1099 input_contexts: A list of `InputContext` instances to be passed to call(s)
1100 to `dataset_fn`. Length and order should match worker order in
1101 `worker_device_pairs`.
1102 dataset_fn: A function that returns a `Dataset` given an `InputContext`.
1103 Either dataset_fn or components should be passed to construct
1104 DistributedDatasetsFromFunction. Use this when constructing
1105 DistributedDataset using a function. Use components when constructing
1106 using DistributedDatasetsFromFunctionSpec.
1107 options: `tf.distribute.InputOptions` used to control options on how this
1108 dataset is distributed.
1109 components: datasets when DistributedDatasetsFromFunction is constructed
1110 from DistributedDatasetsFromFunctionSpec. Only one of dataset or
1111 components should be passed.
1112 element_spec: element spec for DistributedDataset when constructing from
1113 DistributedDatasetSpec. This will be used to set the element_spec for
1114 DistributedDatasetsFromFunctionSpec and verified against element_spec
1115 from components.
1116 build: whether to build underlying datasets when this object is created.
1117 This is only useful for `ParameterServerStrategy` now.
1118 replica_order: the order of the replicas, which will be used to reorder
1119 the iterators to match the device order.
1120 """
1121 super(DistributedDatasetsFromFunction, self).__init__(
1122 input_workers=input_workers)
1123 self._input_workers = input_workers
1124 self._strategy = strategy
1125 self._options = options
1126 self._replica_order = replica_order
1127 if dataset_fn is not None and components is not None:
1128 raise ValueError("Only one of dataset_fn or components should be set")
1129 if dataset_fn is None and components is None:
1130 raise ValueError("At least one of dataset_fn or components should be set")
1132 if dataset_fn is not None:
1133 if input_workers.num_workers != len(input_contexts):
1134 raise ValueError(
1135 "Number of input workers (%d) is not same as number of "
1136 "input_contexts (%d)" %
1137 (input_workers.num_workers, len(input_contexts)))
1138 self._input_contexts = input_contexts
1139 self._num_replicas_in_sync = self._input_contexts[0].num_replicas_in_sync
1140 self._dataset_fn = dataset_fn
1141 self._built = False
1142 if build:
1143 self.build()
1144 else:
1145 if element_spec is None:
1146 raise ValueError(
1147 "element_spec should also be passed when passing components")
1148 if not build:
1149 raise ValueError(
1150 "When constructing DistributedDatasetFromFunction with components, "
1151 "build should not be False. This is an internal error. Please file "
1152 "a bug.")
1153 self._element_spec = element_spec
1154 self._datasets = components
1155 self._num_replicas_in_sync = None
1156 self._built = True
1157 self._cardinality = _cardinality(self._datasets[0])
1158 self._enable_get_next_as_optional = _enable_get_next_as_optional(
1159 self._strategy, self._datasets[0], self._cardinality)
1161 def build(self):
1162 assert not self._built
1163 distribute_start_time_ns = time.time_ns()
1164 self._datasets, element_spec = (
1165 _create_datasets_from_function_with_input_context(
1166 self._input_contexts, self._input_workers, self._dataset_fn))
1167 if context.executing_eagerly():
1168 # Records the time to initialize the distributed dataset.
1169 context.async_wait()
1170 distribute_duration_ms = (time.time_ns() -
1171 distribute_start_time_ns) // 1_000_000
1172 _distributed_dataset_from_function_initialization_time_milliseconds.get_cell(
1173 self._strategy.__class__.__name__,
1174 str(self._input_workers.num_workers)).add(distribute_duration_ms)
1176 self._element_spec = _create_distributed_tensor_spec(
1177 self._strategy, element_spec)
1178 self._cardinality = _cardinality(self._datasets[0])
1179 self._enable_get_next_as_optional = _enable_get_next_as_optional(
1180 self._strategy, self._datasets[0], self._cardinality)
1181 self._built = True
1183 def auto_shard(self, num_shards, shard_ix):
1184 assert (
1185 len(self._datasets) == len(self._input_workers.worker_devices)
1186 ), (
1187 f"datasets: {len(self._datasets)}, "
1188 f"input workers: {len(self._input_workers.worker_devices)}"
1189 )
1190 sharded_datasets = []
1191 for i in range(len(self._input_workers.worker_devices)):
1192 with ops.colocate_with(self._datasets[i]._variant_tensor): # pylint: disable=protected-access
1193 sharded_datasets.append(
1194 input_ops.auto_shard_dataset(
1195 self._datasets[i], num_shards, shard_ix,
1196 self._num_replicas_in_sync
1197 )
1198 )
1199 return DistributedDatasetsFromFunction(self._input_workers, self._strategy,
1200 components=sharded_datasets,
1201 element_spec=self._element_spec,
1202 options=self._options)
1204 @property
1205 def cardinality(self):
1206 if not self._built:
1207 raise ValueError(
1208 "Cannot get the cardinality of a dataset that is not built")
1209 return self._cardinality
1211 def __iter__(self):
1212 if not (ops.executing_eagerly_outside_functions() or
1213 ops.get_default_graph().building_function):
1214 raise RuntimeError("__iter__() is only supported inside of tf.function "
1215 "or when eager execution is enabled.")
1217 if not self._built:
1218 raise ValueError("You need to use this dataset in "
1219 "ClusterCoordinator.create_per_worker_dataset.")
1221 canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
1222 True)
1224 iterators = _create_iterators_per_worker(
1225 self._datasets,
1226 self._input_workers,
1227 options=self._options,
1228 canonicalize_devices=canonicalize_devices)
1229 iterator = DistributedIterator(
1230 input_workers=self._input_workers,
1231 iterators=iterators,
1232 strategy=self._strategy,
1233 cardinality=self._cardinality,
1234 enable_get_next_as_optional=self._enable_get_next_as_optional,
1235 options=self._options,
1236 replica_order=self._replica_order,
1237 )
1238 iterator._element_spec = self._element_spec # pylint: disable=protected-access
1240 # When async eager is enabled, sometimes the iterator may not finish
1241 # initialization before passing to a multi device function, add a sync
1242 # point here to make sure all underlying iterators are initialized.
1243 if context.executing_eagerly():
1244 context.async_wait()
1246 return iterator
1248 @property
1249 def element_spec(self):
1250 """The type specification of an element of this dataset."""
1251 # When partial batch handling is enabled, always set the batch dimension to
1252 # None, otherwise we just follow element_spec of the underlying dataset
1253 # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
1255 if (self._enable_get_next_as_optional and
1256 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
1257 return nest.map_structure(
1258 _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1259 return self._element_spec
1261 @property
1262 def _type_spec(self):
1263 return DistributedDatasetsFromFunctionSpec(self._input_workers,
1264 self._element_spec,
1265 self._strategy, self._options)
1268def _dummy_tensor_fn(value_structure):
1269 """A function to create dummy tensors from `value_structure`."""
1271 def create_dummy_tensor(spec):
1272 """Create a dummy tensor with possible batch dimensions set to 0."""
1273 if hasattr(spec, "_create_empty_value"):
1274 # Type spec may overwrite default dummy values behavior by declaring the
1275 # `_create_empty_value(self)` method. This method must return a value
1276 # compatible with the type spec with batch dimensions set to 0 or fail if
1277 # such a value does not exist. This allows a composite tensor to customize
1278 # dummy values creation as, in general, its dummy value is not composed
1279 # from dummy components (e.g. `row_splits` tensor of a RaggedTensor is
1280 # never allowed to be empty). See b/183969859 for more discussions.
1281 # TODO(b/186079336): reconsider CompositeTensor support.
1282 return spec._create_empty_value() # pylint: disable=protected-access
1284 if isinstance(spec, ragged_tensor.RaggedTensorSpec):
1285 # Splice out the ragged dimensions.
1286 # pylint: disable=protected-access
1287 feature_shape = spec._shape[:1].concatenate(
1288 spec._shape[(1 + spec._ragged_rank):])
1289 feature_type = spec._dtype
1290 # pylint: enable=protected-access
1291 else:
1292 feature_shape = spec.shape
1293 feature_type = spec.dtype
    # Ideally we should set the batch dimension to 0; however, since in
    # DistributionStrategy we don't know the batch dimension, we try to
    # guess it as best we can. If the feature has unknown dimensions, we
    # will set them to 0. If the feature shape is already static, we treat the
    # first dimension as the batch dimension and set it to 0.
1299 dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
1300 if feature_shape else [])
1301 if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
1302 feature_shape.is_fully_defined()):
1303 dims[0] = tensor_shape.Dimension(0)
1305 if isinstance(spec, sparse_tensor.SparseTensorSpec):
1306 return sparse_tensor.SparseTensor(
1307 values=array_ops.zeros(0, feature_type),
1308 indices=array_ops.zeros((0, len(dims)), dtypes.int64),
1309 dense_shape=dims)
1311 # Create the dummy tensor.
1312 dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
1313 if isinstance(spec, ragged_tensor.RaggedTensorSpec):
1314 # Reinsert the ragged dimensions with size 0.
1315 # pylint: disable=protected-access
1316 row_splits = array_ops.zeros(1, spec._row_splits_dtype)
1317 dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
1318 dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
1319 # pylint: enable=protected-access
1320 return dummy_tensor
1322 return nest.map_structure(create_dummy_tensor, value_structure)
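# Hedged illustration of `_dummy_tensor_fn`: given a structure of specs such
# as {"x": tf.TensorSpec([None, 10], tf.float32)}, it returns matching dummy
# values with the batch dimension set to 0, e.g. {"x": tf.zeros([0, 10])}.
# For a SparseTensorSpec it builds an empty SparseTensor, and for a
# RaggedTensorSpec it rebuilds the ragged dimensions with zero-length
# row_splits.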
1325def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
1326 """Returns the value of the optionals or dummy values.
1328 Args:
1329 input_workers: the `InputWorkers`.
    optional_list: a list of lists of `tf.experimental.Optional`. The values
      from each compute device, grouped by the input device.
1332 produce_dummy: a bool. Whether to produce dummy tensors when the optional
1333 doesn't have a value.
1335 Returns:
    A flattened list of Tensors.
1338 """
1339 value_list = []
1340 for i, worker in enumerate(input_workers.worker_devices):
1341 with ops.device(worker):
1342 devices = input_workers.compute_devices_for_worker(i)
1343 for j, device in enumerate(devices):
1344 with ops.device(device):
1345 if produce_dummy:
1346 # pylint: disable=cell-var-from-loop
1347 value_list.append(
1348 tf_cond.cond(
1349 optional_list[i][j].has_value(),
1350 lambda: optional_list[i][j].get_value(), # pylint: disable=unnecessary-lambda
1351 lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
1352 strict=True,
1353 ))
1354 # pylint: enable=cell-var-from-loop
1355 else:
1356 value_list.append(optional_list[i][j].get_value())
1357 return value_list
1360class _SingleWorkerDatasetIteratorBase(object):
1361 """Iterator for a single `tf.data.Dataset`."""
1363 def __init__(self, dataset, worker, devices, options=None):
1364 """Create iterator for the `dataset` to fetch data to worker's `devices` .
1366 A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
1367 input to the devices on the given worker.
1369 Args:
1370 dataset: A `tf.data.Dataset` instance.
1371 worker: Worker on which ops should be created.
1372 devices: Distribute data from `dataset` to these devices.
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.
1374 """
1375 self._dataset = dataset
1376 self._worker = worker
1377 self._devices = devices
1378 self._element_spec = dataset.element_spec
1379 self._options = options
1380 self._make_iterator()
1382 def _make_iterator(self):
1383 raise NotImplementedError("must be implemented in descendants")
1385 def _format_data_list_with_options(self, data_list):
1386 """Change the data in to a list type if required.
1388 The OwnedMultiDeviceIterator returns the list data type,
1389 while the PER_REPLICA iterator (when used with prefetch disabled)
1390 returns without the enclosed list. This is to fix the inconsistency.
1391 Args:
1392 data_list: data_list
1393 Returns:
1394 list
1395 """
1396 if (self._options and self._options.experimental_replication_mode ==
1397 InputReplicationMode.PER_REPLICA and
1398 not self._options.experimental_fetch_to_device):
1399 return [data_list]
1400 else:
1401 return data_list
1403 def get_next(self, device, name=None):
1404 """Get next element for the given device."""
1405 del name
1406 with ops.device(self._worker):
1407 if _should_use_multi_device_iterator(self._options):
1408 return self._iterator.get_next(device)
1409 else:
1410 return self._iterator.get_next()
1412 def get_next_as_list(self, name=None):
1413 """Get next element from the underlying iterator.
    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than
    get_next_as_optional_list(), but it raises EOFError if any of the devices
    doesn't get any data.
1419 Args:
1420 name: not used.
1422 Returns:
1423 A list consisting of the next data from each device.
1424 """
1425 del name
1426 with ops.device(self._worker):
1427 return self._format_data_list_with_options(self._iterator.get_next())
1429 def get_next_as_optional_list(self):
1430 with ops.device(self._worker):
1431 return self._format_data_list_with_options(
1432 self._iterator.get_next_as_optional())
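# Sketch of how the single-worker iterator API above is consumed by the
# distributed iterator (descriptive only): for input worker i,
# `get_next_as_list()` returns one value per compute device on that worker,
# and `get_next_as_optional_list()` returns one `tf.experimental.Optional`
# per compute device, which `DistributedIteratorBase` then combines across
# workers via `_get_value_or_dummy` and `_calculate_replicas_with_values`.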
1435class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
1436 """Type specification for `_SingleWorkerOwnedDatasetIterator`."""
1438 __slots__ = [
1439 "_worker", "_devices", "_element_spec", "_options",
1440 "_canonicalize_devices"
1441 ]
1443 def __init__(self, worker, devices, element_spec, options,
1444 canonicalize_devices=True):
1445 self._worker = worker
1446 if canonicalize_devices:
1447 self._devices = tuple(device_util.canonicalize(d) for d in devices)
1448 else:
1449 self._devices = tuple(
1450 device_util.canonicalize_without_job_and_task(d) for d in devices)
1451 self._element_spec = element_spec
1452 # `self._options` intentionally made not `None` for proper serialization.
1453 self._options = (options if options is not None else
1454 distribute_lib.InputOptions())
1455 self._canonicalize_devices = canonicalize_devices
1457 @property
1458 def value_type(self):
1459 return _SingleWorkerOwnedDatasetIterator
1461 def _serialize(self):
1462 return (self._worker, self._devices, self._element_spec, self._options,
1463 self._canonicalize_devices)
1465 def _get_multi_device_iterator_spec(self, specs):
1466 device_scope = device_util.canonicalize(self._worker, device_util.current())
1467 host_device = device_util.get_host_for_device(device_scope)
1468 # source_device while creating iterator governs the worker device in
1469 # iterator spec.
1470 worker = host_device
1471 specs.append(
1472 multi_device_iterator_ops.MultiDeviceIteratorSpec(
1473 self._devices, worker, element_spec=self._element_spec))
1475 @property
1476 def _component_specs(self):
1477 specs = []
1478 if _should_use_multi_device_iterator(self._options):
1479 self._get_multi_device_iterator_spec(specs)
1480 else:
1481 specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
1482 return specs
1484 def _to_components(self, value):
1485 return [value._iterator] # pylint: disable=protected-access
1487 def _from_components(self, components):
1488 return _SingleWorkerOwnedDatasetIterator(
1489 dataset=None,
1490 worker=self._worker,
1491 devices=self._devices,
1492 components=components,
1493 element_spec=self._element_spec,
1494 options=self._options,
1495 canonicalize_devices=self._canonicalize_devices)
1497 @staticmethod
1498 def from_value(value):
1499 # pylint: disable=protected-access
1500 return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
1501 value._element_spec, value._options,
1502 value._canonicalize_devices)
1505class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
1506 composite_tensor.CompositeTensor):
1507 """Iterator for a DistributedDataset instance."""
1509 def __init__(self,
1510 dataset=None,
1511 worker=None,
1512 devices=None,
1513 components=None,
1514 element_spec=None,
1515 options=None,
1516 canonicalize_devices=None):
1517 """Create iterator for the `dataset` to fetch data to worker's `devices` .
1519 `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
1520 given worker. The lifetime of this iterator is tied to the enclosing
1521 Python object; once that object goes out of scope, or we return from
1522 a tf.function, the underlying iterator resource is deleted.
1524 Args:
1525 dataset: A `tf.data.Dataset` instance.
1526 worker: Worker on which ops should be created.
1527 devices: Distribute data from `dataset` to these devices.
1528 components: Tensor components to construct the
1529 _SingleWorkerOwnedDatasetIterator from.
1530 element_spec: A nested structure of `TypeSpec` objects that represents the
1531 type specification of elements of the iterator.
1532 options: `tf.distribute.InputOptions` used to control options on how this
1533 dataset is distributed.
1534 canonicalize_devices: Whether to canonicalize devices for workers fully or
1535 partially. If False, it will partially canonicalize devices by removing
1536 job and task.
1537 """
1538 if worker is None or devices is None:
1539 raise ValueError("Both `worker` and `devices` should be provided")
1541 error_message = ("Either `dataset` or both `components` and `element_spec` "
1542 "need to be provided.")
1544 self._options = options
1545 self._canonicalize_devices = canonicalize_devices
1546 if dataset is None:
1547 if (components is None or element_spec is None):
1548 raise ValueError(error_message)
1549 self._element_spec = element_spec
1550 self._worker = worker
1551 self._devices = devices
1552 self._iterator = components[0]
1553 else:
1554 if (components is not None or element_spec is not None):
1555 raise ValueError(error_message)
1556 super(_SingleWorkerOwnedDatasetIterator,
1557 self).__init__(dataset, worker, devices, self._options)
1559 def _create_owned_multi_device_iterator(self):
1560 # If the worker devices are already canonicalized, canonicalizing again
1561 # would have no impact.
1562 # For strategies running on remote workers, such as PS Strategy, the device
1563 # scope will be derived from the current worker when used under init_scope().
1564 if not ops.inside_function():
1565 device_scope = device_util.canonicalize(self._worker,
1566 device_util.current())
1567 host_device = device_util.get_host_for_device(device_scope)
1568 else:
1569 # In general, iterators should not be created within tf.functions. For
1570 # exact visitation guarantee solutions for parameter server training,
1571 # however, we do create iterators within the tf.functions that are
1572 # dispatched to workers. In these cases, the traced device must match the
1573 # runtime device. Since tracing occurs on the chief, we do not want to use
1574 # the current device scope, which would be the chief, but rather use the
1575 # relative worker device scope explicitly.
1576 device_scope, host_device = self._worker, self._worker
1577 with ops.device(device_scope):
1578 if self._options is not None:
1579 self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
1580 self._dataset,
1581 self._devices,
1582 source_device=host_device,
1583 max_buffer_size=self._options
1584 .experimental_per_replica_buffer_size,
1585 prefetch_buffer_size=self._options
1586 .experimental_per_replica_buffer_size)
1587 else:
1588 self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
1589 self._dataset, self._devices, source_device=host_device)
1591 def _make_iterator(self):
1592 """Make appropriate iterator on the dataset."""
1593 if not self._worker:
1594 raise ValueError("Worker device must be specified when creating an "
1595 "owned iterator.")
1596 if _should_use_multi_device_iterator(self._options):
1597 self._create_owned_multi_device_iterator()
1598 else:
1599 with ops.device(self._worker):
1600 self._iterator = iter(self._dataset)
1602 @property
1603 def element_spec(self):
1604 return self._element_spec
1606 @property
1607 def _type_spec(self):
1608 return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
1609 self._element_spec, self._options,
1610 self._canonicalize_devices)
1612 @property
1613 def output_classes(self):
1614 """Returns the class of each component of an element of this iterator.
1616 The expected values are `tf.Tensor` and `tf.SparseTensor`.
1618 Returns:
1619 A nested structure of Python `type` objects corresponding to each
1620 component of an element of this dataset.
1621 """
1622 return nest.map_structure(
1623 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
1624 self._element_spec)
1626 @property
1627 def output_shapes(self):
1628 """Returns the shape of each component of an element of this iterator.
1630 Returns:
1631 A nested structure of `tf.TensorShape` objects corresponding to each
1632 component of an element of this dataset.
1633 """
1634 return nest.map_structure(
1635 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
1636 self._element_spec)
1638 @property
1639 def output_types(self):
1640 """Returns the type of each component of an element of this iterator.
1642 Returns:
1643 A nested structure of `tf.DType` objects corresponding to each component
1644 of an element of this dataset.
1645 """
1646 return nest.map_structure(
1647 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
1648 self._element_spec)
1651def _create_iterators_per_worker(worker_datasets,
1652 input_workers,
1653 options=None,
1654 canonicalize_devices=False):
1655 """Create a multidevice iterator on each of the workers."""
1656 assert isinstance(input_workers, InputWorkers)
1657 assert len(worker_datasets) == len(input_workers.worker_devices)
1658 iterators = []
1659 for i, worker in enumerate(input_workers.worker_devices):
1660 with ops.device(worker):
1661 worker_devices = input_workers.compute_devices_for_worker(i)
1662 iterator = _SingleWorkerOwnedDatasetIterator(
1663 dataset=worker_datasets[i],
1664 worker=worker,
1665 devices=worker_devices,
1666 options=options,
1667 canonicalize_devices=canonicalize_devices)
1668 iterators.append(iterator)
1669 return iterators
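# Minimal sketch of the per-worker placement pattern used above, with plain
# tf.data iterators standing in for _SingleWorkerOwnedDatasetIterator.
# `worker_datasets` and `worker_devices` are hypothetical, parallel lists.
def _example_per_worker_iterators(worker_datasets, worker_devices):  # pragma: no cover
  import tensorflow as tf

  iterators = []
  for dataset, device in zip(worker_datasets, worker_devices):
    with tf.device(device):  # iterator ops are created on the owning worker
      iterators.append(iter(dataset))
  return iterators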
1672def _create_datasets_from_function_with_input_context(input_contexts,
1673 input_workers,
1674 dataset_fn):
1675 """Create device datasets per worker given a dataset function."""
1676 datasets = []
1677 for i, ctx in enumerate(input_contexts):
1678 worker = input_workers.worker_devices[i]
1679 with ops.device(worker):
1680 dataset = dataset_fn(ctx)
1681 datasets.append(dataset)
1682 return datasets, dataset.element_spec
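# Public-API analogue of the helper above: distribute_datasets_from_function
# invokes `dataset_fn` once per worker with a tf.distribute.InputContext,
# which is the per-context loop implemented here. Assumes MirroredStrategy.
def _example_datasets_from_function():  # pragma: no cover
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()

  def dataset_fn(input_context):
    per_replica_batch = input_context.get_per_replica_batch_size(
        global_batch_size=8)
    dataset = tf.data.Dataset.range(64).batch(per_replica_batch)
    # Shard the input pipeline across workers.
    return dataset.shard(input_context.num_input_pipelines,
                         input_context.input_pipeline_id)

  return strategy.distribute_datasets_from_function(dataset_fn)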
1685# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
1686def _get_batched_dataset(d):
1687 """Get the batched dataset from `d`."""
1688 # pylint: disable=protected-access
1689 if isinstance(d, dataset_ops.DatasetV1Adapter):
1690 d = d._dataset
1692 if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
1693 return d
1694 elif isinstance(d, (dataset_ops.PrefetchDataset,
1695 dataset_ops._OptionsDataset)):
1696 return _get_batched_dataset(d._input_dataset)
1698 raise ValueError(
1699 "Unable to get batched dataset from the input dataset. `batch` "
1700 "`map_and_batch` need to be the last operations on the dataset. "
1701 "The batch operations can be followed by a prefetch.")
1704def _get_batched_dataset_attributes(d):
1705 """Get `batch_size`, `drop_remainder` of dataset."""
1706 # pylint: disable=protected-access
1707 assert isinstance(d,
1708 (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
1709 if isinstance(d, dataset_ops.BatchDataset):
1710 batch_size = d._batch_size
1711 drop_remainder = d._drop_remainder
1712 elif isinstance(d, batching._MapAndBatchDataset):
1713 batch_size = d._batch_size_t
1714 drop_remainder = d._drop_remainder_t
1715 # pylint: enable=protected-access
1717 if tensor_util.is_tf_type(batch_size):
1718 batch_size = tensor_util.constant_value(batch_size)
1720 if tensor_util.is_tf_type(drop_remainder):
1721 drop_remainder = tensor_util.constant_value(drop_remainder)
1723 return batch_size, drop_remainder
1726# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
1727def _get_dataset_attributes(dataset):
1728 """Get the underlying attributes from the dataset object."""
1729 # pylint: disable=protected-access
1731 # First, get batch_size and drop_remainder from the dataset. We need
1732 # to walk back the dataset creation process and find the batched version in
1733 # order to get the attributes.
1734 batched_dataset = _get_batched_dataset(dataset)
1735 batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)
1737 # Second, the prefetch buffer size should be obtained from the original dataset.
1738 prefetch_buffer = None
1739 if isinstance(dataset, dataset_ops.PrefetchDataset):
1740 prefetch_buffer = dataset._buffer_size
1741 elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
1742 and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
1743 prefetch_buffer = dataset._dataset._buffer_size
1745 return batch_size, drop_remainder, prefetch_buffer
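# Illustrative sketch of the pipeline shape these helpers expect: a batch (or
# map_and_batch) as the last transformation, optionally followed by prefetch.
# The commented call into the private helper is for illustration only.
def _example_batched_pipeline():  # pragma: no cover
  import tensorflow as tf

  dataset = (tf.data.Dataset.range(100)
             .batch(32, drop_remainder=True)
             .prefetch(2))
  # _get_dataset_attributes(dataset) would walk back through the prefetch to
  # the BatchDataset and report batch_size=32, drop_remainder=True, plus the
  # prefetch buffer size of 2.
  return dataset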
1748def _should_use_multi_device_iterator(options):
1749 """Determine whether to use multi_device_iterator_ops."""
1750 if (options is None or
1751 options.experimental_replication_mode == InputReplicationMode.PER_WORKER
1752 or
1753 (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
1754 and options.experimental_fetch_to_device)):
1755 return True
1756 return False
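# Sketch of the option combinations the predicate above distinguishes:
# PER_WORKER replication (or options=None) takes the MultiDeviceIterator path,
# while PER_REPLICA with experimental_fetch_to_device=False takes the plain
# per-worker iterator path instead.
def _example_input_options():  # pragma: no cover
  import tensorflow as tf

  multi_device = tf.distribute.InputOptions(
      experimental_replication_mode=(
          tf.distribute.InputReplicationMode.PER_WORKER))
  single_device = tf.distribute.InputOptions(
      experimental_replication_mode=(
          tf.distribute.InputReplicationMode.PER_REPLICA),
      experimental_fetch_to_device=False)
  return multi_device, single_device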
1759class MultiStepContext(object):
1760 """A context object that can be used to capture things when running steps.
1762 This context object is useful when running multiple steps at a time using the
1763 `experimental_run_steps_on_iterator` API. For example, it allows the user's
1764 step function to specify which outputs to emit at what frequency. Currently
1765 it supports capturing output from the last step, as well as capturing
1766 non-tensor outputs. In the future it will be augmented to support other use
1767 cases such as emitting output every N steps.
1768 """
1770 def __init__(self):
1771 """Initialize an output context.
1773 Returns:
1774 A context object.
1775 """
1776 self._last_step_outputs = {}
1777 self._last_step_outputs_reduce_ops = {}
1778 self._non_tensor_outputs = {}
1780 @property
1781 def last_step_outputs(self):
1782 """A dictionary consisting of outputs to be captured on last step.
1784 Keys in the dictionary are names of tensors to be captured, as specified
1785 when `set_last_step_output` is called.
1786 Values in the dictionary are the tensors themselves. If
1787 `set_last_step_output` was called with a `reduce_op` for this output,
1788 then the value is the reduced value.
1790 Returns:
1791 A dictionary with last step outputs.
1792 """
1793 return self._last_step_outputs
1795 def _set_last_step_outputs(self, outputs):
1796 """Replace the entire dictionary of last step outputs."""
1797 if not isinstance(outputs, dict):
1798 raise ValueError("Need a dictionary to set last_step_outputs.")
1799 self._last_step_outputs = outputs
1801 def set_last_step_output(self, name, output, reduce_op=None):
1802 """Set `output` with `name` to be outputted from the last step.
1804 Args:
1805 name: String, name to identify the output. Doesn't need to match tensor
1806 name.
1807 output: The tensors that should be outputted with `name`. See below for
1808 actual types supported.
1809 reduce_op: Reduction method to use to reduce outputs from multiple
1810 replicas. Required if `set_last_step_output` is called in a replica
1811 context. Optional in a cross-replica context.
1812 When present, the outputs from all the replicas are reduced using the
1813 current distribution strategy's `reduce` method. Hence, the type of
1814 `output` must be what's supported by the corresponding `reduce` method.
1815 For example, if using `MirroredStrategy` and a reduction is set, `output`
1816 must be a `PerReplica` value.
1817 The reduce method is also recorded in the
1818 `_last_step_outputs_reduce_ops` dictionary so that it can later be
1819 determined whether the outputs were already reduced.
1820 """
1821 if distribute_lib.in_cross_replica_context():
1822 self._last_step_outputs_reduce_ops[name] = reduce_op
1823 if reduce_op is None:
1824 self._last_step_outputs[name] = output
1825 else:
1826 distribution = distribute_lib.get_strategy()
1827 self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
1828 axis=None)
1829 else:
1830 assert reduce_op is not None
1831 def merge_fn(distribution, value):
1832 self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
1833 axis=None)
1834 # Setting this inside the `merge_fn` because all replicas share the same
1835 # context object, so it's more robust to set it only once (even if all
1836 # the replicas are trying to set the same value).
1837 self._last_step_outputs_reduce_ops[name] = reduce_op
1839 distribute_lib.get_replica_context().merge_call(
1840 merge_fn, args=(output,))
1842 @property
1843 def non_tensor_outputs(self):
1844 """A dictionary consisting of any non tensor outputs to be captured."""
1845 return self._non_tensor_outputs
1847 def set_non_tensor_output(self, name, output):
1848 """Set `output` with `name` to be captured as a non tensor output."""
1849 if distribute_lib.in_cross_replica_context():
1850 self._non_tensor_outputs[name] = output
1851 else:
1852 def merge_fn(distribution, value):
1853 # NOTE(priyag): For non-tensor outputs, we simply return all the values
1854 # in a list, as reduction doesn't make sense for non-tensors.
1855 self._non_tensor_outputs[name] = (
1856 distribution.experimental_local_results(value))
1857 distribute_lib.get_replica_context().merge_call(
1858 merge_fn, args=(output,))
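# Minimal sketch of how a step function records values through this context;
# `ctx` would be the MultiStepContext supplied by the (legacy) multi-step run
# APIs mentioned in the class docstring, and `inputs` is hypothetical.
def _example_step_fn(ctx, inputs):  # pragma: no cover
  import tensorflow as tf

  loss = tf.reduce_sum(inputs)
  # Reduced across replicas before being stored as a last-step output.
  ctx.set_last_step_output("loss", loss,
                           reduce_op=tf.distribute.ReduceOp.SUM)
  # Non-tensor outputs are captured as-is (or as a per-replica list).
  ctx.set_non_tensor_output("finished", True)
  return loss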
1861def _create_distributed_tensor_spec(strategy, tensor_spec):
1862 """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.
1864 Args:
1865 strategy: The given `tf.distribute` strategy.
1866 tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
1867 shape should be None if you have partial batches.
1869 Returns:
1870 A `tf.TypeSpec` that matches the values produced by a given strategy. This
1871 can be a `tf.TensorSpec` or a `PerReplicaSpec`.
1872 """
1873 num_replicas = len(strategy.extended.worker_devices)
1875 # For a one-device strategy that is not MultiWorkerMirroredStrategy, return the
1876 # tensor_spec as is, since we don't wrap the output with PerReplica in this
1877 # case.
1878 # TODO(b/166464552): remove after we always wrap for all strategies.
1879 if not _always_wrap(strategy):
1880 return tensor_spec
1882 # For other cases we assume the input to tf.function is a per replica type.
1883 def _get_value_per_replica(tensor_spec_per_input):
1884 value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
1885 return values.PerReplicaSpec(*value_specs)
1887 return nest.map_structure(_get_value_per_replica, tensor_spec)
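# Observable effect of the helper above: for multi-replica (or multi-worker)
# strategies, a distributed dataset's element_spec is the per-replica wrapped
# version of the underlying tf.TensorSpec. Assumes a local MirroredStrategy.
def _example_distributed_element_spec():  # pragma: no cover
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()
  dist_dataset = strategy.experimental_distribute_dataset(
      tf.data.Dataset.range(8).batch(4))
  # A plain TensorSpec with one replica, a PerReplicaSpec with several.
  return dist_dataset.element_spec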
1890def _replace_per_replica_spec(spec, i):
1891 """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
1892 if isinstance(spec, values.PerReplicaSpec):
1893 return spec._value_specs[i] # pylint: disable=protected-access
1894 else:
1895 return spec
1898def _cardinality(dataset):
1899 """Returns the cardinality of the dataset."""
1900 if context.executing_eagerly():
1901 with ops.device(dataset._variant_tensor.device): # pylint: disable=protected-access
1902 return dataset.cardinality().numpy()
1903 return cardinality_lib.UNKNOWN
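# The helper above wraps tf.data's cardinality op; for example:
def _example_cardinality():  # pragma: no cover
  import tensorflow as tf

  finite = tf.data.Dataset.range(10).batch(3)
  infinite = tf.data.Dataset.range(10).repeat()
  return (finite.cardinality().numpy(),  # 4 (the last batch is partial)
          infinite.cardinality().numpy() == tf.data.INFINITE_CARDINALITY)  # True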
1906def _enable_get_next_as_optional(strategy, dataset, cardinality):
1907 """Returns whether to enable using partial batch handling."""
1908 # TODO(b/133073708): we currently need a flag to control the usage because
1909 # there is a performance difference between get_next() and
1910 # get_next_as_optional(). And we only enable get_next_as_optional when the
1911 # output shapes are not static.
1912 #
1913 # TODO(rxsang): We want to always enable the get_next_as_optional behavior
1914 # when the user passes an input_fn instead of a dataset.
1915 if not getattr(
1916 strategy.extended, "enable_partial_batch_handling",
1917 getattr(strategy.extended, "experimental_enable_get_next_as_optional",
1918 False)):
1919 return False
1921 # If the dataset is infinite, we don't need to enable last partial batch
1922 # support. Note that we can only evaluate the cardinality of the dataset in
1923 # eager mode.
1924 if cardinality == cardinality_lib.INFINITE:
1925 return False
1927 return not _is_statically_shaped(
1928 dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
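# Whether element shapes are static, and hence whether partial-batch handling
# via get_next_as_optional() is needed, typically hinges on drop_remainder:
def _example_static_vs_dynamic_shapes():  # pragma: no cover
  import tensorflow as tf

  static = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
  dynamic = tf.data.Dataset.range(10).batch(4)
  return (static.element_spec,   # TensorSpec(shape=(4,), dtype=tf.int64, ...)
          dynamic.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int64, ...)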
1931def _create_per_replica(value_list, strategy):
1932 """Creates a PerReplica.
1934 For strategies other than OneDeviceStrategy, it creates a PerReplica whose
1935 type spec is set to the element spec of the dataset. This helps avoid
1936 retracing for partial batches. Retracing is problematic for multi-client
1937 setups when different clients retrace at different times, since retracing
1938 changes the collective keys in the tf.function and causes mismatches among clients.
1940 For single-client strategies, this simply calls distribute_utils.regroup().
1942 Args:
1943 value_list: a list of values, one for each replica.
1944 strategy: the `tf.distribute.Strategy`.
1946 Returns:
1947 a structure of PerReplica.
1949 """
1950 # TODO(b/166464552): always wrap for all one device strategies as well.
1951 always_wrap = _always_wrap(strategy)
1952 per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
1953 return per_replicas
1956def _always_wrap(strategy):
1957 """Returns whether to always wrap the values in a DistributedValues."""
1958 return strategy.extended._in_multi_worker_mode() or len( # pylint: disable=protected-access
1959 strategy.extended.worker_devices) > 1
1962def _rebatch_as_dynamic(per_replica_spec):
1963 """Rebatch the spec to have a dynamic batch dimension."""
1964 assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec
1966 # pylint: disable=protected-access
1967 def _rebatch(spec):
1968 # Rebatch if possible.
1969 try:
1970 return spec._unbatch()._batch(None)
1971 except ValueError:
1972 pass
1973 return spec
1975 return values.PerReplicaSpec(
1976 *nest.map_structure(_rebatch, per_replica_spec._value_specs))
1977 # pylint: enable=protected-access
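# Conceptual effect of the private _unbatch()._batch(None) trick above: the
# leading (batch) dimension of each per-replica component spec becomes
# dynamic. A hand-built equivalent for a single TensorSpec looks like this:
def _example_dynamic_batch_spec():  # pragma: no cover
  import tensorflow as tf

  spec = tf.TensorSpec(shape=(32, 28, 28), dtype=tf.float32)
  dynamic = tf.TensorSpec(shape=[None] + spec.shape[1:].as_list(),
                          dtype=spec.dtype)
  return dynamic  # TensorSpec(shape=(None, 28, 28), dtype=tf.float32)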
1980def _ag_enumerate_not_implemented(s, unused_start):
1981 msg = (
1982 f"enumerate not supported with {s.__class__.__name__} types within "
1983 "tf.functions. Use a for loop over the dataset and keep a separate "
1984 "counter instead."
1985 )
1986 raise NotImplementedError(msg)
1989py_builtins.enumerate_registry.register(
1990 DistributedIterator, _ag_enumerate_not_implemented
1991)
1992py_builtins.enumerate_registry.register(
1993 DistributedDataset, _ag_enumerate_not_implemented
1994)
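# The registrations above make `enumerate(...)` over distributed datasets and
# iterators raise inside tf.functions; the suggested workaround is an explicit
# counter, sketched here under a local MirroredStrategy.
def _example_counter_instead_of_enumerate():  # pragma: no cover
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()
  dist_dataset = strategy.experimental_distribute_dataset(
      tf.data.Dataset.range(8).batch(4))

  @tf.function
  def count_batches():
    step = tf.constant(0)
    for _ in dist_dataset:
      step += 1  # keep a separate counter instead of enumerate()
    return step

  return count_batches()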