Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/input_util.py: 27%

207 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""APIs to deal with input datasets efficiently in DTensor. 

16 

17When using tf.data with DTensor, the `DTensorDataset` API can be used to 

18efficiently handle loading the input data and correctly packing it to the 

19corresponding devices. This API is intended to work with unbatched data and can 

20be used for both data and model parallel setups. 

21 

22Example usage: 

23 

24>>> # 1-D mesh with 4 devices 

25>>> mesh = dtensor.Mesh(dim_names=['batch'], ...) 

26>>> layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1) 

27>>> dataset = tf.data.Dataset.range(256) 

28>>> d_dataset = dtensor.DTensorDataset( 

29... dataset=dataset, 

30... global_batch_size=16, 

31... mesh=mesh, 

32... layouts=layout, 

33... batch_dim='batch') 

34>>> d_iter = iter(d_dataset) 

35>>> # Each batch is a length 16 tensor sharded across 4 devices 

36>>> batch_0_dtensor = next(d_iter) 

37>>> batch_0_dtensor 

38<tf.Tensor: shape=(16,), 

39 dtype=int64, 

40 value={"CPU:0": [ 0 1 2 4], 

41 "CPU:1": [ 5 6 7 8], 

42 "CPU:2": [ 9 10 11 12], 

43 "CPU:3": [13 14 15 16]}> 

44>>> batch_1_dtensor = next(d_iter) 

45>>> batch_1_dtensor 

46<tf.Tensor: shape=(16,), 

47 dtype=int64, 

48 value={"CPU:0": [17 18 19 20], 

49 "CPU:1": [21 22 23 24], 

50 "CPU:2": [25 26 27 28], 

51 "CPU:3": [29 30 31 32]}> 

52 

53For multi-client setups, `DTensorDataset` interacts with tf.data service to 

54correctly distribute the dataset among the participating clients. DTensor works 

55with tf.data service in co-located mode where each worker is running alongside 

56the DTensor client (the Tensorflow Python process). The `TFDataServiceConfig` 

57dataclass can be filled with information about the tf.data service cluster, and 

58passed to `DTensorDataset` to enable distribution. 

59""" 

60 

import dataclasses
import operator

from typing import Any, List, Optional, Sequence, Tuple

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.types import data as data_types
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@dataclasses.dataclass
class TFDataServiceConfig:
  """Specifies the tf.data service configuration to use.

  Attributes:
    dispatcher_address: a string specifying the address of the tf.data service
      dispatcher server.
    job_name: a non-empty string identifying the shared job that will be
      created on tf.data service to process this dataset.
  """
  dispatcher_address: str
  job_name: str
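
# Illustrative sketch (not executed): how a `TFDataServiceConfig` could be
# filled in and handed to `DTensorDataset` for a multi-client setup. The
# dispatcher address, job name, and the `dataset`/`mesh`/`layout` variables
# below are placeholders, and multi-client support is currently gated by the
# NotImplementedError in `DTensorDataset.__init__` (see b/271162918).
#
#   service_config = TFDataServiceConfig(
#       dispatcher_address='dispatcher.example.com:5050',
#       job_name='dtensor_input_job')
#   d_dataset = DTensorDataset(
#       dataset=dataset,
#       mesh=mesh,
#       layouts=layout,
#       global_batch_size=16,
#       batch_dim='batch',
#       tf_data_service_config=service_config)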

# TODO(b/223275517): Add support for get_next_as_optional().
class _DTensorIterator(iterator_ops.OwnedIterator):
  """An iterator for a tf.data.Dataset distributed using DTensor.

  DTensorIterator encapsulates multiple underlying dataset iterators. It
  handles retrieving the tensors to be placed on each underlying device and
  then uses the 'pack' operation to create and return a DTensor. Thus users
  need only interact with a single DTensorIterator to automatically distribute
  dataset tensors onto devices.
  """

  def __init__(
      self,
      dtensor_components: Tuple[ops.Tensor],
      global_element_spec: tensor_spec.TensorSpec,
      layouts: Any):
    """Initializes a distributed iterator for DTensor datasets.

    This iterator encapsulates tf.data iterators for the underlying devices,
    and treats them as a packed DTensor of iterator resource tensors.

    Args:
      dtensor_components: a tuple containing the underlying iterator resources
        packed into a DTensor. This is expected to be a tuple with a single
        element.
      global_element_spec: the underlying dataset's element spec from a global
        view.
      layouts: a structure of DTensor layouts to be applied to the elements
        returned by the underlying iterators. This can be a single layout or
        (possibly nested) tuples or dictionaries of layouts, and the structure
        must match the structure of the iterator elements.
    """
    # dtensor_components is expected to be a single-element tuple.
    [self._iterator_resource_dtensor] = dtensor_components
    self._global_element_spec = global_element_spec
    self._layouts = layouts
    self._layouts_str = nest.map_structure(
        lambda layout: layout.to_string(), layouts)

    super().__init__(
        components=dtensor_components, element_spec=global_element_spec)

  def __next__(self):
    try:
      # IteratorGetNext will return a DTensor on the host, so move it to the
      # device mesh. If the dataset layouts are on the host mesh itself, this
      # is handled by DTensor as a no-op.
      host_elem = self._next_internal()
      device_elem = nest.map_structure(
          api.copy_to_mesh, host_elem, self._layouts)
      context.async_wait()
      return device_elem
    except errors.OutOfRangeError as e:
      # Match TF2 eager executor behavior by raising StopIteration when the
      # iterator is out of range.
      if context.executing_eagerly():
        raise StopIteration from e
      else:
        raise e

  @property
  def _type_spec(self):
    return _DTensorIteratorSpec(self._global_element_spec, self._layouts_str)


class _DTensorIteratorSpec(iterator_ops.IteratorSpec):
  """Type specification for `_DTensorIterator`."""

  __slots__ = ['_global_element_spec', '_layouts_str']

  def __init__(
      self, global_element_spec: tensor_spec.TensorSpec, layouts_str: Any):
    super().__init__(global_element_spec)
    self._global_element_spec = global_element_spec
    self._layouts_str = layouts_str

  @property
  def value_type(self):
    return _DTensorIterator

  def _serialize(self):
    return (self._global_element_spec, self._layouts_str)

  @property
  def _component_specs(self):
    return (tensor_spec.TensorSpec([], dtypes.resource),)

  def _to_components(self, value):
    return (value._iterator_resource_dtensor,)  # pylint: disable=protected-access

  def _from_components(self, components):
    layouts = nest.map_structure(
        layout_lib.Layout.from_string, self._layouts_str)
    return _DTensorIterator(
        dtensor_components=components,
        global_element_spec=self._global_element_spec,
        layouts=layouts)

  @classmethod
  def from_value(cls, value):
    return cls(value._global_element_spec, value._layouts_str)  # pylint: disable=protected-access


def _validate_input(flattened_layouts: Sequence[layout_lib.Layout],
                    flattened_elem_spec: Sequence[tensor_spec.TensorSpec],
                    dataset_already_batched: bool):
  """Checks that the dataset's layouts and element specs are compatible.

  Args:
    flattened_layouts: the flattened list of layouts used to distribute the
      dataset.
    flattened_elem_spec: the flattened list of element specs used in the
      dataset's components.
    dataset_already_batched: whether the dataset to be validated is already
      batched.

  Raises:
    ValueError: if the dataset's inputs are incompatible.
  """
  if not flattened_elem_spec:
    raise ValueError(
        'Expected input element spec of at least one element, was empty.')

  first_elem_shape = flattened_elem_spec[0].shape

  for layout, elem_spec in zip(flattened_layouts, flattened_elem_spec):
    if elem_spec.shape.rank is None:
      raise ValueError(
          'Dataset element shape must have a valid rank, got spec %s.' %
          elem_spec)

    # Check that the layout's rank matches the element's rank. If the dataset
    # is not yet batched, the layout's rank must be one greater than the
    # element's rank.
    expected_rank = elem_spec.shape.rank
    if not dataset_already_batched:
      expected_rank += 1
    if layout.rank != expected_rank:
      raise ValueError(
          ('Expected layout with rank %d for element spec %s, got layout %s. '
           'Check that the dataset is not batched before passing to '
           'DTensorDataset.') %
          (expected_rank, elem_spec, layout.sharding_specs))

    if dataset_already_batched:
      # Check that the batch dimension sizes of all dataset elements match.
      batch_dim_size = first_elem_shape.as_list()[0]
      if batch_dim_size is None:
        raise ValueError(
            ('Size of batch dimension of element spec %s is None. Ensure '
             'drop_remainder=True when batching the dataset.') % elem_spec)

      if elem_spec.shape.as_list()[0] != batch_dim_size:
        raise ValueError(
            ('Size of batch dimension of element spec %s does not match '
             'expected size %d.') % (elem_spec, batch_dim_size))
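
# For intuition about the rank rule checked in _validate_input above
# (assumed shapes, not from the original source): an unbatched dataset element
# with spec TensorSpec(shape=[32]) has rank 1, so its layout must have rank 2
# (for example sharding_specs ['batch', 'unsharded']) because DTensorDataset
# adds the batch dimension itself. With dataset_already_batched=True the
# element spec already carries the batch dimension, so the layout rank must
# equal the element rank and the batch size must be static
# (drop_remainder=True).
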

def _shard_counts(layout: layout_lib.Layout,
                  batch_dim: Optional[str] = None) -> List[int]:
  """Computes a list of the number of shards in each dimension of the layout.

  The shard counts are used to slice each dataset element. The batch
  dimension's count is overridden to 1 since we only consider how many shards
  to make locally (within each local replica). Sharding across clients is
  handled by either tf.data.Dataset's shard transformation (in the
  single-client case) or tf.data service's distribute function (in the
  multi-client case).

  Args:
    layout: the layout to compute the shard counts for.
    batch_dim: the name of the batch dimension of the layout, if present.

  Returns:
    A list of shard counts, one element per dimension of the layout.
  """
  shard_counts = []
  for spec in layout.sharding_specs:
    if spec in (batch_dim, layout_lib.UNSHARDED):
      shard_counts.append(1)
    else:
      shard_counts.append(layout.mesh.dim_size(spec))
  return shard_counts
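
# Worked example for _shard_counts above (hypothetical mesh and layout, for
# illustration only): on a mesh {'batch': 2, 'model': 4}, a layout with
# sharding_specs ['batch', 'unsharded', 'model'] gives
# _shard_counts(layout, 'batch') == [1, 1, 4]. The batch and unsharded
# dimensions each contribute a single local shard, while the dimension sharded
# over 'model' is split into 4 pieces, one per device along that mesh
# dimension.
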

def _index_matrix(layout: layout_lib.Layout,
                  elem_spec: tensor_spec.TensorSpec) -> ops.Tensor:
  """Computes a utility matrix to derive device-based slice offsets.

  This function builds a matrix of shape `[mesh.rank, layout.rank]` for each
  dataset element. This matrix can be used to slice the DTensor components
  returned by the iterator according to the local device that component is to
  be placed on. This can be done by multiplying the device offsets of shape
  `[1, mesh.rank]` with this index matrix to get a `[1, layout.rank]` shape
  tensor containing the slice offsets.

  Note: the index on the batch dim is always 0 since sharding on the batch
  dimension is handled by either tf.data.Dataset's shard transformation (in
  the single-client case) or tf.data service's distribute function (in the
  multi-client case). If there is no sharding on the batch dimension (or any
  other dimension), the slice index remains 0.

  Args:
    layout: the layout of the dataset element.
    elem_spec: the spec of the dataset element.

  Returns:
    The index matrix as a tensor.
  """
  matrix = []
  for dim in layout.mesh.dim_names:
    row = [0]
    for layout_idx, spec in enumerate(layout.sharding_specs[1:]):
      if spec == layout_lib.UNSHARDED or spec != dim:
        row.append(0)
      else:
        row.append(elem_spec.shape[layout_idx] // layout.mesh.dim_size(dim))
    matrix.append(row)

  return constant_op.constant(matrix, dtype=dtypes.int32)
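
# Worked example for _index_matrix above (hypothetical values, for
# illustration only): with mesh dim_names ['batch', 'model'] of sizes (2, 4),
# a layout with sharding_specs ['batch', 'unsharded', 'model'], and an
# unbatched elem_spec of shape [8, 16], the rows built per mesh dimension are:
#   'batch': [0, 0, 0]            (batch and unsharded entries are always 0)
#   'model': [0, 0, 16 // 4] == [0, 0, 4]
# giving a [mesh.rank, layout.rank] = [2, 3] matrix. A device at mesh
# coordinates [1, 3] then maps to slice offsets
# [[1, 3]] @ [[0, 0, 0], [0, 0, 4]] == [[0, 0, 12]].
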

def _pack_iterator_resource_dtensor(
    datasets: List[Tuple[int, data_types.DatasetV2]],
    layouts: Any,
    mesh: layout_lib.Mesh,
    num_local_devices_per_replica: int):
  """Creates a DTensor iterator resource for the per-replica datasets.

  Given a list of replica ID to tf.data.Dataset mappings, this function creates
  iterators for each device and then packs the underlying iterator resource
  tensors into a single DTensor. This resource tensor is used by the
  IteratorGetNext op to retrieve the next element in the dataset.

  Args:
    datasets: a list of tuples of each unique local replica ID to the dataset
      object whose elements will be placed on the devices corresponding to
      that replica.
    layouts: a structure of DTensor layouts to be applied to the elements
      returned by the underlying iterators. This can be a single layout or
      (possibly nested) tuples or dictionaries of layouts, and the structure
      must match the structure of the iterator elements.
    mesh: the DTensor mesh to place the iterator batches on.
    num_local_devices_per_replica: the number of devices in each data-parallel
      replica.

  Returns:
    A DTensor of the underlying iterator resource tensors.
  """
  host_mesh_devices = mesh.host_mesh().local_devices()
  device_idx = 0

  iterators = []
  for _, dataset in datasets:
    for idx in range(num_local_devices_per_replica):
      with ops.device_v2(host_mesh_devices[device_idx]):
        device_dataset = dataset.shard(
            num_shards=num_local_devices_per_replica, index=idx)
        iterators.append(iter(device_dataset))
      device_idx += 1

  if device_idx != len(host_mesh_devices):
    raise ValueError(
        'The `datasets` argument does not have the correct number of'
        f' underlying datasets, found {device_idx} but expected'
        f' {len(host_mesh_devices)}.')

  host_layouts = nest.map_structure(
      lambda l: layout_lib.Layout(l.sharding_specs, mesh.host_mesh()), layouts)

  # Pack the iterator resource tensors into a replicated 0-dimensional DTensor
  # and set the element layouts.
  iterator_resources = [it._iterator_resource for it in iterators]  # pylint: disable=protected-access
  d_iterator_resource = api.pack(
      iterator_resources,
      layout_lib.Layout.replicated(mesh=mesh.host_mesh(), rank=0))
  api._dtensor_device().set_iterator_element_layouts(  # pylint: disable=protected-access
      d_iterator_resource, nest.flatten(host_layouts))

  return d_iterator_resource
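
# Illustration of the packing order above (hypothetical single-client setup):
# with datasets=[(0, ds_a), (1, ds_b)] and num_local_devices_per_replica=2,
# the loop creates four iterators on the host mesh devices in order:
# ds_a shard 0 -> device 0, ds_a shard 1 -> device 1,
# ds_b shard 0 -> device 2, ds_b shard 1 -> device 3.
# The four iterator resource tensors are then packed into a single replicated
# DTensor so a single IteratorGetNext can drive them together.
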

@tf_export('experimental.dtensor.DTensorDataset', v1=[])
class DTensorDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset of DTensors.

  DTensorDataset encapsulates a `tf.data.Dataset` whose elements are
  automatically packed and returned as DTensors based on a given mesh and
  layouts.
  """

  def __init__(self,
               dataset: data_types.DatasetV2,
               *,
               mesh: layout_lib.Mesh,
               layouts: Any,
               global_batch_size: int,
               dataset_already_batched: bool = False,
               batch_dim: Optional[str] = None,
               prefetch: Optional[int] = None,
               tf_data_service_config: Optional[TFDataServiceConfig] = None):
    """Creates a DTensorDataset.

    DTensorDataset automatically handles distribution of the dataset elements
    to each client's devices. It can be used to create an iterator that
    returns DTensors of the input data on each iteration.

    DTensorDataset works best with unbatched datasets. It takes the mesh and
    the provided layouts to automatically calculate how to batch the input
    locally for each replica.

    If the provided dataset is already batched according to the per-replica
    batch size, then `dataset_already_batched` must be set and DTensorDataset
    will check that the batch size is consistent with the intended
    `global_batch_size` using the layout information. Each replica receives a
    separate slice of the global batch, thus the per-replica batch size can be
    computed as the global batch size divided by the number of model replicas.
    For a DTensor mesh, the number of replicas is equal to the size of the
    mesh's batch dimension.

    TODO(b/223275517): add support for input datasets that are already batched
    to the global batch size.

    Args:
      dataset: a `tf.data.Dataset` object.
      mesh: the DTensor mesh to place the dataset batches on.
      layouts: a structure of DTensor layouts to be applied to the input
        dataset values. This can be a single layout or (possibly nested)
        tuples or dictionaries of layouts, and the structure must match the
        structure of the dataset. Either all or none of the layouts should be
        sharded on the batch dimension; having only a subset of layouts batch
        sharded will not work and raises a ValueError.
      global_batch_size: the desired global batch size.
      dataset_already_batched: must be set only if the dataset is already
        batched to the per-replica batch size. The batched dataset must have
        `drop_remainder=True` set since DTensor requires static shapes for
        slicing the input tensors.
      batch_dim: the mesh dimension on which the input's batch dimension is
        sharded. Set to None if the input layouts do not shard on the batch
        dimension.
      prefetch: number of batches to prefetch using Dataset.prefetch.
      tf_data_service_config: if operating in multi-client mode, this config
        specifies the tf.data service configuration to use.

    Raises:
      ValueError: in any of the following situations:
        1. if the structures and ranks of layouts and the dataset do not
           match.
        2. if the shapes in the dataset's spec are not fully defined.
        3. if batch_dim is specified but not all layouts are sharded on the
           batch dimension.
        4. if the dataset is already batched but its per-replica batch size
           does not match the expected size based on the provided mesh.
      TypeError: if the types of the structures of layouts and the dataset do
        not match.
    """
    super().__init__(dataset, dataset_ops.to_variant(dataset))

    # TODO(b/271162918): fix multi-client use case.
    if tf_data_service_config is not None:
      raise NotImplementedError(
          'Multi-client DTensorDataset is currently not supported.'
          ' Check b/271162918.')

    self._mesh = mesh
    self._layouts = layouts
    self._batch_dim = batch_dim
    self._prefetch = prefetch
    self._tf_data_service_config = tf_data_service_config

    nest.assert_same_structure(dataset.element_spec, layouts)
    flattened_layouts = nest.flatten(layouts)
    flattened_elem_spec = nest.flatten(dataset.element_spec)

    if batch_dim:
      num_global_replicas = mesh.dim_size(batch_dim)
      self._local_replica_ids = list(
          dict.fromkeys(
              [loc[batch_dim] for loc in mesh.local_device_locations()]))

      for layout in flattened_layouts:
        if batch_dim != layout.sharding_specs[0]:
          raise ValueError(
              ('batch_dim %s was specified but at least one layout did not '
               'contain it: %s') % (batch_dim, layout))
    else:
      # Only one replica since there is no sharding on the batch dimension.
      num_global_replicas = 1
      self._local_replica_ids = [0]

    # Validate layout and element spec compatibility, and raise ValueError if
    # invalid.
    _validate_input(
        flattened_layouts,
        flattened_elem_spec,
        dataset_already_batched=dataset_already_batched)

    expected_batch_size = global_batch_size // num_global_replicas
    if not dataset_already_batched:
      self._batched_dataset = dataset.batch(
          expected_batch_size, drop_remainder=True)
    else:
      per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0]
      if per_replica_batch_size != expected_batch_size:
        raise ValueError(
            ('The per-replica batch size does not match the expected size '
             'based on the mesh, got %d but expected %d.') %
            (per_replica_batch_size, expected_batch_size))
      self._batched_dataset = dataset

    # Construct a global element spec of the dataset.
    flattened_global_elem_spec = []
    batch_tensor_shape = tensor_shape.as_shape([global_batch_size])
    for elem_spec in nest.flatten(self._batched_dataset.element_spec):
      new_elem_spec = tensor_spec.TensorSpec(
          shape=operator.concat(batch_tensor_shape, elem_spec.shape[1:]),
          dtype=elem_spec.dtype,
          name=elem_spec.name)
      flattened_global_elem_spec.append(new_elem_spec)
    self._global_element_spec = nest.pack_sequence_as(
        dataset.element_spec, flattened_global_elem_spec)

    num_global_devices_per_replica = config.num_global_devices(
        mesh.device_type()) // num_global_replicas
    self._num_local_replicas = len(self._local_replica_ids)
    self._num_local_devices_per_replica = mesh.num_local_devices(
    ) // self._num_local_replicas
    # The number of clients each replica is split over.
    self._num_clients_per_replica = (
        num_global_devices_per_replica // self._num_local_devices_per_replica)
    # In the case where a replica is split across multiple clients, an offset
    # needs to be added to the index used by the partitioning logic such that
    # the local devices on that client can be correctly matched to slices of
    # the input tensor(s). If replicas are wholly contained within a client,
    # then this offset is always 0.
    self._partition_offset = (config.client_id() % self._num_clients_per_replica
                             ) * self._num_local_devices_per_replica

    # Helper data structures used in partitioning the dataset tensors.
    self._all_shard_counts = [
        _shard_counts(layout, batch_dim) for layout in flattened_layouts
    ]
    self._index_matrices = [
        _index_matrix(layout, elem_spec)
        for layout, elem_spec in zip(flattened_layouts, flattened_elem_spec)
    ]
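
  # Worked example of the accounting above (hypothetical mesh, single client):
  # for a mesh {'batch': 4, 'model': 2} with 8 devices all on one client and
  # global_batch_size=32, num_global_replicas == 4, so each replica batches
  # 32 // 4 == 8 elements. Then num_local_replicas == 4,
  # num_local_devices_per_replica == 8 // 4 == 2,
  # num_clients_per_replica == 2 // 2 == 1, and partition_offset == 0 because
  # every replica is wholly contained within the client.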

  def __iter__(self):
    datasets: List[Tuple[int, data_types.DatasetV2]] = []

    # Start with the batched dataset.
    local_dataset = self._batched_dataset

    if self._batch_dim is not None:
      if self._num_clients_per_replica > 1:
        # If a replica is split over multiple clients then each batch needs to
        # be repeated before distribution as many times as there are clients
        # corresponding to that replica.
        local_dataset = self._repeat_batch(local_dataset,
                                           self._num_clients_per_replica)
        sharding_policy = data_service_ops.ShardingPolicy.DATA
      else:
        # Replicas are unique to each client, so FILE based sharding can be
        # used, which is more performant since each worker does not need to
        # read the entire dataset.
        sharding_policy = data_service_ops.ShardingPolicy.FILE
    else:
      # No batch dimension sharding specified, so disable dataset sharding
      # during the distribute step.
      sharding_policy = data_service_ops.ShardingPolicy.OFF

    # Apply distribution here (if specified) so all remaining transformations
    # are executed locally.
    if self._tf_data_service_config is not None:
      local_dataset = local_dataset.apply(
          data_service_ops.distribute(
              processing_mode=sharding_policy,
              service=self._tf_data_service_config.dispatcher_address,
              job_name=f'{self._tf_data_service_config.job_name}_{config.client_id()}',
              target_workers='LOCAL'))

    for local_replica_idx, replica_id in enumerate(self._local_replica_ids):
      # Select the shard for the corresponding replica.
      dataset = local_dataset.shard(self._num_local_replicas, local_replica_idx)

      # Repeat each batch for each local device in the replica.
      dataset = self._repeat_batch(dataset, self._num_local_devices_per_replica)

      # Slice each shard further for all non-batch dim shards. If there is no
      # non-batch dim sharding, this slice is essentially a no-op.
      dataset = self._partition(dataset)

      # Apply prefetch as the last step. Since each batch is repeated, the
      # number of elements to prefetch has to be scaled by the same size.
      if self._prefetch is not None:
        dataset = dataset.prefetch(
            self._prefetch * self._num_local_devices_per_replica)

      datasets.append((replica_id, dataset))

    # Convert the datasets into iterators placed on the host.
    d_iterator_resource = _pack_iterator_resource_dtensor(
        datasets=datasets,
        layouts=self._layouts,
        mesh=self._mesh,
        num_local_devices_per_replica=self._num_local_devices_per_replica)

    return _DTensorIterator(
        dtensor_components=(d_iterator_resource,),
        global_element_spec=self._global_element_spec,
        layouts=self._layouts)

  def _repeat_batch(self, dataset, repeats):
    def repeat(*x):
      return dataset_ops.DatasetV2.from_tensors(x).repeat(repeats)

    return dataset.flat_map(repeat)
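
  # For example (illustrative): applied to a batched dataset yielding
  # b0, b1, ... with repeats=2, _repeat_batch produces b0, b0, b1, b1, ...;
  # each repeated copy is later sliced differently per local device by
  # _partition below.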

  def _partition(self, dataset):
    """Slices each dataset element on any sharded non-batch dimension."""

    # TODO(b/223275517): decouple from self and make testable.
    def slice_batch(index, batch):
      flattened_batch = nest.flatten(batch)
      flattened_output = []

      norm_index = math_ops.cast(
          index % self._num_local_devices_per_replica, dtype=dtypes.int32)
      norm_index += self._partition_offset
      coords = self._mesh.coords(norm_index)
      coords = array_ops.reshape(coords, (1, -1))

      for element, shard_counts, idx_matrix in zip(flattened_batch,
                                                   self._all_shard_counts,
                                                   self._index_matrices):
        indexes = math_ops.matmul(coords, idx_matrix)
        start = array_ops.reshape(indexes, (-1,))
        size = array_ops.shape_v2(
            element, out_type=dtypes.int32) // shard_counts
        flattened_output.append(
            array_ops.slice(element, begin=start, size=size))

      return nest.pack_sequence_as(batch, flattened_output)

    enumerated_dataset = dataset.enumerate()
    partitioned_dataset = enumerated_dataset.map(slice_batch)
    return partitioned_dataset
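
  # Sketch of slice_batch (hypothetical values): with
  # num_local_devices_per_replica=2 and partition_offset=0, dataset indices
  # 0, 1, 2, 3, ... normalize to device offsets 0, 1, 0, 1, ...; each offset
  # is converted to mesh coordinates and multiplied by the precomputed index
  # matrix (see the _index_matrix example above) to obtain the slice start,
  # while the slice size is the element shape divided by the per-dimension
  # shard counts.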

  @property
  def element_spec(self):
    return self._global_element_spec