# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""

import collections
import re

from tensorflow.core.protobuf.tpu import topology_pb2
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

try:
  from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  logging.debug(
      'Falling back to TensorFlow client; we recommend you install the Cloud '
      'TPU client directly with pip install cloud-tpu-client.')
  from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top


def is_running_in_gce():
  return True


class _LocalCloudTpuClient(object):
  """Dummy local Cloud TPU client."""

  def api_available(self):
    return False


_TPU_DEVICE_REGEX = re.compile(
    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
    'DeviceDetails', ['device_map', 'total_cores'])

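# Illustrative sketch (not part of the library): _TPU_DEVICE_REGEX pulls the
# host and core indices out of a fully qualified TPU device name. The device
# string below is a representative example, not one produced by this module.
#
#   match = _TPU_DEVICE_REGEX.match(
#       '/job:worker/replica:0/task:1/device:TPU:3')
#   match.group('host_id')  # -> '1'
#   match.group('core_id')  # -> '3'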

class TPUClusterResolver(cluster_resolver.ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service.

  TPUClusterResolver supports the following distinct environments:
  Google Compute Engine
  Google Kubernetes Engine
  Google internal

  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
  Cloud TPUs.
  """

  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes the TPU and returns a TPUClusterResolver.

    This API connects to the remote TPU cluster and initializes the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper for the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of TPUClusterResolver.

    Raises:
      NotFoundError: If no TPU devices are found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given device names.

    Returns a namedtuple with two attributes:
      device_map: A map of host_ids to a list of core_ids.
      total_cores: The total number of cores within the TPU system.

    Args:
      devices: A list of devices returned by session.list_devices()
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

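  # Illustrative sketch (not part of the library): given device objects whose
  # `.name` fields look like
  #   '/job:worker/replica:0/task:0/device:TPU:0'
  #   '/job:worker/replica:0/task:0/device:TPU:1'
  #   '/job:worker/replica:0/task:1/device:TPU:0'
  #   '/job:worker/replica:0/task:1/device:TPU:1'
  # _get_device_dict_and_cores returns a DeviceDetails whose device_map is
  # {'0': ['0', '1'], '1': ['0', '1']} and whose total_cores is 4.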

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every device in device_dict has the same # of cores."""
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('The number of TPU cores on each device is not the '
                         'same. This should never happen. Devices: '
                         '{}'.format(device_dict))
    return num_cores_per_host_set.pop()

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU
    APIs for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will
        assume that the TPU is directly connected to the VM instead of over
        the network.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if
        the coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an
        ip:port pair). If set to None, a TF server will be started. If
        coordinator_name is None, a TF server will not be started even if
        coordinator_address is None.
      credentials: GCE credentials. If None, then we use default credentials
        from the oauth2client.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of the
        discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URL to the
        discovery document for that service. The environment variable
        'TPU_API_DISCOVERY_URL' will override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """

    if tpu != 'local':
      # Default Cloud environment
      self._cloud_tpu_client = client.Client(
          tpu=tpu,
          zone=zone,
          project=project,
          credentials=credentials,
          service=service,
          discovery_url=discovery_url)
      self._tpu = self._cloud_tpu_client.name()
    else:
      # Directly connected TPU environment
      self._cloud_tpu_client = _LocalCloudTpuClient()
      self._tpu = 'local'

    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0
    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

    self._tpu_topology = None

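  # Illustrative sketch (not part of the library): typical construction. The
  # TPU name 'my-tpu' and zone 'us-central1-b' are placeholders, not values
  # this module provides.
  #
  #   # Remote Cloud TPU, resolved through the Cloud TPU APIs:
  #   resolver = TPUClusterResolver(tpu='my-tpu', zone='us-central1-b')
  #
  #   # TPU VM with the accelerator attached directly to this host:
  #   resolver = TPUClusterResolver(tpu='local')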

  def __enter__(self):
    self._cloud_tpu_client.enter()

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    the first instance in the ClusterSpec returned by the cluster_spec
    function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that
    will be returned instead (e.g. if the tpu argument's value when
    constructing this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
        communicate with TPUs.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the TPUs specified exists.
    """

    if self._tpu != 'local':
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id are from the function parameters
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id are from the object
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # By default we take the first item in the cluster with the right name
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
      return cluster_resolver.format_master_url(master, 'grpc')
    else:
      return ''

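  # Illustrative sketch (not part of the library), assuming a resolver whose
  # first 'worker' endpoint is 10.240.1.2:8470:
  #
  #   resolver.master()                         # -> 'grpc://10.240.1.2:8470'
  #   TPUClusterResolver(tpu='local').master()  # -> ''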

  def get_master(self):
    return self.master()

  def get_job_name(self):
    return self.task_type

  def get_coordination_service_leader(self):
    """Returns the location of the coordination service.

    The coordination service should be located on TPU worker 0.

    Returns:
      A string indicating the location path.
    """
    return '/job:' + self.get_job_name() + '/task:0'

  def get_tpu_system_metadata(self):
    """Returns the metadata of the TPU system.

    Users can call this method to get some facts about the TPU system, like
    the total number of cores, the number of TPU workers and the devices. E.g.
    ```python

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```

    Returns:
      A `tf.tpu.experimental.TPUSystemMetadata` object.
    """
    cluster_spec = self.cluster_spec()
    cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self.master(),
            cluster_def=cluster_def,
            query_topology=False))

    return tpu_system_metadata

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs,
      or None.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  0. [Local case.] When a TPU is connected directly to the VM.
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #      a. Create a ClusterSpec that includes the coordinator job
    #      b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set
    #     of tasks and
    #      a. Create a ClusterSpec with the coordinator
    #      b. Create a ClusterSpec without the coordinator
    ############################################################################

    if self._tpu != 'local':
      network_endpoints = self._cloud_tpu_client.network_endpoints()
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in network_endpoints
      ]
      cluster_spec = {self.task_type: worker_list}
      if self._coordinator_address:
        # {1, 2}.a
        cluster_spec[self._coordinator_name] = [self._coordinator_address]
      return server_lib.ClusterSpec(cluster_spec)
    else:
      return server_lib.ClusterSpec({})

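  # Illustrative sketch (not part of the library): with two hypothetical
  # worker endpoints and a coordinator registered under the hypothetical name
  # 'coordinator' at 10.0.0.10:2222, the result is equivalent to
  #
  #   server_lib.ClusterSpec({
  #       'worker': ['10.240.1.2:8470', '10.240.1.3:8470'],
  #       'coordinator': ['10.0.0.10:2222'],
  #   })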

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master, lists all the devices present on the master, and
    counts them up. Also verifies that the device count per host in the
    cluster is the same before returning the number of TPU cores per host.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

    Returns:
      A dict mapping the accelerator type ('TPU') to the number of cores per
      host.

    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
    """
    if self._tpu == 'local':
      return {
          'TPU':
              len([
                  d for d in framework_config.list_logical_devices()
                  if d.device_type == 'TPU'
              ])
      }

    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            cluster_resolver.get_accelerator_devices(
                self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count,
                          _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {
          'TPU':
              TPUClusterResolver._verify_and_return_same_core_count(
                  device_details.device_map)
      }
    return {'TPU': 0}

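  # Illustrative sketch (not part of the library), assuming a hypothetical
  # host with eight attached TPU cores:
  #
  #   resolver.num_accelerators()  # -> {'TPU': 8}
  #   # A host with no visible TPU devices would yield {'TPU': 0}.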

  def set_tpu_topology(self, serialized_tpu_topology):
    """Sets the TPU topology info stored in this resolver."""
    self._tpu_topology = topology_pb2.TopologyProto()
    self._tpu_topology.ParseFromString(serialized_tpu_topology)

  @property
  def tpu_hardware_feature(self):
    """Returns the TPU hardware feature of the stored topology."""
    if self._tpu_topology is None:
      return self._tpu_topology
    return self._tpu_topology.tpu_hardware_feature

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in."""
    return self._environment

  def _start_local_server(self):
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (
        address, compat.as_text(self._coordinator_port))

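  # Illustrative sketch (not part of the library), assuming the in-process
  # server bound to port 12345 and a local IP of 10.0.0.5:
  #
  #   target = compat.as_bytes('grpc://localhost:12345')  # representative form
  #   target.split(compat.as_bytes(':'))  # [b'grpc', b'//localhost', b'12345']
  #   # -> self._coordinator_address == '10.0.0.5:12345'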

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    return self