Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/collective_util.py: 51%
49 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

import copy
import enum

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplementation",
           "distribute.experimental.CollectiveCommunication")
class CommunicationImplementation(enum.Enum):
  """Cross device communication implementation.

  Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
  deprecated and will be removed in a future version. Use
  `tf.distribute.experimental.CommunicationImplementation` instead.

  * `AUTO`: Automatically chosen by TensorFlow.
  * `RING`: TensorFlow's ring algorithms for all-reduce and all-gather.
  * `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on
    GPUs; all-reduce on CPU, all-gather, and broadcast fall back to RING.
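
  A minimal sketch of selecting an implementation through
  `tf.distribute.experimental.CommunicationOptions` (this assumes a TensorFlow
  version whose `tf.distribute.MultiWorkerMirroredStrategy` constructor
  accepts `communication_options`):

  ```python
  # Assumes MultiWorkerMirroredStrategy accepts `communication_options`.
  options = tf.distribute.experimental.CommunicationOptions(
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  strategy = tf.distribute.MultiWorkerMirroredStrategy(
      communication_options=options)
  ```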
  """
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"
  # TODO(ayushd): add ncclAllGather implementation.


CollectiveCommunication = CommunicationImplementation


@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
  """Options for cross device communications like All-reduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradient all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120.0,
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, options=options)
  optimizer.apply_gradients(zip(grads, vars),
                            experimental_aggregate_gradients=False)
  ```
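
  The `implementation` argument may also be given as a case-insensitive string
  such as `"nccl"`; it is converted to the
  `tf.distribute.experimental.CommunicationImplementation` enum in
  `Options.__init__` below:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      implementation="nccl")
  ```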
  """

  def __new__(cls, *args, **kwargs):
    # We expose a dummy class so that we can separate internal and public APIs.
    # Note that __init__ won't be called on the returned object if it's a
    # different class [1].
    # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__
    return Options(*args, **kwargs)

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    """Creates a CommunicationOptions.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This hint is respected by all multi-replica strategies
        except `TPUStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. Zero disables the timeout. This can be useful when
        debugging hanging issues, but it should only be used for debugging
        since it creates a new thread for each collective, i.e. an overhead of
        `timeout_seconds * num_collectives_per_second` more threads. This only
        works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
      implementation: a
        `tf.distribute.experimental.CommunicationImplementation`. This is a
        hint on the preferred communication implementation. Possible values
        include `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant
        on GPUs but doesn't work on CPUs. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass


class Options(object):
  """Implementation of OptionsInterface."""

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    if bytes_per_pack < 0:
      raise ValueError(
          f"Argument `bytes_per_pack` must be >=0. Received {bytes_per_pack}.")
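    # Accept the implementation as a case-insensitive string (e.g. "nccl") and
    # normalize it to the CommunicationImplementation enum before validating.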
    if isinstance(implementation, str):
      implementation = CommunicationImplementation(implementation.upper())
    if not isinstance(implementation, CommunicationImplementation):
      raise ValueError(
          "Argument `implementation` must be an instance of "
          "`tf.distribute.experimental.CommunicationImplementation`.")
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  __init__.__doc__ = _OptionsExported.__init__.__doc__

  def merge(self, options):
    """Merges with another `Options` object and returns a new one.

    Values specified in `options` take precedence if they are not the default.
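
    For example, a sketch of the precedence rules implemented below:

    ```python
    base = Options(bytes_per_pack=50 * 1024 * 1024)
    override = Options(timeout_seconds=120.0)
    merged = base.merge(override)
    # merged.bytes_per_pack == 50 * 1024 * 1024 (override left it at default 0)
    # merged.timeout_seconds == 120.0 (non-default value takes precedence)
    # merged.implementation == CommunicationImplementation.AUTO
    ```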

    Args:
      options: a `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A new `tf.distribute.experimental.CommunicationOptions`.
    """
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != CommunicationImplementation.AUTO:
      merged.implementation = options.implementation
    return merged

  def __str__(self):
    return (f"Options(bytes_per_pack={self.bytes_per_pack}, "
            f"timeout_seconds={self.timeout_seconds}, "
            f"implementation={self.implementation})")


@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like AllReduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies
  and are ignored by others.

  One common optimization is to break gradient all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  - bytes_per_pack

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars),
                            experimental_aggregate_gradients=False)
  ```

  - timeout_seconds

  ```python
  strategy = tf.distribute.MirroredStrategy()
  hints = tf.distribute.experimental.CollectiveHints(
      timeout_seconds=120.0)
  try:
    strategy.reduce("sum", v, axis=None, experimental_hints=hints)
  except tf.errors.DeadlineExceededError:
    do_something()
  ```

  """

  @deprecation.deprecated(
      None, "use distribute.experimental.CommunicationOptions instead")
  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
    return Options(
        bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)

  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of a certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. This can be useful when debugging hanging issues,
        but it should only be used for debugging since it creates a new thread
        for each collective, i.e. an overhead of `timeout_seconds *
        num_collectives_per_second` more threads. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass