Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/input_lib.py: 25%

785 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Various classes representing distributed inputs.""" 

16 

17import functools 

18import sys 

19import time 

20 

21import six 

22 

23from tensorflow.python.autograph.operators import py_builtins 

24from tensorflow.python.data.experimental.ops import batching 

25from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib 

26from tensorflow.python.data.experimental.ops import distribute 

27from tensorflow.python.data.ops import dataset_ops 

28from tensorflow.python.data.ops import iterator_ops 

29from tensorflow.python.data.ops import multi_device_iterator_ops 

30from tensorflow.python.data.ops import optional_ops 

31from tensorflow.python.distribute import device_util 

32from tensorflow.python.distribute import distribute_lib 

33from tensorflow.python.distribute import distribute_utils 

34from tensorflow.python.distribute import input_ops 

35from tensorflow.python.distribute import reduce_util 

36from tensorflow.python.distribute import values 

37from tensorflow.python.distribute.distribute_lib import InputReplicationMode 

38from tensorflow.python.eager import context 

39from tensorflow.python.eager import monitoring 

40from tensorflow.python.framework import composite_tensor 

41from tensorflow.python.framework import device as tf_device 

42from tensorflow.python.framework import dtypes 

43from tensorflow.python.framework import errors 

44from tensorflow.python.framework import ops 

45from tensorflow.python.framework import sparse_tensor 

46from tensorflow.python.framework import tensor_shape 

47from tensorflow.python.framework import tensor_util 

48from tensorflow.python.framework import type_spec 

49from tensorflow.python.ops import array_ops 

50from tensorflow.python.ops import cond as tf_cond 

51from tensorflow.python.ops import math_ops 

52from tensorflow.python.ops import while_loop 

53from tensorflow.python.ops.ragged import ragged_tensor 

54from tensorflow.python.platform import tf_logging as logging 

55from tensorflow.python.types import distribute as distribute_types 

56from tensorflow.python.util import nest 

57from tensorflow.python.util.compat import collections_abc 

58 

59 

# Sampler metric for the time, in milliseconds, spent initializing a
# distributed dataset. Cells are addressed by two labels, "strategy" and
# "workers". Buckets are exponential: scale 1, growth factor 2, 26 buckets.
_distributed_dataset_initialization_time_milliseconds = monitoring.Sampler(
    "/tensorflow/api/distribution_strategy/"
    "distributed_dataset_initialization_time_milliseconds",
    monitoring.ExponentialBuckets(scale=1, growth_factor=2, bucket_count=26),
    "Track the time (in milliseconds) to initialize distributed datasets.",
    "strategy", "workers")

# Sampler metric with the same labels and bucket layout, tracking the
# initialization time of distributed datasets created from a dataset function.
_distributed_dataset_from_function_initialization_time_milliseconds = (
    monitoring.Sampler(
        "/tensorflow/api/distribution_strategy/"
        "distributed_dataset_from_function_initialization_time_milliseconds",
        monitoring.ExponentialBuckets(
            scale=1, growth_factor=2, bucket_count=26),
        "Track the time (in milliseconds) to initialize distributed datasets "
        "from function.",
        "strategy", "workers"))

76 

77 

def get_iterator_spec_from_dataset(strategy, dataset):
  """Builds the type spec of the iterator obtained from `iter(dataset)`.

  Args:
    strategy: a `tf.distribute.Strategy` object. Its `extended` object supplies
      the input workers, the container strategy and the per-worker GPU count.
    dataset: A tf.data.Dataset instance. If using a function that returns a
      tf.data.Dataset instance, pass dataset_fn.structured_outputs.

  Returns:
    A `DistributedIteratorSpec` when the dataset's type spec is a
    `DistributedDatasetSpec` or `DistributedDatasetsFromFunctionSpec`;
    otherwise an `iterator_ops.IteratorSpec` for the dataset's element spec.
  """
  # pylint: disable=protected-access
  element_spec = dataset.element_spec
  distributed_spec_types = (DistributedDatasetSpec,
                            DistributedDatasetsFromFunctionSpec)

  if not isinstance(dataset._type_spec, distributed_spec_types):
    # Plain dataset: warn when GPUs are configured, since a non-distributed
    # iterator is being used.
    if strategy.extended._num_gpus_per_worker:
      logging.warning(
          f"{strategy.extended._num_gpus_per_worker} GPUs "
          "are allocated per worker. Please use DistributedDataset by "
          "calling strategy.experimental_distribute_dataset or strategy."
          "distribute_datasets_from_function to make best use of GPU "
          "resources"
      )
    return iterator_ops.IteratorSpec(element_spec)

  return DistributedIteratorSpec(
      strategy.extended._input_workers_with_options(),
      element_spec,
      strategy.extended._container_strategy(),
      options=None,
      cardinality=dataset.cardinality,
      enable_get_next_as_optional=True)
  # pylint: enable=protected-access

118 

119 

class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  # TODO(ishark): Remove option canonicalize_devices and make all the callers
  # pass canonicalized or raw device strings as relevant from strategy.
  def __init__(self,
               worker_device_pairs,
               canonicalize_devices=True):
    """Initialize an `InputWorkers` object.

    Args:
      worker_device_pairs: A sequence of pairs: `(input device, a tuple of
        compute devices fed by that input device)`.
      canonicalize_devices: Whether to canonicalize devices for workers fully or
        partially. If False, it will partially canonicalize devices by removing
        job and task.
    """
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
    self._canonicalize_devices = canonicalize_devices
    if canonicalize_devices:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize(d)
                for d in f)
          for _, f in self._worker_device_pairs)
    else:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize_without_job_and_task(d)
                for d in f)
          for _, f in self._worker_device_pairs)

  @property
  def num_workers(self):
    """Number of input worker devices."""
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    """Tuple of input worker device strings, in the order they were given."""
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    """Returns the tuple of compute devices fed by worker `worker_index`."""
    return self._fed_devices[worker_index]

  def __repr__(self):
    debug_repr = ",\n".join(
        " %d %s: %s" % (i, worker, fed)
        for i, (worker, fed) in enumerate(
            zip(self.worker_devices, self._fed_devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)

  def serialize(self):
    """Returns `(worker_device_pairs, canonicalize_devices)`."""
    return (self._worker_device_pairs, self._canonicalize_devices)

  def deserialize(self, serialized):
    """Rebuilds an `InputWorkers` from the output of `serialize()`.

    Args:
      serialized: Either the `(worker_device_pairs, canonicalize_devices)`
        tuple produced by `serialize()`, or a bare `worker_device_pairs`
        sequence (accepted for backward compatibility).

    Returns:
      A new `InputWorkers` object.
    """
    # `serialize()` emits a 2-tuple whose second item is the boolean
    # canonicalization flag. A bare sequence of worker/device pairs never has
    # a bool in that position, so this check tells the two formats apart.
    if (isinstance(serialized, tuple) and len(serialized) == 2 and
        isinstance(serialized[1], bool)):
      worker_device_pairs, canonicalize_devices = serialized
      return InputWorkers(worker_device_pairs, canonicalize_devices)
    return InputWorkers(serialized)

174 

175 

def _calculate_replicas_with_values(strategy, input_workers, optional_list):
  """Counts the replicas whose optional holds a value.

  Args:
    strategy: the `tf.distribute.Strategy`.
    input_workers: the `InputWorkers`.
    optional_list: a list of lists `tf.experimental.Optional`. The values from
      each compute device grouped by the input device.

  Returns:
    A scalar Tensor.
  """
  per_worker_counts = []
  for device, worker_optionals in zip(input_workers.worker_devices,
                                      optional_list):
    with ops.device(device):
      # One int64 flag per compute device fed by this worker.
      flags = []
      for optional in worker_optionals:
        flags.append(math_ops.cast(optional.has_value(), dtypes.int64))
      per_worker_counts.append(math_ops.reduce_sum(flags, keepdims=True))

  total = math_ops.reduce_sum(per_worker_counts, keepdims=True)
  if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
    # In multi-worker mode the local count is summed through the strategy.
    total = strategy.reduce(reduce_util.ReduceOp.SUM, total, axis=None)
  return array_ops.reshape(total, [])

203 

204 

def _is_statically_shaped(element_spec):
  """Checks whether every component of an iterator output has a static shape.

  For sparse and ragged tensors this only tests the batch dimension.

  Args:
    element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
      dataset of the iterator.

  Returns:
    True if the shape is static, false otherwise.
  """
  batch_dim_only_types = (sparse_tensor.SparseTensorSpec,
                          ragged_tensor.RaggedTensorSpec)

  def _spec_is_static(spec):
    if isinstance(spec, batch_dim_only_types):
      # For sparse or ragged tensor, we should only check the first
      # dimension in order to get_next_as_optional. This is because
      # when these tensors get batched by dataset only the batch dimension
      # is set.
      shape = spec.shape
      return not (shape.rank > 0 and shape.as_list()[0] is None)
    return all(
        component.shape.is_fully_defined()
        for component in spec._flat_tensor_specs)  # pylint: disable=protected-access

  return all(_spec_is_static(spec) for spec in nest.flatten(element_spec))

232 

233 

class DistributedIteratorBase(collections_abc.Iterator,
                              distribute_types.DistributedIteratorInterface):
  """Common implementation for all input iterators.

  Holds one underlying iterator per input worker and combines their outputs
  into per-replica values. Subclasses are expected to provide
  `self._element_spec`, which `get_next_as_optional` reads but this class's
  `__init__` does not set.
  """

  # pylint: disable=super-init-not-called
  def __init__(
      self,
      input_workers,
      iterators,
      strategy,
      cardinality,
      enable_get_next_as_optional,
      replica_order=None,
  ):
    """Stores the iterator state.

    Args:
      input_workers: an `InputWorkers` object with at least one worker device.
      iterators: per-worker iterators, indexed in the same order as
        `input_workers.worker_devices`.
      strategy: the `tf.distribute.Strategy` this iterator belongs to.
      cardinality: cardinality of the dataset; compared against
        `cardinality_lib.INFINITE` and `0` in `get_next_as_optional`.
      enable_get_next_as_optional: whether `get_next` performs partial batch
        handling through optionals.
      replica_order: optional sequence of indices used to reorder the flat list
        of per-replica values.

    Raises:
      ValueError: if `input_workers` has no worker devices.
    """
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy
    self._cardinality = cardinality
    self._enable_get_next_as_optional = enable_get_next_as_optional
    self._replica_order = replica_order

  def next(self):
    """Alias of `__next__`."""
    return self.__next__()

  def __next__(self):
    # Translate TensorFlow's end-of-data error into the Python iterator
    # protocol's StopIteration.
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

  def get_next_as_optional(self):
    """Returns the next per-replica element wrapped in an `Optional`.

    The optional is empty when no replica has a value left.
    """
    # Ideally get_next_as_optional() should be consistent with get_next(), but
    # we used to always do partial batch handling in get_next_as_optional(). We
    # are keeping this behavior for now until we understand the impact.

    # Skip partial batch handling when the dataset is infinite or empty, as
    # there won't be any partial batches in those cases. This gives the user
    # more static shapes as it avoids the tf.cond. Note that for empty datasets,
    # we can only skip in single client mode, as the dataset can be non-empty on
    # other workers.
    if self._cardinality == cardinality_lib.INFINITE:
      return optional_ops.Optional.from_value(
          self._get_next_no_partial_batch_handling())
    if (self._cardinality == 0 and
        not self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return optional_ops.Optional.empty(self._element_spec)

    # Fetch one list of optionals per input worker, on that worker's device.
    optional_list = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        optional_list.append(self._iterators[i].get_next_as_optional_list())

    def _create_optional_with_dummy():
      # Branch taken when at least one replica has a value; replicas without a
      # value get a dummy one (produce_dummy=True).
      value_list = _get_value_or_dummy(
          self._input_workers, optional_list, produce_dummy=True)

      if self._replica_order is not None:
        value_list = self._reorder_replicas(value_list)

      per_replica = _create_per_replica(value_list, self._strategy)
      return optional_ops.Optional.from_value(per_replica)

    def _create_empty_optional():
      # Branch taken when no replica has a value.
      return optional_ops.Optional.empty(self._element_spec)

    num_replicas_with_values = _calculate_replicas_with_values(
        self._strategy, self._input_workers, optional_list)

    return tf_cond.cond(
        num_replicas_with_values > 0,
        _create_optional_with_dummy,
        _create_empty_optional,
        strict=True)

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    # Must be called in cross-replica context, not from inside a replica_fn.
    with distribute_lib.enter_or_assert_strategy(
        self._strategy):
      if distribute_lib.get_replica_context() is not None:
        raise ValueError("next(iterator) should be called from outside of "
                         "replica_fn. e.g. strategy.run(replica_fn, "
                         "args=(next(iterator),))")

    if not self._enable_get_next_as_optional:
      return self._get_next_no_partial_batch_handling(name)

    # Partial batch handling: fetch optionals per worker, then decide with a
    # cond whether any replica still has data.
    optional_list = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        optional_list.append(self._iterators[i].get_next_as_optional_list())
    num_replicas_with_values = _calculate_replicas_with_values(
        self._strategy, self._input_workers, optional_list)

    def _value_or_dummy():
      value_list = _get_value_or_dummy(
          self._input_workers, optional_list, produce_dummy=True)

      if self._replica_order is not None:
        value_list = self._reorder_replicas(value_list)

      return _create_per_replica(value_list, self._strategy)

    def _eof():
      # Optional.get_value raises InvalidArgumentError when there's no value,
      # so we need to call GetNext to raise EOFError.
      return self._get_next_no_partial_batch_handling()

    return tf_cond.cond(
        num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)

  def _get_next_no_partial_batch_handling(self, name=None):
    """Fetches the next element from every worker without using optionals.

    Args:
      name: optional name prefix. When given, each worker's call receives
        `"<name>_<job>_<task>"` built from that worker's device spec.

    Returns:
      The per-replica value built by `_create_per_replica`.
    """
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(self._iterators[i].get_next_as_list(new_name))

    if self._replica_order is not None:
      replicas = self._reorder_replicas(replicas)

    return _create_per_replica(replicas, self._strategy)

  def _reorder_replicas(self, replicas):
    """Returns `replicas` permuted according to `self._replica_order`."""
    assert len(self._replica_order) == len(
        replicas
    ), "replica order size ({}) != replicas size ({})!".format(
        len(self._replica_order), len(replicas)
    )
    return [replicas[i] for i in self._replica_order]

375 

376 

class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec):
  """Common type specification for distributed datasets and iterators.

  Shared base of the specs for `DistributedDataset`,
  `DistributedDatasetsFromFunction` and `DistributedIterator`.
  """

  __slots__ = [
      "_input_workers", "_element_spec", "_strategy", "_cardinality",
      "_enable_get_next_as_optional", "_options", "_canonicalize_devices",
      "_replica_order"
  ]

  def __init__(
      self,
      input_workers,
      element_spec,
      strategy,
      options,
      cardinality=cardinality_lib.UNKNOWN,
      enable_get_next_as_optional=None,
      replica_order=None,
  ):
    """Creates the spec.

    Args:
      input_workers: the `InputWorkers` of the described value. A `tuple` is
        rejected because it indicates an attempted deserialization.
      element_spec: nested structure of type specs of the produced elements.
      strategy: the `tf.distribute.Strategy` the value belongs to.
      options: the input options the value was created with.
      cardinality: cardinality of the underlying dataset.
      enable_get_next_as_optional: whether partial batch handling is enabled.
      replica_order: optional order used to reorder per-replica values.

    Raises:
      NotImplementedError: if `input_workers` is a tuple.
    """
    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only places where
    # _deserialize is called is when we save/restore using SavedModels.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")

    self._input_workers = input_workers
    self._element_spec = element_spec
    self._strategy = strategy
    self._cardinality = cardinality
    self._enable_get_next_as_optional = enable_get_next_as_optional
    self._options = options
    if self._strategy:
      self._canonicalize_devices = getattr(self._strategy,
                                           "_canonicalize_devices", True)
    else:
      self._canonicalize_devices = True
    self._replica_order = replica_order

  def _serialize(self):
    # We cannot serialize the strategy object so we convert it to an id that we
    # can use for comparison.
    return (self._input_workers.serialize(), self._element_spec,
            id(self._strategy), id(self._options))

  def _deserialize(self):
    raise ValueError(
        f"Deserialization is currently unsupported for {type(self)}.")

  def sanity_check_type(self, other):
    """Checks that `other` can be combined with `self`.

    The two specs must have the same type, equal serialized input workers and
    the very same strategy object. Nothing is returned on success.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    if self._input_workers.serialize() != other._input_workers.serialize():
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))
    if self._strategy is not other._strategy:
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))

  def is_subtype_of(self, other):
    """Returns True if `self` is subtype of `other`.

    Args:
      other: A `TypeSpec`.
    """
    try:
      self.sanity_check_type(other)
      nest.assert_same_structure(self._element_spec, other._element_spec)  # pylint: disable=protected-access
    except (TypeError, ValueError):
      return False

    self_elements = nest.flatten(self._element_spec)
    other_elements = nest.flatten(other._element_spec)  # pylint: disable=protected-access

    return all(
        self_element.is_subtype_of(other_element)
        for (self_element, other_element) in zip(self_elements, other_elements))

  def most_specific_common_supertype(self, others):
    """Returns the most specific supertype of `self` and `others`.

    Args:
      others: A Sequence of `TypeSpec`.

    Returns `None` if a supertype does not exist.
    """
    try:
      for other in others:
        self.sanity_check_type(other)
        nest.assert_same_structure(self._element_spec, other._element_spec)  # pylint: disable=protected-access
    except (TypeError, ValueError):
      return None

    self_elements = nest.flatten(self._element_spec)
    others_elements = [nest.flatten(other._element_spec) for other in others]  # pylint: disable=protected-access
    common_elements = []

    # Compute the common supertype leaf by leaf; give up as soon as one leaf
    # has none.
    for i, self_element in enumerate(self_elements):
      common = self_element.most_specific_common_supertype(
          [other_elements[i] for other_elements in others_elements])
      if common is None:
        return None
      common_elements.append(common)
    common_element_spec = nest.pack_sequence_as(self._element_spec,
                                                common_elements)
    return self._copy_with_element_spec(common_element_spec)

  def _with_tensor_ranks_only(self):
    element_spec = nest.map_structure(
        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
        self._element_spec)
    return self._copy_with_element_spec(element_spec)

  # TODO(b/206014848): Remove once names are not used.
  def _without_tensor_names(self):
    element_spec = nest.map_structure(
        lambda s: s._without_tensor_names(),  # pylint: disable=protected-access
        self._element_spec)
    return self._copy_with_element_spec(element_spec)

  def _copy_with_element_spec(self, element_spec):
    """Returns a spec of the same type that differs only in `element_spec`."""
    # `replica_order` is forwarded so that derived specs keep the replica
    # ordering of the spec they were derived from.
    return type(self)(
        self._input_workers,
        element_spec,
        self._strategy,
        self._options,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        replica_order=self._replica_order)

523 

524 

class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec):
  """Type specification for `DistributedIterator`."""

  @property
  def value_type(self):
    return DistributedIterator

  @property
  def _component_specs(self):
    """One `_SingleWorkerDatasetIteratorSpec` per input worker."""
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
      # Specialize the element spec for worker `i`.
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(
          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
                                           element_spec, self._options,
                                           self._canonicalize_devices))
    return specs

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(
        input_workers=self._input_workers,
        iterators=None,
        components=components,
        element_spec=self._element_spec,
        strategy=self._strategy,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )

  @staticmethod
  def from_value(value):
    """Builds the spec describing the distributed iterator `value`.

    Args:
      value: the iterator to describe; its private state attributes are read.

    Returns:
      A `DistributedIteratorSpec`.
    """
    # pylint: disable=protected-access
    return DistributedIteratorSpec(
        value._input_workers,
        value._element_spec,
        value._strategy,
        value._options,
        cardinality=value._cardinality,
        enable_get_next_as_optional=value._enable_get_next_as_optional,
        # Keep the replica order so that an iterator rebuilt from this spec
        # orders its replicas like `value` does.
        replica_order=getattr(value, "_replica_order", None))

572 

573 

class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Iterator over the elements of a distributed dataset."""

  def __init__(
      self,
      input_workers=None,
      iterators=None,
      strategy=None,
      components=None,
      element_spec=None,
      cardinality=cardinality_lib.UNKNOWN,
      enable_get_next_as_optional=False,
      options=None,
      replica_order=None,
  ):
    """Creates the iterator from `iterators` or from spec components.

    When `iterators` is given, the base class initializer is used. Otherwise
    both `components` and `element_spec` must be given and the state is
    assigned directly, with `components` used as the per-worker iterators.

    Raises:
      ValueError: if `input_workers` is None, or if the combination of
        `iterators`, `components` and `element_spec` is invalid.
    """
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    error_message = ("Either `input_workers` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")
    self._options = options

    if iterators is not None:
      # Built from live iterators: spec components must not be passed too.
      if components is not None and element_spec is not None:
        raise ValueError(error_message)

      super().__init__(
          input_workers,
          iterators,
          strategy,
          cardinality,
          enable_get_next_as_optional,
          replica_order,
      )
      return

    # Built from a type spec: both pieces are required.
    if components is None or element_spec is None:
      raise ValueError(error_message)
    self._element_spec = element_spec
    self._input_workers = input_workers
    self._iterators = components
    self._strategy = strategy
    self._cardinality = cardinality
    self._enable_get_next_as_optional = enable_get_next_as_optional
    self._replica_order = replica_order

  @property
  def element_spec(self):
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batching handling we could always produce empty batches.
    if (not self._enable_get_next_as_optional or
        not self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return self._element_spec
    return nest.map_structure(
        _rebatch_as_dynamic, self._element_spec, expand_composites=False)

  @property
  def _type_spec(self):
    # Note that we use actual element_spec instead of the rebatched-as-dynamic
    # one to create DistributedIteratorSpec, to be consistent with the
    # underlying iterators' specs.
    return DistributedIteratorSpec(
        self._input_workers,
        self._element_spec,
        self._strategy,
        self._options,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        replica_order=self._replica_order,
    )

648 

649 

class _IterableInput(collections_abc.Iterable,
                     distribute_types.DistributedDatasetInterface):
  """Base class of the iterable inputs used by distribution strategies."""

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Folds `reduce_fn` over every element of the input.

    Args:
      initial_state: the state passed to the first `reduce_fn` call.
      reduce_fn: callable taking `(state, element)` and returning the new
        state.

    Returns:
      The state after the last element has been consumed.
    """
    iterator = iter(self)

    def _has_more(current, state):
      del state  # Unused.
      return current.has_value()

    def _step(current, state):
      """Applies `reduce_fn` once and fetches the next optional element."""
      new_state = reduce_fn(state, current.get_value())
      return iterator.get_next_as_optional(), new_state

    first = iterator.get_next_as_optional()
    _, final_state = while_loop.while_loop(
        _has_more,
        _step, [first, initial_state],
        parallel_iterations=1,
        return_same_structure=True)
    return final_state

683 

684 

class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec):
  """Type specification for `DistributedDataset`."""

  @property
  def value_type(self):
    return DistributedDataset

  @property
  def _component_specs(self):
    """One `DatasetSpec` per input worker."""
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, _ in enumerate(worker_device_pairs):
      # Specialize the element spec for worker `i`.
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(dataset_ops.DatasetSpec(element_spec))
    return specs

  def _to_components(self, value):
    return value._cloned_datasets  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedDataset(
        input_workers=self._input_workers,
        strategy=self._strategy,
        components=components,
        element_spec=self._element_spec,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )

  @staticmethod
  def from_value(value):
    """Builds the spec describing the distributed dataset `value`.

    Args:
      value: the dataset to describe; its private state attributes are read.

    Returns:
      A `DistributedDatasetSpec`.
    """
    # pylint: disable=protected-access
    return DistributedDatasetSpec(
        value._input_workers,
        value._element_spec,
        value._strategy,
        value._options,
        enable_get_next_as_optional=value._enable_get_next_as_optional,
        # Keep the replica order so that a dataset rebuilt from this spec
        # orders its replicas like `value` does.
        replica_order=getattr(value, "_replica_order", None))
    # pylint: enable=protected-access

727 

728 

729class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor): 

730 """Distributed dataset that supports prefetching to multiple devices.""" 

731 

732 def __init__( 

733 self, 

734 input_workers, 

735 strategy, 

736 dataset=None, 

737 num_replicas_in_sync=None, 

738 input_context=None, 

739 components=None, 

740 element_spec=None, 

741 enable_get_next_as_optional=None, 

742 build=True, 

743 options=None, 

744 replica_order=None, 

745 ): 

746 """Distribute the dataset on all workers. 

747 

748 If `num_replicas_in_sync` is not None, we split each batch of the dataset 

749 into `num_replicas_in_sync` smaller batches, to be distributed among that 

750 worker's replicas, so that the batch size for a global step (across all 

751 workers and replicas) is as expected. 

752 

753 Args: 

754 input_workers: an `InputWorkers` object. 

755 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 

756 handle last partial batch. 

757 dataset: `tf.data.Dataset` that will be used as the input source. Either 

758 dataset or components field should be passed when constructing 

759 DistributedDataset. Use this when contructing DistributedDataset from a 

760 new `tf.data.Dataset`. Use components when constructing using 

761 DistributedDatasetSpec. 

762 num_replicas_in_sync: Optional integer. If this is not None, the value is 

763 used to decide how to rebatch datasets into smaller batches so that the 

764 total batch size for each step (across all workers and replicas) adds up 

765 to `dataset`'s batch size. 

766 input_context: `InputContext` for sharding. Only pass this in for between 

767 graph multi-worker cases where there is only one `input_worker`. In 

768 these cases, we will shard based on the `input_pipeline_id` and 

769 `num_input_pipelines` in the `InputContext`. 

770 components: datasets when DistributedDataset is constructed from 

771 DistributedDatasetSpec. Either field dataset or components should be 

772 passed. 

773 element_spec: element spec for DistributedDataset when constructing from 

774 DistributedDatasetSpec. This will be used to set the element_spec for 

775 DistributedDataset and verified against element_spec from components. 

776 enable_get_next_as_optional: this is required when components is passed 

777 instead of dataset. 

778 build: whether to build underlying datasets when this object is created. 

779 This is only useful for `ParameterServerStrategy` now. 

780 options: `tf.distribute.InputOptions` used to control options on how this 

781 dataset is distributed. 

782 replica_order: the order of the replicas, which will be used to reorder 

783 the iterators to match the device order. 

784 """ 

785 super(DistributedDataset, self).__init__(input_workers=input_workers) 

786 if input_workers is None or strategy is None: 

787 raise ValueError("input_workers and strategy are required arguments") 

788 if dataset is not None and components is not None: 

789 raise ValueError("Only one of dataset or components should be present") 

790 if dataset is None and components is None: 

791 raise ValueError("At least one of dataset or components should be passed") 

792 

793 self._input_workers = input_workers 

794 self._strategy = strategy 

795 self._options = options 

796 self._input_context = input_context 

797 self._num_replicas_in_sync = num_replicas_in_sync 

798 self._replica_order = replica_order 

799 

800 if dataset is not None: 

801 self._original_dataset = dataset 

802 self._built = False 

803 if build: 

804 self.build() 

805 else: 

806 if not build: 

807 raise ValueError( 

808 "When constructing DistributedDataset with components, build " 

809 "should not be False. This is an internal error. Please file a " 

810 "bug.") 

811 if enable_get_next_as_optional is None: 

812 raise ValueError( 

813 "When constructing DistributedDataset with components, " + 

814 "enable_get_next_as_optional should also be passed") 

815 self._cloned_datasets = components 

816 self._cardinality = _cardinality(self._cloned_datasets[0]) 

817 self._enable_get_next_as_optional = enable_get_next_as_optional 

818 

819 assert element_spec is not None 

820 if element_spec != _create_distributed_tensor_spec( 

821 self._strategy, self._cloned_datasets[0].element_spec): 

822 raise ValueError("Mismatched element_spec from the passed components") 

823 self._element_spec = element_spec 

824 

825 self._built = True 

826 

def build(self, dataset_to_replace=None):
  """Creates the per-worker datasets and marks this object as built.

  Must only be called while the object is not built yet (asserted).

  Args:
    dataset_to_replace: optional dataset to distribute instead of the one
      stored at construction time. When it is falsy,
      `self._original_dataset` is used.
  """
  assert not self._built
  source_dataset = (
      dataset_to_replace if dataset_to_replace else self._original_dataset)
  self._cardinality = _cardinality(source_dataset)
  self._enable_get_next_as_optional = _enable_get_next_as_optional(
      self._strategy, source_dataset, self._cardinality)

  start_ns = time.time_ns()
  self._create_cloned_datasets_from_dataset(
      source_dataset,
      self._input_context,
      self._input_workers,
      self._strategy,
      self._num_replicas_in_sync,
  )
  if context.executing_eagerly():
    # Record how long it took to initialize the distributed dataset. The
    # metric cell is keyed by strategy class name and number of workers.
    context.async_wait()
    elapsed_ms = (time.time_ns() - start_ns) // 1_000_000
    metric_cell = (
        _distributed_dataset_initialization_time_milliseconds.get_cell(
            self._strategy.__class__.__name__,
            str(self._input_workers.num_workers)))
    metric_cell.add(elapsed_ms)

  first_clone_spec = self._cloned_datasets[0].element_spec
  self._element_spec = _create_distributed_tensor_spec(
      self._strategy, first_clone_spec)
  self._built = True

849 

def auto_shard(self, num_shards, shard_ix):
  """Returns a new `DistributedDataset` whose per-worker datasets are sharded.

  Each cloned dataset is passed through `input_ops.auto_shard_dataset` with
  the same `num_shards` / `shard_ix`, colocated with its variant tensor.

  Args:
    num_shards: forwarded to `input_ops.auto_shard_dataset`.
    shard_ix: forwarded to `input_ops.auto_shard_dataset`.

  Returns:
    A `DistributedDataset` constructed from the sharded components.
  """
  num_datasets = len(self._cloned_datasets)
  num_input_workers = len(self._input_workers.worker_devices)
  assert num_datasets == num_input_workers, (
      f"datasets: {num_datasets}, "
      f"input workers: {num_input_workers}"
  )

  sharded_datasets = []
  for worker_index in range(num_input_workers):
    worker_dataset = self._cloned_datasets[worker_index]
    with ops.colocate_with(worker_dataset._variant_tensor):  # pylint:disable=protected-access
      sharded_datasets.append(
          input_ops.auto_shard_dataset(
              worker_dataset, num_shards, shard_ix,
              self._num_replicas_in_sync))

  # NOTE(review): `self._replica_order` is not forwarded to the sharded copy;
  # confirm that dropping it here is intended.
  return DistributedDataset(
      self._input_workers,
      self._strategy,
      components=sharded_datasets,
      element_spec=self._element_spec,
      options=self._options,
      enable_get_next_as_optional=self._enable_get_next_as_optional)

872 

@property
def cardinality(self):
  """The cardinality stored for this dataset.

  Raises:
    ValueError: if the dataset has not been built yet.
  """
  if self._built:
    return self._cardinality
  raise ValueError(
      "Cannot get the cardinality of a dataset that is not built")

879 

def _create_cloned_datasets_from_dataset(self, dataset, input_context,
                                         input_workers, strategy,
                                         num_replicas_in_sync):
  """Populates `self._cloned_datasets` with one dataset per input worker.

  The dataset is cloned and sharded for each worker. The current setup tries
  to shard by files if possible so that each worker sees a different subset
  of files; otherwise the final input is sharded so that each worker runs the
  entire preprocessing pipeline and only receives its own shard.

  When `num_replicas_in_sync` is greater than one, each worker's dataset is
  additionally rebatched into smaller batches for that worker's replicas, so
  that the batch size of a global step (across all workers and replicas) adds
  up to the original dataset's batch size.

  Args:
    dataset: the dataset to clone.
    input_context: if truthy, used for sharding (between-graph case);
      exactly one input worker is expected in that case.
    input_workers: provides `worker_devices` and `num_workers`.
    strategy: not used by this method.
    num_replicas_in_sync: number of replicas in sync, or None.
  """
  rebatch_fn = None
  if num_replicas_in_sync is not None and num_replicas_in_sync > 1:
    if input_context:
      num_workers = input_context.num_input_pipelines
    else:
      num_workers = len(input_workers.worker_devices)
    rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
                                       num_replicas_in_sync)

  self._cloned_datasets = []

  if input_context:
    # Between-graph where we rely on the input_context for sharding.
    assert input_workers.num_workers == 1
    if rebatch_fn is not None:
      dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
    dataset = input_ops.auto_shard_dataset(
        dataset,
        input_context.num_input_pipelines,
        input_context.input_pipeline_id,
        num_replicas_in_sync)
    self._cloned_datasets.append(dataset)
    return

  replicas_by_worker = distribute.replicate(dataset,
                                            input_workers.worker_devices)
  for worker_index, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_dataset = replicas_by_worker[worker]
      if rebatch_fn is not None:
        worker_dataset = rebatch_fn(worker_dataset, worker_index)
      worker_dataset = input_ops.auto_shard_dataset(
          worker_dataset,
          len(input_workers.worker_devices),
          worker_index,
          num_replicas_in_sync)
      self._cloned_datasets.append(worker_dataset)

923 

924 def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): 

925 """Returns a callable that rebatches the input dataset. 

926 

927 Args: 

928 dataset: A `tf.data.Dataset` representing the dataset to be distributed. 

929 num_workers: An integer representing the number of workers to distribute 

930 `dataset` among. 

931 num_replicas_in_sync: An integer representing the number of replicas in 

932 sync across all workers. 

933 """ 

934 if num_replicas_in_sync % num_workers: 

935 raise ValueError( 

936 "tf.distribute expects every worker to have the same number of " 

937 "replicas. However, encountered `num_replicas_in_sync` ({}) that " 

938 "cannot be divided by `num_workers` ({})".format( 

939 num_replicas_in_sync, num_workers)) 

940 

941 num_replicas_per_worker = num_replicas_in_sync // num_workers 

942 with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access 

943 batch_size = distribute.compute_batch_size(dataset) 

944 

945 def rebatch_fn(dataset, worker_index): 

946 try: 

947 

948 def apply_rebatch(): 

949 batch_sizes = distribute.batch_sizes_for_worker( 

950 batch_size, num_workers, num_replicas_per_worker, worker_index) 

951 return dataset.rebatch(batch_sizes).prefetch(num_replicas_per_worker) 

952 

953 # pylint: disable=protected-access 

954 def apply_legacy_rebatch(): 

955 return distribute._LegacyRebatchDataset( 

956 dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker) 

957 

958 with ops.colocate_with(dataset._variant_tensor): 

959 return tf_cond.cond( 

960 math_ops.not_equal(batch_size, -1), 

961 true_fn=apply_rebatch, 

962 false_fn=apply_legacy_rebatch) 

963 except errors.InvalidArgumentError as e: 

964 if "without encountering a batch" in str(e): 

965 six.reraise( 

966 ValueError, 

967 ValueError( 

968 "Call the `batch` method on the input Dataset in order to be " 

969 "able to split your input across {} replicas.\n Please see " 

970 "the tf.distribute.Strategy guide. {}".format( 

971 num_replicas_in_sync, e)), 

972 sys.exc_info()[2]) 

973 else: 

974 raise 

975 

976 return rebatch_fn 

977 

def __iter__(self):
  """Creates a `DistributedIterator` over the cloned per-worker datasets.

  Returns:
    A `DistributedIterator` whose `_element_spec` is set to this dataset's.

  Raises:
    RuntimeError: if called neither eagerly nor while building a function.
    ValueError: if the dataset has not been built.
  """
  supported_context = (
      context.executing_eagerly() or
      ops.get_default_graph().building_function)
  if not supported_context:
    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")
  if not self._built:
    raise ValueError("To use this dataset, you need to pass this dataset to "
                     "ClusterCoordinator.create_per_worker_dataset.")

  # Strategies that do not define `_canonicalize_devices` default to True.
  canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
                                 True)

  per_worker_iterators = _create_iterators_per_worker(
      self._cloned_datasets,
      self._input_workers,
      options=self._options,
      canonicalize_devices=canonicalize_devices)
  distributed_iterator = DistributedIterator(
      self._input_workers,
      per_worker_iterators,
      self._strategy,
      cardinality=self._cardinality,
      enable_get_next_as_optional=self._enable_get_next_as_optional,
      options=self._options,
      replica_order=self._replica_order,
  )
  distributed_iterator._element_spec = self._element_spec  # pylint: disable=protected-access

  # When async eager is enabled, sometimes the iterator may not finish
  # initialization before passing to a multi device function, add a sync point
  # here to make sure all underlying iterators are initialized.
  if context.executing_eagerly():
    context.async_wait()

  return distributed_iterator

1013 

@property
def element_spec(self):
  """The type specification of an element of this dataset."""
  # Without partial batch handling we just follow the element_spec of the
  # underlying dataset (whose batch dimension may also be None).
  if not self._enable_get_next_as_optional:
    return self._element_spec
  if not self._strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
    return self._element_spec
  # With partial batch handling enabled we could always produce empty
  # batches, so the batch dimension is always set to None.
  return nest.map_structure(
      _rebatch_as_dynamic, self._element_spec, expand_composites=False)

1026 

@property
def _type_spec(self):
  """A `DistributedDatasetSpec` built from this dataset's attributes."""
  spec = DistributedDatasetSpec(
      self._input_workers,
      self._element_spec,
      self._strategy,
      self._options,
      enable_get_next_as_optional=self._enable_get_next_as_optional,
  )
  return spec

1035 

1036 

class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec):
  """Type specification for `DistributedDatasetsFromFunction`."""

  @property
  def value_type(self):
    """The Python type this spec describes."""
    return DistributedDatasetsFromFunction

  @property
  def _component_specs(self):
    """One `DatasetSpec` per entry of the input workers' device pairs."""
    device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
    return [
        dataset_ops.DatasetSpec(
            nest.map_structure(
                functools.partial(_replace_per_replica_spec, i=pair_index),
                self._element_spec))
        for pair_index, _ in enumerate(device_pairs)
    ]

  def _to_components(self, value):
    """Returns the per-worker datasets held by `value`."""
    return value._datasets  # pylint: disable=protected-access

  def _from_components(self, components):
    """Rebuilds a `DistributedDatasetsFromFunction` from `components`."""
    rebuilt = DistributedDatasetsFromFunction(
        input_workers=self._input_workers,
        strategy=self._strategy,
        components=components,
        element_spec=self._element_spec,
        options=self._options)
    return rebuilt

  @staticmethod
  def from_value(value):
    """Creates the spec describing `value`."""
    # pylint: disable=protected-access
    spec = DistributedDatasetsFromFunctionSpec(
        input_workers=value._input_workers,
        element_spec=value._element_spec,
        strategy=value._strategy,
        options=value._options)
    return spec

1074 

1075 

1076# TODO(priyag): Add other replication modes. 

class DistributedDatasetsFromFunction(_IterableInput,
                                      composite_tensor.CompositeTensor):
  """Inputs created from dataset function."""

  def __init__(
      self,
      input_workers,
      strategy,
      input_contexts=None,
      dataset_fn=None,
      options=None,
      components=None,
      element_spec=None,
      build=True,
      replica_order=None,
  ):
    """Makes an iterable from datasets created by the given function.

    Args:
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
        Either dataset_fn or components should be passed to construct
        DistributedDatasetsFromFunction. Use this when constructing
        DistributedDataset using a function. Use components when constructing
        using DistributedDatasetsFromFunctionSpec.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
      components: datasets when DistributedDatasetsFromFunction is constructed
        from DistributedDatasetsFromFunctionSpec. Only one of dataset_fn or
        components should be passed.
      element_spec: element spec to store on this object. Required when
        `components` is passed.
      build: whether to build underlying datasets when this object is created.
        This is only useful for `ParameterServerStrategy` now.
      replica_order: the order of the replicas, which will be used to reorder
        the iterators to match the device order.

    Raises:
      ValueError: if both or neither of `dataset_fn` and `components` are set,
        if the number of input workers differs from the number of
        `input_contexts`, or if `components` is passed without `element_spec`
        or with `build=False`.
    """
    super().__init__(input_workers=input_workers)
    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options
    self._replica_order = replica_order

    has_dataset_fn = dataset_fn is not None
    has_components = components is not None
    if has_dataset_fn and has_components:
      raise ValueError("Only one of dataset_fn or components should be set")
    if not (has_dataset_fn or has_components):
      raise ValueError("At least one of dataset_fn or components should be set")

    if not has_dataset_fn:
      # Construction from already existing per-worker datasets.
      if element_spec is None:
        raise ValueError(
            "element_spec should also be passed when passing components")
      if not build:
        raise ValueError(
            "When constructing DistributedDatasetFromFunction with components, "
            "build should not be False. This is an internal error. Please file "
            "a bug.")
      self._element_spec = element_spec
      self._datasets = components
      self._num_replicas_in_sync = None
      self._built = True
      self._cardinality = _cardinality(self._datasets[0])
      self._enable_get_next_as_optional = _enable_get_next_as_optional(
          self._strategy, self._datasets[0], self._cardinality)
      return

    # Construction from a dataset function, optionally deferred to `build()`.
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))
    self._input_contexts = input_contexts
    self._num_replicas_in_sync = self._input_contexts[0].num_replicas_in_sync
    self._dataset_fn = dataset_fn
    self._built = False
    if build:
      self.build()

  def build(self):
    """Creates the per-worker datasets from `dataset_fn` and marks as built."""
    assert not self._built
    start_ns = time.time_ns()
    datasets, fn_element_spec = (
        _create_datasets_from_function_with_input_context(
            self._input_contexts, self._input_workers, self._dataset_fn))
    self._datasets = datasets
    if context.executing_eagerly():
      # Record how long it took to initialize the distributed dataset. The
      # metric cell is keyed by strategy class name and number of workers.
      context.async_wait()
      elapsed_ms = (time.time_ns() - start_ns) // 1_000_000
      metric = (
          _distributed_dataset_from_function_initialization_time_milliseconds)
      metric.get_cell(
          self._strategy.__class__.__name__,
          str(self._input_workers.num_workers)).add(elapsed_ms)

    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, fn_element_spec)
    self._cardinality = _cardinality(self._datasets[0])
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, self._datasets[0], self._cardinality)
    self._built = True

  def auto_shard(self, num_shards, shard_ix):
    """Returns a new instance whose per-worker datasets are sharded.

    Args:
      num_shards: forwarded to `input_ops.auto_shard_dataset`.
      shard_ix: forwarded to `input_ops.auto_shard_dataset`.

    Returns:
      A `DistributedDatasetsFromFunction` built from the sharded components.
    """
    num_datasets = len(self._datasets)
    num_input_workers = len(self._input_workers.worker_devices)
    assert num_datasets == num_input_workers, (
        f"datasets: {num_datasets}, "
        f"input workers: {num_input_workers}"
    )

    sharded_datasets = []
    for worker_index in range(num_input_workers):
      worker_dataset = self._datasets[worker_index]
      with ops.colocate_with(worker_dataset._variant_tensor):  # pylint: disable=protected-access
        sharded_datasets.append(
            input_ops.auto_shard_dataset(
                worker_dataset, num_shards, shard_ix,
                self._num_replicas_in_sync))

    # NOTE(review): `self._replica_order` is not forwarded to the sharded
    # copy; confirm that dropping it here is intended.
    return DistributedDatasetsFromFunction(
        self._input_workers,
        self._strategy,
        components=sharded_datasets,
        element_spec=self._element_spec,
        options=self._options)

  @property
  def cardinality(self):
    """The cardinality stored for this dataset.

    Raises:
      ValueError: if the dataset has not been built yet.
    """
    if self._built:
      return self._cardinality
    raise ValueError(
        "Cannot get the cardinality of a dataset that is not built")

  def __iter__(self):
    """Creates a `DistributedIterator` over the per-worker datasets.

    Raises:
      RuntimeError: if called neither eagerly (outside functions) nor while
        building a function.
      ValueError: if the dataset has not been built.
    """
    supported_context = (
        ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function)
    if not supported_context:
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    if not self._built:
      raise ValueError("You need to use this dataset in "
                       "ClusterCoordinator.create_per_worker_dataset.")

    # Strategies that do not define `_canonicalize_devices` default to True.
    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
                                   True)

    per_worker_iterators = _create_iterators_per_worker(
        self._datasets,
        self._input_workers,
        options=self._options,
        canonicalize_devices=canonicalize_devices)
    distributed_iterator = DistributedIterator(
        input_workers=self._input_workers,
        iterators=per_worker_iterators,
        strategy=self._strategy,
        cardinality=self._cardinality,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options,
        replica_order=self._replica_order,
    )
    distributed_iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync
    # point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return distributed_iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # Without partial batch handling we just follow the element_spec of the
    # underlying dataset (whose batch dimension may also be None).
    if not self._enable_get_next_as_optional:
      return self._element_spec
    if not self._strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
      return self._element_spec
    # With partial batch handling enabled we could always produce empty
    # batches, so the batch dimension is always set to None.
    return nest.map_structure(
        _rebatch_as_dynamic, self._element_spec, expand_composites=False)

  @property
  def _type_spec(self):
    """A `DistributedDatasetsFromFunctionSpec` describing this object."""
    spec = DistributedDatasetsFromFunctionSpec(
        self._input_workers,
        self._element_spec,
        self._strategy,
        self._options)
    return spec

1266 

1267 

def _dummy_tensor_fn(value_structure):
  """Creates dummy tensors for every spec in `value_structure`.

  Args:
    value_structure: a nested structure of type specs.

  Returns:
    The same nested structure with each spec replaced by a dummy value.
  """

  def create_dummy_tensor(spec):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    if hasattr(spec, "_create_empty_value"):
      # Type spec may overwrite default dummy values behavior by declaring the
      # `_create_empty_value(self)` method. This method must return a value
      # compatible with the type spec with batch dimensions set to 0 or fail if
      # such a value does not exist. This allows a composite tensor to customize
      # dummy values creation as, in general, its dummy value is not composed
      # from dummy components (e.g. `row_splits` tensor of a RaggedTensor is
      # never allowed to be empty). See b/183969859 for more discussions.
      # TODO(b/186079336): reconsider CompositeTensor support.
      return spec._create_empty_value()  # pylint: disable=protected-access

    is_ragged = isinstance(spec, ragged_tensor.RaggedTensorSpec)
    if is_ragged:
      # Splice out the ragged dimensions: keep the first dimension and
      # everything after the ragged ones.
      # pylint: disable=protected-access
      leading_dims = spec._shape[:1]
      trailing_dims = spec._shape[(1 + spec._ragged_rank):]
      feature_shape = leading_dims.concatenate(trailing_dims)
      feature_type = spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = spec.shape
      feature_type = spec.dtype

    # Ideally we should set the batch dimension to 0, however as in
    # DistributionStrategy we don't know the batch dimension, we try to
    # guess it as much as possible. If the feature has unknown dimensions, we
    # will set them to 0. If the feature shape is already static, we guess the
    # first dimension as batch dimension and set it to 0.
    if feature_shape:
      dims = [0 if dim is None else dim for dim in feature_shape.as_list()]
    else:
      dims = []
    if dims and (is_ragged or feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(spec, sparse_tensor.SparseTensorSpec):
      # An empty sparse tensor: no values, no indices, shape given by `dims`.
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if not is_ragged:
      return dummy_tensor

    # Reinsert the ragged dimensions with size 0.
    # pylint: disable=protected-access
    row_splits = array_ops.zeros(1, spec._row_splits_dtype)
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
    # pylint: enable=protected-access

  return nest.map_structure(create_dummy_tensor, value_structure)

1323 

1324 

1325def _get_value_or_dummy(input_workers, optional_list, produce_dummy): 

1326 """Returns the value of the optionals or dummy values. 

1327 

1328 Args: 

1329 input_workers: the `InputWorkers`. 

1330 optional_list: a list of lists `tf.experimental.Optional`. The values from 

1331 each compute device grouped by the input device. 

1332 produce_dummy: a bool. Whether to produce dummy tensors when the optional 

1333 doesn't have a value. 

1334 

1335 Returns: 

1336 A flatten list of Tensors. 

1337 

1338 """ 

1339 value_list = [] 

1340 for i, worker in enumerate(input_workers.worker_devices): 

1341 with ops.device(worker): 

1342 devices = input_workers.compute_devices_for_worker(i) 

1343 for j, device in enumerate(devices): 

1344 with ops.device(device): 

1345 if produce_dummy: 

1346 # pylint: disable=cell-var-from-loop 

1347 value_list.append( 

1348 tf_cond.cond( 

1349 optional_list[i][j].has_value(), 

1350 lambda: optional_list[i][j].get_value(), # pylint: disable=unnecessary-lambda 

1351 lambda: _dummy_tensor_fn(optional_list[i][j].element_spec), 

1352 strict=True, 

1353 )) 

1354 # pylint: enable=cell-var-from-loop 

1355 else: 

1356 value_list.append(optional_list[i][j].get_value()) 

1357 return value_list 

1358 

1359 

1360class _SingleWorkerDatasetIteratorBase(object): 

1361 """Iterator for a single `tf.data.Dataset`.""" 

1362 

1363 def __init__(self, dataset, worker, devices, options=None): 

1364 """Create iterator for the `dataset` to fetch data to worker's `devices` . 

1365 

1366 A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch 

1367 input to the devices on the given worker. 

1368 

1369 Args: 

1370 dataset: A `tf.data.Dataset` instance. 

1371 worker: Worker on which ops should be created. 

1372 devices: Distribute data from `dataset` to these devices. 

1373 options: options. 

1374 """ 

1375 self._dataset = dataset 

1376 self._worker = worker 

1377 self._devices = devices 

1378 self._element_spec = dataset.element_spec 

1379 self._options = options 

1380 self._make_iterator() 

1381 

1382 def _make_iterator(self): 

1383 raise NotImplementedError("must be implemented in descendants") 

1384 

1385 def _format_data_list_with_options(self, data_list): 

1386 """Change the data in to a list type if required. 

1387 

1388 The OwnedMultiDeviceIterator returns the list data type, 

1389 while the PER_REPLICA iterator (when used with prefetch disabled) 

1390 returns without the enclosed list. This is to fix the inconsistency. 

1391 Args: 

1392 data_list: data_list 

1393 Returns: 

1394 list 

1395 """ 

1396 if (self._options and self._options.experimental_replication_mode == 

1397 InputReplicationMode.PER_REPLICA and 

1398 not self._options.experimental_fetch_to_device): 

1399 return [data_list] 

1400 else: 

1401 return data_list 

1402 

1403 def get_next(self, device, name=None): 

1404 """Get next element for the given device.""" 

1405 del name 

1406 with ops.device(self._worker): 

1407 if _should_use_multi_device_iterator(self._options): 

1408 return self._iterator.get_next(device) 

1409 else: 

1410 return self._iterator.get_next() 

1411 

1412 def get_next_as_list(self, name=None): 

1413 """Get next element from the underlying iterator. 

1414 

1415 Runs the iterator get_next() within a device scope. Since this doesn't use 

1416 get_next_as_optional(), it is considerably faster than get_next_as_list(), 

1417 but it raises EOFError if any of the device doesn't get any data. 

1418 

1419 Args: 

1420 name: not used. 

1421 

1422 Returns: 

1423 A list consisting of the next data from each device. 

1424 """ 

1425 del name 

1426 with ops.device(self._worker): 

1427 return self._format_data_list_with_options(self._iterator.get_next()) 

1428 

1429 def get_next_as_optional_list(self): 

1430 with ops.device(self._worker): 

1431 return self._format_data_list_with_options( 

1432 self._iterator.get_next_as_optional()) 

1433 

1434 

class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

  __slots__ = [
      "_worker", "_devices", "_element_spec", "_options",
      "_canonicalize_devices"
  ]

  def __init__(self, worker, devices, element_spec, options,
               canonicalize_devices=True):
    """Stores the iterator's worker, devices, element spec and options.

    Args:
      worker: the worker device of the iterator.
      devices: the devices of the iterator; stored as a tuple of
        canonicalized device strings.
      element_spec: the element spec of the iterator.
      options: input options; `None` is replaced by a default
        `distribute_lib.InputOptions()`.
      canonicalize_devices: if True devices are canonicalized with
        `device_util.canonicalize`, otherwise with
        `device_util.canonicalize_without_job_and_task`.
    """
    self._worker = worker
    if canonicalize_devices:
      canonicalize = device_util.canonicalize
    else:
      canonicalize = device_util.canonicalize_without_job_and_task
    self._devices = tuple(canonicalize(d) for d in devices)
    self._element_spec = element_spec
    # `self._options` intentionally made not `None` for proper serialization.
    if options is None:
      options = distribute_lib.InputOptions()
    self._options = options
    self._canonicalize_devices = canonicalize_devices

  @property
  def value_type(self):
    """The Python type this spec describes."""
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    """Returns the constructor arguments of this spec as a tuple."""
    return (
        self._worker,
        self._devices,
        self._element_spec,
        self._options,
        self._canonicalize_devices,
    )

  def _get_multi_device_iterator_spec(self, specs):
    """Appends a `MultiDeviceIteratorSpec` for this iterator to `specs`."""
    device_scope = device_util.canonicalize(self._worker, device_util.current())
    # source_device while creating iterator governs the worker device in
    # iterator spec.
    worker = device_util.get_host_for_device(device_scope)
    specs.append(
        multi_device_iterator_ops.MultiDeviceIteratorSpec(
            self._devices, worker, element_spec=self._element_spec))

  @property
  def _component_specs(self):
    """A single-element list with the spec of the underlying iterator."""
    if _should_use_multi_device_iterator(self._options):
      specs = []
      self._get_multi_device_iterator_spec(specs)
      return specs
    return [iterator_ops.IteratorSpec(element_spec=self._element_spec)]

  def _to_components(self, value):
    """Returns the underlying iterator of `value` wrapped in a list."""
    return [value._iterator]  # pylint: disable=protected-access

  def _from_components(self, components):
    """Rebuilds a `_SingleWorkerOwnedDatasetIterator` from `components`."""
    rebuilt = _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec,
        options=self._options,
        canonicalize_devices=self._canonicalize_devices)
    return rebuilt

  @staticmethod
  def from_value(value):
    """Creates the spec describing `value`."""
    # pylint: disable=protected-access
    spec = _SingleWorkerDatasetIteratorSpec(
        value._worker,
        value._devices,
        value._element_spec,
        value._options,
        value._canonicalize_devices)
    return spec

1503 

1504 

1505class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, 

1506 composite_tensor.CompositeTensor): 

1507 """Iterator for a DistributedDataset instance.""" 

1508 

1509 def __init__(self, 

1510 dataset=None, 

1511 worker=None, 

1512 devices=None, 

1513 components=None, 

1514 element_spec=None, 

1515 options=None, 

1516 canonicalize_devices=None): 

1517 """Create iterator for the `dataset` to fetch data to worker's `devices` . 

1518 

1519 `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the 

1520 given worker. The lifetime of this iterator is tied to the encompassing 

1521 python object. Once we go out of scope of the python object or return from 

1522 a tf.function the underlying iterator resource is deleted. 

1523 

1524 Args: 

1525 dataset: A `tf.data.Dataset` instance. 

1526 worker: Worker on which ops should be created. 

1527 devices: Distribute data from `dataset` to these devices. 

1528 components: Tensor components to construct the 

1529 _SingleWorkerOwnedDatasetIterator from. 

1530 element_spec: A nested structure of `TypeSpec` objects that represents the 

1531 type specification of elements of the iterator. 

1532 options: `tf.distribute.InputOptions` used to control options on how this 

1533 dataset is distributed. 

1534 canonicalize_devices: Whether to canonicalize devices for workers fully or 

1535 partially. If False, it will partially canonicalize devices by removing 

1536 job and task. 

1537 """ 

1538 if worker is None or devices is None: 

1539 raise ValueError("Both `worker` and `devices` should be provided") 

1540 

1541 error_message = ("Either `dataset` or both `components` and `element_spec` " 

1542 "need to be provided.") 

1543 

1544 self._options = options 

1545 self._canonicalize_devices = canonicalize_devices 

1546 if dataset is None: 

1547 if (components is None or element_spec is None): 

1548 raise ValueError(error_message) 

1549 self._element_spec = element_spec 

1550 self._worker = worker 

1551 self._devices = devices 

1552 self._iterator = components[0] 

1553 else: 

1554 if (components is not None or element_spec is not None): 

1555 raise ValueError(error_message) 

1556 super(_SingleWorkerOwnedDatasetIterator, 

1557 self).__init__(dataset, worker, devices, self._options) 

1558 

def _create_owned_multi_device_iterator(self):
  """Creates `self._iterator` as an `OwnedMultiDeviceIterator`.

  When options are set, `experimental_per_replica_buffer_size` is used for
  both `max_buffer_size` and `prefetch_buffer_size`.
  """
  if ops.inside_function():
    # In general, iterators should not be created within tf.functions. For
    # exact visitation guarantee solutions for parameter server training,
    # however, we do create iterators within the tf.functions that are
    # dispatched to workers. In these cases, the traced device must match the
    # runtime device. Since tracing occurs on the chief, we do not want to use
    # the current device scope, which would be the chief, but rather use the
    # relative worker device scope explicitly.
    device_scope = self._worker
    host_device = self._worker
  else:
    # If the worker devices are already canonicalized, canonicalizing again
    # would have no impact.
    # For strategies running on remote workers such as PS Strategy, the device
    # scope will be derived from current worker, if used under init_scope().
    device_scope = device_util.canonicalize(self._worker,
                                            device_util.current())
    host_device = device_util.get_host_for_device(device_scope)

  iterator_kwargs = {"source_device": host_device}
  if self._options is not None:
    buffer_size = self._options.experimental_per_replica_buffer_size
    iterator_kwargs["max_buffer_size"] = buffer_size
    iterator_kwargs["prefetch_buffer_size"] = buffer_size

  with ops.device(device_scope):
    self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
        self._dataset, self._devices, **iterator_kwargs)

1590 

1591 def _make_iterator(self): 

1592 """Make appropriate iterator on the dataset.""" 

1593 if not self._worker: 

1594 raise ValueError("Worker device must be specified when creating an " 

1595 "owned iterator.") 

1596 if _should_use_multi_device_iterator(self._options): 

1597 self._create_owned_multi_device_iterator() 

1598 else: 

1599 with ops.device(self._worker): 

1600 self._iterator = iter(self._dataset) 

1601 

  @property
  def element_spec(self):
    """Returns the element spec stored on this iterator (`_element_spec`)."""
    return self._element_spec

1605 

1606 @property 

1607 def _type_spec(self): 

1608 return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices, 

1609 self._element_spec, self._options, 

1610 self._canonicalize_devices) 

1611 

1612 @property 

1613 def output_classes(self): 

1614 """Returns the class of each component of an element of this iterator. 

1615 

1616 The expected values are `tf.Tensor` and `tf.SparseTensor`. 

1617 

1618 Returns: 

1619 A nested structure of Python `type` objects corresponding to each 

1620 component of an element of this dataset. 

1621 """ 

1622 return nest.map_structure( 

1623 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 

1624 self._element_spec) 

1625 

1626 @property 

1627 def output_shapes(self): 

1628 """Returns the shape of each component of an element of this iterator. 

1629 

1630 Returns: 

1631 A nested structure of `tf.TensorShape` objects corresponding to each 

1632 component of an element of this dataset. 

1633 """ 

1634 return nest.map_structure( 

1635 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 

1636 self._element_spec) 

1637 

1638 @property 

1639 def output_types(self): 

1640 """Returns the type of each component of an element of this iterator. 

1641 

1642 Returns: 

1643 A nested structure of `tf.DType` objects corresponding to each component 

1644 of an element of this dataset. 

1645 """ 

1646 return nest.map_structure( 

1647 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 

1648 self._element_spec) 

1649 

1650 

def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 options=None,
                                 canonicalize_devices=False):
  """Creates one `_SingleWorkerOwnedDatasetIterator` per input worker.

  Args:
    worker_datasets: Indexable collection of datasets; entry `i` feeds the
      `i`th worker in `input_workers.worker_devices`.
    input_workers: An `InputWorkers` instance.
    options: Forwarded unchanged to each iterator.
    canonicalize_devices: Forwarded unchanged to each iterator.

  Returns:
    A list of iterators, in the order of `input_workers.worker_devices`.
  """
  assert isinstance(input_workers, InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)

  per_worker_iterators = []
  for index, worker in enumerate(input_workers.worker_devices):
    # Each iterator is constructed under its own worker's device scope.
    with ops.device(worker):
      compute_devices = input_workers.compute_devices_for_worker(index)
      per_worker_iterators.append(
          _SingleWorkerOwnedDatasetIterator(
              dataset=worker_datasets[index],
              worker=worker,
              devices=compute_devices,
              options=options,
              canonicalize_devices=canonicalize_devices))
  return per_worker_iterators

1670 

1671 

def _create_datasets_from_function_with_input_context(input_contexts,
                                                      input_workers,
                                                      dataset_fn):
  """Create device datasets per worker given a dataset function.

  Args:
    input_contexts: Iterable of context objects. The `i`th context is passed
      to `dataset_fn` under the device scope of
      `input_workers.worker_devices[i]`.
    input_workers: Object exposing `worker_devices`, indexable by position.
    dataset_fn: Callable that takes one context and returns a dataset.

  Returns:
    A tuple `(datasets, element_spec)`: the list of datasets in input order,
    and the `element_spec` of the last dataset created.

  Raises:
    ValueError: If `input_contexts` yields no contexts, so there is no dataset
      to take the element spec from.
  """
  datasets = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      datasets.append(dataset_fn(ctx))

  # Without this check an empty input would surface as an UnboundLocalError
  # on the return statement instead of a meaningful error.
  if not datasets:
    raise ValueError(
        "`input_contexts` must contain at least one input context in order "
        "to create per-worker datasets.")

  return datasets, datasets[-1].element_spec

1683 

1684 

1685# TODO(sourabhbajaj): Remove this in lieu of distributed datasets 

def _get_batched_dataset(d):
  """Walks back from `d` to the dataset produced by the batching operation.

  A `DatasetV1Adapter` is unwrapped to the dataset it holds, and prefetch and
  options datasets are skipped by following their input dataset, until a
  `BatchDataset` or `_MapAndBatchDataset` is found.

  Args:
    d: The dataset to start from.

  Returns:
    The first `BatchDataset` or `_MapAndBatchDataset` found on that walk.

  Raises:
    ValueError: If a dataset of any other kind is met before a batched one.
  """
  # pylint: disable=protected-access
  current = d
  while True:
    if isinstance(current, dataset_ops.DatasetV1Adapter):
      current = current._dataset

    if isinstance(current,
                  (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
      return current

    if not isinstance(current, (dataset_ops.PrefetchDataset,
                                dataset_ops._OptionsDataset)):
      raise ValueError(
          "Unable to get batched dataset from the input dataset. `batch` "
          "`map_and_batch` need to be the last operations on the dataset. "
          "The batch operations can be followed by a prefetch.")

    # Prefetch/options wrappers are transparent: keep walking to their input.
    current = current._input_dataset

1702 

1703 

def _get_batched_dataset_attributes(d):
  """Reads `batch_size` and `drop_remainder` from a batched dataset.

  Args:
    d: A `BatchDataset` or `_MapAndBatchDataset`.

  Returns:
    A tuple `(batch_size, drop_remainder)`. A value that is a TF type is
    replaced by the result of `tensor_util.constant_value` on it.
  """

  def _static_value(value):
    # TF-typed values are converted through `constant_value`; anything else is
    # passed through untouched.
    if tensor_util.is_tf_type(value):
      return tensor_util.constant_value(value)
    return value

  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size, drop_remainder = d._batch_size, d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size, drop_remainder = d._batch_size_t, d._drop_remainder_t
  # pylint: enable=protected-access

  batch_size = _static_value(batch_size)
  drop_remainder = _static_value(drop_remainder)
  return batch_size, drop_remainder

1724 

1725 

1726# TODO(sourabhbajaj): Remove this in lieu of distributed datasets 

def _get_dataset_attributes(dataset):
  """Extracts batching and prefetching attributes from `dataset`.

  Args:
    dataset: The dataset to inspect.

  Returns:
    A tuple `(batch_size, drop_remainder, prefetch_buffer)`. `prefetch_buffer`
    is None unless `dataset` itself, or the dataset wrapped by a
    `DatasetV1Adapter`, is a `PrefetchDataset`.
  """
  # pylint: disable=protected-access

  # The batch attributes live on the batched dataset, which has to be found by
  # walking back through the dataset creation process.
  batch_size, drop_remainder = _get_batched_dataset_attributes(
      _get_batched_dataset(dataset))

  # The prefetch buffer, in contrast, is read from the original dataset.
  prefetch_source = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_source = dataset
  elif isinstance(dataset, dataset_ops.DatasetV1Adapter):
    wrapped = dataset._dataset
    if isinstance(wrapped, dataset_ops.PrefetchDataset):
      prefetch_source = wrapped

  prefetch_buffer = (
      None if prefetch_source is None else prefetch_source._buffer_size)
  return batch_size, drop_remainder, prefetch_buffer

1746 

1747 

def _should_use_multi_device_iterator(options):
  """Decides whether `multi_device_iterator_ops` should be used.

  Args:
    options: Input options, or None.

  Returns:
    True when `options` is None, when its replication mode is `PER_WORKER`, or
    when the mode is `PER_REPLICA` and `experimental_fetch_to_device` is
    truthy. False otherwise.
  """
  if options is None:
    return True

  replication_mode = options.experimental_replication_mode
  if replication_mode == InputReplicationMode.PER_WORKER:
    return True

  return bool(replication_mode == InputReplicationMode.PER_REPLICA and
              options.experimental_fetch_to_device)

1757 

1758 

class MultiStepContext(object):
  """Captures outputs while several steps are run at a time.

  This context object is meant for use with the
  `experimental_run_steps_on_iterator` API, where it lets the user's step
  function specify which outputs to emit. It currently supports capturing
  outputs from the last step and capturing non tensor outputs. In the future
  it may be extended to other use cases, such as emitting output every N
  steps.
  """

  def __init__(self):
    """Initializes a context with no captured outputs."""
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary of the outputs to be captured on the last step.

    Keys are the names given to `set_last_step_output`; values are the
    corresponding tensors. If `set_last_step_output` was called with a
    `reduce_op` for an output, the stored value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replaces the whole dictionary of last step outputs.

    Args:
      outputs: The new dictionary.

    Raises:
      ValueError: If `outputs` is not a `dict`.
    """
    if isinstance(outputs, dict):
      self._last_step_outputs = outputs
      return
    raise ValueError("Need a dictionary to set last_step_outputs.")

  def set_last_step_output(self, name, output, reduce_op=None):
    """Registers `output` under `name` as an output of the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce`
        method. For e.g. if using MirroredStrategy and reduction is set,
        output must be a `PerReplica` value.
        The reduce method is also recorded in a dictionary
        `_last_step_outputs_reduce_ops` for later interpreting of the
        outputs as already reduced or not.
    """
    if not distribute_lib.in_cross_replica_context():
      assert reduce_op is not None

      def _reduce_on_merge(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
            reduce_op, value, axis=None)
        # Recorded inside the merge function because all replicas share the
        # same context object, so it's more robust to set it only once (even
        # if all the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribute_lib.get_replica_context().merge_call(
          _reduce_on_merge, args=(output,))
      return

    # Cross-replica context: record the reduce op, then store the output
    # either as is or reduced through the current strategy.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
      self._last_step_outputs[name] = output
    else:
      distribution = distribute_lib.get_strategy()
      self._last_step_outputs[name] = distribution.reduce(
          reduce_op, output, axis=None)

  @property
  def non_tensor_outputs(self):
    """A dictionary of the non tensor outputs that have been captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Registers `output` under `name` as a captured non tensor output.

    Args:
      name: Key under which the output is stored.
      output: The value to capture. In a replica context, what gets stored is
        the result of `experimental_local_results` on the merged value.
    """
    if distribute_lib.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
      return

    def _collect_on_merge(distribution, value):
      # NOTE(priyag): For non tensor outputs, we simply return all the values
      # in a list as reduction doesn't make sense on non tensors.
      self._non_tensor_outputs[name] = (
          distribution.experimental_local_results(value))

    distribute_lib.get_replica_context().merge_call(
        _collect_on_merge, args=(output,))

1859 

1860 

def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Builds the `tf.TypeSpec` that `strategy` produces for `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
  """
  num_replicas = len(strategy.extended.worker_devices)

  # For one device strategy that is not MultiWorkerMirroredStrategy, return the
  # tensor_spec as is, since we don't wrap the output with PerReplica in this
  # case.
  # TODO(b/166464552): remove after we always wrap for all strategies.
  if not _always_wrap(strategy):
    return tensor_spec

  # For other cases we assume the input to tf.function is a per replica type,
  # so every leaf spec is repeated once per replica.
  def _replicate(leaf_spec):
    return values.PerReplicaSpec(*([leaf_spec] * num_replicas))

  return nest.map_structure(_replicate, tensor_spec)

1888 

1889 

def _replace_per_replica_spec(spec, i):
  """Returns the `i`th value spec of a `PerReplicaSpec`, else `spec` itself."""
  if not isinstance(spec, values.PerReplicaSpec):
    return spec
  return spec._value_specs[i]  # pylint: disable=protected-access

1896 

1897 

def _cardinality(dataset):
  """Returns the cardinality of `dataset`, or UNKNOWN outside eager mode.

  When executing eagerly, the cardinality is evaluated under the device of the
  dataset's variant tensor and returned via `.numpy()`. Otherwise
  `cardinality_lib.UNKNOWN` is returned.
  """
  if not context.executing_eagerly():
    return cardinality_lib.UNKNOWN

  variant_device = dataset._variant_tensor.device  # pylint: disable=protected-access
  with ops.device(variant_device):
    return dataset.cardinality().numpy()

1904 

1905 

def _enable_get_next_as_optional(strategy, dataset, cardinality):
  """Returns whether partial batch handling should be enabled.

  Args:
    strategy: The strategy whose `extended` object carries the enabling flag.
    dataset: The dataset whose `element_spec` is checked for static shapes.
    cardinality: The cardinality of the dataset, compared against
      `cardinality_lib.INFINITE`.

  Returns:
    False if the strategy's flag is off or the dataset is infinite. Otherwise
    True if the element spec is not statically shaped, else the result of the
    strategy's `_in_multi_worker_mode()`.
  """
  # TODO(b/133073708): we currently need a flag to control the usage because
  # there is a performance difference between get_next() and
  # get_next_as_optional(). And we only enable get_next_as_optional when the
  # output shapes are not static.
  #
  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
  # when user passed input_fn instead of dataset.
  extended = strategy.extended
  # `enable_partial_batch_handling` takes precedence; the older
  # `experimental_enable_get_next_as_optional` is only the fallback value.
  fallback_flag = getattr(extended, "experimental_enable_get_next_as_optional",
                          False)
  if not getattr(extended, "enable_partial_batch_handling", fallback_flag):
    return False

  # If the dataset is infinite, we don't need to enable last partial batch
  # support. Note that we can only evaluate the cardinality of the dataset in
  # eager.
  if cardinality == cardinality_lib.INFINITE:
    return False

  if not _is_statically_shaped(dataset.element_spec):
    return True
  return strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access

1929 

1930 

def _create_per_replica(value_list, strategy):
  """Regroups one value per replica into a structure of PerReplica.

  This delegates to `distribute_utils.regroup()`, asking it to always wrap
  the values when `_always_wrap(strategy)` is true.

  NOTE: the stated motivation for wrapping is to avoid retracing for partial
  batches. Retracing is problematic for multi client when different clients
  retrace at different times, since retracing changes the collective keys in
  the tf.function and causes mismatches among clients.

  Args:
    value_list: a list of values, one for each replica.
    strategy: the `tf.distribute.Strategy`.

  Returns:
    a structure of PerReplica.
  """
  # TODO(b/166464552): always wrap for all one device strategies as well.
  return distribute_utils.regroup(
      value_list, always_wrap=_always_wrap(strategy))

1954 

1955 

def _always_wrap(strategy):
  """Returns whether to always wrap the values in a DistributedValues.

  That is the case when the strategy is in multi worker mode or has more than
  one worker device.
  """
  extended = strategy.extended
  in_multi_worker_mode = extended._in_multi_worker_mode()  # pylint: disable=protected-access
  return in_multi_worker_mode or len(extended.worker_devices) > 1

1960 

1961 

def _rebatch_as_dynamic(per_replica_spec):
  """Returns a `PerReplicaSpec` whose value specs have a dynamic batch size.

  Each value spec is unbatched and then batched again with a batch size of
  None. A spec for which that raises `ValueError` is kept unchanged.

  Args:
    per_replica_spec: A `values.PerReplicaSpec`.

  Returns:
    A new `values.PerReplicaSpec` built from the rebatched value specs.
  """
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _to_dynamic_batch(spec):
    # Rebatch if possible; otherwise fall back to the spec as given.
    try:
      return spec._unbatch()._batch(None)
    except ValueError:
      return spec

  rebatched_specs = nest.map_structure(_to_dynamic_batch,
                                       per_replica_spec._value_specs)
  return values.PerReplicaSpec(*rebatched_specs)
  # pylint: enable=protected-access

1978 

1979 

def _ag_enumerate_not_implemented(s, unused_start):
  """Rejects `enumerate` on `s` by raising `NotImplementedError`.

  Args:
    s: The object `enumerate` was called on; only its class name is used, in
      the error message.
    unused_start: Ignored.

  Raises:
    NotImplementedError: Always.
  """
  type_name = s.__class__.__name__
  raise NotImplementedError(
      f"enumerate not supported with {type_name} types within "
      "tf.functions. Use a for loop over the dataset and keep a separate "
      "counter instead."
  )

1987 

1988 

# Register `_ag_enumerate_not_implemented`, which always raises
# NotImplementedError, as the `enumerate` handler for both distributed types.
# NOTE(review): presumably autograph consults this registry when it converts
# `enumerate(...)` calls inside tf.functions -- confirm against py_builtins.
py_builtins.enumerate_registry.register(
    DistributedIterator, _ag_enumerate_not_implemented
)
py_builtins.enumerate_registry.register(
    DistributedDataset, _ag_enumerate_not_implemented
)