Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cross_device_ops.py: 21%
504 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for different algorithms of reduction and broadcasting."""
17import collections
18import copy
19import multiprocessing.dummy
20import multiprocessing.pool
21import threading
23import numpy as np
24import six
26from tensorflow.python.client import device_lib
27from tensorflow.python.distribute import collective_util
28from tensorflow.python.distribute import cross_device_utils
29from tensorflow.python.distribute import device_util
30from tensorflow.python.distribute import distribute_utils
31from tensorflow.python.distribute import ps_values
32from tensorflow.python.distribute import reduce_util
33from tensorflow.python.distribute import tpu_values
34from tensorflow.python.distribute import values as value_lib
35from tensorflow.python.distribute import values_util
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.framework import indexed_slices
39from tensorflow.python.framework import kernels
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import tensor_util
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import resource_variable_ops
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.python.util import nest
47from tensorflow.python.util.tf_export import tf_export
48from tensorflow.tools.docs import doc_controls
51def check_destinations(destinations):
52 """Checks whether `destinations` is not empty.
54 Args:
55 destinations: a `DistributedValues`, variable, or string object.
57 Returns:
58 Boolean which is True if `destinations` is not empty.
59 """
60 # Calling bool() on a ResourceVariable is not allowed.
61 if isinstance(destinations,
62 (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
63 return bool(destinations.device)
64 return bool(destinations)
67def validate_destinations(destinations):
68 """Validates the `destination` is one of expected types."""
69 if not isinstance(
70 destinations,
71 (value_lib.DistributedValues, ops.Tensor, indexed_slices.IndexedSlices,
72 ps_values.AggregatingVariable, six.string_types,
73 tpu_values.TPUMirroredVariable
74 )) and not resource_variable_ops.is_resource_variable(destinations):
75 raise ValueError("destinations must be one of a `DistributedValues` object,"
76 " a tf.Variable object, or a device string.")
78 if not check_destinations(destinations):
79 raise ValueError("destinations can not be empty")
82def reduce_non_distributed_value(reduce_op,
83 value,
84 destinations,
85 num_replicas_in_graph,
86 canonicalize_devices=True):
87 """Reduce a non-DistributedValue `value` to `destinations`."""
88 if isinstance(value, value_lib.DistributedValues):
89 raise ValueError("You are passing a `DistributedValues` to "
90 "`reduce_non_distributed_value`, which is not allowed.")
92 # If the same value is present on all replicas then the PerReplica value will
93 # be a single value. We also handle the case when `value` is a single value
94 # and equal to 0.
95 # TODO(b/138823479): handle the tensor value properly.
96 if not tensor_util.is_tf_type(value) and np.all(value == 0):
97 return np.zeros(value.shape, dtype=value.dtype)
98 # If there is only a single value and the reduce op is MEAN,
99 # that value should be on all destinations.
100 if reduce_op == reduce_util.ReduceOp.MEAN:
101 return value
102 elif num_replicas_in_graph != 1:
103 # We do not support a reduce op of SUM if the value is the same across
104 # all replicas. We call this as part of assign functions for
105 # MirroredVariables and summing up identical values across replicas is not
106 # clearly defined.
107 raise ValueError("A non-DistributedValues value %s cannot be reduced with "
108 "the given reduce op %s." % (value, reduce_op))
109 else:
110 validate_destinations(destinations)
111 return simple_broadcast(
112 value, destinations, canonicalize_devices=canonicalize_devices)
115def _make_tensor_into_per_replica(input_tensor):
116 """Converts a single tensor into a PerReplica object."""
117 if isinstance(input_tensor, value_lib.DistributedValues):
118 return input_tensor
120 # If input is not a Tensor, convert it to a Tensor first.
121 if not tensor_util.is_tensor(input_tensor):
122 input_tensor = ops.convert_to_tensor(input_tensor)
124 if hasattr(input_tensor, "device"):
125 return value_lib.PerReplica((input_tensor,))
127 raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
128 "because it doesn't have device set.")
131def _normalize_value_destination_pairs(value_destination_pairs):
132 """Converts each tensor into a PerReplica object in the input list."""
133 result = []
135 value_destination_pairs = list(value_destination_pairs)
137 if not isinstance(value_destination_pairs, (list, tuple)):
138 raise ValueError("`value_destination_pairs` should be a list or tuple")
139 for pair in value_destination_pairs:
140 if not isinstance(pair, tuple):
141 raise ValueError(
142 "Each element of `value_destination_pairs` should be a tuple.")
143 if len(pair) != 2:
144 raise ValueError("Each element of `value_destination_pairs` should be a "
145 "tuple of size 2.")
147 per_replica = _make_tensor_into_per_replica(pair[0])
148 result.append((per_replica, pair[1]))
149 return result
152def _validate_value_destination_pairs(value_destination_pairs):
153 """Validates value_destination_pairs are valid."""
154 # TODO(yuefengz): raise exceptions instead of returning False.
155 if not value_destination_pairs: return False
156 if not isinstance(value_destination_pairs, (list, tuple)): return False
157 if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
158 return False
159 if not all(isinstance(v[0], value_lib.PerReplica)
160 for v in value_destination_pairs):
161 return False
162 return True
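# Editor's note (illustrative, not part of the original source): a valid
# `value_destination_pairs` argument looks like
#   [(PerReplica((t_gpu0, t_gpu1)), "/gpu:0"),
#    (PerReplica((u_gpu0, u_gpu1)), "/gpu:1")]
# i.e. the first element of every pair must already be a `PerReplica`; plain
# tensors are converted by `_normalize_value_destination_pairs` above.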
165# TODO(yuefengz): consider calling this function in the caller of
166# CrossDeviceOps.
167def get_devices_from(destinations, canonicalize_devices=True):
168 if isinstance(destinations, value_lib.DistributedValues):
169 return destinations._devices # pylint: disable=protected-access
170 if canonicalize_devices:
171 if isinstance(destinations, six.string_types):
172 return (device_util.resolve(destinations),)
173 return (device_util.resolve(destinations.device),)
175 # Let placer canonicalize and resolve destination devices.
176 if isinstance(destinations, six.string_types):
177 return (device_util.canonicalize_without_job_and_task(destinations),)
178 return (device_util.canonicalize_without_job_and_task(destinations.device),)
181def _devices_match(left, right, canonicalize_devices=True):
182 return left is right or set(get_devices_from(
183 left, canonicalize_devices)) == set(
184 get_devices_from(right, canonicalize_devices))
187def _all_devices_match(value_destination_pairs, canonicalize_devices=True):
188 if not all(
189 _devices_match(v, d, canonicalize_devices)
190 for v, d in value_destination_pairs):
191 return False
192 if not all(
193 _devices_match(v, value_destination_pairs[0][0], canonicalize_devices)
194 for v, _ in value_destination_pairs[1:]):
195 return False
196 return True
199def simple_broadcast(value,
200 destinations,
201 always_mirrored=False,
202 canonicalize_devices=True):
203 """Broadcast `value` to `destinations` using simple copies."""
204 devices = get_devices_from(destinations, canonicalize_devices)
205 if len(devices) == 1 and not always_mirrored:
206 return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
207 value, devices[0])
208 else:
209 value_updates = []
210 for d in devices:
211 value_updates.append(
212 cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
213 return distribute_utils.regroup(value_updates,
214 wrap_class=value_lib.Mirrored)
217def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
218 reduce_op):
219 """Reduces the value by accumulation_fn and reduce_op."""
220 all_values = per_replica_value.values
221 if not all_values:
222 raise ValueError("`per_replica_value` must be non-empty")
223 count = len(all_values)
225 with ops.device(reduce_to_device):
226 with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
227 reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
228 all_values, accumulation_fn)
229 if reduce_op == reduce_util.ReduceOp.MEAN:
230 reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
231 reduced, count)
232 elif reduce_op != reduce_util.ReduceOp.SUM:
233 raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
234 return reduced
237def _simple_gather(per_replica_value, reduce_to_device, axis):
238 """Concatenate all values in the DistributedValues input and return."""
239 all_values = per_replica_value.values
240 if not all_values:
241 raise ValueError("`per_replica_value` must be non-empty")
243 with ops.device(reduce_to_device):
244 with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
245 gathered = array_ops.concat(all_values, axis)
246 return gathered
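# --- Editor's sketch (hypothetical helper, not part of the original source) ---
# For dense tensors, the two helpers above boil down to "copy everything to one
# device, then add_n / concat". A minimal standalone illustration using only
# modules already imported in this file; it is not called anywhere.
def _example_reduce_and_gather_on_one_device(values, reduce_to_device="/cpu:0"):
  """Sketch: sum/mean-reduce and concat a list of `tf.Tensor`s on one device."""
  with ops.device(reduce_to_device):
    summed = math_ops.add_n(values)               # what ReduceOp.SUM amounts to
    mean = summed / len(values)                   # what ReduceOp.MEAN amounts to
    gathered = array_ops.concat(values, axis=0)   # what `_simple_gather` does
  return summed, mean, gathered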
249@tf_export("distribute.CrossDeviceOps")
250class CrossDeviceOps(object):
251 """Base class for cross-device reduction and broadcasting algorithms.
253 The main purpose of this class is to be passed to
254 `tf.distribute.MirroredStrategy` in order to choose among different cross
255 device communication implementations. Prefer using the methods of
256 `tf.distribute.Strategy` instead of the ones of this class.
258 Implementations:
259 * `tf.distribute.ReductionToOneDevice`
260 * `tf.distribute.NcclAllReduce`
261 * `tf.distribute.HierarchicalCopyAllReduce`
262 """
264 def __init__(self):
265 self._canonicalize_devices = True
266 pass
268 @property
269 def _num_between_graph_workers(self):
270 # Returns 1 by default; the value may be overridden by subclasses.
271 return 1
273 def reduce(self, reduce_op, per_replica_value, destinations, options=None):
274 """Reduce `per_replica_value` to `destinations`.
276 See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
277 the cross-replica context.
279 Args:
280 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
281 combined.
282 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
283 like object.
284 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
285 `tf.Tensor` alike object, or a device string. It specifies the devices
286 to reduce to. To perform an all-reduce, pass the same to `value` and
287 `destinations`. Note that if it's a `tf.Variable`, the value is reduced
288 to the devices of that variable, and this method doesn't update the
289 variable.
290 options: a `tf.distribute.experimental.CommunicationOptions`. See
291 `tf.distribute.experimental.CommunicationOptions` for details.
293 Returns:
294 A `tf.Tensor` or `tf.distribute.DistributedValues`.
296 Raises:
297 ValueError: if per_replica_value can't be converted to a
298 `tf.distribute.DistributedValues` or if destinations is not a string,
299 `tf.Variable` or `tf.distribute.DistributedValues`.
300 """
301 if options is None:
302 options = collective_util.Options()
304 per_replica_value = _make_tensor_into_per_replica(per_replica_value)
306 validate_destinations(destinations)
308 # Shortcut if `per_replica_value` only contains one value.
309 if self._num_between_graph_workers == 1 and len(
310 per_replica_value.values) == 1 and _devices_match(
311 per_replica_value, destinations, self._canonicalize_devices):
312 with ops.device(per_replica_value.values[0].device):
313 v = array_ops.identity(per_replica_value.values[0])
314 return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
316 if options is None:
317 options = collective_util.Options()
318 return self.reduce_implementation(reduce_op, per_replica_value,
319 destinations, options)
321 def _gather(self, per_replica_value, destinations, axis, options=None):
322 """Gather `per_replica_value` to `destinations`.
324 Args:
325 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
326 like object.
327 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
328 `tf.Tensor` alike object, or a device string. It specifies the devices
329 to gather to. To perform an all-gather, pass the same to `value` and
330 `destinations`. Note that if it's a `tf.Variable`, the value is gathered
331 to the devices of that variable, and this method doesn't update the
332 variable.
333 axis: specifies the dimension to gather along within each replica's
334 tensor.
335 options: a `tf.distribute.experimental.CommunicationOptions`. See
336 `tf.distribute.experimental.CommunicationOptions` for details.
338 Returns:
339 A `tf.Tensor` or `tf.distribute.DistributedValues`
341 Raises:
342 ValueError: if per_replica_value can't be converted to a
343 `tf.distribute.DistributedValues` or if destinations is not a string,
344 `tf.Variable` or `tf.distribute.DistributedValues`.
345 """
346 if isinstance(per_replica_value, indexed_slices.IndexedSlices):
347 raise NotImplementedError("gather/all_gather does not support "
348 "IndexedSlices")
349 if options is None:
350 options = collective_util.Options()
352 per_replica_value = _make_tensor_into_per_replica(per_replica_value)
354 validate_destinations(destinations)
356 # Shortcut if `per_replica_value` only contains one value.
357 if self._num_between_graph_workers == 1 and len(
358 per_replica_value.values) == 1 and _devices_match(
359 per_replica_value, destinations, self._canonicalize_devices):
360 with ops.device(per_replica_value.values[0].device):
361 v = array_ops.identity(per_replica_value.values[0])
362 return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
364 return self._gather_implementation(per_replica_value, destinations, axis,
365 options)
367 def _gather_implementation(self, per_replica_value, destinations, axis,
368 options):
369 """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.
371 Overriding this method is useful for subclass implementers.
373 Args:
374 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
375 like object.
376 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
377 `tf.Tensor` alike object, or a device string. It specifies the devices
378 to gather to. To perform an all-gather, pass the same to `value` and
379 `destinations`. Note that if it's a `tf.Variable`, the value is gathered
380 to the devices of that variable; this method doesn't update the
381 variable.
382 axis: specifies the dimension to gather along within each replica's
383 tensor.
384 options: a `tf.distribute.experimental.CommunicationOptions`. See
385 `tf.distribute.experimental.CommunicationOptions` for details.
387 Returns:
388 A `tf.Tensor` or `tf.distribute.DistributedValues`.
390 Raises:
391 ValueError: if per_replica_value can't be converted to a
392 `tf.distribute.DistributedValues` or if destinations is not a string,
393 `tf.Variable` or `tf.distribute.DistributedValues`.
394 """
395 raise NotImplementedError(
396 "_gather method must be implemented in descendants.")
398 def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
399 """Reduce values to destinations in batches.
401 See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
402 called in the cross-replica context.
404 Args:
405 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
406 combined.
407 value_destination_pairs: a sequence of (value, destinations) pairs. See
408 `tf.distribute.CrossDeviceOps.reduce` for descriptions.
409 options: a `tf.distribute.experimental.CommunicationOptions`. See
410 `tf.distribute.experimental.CommunicationOptions` for details.
412 Returns:
413 A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
414 in `value_destination_pairs`.
416 Raises:
417 ValueError: if `value_destination_pairs` is not an iterable of
418 tuples of `tf.distribute.DistributedValues` and destinations.
419 """
420 if options is None:
421 options = collective_util.Options()
422 # TODO(yuefengz): if destinations are different, split into several
423 # `_batch_reduce` invocations.
424 if not _validate_value_destination_pairs(value_destination_pairs):
425 # If the first element of each pair is a tensor, we try to turn it into a
426 # PerReplica object.
427 value_destination_pairs = _normalize_value_destination_pairs(
428 value_destination_pairs)
430 for _, d in value_destination_pairs:
431 validate_destinations(d)
433 # Shortcut if all PerReplica objects contain only one value.
434 if self._num_between_graph_workers == 1 and _all_devices_match(
435 value_destination_pairs, self._canonicalize_devices) and len(
436 value_destination_pairs[0][0].values) == 1:
437 return [
438 distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
439 for v, _ in value_destination_pairs
440 ]
442 if options is None:
443 options = collective_util.Options()
444 return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
445 options)
447 def broadcast(self, tensor, destinations):
448 """Broadcast `tensor` to `destinations`.
450 This can only be called in the cross-replica context.
452 Args:
453 tensor: a `tf.Tensor` like object. The value to broadcast.
454 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
455 `tf.Tensor` alike object, or a device string. It specifies the devices
456 to broadcast to. Note that if it's a `tf.Variable`, the value is
457 broadcasted to the devices of that variable; this method doesn't update
458 the variable.
460 Returns:
461 A `tf.Tensor` or `tf.distribute.DistributedValues`.
462 """
463 validate_destinations(destinations)
464 return self.broadcast_implementation(tensor, destinations)
466 @doc_controls.for_subclass_implementers
467 def reduce_implementation(self, reduce_op, per_replica_value, destinations,
468 options):
469 """Implementation of `reduce`.
471 Overriding this method is useful for subclass implementers.
473 Args:
474 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
475 combined.
476 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
477 like object.
478 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
479 `tf.Tensor` alike object, or a device string. It specifies the devices
480 to reduce to. To perform an all-reduce, pass the same to `value` and
481 `destinations`. Note that if it's a `tf.Variable`, the value is reduced
482 to the devices of that variable; this method doesn't update the
483 variable.
484 options: a `tf.distribute.experimental.CommunicationOptions`. See
485 `tf.distribute.experimental.CommunicationOptions` for details.
487 Returns:
488 A `tf.Tensor` or `tf.distribute.DistributedValues`.
490 Raises:
491 ValueError: if per_replica_value can't be converted to a
492 `tf.distribute.DistributedValues` or if destinations is not a string,
493 `tf.Variable` or `tf.distribute.DistributedValues`.
494 """
495 raise NotImplementedError(
496 "_reduce method must be implemented in descendants.")
498 @doc_controls.for_subclass_implementers
499 def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
500 options):
501 """Implementation of `batch_reduce`.
503 Overriding this method is useful for subclass implementers.
505 Args:
506 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
507 combined.
508 value_destination_pairs: a sequence of (value, destinations) pairs. See
509 `reduce` for descriptions.
510 options: a `tf.distribute.experimental.CommunicationOptions`. See
511 `tf.distribute.experimental.CommunicationOptions` for details.
513 Returns:
514 A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
515 in `value_destination_pairs`.
517 Raises:
518 ValueError: if `value_destination_pairs` is not an iterable of
519 tuples of `tf.distribute.DistributedValues` and destinations.
520 """
521 raise NotImplementedError(
522 "batch_reduce_implementation method must be implemented in descendants."
523 )
525 @doc_controls.for_subclass_implementers
526 def broadcast_implementation(self, tensor, destinations):
527 """Implementation of `broadcast`.
529 Args:
530 tensor: a `tf.Tensor` like object. The value to broadcast.
531 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
532 `tf.Tensor` alike object, or a device string. It specifies the devices
533 to broadcast to. Note that if it's a `tf.Variable`, the value is
534 broadcasted to the devices of that variable; this method doesn't update
535 the variable.
538 Returns:
539 A `tf.Tensor` or `tf.distribute.DistributedValues`.
540 """
541 return simple_broadcast(
542 tensor,
543 destinations,
544 always_mirrored=True,
545 canonicalize_devices=self._canonicalize_devices)
547 # ========================== Collective APIs ================================
548 #
549 # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
550 # cross-replica context, collective APIs are to be called in replica
551 # context.
553 def _all_reduce(self, reduce_op, value, replica_id, options):
554 """All-reduce the `value` across all replicas so that all get the result.
556 `value` can be a nested structure of tensors or `IndexedSlices`. The
557 implementation should generally batch the all-reduces when possible.
558 `options` can be set to hint the batching behavior.
560 This API must be called in a replica context.
562 Args:
563 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
564 be combined.
565 value: Value to be reduced. A tensor or a nested structure of tensors or
566 `IndexedSlices`.
567 replica_id: An integer indicating the id of the replica that this
568 all_reduce is called under. This is the local replica id, which ranges
569 from 0 to len(local_devices) - 1.
570 options: A `tf.distribute.experimental.CommunicationOptions`.
572 Returns:
573 A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
574 the reduced values. The structure is the same as `value`.
575 """
576 raise NotImplementedError("_all_reduce must be implemented in descendants.")
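# --- Editor's sketch (hypothetical usage, not part of the original source) ---
# How a concrete `CrossDeviceOps` is normally plugged in, using only public
# APIs; the device names below are assumptions.
#
#   import tensorflow as tf
#   strategy = tf.distribute.MirroredStrategy(
#       devices=["/gpu:0", "/gpu:1"],
#       cross_device_ops=tf.distribute.NcclAllReduce())
#
#   @tf.function
#   def train_step():
#     per_replica = strategy.run(lambda: tf.constant(1.0))
#     # strategy.reduce eventually dispatches to CrossDeviceOps.reduce above.
#     return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)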
579@tf_export("distribute.ReductionToOneDevice")
580class ReductionToOneDevice(CrossDeviceOps):
581 """A CrossDeviceOps implementation that copies values to one device to reduce.
583 This implementation always copies values to one device to reduce them, then
584 broadcasts the reduced values to the destinations. It doesn't support
585 efficient batching.
587 Here is how you can use `ReductionToOneDevice` in
588 `tf.distribute.MirroredStrategy`:
590 ```
591 strategy = tf.distribute.MirroredStrategy(
592 cross_device_ops=tf.distribute.ReductionToOneDevice())
593 ```
594 """
596 def __init__(self, reduce_to_device=None, accumulation_fn=None):
597 """Initializes with a device to reduce to and a way to accumulate.
599 Args:
600 reduce_to_device: the intermediate device to reduce to. If None, reduce
601 to the first device in `destinations` of the `reduce` method.
602 accumulation_fn: a function that does accumulation. If None,
603 `tf.math.add_n` is used.
604 """
605 self.reduce_to_device = reduce_to_device
606 self.accumulation_fn = accumulation_fn or math_ops.add_n
607 super(ReductionToOneDevice, self).__init__()
609 def reduce_implementation(self, reduce_op, per_replica_value, destinations,
610 options):
611 del options # Unused.
612 if check_destinations(destinations):
613 devices = get_devices_from(destinations, self._canonicalize_devices)
614 else:
615 devices = get_devices_from(per_replica_value, self._canonicalize_devices)
616 reduce_to_device = self.reduce_to_device or devices[0]
617 logging.log_first_n(
618 logging.INFO,
619 "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
620 reduced = _simple_reduce(per_replica_value, reduce_to_device,
621 self.accumulation_fn, reduce_op)
622 return self.broadcast(reduced, destinations)
624 def _gather_implementation(self, per_replica_value, destinations, axis,
625 options):
626 del options # Unused.
627 if check_destinations(destinations):
628 devices = get_devices_from(destinations, self._canonicalize_devices)
629 else:
630 devices = get_devices_from(per_replica_value, self._canonicalize_devices)
631 reduce_to_device = self.reduce_to_device or devices[0]
632 logging.log_first_n(
633 logging.INFO,
634 "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10)
635 gathered = _simple_gather(per_replica_value, reduce_to_device, axis)
636 return self.broadcast(gathered, destinations)
638 def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
639 options):
640 return [
641 self.reduce_implementation(
642 reduce_op, t, destinations=v, options=options)
643 for t, v in value_destination_pairs
644 ]
647def _group_value_by_device(per_replica_values):
648 """Group values into sublists by their devices.
650 This grouping is needed to call the all-reduce library because it expects a
651 list of the following form:
652 [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
653 [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
654 [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
655 ...
656 ]
658 Args:
659 per_replica_values: a list of PerReplica objects.
661 Returns:
662 a list of lists; each sublist contains the components of the PerReplica
663 objects for its corresponding device, each paired with a None.
664 """
665 destinations = per_replica_values[0]._devices # pylint: disable=protected-access
666 grouped = [[] for _ in range(len(destinations))]
667 for per_replica_value in per_replica_values:
668 # pylint: disable=protected-access
669 for i, v in enumerate(per_replica_value.values):
670 assert per_replica_value._devices == destinations
671 grouped[i].append((v, None))
672 return grouped
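# Editor's note (illustrative, not part of the original source): for two
# PerReplica values on two devices, the grouping above turns
#   [PerReplica((a_gpu0, a_gpu1)), PerReplica((b_gpu0, b_gpu1))]
# into
#   [[(a_gpu0, None), (b_gpu0, None)], [(a_gpu1, None), (b_gpu1, None)]]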
675def _ungroup_and_make_mirrored(grouped_reduced,
676 destinations,
677 reduce_op,
678 num_between_graph_workers=1):
679 """Ungroup results from all-reduce and make Mirrored objects.
681 Each all-reduce result will be divided by the number of destinations before
682 Mirrored objects are created if reduce_op is "mean".
684 Args:
685 grouped_reduced: a list of lists, each sublist has components for each
686 device, paired with a None. It is the result from
687 cross_device_utils.aggregate_gradients_using*.
688 destinations: a value to colocate the result with.
689 reduce_op: Indicates how values will be aggregated. Accepted values
690 are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
691 num_between_graph_workers: number of workers in the between-graph
692 replication.
694 Returns:
695 a list of Mirrored objects.
696 """
697 num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
698 index = [[] for _ in range(len(grouped_reduced[0]))]
699 for per_replica_reduced in grouped_reduced:
700 for i, (v, _) in enumerate(per_replica_reduced):
701 if reduce_op == reduce_util.ReduceOp.MEAN:
702 with ops.device(v.device):
703 index[i].append(v / num_replicas)
704 else:
705 index[i].append(v)
706 return [distribute_utils.regroup(
707 v, wrap_class=value_lib.Mirrored) for v in index]
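# Editor's note (illustrative, not part of the original source): with two
# destination devices, one between-graph worker, and reduce_op == MEAN, each
# summed component `v` above becomes `v / 2` before being regrouped into a
# Mirrored value.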
710class _ConcatAndSplitPacker(object):
711 """Concatenate and split tensors for reduction."""
713 def __init__(self, num_packs=1):
714 """Initialize the _ConcatAndSplitPacker object.
716 Args:
717 num_packs: specifies the number of split packs that will be
718 formed.
720 Raises:
721 ValueError: if num_packs is not greater than 0.
722 """
723 if num_packs <= 0:
724 raise ValueError("num_packs must be greater than zero.")
725 self.num_packs = num_packs
727 def pack(self, grouped_grads_and_vars):
728 """Pack tensors."""
729 self.grouped_grads_and_vars = grouped_grads_and_vars
730 self.all_device_shapes = []
731 self.all_device_sizes = []
733 device_grad_packs = []
734 for device_grads_and_vars in grouped_grads_and_vars:
735 with ops.colocate_with(device_grads_and_vars[0][0]):
736 # Flatten all the grads.
737 flat_grads = [
738 array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
739 ]
740 # Remember the original shape of all the grads.
741 device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
742 # Remember the original sizes of all the grads.
743 device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
744 # Concat all the flat grads into a big flat tensor.
745 concat_grads = array_ops.concat(flat_grads, 0)
747 # Split the big tensor into num_splits packs. In cases where the
748 # total size is not divisible by num_splits, the last pack gets
749 # more elements.
750 # TODO(zhengxq): it is also possible to optimize away all the concat
751 # as well.
752 num_splits = self.num_packs
754 # The array_ops.size function will sometimes remove static shapes. So if
755 # all gradient shapes are defined, we use another method to get the
756 # total size.
757 # TODO(yuefengz): move this logic to array_ops.size.
758 if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
759 total_grad_size = sum(
760 [g.shape.num_elements() for g, _ in device_grads_and_vars])
761 else:
762 total_grad_size = array_ops.size(concat_grads)
764 split_size = total_grad_size // num_splits
765 split_size_last = total_grad_size - split_size * (num_splits - 1)
766 split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
767 grad_packs = array_ops.split(concat_grads, split_sizes)
769 # Ready to aggregate the repacked gradients, with fake variables.
770 # TODO(zhengxq): It is hacky to have to use fake variables.
771 # We should remove the need for variables in
772 # aggregate_gradients_using*.
773 device_grad_packs.append(zip(grad_packs, [None] * num_splits))
774 self.all_device_shapes.append(device_shapes)
775 self.all_device_sizes.append(device_sizes)
777 return device_grad_packs
779 def unpack(self, summed_device_grad_packs):
780 """Reverse the pack."""
781 aggregated_device_grads = []
782 for (summed_device_grad_packs,
783 device_grads_and_vars, device_shapes, device_sizes) in zip(
784 summed_device_grad_packs, self.grouped_grads_and_vars,
785 self.all_device_shapes, self.all_device_sizes):
786 # pylint: enable=line-too-long
787 # Reverse the packing operations in the previous steps. Form the
788 # summed gradients back into their original shapes.
789 with ops.colocate_with(summed_device_grad_packs[0][0]):
790 # Form a list of the summed grad packs.
791 device_grad_packs = [g for g, _ in summed_device_grad_packs]
793 # Concat them back into a big flat tensor.
794 device_grads_concat = array_ops.concat(device_grad_packs, 0)
796 # Split the tensors back into their original sizes.
797 grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)
799 # Reshape the tensors back into their original shapes.
800 grads_with_shapes = [
801 array_ops.reshape(grad, shape)
802 for shape, grad in zip(device_shapes, grads_with_sizes)
803 ]
805 # Form the list with the original list of variables.
806 summed_device_grads = [
807 (g, v) for g, (_, v) in zip(grads_with_shapes,
808 device_grads_and_vars)
809 ]
810 aggregated_device_grads.append(summed_device_grads)
811 return aggregated_device_grads
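# --- Editor's sketch (hypothetical helper, not part of the original source) ---
# The pack/unpack round trip above is reshape -> concat -> split and back. The
# split-size arithmetic used in `pack` is easy to check in isolation; the
# helper below is pure Python and is not called anywhere.
def _example_split_sizes(total_size, num_packs):
  """Sketch of the pack sizes: the last pack absorbs the remainder."""
  split_size = total_size // num_packs
  split_size_last = total_size - split_size * (num_packs - 1)
  return [split_size] * (num_packs - 1) + [split_size_last]
# e.g. _example_split_sizes(10, 3) == [3, 3, 4]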
814def _pack_tensors(device_grads, num_packs=0):
815 """Pack tensors if specified."""
816 if num_packs > 0:
817 tensor_packer = _ConcatAndSplitPacker(num_packs)
818 device_grad_packs = tensor_packer.pack(device_grads)
819 else:
820 tensor_packer = None
821 device_grad_packs = device_grads
822 return device_grad_packs, tensor_packer
825def _unpack_tensors(reduced, tensor_packer=None):
826 """Unpack tensors if they are packed before all-reduce."""
827 if tensor_packer:
828 return tensor_packer.unpack(reduced)
829 return reduced
832class AllReduceCrossDeviceOps(CrossDeviceOps):
833 """All-reduce implementation of CrossDeviceOps.
835 It performs all-reduce when applicable using NCCL or hierarchical copy. For
836 the batch API, tensors will be repacked or aggregated for more efficient
837 cross-device transportation.
839 For reduces that are not all-reduce, it falls back to
840 `tf.distribute.ReductionToOneDevice`.
841 """
843 def __init__(self, all_reduce_alg="nccl", num_packs=1):
844 """Initializes the object.
846 Args:
847 all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
848 "hierarchical_copy" are supported.
849 num_packs: a non-negative integer. The number of packs to split values
850 into. If zero, no packing will be done.
851 """
852 self._all_reduce_alg = all_reduce_alg
853 self._num_packs = num_packs
854 self._simple_cross_replica_ops = ReductionToOneDevice()
855 super(AllReduceCrossDeviceOps, self).__init__()
857 def reduce_implementation(self, reduce_op, per_replica_value, destinations,
858 options):
859 del options # Unused.
860 # To use NCCL or hierarchical-copy all-reduce, source and destination devices
861 # should match, and none of the devices should be CPU.
862 if (_devices_match(per_replica_value, destinations) and
863 not any("cpu" in d.lower() for d in get_devices_from(destinations))):
864 return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
865 else:
866 return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
867 destinations)
869 def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
870 options):
871 if _all_devices_match(value_destination_pairs):
872 return self._batch_all_reduce(reduce_op,
873 [v[0] for v in value_destination_pairs])
874 else:
875 return [
876 self.reduce_implementation(reduce_op, value, dest, options)
877 for value, dest in value_destination_pairs
878 ]
880 def _batch_all_reduce(self, reduce_op, per_replica_values):
881 """All-reduce algorithm in a batch."""
882 dense_values, dense_indices, sparse_values, sparse_indices = (
883 cross_device_utils.split_by_sparsity(per_replica_values))
884 if dense_values:
885 dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
886 else:
887 dense_results = []
888 if sparse_values:
889 sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
890 sparse_values)
891 else:
892 sparse_results = []
893 return cross_device_utils.stitch_values(((dense_results, dense_indices),
894 (sparse_results, sparse_indices)))
896 def _do_batch_all_reduce(self, reduce_op, dense_values):
897 """Run batch all-reduces."""
898 logging.log_first_n(
899 logging.INFO,
900 "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
901 (len(dense_values), self._all_reduce_alg, self._num_packs), 10)
903 destinations = dense_values[0]._devices # pylint: disable=protected-access
904 grouped = _group_value_by_device(dense_values)
906 # device_grad_packs:
907 # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
908 device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)
910 # The actual aggregation of the repacked gradients. Note that they are
911 # sharded among different aggregation trees. So it is important to strike
912 # the balance on num_splits.
913 if self._all_reduce_alg == "nccl":
914 # TODO(yuefengz): merge this into the all-reduce library.
915 reduced = cross_device_utils.aggregate_gradients_using_nccl(
916 device_grad_packs)
917 else:
918 # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
919 # order.
920 reduced = (
921 cross_device_utils.aggregate_gradients_using_hierarchical_copy(
922 destinations, device_grad_packs))
924 reduced = _unpack_tensors(reduced, tensor_packer)
925 return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)
927 def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
928 """Run batch all-reduce for sparse values."""
929 logging.log_first_n(
930 logging.WARN,
931 "Efficient allreduce is not supported for %d IndexedSlices" %
932 len(sparse_values), 10)
933 # Use `sparse_values` as destinations to do all-reduces. It is effectively
934 # an allgather under the hood but not an efficient one.
935 return self._simple_cross_replica_ops.batch_reduce(
936 reduce_op, zip(sparse_values, sparse_values))
938 def _gather_implementation(self, per_replica_value, destinations, axis,
939 options):
940 logging.log_first_n(
941 logging.WARN,
942 "gather/all_gather with NCCL or HierarchicalCopy is not supported. "
943 "Falling back to gather on one device and then broadcast. We're working"
944 " on a more efficient implementation.", 3)
945 return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access
946 options)
949# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
950AllReduceCrossTowerOps = AllReduceCrossDeviceOps
953AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
954 "alg shards limit")
957@tf_export("distribute.NcclAllReduce")
958class NcclAllReduce(AllReduceCrossDeviceOps):
959 """NCCL all-reduce implementation of CrossDeviceOps.
961 It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
962 repacked or aggregated for more efficient cross-device transportation.
964 For reduces that are not all-reduce, it falls back to
965 `tf.distribute.ReductionToOneDevice`.
967 Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:
970 ```
971 strategy = tf.distribute.MirroredStrategy(
972 cross_device_ops=tf.distribute.NcclAllReduce())
973 ```
974 """
976 def __init__(self, num_packs=1):
977 """Initializes the object.
979 Args:
980 num_packs: a non-negative integer. The number of packs to split values
981 into. If zero, no packing will be done.
983 Raises:
984 ValueError: if `num_packs` is negative.
985 """
986 if num_packs < 0:
987 raise ValueError(
988 "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
989 num_packs))
990 super(NcclAllReduce, self).__init__(
991 all_reduce_alg="nccl", num_packs=num_packs)
994@tf_export("distribute.HierarchicalCopyAllReduce")
995class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
996 """Hierarchical copy all-reduce implementation of CrossDeviceOps.
998 It reduces to one GPU along edges in some hierarchy and broadcasts back to
999 each GPU along the same path. For the batch API, tensors will be repacked or
1000 aggregated for more efficient cross-device transportation.
1002 This is a reduction created for the Nvidia DGX-1, and it assumes the GPUs are
1003 connected as on a DGX-1 machine. If you have different GPU interconnections,
1004 it is likely to be slower than `tf.distribute.ReductionToOneDevice`.
1006 For reduces that are not all-reduce, it falls back to
1007 `tf.distribute.ReductionToOneDevice`.
1009 Here is how you can use `HierarchicalCopyAllReduce` in
1010 `tf.distribute.MirroredStrategy`:
1012 ```
1013 strategy = tf.distribute.MirroredStrategy(
1014 cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
1015 ```
1016 """
1018 def __init__(self, num_packs=1):
1019 """Initializes the object.
1021 Args:
1022 num_packs: a non-negative integer. The number of packs to split values
1023 into. If zero, no packing will be done.
1025 Raises:
1026 ValueError: if `num_packs` is negative.
1027 """
1028 if num_packs < 0:
1029 raise ValueError(
1030 "HierarchicalCopy requires num_packs >= 0, but {} is specified"
1031 .format(num_packs))
1032 super(HierarchicalCopyAllReduce, self).__init__(
1033 all_reduce_alg="hierarchical_copy",
1034 num_packs=num_packs)
1037# TODO(crccw): remove after migrating all callers.
1038CollectiveCommunication = collective_util.CommunicationImplementation
1039CommunicationImplementation = collective_util.CommunicationImplementation
1042# TODO(yuefengz): support in-graph collective all-reduce.
1043class CollectiveAllReduce(CrossDeviceOps):
1044 """All-reduce cross device ops using collective ops.
1046 In between-graph replicated training, it will still do all-reduces across
1047 all workers and then put results on the right destinations.
1048 """
1050 def __init__(self,
1051 devices,
1052 group_size,
1053 options,
1054 collective_keys=None,
1055 canonicalize_devices=True):
1056 """Initializes the object.
1058 Args:
1059 devices: a list of device strings to run collectives on.
1060 group_size: the global group size. For between-graph replicated training
1061 it's the total number of devices across all workers.
1062 options: a `tf.distribute.experimental.CommunicationOptions`.
1063 collective_keys: an optional CollectiveKey object.
1064 canonicalize_devices: Whether to canonicalize devices for workers or not.
1065 """
1066 if group_size % len(devices) > 0:
1067 raise ValueError("group_size must be divisible by the number of devices.")
1069 self._group_size = group_size
1070 self._options = options
1071 self._collective_keys = (collective_keys or
1072 cross_device_utils.CollectiveKeys())
1073 # This lock guards all collective launches, i.e. calls to
1074 # cross_device_utils.build_collective_*.
1075 #
1076 # In a multi threaded eager program we need to ensure different groups of
1077 # collectives don't interleave each other, otherwise there could be
1078 # deadlocks. E.g. if two user threads both are launching collectives:
1079 # user-thread-0 device0 device1
1080 # user-thread-1 device0 device1
1081 # In eager mode, we use one thread per device to launch collective ops, so
1082 # the above launch sequences end up with the following queues:
1083 # device-0 collective-0 collective-1
1084 # device-1 collective-1 collective-0
1085 # This deadlocks since neither collective is able to finish.
1086 self._lock = threading.Lock()
1088 if canonicalize_devices:
1089 self._devices = tuple(device_util.canonicalize(d) for d in devices)
1090 else:
1091 self._devices = tuple(
1092 device_util.canonicalize_without_job_and_task(d) for d in devices)
1093 group_key = self._collective_keys.get_group_key(self._devices)
1094 self._launchers = []
1095 # Whether to only use NCCL for batched all-reduce when NCCL is requested.
1096 # This is because of the lack of mechanism to order NCCL operations
1097 # deterministically.
1098 self._limited_nccl = False
1099 for device in self._devices:
1100 launcher = cross_device_utils.CollectiveReplicaLauncher(
1101 group_key, group_size, self._collective_keys, device, options)
1102 self._launchers.append(launcher)
1103 if not launcher.can_order_nccl():
1104 self._limited_nccl = True
1106 super(CollectiveAllReduce, self).__init__()
1107 self._canonicalize_devices = canonicalize_devices
1109 @property
1110 def _num_between_graph_workers(self):
1111 # Currently we only support an equal number of devices on each worker.
1112 return self._group_size / len(self._devices)
1114 def _all_reduce(self, reduce_op, value, replica_id, options):
1115 """Implements CrossDeviceOps.all_reduce."""
1116 # TODO(b/122840926): reuse this method in _batch_all_reduce.
1117 flat_values = nest.flatten(value)
1119 # If NCCL launches can't be ordered (self._limited_nccl == True), we only
1120 # use NCCL when batch_size > 1, hoping that there's only one batched
1121 # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
1122 # NCCL launches are always ordered.
1123 if (self._limited_nccl and options.implementation
1124 == collective_util.CommunicationImplementation.NCCL and
1125 len(flat_values) == 1):
1126 options = options.merge(
1127 collective_util.Options(
1128 implementation=collective_util.CommunicationImplementation.RING))
1130 launcher = self._launchers[replica_id]
1131 dense_values, dense_indices, sparse_values, sparse_indices = (
1132 cross_device_utils.split_by_sparsity(flat_values))
1133 dense_results = []
1134 sparse_results = []
1136 if dense_values:
1137 # Reverse the lists so that there's a better chance that values follow
1138 # the order in which they are calculated (e.g. when they're gradients), so
1139 # as to overlap calculation with communication. However, this may not be
1140 # optimal for cases like gradients of complicated non-sequential models.
1141 #
1142 # Note that we reverse the list before packing so that the first pack
1143 # won't be too small, since it's more likely for first few packs to have
1144 # long queuing time due to concurrent intense computation.
1145 #
1146 # TODO(b/147393503): explore solutions for optimal ordering.
1147 dense_values.reverse()
1148 packs = cross_device_utils.group_by_size(dense_values,
1149 options.bytes_per_pack)
1151 if not context.executing_eagerly() and replica_id == 0:
1152 logging.info(
1153 "Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
1154 "group_size = %d, implementation = %s, num_packs = %d",
1155 len(dense_values), len(self._launchers), self._group_size,
1156 options.implementation, len(packs))
1158 dense_results = launcher.batch_all_reduce(packs, options)
1159 if reduce_op == reduce_util.ReduceOp.MEAN:
1160 for i, v in enumerate(dense_results):
1161 with ops.device(self._devices[replica_id]):
1162 dense_results[i] = v / self._group_size
1163 dense_results.reverse()
1165 if sparse_values:
1166 if not context.executing_eagerly() and replica_id == 0:
1167 logging.info(
1168 "Collective all_reduce IndexedSlices: %d all_reduces, num_devices ="
1169 "%d, group_size = %d, implementation = %s", len(sparse_values),
1170 len(self._launchers), self._group_size, options.implementation)
1172 for indexed_slice in sparse_values:
1173 sparse_results.append(
1174 launcher.all_reduce_indexed_slices(indexed_slice, options))
1176 if reduce_op == reduce_util.ReduceOp.MEAN:
1177 for i, v in enumerate(sparse_results):
1178 with ops.device(self._devices[replica_id]):
1179 sparse_results[i] = indexed_slices.IndexedSlices(
1180 values=sparse_results[i].values / self._group_size,
1181 indices=sparse_results[i].indices,
1182 dense_shape=sparse_results[i].dense_shape)
1184 flat_results = cross_device_utils.stitch_values(
1185 ((dense_results, dense_indices), (sparse_results, sparse_indices)))
1186 return nest.pack_sequence_as(value, flat_results)
1188 def _all_reduce_per_replica_values(self, reduce_op, per_replica_values,
1189 options):
1190 """All reduce a list of per_replica_value."""
1191 values_by_device = [[] for _ in self._devices]
1192 num_devices = len(self._devices)
1193 for per_replica in per_replica_values:
1194 for i in range(num_devices):
1195 values_by_device[i].append(per_replica.values[i])
1197 if context.executing_eagerly():
1199 def thread_fn(device_id):
1200 with context.eager_mode():
1201 return self._all_reduce(reduce_op, values_by_device[device_id],
1202 device_id, options)
1204 with self._lock:
1205 pool = multiprocessing.pool.ThreadPool(len(self._devices))
1206 outputs_by_device = pool.map(thread_fn, list(range(num_devices)))
1207 pool.close()
1208 else:
1209 outputs_by_device = []
1210 with self._lock:
1211 for i in range(num_devices):
1212 outputs_by_device.append(
1213 self._all_reduce(reduce_op, values_by_device[i], i, options))
1215 result = []
1216 for values in zip(*outputs_by_device):
1217 result.append(
1218 distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
1219 return result
1221 def reduce_implementation(self, reduce_op, per_replica_value, destinations,
1222 options):
1223 values_util.mark_as_unsaveable()
1224 all_reduced = self._all_reduce_per_replica_values(reduce_op,
1225 [per_replica_value],
1226 options)[0]
1227 devices = get_devices_from(destinations, self._canonicalize_devices)
1229 if _devices_match(per_replica_value, destinations,
1230 self._canonicalize_devices):
1231 return all_reduced
1233 # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
1234 # utility to access component for a particular device.
1235 if not isinstance(all_reduced, value_lib.Mirrored):
1236 all_reduced = value_lib.Mirrored([all_reduced])
1238 # If we got this far, the destination devices do not match the all-reduce
1239 # devices, so we must map from one to the other.
1240 index = []
1241 # We must add these control dependencies, otherwise we can get deadlock.
1242 with ops.control_dependencies(all_reduced.values):
1243 for d in devices:
1244 with ops.device(d):
1245 for v in all_reduced.values:
1246 if v.device == d:
1247 index.append(array_ops.identity(v))
1248 break
1249 else:
1250 # TODO(josh11b): Once we add support for model parallelism, get the
1251 # copy from the corresponding replica instead of the primary.
1252 index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access
1253 return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
1255 def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
1256 options):
1257 values_util.mark_as_unsaveable()
1258 all_devices_match = _all_devices_match(value_destination_pairs,
1259 self._canonicalize_devices)
1260 if all_devices_match:
1261 return self._all_reduce_per_replica_values(
1262 reduce_op, [v[0] for v in value_destination_pairs], options)
1263 else:
1264 if not all_devices_match:
1265 logging.log_first_n(
1266 logging.WARN, "Efficient batch_reduce is not supported if "
1267 "destinations are different.", 10)
1269 return [
1270 self.reduce_implementation(reduce_op, value, dest, options)
1271 for value, dest in value_destination_pairs
1272 ]
1274 def _gather_implementation(self, per_replica_value, destinations, axis,
1275 options):
1276 all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
1277 values_util.mark_as_unsaveable()
1278 devices = get_devices_from(destinations, self._canonicalize_devices)
1280 if _devices_match(per_replica_value, destinations,
1281 self._canonicalize_devices):
1282 return all_gathered
1284 # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
1285 # utility to access component for a particular device.
1286 if not isinstance(all_gathered, value_lib.Mirrored):
1287 all_gathered = value_lib.Mirrored([all_gathered])
1289 # If we got this far, the destination devices do not match the all-gather
1290 # devices, so we must map from one to the other.
1291 index = []
1292 # We must add these control dependencies, otherwise we can get deadlock.
1293 with ops.control_dependencies(all_gathered.values):
1294 for d in devices:
1295 with ops.device(d):
1296 for v in all_gathered.values:
1297 if v.device == d:
1298 index.append(array_ops.identity(v))
1299 break
1300 else:
1301 index.append(array_ops.identity(all_gathered._primary)) # pylint: disable=protected-access
1302 return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
1304 def _batch_all_gather(self, per_replica_values, axis, options):
1305 """all gather multiple per-replica-values."""
1306 batch_size = len(per_replica_values)
1307 # For now, we use NCCL only when batch_size > 1.
1308 # TODO(b/132575814): switch to NCCL for all collectives when implementation
1309 # is NCCL.
1310 if (self._limited_nccl and options.implementation
1311 == collective_util.CommunicationImplementation.NCCL and
1312 batch_size == 1):
1313 options = options.merge(
1314 collective_util.Options(
1315 implementation=collective_util.CommunicationImplementation.RING))
1317 logging.log_first_n(
1318 logging.INFO, "Collective batch_all_gather: %d all-gathers, "
1319 "num_devices = %d, group_size = %d, implementation = %s, " %
1320 (batch_size, len(
1321 self._devices), self._group_size, options.implementation), 10)
1323 def compute_gathered_values():
1324 gathered_values = []
1325 with self._lock, ops.name_scope("allgather"):
1326 for per_replica in per_replica_values:
1327 outputs = []
1328 for i in range(len(self._devices)):
1329 outputs.append(self._launchers[i].all_gather(
1330 per_replica.values[i], axis, options))
1331 gathered_values.append(outputs)
1332 return gathered_values
1334 if context.executing_eagerly():
1335 gathered_values = def_function.function(compute_gathered_values)()
1336 else:
1337 gathered_values = compute_gathered_values()
1339 mirrored = []
1340 for value in gathered_values:
1341 mirrored.append(
1342 distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
1343 return mirrored
1345 def __deepcopy__(self, memo):
1346 # distribute_coordinator deep-copies the strategy object, so
1347 # CollectiveAllReduce needs to support deep copy as well.
1348 collective_keys = copy.deepcopy(self._collective_keys, memo)
1349 return CollectiveAllReduce(self._devices, self._group_size, self._options,
1350 collective_keys, self._canonicalize_devices)
1353def select_cross_device_ops(devices, session_config=None):
1354 """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.
1356 Args:
1357 devices: a list of devices passed to `tf.distribute.Strategy`.
1358 session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
1359 make decision based on all logical devices.
1361 Returns:
1362 A subclass of `CrossDeviceOps`.
1363 """
1364 requested_devices = set(device_util.canonicalize(d) for d in devices)
1365 if ops.executing_eagerly_outside_functions():
1366 logical_gpus = context.context().list_logical_devices(device_type="GPU")
1367 physical_gpus = context.context().list_physical_devices(device_type="GPU")
1368 if len(logical_gpus) != len(physical_gpus):
1369 logging.warning("NCCL is not supported when using virtual GPUs, falling"
1370 "back to reduction to one device")
1371 return ReductionToOneDevice()
1373 machine_devices = context.context().list_logical_devices()
1374 else:
1375 machine_devices = device_lib.list_local_devices(
1376 session_config=session_config)
1377 using_devices = set()
1378 for d in machine_devices:
1379 if device_util.canonicalize(d.name) in requested_devices:
1380 using_devices.add(d.name)
1382 if len(using_devices) != len(requested_devices):
1383 logging.warning(
1384 "Some requested devices in `tf.distribute.Strategy` are not visible "
1385 "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))
1387 if any("gpu" not in d.lower() for d in requested_devices):
1388 logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
1389 "not using nccl allreduce.")
1390 return ReductionToOneDevice()
1392 if kernels.get_registered_kernels_for_op("NcclAllReduce"):
1393 return NcclAllReduce(num_packs=1)
1394 else:
1395 logging.warning("Nccl kernel is not found, not using nccl allreduce.")
1396 return ReductionToOneDevice()
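# Editor's note (illustrative, not part of the original source): in recent TF
# releases, `tf.distribute.MirroredStrategy` typically calls
# `select_cross_device_ops` when no `cross_device_ops` argument is supplied.
# Called directly (the device list is an assumption):
#
#   select_cross_device_ops(["/gpu:0", "/gpu:1"])
#   # -> NcclAllReduce(num_packs=1) if the NCCL kernel is registered,
#   #    otherwise ReductionToOneDevice().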