Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/input_util.py: 27%
207 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""APIs to deal with input datasets efficiently in DTensor.
17When using tf.data with DTensor, the `DTensorDataset` API can be used to
18efficiently handle loading the input data and correctly packing it to the
19corresponding devices. This API is intended to work with unbatched data and can
20be used for both data and model parallel setups.
22Example usage:
24>>> # 1-D mesh with 4 devices
25>>> mesh = dtensor.Mesh(dim_names=['batch'], ...)
26>>> layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)
27>>> dataset = tf.data.Dataset.range(256)
28>>> d_dataset = dtensor.DTensorDataset(
29... dataset=dataset,
30... global_batch_size=16,
31... mesh=mesh,
32... layouts=layout,
33... batch_dim='batch')
34>>> d_iter = iter(d_dataset)
35>>> # Each batch is a length 16 tensor sharded across 4 devices
36>>> batch_0_dtensor = next(d_iter)
37>>> batch_0_dtensor
38<tf.Tensor: shape=(16,),
39 dtype=int64,
40 value={"CPU:0": [ 0 1 2 4],
41 "CPU:1": [ 5 6 7 8],
42 "CPU:2": [ 9 10 11 12],
43 "CPU:3": [13 14 15 16]}>
44>>> batch_1_dtensor = next(d_iter)
45>>> batch_1_dtensor
46<tf.Tensor: shape=(16,),
47 dtype=int64,
48 value={"CPU:0": [17 18 19 20],
49 "CPU:1": [21 22 23 24],
50 "CPU:2": [25 26 27 28],
51 "CPU:3": [29 30 31 32]}>
53For multi-client setups, `DTensorDataset` interacts with tf.data service to
54correctly distribute the dataset among the participating clients. DTensor works
55with tf.data service in co-located mode where each worker is running alongside
56the DTensor client (the Tensorflow Python process). The `TFDataServiceConfig`
57dataclass can be filled with information about the tf.data service cluster, and
58passed to `DTensorDataset` to enable distribution.
59"""

import dataclasses
import operator

from typing import Any, List, Optional, Sequence, Tuple

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.types import data as data_types
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@dataclasses.dataclass
class TFDataServiceConfig:
  """Specifies the tf.data service configuration to use.

  Attributes:
    dispatcher_address: a string specifying the address of the tf.data service
      dispatcher server.
    job_name: a non-empty string identifying the shared job that will be
      created on tf.data service to process this dataset.
  """
  dispatcher_address: str
  job_name: str
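
# Editorial note: a minimal construction sketch for the dataclass above. The
# dispatcher address and job name are placeholders, not real endpoints, and
# passing the config to `DTensorDataset` currently raises NotImplementedError
# while multi-client support is being fixed (see b/271162918 below):
#
#   service_config = TFDataServiceConfig(
#       dispatcher_address='localhost:5050',
#       job_name='dtensor_input_job')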


# TODO(b/223275517): Add support for get_next_as_optional().
class _DTensorIterator(iterator_ops.OwnedIterator):
  """An iterator for a tf.data.Dataset distributed using DTensor.

  DTensorIterator encapsulates multiple underlying dataset iterators. It
  handles retrieving the tensors to be placed on each underlying device and
  then uses the 'pack' operation to create and return a DTensor. Thus users
  need only interact with a single DTensorIterator to automatically distribute
  dataset tensors onto devices.
  """

  def __init__(
      self,
      dtensor_components: Tuple[ops.Tensor],
      global_element_spec: tensor_spec.TensorSpec,
      layouts: Any):
    """Initializes a distributed iterator for DTensor datasets.

    This iterator encapsulates tf.data iterators for the underlying devices,
    and treats them as a packed DTensor of iterator resource tensors.

    Args:
      dtensor_components: a tuple containing the underlying iterator resources
        packed into a DTensor. This is expected to be a tuple with a single
        element.
      global_element_spec: the underlying dataset's element spec from a global
        view.
      layouts: a structure of DTensor layouts to be applied to the elements
        returned by the underlying iterators. This can be a single layout or
        (possibly nested) tuples or dictionaries of layouts, and the structure
        must match the structure of the iterator elements.
    """
    # dtensor_components is expected to be a single-element tuple.
    [self._iterator_resource_dtensor] = dtensor_components
    self._global_element_spec = global_element_spec
    self._layouts = layouts
    self._layouts_str = nest.map_structure(
        lambda layout: layout.to_string(), layouts)

    super().__init__(
        components=dtensor_components, element_spec=global_element_spec)

  def __next__(self):
    try:
      # IteratorGetNext will return a DTensor on the host, so move it to the
      # device mesh. If the dataset layouts are on the host mesh itself, this
      # is handled by DTensor as a no-op.
      host_elem = self._next_internal()
      device_elem = nest.map_structure(
          api.copy_to_mesh, host_elem, self._layouts)
      context.async_wait()
      return device_elem
    except errors.OutOfRangeError as e:
      # Match TF2 eager executor behavior by raising StopIteration when the
      # iterator is out of range.
      if context.executing_eagerly():
        raise StopIteration from e
      else:
        raise e

  @property
  def _type_spec(self):
    return _DTensorIteratorSpec(self._global_element_spec, self._layouts_str)


class _DTensorIteratorSpec(iterator_ops.IteratorSpec):
  """Type specification for `_DTensorIterator`."""

  __slots__ = ['_global_element_spec', '_layouts_str']

  def __init__(
      self, global_element_spec: tensor_spec.TensorSpec, layouts_str: Any):
    super().__init__(global_element_spec)
    self._global_element_spec = global_element_spec
    self._layouts_str = layouts_str

  @property
  def value_type(self):
    return _DTensorIterator

  def _serialize(self):
    return (self._global_element_spec, self._layouts_str)

  @property
  def _component_specs(self):
    return (tensor_spec.TensorSpec([], dtypes.resource),)

  def _to_components(self, value):
    return (value._iterator_resource_dtensor,)  # pylint: disable=protected-access

  def _from_components(self, components):
    layouts = nest.map_structure(
        layout_lib.Layout.from_string, self._layouts_str)
    return _DTensorIterator(
        dtensor_components=components,
        global_element_spec=self._global_element_spec,
        layouts=layouts)

  @classmethod
  def from_value(cls, value):
    return cls(value._global_element_spec, value._layouts_str)  # pylint: disable=protected-access


def _validate_input(flattened_layouts: Sequence[layout_lib.Layout],
                    flattened_elem_spec: Sequence[tensor_spec.TensorSpec],
                    dataset_already_batched: bool):
  """Checks that the dataset's layouts and element specs are compatible.

  Args:
    flattened_layouts: the flattened list of layouts used to distribute the
      dataset.
    flattened_elem_spec: the flattened list of element specs used in the
      dataset's components.
    dataset_already_batched: whether the dataset to be validated is already
      batched.

  Raises:
    ValueError: if the dataset's inputs are incompatible.
  """
  if not flattened_elem_spec:
    raise ValueError(
        'Expected input element spec of at least one element, was empty.')

  first_elem_shape = flattened_elem_spec[0].shape

  for layout, elem_spec in zip(flattened_layouts, flattened_elem_spec):
    if elem_spec.shape.rank is None:
      raise ValueError(
          'Dataset element shape must have a valid rank, got spec %s.' %
          elem_spec)

    # Check that layout's rank matches the element's rank. If dataset is not
    # yet batched, then the layout's rank must be one greater than the
    # element's rank.
    expected_rank = elem_spec.shape.rank
    if not dataset_already_batched:
      expected_rank += 1
    if layout.rank != expected_rank:
      raise ValueError(
          ('Expected layout with rank %d for element spec %s, got layout %s. '
           'Check that the dataset is not batched before passing to '
           'DTensorDataset.') %
          (expected_rank, elem_spec, layout.sharding_specs))

    if dataset_already_batched:
      # Check that the batch dimension size of all dataset elements match.
      batch_dim_size = first_elem_shape.as_list()[0]
      if batch_dim_size is None:
        raise ValueError(
            ('Size of batch dimension of element spec %s is None. Ensure '
             'drop_remainder=True when batching the dataset.') % elem_spec)

      if elem_spec.shape.as_list()[0] != batch_dim_size:
        raise ValueError(
            ('Size of batch dimension of element spec %s does not match '
             'expected size %d.') % (elem_spec, batch_dim_size))
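
# Editorial example for _validate_input (illustrative only): an unbatched
# dataset whose elements have spec TensorSpec(shape=[28, 28]) (rank 2) must be
# paired with a rank-3 layout, since DTensorDataset adds a batch dimension when
# it batches the dataset; passing a rank-2 layout raises the ValueError above.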


def _shard_counts(layout: layout_lib.Layout,
                  batch_dim: Optional[str] = None) -> List[int]:
  """Computes a list of the number of shards in each dimension of the layout.

  The shard counts are used to slice each dataset element. The batch
  dimension's count is overridden to 1 since we only consider how many shards
  to make locally (within each local replica). Sharding across clients is
  handled by either tf.data.Dataset's shard transformation (in the
  single-client case) or tf.data service's distribute function (in the
  multi-client case).

  Args:
    layout: the layout to compute the shard counts for.
    batch_dim: the name of the batch dimension of the layout, if present.

  Returns:
    A list of shard counts, one element per dimension of the layout.
  """
  shard_counts = []
  for spec in layout.sharding_specs:
    if spec in (batch_dim, layout_lib.UNSHARDED):
      shard_counts.append(1)
    else:
      shard_counts.append(layout.mesh.dim_size(spec))
  return shard_counts
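
# Editorial example for _shard_counts (illustrative only): assume a
# hypothetical 2x4 mesh with dimensions 'batch' and 'x', and a rank-2 layout
# with sharding_specs=['batch', 'x']. Then _shard_counts(layout, 'batch')
# returns [1, 4]: the batch dimension is never sliced locally, while the second
# dimension is split across the 4 devices along 'x'.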


def _index_matrix(layout: layout_lib.Layout,
                  elem_spec: tensor_spec.TensorSpec) -> ops.Tensor:
  """Computes a utility matrix to derive device-based slice offsets.

  This function builds a matrix of shape `[mesh.rank, layout.rank]` for each
  dataset element. This matrix can be used to slice the DTensor components
  returned by the iterator according to the local device that component is to
  be placed on. This can be done by multiplying the device offsets of shape
  `[1, mesh.rank]` with this index matrix to get a `[1, layout.rank]` shape
  tensor containing the slice offsets.

  Note: the index on the batch dim is always 0 since sharding on the batch
  dimension is handled by either tf.data.Dataset's shard transformation (in
  the single-client case) or tf.data service's distribute function (in the
  multi-client case). If there is no sharding on the batch dimension (or any
  other dimension), the slice index remains 0.

  Args:
    layout: the layout of the dataset element.
    elem_spec: the spec of the dataset element.

  Returns:
    The index matrix as a tensor.
  """
  matrix = []
  for dim in layout.mesh.dim_names:
    row = [0]
    for layout_idx, spec in enumerate(layout.sharding_specs[1:]):
      if spec == layout_lib.UNSHARDED or spec != dim:
        row.append(0)
      else:
        row.append(elem_spec.shape[layout_idx] // layout.mesh.dim_size(dim))
    matrix.append(row)

  return constant_op.constant(matrix, dtype=dtypes.int32)
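
# Editorial example for _index_matrix (illustrative only): with the
# hypothetical ['batch', 'x'] mesh of sizes [2, 4] above, a layout with
# sharding_specs=['batch', 'x'], and an unbatched element spec of shape [8],
# the matrix is [[0, 0], [0, 2]] (8 // 4 == 2 in the 'x' row). A device with
# mesh coordinates [b, 1] then gets offsets [b, 1] @ [[0, 0], [0, 2]] == [0, 2]
# and reads columns 2..3 of each per-replica batch.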


def _pack_iterator_resource_dtensor(
    datasets: List[Tuple[int, data_types.DatasetV2]],
    layouts: Any,
    mesh: layout_lib.Mesh,
    num_local_devices_per_replica: int):
  """Creates a DTensor iterator resource for the per-replica datasets.

  Given a list of replica ID to tf.data.Dataset mappings, this function creates
  iterators for each device and then packs the underlying iterator resource
  tensors into a single DTensor. This resource tensor is used by the
  IteratorGetNext op to retrieve the next element in the dataset.

  Args:
    datasets: a list of tuples of each unique local replica ID to the dataset
      object whose elements will be placed on the devices corresponding to that
      replica.
    layouts: a structure of DTensor layouts to be applied to the elements
      returned by the underlying iterators. This can be a single layout or
      (possibly nested) tuples or dictionaries of layouts, and the structure
      must match the structure of the iterator elements.
    mesh: the DTensor mesh to place the iterator batches on.
    num_local_devices_per_replica: the number of devices in each data-parallel
      replica.

  Returns:
    A DTensor of the underlying iterator resource tensors.
  """
  host_mesh_devices = mesh.host_mesh().local_devices()
  device_idx = 0

  iterators = []
  for _, dataset in datasets:
    for idx in range(num_local_devices_per_replica):
      with ops.device_v2(host_mesh_devices[device_idx]):
        device_dataset = dataset.shard(
            num_shards=num_local_devices_per_replica, index=idx)
        iterators.append(iter(device_dataset))
      device_idx += 1

  if device_idx != len(host_mesh_devices):
    raise ValueError(
        'The `datasets` argument does not have the correct number of'
        f' underlying datasets, found {device_idx} but expected'
        f' {len(host_mesh_devices)}.')

  host_layouts = nest.map_structure(
      lambda l: layout_lib.Layout(l.sharding_specs, mesh.host_mesh()), layouts)

  # Pack the iterator resource tensors into a replicated 0-dimensional DTensor
  # and set the element layouts.
  iterator_resources = [it._iterator_resource for it in iterators]  # pylint: disable=protected-access
  d_iterator_resource = api.pack(
      iterator_resources,
      layout_lib.Layout.replicated(mesh=mesh.host_mesh(), rank=0))
  api._dtensor_device().set_iterator_element_layouts(  # pylint: disable=protected-access
      d_iterator_resource, nest.flatten(host_layouts))

  return d_iterator_resource
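
# Editorial note: the per-device `dataset.shard` call above works together with
# DTensorDataset._repeat_batch and DTensorDataset._partition below. Each batch
# is repeated once per local device in the replica and sliced per device, so
# sharding by num_local_devices_per_replica hands every device its own slice of
# every batch exactly once.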


@tf_export('experimental.dtensor.DTensorDataset', v1=[])
class DTensorDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset of DTensors.

  DTensorDataset encapsulates a `tf.data.Dataset` whose elements are
  automatically packed and returned as DTensors based on a given mesh and
  layouts.
  """

  def __init__(self,
               dataset: data_types.DatasetV2,
               *,
               mesh: layout_lib.Mesh,
               layouts: Any,
               global_batch_size: int,
               dataset_already_batched: bool = False,
               batch_dim: Optional[str] = None,
               prefetch: Optional[int] = None,
               tf_data_service_config: Optional[TFDataServiceConfig] = None):
    """Creates a DTensorDataset.

    DTensorDataset automatically handles distribution of the dataset elements
    to each client's devices. It can be used to create an iterator that returns
    DTensors of the input data on each iteration.

    DTensorDataset works best with unbatched datasets. It takes the mesh and
    the provided layouts to automatically calculate how to batch the input
    locally for each replica.

    If the provided dataset is already batched according to the per-replica
    batch size, then `dataset_already_batched` must be set and DTensorDataset
    will check that the batch size is consistent with the intended
    `global_batch_size` using the layout information. Each replica receives a
    separate slice of the global batch, thus the per-replica batch size can be
    computed as the global batch size divided by the number of model replicas.
    For a DTensor mesh, the number of replicas is equal to the size of the
    mesh's batch dimension.

    TODO(b/223275517): add support for input datasets that are already batched
    to the global batch size.

    Args:
      dataset: a `tf.data.Dataset` object.
      mesh: the DTensor mesh to place the dataset batches on.
      layouts: a structure of DTensor layouts to be applied to the input
        dataset values. This can be a single layout or (possibly nested) tuples
        or dictionaries of layouts, and the structure must match the structure
        of the dataset. Either all or none of the layouts should be sharded on
        the batch dimension; having only a subset of layouts batch sharded will
        not work and raises a ValueError.
      global_batch_size: the desired global batch size.
      dataset_already_batched: must be set only if the dataset is already
        batched to the per-replica batch size. The batched dataset must have
        `drop_remainder=True` set since DTensor requires static shapes for
        slicing the input tensors.
      batch_dim: the mesh dimension on which the input's batch dimension is
        sharded. Set to None if the input layouts do not shard on the batch
        dimension.
      prefetch: number of batches to prefetch using Dataset.prefetch.
      tf_data_service_config: if operating in multi-client mode, this config
        specifies the tf.data service configuration to use.

    Raises:
      ValueError: on any of the following situations,
        1. if the structures and ranks of layouts and the dataset do not match.
        2. if the shapes in the dataset's spec are not fully defined.
        3. if batch_dim is specified and all layouts are not batch-sharded.
        4. if per_replica_batch_size is specified for an already batched
           Dataset but it does not match the expected per-replica size based
           on the provided mesh.
      TypeError: if type of structures of layouts and the dataset do not match.
    """
    super().__init__(dataset, dataset_ops.to_variant(dataset))

    # TODO(b/271162918): fix multi-client use case.
    if tf_data_service_config is not None:
      raise NotImplementedError(
          'Multi-client DTensorDataset is currently not supported.'
          ' Check b/271162918.')

    self._mesh = mesh
    self._layouts = layouts
    self._batch_dim = batch_dim
    self._prefetch = prefetch
    self._tf_data_service_config = tf_data_service_config

    nest.assert_same_structure(dataset.element_spec, layouts)
    flattened_layouts = nest.flatten(layouts)
    flattened_elem_spec = nest.flatten(dataset.element_spec)

    if batch_dim:
      num_global_replicas = mesh.dim_size(batch_dim)
      self._local_replica_ids = list(
          dict.fromkeys(
              [loc[batch_dim] for loc in mesh.local_device_locations()]))

      for layout in flattened_layouts:
        if batch_dim != layout.sharding_specs[0]:
          raise ValueError(
              ('batch_dim %s was specified but at least one layout did not '
               'contain it: %s') % (batch_dim, layout))
    else:
      # Only one replica since there is no sharding on the batch dimension.
      num_global_replicas = 1
      self._local_replica_ids = [0]

    # Validate layout and element spec compatibility, and raise ValueError if
    # invalid.
    _validate_input(
        flattened_layouts,
        flattened_elem_spec,
        dataset_already_batched=dataset_already_batched)

    expected_batch_size = global_batch_size // num_global_replicas
    if not dataset_already_batched:
      self._batched_dataset = dataset.batch(
          expected_batch_size, drop_remainder=True)
    else:
      per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0]
      if per_replica_batch_size != expected_batch_size:
        raise ValueError(
            ('per_replica_batch_size does not match the expected size based '
             'on the mesh, got %d but expected %d.') %
            (per_replica_batch_size, expected_batch_size))
      self._batched_dataset = dataset

    # Construct a global element spec of the dataset.
    flattened_global_elem_spec = []
    batch_tensor_shape = tensor_shape.as_shape([global_batch_size])
    for elem_spec in nest.flatten(self._batched_dataset.element_spec):
      new_elem_spec = tensor_spec.TensorSpec(
          shape=operator.concat(batch_tensor_shape, elem_spec.shape[1:]),
          dtype=elem_spec.dtype,
          name=elem_spec.name)
      flattened_global_elem_spec.append(new_elem_spec)
    self._global_element_spec = nest.pack_sequence_as(
        dataset.element_spec, flattened_global_elem_spec)

    num_global_devices_per_replica = config.num_global_devices(
        mesh.device_type()) // num_global_replicas
    self._num_local_replicas = len(self._local_replica_ids)
    self._num_local_devices_per_replica = mesh.num_local_devices(
    ) // self._num_local_replicas
    # The number of clients each replica is split over.
    self._num_clients_per_replica = (
        num_global_devices_per_replica // self._num_local_devices_per_replica)
    # In the case where a replica is split across multiple clients, an offset
    # needs to be added to the index used by the partitioning logic such that
    # the local devices on that client can be correctly matched to slices of
    # the input tensor(s). If replicas are wholly contained within a client,
    # then this offset is always 0.
    self._partition_offset = (
        config.client_id() % self._num_clients_per_replica
    ) * self._num_local_devices_per_replica

    # Helper data structures used in partitioning the dataset tensors.
    self._all_shard_counts = [
        _shard_counts(layout, batch_dim) for layout in flattened_layouts
    ]
    self._index_matrices = [
        _index_matrix(layout, elem_spec)
        for layout, elem_spec in zip(flattened_layouts, flattened_elem_spec)
    ]
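
  # Editorial sizing example (illustrative only): on a hypothetical
  # single-client mesh with 8 devices and a 'batch' dimension of size 4, with
  # global_batch_size=32, num_global_replicas == 4, each replica batches
  # 32 // 4 == 8 elements, there are 8 // 4 == 2 local devices per replica,
  # _num_clients_per_replica == 1, and _partition_offset == 0.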

  def __iter__(self):
    datasets: List[Tuple[int, data_types.DatasetV2]] = []

    # Start with the batched dataset.
    local_dataset = self._batched_dataset

    if self._batch_dim is not None:
      if self._num_clients_per_replica > 1:
        # If a replica is split over multiple clients then each batch needs to
        # be repeated before distribution as many times as there are clients
        # corresponding to that replica.
        local_dataset = self._repeat_batch(local_dataset,
                                           self._num_clients_per_replica)
        sharding_policy = data_service_ops.ShardingPolicy.DATA
      else:
        # Replicas are unique to each client, so FILE based sharding can be
        # used, which is more performant since each worker does not need to
        # read the entire dataset.
        sharding_policy = data_service_ops.ShardingPolicy.FILE
    else:
      # No batch dimension sharding specified, so disable dataset sharding
      # during the distribute step.
      sharding_policy = data_service_ops.ShardingPolicy.OFF

    # Apply distribution here (if specified) so all remaining transformations
    # are executed locally.
    if self._tf_data_service_config is not None:
      local_dataset = local_dataset.apply(
          data_service_ops.distribute(
              processing_mode=sharding_policy,
              service=self._tf_data_service_config.dispatcher_address,
              job_name=f'{self._tf_data_service_config.job_name}_{config.client_id()}',
              target_workers='LOCAL'))

    for local_replica_idx, replica_id in enumerate(self._local_replica_ids):
      # Select the shard for the corresponding replica.
      dataset = local_dataset.shard(self._num_local_replicas, local_replica_idx)

      # Repeat each batch for each local device in the replica.
      dataset = self._repeat_batch(dataset, self._num_local_devices_per_replica)

      # Slice each shard further for all non-batch dim shards. If there is no
      # non-batch dim sharding, this slice is essentially a no-op.
      dataset = self._partition(dataset)

      # Apply prefetch as the last step. Since each batch is repeated, the
      # number of elements to prefetch has to be scaled by the same size.
      if self._prefetch is not None:
        dataset = dataset.prefetch(
            self._prefetch * self._num_local_devices_per_replica)

      datasets.append((replica_id, dataset))

    # Convert the datasets into iterators placed on the host.
    d_iterator_resource = _pack_iterator_resource_dtensor(
        datasets=datasets,
        layouts=self._layouts,
        mesh=self._mesh,
        num_local_devices_per_replica=self._num_local_devices_per_replica)

    return _DTensorIterator(
        dtensor_components=(d_iterator_resource,),
        global_element_spec=self._global_element_spec,
        layouts=self._layouts)

  def _repeat_batch(self, dataset, repeats):
    def repeat(*x):
      return dataset_ops.DatasetV2.from_tensors(x).repeat(repeats)

    return dataset.flat_map(repeat)
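
  # Editorial example for _repeat_batch (illustrative only): with repeats=2, a
  # dataset yielding batches [b0, b1, b2] becomes [b0, b0, b1, b1, b2, b2];
  # each copy is later sliced for a different local device (or client) by
  # _partition.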

  def _partition(self, dataset):
    """Slices each dataset element on any sharded non-batch dimension."""

    # TODO(b/223275517): decouple from self and make testable.
    def slice_batch(index, batch):
      flattened_batch = nest.flatten(batch)
      flattened_output = []

      norm_index = math_ops.cast(
          index % self._num_local_devices_per_replica, dtype=dtypes.int32)
      norm_index += self._partition_offset
      coords = self._mesh.coords(norm_index)
      coords = array_ops.reshape(coords, (1, -1))

      for element, shard_counts, idx_matrix in zip(flattened_batch,
                                                   self._all_shard_counts,
                                                   self._index_matrices):
        indexes = math_ops.matmul(coords, idx_matrix)
        start = array_ops.reshape(indexes, (-1,))
        size = array_ops.shape_v2(
            element, out_type=dtypes.int32) // shard_counts
        flattened_output.append(
            array_ops.slice(element, begin=start, size=size))

      return nest.pack_sequence_as(batch, flattened_output)

    enumerated_dataset = dataset.enumerate()
    partitioned_dataset = enumerated_dataset.map(slice_batch)
    return partitioned_dataset
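
  # Editorial example for _partition (illustrative only), continuing the
  # _index_matrix sketch above: a device with mesh coordinates [b, 1] on the
  # hypothetical ['batch', 'x'] mesh computes begin offsets [0, 2] and sizes
  # [per_replica_batch, 8 // 4], so it receives element[:, 2:4] of each
  # repeated batch.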

  @property
  def element_spec(self):
    return self._global_element_spec