
# coding=utf-8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collectives."""

import copy
import enum

from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/170340570): print deprecation warning for CollectiveCommunication.
@tf_export("distribute.experimental.CommunicationImplementation",
           "distribute.experimental.CollectiveCommunication")
class CommunicationImplementation(enum.Enum):
  """Cross device communication implementation.

  Warning: The alias `tf.distribute.experimental.CollectiveCommunication` is
  deprecated and will be removed in a future version. Use
  `tf.distribute.experimental.CommunicationImplementation` instead.

  * `AUTO`: Automatically chosen by TensorFlow.
  * `RING`: TensorFlow's ring algorithms for all-reduce and all-gather.
  * `NCCL`: NVIDIA®'s NCCL library. This is now only used for all-reduce on
    GPUs; all-reduce on CPU, all-gather and broadcast fall back to RING.
  """
  AUTO = "AUTO"
  RING = "RING"
  NCCL = "NCCL"
  # TODO(ayushd): add ncclAllGather implementation.


CollectiveCommunication = CommunicationImplementation
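# The sketch below is illustrative only and not part of this module's API; the
# helper name is hypothetical. It shows the enum described above in use:
# members can be looked up from their string values (the same conversion that
# `Options` below applies when given a plain string), and `AUTO` defers the
# RING-vs-NCCL choice to TensorFlow.
def _communication_implementation_example():
  """Illustrative only: string lookup of `CommunicationImplementation`."""
  nccl = CommunicationImplementation("nccl".upper())
  assert nccl is CommunicationImplementation.NCCL
  # The deprecated alias points at the same enum class.
  assert CollectiveCommunication is CommunicationImplementation
  return CommunicationImplementation.AUTO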

@tf_export("distribute.experimental.CommunicationOptions")
class _OptionsExported(object):
  """Options for cross device communications like All-reduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies and
  are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  ```python
  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120.0,
      implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
  )
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, options=options)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  """

  def __new__(cls, *args, **kwargs):
    # We expose a dummy class so that we can separate internal and public APIs.
    # Note that __init__ won't be called on the returned object if it's a
    # different class [1].
    # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__
    return Options(*args, **kwargs)

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    """Creates a CommunicationOptions.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of certain size. If it's zero, the value is determined
        automatically. This hint is respected by all multi-replica strategies
        except `TPUStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. Zero disables timeout. This can be useful when
        debugging hanging issues. This should only be used for debugging since
        it creates a new thread for each collective, i.e. an overhead of
        `timeout_seconds * num_collectives_per_second` more threads. This only
        works for `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
      implementation: a
        `tf.distribute.experimental.CommunicationImplementation`. This is a
        hint on the preferred communication implementation. Possible values
        include `AUTO`, `RING`, and `NCCL`. NCCL is generally more performant
        for GPU, but doesn't work for CPU. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass
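# A minimal usage sketch, illustrative only; the helper name is hypothetical.
# It demonstrates the dummy-class pattern above: instantiating the exported
# `CommunicationOptions` class returns an internal `Options` object, so
# `__init__` above is never run. In user code the resulting object would
# typically be handed to a strategy (e.g. via the `communication_options`
# argument of `tf.distribute.MultiWorkerMirroredStrategy` in recent TF
# releases) or to `all_reduce` as shown in the docstring example.
def _communication_options_example():
  """Illustrative only: `_OptionsExported` constructs an `Options`."""
  options = _OptionsExported(
      bytes_per_pack=50 * 1024 * 1024,
      timeout_seconds=120.0,
      implementation=CommunicationImplementation.NCCL)
  assert isinstance(options, Options)
  assert not isinstance(options, _OptionsExported)
  return options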

class Options(object):
  """Implementation of OptionsInterface."""

  def __init__(self,
               bytes_per_pack=0,
               timeout_seconds=None,
               implementation=CommunicationImplementation.AUTO):
    if bytes_per_pack < 0:
      raise ValueError(
          f"Argument `bytes_per_pack` must be >= 0. Received {bytes_per_pack}.")
    if isinstance(implementation, str):
      implementation = CommunicationImplementation(implementation.upper())
    if not isinstance(implementation, CommunicationImplementation):
      raise ValueError(
          "Argument `implementation` must be an instance of "
          "`tf.distribute.experimental.CommunicationImplementation`.")
    self.bytes_per_pack = bytes_per_pack
    self.timeout_seconds = timeout_seconds
    self.implementation = implementation

  __init__.__doc__ = _OptionsExported.__init__.__doc__

  def merge(self, options):
    """Merges with another `Options` and returns a new one.

    Values specified in `options` take precedence if they are not the default.

    Args:
      options: a `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A new `tf.distribute.experimental.CommunicationOptions`.
    """
    merged = copy.deepcopy(self)
    if options is None:
      return merged
    if options.bytes_per_pack != 0:
      merged.bytes_per_pack = options.bytes_per_pack
    if options.timeout_seconds is not None:
      merged.timeout_seconds = options.timeout_seconds
    if options.implementation != CommunicationImplementation.AUTO:
      merged.implementation = options.implementation
    return merged

  def __str__(self):
    return (f"Options(bytes_per_pack={self.bytes_per_pack}, "
            f"timeout_seconds={self.timeout_seconds}, "
            f"implementation={self.implementation})")

@tf_export("distribute.experimental.CollectiveHints")
class Hints(object):
  """Hints for collective operations like AllReduce.

  This can be passed to methods like
  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
  operation performance. Note that these are only hints, which may or may not
  change the actual behavior. Some options only apply to certain strategies and
  are ignored by others.

  One common optimization is to break gradients all-reduce into multiple packs
  so that weight updates can overlap with gradient all-reduce.

  Examples:

  - bytes_per_pack

  ```python
  hints = tf.distribute.experimental.CollectiveHints(
      bytes_per_pack=50 * 1024 * 1024)
  grads = tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)
  optimizer.apply_gradients(zip(grads, vars),
      experimental_aggregate_gradients=False)
  ```

  - timeout_seconds

  ```python
  strategy = tf.distribute.MirroredStrategy()
  hints = tf.distribute.experimental.CollectiveHints(
      timeout_seconds=120.0)
  try:
    strategy.reduce("sum", v, axis=None, experimental_hints=hints)
  except tf.errors.DeadlineExceededError:
    do_something()
  ```

  """

  @deprecation.deprecated(
      None, "use distribute.experimental.CommunicationOptions instead")
  def __new__(cls, bytes_per_pack=0, timeout_seconds=None):
    return Options(
        bytes_per_pack=bytes_per_pack, timeout_seconds=timeout_seconds)

  def __init__(self, bytes_per_pack=0, timeout_seconds=None):
    """Creates a CollectiveHints.

    Args:
      bytes_per_pack: a non-negative integer. Breaks collective operations into
        packs of certain size. If it's zero, the value is determined
        automatically. This currently only applies to all-reduce with
        `MultiWorkerMirroredStrategy`.
      timeout_seconds: a float or None, timeout in seconds. If not None, the
        collective raises `tf.errors.DeadlineExceededError` if it takes longer
        than this timeout. This can be useful when debugging hanging issues.
        This should only be used for debugging since it creates a new thread
        for each collective, i.e. an overhead of `timeout_seconds *
        num_collectives_per_second` more threads. This only works for
        `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

    Raises:
      ValueError: When arguments have invalid values.
    """
    pass
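# A final illustrative sketch (hypothetical helper name, not part of the API)
# of the deprecated path above: constructing `Hints` goes through the decorated
# `__new__`, which logs a deprecation warning and returns an `Options`
# instance, so the `__init__` above is never run.
def _hints_example():
  """Illustrative only: `Hints` resolves to `Options`."""
  hints = Hints(bytes_per_pack=50 * 1024 * 1024, timeout_seconds=120.0)
  assert isinstance(hints, Options)
  return hints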