Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/collective_all_reduce_strategy.py: 27%
406 statements
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
17import copy
18import threading
19import time
20import weakref
22from tensorflow.core.protobuf import rewriter_config_pb2
23from tensorflow.core.protobuf import tensorflow_server_pb2
24from tensorflow.python.distribute import collective_util
25from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
26from tensorflow.python.distribute import cross_device_utils
27from tensorflow.python.distribute import device_util
28from tensorflow.python.distribute import distribute_lib
29from tensorflow.python.distribute import distribute_utils
30from tensorflow.python.distribute import input_lib
31from tensorflow.python.distribute import input_util
32from tensorflow.python.distribute import mirrored_strategy
33from tensorflow.python.distribute import multi_worker_util
34from tensorflow.python.distribute import numpy_dataset
35from tensorflow.python.distribute import reduce_util
36from tensorflow.python.distribute import values
37from tensorflow.python.distribute.cluster_resolver import ClusterResolver
38from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
39from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
40from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
41from tensorflow.python.eager import context
42from tensorflow.python.framework import device as tf_device
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import ops
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import collective_ops
47from tensorflow.python.ops import control_flow_util
48from tensorflow.python.platform import tf_logging as logging
49from tensorflow.python.tpu import tpu_strategy_util
50from tensorflow.python.trackable import base
51from tensorflow.python.util import deprecation
52from tensorflow.python.util.tf_export import tf_export
53from tensorflow.tsl.protobuf import coordination_config_pb2
56# pylint: disable=line-too-long
57@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
58class CollectiveAllReduceStrategy(distribute_lib.Strategy):
59 """A distribution strategy for synchronous training on multiple workers.
61 This strategy implements synchronous distributed training across multiple
62 workers, each with potentially multiple GPUs. Similar to
63 `tf.distribute.MirroredStrategy`, it replicates all variables and computations
64 to each local device. The difference is that it uses a distributed collective
65 implementation (e.g. all-reduce), so that multiple workers can work together.
67 You need to launch your program on each worker and configure
68 `cluster_resolver` correctly. For example, if you are using
69 `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
70 have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
71 environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
72 is:
74 ```
75 TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
76 ```
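The corresponding TF_CONFIG on worker-1 of the same cluster differs only in the
task index:
```
TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 1} }'
```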
78 Your program runs on each worker as-is. Note that collectives require each
79 worker to participate. All `tf.distribute` and non-`tf.distribute` APIs may use
80 collectives internally, e.g. checkpointing and saving, since reading a
81 `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
82 Therefore it's recommended to run exactly the same program on each worker.
83 Dispatching based on `task_type` or `task_id` of the worker is error-prone.
85 `cluster_resolver.num_accelerators()` determines the number of GPUs the
86 strategy uses. If it's zero, the strategy uses the CPU. All workers need to
87 use the same number of devices, otherwise the behavior is undefined.
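If you need the strategy to see an explicit device count, you can instead pass
a cluster resolver that reports it, e.g. a
`tf.distribute.cluster_resolver.SimpleClusterResolver`. The addresses and the
accelerator count below are placeholders:
```
cluster_spec = tf.train.ClusterSpec(
    {"worker": ["localhost:12345", "localhost:23456"]})
resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, task_type="worker", task_id=0,
    num_accelerators={"GPU": 2}, rpc_layer="grpc")
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    cluster_resolver=resolver)
```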
89 This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
90 instead.
92 After setting up TF_CONFIG, using this strategy is similar to using
93 `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.
95 ```
96 strategy = tf.distribute.MultiWorkerMirroredStrategy()
98 with strategy.scope():
99 model = tf.keras.Sequential([
100 tf.keras.layers.Dense(2, input_shape=(5,)),
101 ])
102 optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
104 def dataset_fn(ctx):
105 x = np.random.random((2, 5)).astype(np.float32)
106 y = np.random.randint(2, size=(2, 1))
107 dataset = tf.data.Dataset.from_tensor_slices((x, y))
108 return dataset.repeat().batch(1, drop_remainder=True)
109 dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
111 model.compile()
112 model.fit(dist_dataset)
113 ```
115 You can also write your own training loop:
117 ```
118 @tf.function
119 def train_step(iterator):
121 def step_fn(inputs):
122 features, labels = inputs
123 with tf.GradientTape() as tape:
124 logits = model(features, training=True)
125 loss = tf.keras.losses.sparse_categorical_crossentropy(
126 labels, logits)
128 grads = tape.gradient(loss, model.trainable_variables)
129 optimizer.apply_gradients(zip(grads, model.trainable_variables))
131 strategy.run(step_fn, args=(next(iterator),))
133 for _ in range(NUM_STEP):
134 train_step(iterator)
135 ```
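Here `iterator` can be created with `iterator = iter(dist_dataset)` before the
loop, and `NUM_STEP` is the number of training steps to run.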
137 See
138 [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
139 for a detailed tutorial.
141 __Saving__
143 You need to save and checkpoint on all workers instead of just one. This is
144 because variables with synchronization=ON_READ trigger aggregation during
145 saving. It's recommended to save to a different path on each worker to avoid
146 race conditions. Each worker saves the same thing. See
147 [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
148 tutorial for examples.
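A minimal sketch of per-worker checkpointing (the directory naming below is
illustrative, not required):
```
resolver = strategy.cluster_resolver
checkpoint = tf.train.Checkpoint(model=model)
checkpoint_prefix = "/tmp/ckpt-%s-%d" % (resolver.task_type, resolver.task_id)
checkpoint.save(checkpoint_prefix)
```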
150 __Known Issues__
152 * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
153 correct number of accelerators. The strategy uses all available GPUs if
154 `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
155 or `None`.
156 * In eager mode, the strategy needs to be created before calling any other
157 TensorFlow API.
159 """
160 # pylint: enable=line-too-long
162 # TODO(anjalisridhar): Update our guides with examples showing how we can use
163 # the cluster_resolver argument.
165 # The starting number for collective keys. This should only be set in tests.
166 _collective_key_base = 0
168 def __init__(self,
169 cluster_resolver=None,
170 communication_options=None):
171 """Creates the strategy.
173 Args:
174 cluster_resolver: optional
175 `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
176 `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
177 communication_options: optional
178 `tf.distribute.experimental.CommunicationOptions`. This configures the
179 default options for cross device communications. It can be overridden by
180 options provided to the communication APIs like
181 `tf.distribute.ReplicaContext.all_reduce`. See
182 `tf.distribute.experimental.CommunicationOptions` for details.
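For example, to prefer NCCL for cross-device communication (a typical choice
when every worker runs on GPUs):
```
communication_options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=communication_options)
```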
183 """
184 if communication_options is None:
185 communication_options = collective_util.Options()
186 super(CollectiveAllReduceStrategy, self).__init__(
187 CollectiveAllReduceExtended(
188 self,
189 cluster_resolver=cluster_resolver,
190 communication_options=communication_options))
192 distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
193 "MultiWorkerMirroredStrategy")
194 # pylint: disable=protected-access
195 distribute_lib.distribution_strategy_replica_gauge.get_cell(
196 "num_workers").set(self.extended._num_workers)
197 distribute_lib.distribution_strategy_replica_gauge.get_cell(
198 "num_replicas_per_worker").set(self.extended._num_devices_per_worker)
200 @classmethod
201 def _from_local_devices(cls, devices, communication_options=None):
202 """A convenience method to create an object with a list of devices."""
203 obj = cls(communication_options=communication_options)
204 obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
205 return obj
207 @property
208 def cluster_resolver(self):
209 """Returns the cluster resolver associated with this strategy.
211 As a multi-worker strategy, `tf.distribute.MultiWorkerMirroredStrategy`
212 provides the associated `tf.distribute.cluster_resolver.ClusterResolver`. If
213 the user provides one in `__init__`, that instance is returned; if the user
214 does not, a default `TFConfigClusterResolver` is provided.
215 """
216 return self.extended._cluster_resolver # pylint: disable=protected-access
219class _CollectiveAllReduceStrategyExperimentalMeta(type):
221 @classmethod
222 def __instancecheck__(cls, instance):
223 # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
224 # tf.distribute.experimental.MultiWorkerMirroredStrategy) return True, since
225 # some libraries perform such checks.
226 return isinstance(instance, CollectiveAllReduceStrategy)
229@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
230class _CollectiveAllReduceStrategyExperimental(
231 CollectiveAllReduceStrategy,
232 metaclass=_CollectiveAllReduceStrategyExperimentalMeta):
234 __doc__ = CollectiveAllReduceStrategy.__doc__
236 @deprecation.deprecated(
237 None, "use distribute.MultiWorkerMirroredStrategy instead")
238 def __init__(self,
239 communication=collective_util.CommunicationImplementation.AUTO,
240 cluster_resolver=None):
241 """Creates the strategy.
243 Args:
244 communication: optional
245 `tf.distribute.experimental.CommunicationImplementation`. This is a hint
246 on the preferred collective communication implementation. Possible
247 values include `AUTO`, `RING`, and `NCCL`.
248 cluster_resolver: optional
249 `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
250 `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
251 """
252 communication_options = collective_util.Options(
253 implementation=communication)
254 super(_CollectiveAllReduceStrategyExperimental,
255 self).__init__(cluster_resolver, communication_options)
257 @classmethod
258 def _from_local_devices(
259 cls,
260 devices,
261 communication=collective_util.CommunicationImplementation.AUTO):
262 """A convenience method to create an object with a list of devices."""
263 obj = cls(communication)
264 obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices) # pylint: disable=protected-access
265 return obj
268_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__
271@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring
272class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
274 __doc__ = CollectiveAllReduceStrategy.__doc__
276 # The starting number for collective keys. This should only be set in tests.
277 _collective_key_base = 0
279 def __init__(self,
280 communication=collective_util.CommunicationImplementation.AUTO,
281 cluster_resolver=None):
282 """Initializes the object."""
283 communication_options = collective_util.Options(
284 implementation=communication)
285 super(CollectiveAllReduceStrategyV1, self).__init__(
286 CollectiveAllReduceExtended(
287 self,
288 cluster_resolver=cluster_resolver,
289 communication_options=communication_options))
290 distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
291 "MultiWorkerMirroredStrategy")
292 # pylint: disable=protected-access
293 distribute_lib.distribution_strategy_replica_gauge.get_cell(
294 "num_workers").set(self.extended._num_workers)
295 distribute_lib.distribution_strategy_replica_gauge.get_cell(
296 "num_gpu_per_worker").set(
297 self.extended._num_devices_per_worker
298 if self.extended._local_device_type == "GPU"
299 else 0)
302def _is_gpu_device(device):
303 return tf_device.DeviceSpec.from_string(device).device_type == "GPU"
306class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
307 """Implementation of CollectiveAllReduceStrategy."""
309 # Whether to periodically check the health of the cluster. If any worker is not
310 # reachable, collectives are aborted and the user program should get a
311 # tf.errors.UnavailableError. The program must be restarted to recover.
312 _enable_check_health = True
313 # Check health interval in seconds.
314 _check_health_interval = 30
315 # Timeout in seconds for the first health check. The first health check needs
316 # to wait for the cluster to be ready, which may take a longer time.
317 _check_health_initial_timeout = 0
318 # Number of times to retry before considering the peer down.
319 _check_health_retry_limit = 3
320 # Timeout in seconds for each health check.
321 _check_health_timeout = 10
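# With the defaults above, a single health-check sweep declares a peer down
# after roughly _check_health_retry_limit * _check_health_timeout = 30 seconds
# of consecutive failed attempts, and sweeps repeat every
# _check_health_interval seconds (see _check_health below).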
323 def __init__(self, container_strategy, cluster_resolver,
324 communication_options, devices=None):
325 if not isinstance(communication_options, collective_util.Options):
326 raise ValueError("communication_options must be an instance of "
327 "tf.distribute.experimental.CommunicationOptions")
328 if cluster_resolver and devices:
329 raise ValueError(
330 "cluster_resolver and devices cannot be set at the same time")
332 self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
333 if not isinstance(self._cluster_resolver, ClusterResolver):
334 raise ValueError("cluster_resolver must be an instance of "
335 "tf.distribute.cluster_resolver.ClusterResolver")
336 distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
337 self._communication_options = communication_options
338 self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access
339 self._initialize_strategy(self._cluster_resolver, devices=devices)
340 self._cfer_fn_cache = weakref.WeakKeyDictionary()
341 self.experimental_enable_get_next_as_optional = True
342 assert isinstance(self._cross_device_ops,
343 cross_device_ops_lib.CollectiveAllReduce)
345 def _use_merge_call(self):
346 # We currently only disable merge_call when XLA is used to compile the `fn`
347 # passed to `strategy.run` and all devices are GPU.
348 return not control_flow_util.GraphOrParentsInXlaContext(
349 ops.get_default_graph()) or not all(
350 [_is_gpu_device(d) for d in self._devices])
352 def _initialize_strategy(self, cluster_resolver, devices):
353 # If devices are provided or cluster_spec is not specified, initialize for a
354 # single worker. Otherwise initialize for multiple workers.
355 if devices or not cluster_resolver.cluster_spec().as_dict():
356 self._initialize_local(cluster_resolver, devices=devices)
357 else:
358 self._initialize_multi_worker(cluster_resolver)
360 def _initialize_local_devices(self, cluster_resolver, worker_device):
361 # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
362 # some cases.
363 if isinstance(cluster_resolver, TFConfigClusterResolver):
364 num_gpus = context.num_gpus()
365 num_tpus = 0
366 else:
367 num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
368 num_tpus = cluster_resolver.num_accelerators().get("TPU", 0)
370 if num_gpus:
371 local_device_type = "GPU"
372 num_local_devices = num_gpus
373 elif num_tpus:
374 local_device_type = "TPU"
375 num_local_devices = num_tpus
376 else:
377 local_device_type = "CPU"
378 num_local_devices = 1
379 local_devices = tuple(
380 f"{worker_device}/device:{local_device_type}:{i}"
381 for i in range(num_local_devices))
382 return local_devices, local_device_type
384 def _initialize_local(self, cluster_resolver, devices=None):
385 """Initializes the object for local training."""
386 self._is_chief = True
387 self._num_workers = 1
389 if ops.executing_eagerly_outside_functions():
390 try:
391 context.context().configure_collective_ops(
392 scoped_allocator_enabled_ops=("CollectiveReduce",))
393 except RuntimeError:
394 logging.warning("Collective ops are not configured at program startup. "
395 "Some performance features may not be enabled.")
396 self._collective_ops_configured = True
398 if devices:
399 local_devices = devices
400 if "GPU" in devices[0]:
401 local_device_type = "GPU"
402 elif "TPU" in devices[0]:
403 local_device_type = "TPU"
404 else:
405 local_device_type = "CPU"
406 else:
407 local_devices, local_device_type = self._initialize_local_devices(
408 cluster_resolver, worker_device="")
410 self._worker_device = device_util.canonicalize("/device:CPU:0")
411 self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
413 self._collective_keys = cross_device_utils.CollectiveKeys(
414 group_key_start=1 + self._collective_key_base)
415 self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
416 devices=local_devices,
417 group_size=len(local_devices),
418 options=self._communication_options,
419 collective_keys=self._collective_keys)
420 # CrossDeviceOps for per host tensors.
421 self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
422 devices=[self._worker_device],
423 group_size=self._num_workers,
424 options=self._communication_options,
425 collective_keys=self._collective_keys)
426 super(CollectiveAllReduceExtended, self)._initialize_single_worker(
427 local_devices)
429 self._cluster_spec = None
430 self._task_type = None
431 self._task_id = None
432 self._id_in_cluster = 0
434 # This is a marker that tells whether we are running with a standalone client
435 # or an independent worker. Right now with a standalone client, the strategy
436 # object is created as a local strategy and then turned into a multi-worker
437 # strategy via the configure call.
438 self._local_or_standalone_client_mode = True
440 # Save the num_devices_per_worker and rpc_layer for configure method.
441 self._num_devices_per_worker = len(local_devices)
442 self._local_device_type = local_device_type
443 self._rpc_layer = cluster_resolver.rpc_layer
444 self._warn_nccl_no_gpu()
446 logging.info(
447 "Single-worker MultiWorkerMirroredStrategy with local_devices "
448 "= %r, communication = %s", local_devices,
449 self._communication_options.implementation)
451 def _initialize_multi_worker(self, cluster_resolver):
452 """Initializes the object for multi-worker training."""
453 cluster_spec = multi_worker_util.normalize_cluster_spec(
454 cluster_resolver.cluster_spec())
455 task_type = cluster_resolver.task_type
456 task_id = cluster_resolver.task_id
457 if task_type is None or task_id is None:
458 raise ValueError("When `cluster_spec` is given, you must also specify "
459 "`task_type` and `task_id`.")
460 self._cluster_spec = cluster_spec
461 self._task_type = task_type
462 self._task_id = task_id
463 self._id_in_cluster = multi_worker_util.id_in_cluster(
464 self._cluster_spec, self._task_type, self._task_id)
466 self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
467 if not self._num_workers:
468 raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
469 "in `cluster_spec`.")
471 self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
472 task_id)
474 self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
475 self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
477 if (ops.executing_eagerly_outside_functions() and
478 not getattr(self, "_local_or_standalone_client_mode", False)):
479 context.context().configure_collective_ops(
480 collective_leader=multi_worker_util.collective_leader(
481 cluster_spec, task_type, task_id),
482 scoped_allocator_enabled_ops=("CollectiveReduce",),
483 device_filters=("/job:%s/task:%d" % (task_type, task_id),))
484 self._collective_ops_configured = True
485 if context.context().coordination_service is None:
486 coordinated_jobs = ["chief", "worker"]
487 if task_type in coordinated_jobs:
488 coordinated_job_config = []
489 for job in coordinated_jobs:
490 if job in cluster_spec.jobs:
491 coordinated_job_config.append(
492 coordination_config_pb2.CoordinatedJob(
493 name=job,
494 num_tasks=cluster_spec.num_tasks(job)))
495 context.context().configure_coordination_service(
496 service_type="standalone",
497 service_leader=multi_worker_util.coordination_leader(
498 cluster_spec),
499 coordinated_jobs=coordinated_job_config)
501 # Starting a std server in eager mode and in independent worker mode.
502 if (context.executing_eagerly() and
503 not getattr(self, "_std_server_started", False) and
504 not getattr(self, "_local_or_standalone_client_mode", False)):
505 # Checking _local_or_standalone_client_mode as well because we should not
506 # create the std server in standalone client mode.
507 config_proto = copy.deepcopy(context.context().config)
508 config_proto = self._update_config_proto(config_proto)
510 # If coordination service is enabled, use its internal heartbeat to detect
511 # peer failures instead of the Python-level health check.
512 if config_proto.experimental.coordination_config.service_type:
513 self._enable_check_health = False
515 if hasattr(cluster_resolver, "port"):
516 port = cluster_resolver.port
517 else:
518 port = 0
519 server_def = tensorflow_server_pb2.ServerDef(
520 cluster=cluster_spec.as_cluster_def(),
521 default_session_config=config_proto,
522 job_name=task_type,
523 task_index=task_id,
524 protocol=cluster_resolver.rpc_layer or "grpc",
525 port=port)
526 context.context().enable_collective_ops(server_def)
527 self._std_server_started = True
528 # The `ensure_initialized` is needed before calling
529 # `context.context().devices()`.
530 context.context().ensure_initialized()
531 logging.info(
532 "Enabled multi-worker collective ops with available devices: %r",
533 context.context().devices())
535 # TODO(yuefengz): The `num_gpus` is only for this particular task. It
536 # assumes all workers have the same number of GPUs. We should remove this
537 # assumption by querying all tasks for their numbers of GPUs.
538 # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
539 # some cases.
540 local_devices, local_device_type = self._initialize_local_devices(
541 cluster_resolver, self._worker_device)
542 if local_device_type == "TPU":
543 tpu_strategy_util.initialize_tpu_system()
545 self._collective_keys = cross_device_utils.CollectiveKeys(
546 group_key_start=1 + self._collective_key_base)
547 self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
548 devices=local_devices,
549 group_size=len(local_devices) * self._num_workers,
550 options=self._communication_options,
551 collective_keys=self._collective_keys)
552 # CrossDeviceOps for per host tensors.
553 self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
554 devices=[self._worker_device],
555 group_size=self._num_workers,
556 options=self._communication_options,
557 collective_keys=self._collective_keys)
558 super(CollectiveAllReduceExtended, self)._initialize_single_worker(
559 local_devices)
561 # Add a default device so that ops without specified devices will not end up
562 # on other workers.
563 self._default_device = "/job:%s/task:%d" % (task_type, task_id)
565 # Save the num_devices_per_worker and rpc_layer for configure method.
566 self._num_devices_per_worker = len(local_devices)
567 self._local_device_type = local_device_type
568 self._rpc_layer = cluster_resolver.rpc_layer
569 self._warn_nccl_no_gpu()
571 if self._enable_check_health and context.executing_eagerly():
572 self._start_check_health_thread()
573 else:
574 logging.info("Check health not enabled.")
576 logging.info(
577 "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
578 "task_id = %r, num_workers = %r, local_devices = %r, "
579 "communication = %s", cluster_spec.as_dict(), task_type, task_id,
580 self._num_workers, local_devices,
581 self._communication_options.implementation)
583 def __del__(self):
584 self._stop_check_health_thread()
586 def _input_workers_with_options(self, options=None):
587 host_device = device_util.get_host_for_device(self._worker_device)
588 if not options or options.experimental_fetch_to_device:
589 return input_lib.InputWorkers([(host_device, self.worker_devices)])
590 else:
591 return input_lib.InputWorkers([(
592 host_device,
593 [device_util.get_host_for_device(worker) for worker in
594 self.worker_devices])])
596 @property
597 def _input_workers(self):
598 return self._input_workers_with_options()
600 def _get_variable_creator_initial_value(self,
601 replica_id,
602 device,
603 primary_var,
604 **kwargs):
605 if replica_id == 0: # First replica on each worker.
606 assert device is not None
607 assert primary_var is None
609 def initial_value_fn(): # pylint: disable=g-missing-docstring
610 # Only the first device participates in the broadcast of initial values.
611 group_key = self._collective_keys.get_group_key([device])
612 group_size = self._num_workers
613 collective_instance_key = (
614 self._collective_keys.get_instance_key(group_key, device))
616 with ops.device(device):
617 initial_value = kwargs["initial_value"]
618 if callable(initial_value):
619 initial_value = initial_value()
620 if isinstance(initial_value, base.CheckpointInitialValue):
621 initial_value = initial_value.wrapped_value
622 assert not callable(initial_value)
623 initial_value = ops.convert_to_tensor(
624 initial_value, dtype=kwargs.get("dtype", None))
626 if self._num_workers > 1:
627 if self._is_chief:
628 bcast_send = collective_ops.broadcast_send(
629 initial_value, initial_value.shape, initial_value.dtype,
630 group_size, group_key, collective_instance_key)
631 with ops.control_dependencies([bcast_send]):
632 return array_ops.identity(initial_value)
633 else:
634 return collective_ops.broadcast_recv(initial_value.shape,
635 initial_value.dtype,
636 group_size, group_key,
637 collective_instance_key)
638 return initial_value
640 return initial_value_fn
641 else:
642 return super(CollectiveAllReduceExtended,
643 self)._get_variable_creator_initial_value(
644 replica_id=replica_id,
645 device=device,
646 primary_var=primary_var,
647 **kwargs)
649 def _make_input_context(self):
650 input_context = distribute_lib.InputContext(
651 num_input_pipelines=self._num_workers,
652 input_pipeline_id=self._id_in_cluster,
653 num_replicas_in_sync=self._num_replicas_in_sync)
654 return input_context
656 def _experimental_distribute_dataset(self, dataset, options):
657 if (options and options.experimental_replication_mode ==
658 distribute_lib.InputReplicationMode.PER_REPLICA):
659 raise NotImplementedError(
660 "InputReplicationMode.PER_REPLICA "
661 "is only supported in "
662 "`distribute_datasets_from_function` "
663 "of tf.distribute.MirroredStrategy"
664 )
665 input_context = self._make_input_context()
666 return input_util.get_distributed_dataset(
667 dataset,
668 self._input_workers_with_options(options),
669 self._container_strategy(),
670 num_replicas_in_sync=self._num_replicas_in_sync,
671 input_context=input_context,
672 options=options)
674 def _distribute_datasets_from_function(self, dataset_fn, options):
675 if (options and options.experimental_replication_mode ==
676 distribute_lib.InputReplicationMode.PER_REPLICA):
677 raise NotImplementedError(
678 "InputReplicationMode.PER_REPLICA "
679 "is only supported in "
680 "`distribute_datasets_from_function` "
681 "of tf.distribute.MirroredStrategy")
682 input_context = self._make_input_context()
683 return input_util.get_distributed_datasets_from_function(
684 dataset_fn=dataset_fn,
685 input_workers=self._input_workers_with_options(options),
686 input_contexts=[input_context],
687 strategy=self._container_strategy(),
688 options=options)
690 def _experimental_distribute_values_from_function(self, value_fn):
691 per_replica_values = []
692 num_local_replicas = len(self.worker_devices)
693 for local_replica_id in range(num_local_replicas):
694 replica_id = (self._id_in_cluster * num_local_replicas +
695 local_replica_id)
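# For example, with 2 workers and 2 local replicas each, local replica 0 on
# worker 1 gets global replica id 1 * 2 + 0 = 2.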
696 value_context = distribute_lib.ValueContext(
697 replica_id, self._num_replicas_in_sync)
698 per_replica_values.append(value_fn(value_context))
699 return distribute_utils.regroup(per_replica_values, always_wrap=True)
701 def _make_dataset_iterator(self, dataset):
702 """Distributes the dataset to each local GPU."""
703 input_context = self._make_input_context()
704 return input_lib_v1.DatasetIterator(
705 dataset,
706 self._input_workers,
707 self._container_strategy(),
708 num_replicas_in_sync=self._num_replicas_in_sync,
709 input_context=input_context)
711 def _make_input_fn_iterator(
712 self,
713 input_fn,
714 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
715 """Distributes the input function to each local GPU."""
716 input_context = self._make_input_context()
717 return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
718 [input_context],
719 self._container_strategy())
721 def _configure(self,
722 session_config=None,
723 cluster_spec=None,
724 task_type=None,
725 task_id=None):
726 """Configures the object.
728 Args:
729 session_config: a `tf.compat.v1.ConfigProto`
730 cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
731 cluster configurations.
732 task_type: the current task type, such as "worker".
733 task_id: the current task id.
735 Raises:
736 ValueError: if `task_type` is not in the `cluster_spec`.
737 """
738 if cluster_spec:
739 cluster_resolver = SimpleClusterResolver(
740 cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
741 task_type=task_type,
742 task_id=task_id,
743 num_accelerators={
744 self._local_device_type: self._num_devices_per_worker},
745 rpc_layer=self._rpc_layer)
746 self._initialize_multi_worker(cluster_resolver)
747 assert isinstance(self._cross_device_ops,
748 cross_device_ops_lib.CollectiveAllReduce)
750 if session_config:
751 session_config.CopyFrom(self._update_config_proto(session_config))
753 def _update_config_proto(self, config_proto):
754 updated_config = copy.deepcopy(config_proto)
755 # Enable the scoped allocator optimization for CollectiveOps. This
756 # optimization converts many small all-reduces into fewer larger
757 # all-reduces.
758 rewrite_options = updated_config.graph_options.rewrite_options
759 rewrite_options.scoped_allocator_optimization = (
760 rewriter_config_pb2.RewriterConfig.ON)
761 # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
762 # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
763 # clear and then append.
764 del rewrite_options.scoped_allocator_opts.enable_op[:]
765 rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
767 if (not ops.executing_eagerly_outside_functions() and
768 self._communication_options.implementation ==
769 collective_util.CommunicationImplementation.NCCL):
770 updated_config.experimental.collective_nccl = True
772 if not self._cluster_spec:
773 return updated_config
775 assert self._task_type
776 assert self._task_id is not None
778 # Collective group leader is needed for collective ops to coordinate
779 # workers.
780 updated_config.experimental.collective_group_leader = (
781 multi_worker_util.collective_leader(self._cluster_spec, self._task_type,
782 self._task_id))
784 # The device filters prevent communication between workers.
785 del updated_config.device_filters[:]
786 updated_config.device_filters.append(
787 "/job:%s/task:%d" % (self._task_type, self._task_id))
789 return updated_config
791 def _get_cross_device_ops(self, value):
792 # CollectiveAllReduce works on a predefined set of devices. In most cases
793 # they should be the compute devices, but certain use cases may reduce host
794 # tensors as well (e.g. early stopping). We infer the cross_device_ops to
795 # use based on the number of devices, since inputs don't always have device
796 # annotations. The compute-devices one is preferred since we can potentially
797 # leverage NCCL.
798 if isinstance(value, values.DistributedValues):
799 num_devices = len(value._values) # pylint: disable=protected-access
800 else:
801 num_devices = 1
802 if num_devices == len(self.worker_devices):
803 return self._cross_device_ops
804 else:
805 return self._host_cross_device_ops
807 def _gather_to_implementation(self, value, destinations, axis, options):
808 return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access
809 value,
810 destinations=destinations,
811 axis=axis,
812 options=options)
814 def _reduce_to(self, reduce_op, value, destinations, options):
815 if (isinstance(value, values.Mirrored) and
816 reduce_op == reduce_util.ReduceOp.MEAN):
817 return value
818 assert not isinstance(value, values.Mirrored)
820 if (isinstance(value, values.DistributedValues) and
821 len(self.worker_devices) == 1):
822 value = value.values[0]
824 # When there are multiple workers, we need to reduce across workers using
825 # collective ops.
826 if (not isinstance(value, values.DistributedValues) and
827 self._num_workers == 1):
828 # This function handles reducing values that are not PerReplica or
829 # Mirrored values. For example, the same value could be present on all
830 # replicas, in which case `value` would be a single value, or `value` could
831 # be 0.
832 return cross_device_ops_lib.reduce_non_distributed_value(
833 reduce_op, value, destinations, len(self.worker_devices))
834 return self._get_cross_device_ops(value).reduce(
835 reduce_op,
836 value,
837 destinations=destinations,
838 options=self._communication_options.merge(options))
840 def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
841 """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
842 # This implementation avoids using `merge_call` and just launches collective
843 # ops in one replica.
844 if options is None:
845 options = collective_util.Options()
847 if context.executing_eagerly():
848 # In eager mode, falls back to the default implementation that uses
849 # `merge_call`. Replica functions run sequentially in eager mode, and due to
850 # the blocking nature of collective ops, execution will hang if collective
851 # ops are launched sequentially.
852 return super()._replica_ctx_all_reduce(reduce_op, value, options)
854 replica_context = distribute_lib.get_replica_context()
855 assert replica_context, (
856 "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
857 "replica context")
858 return self._cross_device_ops._all_reduce( # pylint: disable=protected-access
859 reduce_op,
860 value,
861 replica_context._replica_id, # pylint: disable=protected-access
862 options)
864 def _check_health(self):
865 while True:
866 if self._check_health_thread_should_stop.is_set():
867 return
868 for job in self._cluster_spec.jobs:
869 for task_id in range(self._cluster_spec.num_tasks(job)):
870 peer = "/job:{}/replica:0/task:{}".format(job, task_id)
871 attempts = 0
872 while True:
873 attempts += 1
874 try:
875 context.context().check_collective_ops_peer_health(
876 peer, timeout_in_ms=self._check_health_timeout * 1000)
877 # If check_collective_ops_peer_health doesn't raise an Exception,
878 # the peer is healthy.
879 break
880 except (errors.UnavailableError, errors.FailedPreconditionError,
881 errors.DeadlineExceededError) as e:
882 # TODO(b/151232436): Always raise UnavailableError when a peer
883 # fails. Now there could be many kinds of errors:
884 # - Unavailable: when the peer is not reachable, e.g. it's down.
885 # - FailedPrecondition: when the peer has restarted.
886 if attempts < self._check_health_retry_limit:
887 logging.warning("%s seems down, retrying %d/%d", peer, attempts,
888 self._check_health_retry_limit)
889 continue
890 logging.error(
891 "Cluster check alive failed, %s is down, "
892 "aborting collectives: %s", peer, e)
893 context.context().abort_collective_ops(
894 errors.UNAVAILABLE,
895 "cluster check alive failed, {} is down".format(peer))
896 return
897 except Exception as e: # pylint: disable=broad-except
898 logging.error("Unexpected exception in check alive: %s", e)
899 context.context().abort_collective_ops(
900 errors.INTERNAL,
901 "unexecpted exception in check alive: %s" % e)
902 return
903 time.sleep(self._check_health_interval)
905 def _start_check_health_thread(self):
906 # Use a dummy all-reduce as a barrier to wait for all workers to be up,
907 # otherwise the health check may fail immediately.
909 # Use array_ops.identity to create the dummy tensor so that we have a new
910 # Tensor. If we use a constant, it may be a cached one from a /job:localhost
911 # device, which will cause some code that relies on tensor.device to error.
912 #
913 # TODO(b/151232436): change to an explicit barrier if we have it.
914 dummy_value = array_ops.identity([])
915 logging.info("Waiting for the cluster, timeout = %s",
916 self._check_health_initial_timeout or "inf")
917 try:
918 self._host_cross_device_ops.reduce(
919 reduce_util.ReduceOp.SUM,
920 dummy_value,
921 dummy_value,
922 options=collective_util.Options(
923 timeout_seconds=self._check_health_initial_timeout,
924 implementation=collective_util.CommunicationImplementation.RING))
925 if context.is_async():
926 context.async_wait()
927 except errors.DeadlineExceededError:
928 raise RuntimeError(
929 "Timeout waiting for the cluster, timeout is %d seconds" %
930 self._check_health_initial_timeout)
931 logging.info("Cluster is ready.")
932 self._check_health_thread_should_stop = threading.Event()
933 # Start the thread as daemon to avoid it blocking the program from exiting.
934 # We try our best to shut down the thread, but __del__ is not guaranteed to
935 # be called when the program exits.
936 self._check_health_thread = threading.Thread(
937 target=self._check_health,
938 daemon=True)
939 self._check_health_thread.start()
941 def _stop_check_health_thread(self):
942 if getattr(self, "_check_health_thread", None):
943 logging.info("stopping check health thread")
944 self._check_health_thread_should_stop.set()
945 self._check_health_thread.join()
946 self._check_health_thread = None
947 logging.info("check health thread stopped")
949 def _warn_nccl_no_gpu(self):
950 if ((self._communication_options.implementation ==
951 collective_util.CommunicationImplementation.NCCL) and
952 self._local_device_type != "GPU"):
953 logging.warning("Enabled NCCL communication but no GPUs detected/"
954 "specified.")
956 def _in_multi_worker_mode(self):
957 """Whether this strategy indicates working in multi-worker settings."""
958 return self._num_workers > 1
960 @property
961 def experimental_between_graph(self):
962 return True
964 @property
965 def experimental_should_init(self):
966 return True
968 @property
969 def should_checkpoint(self):
970 return self._is_chief
972 @property
973 def should_save_summary(self):
974 return self._is_chief
976 @property
977 def _num_replicas_in_sync(self):
978 return len(self.worker_devices) * self._num_workers
980 # TODO(priyag): Delete this once all strategies use global batch size.
981 @property
982 def _global_batch_size(self):
983 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
985 `make_input_fn_iterator` assumes per-replica batching.
987 Returns:
988 Boolean.
989 """
990 return True
992 def _get_replica_id_in_sync_group(self, replica_id):
993 return self._id_in_cluster * len(self.worker_devices) + replica_id
995 def _get_local_replica_id(self, replica_id_in_sync_group):
996 return (replica_id_in_sync_group -
997 self._id_in_cluster * len(self.worker_devices))
999 def __deepcopy__(self, memo):
1000 # We check for the check-health thread instead of whether we are in eager
1001 # mode to limit backward incompatibility.
1002 if hasattr(self, "_check_health_thread"):
1003 raise ValueError(
1004 "MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
1005 "If you're using Estimator and see this error message, call "
1006 "tf.compat.v1.disable_eager_execution() at the beginning of your "
1007 "program")
1008 # Otherwise, do a regular deepcopy.
1009 cls = self.__class__
1010 result = cls.__new__(cls)
1011 memo[id(self)] = result
1012 for k, v in self.__dict__.items():
1013 setattr(result, k, copy.deepcopy(v, memo))
1014 return result