Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/tpu_util.py: 15%

322 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU-specific utilities for DTensor."""

import functools
import time
from typing import List, Optional, Dict

import numpy as np

from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.util.tf_export import tf_export


_MESH_DIM_X = "x"
_TPU_DEVICE_TYPE = "TPU"

# A dedicated, hidden device used to make C++ API calls.
_dtensor_device = None

# `_tpu_topology.mesh_shape` contains the TPU hardware slice size.
# `_tpu_topology.device_coordinates` maps TF task-device ordinals to TPU core
# IDs.
_tpu_topology = None

# Cache core ID <-> location mappings so we need not make repeated C++ calls.
# Both are indexed by TF task-device ordinals.
_all_core_ids = None
_all_core_locations = None


class _CoreLocation:
  """Represents a TPU core's location in the mesh."""

  def __init__(self, x: int = 0, y: int = 0, z: int = 0, core: int = 0):
    self.x = x
    self.y = y
    self.z = z
    self.core = core

  def __eq__(self, other):
    if not isinstance(other, _CoreLocation):
      return False
    return self.x == other.x and self.y == other.y and self.z == other.z and self.core == other.core

  def __ne__(self, other):
    if not isinstance(other, _CoreLocation):
      return True
    return not self == other

  def __hash__(self):
    return hash((self.x, self.y, self.z, self.core))

  def __repr__(self):
    return f"{type(self).__name__}(x={self.x}, y={self.y}, z={self.z}, core={self.core})"

  def to_list(self):
    return [self.x, self.y, self.z, self.core]


def _create_device_array(shape, device_type, host_id, local_device_ids=None):
  """Returns ID and device lists that can be used to create a mesh."""
  num_global_devices = config.num_global_devices(device_type)
  global_device_ids = np.arange(num_global_devices).reshape(shape)
  local_device_list = config.local_devices(device_type)

  # Users can specify `local_device_ids` or use the default list for
  # multi-host setups.
  num_local_devices = len(local_device_list)
  local_device_ids = [
      x + host_id * num_local_devices for x in range(num_local_devices)
  ] if not local_device_ids else local_device_ids

  return global_device_ids, local_device_ids, local_device_list

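
# Illustrative sketch (not part of the original module): shows the device ID
# arithmetic performed by `_create_device_array` above, using made-up numbers
# instead of querying `config`, since that requires a live multi-client setup.
# With 8 local devices per host, host 0 owns global IDs 0..7, host 1 owns
# 8..15, and so on, while `global_device_ids` is simply 0..N-1 reshaped to the
# requested mesh shape.
def _example_device_id_layout():
  num_local_devices = 8   # hypothetical devices per host
  host_id = 1             # hypothetical DTensor client ID
  local_device_ids = [
      x + host_id * num_local_devices for x in range(num_local_devices)
  ]
  assert local_device_ids == [8, 9, 10, 11, 12, 13, 14, 15]
  global_device_ids = np.arange(16).reshape((2, 8))  # 2 hosts x 8 devices
  return global_device_ids, local_device_ids
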

def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int,
                         num_devices_per_task: int) -> topology.Topology:
  """Returns a Topology object built from a _CoreLocation list.

  Args:
    core_locations: A list of _CoreLocation objects sorted first by TF task ID
      and then by per-task device ordinals.
    num_tasks: The number of TF tasks in the cluster.
    num_devices_per_task: The number of TPU devices local to each task.
  """

  assert min([l.x for l in core_locations]) == 0
  assert min([l.y for l in core_locations]) == 0
  assert min([l.z for l in core_locations]) == 0
  assert min([l.core for l in core_locations]) == 0
  x_max = max([l.x for l in core_locations])
  y_max = max([l.y for l in core_locations])
  z_max = max([l.z for l in core_locations])
  core_max = max([l.core for l in core_locations])
  mesh_shape = [x_max + 1, y_max + 1, z_max + 1, core_max + 1]

  device_coordinates = [[l.x, l.y, l.z, l.core] for l in core_locations]
  device_coordinates = np.asarray(device_coordinates).reshape(
      num_tasks, num_devices_per_task, 4)

  return topology.Topology(
      mesh_shape=mesh_shape, device_coordinates=device_coordinates)

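
# Illustrative sketch (not part of the original module): builds a Topology for
# a hypothetical single-task slice of four cores laid out as a 2x2 grid of
# single-core chips, mirroring how `tpu_system_init_helper` below calls
# `_create_tpu_topology`.
def _example_build_topology():
  core_locations = [
      _CoreLocation(0, 0, 0, 0),
      _CoreLocation(1, 0, 0, 0),
      _CoreLocation(0, 1, 0, 0),
      _CoreLocation(1, 1, 0, 0),
  ]
  tpu_topology = _create_tpu_topology(
      core_locations, num_tasks=1, num_devices_per_task=4)
  # mesh_shape records the bounding box of the listed coordinates.
  assert list(tpu_topology.mesh_shape) == [2, 2, 1, 1]
  return tpu_topology
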

def shutdown_tpu_system():
  """Shuts down the TPU system."""

  @def_function.function
  def _shutdown_tpu_system():
    return gen_dtensor_ops.shutdown_tpu_system()

  success = _shutdown_tpu_system() if context.is_tfrt_enabled() else True
  if success:
    logging.info("TPU system shut down.")
  else:
    logging.warning("TPU system failed to shut down.")


def tpu_system_init_helper(task_id,
                           num_tasks,
                           num_devices,
                           use_tfrt_host_runtime=True):
  """A helper function to initialize the multi-client TPU system."""

  @def_function.function
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu(
        use_tfrt_host_runtime=use_tfrt_host_runtime)

  @def_function.function
  def _set_global_tpu_array_fn(topology_proto):
    gen_dtensor_ops.d_tensor_set_global_tpu_array(topology_proto)

  with ops.device("/job:" + config.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
    my_core_ids = _tpu_init_fn()
  logging.info("TPU core IDs: %s", my_core_ids)

  # `my_core_ids` contains the IDs of TPU cores attached to this host.
  #
  # To generate correct and efficient XLA AllReduce group assignment, we must
  # merge these arrays from all hosts and broadcast the result back to all
  # hosts, so all hosts can use these mappings in their MLIR passes.
  #
  # This is essentially doing what WaitForDistributedTpuOp and
  # SetGlobalTPUArrayOp do, in our multi-client environment.
  num_devices_per_task = int(num_devices / num_tasks)

  # Create a one-time use mesh and layout just for merging core IDs.
  mesh = layout_lib.Mesh([_MESH_DIM_X],
                         *_create_device_array((num_devices,), _TPU_DEVICE_TYPE,
                                               config.client_id()))
  layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
  device = dtensor_device.DTensorDevice(meshes=[mesh])
  logging.info("TPU core locations: %s",
               device.tpu_core_ids_to_locations(my_core_ids))

  # At this point, we don't know which cores are attached to other hosts.
  # The core ID mappings in the runtime haven't been set yet.
  #
  # The core ID merging AllReduce below is carefully written so it works
  # without needing correct core mappings to be set in the runtime. We will
  # use this AllReduce's result to set the core ID mappings, and all future
  # user-initiated AllReduces will use the mappings.
  #
  # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
  all_core_ids = np.zeros([num_devices], dtype=np.int32)
  for i in range(len(my_core_ids)):
    all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

  # Only one local device gets valid input: 8 local core IDs among
  # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using task ID as offset.
  # The other 7 local devices get zero inputs. All devices on all hosts
  # participate in one AllReduce, whose result will be core IDs arranged by
  # task-device ordinals.
  all_core_ids = constant_op.constant([all_core_ids])
  zeros = array_ops.zeros_like(all_core_ids)
  all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

  with ops.device(device.name):
    all_core_ids = device.pack(all_core_ids, layout)
    all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
    unpacked_all_tpu_ids = device.unpack(all_core_ids)

  all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
  logging.info("All TPU core IDs: %s", all_core_ids)

  # Set the default core ID mappings in the runtime for legacy code and tests.
  #
  # Legacy code and tests create TPU meshes directly without using the
  # `create_tpu_mesh` function below. Those meshes have global device IDs
  # equal to TF task-device ordinals. The `all_core_ids` array happens to
  # arrange core IDs by TF task-device ordinals. Using this array on those
  # meshes guarantees correct, although inefficient, results.
  device.set_tpu_core_ids("", all_core_ids)

  # Remember enough global, immutable information to be able to build any ring
  # prescribed by `create_tpu_mesh` in the future.
  global _all_core_ids
  _all_core_ids = all_core_ids

  all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
  all_core_locations = [
      _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
  ]
  global _all_core_locations
  _all_core_locations = all_core_locations
  logging.info("All TPU core locations: %s", all_core_locations)

  tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                      num_devices_per_task)

  _set_global_tpu_array_fn(tpu_topology.serialized())
  return tpu_topology, device

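
# Illustrative sketch (not part of the original module): emulates the core ID
# merging performed in `tpu_system_init_helper` with plain numpy. Each task
# contributes a length-`num_devices` vector that is zero everywhere except in
# its own slot range, so summing the contributions recovers the global core ID
# list arranged by task-device ordinals. The core IDs used here are made up;
# the real code performs the sum as a DTensor reduce_sum over a 1D mesh.
def _example_core_id_merge():
  num_tasks, num_devices_per_task = 2, 4
  num_devices = num_tasks * num_devices_per_task
  per_task_core_ids = [[0, 1, 4, 5], [2, 3, 6, 7]]  # hypothetical local IDs

  contributions = []
  for task_id in range(num_tasks):
    padded = np.zeros([num_devices], dtype=np.int32)
    offset = task_id * num_devices_per_task
    padded[offset:offset + num_devices_per_task] = per_task_core_ids[task_id]
    contributions.append(padded)

  # Summing the padded per-task vectors is the "AllReduce" in this sketch.
  all_core_ids = np.sum(contributions, axis=0)
  assert list(all_core_ids) == [0, 1, 4, 5, 2, 3, 6, 7]
  return all_core_ids
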

def initialize_tpu_system():
  """Initializes the TPU system."""

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  use_tfrt_host_runtime = context.context().use_tfrt
  logging.info("Using TFRT host runtime is set to %s", use_tfrt_host_runtime)
  try:
    task_id = config.client_id()
    num_tasks = config.num_clients()
    num_devices = config.num_global_devices(_TPU_DEVICE_TYPE)

    tpu_topology, device = tpu_system_init_helper(
        task_id,
        num_tasks,
        num_devices,
        use_tfrt_host_runtime=use_tfrt_host_runtime)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None,
        "Initialization failed, no valid TPUs found. " + str(e)) from e

  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  + "It is likely a hardware failure. \nPlease check the error "
                  + "messages above to see whether that's the case. \nIf so, "
                  + "consider restarting the job or trying another machine.")
    raise e

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access


def _enumerate_cores(bounds: List[int], ring_bounds: List[int],
                     ring_sizes: List[int], host_bounds: List[int],
                     host_sizes: List[int]) -> List[List[int]]:
  """Enumerates cores within `bounds` from fastest to slowest varying axes.

  Args:
    bounds: Upper bounds of axes, from fastest to slowest varying.
    ring_bounds: Upper bounds of ring size per axis in the same axis order.
    ring_sizes: Number of consecutive cores in the ring built so far,
      cumulatively.
    host_bounds: Number of axis values per host in the same axis order.
    host_sizes: Number of consecutive cores on one host, cumulatively.

  Returns:
    Cores represented as a list of 4 integers in the same axis order.
  """
  if not bounds:
    return [[]]

  # Recursively enumerate cores under all but the slowest varying axis.
  partials = _enumerate_cores(bounds[:-1], ring_bounds[:-1], ring_sizes[:-1],
                              host_bounds[:-1], host_sizes[:-1])

  # Append the slowest varying axis to the end of all partial results.
  # From ring_i|j to host_i|j to core_i|j, use progressively smaller or equal
  # iteration groupings until every one of the bounds[-1] * len(partials)
  # combinations is iterated on.
  # Despite the six levels of nested loops below, the total time complexity for
  # this invocation is O(N), where N is the number of cores in the topology.
  results = []
  for ring_i in range(0, bounds[-1], ring_bounds[-1]):
    for ring_j in range(0, len(partials), ring_sizes[-1]):
      for host_i in range(ring_i, ring_i + ring_bounds[-1], host_bounds[-1]):
        for host_j in range(ring_j, ring_j + ring_sizes[-1], host_sizes[-1]):
          for i in range(host_i, host_i + host_bounds[-1]):
            for j in range(host_j, host_j + host_sizes[-1]):
              results.append(partials[j] + [i])
  return results

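
# Illustrative sketch (not part of the original module): a tiny worked example
# of `_enumerate_cores`. It uses only two axes to keep the output small (real
# callers always pass four), with ring/host grouping values chosen analogously
# to how `_enumerate_core_locations` constructs them when hosts may be split.
# The first (fastest varying) axis cycles through its values before the second
# (slowest varying) axis advances.
def _example_enumerate_cores():
  cores = _enumerate_cores(
      bounds=[2, 2],
      ring_bounds=[2, 2],
      ring_sizes=[1, 2],
      host_bounds=[1, 1],
      host_sizes=[1, 1])
  assert cores == [[0, 0], [1, 0], [0, 1], [1, 1]]
  return cores
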

def _enumerate_core_locations(bounds: List[int], ring_bounds: List[int],
                              axes: List[str],
                              can_split_host_across_rings: bool,
                              ring_size: int) -> List[_CoreLocation]:
  """Enumerates all possible core locations under the axis iteration order.

  Args:
    bounds: A list of 4 positive integers, upper bound values for x, y, z, core.
    ring_bounds: A list of 4 positive integers, upper bound values for ring size
      in x, y, z, core axes.
    axes: A permutation of ["x", "y", "z", "core"], the axis iteration order.
    can_split_host_across_rings: If true, devices attached to the same host may
      get assigned to different rings.
    ring_size: Number of devices in a ring, only for argument validation.

  Returns:
    A list of all _CoreLocation objects defined in a TPU slice of shape
    `bounds`, sorted by the axis iteration order specified by `axes`.

    For example, given bounds=[2, 2, 1, 2] and axes=["core", "z", "y", "x"],
    return 8 core locations expressed in (x, y, z, core) format but iterated in
    core -> z -> y -> x order (fastest to slowest varying):

    [_CoreLocation(0, 0, 0, 0),
     _CoreLocation(0, 0, 0, 1),
     _CoreLocation(0, 1, 0, 0),
     _CoreLocation(0, 1, 0, 1),
     _CoreLocation(1, 0, 0, 0),
     _CoreLocation(1, 0, 0, 1),
     _CoreLocation(1, 1, 0, 0),
     _CoreLocation(1, 1, 0, 1)]

  Raises:
    ValueError: If ring_size cannot be fulfilled without splitting hosts.
  """

  num_cores_per_chip = bounds[3]
  if num_cores_per_chip != 1 and num_cores_per_chip != 2:
    raise ValueError("Unsupported TPU slice size: %s" % bounds)

  # Translate `axes` from string to integer format.
  axes = [{"x": 0, "y": 1, "z": 2, "core": 3}[axis] for axis in axes]
  # Reorder bounds from fastest to slowest varying axes.
  bounds = [bounds[i] for i in axes]

  # Set and validate host_bounds.
  if can_split_host_across_rings:
    # If we can split hosts, shrink every host to effectively contain 1 device.
    host_bounds = [1, 1, 1, 1]
  elif np.prod(bounds) <= 2:
    # We must be running on 1x1 or 1x1x1 Forge.
    host_bounds = [[1, 1, 1, num_cores_per_chip][i] for i in axes]
  else:
    # Other cases including 2x2 Forge and Borg must use a full donut.
    host_bounds = [[2, 2, 1, num_cores_per_chip][i] for i in axes]
  # host_sizes contains the cumulative products of host_bounds.
  host_sizes = [1]
  for host_bound in host_bounds:
    host_sizes.append(host_sizes[-1] * host_bound)
  host_size = host_sizes.pop()
  # When can_split_host_across_rings is false, a ring must contain at least as
  # many devices as a host has.
  if ring_size < host_size:
    assert not can_split_host_across_rings
    raise ValueError(
        "Rings too small for can_split_host_across_rings = False: %d" %
        ring_size)

  # Reorder ring_bounds and validate it's element-wise >= host_bounds.
  ring_bounds = [ring_bounds[i] for i in axes]
  if ring_bounds < host_bounds:
    raise ValueError("ring_bounds %s should be >= host_bounds %s" %
                     (ring_bounds, host_bounds))
  ring_sizes = [1]
  # ring_sizes contains the cumulative products of ring_bounds.
  for ring_bound in ring_bounds:
    ring_sizes.append(ring_sizes[-1] * ring_bound)
  ring_sizes.pop()

  # Enumerate cores in the given iteration order. Each core is represented as a
  # list of int, which are offsets from fastest to slowest varying axes.
  cores = _enumerate_cores(bounds, ring_bounds, ring_sizes, host_bounds,
                           host_sizes)
  # Reorder offsets of each core back to the x, y, z, core order.
  core_locations = []
  for core in cores:
    core = [core[axes.index(i)] for i in range(4)]
    core_locations.append(_CoreLocation(core[0], core[1], core[2], core[3]))
  return core_locations


def _build_all_reduce_ring(core_locations: List[_CoreLocation],
                           rotate: bool = False) -> List[int]:
  """Reorders a list of TPU cores to optimize for AllReduce performance.

  This is ported from the C++ tensorflow::BuildAllReduceRing function,
  mixed with some logic from TF TPU's device_assignment._ring_3d.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    rotate: If true, scan the cores in a column-major order. False by default.

  Returns:
    A permutation of the input list such that neighbors in the sequence are
    nearby in the TPU topology.
  """

  permutation = list(range(len(core_locations)))
  if not permutation:
    return permutation
  logging.vlog(2, "Core locations in: %s", core_locations)

  first_column = min([l.x for l in core_locations])
  first_row = min([l.y for l in core_locations])
  same_z = (len(set([l.z for l in core_locations])) == 1)
  logging.vlog(2, "first_column: %d", first_column)
  logging.vlog(2, "first_row: %d", first_row)
  logging.vlog(2, "same_z: %s", same_z)

  def _cmp_2d(ia: int, ib: int) -> int:
    if not rotate:
      a = core_locations[ia]
      b = core_locations[ib]

      # Order the first column last in the sequence, except for the first row.
      a_first = (a.x == first_column and a.y != first_row)
      b_first = (b.x == first_column and b.y != first_row)
      if a_first != b_first:
        return -1 if b_first else 1

      # Order rows in increasing order, unless in the first column.
      if a.y != b.y:
        return b.y - a.y if a_first else a.y - b.y

      # Order even rows left to right, odd rows right to left.
      if a.x != b.x:
        return a.x - b.x if a.y % 2 == 0 else b.x - a.x

      # Order cores in increasing order.
      return a.core - b.core
    else:
      a = core_locations[ia]
      b = core_locations[ib]

      # Order the first row last in the sequence, except for the first column.
      a_first = (a.y == first_row and a.x != first_column)
      b_first = (b.y == first_row and b.x != first_column)
      if a_first != b_first:
        return -1 if b_first else 1

      # Order columns in increasing order, unless in the first row.
      if a.x != b.x:
        return b.x - a.x if a_first else a.x - b.x

      # Order even columns top down, odd columns bottom up.
      if a.y != b.y:
        return a.y - b.y if a.x % 2 == 0 else b.y - a.y

      # Order cores in increasing order.
      return a.core - b.core

  def _cmp_3d(ia: int, ib: int) -> int:
    a = core_locations[ia]
    b = core_locations[ib]

    a_corner = (a.x == first_column and a.y == first_row)
    b_corner = (b.x == first_column and b.y == first_row)

    # If both are in the corner, order in reverse z then core order.
    if a_corner and b_corner:
      return b.z - a.z if a.z != b.z else a.core - b.core

    # Corner cores always go after non-corner cores.
    if a_corner != b_corner:
      return -1 if b_corner else 1

    # Both non-corner cores are on the same z-plane. Reverse odd z-planes.
    if a.z == b.z:
      return _cmp_2d(ia, ib) if a.z % 2 == 0 else -_cmp_2d(ia, ib)

    # Both non-corner cores are on different z-planes. Smaller z goes first.
    return a.z - b.z

  # If all cores are on the same z-plane, order as usual. Otherwise, order
  # neighbor z-planes in opposite orders. Stack all z-planes along the z axis
  # and connect them in one corner.
  if same_z:
    permutation.sort(key=functools.cmp_to_key(_cmp_2d))
  else:
    permutation.sort(key=functools.cmp_to_key(_cmp_3d))
  logging.vlog(2, "Permutation out: %s", permutation)
  return permutation

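
# Illustrative sketch (not part of the original module): applies
# `_build_all_reduce_ring` to a hypothetical 2x2 grid of single-core chips.
# The "snake" ordering places the first column last except for the first row,
# so the resulting permutation visits (0,0) -> (1,0) -> (1,1) -> (0,1), i.e. a
# closed ring on the grid.
def _example_all_reduce_ring():
  core_locations = [
      _CoreLocation(0, 0, 0, 0),
      _CoreLocation(1, 0, 0, 0),
      _CoreLocation(0, 1, 0, 0),
      _CoreLocation(1, 1, 0, 0),
  ]
  permutation = _build_all_reduce_ring(core_locations)
  assert permutation == [0, 1, 3, 2]
  return [core_locations[i] for i in permutation]
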

def _build_orthogonal_rings(
    core_locations: List[_CoreLocation], ring_size: int,
    rotate_ring_across_rings: bool) -> List[_CoreLocation]:
  """Build two all-reduce rings orthogonal to each other.

  One ring includes every `ring_size` consecutive core locations. It is usually
  applied to the model-parallel dimension of a mesh to achieve best 1D
  all-reduce performance. The other ring includes core locations separated by
  a stride of `ring_size`. It is usually applied to the data-parallel dimension
  of a mesh to get predictable strided all-reduce performance.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    ring_size: The number of core locations in the consecutive ring.
    rotate_ring_across_rings: Build column-major secondary rings.

  Returns:
    A permutation of the input list forming the described rings.
  """
  # Build a ring for the first `ring_size` cores, and apply that permutation to
  # every group of `ring_size` cores.
  num_cores = len(core_locations)
  permutation = _build_all_reduce_ring(core_locations[:ring_size])
  for r in range(0, num_cores, ring_size):
    core_locations[r:r + ring_size] = [
        core_locations[r + permutation[i]] for i in range(ring_size)
    ]
  logging.vlog(1, "Permutated core locations: %s", core_locations)

  # Build a "ring" for the collection of devices consisting of the 0th device
  # from every group, and apply that permutation to every i-th device group.
  # This is achieved by transposing the list and back.
  transposed = []
  for i in range(ring_size):
    transposed += [
        core_locations[g + i] for g in range(0, num_cores, ring_size)
    ]

  num_rings = int(num_cores / ring_size)
  permutation = _build_all_reduce_ring(
      transposed[:num_rings], rotate=rotate_ring_across_rings)
  for r in range(0, num_cores, num_rings):
    transposed[r:r + num_rings] = [
        transposed[r + permutation[i]] for i in range(num_rings)
    ]

  untransposed = []
  for i in range(num_rings):
    untransposed += [transposed[g + i] for g in range(0, num_cores, num_rings)]
  logging.vlog(1, "Stride-permutated core locations: %s", untransposed)

  return untransposed


@tf_export("experimental.dtensor.create_tpu_mesh", v1=[])
def create_tpu_mesh(
    mesh_dim_names: List[str],
    mesh_shape: List[int],
    mesh_name: str,
    ring_dims: Optional[int] = None,
    ring_axes: Optional[List[str]] = None,
    ring_bounds: Optional[List[int]] = None,
    can_split_host_across_rings: bool = True,
    build_ring_across_rings: bool = False,
    rotate_ring_across_rings: bool = False,
    use_xla_spmd: bool = layout_lib.USE_XLA_SPMD) -> layout_lib.Mesh:
  """Returns a distributed TPU mesh optimized for AllReduce ring reductions.

  Only as many of the leading axes specified by `ring_axes` as necessary will
  be used to build rings, as long as the subslice formed by these axes has
  enough cores to contain a ring of the required size. The leftover axes in
  `ring_axes` won't affect results.

  This function always uses all TPU devices, and offers more customization than
  `tf.experimental.dtensor.create_distributed_mesh`.

  Args:
    mesh_dim_names: List of mesh dimension names.
    mesh_shape: Shape of the mesh.
    mesh_name: A unique name for the mesh. If empty, internally generate one.
    ring_dims: Optional; The number of leading (ring_dims > 0) or trailing
      (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build
      rings for all but the first dimension.
    ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying
      the order of TPU topology axes to build rings in. If unspecified, default
      to ["core", "x", "y", "z"].
    ring_bounds: Optional; The maximum number of devices on each axis, in the x,
      y, z, core order. If unspecified, default to physical topology limits.
    can_split_host_across_rings: Optional; If true, devices attached to the same
      host (i.e., DTensor client) may get assigned to different rings. Setting
      it to false may cause some combinations of arguments to be infeasible; see
      DeviceAssignmentTest.testCreateMesh[No]SplittingHosts* for examples.
    build_ring_across_rings: Optional; If true, also build a data-parallel ring
      across model-parallel rings. This ring could be strided.
    rotate_ring_across_rings: Optional; If true, build the data-parallel ring in
      column-major instead of row-major order.
    use_xla_spmd: If True, use XLA SPMD instead of DTensor SPMD.
  """

  logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape)
  logging.info("Requested ring_dims: %s", ring_dims)
  logging.info("Requested ring_axes: %s", ring_axes)
  logging.info("Requested ring_bounds: %s", ring_bounds)
  logging.info("Requested can_split_host_across_rings: %s",
               can_split_host_across_rings)
  if not mesh_name:
    mesh_name = "mesh_%f" % time.time()
  logging.info("Requested mesh_name: %s", mesh_name)

  # By default, build rings for all but the first (usually batch) dimension.
  if ring_dims is None:
    ring_dims = 1 - len(mesh_shape)
  elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape):
    raise ValueError("Invalid ring_dims value: %d" % ring_dims)
  logging.info("Actual ring_dims: %s", ring_dims)

  # By default, vary axes in the core -> x -> y -> z order.
  if ring_axes is None:
    ring_axes = ["core", "x", "y", "z"]
  elif len(ring_axes) != 4:
    raise ValueError("Expected 4 elements in ring_axes, got %s" % ring_axes)
  elif sorted(ring_axes) != ["core", "x", "y", "z"]:
    raise ValueError("Invalid ring_axes value: %s" % ring_axes)
  logging.info("Actual ring_axes: %s", ring_axes)

  # Validate ring_bounds values.
  if _tpu_topology is None:
    raise ValueError(
        "Invalid TPU topology, run dtensor.initialize_tpu_system() first")
  topology_shape = list(_tpu_topology.mesh_shape)
  if ring_bounds is None:
    ring_bounds = topology_shape
  elif len(ring_bounds) != 4:
    raise ValueError("Expected 4 elements in ring_bounds, got %s" % ring_bounds)
  elif ring_bounds > topology_shape:
    raise ValueError("ring_bounds %s should be <= topology sizes %s" %
                     (ring_bounds, topology_shape))
  logging.info("Actual ring_bounds: %s", ring_bounds)

  # Compute ring_size, the number of cores in a ring.
  if ring_dims > 0:
    ring_size = np.prod(mesh_shape[:ring_dims])
  elif ring_dims < 0:
    ring_size = np.prod(mesh_shape[ring_dims:])
  else:
    ring_size = 1  # single-core rings
  logging.info("Actual ring_size: %d", ring_size)

  # Rearrange all cores according to the axis iteration order.
  global_core_locations = _enumerate_core_locations(
      topology_shape, ring_bounds, ring_axes, can_split_host_across_rings,
      ring_size)
  logging.vlog(1, "Enumerated core locations: %s", global_core_locations)
  num_cores = len(global_core_locations)

  # The mesh to be created must use all TPU cores in the system.
  mesh_size = np.prod(mesh_shape)
  if mesh_size != num_cores:
    raise ValueError(
        "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" %
        (mesh_shape, num_cores))

  # Build a ring for the `ring_size` dimension and, if required, a strided ring
  # for the orthogonal dimension.
  if build_ring_across_rings:
    global_core_locations = _build_orthogonal_rings(global_core_locations,
                                                    ring_size,
                                                    rotate_ring_across_rings)
  else:
    permutation = _build_all_reduce_ring(global_core_locations[:ring_size])
    for r in range(0, num_cores, ring_size):
      global_core_locations[r:r + ring_size] = [
          global_core_locations[r + permutation[i]] for i in range(ring_size)
      ]
    logging.vlog(1, "Permutated core locations: %s", global_core_locations)

  # From this point on, change from List[_CoreLocation] to List[List[int]] for
  # easier interaction with the C++ API.
  global_core_locations = [l.to_list() for l in global_core_locations]
  if _dtensor_device is None:
    raise ValueError("Invalid system device, "
                     "run dtensor.initialize_accelerator_system() first")
  global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
      global_core_locations)

  # Store a per-mesh mapping in the runtime.
  _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids)

  # Create the mesh by manually specifying local_device_ids.
  local_core_locations = _tpu_topology.device_coordinates[config.client_id()]
  indexes = [
      global_core_locations.index(list(local_core_location))
      for local_core_location in local_core_locations
  ]
  global_device_ids, local_device_ids, local_device_list = _create_device_array(
      mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes)
  return layout_lib.Mesh(mesh_dim_names, global_device_ids, local_device_ids,
                         local_device_list, mesh_name, use_xla_spmd)

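
# Illustrative sketch (not part of the original module): one way
# `create_tpu_mesh` might be used once the TPU system has been brought up
# (normally via `tf.experimental.dtensor.initialize_accelerator_system("TPU")`,
# which ends up calling `initialize_tpu_system` above). The dimension names,
# the mesh name, and the 2-way model-parallel split are arbitrary choices for
# the example, and an even number of global TPU cores is assumed.
def _example_create_2d_tpu_mesh():
  num_cores = config.num_global_devices(_TPU_DEVICE_TYPE)
  # Put every group of 2 consecutive cores on a model-parallel ring, let the
  # batch dimension span the remaining cores, and also build the strided
  # data-parallel ring across those rings.
  return create_tpu_mesh(
      mesh_dim_names=["batch", "model"],
      mesh_shape=[num_cores // 2, 2],
      mesh_name="example_2d_mesh",
      ring_dims=-1,  # build rings for the trailing (model) dimension
      build_ring_across_rings=True)
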

def get_device_ids(mesh: layout_lib.Mesh,
                   client_id: Optional[int] = None) -> List[int]:
  """Returns the device IDs of all TPU cores local to the given client.

  A device ID is a non-negative integer that uniquely identifies a device in
  the mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of [0, 1, 2, 3].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If empty, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == config.client_id():
    return mesh.local_device_ids()

  # It's not clear we should ever allow a client to query other clients for
  # their device IDs.
  raise NotImplementedError(
      "Looking up other clients' device IDs is not supported")


def get_device_locations(
    mesh: layout_lib.Mesh,
    client_id: Optional[int] = None) -> List[Dict[str, int]]:
  """Returns the device locations of all TPU cores local to the given client.

  A device location is a dictionary from dimension names to indices on those
  dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of this list:

    [{'x': 0, 'y': 0},
     {'x': 0, 'y': 1},
     {'x': 1, 'y': 0},
     {'x': 1, 'y': 1}].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If empty, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == config.client_id():
    return mesh.local_device_locations()

  # It's not clear we should ever allow a client to query other clients for
  # their device locations.
  raise NotImplementedError(
      "Looking up other clients' device locations is not supported")

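
# Illustrative sketch (not part of the original module): shows how the two
# lookups above relate on this client's portion of a TPU mesh. Device IDs are
# the linearized form of device locations, so zipping the two lists gives an
# (id, {dim: index, ...}) view of every local core.
def _example_describe_local_devices(mesh: layout_lib.Mesh):
  ids = get_device_ids(mesh)
  locations = get_device_locations(mesh)
  return list(zip(ids, locations))
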

# TODO(b/245589661): Remove dtensor_initialize_tpu_system() and
# dtensor_shutdown_tpu_system() after users have stopped using them.
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Deprecated way to initialize the TPU system."""
  from . import accelerator_util  # pylint: disable=g-import-not-at-top
  accelerator_util.initialize_accelerator_system(
      "TPU", enable_coordination_service=enable_coordination_service)


def dtensor_shutdown_tpu_system():
  """Deprecated way to shut down the TPU system."""
  from . import accelerator_util  # pylint: disable=g-import-not-at-top
  accelerator_util.shutdown_accelerator_system()