Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/multi_worker_util.py: 18%

87 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for multi-worker distribution strategies.""" 

16 

17from tensorflow.core.protobuf import cluster_pb2 

18from tensorflow.python.distribute import distribute_coordinator_context as dc_context 

19from tensorflow.python.training import server_lib 

20 

21 

def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    # Both plain dicts and ClusterDef protos are accepted directly by the
    # ClusterSpec constructor.
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    # NOTE: fixed the unbalanced backquote in the original message
    # ("`cluster_spec'" -> "`cluster_spec`").
    raise ValueError(
        "`cluster_spec` should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec

43 

44 

def task_count(cluster_spec, task_type):
  """Returns the number of `task_type` tasks, or 0 if the job is absent."""
  try:
    count = cluster_spec.num_tasks(task_type)
  except ValueError:
    # `ClusterSpec.num_tasks` raises ValueError for a job name that is not
    # present in the cluster; treat that as zero tasks.
    count = 0
  return count

50 

51 

def _validate_cluster_spec(cluster_spec,
                           task_type,
                           task_id):
  """Validates `cluster_spec`.

  It checks:
  1) task type is one of "chief", "worker", "ps", "evaluator", or not provided
     (None).
  2) whether there is such a task type as `task_type` in the `cluster_spec`.
     The only exception is `evaluator`: it is still a valid configuration when
     `task_type` is `evaluator` but it doesn't appear in `cluster_spec`. This
     is to be compatible with `TF_CONFIG` in Estimator.
  3) whether there is at most one "chief" job.
  4) whether there is at most one "evaluator" job.
  5) whether the `task_id` is smaller than the number of tasks for that
     particular `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)

  cluster_spec = normalize_cluster_spec(cluster_spec)

  unknown_jobs = [
      job for job in cluster_spec.jobs if job not in allowed_task_types
  ]
  if unknown_jobs:
    raise ValueError("Disallowed task type found in cluster spec. Allowed "
                     "types are {} and the cluster spec is {}.".format(
                         allowed_task_types, cluster_spec))

  if task_type not in allowed_task_types:
    raise ValueError(
        "Unrecognized task_type: {}, valid task types are: {}".format(
            task_type, allowed_task_types))

  # "evaluator" is the only task type that may legitimately be missing from
  # the cluster spec (for compatibility with Estimator's `TF_CONFIG`).
  if (task_type and task_type != "evaluator" and
      task_type not in cluster_spec.jobs):
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  if task_count(cluster_spec, "chief") > 1:
    raise ValueError("There must be at most one 'chief' job.")

  if task_count(cluster_spec, "evaluator") > 1:
    raise ValueError("There must be at most one 'evaluator' job.")

  # Only bound-check `task_id` when the job actually appears in the cluster;
  # the `evaluator` job is allowed to be missing in `cluster_spec`.
  if task_type in cluster_spec.jobs and task_id >= task_count(
      cluster_spec, task_type):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))

106 

107 

def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job on
  its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id`
      exceeds the maximum id of the `task_type`.
  """
  if has_worker_context():
    # A worker context already knows the answer; use it directly.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  # The chief itself and the (single, independent) evaluator are both chief.
  if task_type in ("chief", "evaluator"):
    return True

  # When no explicit "chief" job exists, the first worker acts as chief. This
  # is common in CollectiveAllReduceStrategy.
  return ("chief" not in cluster_dict and task_type == "worker" and
          task_id == 0)

146 

147 

def collective_leader(cluster_spec, task_type, task_id):
  """Return the job name for the leader of for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name or empty string if no need to set
    leader job.
  """
  spec = normalize_cluster_spec(cluster_spec)

  # Local training has an empty cluster spec: no collective leader needed.
  if not spec.as_dict():
    return ""

  _validate_cluster_spec(spec, task_type, task_id)

  # There is only one evaluator and it runs independently of the training
  # cluster, so it needs no collective leader either.
  if task_type == "evaluator":
    return ""

  # Prefer the chief when the cluster has one.
  if "chief" in spec.jobs:
    return "/job:chief/replica:0/task:0"

  # Otherwise worker 0 leads.
  assert "worker" in spec.jobs
  return "/job:worker/replica:0/task:0"

180 

181 

def coordination_leader(cluster_spec):
  """Return the task name of the coordination service leader.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.

  Returns:
    a string indicating the task name of the coordination service leader.
  """
  # NOTE: fixed docstring typo "sxpecifying" -> "specifying".
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # No need to set coordination service leader for local training (empty
  # cluster spec).
  if not cluster_spec.as_dict():
    return ""

  # Use PS 0 if parameter servers are in the cluster.
  if "ps" in cluster_spec.jobs:
    return "/job:ps/replica:0/task:0"

  # Use chief if chief is in the cluster.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  # Use worker 0 if no chief job.
  assert "worker" in cluster_spec.jobs
  return "/job:worker/replica:0/task:0"

209 

210 

def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster."""
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ("chief", "worker", "evaluator"):
    raise ValueError("Unexpected `task_type` %r" % task_type)

  # The "evaluator" is in its own cluster or its own partition of a cluster,
  # so "chief" and "worker" tasks are not counted for an "evaluator".
  if task_type == "evaluator":
    return len(cluster_dict["evaluator"])

  # In the non-evaluator case the total is "chief" plus "worker" tasks, since
  # the "chief" is also a worker.
  chief_tasks = cluster_dict.get("chief", [])
  worker_tasks = cluster_dict.get("worker", [])
  return len(chief_tasks) + len(worker_tasks)

230 

231 

def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id ranging from [0, `worker_count(task_type, task_id)`).

  Note: this function assumes that "evaluate" job is in its own cluster or its
  own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  # The "evaluator" is in its own cluster or its own partition of a cluster,
  # so its id is simply its task id.
  if task_type == "evaluator":
    return task_id

  # There is at most one "chief" and it always has id 0.
  if task_type == "chief":
    return 0

  # "worker" ids come after the chief (if any).
  if task_type == "worker":
    return len(cluster_dict.get("chief", [])) + task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)

268 

269 

def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should save checkpoint but not
  necessarily every worker in the cluster should.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as
  there can be other files to save such as summary.

  Returns:
    Whether this particular worker in the cluster should save checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.should_checkpoint

284 

285 

def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should load checkpoint but not
  necessarily every worker in the cluster should.

  Returns:
    Whether this particular worker in the cluster should load checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.experimental_should_init

297 

298 

def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context.wait_for_other_workers()

302 

303 

def has_worker_context():
  """Returns whether a worker context has been entered."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context is not None