Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/mesh_util.py: 24%
97 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to help with mesh creation."""

from typing import List, Optional, Tuple
from absl import logging
import numpy as np

from tensorflow.dtensor.python import accelerator_util
from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import layout
from tensorflow.dtensor.python import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


def _print_context(num_global_devices: int, num_clients: int, client_id: int,
                   device_type: str, mesh: layout.Mesh) -> None:
  logging.info('This is client %d of %d clients', client_id, num_clients)
  logging.info('Number of global %s devices: %d', device_type.upper(),
               num_global_devices)
  # pylint: disable=protected-access
  logging.info('Global device IDs: %s', mesh.global_device_ids())
  logging.info('Local device IDs: %s', mesh.local_device_ids())
  logging.info('Local devices: %s', mesh.local_devices())
  # pylint: enable=protected-access


def _make_device_specs(
    devices: Optional[List[str]] = None,
    device_type: Optional[str] = None
) -> Tuple[List[tf_device.DeviceSpec], str]:
  """Makes device specs from local device names or a default device type."""
  if devices is None:
    if device_type is None:
      device_type = 'CPU'
    devices = config.local_devices(device_type)
  else:
    devices = [tf_device.DeviceSpec.from_string(d) for d in devices]
    if device_type is None:
      device_type = devices[0].device_type

    if device_type.upper() != devices[0].device_type.upper():
      raise ValueError(
          f'Conflicting devices {str(devices)} and device_type {device_type}')

  return devices, device_type
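
# A minimal illustrative sketch (not part of the original module) of how
# `_make_device_specs` behaves; the two-CPU device list is an assumption made
# only for this example:
#
#   specs, device_type = _make_device_specs(devices=['CPU:0', 'CPU:1'])
#   # `specs` holds two tf_device.DeviceSpec objects parsed from the strings,
#   # and `device_type` is 'CPU', inferred from the first device.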


@tf_export('experimental.dtensor.create_mesh', v1=[])
def create_mesh(mesh_dims: Optional[List[Tuple[str, int]]] = None,
                mesh_name: str = '',
                devices: Optional[List[str]] = None,
                device_type: Optional[str] = None,
                use_xla_spmd: bool = layout.USE_XLA_SPMD) -> layout.Mesh:
  """Creates a single-client mesh.

  If both `mesh_dims` and `devices` are specified, they must match each other.
  As a special case, when all arguments are missing, this creates a 1D CPU mesh
  with an empty name, assigning all available devices to that dimension.

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples. Defaults to a single
      batch-parallel dimension called 'x' using all devices. As a special case,
      a single-element mesh_dims whose dim_size is -1 also uses all devices.
    mesh_name: Name of the created mesh. Defaults to ''.
    devices: String representations of devices to use. This is the device part
      of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available logical devices.
    device_type: If `devices` is missing, the type of devices to use. Defaults
      to 'CPU'.
    use_xla_spmd: If True, uses XLA SPMD instead of DTensor SPMD.

  Returns:
    A single-client mesh created from specified or default arguments.
  """
  device_specs, device_type = _make_device_specs(devices, device_type)

  local_spec = tf_device.DeviceSpec(job=config.job_name(), replica=0, task=0)
  device_specs = [local_spec.make_merged_spec(d) for d in device_specs]

  if mesh_dims is None:
    mesh_dims = [('x', len(device_specs))]
  elif len(mesh_dims) == 1 and mesh_dims[0][1] == -1:
    # Replace a -1 dim_size in a 1D mesh with the number of all devices.
    mesh_dims[0] = (mesh_dims[0][0], len(device_specs))

  dim_names = [d[0] for d in mesh_dims]
  shape = [d[1] for d in mesh_dims]

  if np.prod(shape) != len(device_specs):
    raise ValueError(f'length of devices ({len(device_specs)}) must be '
                     f'equal to total size of the mesh of shape {shape}')

  global_device_ids = np.arange(len(device_specs)).reshape(shape)
  local_device_ids = np.ravel(global_device_ids).tolist()
  mesh = layout.Mesh(
      dim_names=dim_names,
      global_device_ids=global_device_ids,
      local_device_ids=local_device_ids,
      local_devices=device_specs,
      mesh_name=mesh_name,
      use_xla_spmd=use_xla_spmd)
  _print_context(
      num_global_devices=len(device_specs),
      num_clients=1,
      client_id=0,
      device_type=device_type,
      mesh=mesh)
  return mesh
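
# Illustrative usage sketch (not part of the original module): creating a 2x2
# single-client mesh over four logical CPU devices via the public
# tf.experimental.dtensor API. Splitting one physical CPU into four logical
# devices is an assumption made only for this example.
#
#   import tensorflow as tf
#   from tensorflow.experimental import dtensor
#
#   phys_cpu = tf.config.list_physical_devices('CPU')[0]
#   tf.config.set_logical_device_configuration(
#       phys_cpu, [tf.config.LogicalDeviceConfiguration()] * 4)
#   mesh = dtensor.create_mesh([('x', 2), ('y', 2)], device_type='CPU')
#   # With no arguments, dtensor.create_mesh() instead builds a 1D mesh named
#   # 'x' over all available CPU devices.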


@tf_export('experimental.dtensor.create_distributed_mesh', v1=[])
def create_distributed_mesh(
    mesh_dims: List[Tuple[str, int]],
    mesh_name: str = '',
    local_devices: Optional[List[str]] = None,
    device_type: Optional[str] = None,
    use_xla_spmd: bool = layout.USE_XLA_SPMD) -> layout.Mesh:
  """Creates a distributed mesh.

  This is similar to `create_mesh`, but with a different set of arguments to
  create a mesh that spans evenly across a multi-client DTensor cluster.

  For CPU and GPU meshes, users can choose to use fewer local devices than what
  is available by passing `local_devices`.

  For TPU, only meshes that use all TPU cores are supported by the DTensor
  runtime.

  Args:
    mesh_dims: A list of (dim_name, dim_size) tuples.
    mesh_name: Name of the created mesh. Defaults to ''.
    local_devices: String representations of devices to use. This is the device
      part of tf.DeviceSpec, e.g. 'CPU:0'. Defaults to all available local
      logical devices.
    device_type: Type of device to build the mesh for. Defaults to 'CPU'.
      Supported values are 'CPU', 'GPU', 'TPU'.
    use_xla_spmd: If True, uses XLA SPMD instead of DTensor SPMD.

  Returns:
    A mesh that spans evenly across all DTensor clients in the cluster.
  """
  dim_names, shape = zip(*mesh_dims)

  if not accelerator_util.is_initialized():
    raise ValueError('Accelerators are uninitialized, please run '
                     'dtensor.initialize_accelerator_system() first.')

  if device_type and device_type.upper() == 'TPU':
    # TODO(b/185940495): Allow multi-mesh and partial on TPU.
    # TPU meshes can only be configured through environment variables that
    # reflect the actual TPU topology. Do not let users specify custom args.
    if local_devices is not None:
      raise ValueError(
          f'Do not specify devices for {device_type.upper()} meshes. '
          f'Using a partial list of devices for {device_type.upper()} '
          f'is not supported.')

  device_specs, device_type = _make_device_specs(local_devices, device_type)

  if device_type.upper() in ['CPU', 'GPU']:
    # For CPU and GPU meshes, user-specified args take precedence over env
    # vars. This is particularly useful on single clients when users want to
    # create meshes that use fewer logical devices than what's available.

    local_spec = tf_device.DeviceSpec(
        job=config.job_name(), replica=0, task=config.client_id())
    device_specs = [local_spec.make_merged_spec(d) for d in device_specs]

    # Assumes identical number of local devices per client.
    num_global_devices = len(device_specs) * config.num_clients()

    if np.prod(shape) != num_global_devices:
      raise ValueError(
          f'Global number of devices '
          f'({len(device_specs)} per client * {config.num_clients()} clients '
          f'= {num_global_devices}) must be '
          f'equal to total size of the mesh of shape {shape}')

    global_device_ids = np.arange(num_global_devices).reshape(shape)
    flattened = np.ravel(global_device_ids).tolist()
    start_idx = len(device_specs) * config.client_id()
    local_device_ids = flattened[start_idx:start_idx + len(device_specs)]

    mesh = layout.Mesh(
        dim_names=dim_names,
        global_device_ids=global_device_ids,
        local_device_ids=local_device_ids,
        local_devices=device_specs,
        mesh_name=mesh_name,
        use_xla_spmd=use_xla_spmd)
    _print_context(num_global_devices, config.num_clients(),
                   config.client_id(), device_type, mesh)
    return mesh

  if device_type.upper() == 'TPU':
    mesh = tpu_util.create_tpu_mesh(
        mesh_dim_names=dim_names,
        mesh_shape=shape,
        mesh_name=mesh_name,
        use_xla_spmd=use_xla_spmd)
    _print_context(
        config.num_global_devices(device_type), config.num_clients(),
        config.client_id(), device_type, mesh)
    return mesh

  raise ValueError(f'Device type {device_type} is not CPU, GPU or TPU')
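
# Illustrative multi-client sketch (not part of the original module): building
# a mesh that spans two clients with four GPUs each. The cluster environment
# is assumed to be configured by the launcher, and the 2x4 shape is an
# assumption made only for this example.
#
#   from tensorflow.experimental import dtensor
#
#   dtensor.initialize_accelerator_system('GPU')
#   mesh = dtensor.create_distributed_mesh(
#       [('batch', 2), ('model', 4)], device_type='GPU')
#   # The mesh size np.prod([2, 4]) == 8 must equal
#   # (local devices per client) * dtensor.num_clients().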


_BARRIER_DICT = {}


@tf_export('experimental.dtensor.barrier', v1=[])
def barrier(mesh: layout.Mesh,
            barrier_name: Optional[str] = None,
            timeout_in_ms: Optional[int] = None):
  """Runs a barrier on the mesh.

  Upon returning from the barrier, all operations run before the barrier
  would have completed across all clients. Currently we allocate a fully
  sharded tensor with mesh shape and run an all_reduce on it.

  Example:

  A barrier can be used before application exit to ensure completion of pending
  ops.

  ```python

  x = [1, 2, 3]
  x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
  dtensor.barrier(mesh)

  # At this point all devices on all clients in the mesh have completed
  # operations before the barrier. Therefore it is OK to tear down the clients.
  sys.exit()
  ```

  Args:
    mesh: The mesh to run the barrier on.
    barrier_name: The name of the barrier. Mainly used for logging purposes.
    timeout_in_ms: The timeout of the barrier in ms. If omitted, blocks
      indefinitely until the barrier is reached from all clients.
  """
  if barrier_name is None:
    barrier_name = '(barrier)'

  logging.info('entering barrier before op: %s', barrier_name)

  # Make sure all ops are consumed before running the sync.
  context.async_wait()

  # Reduction on a fully sharded tensor requires all devices to participate
  # and serves as a barrier on the mesh.
  component = array_ops.reshape(1.0, [1] * len(mesh.shape()))
  ones = api.pack([component] * mesh.num_local_devices(),
                  layout.Layout(mesh.dim_names, mesh))

  mesh_size = math_ops.reduce_sum(ones)
  if mesh_size != mesh.size:
    raise ValueError(
        'Global barrier produced wrong mesh size : {0} while mesh has actual '
        'size : {1}'.format(mesh_size, mesh.size))

  # TODO(hthu): This isn't strictly needed but might cause confusing behaviors
  # from users. Consider dropping this if there is a `big` performance hit.
  context.async_wait()

  if context.context().coordination_service:
    if timeout_in_ms is None:
      timeout_in_ms = 24 * 60 * 60 * 1000  # 24 hours to stand in for infinite.

    num_calls = _BARRIER_DICT.setdefault(barrier_name, 0)
    _BARRIER_DICT[barrier_name] = num_calls + 1

    barrier_id = f'{barrier_name}:{num_calls}'
    context.context().wait_at_barrier(barrier_id, timeout_in_ms)

  logging.info('finished running barrier across all clients after '
               'op: %s', barrier_name)
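
# Illustrative sketch (not part of the original module) of naming a barrier and
# bounding its wait time; the barrier name and the 60-second timeout are
# assumptions made only for this example.
#
#   from tensorflow.experimental import dtensor
#
#   dtensor.barrier(mesh, barrier_name='after_checkpoint',
#                   timeout_in_ms=60 * 1000)
#   # With a coordination service enabled, repeated calls under the same name
#   # are disambiguated internally as 'after_checkpoint:0',
#   # 'after_checkpoint:1', and so on.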