Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_system_metadata.py: 25%

102 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

import collections

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu
from tensorflow.python.util.tf_export import tf_export

_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min
_RETRY_TIMES = 12 * 24  # 1 day
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
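# Note: together these give the retry loop in _query_tpu_system_metadata()
# below a budget of roughly one day: _RETRY_TIMES is 12 * 24 = 288 attempts,
# each bounded by the 5-minute ping timeout (288 * 5 min = 24 hours).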


_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')


@tf_export('tpu.experimental.TPUSystemMetadata')
class TPUSystemMetadata(
    collections.namedtuple('TPUSystemMetadata', [
        'num_cores',
        'num_hosts',
        'num_of_cores_per_host',
        'topology',
        'devices',
    ])):
  """Describes some metadata about the TPU system.

  Attributes:
    num_cores: integer. Total number of TPU cores in the TPU system.
    num_hosts: integer. Total number of hosts (TPU workers) in the TPU system.
    num_of_cores_per_host: integer. Number of TPU cores per host (TPU worker).
    topology: an instance of `tf.tpu.experimental.Topology`, which describes
      the physical topology of the TPU system.
    devices: a tuple of strings, which describes all the TPU devices in the
      system.
  """

  def __new__(cls, num_cores, num_hosts, num_of_cores_per_host, topology,
              devices):
    return super(TPUSystemMetadata,
                 cls).__new__(cls, num_cores, num_hosts, num_of_cores_per_host,
                              topology, devices)
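
# Illustrative sketch (not part of the original module): TPUSystemMetadata is
# a plain namedtuple, so instances can be constructed and inspected directly;
# in practice they are produced by _query_tpu_system_metadata() below.
#
#   metadata = TPUSystemMetadata(
#       num_cores=8, num_hosts=1, num_of_cores_per_host=8,
#       topology=None, devices=())
#   metadata.num_cores  # -> 8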



def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    logical_devices = config.list_logical_devices()

    # We want the output type to match in both eager and session mode
    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
                                             d.device_type, 0, 0)
               for d in logical_devices]
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS,
                  cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
               'not be ready (still scheduling) or the Tensorflow master '
               'address is incorrect: got (%s).' %
               (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    spec = tf_device.DeviceSpec.from_string(device.name)
    if spec.device_type == 'TPU':
      device_dict[spec.task].append(spec.device_index)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores on each host is not the same. This should '
          'not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata
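
# Illustrative sketch (assumed usage, not part of the original module): under
# eager execution the master address is effectively ignored and the locally
# visible logical devices are reported, e.g.
#
#   metadata = _query_tpu_system_metadata('', cluster_def=None,
#                                         query_topology=False)
#   for device in metadata.devices:
#     print(device.name)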



def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology."""
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Failed to initialize the TPU system with master (%s). '
        'Please double check the TPU system is functional.' % (
            master_address))
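
# Illustrative note (not part of the original module): the value returned by
# _obtain_topology() is the serialized TopologyProto produced by
# tpu.initialize_system(); callers can wrap it for inspection. The master
# address below is hypothetical.
#
#   from tensorflow.python.tpu import topology as topology_lib
#   topology = topology_lib.Topology(
#       serialized=_obtain_topology('grpc://10.0.0.1:8470', None))
#   topology.num_tasks  # number of TPU hosts in the fabric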



def get_session_config_with_timeout(timeout_in_secs, cluster_def):
  """Returns a session config with the given timeout and cluster config."""
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
  return config_proto
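
# Illustrative sketch (not part of the original module): despite the
# timeout_in_secs parameter name, callers in this file pass millisecond
# values, which land directly in operation_timeout_in_ms, e.g.
#
#   cfg = get_session_config_with_timeout(_PINGING_MASTER_TIMEOUT_IN_MS, None)
#   cfg.operation_timeout_in_ms  # -> 300000 (5 minutes)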



def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified job
      names are inappropriate.
  """
  # If the user specifies the tpu_job_name, use that.

  if master in _LOCAL_MASTERS:
    return None

  if (not cluster_def or not cluster_def.job):
    return _DEFAULT_JOB_NAME
  job_names = set(job.name for job in cluster_def.job)
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2:
    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
      return job_names.pop()
  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError('Could not infer TPU job name.')
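
# Illustrative sketch (assumed usage, not part of the original module): with a
# ClusterDef containing a single non-coordinator job, that job's name is
# returned; the address and job name below are hypothetical.
#
#   from tensorflow.core.protobuf import cluster_pb2
#   cluster_def = cluster_pb2.ClusterDef()
#   cluster_def.job.add(name='worker')
#   master_job('grpc://10.0.0.1:8470', cluster_def)  # -> 'worker'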