Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_system_metadata.py: 25%
102 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
import collections

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu
from tensorflow.python.util.tf_export import tf_export

_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min per attempt
_RETRY_TIMES = 12 * 24  # 288 retries at 5 min each, i.e. about 1 day
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 min

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')


@tf_export('tpu.experimental.TPUSystemMetadata')
class TPUSystemMetadata(
    collections.namedtuple('TPUSystemMetadata', [
        'num_cores',
        'num_hosts',
        'num_of_cores_per_host',
        'topology',
        'devices',
    ])):
  """Describes some metadata about the TPU system.

  Attributes:
    num_cores: integer. Total number of TPU cores in the TPU system.
    num_hosts: integer. Total number of hosts (TPU workers) in the TPU system.
    num_of_cores_per_host: integer. Number of TPU cores per host (TPU worker).
    topology: an instance of `tf.tpu.experimental.Topology`, which describes
      the physical topology of the TPU system.
    devices: a tuple of strings, which describes all the TPU devices in the
      system.
  """

  def __new__(cls, num_cores, num_hosts, num_of_cores_per_host, topology,
              devices):
    return super(TPUSystemMetadata,
                 cls).__new__(cls, num_cores, num_hosts, num_of_cores_per_host,
                              topology, devices)


def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    logical_devices = config.list_logical_devices()

    # We want the output type to match in both eager and session mode
    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
                                             d.device_type, 0, 0)
               for d in logical_devices]
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying TensorFlow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS,
                  cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the TensorFlow master. The TPU worker may '
               'not be ready (still scheduling) or the TensorFlow master '
               'address is incorrect: got (%s).' %
               (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    spec = tf_device.DeviceSpec.from_string(device.name)
    if spec.device_type == 'TPU':
      device_dict[spec.task].append(spec.device_index)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same. This should '
          'not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata
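

# --- Illustrative example (not part of the original module) ---
# A minimal sketch of the per-host grouping performed above, applied to a
# few assumed device names instead of devices returned by a real master.
# Each of the two hypothetical tasks exposes two cores, so the derived
# num_of_cores_per_host would be 2.
def _example_group_cores_by_host():
  example_device_names = [
      '/job:tpu_worker/replica:0/task:0/device:TPU:0',
      '/job:tpu_worker/replica:0/task:0/device:TPU:1',
      '/job:tpu_worker/replica:0/task:1/device:TPU:0',
      '/job:tpu_worker/replica:0/task:1/device:TPU:1',
  ]
  cores_per_task = collections.defaultdict(list)
  for name in example_device_names:
    spec = tf_device.DeviceSpec.from_string(name)
    if spec.device_type == 'TPU':
      cores_per_task[spec.task].append(spec.device_index)
  return {task: len(core_ids) for task, core_ids in cores_per_task.items()}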


def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize TPU system with master (%s). '
        'Please double-check that the TPU system is functional.' % (
            master_address))
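

# --- Illustrative example (not part of the original module) ---
# `_obtain_topology` returns the serialized TopologyProto produced by
# `tpu.initialize_system()`. A sketch of turning those bytes into a
# `tf.tpu.experimental.Topology` object; the local import keeps the example
# self-contained and is not part of the module's import list.
def _example_parse_topology(serialized_topology):
  from tensorflow.python.tpu import topology as topology_lib  # pylint: disable=g-import-not-at-top
  parsed = topology_lib.Topology(serialized=serialized_topology)
  # mesh_shape gives the physical mesh dimensions, e.g. [x, y, z, cores].
  return parsed.mesh_shape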


def get_session_config_with_timeout(timeout_in_ms, cluster_def):
  """Returns a session config with the given timeout (ms) and cluster def."""
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_ms, cluster_def=cluster_def)
  return config_proto
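

# --- Illustrative example (not part of the original module) ---
# A sketch of what the helper above produces: a ConfigProto whose
# operation_timeout_in_ms makes RPCs against an unresponsive master fail
# with DeadlineExceededError instead of hanging indefinitely.
def _example_session_config():
  config_proto = get_session_config_with_timeout(
      _PINGING_MASTER_TIMEOUT_IN_MS, cluster_def=None)
  # 5 * 60 * 1000, i.e. the five-minute pinging timeout defined above.
  return config_proto.operation_timeout_in_ms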


def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified
      job names are inappropriate.
  """
  # If the user specifies the tpu_job_name, use that.

  if master in _LOCAL_MASTERS:
    return None

  if (not cluster_def or not cluster_def.job):
    return _DEFAULT_JOB_NAME
  job_names = set(job.name for job in cluster_def.job)
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2:
    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
      return job_names.pop()
  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError('Could not infer TPU job name.')
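

# --- Illustrative example (not part of the original module) ---
# A sketch of the job-name inference above on a hypothetical ClusterDef with
# a single non-default job named 'worker'; the master address and host below
# are assumptions. A local master ('' or 'local') would instead return None,
# and an empty ClusterDef would fall back to 'tpu_worker'.
def _example_master_job():
  from tensorflow.core.protobuf import cluster_pb2  # pylint: disable=g-import-not-at-top
  cluster_def = cluster_pb2.ClusterDef()
  job = cluster_def.job.add()
  job.name = 'worker'
  job.tasks[0] = 'tpu-host:8470'
  # With exactly one job that is not 'tpu_worker', that job's name is used.
  return master_job('grpc://tpu-host:8470', cluster_def)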