Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/multi_worker_util.py: 18%
87 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for multi-worker distribution strategies."""
17from tensorflow.core.protobuf import cluster_pb2
18from tensorflow.python.distribute import distribute_coordinator_context as dc_context
19from tensorflow.python.training import server_lib
def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    # Fixed mismatched quoting in the error message (was "`cluster_spec'").
    raise ValueError(
        "`cluster_spec` should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec
def task_count(cluster_spec, task_type):
  """Returns how many tasks of `task_type` exist in `cluster_spec`.

  Args:
    cluster_spec: a `ClusterSpec` object.
    task_type: the task type to count.

  Returns:
    the number of tasks registered under `task_type`, or 0 when
    `cluster_spec` raises `ValueError` for an unknown task type.
  """
  try:
    count = cluster_spec.num_tasks(task_type)
  except ValueError:
    count = 0
  return count
def _validate_cluster_spec(cluster_spec,
                           task_type,
                           task_id):
  """Validates `cluster_spec`.

  Checks, in order:
    1) every job in the cluster is one of "chief", "worker", "ps" or
       "evaluator";
    2) `task_type` itself is one of those types, or None;
    3) `task_type`, when given, exists in the cluster — except "evaluator",
       which may legitimately be absent from `cluster_spec` (kept compatible
       with `TF_CONFIG` in Estimator);
    4) there is at most one "chief" job and at most one "evaluator" job;
    5) `task_id` is a valid index into the `task_type` task list.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)

  cluster_spec = normalize_cluster_spec(cluster_spec)

  for job in cluster_spec.jobs:
    if job not in allowed_task_types:
      raise ValueError("Disallowed task type found in cluster spec. Allowed "
                       "types are {} and the cluster spec is {}.".format(
                           allowed_task_types, cluster_spec))

  if task_type not in allowed_task_types:
    raise ValueError(
        "Unrecognized task_type: {}, valid task types are: {}".format(
            task_type, allowed_task_types))

  if (task_type and task_type != "evaluator" and
      task_type not in cluster_spec.jobs):
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  # "chief" and "evaluator" are singleton jobs.
  for singleton in ("chief", "evaluator"):
    if task_count(cluster_spec, singleton) > 1:
      raise ValueError("There must be at most one '%s' job." % singleton)

  # The `evaluator` job is allowed to be missing in `cluster_spec`, so only
  # range-check `task_id` for job types that are actually present.
  if task_type in cluster_spec.jobs:
    if task_id >= task_count(cluster_spec, task_type):
      raise ValueError(
          "The `task_id` %d exceeds the maximum id of %s." %
          (task_id, task_type))
def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job
  on its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already
  available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id`
      exceeds the maximum id of the `task_type`.
  """
  if has_worker_context():
    # A worker context already knows the answer; defer to it.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type in ("chief", "evaluator"):
    return True

  # When no "chief" job is present, worker 0 acts as chief. This is common in
  # CollectiveAllReduceStrategy.
  return ("chief" not in cluster_spec and task_type == "worker"
          and task_id == 0)
def collective_leader(cluster_spec, task_type, task_id):
  """Returns the job name of the leader for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name, or an empty string if no leader
    job needs to be set.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # Local training has no collective leader.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # There is only one evaluator, so it never needs a collective leader.
  if task_type == "evaluator":
    return ""

  # Prefer the chief; otherwise fall back to worker 0.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  assert "worker" in cluster_spec.jobs
  return "/job:worker/replica:0/task:0"
def coordination_leader(cluster_spec):
  """Returns the task name of the coordination service leader.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.

  Returns:
    a string indicating the task name of the coordination service leader, or
    an empty string for a local (single-process) setup.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # Local setup: no coordination service leader needs to be set.
  if not cluster_spec.as_dict():
    return ""

  # Preference order for the leader: ps 0, then chief, then worker 0.
  jobs = cluster_spec.jobs
  if "ps" in jobs:
    return "/job:ps/replica:0/task:0"
  if "chief" in jobs:
    return "/job:chief/replica:0/task:0"

  assert "worker" in jobs
  return "/job:worker/replica:0/task:0"
def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type of the current task; must be "chief", "worker"
      or "evaluator".

  Returns:
    the number of evaluator tasks when `task_type` is "evaluator", otherwise
    the combined number of "chief" and "worker" tasks.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator", or
      `cluster_spec` is invalid.
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ("chief", "worker", "evaluator"):
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" runs in its own cluster or its own partition of a
    # cluster, so only evaluator tasks count.
    return len(cluster_spec["evaluator"])

  # The "chief" also acts as a worker, so chief and worker tasks are counted
  # together.
  num_chief = len(cluster_spec.get("chief", []))
  num_worker = len(cluster_spec.get("worker", []))
  return num_chief + num_worker
def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id ranging from [0, `worker_count(task_type, task_id)`).

  Note: this function assumes that "evaluate" job is in its own cluster or
  its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be
      validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Throws:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type == "chief":
    # There is at most one chief and it always gets id 0.
    return 0

  if task_type == "worker":
    # Worker ids come right after the (optional) chief.
    return len(cluster_spec.get("chief", [])) + task_id

  if task_type == "evaluator":
    # The evaluator lives in its own cluster or partition, so its id is its
    # own task id.
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)
def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should save checkpoint but not
  necessarily every worker in the cluster should.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as
  there can be other files to save such as summary.

  Returns:
    Whether this particular worker in the cluster should save checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.should_checkpoint
def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should load checkpoint but not
  necessarily every worker in the cluster should.

  Returns:
    Whether this particular worker in the cluster should load checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.experimental_should_init
def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context.wait_for_other_workers()
def has_worker_context():
  """Returns whether a worker context has been entered."""
  current_context = dc_context.get_current_worker_context()
  return current_context is not None