1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Util of GCE specifics to integrate with WorkerPreemptionHandler."""
16import enum
17import os
18import sys
19import requests
20
21from six.moves.urllib import request
22from tensorflow.python.eager import context
23from tensorflow.python.platform import tf_logging as logging
24
25
# HTTP header sent with requests to the GCE metadata server (see `on_gcp`).
GCP_METADATA_HEADER = {'Metadata-Flavor': 'Google'}
# Name of the environment variable that overrides the metadata server host;
# when it is unset, 'metadata.google.internal' is used instead.
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
# Exit code passed to sys.exit() by `gce_exit_fn`.
# NOTE(review): 143 is conventionally 128 + SIGTERM (15); confirm that this is
# the code the job scheduler treats as "restartable".
_RESTARTABLE_EXIT_CODE = 143
# Grace period for GCE.
# NOTE(review): presumably expressed in seconds (i.e. one hour); this constant
# is not read anywhere in this module, so confirm the unit against its callers.
GRACE_PERIOD_GCE = 3600
30
31
def gce_exit_fn():
  """Calls `sys.exit` with the module's restartable exit code.

  Raises:
    SystemExit: always, carrying `_RESTARTABLE_EXIT_CODE`.
  """
  exit_code = _RESTARTABLE_EXIT_CODE
  sys.exit(exit_code)
34
35
def default_tpu_exit_fn():
  """Default exit function to run after saving checkpoint for TPUStrategy.

  For TPUStrategy, we want the coordinator to exit after workers are down so
  that the restarted coordinator would not connect to workers scheduled to be
  preempted. This function achieves that by attempting to read a key from the
  coordination service's key-value store, which will block until workers are
  done and then return with an error. Then we have the coordinator
  sys.exit(42) to re-schedule the job.

  Raises:
    SystemExit: with exit code 42, once the blocking lookup raises.
  """
  logging.info('Waiting for workers to exit...')
  try:
    # Per the docstring above, this lookup is expected to block until the
    # workers are gone and then fail; the value itself is never used.
    context.context().get_config_key_value('BLOCK_TILL_EXIT')
  except:  # pylint: disable=bare-except
    # The bare except is deliberate: any failure of the lookup is taken as the
    # signal that the workers are down and the job should be re-scheduled.
    # NOTE(review): if the lookup ever returns normally, this function returns
    # without exiting -- confirm that this is the intended behavior.
    logging.info('Restarting cluster due to preemption.')
    sys.exit(42)
52
53
def request_compute_metadata(path, timeout=None):
  """Returns GCE VM compute metadata.

  Args:
    path: Metadata path relative to `computeMetadata/v1/`, for example
      'instance/maintenance-event'.
    timeout: Optional timeout, in seconds, for the HTTP request. When None
      (the default) no explicit timeout is passed, so `urlopen`'s own default
      applies exactly as it did before this parameter existed.

  Returns:
    The response body as a string. A `bytes` body is decoded as UTF-8; any
    other body type is returned unchanged.
  """
  gce_metadata_endpoint = 'http://' + os.environ.get(
      _GCE_METADATA_URL_ENV_VARIABLE, 'metadata.google.internal')
  req = request.Request(
      '%s/computeMetadata/v1/%s' % (gce_metadata_endpoint, path),
      # Reuse the shared header constant; pass a copy so the module-level dict
      # can never be modified through the request object.
      headers=dict(GCP_METADATA_HEADER))

  # Only forward `timeout` when the caller supplied one, so that the default
  # call is identical to the original `urlopen(req)`.
  urlopen_kwargs = {} if timeout is None else {'timeout': timeout}
  response = request.urlopen(req, **urlopen_kwargs)
  try:
    info = response.read()
  finally:
    # Always release the underlying connection, even if the read fails.
    response.close()

  if isinstance(info, bytes):
    return info.decode('utf-8')
  return info
66
67
def termination_watcher_function_gce():
  """Checks the GCE metadata server for a pending termination event.

  Returns:
    True if the 'instance/maintenance-event' metadata value equals
    'TERMINATE_ON_HOST_MAINTENANCE', False otherwise.
  """
  maintenance_event = request_compute_metadata('instance/maintenance-event')
  return maintenance_event == 'TERMINATE_ON_HOST_MAINTENANCE'
72
73
def on_gcp():
  """Detect whether the current running environment is on GCP.

  Probes the GCE metadata server for 'instance/hostname'.

  Returns:
    True if the metadata request returned HTTP status 200. False if it
    returned any other status or raised
    `requests.exceptions.RequestException`.
  """
  metadata_host = os.environ.get(
      _GCE_METADATA_URL_ENV_VARIABLE, 'metadata.google.internal')
  endpoint = 'http://' + metadata_host
  hostname_url = '%s/computeMetadata/v1/%s' % (endpoint, 'instance/hostname')

  try:
    # Cap the request at 5 seconds in case the test environment has
    # connectivity issues: there is no default timeout, so without one the
    # call might block forever.
    reply = requests.get(
        hostname_url, headers=GCP_METADATA_HEADER, timeout=5)
    return reply.status_code == 200
  except requests.exceptions.RequestException:
    return False
90
91
@enum.unique
class PlatformDevice(enum.Enum):
  """Platform / device combinations reported by `detect_platform`.

  Each member's value is a unique string; `enum.unique` enforces that no two
  members share a value.
  """

  # Not running on GCP.
  INTERNAL_CPU = 'internal_CPU'
  INTERNAL_GPU = 'internal_GPU'
  INTERNAL_TPU = 'internal_TPU'

  # Running on GCP.
  GCE_GPU = 'GCE_GPU'
  GCE_TPU = 'GCE_TPU'
  GCE_CPU = 'GCE_CPU'

  # Anything else.
  UNSUPPORTED = 'unsupported'
101
102
def detect_platform():
  """Returns the platform and device information.

  The platform is chosen by `on_gcp()`; the device is chosen by probing the
  eager context for logical 'GPU' devices first, then 'TPU' devices, falling
  back to CPU when neither is present.

  Returns:
    A `PlatformDevice` member: one of the GCE_* members when on GCP, otherwise
    one of the INTERNAL_* members. `PlatformDevice.UNSUPPORTED` is never
    returned by this function.
  """
  # Pick the candidate results for the detected platform first, then probe the
  # devices once in GPU -> TPU -> CPU priority order.
  if on_gcp():
    with_gpu = PlatformDevice.GCE_GPU
    with_tpu = PlatformDevice.GCE_TPU
    cpu_only = PlatformDevice.GCE_CPU
  else:
    with_gpu = PlatformDevice.INTERNAL_GPU
    with_tpu = PlatformDevice.INTERNAL_TPU
    cpu_only = PlatformDevice.INTERNAL_CPU

  if context.context().list_logical_devices('GPU'):
    return with_gpu
  if context.context().list_logical_devices('TPU'):
    return with_tpu
  return cpu_only