# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TPU-specific utilities for DTensor."""
17import functools
18import time
19from typing import List, Optional, Dict
21import numpy as np
23from tensorflow.dtensor.python import config
24from tensorflow.dtensor.python import dtensor_device
25from tensorflow.dtensor.python import gen_dtensor_ops
26from tensorflow.dtensor.python import layout as layout_lib
27from tensorflow.python.eager import context
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.tpu import topology
36from tensorflow.python.util.tf_export import tf_export
39_MESH_DIM_X = "x"
40_TPU_DEVICE_TYPE = "TPU"
42# A dedicated, hidden device used to make C++ API calls.
43_dtensor_device = None
# `_tpu_topology.mesh_shape` contains the TPU hardware slice size.
# `_tpu_topology.device_coordinates` maps TF task-device ordinals to TPU core
# IDs.
_tpu_topology = None
# Cache core ID <-> location mappings so we need not make repeated C++ calls.
# Both are indexed by TF task-device ordinals.
_all_core_ids = None
_all_core_locations = None


class _CoreLocation:
  """Represents a TPU core's location in the mesh."""

  def __init__(self, x: int = 0, y: int = 0, z: int = 0, core: int = 0):
    self.x = x
    self.y = y
    self.z = z
    self.core = core

  def __eq__(self, other):
    if not isinstance(other, _CoreLocation):
      return False
    return (self.x == other.x and self.y == other.y and self.z == other.z and
            self.core == other.core)

  def __ne__(self, other):
    if not isinstance(other, _CoreLocation):
      return True
    return not self == other

  def __hash__(self):
    return hash((self.x, self.y, self.z, self.core))

  def __repr__(self):
    return f"{type(self).__name__}(x={self.x}, y={self.y}, z={self.z}, core={self.core})"

  def to_list(self):
    return [self.x, self.y, self.z, self.core]


def _create_device_array(shape, device_type, host_id, local_device_ids=None):
  """Returns ID and device lists that can be used to create a mesh."""
  num_global_devices = config.num_global_devices(device_type)
  global_device_ids = np.arange(num_global_devices).reshape(shape)
  local_device_list = config.local_devices(device_type)

  # Users can specify local_device_ids; otherwise use the default multi-host
  # list, which offsets each host's device IDs by host_id.
  num_local_devices = len(local_device_list)
  local_device_ids = [
      x + host_id * num_local_devices for x in range(num_local_devices)
  ] if not local_device_ids else local_device_ids

  return global_device_ids, local_device_ids, local_device_list
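

# Illustrative sketch (not part of the original module): for a hypothetical
# 2-host job where config reports 16 global and 8 local TPU devices, a call on
# host 1 such as
#
#   _create_device_array((16,), "TPU", host_id=1)
#
# would return global_device_ids = np.arange(16).reshape((16,)),
# local_device_ids = [8, 9, ..., 15], and the 8 local TPU device specs
# reported by config.local_devices("TPU").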


def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int,
                         num_devices_per_task: int) -> topology.Topology:
  """Returns a Topology object built from a _CoreLocation list.

  Args:
    core_locations: A list of _CoreLocation objects sorted first by TF task ID
      and then by per-task device ordinals.
    num_tasks: The number of TF tasks in the cluster.
    num_devices_per_task: The number of TPU devices local to each task.
  """

  assert min([l.x for l in core_locations]) == 0
  assert min([l.y for l in core_locations]) == 0
  assert min([l.z for l in core_locations]) == 0
  assert min([l.core for l in core_locations]) == 0
  x_max = max([l.x for l in core_locations])
  y_max = max([l.y for l in core_locations])
  z_max = max([l.z for l in core_locations])
  core_max = max([l.core for l in core_locations])
  mesh_shape = [x_max + 1, y_max + 1, z_max + 1, core_max + 1]

  device_coordinates = [[l.x, l.y, l.z, l.core] for l in core_locations]
  device_coordinates = np.asarray(device_coordinates).reshape(
      num_tasks, num_devices_per_task, 4)

  return topology.Topology(
      mesh_shape=mesh_shape, device_coordinates=device_coordinates)
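

# Illustrative sketch (not part of the original module): for a single-task job
# on a 2x2 TPU donut with 2 cores per chip, the 8 entries
#
#   [_CoreLocation(0, 0, 0, 0), _CoreLocation(0, 0, 0, 1),
#    _CoreLocation(1, 0, 0, 0), _CoreLocation(1, 0, 0, 1),
#    _CoreLocation(0, 1, 0, 0), _CoreLocation(0, 1, 0, 1),
#    _CoreLocation(1, 1, 0, 0), _CoreLocation(1, 1, 0, 1)]
#
# passed with num_tasks=1 and num_devices_per_task=8 yield a Topology with
# mesh_shape [2, 2, 1, 2] and device_coordinates of shape (1, 8, 4).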


def shutdown_tpu_system():
  """Shuts down the TPU system."""

  @def_function.function
  def _shutdown_tpu_system():
    return gen_dtensor_ops.shutdown_tpu_system()

  success = _shutdown_tpu_system() if context.is_tfrt_enabled() else True
  if success:
    logging.info("TPU system shut down.")
  else:
    logging.warning("TPU system failed to shut down.")


def tpu_system_init_helper(task_id,
                           num_tasks,
                           num_devices,
                           use_tfrt_host_runtime=True):
  """A helper function to initialize a multi-client TPU system."""

  @def_function.function
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu(
        use_tfrt_host_runtime=use_tfrt_host_runtime)

  @def_function.function
  def _set_global_tpu_array_fn(topology_proto):
    gen_dtensor_ops.d_tensor_set_global_tpu_array(topology_proto)

  with ops.device("/job:" + config.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
    my_core_ids = _tpu_init_fn()
  logging.info("TPU core IDs: %s", my_core_ids)

  # `my_core_ids` contains the IDs of TPU cores attached to this host.
  #
  # To generate correct and efficient XLA AllReduce group assignments, we must
  # merge these arrays from all hosts and broadcast the result back to all
  # hosts, so all hosts can use these mappings in their MLIR passes.
  #
  # This is essentially doing what WaitForDistributedTpuOp and
  # SetGlobalTPUArrayOp do, in our multi-client environment.
  num_devices_per_task = int(num_devices / num_tasks)

  # Create a one-time-use mesh and layout just for merging core IDs.
  mesh = layout_lib.Mesh([_MESH_DIM_X],
                         *_create_device_array((num_devices,), _TPU_DEVICE_TYPE,
                                               config.client_id()))
  layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
  device = dtensor_device.DTensorDevice(meshes=[mesh])
  logging.info("TPU core locations: %s",
               device.tpu_core_ids_to_locations(my_core_ids))

  # At this point, we don't know which cores are attached to other hosts.
  # The core ID mappings in the runtime haven't been set yet.
  #
  # The core ID merging AllReduce below is carefully written so it works
  # without needing correct core mappings to be set in the runtime. We will
  # use this AllReduce's result to set the core ID mappings, and all future
  # user-initiated AllReduces will use the mappings.
  #
  # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
  all_core_ids = np.zeros([num_devices], dtype=np.int32)
  for i in range(len(my_core_ids)):
    all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

  # Only one local device gets valid input: 8 local core IDs among
  # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using the task ID as an
  # offset. The other 7 local devices get zero inputs. All devices on all
  # hosts participate in one AllReduce, whose result will be core IDs arranged
  # by task-device ordinals.
  all_core_ids = constant_op.constant([all_core_ids])
  zeros = array_ops.zeros_like(all_core_ids)
  all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

  with ops.device(device.name):
    all_core_ids = device.pack(all_core_ids, layout)
    all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
    unpacked_all_tpu_ids = device.unpack(all_core_ids)

  all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
  logging.info("All TPU core IDs: %s", all_core_ids)

  # Set the default core ID mappings in the runtime for legacy code and tests.
  #
  # Legacy code and tests create TPU meshes directly without using the
  # `create_tpu_mesh` function below. Those meshes have global device IDs
  # equal to TF task-device ordinals. The `all_core_ids` array happens to
  # arrange core IDs by TF task-device ordinals. Using this array on those
  # meshes guarantees correct, although inefficient, results.
  device.set_tpu_core_ids("", all_core_ids)

  # Remember enough global, immutable information to be able to build any ring
  # we want prescribed by `create_tpu_mesh` in the future.
  global _all_core_ids
  _all_core_ids = all_core_ids

  all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
  all_core_locations = [
      _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
  ]
  global _all_core_locations
  _all_core_locations = all_core_locations
  logging.info("All TPU core locations: %s", all_core_locations)

  tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                      num_devices_per_task)

  _set_global_tpu_array_fn(tpu_topology.serialized())
  return tpu_topology, device
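

# Illustrative sketch (not part of the original module) of the zero-padded
# AllReduce merge performed above, assuming 2 tasks with 2 devices each and
# made-up core IDs:
#
#   task 0 contributes [10, 11,  0,  0]
#   task 1 contributes [ 0,  0, 20, 21]
#   element-wise sum = [10, 11, 20, 21]  # global core IDs arranged by
#                                        # task-device ordinals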


def initialize_tpu_system():
  """Initializes the TPU system."""

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  use_tfrt_host_runtime = context.context().use_tfrt
  logging.info("Using TFRT host runtime is set to %s", use_tfrt_host_runtime)
  try:
    task_id = config.client_id()
    num_tasks = config.num_clients()
    num_devices = config.num_global_devices(_TPU_DEVICE_TYPE)

    tpu_topology, device = tpu_system_init_helper(
        task_id,
        num_tasks,
        num_devices,
        use_tfrt_host_runtime=use_tfrt_host_runtime)
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()
  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None,
        "Initialization failed, no valid TPUs found. " + str(e)) from e
  except errors.InternalError as e:
    logging.error("Hit an internal error during TPU system initialization. "
                  "It is likely a hardware failure. \nPlease check the error "
                  "messages above to see whether that is the case. \nIf so, "
                  "consider restarting the job or trying another machine.")
    raise e

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access


def _enumerate_cores(bounds: List[int], ring_bounds: List[int],
                     ring_sizes: List[int], host_bounds: List[int],
                     host_sizes: List[int]) -> List[List[int]]:
  """Enumerates cores within `bounds` from fastest to slowest varying axes.

  Args:
    bounds: Upper bounds of axes, from fastest to slowest varying.
    ring_bounds: Upper bounds of ring size per axis in the same axis order.
    ring_sizes: Number of consecutive cores in the ring built so far,
      cumulatively.
    host_bounds: Number of axis values per host in the same axis order.
    host_sizes: Number of consecutive cores on one host, cumulatively.

  Returns:
    Cores represented as a list of 4 integers in the same axis order.
  """
  if not bounds:
    return [[]]

  # Recursively enumerate cores under all but the slowest varying axis.
  partials = _enumerate_cores(bounds[:-1], ring_bounds[:-1], ring_sizes[:-1],
                              host_bounds[:-1], host_sizes[:-1])

  # Append the slowest varying axis to the end of all partial results.
  # From ring_i|j to host_i|j to core_i|j, use progressively smaller or equal
  # iteration groupings until every one of the bounds[-1] * len(partials)
  # combinations is iterated on.
  # Despite the six levels of nested loops below, the total time complexity for
  # this invocation is O(N), where N is the number of cores in the topology.
  results = []
  for ring_i in range(0, bounds[-1], ring_bounds[-1]):
    for ring_j in range(0, len(partials), ring_sizes[-1]):
      for host_i in range(ring_i, ring_i + ring_bounds[-1], host_bounds[-1]):
        for host_j in range(ring_j, ring_j + ring_sizes[-1], host_sizes[-1]):
          for i in range(host_i, host_i + host_bounds[-1]):
            for j in range(host_j, host_j + host_sizes[-1]):
              results.append(partials[j] + [i])
  return results
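

# Illustrative sketch (not part of the original module): on a toy 2-axis 2x2
# slice with full-size rings and single-device hosts, the enumeration reduces
# to a plain fastest-axis-first sweep:
#
#   _enumerate_cores(bounds=[2, 2], ring_bounds=[2, 2], ring_sizes=[1, 2],
#                    host_bounds=[1, 1], host_sizes=[1, 1])
#   # -> [[0, 0], [1, 0], [0, 1], [1, 1]]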


def _enumerate_core_locations(bounds: List[int], ring_bounds: List[int],
                              axes: List[str],
                              can_split_host_across_rings: bool,
                              ring_size: int) -> List[_CoreLocation]:
  """Enumerates all possible core locations under the axis iteration order.

  Args:
    bounds: A list of 4 positive integers, upper bound values for x, y, z,
      core.
    ring_bounds: A list of 4 positive integers, upper bound values for ring
      size in x, y, z, core axes.
    axes: A permutation of ["x", "y", "z", "core"], the axis iteration order.
    can_split_host_across_rings: If true, devices attached to the same host may
      get assigned to different rings.
    ring_size: Number of devices in a ring, only for argument validation.

  Returns:
    A list of all _CoreLocation objects defined in a TPU slice of shape
    `bounds`, sorted by the axis iteration order specified by `axes`.

    For example, given bounds=[2, 2, 1, 2] and axes=["core", "z", "y", "x"],
    return 8 core locations expressed in (x, y, z, core) format but iterated in
    core -> z -> y -> x order (fastest to slowest varying):

    [_CoreLocation(0, 0, 0, 0),
     _CoreLocation(0, 0, 0, 1),
     _CoreLocation(0, 1, 0, 0),
     _CoreLocation(0, 1, 0, 1),
     _CoreLocation(1, 0, 0, 0),
     _CoreLocation(1, 0, 0, 1),
     _CoreLocation(1, 1, 0, 0),
     _CoreLocation(1, 1, 0, 1)]

  Raises:
    ValueError: If ring_size cannot be fulfilled without splitting hosts.
  """
  num_cores_per_chip = bounds[3]
  if num_cores_per_chip != 1 and num_cores_per_chip != 2:
    raise ValueError("Unsupported TPU slice size: %s" % bounds)

  # Translate `axes` from string to integer format.
  axes = [{"x": 0, "y": 1, "z": 2, "core": 3}[axis] for axis in axes]
  # Reorder bounds from fastest to slowest varying axes.
  bounds = [bounds[i] for i in axes]

  # Set and validate host_bounds.
  if can_split_host_across_rings:
    # If we can split hosts, shrink every host to effectively contain 1 device.
    host_bounds = [1, 1, 1, 1]
  elif np.prod(bounds) <= 2:
    # We must be running on 1x1 or 1x1x1 Forge.
    host_bounds = [[1, 1, 1, num_cores_per_chip][i] for i in axes]
  else:
    # Other cases including 2x2 Forge and Borg must use a full donut.
    host_bounds = [[2, 2, 1, num_cores_per_chip][i] for i in axes]

  # host_sizes is the cumulative product of host_bounds.
  host_sizes = [1]
  for host_bound in host_bounds:
    host_sizes.append(host_sizes[-1] * host_bound)
  host_size = host_sizes.pop()

  # When can_split_host_across_rings is false, a ring must contain at least as
  # many devices as a host has.
  if ring_size < host_size:
    assert not can_split_host_across_rings
    raise ValueError(
        "Rings too small for can_split_host_across_rings = False: %d" %
        ring_size)

  # Reorder ring_bounds and validate it's element-wise >= host_bounds.
  ring_bounds = [ring_bounds[i] for i in axes]
  if ring_bounds < host_bounds:
    raise ValueError("ring_bounds %s should be >= host_bounds %s" %
                     (ring_bounds, host_bounds))

  # ring_sizes is the cumulative product of ring_bounds.
  ring_sizes = [1]
  for ring_bound in ring_bounds:
    ring_sizes.append(ring_sizes[-1] * ring_bound)
  ring_sizes.pop()

  # Enumerate cores in the given iteration order. Each core is represented as a
  # list of int, which are offsets from fastest to slowest varying axes.
  cores = _enumerate_cores(bounds, ring_bounds, ring_sizes, host_bounds,
                           host_sizes)

  # Reorder the offsets of each core back to the x, y, z, core order.
  core_locations = []
  for core in cores:
    core = [core[axes.index(i)] for i in range(4)]
    core_locations.append(_CoreLocation(core[0], core[1], core[2], core[3]))
  return core_locations


def _build_all_reduce_ring(core_locations: List[_CoreLocation],
                           rotate: bool = False) -> List[int]:
  """Reorders a list of TPU cores to optimize for AllReduce performance.

  This is ported from the C++ tensorflow::BuildAllReduceRing function,
  mixed with some logic from TF TPU's device_assignment._ring_3d.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    rotate: If true, scan the cores in a column-major order. False by default.

  Returns:
    A permutation of the input list such that neighbors in the sequence are
    nearby in the TPU topology.
  """

  permutation = list(range(len(core_locations)))
  if not permutation:
    return permutation
  logging.vlog(2, "Core locations in: %s", core_locations)

  first_column = min([l.x for l in core_locations])
  first_row = min([l.y for l in core_locations])
  same_z = (len(set([l.z for l in core_locations])) == 1)
  logging.vlog(2, "first_column: %d", first_column)
  logging.vlog(2, "first_row: %d", first_row)
  logging.vlog(2, "same_z: %s", same_z)

  def _cmp_2d(ia: int, ib: int) -> int:
    if not rotate:
      a = core_locations[ia]
      b = core_locations[ib]

      # Order the first column last in the sequence, except for the first row.
      a_first = (a.x == first_column and a.y != first_row)
      b_first = (b.x == first_column and b.y != first_row)
      if a_first != b_first:
        return -1 if b_first else 1

      # Order rows in increasing order, unless in the first column.
      if a.y != b.y:
        return b.y - a.y if a_first else a.y - b.y

      # Order even rows left to right, odd rows right to left.
      if a.x != b.x:
        return a.x - b.x if a.y % 2 == 0 else b.x - a.x

      # Order cores in increasing order.
      return a.core - b.core
    else:
      a = core_locations[ia]
      b = core_locations[ib]

      # Order the first row last in the sequence, except for the first column.
      a_first = (a.y == first_row and a.x != first_column)
      b_first = (b.y == first_row and b.x != first_column)
      if a_first != b_first:
        return -1 if b_first else 1

      # Order columns in increasing order, unless in the first row.
      if a.x != b.x:
        return b.x - a.x if a_first else a.x - b.x

      # Order even columns top down, odd columns bottom up.
      if a.y != b.y:
        return a.y - b.y if a.x % 2 == 0 else b.y - a.y

      # Order cores in increasing order.
      return a.core - b.core

  def _cmp_3d(ia: int, ib: int) -> int:
    a = core_locations[ia]
    b = core_locations[ib]

    a_corner = (a.x == first_column and a.y == first_row)
    b_corner = (b.x == first_column and b.y == first_row)

    # If both are in the corner, order in reverse z then core order.
    if a_corner and b_corner:
      return b.z - a.z if a.z != b.z else a.core - b.core

    # Corner cores always go after non-corner cores.
    if a_corner != b_corner:
      return -1 if b_corner else 1

    # Both non-corner cores are on the same z-plane. Reverse odd z-planes.
    if a.z == b.z:
      return _cmp_2d(ia, ib) if a.z % 2 == 0 else -_cmp_2d(ia, ib)

    # Both non-corner cores are on different z-planes. Smaller z goes first.
    return a.z - b.z

  # If all cores are on the same z-plane, order as usual. Otherwise, order
  # neighbor z-planes in opposite orders. Stack all z-planes along the z axis
  # and connect them in one corner.
  if same_z:
    permutation.sort(key=functools.cmp_to_key(_cmp_2d))
  else:
    permutation.sort(key=functools.cmp_to_key(_cmp_3d))
  logging.vlog(2, "Permutation out: %s", permutation)
  return permutation
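

# Illustrative sketch (not part of the original module): on a 2x2 grid of
# single-core chips, the row-major snake ordering produces a ring where every
# hop is between physically adjacent chips:
#
#   locations = [_CoreLocation(0, 0, 0, 0), _CoreLocation(1, 0, 0, 0),
#                _CoreLocation(0, 1, 0, 0), _CoreLocation(1, 1, 0, 0)]
#   _build_all_reduce_ring(locations)
#   # -> [0, 1, 3, 2], i.e. (0,0) -> (1,0) -> (1,1) -> (0,1) and back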


def _build_orthogonal_rings(
    core_locations: List[_CoreLocation], ring_size: int,
    rotate_ring_across_rings: bool) -> List[_CoreLocation]:
  """Builds two all-reduce rings orthogonal to each other.

  One ring includes every `ring_size` consecutive core locations. It is usually
  applied to the model-parallel dimension of a mesh to achieve the best 1D
  all-reduce performance. The other ring includes core locations separated by
  a stride of `ring_size`. It is usually applied to the data-parallel dimension
  of a mesh to get predictable strided all-reduce performance.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    ring_size: The number of core locations in the consecutive ring.
    rotate_ring_across_rings: Build column-major secondary rings.

  Returns:
    A permutation of the input list forming the described rings.
  """
  # Build a ring for the first `ring_size` cores, and apply that permutation to
  # every group of `ring_size` cores.
  num_cores = len(core_locations)
  permutation = _build_all_reduce_ring(core_locations[:ring_size])
  for r in range(0, num_cores, ring_size):
    core_locations[r:r + ring_size] = [
        core_locations[r + permutation[i]] for i in range(ring_size)
    ]
  logging.vlog(1, "Permuted core locations: %s", core_locations)

  # Build a "ring" for the collection of devices consisting of the 0th device
  # from every group, and apply that permutation to every i-th device group.
  # This is achieved by transposing the list and transposing it back.
  transposed = []
  for i in range(ring_size):
    transposed += [
        core_locations[g + i] for g in range(0, num_cores, ring_size)
    ]

  num_rings = int(num_cores / ring_size)
  permutation = _build_all_reduce_ring(
      transposed[:num_rings], rotate=rotate_ring_across_rings)
  for r in range(0, num_cores, num_rings):
    transposed[r:r + num_rings] = [
        transposed[r + permutation[i]] for i in range(num_rings)
    ]

  untransposed = []
  for i in range(num_rings):
    untransposed += [transposed[g + i] for g in range(0, num_cores, num_rings)]
  logging.vlog(1, "Stride-permuted core locations: %s", untransposed)

  return untransposed
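

# Illustrative sketch (not part of the original module): with 8 cores and
# ring_size=4, the returned list is laid out so that positions [0..3] and
# [4..7] each form a consecutive (model-parallel) ring, while same-offset
# positions across the two groups, e.g. positions {0, 4}, form the strided
# (data-parallel) ring.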


@tf_export("experimental.dtensor.create_tpu_mesh", v1=[])
def create_tpu_mesh(
    mesh_dim_names: List[str],
    mesh_shape: List[int],
    mesh_name: str,
    ring_dims: Optional[int] = None,
    ring_axes: Optional[List[str]] = None,
    ring_bounds: Optional[List[int]] = None,
    can_split_host_across_rings: bool = True,
    build_ring_across_rings: bool = False,
    rotate_ring_across_rings: bool = False,
    use_xla_spmd: bool = layout_lib.USE_XLA_SPMD) -> layout_lib.Mesh:
  """Returns a distributed TPU mesh optimized for AllReduce ring reductions.

  Only as many of the leading axes specified by `ring_axes` as necessary are
  used to build rings, as long as the subslice formed by these axes has enough
  cores to contain a ring of the required size. The leftover axes in
  `ring_axes` don't affect results.

  This function always uses all TPU devices, and offers more customization than
  `tf.experimental.dtensor.create_distributed_mesh`.

  Args:
    mesh_dim_names: List of mesh dimension names.
    mesh_shape: Shape of the mesh.
    mesh_name: A unique name for the mesh. If empty, internally generate one.
    ring_dims: Optional; The number of leading (ring_dims > 0) or trailing
      (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build
      rings for all but the first dimension.
    ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying
      the order of TPU topology axes to build rings in. If unspecified, default
      to ["core", "x", "y", "z"].
    ring_bounds: Optional; The maximum number of devices on each axis, in the
      x, y, z, core order. If unspecified, default to physical topology limits.
    can_split_host_across_rings: Optional; If true, devices attached to the
      same host (i.e., DTensor client) may get assigned to different rings.
      Setting it to false may cause some combinations of arguments to be
      infeasible; see DeviceAssignmentTest.testCreateMesh[No]SplittingHosts*
      for examples.
    build_ring_across_rings: Optional; If true, also build a data-parallel ring
      across model-parallel rings. This ring could be strided.
    rotate_ring_across_rings: Optional; If true, build the data-parallel ring
      in column-major instead of row-major order.
    use_xla_spmd: Optional; If true, use XLA SPMD instead of DTensor SPMD.
  """

  logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape)
  logging.info("Requested ring_dims: %s", ring_dims)
  logging.info("Requested ring_axes: %s", ring_axes)
  logging.info("Requested ring_bounds: %s", ring_bounds)
  logging.info("Requested can_split_host_across_rings: %s",
               can_split_host_across_rings)
  if not mesh_name:
    mesh_name = "mesh_%f" % time.time()
  logging.info("Requested mesh_name: %s", mesh_name)

  # By default, build rings for all but the first (usually batch) dimension.
  if ring_dims is None:
    ring_dims = 1 - len(mesh_shape)
  elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape):
    raise ValueError("Invalid ring_dims value: %d" % ring_dims)
  logging.info("Actual ring_dims: %s", ring_dims)

  # By default, vary axes in the core -> x -> y -> z order.
  if ring_axes is None:
    ring_axes = ["core", "x", "y", "z"]
  elif len(ring_axes) != 4:
    raise ValueError("Expected 4 elements in ring_axes, got %s" % ring_axes)
  elif sorted(ring_axes) != ["core", "x", "y", "z"]:
    raise ValueError("Invalid ring_axes value: %s" % ring_axes)
  logging.info("Actual ring_axes: %s", ring_axes)

  # Validate ring_bounds values.
  if _tpu_topology is None:
    raise ValueError(
        "Invalid TPU topology, run dtensor.initialize_tpu_system() first")
  topology_shape = list(_tpu_topology.mesh_shape)
  if ring_bounds is None:
    ring_bounds = topology_shape
  elif len(ring_bounds) != 4:
    raise ValueError("Expected 4 elements in ring_bounds, got %s" % ring_bounds)
  elif ring_bounds > topology_shape:
    raise ValueError("ring_bounds %s should be <= topology sizes %s" %
                     (ring_bounds, topology_shape))
  logging.info("Actual ring_bounds: %s", ring_bounds)

  # Compute ring_size, the number of cores in a ring.
  if ring_dims > 0:
    ring_size = np.prod(mesh_shape[:ring_dims])
  elif ring_dims < 0:
    ring_size = np.prod(mesh_shape[ring_dims:])
  else:
    ring_size = 1  # single-core rings
  logging.info("Actual ring_size: %d", ring_size)

  # Rearrange all cores according to the axis iteration order.
  global_core_locations = _enumerate_core_locations(
      topology_shape, ring_bounds, ring_axes, can_split_host_across_rings,
      ring_size)
  logging.vlog(1, "Enumerated core locations: %s", global_core_locations)
  num_cores = len(global_core_locations)

  # The mesh to be created must use all TPU cores in the system.
  mesh_size = np.prod(mesh_shape)
  if mesh_size != num_cores:
    raise ValueError(
        "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" %
        (mesh_shape, num_cores))

  # Build a ring for the `ring_size` dimension and, if required, a strided ring
  # for the orthogonal dimension.
  if build_ring_across_rings:
    global_core_locations = _build_orthogonal_rings(global_core_locations,
                                                    ring_size,
                                                    rotate_ring_across_rings)
  else:
    permutation = _build_all_reduce_ring(global_core_locations[:ring_size])
    for r in range(0, num_cores, ring_size):
      global_core_locations[r:r + ring_size] = [
          global_core_locations[r + permutation[i]] for i in range(ring_size)
      ]
    logging.vlog(1, "Permuted core locations: %s", global_core_locations)

  # From this point on, change from List[_CoreLocation] to List[List[int]] for
  # easier interaction with the C++ API.
  global_core_locations = [l.to_list() for l in global_core_locations]
  if _dtensor_device is None:
    raise ValueError("Invalid system device, "
                     "run dtensor.initialize_accelerator_system() first")
  global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
      global_core_locations)

  # Store a per-mesh mapping in the runtime.
  _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids)

  # Create the mesh by manually specifying local_device_ids.
  local_core_locations = _tpu_topology.device_coordinates[config.client_id()]
  indexes = [
      global_core_locations.index(list(local_core_location))
      for local_core_location in local_core_locations
  ]
  global_device_ids, local_device_ids, local_device_list = _create_device_array(
      mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes)
  return layout_lib.Mesh(mesh_dim_names, global_device_ids, local_device_ids,
                         local_device_list, mesh_name, use_xla_spmd)
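

# Illustrative usage sketch (not part of the original module; requires a real
# TPU slice and a prior call to initialize the TPU system, e.g. via
# dtensor.initialize_accelerator_system("TPU")):
#
#   mesh = create_tpu_mesh(
#       mesh_dim_names=["batch", "model"],
#       mesh_shape=[4, 2],  # must multiply to the total number of TPU cores
#       mesh_name="demo_mesh",
#       ring_dims=-1)  # build rings over the trailing "model" dimension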


def get_device_ids(mesh: layout_lib.Mesh,
                   client_id: Optional[int] = None) -> List[int]:
  """Returns the device IDs of all TPU cores local to the given client.

  A device ID is a non-negative integer that uniquely identifies a device in
  the mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of [0, 1, 2, 3].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If empty, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == config.client_id():
    return mesh.local_device_ids()

  # It's not clear we should ever allow a client to query other clients for
  # their device IDs.
  raise NotImplementedError(
      "Looking up other clients' device IDs is not supported")


def get_device_locations(
    mesh: layout_lib.Mesh,
    client_id: Optional[int] = None) -> List[Dict[str, int]]:
  """Returns the device locations of all TPU cores local to the given client.

  A device location is a dictionary from dimension names to indices on those
  dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of this list:

    [{'x': 0, 'y': 0},
     {'x': 0, 'y': 1},
     {'x': 1, 'y': 0},
     {'x': 1, 'y': 1}].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If empty, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == config.client_id():
    return mesh.local_device_locations()

  # It's not clear we should ever allow a client to query other clients for
  # their device locations.
  raise NotImplementedError(
      "Looking up other clients' device locations is not supported")


# TODO(b/245589661): Remove dtensor_initialize_tpu_system() and
# dtensor_shutdown_tpu_system() after users have stopped using them.
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Deprecated way to initialize the TPU system."""
  from . import accelerator_util  # pylint: disable=g-import-not-at-top
  accelerator_util.initialize_accelerator_system(
      "TPU", enable_coordination_service=enable_coordination_service)


def dtensor_shutdown_tpu_system():
  """Deprecated way to shut down the TPU system."""
  from . import accelerator_util  # pylint: disable=g-import-not-at-top
  accelerator_util.shutdown_accelerator_system()