Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/client/client.py: 25%
232 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Cloud TPU Client."""
from concurrent import futures
import datetime
import json
import logging
import os
import time
import urllib
import urllib.error
import urllib.request

from absl import flags

_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
FLAGS = flags.FLAGS

# Command-line switches read by Client.recoverable(): when set, a recent OOM
# symptom makes the TPU be reported as unrecoverable.
flags.DEFINE_bool(
    'runtime_oom_exit', True, 'Exit the script when the TPU runtime is OOM.')
flags.DEFINE_bool(
    'hbm_oom_exit', True, 'Exit the script when the TPU HBM is OOM.')

# Environment variables consulted while resolving the TPU and its endpoints.
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_TPUCONFIG_VARIABLE = 'TPU_CONFIG'
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
_GCE_METADATA_ENDPOINT_ENV_VARIABLE = 'GCE_METADATA_HOST'

# Endpoint-list parsing and TPU runtime defaults.
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENDPOINT_PORT = '8470'
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'

# Window, in seconds, within which an OOM symptom counts as "recent".
_OOM_EVENT_COOL_TIME_SEC = 90
def _utcnow():
  """Returns the current UTC time as a naive datetime.datetime.

  This function is created for unit testing purpose. It's not easy to do
  StubOutWithMock with datetime.datetime package.

  The result is deliberately naive (tzinfo is None), matching what
  datetime.datetime.utcnow() used to return, so it can be subtracted from the
  naive datetimes parsed out of TPU symptom timestamps.

  Returns:
    datetime.datetime
  """
  # datetime.datetime.utcnow() is deprecated since Python 3.12; this is the
  # documented replacement, with tzinfo dropped to keep the result naive.
  return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
def _environment_discovery_url():
  """Returns the TPU API discovery URL override from the environment.

  Returns:
    The value of the TPU_API_DISCOVERY_URL environment variable, or None when
    it is not set.
  """
  return os.getenv(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
def _gce_metadata_endpoint():
  """Returns the base URL (with http:// scheme) of the GCE metadata server.

  GCE_METADATA_HOST wins when set to a non-empty value; otherwise
  GCE_METADATA_IP is used, falling back to 'metadata.google.internal'.
  """
  host = (os.environ.get(_GCE_METADATA_ENDPOINT_ENV_VARIABLE) or
          os.environ.get(_GCE_METADATA_URL_ENV_VARIABLE,
                         'metadata.google.internal'))
  return 'http://{}'.format(host)
def _request_compute_metadata(path):
  """Fetches a single value from the GCE metadata server.

  Args:
    path: Path relative to `computeMetadata/v1/`, e.g. 'project/project-id'.

  Returns:
    The response body as text.
  """
  req = urllib.request.Request(
      '%s/computeMetadata/v1/%s' % (_gce_metadata_endpoint(), path),
      headers={'Metadata-Flavor': 'Google'})
  # Close the HTTP response deterministically instead of leaking it to GC.
  with urllib.request.urlopen(req) as resp:
    return _as_text(resp.read())
def _environment_var_to_network_endpoints(endpoints):
  """Yields a dict with ip address and port for each listed endpoint.

  Args:
    endpoints: _ENDPOINTS_SEPARATOR-separated endpoint list. Each entry may be
      prefixed with 'grpc://' and may end in ':<port>'.

  Yields:
    {'ipAddress': str, 'port': str} for every non-empty entry; the port is
    _DEFAULT_ENDPOINT_PORT when the entry does not specify one.
  """
  grpc_prefix = 'grpc://'
  for endpoint in endpoints.split(_ENDPOINTS_SEPARATOR):
    # Tolerate whitespace around separators and skip empty entries (e.g. from
    # a trailing separator), which would otherwise yield a blank ipAddress.
    endpoint = endpoint.strip()
    if not endpoint:
      continue
    if endpoint.startswith(grpc_prefix):
      endpoint = endpoint[len(grpc_prefix):]
    parts = endpoint.split(':')
    yield {
        'ipAddress': parts[0],
        'port': parts[1] if len(parts) > 1 else _DEFAULT_ENDPOINT_PORT,
    }
def _get_tpu_node_config():
  """Returns the TPU node config decoded from the TPU_CONFIG env variable.

  Returns:
    The JSON-decoded value of TPU_CONFIG, or None when it is unset or empty.
  """
  raw_config = os.environ.get(_DEFAULT_TPUCONFIG_VARIABLE)
  return json.loads(raw_config) if raw_config else None
def _get_tpu_name(tpu):
  """Resolves the TPU name, falling back to environment variables.

  Args:
    tpu: Explicitly requested TPU name; returned unchanged when truthy.

  Returns:
    `tpu` when truthy; otherwise the value of the first variable present in
    the environment among _GKE_ENV_VARIABLE and _DEFAULT_ENV_VARIABLE (in that
    order); otherwise None.
  """
  if tpu:
    return tpu
  candidates = (_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE)
  return next(
      (os.environ[name] for name in candidates if name in os.environ), None)
def _as_text(s):
  """Decodes `s` as UTF-8 when it is bytes; any other value is returned as-is."""
  return s.decode('utf-8') if isinstance(s, bytes) else s
class Client:
  """Client for working with the Cloud TPU API.

  This client is intended to be used for resolving tpu name to ip addresses.

  It's recommended to use this library as a contextlib to utilize all
  functionality; entering the context returns the client itself, e.g.
  `with Client(tpu='my-tpu') as c: ...`.
  """

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Resolves the TPU name and, when the Cloud API is used, project/zone.

    Args:
      tpu: TPU name or `grpc://` address (a comma-separated list of addresses
        is accepted), or a single-element list holding one. When falsy, the
        name is read from environment variables and then from the TPU_CONFIG
        JSON.
      zone: GCE zone. If unset, taken from TPU_CONFIG (when the TPU name came
        from there) or, for Cloud API TPUs, from the GCE metadata server.
      project: GCP project. Resolved the same way as `zone`.
      credentials: Credentials for the Cloud TPU API, or 'default' to use
        application default credentials.
      service: Pre-built Cloud TPU API object to use instead of building one.
      discovery_url: API discovery URL; the TPU_API_DISCOVERY_URL environment
        variable takes precedence over it.

    Raises:
      ValueError: If `tpu` is an empty list or no TPU name can be determined.
      NotImplementedError: If `tpu` lists more than one TPU.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    tpu = _get_tpu_name(tpu)

    if tpu is None:
      tpu_node_config = _get_tpu_node_config()
      if tpu_node_config:
        tpu = tpu_node_config.get('tpu_node_name')
        project = project or tpu_node_config.get('project')
        zone = zone or tpu_node_config.get('zone')
      # Also rejects a TPU_CONFIG without 'tpu_node_name'; otherwise the
      # `startswith` call below would fail with an AttributeError on None.
      if tpu is None:
        raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = _as_text(tpu)

    # `grpc://` addresses are used directly, without the Cloud TPU API.
    self._use_api = not self._tpu.startswith('grpc://')
    self._service = service

    self._credentials = None
    self._project = None
    self._zone = None
    self._discovery_url = None
    if self._use_api:
      if credentials != 'default':
        self._credentials = credentials
      # Automatically detect project and zone if unspecified.
      if project:
        self._project = _as_text(project)
      else:
        self._project = _request_compute_metadata('project/project-id')
      if zone:
        self._zone = _as_text(zone)
      else:
        zone_path = _request_compute_metadata('instance/zone')
        # The metadata value is a path; its last component is the zone name.
        self._zone = zone_path.split('/')[-1]
      self._discovery_url = _environment_discovery_url() or discovery_url

  def _symptom_msg(self, msg):
    """Return the structured Symptom message."""
    return 'Symptom: ' + msg

  @staticmethod
  def _parse_create_time(create_time):
    """Parses a symptom 'createTime' string into a naive datetime.

    Fractional seconds and a trailing 'Z' are dropped, so both
    '2019-10-01T12:34:56.123Z' and '2019-10-01T12:34:56Z' parse; the result is
    compared against _utcnow(), i.e. treated as UTC.
    """
    return datetime.datetime.strptime(
        create_time.split('.')[0].rstrip('Z'), '%Y-%m-%dT%H:%M:%S')

  def _recent_symptom_age(self, symptoms, symptom_type):
    """Returns the age of a recent symptom of the given type, if any.

    The list is scanned from its end backwards (presumably newest first --
    TODO confirm API ordering); the first `symptom_type` entry created less
    than _OOM_EVENT_COOL_TIME_SEC seconds ago is reported.

    Args:
      symptoms: List of symptom dicts from the TPU metadata, or None.
      symptom_type: The 'symptomType' value to look for.

    Returns:
      The symptom's age as a datetime.timedelta, or None if no recent symptom
      of that type exists.
    """
    for symptom in reversed(symptoms or []):
      if symptom.get('symptomType') != symptom_type:
        continue
      time_diff = _utcnow() - self._parse_create_time(symptom['createTime'])
      if time_diff < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC):
        return time_diff
    return None

  def _oom_event(self, symptoms):
    """Check if a runtime OOM event is reported (and log it if so)."""
    time_diff = self._recent_symptom_age(symptoms, 'OUT_OF_MEMORY')
    if time_diff is None:
      return False
    logging.warning(
        self._symptom_msg(
            'a recent runtime OOM has occurred ~{} seconds ago. The model '
            'script will terminate automatically. To prevent future OOM '
            'events, please consider reducing the model size. To disable this '
            'behavior, set flag --runtime_oom_exit=false when starting the '
            'script.'.format(time_diff.seconds)))
    return True

  def _hbm_oom_event(self, symptoms):
    """Check if a HBM OOM event is reported (and log it if so)."""
    time_diff = self._recent_symptom_age(symptoms, 'HBM_OUT_OF_MEMORY')
    if time_diff is None:
      return False
    logging.warning(
        self._symptom_msg(
            'a recent HBM OOM has occurred ~{} seconds ago. The model '
            'script will terminate automatically. To prevent future HBM OOM '
            'events, please consider reducing the model size. To disable this '
            'behavior, set flag --hbm_oom_exit=false when starting the '
            'script.'.format(time_diff.seconds)))
    return True

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API.

    Raises:
      RuntimeError: If the dependent Python packages are missing.

    Returns:
      A Google Cloud TPU API object (the `service` passed to the constructor,
      if any).
    """
    if self._service:
      return self._service

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise RuntimeError('Missing runtime dependency on the Google API client. '
                         'Run `pip install cloud-tpu-client` to fix.')

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = client.GoogleCredentials.get_application_default()

    if self._discovery_url:
      return discovery.build(
          'tpu',
          'v1',
          credentials=credentials,
          discoveryServiceUrl=self._discovery_url,
          cache_discovery=False)
    return discovery.build(
        'tpu', 'v1', credentials=credentials, cache_discovery=False)

  def _full_name(self):
    """Returns the full Cloud name for this TPU."""
    return 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, self._tpu)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call.

    Raises:
      ValueError: If the lookup fails for any reason; the original error is
        chained as the cause.
    """
    service = self._tpu_service()
    try:
      r = service.projects().locations().nodes().get(name=self._full_name())
      return r.execute()
    except Exception as e:  # pylint: disable=broad-except
      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
                       'doublecheck the tpu argument in the TPUClusterResolver '
                       'constructor. Exception: %s' % (self._tpu, e)) from e

  def _get_tpu_property(self, key):
    """Returns `key` from the Cloud TPU metadata, or None without the API."""
    if self._use_api:
      metadata = self._fetch_cloud_tpu_metadata()
      return metadata.get(key)

    return None

  def __enter__(self):
    self._open = True
    # Return the client so that `with Client(...) as c:` binds it.
    return self

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    # Returning None (falsy) lets any exception propagate.
    del type, value, traceback

  def recoverable(self):
    """Returns true if the TPU is in a state where training should eventually resume.

    If false the TPU is in a unrecoverable state and should be recreated.
    """
    state = self.state()
    symptoms = self.symptoms()
    if state in ('TERMINATED', 'PREEMPTED'):
      return False
    if FLAGS.runtime_oom_exit and self._oom_event(symptoms):
      return False
    if FLAGS.hbm_oom_exit and self._hbm_oom_event(symptoms):
      return False
    return True

  def symptoms(self):
    """Return Cloud TPU Symptoms of the TPU."""
    return self._get_tpu_property('symptoms')

  def state(self):
    """Return state of the TPU."""
    return self._get_tpu_property('state')

  def health(self):
    """Return health of the TPU."""
    return self._get_tpu_property('health')

  def runtime_version(self):
    """Return runtime version of the TPU.

    Without the Cloud API, the version is queried from the version switcher on
    the first network endpoint; a 404 from it yields None.
    """
    if not self._use_api:
      # Fallback on getting version directly from TPU.
      url = _VERSION_SWITCHER_ENDPOINT.format(
          self.network_endpoints()[0]['ipAddress'])
      try:
        req = urllib.request.Request(url)
        with urllib.request.urlopen(req) as resp:
          version_details = json.loads(resp.read())
      except urllib.error.HTTPError as e:
        if e.code == 404:
          return None
        raise
      return version_details.get('currentVersion')
    return self._get_tpu_property('tensorflowVersion')

  def accelerator_type(self):
    """Return accelerator type of the TPU."""
    return self._get_tpu_property('acceleratorType')

  def api_available(self):
    """Return if the Cloud TPU API is available, if not certain features will not work."""
    return self._use_api

  def name(self):
    """Return the name of the tpu, or the ip address if name is not provided."""
    return self._tpu

  def get_local_ip(self):
    """Return the local ip address of the Google Cloud VM the workload is running on."""
    return _request_compute_metadata('instance/network-interfaces/0/ip')

  def network_endpoints(self):
    """Return a list of tpu endpoints.

    Raises:
      RuntimeError: If the Cloud API reports the TPU state as not 'READY'.
    """
    if not self._use_api:
      return list(_environment_var_to_network_endpoints(self._tpu))
    response = self._fetch_cloud_tpu_metadata()

    if response.get('state') != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (self._tpu, response.get('state')))
    if 'networkEndpoints' in response:
      return response['networkEndpoints']
    # Fall back to the single ipAddress/port pair in the metadata.
    return [{'ipAddress': response['ipAddress'], 'port': response['port']}]

  def wait_for_healthy(self, timeout_s=1200, interval=30):
    """Wait for TPU to become healthy or raise error if timeout reached.

    Without the Cloud API, health() is always None, so this waits until the
    timeout is reached.

    Args:
      timeout_s (int): The timeout in seconds for waiting TPU to become healthy.
      interval (int): The interval in seconds to poll the TPU for health.

    Raises:
      RuntimeError: If the TPU doesn't become healthy by the timeout.
    """
    deadline = time.time() + timeout_s
    # Query health once per poll and log the same value that was checked.
    health = self.health()
    while health != 'HEALTHY':
      logging.warning(
          ('Waiting for TPU "%s" with state "%s" '
           'and health "%s" to become healthy'),
          self.name(), self.state(), health)
      if time.time() + interval > deadline:
        raise RuntimeError(
            'Timed out waiting for TPU "%s" to become healthy' % self.name())
      time.sleep(interval)
      health = self.health()

    logging.warning('TPU "%s" is healthy.', self.name())

  def configure_tpu_version(self, version, restart_type='always'):
    """Configure TPU software version.

    Args:
      version (string): Version of software to configure the TPU with.
      restart_type (string): Restart behaviour when switching versions,
        defaults to always restart. Options are 'always', 'ifNeeded'.

    Raises:
      Exception: If any worker rejects the request with an HTTP error.
    """

    def configure_worker(worker):
      """Configure individual TPU worker.

      Args:
        worker: A dict with the field ipAddress where the configure request will
          be sent.
      """
      ip_address = worker['ipAddress']
      url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
          ip_address, version, restart_type)
      req = urllib.request.Request(url, data=b'')
      try:
        # Only the status matters; close the response right away.
        with urllib.request.urlopen(req):
          pass
      except urllib.error.HTTPError as e:
        if e.code == 404:
          raise Exception(
              'Tensorflow version {} is not available on Cloud TPU, '
              'try a previous nightly version or refer to '
              'https://cloud.google.com/tpu/docs/release-notes for '
              'the latest official version.'.format(version)) from e
        raise Exception(
            'Failed to configure worker {}'.format(ip_address)) from e

    workers = self.network_endpoints()

    # ThreadPoolExecutor rejects max_workers=0, so clamp for an empty list.
    with futures.ThreadPoolExecutor(
        max_workers=max(len(workers), 1)) as executor:
      # Executor.map re-raises a worker's exception when its result is
      # retrieved; drain the iterator so failures reach the caller.
      list(executor.map(configure_worker, workers))