# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
16
17import abc
18
19import collections
20
21import six
22
23from tensorflow.python.client import session
24from tensorflow.python.eager import context
25from tensorflow.python.framework import config
26from tensorflow.python.framework import ops
27from tensorflow.python.training.server_lib import ClusterSpec
28from tensorflow.python.util.tf_export import tf_export
29
30
def format_master_url(master, rpc_layer=None):
  """Returns the master address, prefixed with the RPC layer if one is given.

  For example, `format_master_url('host:2222', 'grpc')` returns
  `'grpc://host:2222'`.
  """
  if rpc_layer:
    return '%s://%s' % (rpc_layer, master)
  else:
    return master


def get_accelerator_devices(master, config_proto):
  """Returns accelerator devices given a master and a configuration."""
  if context.executing_eagerly():
    # In eager mode, query the logical devices visible to the runtime.
    logical_devices = config.list_logical_devices()
    devices = []
    for d in logical_devices:
      if d.device_type == 'CPU' or d.device_type == 'XLA_CPU':  # Filter CPUs
        continue
      devices.append(session._DeviceAttributes(d.name, d.device_type, 0, 0))  # pylint: disable=protected-access
    return devices
  else:
    # In graph mode, list devices by opening a session against the master.
    with ops.Graph().as_default():
      with session.Session(master, config=config_proto) as s:
        devices = s.list_devices()
    return devices


@tf_export('distribute.cluster_resolver.ClusterResolver')
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
  """Abstract class for all implementations of ClusterResolvers.

  This defines the skeleton for all implementations of ClusterResolvers.
  ClusterResolvers are a way for TensorFlow to communicate with various cluster
  management systems (e.g. GCE, AWS, etc.) and give TensorFlow the necessary
  information to set up distributed training.

  By letting TensorFlow communicate with these systems, we will be able to
  automatically discover and resolve IP addresses for various TensorFlow
  workers. This will eventually allow us to automatically recover from
  underlying machine failures and scale TensorFlow worker clusters up and down.

  Note to implementors of `tf.distribute.cluster_resolver.ClusterResolver`
  subclasses: in addition to these abstract methods, when the task_type,
  task_id, and rpc_layer attributes are applicable, you should also implement
  them either as properties with getters and setters, or directly set the
  attributes `self._task_type`, `self._task_id`, or `self._rpc_layer` so the
  base class' getters and setters are used. See
  `tf.distribute.cluster_resolver.SimpleClusterResolver.__init__` for an
  example.
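
  A minimal sketch of a subclass that follows this note (purely illustrative;
  the class name, cluster, and addresses below are made up):

  ```python
  class StaticClusterResolver(tf.distribute.cluster_resolver.ClusterResolver):

    def __init__(self):
      # Setting these attributes lets the base class supply the task_type,
      # task_id, and rpc_layer getters and setters.
      self._task_type = 'worker'
      self._task_id = 0
      self._rpc_layer = 'grpc'

    def cluster_spec(self):
      return tf.train.ClusterSpec({'worker': ['localhost:2222']})

    def master(self, task_type=None, task_id=None, rpc_layer=None):
      return 'grpc://localhost:2222'
  ```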

  In general, multi-client tf.distribute strategies such as
  `tf.distribute.experimental.MultiWorkerMirroredStrategy` require task_type
  and task_id properties to be available in the `ClusterResolver` they are
  using. On the other hand, these concepts are not applicable in single-client
  strategies, such as `tf.distribute.experimental.TPUStrategy`, because the
  program is only expected to be run on one task, so there should not be a need
  to have code branches according to task type and task id.

  - task_type is the name of the server's current named job (e.g. 'worker',
    'ps' in a distributed parameter server training job).
  - task_id is the ordinal index of the server within the task type.
  - rpc_layer is the protocol used by TensorFlow to communicate with other
    TensorFlow servers in a distributed environment.
  """

  @abc.abstractmethod
  def cluster_spec(self):
    """Retrieve the current state of the cluster and return a `tf.train.ClusterSpec`.

    Returns:
      A `tf.train.ClusterSpec` representing the state of the cluster at the
      moment this function is called.

    Implementors of this function must ensure that the ClusterSpec returned is
    up-to-date at the time this function is called. This usually means
    retrieving the information from the underlying cluster management system
    every time this function is invoked and reconstructing a cluster_spec,
    rather than attempting to cache anything.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Retrieves the name or URL of the session master.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster.

    Returns:
      The name or URL of the session master.

    Implementors of this function must ensure that the master returned is
    up-to-date at the time this function is called. This usually means
    retrieving the master every time this function is invoked.
    """
    raise NotImplementedError()

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of accelerator cores per worker.

    This returns the number of accelerator cores (such as GPUs and TPUs)
    available per worker.

    Optionally, callers may specify task_type and task_id to target a specific
    TensorFlow task when querying the number of accelerators. This supports
    heterogeneous environments, where the number of accelerator cores per host
    differs.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the machine we
        want to query.
      task_id: (Optional) The index of the TensorFlow task of the machine we
        want to query.
      config_proto: (Optional) Configuration for starting a new session to
        query how many accelerator cores it has.

    Returns:
      A map of accelerator types to number of cores.
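      For example, a worker with four GPUs would typically be reported as
      `{'GPU': 4}` (illustrative; the actual device types and counts depend on
      the cluster).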
    """
    master = self.master(task_type, task_id)
    # TODO(b/126786766): in eager mode, we should check whether
    # `tf.config.experimental_connect_to_cluster` is called or not.
    devices = get_accelerator_devices(master, config_proto)
    mapping = collections.defaultdict(int)
    for device in devices:
      if task_type is not None and task_id is not None:
        job_path = '/job:%s' % task_type
        task_path = '/task:%s' % task_id
        if job_path not in device.name or task_path not in device.name:
          continue
      mapping[device.device_type] += 1
    return mapping

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in.

    There are two possible return values: "google" (when TensorFlow is running
    in a Google-internal environment) or an empty string (when TensorFlow is
    running elsewhere).

    If you are implementing a ClusterResolver that works in both the Google
    environment and the open-source world (for instance, a TPU ClusterResolver
    or similar), you will have to detect the environment and return the
    appropriate string.

    Otherwise, if you are implementing a ClusterResolver that will only work
    in open-source TensorFlow, you do not need to implement this property.
    """
    return ''

  @property
  def task_type(self):
    """Returns the task type this `ClusterResolver` indicates.

    In a TensorFlow distributed environment, each job may have an applicable
    task type. Valid task types in TensorFlow include
    'chief': a worker that is designated with more responsibility,
    'worker': a regular worker for training/evaluation,
    'ps': a parameter server, or
    'evaluator': an evaluator that evaluates the checkpoints for metrics.

    See [Multi-worker configuration](
    https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#multi-worker_configuration)
    for more information about the 'chief' and 'worker' task types, which are
    the most commonly used.

    Having access to such information is useful when the user needs to run
    specific code according to task types. For example,

    ```python
    cluster_spec = tf.train.ClusterSpec({
        "ps": ["localhost:2222", "localhost:2223"],
        "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
    })

    # SimpleClusterResolver is used here for illustration; other cluster
    # resolvers may be used for other sources of task type/id.
    cluster_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
                                             task_id=1)

    ...

    if cluster_resolver.task_type == 'worker':
      # Perform something that's only applicable on workers. This block
      # will run on this particular instance since we've specified this task
      # to be a worker in the above cluster resolver.
    elif cluster_resolver.task_type == 'ps':
      # Perform something that's only applicable on parameter servers. This
      # block will not run on this particular instance.
    ```

    Returns `None` if such information is not available or is not applicable
    in the current distributed environment, such as training with
    `tf.distribute.experimental.TPUStrategy`.

    For more information, please see
    `tf.distribute.cluster_resolver.ClusterResolver`'s class docstring.
    """
    return getattr(self, '_task_type', None)

  @property
  def task_id(self):
    """Returns the task id this `ClusterResolver` indicates.

    In a TensorFlow distributed environment, each job may have an applicable
    task id, which is the index of the instance within its task type. This is
    useful when the user needs to run specific code according to the task
    index. For example,

    ```python
    cluster_spec = tf.train.ClusterSpec({
        "ps": ["localhost:2222", "localhost:2223"],
        "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
    })

    # SimpleClusterResolver is used here for illustration; other cluster
    # resolvers may be used for other sources of task type/id.
    cluster_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
                                             task_id=0)

    ...

    if cluster_resolver.task_type == 'worker' and cluster_resolver.task_id == 0:
      # Perform something that's only applicable on 'worker' type, id 0. This
      # block will run on this particular instance since we've specified this
      # task to be a 'worker', id 0 in the above cluster resolver.
    else:
      # Perform something that's only applicable on other ids. This block will
      # not run on this particular instance.
    ```

    Returns `None` if such information is not available or is not applicable
    in the current distributed environment, such as training with
    `tf.distribute.cluster_resolver.TPUClusterResolver`.

    For more information, please see
    `tf.distribute.cluster_resolver.ClusterResolver`'s class docstring.
    """
    return getattr(self, '_task_id', None)

  @task_type.setter
  def task_type(self, task_type):
    """Setter of `task_type` property. See `task_type` property doc."""
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    """Setter of `task_id` property. See `task_id` property doc."""
    self._task_id = task_id


@tf_export('distribute.cluster_resolver.SimpleClusterResolver')
class SimpleClusterResolver(ClusterResolver):
  """Simple implementation of ClusterResolver that accepts all attributes.

  Please see the base class for documentation of arguments of its constructor.

  It is useful if you want to specify some or all attributes.

  Usage example with `tf.distribute.Strategy`:

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222"]})

  # On worker 0
  cluster_resolver = SimpleClusterResolver(cluster, task_type="worker",
                                           task_id=0,
                                           num_accelerators={"GPU": 8},
                                           rpc_layer="grpc")
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      cluster_resolver=cluster_resolver)

  # On worker 1
  cluster_resolver = SimpleClusterResolver(cluster, task_type="worker",
                                           task_id=1,
                                           num_accelerators={"GPU": 8},
                                           rpc_layer="grpc")
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      cluster_resolver=cluster_resolver)
  ```
  """

  def __init__(self, cluster_spec, master='', task_type=None, task_id=None,
               environment='', num_accelerators=None,
               rpc_layer=None):
    """Creates a SimpleClusterResolver from a ClusterSpec."""
    super(SimpleClusterResolver, self).__init__()

    self._task_type = task_type
    self._task_id = task_id
    self._environment = environment

    self._num_accelerators = num_accelerators
    self._rpc_layer = rpc_layer

    if not isinstance(cluster_spec, ClusterSpec):
      raise TypeError('cluster_spec must be a `tf.train.ClusterSpec`.')
    self._cluster_spec = cluster_spec

    if not isinstance(master, str):
      raise TypeError('master must be a string.')
    self._master = master

  def cluster_spec(self):
    """Returns the ClusterSpec passed into the constructor."""
    return self._cluster_spec

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol used by distributed TensorFlow.

    Returns:
      The name or URL of the session master.

    If both task_type and task_id are given, this will override the `master`
    string passed into the initialization function.
    """
    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
    else:
      master = self._master

    return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)

  @property
  def task_type(self):
    return self._task_type

  @property
  def task_id(self):
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    return self._environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of accelerator cores per worker.

    The SimpleClusterResolver does not do automatic detection of accelerators,
    so all arguments are unused and we simply return the value provided in the
    constructor.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Unused.

    Returns:
      The `num_accelerators` dictionary passed to the constructor, or an empty
      dictionary if none was provided.
    """
    # Unused
    del task_type, task_id, config_proto
    if self._num_accelerators is None:
      return {}
    return self._num_accelerators

  @property
  def rpc_layer(self):
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer


@tf_export('distribute.cluster_resolver.UnionResolver')
class UnionClusterResolver(ClusterResolver):
  """Performs a union on underlying ClusterResolvers.

  This class performs a union given two or more existing ClusterResolvers. It
  merges the underlying ClusterResolvers, and returns one unified ClusterSpec
  when cluster_spec is called. The details of the merge function are
  documented in the cluster_spec function.

  For additional ClusterResolver properties such as task type, task index,
  rpc layer, environment, etc., we will return the value from the first
  ClusterResolver in the union.

  An example to combine two cluster resolvers:

  ```python
  cluster_0 = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                               "worker1.example.com:2222"]})
  cluster_resolver_0 = SimpleClusterResolver(cluster_0, task_type="worker",
                                             task_id=0,
                                             rpc_layer="grpc")

  cluster_1 = tf.train.ClusterSpec({"ps": ["ps0.example.com:2222",
                                           "ps1.example.com:2222"]})
  cluster_resolver_1 = SimpleClusterResolver(cluster_1, task_type="ps",
                                             task_id=0,
                                             rpc_layer="grpc")

  # Its task type would be "worker".
  cluster_resolver = UnionClusterResolver(cluster_resolver_0,
                                          cluster_resolver_1)
  ```

  An example to override the number of GPUs in a TFConfigClusterResolver
  instance:

  ```python
  tf_config = TFConfigClusterResolver()
  gpu_override = SimpleClusterResolver(tf_config.cluster_spec(),
                                       num_accelerators={"GPU": 1})
  cluster_resolver = UnionClusterResolver(gpu_override, tf_config)
  ```
  """

  def __init__(self, *args, **kwargs):
    """Initializes a UnionClusterResolver with other ClusterResolvers.

    Args:
      *args: `ClusterResolver` objects to be unionized.
      **kwargs:
        rpc_layer - (Optional) Override value for the RPC layer used by
          TensorFlow.
        task_type - (Optional) Override value for the current task type.
        task_id - (Optional) Override value for the current task index.

    Raises:
      TypeError: If any argument is not an instance of `ClusterResolver`.
      ValueError: If no arguments are passed.
    """
    super(UnionClusterResolver, self).__init__()

    self._rpc_layer = kwargs.pop('rpc_layer', None)
    self._task_type = kwargs.pop('task_type', None)
    self._task_id = kwargs.pop('task_id', None)

    if kwargs:
      raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))

    if not args:
      raise ValueError('At least one ClusterResolver is required.')

    for cluster_resolver in args:
      if not isinstance(cluster_resolver, ClusterResolver):
        raise TypeError('All arguments must be an instance of '
                        '`ClusterResolver`.')
    self._cluster_resolvers = args

  def cluster_spec(self):
    """Returns a union of all the ClusterSpecs from the ClusterResolvers.

    Returns:
      A ClusterSpec containing host information merged from all the underlying
      ClusterResolvers.

    Raises:
      KeyError: If there are conflicting keys detected when merging two or
        more dictionaries, this exception is raised.

    Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
    same job name, we will merge the list/dict of workers.

    If *all* underlying ClusterSpecs expose the set of workers as lists, we will
    concatenate the lists of workers, starting with the list of workers from
    the first ClusterResolver passed into the constructor.

    If *any* of the ClusterSpecs expose the set of workers as a dict, we will
    treat all the sets of workers as dicts (even if they are returned as lists)
    and will only merge them into a dict if there are no conflicting keys. If
    there is a conflicting key, we will raise a `KeyError`.
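
    A minimal sketch of the list-merge case (illustrative; the hosts below are
    made up):

    ```python
    resolver_a = SimpleClusterResolver(
        tf.train.ClusterSpec({"worker": ["host0:2222"]}))
    resolver_b = SimpleClusterResolver(
        tf.train.ClusterSpec({"worker": ["host1:2222"]}))
    union = UnionClusterResolver(resolver_a, resolver_b)
    union.cluster_spec().as_dict()
    # {'worker': ['host0:2222', 'host1:2222']}
    ```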
    """

    merged_cluster = {}

    # We figure out whether it is all lists for a particular job, or whether
    # there are dicts inside.
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if job_name in merged_cluster:
          # If we see a dict, then we write a dict out regardless.
          if isinstance(tasks, dict):
            merged_cluster[job_name] = {}
        else:
          # We take whichever type is present.
          if isinstance(tasks, list):
            merged_cluster[job_name] = []
          else:
            merged_cluster[job_name] = {}

    # We then do the merge as appropriate in merged_cluster[job].
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if isinstance(merged_cluster[job_name], list):
          # We all have lists, we can just concatenate and be done.
          merged_cluster[job_name].extend(tasks)
        else:
          if isinstance(tasks, list):
            # We convert to a dictionary if the type is a list.
            task_dict = dict(zip(range(0, len(tasks)), tasks))
          else:
            # We can simply make a copy (for update) and be done.
            task_dict = tasks.copy()

          # We detect if there are duplicates, and raise an error if so.
          task_keys = set(task_dict)
          merged_keys = set(merged_cluster[job_name].keys())
          intersected_keys = task_keys.intersection(merged_keys)
          if intersected_keys:
            raise KeyError('Duplicate keys detected when merging two '
                           'ClusterSpecs: %s' % repr(intersected_keys))

          # We do the merge after all the processing.
          merged_cluster[job_name].update(task_dict)

    return ClusterSpec(merged_cluster)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    This usually returns the master from the first ClusterResolver passed in,
    but you can override this by specifying the task_type and task_id.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster.

    Returns:
      The name or URL of the session master.
    """
    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      return format_master_url(master, rpc_layer or self._rpc_layer)

    return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)

  @property
  def task_type(self):
    return self._task_type or self._cluster_resolvers[0].task_type

  @property
  def task_id(self):
    return self._task_id or self._cluster_resolvers[0].task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    return self._cluster_resolvers[0].environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    return self._cluster_resolvers[0].num_accelerators(
        task_type, task_id, config_proto)

  @property
  def rpc_layer(self):
    return self._rpc_layer or self._cluster_resolvers[0].rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer