# Coverage report: /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py
# 25% of 421 statements covered (coverage.py v7.4.0, created 2024-01-03 07:57 +0000)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

import copy

from tensorflow.python import tf2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import while_loop
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for single or multi-worker.

  Args:
    devices: a list of device strings or tf.config.LogicalDevice objects, for
      either local or for remote devices.

  Returns:
    a boolean indicating whether these device strings are for local or for
    remote.

  Raises:
    ValueError: if device strings are not consistent.
  """
  specs = []
  for d in devices:
    name = d.name if isinstance(d, context.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  num_workers = len({(d.job, d.task, d.replica) for d in specs})
  all_local = all(d.job in (None, "localhost") for d in specs)
  any_local = any(d.job in (None, "localhost") for d in specs)

  if any_local and not all_local:
    raise ValueError("Local device should have only 'localhost' in the job "
                     "field in device string. "
                     "E.g. 'job:localhost' in "
                     "/job:localhost/replica:0/task:0/device:CPU:0"
                     "Devices cannot have mixed list of device strings "
                     "containing both localhost and other job types such as "
                     "worker, ps etc. ")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device string must have task specified."
                       "E.g. 'task:0' in "
                       "/job:worker/replica:0/task:0/device:CPU:0")

  return num_workers == 1
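
# A minimal illustration of the helper above, using hypothetical device strings
# (not from this file):
#   _is_device_list_single_worker(["/gpu:0", "/gpu:1"]) -> True
#   _is_device_list_single_worker(["/job:worker/task:0/device:GPU:0",
#                                  "/job:worker/task:1/device:GPU:0"]) -> False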


def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
  """Returns a device list given a cluster spec."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  devices = []
  for task_type in ("chief", "worker"):
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
      else:
        devices.extend([
            "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker)
        ])
  return devices
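
# A sketch of the mapping performed above, assuming a hypothetical two-worker
# cluster spec:
#   _cluster_spec_to_device_list({"worker": ["host1:2222", "host2:2222"]},
#                                num_gpus_per_worker=2)
#   -> ["/job:worker/task:0/device:GPU:0", "/job:worker/task:0/device:GPU:1",
#       "/job:worker/task:1/device:GPU:0", "/job:worker/task:1/device:GPU:1"]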


def _group_device_list(devices):
  """Groups the devices list by task_type and task_id.

  Args:
    devices: a list of device strings for remote devices.

  Returns:
    a dict of list of device strings mapping from task_type to a list of devices
    for the task_type in the ascending order of task_id.
  """
  assert not _is_device_list_single_worker(devices)
  device_dict = {}

  for d in devices:
    d_spec = tf_device.DeviceSpec.from_string(d)

    # Create an entry for the task_type.
    if d_spec.job not in device_dict:
      device_dict[d_spec.job] = []

    # Fill the device list for task_type until it covers the task_id.
    while len(device_dict[d_spec.job]) <= d_spec.task:
      device_dict[d_spec.job].append([])

    device_dict[d_spec.job][d_spec.task].append(d)

  return device_dict
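
# For example (hypothetical remote device strings), the grouping above yields
# per-task lists ordered by task_id regardless of input order:
#   _group_device_list(["/job:worker/task:1/device:GPU:0",
#                       "/job:worker/task:0/device:GPU:0"])
#   -> {"worker": [["/job:worker/task:0/device:GPU:0"],
#                  ["/job:worker/task:1/device:GPU:0"]]}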


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def _infer_num_gpus_per_worker(devices):
  """Infers the number of GPUs on each worker.

  Currently to make multi-worker cross device ops work, we need all workers to
  have the same number of GPUs.

  Args:
    devices: a list of device strings, can be either local devices or remote
      devices.

  Returns:
    number of GPUs per worker.

  Raises:
    ValueError if workers have different number of GPUs or GPU indices are not
    consecutive and starting from 0.
  """
  if _is_device_list_single_worker(devices):
    return sum(1 for d in devices if _is_gpu_device(d))
  else:
    device_dict = _group_device_list(devices)
    num_gpus = None
    for _, devices_in_task in device_dict.items():
      for device_in_task in devices_in_task:
        if num_gpus is None:
          num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))

        # Verify other workers have the same number of GPUs.
        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
          raise ValueError("All workers should have the same number of GPUs.")

        for d in device_in_task:
          d_spec = tf_device.DeviceSpec.from_string(d)
          if (d_spec.device_type == "GPU" and
              d_spec.device_index >= num_gpus):
            raise ValueError("GPU `device_index` on a worker should be "
                             "consecutive and start from 0.")
    return num_gpus
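
# Rough examples of the inference above (hypothetical device strings):
#   _infer_num_gpus_per_worker(["/gpu:0", "/gpu:1"]) -> 2  (single worker)
#   _infer_num_gpus_per_worker(["/job:worker/task:0/device:GPU:0",
#                               "/job:worker/task:1/device:GPU:0"]) -> 1
# Workers with differing GPU counts raise ValueError, as do GPU indices that
# are not consecutive starting from 0.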


def all_local_devices(num_gpus=None):
  devices = config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  return devices or config.list_logical_devices("CPU")


def all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()
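
# Note: `all_devices()` only returns remote devices when TF_CONFIG provides a
# cluster spec; otherwise it falls back to `all_local_devices()`, which prefers
# logical GPUs and falls back to CPUs.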


@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.

  This strategy is typically used for training on one
  machine with multiple GPUs. For TPUs, use
  `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
  please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument of
  the strategy then it will use all the available GPUs. If no GPUs are found, it
  will use the available CPUs. Note that TensorFlow treats all CPUs on a
  machine as a single device, and uses threads internally for parallelism.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   x = tf.Variable(1.)
  >>> x
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  While using distribution strategies, all the variable creation should be done
  within the strategy's scope. This will replicate the variables across all the
  replicas and keep them in sync using an all-reduce algorithm.

  Variables created inside a `MirroredStrategy` which is wrapped with a
  `tf.function` are still `MirroredVariables`.

  >>> x = []
  >>> @tf.function  # Wrap the function with tf.function.
  ... def create_variable():
  ...   if not x:
  ...     x.append(tf.Variable(1.))
  ...   return x[0]
  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   _ = create_variable()
  ...   print(x[0])
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  `experimental_distribute_dataset` can be used to distribute the dataset across
  the replicas when writing your own training loop. If you are using `.fit` and
  `.compile` methods available in `tf.keras`, then `tf.keras` will handle the
  distribution for you.

  For example:

  ```python
  my_strategy = tf.distribute.MirroredStrategy()
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  Args:
    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If
      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is not
      set, `NcclAllReduce()` will be used by default. One would customize this
      if NCCL isn't available or if a special implementation that exploits
      the particular hardware is available.
  """
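
  # Construction sketch (user code, not part of this module): if NCCL is not
  # available, an explicit cross-device reduction can be supplied, e.g.
  #   tf.distribute.MirroredStrategy(
  #       devices=["/gpu:0", "/gpu:1"],
  #       cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())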

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MirroredStrategy")


@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring

  __doc__ = MirroredStrategy.__doc__

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategyV1, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MirroredStrategy")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of MirroredStrategy."""

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # if you are executing in eager mode, only the single machine code
          # path is supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")

    self._collective_key_base = container_strategy._collective_key_base
    self._communication_options = collective_util.Options(
        implementation=collective_util.CommunicationImplementation.NCCL)
    self._cross_device_ops = cross_device_ops
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _use_merge_call(self):
    # We currently only disable merge_call when XLA is used to compile the `fn`
    # passed to `strategy.run` and all devices are GPU.
    return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))

    self._initialize_single_worker(devices)

    self._collective_ops = self._make_collective_ops_with_fallbacks()
    # If cross_device_ops is not provided, set it to collective op by default.
    if not self._cross_device_ops:
      self._cross_device_ops = self._collective_ops

  def _make_collective_ops_with_fallbacks(self):
    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)

    if not ops.executing_eagerly_outside_functions() and any(
        "gpu" not in d.lower() for d in self._devices):
      # In TF1/Session, fall back to ReductionToOneDevice() if there are
      # non-GPU devices or virtual GPUs are used.
      return cross_device_ops_lib.ReductionToOneDevice()

    # Use ReductionToOneDevice() if mixed devices are used.
    if any("cpu" in d.lower() for d in self._devices) and any(
        "gpu" in d.lower() for d in self._devices):
      return cross_device_ops_lib.ReductionToOneDevice()

    if all("cpu" in d.lower() for d in self._devices):
      # Use RING collective ops if all devices are CPU.
      self._communication_options = collective_util.Options(
          implementation=collective_util.CommunicationImplementation.RING)

    else:
      physical_gpus = context.context().list_physical_devices(device_type="GPU")
      logical_gpus = context.context().list_logical_devices(device_type="GPU")
      # Use RING collective ops if virtual devices are used.
      if len(physical_gpus) < len(logical_gpus):
        self._communication_options = collective_util.Options(
            implementation=collective_util.CommunicationImplementation.RING)

    # If all devices are physical GPU, use NCCL implementation.
    return cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        options=self._communication_options,
        collective_keys=self._collective_keys)
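
  # In short, the fallback logic above is: ReductionToOneDevice for TF1 graph
  # mode with non-GPU devices and for mixed CPU/GPU device lists; otherwise
  # CollectiveAllReduce, using RING for CPU-only or virtual-GPU setups and NCCL
  # when all devices are physical GPUs.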

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures when we enter strategy.scope() we use the correct default device
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

    logging.info("Using MirroredStrategy with devices %r", devices)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify device under distribution.scope by
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the cpu device of its first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_fetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with record.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)
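
  # Illustration of the naming scheme above (hypothetical variable): a variable
  # created as `tf.Variable(1., name="w")` under a two-replica strategy yields
  # per-replica components named roughly "w:0" and "w/replica_1:0".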

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib_v1.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              input_contexts,
                                              self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_util.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)
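
  # Usage sketch of the method above (user code, not part of this module):
  #   strategy.experimental_distribute_values_from_function(
  #       lambda ctx: tf.constant(ctx.replica_id_in_sync_group))
  # returns a per-replica value holding one tensor per replica.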

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = while_loop.while_loop(
        cond,
        body, [i] + initial_loop_values,
        name="",
        parallel_iterations=1,
        back_prop=False,
        swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    # Always use CollectiveAllReduce when XLA is enabled, since other cross
    # device ops don't have as good support on XLA.
    if not self._use_merge_call():
      if not isinstance(self._cross_device_ops,
                        cross_device_ops_lib.CollectiveAllReduce):
        logging.warning(
            "Under XLA context, MirroredStrategy uses CollectiveAllReduce op. "
            "Although %r is provided to initialize MirroredStrategy, it is "
            "ignored in XLA. Please use CollectiveAllReduce (or the default "
            "option) in the future, since other cross device ops are not well "
            "supported on XLA.", self._cross_device_ops
        )
      return self._collective_ops

    if isinstance(value, values.DistributedValues):
      value_int32 = True in {
          dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
      }
    else:
      value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32

    if value_int32:
      return cross_device_ops_lib.ReductionToOneDevice()
    else:
      return self._cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    def get_values(value):
      if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on all
        # replicas in which case `value` would be a single value or value could
        # be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)

      if self._use_merge_call() and (
          not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
          any("cpu" in d.lower()
              for d in cross_device_ops_lib.get_devices_from(destinations))):
        return cross_device_ops_lib.ReductionToOneDevice().reduce(
            reduce_op, value, destinations)
      return self._get_cross_device_ops(value).reduce(
          reduce_op,
          value,
          destinations=destinations,
          options=self._communication_options.merge(options))

    return nest.map_structure(get_values, value)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("Inputs to batch_reduce_to must be either all on "
                         "the host or all on the compute devices.")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly() or (
        not tf2.enabled()) or self._use_merge_call():
      # In eager mode, falls back to the default implementation that uses
      # `merge_call`. Replica functions are running sequentially in eager mode,
      # and due to the blocking nature of collective ops, execution will hang if
      # collective ops are to be launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = distribute_lib.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._get_cross_device_ops(value)._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _replica_ctx_update(self, var, fn, args, kwargs, group):
    if self._use_merge_call():
      return super()._replica_ctx_update(var, fn, args, kwargs, group)

    replica_context = distribute_lib.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id

    if isinstance(var, values.DistributedVariable):
      var = var._get_replica(replica_id)  # pylint: disable=protected-access

    with ops.device(var.device), ops.name_scope(name):
      result = fn(var, *args, **kwargs)
    return result

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id