Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/layout.py: 30%
217 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python definitions for `Mesh` and `Layout`."""

import collections
import itertools
from typing import List, Dict, Optional, Union

import numpy as np

from tensorflow.dtensor.proto import layout_pb2
from tensorflow.dtensor.python import config
from tensorflow.python import _pywrap_dtensor_device
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

# UNSHARDED indicates a tensor dimension is not sharded over any mesh dimension.
UNSHARDED = 'unsharded'
MATCH = 'match'
USE_XLA_SPMD = False

tf_export(
    'experimental.dtensor.UNSHARDED',
    v1=[]).export_constant(__name__, 'UNSHARDED')
tf_export(
    'experimental.dtensor.MATCH', v1=[]).export_constant(__name__, 'MATCH')

MeshDimension = collections.namedtuple('MeshDimension', ['name', 'size'])


def _compute_mesh_strides(shape: List[int]) -> List[int]:
  strides = [1]
  for idx, dim_size in enumerate(reversed(shape[1:])):
    strides.append(strides[idx] * dim_size)
  strides.reverse()
  return strides
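
# For illustration: with a mesh of shape [2, 3, 4] the strides are
# [3 * 4, 4, 1] == [12, 4, 1], so device ID 17 sits at coordinates
# [(17 // 12) % 2, (17 // 4) % 3, (17 // 1) % 4] == [1, 1, 1].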


@tf_export('experimental.dtensor.Mesh', v1=[])
class Mesh(_pywrap_dtensor_device.Mesh):
  """Represents a Mesh configuration over a certain list of Mesh Dimensions.

  A mesh consists of named dimensions with sizes, which describe how a set of
  devices are arranged. Defining tensor layouts in terms of mesh dimensions
  allows us to efficiently determine the communication required when computing
  an operation with tensors of different layouts.

  A mesh provides information not only about the placement of the tensors but
  also the topology of the underlying devices. For example, we can group 8 TPUs
  as a 1-D array for data parallelism or a `2x4` grid for (2-way) data
  parallelism and (4-way) model parallelism.

  Note: the utilities `dtensor.create_mesh` and
  `dtensor.create_distributed_mesh` provide a simpler API to create meshes for
  single- or multi-client use cases.
  """

  def __init__(
      self,
      dim_names: List[str],
      global_device_ids: np.ndarray,
      local_device_ids: List[int],
      local_devices: List[Union[tf_device.DeviceSpec, str]],
      mesh_name: str = '',
      global_devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
      use_xla_spmd: bool = USE_XLA_SPMD,
  ):
    """Builds a Mesh.

    The `dim_names` and `global_device_ids` arguments describe the dimension
    names and shape for the mesh.

    For example,

    ```python
      dim_names = ('x', 'y'),
      global_device_ids = [[0, 1],
                           [2, 3],
                           [4, 5]]
    ```

    defines a 2D mesh of shape 3x2. A reduction over the 'x' dimension will
    reduce across columns (0, 2, 4) and (1, 3, 5), and a reduction over the 'y'
    dimension reduces across rows.

    Note: the utilities `dtensor.create_mesh` and
    `dtensor.create_distributed_mesh` provide a simpler API to create meshes
    for single- or multi-client use cases.

    Args:
      dim_names: A list of strings indicating dimension names.
      global_device_ids: An ndarray of global device IDs used to compose
        DeviceSpecs describing the mesh. The shape of this array determines the
        size of each mesh dimension. Values in this array should increment
        sequentially from 0. This argument is the same for every DTensor
        client.
      local_device_ids: A list of local device IDs equal to a subset of values
        in global_device_ids. They indicate the position of local devices in
        the global mesh. Different DTensor clients must contain distinct
        local_device_ids contents. All local_device_ids from all DTensor
        clients must cover every element in global_device_ids.
      local_devices: The list of devices hosted locally. The elements
        correspond 1:1 to those of local_device_ids.
      mesh_name: The name of the mesh. Currently, this is rarely used, and is
        mostly used to indicate whether it is a CPU, GPU, or TPU-based mesh.
      global_devices (optional): The list of global devices. Set when multiple
        device meshes are in use.
      use_xla_spmd (optional): Boolean. When True, use XLA SPMD instead of
        DTensor SPMD.
    """
    # Check if input args are valid.
    if not isinstance(global_device_ids, np.ndarray):
      raise ValueError('Variable global_device_ids must be an ndarray.')
    if global_device_ids.size == 0:
      raise ValueError('Variable global_device_ids must be non-empty.')
    flat_global_device_ids = global_device_ids.flatten()
    # global_device_ids are expected to be consecutive numbers.
    # LINT.IfChange
    distance = flat_global_device_ids[0]
    if any(
        (gid - i != distance) for i, gid in enumerate(flat_global_device_ids)):
      raise ValueError('global_device_ids must sequentially increase: %s' %
                       global_device_ids)
    # LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_device.cc)

    # TODO(b/242201545): This class only performs argument-type transformation
    # for the exported C++ Mesh class after the unification is complete. Any
    # other logic, including validation checks, should reside in the C++ layer.

    if len(dim_names) != global_device_ids.ndim:
      raise ValueError(
          'Number of mesh dimensions does not match number of dimension names.')

    if not isinstance(local_device_ids, list):
      raise ValueError('Variable local_device_ids must be a list of integers.')

    if not isinstance(local_devices, list):
      raise ValueError('Variable local_devices must be a list of DeviceSpecs.')

    if global_devices and not isinstance(global_devices, list):
      raise ValueError('Variable global_devices must be a list of DeviceSpecs.')

    if not local_devices and not global_devices:
      raise ValueError('Empty list of devices not allowed.')

    # Transform args format for C++ Mesh constructor
    global_device_ids_flatten = global_device_ids.flatten()
    global_device_ids_shape = global_device_ids.shape

    def to_str(d) -> str:
      if isinstance(d, tf_device.DeviceSpec):
        return d.to_string()
      return d

    def to_spec(d) -> tf_device.DeviceSpec:
      if not isinstance(d, tf_device.DeviceSpec):
        return tf_device.DeviceSpec.from_string(d)
      return d

    local_devices_str = [to_str(d) for d in local_devices]
    local_devices_spec = [to_spec(d) for d in local_devices]
    if not global_devices:
      global_devices = []
    global_devices_str = [to_str(d) for d in global_devices]
    global_devices_spec = [to_spec(d) for d in global_devices]

    local_devices_set = set(local_devices_spec)
    local_device_only_contains_host_cpu = (
        len(local_devices_set) == 1 and
        list(local_devices_set)[0].device_type == 'CPU')
    if not local_device_only_contains_host_cpu and len(local_devices) != len(
        local_devices_set):
      raise ValueError('Duplicate devices found in mesh specification %s.' %
                       [d for d in local_devices if local_devices.count(d) > 1])

    if len(local_device_ids) != len(local_devices):
      raise ValueError(
          'Variable local_device_ids does not have same size as local_devices.')

    if len(local_device_ids) > len(np.ravel(global_device_ids)):
      raise ValueError('Cannot have more local than global device IDs.')

    device_types = set([device.device_type for device in local_devices_spec])
    if not device_types:
      device_types = set([device.device_type for device in global_devices_spec])
    if None in device_types:
      raise ValueError('device_type is required')
    if len(device_types) > 1:
      raise ValueError('Devices containing multiple device_types: %s' %
                       device_types)
    device_type = device_types.pop()
    if use_xla_spmd and device_type != 'TPU':
      raise ValueError('XLA SPMD is not currently supported for %s mesh.' %
                       device_type)

    super().__init__(
        mesh_name,
        dim_names,
        global_device_ids_shape,
        global_device_ids_flatten,
        global_devices_str,
        local_device_ids,
        local_devices_str,
        use_xla_spmd,
    )

  def global_device_ids(self) -> np.ndarray:
    """Returns a global device list as an array."""
    return np.array(super().global_device_ids(), dtype=np.int64).reshape(
        self.shape()
    )

  def __getitem__(self, dim_name: str) -> MeshDimension:
    return MeshDimension(name=dim_name, size=self.dim_size(dim_name))

  def __hash__(self):
    return hash(self.as_proto().SerializeToString(deterministic=True))

  def __repr__(self) -> str:
    dims = [tuple(self[dim_name]) for dim_name in self.dim_names]
    return (
        f'<Mesh object with dims={dims}, device_type="{self.device_type()}", '
        f'num_local_devices={self.num_local_devices()}, '
        f'size={self.size}>'
    )

  # TODO(panzf): change to pybind11 pickle implementation in the last step
  def __reduce__(self):
    return Mesh.from_string, (self.to_string(),)

  # TODO(b/242201545): implement this in Mesh C++ class
  def coords(self, device_idx: int) -> ops.Tensor:
    """Converts the device index into a tensor of mesh coordinates."""
    strides = ops.convert_to_tensor(self.strides)
    shape = ops.convert_to_tensor(self.shape())
    return (device_idx // strides) % shape
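
  # For illustration: on a mesh of shape [3, 2] the strides are [2, 1], so
  # coords(5) evaluates elementwise to (5 // [2, 1]) % [3, 2] == [2, 1],
  # i.e. device 5 sits at ('x'=2, 'y'=1).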

  @classmethod
  def from_proto(cls, proto: layout_pb2.MeshProto) -> 'Mesh':
    """Construct a mesh instance from input `proto`."""
    mesh = _pywrap_dtensor_device.Mesh.__new__(cls)
    _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_proto=proto)
    return mesh

  @classmethod
  def from_string(cls, mesh_str: str) -> 'Mesh':
    mesh = _pywrap_dtensor_device.Mesh.__new__(cls)
    _pywrap_dtensor_device.Mesh.__init__(mesh, mesh_str=mesh_str)
    return mesh

  @classmethod
  def from_device(cls, device: str) -> 'Mesh':
    """Constructs a single device mesh from a device string."""
    mesh = _pywrap_dtensor_device.Mesh.__new__(cls)
    _pywrap_dtensor_device.Mesh.__init__(mesh, single_device=device)
    return mesh

  # TODO(b/242201545): implement this in Mesh C++ class
  def host_mesh(self):
    """Returns the 1-1 mapped host mesh."""
    if self.device_type().upper() == 'CPU':
      return self

    v_cpus_counts = config.num_local_devices('CPU')
    if v_cpus_counts < len(self.local_devices()):
      raise ValueError(
          'Must have at least {0} virtual CPUs for mesh : {1}, '
          'but got : {2} virtual CPUs. '
          'Call tf.experimental.dtensor.initialize_accelerator_system() '
          'to initialize the host CPU devices with the accelerators.'.format(
              len(self.local_devices()), self.to_string(), v_cpus_counts
          )
      )
    local_device_specs = [
        tf_device.DeviceSpec.from_string(d) for d in self.local_devices()
    ]
    global_device_specs = [
        tf_device.DeviceSpec.from_string(d) for d in self.global_devices()
    ]

    device_array = np.asarray(
        [spec.replace(device_type='CPU') for spec in local_device_specs]
    ).reshape((len(self.local_devices()), 1))
    global_devices = [
        spec.replace(device_type='CPU') for spec in global_device_specs
    ]
    h_mesh = Mesh(
        self.dim_names,
        self.global_device_ids(),
        self.local_device_ids(),
        np.ravel(device_array).tolist(),
        global_devices=global_devices,
    )
    return h_mesh
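
  # Note (illustration; the device string below is an example, not a fixture
  # of this module): host_mesh() only rewrites the device_type of each
  # DeviceSpec, so a local device such as
  # '/job:localhost/replica:0/task:0/device:TPU:3' maps to
  # '/job:localhost/replica:0/task:0/device:CPU:3'; dimension names and device
  # IDs are unchanged.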

  # TODO(b/242201545): implement this in Mesh C++ class
  def local_device_locations(self) -> List[Dict[str, int]]:
    """Returns a list of local device locations.

    A device location is a dictionary from dimension names to indices on those
    dimensions.
    """
    mapping = self.unravel_index()
    return [mapping[device_id] for device_id in self.local_device_ids()]

  # TODO(b/242201545): implement this in Mesh C++ class
  @property
  def strides(self) -> List[int]:
    """Returns the strides tensor array for this mesh.

    If the mesh shape is `[a, b, c, d]`, then the strides array can be computed
    as `[b*c*d, c*d, d, 1]`. This array can be useful in computing local device
    offsets given a device ID. Using the same example, the device coordinates
    of the mesh can be computed as:

    ```
    [(device_id / (b*c*d)) % a,
     (device_id / (c*d))   % b,
     (device_id / (d))     % c,
     (device_id)           % d]
    ```

    This is the same as `(device_id // mesh.strides) % mesh.shape`.

    Returns:
      The mesh strides as an integer tensor.
    """
    return _compute_mesh_strides(self.shape())

  # TODO(b/242201545): implement this in Mesh C++ class
  def unravel_index(self):
    """Returns a dictionary from device ID to {dim_name: dim_index}.

    For example, for a 3x2 mesh, this returns:

    ```
    { 0: {'x': 0, 'y': 0},
      1: {'x': 0, 'y': 1},
      2: {'x': 1, 'y': 0},
      3: {'x': 1, 'y': 1},
      4: {'x': 2, 'y': 0},
      5: {'x': 2, 'y': 1} }
    ```
    """
    idx_ranges = [range(self.dim_size(dim_name)) for dim_name in self.dim_names]
    mesh_pos = itertools.product(*idx_ranges)
    mapping = {}
    for device_id, device_pos in enumerate(mesh_pos):
      device_loc = {}
      for dim_name, dim_index in zip(self.dim_names, device_pos):
        device_loc[dim_name] = dim_index
      mapping[device_id] = device_loc
    return mapping
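
# A minimal usage sketch (illustrative only; the CPU device strings below are
# assumptions, not fixtures of this module): a 3x2 single-client mesh built
# directly from the constructor.
#
#   devices = ['/job:localhost/replica:0/task:0/device:CPU:%d' % i
#              for i in range(6)]
#   mesh = Mesh(
#       dim_names=['x', 'y'],
#       global_device_ids=np.arange(6).reshape((3, 2)),
#       local_device_ids=list(range(6)),
#       local_devices=devices,
#       mesh_name='cpu_mesh')
#   mesh.shape()   # [3, 2]
#   mesh.strides   # [2, 1]
#   mesh['x']      # MeshDimension(name='x', size=3)
#
# In practice `tf.experimental.dtensor.create_mesh` is the simpler entry point,
# as noted in the class docstring.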


# TODO(hthu): Consider making this class immutable.
@tf_export('experimental.dtensor.Layout', v1=[])
class Layout(_pywrap_dtensor_device.Layout):
  """Represents the layout information of a DTensor.

  A layout describes how a distributed tensor is partitioned across a mesh (and
  thus across devices). For each axis of the tensor, the corresponding sharding
  spec indicates which dimension of the mesh it is sharded over. The special
  sharding spec `UNSHARDED` indicates that the axis is replicated on all the
  devices of that mesh.

  For example, let's consider a 1-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
  ```

  This mesh arranges 6 TPU devices into a 1-D array. `Layout([UNSHARDED], mesh)`
  is a layout for a rank-1 tensor which is replicated on the 6 devices.

  For another example, let's consider a 2-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
       [("x", 3), ("y", 2)])
  ```

  This mesh arranges 6 TPU devices into a `3x2` 2-D array.
  `Layout(["x", UNSHARDED], mesh)` is a layout for a rank-2 tensor whose first
  axis is sharded on mesh dimension "x" and whose second axis is replicated. If
  we place `np.arange(6).reshape((3, 2))` using this layout, the individual
  component tensors look like:

  ```
  Device  |  Component
   TPU:0     [[0, 1]]
   TPU:1     [[0, 1]]
   TPU:2     [[2, 3]]
   TPU:3     [[2, 3]]
   TPU:4     [[4, 5]]
   TPU:5     [[4, 5]]
  ```
  """

  def __init__(self, sharding_specs: List[str], mesh: Mesh):
    """Builds a Layout from a list of dimension names and a Mesh.

    Args:
      sharding_specs: List of sharding specifications, each corresponding to a
        tensor axis. Each specification (dim_sharding) can either be a mesh
        dimension or the special value UNSHARDED.
      mesh: A mesh configuration for the Tensor.

    Returns:
      A valid Layout built with the given sharding specs and mesh.
    """
    # Validate mesh
    if not isinstance(mesh, Mesh):
      raise ValueError('mesh is not a valid Mesh object.')

    # Validate sharding spec
    for dim_sharding in sharding_specs:
      # If special value, no need to check for uniqueness; just skip.
      if dim_sharding == UNSHARDED or dim_sharding == MATCH:
        continue
      # Check dim_sharding is unique.
      if sharding_specs.count(dim_sharding) > 1:
        raise ValueError(
            ('Mesh dimension {mesh_dim} was repeated in sharding '
             'specification {sharding_specs}. Mesh dimensions must be unique '
             'in a layout.').format(
                 mesh_dim=dim_sharding, sharding_specs=sharding_specs))
      # Check dim_sharding is a mesh dimension.
      if dim_sharding not in mesh:
        raise ValueError(
            ('{dim_sharding}: A dimension sharding must either be a '
             'valid mesh dimension or UNSHARDED.').format(
                 dim_sharding=dim_sharding))

    super().__init__(sharding_specs=sharding_specs, mesh=mesh)

  def __repr__(self) -> str:
    return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})'

  def __hash__(self):
    return hash(self.as_proto().SerializeToString(deterministic=True))

  # TODO(panzf): change to pybind11 pickle implementation in the last step
  def __reduce__(self):
    return Layout.from_string, (self.to_string(),)

  # TODO(b/242201545): Find a way to return Mesh object from the pywrap module.
  @property
  def mesh(self):
    return Mesh.from_proto(super().mesh.as_proto())

  @property
  def shape(self):
    return self.mesh.shape()

  @classmethod
  def batch_sharded(
      cls, mesh: Mesh, batch_dim: str, rank: int, axis: int = 0
  ) -> 'Layout':
    """Returns a layout sharded on the batch dimension."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(
        # Watch out for the different argument ordering.
        layout_obj,
        mesh=mesh,
        rank=rank,
        batch_dim=batch_dim,
        axis=axis,
    )
    return layout_obj

  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
  def delete(self, dims: List[int]) -> 'Layout':
    """Returns the layout with the given dimensions deleted."""
    if not isinstance(dims, list):
      dims = [dims]
    new_specs = [
        spec for i, spec in enumerate(self.sharding_specs) if i not in dims
    ]
    return Layout(new_specs, self.mesh)
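
  # For illustration: deleting axis 0 from a rank-2 layout such as
  # Layout(['x', UNSHARDED], mesh) yields the rank-1 replicated layout
  # Layout([UNSHARDED], mesh), i.e.
  #   layout.delete([0]).sharding_specs == [UNSHARDED]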

  @classmethod
  def from_proto(cls, layout_proto: layout_pb2.LayoutProto) -> 'Layout':
    """Creates an instance from a LayoutProto."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(
        layout_obj, layout_proto=layout_proto
    )
    return layout_obj

  @classmethod
  def from_string(cls, layout_str: str) -> 'Layout':
    """Creates an instance from a human-readable string."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(layout_obj, layout_str=layout_str)
    return layout_obj

  @classmethod
  def inner_sharded(cls, mesh: Mesh, inner_dim: str, rank: int) -> 'Layout':
    """Returns a layout sharded on the inner (last) dimension."""
    return cls.batch_sharded(mesh, inner_dim, rank, axis=rank - 1)

  @classmethod
  def from_single_device_mesh(cls, mesh: Mesh) -> 'Layout':
    """Constructs a single-device layout from a single-device mesh."""
    layout = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(layout, mesh=mesh)
    return layout

  @classmethod
  def from_device(cls, device: str) -> 'Layout':
    """Constructs a single-device layout from a device string."""
    return cls.from_single_device_mesh(Mesh.from_device(device))

  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
  def offset_to_shard(self):
    """Mapping from offset in a flattened list to shard index."""
    unravel_index = self.mesh.unravel_index()
    locations = [None] * self.mesh.size
    for offset, mesh_loc in unravel_index.items():
      loc = []
      for dim_sharding in self.sharding_specs:
        if dim_sharding == UNSHARDED:
          loc.append(0)
        else:
          loc.append(mesh_loc[dim_sharding])
      locations[offset] = tuple(loc)

    return locations
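
  # For illustration: for Layout(['x', UNSHARDED], mesh) on the 3x2 mesh from
  # the class docstring, offset_to_shard() returns
  #   [(0, 0), (0, 0), (1, 0), (1, 0), (2, 0), (2, 0)]
  # so devices 0 and 1 hold the same shard, matching the component table above.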

  # TODO(b/242201545): Move this to C++ / find the corresponding function there.
  def offset_tuple_to_global_index(self, offset_tuple):
    """Mapping from offset to index in global tensor."""
    index = 0
    for i, o in enumerate(offset_tuple):
      m = 1
      for x in range(i + 1, self.rank):
        m = m * self.num_shards(x)
      index = index + m * o
    return index
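
  # For illustration: with the same Layout(['x', UNSHARDED], mesh) the shard
  # grid has shape (3, 1), so offset_tuple_to_global_index((2, 0)) is
  # 2 * 1 + 0 == 2 (row-major order over the shard grid).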

  @classmethod
  def replicated(cls, mesh: Mesh, rank: int) -> 'Layout':
    """Returns a replicated layout of rank `rank`."""
    layout_obj = _pywrap_dtensor_device.Layout.__new__(cls)
    _pywrap_dtensor_device.Layout.__init__(layout_obj, mesh=mesh, rank=rank)
    return layout_obj
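
# A minimal end-to-end sketch (illustrative only; `mesh` refers to the
# hypothetical 3x2 CPU mesh sketched after the Mesh class above):
#
#   replicated = Layout.replicated(mesh, rank=2)
#   batch = Layout.batch_sharded(mesh, batch_dim='x', rank=2)
#   # Equivalent spellings of the batch-sharded layout:
#   #   Layout(['x', UNSHARDED], mesh)
#   #   Layout.from_string(batch.to_string())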