Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/data_adapter.py: 25%

746 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Adapter module that convert different input data objects into tf.dataset.""" 

16 

17import abc 

18import contextlib 

19import functools 

20import itertools 

21import math 

22import random 

23 

24import numpy as np 

25 

26from tensorflow.python.data.experimental.ops import cardinality 

27from tensorflow.python.data.ops import dataset_ops 

28from tensorflow.python.data.ops import iterator_ops 

29from tensorflow.python.data.ops import options as options_lib 

30from tensorflow.python.distribute import distribute_lib 

31from tensorflow.python.distribute import input_lib 

32from tensorflow.python.eager import context 

33from tensorflow.python.framework import dtypes 

34from tensorflow.python.framework import errors 

35from tensorflow.python.framework import ops 

36from tensorflow.python.framework import smart_cond 

37from tensorflow.python.framework import sparse_tensor 

38from tensorflow.python.framework import tensor_conversion 

39from tensorflow.python.framework import tensor_shape 

40from tensorflow.python.keras import backend 

41from tensorflow.python.keras.engine import training_utils 

42from tensorflow.python.keras.utils import data_utils 

43from tensorflow.python.keras.utils import dataset_creator 

44from tensorflow.python.keras.utils import tf_utils 

45from tensorflow.python.ops import array_ops 

46from tensorflow.python.ops import math_ops 

47from tensorflow.python.ops import random_ops 

48from tensorflow.python.ops import script_ops 

49from tensorflow.python.platform import tf_logging as logging 

50from tensorflow.python.types import data as data_types 

51from tensorflow.python.util import nest 

52from tensorflow.python.util.tf_export import keras_export 

53 

54 

55class DataAdapter(object, metaclass=abc.ABCMeta): 

56 """Base class for input data adapter. 

57 

58 In TF 2.0, tf.data is the preferred API for users to feed in data. In order

59 to simplify the training code path, all input data objects will be

60 converted to `tf.data.Dataset` if possible.

61 

62 Note that since this class is mainly targeted at TF 2.0, it might make a lot

63 of assumptions under the hood, e.g. eager context by default, distribution

64 strategy, etc. In the meantime, some legacy feature support might be dropped,

65 e.g. the Iterator from the dataset API in v1.

66 

67 Sample usage of this class:

68 

69 ``` 

70 x = tf.data.Dataset.range(100) 

71 adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter] 

72 applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)] 

73 if len(applicable_adapters) != 1: 

74 raise ValueError("Expect only one adapter class to handle the input") 

75 

76 dataset = applicable_adapters[0](x).get_dataset() 

77 for data in dataset: 

78 # training 

79 ``` 

80 """ 

81 

82 @staticmethod 

83 def can_handle(x, y=None): 

84 """Whether the current DataAdapter could handle the input x and y. 

85 

86 Structure wise, x and y can be single object, or list of objects if there 

87 multiple input/output, or dictionary of objects when the intput/output are 

88 named. 

89 

90 Args: 

91 x: input features. 

92 y: target labels. Note that y could be None in the case of prediction. 

93 

94 Returns: 

95 boolean 

96 """ 

97 raise NotImplementedError 

98 

99 @abc.abstractmethod 

100 def __init__(self, x, y=None, **kwargs): 

101 """Create a DataAdapter based on data inputs. 

102 

103 The caller must make sure to call `can_handle()` first before invoking this

104 method. Providing an unsupported data type will result in unexpected behavior.

105 

106 Args: 

107 x: input features. 

108 y: target labels. Note that y could be None in the case of prediction. 

109 **kwargs: Other keyword arguments for DataAdapter during the construction 

110 of the tf.dataset.Dataset. For example: 

111 - Numpy data might have `sample_weights` which will be used for 

112 weighting the loss function during training. 

113 - Numpy data might need to have `batch_size` parameter when constructing 

114 the dataset and iterator. 

115 - Certain input might need to be distribution strategy aware. When 

116 `distribution_strategy` is passed, the created dataset need to respect 

117 the strategy. 

118 DataAdapter might choose to ignore any keyword argument if it doesn't

119 use it, or raise an exception if any required argument is not provided.

120 """ 

121 if not self.can_handle(x, y): 

122 raise ValueError("{} Cannot handle input {}, {}".format( 

123 self.__class__, x, y)) 

124 

125 @abc.abstractmethod 

126 def get_dataset(self): 

127 """Get a dataset instance for the current DataAdapter. 

128 

129 Note that the dataset returned does not repeat across epochs, so the caller

130 might need to create a new iterator for the same dataset at the beginning of

131 each epoch. This behavior might change in the future.

132 

133 Returns: 

134 A tf.data.Dataset. The caller might use the dataset in different

135 contexts, e.g. iter(dataset) in eager mode to get the values directly, or in

136 graph mode, provide the iterator tensor to the Keras model function.

137 """ 

138 raise NotImplementedError 

139 

140 @abc.abstractmethod 

141 def get_size(self): 

142 """Return the size (number of batches) for the dataset created. 

143 

144 For certain types of data input, the number of batches is known, e.g. for

145 Numpy data the size is ceil(number_of_elements / batch_size). Whereas

146 for a dataset or python generator, the size is unknown since it may or may

147 not have an end state.

148 

149 Returns: 

150 int, the number of batches for the dataset, or None if it is unknown. The

151 caller could use this to control the training loop, show a progress bar,

152 or handle an unexpected StopIteration error.

153 """ 

154 raise NotImplementedError 

155 

156 @abc.abstractmethod 

157 def batch_size(self): 

158 """Return the batch size of the dataset created. 

159 

160 For certain types of data input, the batch size is known, and even

161 required, e.g. numpy arrays. Whereas for a dataset, the batch size is

162 unknown unless we take a peek.

163 

164 Returns: 

165 int, the batch size of the dataset, or None if it is unknown. 

166 """ 

167 raise NotImplementedError 

168 

169 def representative_batch_size(self): 

170 """Return a representative size for batches in the dataset. 

171 

172 This is not guaranteed to be the batch size for all batches in the 

173 dataset. It just needs to be a rough approximation for batch sizes in 

174 the dataset. 

175 

176 Returns: 

177 int, a representative size for batches found in the dataset, 

178 or None if it is unknown. 

179 """ 

180 return self.batch_size() 

181 

182 @abc.abstractmethod 

183 def has_partial_batch(self): 

184 """Whether the dataset has partial batch at the end.""" 

185 raise NotImplementedError 

186 

187 @abc.abstractmethod 

188 def partial_batch_size(self): 

189 """The size of the final partial batch for dataset. 

190 

191 Will return None if has_partial_batch is False or batch_size is None. 

192 """ 

193 raise NotImplementedError 

194 

195 @abc.abstractmethod 

196 def should_recreate_iterator(self): 

197 """Returns whether a new iterator should be created every epoch.""" 

198 raise NotImplementedError 

199 

200 def get_samples(self): 

201 """Returns number of samples in the data, or `None`.""" 

202 if not self.get_size() or not self.batch_size(): 

203 return None 

204 total_sample = self.get_size() * self.batch_size() 

205 if self.has_partial_batch(): 

206 total_sample -= (self.batch_size() - self.partial_batch_size()) 

207 return total_sample 

208 

209 def on_epoch_end(self): 

210 """A hook called after each epoch.""" 

211 pass 

212 

213 
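To make the abstract interface above concrete, here is a minimal sketch of a custom adapter for a single NumPy array. The `SingleTensorAdapter` name and its simplistic batching are hypothetical, for illustration only; the real adapters below handle many more cases (targets, sample weights, shuffling, distribution).

```python
import math

import numpy as np
import tensorflow as tf


class SingleTensorAdapter(DataAdapter):  # hypothetical name, illustration only
  """Sketch of an adapter for a single NumPy array `x` with no targets."""

  @staticmethod
  def can_handle(x, y=None):
    return isinstance(x, np.ndarray) and y is None

  def __init__(self, x, y=None, batch_size=32, **kwargs):
    super(SingleTensorAdapter, self).__init__(x, y, **kwargs)
    self._batch_size = batch_size
    self._num_samples = int(x.shape[0])
    self._dataset = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return int(math.ceil(self._num_samples / self._batch_size))

  def batch_size(self):
    return self._batch_size

  def has_partial_batch(self):
    return self._num_samples % self._batch_size > 0

  def partial_batch_size(self):
    return (self._num_samples % self._batch_size) or None

  def should_recreate_iterator(self):
    # The dataset above is finite, so a fresh iterator is needed every epoch.
    return True
```

An adapter like this would be used exactly as in the docstring sample: check `can_handle`, construct it, then iterate over `get_dataset()`.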

214class TensorLikeDataAdapter(DataAdapter): 

215 """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.""" 

216 

217 @staticmethod 

218 def can_handle(x, y=None): 

219 # TODO(kaftan): Check performance implications of using a flatten 

220 # here for other types of inputs. 

221 flat_inputs = nest.flatten(x) 

222 if y is not None: 

223 flat_inputs += nest.flatten(y) 

224 

225 tensor_types = _get_tensor_types() 

226 

227 def _is_tensor(v): 

228 if isinstance(v, tensor_types): 

229 return True 

230 return False 

231 

232 return all(_is_tensor(v) for v in flat_inputs) 

233 

234 def __init__(self, 

235 x, 

236 y=None, 

237 sample_weights=None, 

238 sample_weight_modes=None, 

239 batch_size=None, 

240 epochs=1, 

241 steps=None, 

242 shuffle=False, 

243 **kwargs): 

244 super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs) 

245 x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) 

246 sample_weight_modes = broadcast_sample_weight_modes( 

247 sample_weights, sample_weight_modes) 

248 

249 # If sample_weights are not specified for an output, use 1.0 as weights.

250 (sample_weights, _, _) = training_utils.handle_partial_sample_weights( 

251 y, sample_weights, sample_weight_modes, check_all_flat=True) 

252 

253 inputs = pack_x_y_sample_weight(x, y, sample_weights) 

254 

255 num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)).pop() 

256 _check_data_cardinality(inputs) 

257 

258 # If batch_size is not passed but steps is, calculate from the input data. 

259 # Default to 32 for backwards compat. 

260 if not batch_size: 

261 batch_size = int(math.ceil(num_samples / steps)) if steps else 32 

262 

263 self._size = int(math.ceil(num_samples / batch_size)) 

264 self._batch_size = batch_size 

265 

266 num_full_batches = int(num_samples // batch_size) 

267 self._partial_batch_size = num_samples % batch_size 

268 

269 if isinstance(shuffle, str): 

270 shuffle = shuffle.lower() 

271 

272 self._shuffle = shuffle 

273 # Vectorized version of shuffle. 

274 # This is a performance improvement over using `from_tensor_slices`. 

275 # The indices of the data are shuffled and batched, and these indices 

276 # are then zipped with the data and used to extract a batch of the data 

277 # at each step. The performance improvements here come from: 

278 # 1. vectorized batch using gather 

279 # 2. parallelized map 

280 # 3. pipelined permutation generation 

281 # 4. optimized permutation batching 

282 # 5. disabled static optimizations 

283 

284 indices_dataset = dataset_ops.DatasetV2.range(1) 

285 if shuffle != "batch": 

286 indices_dataset = indices_dataset.repeat(epochs) 

287 

288 def permutation(_): 

289 # It turns out to be more performant to make a new set of indices rather 

290 # than reusing the same range Tensor. (presumably because of buffer 

291 # forwarding.) 

292 indices = math_ops.range(num_samples, dtype=dtypes.int64) 

293 if shuffle and shuffle != "batch": 

294 indices = random_ops.random_shuffle(indices) 

295 return indices 

296 

297 # We prefetch a single element. Computing large permutations can take quite 

298 # a while so we don't want to wait for prefetching over an epoch boundary to 

299 # trigger the next permutation. On the other hand, too many simultaneous 

300 # shuffles can contend on a hardware level and degrade all performance. 

301 indices_dataset = indices_dataset.map(permutation).prefetch(1) 

302 

303 def slice_batch_indices(indices): 

304 """Convert a Tensor of indices into a dataset of batched indices. 

305 

306 This step can be accomplished in several ways. The most natural is to 

307 slice the Tensor in a Dataset map. (With a condition on the upper index to 

308 handle the partial batch.) However it turns out that coercing the Tensor 

309 into a shape which is divisible by the batch size (and handling the last 

310 partial batch separately) allows for a much more favorable memory access 

311 pattern and improved performance. 

312 

313 Args: 

314 indices: Tensor which determines the data order for an entire epoch. 

315 

316 Returns: 

317 A Dataset of batched indices. 

318 """ 

319 num_in_full_batch = num_full_batches * batch_size 

320 first_k_indices = array_ops.slice(indices, [0], [num_in_full_batch]) 

321 first_k_indices = array_ops.reshape( 

322 first_k_indices, [num_full_batches, batch_size]) 

323 

324 flat_dataset = dataset_ops.DatasetV2.from_tensor_slices(first_k_indices) 

325 if self._partial_batch_size: 

326 index_remainder = dataset_ops.DatasetV2.from_tensors(array_ops.slice( 

327 indices, [num_in_full_batch], [self._partial_batch_size])) 

328 flat_dataset = flat_dataset.concatenate(index_remainder) 

329 

330 if shuffle == "batch": 

331 # 1024 is a magic constant that has not been properly evaluated 

332 flat_dataset = flat_dataset.shuffle(1024).repeat(epochs) 

333 return flat_dataset 

334 

335 indices_dataset = indices_dataset.flat_map(slice_batch_indices) 

336 

337 dataset = self.slice_inputs(indices_dataset, inputs) 

338 

339 if shuffle == "batch": 

340 def shuffle_batch(*batch): 

341 return nest.map_structure(random_ops.random_shuffle, batch) 

342 dataset = dataset.map(shuffle_batch) 

343 

344 self._dataset = dataset 

345 

346 def slice_inputs(self, indices_dataset, inputs): 

347 """Slice inputs into a Dataset of batches. 

348 

349 Given a Dataset of batch indices and the unsliced inputs, 

350 this step slices the inputs in a parallelized fashion 

351 and produces a dataset of input batches. 

352 

353 Args: 

354 indices_dataset: A Dataset of batched indices 

355 inputs: A python data structure that contains the inputs, targets, 

356 and possibly sample weights. 

357 

358 Returns: 

359 A Dataset of input batches matching the batch indices. 

360 """ 

361 dataset = dataset_ops.DatasetV2.zip(( 

362 indices_dataset, 

363 dataset_ops.DatasetV2.from_tensors(inputs).repeat() 

364 )) 

365 

366 def grab_batch(i, data): 

367 return nest.map_structure(lambda d: array_ops.gather(d, i, axis=0), data) 

368 

369 dataset = dataset.map( 

370 grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE) 

371 

372 # Default optimizations are disabled to avoid the overhead of (unnecessary) 

373 # input pipeline graph serialization and deserialization 

374 options = options_lib.Options() 

375 options.experimental_optimization.apply_default_optimizations = False 

376 if self._shuffle: 

377 # See b/141490660 for more details. 

378 options.experimental_external_state_policy = ( 

379 options_lib.ExternalStatePolicy.IGNORE) 

380 dataset = dataset.with_options(options) 

381 return dataset 

382 

383 def get_dataset(self): 

384 return self._dataset 

385 

386 def get_size(self): 

387 return self._size 

388 

389 def batch_size(self): 

390 return self._batch_size 

391 

392 def has_partial_batch(self): 

393 return self._partial_batch_size > 0 

394 

395 def partial_batch_size(self): 

396 return self._partial_batch_size or None 

397 

398 def should_recreate_iterator(self): 

399 # An infinite dataset is always created here. 

400 return False 

401 

402 
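The vectorized shuffle described in the comments of `TensorLikeDataAdapter.__init__` can be illustrated with a rough standalone sketch (names and sizes here are illustrative; the real implementation also handles partial batches, `shuffle='batch'`, multi-epoch permutations, and dataset options):

```python
import tensorflow as tf

num_samples, batch_size = 10, 4
x = tf.range(num_samples) * 10  # stand-in for the full in-memory inputs

# One shuffled permutation of the sample indices, trimmed to full batches.
indices = tf.random.shuffle(tf.range(num_samples, dtype=tf.int64))
num_full = num_samples // batch_size
batched_indices = tf.reshape(indices[:num_full * batch_size],
                             [num_full, batch_size])

# Zip each index batch with the (repeated) full inputs and gather the rows.
indices_ds = tf.data.Dataset.from_tensor_slices(batched_indices)
data_ds = tf.data.Dataset.from_tensors(x).repeat()
dataset = tf.data.Dataset.zip((indices_ds, data_ds)).map(
    lambda i, d: tf.gather(d, i, axis=0),
    num_parallel_calls=tf.data.AUTOTUNE)

for batch in dataset:
  print(batch.numpy())  # e.g. [70 10 40 90]
```

Because each batch is produced by a single `tf.gather` over the in-memory tensors, this avoids the per-element overhead of `from_tensor_slices(...).shuffle(...).batch(...)`.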

403class GenericArrayLikeDataAdapter(TensorLikeDataAdapter): 

404 """Adapter that handles array-like data without forcing it into memory. 

405 

406 This adapter handles array-like datasets that may be too big to fully 

407 fit into memory. 

408 

409 Specifically, this adapter handles any Python class which implements: 

410 `__getitem__`, `__len__`, `shape`, and `dtype` with the same meanings

411 as Numpy, but it ignores any case where all the inputs are Tensors or Numpy 

412 arrays (because that case is handled by the base TensorLikeDataAdapter). 

413 

414 It ignores scipy sparse matrices and Composite Tensors because those are 

415 handled by the CompositeTensorDataAdapter. 

416 

417 It also does not handle lists/tuples of scalars, because those are handled 

418 by the ListsOfScalarsDataAdapter. 

419 """ 

420 

421 @staticmethod 

422 def can_handle(x, y=None): 

423 flat_inputs = nest.flatten(x) 

424 if y is not None: 

425 flat_inputs += nest.flatten(y) 

426 

427 def _is_array_like(v): 

428 """Return True if v is a Tensor, array, or is array-like.""" 

429 return ( 

430 hasattr(v, "__getitem__") and 

431 hasattr(v, "shape") and 

432 hasattr(v, "dtype") and 

433 hasattr(v, "__len__") 

434 ) 

435 

436 if (not TensorLikeDataAdapter.can_handle(x, y) and 

437 not CompositeTensorDataAdapter.can_handle(x, y)): 

438 return all(_is_array_like(v) for v in flat_inputs) 

439 else: 

440 return False 

441 

442 def __init__(self, *args, **kwargs): 

443 logging.warning( 

444 "Keras is training/fitting/evaluating on array-like data. Keras may " 

445 "not be optimized for this format, so if your input data format is " 

446 "supported by TensorFlow I/O (https://github.com/tensorflow/io) we " 

447 "recommend using that to load a Dataset instead.") 

448 

449 super(GenericArrayLikeDataAdapter, self).__init__(*args, **kwargs) 

450 

451 def slice_inputs(self, indices_dataset, inputs): 

452 """Slice inputs into a Dataset of batches. 

453 

454 Given a Dataset of batch indices and the unsliced inputs, 

455 this step slices the inputs in a parallelized fashion 

456 and produces a dataset of input batches. 

457 

458 Args: 

459 indices_dataset: A Dataset of batched indices 

460 inputs: A python data structure that contains the inputs, targets, 

461 and possibly sample weights. 

462 

463 Returns: 

464 A Dataset of input batches matching the batch indices. 

465 """ 

466 flat_inputs = nest.flatten(inputs) 

467 def dynamic_shape_like(t): 

468 shape = list(t.shape) 

469 shape[0] = None 

470 return tuple(shape) 

471 

472 flat_dtypes = [inp.dtype for inp in flat_inputs] 

473 contiguous = True 

474 if self._shuffle and self._shuffle != "batch": 

475 contiguous = False 

476 

477 def grab_batch(indices): 

478 """Grab a batch of data from the inputs.""" 

479 # This uses a py_function to avoid converting the array-like 

480 # into a Tensor before slicing it, because converting the array-like 

481 # to a Tensor may force it into memory.

482 def py_method(ind): 

483 def slice_array(data): 

484 return training_utils.slice_arrays(data, ind.numpy(), 

485 contiguous=contiguous) 

486 return [slice_array(inp) for inp in flat_inputs] 

487 

488 flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes) 

489 for v, original_inp in zip(flat_out, flat_inputs): 

490 v.set_shape(dynamic_shape_like(original_inp)) 

491 return nest.pack_sequence_as(inputs, flat_out) 

492 

493 dataset = indices_dataset.map( 

494 grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE) 

495 

496 return dataset 

497 

498 
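For reference, a minimal sketch of the kind of object `GenericArrayLikeDataAdapter` is designed for: something that exposes `__getitem__`, `__len__`, `shape`, and `dtype` without materializing all rows in memory. The `LazyColumn` class is hypothetical.

```python
import numpy as np


class LazyColumn(object):  # hypothetical, illustration only
  """Array-like object whose rows are generated on demand."""

  def __init__(self, num_rows, num_cols):
    self.shape = (num_rows, num_cols)
    self.dtype = np.dtype("float32")

  def __len__(self):
    return self.shape[0]

  def __getitem__(self, indices):
    # `indices` is assumed to be a 1-D array of row indices, which is what the
    # adapter's py_function-based slicing passes in.
    return np.stack([np.full(self.shape[1], i, dtype=self.dtype)
                     for i in np.asarray(indices)])


x = LazyColumn(1000000, 8)
print(GenericArrayLikeDataAdapter.can_handle(x))  # True for this sketch
```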

499class DatasetCreatorAdapter(DataAdapter): 

500 """Adapter that handles dataset functions.""" 

501 

502 def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs): 

503 super(DatasetCreatorAdapter, self).__init__(x, **kwargs) 

504 

505 if not isinstance(x, dataset_creator.DatasetCreator): 

506 raise TypeError("The input of a `DatasetCreatorAdapter` should be a " 

507 "`DatasetCreator` but it received type {}.".format( 

508 type(x))) 

509 if steps is None: 

510 raise ValueError("When using a " 

511 "`tf.keras.utils.experimental.DatasetCreator`, " 

512 "`steps_per_epoch`, `validation_steps` or `steps` " 

513 "argument must be provided in `Model.fit`, " 

514 "`Model.evaluate`, or `Model.predict`.") 

515 self.dataset_creator = x 

516 self.steps = steps 

517 self.strategy = distribution_strategy 

518 

519 @staticmethod 

520 def can_handle(x, y=None): 

521 if isinstance(x, dataset_creator.DatasetCreator): 

522 assert y is None 

523 return True 

524 

525 def should_recreate_iterator(self): 

526 # We expect users to shuffle the dataset in their `dataset_fn` supplied to 

527 # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset 

528 # the dataset so the batches that are not shuffled can still be pulled. 

529 return False 

530 

531 def get_size(self): 

532 return None # To be inferred by `DataHandler`. 

533 

534 def get_dataset(self): 

535 return self.strategy.distribute_datasets_from_function( 

536 self.dataset_creator, options=self.dataset_creator.input_options) 

537 

538 def batch_size(self): 

539 raise NotImplementedError() 

540 

541 def has_partial_batch(self): 

542 raise NotImplementedError() 

543 

544 def partial_batch_size(self): 

545 raise NotImplementedError() 

546 

547 
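A rough sketch of the user-facing path that ends in `DatasetCreatorAdapter`; the toy model and `dataset_fn` are placeholders, and depending on the TF version this path may need to run under a `tf.distribute` strategy such as `ParameterServerStrategy`. Note that `steps_per_epoch` is mandatory, as enforced in `__init__` above.

```python
import tensorflow as tf


def dataset_fn(input_context):
  # Placeholder data; a real dataset_fn would shard/shuffle per worker here.
  batch_size = input_context.get_per_replica_batch_size(64)
  x = tf.random.uniform((1024, 10))
  y = tf.random.uniform((1024, 1))
  return tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(batch_size)


model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
model.compile(optimizer="sgd", loss="mse")
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=2, steps_per_epoch=16)
```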

548class CompositeTensorDataAdapter(DataAdapter): 

549 """Adapter that handles composite tensor.""" 

550 

551 @staticmethod 

552 def can_handle(x, y=None): 

553 flat_inputs = nest.flatten(x) 

554 if y is not None: 

555 flat_inputs += nest.flatten(y) 

556 

557 def _is_composite(v): 

558 # Dataset/iterator/DistributedDataset inherit from CompositeTensor but

559 # should be handled by DatasetAdapter and GeneratorAdapter.

560 if (tf_utils.is_extension_type(v) and 

561 not isinstance(v, 

562 (dataset_ops.DatasetV2, iterator_ops.IteratorBase)) and 

563 not _is_distributed_dataset(v)): 

564 return True 

565 # Support Scipy sparse tensors if scipy is installed 

566 return _is_scipy_sparse(v) 

567 

568 def _is_tensor_or_composite(v): 

569 if isinstance(v, (ops.Tensor, np.ndarray)): 

570 return True 

571 return _is_composite(v) 

572 

573 return (any(_is_composite(v) for v in flat_inputs) and 

574 all(_is_tensor_or_composite(v) for v in flat_inputs)) 

575 

576 def __init__(self, 

577 x, 

578 y=None, 

579 sample_weights=None, 

580 sample_weight_modes=None, 

581 batch_size=None, 

582 steps=None, 

583 shuffle=False, 

584 **kwargs): 

585 super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs) 

586 x, y, sample_weights = _process_tensorlike((x, y, sample_weights)) 

587 sample_weight_modes = broadcast_sample_weight_modes( 

588 sample_weights, sample_weight_modes) 

589 

590 # If sample_weights are not specified for an output use 1.0 as weights. 

591 (sample_weights, _, _) = training_utils.handle_partial_sample_weights( 

592 y, sample_weights, sample_weight_modes, check_all_flat=True) 

593 

594 inputs = pack_x_y_sample_weight(x, y, sample_weights) 

595 

596 dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs) 

597 num_samples = int(nest.flatten(x)[0].shape[0]) 

598 if shuffle: 

599 dataset = dataset.shuffle(num_samples) 

600 

601 # If batch_size is not passed but steps is, calculate from the input data. 

602 # Default to 32 for backwards compat. 

603 if not batch_size: 

604 batch_size = int(math.ceil(num_samples / steps)) if steps else 32 

605 

606 dataset = dataset.batch(batch_size) 

607 self._size = int(math.ceil(num_samples / batch_size)) 

608 self._batch_size = batch_size 

609 self._has_partial_batch = (self._size != (num_samples // batch_size)) 

610 

611 self._partial_batch_size = None 

612 if self._has_partial_batch: 

613 self._partial_batch_size = ( 

614 num_samples - (self._size - 1) * self._batch_size) 

615 

616 self._dataset = dataset 

617 

618 def get_dataset(self): 

619 return self._dataset 

620 

621 def get_size(self): 

622 return self._size 

623 

624 def batch_size(self): 

625 return self._batch_size 

626 

627 def has_partial_batch(self): 

628 return self._has_partial_batch 

629 

630 def partial_batch_size(self): 

631 return self._partial_batch_size 

632 

633 def should_recreate_iterator(self): 

634 return True 

635 

636 

637class ListsOfScalarsDataAdapter(DataAdapter): 

638 """Adapter that handles lists of scalars and lists of lists of scalars.""" 

639 

640 @staticmethod 

641 def can_handle(x, y=None): 

642 handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x) 

643 handles_y = True 

644 if y is not None: 

645 handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y) 

646 return handles_x and handles_y 

647 

648 @staticmethod 

649 def _is_list_of_scalars(inp): 

650 if isinstance(inp, (float, int, str, bytes, bytearray)): 

651 return True 

652 if isinstance(inp, (list, tuple)) and inp: 

653 return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0]) 

654 return False 

655 

656 def __init__(self, 

657 x, 

658 y=None, 

659 sample_weights=None, 

660 sample_weight_modes=None, 

661 batch_size=None, 

662 shuffle=False, 

663 **kwargs): 

664 super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs) 

665 x = np.asarray(x) 

666 if y is not None: 

667 y = np.asarray(y) 

668 if sample_weights is not None: 

669 sample_weights = np.asarray(sample_weights) 

670 sample_weight_modes = broadcast_sample_weight_modes( 

671 sample_weights, sample_weight_modes) 

672 

673 self._internal_adapter = TensorLikeDataAdapter( 

674 x, 

675 y=y, 

676 sample_weights=sample_weights, 

677 sample_weight_modes=sample_weight_modes, 

678 batch_size=batch_size, 

679 shuffle=shuffle, 

680 **kwargs) 

681 

682 def get_dataset(self): 

683 return self._internal_adapter.get_dataset() 

684 

685 def get_size(self): 

686 return self._internal_adapter.get_size() 

687 

688 def batch_size(self): 

689 return self._internal_adapter.batch_size() 

690 

691 def has_partial_batch(self): 

692 return self._internal_adapter.has_partial_batch() 

693 

694 def partial_batch_size(self): 

695 return self._internal_adapter.partial_batch_size() 

696 

697 def should_recreate_iterator(self): 

698 return True 

699 

700 

701class DatasetAdapter(DataAdapter): 

702 """Adapter that handles `tf.data.Dataset`.""" 

703 

704 @staticmethod 

705 def can_handle(x, y=None): 

706 return (isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)) or 

707 _is_distributed_dataset(x)) 

708 

709 def __init__(self, 

710 x, 

711 y=None, 

712 sample_weights=None, 

713 steps=None, 

714 **kwargs): 

715 super(DatasetAdapter, self).__init__(x, y, **kwargs) 

716 # Note that the dataset instance is immutable, so it's fine to reuse the

717 # user-provided dataset.

718 self._dataset = x 

719 

720 # The user-provided steps. 

721 self._user_steps = steps 

722 

723 self._validate_args(y, sample_weights, steps) 

724 

725 def get_dataset(self): 

726 return self._dataset 

727 

728 def get_size(self): 

729 return # Inferred in `DataHandler`. 

730 

731 def batch_size(self): 

732 return None 

733 

734 def has_partial_batch(self): 

735 return False 

736 

737 def partial_batch_size(self): 

738 return None 

739 

740 def should_recreate_iterator(self): 

741 # Since DistributedDatasets have no cardinality, the user must provide 

742 # all steps that need to be run, calling `.repeat()` as needed. 

743 if _is_distributed_dataset(self._dataset): 

744 return False 

745 

746 # If user doesn't supply `steps`, or if they supply `steps` that 

747 # exactly equals the size of the `Dataset`, create a new iterator 

748 # each epoch. 

749 return (self._user_steps is None or 

750 cardinality.cardinality(self._dataset).numpy() == self._user_steps) 

751 

752 def _validate_args(self, y, sample_weights, steps): 

753 """Validates `__init__` arguments.""" 

754 # Arguments that shouldn't be passed. 

755 if not is_none_or_empty(y): 

756 raise ValueError("`y` argument is not supported when using " 

757 "dataset as input.") 

758 if not is_none_or_empty(sample_weights): 

759 raise ValueError("`sample_weight` argument is not supported when using " 

760 "dataset as input.") 

761 

762 if steps is None: 

763 if _is_distributed_dataset(self._dataset): 

764 raise ValueError("When providing a distributed dataset, you must " 

765 "specify the number of steps to run.") 

766 

767 size = cardinality.cardinality(self._dataset).numpy() 

768 if size == cardinality.INFINITE and steps is None: 

769 raise ValueError( 

770 "When providing an infinite dataset, you must specify " 

771 "the number of steps to run (if you did not intend to " 

772 "create an infinite dataset, make sure to not call " 

773 "`repeat()` on the dataset).") 

774 

775 

776class GeneratorDataAdapter(DataAdapter): 

777 """Adapter that handles python generators and iterators.""" 

778 

779 @staticmethod 

780 def can_handle(x, y=None): 

781 return ((hasattr(x, "__next__") or hasattr(x, "next")) 

782 and hasattr(x, "__iter__") 

783 and not isinstance(x, data_utils.Sequence)) 

784 

785 def __init__(self, 

786 x, 

787 y=None, 

788 sample_weights=None, 

789 workers=1, 

790 use_multiprocessing=False, 

791 max_queue_size=10, 

792 model=None, 

793 **kwargs): 

794 # Generators should never shuffle as exhausting the generator in order to 

795 # shuffle the batches is inefficient. 

796 kwargs.pop("shuffle", None) 

797 

798 if not is_none_or_empty(y): 

799 raise ValueError("`y` argument is not supported when using " 

800 "python generator as input.") 

801 if not is_none_or_empty(sample_weights): 

802 raise ValueError("`sample_weight` argument is not supported when using " 

803 "python generator as input.") 

804 

805 super(GeneratorDataAdapter, self).__init__(x, y, **kwargs) 

806 

807 # Since we have to know the dtype of the python generator when we build the 

808 # dataset, we have to look at a batch to infer the structure. 

809 peek, x = self._peek_and_restore(x) 

810 peek = self._standardize_batch(peek) 

811 peek = _process_tensorlike(peek) 

812 

813 # Need to build the Model on concrete input shapes. 

814 if model is not None and not model.built: 

815 concrete_x, _, _ = unpack_x_y_sample_weight(peek) 

816 model.distribute_strategy.run( 

817 lambda x: model(x, training=False), args=(concrete_x,)) 

818 

819 self._first_batch_size = int(nest.flatten(peek)[0].shape[0]) 

820 

821 def _get_dynamic_shape(t): 

822 shape = t.shape 

823 # Unknown number of dimensions, `as_list` cannot be called. 

824 if shape.rank is None: 

825 return shape 

826 return tensor_shape.TensorShape([None for _ in shape.as_list()]) 

827 

828 output_shapes = nest.map_structure(_get_dynamic_shape, peek) 

829 output_types = nest.map_structure(lambda t: t.dtype, peek) 

830 

831 # Note that the dataset API takes a callable that creates a generator object,

832 # rather than the generator itself, which is why we define a function here.

833 generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing, 

834 max_queue_size) 

835 

836 def wrapped_generator(): 

837 for data in generator_fn(): 

838 yield self._standardize_batch(data) 

839 

840 dataset = dataset_ops.DatasetV2.from_generator( 

841 wrapped_generator, output_types, output_shapes=output_shapes) 

842 

843 if workers == 1 and not use_multiprocessing: 

844 dataset = dataset.prefetch(1) 

845 

846 self._dataset = dataset 

847 

848 def _standardize_batch(self, data): 

849 """Standardizes a batch output by a generator.""" 

850 # Removes `None`s. 

851 x, y, sample_weight = unpack_x_y_sample_weight(data) 

852 data = pack_x_y_sample_weight(x, y, sample_weight) 

853 

854 data = nest.list_to_tuple(data) 

855 

856 def _convert_dtype(t): 

857 if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)): 

858 return np.array(t, dtype=backend.floatx()) 

859 return t 

860 

861 data = nest.map_structure(_convert_dtype, data) 

862 return data 

863 

864 @staticmethod 

865 def _peek_and_restore(x): 

866 peek = next(x) 

867 return peek, itertools.chain([peek], x) 

868 

869 def _handle_multiprocessing(self, x, workers, use_multiprocessing, 

870 max_queue_size): 

871 """Create a callable, possibly including an Enqueuer.""" 

872 if workers > 1 or (workers > 0 and use_multiprocessing): 

873 def generator_fn(): 

874 enqueuer = data_utils.GeneratorEnqueuer( 

875 x, use_multiprocessing=use_multiprocessing) 

876 enqueuer.start(workers=workers, max_queue_size=max_queue_size) 

877 return enqueuer.get() 

878 else: 

879 generator_fn = lambda: x 

880 return generator_fn 

881 

882 def get_dataset(self): 

883 return self._dataset 

884 

885 def get_size(self): 

886 return None 

887 

888 def batch_size(self): 

889 return None 

890 

891 def representative_batch_size(self): 

892 return self._first_batch_size 

893 

894 def has_partial_batch(self): 

895 return False 

896 

897 def partial_batch_size(self): 

898 return 

899 

900 def should_recreate_iterator(self): 

901 return False 

902 

903 
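The `_peek_and_restore` trick above is small enough to show in isolation: one element is consumed to infer structure and dtypes, then chained back in front so the wrapped generator still yields every batch.

```python
import itertools


def batches():
  for i in range(3):
    yield i


gen = batches()
peek = next(gen)                    # inspect shapes/dtypes from `peek`
gen = itertools.chain([peek], gen)  # put the consumed element back in front
print(list(gen))                    # [0, 1, 2]
```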

904class KerasSequenceAdapter(GeneratorDataAdapter): 

905 """Adapter that handles `keras.utils.Sequence`.""" 

906 

907 @staticmethod 

908 def can_handle(x, y=None): 

909 return isinstance(x, data_utils.Sequence) 

910 

911 def __init__(self, 

912 x, 

913 y=None, 

914 sample_weights=None, 

915 shuffle=False, 

916 workers=1, 

917 use_multiprocessing=False, 

918 max_queue_size=10, 

919 model=None, 

920 **kwargs): 

921 if not is_none_or_empty(y): 

922 raise ValueError("`y` argument is not supported when using " 

923 "`keras.utils.Sequence` as input.") 

924 if not is_none_or_empty(sample_weights): 

925 raise ValueError("`sample_weight` argument is not supported when using " 

926 "`keras.utils.Sequence` as input.") 

927 

928 self._size = len(x) 

929 self._shuffle_sequence = shuffle 

930 self._keras_sequence = x 

931 self._enqueuer = None 

932 super(KerasSequenceAdapter, self).__init__( 

933 x, 

934 shuffle=False, # Shuffle is handled in the _handle_multiprocessing override.

935 workers=workers, 

936 use_multiprocessing=use_multiprocessing, 

937 max_queue_size=max_queue_size, 

938 model=model, 

939 **kwargs) 

940 

941 @staticmethod 

942 def _peek_and_restore(x): 

943 return x[0], x 

944 

945 def _handle_multiprocessing(self, x, workers, use_multiprocessing, 

946 max_queue_size): 

947 if workers > 1 or (workers > 0 and use_multiprocessing): 

948 def generator_fn(): 

949 self._enqueuer = data_utils.OrderedEnqueuer( 

950 x, use_multiprocessing=use_multiprocessing, 

951 shuffle=self._shuffle_sequence) 

952 self._enqueuer.start(workers=workers, max_queue_size=max_queue_size) 

953 return self._enqueuer.get() 

954 else: 

955 def generator_fn(): 

956 order = range(len(x)) 

957 if self._shuffle_sequence: 

958 # Match the shuffle convention in OrderedEnqueuer. 

959 order = list(order) 

960 random.shuffle(order) 

961 

962 for i in order: 

963 yield x[i] 

964 

965 return generator_fn 

966 

967 def get_size(self): 

968 return self._size 

969 

970 def should_recreate_iterator(self): 

971 return True 

972 

973 def on_epoch_end(self): 

974 if self._enqueuer: 

975 self._enqueuer.stop() 

976 self._keras_sequence.on_epoch_end() 

977 

978 
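For context, a minimal sketch of the `keras.utils.Sequence` contract that `KerasSequenceAdapter` wraps (the class and data are illustrative): `__len__` returns the number of batches and `__getitem__` returns one whole batch.

```python
import numpy as np
import tensorflow as tf


class ToySequence(tf.keras.utils.Sequence):  # illustrative only

  def __init__(self, x, y, batch_size=32):
    self.x, self.y, self.batch_size = x, y, batch_size

  def __len__(self):
    return int(np.ceil(len(self.x) / self.batch_size))

  def __getitem__(self, idx):
    sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
    return self.x[sl], self.y[sl]

  def on_epoch_end(self):
    pass  # e.g. reshuffle indices here; invoked via the adapter's on_epoch_end.


seq = ToySequence(np.random.rand(100, 4), np.random.rand(100, 1), batch_size=16)
x_batch, y_batch = seq[0]
print(x_batch.shape, len(seq))  # (16, 4) 7
```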

979ALL_ADAPTER_CLS = [ 

980 ListsOfScalarsDataAdapter, TensorLikeDataAdapter, 

981 GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter, 

982 KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter 

983] 

984 

985 

986def select_data_adapter(x, y): 

987 """Selects a data adapter than can handle a given x and y.""" 

988 adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)] 

989 if not adapter_cls: 

990 # TODO(scottzhu): This should be a less implementation-specific error. 

991 raise ValueError( 

992 "Failed to find data adapter that can handle " 

993 "input: {}, {}".format( 

994 _type_name(x), _type_name(y))) 

995 elif len(adapter_cls) > 1: 

996 raise RuntimeError( 

997 "Data adapters should be mutually exclusive for " 

998 "handling inputs. Found multiple adapters {} to handle " 

999 "input: {}, {}".format( 

1000 adapter_cls, _type_name(x), _type_name(y))) 

1001 return adapter_cls[0] 

1002 

1003 
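A small sketch of how the selection behaves for two common input types, assuming the adapters registered in `ALL_ADAPTER_CLS` above:

```python
import numpy as np
import tensorflow as tf

x_np = np.zeros((8, 3), dtype="float32")
y_np = np.zeros((8, 1), dtype="float32")
print(select_data_adapter(x_np, y_np))  # <class '...TensorLikeDataAdapter'>

ds = tf.data.Dataset.from_tensor_slices((x_np, y_np)).batch(4)
print(select_data_adapter(ds, None))    # <class '...DatasetAdapter'>
```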

1004def _type_name(x): 

1005 """Generates a description of the type of an object.""" 

1006 if isinstance(x, dict): 

1007 key_types = set(_type_name(key) for key in x.keys()) 

1008 val_types = set(_type_name(val) for val in x.values())

1009 return "({} containing {} keys and {} values)".format( 

1010 type(x), key_types, val_types) 

1011 if isinstance(x, (list, tuple)): 

1012 types = set(_type_name(val) for val in x) 

1013 return "({} containing values of types {})".format( 

1014 type(x), types) 

1015 return str(type(x)) 

1016 

1017 

1018def _process_tensorlike(inputs): 

1019 """Process tensor-like inputs. 

1020 

1021 This function: 

1022 

1023 (1) Converts `Numpy` arrays to `Tensor`s. 

1024 (2) Converts `Scipy` sparse matrices to `SparseTensor`s. 

1025 (3) Converts `list`s to `tuple`s (for `tf.data` support).

1026 

1027 Args: 

1028 inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like. 

1029 

1030 Returns: 

1031 Structure of `Tensor`s or tensor-like. 

1032 """ 

1033 

1034 def _convert_numpy_and_scipy(x): 

1035 if isinstance(x, np.ndarray): 

1036 dtype = None 

1037 if issubclass(x.dtype.type, np.floating): 

1038 dtype = backend.floatx() 

1039 return tensor_conversion.convert_to_tensor_v2_with_dispatch( 

1040 x, dtype=dtype 

1041 ) 

1042 elif _is_scipy_sparse(x): 

1043 return _scipy_sparse_to_sparse_tensor(x) 

1044 return x 

1045 

1046 inputs = nest.map_structure(_convert_numpy_and_scipy, inputs) 

1047 return nest.list_to_tuple(inputs) 

1048 

1049 

1050def is_none_or_empty(inputs): 

1051 # Util method to check if the input is None or an empty list.

1052 # The python "not" check will raise an error like the one below if the input

1053 # is a numpy array:

1054 # "The truth value of an array with more than one element is ambiguous. 

1055 # Use a.any() or a.all()" 

1056 return inputs is None or not nest.flatten(inputs) 

1057 

1058 

1059def broadcast_sample_weight_modes(target_structure, sample_weight_modes): 

1060 """Match sample_weight_modes structure with output structure.""" 

1061 if target_structure is None or not nest.flatten(target_structure): 

1062 return sample_weight_modes 

1063 

1064 if isinstance(sample_weight_modes, str): 

1065 if isinstance(target_structure, dict): 

1066 return {key: sample_weight_modes for key in target_structure.keys()} 

1067 return [sample_weight_modes for _ in target_structure] 

1068 

1069 if sample_weight_modes: 

1070 try: 

1071 nest.assert_same_structure( 

1072 training_utils.list_to_tuple(target_structure), 

1073 training_utils.list_to_tuple(sample_weight_modes)) 

1074 except (ValueError, TypeError): 

1075 target_str = str(nest.map_structure(lambda _: "...", target_structure)) 

1076 mode_str = str(nest.map_structure(lambda _: "...", sample_weight_modes)) 

1077 

1078 # Attempt to coerce sample_weight_modes to the target structure. This 

1079 # implicitly depends on the fact that Model flattens outputs for its 

1080 # internal representation. 

1081 try: 

1082 sample_weight_modes = nest.pack_sequence_as( 

1083 target_structure, nest.flatten(sample_weight_modes)) 

1084 logging.warning( 

1085 "sample_weight modes were coerced from\n {}\n to \n {}" 

1086 .format(target_str, mode_str)) 

1087 except (ValueError, TypeError): 

1088 raise ValueError( 

1089 "Unable to match target structure and sample_weight_modes " 

1090 "structure:\n {}\n to \n {}".format(target_str, mode_str)) 

1091 

1092 return sample_weight_modes 

1093 

1094 

1095class DataHandler(object): 

1096 """Handles iterating over epoch-level `tf.data.Iterator` objects.""" 

1097 

1098 def __init__(self, 

1099 x, 

1100 y=None, 

1101 sample_weight=None, 

1102 batch_size=None, 

1103 steps_per_epoch=None, 

1104 initial_epoch=0, 

1105 epochs=1, 

1106 shuffle=False, 

1107 class_weight=None, 

1108 max_queue_size=10, 

1109 workers=1, 

1110 use_multiprocessing=False, 

1111 model=None, 

1112 steps_per_execution=None, 

1113 distribute=True): 

1114 """Initializes a `DataHandler`. 

1115 

1116 Arguments: 

1117 x: See `Model.fit`. 

1118 y: See `Model.fit`. 

1119 sample_weight: See `Model.fit`. 

1120 batch_size: See `Model.fit`. 

1121 steps_per_epoch: See `Model.fit`. 

1122 initial_epoch: See `Model.fit`. 

1123 epochs: See `Model.fit`. 

1124 shuffle: See `Model.fit`. 

1125 class_weight: See `Model.fit`. 

1126 max_queue_size: See `Model.fit`. 

1127 workers: See `Model.fit`. 

1128 use_multiprocessing: See `Model.fit`. 

1129 model: The `Model` instance. Needed in order to correctly `build` the 

1130 `Model` using generator-like inputs (see `GeneratorDataAdapter`). 

1131 steps_per_execution: See `Model.compile`. 

1132 distribute: Whether to distribute the `tf.dataset`. 

1133 `PreprocessingLayer.adapt` does not support distributed datasets, 

1134 `Model` should always set this to `True`. 

1135 """ 

1136 

1137 self._initial_epoch = initial_epoch 

1138 self._epochs = epochs 

1139 self._insufficient_data = False 

1140 self._model = model 

1141 

1142 # `steps_per_execution_value` is the cached initial value. 

1143 # `steps_per_execution` is mutable and may be changed by the DataAdapter 

1144 # to handle partial executions. 

1145 if steps_per_execution is None: 

1146 self._steps_per_execution = 1 

1147 self._steps_per_execution_value = 1 

1148 else: 

1149 self._steps_per_execution = steps_per_execution 

1150 self._steps_per_execution_value = steps_per_execution.numpy().item() 

1151 

1152 adapter_cls = select_data_adapter(x, y) 

1153 self._adapter = adapter_cls( 

1154 x, 

1155 y, 

1156 batch_size=batch_size, 

1157 steps=steps_per_epoch, 

1158 epochs=epochs - initial_epoch, 

1159 sample_weights=sample_weight, 

1160 shuffle=shuffle, 

1161 max_queue_size=max_queue_size, 

1162 workers=workers, 

1163 use_multiprocessing=use_multiprocessing, 

1164 distribution_strategy=distribute_lib.get_strategy(), 

1165 model=model) 

1166 

1167 strategy = distribute_lib.get_strategy() 

1168 

1169 self._current_step = 0 

1170 self._step_increment = self._steps_per_execution_value - 1 

1171 self._insufficient_data = False 

1172 

1173 self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch, 

1174 class_weight, distribute) 

1175 

1176 def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch, 

1177 class_weight, distribute): 

1178 """Configure the `_dataset` and `_inferred_steps` attributes.""" 

1179 del x 

1180 dataset = self._adapter.get_dataset() 

1181 if class_weight: 

1182 dataset = dataset.map(_make_class_weight_map_fn(class_weight)) 

1183 self._inferred_steps = self._infer_steps(steps_per_epoch, dataset) 

1184 

1185 # `PreprocessingLayer.adapt` does not currently support distributed 

1186 # datasets, so we pass `distribute=False` there. 

1187 if distribute and not _is_distributed_dataset(dataset): 

1188 dataset = strategy.experimental_distribute_dataset(dataset) 

1189 self._dataset = dataset 

1190 self._validate_data_handler() 

1191 

1192 def enumerate_epochs(self): 

1193 """Yields `(epoch, tf.data.Iterator)`.""" 

1194 with self._truncate_execution_to_epoch(): 

1195 data_iterator = iter(self._dataset) 

1196 for epoch in range(self._initial_epoch, self._epochs): 

1197 if self._insufficient_data: # Set by `catch_stop_iteration`. 

1198 break 

1199 if self._adapter.should_recreate_iterator(): 

1200 data_iterator = iter(self._dataset) 

1201 yield epoch, data_iterator 

1202 self._adapter.on_epoch_end() 

1203 

1204 @contextlib.contextmanager 

1205 def _truncate_execution_to_epoch(self): 

1206 """Truncates steps per execution to at most one epoch.""" 

1207 should_truncate = ( 

1208 self._inferred_steps is not None and 

1209 self._steps_per_execution_value > self._inferred_steps) 

1210 original_value = self._steps_per_execution_value 

1211 try: 

1212 if should_truncate: 

1213 self._steps_per_execution.assign(self._inferred_steps) 

1214 self._steps_per_execution_value = self._inferred_steps 

1215 yield 

1216 finally: 

1217 if should_truncate: 

1218 self._steps_per_execution.assign(original_value) 

1219 self._steps_per_execution_value = original_value 

1220 

1221 def sync(self): 

1222 context.async_wait() 

1223 

1224 @contextlib.contextmanager 

1225 def catch_stop_iteration(self): 

1226 """Catches errors when an iterator runs out of data.""" 

1227 try: 

1228 yield 

1229 self.sync() 

1230 except (StopIteration, errors.OutOfRangeError): 

1231 if self._inferred_steps is None: 

1232 self._inferred_steps = self._current_step 

1233 else: 

1234 self._insufficient_data = True 

1235 total_epochs = self._epochs - self._initial_epoch 

1236 logging.warning( 

1237 "Your input ran out of data; interrupting training. " 

1238 "Make sure that your dataset or generator can generate at " 

1239 "least `steps_per_epoch * epochs` batches (in this case, " 

1240 "{} batches). You may need to use the repeat() function " 

1241 "when building your dataset.".format(total_epochs * 

1242 self._inferred_steps)) 

1243 

1244 def steps(self): 

1245 """Yields steps for the current epoch.""" 

1246 self._current_step = 0 

1247 # `self._inferred_steps` can be changed by `catch_stop_iteration`. 

1248 while (self._inferred_steps is None or 

1249 self._current_step < self._inferred_steps): 

1250 if self._insufficient_data: # Set by `catch_stop_iteration`. 

1251 break 

1252 

1253 can_run_full_execution = ( 

1254 self._steps_per_execution_value == 1 or 

1255 self._inferred_steps is None or 

1256 self._inferred_steps - self._current_step >= 

1257 self._steps_per_execution_value) 

1258 

1259 if can_run_full_execution: 

1260 self._step_increment = self._steps_per_execution_value - 1 

1261 yield self._current_step 

1262 self._current_step += self._steps_per_execution_value 

1263 else: 

1264 # Last partial execution. 

1265 steps_remaining = self._inferred_steps - self._current_step 

1266 self._steps_per_execution.assign(steps_remaining) 

1267 self._step_increment = steps_remaining - 1 

1268 yield self._current_step 

1269 self._current_step += steps_remaining 

1270 self._steps_per_execution.assign(self._steps_per_execution_value) 

1271 

1272 @property 

1273 def step_increment(self): 

1274 """The number to increment the step for `on_batch_end` methods.""" 

1275 return self._step_increment 

1276 

1277 @property 

1278 def inferred_steps(self): 

1279 """The inferred steps per epoch of the created `Dataset`. 

1280 

1281 This will be `None` in the case where: 

1282 

1283 (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and 

1284 (2) `steps_per_epoch` was not provided, and 

1285 (3) The first epoch of iteration has not yet completed. 

1286 

1287 Returns: 

1288 The inferred steps per epoch of the created `Dataset`. 

1289 """ 

1290 return self._inferred_steps 

1291 

1292 @property 

1293 def should_sync(self): 

1294 # Catch OutOfRangeError for Datasets of unknown size. 

1295 # This blocks until the batch has finished executing. 

1296 # TODO(b/150292341): Allow multiple async steps here. 

1297 return self._inferred_steps is None 

1298 

1299 def _log_indefinite_training_warning(self): 

1300 logging.warning("The training loop will run indefinitely since you have " 

1301 "set `steps_per_epoch=-1`. Please use batch-level " 

1302 "callbacks to save checkpoints or log training progress, " 

1303 "etc") 

1304 

1305 def _infer_steps(self, steps, dataset): 

1306 """Infers steps_per_epoch needed to loop through a dataset.""" 

1307 if steps == -1: 

1308 self._log_indefinite_training_warning() 

1309 return None 

1310 

1311 if steps is not None: 

1312 return steps 

1313 

1314 adapter_steps = self._adapter.get_size() 

1315 if adapter_steps is not None: 

1316 return adapter_steps 

1317 

1318 size = cardinality.cardinality(dataset) 

1319 if size == cardinality.INFINITE and steps is None: 

1320 raise ValueError( 

1321 "When passing an infinitely repeating dataset, please specify a " 

1322 "`steps_per_epoch` value so that epoch level " 

1323 "callbacks continue to work. The value can be arbitrary, or a number " 

1324 "that you think correctly defines the size of an epoch. " 

1325 "Epoch-level callbacks will then be called at this interval.") 

1326 if size >= 0: 

1327 return size.numpy().item() 

1328 return None 

1329 

1330 @property 

1331 def _samples(self): 

1332 return self._adapter.get_samples() 

1333 

1334 def _validate_data_handler(self): 

1335 # TODO(b/152094471): Support this with DistIter.get_next_as_optional. 

1336 if self._steps_per_execution_value > 1 and self._inferred_steps is None: 

1337 raise ValueError( 

1338 "Could not infer the size of the data. With " 

1339 "`steps_per_execution > 1`, you must specify the number of steps " 

1340 "to run.") 

1341 

1342 

1343class _ClusterCoordinatorDataHandler(DataHandler): 

1344 """A `DataHandler` that is compatible with `ClusterCoordinator`.""" 

1345 

1346 def __init__(self, x, y=None, **kwargs): 

1347 if not isinstance(x, dataset_creator.DatasetCreator): 

1348 x = self._convert_to_dataset_creator(x, y, **kwargs) 

1349 

1350 super().__init__(x=x, **kwargs) 

1351 

1352 def _convert_to_dataset_creator(self, x, y, **kwargs): 

1353 """Converts non-tf.data.Dataset to `DatasetCreator` instances.""" 

1354 

1355 def _dataset_fn(input_context): 

1356 del input_context 

1357 data_adapter_cls = select_data_adapter(x, y) 

1358 return data_adapter_cls(x=x, y=y, **kwargs).get_dataset() 

1359 

1360 # This check is needed because types like `tf.data.Dataset` don't work with 

1361 # PSS yet. So only apply this logic to the types we can support. 

1362 if (isinstance(x, _get_tensor_types()) and 

1363 isinstance(y, _get_tensor_types())): 

1364 return dataset_creator.DatasetCreator(_dataset_fn) 

1365 else: 

1366 raise NotImplementedError( 

1367 "Only `tf.keras.utils.experimental.DatasetCreator`, `tf.Tensor`, " 

1368 "numpy arrays and pandas dataframes are supported types at this " 

1369 "time.") 

1370 

1371 def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch, 

1372 class_weight, distribute): 

1373 if not isinstance(x, dataset_creator.DatasetCreator): 

1374 raise TypeError("When using `ParameterServerStrategy`, `x` must be a " 

1375 "`DatasetCreator`.") 

1376 

1377 def per_worker_dataset_fn(): 

1378 

1379 return strategy.distribute_datasets_from_function( 

1380 x, options=x.input_options) 

1381 

1382 self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access 

1383 per_worker_dataset_fn) 

1384 

1385 if steps_per_epoch == -1: 

1386 self._inferred_steps = None 

1387 self._log_indefinite_training_warning() 

1388 else: 

1389 self._inferred_steps = steps_per_epoch 

1390 

1391 def sync(self): 

1392 self._model._cluster_coordinator.join() # pylint: disable=protected-access 

1393 

1394 

1395def get_data_handler(*args, **kwargs): 

1396 if getattr(kwargs["model"], "_cluster_coordinator", None): 

1397 return _ClusterCoordinatorDataHandler(*args, **kwargs) 

1398 return DataHandler(*args, **kwargs) 

1399 

1400 

1401def _make_class_weight_map_fn(class_weight): 

1402 """Applies class weighting to a `Dataset`. 

1403 

1404 The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where 

1405 `y` must be a single `Tensor`. 

1406 

1407 Args: 

1408 class_weight: A map where the keys are integer class ids and values are 

1409 the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}` 

1410 

1411 Returns: 

1412 A function that can be used with `tf.data.Dataset.map` to apply class 

1413 weighting. 

1414 """ 

1415 class_ids = list(sorted(class_weight.keys())) 

1416 expected_class_ids = list(range(len(class_ids))) 

1417 if class_ids != expected_class_ids: 

1418 error_msg = ( 

1419 "Expected `class_weight` to be a dict with keys from 0 to one less " 

1420 "than the number of classes, found {}").format(class_weight) 

1421 raise ValueError(error_msg) 

1422 

1423 class_weight_tensor = tensor_conversion.convert_to_tensor_v2_with_dispatch( 

1424 [class_weight[int(c)] for c in class_ids] 

1425 ) 

1426 

1427 def _class_weights_map_fn(*data): 

1428 """Convert `class_weight` to `sample_weight`.""" 

1429 x, y, sw = unpack_x_y_sample_weight(data) 

1430 

1431 if nest.is_nested(y): 

1432 raise ValueError( 

1433 "`class_weight` is only supported for Models with a single output.") 

1434 

1435 if y.shape.rank > 2: 

1436 raise ValueError("`class_weight` not supported for " 

1437 "3+ dimensional targets.") 

1438 

1439 y_classes = smart_cond.smart_cond( 

1440 y.shape.rank == 2 and backend.shape(y)[1] > 1, 

1441 lambda: backend.argmax(y, axis=1), 

1442 lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64)) 

1443 

1444 cw = array_ops.gather_v2(class_weight_tensor, y_classes) 

1445 if sw is not None: 

1446 cw = math_ops.cast(cw, sw.dtype) 

1447 sw, cw = expand_1d((sw, cw)) 

1448 # `class_weight` and `sample_weight` are multiplicative. 

1449 sw = sw * cw 

1450 else: 

1451 sw = cw 

1452 

1453 return x, y, sw 

1454 

1455 return _class_weights_map_fn 

1456 

1457 
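The effect of the mapping function above on a toy single-output dataset; a sketch that calls the module-level helper directly.

```python
import numpy as np
import tensorflow as tf

x = np.zeros((4, 2), dtype="float32")
y = np.array([[0], [1], [1], [0]], dtype="int64")

ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)
ds = ds.map(_make_class_weight_map_fn({0: 1.0, 1: 3.0}))

for _, _, sw in ds:
  print(sw.numpy())  # [1. 3. 3. 1.]
```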

1458def expand_1d(data): 

1459 """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.""" 

1460 

1461 def _expand_single_1d_tensor(t): 

1462 # Leaves `CompositeTensor`s as-is. 

1463 if (isinstance(t, ops.Tensor) and 

1464 isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1): 

1465 return array_ops.expand_dims_v2(t, axis=-1) 

1466 return t 

1467 

1468 return nest.map_structure(_expand_single_1d_tensor, data) 

1469 

1470 

1471def train_validation_split(arrays, validation_split): 

1472 """Split arrays into train and validation subsets in deterministic order. 

1473 

1474 The last part of data will become validation data. 

1475 

1476 Args: 

1477 arrays: Tensors to split. Allowed inputs are arbitrarily nested structures 

1478 of Tensors and NumPy arrays. 

1479 validation_split: Float between 0 and 1. The proportion of the dataset to 

1480 include in the validation split. The rest of the dataset will be included 

1481 in the training split. 

1482 Returns: 

1483 `(train_arrays, validation_arrays)` 

1484 """ 

1485 

1486 def _can_split(t): 

1487 tensor_types = _get_tensor_types() 

1488 return isinstance(t, tensor_types) or t is None 

1489 

1490 flat_arrays = nest.flatten(arrays) 

1491 unsplitable = [type(t) for t in flat_arrays if not _can_split(t)] 

1492 if unsplitable: 

1493 raise ValueError( 

1494 "`validation_split` is only supported for Tensors or NumPy " 

1495 "arrays, found following types in the input: {}".format(unsplitable)) 

1496 

1497 if all(t is None for t in flat_arrays): 

1498 return arrays, arrays 

1499 

1500 first_non_none = None 

1501 for t in flat_arrays: 

1502 if t is not None: 

1503 first_non_none = t 

1504 break 

1505 

1506 # Assumes all arrays have the same batch shape or are `None`. 

1507 batch_dim = int(first_non_none.shape[0]) 

1508 split_at = int(math.floor(batch_dim * (1. - validation_split))) 

1509 

1510 if split_at == 0 or split_at == batch_dim: 

1511 raise ValueError( 

1512 "Training data contains {batch_dim} samples, which is not sufficient " 

1513 "to split it into a validation and training set as specified by " 

1514 "`validation_split={validation_split}`. Either provide more data, or a " 

1515 "different value for the `validation_split` argument." .format( 

1516 batch_dim=batch_dim, validation_split=validation_split)) 

1517 

1518 def _split(t, start, end): 

1519 if t is None: 

1520 return t 

1521 return t[start:end] 

1522 

1523 train_arrays = nest.map_structure( 

1524 functools.partial(_split, start=0, end=split_at), arrays) 

1525 val_arrays = nest.map_structure( 

1526 functools.partial(_split, start=split_at, end=batch_dim), arrays) 

1527 

1528 return train_arrays, val_arrays 

1529 

1530 
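A quick sketch of the deterministic split above: the trailing `validation_split` fraction of samples becomes the validation set.

```python
import numpy as np

x = np.arange(10).reshape(10, 1)
y = np.arange(10)

(x_train, y_train), (x_val, y_val) = train_validation_split(
    (x, y), validation_split=0.2)
print(len(x_train), len(x_val))  # 8 2
print(y_val)                     # [8 9]
```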

1531@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[]) 

1532def unpack_x_y_sample_weight(data): 

1533 """Unpacks user-provided data tuple. 

1534 

1535 This is a convenience utility to be used when overriding 

1536 `Model.train_step`, `Model.test_step`, or `Model.predict_step`. 

1537 This utility makes it easy to support data of the form `(x,)`, 

1538 `(x, y)`, or `(x, y, sample_weight)`. 

1539 

1540 Standalone usage: 

1541 

1542 >>> features_batch = tf.ones((10, 5)) 

1543 >>> labels_batch = tf.zeros((10, 5)) 

1544 >>> data = (features_batch, labels_batch) 

1545 >>> # `y` and `sample_weight` will default to `None` if not provided. 

1546 >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) 

1547 >>> sample_weight is None 

1548 True 

1549 

1550 Example in overridden `Model.train_step`: 

1551 

1552 ```python 

1553 class MyModel(tf.keras.Model): 

1554 

1555 def train_step(self, data): 

1556 # If `sample_weight` is not provided, all samples will be weighted 

1557 # equally. 

1558 x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) 

1559 

1560 with tf.GradientTape() as tape: 

1561 y_pred = self(x, training=True) 

1562 loss = self.compiled_loss( 

1563 y, y_pred, sample_weight, regularization_losses=self.losses) 

1564 trainable_variables = self.trainable_variables 

1565 gradients = tape.gradient(loss, trainable_variables) 

1566 self.optimizer.apply_gradients(zip(gradients, trainable_variables)) 

1567 

1568 self.compiled_metrics.update_state(y, y_pred, sample_weight) 

1569 return {m.name: m.result() for m in self.metrics} 

1570 ``` 

1571 

1572 Args: 

1573 data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. 

1574 

1575 Returns: 

1576 The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not 

1577 provided. 

1578 """ 

1579 if not isinstance(data, tuple): 

1580 return (data, None, None) 

1581 elif len(data) == 1: 

1582 return (data[0], None, None) 

1583 elif len(data) == 2: 

1584 return (data[0], data[1], None) 

1585 elif len(data) == 3: 

1586 return (data[0], data[1], data[2]) 

1587 else: 

1588 error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, " 

1589 "or `(x, y, sample_weight)`, found: {}").format(data) 

1590 raise ValueError(error_msg) 

1591 

1592 

1593@keras_export("keras.utils.pack_x_y_sample_weight", v1=[]) 

1594def pack_x_y_sample_weight(x, y=None, sample_weight=None): 

1595 """Packs user-provided data into a tuple. 

1596 

1597 This is a convenience utility for packing data into the tuple formats 

1598 that `Model.fit` uses. 

1599 

1600 Standalone usage: 

1601 

1602 >>> x = tf.ones((10, 1)) 

1603 >>> data = tf.keras.utils.pack_x_y_sample_weight(x) 

1604 >>> isinstance(data, tf.Tensor) 

1605 True 

1606 >>> y = tf.ones((10, 1)) 

1607 >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y) 

1608 >>> isinstance(data, tuple) 

1609 True 

1610 >>> x, y = data 

1611 

1612 Args: 

1613 x: Features to pass to `Model`. 

1614 y: Ground-truth targets to pass to `Model`. 

1615 sample_weight: Sample weight for each element. 

1616 

1617 Returns: 

1618 Tuple in the format used in `Model.fit`. 

1619 """ 

1620 if y is None: 

1621 # For single x-input, we do no tuple wrapping since in this case 

1622 # there is no ambiguity. This also makes NumPy and Dataset 

1623 # consistent in that the user does not have to wrap their Dataset 

1624 # data in an unnecessary tuple.

1625 if not nest.is_nested(x): 

1626 return x 

1627 else: 

1628 return (x,) 

1629 elif sample_weight is None: 

1630 return (x, y) 

1631 else: 

1632 return (x, y, sample_weight) 

1633 

1634 

1635def single_batch_iterator(strategy, 

1636 x, 

1637 y=None, 

1638 sample_weight=None, 

1639 class_weight=None): 

1640 """Creates a single-batch dataset.""" 

1641 x, y, sample_weight = _process_tensorlike((x, y, sample_weight)) 

1642 if y is None: 

1643 data = (x,) 

1644 elif sample_weight is None: 

1645 data = (x, y) 

1646 else: 

1647 data = (x, y, sample_weight) 

1648 

1649 _check_data_cardinality(data) 

1650 dataset = dataset_ops.DatasetV2.from_tensors(data) 

1651 if class_weight: 

1652 dataset = dataset.map(_make_class_weight_map_fn(class_weight)) 

1653 dataset = strategy.experimental_distribute_dataset(dataset) 

1654 return iter(dataset) 

1655 

1656 

1657def _check_data_cardinality(data): 

1658 num_samples = set(int(i.shape[0]) for i in nest.flatten(data)) 

1659 if len(num_samples) > 1: 

1660 msg = "Data cardinality is ambiguous:\n" 

1661 for label, single_data in zip(["x", "y", "sample_weight"], data): 

1662 msg += " {} sizes: {}\n".format( 

1663 label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data))) 

1664 msg += "Make sure all arrays contain the same number of samples." 

1665 raise ValueError(msg) 

1666 

1667 

1668def _get_tensor_types(): 

1669 try: 

1670 import pandas as pd # pylint: disable=g-import-not-at-top 

1671 

1672 return (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame) 

1673 except ImportError: 

1674 return (ops.Tensor, np.ndarray) 

1675 

1676 

1677def _is_scipy_sparse(x): 

1678 try: 

1679 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 

1680 

1681 return issparse(x) 

1682 except ImportError: 

1683 return False 

1684 

1685 

1686def _scipy_sparse_to_sparse_tensor(t): 

1687 """Converts a SciPy sparse matrix to a SparseTensor.""" 

1688 sparse_coo = t.tocoo() 

1689 row, col = sparse_coo.row, sparse_coo.col 

1690 data, shape = sparse_coo.data, sparse_coo.shape 

1691 if issubclass(data.dtype.type, np.floating): 

1692 data = data.astype(backend.floatx()) 

1693 indices = np.concatenate( 

1694 (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1) 

1695 return sparse_tensor.SparseTensor(indices, data, shape) 

1696 

1697 
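A sketch of the SciPy-to-`SparseTensor` conversion above, assuming `scipy` is installed; indices come back as `[row, col]` pairs and floating-point data is cast to `backend.floatx()`.

```python
import numpy as np
from scipy.sparse import coo_matrix

m = coo_matrix((np.array([1.0, 2.0]), (np.array([0, 2]), np.array([1, 0]))),
               shape=(3, 3))
st = _scipy_sparse_to_sparse_tensor(m)
print(st.indices.numpy())      # [[0 1] [2 0]]
print(st.values.numpy())       # [1. 2.]
print(st.dense_shape.numpy())  # [3 3]
```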

1698def _is_distributed_dataset(ds): 

1699 return isinstance(ds, input_lib.DistributedDatasetInterface)