Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/xla/experimental/xla_sharding.py: 26%
153 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental support for defining XLA shardings."""

import numpy as _np  # Avoids becoming a part of public Tensorflow API.

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2


class Sharding(object):
  """A class to support adding sharding attributes to Ops.

  Use the factory constructors and then call apply_to_tensor:
    Sharding.replicate().apply_to_tensor(tensor)
  """

  def __init__(self, proto=None):
    """Do not use this constructor; use the factory functions below."""
    self._proto = proto

  @classmethod
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))

  @classmethod
  def manual(cls):
    """Returns a manual sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))
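
  # Usage sketch (comment only, not exercised by this module): the factory
  # constructors above and below return a Sharding whose attribute is then
  # attached to a tensor's producing op, e.g.
  #
  #   t = Sharding.replicate().apply_to_tensor(t)  # replicate t on all cores
  #   t = Sharding.manual().apply_to_tensor(t)     # user-managed partitioning
  #
  # `t` is assumed to be a tf.Tensor built inside a compiled (tf.function/XLA)
  # graph; eager tensors have no op to annotate.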

  @classmethod
  def assign_device(cls, core):
    """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.

    Args:
      core: The core to assign this Op to.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.MAXIMAL,
            tile_assignment_dimensions=[1],
            tile_assignment_devices=[core]))

  @classmethod
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))
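
  # Example sketch (assumes 4 devices and a rank-2 tf.Tensor `t` defined
  # elsewhere): a 2x2 assignment array splits both tensor dimensions in two,
  # with device mesh[i, j] computing tile (i, j).
  #
  #   mesh = _np.array([[0, 1], [2, 3]])
  #   t = Sharding.tile(mesh).apply_to_tensor(t, use_sharding_op=True)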

  @classmethod
  def subgroup_tile(cls, tile_assignment, subgroup_modes):
    """Returns a subgroup manual sharding attribute.

    This is similar to tile(), but tile_assignment has one or more dimensions
    more than the tensor, and subgroup_modes defines the sharding types in the
    last dimensions of tile_assignment.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.
      subgroup_modes: sharding types for the dimensions beyond the tensor
        shape rank.

    Raises:
      TypeError: tile_assignment was not of np.array type or subgroup_modes
        has unsupported sharding type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('SubgroupTile assignment must be of type np.ndarray')

    if not isinstance(subgroup_modes, list):
      raise TypeError('subgroup_modes in subgroup manual must be of type list')

    if len(tile_assignment.shape) < len(subgroup_modes):
      raise TypeError('SubgroupTile assignment must have rank larger than'
                      ' length of subgroup_modes')

    for sharding_type in subgroup_modes:
      if sharding_type not in [
          xla_data_pb2.OpSharding.REPLICATED, xla_data_pb2.OpSharding.MANUAL
      ]:
        raise TypeError(
            'Each sharding_type in subgroup_modes in subgroup manual must '
            'be of type xla_data_pb2.OpSharding.REPLICATED'
            ' or xla_data_pb2.OpSharding.MANUAL')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            last_tile_dims=list(subgroup_modes)))
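
  # Example sketch: a rank-1 tf.Tensor `t` (assumed defined elsewhere) split
  # 2 ways, with a trailing replication subgroup of size 2, so devices {0, 2}
  # and {1, 3} hold replicas of the two shards. Assumes 4 devices.
  #
  #   assignment = _np.array([[0, 2], [1, 3]])  # rank = tensor rank + 1
  #   modes = [xla_data_pb2.OpSharding.REPLICATED]
  #   t = Sharding.subgroup_tile(assignment, modes).apply_to_tensor(
  #       t, use_sharding_op=True)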

  @classmethod
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))
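
  # Example sketch: split a rank-1 tf.Tensor `t` 2 ways and replicate each
  # shard on 2 devices (4 devices assumed available).
  #
  #   assignment = _np.array([[0, 1], [2, 3]])  # last dim = replication group
  #   t = Sharding.partial_tile(assignment).apply_to_tensor(
  #       t, use_sharding_op=True)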

  @classmethod
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for
    the common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))
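
  # Example sketch: shard a [batch, features] tf.Tensor `t` 4 ways along the
  # batch dimension (dimension 0), assuming 4 devices.
  #
  #   t = Sharding.split(t, split_dimension=0, num_devices=4).apply_to_tensor(
  #       t, use_sharding_op=True)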

  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False,
                      unspecified_dims=None):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to split.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: Whether to create a sharding op on `tensor`.
      unspecified_dims: An optional list of dimensions whose sharding is left
        unspecified.

    Returns:
      The tensor with Sharding attribute.
    """
    if unspecified_dims:
      assert use_sharding_op and not assign_tuple_sharding
    proto = self._proto
    if use_sharding_op:
      if assign_tuple_sharding:
        proto = self._create_tuple_proto(num_outputs=1)
        tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
      else:
        tensor = tf2xla.sharding(
            tensor,
            sharding=proto.SerializeToString(),
            unspecified_dims=unspecified_dims or [])
    elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
      proto = self._get_or_create_tuple_proto(tensor.op)
      # We can't mutate an element of old_proto.tuple_shardings, so create
      # a new proto.
      tuple_shardings = list(proto.tuple_shardings)
      tuple_shardings[tensor.value_index] = self._proto
      proto = xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)

    # TODO(jmolloy): This needs to be seriously revisited before declaring this
    # API available for public use.
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding',
                        attr_value_pb2.AttrValue(s=proto.SerializeToString()))
    return tensor
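
  # Usage sketch: with use_sharding_op=False the attribute is written directly
  # onto the op that produced `t`; with use_sharding_op=True a separate
  # sharding op is inserted instead, leaving the producer untouched.
  #
  #   t = Sharding.replicate().apply_to_tensor(t)
  #   t = Sharding.replicate().apply_to_tensor(t, use_sharding_op=True)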

  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.

    Args:
      operation: A tf.Operation to annotate with this sharding.
    """
    attr_value = attr_value_pb2.AttrValue(s=self._proto.SerializeToString())
    # pylint: disable=protected-access
    operation._set_attr('_XlaSharding', attr_value)

  @property
  def proto(self):
    """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
    return self._proto

  def _get_or_create_tuple_proto(self, op):
    try:
      attr = op.get_attr('_XlaSharding')
      proto = xla_data_pb2.OpSharding()
      proto.ParseFromString(attr)
      return proto
    except ValueError:
      return self._create_tuple_proto(len(op.outputs))

  def _create_tuple_proto(self, num_outputs):
    shardings = [
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    ] * num_outputs
    return xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)


def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies a tensor's sharding to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: The tensor to annotate with the copied sharding.
    use_sharding_op: Whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_tensor_sharding(from_tensor)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor
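

# Example sketch: propagate an existing annotation from tensor `a` to tensor
# `b` of the same shape (`a` is assumed to already carry an _XlaSharding
# attribute; otherwise `b` is returned unchanged).
#
#   b = copy_sharding(a, b, use_sharding_op=True)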


# Helpers for the above factory functions that allow easy application of
# shardings, for example:
#   tensor = xla_sharding.replicate(tensor)


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False,
         unspecified_dims=None):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])


def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False,
          input_shape=None):
  """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    input_shape: The full shape of the input tensor.
  """
  return Sharding.split(tensor, split_dimension, num_devices,
                        input_shape).apply_to_tensor(
                            tensor,
                            assign_tuple_sharding=assign_tuple_sharding,
                            use_sharding_op=use_sharding_op)
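

# Example sketch of the convenience wrappers: shard a [batch, features]
# tf.Tensor `t` 8 ways over the batch dimension. Assumes 8 devices and a
# compiled (tf.function/XLA) context.
#
#   t = split(t, split_dimension=0, num_devices=8, use_sharding_op=True)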


def partial_tile(tensor,
                 tile_assignment,
                 use_sharding_op=False,
                 unspecified_dims=None):
  """Returns a tensor that has partially tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than tensor, and the last dimension represents partially
      replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.
  """
  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
      tensor,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])


def get_op_sharding(op):
  """Returns sharding attribute of an op.

  Args:
    op: a TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  try:
    return op.get_attr('_XlaSharding')
  except ValueError:
    return None
  except AttributeError:
    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
    return None


def get_tensor_sharding(tensor):
  """Returns sharding attribute of a Tensor.

  Args:
    tensor: a Tensor.

  Returns:
    The attribute representing XLA sharding on tensor's op.
  """
  try:
    return get_op_sharding(tensor.op)
  except AttributeError:
    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    return None


def get_sharding_tile_shape(sharding):
  """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: a serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
    into which it has been split. Returns None if the input indicates no tile
    assignments.
  """
  if sharding is None:
    return None
  sharding_message = xla_data_pb2.OpSharding()
  sharding_message.ParseFromString(sharding)
  if sharding_message.tile_assignment_dimensions:
    return sharding_message.tile_assignment_dimensions
  else:
    return None
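

# Example sketch: reading a sharding back. After annotating `t` (a tf.Tensor
# assumed sharded with a 2x2 tiling), the serialized attribute can be
# inspected with the helpers above.
#
#   serialized = get_tensor_sharding(t)          # bytes, or None
#   print(get_sharding_tile_shape(serialized))   # e.g. [2, 2]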


def auto_to_manual_spmd_partition(tensor,
                                  manual_sharding,
                                  single_dim=-1,
                                  unspecified_dims=None):
  """Switches from automatic SPMD partitioning to manual partitioning.

  Converts a full-shaped tensor (to be automatically partitioned by the SPMD
  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
  ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    single_dim: If >= 0, the conversion will happen only on this dim in
      subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  return tf2xla.spmd_full_to_shard_shape(
      tensor,
      manual_sharding=manual_sharding,
      dim=single_dim,
      unspecified_dims=unspecified_dims or [])


def manual_to_auto_spmd_partition(tensor,
                                  manual_sharding,
                                  full_shape,
                                  single_dim=-1,
                                  unspecified_dims=None):
  """Switches from manual partitioning to automatic SPMD partitioning.

  Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
  full-shaped tensor to be partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: The shape of the tensor before partitioning.
    single_dim: If >= 0, the conversion will happen only on this dim in
      subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  return tf2xla.spmd_shard_to_full_shape(
      tensor,
      manual_sharding=manual_sharding,
      full_shape=full_shape,
      dim=single_dim,
      unspecified_dims=unspecified_dims or [])
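

# Roundtrip sketch (assumptions: 2 devices, `t` is a full-shape [8, 16]
# tf.Tensor annotated with a 2-way split on dimension 0, and
# `manually_partitioned_fn` is a hypothetical user function operating on
# per-core shards):
#
#   split_sharding = Sharding.split(t, 0, 2).proto.SerializeToString()
#   shard = auto_to_manual_spmd_partition(t, split_sharding)   # [4, 16]/core
#   shard = manually_partitioned_fn(shard)
#   full = manual_to_auto_spmd_partition(
#       shard, split_sharding, full_shape=[8, 16])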


def mesh_split_sharding(device_mesh,
                        tensor_split_dims_mapping,
                        manual_mesh_dims=None):
  """Returns a Sharding object representing sharding along multiple dimensions.

  Args:
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    manual_mesh_dims: An optional list of mesh dims for manual subgroups.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the device
      mesh rank.
  """
  manual_mesh_dims = manual_mesh_dims or []
  permutation = [d for d in tensor_split_dims_mapping if d >= 0
                ] + manual_mesh_dims
  if len(permutation) > len(device_mesh.shape):
    raise ValueError(
        'Number of tensor split dimensions (%r) is larger than device mesh '
        'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(device_mesh.shape), tensor_split_dims_mapping,
         device_mesh.shape))
  # Append replicated dimensions to the end.
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation
  ]
  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
  tile_shape = [
      1 if d < 0 else device_mesh.shape[d]
      for d in (tensor_split_dims_mapping + manual_mesh_dims)
  ]
  subgroup_modes = [xla_data_pb2.OpSharding.MANUAL] * len(manual_mesh_dims)
  partial = len(permutation) < len(device_mesh.shape)
  if partial:
    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
    subgroup_modes.append(xla_data_pb2.OpSharding.REPLICATED)
  tile_assignment = _np.reshape(tile_assignment, tile_shape)

  if manual_mesh_dims:
    return Sharding.subgroup_tile(tile_assignment, subgroup_modes)

  if partial:
    return Sharding.partial_tile(tile_assignment)
  return Sharding.tile(tile_assignment)
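

# Worked example (sketch): an 8-device mesh arranged 2x4, with a rank-2 tensor
# sharded along mesh axis 0 on its first dimension and unsharded on its
# second. The unused mesh axis (size 4) becomes a trailing replication
# dimension, so a partial_tile sharding is produced.
#
#   mesh = _np.arange(8).reshape((2, 4))
#   sharding = mesh_split_sharding(mesh, tensor_split_dims_mapping=[0, -1])
#   # sharding.proto.tile_assignment_dimensions == [2, 1, 4]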


def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False,
               manual_mesh_dims=None,
               unspecified_dims=None):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    manual_mesh_dims: An optional list of mesh dims for manual subgroups.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the device
      mesh rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping,
                                 manual_mesh_dims)
  return sharding.apply_to_tensor(
      tensor,
      use_sharding_op=use_sharding_op,
      unspecified_dims=unspecified_dims or [])
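

# Example sketch: shard a [batch, model] tf.Tensor `w` over a 4x2 device mesh,
# batch along mesh axis 0 and model along mesh axis 1. Assumes 8 devices and a
# compiled (tf.function/XLA) context.
#
#   mesh = _np.arange(8).reshape((4, 2))
#   w = mesh_split(w, mesh, tensor_split_dims_mapping=[0, 1],
#                  use_sharding_op=True)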