Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/layout.py: 30% of 217 statements


# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python definitions for `Mesh` and `Layout`."""

import collections
import itertools
from typing import List, Dict, Optional, Union

import numpy as np

from tensorflow.dtensor.proto import layout_pb2
from tensorflow.dtensor.python import config
from tensorflow.python import _pywrap_dtensor_device
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

# UNSHARDED indicates a tensor dimension is not sharded over any mesh dimension.
UNSHARDED = 'unsharded'
MATCH = 'match'
USE_XLA_SPMD = False

tf_export(
    'experimental.dtensor.UNSHARDED',
    v1=[]).export_constant(__name__, 'UNSHARDED')
tf_export(
    'experimental.dtensor.MATCH', v1=[]).export_constant(__name__, 'MATCH')

MeshDimension = collections.namedtuple('MeshDimension', ['name', 'size'])


def _compute_mesh_strides(shape: List[int]) -> List[int]:
  strides = [1]
  for idx, dim_size in enumerate(reversed(shape[1:])):
    strides.append(strides[idx] * dim_size)
  strides.reverse()
  return strides
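# Illustrative worked example (added as a sketch, not part of the original
# source): for a mesh of shape [2, 3, 4], _compute_mesh_strides returns
# [12, 4, 1], so device 7 sits at coordinates
# [(7 // 12) % 2, (7 // 4) % 3, (7 // 1) % 4] == [0, 1, 3].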


@tf_export('experimental.dtensor.Mesh', v1=[])
class Mesh(_pywrap_dtensor_device.Mesh):
  """Represents a Mesh configuration over a certain list of Mesh Dimensions.

  A mesh consists of named dimensions with sizes, which describe how a set of
  devices are arranged. Defining tensor layouts in terms of mesh dimensions
  allows us to efficiently determine the communication required when computing
  an operation with tensors of different layouts.

  A mesh provides information not only about the placement of the tensors but
  also the topology of the underlying devices. For example, we can group 8 TPUs
  as a 1-D array for data parallelism or a `2x4` grid for (2-way) data
  parallelism and (4-way) model parallelism.

  Note: the utilities `dtensor.create_mesh` and
  `dtensor.create_distributed_mesh` provide a simpler API to create meshes for
  single- or multi-client use cases.
  """

  def __init__(
      self,
      dim_names: List[str],
      global_device_ids: np.ndarray,
      local_device_ids: List[int],
      local_devices: List[Union[tf_device.DeviceSpec, str]],
      mesh_name: str = '',
      global_devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
      use_xla_spmd: bool = USE_XLA_SPMD,
  ):

81 """Builds a Mesh. 

82 

83 The `dim_names` and `global_device_ids` arguments describe the dimension 

84 names and shape for the mesh. 

85 

86 For example, 

87 

88 ```python 

89 dim_names = ('x', 'y'), 

90 global_device_ids = [[0, 1], 

91 [2, 3], 

92 [4, 5]] 

93 ``` 

94 

95 defines a 2D mesh of shape 3x2. A reduction over the 'x' dimension will 

96 reduce across columns (0, 2, 4) and (1, 3, 5), and a reduction over the 'y' 

97 dimension reduces across rows. 

98 

99 Note: the utilities `dtensor.create_mesh` and 

100 `dtensor.create_distributed_mesh` provide a simpler API to create meshes for 

101 single- or multi-client use cases. 

102 

103 Args: 

104 dim_names: A list of strings indicating dimension names. 

105 global_device_ids: An ndarray of global device IDs is used to compose 

106 DeviceSpecs describing the mesh. The shape of this array determines the 

107 size of each mesh dimension. Values in this array should increment 

108 sequentially from 0. This argument is the same for every DTensor client. 

109 local_device_ids: A list of local device IDs equal to a subset of values 

110 in global_device_ids. They indicate the position of local devices in the 

111 global mesh. Different DTensor clients must contain distinct 

112 local_device_ids contents. All local_device_ids from all DTensor clients 

113 must cover every element in global_device_ids. 

114 local_devices: The list of devices hosted locally. The elements correspond 

115 1:1 to those of local_device_ids. 

116 mesh_name: The name of the mesh. Currently, this is rarely used, and is 

117 mostly used to indicate whether it is a CPU, GPU, or TPU-based mesh. 

118 global_devices (optional): The list of global devices. Set when multiple 

119 device meshes are in use. 

120 use_xla_spmd (optional): Boolean when True, will use XLA SPMD instead of 

121 DTensor SPMD. 

122 """ 

    # Check if input args are valid.
    if not isinstance(global_device_ids, np.ndarray):
      raise ValueError('Variable global_device_ids must be an ndarray.')
    if global_device_ids.size == 0:
      raise ValueError('Variable global_device_ids must be non-empty.')
    flat_global_device_ids = global_device_ids.flatten()
    # global_device_ids are expected to be consecutive numbers.
    # LINT.IfChange
    distance = flat_global_device_ids[0]
    if any(
        (gid - i != distance) for i, gid in enumerate(flat_global_device_ids)):
      raise ValueError('global_device_ids must sequentially increase: %s' %
                       global_device_ids)
    # LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_device.cc)

    # TODO(b/242201545): This class only performs argument type transformation
    # for the exported C++ Mesh class after the unification is complete. Any
    # other logic, including validation checks, should reside in the C++ layer.

    if len(dim_names) != global_device_ids.ndim:
      raise ValueError(
          'Number of mesh dimensions does not match number of dimension names.')

    if not isinstance(local_device_ids, list):
      raise ValueError('Variable local_device_ids must be a list of integers.')

    if not isinstance(local_devices, list):
      raise ValueError('Variable local_devices must be a list of DeviceSpecs.')

    if global_devices and not isinstance(global_devices, list):
      raise ValueError('Variable global_devices must be a list of DeviceSpecs.')

    if not local_devices and not global_devices:
      raise ValueError('Empty list of devices not allowed.')

    # Transform args format for the C++ Mesh constructor.
    global_device_ids_flatten = global_device_ids.flatten()
    global_device_ids_shape = global_device_ids.shape

    def to_str(d) -> str:
      if isinstance(d, tf_device.DeviceSpec):
        return d.to_string()
      return d

    def to_spec(d) -> tf_device.DeviceSpec:
      if not isinstance(d, tf_device.DeviceSpec):
        return tf_device.DeviceSpec.from_string(d)
      return d

    local_devices_str = [to_str(d) for d in local_devices]
    local_devices_spec = [to_spec(d) for d in local_devices]
    if not global_devices:
      global_devices = []
    global_devices_str = [to_str(d) for d in global_devices]
    global_devices_spec = [to_spec(d) for d in global_devices]

    local_devices_set = set(local_devices_spec)
    local_device_only_contains_host_cpu = (
        len(local_devices_set) == 1 and
        list(local_devices_set)[0].device_type == 'CPU')
    if not local_device_only_contains_host_cpu and len(local_devices) != len(
        local_devices_set):
      raise ValueError('Duplicate devices found in mesh specification %s.' %
                       [d for d in local_devices if local_devices.count(d) > 1])

    if len(local_device_ids) != len(local_devices):
      raise ValueError(
          'Variable local_device_ids does not have same size as local_devices.')

    if len(local_device_ids) > len(np.ravel(global_device_ids)):
      raise ValueError('Cannot have more local than global device IDs.')

    device_types = set([device.device_type for device in local_devices_spec])
    if not device_types:
      device_types = set([device.device_type for device in global_devices_spec])
    if None in device_types:
      raise ValueError('device_type is required')
    if len(device_types) > 1:
      raise ValueError('Devices containing multiple device_types : %s' %
                       device_types)
    device_type = device_types.pop()
    if use_xla_spmd and device_type != 'TPU':
      raise ValueError('XLA SPMD is not currently supported for %s mesh.' %
                       device_type)

    super().__init__(
        mesh_name,
        dim_names,
        global_device_ids_shape,
        global_device_ids_flatten,
        global_devices_str,
        local_device_ids,
        local_devices_str,
        use_xla_spmd,
    )
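  # Illustrative usage (a sketch, not part of the original source; the device
  # strings below are hypothetical single-client CPU devices):
  #
  #   mesh = Mesh(
  #       dim_names=['x', 'y'],
  #       global_device_ids=np.arange(4).reshape((2, 2)),
  #       local_device_ids=[0, 1, 2, 3],
  #       local_devices=[
  #           '/job:localhost/replica:0/task:0/device:CPU:%d' % i
  #           for i in range(4)
  #       ],
  #   )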

  def global_device_ids(self) -> np.ndarray:
    """Returns a global device list as an array."""
    return np.array(super().global_device_ids(), dtype=np.int64).reshape(
        self.shape()
    )

  def __getitem__(self, dim_name: str) -> MeshDimension:
    return MeshDimension(name=dim_name, size=self.dim_size(dim_name))

  def __hash__(self):
    return hash(self.as_proto().SerializeToString(deterministic=True))

  def __repr__(self) -> str:
    dims = [tuple(self[dim_name]) for dim_name in self.dim_names]
    return (
        f'<Mesh object with dims={dims}, device_type="{self.device_type()}", '
        f'num_local_devices={self.num_local_devices()}, '
        f'size={self.size}>'
    )

  # TODO(panzf): change to pybind11 pickle implementation in the last step
  def __reduce__(self):
    return Mesh.from_string, (self.to_string(),)

  # TODO(b/242201545): implement this in Mesh C++ class
  def coords(self, device_idx: int) -> ops.Tensor:
    """Converts the device index into a tensor of mesh coordinates."""
    strides = ops.convert_to_tensor(self.strides)
    shape = ops.convert_to_tensor(self.shape())
    return (device_idx // strides) % shape
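  # Worked example (added as a sketch, not part of the original source): on a
  # 3x2 mesh, strides == [2, 1] and shape == [3, 2], so coords(4) yields
  # [(4 // 2) % 3, (4 // 1) % 2] == [2, 0].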

  @classmethod
  def from_proto(cls, proto: layout_pb2.MeshProto) -> 'Mesh':
    """Construct a mesh instance from input `proto`."""
    mesh = _pywrap_dtensor_device.Mesh.__new__(cls)
    _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_proto=proto)
    return mesh

  @classmethod
  def from_string(cls, mesh_str: str) -> 'Mesh':
    mesh = _pywrap_dtensor_device.Mesh.__new__(cls)
    _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_str=mesh_str)
    return mesh

  @classmethod
  def from_device(cls, device: str) -> 'Mesh':
    """Constructs a single device mesh from a device string."""
    mesh = _pywrap_dtensor_device.Mesh.__new__(cls)
    _pywrap_dtensor_device.Mesh.__init__(mesh, single_device=device)
    return mesh

  # TODO(b/242201545): implement this in Mesh C++ class
  def host_mesh(self):
    """Returns the 1-1 mapped host mesh."""
    if self.device_type().upper() == 'CPU':
      return self

    v_cpus_counts = config.num_local_devices('CPU')
    if v_cpus_counts < len(self.local_devices()):
      raise ValueError(
          'Must have at least {0} virtual CPUs for mesh: {1}, '
          'but got: {2} virtual CPUs. '
          'Call tf.experimental.dtensor.initialize_accelerator_system() '
          'to initialize the host CPU devices with the accelerators.'.format(
              len(self.local_devices()), self.to_string(), v_cpus_counts
          )
      )
    local_device_specs = [
        tf_device.DeviceSpec.from_string(d) for d in self.local_devices()
    ]
    global_device_specs = [
        tf_device.DeviceSpec.from_string(d) for d in self.global_devices()
    ]

    device_array = np.asarray(
        [spec.replace(device_type='CPU') for spec in local_device_specs]
    ).reshape((len(self.local_devices()), 1))
    global_devices = [
        spec.replace(device_type='CPU') for spec in global_device_specs
    ]
    h_mesh = Mesh(
        self.dim_names,
        self.global_device_ids(),
        self.local_device_ids(),
        np.ravel(device_array).tolist(),
        global_devices=global_devices,
    )
    return h_mesh

  # TODO(b/242201545): implement this in Mesh C++ class
  def local_device_locations(self) -> List[Dict[str, int]]:
    """Returns a list of local device locations.

    A device location is a dictionary from dimension names to indices on those
    dimensions.
    """
    mapping = self.unravel_index()
    return [mapping[device_id] for device_id in self.local_device_ids()]

  # TODO(b/242201545): implement this in Mesh C++ class
  @property
  def strides(self) -> List[int]:
    """Returns the strides array for this mesh.

    If the mesh shape is `[a, b, c, d]`, then the strides array can be computed
    as `[b*c*d, c*d, d, 1]`. This array can be useful in computing local device
    offsets given a device ID. Using the same example, the device coordinates
    of the mesh can be computed as:

    ```
    [(device_id / (b*c*d)) % a,
     (device_id / (c*d)) % b,
     (device_id / (d)) % c,
     (device_id) % d]
    ```

    This is the same as `(device_id // mesh.strides) % mesh.shape`.

    Returns:
      The mesh strides as a list of integers.
    """
    return _compute_mesh_strides(self.shape())

  # TODO(b/242201545): implement this in Mesh C++ class
  def unravel_index(self):
    """Returns a dictionary from device ID to {dim_name: dim_index}.

    For example, for a 3x2 mesh, this returns:

    ```
    { 0: {'x': 0, 'y': 0},
      1: {'x': 0, 'y': 1},
      2: {'x': 1, 'y': 0},
      3: {'x': 1, 'y': 1},
      4: {'x': 2, 'y': 0},
      5: {'x': 2, 'y': 1} }
    ```
    """
    idx_ranges = [range(self.dim_size(dim_name)) for dim_name in self.dim_names]
    mesh_pos = itertools.product(*idx_ranges)
    mapping = {}
    for device_id, device_pos in enumerate(mesh_pos):
      device_loc = {}
      for dim_name, dim_index in zip(self.dim_names, device_pos):
        device_loc[dim_name] = dim_index
      mapping[device_id] = device_loc
    return mapping


# TODO(hthu): Consider making this class immutable.
@tf_export('experimental.dtensor.Layout', v1=[])
class Layout(_pywrap_dtensor_device.Layout):
  """Represents the layout information of a DTensor.

  A layout describes how a distributed tensor is partitioned across a mesh (and
  thus across devices). For each axis of the tensor, the corresponding
  sharding spec indicates which dimension of the mesh it is sharded over. A
  special sharding spec `UNSHARDED` indicates that the axis is replicated on
  all the devices of that mesh.

  For example, let's consider a 1-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
  ```

  This mesh arranges 6 TPU devices into a 1-D array. `Layout([UNSHARDED], mesh)`
  is a layout for a rank-1 tensor which is replicated on the 6 devices.

  For another example, let's consider a 2-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
       [("x", 3), ("y", 2)])
  ```

  This mesh arranges 6 TPU devices into a `3x2` 2-D array.
  `Layout(["x", UNSHARDED], mesh)` is a layout for a rank-2 tensor whose first
  axis is sharded on mesh dimension "x" and the second axis is replicated. If we
  place `np.arange(6).reshape((3, 2))` using this layout, the individual
  component tensors would look like:

  ```
  Device  |  Component
   TPU:0     [[0, 1]]
   TPU:1     [[0, 1]]
   TPU:2     [[2, 3]]
   TPU:3     [[2, 3]]
   TPU:4     [[4, 5]]
   TPU:5     [[4, 5]]
  ```
  """

  def __init__(self, sharding_specs: List[str], mesh: Mesh):
    """Builds a Layout from a list of dimension names and a Mesh.

    Args:
      sharding_specs: List of sharding specifications, each corresponding to a
        tensor axis. Each specification (dim_sharding) can either be a mesh
        dimension or the special value UNSHARDED.
      mesh: A mesh configuration for the Tensor.

    Returns:
      A valid Layout built with the given sharding specs and mesh.
    """
    # Validate mesh
    if not isinstance(mesh, Mesh):
      raise ValueError('mesh is not a valid Mesh object.')

    # Validate sharding spec
    for _, dim_sharding in enumerate(sharding_specs):
      # If dim_sharding is a special value, there is no need to check for
      # uniqueness; just skip it.
      if dim_sharding == UNSHARDED or dim_sharding == MATCH:
        continue
      # Check that dim_sharding is unique.
      if sharding_specs.count(dim_sharding) > 1:
        raise ValueError(
            ('Mesh dimension {mesh_dim} was repeated in sharding ' +
             'specification {sharding_specs}. Mesh dimensions must be unique ' +
             'in a layout.').format(
                 mesh_dim=dim_sharding, sharding_specs=sharding_specs))
      # Check that dim_sharding is a mesh dimension.
      if dim_sharding not in mesh:
        raise ValueError(
            ('{dim_sharding}: A dimension sharding must either be a ' +
             'valid mesh dimension or UNSHARDED.').format(
                 dim_sharding=dim_sharding))

    super().__init__(sharding_specs=sharding_specs, mesh=mesh)

  def __repr__(self) -> str:
    return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})'

  def __hash__(self):
    return hash(self.as_proto().SerializeToString(deterministic=True))

  # TODO(panzf): change to pybind11 pickle implementation in the last step
  def __reduce__(self):
    return Layout.from_string, (self.to_string(),)

  # TODO(b/242201545): Find a way to return Mesh object from the pywrap module.
  @property
  def mesh(self):
    return Mesh.from_proto(super().mesh.as_proto())

  @property
  def shape(self):
    return self.mesh.shape()

  @classmethod
  def batch_sharded(
      cls, mesh: Mesh, batch_dim: str, rank: int, axis: int = 0
  ) -> 'Layout':
    """Returns a layout sharded on the batch dimension."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(
        # Watch out for the different ordering.
        layout_obj,
        mesh=mesh,
        rank=rank,
        batch_dim=batch_dim,
        axis=axis,
    )
    return layout_obj

  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
  def delete(self, dims: List[int]) -> 'Layout':
    """Returns the layout with the given dimensions deleted."""
    if not isinstance(dims, list):
      dims = [dims]
    new_specs = [
        spec for i, spec in enumerate(self.sharding_specs) if i not in dims
    ]
    return Layout(new_specs, self.mesh)
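  # Illustrative example (a sketch, not part of the original source): given a
  # rank-3 layout with sharding_specs ['x', UNSHARDED, 'y'], calling
  # layout.delete([1]) returns a rank-2 layout with sharding_specs ['x', 'y'].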

  @classmethod
  def from_proto(cls, layout_proto: layout_pb2.LayoutProto) -> 'Layout':
    """Creates an instance from a LayoutProto."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(
        layout_obj, layout_proto=layout_proto
    )
    return layout_obj

  @classmethod
  def from_string(cls, layout_str: str) -> 'Layout':
    """Creates an instance from a human-readable string."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(layout_obj, layout_str=layout_str)
    return layout_obj

  @classmethod
  def inner_sharded(cls, mesh: Mesh, inner_dim: str, rank: int) -> 'Layout':
    """Returns a layout sharded on the inner (last) dimension."""
    return cls.batch_sharded(mesh, inner_dim, rank, axis=rank - 1)

  @classmethod
  def from_single_device_mesh(cls, mesh: Mesh) -> 'Layout':
    """Constructs a single device layout from a single device mesh."""
    layout = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(layout, mesh=mesh)
    return layout

  @classmethod
  def from_device(cls, device: str) -> 'Layout':
    """Constructs a single device layout from a device string."""
    return cls.from_single_device_mesh(Mesh.from_device(device))

  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
  def offset_to_shard(self):
    """Mapping from offset in a flattened list to shard index."""
    unravel_index = self.mesh.unravel_index()
    locations = [None] * self.mesh.size
    for offset, mesh_loc in unravel_index.items():
      loc = []
      for dim_sharding in self.sharding_specs:
        if dim_sharding == UNSHARDED:
          loc.append(0)
        else:
          loc.append(mesh_loc[dim_sharding])
      locations[offset] = tuple(loc)

    return locations
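  # Illustrative example (a sketch, not part of the original source): for a
  # 3x2 mesh with dimensions 'x' and 'y' and a layout with sharding_specs
  # ['x', UNSHARDED], offset_to_shard() returns
  # [(0, 0), (0, 0), (1, 0), (1, 0), (2, 0), (2, 0)]: devices 0 and 1 hold
  # shard (0, 0), devices 2 and 3 hold shard (1, 0), and devices 4 and 5 hold
  # shard (2, 0).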

  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
  def offset_tuple_to_global_index(self, offset_tuple):
    """Mapping from offset to index in global tensor."""
    index = 0
    for i, o in enumerate(offset_tuple):
      m = 1
      for x in range(i + 1, self.rank):
        m = m * self.num_shards(x)
      index = index + m * o
    return index
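  # Worked example (added as a sketch, not part of the original source): for a
  # layout with sharding_specs ['x', 'y'] on the 3x2 mesh above (rank == 2,
  # num_shards(0) == 3, num_shards(1) == 2),
  # offset_tuple_to_global_index((2, 1)) == 2 * 2 + 1 * 1 == 5, the row-major
  # position of shard (2, 1).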

  @classmethod
  def replicated(cls, mesh: Mesh, rank: int) -> 'Layout':
    """Returns a replicated layout of rank `rank`."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(layout_obj, mesh=mesh, rank=rank)
    return layout_obj
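
# Illustrative usage of the layout factory methods (a sketch, not part of the
# original source; `mesh` is assumed to be a 2-D Mesh with dimensions 'x' and
# 'y'):
#
#   Layout.replicated(mesh, rank=2)          # specs: [UNSHARDED, UNSHARDED]
#   Layout.batch_sharded(mesh, 'x', rank=2)  # specs: ['x', UNSHARDED]
#   Layout.inner_sharded(mesh, 'y', rank=2)  # specs: [UNSHARDED, 'y']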