Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/client/client.py: 25%

232 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Cloud TPU Client.""" 

16 

17from concurrent import futures 

18import datetime 

19import json 

20import logging 

21import os 

22import time 

23import urllib 

24 

25from absl import flags 

26 

27_GOOGLE_API_CLIENT_INSTALLED = True 

28try: 

29 from googleapiclient import discovery # pylint: disable=g-import-not-at-top 

30 from oauth2client import client # pylint: disable=g-import-not-at-top 

31except ImportError: 

32 _GOOGLE_API_CLIENT_INSTALLED = False 

33 

FLAGS = flags.FLAGS

# Both flags are read by `Client.recoverable()` to decide whether a recent
# OOM symptom makes the TPU count as unrecoverable.
flags.DEFINE_bool('runtime_oom_exit', True,
                  'Exit the script when the TPU runtime is OOM.')
flags.DEFINE_bool('hbm_oom_exit', True,
                  'Exit the script when the TPU HBM is OOM.')

# Environment variables consulted when resolving the TPU to talk to.
# `_get_tpu_name` checks the GKE variable before the default one;
# `_get_tpu_node_config` parses TPU_CONFIG as JSON.
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_TPUCONFIG_VARIABLE = 'TPU_CONFIG'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
# `_gce_metadata_endpoint` prefers GCE_METADATA_HOST over GCE_METADATA_IP.
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
_GCE_METADATA_ENDPOINT_ENV_VARIABLE = 'GCE_METADATA_HOST'
# Port assumed for an endpoint string that does not carry one.
_DEFAULT_ENDPOINT_PORT = '8470'
# An OOM symptom younger than this many seconds is treated as recent.
_OOM_EVENT_COOL_TIME_SEC = 90
# Formatted with a worker IP address to reach its version switcher.
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'

51 

52 

def _utcnow():
  """Returns the current UTC time as a naive `datetime.datetime`.

  This thin wrapper exists for unit testing: stubbing a module-level function
  is easier than stubbing the `datetime.datetime` type itself.

  Returns:
    A `datetime.datetime` holding the current UTC time with `tzinfo` unset.
  """
  # `datetime.datetime.utcnow()` is deprecated since Python 3.12. Taking an
  # aware UTC timestamp and dropping its tzinfo yields the same naive value,
  # so arithmetic against other naive datetimes keeps working.
  return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

63 

64 

def _environment_discovery_url():
  """Returns the discovery service URL from the environment, or None if unset."""
  return os.getenv(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)

67 

68 

def _gce_metadata_endpoint():
  """Returns the base `http://` URL of the GCE metadata server.

  The host comes from the metadata host environment variable when it is set
  and non-empty; otherwise from the metadata IP environment variable, falling
  back to 'metadata.google.internal'.
  """
  host = os.getenv(_GCE_METADATA_ENDPOINT_ENV_VARIABLE) or os.getenv(
      _GCE_METADATA_URL_ENV_VARIABLE, 'metadata.google.internal')
  return 'http://' + host

76 

77 

def _request_compute_metadata(path):
  """Fetches one value from the GCE metadata server.

  Args:
    path: Path below `computeMetadata/v1/` to request, for example
      'project/project-id'.

  Returns:
    The response body, decoded to text by `_as_text`.
  """
  req = urllib.request.Request(
      '%s/computeMetadata/v1/%s' % (_gce_metadata_endpoint(), path),
      headers={'Metadata-Flavor': 'Google'})
  resp = urllib.request.urlopen(req)
  try:
    return _as_text(resp.read())
  finally:
    # Release the connection promptly instead of waiting for garbage
    # collection of the response object.
    resp.close()

84 

85 

def _environment_var_to_network_endpoints(endpoints):
  """Yields one `{'ipAddress', 'port'}` dict per comma-separated endpoint.

  A leading 'grpc://' is removed from each entry. An entry without a ':port'
  part gets `_DEFAULT_ENDPOINT_PORT`.

  Args:
    endpoints: String of comma-separated endpoints, each optionally prefixed
      with 'grpc://' and optionally followed by ':port'.

  Yields:
    A dict with string values under the keys 'ipAddress' and 'port'.
  """
  grpc_prefix = 'grpc://'
  for raw in endpoints.split(','):
    address = raw.split(grpc_prefix)[1] if raw.startswith(grpc_prefix) else raw
    pieces = address.split(':')
    yield {
        'ipAddress': pieces[0],
        'port': pieces[1] if len(pieces) > 1 else _DEFAULT_ENDPOINT_PORT,
    }

101 

102 

def _get_tpu_node_config():
  """Returns the JSON-decoded TPU config from the environment, or None.

  None is returned when the TPU config environment variable is unset or empty.
  """
  raw_config = os.environ.get(_DEFAULT_TPUCONFIG_VARIABLE)
  return json.loads(raw_config) if raw_config else None

108 

109 

def _get_tpu_name(tpu):
  """Returns `tpu` if truthy, else the first TPU name found in the environment.

  Args:
    tpu: Caller-supplied TPU name; any truthy value is returned unchanged.

  Returns:
    `tpu`, or the value of the GKE endpoints variable, or the value of the
    default TPU name variable (checked in that order), or None if neither
    variable is present.
  """
  if tpu:
    return tpu

  from_env = (
      os.environ.get(var) for var in (_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE))
  # A variable that is present but empty still wins, so test against None.
  return next((value for value in from_env if value is not None), None)

118 

119 

def _as_text(s):
  """Returns `s` decoded as UTF-8 when it is `bytes`; otherwise `s` unchanged."""
  return s.decode('utf-8') if isinstance(s, bytes) else s

124 

125 

class Client:
  """Client for working with the Cloud TPU API.

  This client is intended to be used for resolving tpu name to ip addresses.

  It's recommended to use this library as a context manager to utilize all
  functionality.
  """

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Resolves the TPU name and, for API-backed TPUs, its project and zone.

    Args:
      tpu: TPU name, a 'grpc://' address string, or a one-element list holding
        either. When falsy, the name is taken from the environment via
        `_get_tpu_name`, then from the TPU node config.
      zone: Zone of the TPU. When falsy it is taken from the TPU node config
        (if that was consulted) or from the GCE metadata server.
      project: Project of the TPU. Resolved like `zone`.
      credentials: Credentials handed to the API client. The string 'default'
        selects application default credentials.
      service: Optional API service object; when truthy, `_tpu_service`
        returns it instead of building one.
      discovery_url: Discovery service URL used when the environment does not
        supply one.

    Raises:
      ValueError: If `tpu` is an empty list or no TPU name can be determined.
      NotImplementedError: If `tpu` is a list with more than one element.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    tpu = _get_tpu_name(tpu)

    if tpu is None:
      tpu_node_config = _get_tpu_node_config()
      if tpu_node_config:
        tpu = tpu_node_config.get('tpu_node_name')
        project = project or tpu_node_config.get('project')
        zone = zone or tpu_node_config.get('zone')
      # Also covers a node config without 'tpu_node_name', which would
      # otherwise surface below as an AttributeError on None.
      if tpu is None:
        raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = _as_text(tpu)

    # A 'grpc://' address is used directly; anything else is a TPU name that
    # has to be resolved through the Cloud TPU API.
    self._use_api = not self._tpu.startswith('grpc://')
    self._service = service

    self._credentials = None
    self._project = None
    self._zone = None
    self._discovery_url = None
    if self._use_api:
      if credentials != 'default':
        self._credentials = credentials
      # Automatically detect project and zone if unspecified.
      if project:
        self._project = _as_text(project)
      else:
        self._project = _request_compute_metadata('project/project-id')
      if zone:
        self._zone = _as_text(zone)
      else:
        zone_path = _request_compute_metadata('instance/zone')
        self._zone = zone_path.split('/')[-1]
      self._discovery_url = _environment_discovery_url() or discovery_url

  def _symptom_msg(self, msg):
    """Return the structured Symptom message."""
    return 'Symptom: ' + msg

  @staticmethod
  def _parse_symptom_time(create_time):
    """Parses a symptom `createTime` string into a naive datetime.

    Fractional seconds and a trailing 'Z' are dropped, so both
    '2024-01-02T03:04:05.123Z' and '2024-01-02T03:04:05Z' are accepted.

    Args:
      create_time: Timestamp string of the form 'YYYY-MM-DDTHH:MM:SS',
        optionally followed by '.<fraction>' and/or 'Z'.

    Returns:
      A `datetime.datetime` without tzinfo, truncated to whole seconds.
    """
    whole_seconds = create_time.split('.')[0].rstrip('Z')
    return datetime.datetime.strptime(whole_seconds, '%Y-%m-%dT%H:%M:%S')

  def _recent_symptom_age(self, symptoms, symptom_type):
    """Returns the age of a recent symptom of `symptom_type`, or None.

    Symptoms are scanned from last to first. The first one of the requested
    type that is younger than `_OOM_EVENT_COOL_TIME_SEC` determines the result.

    Args:
      symptoms: Sequence of symptom dicts with 'symptomType' and 'createTime'
        keys, or a falsy value.
      symptom_type: The 'symptomType' value to look for.

    Returns:
      A `datetime.timedelta` with the age of the matching symptom, or None
      when there is no sufficiently recent match.
    """
    if not symptoms:
      return None
    for symptom in reversed(symptoms):
      if symptom['symptomType'] != symptom_type:
        continue
      age = _utcnow() - self._parse_symptom_time(symptom['createTime'])
      if age < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC):
        return age
    return None

  def _oom_event(self, symptoms):
    """Check if a runtime OOM event is reported."""
    age = self._recent_symptom_age(symptoms, 'OUT_OF_MEMORY')
    if age is None:
      return False
    # total_seconds() stays meaningful if the symptom timestamp is slightly
    # in the future; `.seconds` alone would wrap to almost a full day.
    logging.warning(
        self._symptom_msg(
            'a recent runtime OOM has occurred ~{} seconds ago. The model '
            'script will terminate automatically. To prevent future OOM '
            'events, please consider reducing the model size. To disable this '
            'behavior, set flag --runtime_oom_exit=false when starting the '
            'script.'.format(int(age.total_seconds()))))
    return True

  def _hbm_oom_event(self, symptoms):
    """Check if a HBM OOM event is reported."""
    age = self._recent_symptom_age(symptoms, 'HBM_OUT_OF_MEMORY')
    if age is None:
      return False
    logging.warning(
        self._symptom_msg(
            'a recent HBM OOM has occurred ~{} seconds ago. The model '
            'script will terminate automatically. To prevent future HBM OOM '
            'events, please consider reducing the model size. To disable this '
            'behavior, set flag --hbm_oom_exit=false when starting the '
            'script.'.format(int(age.total_seconds()))))
    return True

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API.

    Raises:
      RuntimeError: If the dependent Python packages are missing.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise RuntimeError('Missing runtime dependency on the Google API client. '
                         'Run `pip install cloud-tpu-client` to fix.')

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = client.GoogleCredentials.get_application_default()

    if self._discovery_url:
      return discovery.build(
          'tpu',
          'v1',
          credentials=credentials,
          discoveryServiceUrl=self._discovery_url,
          cache_discovery=False)
    else:
      return discovery.build(
          'tpu', 'v1', credentials=credentials, cache_discovery=False)

  def _full_name(self):
    """Returns the full Cloud name for this TPU."""
    return 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, self._tpu)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call.

    Raises:
      ValueError: If the API request fails for any reason; the original
        exception is attached as the cause.
    """
    service = self._tpu_service()
    try:
      r = service.projects().locations().nodes().get(name=self._full_name())
      return r.execute()
    except Exception as e:
      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
                       'doublecheck the tpu argument in the TPUClusterResolver '
                       'constructor. Exception: %s' % (self._tpu, e)) from e

  def _get_tpu_property(self, key):
    """Returns metadata field `key`, or None when the API is not in use."""
    if not self._use_api:
      return None
    return self._fetch_cloud_tpu_metadata().get(key)

  def __enter__(self):
    self._open = True
    # Returning self makes `with Client(...) as c:` bind the client.
    return self

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    del type, value, traceback

  def recoverable(self):
    """Returns true if the TPU is in a state where training should eventually resume.

    If false the TPU is in a unrecoverable state and should be recreated.
    """
    state = self.state()
    symptoms = self.symptoms()
    if state and state in ['TERMINATED', 'PREEMPTED']:
      return False
    elif FLAGS.runtime_oom_exit and self._oom_event(symptoms):
      return False
    elif FLAGS.hbm_oom_exit and self._hbm_oom_event(symptoms):
      return False
    return True

  def symptoms(self):
    """Return Cloud TPU Symptoms of the TPU."""
    return self._get_tpu_property('symptoms')

  def state(self):
    """Return state of the TPU."""
    return self._get_tpu_property('state')

  def health(self):
    """Return health of the TPU."""
    return self._get_tpu_property('health')

  def runtime_version(self):
    """Return runtime version of the TPU.

    Without the API, the version is requested from the first network
    endpoint's version switcher; an HTTP 404 from it yields None.
    """
    if not self._use_api:
      # Fallback on getting version directly from TPU.
      url = _VERSION_SWITCHER_ENDPOINT.format(
          self.network_endpoints()[0]['ipAddress'])
      try:
        req = urllib.request.Request(url)
        resp = urllib.request.urlopen(req)
        try:
          version_details = json.loads(resp.read())
        finally:
          resp.close()
        return version_details.get('currentVersion')
      except urllib.error.HTTPError as e:
        if e.code == 404:
          return None
        raise
    return self._get_tpu_property('tensorflowVersion')

  def accelerator_type(self):
    """Return accelerator type of the TPU."""
    return self._get_tpu_property('acceleratorType')

  def api_available(self):
    """Return if the Cloud TPU API is available, if not certain features will not work."""
    return self._use_api

  def name(self):
    """Return the name of the tpu, or the ip address if name is not provided."""
    return self._tpu

  def get_local_ip(self):
    """Return the local ip address of the Google Cloud VM the workload is running on."""
    return _request_compute_metadata('instance/network-interfaces/0/ip')

  def network_endpoints(self):
    """Return a list of tpu endpoints.

    Raises:
      RuntimeError: If the API reports a TPU state other than 'READY'.
    """
    if not self._use_api:
      return list(_environment_var_to_network_endpoints(self._tpu))
    response = self._fetch_cloud_tpu_metadata()

    if response.get('state') != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (self._tpu, response.get('state')))
    if 'networkEndpoints' in response:
      return response['networkEndpoints']
    else:
      return [{'ipAddress': response['ipAddress'], 'port': response['port']}]

  def wait_for_healthy(self, timeout_s=1200, interval=30):
    """Wait for TPU to become healthy or raise error if timeout reached.

    Args:
      timeout_s (int): The timeout in seconds for waiting TPU to become healthy.
      interval (int): The interval in seconds to poll the TPU for health.

    Raises:
      RuntimeError: If the TPU doesn't become healthy by the timeout.
    """
    timeout = time.time() + timeout_s
    while self.health() != 'HEALTHY':
      logging.warning(
          ('Waiting for TPU "%s" with state "%s" '
           'and health "%s" to become healthy'),
          self.name(), self.state(), self.health())
      # Give up if the next sleep would end past the deadline.
      if time.time() + interval > timeout:
        raise RuntimeError(
            'Timed out waiting for TPU "%s" to become healthy' % self.name())
      time.sleep(interval)

    logging.warning('TPU "%s" is healthy.', self.name())

  def configure_tpu_version(self, version, restart_type='always'):
    """Configure TPU software version.

    Args:
      version (string): Version of software to configure the TPU with.
      restart_type (string): Restart behaviour when switching versions,
        defaults to always restart. Options are 'always', 'ifNeeded'.

    Raises:
      Exception: If a worker rejects the request; an HTTP 404 is reported as
        the version being unavailable.
    """

    def configure_worker(worker):
      """Configure individual TPU worker.

      Args:
        worker: A dict with the field ipAddress where the configure request will
          be sent.
      """
      ip_address = worker['ipAddress']
      url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
          ip_address, version, restart_type)
      # Passing `data` turns the request into a POST.
      req = urllib.request.Request(url, data=b'')
      try:
        # Only success or failure matters here; close the response right away.
        urllib.request.urlopen(req).close()
      except urllib.error.HTTPError as e:
        if e.code == 404:
          raise Exception(
              'Tensorflow version {} is not available on Cloud TPU, '
              'try a previous nightly version or refer to '
              'https://cloud.google.com/tpu/docs/release-notes for '
              'the latest official version.'.format(version)) from e
        raise Exception(
            'Failed to configure worker {}'.format(ip_address)) from e

    workers = self.network_endpoints()

    with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
      # Draining the iterator re-raises the first exception raised by a
      # worker call, if any; the per-worker return values are all None.
      for _ in executor.map(configure_worker, workers):
        pass

437 if result: 

438 result.result()