Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/mesh_util.py: 24%

97 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to help with mesh creation."""

from typing import List, Optional, Tuple
from absl import logging
import numpy as np

from tensorflow.dtensor.python import accelerator_util
from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import layout
from tensorflow.dtensor.python import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


def _print_context(num_global_devices: int, num_clients: int, client_id: int,
                   device_type: str, mesh: layout.Mesh) -> None:
  logging.info('This is client %d of %d clients', client_id, num_clients)
  logging.info('Number of global %s devices: %d', device_type.upper(),
               num_global_devices)
  # pylint: disable=protected-access
  logging.info('Global device IDs: %s', mesh.global_device_ids())
  logging.info('Local device IDs: %s', mesh.local_device_ids())
  logging.info('Local devices: %s', mesh.local_devices())
  # pylint: enable=protected-access


def _make_device_specs(
    devices: Optional[List[str]] = None,
    device_type: Optional[str] = None
) -> Tuple[List[tf_device.DeviceSpec], str]:
  """Makes device specs from local device names or the number of global devices."""

  if devices is None:
    if device_type is None:
      device_type = 'CPU'
    devices = config.local_devices(device_type)
  else:
    devices = [tf_device.DeviceSpec.from_string(d) for d in devices]
    if device_type is None:
      device_type = devices[0].device_type

    if device_type.upper() != devices[0].device_type.upper():
      raise ValueError(
          f'Conflicting devices {str(devices)} and device_type {device_type}')

  return devices, device_type


@tf_export('experimental.dtensor.create_mesh', v1=[])
def create_mesh(mesh_dims: Optional[List[Tuple[str, int]]] = None,
                mesh_name: str = '',
                devices: Optional[List[str]] = None,
                device_type: Optional[str] = None,
                use_xla_spmd: bool = layout.USE_XLA_SPMD) -> layout.Mesh:
  """Creates a single-client mesh.

  If both `mesh_dims` and `devices` are specified, they must match each other.
  As a special case, when all arguments are missing, this creates a 1D CPU mesh
  with an empty name, assigning all available devices to that dimension.
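
  Example:

  A minimal usage sketch, assuming 8 logical CPU devices have been configured
  (e.g. with `tf.config.set_logical_device_configuration`); the dimension
  names are arbitrary:

  ```python
  # Default: a 1D mesh named '' with one dimension 'x' over all CPU devices.
  mesh = dtensor.create_mesh()

  # A 2x4 mesh over the same 8 logical CPU devices; the dimension sizes must
  # multiply to the total number of devices.
  mesh = dtensor.create_mesh([('batch', 2), ('model', 4)], device_type='CPU')
  ```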

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples. Defaults to a single
      batch-parallel dimension called 'x' using all devices. As a special case,
      a single-element mesh_dims whose dim_size is -1 also uses all devices.
    mesh_name: Name of the created mesh. Defaults to ''.
    devices: String representations of devices to use. This is the device part
      of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available logical devices.
    device_type: If `devices` is missing, the type of devices to use. Defaults
      to 'CPU'.
    use_xla_spmd: If True, uses XLA SPMD instead of DTensor SPMD.

  Returns:
    A single-client mesh created from the specified or default arguments.
93 """ 

94 device_specs, device_type = _make_device_specs(devices, device_type) 

95 

96 local_spec = tf_device.DeviceSpec(job=config.job_name(), replica=0, task=0) 

97 device_specs = [local_spec.make_merged_spec(d) for d in device_specs] 

98 

99 if mesh_dims is None: 

100 mesh_dims = [('x', len(device_specs))] 

101 elif len(mesh_dims) == 1 and mesh_dims[0][1] == -1: 

102 # Replace -1 dim_size in a 1D mesh will the number of all devices. 

103 mesh_dims[0] = (mesh_dims[0][0], len(device_specs)) 

104 

105 dim_names = [d[0] for d in mesh_dims] 

106 shape = [d[1] for d in mesh_dims] 

107 

108 if np.prod(shape) != len(device_specs): 

109 raise ValueError(f'length of devices ({len(device_specs)}) must be ' 

110 f'equal to total size of the mesh of shape {shape}') 

111 

112 global_device_ids = np.arange(len(device_specs)).reshape(shape) 

113 local_device_ids = np.ravel(global_device_ids).tolist() 

114 mesh = layout.Mesh( 

115 dim_names=dim_names, 

116 global_device_ids=global_device_ids, 

117 local_device_ids=local_device_ids, 

118 local_devices=device_specs, 

119 mesh_name=mesh_name, 

120 use_xla_spmd=use_xla_spmd) 

121 _print_context( 

122 num_global_devices=len(device_specs), 

123 num_clients=1, 

124 client_id=0, 

125 device_type=device_type, 

126 mesh=mesh) 

127 return mesh 

128 

129 

@tf_export('experimental.dtensor.create_distributed_mesh', v1=[])
def create_distributed_mesh(
    mesh_dims: List[Tuple[str, int]],
    mesh_name: str = '',
    local_devices: Optional[List[str]] = None,
    device_type: Optional[str] = None,
    use_xla_spmd: bool = layout.USE_XLA_SPMD) -> layout.Mesh:
  """Creates a distributed mesh.

  This is similar to `create_mesh`, but with a different set of arguments to
  create a mesh that spans evenly across a multi-client DTensor cluster.

  For CPU and GPU meshes, users can choose to use fewer local devices than what
  is available by passing `local_devices`.

  For TPU, only meshes that use all TPU cores are supported by the DTensor
  runtime.
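
  Example:

  A minimal multi-client sketch, assuming 2 DTensor clients with 4 GPUs each
  (8 global devices); the dimension names are arbitrary:

  ```python
  # Run on every client before creating a distributed mesh.
  dtensor.initialize_accelerator_system('GPU')

  # An 8-device mesh spanning both clients: 2-way 'batch' by 4-way 'model'.
  mesh = dtensor.create_distributed_mesh(
      [('batch', 2), ('model', 4)], device_type='GPU')
  ```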

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples.
    mesh_name: Name of the created mesh. Defaults to ''.
    local_devices: String representations of devices to use. This is the device
      part of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available local
      logical devices.
    device_type: Type of device to build the mesh for. Defaults to 'CPU'.
      Supported values are 'CPU', 'GPU', 'TPU'.
    use_xla_spmd: If True, uses XLA SPMD instead of DTensor SPMD.

  Returns:
    A mesh that spans evenly across all DTensor clients in the cluster.
161 """ 

162 dim_names, shape = zip(*mesh_dims) 

163 

164 if not accelerator_util.is_initialized(): 

165 raise ValueError('Accelerators are uninitialized, please run ' 

166 'dtensor.initialize_accelerator_system() first.') 

167 

168 if device_type and device_type.upper() == 'TPU': 

169 # TODO(b/185940495): Allow multi-mesh and partial on TPU. 

170 # TPU meshes can only be configured through environment variables that 

171 # reflect the actual TPU topology. Do not let users specify custom args. 

172 if local_devices is not None: 

173 raise ValueError( 

174 f'Do not specify devices for {device_type.upper()} meshes. ' 

175 f'Using a partial list of devices for {device_type.upper()} ' 

176 f'is not supported.') 

177 

178 device_specs, device_type = _make_device_specs(local_devices, device_type) 

179 

180 if device_type.upper() in ['CPU', 'GPU']: 

181 # For CPU and GPU meshes, user-specified args take precedence over env vars. 

182 # This is particularly useful on single clients when users want to create 

183 # meshes that use fewer logical devices than what's available. 

184 

185 local_spec = tf_device.DeviceSpec( 

186 job=config.job_name(), replica=0, task=config.client_id()) 

187 device_specs = [local_spec.make_merged_spec(d) for d in device_specs] 

188 

189 # Assumes identical number of local devices per client. 

190 num_global_devices = len(device_specs) * config.num_clients() 

191 

192 if np.prod(shape) != num_global_devices: 

193 raise ValueError( 

194 f'Global number of devices ' 

195 f'({len(device_specs)} per client * {config.num_clients()} clients ' 

196 f'= {num_global_devices}) must be ' 

197 f'equal to total size of the mesh of shape {shape}') 

198 

199 global_device_ids = np.arange(num_global_devices).reshape(shape) 

200 flattened = np.ravel(global_device_ids).tolist() 

201 start_idx = len(device_specs) * config.client_id() 

202 local_device_ids = flattened[start_idx:start_idx + len(device_specs)] 

203 

204 mesh = layout.Mesh( 

205 dim_names=dim_names, 

206 global_device_ids=global_device_ids, 

207 local_device_ids=local_device_ids, 

208 local_devices=device_specs, 

209 mesh_name=mesh_name, 

210 use_xla_spmd=use_xla_spmd) 

211 _print_context(num_global_devices, config.num_clients(), config.client_id(), 

212 device_type, mesh) 

213 return mesh 

214 

215 if device_type.upper() == 'TPU': 

216 mesh = tpu_util.create_tpu_mesh( 

217 mesh_dim_names=dim_names, 

218 mesh_shape=shape, 

219 mesh_name=mesh_name, 

220 use_xla_spmd=use_xla_spmd) 

221 _print_context( 

222 config.num_global_devices(device_type), config.num_clients(), 

223 config.client_id(), device_type, mesh) 

224 return mesh 

225 

226 raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU') 

227 

228 

229_BARRIER_DICT = {} 

230 

231 

@tf_export('experimental.dtensor.barrier', v1=[])
def barrier(mesh: layout.Mesh,
            barrier_name: Optional[str] = None,
            timeout_in_ms: Optional[int] = None):
  """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  will have completed across all clients. Currently we allocate a fully
  sharded tensor with the mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python
  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```
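
  A named barrier with an explicit timeout can be used the same way; a sketch,
  where the name 'before_exit' is an arbitrary label used for logging:

  ```python
  # Block for at most 60 seconds waiting for every client to reach this point.
  dtensor.barrier(mesh, barrier_name='before_exit', timeout_in_ms=60 * 1000)
  ```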

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
    timeout_in_ms: The timeout of the barrier in ms. If omitted, blocks
      indefinitely until the barrier is reached by all clients.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate
  # and serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))

  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size: {0} while the mesh has '
        'actual size: {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
  # from users. Consider dropping this if there is a `big` performance hit.
  context.async_wait()

  if context.context().coordination_service:
    if timeout_in_ms is None:
      timeout_in_ms = 24 * 60 * 60 * 1000  # 24 hours to stand in for infinite.

    num_calls = _BARRIER_DICT.setdefault(barrier_name, 0)
    _BARRIER_DICT[barrier_name] = num_calls + 1

    barrier_id = f'{barrier_name}:{num_calls}'
    context.context().wait_at_barrier(barrier_id, timeout_in_ms)

  logging.info('finished running barrier across all clients after '
               'op: %s', barrier_name)