Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/v1/input_lib.py: 36%
166 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.types import data as data_types
from tensorflow.python.util.deprecation import deprecated


class DistributedDatasetV1(input_lib.DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None,
               options=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already initialized iterators. Hence we cannot
    # create an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()
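
  # A minimal usage sketch of the two creation paths above (illustrative only,
  # not part of this module; `strategy` and `dataset` are assumed to be a
  # `tf.distribute` strategy and a `tf.data.Dataset` supplied by the caller):
  #
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset)
  #   if tf.executing_eagerly():
  #     iterator = iter(dist_dataset)     # already initialized, one-shot style
  #     batch = next(iterator)
  #   else:
  #     iterator = dist_dataset.make_initializable_iterator()
  #     batch = iterator.get_next()       # per-replica values
  #     with tf.compat.v1.Session() as sess:
  #       sess.run(iterator.initializer)  # must run before evaluating `batch`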
  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    self._options)
    cardinality = input_lib._cardinality(self._cloned_datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync
    # point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")
  # pylint: enable=non-iterator-returned


class DistributedDatasetsFromFunctionV1(
    input_lib.DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot
    # create an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers,
                                             self._options)
    cardinality = input_lib._cardinality(self._datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync
    # point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")
  # pylint: enable=non-iterator-returned
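
  # Illustrative sketch of how such an input is typically built (an assumption
  # about caller code, not part of this module; `strategy` comes from the
  # caller and `GLOBAL_BATCH_SIZE` is a hypothetical constant):
  #
  #   def dataset_fn(input_context):
  #     batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  #     return tf.data.Dataset.range(1000).repeat().batch(batch_size)
  #
  #   dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
  #   # (older releases: strategy.experimental_distribute_datasets_from_function)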
class DistributedIteratorV1(input_lib.DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize
  # the iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns a list of ops that initialize the iterator."""
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec
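
  # Graph-mode (re)initialization sketch, e.g. between epochs in a Keras-style
  # loop as noted above (illustrative only; `dist_iterator` is assumed to be an
  # instance of this class, `sess` an existing `tf.compat.v1.Session`, and
  # `num_epochs` / `train_step_with` are hypothetical placeholders):
  #
  #   batch = dist_iterator.get_next()
  #   for _ in range(num_epochs):
  #     sess.run(dist_iterator.initializer)  # re-runs every per-worker initializer
  #     while True:
  #       try:
  #         sess.run(train_step_with(batch))
  #       except tf.errors.OutOfRangeError:
  #         break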
class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    """Make an iterator for the dataset on given devices.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value
        is used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for
        between graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    """
    dist_dataset = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
    # pylint: disable=protected-access
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)
    super(DatasetIterator,
          self).__init__(input_workers, worker_iterators, strategy,
                         dist_dataset.cardinality,
                         dist_dataset._enable_get_next_as_optional)
    self._element_spec = dist_dataset.element_spec
    # pylint: enable=protected-access
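
  # Worked example of the rebatching described in the docstring (illustrative):
  # with `dataset` batched to a global batch size of 64 and
  # `num_replicas_in_sync=4`, each replica receives batches of 64 // 4 = 16
  # elements, so one global step still consumes 4 * 16 = 64 elements.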
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to
        call(s) to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    """
    assert isinstance(input_workers, input_lib.InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError("Number of input workers (%d) is not same as number of "
                       "input_contexts (%d)" %
                       (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, data_types.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(
        input_workers,
        iterators,
        strategy,
        cardinality=cardinality_lib.UNKNOWN,
        enable_get_next_as_optional=False)
    self._enable_get_next_as_optional = False
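
  # Sketch of an `input_fn` and its TF1-style use (an assumption about caller
  # code; `strategy` is a `tf.compat.v1` distribution strategy, `features` and
  # `sess` are hypothetical placeholders):
  #
  #   def input_fn(input_context):
  #     dataset = tf.data.Dataset.from_tensor_slices(features).batch(16)
  #     return dataset.shard(input_context.num_input_pipelines,
  #                          input_context.input_pipeline_id)
  #
  #   iterator = strategy.make_input_fn_iterator(input_fn)
  #   sess.run(iterator.initializer)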
class _SingleWorkerDatasetIterator(input_lib._SingleWorkerDatasetIteratorBase):  # pylint: disable=protected-access
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
            max_buffer_size=self._options.experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
        )
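
  # The buffer sizes above come from `tf.distribute.InputOptions`; a sketch of
  # how a caller might set them (an assumption about caller code):
  #
  #   options = tf.distribute.InputOptions(experimental_per_replica_buffer_size=2)
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset, options)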
  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_optional_list(self):
    with ops.device(self._worker):
      data_list = [
          optional_ops.Optional.from_value(self._fn()) for _ in self._devices
      ]
      return data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []
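
  # Sketch of the callable path accepted by `InputFunctionIterator` above
  # (illustrative; shapes and names are hypothetical):
  #
  #   def input_fn(input_context):
  #     del input_context
  #     # Returning a callable (rather than a Dataset) yields a
  #     # _SingleWorkerCallableIterator per worker.
  #     return lambda: tf.random.uniform([16, 28, 28])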


def _create_iterators_per_worker(worker_datasets, input_workers, options=None):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, input_lib.InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(
          worker_datasets[i],  # pylint: disable=protected-access
          worker,
          worker_devices,
          options)
      iterators.append(iterator)
  return iterators