Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/data_adapter.py: 24%

800 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Adapter module that convert different input data objects into tf.dataset.""" 

16 

17import abc 

18import contextlib 

19import functools 

20import itertools 

21import math 

22import random 

23 

24import numpy as np 

25import tensorflow.compat.v2 as tf 

26 

27from keras.src import backend 

28from keras.src.distribute import distributed_training_utils 

29from keras.src.engine import training_utils 

30from keras.src.utils import data_utils 

31from keras.src.utils import dataset_creator 

32from keras.src.utils import tf_utils 

33 

34# isort: off 

35from tensorflow.python.distribute.input_lib import ( 

36 DistributedDataset, 

37) 

38from tensorflow.python.eager import context 

39from tensorflow.python.framework import type_spec 

40from tensorflow.python.platform import tf_logging as logging 

41from tensorflow.python.util.tf_export import keras_export 

42from tensorflow.python.data.ops import ( 

43 from_sparse_tensor_slices_op, 

44) 

45from tensorflow.python.data.ops import from_generator_op 

46from tensorflow.python.data.ops import range_op 

47from tensorflow.python.data.ops import from_tensors_op 

48from tensorflow.python.data.ops import from_tensor_slices_op 

49 

50try: 

51 import pandas as pd 

52except ImportError: 

53 pd = None 

54 

55keras_data_adapter_gauge = tf.__internal__.monitoring.BoolGauge( 

56 "/tensorflow/api/keras/data_adapters", "keras data adapter usage", "method" 

57) 

58 

59 

60class DataAdapter(object, metaclass=abc.ABCMeta): 

61 """Base class for input data adapter. 

62 

63 In TF 2.0, tf.data is the preferred API for users to feed in data. In order 

64 to simplify the training code path, all input data objects will be 

65 converted to `tf.data.Dataset` if possible. 

66 

67 Note that since this class is mainly targeted at TF 2.0, it might make a 

68 lot of assumptions under the hood, e.g. eager context by default, 

69 distribution strategy, etc. In the meantime, some legacy feature support 

70 might be dropped, e.g. the Iterator from the v1 dataset API. 

71 

72 Sample usage of this class: 

73 

74 ``` 

75 x = tf.data.Dataset.range(100) 

76 adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter] 

77 applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)] 

78 if len(applicable_adapters) != 1: 

79 raise ValueError("Expect only one adapter class to handle the input") 

80 

81 dataset = applicable_adapters[0](x).get_dataset() 

82 for data in dataset: 

83 # training 

84 ``` 

85 """ 

86 

87 @staticmethod 

88 def can_handle(x, y=None): 

89 """Whether the current DataAdapter could handle the input x and y. 

90 

91 Structure-wise, x and y can be a single object, a list of objects if 

92 there are multiple inputs/outputs, or a dictionary of objects when the 

93 inputs/outputs are named. 

94 

95 Args: 

96 x: input features. 

97 y: target labels. Note that y could be None in the case of prediction. 

98 

99 Returns: 

100 boolean 

101 """ 

102 raise NotImplementedError 

103 

104 @abc.abstractmethod 

105 def __init__(self, x, y=None, **kwargs): 

106 """Create a DataAdapter based on data inputs. 

107 

108 The caller must make sure to call `can_handle()` first before invoking 

109 this method. Providing an unsupported data type will result in 

110 unexpected behavior. 

111 

112 Args: 

113 x: input features. 

114 y: target labels. Note that y could be None in the case of prediction. 

115 **kwargs: Other keyword arguments for DataAdapter during the 

116 construction of the tf.data.Dataset. For example: 

117 - Numpy data might have `sample_weights` which will be used for 

118 weighting the loss function during training. 

119 - Numpy data might need to have `batch_size` parameter when 

120 constructing the dataset and iterator. 

121 - Certain input might need to be distribution strategy aware. When 

122 `distribution_strategy` is passed, the created dataset needs to 

123 respect the strategy. 

124 DataAdapter might choose to ignore any keyword argument if it 

125 doesn't use it, or raise an exception if a required argument is not 

126 provided. 

127 """ 

128 if not self.can_handle(x, y): 

129 raise ValueError(f"{self.__class__} Cannot handle input {x}, {y}") 

130 

131 @abc.abstractmethod 

132 def get_dataset(self): 

133 """Get a dataset instance for the current DataAdapter. 

134 

135 Note that the returned dataset does not repeat across epochs, so the 

136 caller might need to create a new iterator for the same dataset at the 

137 beginning of each epoch. This behavior might change in the future. 

138 

139 Returns: 

140 A `tf.data.Dataset`. The caller might use the dataset in different 

141 contexts, e.g. `iter(dataset)` in eager mode to get values directly, or 

142 in graph mode, providing the iterator tensor to the Keras model function. 

143 """ 

144 raise NotImplementedError 

145 

146 @abc.abstractmethod 

147 def get_size(self): 

148 """Return the size (number of batches) for the dataset created. 

149 

150 For certain types of data input, the number of batches is known, e.g. 

151 for Numpy data the size is the same as (number_of_elements / batch_size), 

152 whereas for a dataset or python generator the size is unknown since it 

153 may or may not have an end state. 

154 

155 Returns: 

156 int, the number of batches for the dataset, or None if it is unknown. 

157 The caller could use this to control the training loop, show a 

158 progress bar, or handle an unexpected StopIteration error. 

159 """ 

160 raise NotImplementedError 

161 

162 @abc.abstractmethod 

163 def batch_size(self): 

164 """Return the batch size of the dataset created. 

165 

166 For certain types of data input, the batch size is known, and even 

167 required, e.g. a numpy array, whereas for a dataset the batch size is 

168 unknown unless we take a peek. 

169 

170 Returns: 

171 int, the batch size of the dataset, or None if it is unknown. 

172 """ 

173 raise NotImplementedError 

174 

175 def representative_batch_size(self): 

176 """Return a representative size for batches in the dataset. 

177 

178 This is not guaranteed to be the batch size for all batches in the 

179 dataset. It just needs to be a rough approximation for batch sizes in 

180 the dataset. 

181 

182 Returns: 

183 int, a representative size for batches found in the dataset, 

184 or None if it is unknown. 

185 """ 

186 return self.batch_size() 

187 

188 @abc.abstractmethod 

189 def has_partial_batch(self): 

190 """Whether the dataset has partial batch at the end.""" 

191 raise NotImplementedError 

192 

193 @abc.abstractmethod 

194 def partial_batch_size(self): 

195 """The size of the final partial batch for dataset. 

196 

197 Will return None if has_partial_batch is False or batch_size is None. 

198 """ 

199 raise NotImplementedError 

200 

201 @abc.abstractmethod 

202 def should_recreate_iterator(self): 

203 """Returns whether a new iterator should be created every epoch.""" 

204 raise NotImplementedError 

205 

206 def get_samples(self): 

207 """Returns number of samples in the data, or `None`.""" 

208 if not self.get_size() or not self.batch_size(): 

209 return None 

210 total_sample = self.get_size() * self.batch_size() 

211 if self.has_partial_batch(): 

212 total_sample -= self.batch_size() - self.partial_batch_size() 

213 return total_sample 

214 

215 def on_epoch_end(self): 

216 """A hook called after each epoch.""" 

217 pass 

218 

219 
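
# --- Editor's illustrative sketch (not part of the original module) ----------
# The `get_samples()` helper above derives the sample count from batch-level
# bookkeeping. A minimal stand-alone version of the same arithmetic, with
# hypothetical numbers:
def _sketch_sample_count(num_batches=4, batch_size=32, partial_batch=10):
    total = num_batches * batch_size       # 128, as if every batch were full
    total -= batch_size - partial_batch    # drop the missing tail: 128 - 22 = 106
    return total
# ------------------------------------------------------------------------------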

220class TensorLikeDataAdapter(DataAdapter): 

221 """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.""" 

222 

223 @staticmethod 

224 def can_handle(x, y=None): 

225 # TODO(kaftan): Check performance implications of using a flatten 

226 # here for other types of inputs. 

227 flat_inputs = tf.nest.flatten(x) 

228 if y is not None: 

229 flat_inputs += tf.nest.flatten(y) 

230 

231 tensor_types = _get_tensor_types() 

232 

233 def _is_tensor(v): 

234 if isinstance(v, tensor_types): 

235 return True 

236 return False 

237 

238 return all(_is_tensor(v) for v in flat_inputs) 

239 

240 def __init__( 

241 self, 

242 x, 

243 y=None, 

244 sample_weights=None, 

245 sample_weight_modes=None, 

246 batch_size=None, 

247 epochs=1, 

248 steps=None, 

249 shuffle=False, 

250 **kwargs, 

251 ): 

252 super().__init__(x, y, **kwargs) 

253 x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) 

254 sample_weight_modes = broadcast_sample_weight_modes( 

255 sample_weights, sample_weight_modes 

256 ) 

257 

258 # If sample_weights are not specified for an output use 1.0 as weights. 

259 (sample_weights, _, _) = training_utils.handle_partial_sample_weights( 

260 y, sample_weights, sample_weight_modes, check_all_flat=True 

261 ) 

262 

263 inputs = pack_x_y_sample_weight(x, y, sample_weights) 

264 

265 num_samples = set( 

266 int(i.shape[0]) for i in tf.nest.flatten(inputs) 

267 ).pop() 

268 _check_data_cardinality(inputs) 

269 

270 # If batch_size is not passed but steps is, calculate from the input 

271 # data. Defaults to `32` for backwards compatibility. 

272 if not batch_size: 

273 batch_size = int(math.ceil(num_samples / steps)) if steps else 32 

274 

275 self._size = int(math.ceil(num_samples / batch_size)) 

276 self._batch_size = batch_size 

277 

278 num_full_batches = int(num_samples // batch_size) 

279 self._partial_batch_size = num_samples % batch_size 

280 

281 if isinstance(shuffle, str): 

282 shuffle = shuffle.lower() 

283 

284 self._shuffle = shuffle 

285 # Vectorized version of shuffle. 

286 # This is a performance improvement over using `from_tensor_slices`. 

287 # The indices of the data are shuffled and batched, and these indices 

288 # are then zipped with the data and used to extract a batch of the data 

289 # at each step. The performance improvements here come from: 

290 # 1. vectorized batch using gather 

291 # 2. parallelized map 

292 # 3. pipelined permutation generation 

293 # 4. optimized permutation batching 

294 # 5. disabled static optimizations 

295 

296 indices_dataset = tf.data.Dataset.range(1) 

297 if shuffle != "batch": 

298 indices_dataset = indices_dataset.repeat(epochs) 

299 

300 def permutation(_): 

301 # It turns out to be more performant to make a new set of indices 

302 # rather than reusing the same range Tensor. (presumably because of 

303 # buffer forwarding.) 

304 indices = tf.range(num_samples, dtype=tf.int64) 

305 if shuffle and shuffle != "batch": 

306 indices = tf.random.shuffle(indices) 

307 return indices 

308 

309 # We prefetch a single element. Computing large permutations can take 

310 # quite a while so we don't want to wait for prefetching over an epoch 

311 # boundary to trigger the next permutation. On the other hand, too many 

312 # simultaneous shuffles can contend on a hardware level and degrade all 

313 # performance. 

314 indices_dataset = indices_dataset.map(permutation).prefetch(1) 

315 

316 def slice_batch_indices(indices): 

317 """Convert a Tensor of indices into a dataset of batched indices. 

318 

319 This step can be accomplished in several ways. The most natural is 

320 to slice the Tensor in a Dataset map. (With a condition on the upper 

321 index to handle the partial batch.) However it turns out that 

322 coercing the Tensor into a shape which is divisible by the batch 

323 size (and handling the last partial batch separately) allows for a 

324 much more favorable memory access pattern and improved performance. 

325 

326 Args: 

327 indices: Tensor which determines the data order for an entire 

328 epoch. 

329 

330 Returns: 

331 A Dataset of batched indices. 

332 """ 

333 num_in_full_batch = num_full_batches * batch_size 

334 first_k_indices = tf.slice(indices, [0], [num_in_full_batch]) 

335 first_k_indices = tf.reshape( 

336 first_k_indices, [num_full_batches, batch_size] 

337 ) 

338 

339 flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices) 

340 if self._partial_batch_size: 

341 index_remainder = tf.data.Dataset.from_tensors( 

342 tf.slice( 

343 indices, [num_in_full_batch], [self._partial_batch_size] 

344 ) 

345 ) 

346 flat_dataset = flat_dataset.concatenate(index_remainder) 

347 

348 if shuffle == "batch": 

349 # 1024 is a magic constant that has not been properly evaluated 

350 flat_dataset = flat_dataset.shuffle(1024).repeat(epochs) 

351 return flat_dataset 

352 

353 indices_dataset = indices_dataset.flat_map(slice_batch_indices) 

354 

355 dataset = self.slice_inputs(indices_dataset, inputs) 

356 

357 if shuffle == "batch": 

358 

359 def shuffle_batch(*batch): 

360 return tf.nest.map_structure(tf.random.shuffle, batch) 

361 

362 dataset = dataset.map(shuffle_batch) 

363 

364 options = tf.data.Options() 

365 options.experimental_distribute.auto_shard_policy = ( 

366 tf.data.experimental.AutoShardPolicy.DATA 

367 ) 

368 dataset = dataset.with_options(options) 

369 

370 self._dataset = dataset.prefetch(tf.data.AUTOTUNE) 

371 

372 def slice_inputs(self, indices_dataset, inputs): 

373 """Slice inputs into a Dataset of batches. 

374 

375 Given a Dataset of batch indices and the unsliced inputs, 

376 this step slices the inputs in a parallelized fashion 

377 and produces a dataset of input batches. 

378 

379 Args: 

380 indices_dataset: A Dataset of batched indices 

381 inputs: A python data structure that contains the inputs, targets, 

382 and possibly sample weights. 

383 

384 Returns: 

385 A Dataset of input batches matching the batch indices. 

386 """ 

387 dataset = tf.data.Dataset.zip( 

388 (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat()) 

389 ) 

390 

391 def grab_batch(i, data): 

392 return tf.nest.map_structure( 

393 lambda d: tf.gather(d, i, axis=0), data 

394 ) 

395 

396 dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE) 

397 

398 # Default optimizations are disabled to avoid the overhead of 

399 # (unnecessary) input pipeline graph serialization and deserialization 

400 options = tf.data.Options() 

401 options.experimental_optimization.apply_default_optimizations = False 

402 if self._shuffle: 

403 # See b/141490660 for more details. 

404 options.experimental_external_state_policy = ( 

405 tf.data.experimental.ExternalStatePolicy.IGNORE 

406 ) 

407 dataset = dataset.with_options(options) 

408 return dataset 

409 

410 def get_dataset(self): 

411 return self._dataset 

412 

413 def get_size(self): 

414 return self._size 

415 

416 def batch_size(self): 

417 return self._batch_size 

418 

419 def has_partial_batch(self): 

420 return self._partial_batch_size > 0 

421 

422 def partial_batch_size(self): 

423 return self._partial_batch_size or None 

424 

425 def should_recreate_iterator(self): 

426 # An infinite dataset is always created here. 

427 return False 

428 

429 
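
# --- Editor's illustrative sketch (not part of the original module) ----------
# The core of TensorLikeDataAdapter's vectorized shuffle above is "shuffle the
# indices, batch the indices, then gather". A hedged stand-alone sketch of that
# pattern (names are the editor's; it reuses the module's `tf` import and, for
# brevity, drops the partial batch that the adapter handles explicitly):
def _sketch_vectorized_batches(x, y, batch_size=32, shuffle=True):
    num_samples = int(x.shape[0])
    indices = tf.range(num_samples, dtype=tf.int64)
    if shuffle:
        indices = tf.random.shuffle(indices)
    num_full = num_samples // batch_size
    index_ds = tf.data.Dataset.from_tensor_slices(
        tf.reshape(indices[: num_full * batch_size], [num_full, batch_size])
    )
    data_ds = tf.data.Dataset.from_tensors((x, y)).repeat()
    # One `tf.gather` per batch of indices slices every input at once.
    return tf.data.Dataset.zip((index_ds, data_ds)).map(
        lambda i, data: tf.nest.map_structure(
            lambda t: tf.gather(t, i, axis=0), data
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

# Hypothetical usage: each element of
#   _sketch_vectorized_batches(np.zeros((100, 8)), np.zeros((100,)))
# is a ((32, 8), (32,)) pair; the 4 leftover samples are dropped in this sketch.
# ------------------------------------------------------------------------------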

430class GenericArrayLikeDataAdapter(TensorLikeDataAdapter): 

431 """Adapter that handles array-like data without forcing it into memory. 

432 

433 This adapter handles array-like datasets that may be too big to fully 

434 fit into memory. 

435 

436 Specifically, this adapter handles any Python class which implements: 

437 `__getitem__`, `__len__`, `shape`, and `dtype` with the same meanings 

438 as Numpy, but it ignores any case where all the inputs are Tensors or Numpy 

439 arrays (because that case is handled by the base TensorLikeDataAdapter). 

440 

441 It ignores scipy sparse matrices and Composite Tensors because those are 

442 handled by the CompositeTensorDataAdapter. 

443 

444 It also does not handle lists/tuples of scalars, because those are handled 

445 by the ListsOfScalarsDataAdapter. 

446 """ 

447 

448 @staticmethod 

449 def can_handle(x, y=None): 

450 flat_inputs = tf.nest.flatten(x) 

451 if y is not None: 

452 flat_inputs += tf.nest.flatten(y) 

453 

454 def _is_array_like(v): 

455 """Return True if v is a Tensor, array, or is array-like.""" 

456 return ( 

457 hasattr(v, "__getitem__") 

458 and hasattr(v, "shape") 

459 and hasattr(v, "dtype") 

460 and hasattr(v, "__len__") 

461 ) 

462 

463 if not TensorLikeDataAdapter.can_handle( 

464 x, y 

465 ) and not CompositeTensorDataAdapter.can_handle(x, y): 

466 return all(_is_array_like(v) for v in flat_inputs) 

467 else: 

468 return False 

469 

470 def __init__(self, *args, **kwargs): 

471 logging.warning( 

472 "Keras is training/fitting/evaluating on array-like data. Keras " 

473 "may not be optimized for this format, so if your input data " 

474 "format is supported by TensorFlow I/O " 

475 "(https://github.com/tensorflow/io) we recommend using that to " 

476 "load a Dataset instead." 

477 ) 

478 

479 super().__init__(*args, **kwargs) 

480 

481 def slice_inputs(self, indices_dataset, inputs): 

482 """Slice inputs into a Dataset of batches. 

483 

484 Given a Dataset of batch indices and the unsliced inputs, 

485 this step slices the inputs in a parallelized fashion 

486 and produces a dataset of input batches. 

487 

488 Args: 

489 indices_dataset: A Dataset of batched indices 

490 inputs: A python data structure that contains the inputs, targets, 

491 and possibly sample weights. 

492 

493 Returns: 

494 A Dataset of input batches matching the batch indices. 

495 """ 

496 flat_inputs = tf.nest.flatten(inputs) 

497 

498 def dynamic_shape_like(t): 

499 shape = list(t.shape) 

500 shape[0] = None 

501 return tuple(shape) 

502 

503 flat_dtypes = [inp.dtype for inp in flat_inputs] 

504 contiguous = True 

505 if self._shuffle and self._shuffle != "batch": 

506 contiguous = False 

507 

508 def grab_batch(indices): 

509 """Grab a batch of data from the inputs.""" 

510 # This uses a py_function to avoid converting the array-like 

511 # into a Tensor before slicing it, because converting the array-like 

512 # to a Tensor may force it into memory. 

513 def py_method(ind): 

514 def slice_array(data): 

515 return training_utils.slice_arrays( 

516 data, ind.numpy(), contiguous=contiguous 

517 ) 

518 

519 return [slice_array(inp) for inp in flat_inputs] 

520 

521 flat_out = tf.py_function(py_method, [indices], flat_dtypes) 

522 for v, original_inp in zip(flat_out, flat_inputs): 

523 v.set_shape(dynamic_shape_like(original_inp)) 

524 return tf.nest.pack_sequence_as(inputs, flat_out) 

525 

526 dataset = indices_dataset.map( 

527 grab_batch, num_parallel_calls=tf.data.AUTOTUNE 

528 ) 

529 

530 return dataset 

531 

532 
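
# --- Editor's illustrative sketch (not part of the original module) ----------
# A minimal object satisfying the "array-like" protocol that
# GenericArrayLikeDataAdapter checks for (`__getitem__`, `__len__`, `shape`,
# `dtype`) without being a Tensor or ndarray. Entirely hypothetical; a real
# implementation might read rows lazily from disk.
class _SketchLazyArray:
    def __init__(self, num_rows, num_cols):
        self.shape = (num_rows, num_cols)
        self.dtype = np.dtype("float32")

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, indices):
        # The adapter slices with arrays of row indices via a py_function;
        # zeros keep this sketch trivial instead of a real lazy read.
        indices = np.atleast_1d(indices)
        return np.zeros((len(indices), self.shape[1]), dtype=self.dtype)

# GenericArrayLikeDataAdapter.can_handle(_SketchLazyArray(100, 8)) -> True
# ------------------------------------------------------------------------------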

533class DatasetCreatorAdapter(DataAdapter): 

534 """Adapter that handles dataset functions.""" 

535 

536 def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs): 

537 super().__init__(x, **kwargs) 

538 

539 if not isinstance(x, dataset_creator.DatasetCreator): 

540 raise TypeError( 

541 "The input of a `DatasetCreatorAdapter` should be a " 

542 "`DatasetCreator` but it received type {}.".format(type(x)) 

543 ) 

544 if steps is None: 

545 if not kwargs.get("pss_evaluation_shards"): 

546 raise ValueError( 

547 "When using a " 

548 "`tf.keras.utils.experimental.DatasetCreator`, " 

549 "`steps_per_epoch`, `validation_steps`, `steps`, or " 

550 "`pss_evaluation_shards` argument must be provided in " 

551 "`Model.fit`, `Model.evaluate`, or `Model.predict`." 

552 ) 

553 self.dataset_creator = x 

554 self.steps = steps 

555 self.strategy = distribution_strategy 

556 

557 @staticmethod 

558 def can_handle(x, y=None): 

559 if isinstance(x, dataset_creator.DatasetCreator): 

560 assert y is None 

561 return True 

562 

563 def should_recreate_iterator(self): 

564 # We expect users to shuffle the dataset in their `dataset_fn` supplied 

565 # to `DatasetCreator`. Since that is a buffered shuffle, we intend to 

566 # not reset the dataset so the batches that are not shuffled can still 

567 # be pulled. 

568 return False 

569 

570 def get_size(self): 

571 return None # To be inferred by `DataHandler`. 

572 

573 def get_dataset(self): 

574 return self.strategy.distribute_datasets_from_function( 

575 self.dataset_creator, options=self.dataset_creator.input_options 

576 ) 

577 

578 def batch_size(self): 

579 raise NotImplementedError() 

580 

581 def has_partial_batch(self): 

582 raise NotImplementedError() 

583 

584 def partial_batch_size(self): 

585 raise NotImplementedError() 

586 

587 

588class CompositeTensorDataAdapter(DataAdapter): 

589 """Adapter that handles composite tensor.""" 

590 

591 @staticmethod 

592 def can_handle(x, y=None): 

593 flat_inputs = tf.nest.flatten(x) 

594 if y is not None: 

595 flat_inputs += tf.nest.flatten(y) 

596 

597 def _is_composite(v): 

598 # Dataset/iterator/DistributedDataset inherits from CompositeTensor 

599 # but should be handled by DatasetAdapter and GeneratorAdapter. 

600 if ( 

601 tf_utils.is_extension_type(v) 

602 and not isinstance(v, (tf.data.Dataset, tf.data.Iterator)) 

603 and not _is_distributed_dataset(v) 

604 ): 

605 return True 

606 # Support Scipy sparse tensors if scipy is installed 

607 return _is_scipy_sparse(v) 

608 

609 def _is_tensor_or_composite(v): 

610 if isinstance(v, (tf.Tensor, np.ndarray)): 

611 return True 

612 return _is_composite(v) 

613 

614 return any(_is_composite(v) for v in flat_inputs) and all( 

615 _is_tensor_or_composite(v) for v in flat_inputs 

616 ) 

617 

618 def __init__( 

619 self, 

620 x, 

621 y=None, 

622 sample_weights=None, 

623 sample_weight_modes=None, 

624 batch_size=None, 

625 steps=None, 

626 shuffle=False, 

627 **kwargs, 

628 ): 

629 super().__init__(x, y, **kwargs) 

630 x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) 

631 sample_weight_modes = broadcast_sample_weight_modes( 

632 sample_weights, sample_weight_modes 

633 ) 

634 

635 # If sample_weights are not specified for an output use 1.0 as weights. 

636 (sample_weights, _, _) = training_utils.handle_partial_sample_weights( 

637 y, sample_weights, sample_weight_modes, check_all_flat=True 

638 ) 

639 

640 inputs = pack_x_y_sample_weight(x, y, sample_weights) 

641 

642 dataset = tf.data.Dataset.from_tensor_slices(inputs) 

643 num_samples = int(tf.nest.flatten(x)[0].shape[0]) 

644 if shuffle: 

645 dataset = dataset.shuffle(num_samples) 

646 

647 # If batch_size is not passed but steps is, calculate from the input 

648 # data. Defaults to `32` for backwards compatibility. 

649 if not batch_size: 

650 batch_size = int(math.ceil(num_samples / steps)) if steps else 32 

651 

652 dataset = dataset.batch(batch_size) 

653 self._size = int(math.ceil(num_samples / batch_size)) 

654 self._batch_size = batch_size 

655 self._has_partial_batch = self._size != (num_samples // batch_size) 

656 

657 self._partial_batch_size = None 

658 if self._has_partial_batch: 

659 self._partial_batch_size = ( 

660 num_samples - (self._size - 1) * self._batch_size 

661 ) 

662 

663 self._dataset = dataset.prefetch(tf.data.AUTOTUNE) 

664 

665 def get_dataset(self): 

666 return self._dataset 

667 

668 def get_size(self): 

669 return self._size 

670 

671 def batch_size(self): 

672 return self._batch_size 

673 

674 def has_partial_batch(self): 

675 return self._has_partial_batch 

676 

677 def partial_batch_size(self): 

678 return self._partial_batch_size 

679 

680 def should_recreate_iterator(self): 

681 return True 

682 

683 

684class ListsOfScalarsDataAdapter(DataAdapter): 

685 """Adapter that handles lists of scalars and lists of lists of scalars.""" 

686 

687 @staticmethod 

688 def can_handle(x, y=None): 

689 handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x) 

690 handles_y = True 

691 if y is not None: 

692 handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y) 

693 return handles_x and handles_y 

694 

695 @staticmethod 

696 def _is_list_of_scalars(inp): 

697 if isinstance(inp, (float, int, str, bytes, bytearray)): 

698 return True 

699 if isinstance(inp, (list, tuple)) and inp: 

700 return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0]) 

701 return False 

702 

703 def __init__( 

704 self, 

705 x, 

706 y=None, 

707 sample_weights=None, 

708 sample_weight_modes=None, 

709 batch_size=None, 

710 shuffle=False, 

711 **kwargs, 

712 ): 

713 super().__init__(x, y, **kwargs) 

714 x = np.asarray(x) 

715 if y is not None: 

716 y = np.asarray(y) 

717 if sample_weights is not None: 

718 sample_weights = np.asarray(sample_weights) 

719 sample_weight_modes = broadcast_sample_weight_modes( 

720 sample_weights, sample_weight_modes 

721 ) 

722 

723 self._internal_adapter = TensorLikeDataAdapter( 

724 x, 

725 y=y, 

726 sample_weights=sample_weights, 

727 sample_weight_modes=sample_weight_modes, 

728 batch_size=batch_size, 

729 shuffle=shuffle, 

730 **kwargs, 

731 ) 

732 

733 def get_dataset(self): 

734 return self._internal_adapter.get_dataset() 

735 

736 def get_size(self): 

737 return self._internal_adapter.get_size() 

738 

739 def batch_size(self): 

740 return self._internal_adapter.batch_size() 

741 

742 def has_partial_batch(self): 

743 return self._internal_adapter.has_partial_batch() 

744 

745 def partial_batch_size(self): 

746 return self._internal_adapter.partial_batch_size() 

747 

748 def should_recreate_iterator(self): 

749 return True 

750 

751 

752class DatasetAdapter(DataAdapter): 

753 """Adapter that handles `tf.data.Dataset`.""" 

754 

755 @staticmethod 

756 def can_handle(x, y=None): 

757 return isinstance( 

758 x, (tf.compat.v1.data.Dataset, tf.data.Dataset) 

759 ) or _is_distributed_dataset(x) 

760 

761 def __init__(self, x, y=None, sample_weights=None, steps=None, **kwargs): 

762 super().__init__(x, y, **kwargs) 

763 # Note that the dataset instance is immutable, so it's fine to reuse the 

764 # user-provided dataset. 

765 self._dataset = x 

766 

767 # The user-provided steps. 

768 self._user_steps = steps 

769 

770 self._validate_args( 

771 y, sample_weights, steps, kwargs.get("pss_evaluation_shards") 

772 ) 

773 

774 def get_dataset(self): 

775 return self._dataset 

776 

777 def get_size(self): 

778 return # Inferred in `DataHandler`. 

779 

780 def batch_size(self): 

781 return None 

782 

783 def has_partial_batch(self): 

784 return False 

785 

786 def partial_batch_size(self): 

787 return None 

788 

789 def should_recreate_iterator(self): 

790 # Since DistributedDatasets have no cardinality, the user must provide 

791 # all steps that need to be run, calling `.repeat()` as needed. 

792 if _is_distributed_dataset(self._dataset): 

793 return False 

794 

795 # If user doesn't supply `steps`, or if they supply `steps` that 

796 # exactly equals the size of the `Dataset`, create a new iterator 

797 # each epoch. 

798 return ( 

799 self._user_steps is None 

800 or tf.data.experimental.cardinality(self._dataset).numpy() 

801 == self._user_steps 

802 ) 

803 

804 def _validate_args(self, y, sample_weights, steps, pss_evaluation_shards): 

805 """Validates `__init__` arguments.""" 

806 # Arguments that shouldn't be passed. 

807 if not is_none_or_empty(y): 

808 raise ValueError( 

809 "`y` argument is not supported when using dataset as input." 

810 ) 

811 if not is_none_or_empty(sample_weights): 

812 raise ValueError( 

813 "`sample_weight` argument is not supported when using " 

814 "dataset as input." 

815 ) 

816 

817 if steps is None: 

818 if _is_distributed_dataset(self._dataset): 

819 if not pss_evaluation_shards: 

820 raise ValueError( 

821 "When providing a distributed dataset, you must " 

822 "specify the number of steps to run." 

823 ) 

824 else: 

825 size = tf.data.experimental.cardinality(self._dataset).numpy() 

826 if size == tf.data.experimental.INFINITE_CARDINALITY: 

827 if pss_evaluation_shards: 

828 raise ValueError( 

829 "When performing exact evaluation, the dataset " 

830 "must be finite. Make sure not to call `repeat()` " 

831 "on your dataset." 

832 ) 

833 else: 

834 raise ValueError( 

835 "When providing an infinite dataset, you must " 

836 "specify the number of steps to run (if you did " 

837 "not intend to create an infinite dataset, make " 

838 "sure to not call `repeat()` on the dataset)." 

839 ) 

840 

841 

842class GeneratorDataAdapter(DataAdapter): 

843 """Adapter that handles python generators and iterators.""" 

844 

845 @staticmethod 

846 def can_handle(x, y=None): 

847 return ( 

848 (hasattr(x, "__next__") or hasattr(x, "next")) 

849 and hasattr(x, "__iter__") 

850 and not isinstance(x, data_utils.Sequence) 

851 ) 

852 

853 def __init__( 

854 self, 

855 x, 

856 y=None, 

857 sample_weights=None, 

858 workers=1, 

859 use_multiprocessing=False, 

860 max_queue_size=10, 

861 model=None, 

862 **kwargs, 

863 ): 

864 # Generators should never shuffle as exhausting the generator in order 

865 # to shuffle the batches is inefficient. 

866 kwargs.pop("shuffle", None) 

867 

868 if not is_none_or_empty(y): 

869 raise ValueError( 

870 "`y` argument is not supported when using " 

871 "python generator as input." 

872 ) 

873 if not is_none_or_empty(sample_weights): 

874 raise ValueError( 

875 "`sample_weight` argument is not supported when using " 

876 "python generator as input." 

877 ) 

878 

879 super().__init__(x, y, **kwargs) 

880 

881 # Since we have to know the dtype of the python generator when we build 

882 # the dataset, we have to look at a batch to infer the structure. 

883 peek, x = self._peek_and_restore(x) 

884 peek = self._standardize_batch(peek) 

885 peek = _process_tensorlike(peek) 

886 

887 # Need to build the Model on concrete input shapes. 

888 if model is not None and not model.built: 

889 concrete_x, _, _ = unpack_x_y_sample_weight(peek) 

890 try: 

891 model.distribute_strategy.run( 

892 lambda x: model(x, training=False), args=(concrete_x,) 

893 ) 

894 except NotImplementedError: 

895 # The above call may fail if the model is a container-like class 

896 # that does not implement its own forward pass (e.g. a GAN or 

897 # VAE where the forward pass is handled by subcomponents). Such 

898 # a model does not need to be built. 

899 pass 

900 

901 self._first_batch_size = int(tf.nest.flatten(peek)[0].shape[0]) 

902 

903 def _get_tensor_spec(t): 

904 # TODO(b/226395276): Remove _with_tensor_ranks_only usage. 

905 return type_spec.type_spec_from_value(t)._with_tensor_ranks_only() 

906 

907 output_signature = tf.nest.map_structure(_get_tensor_spec, peek) 

908 

909 # Note that dataset API takes a callable that creates a generator 

910 # object, rather than generator itself, which is why we define a 

911 # function here. 

912 generator_fn = self._handle_multiprocessing( 

913 x, workers, use_multiprocessing, max_queue_size 

914 ) 

915 

916 def wrapped_generator(): 

917 for data in generator_fn(): 

918 yield self._standardize_batch(data) 

919 

920 dataset = tf.data.Dataset.from_generator( 

921 wrapped_generator, output_signature=output_signature 

922 ) 

923 

924 if workers == 1 and not use_multiprocessing: 

925 dataset = dataset.prefetch(1) 

926 

927 self._dataset = dataset.prefetch(tf.data.AUTOTUNE) 

928 

929 def _standardize_batch(self, data): 

930 """Standardizes a batch output by a generator.""" 

931 # Removes `None`s. 

932 x, y, sample_weight = unpack_x_y_sample_weight(data) 

933 data = pack_x_y_sample_weight(x, y, sample_weight) 

934 

935 data = tf.__internal__.nest.list_to_tuple(data) 

936 

937 def _convert_dtype(t): 

938 if isinstance(t, np.ndarray) and issubclass( 

939 t.dtype.type, np.floating 

940 ): 

941 return np.array(t, dtype=backend.floatx()) 

942 return t 

943 

944 data = tf.nest.map_structure(_convert_dtype, data) 

945 return data 

946 

947 @staticmethod 

948 def _peek_and_restore(x): 

949 peek = next(x) 

950 return peek, itertools.chain([peek], x) 

951 

952 def _handle_multiprocessing( 

953 self, x, workers, use_multiprocessing, max_queue_size 

954 ): 

955 """Create a callable, possibly including an Enqueuer.""" 

956 if workers > 1 or (workers > 0 and use_multiprocessing): 

957 

958 def generator_fn(): 

959 enqueuer = data_utils.GeneratorEnqueuer( 

960 x, use_multiprocessing=use_multiprocessing 

961 ) 

962 enqueuer.start(workers=workers, max_queue_size=max_queue_size) 

963 return enqueuer.get() 

964 

965 else: 

966 generator_fn = lambda: x 

967 return generator_fn 

968 

969 def get_dataset(self): 

970 return self._dataset 

971 

972 def get_size(self): 

973 return None 

974 

975 def batch_size(self): 

976 return None 

977 

978 def representative_batch_size(self): 

979 return self._first_batch_size 

980 

981 def has_partial_batch(self): 

982 return False 

983 

984 def partial_batch_size(self): 

985 return 

986 

987 def should_recreate_iterator(self): 

988 return False 

989 

990 
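
# --- Editor's illustrative sketch (not part of the original module) ----------
# The adapter above peeks at one batch to build an `output_signature`, then
# chains the peeked batch back in front of the generator. A hedged stand-alone
# version of that pattern (names are the editor's; like the adapter's
# single-worker path, the resulting dataset can only be iterated once):
def _sketch_dataset_from_generator(gen):
    peek = next(gen)                     # look at one batch to learn dtypes/shapes
    gen = itertools.chain([peek], gen)   # ...then put it back in front
    signature = tf.nest.map_structure(type_spec.type_spec_from_value, peek)
    return tf.data.Dataset.from_generator(lambda: gen, output_signature=signature)

# Hypothetical usage with a generator of fixed-size (features, labels) batches:
#   ds = _sketch_dataset_from_generator(
#       (np.zeros((8, 4), "float32"), np.zeros((8,), "int64")) for _ in range(10)
#   )
# ------------------------------------------------------------------------------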

991class KerasSequenceAdapter(GeneratorDataAdapter): 

992 """Adapter that handles `keras.utils.Sequence`.""" 

993 

994 @staticmethod 

995 def can_handle(x, y=None): 

996 return isinstance(x, data_utils.Sequence) 

997 

998 def __init__( 

999 self, 

1000 x, 

1001 y=None, 

1002 sample_weights=None, 

1003 shuffle=False, 

1004 workers=1, 

1005 use_multiprocessing=False, 

1006 max_queue_size=10, 

1007 model=None, 

1008 **kwargs, 

1009 ): 

1010 if not is_none_or_empty(y): 

1011 raise ValueError( 

1012 "`y` argument is not supported when using " 

1013 "`keras.utils.Sequence` as input." 

1014 ) 

1015 if not is_none_or_empty(sample_weights): 

1016 raise ValueError( 

1017 "`sample_weight` argument is not supported when using " 

1018 "`keras.utils.Sequence` as input." 

1019 ) 

1020 

1021 self._shuffle_sequence = shuffle 

1022 self._keras_sequence = x 

1023 self._enqueuer = None 

1024 super().__init__( 

1025 x, 

1026 shuffle=False,  # Shuffle is handled in the _make_callable override. 

1027 workers=workers, 

1028 use_multiprocessing=use_multiprocessing, 

1029 max_queue_size=max_queue_size, 

1030 model=model, 

1031 **kwargs, 

1032 ) 

1033 

1034 @staticmethod 

1035 def _peek_and_restore(x): 

1036 return x[0], x 

1037 

1038 def _handle_multiprocessing( 

1039 self, x, workers, use_multiprocessing, max_queue_size 

1040 ): 

1041 if workers > 1 or (workers > 0 and use_multiprocessing): 

1042 

1043 def generator_fn(): 

1044 self._enqueuer = data_utils.OrderedEnqueuer( 

1045 x, 

1046 use_multiprocessing=use_multiprocessing, 

1047 shuffle=self._shuffle_sequence, 

1048 ) 

1049 self._enqueuer.start( 

1050 workers=workers, max_queue_size=max_queue_size 

1051 ) 

1052 return self._enqueuer.get() 

1053 

1054 else: 

1055 

1056 def generator_fn(): 

1057 order = range(len(x)) 

1058 if self._shuffle_sequence: 

1059 # Match the shuffle convention in OrderedEnqueuer. 

1060 order = list(order) 

1061 random.shuffle(order) 

1062 

1063 for i in order: 

1064 yield x[i] 

1065 

1066 return generator_fn 

1067 

1068 def get_size(self): 

1069 return len(self._keras_sequence) 

1070 

1071 def should_recreate_iterator(self): 

1072 return True 

1073 

1074 def on_epoch_end(self): 

1075 if self._enqueuer: 

1076 self._enqueuer.stop() 

1077 self._keras_sequence.on_epoch_end() 

1078 

1079 
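
# --- Editor's illustrative sketch (not part of the original module) ----------
# A minimal `keras.utils.Sequence` of the kind KerasSequenceAdapter consumes;
# the batch count, shapes and values below are hypothetical.
class _SketchSequence(data_utils.Sequence):
    def __init__(self, num_batches=5, batch_size=8):
        self._num_batches = num_batches
        self._batch_size = batch_size

    def __len__(self):
        return self._num_batches  # KerasSequenceAdapter.get_size() returns this

    def __getitem__(self, idx):
        x = np.zeros((self._batch_size, 4), dtype="float32")
        y = np.zeros((self._batch_size,), dtype="int64")
        return x, y

# KerasSequenceAdapter.can_handle(_SketchSequence()) -> True
# ------------------------------------------------------------------------------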

1080ALL_ADAPTER_CLS = [ 

1081 ListsOfScalarsDataAdapter, 

1082 TensorLikeDataAdapter, 

1083 GenericArrayLikeDataAdapter, 

1084 DatasetAdapter, 

1085 GeneratorDataAdapter, 

1086 KerasSequenceAdapter, 

1087 CompositeTensorDataAdapter, 

1088 DatasetCreatorAdapter, 

1089] 

1090 

1091UNSHARDABLE_DATASET_TYPES = [ 

1092 from_generator_op._GeneratorDataset, 

1093 range_op._RangeDataset, 

1094 from_sparse_tensor_slices_op._SparseTensorSliceDataset, 

1095 from_tensors_op._TensorDataset, 

1096 from_tensor_slices_op._TensorSliceDataset, 

1097] 

1098 

1099 

1100def select_data_adapter(x, y): 

1101 """Selects a data adapter that can handle a given x and y.""" 

1102 adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)] 

1103 if not adapter_cls: 

1104 # TODO(scottzhu): This should be a less implementation-specific error. 

1105 raise ValueError( 

1106 "Failed to find data adapter that can handle input: {}, {}".format( 

1107 _type_name(x), _type_name(y) 

1108 ) 

1109 ) 

1110 elif len(adapter_cls) > 1: 

1111 raise RuntimeError( 

1112 "Data adapters should be mutually exclusive for " 

1113 "handling inputs. Found multiple adapters {} to handle " 

1114 "input: {}, {}".format(adapter_cls, _type_name(x), _type_name(y)) 

1115 ) 

1116 # Instrument the data adapter usage before returning it 

1117 keras_data_adapter_gauge.get_cell(adapter_cls[0].__name__).set(True) 

1118 return adapter_cls[0] 

1119 

1120 
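
# --- Editor's illustrative sketch (not part of the original module) ----------
# How the selection above resolves for two common inputs; shapes are
# hypothetical.
def _sketch_adapter_selection():
    # Dense numpy features/labels resolve to the tensor-like adapter...
    adapter = select_data_adapter(np.zeros((10, 3)), np.zeros((10,)))
    assert adapter is TensorLikeDataAdapter
    # ...while a tf.data.Dataset resolves to DatasetAdapter.
    assert select_data_adapter(tf.data.Dataset.range(10), None) is DatasetAdapter
# ------------------------------------------------------------------------------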

1121def _type_name(x): 

1122 """Generates a description of the type of an object.""" 

1123 if isinstance(x, dict): 

1124 key_types = set(_type_name(key) for key in x.keys()) 

1125 val_types = set(_type_name(key) for key in x.values()) 

1126 return f"({type(x)} containing {key_types} keys and {val_types} values)" 

1127 if isinstance(x, (list, tuple)): 

1128 types = set(_type_name(val) for val in x) 

1129 return f"({type(x)} containing values of types {types})" 

1130 return str(type(x)) 

1131 

1132 

1133def _process_tensorlike(inputs): 

1134 """Process tensor-like inputs. 

1135 

1136 This function: 

1137 

1138 (1) Converts `Numpy` arrays to `Tensor`s. 

1139 (2) Converts `Scipy` sparse matrices to `SparseTensor`s. 

1140 (3) Converts `pandas.Series` to `Tensor`s. 

1141 (4) Converts `list`s to `tuple`s (for `tf.data` support). 

1142 

1143 Args: 

1144 inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like. 

1145 

1146 Returns: 

1147 Structure of `Tensor`s or tensor-like. 

1148 """ 

1149 

1150 def _convert_single_tensor(x): 

1151 if _is_pandas_series(x): 

1152 x = np.expand_dims(x.to_numpy(), axis=-1) 

1153 

1154 if isinstance(x, np.ndarray): 

1155 dtype = None 

1156 if issubclass(x.dtype.type, np.floating): 

1157 dtype = backend.floatx() 

1158 return tf.convert_to_tensor(x, dtype=dtype) 

1159 elif _is_scipy_sparse(x): 

1160 return _scipy_sparse_to_sparse_tensor(x) 

1161 return x 

1162 

1163 inputs = tf.nest.map_structure(_convert_single_tensor, inputs) 

1164 return tf.__internal__.nest.list_to_tuple(inputs) 

1165 

1166 
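
# --- Editor's illustrative sketch (not part of the original module) ----------
# What the conversion above does to a float64 numpy array and a plain list,
# assuming `backend.floatx()` is the default "float32".
def _sketch_process_tensorlike():
    x, y = _process_tensorlike((np.zeros((4, 2), dtype="float64"), [1, 2, 3]))
    assert x.dtype == tf.float32  # floating numpy inputs are cast to floatx
    assert isinstance(y, tuple)   # lists become tuples for tf.data compatibility
# ------------------------------------------------------------------------------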

1167def is_none_or_empty(inputs): 

1168 # Util method to check if the input is None or an empty list. 

1169 # A plain python "not" check would raise an error like the one below if 

1170 # the input is a numpy array: 

1171 # "The truth value of an array with more than one element is ambiguous. 

1172 # Use a.any() or a.all()" 

1173 return inputs is None or not tf.nest.flatten(inputs) 

1174 

1175 

1176def broadcast_sample_weight_modes(target_structure, sample_weight_modes): 

1177 """Match sample_weight_modes structure with output structure.""" 

1178 if target_structure is None or not tf.nest.flatten(target_structure): 

1179 return sample_weight_modes 

1180 

1181 if isinstance(sample_weight_modes, str): 

1182 if isinstance(target_structure, dict): 

1183 return {key: sample_weight_modes for key in target_structure.keys()} 

1184 return [sample_weight_modes for _ in target_structure] 

1185 

1186 if sample_weight_modes: 

1187 try: 

1188 tf.nest.assert_same_structure( 

1189 training_utils.list_to_tuple(target_structure), 

1190 training_utils.list_to_tuple(sample_weight_modes), 

1191 ) 

1192 except (ValueError, TypeError): 

1193 target_str = str( 

1194 tf.nest.map_structure(lambda _: "...", target_structure) 

1195 ) 

1196 mode_str = str( 

1197 tf.nest.map_structure(lambda _: "...", sample_weight_modes) 

1198 ) 

1199 

1200 # Attempt to coerce sample_weight_modes to the target structure. 

1201 # This implicitly depends on the fact that Model flattens outputs 

1202 # for its internal representation. 

1203 try: 

1204 sample_weight_modes = tf.nest.pack_sequence_as( 

1205 target_structure, tf.nest.flatten(sample_weight_modes) 

1206 ) 

1207 logging.warning( 

1208 "sample_weight modes were coerced from\n " 

1209 "{}\n to \n {}".format(target_str, mode_str) 

1210 ) 

1211 except (ValueError, TypeError): 

1212 raise ValueError( 

1213 "Unable to match target structure and sample_weight_modes " 

1214 "structure:\n {}\n to \n {}".format( 

1215 target_str, mode_str 

1216 ) 

1217 ) 

1218 

1219 return sample_weight_modes 

1220 

1221 
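
# --- Editor's illustrative sketch (not part of the original module) ----------
# How the broadcasting above expands a single string mode over a two-output
# structure; the target values are hypothetical placeholders.
def _sketch_broadcast_modes():
    targets = {"a": np.zeros((4, 1)), "b": np.zeros((4, 1))}
    modes = broadcast_sample_weight_modes(targets, "temporal")
    assert modes == {"a": "temporal", "b": "temporal"}
# ------------------------------------------------------------------------------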

1222class DataHandler: 

1223 """Handles iterating over epoch-level `tf.data.Iterator` objects.""" 

1224 

1225 def __init__( 

1226 self, 

1227 x, 

1228 y=None, 

1229 sample_weight=None, 

1230 batch_size=None, 

1231 steps_per_epoch=None, 

1232 initial_epoch=0, 

1233 epochs=1, 

1234 shuffle=False, 

1235 class_weight=None, 

1236 max_queue_size=10, 

1237 workers=1, 

1238 use_multiprocessing=False, 

1239 model=None, 

1240 steps_per_execution=None, 

1241 distribute=True, 

1242 pss_evaluation_shards=0, 

1243 ): 

1244 """Initializes a `DataHandler`. 

1245 

1246 Arguments: 

1247 x: See `Model.fit`. 

1248 y: See `Model.fit`. 

1249 sample_weight: See `Model.fit`. 

1250 batch_size: See `Model.fit`. 

1251 steps_per_epoch: See `Model.fit`. 

1252 initial_epoch: See `Model.fit`. 

1253 epochs: See `Model.fit`. 

1254 shuffle: See `Model.fit`. 

1255 class_weight: See `Model.fit`. 

1256 max_queue_size: See `Model.fit`. 

1257 workers: See `Model.fit`. 

1258 use_multiprocessing: See `Model.fit`. 

1259 model: The `Model` instance. Needed in order to correctly `build` the 

1260 `Model` using generator-like inputs (see `GeneratorDataAdapter`). 

1261 steps_per_execution: See `Model.compile`. 

1262 distribute: Whether to distribute the `tf.dataset`. 

1263 `PreprocessingLayer.adapt` does not support distributed datasets, 

1264 `Model` should always set this to `True`. 

1265 pss_evaluation_shards: See `Model.fit`. 

1266 """ 

1267 

1268 self._initial_epoch = initial_epoch 

1269 self._initial_step = 0 

1270 self._epochs = epochs 

1271 self._insufficient_data = False 

1272 self._model = model 

1273 

1274 self._steps_per_epoch = steps_per_epoch 

1275 

1276 # `steps_per_execution_value` is the cached initial value. 

1277 # `steps_per_execution` is mutable and may be changed by the DataAdapter 

1278 # to handle partial executions. 

1279 if steps_per_execution is None: 

1280 self._steps_per_execution = tf.Variable(1) 

1281 else: 

1282 self._steps_per_execution = steps_per_execution 

1283 

1284 adapter_cls = select_data_adapter(x, y) 

1285 self._adapter = adapter_cls( 

1286 x, 

1287 y, 

1288 batch_size=batch_size, 

1289 steps=steps_per_epoch, 

1290 epochs=epochs - initial_epoch, 

1291 sample_weights=sample_weight, 

1292 shuffle=shuffle, 

1293 max_queue_size=max_queue_size, 

1294 workers=workers, 

1295 use_multiprocessing=use_multiprocessing, 

1296 distribution_strategy=tf.distribute.get_strategy(), 

1297 model=model, 

1298 pss_evaluation_shards=pss_evaluation_shards, 

1299 ) 

1300 

1301 strategy = tf.distribute.get_strategy() 

1302 

1303 self._current_step = 0 

1304 self._step_increment = self._steps_per_execution.numpy().item() - 1 

1305 self._insufficient_data = False 

1306 

1307 self._configure_dataset_and_inferred_steps( 

1308 strategy, x, steps_per_epoch, class_weight, distribute 

1309 ) 

1310 

1311 def _configure_dataset_and_inferred_steps( 

1312 self, strategy, x, steps_per_epoch, class_weight, distribute 

1313 ): 

1314 """Configure the `_dataset` and `_inferred_steps` attributes.""" 

1315 del x 

1316 dataset = self._adapter.get_dataset() 

1317 if class_weight: 

1318 dataset = dataset.map(_make_class_weight_map_fn(class_weight)) 

1319 self._inferred_steps = self._infer_steps(steps_per_epoch, dataset) 

1320 

1321 # `PreprocessingLayer.adapt` does not currently support distributed 

1322 # datasets, so we pass `distribute=False` there. 

1323 if distribute and not _is_distributed_dataset(dataset): 

1324 dataset = strategy.experimental_distribute_dataset(dataset) 

1325 self._dataset = dataset 

1326 self._validate_data_handler() 

1327 

1328 def enumerate_epochs(self): 

1329 """Yields `(epoch, tf.data.Iterator)`.""" 

1330 with self._truncate_execution_to_epoch(): 

1331 data_iterator = iter(self._dataset) 

1332 for epoch in range(self._initial_epoch, self._epochs): 

1333 if self._insufficient_data: # Set by `catch_stop_iteration`. 

1334 break 

1335 if self._adapter.should_recreate_iterator(): 

1336 data_iterator = iter(self._dataset) 

1337 if not isinstance(self._dataset, DistributedDataset): 

1338 steps = self._infer_steps( 

1339 self._steps_per_epoch, self._dataset 

1340 ) 

1341 if steps is not None: 

1342 self._inferred_steps = steps 

1343 yield epoch, data_iterator 

1344 self._adapter.on_epoch_end() 

1345 

1346 @contextlib.contextmanager 

1347 def _truncate_execution_to_epoch(self): 

1348 """Truncates steps per execution to at most one epoch.""" 

1349 should_truncate = ( 

1350 self._inferred_steps is not None 

1351 and self._steps_per_execution.numpy().item() > self._inferred_steps 

1352 ) 

1353 original_value = self._steps_per_execution.numpy().item() 

1354 try: 

1355 if should_truncate: 

1356 self._steps_per_execution.assign(self._inferred_steps) 

1357 yield 

1358 finally: 

1359 if should_truncate: 

1360 self._steps_per_execution.assign(original_value) 

1361 

1362 def sync(self): 

1363 context.async_wait() 

1364 

1365 @contextlib.contextmanager 

1366 def catch_stop_iteration(self): 

1367 """Catches errors when an iterator runs out of data.""" 

1368 with distributed_training_utils.maybe_preemption_handler_scope( 

1369 self._model 

1370 ): 

1371 try: 

1372 yield 

1373 self.sync() 

1374 except (StopIteration, tf.errors.OutOfRangeError): 

1375 if self._inferred_steps is None: 

1376 self._inferred_steps = self._current_step 

1377 else: 

1378 self._insufficient_data = True 

1379 total_epochs = self._epochs - self._initial_epoch 

1380 logging.warning( 

1381 "Your input ran out of data; interrupting training. " 

1382 "Make sure that your dataset or generator can generate " 

1383 "at least `steps_per_epoch * epochs` batches (in this " 

1384 "case, {} batches). You may need to use the repeat() " 

1385 "function when building your dataset.".format( 

1386 total_epochs * self._inferred_steps 

1387 ) 

1388 ) 

1389 

1390 def steps(self): 

1391 """Yields steps for the current epoch.""" 

1392 self._current_step = self._initial_step 

1393 self._initial_step = 0 

1394 # `self._inferred_steps` can be changed by `catch_stop_iteration`. 

1395 while ( 

1396 self._inferred_steps is None 

1397 or self._current_step < self._inferred_steps 

1398 ): 

1399 if self._insufficient_data: # Set by `catch_stop_iteration`. 

1400 break 

1401 original_spe = self._steps_per_execution.numpy().item() 

1402 can_run_full_execution = ( 

1403 original_spe == 1 

1404 or self._inferred_steps is None 

1405 or self._inferred_steps - self._current_step >= original_spe 

1406 ) 

1407 

1408 if can_run_full_execution: 

1409 self._step_increment = original_spe - 1 

1410 yield self._current_step 

1411 self._current_step += original_spe 

1412 else: 

1413 # Last partial execution. 

1414 steps_remaining = self._inferred_steps - self._current_step 

1415 self._steps_per_execution.assign(steps_remaining) 

1416 self._step_increment = steps_remaining - 1 

1417 yield self._current_step 

1418 self._current_step += steps_remaining 

1419 self._steps_per_execution.assign(original_spe) 

1420 

1421 @property 

1422 def step_increment(self): 

1423 """The number to increment the step for `on_batch_end` methods.""" 

1424 return self._step_increment 

1425 

1426 @property 

1427 def inferred_steps(self): 

1428 """The inferred steps per epoch of the created `Dataset`. 

1429 

1430 This will be `None` in the case where: 

1431 

1432 (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, 

1433 (2) `steps_per_epoch` was not provided, and 

1434 (3) The first epoch of iteration has not yet completed. 

1435 

1436 Returns: 

1437 The inferred steps per epoch of the created `Dataset`. 

1438 """ 

1439 return self._inferred_steps 

1440 

1441 @property 

1442 def should_sync(self): 

1443 # Catch OutOfRangeError for Datasets of unknown size. 

1444 # This blocks until the batch has finished executing. 

1445 # TODO(b/150292341): Allow multiple async steps here. 

1446 return self._inferred_steps is None 

1447 

1448 def _log_indefinite_training_warning(self): 

1449 logging.warning( 

1450 "The training loop will run indefinitely since you have " 

1451 "set `steps_per_epoch=-1`. Please use batch-level " 

1452 "callbacks to save checkpoints or log training progress, " 

1453 "etc" 

1454 ) 

1455 

1456 def _infer_steps(self, steps, dataset): 

1457 """Infers steps_per_epoch needed to loop through a dataset.""" 

1458 if steps == -1: 

1459 self._log_indefinite_training_warning() 

1460 return None 

1461 

1462 if steps is not None: 

1463 return steps 

1464 

1465 adapter_steps = self._adapter.get_size() 

1466 if adapter_steps is not None: 

1467 return adapter_steps 

1468 

1469 # tf.distribute's `PerWorkerDataset` does not inherit from 

1470 # `tf.data.Dataset` and in those cases we give up on inferring steps. 

1471 if not isinstance(dataset, tf.data.Dataset): 

1472 return None 

1473 

1474 size = tf.data.experimental.cardinality(dataset) 

1475 if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None: 

1476 raise ValueError( 

1477 "When passing an infinitely repeating dataset, please specify " 

1478 "a `steps_per_epoch` value so that epoch level " 

1479 "callbacks continue to work. The value can be arbitrary, or a " 

1480 "number that you think correctly defines the size of an epoch. " 

1481 "Epoch-level callbacks will then be called at this interval." 

1482 ) 

1483 if size >= 0: 

1484 return size.numpy().item() 

1485 return None 

1486 

1487 @property 

1488 def _samples(self): 

1489 return self._adapter.get_samples() 

1490 

1491 def _validate_data_handler(self): 

1492 # TODO(b/152094471): Support this with DistIter.get_next_as_optional. 

1493 if ( 

1494 self._steps_per_execution.numpy().item() > 1 

1495 and self._inferred_steps is None 

1496 ): 

1497 raise ValueError( 

1498 "Could not infer the size of the data. With " 

1499 "`steps_per_execution > 1`, you must specify the number of " 

1500 "steps to run." 

1501 ) 

1502 

1503 

1504class _ClusterCoordinatorDataHandler(DataHandler): 

1505 """A `DataHandler` that is compatible with `ClusterCoordinator`.""" 

1506 

1507 def __init__(self, x, y=None, **kwargs): 

1508 if not _is_distributed_dataset(x) and not isinstance( 

1509 x, (dataset_creator.DatasetCreator, tf.data.Dataset) 

1510 ): 

1511 x = self._convert_to_dataset_creator(x, y, **kwargs) 

1512 

1513 super().__init__(x=x, **kwargs) 

1514 

1515 def _convert_to_dataset_creator(self, x, y, **kwargs): 

1516 """Converts non-tf.data.Dataset to `DatasetCreator` instances.""" 

1517 

1518 def _dataset_fn(input_context): 

1519 del input_context 

1520 data_adapter_cls = select_data_adapter(x, y) 

1521 return data_adapter_cls(x=x, y=y, **kwargs).get_dataset() 

1522 

1523 # This check is needed because types like `tf.data.Dataset` don't work 

1524 # with PSS yet. So only apply this logic to the types we can support. 

1525 if isinstance(x, _get_tensor_types()) and isinstance( 

1526 y, _get_tensor_types() 

1527 ): 

1528 return dataset_creator.DatasetCreator(_dataset_fn) 

1529 else: 

1530 raise NotImplementedError( 

1531 "Only `tf.keras.utils.experimental.DatasetCreator`, " 

1532 "`tf.Tensor`, numpy arrays and pandas dataframes are " 

1533 "supported types at this time." 

1534 ) 

1535 

1536 def _configure_dataset_and_inferred_steps( 

1537 self, strategy, x, steps_per_epoch, class_weight, distribute 

1538 ): 

1539 if isinstance(x, dataset_creator.DatasetCreator): 

1540 

1541 def per_worker_dataset_fn(): 

1542 

1543 return strategy.distribute_datasets_from_function( 

1544 x, options=x.input_options 

1545 ) 

1546 

1547 coordinator = self._model._cluster_coordinator 

1548 self._dataset = coordinator.create_per_worker_dataset( 

1549 per_worker_dataset_fn 

1550 ) 

1551 else: 

1552 assert distribute 

1553 if not _is_distributed_dataset(x): 

1554 x = strategy.experimental_distribute_dataset(x) 

1555 

1556 coordinator = self._model._cluster_coordinator 

1557 self._dataset = coordinator.create_per_worker_dataset(x) 

1558 

1559 if steps_per_epoch == -1: 

1560 self._inferred_steps = None 

1561 self._log_indefinite_training_warning() 

1562 else: 

1563 self._inferred_steps = steps_per_epoch 

1564 

1565 def sync(self): 

1566 self._model._cluster_coordinator.join() 

1567 

1568 

1569class _ClusterCoordinatorExactEvalDataHandler(_ClusterCoordinatorDataHandler): 

1570 def __init__(self, x, y=None, **kwargs): 

1571 super().__init__(x=x, **kwargs) 

1572 self._total_shards = kwargs.get("pss_evaluation_shards") 

1573 

1574 def _warn_if_not_file_shardable(self, dataset): 

1575 # Traverse backwards to find source dataset and check if that is one of 

1576 # the unshardable types 

1577 # TODO(b/268521864): expand this to inspect dataset function graphs and 

1578 # use the auto-sharding logic rather than re-creating it here. 

1579 cur_dataset = dataset 

1580 while hasattr(cur_dataset, "_input_dataset"): 

1581 cur_dataset = cur_dataset._input_dataset 

1582 if type(cur_dataset) in UNSHARDABLE_DATASET_TYPES: 

1583 logging.warning( 

1584 "Found source dataset of type {}. This type is not " 

1585 "efficiently shardable, so exact evaluation may be " 

1586 "slower than inexact evaluation. Try converting to " 

1587 "a TFRecord or other file-based dataset if " 

1588 "performance is a concern.".format(type(cur_dataset)) 

1589 ) 

1590 

1591 def _configure_dataset_and_inferred_steps( 

1592 self, strategy, x, steps_per_epoch, class_weight, distribute 

1593 ): 

1594 if isinstance(x, dataset_creator.DatasetCreator): 

1595 

1596 def per_worker_dataset_fn(): 

1597 ddf = strategy.distribute_datasets_from_function( 

1598 x, options=x.input_options 

1599 ) 

1600 return ddf 

1601 

1602 coordinator = self._model._cluster_coordinator 

1603 self._dataset = coordinator.create_per_worker_dataset( 

1604 per_worker_dataset_fn 

1605 ) 

1606 logging.info("dataset element spec: %r", self._dataset.element_spec) 

1607 self._dataset = self._dataset.build() 

1608 else: 

1609 # TODO(b/268226218): Support DistributedDataset input 

1610 if not _is_distributed_dataset(x): 

1611 self._warn_if_not_file_shardable(x) 

1612 x = strategy.experimental_distribute_dataset(x) 

1613 

1614 coordinator = self._model._cluster_coordinator 

1615 self._dataset = coordinator.create_per_worker_dataset(x) 

1616 self._dataset = self._dataset.build() 

1617 

1618 if steps_per_epoch == -1: 

1619 self._inferred_steps = None 

1620 self._log_indefinite_training_warning() 

1621 else: 

1622 self._inferred_steps = steps_per_epoch 

1623 

1624 def enumerate_epochs(self): 

1625 """Yields `(epoch, dataset)`.""" 

1626 for epoch in range(self._initial_epoch, self._epochs): 

1627 yield epoch, self._dataset 

1628 self._adapter.on_epoch_end() 

1629 

1630 def steps(self): 

1631 """Yields steps for the current epoch.""" 

1632 for step in range(self._total_shards): 

1633 yield step 

1634 

1635 

1636@keras_export("keras.__internal__.utils.get_data_handler", v1=[]) 

1637def get_data_handler(*args, **kwargs): 

1638 """Creates a `DataHandler`, providing standardized access to a `Dataset`. 

1639 

1640 See `DataHandler` for the list and definition of the arguments. See the 

1641 implementation of `Model.fit()`, `evaluate()`, or `predict()` methods 

1642 for complete usage examples. As a rule of thumb, `get_data_handler()` accepts 

1643 the same inputs as the `x` argument of `Model.fit()`. 

1644 

1645 Example: 

1646 

1647 ```python 

1648 def step(iterator): 

1649 data = next(iterator) 

1650 # result <= Do something with data 

1651 return result 

1652 tf_step = tf.function(step, reduce_retracing=True) 

1653 

1654 # Assume x is a tf.data Dataset. 

1655 data_handler = data_adapter.get_data_handler(x=x) 

1656 # Epoch iteration 

1657 for epo_idx, iterator in data_handler.enumerate_epochs(): 

1658 # Stop on dataset exhaustion. 

1659 with data_handler.catch_stop_iteration(): 

1660 for step in data_handler.steps(): # Step iteration 

1661 step_result = step(iterator) 

1662 ``` 

1663 

1664 Args: 

1665 *args: Arguments passed to the `DataHandler` constructor. 

1666 **kwargs: Arguments passed to the `DataHandler` constructor. 

1667 

1668 Returns: 

1669 A `DataHandler` object. If the model's cluster coordinator is set (e.g. the

1670 model was defined under a parameter-server strategy), returns a 

1671 `_ClusterCoordinatorDataHandler`. 

1672 

1673 """ 

1674 if getattr(kwargs["model"], "_cluster_coordinator", None): 

1675 if kwargs.get("pss_evaluation_shards"): 

1676 return _ClusterCoordinatorExactEvalDataHandler(*args, **kwargs) 

1677 return _ClusterCoordinatorDataHandler(*args, **kwargs) 

1678 return DataHandler(*args, **kwargs) 

1679 

1680 

1681def _make_class_weight_map_fn(class_weight): 

1682 """Applies class weighting to a `Dataset`. 

1683 

1684 The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where 

1685 `y` must be a single `Tensor`. 

1686 

1687 Args: 

1688 class_weight: A map where the keys are integer class ids and values are 

1689 the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}` 

1690 

1691 Returns: 

1692 A function that can be used with `tf.data.Dataset.map` to apply class 

1693 weighting. 

1694 """ 

1695 class_ids = sorted(class_weight.keys())

1696 expected_class_ids = list(range(len(class_ids))) 

1697 if class_ids != expected_class_ids: 

1698 error_msg = ( 

1699 "Expected `class_weight` to be a dict with keys from 0 to one less " 

1700 "than the number of classes, found {}" 

1701 ).format(class_weight) 

1702 raise ValueError(error_msg) 

1703 

1704 class_weight_tensor = tf.convert_to_tensor( 

1705 [class_weight[int(c)] for c in class_ids] 

1706 ) 

1707 

1708 def _class_weights_map_fn(*data): 

1709 """Convert `class_weight` to `sample_weight`.""" 

1710 x, y, sw = unpack_x_y_sample_weight(data) 

1711 

1712 if tf.nest.is_nested(y): 

1713 raise ValueError( 

1714 "`class_weight` is only supported for Models with a single " 

1715 "output." 

1716 ) 

1717 

1718 if y.shape.rank >= 2: 

1719 y_classes = tf.__internal__.smart_cond.smart_cond( 

1720 backend.shape(y)[-1] > 1, 

1721 lambda: backend.argmax(y, axis=-1), 

1722 lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int64), 

1723 ) 

1724 else: 

1725 # Special casing for rank 1, where we can guarantee sparse encoding. 

1726 y_classes = tf.cast(tf.round(y), tf.int64) 

1727 

1728 cw = tf.gather(class_weight_tensor, y_classes) 

1729 if sw is not None: 

1730 cw = tf.cast(cw, sw.dtype) 

1731 # `class_weight` and `sample_weight` are multiplicative. 

1732 # If class_weight has more than 2 dimensions, we need to reshape 

1733 # sample_weight to make broadcasting possible for multiplication. 

1734 rank_delta = cw.shape.rank - sw.shape.rank 

1735 sw = tf.reshape(sw, sw.shape + [1] * rank_delta) 

1736 sw = sw * cw 

1737 else: 

1738 sw = cw 

1739 return x, y, sw 

1740 

1741 return _class_weights_map_fn 

1742 
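# Illustrative sketch: applying the map function returned by
# `_make_class_weight_map_fn` to a batched `(x, y)` dataset. The helper
# `_example_class_weight_mapping` is hypothetical and assumes a single
# sparse-label output; each label simply gathers its weight from
# `class_weight` into the `sample_weight` slot.
def _example_class_weight_mapping():
    class_weight = {0: 1.0, 1: 2.0, 2: 4.0}
    map_fn = _make_class_weight_map_fn(class_weight)
    ds = tf.data.Dataset.from_tensor_slices(
        (tf.zeros((3, 5)), tf.constant([0.0, 1.0, 2.0]))
    ).batch(3)
    x, y, sw = next(iter(ds.map(map_fn)))
    # sw == [1.0, 2.0, 4.0], i.e. the per-class weights indexed by the labels.
    return x, y, sw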

1743 

1744def train_validation_split(arrays, validation_split): 

1745 """Split arrays into train and validation subsets in deterministic order. 

1746 

1747 The last part of the data will become the validation data.

1748 

1749 Args: 

1750 arrays: Tensors to split. Allowed inputs are arbitrarily nested structures 

1751 of Tensors and NumPy arrays. 

1752 validation_split: Float between 0 and 1. The proportion of the dataset to 

1753 include in the validation split. The rest of the dataset will be 

1754 included in the training split. 

1755 Returns: 

1756 `(train_arrays, validation_arrays)` 

1757 """ 

1758 

1759 def _can_split(t): 

1760 tensor_types = _get_tensor_types() 

1761 return isinstance(t, tensor_types) or t is None 

1762 

1763 flat_arrays = tf.nest.flatten(arrays) 

1764 unsplitable = [type(t) for t in flat_arrays if not _can_split(t)] 

1765 if unsplitable: 

1766 raise ValueError( 

1767 "`validation_split` is only supported for Tensors or NumPy " 

1768 "arrays, found the following types in the input: {}".format(unsplitable)

1769 ) 

1770 

1771 if all(t is None for t in flat_arrays): 

1772 return arrays, arrays 

1773 

1774 first_non_none = None 

1775 for t in flat_arrays: 

1776 if t is not None: 

1777 first_non_none = t 

1778 break 

1779 

1780 # Assumes all arrays have the same batch shape or are `None`. 

1781 batch_dim = int(first_non_none.shape[0]) 

1782 split_at = int(math.floor(batch_dim * (1.0 - validation_split))) 

1783 

1784 if split_at == 0 or split_at == batch_dim: 

1785 raise ValueError( 

1786 "Training data contains {batch_dim} samples, which is not " 

1787 "sufficient to split it into a validation and training set as " 

1788 "specified by `validation_split={validation_split}`. Either " 

1789 "provide more data, or a different value for the " 

1790 "`validation_split` argument.".format( 

1791 batch_dim=batch_dim, validation_split=validation_split 

1792 ) 

1793 ) 

1794 

1795 def _split(t, start, end): 

1796 if t is None: 

1797 return t 

1798 return t[start:end] 

1799 

1800 train_arrays = tf.nest.map_structure( 

1801 functools.partial(_split, start=0, end=split_at), arrays 

1802 ) 

1803 val_arrays = tf.nest.map_structure( 

1804 functools.partial(_split, start=split_at, end=batch_dim), arrays 

1805 ) 

1806 

1807 return train_arrays, val_arrays 

1808 
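# Illustrative sketch: an 80/20 split of a pair of NumPy arrays with
# `train_validation_split`. The helper `_example_train_validation_split` is
# hypothetical; note that the *last* 20% of rows, in the original order,
# become the validation subset.
def _example_train_validation_split():
    x = np.arange(10).reshape(10, 1)
    y = np.arange(10)
    (x_train, y_train), (x_val, y_val) = train_validation_split(
        (x, y), validation_split=0.2
    )
    # x_train/y_train hold the first 8 samples, x_val/y_val the final 2.
    return (x_train, y_train), (x_val, y_val)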

1809 

1810@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[]) 

1811def unpack_x_y_sample_weight(data): 

1812 """Unpacks user-provided data tuple. 

1813 

1814 This is a convenience utility to be used when overriding 

1815 `Model.train_step`, `Model.test_step`, or `Model.predict_step`. 

1816 This utility makes it easy to support data of the form `(x,)`, 

1817 `(x, y)`, or `(x, y, sample_weight)`. 

1818 

1819 Standalone usage: 

1820 

1821 >>> features_batch = tf.ones((10, 5)) 

1822 >>> labels_batch = tf.zeros((10, 5)) 

1823 >>> data = (features_batch, labels_batch) 

1824 >>> # `y` and `sample_weight` will default to `None` if not provided. 

1825 >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) 

1826 >>> sample_weight is None 

1827 True 

1828 

1829 Example in overridden `Model.train_step`: 

1830 

1831 ```python 

1832 class MyModel(tf.keras.Model): 

1833 

1834 def train_step(self, data): 

1835 # If `sample_weight` is not provided, all samples will be weighted 

1836 # equally. 

1837 x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) 

1838 

1839 with tf.GradientTape() as tape: 

1840 y_pred = self(x, training=True) 

1841 loss = self.compiled_loss( 

1842 y, y_pred, sample_weight, regularization_losses=self.losses) 

1843 trainable_variables = self.trainable_variables 

1844 gradients = tape.gradient(loss, trainable_variables) 

1845 self.optimizer.apply_gradients(zip(gradients, trainable_variables)) 

1846 

1847 self.compiled_metrics.update_state(y, y_pred, sample_weight) 

1848 return {m.name: m.result() for m in self.metrics} 

1849 ``` 

1850 

1851 Args: 

1852 data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. 

1853 

1854 Returns: 

1855 The unpacked tuple, with `None`s for `y` and `sample_weight` if they are 

1856 not provided. 

1857 """ 

1858 if isinstance(data, list): 

1859 data = tuple(data) 

1860 if not isinstance(data, tuple): 

1861 return (data, None, None) 

1862 elif len(data) == 1: 

1863 return (data[0], None, None) 

1864 elif len(data) == 2: 

1865 return (data[0], data[1], None) 

1866 elif len(data) == 3: 

1867 return (data[0], data[1], data[2]) 

1868 else: 

1869 error_msg = ( 

1870 "Data is expected to be in format `x`, `(x,)`, `(x, y)`, " 

1871 "or `(x, y, sample_weight)`, found: {}" 

1872 ).format(data) 

1873 raise ValueError(error_msg) 

1874 

1875 

1876@keras_export("keras.utils.pack_x_y_sample_weight", v1=[]) 

1877def pack_x_y_sample_weight(x, y=None, sample_weight=None): 

1878 """Packs user-provided data into a tuple. 

1879 

1880 This is a convenience utility for packing data into the tuple formats 

1881 that `Model.fit` uses. 

1882 

1883 Standalone usage: 

1884 

1885 >>> x = tf.ones((10, 1)) 

1886 >>> data = tf.keras.utils.pack_x_y_sample_weight(x) 

1887 >>> isinstance(data, tf.Tensor) 

1888 True 

1889 >>> y = tf.ones((10, 1)) 

1890 >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y) 

1891 >>> isinstance(data, tuple) 

1892 True 

1893 >>> x, y = data 

1894 

1895 Args: 

1896 x: Features to pass to `Model`. 

1897 y: Ground-truth targets to pass to `Model`. 

1898 sample_weight: Sample weight for each element. 

1899 

1900 Returns: 

1901 Tuple in the format used in `Model.fit`. 

1902 """ 

1903 if y is None: 

1904 # For single x-input, we do no tuple wrapping since in this case 

1905 # there is no ambiguity. This also makes NumPy and Dataset 

1906 # consistent in that the user does not have to wrap their Dataset 

1907 # data in an unnecessary tuple. 

1908 if not isinstance(x, (tuple, list)):

1909 return x 

1910 else: 

1911 return (x,) 

1912 elif sample_weight is None: 

1913 return (x, y) 

1914 else: 

1915 return (x, y, sample_weight) 

1916 
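# Illustrative sketch: a round trip through `pack_x_y_sample_weight` and
# `unpack_x_y_sample_weight`, including the single-input case where no tuple
# wrapping is applied. `_example_pack_unpack_round_trip` is a hypothetical
# helper for demonstration only.
def _example_pack_unpack_round_trip():
    x = tf.ones((4, 3))
    y = tf.zeros((4, 1))
    packed = pack_x_y_sample_weight(x, y)  # -> (x, y)
    x2, y2, sw = unpack_x_y_sample_weight(packed)  # sw is None
    single = pack_x_y_sample_weight(x)  # -> x, returned unwrapped
    return x2, y2, sw, single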

1917 

1918def single_batch_iterator( 

1919 strategy, x, y=None, sample_weight=None, class_weight=None 

1920): 

1921 """Creates a single-batch dataset.""" 

1922 x, y, sample_weight = _process_tensorlike((x, y, sample_weight)) 

1923 if y is None: 

1924 data = (x,) 

1925 elif sample_weight is None: 

1926 data = (x, y) 

1927 else: 

1928 data = (x, y, sample_weight) 

1929 

1930 _check_data_cardinality(data) 

1931 dataset = tf.data.Dataset.from_tensors(data) 

1932 if class_weight: 

1933 dataset = dataset.map(_make_class_weight_map_fn(class_weight)) 

1934 dataset = strategy.experimental_distribute_dataset(dataset) 

1935 return iter(dataset) 

1936 
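# Illustrative sketch: building a one-batch iterator under the default
# (no-op) distribution strategy. `_example_single_batch_iterator` is a
# hypothetical helper and assumes eager execution; the whole input comes back
# as a single batch.
def _example_single_batch_iterator():
    strategy = tf.distribute.get_strategy()
    x = np.ones((4, 3), dtype="float32")
    y = np.zeros((4, 1), dtype="float32")
    iterator = single_batch_iterator(strategy, x, y)
    batch_x, batch_y = next(iterator)  # one batch containing all 4 samples
    return batch_x, batch_y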

1937 

1938def _check_data_cardinality(data): 

1939 num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data)) 

1940 if len(num_samples) > 1: 

1941 msg = "Data cardinality is ambiguous:\n" 

1942 for label, single_data in zip(["x", "y", "sample_weight"], data): 

1943 msg += " {} sizes: {}\n".format( 

1944 label, 

1945 ", ".join( 

1946 str(i.shape[0]) for i in tf.nest.flatten(single_data) 

1947 ), 

1948 ) 

1949 msg += "Make sure all arrays contain the same number of samples." 

1950 raise ValueError(msg) 

1951 
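# Illustrative sketch: `_check_data_cardinality` rejects inputs whose leading
# dimensions disagree before any dataset is built. The helper
# `_example_cardinality_mismatch` is hypothetical.
def _example_cardinality_mismatch():
    try:
        _check_data_cardinality((np.ones((4, 2)), np.ones((3, 1))))
    except ValueError as e:
        return str(e)  # "Data cardinality is ambiguous: ..."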

1952 

1953def _get_tensor_types(): 

1954 if pd is None: 

1955 return (tf.Tensor, np.ndarray) 

1956 else: 

1957 return (tf.Tensor, np.ndarray, pd.Series, pd.DataFrame) 

1958 

1959 

1960def _is_scipy_sparse(x): 

1961 try: 

1962 from scipy.sparse import issparse 

1963 

1964 return issparse(x) 

1965 except ImportError: 

1966 return False 

1967 

1968 

1969def _is_pandas_series(x): 

1970 if pd is None: 

1971 return False 

1972 else: 

1973 return isinstance(x, pd.Series) 

1974 

1975 

1976def _scipy_sparse_to_sparse_tensor(t): 

1977 """Converts a SciPy sparse matrix to a SparseTensor.""" 

1978 sparse_coo = t.tocoo() 

1979 row, col = sparse_coo.row, sparse_coo.col 

1980 data, shape = sparse_coo.data, sparse_coo.shape 

1981 if issubclass(data.dtype.type, np.floating): 

1982 data = data.astype(backend.floatx()) 

1983 indices = np.concatenate( 

1984 (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1 

1985 ) 

1986 return tf.SparseTensor(indices, data, shape) 

1987 
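# Illustrative sketch: converting a SciPy COO matrix into a `tf.SparseTensor`
# with `_scipy_sparse_to_sparse_tensor`. The helper is hypothetical and
# requires SciPy to be installed; floating-point data is cast to
# `backend.floatx()`.
def _example_scipy_to_sparse_tensor():
    from scipy.sparse import coo_matrix

    m = coo_matrix(([1.0, 2.0], ([0, 2], [1, 0])), shape=(3, 3))
    st = _scipy_sparse_to_sparse_tensor(m)
    # st.indices == [[0, 1], [2, 0]]; st.values == [1., 2.]; st.dense_shape == [3, 3]
    return st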

1988 

1989def _is_distributed_dataset(ds): 

1990 return isinstance( 

1991 ds, 

1992 ( 

1993 tf.distribute.DistributedDataset, 

1994 tf.experimental.dtensor.DTensorDataset, 

1995 ), 

1996 ) 

1997