# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""

import collections
import re

from tensorflow.core.protobuf.tpu import topology_pb2
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

try:
  from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  logging.debug(
      'Falling back to TensorFlow client; we recommend you install the Cloud '
      'TPU client directly with pip install cloud-tpu-client.')
  from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top


def is_running_in_gce():
  return True


class _LocalCloudTpuClient(object):
  """Dummy local Cloud TPU client."""

  def api_available(self):
    return False


_TPU_DEVICE_REGEX = re.compile(
    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
    'DeviceDetails', ['device_map', 'total_cores'])


class TPUClusterResolver(cluster_resolver.ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service.

  TPUClusterResolver supports the following distinct environments:
  Google Compute Engine
  Google Kubernetes Engine
  Google internal

  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
  Cloud TPUs.
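
  For example, a typical TF2 setup looks like the following (a minimal sketch;
  'my-tpu' is a placeholder for the name of your TPU resource, or '' to
  auto-detect the TPU from within a Cloud TPU VM):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)
  ```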
  """

  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes TPU and returns a TPUClusterResolver.

    This API will connect to the remote TPU cluster and initialize the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper of the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of TPUClusterResolver object.

    Raises:
      NotFoundError: If no TPU devices are found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given device names.

    Returns a namedtuple with two attributes:
      device_map: A map of host_ids to a list of core_ids.
      total_cores: The total number of cores within the TPU system.
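
    For illustration, a rough sketch with made-up device names (the
    `FakeDevice` class below is hypothetical and only mimics the `.name`
    attribute of a real device):

    ```python
    class FakeDevice(object):

      def __init__(self, name):
        self.name = name

    devices = [
        FakeDevice('/job:worker/replica:0/task:0/device:TPU:0'),
        FakeDevice('/job:worker/replica:0/task:0/device:TPU:1'),
        FakeDevice('/job:worker/replica:0/task:1/device:TPU:0'),
        FakeDevice('/job:worker/replica:0/task:1/device:TPU:1'),
    ]
    details = TPUClusterResolver._get_device_dict_and_cores(devices)
    # details.device_map == {'0': ['0', '1'], '1': ['0', '1']}
    # details.total_cores == 4
    ```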

    Args:
      devices: A list of devices returned by session.list_devices()
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every device in device_dict has the same # of cores."""
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('The number of TPU cores per host is not the same. '
                         'This should never happen. '
                         'Devices: {}'.format(device_dict))
    return num_cores_per_host_set.pop()

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will use the parameters to query the Cloud TPU APIs for
    the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will
        assume that the TPU is directly connected to the VM instead of over the
        network.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of the discovery
        service. It should have two parameters {api} and {apiVersion} that when
        filled in produce an absolute URL to the discovery document for that
        service. The environment variable 'TPU_API_DISCOVERY_URL' will override
        this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
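
    For example (a sketch; the TPU name, zone and project shown here are
    illustrative placeholders):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='my-tpu',
        zone='us-central1-b',
        project='my-project',
        job_name='worker')
    ```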
    """

    if tpu != 'local':
      # Default Cloud environment
      self._cloud_tpu_client = client.Client(
          tpu=tpu,
          zone=zone,
          project=project,
          credentials=credentials,
          service=service,
          discovery_url=discovery_url)
      self._tpu = self._cloud_tpu_client.name()
    else:
      # Directly connected TPU environment
      self._cloud_tpu_client = _LocalCloudTpuClient()
      self._tpu = 'local'

    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0
    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

    self._tpu_topology = None
    # Default value returned by the `environment` property.
    self._environment = ''

  def __enter__(self):
    self._cloud_tpu_client.enter()

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    the first instance in the ClusterSpec returned by the cluster_spec
    function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that will
    be returned instead (e.g. if the tpu argument's value when constructing
    this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
        communicate with TPUs.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the TPUs specified exists.
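
    For example (a sketch; the TPU name and the returned address are
    illustrative):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
    master = resolver.master()  # e.g. 'grpc://10.240.1.2:8470'
    ```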
    """

    if self._tpu != 'local':
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id are from the function parameters.
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id are from the object.
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # By default we take the first item in the cluster with the right name.
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
      return cluster_resolver.format_master_url(master, 'grpc')
    else:
      return ''

  def get_master(self):
    return self.master()

  def get_job_name(self):
    return self.task_type

  def get_coordination_service_leader(self):
    """Returns the location for coordination service.

    The coordination service should be located on TPU worker0.

    Returns:
      A string indicating the location path.
    """
    return '/job:' + self.get_job_name() + '/task:0'

  def get_tpu_system_metadata(self):
    """Returns the metadata of the TPU system.

    Users can call this method to get some facts about the TPU system, like
    the total number of cores, the number of TPU workers and the devices. E.g.

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```

    Returns:
      A `tf.tpu.experimental.TPUSystemMetadata` object.
    """
    cluster_spec = self.cluster_spec()
    cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self.master(),
            cluster_def=cluster_def,
            query_topology=False))

    return tpu_system_metadata

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs,
      or None.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
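
    For example, on a two-worker Cloud TPU the result might look like the
    following (a sketch; the addresses are illustrative):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
    cluster_spec = resolver.cluster_spec()
    # cluster_spec.as_dict() ==
    #     {'worker': ['10.2.0.2:8470', '10.2.0.3:8470']}
    ```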
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  0. [Local case.] When a TPU is connected directly to the VM.
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #     a. Create a ClusterSpec that includes the coordinator job
    #     b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
    #     tasks and
    #     a. Create a ClusterSpec with the coordinator
    #     b. Create a ClusterSpec without the coordinator
    ############################################################################

    if self._tpu != 'local':
      network_endpoints = self._cloud_tpu_client.network_endpoints()
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in network_endpoints
      ]
      cluster_spec = {self.task_type: worker_list}
      if self._coordinator_address:
        # {1, 2}.a
        cluster_spec[self._coordinator_name] = [self._coordinator_address]
      return server_lib.ClusterSpec(cluster_spec)
    else:
      return server_lib.ClusterSpec({})

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master, lists all the devices present on it, and counts
    them. Also verifies that the device count per host in the cluster is the
    same before returning the number of TPU cores per host.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

    Returns:
      A dict mapping the accelerator type ('TPU') to the number of TPU cores
      per worker.

    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
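
    For example (a sketch; the actual core count depends on the TPU type):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
    resolver.num_accelerators()  # e.g. {'TPU': 8}
    ```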
    """
    if self._tpu == 'local':
      return {
          'TPU':
              len([
                  d for d in framework_config.list_logical_devices()
                  if d.device_type == 'TPU'
              ])
      }

    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            cluster_resolver.get_accelerator_devices(
                self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {
          'TPU':
              TPUClusterResolver._verify_and_return_same_core_count(
                  device_details.device_map)
      }
    return {'TPU': 0}

  def set_tpu_topology(self, serialized_tpu_topology):
    """Sets the TPU topology info stored in this resolver."""
    self._tpu_topology = topology_pb2.TopologyProto()
    self._tpu_topology.ParseFromString(serialized_tpu_topology)

  @property
  def tpu_hardware_feature(self):
    """Returns the `tpu_hardware_feature` from the stored TPU topology."""
    if self._tpu_topology is None:
      return self._tpu_topology
    return self._tpu_topology.tpu_hardware_feature

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in."""
    return self._environment

  def _start_local_server(self):
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (
        address, compat.as_text(self._coordinator_port))

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    return self