Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/config.py: 43%

88 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""DTensor Configuration API.""" 

16 

17import os 

18from typing import List, Optional, Union 

19 

20from tensorflow.python.eager import context 

21from tensorflow.python.framework import config as tf_config 

22from tensorflow.python.framework import device as tf_device 

23from tensorflow.python.util.tf_export import tf_export 

24 

# Names of the environment variables that configure the DTensor cluster.
_DT_CLIENT_ID = "DTENSOR_CLIENT_ID"  # This client's integer ID; parsed by client_id().
# DTENSOR_NUM_CLIENTS is removed, but some DTensor users still use this symbol.
_DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
_DT_JOB_NAME = "DTENSOR_JOB_NAME"  # TF job name shared by all clients; read by job_name().
_DT_JOBS = "DTENSOR_JOBS"  # Comma-separated list of all clients' job names; read by jobs().
_DT_HEARTBEAT_ENABLED = "DTENSOR_ENABLE_HEARTBEAT"  # "true"/"1" enables the heartbeat service.

31 

32 

33# All functions in this file can be used before calling 

34# `tf.experimental.dtensor.initialize_accelerator_system`. 

35 

36 

37# ----------------------------------------------------------------------------- 

38# Distributed training-related methods. 

39# 

40# Most users should use DTensor utility methods to create a mesh. The methods 

41# here are only for advanced users who want to fully customize their meshes. 

42# Note that local_devices and num_local_devices return the actual number of 

43# locally attached devices. The others are set through environment variables. 

44 

45 

@tf_export("experimental.dtensor.local_devices", v1=[])
def local_devices(
    device_type: str,
    for_client_id: Optional[int] = None) -> List[tf_device.DeviceSpec]:
  """Returns a list of device specs configured on this client."""
  if device_type.upper() not in ["CPU", "GPU", "TPU"]:
    raise ValueError(f"Device type {device_type} is not CPU, GPU, or TPU.")

  # Default to this client's own ID (which equals its task ID).
  if for_client_id is None:
    for_client_id = client_id()

  # Build fully qualified device specs, in increasing device-index order.
  specs = []
  for device_index in range(num_local_devices(device_type)):
    specs.append(
        tf_device.DeviceSpec(
            job=job_name(),
            replica=0,  # replica is deprecated and mostly hard-coded now.
            task=for_client_id,
            device_type=device_type,
            device_index=device_index))
  return specs

66 

67 

@tf_export("experimental.dtensor.num_local_devices", v1=[])
def num_local_devices(device_type: str) -> int:
  """Returns the number of devices of device_type configured on this client."""
  device_kind = device_type.upper()

  # CPU and GPU counts come from the context config, because both can be
  # expanded through logical devices.
  if device_kind in ["CPU", "GPU"]:
    return context.get_config().device_count[device_kind]

  # Everything else (TPU) is counted from the physical device listing.
  return len(tf_config.list_physical_devices(device_type))

78 

79 

@tf_export("experimental.dtensor.num_global_devices", v1=[])
def num_global_devices(device_type: str) -> int:
  """Returns the number of devices of device_type in this DTensor cluster."""
  # Every client is assumed to host the same number of local devices.
  per_client = num_local_devices(device_type)
  return per_client * num_clients()

84 

85 

@tf_export("experimental.dtensor.client_id", v1=[])
def client_id() -> int:
  """Returns this client's ID."""
  # A missing variable means a single-client run, whose sole client is 0.
  value = int(os.environ.get(_DT_CLIENT_ID, "0"))
  if value < 0:
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be >= 0, got {value}. ")
  total = num_clients()
  if value >= total:
    raise ValueError(f"Environment variable {_DT_CLIENT_ID} "
                     f"must be < {total}, got {value}")
  return value

98 

99 

@tf_export("experimental.dtensor.num_clients", v1=[])
def num_clients() -> int:
  """Returns the number of clients in this DTensor cluster."""
  # Local mode always has exactly one client; otherwise one client per job.
  return 1 if is_local_mode() else len(jobs())

106 

107 

@tf_export("experimental.dtensor.job_name", v1=[])
def job_name() -> str:
  """Returns the job name used by all clients in this DTensor cluster."""
  # Per TensorFlow convention, a single-client (local) program runs under the
  # "localhost" job and a multi-client one under "worker", unless overridden.
  fallback = "localhost" if num_clients() == 1 else "worker"
  return os.environ.get(_DT_JOB_NAME, fallback)

115 

116 

@tf_export("experimental.dtensor.full_job_name", v1=[])
def full_job_name(task_id: Optional[int] = None) -> str:
  """Returns the fully qualified TF job name for this or another task."""
  if task_id is None:
    # This client's ID doubles as its task ID.
    task_id = client_id()
  # Local runs and unit tests host exactly one client on one TF task, so a
  # nonzero task ID there cannot be valid.
  if num_clients() == 1 and task_id != 0:
    raise ValueError(f"Unexpected task ID {task_id} in local runs")
  return f"{job_name()}/replica:0/task:{task_id}"

128 

129 

130def _bns_task_id(job: str) -> Union[int, str]: 

131 """Tries to extract an integer task ID from a job name. 

132 

133 For example, for `job` = '/.../tpu_worker/0:port_name', return 0. 

134 

135 Args: 

136 job: A job name to extract task ID from. 

137 

138 Returns: 

139 The task ID on success, or the original job name on failure. 

140 """ 

141 maybe_task_id = job.rsplit("/")[-1].rsplit(":")[0] 

142 try: 

143 return int(maybe_task_id) 

144 except ValueError: 

145 return job 

146 

147 

@tf_export("experimental.dtensor.jobs", v1=[])
def jobs() -> List[str]:
  """Returns a list of job names of all clients in this DTensor cluster.

  Reads the comma-separated `DTENSOR_JOBS` environment variable. Returns an
  empty list when it is unset, which signals local mode.

  Raises:
    ValueError: If `DTENSOR_JOBS` contains BNS-style names that are not
      sorted by task ID.
  """
  d_jobs = os.environ.get(_DT_JOBS)
  if d_jobs is None:
    return []
  d_jobs_list = d_jobs.split(",")

  # Validate ordering for BNS style job names.
  # For definition of BNS, refer to https://research.google/pubs/pub43438/.
  # Use a generator with any() (instead of materializing a list) so the scan
  # short-circuits at the first BNS-style name.
  if any(name.startswith("/bns/") for name in d_jobs_list):
    if d_jobs_list != sorted(d_jobs_list, key=_bns_task_id):
      raise ValueError(
          f"Unexpected DTENSOR_JOBS content {d_jobs}. Sort entries "
          "in DTENSOR_JOBS because cluster construction relies on "
          "the order.")

  return d_jobs_list

166 

167 

@tf_export("experimental.dtensor.heartbeat_enabled", v1=[])
def heartbeat_enabled() -> bool:
  """Returns true if DTensor heartbeat service is enabled."""
  # Heartbeat defaults to on; when set, only "true"/"1" (case-insensitive)
  # count as enabled.
  setting = os.environ.get(_DT_HEARTBEAT_ENABLED, "true")
  return setting.lower() in ("true", "1")

172 

173 

def is_local_mode() -> bool:
  """Returns true if DTensor shall run in local mode."""
  # An empty job list (DTENSOR_JOBS unset) means there is no cluster to join.
  configured_jobs = jobs()
  return not configured_jobs

177 

178 

def is_tpu_present() -> bool:
  """Returns true if TPU devices are present."""
  # The presence of a TPU_SYSTEM device in the initialized context is the
  # signal that TPUs are attached.
  return bool(tf_config.list_physical_devices("TPU_SYSTEM"))

185 

186 

def is_gpu_present() -> bool:
  """Returns true if GPU devices are present."""
  # Note: the previous docstring said "TPU" — a copy-paste error; this
  # function checks for physical GPU devices.
  return bool(tf_config.list_physical_devices("GPU"))

190 

191 

@tf_export("experimental.dtensor.preferred_device_type", v1=[])
def preferred_device_type() -> str:
  """Returns the preferred device type for the accelerators.

  The returned device type is determined by checking the first present device
  type from all supported device types in the order of 'TPU', 'GPU', 'CPU'.
  """
  # Scan accelerators in priority order; CPU is the unconditional fallback.
  for device_type, present in (("TPU", is_tpu_present),
                               ("GPU", is_gpu_present)):
    if present():
      return device_type
  return "CPU"

205 

206 

def gpu_use_nccl_communication() -> bool:
  """Return True if environment indicates NCCL shall be used for GPU."""
  # Any value other than the default "0" enables NCCL.
  setting = os.environ.get("DTENSOR_GPU_USE_NCCL_COMMUNICATION", "0")
  return setting != "0"

210 

211 

def backend_is_pw() -> bool:
  """Return True if environment indicates the backend is Pathways."""
  # Only the exact value "pw" selects the Pathways parallel executor; an
  # unset variable yields None and compares unequal.
  return "pw" == os.environ.get("DTENSOR_USE_PARALLEL_EXECUTOR")