1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Adapter module that convert different input data objects into tf.dataset."""
17import abc
18import contextlib
19import functools
20import itertools
21import math
22import random
24import numpy as np
25import tensorflow.compat.v2 as tf
27from keras.src import backend
28from keras.src.distribute import distributed_training_utils
29from keras.src.engine import training_utils
30from keras.src.utils import data_utils
31from keras.src.utils import dataset_creator
32from keras.src.utils import tf_utils
34# isort: off
35from tensorflow.python.distribute.input_lib import (
36 DistributedDataset,
37)
38from tensorflow.python.eager import context
39from tensorflow.python.framework import type_spec
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.util.tf_export import keras_export
42from tensorflow.python.data.ops import (
43 from_sparse_tensor_slices_op,
44)
45from tensorflow.python.data.ops import from_generator_op
46from tensorflow.python.data.ops import range_op
47from tensorflow.python.data.ops import from_tensors_op
48from tensorflow.python.data.ops import from_tensor_slices_op
50try:
51 import pandas as pd
52except ImportError:
53 pd = None
55keras_data_adapter_gauge = tf.__internal__.monitoring.BoolGauge(
56 "/tensorflow/api/keras/data_adapters", "keras data adapter usage", "method"
57)
60class DataAdapter(object, metaclass=abc.ABCMeta):
61 """Base class for input data adapter.
63 In TF 2.0, tf.data is the preferred API for users to feed in data. In
64 order to simplify the training code path, all input data objects will be
65 converted to `tf.data.Dataset` if possible.
67 Note that since this class is mainly targeted at TF 2.0, it might make a
68 lot of assumptions under the hood, e.g. an eager context by default, a
69 distribution strategy, etc. In the meantime, some legacy feature support
70 might be dropped, e.g. the `Iterator` from the dataset API in v1.
72 Sample usage of this class:
74 ```
75 x = tf.data.Dataset.range(100)
76 adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
77 applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
78 if len(applicable_adapters) != 1:
79 raise ValueError("Expect only one adapter class to handle the input")
81 dataset = applicable_adapters[0](x).get_dataset()
82 for data in dataset:
83 # training
84 ```
85 """
87 @staticmethod
88 def can_handle(x, y=None):
89 """Whether the current DataAdapter could handle the input x and y.
91 Structure-wise, x and y can be a single object, a list of objects if
92 there are multiple inputs/outputs, or a dictionary of objects when the
93 inputs/outputs are named.
95 Args:
96 x: input features.
97 y: target labels. Note that y could be None in the case of prediction.
99 Returns:
100 boolean
101 """
102 raise NotImplementedError
104 @abc.abstractmethod
105 def __init__(self, x, y=None, **kwargs):
106 """Create a DataAdapter based on data inputs.
108 The caller must make sure to call `can_handle()` first before invoking
109 this method. Providing an unsupported data type will result in
110 unexpected behavior.
112 Args:
113 x: input features.
114 y: target labels. Note that y could be None in the case of prediction.
115 **kwargs: Other keyword arguments for DataAdapter during the
116 construction of the tf.data.Dataset. For example:
117 - Numpy data might have `sample_weights` which will be used for
118 weighting the loss function during training.
119 - Numpy data might need to have a `batch_size` parameter when
120 constructing the dataset and iterator.
121 - Certain inputs might need to be distribution-strategy aware. When
122 `distribution_strategy` is passed, the created dataset needs to
123 respect the strategy.
124 DataAdapter might choose to ignore any keyword argument it does not
125 use, or raise an exception if any required argument is not
126 provided.
127 """
128 if not self.can_handle(x, y):
129 raise ValueError(f"{self.__class__} Cannot handle input {x}, {y}")
131 @abc.abstractmethod
132 def get_dataset(self):
133 """Get a dataset instance for the current DataAdapter.
135 Note that the returned dataset does not repeat across epochs, so the
136 caller might need to create a new iterator for the same dataset at the
137 beginning of each epoch. This behavior might change in the future.
139 Returns:
140 A `tf.data.Dataset`. The caller might use the dataset in different
141 contexts, e.g. `iter(dataset)` in eager mode to get values directly,
142 or in graph mode, provide the iterator tensor to the Keras model.
143 """
144 raise NotImplementedError
146 @abc.abstractmethod
147 def get_size(self):
148 """Return the size (number of batches) for the dataset created.
150 For certain types of data input, the number of batches is known, e.g.
151 for Numpy data the size is (number_of_elements / batch_size), whereas
152 for a dataset or python generator the size is unknown since it may or
153 may not have an end state.
155 Returns:
156 int, the number of batches for the dataset, or None if it is unknown.
157 The caller could use this to control the training loop, show a
158 progress bar, or handle an unexpected StopIteration error.
159 """
160 raise NotImplementedError
162 @abc.abstractmethod
163 def batch_size(self):
164 """Return the batch size of the dataset created.
166 For certain types of data input the batch size is known, and even
167 required, e.g. for a numpy array; whereas for a dataset the batch size
168 is unknown unless we take a peek.
170 Returns:
171 int, the batch size of the dataset, or None if it is unknown.
172 """
173 raise NotImplementedError
175 def representative_batch_size(self):
176 """Return a representative size for batches in the dataset.
178 This is not guaranteed to be the batch size for all batches in the
179 dataset. It just needs to be a rough approximation for batch sizes in
180 the dataset.
182 Returns:
183 int, a representative size for batches found in the dataset,
184 or None if it is unknown.
185 """
186 return self.batch_size()
188 @abc.abstractmethod
189 def has_partial_batch(self):
190 """Whether the dataset has partial batch at the end."""
191 raise NotImplementedError
193 @abc.abstractmethod
194 def partial_batch_size(self):
195 """The size of the final partial batch for dataset.
197 Will return None if has_partial_batch is False or batch_size is None.
198 """
199 raise NotImplementedError
201 @abc.abstractmethod
202 def should_recreate_iterator(self):
203 """Returns whether a new iterator should be created every epoch."""
204 raise NotImplementedError
206 def get_samples(self):
207 """Returns number of samples in the data, or `None`."""
208 if not self.get_size() or not self.batch_size():
209 return None
210 total_sample = self.get_size() * self.batch_size()
211 if self.has_partial_batch():
212 total_sample -= self.batch_size() - self.partial_batch_size()
213 return total_sample
215 def on_epoch_end(self):
216 """A hook called after each epoch."""
217 pass
220class TensorLikeDataAdapter(DataAdapter):
221 """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
223 @staticmethod
224 def can_handle(x, y=None):
225 # TODO(kaftan): Check performance implications of using a flatten
226 # here for other types of inputs.
227 flat_inputs = tf.nest.flatten(x)
228 if y is not None:
229 flat_inputs += tf.nest.flatten(y)
231 tensor_types = _get_tensor_types()
233 def _is_tensor(v):
234 if isinstance(v, tensor_types):
235 return True
236 return False
238 return all(_is_tensor(v) for v in flat_inputs)
240 def __init__(
241 self,
242 x,
243 y=None,
244 sample_weights=None,
245 sample_weight_modes=None,
246 batch_size=None,
247 epochs=1,
248 steps=None,
249 shuffle=False,
250 **kwargs,
251 ):
252 super().__init__(x, y, **kwargs)
253 x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
254 sample_weight_modes = broadcast_sample_weight_modes(
255 sample_weights, sample_weight_modes
256 )
258 # If sample_weights are not specified for an output use 1.0 as weights.
259 (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
260 y, sample_weights, sample_weight_modes, check_all_flat=True
261 )
263 inputs = pack_x_y_sample_weight(x, y, sample_weights)
265 num_samples = set(
266 int(i.shape[0]) for i in tf.nest.flatten(inputs)
267 ).pop()
268 _check_data_cardinality(inputs)
270 # If batch_size is not passed but steps is, calculate from the input
271 # data. Defaults to `32` for backwards compatibility.
272 if not batch_size:
273 batch_size = int(math.ceil(num_samples / steps)) if steps else 32
275 self._size = int(math.ceil(num_samples / batch_size))
276 self._batch_size = batch_size
278 num_full_batches = int(num_samples // batch_size)
279 self._partial_batch_size = num_samples % batch_size
281 if isinstance(shuffle, str):
282 shuffle = shuffle.lower()
284 self._shuffle = shuffle
285 # Vectorized version of shuffle.
286 # This is a performance improvement over using `from_tensor_slices`.
287 # The indices of the data are shuffled and batched, and these indices
288 # are then zipped with the data and used to extract a batch of the data
289 # at each step. The performance improvements here come from:
290 # 1. vectorized batch using gather
291 # 2. parallelized map
292 # 3. pipelined permutation generation
293 # 4. optimized permutation batching
294 # 5. disabled static optimizations
296 indices_dataset = tf.data.Dataset.range(1)
297 if shuffle != "batch":
298 indices_dataset = indices_dataset.repeat(epochs)
300 def permutation(_):
301 # It turns out to be more performant to make a new set of indices
302 # rather than reusing the same range Tensor. (presumably because of
303 # buffer forwarding.)
304 indices = tf.range(num_samples, dtype=tf.int64)
305 if shuffle and shuffle != "batch":
306 indices = tf.random.shuffle(indices)
307 return indices
309 # We prefetch a single element. Computing large permutations can take
310 # quite a while so we don't want to wait for prefetching over an epoch
311 # boundary to trigger the next permutation. On the other hand, too many
312 # simultaneous shuffles can contend on a hardware level and degrade all
313 # performance.
314 indices_dataset = indices_dataset.map(permutation).prefetch(1)
316 def slice_batch_indices(indices):
317 """Convert a Tensor of indices into a dataset of batched indices.
319 This step can be accomplished in several ways. The most natural is
320 to slice the Tensor in a Dataset map. (With a condition on the upper
321 index to handle the partial batch.) However it turns out that
322 coercing the Tensor into a shape which is divisible by the batch
323 size (and handling the last partial batch separately) allows for a
324 much more favorable memory access pattern and improved performance.
326 Args:
327 indices: Tensor which determines the data order for an entire
328 epoch.
330 Returns:
331 A Dataset of batched indices.
332 """
333 num_in_full_batch = num_full_batches * batch_size
334 first_k_indices = tf.slice(indices, [0], [num_in_full_batch])
335 first_k_indices = tf.reshape(
336 first_k_indices, [num_full_batches, batch_size]
337 )
339 flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices)
340 if self._partial_batch_size:
341 index_remainder = tf.data.Dataset.from_tensors(
342 tf.slice(
343 indices, [num_in_full_batch], [self._partial_batch_size]
344 )
345 )
346 flat_dataset = flat_dataset.concatenate(index_remainder)
348 if shuffle == "batch":
349 # 1024 is a magic constant that has not been properly evaluated
350 flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
351 return flat_dataset
353 indices_dataset = indices_dataset.flat_map(slice_batch_indices)
355 dataset = self.slice_inputs(indices_dataset, inputs)
357 if shuffle == "batch":
359 def shuffle_batch(*batch):
360 return tf.nest.map_structure(tf.random.shuffle, batch)
362 dataset = dataset.map(shuffle_batch)
364 options = tf.data.Options()
365 options.experimental_distribute.auto_shard_policy = (
366 tf.data.experimental.AutoShardPolicy.DATA
367 )
368 dataset = dataset.with_options(options)
370 self._dataset = dataset.prefetch(tf.data.AUTOTUNE)
372 def slice_inputs(self, indices_dataset, inputs):
373 """Slice inputs into a Dataset of batches.
375 Given a Dataset of batch indices and the unsliced inputs,
376 this step slices the inputs in a parallelized fashion
377 and produces a dataset of input batches.
379 Args:
380 indices_dataset: A Dataset of batched indices
381 inputs: A python data structure that contains the inputs, targets,
382 and possibly sample weights.
384 Returns:
385 A Dataset of input batches matching the batch indices.
386 """
387 dataset = tf.data.Dataset.zip(
388 (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat())
389 )
391 def grab_batch(i, data):
392 return tf.nest.map_structure(
393 lambda d: tf.gather(d, i, axis=0), data
394 )
396 dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
398 # Default optimizations are disabled to avoid the overhead of
399 # (unnecessary) input pipeline graph serialization and deserialization
400 options = tf.data.Options()
401 options.experimental_optimization.apply_default_optimizations = False
402 if self._shuffle:
403 # See b/141490660 for more details.
404 options.experimental_external_state_policy = (
405 tf.data.experimental.ExternalStatePolicy.IGNORE
406 )
407 dataset = dataset.with_options(options)
408 return dataset
410 def get_dataset(self):
411 return self._dataset
413 def get_size(self):
414 return self._size
416 def batch_size(self):
417 return self._batch_size
419 def has_partial_batch(self):
420 return self._partial_batch_size > 0
422 def partial_batch_size(self):
423 return self._partial_batch_size or None
425 def should_recreate_iterator(self):
426 # An infinite dataset is always created here.
427 return False
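# A minimal sketch (not part of the original module) of the vectorized
# batching pattern `TensorLikeDataAdapter` uses above: shuffle and batch the
# *indices*, then `tf.gather` the corresponding rows from the in-memory
# tensors. `_example_gather_batches` is a hypothetical helper for
# illustration only.
def _example_gather_batches(x, y, batch_size=32):
    num_samples = int(tf.nest.flatten(x)[0].shape[0])
    indices = tf.random.shuffle(tf.range(num_samples, dtype=tf.int64))
    index_batches = tf.data.Dataset.from_tensor_slices(indices).batch(batch_size)
    data = tf.data.Dataset.from_tensors((x, y)).repeat()

    def grab(i, d):
        # Vectorized slice: one gather per batch instead of per-example slicing.
        return tf.nest.map_structure(lambda t: tf.gather(t, i, axis=0), d)

    return tf.data.Dataset.zip((index_batches, data)).map(grab)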
430class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
431 """Adapter that handles array-like data without forcing it into memory.
433 This adapter handles array-like datasets that may be too big to fully
434 fit into memory.
436 Specifically, this adapter handles any Python class which implements:
437 `__getitem__`, `__len__`, `shape`, and `dtype` with the same meanings
438 as Numpy, but it ignores any case where all the inputs are Tensors or Numpy
439 arrays (because that case is handled by the base TensorLikeDataAdapter).
441 It ignores scipy sparse matrices and Composite Tensors because those are
442 handled by the CompositeTensorDataAdapter.
444 It also does not handle lists/tuples of scalars, because those are handled
445 by the ListsOfScalarsDataAdapter.
446 """
448 @staticmethod
449 def can_handle(x, y=None):
450 flat_inputs = tf.nest.flatten(x)
451 if y is not None:
452 flat_inputs += tf.nest.flatten(y)
454 def _is_array_like(v):
455 """Return True if v is a Tensor, array, or is array-like."""
456 return (
457 hasattr(v, "__getitem__")
458 and hasattr(v, "shape")
459 and hasattr(v, "dtype")
460 and hasattr(v, "__len__")
461 )
463 if not TensorLikeDataAdapter.can_handle(
464 x, y
465 ) and not CompositeTensorDataAdapter.can_handle(x, y):
466 return all(_is_array_like(v) for v in flat_inputs)
467 else:
468 return False
470 def __init__(self, *args, **kwargs):
471 logging.warning(
472 "Keras is training/fitting/evaluating on array-like data. Keras "
473 "may not be optimized for this format, so if your input data "
474 "format is supported by TensorFlow I/O "
475 "(https://github.com/tensorflow/io) we recommend using that to "
476 "load a Dataset instead."
477 )
479 super().__init__(*args, **kwargs)
481 def slice_inputs(self, indices_dataset, inputs):
482 """Slice inputs into a Dataset of batches.
484 Given a Dataset of batch indices and the unsliced inputs,
485 this step slices the inputs in a parallelized fashion
486 and produces a dataset of input batches.
488 Args:
489 indices_dataset: A Dataset of batched indices
490 inputs: A python data structure that contains the inputs, targets,
491 and possibly sample weights.
493 Returns:
494 A Dataset of input batches matching the batch indices.
495 """
496 flat_inputs = tf.nest.flatten(inputs)
498 def dynamic_shape_like(t):
499 shape = list(t.shape)
500 shape[0] = None
501 return tuple(shape)
503 flat_dtypes = [inp.dtype for inp in flat_inputs]
504 contiguous = True
505 if self._shuffle and self._shuffle != "batch":
506 contiguous = False
508 def grab_batch(indices):
509 """Grab a batch of data from the inputs."""
510 # This uses a py_function to avoid converting the array-like
511 # into a Tensor before slicing it, because converting the array-like
512 # to a Tensor may force it into memory.
513 def py_method(ind):
514 def slice_array(data):
515 return training_utils.slice_arrays(
516 data, ind.numpy(), contiguous=contiguous
517 )
519 return [slice_array(inp) for inp in flat_inputs]
521 flat_out = tf.py_function(py_method, [indices], flat_dtypes)
522 for v, original_inp in zip(flat_out, flat_inputs):
523 v.set_shape(dynamic_shape_like(original_inp))
524 return tf.nest.pack_sequence_as(inputs, flat_out)
526 dataset = indices_dataset.map(
527 grab_batch, num_parallel_calls=tf.data.AUTOTUNE
528 )
530 return dataset
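# A minimal sketch (illustration only) of an "array-like" object that
# satisfies the protocol `GenericArrayLikeDataAdapter` describes above
# (`__getitem__`, `__len__`, `shape`, `dtype`) without materializing the full
# data in memory. `_ExampleLazyColumn` is hypothetical.
class _ExampleLazyColumn:
    def __init__(self, num_rows, num_cols):
        self.shape = (num_rows, num_cols)
        self.dtype = np.dtype("float32")

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, idx):
        # Rows are produced on demand instead of being stored up front.
        selected = np.arange(self.shape[0])[idx]
        return np.ones((np.size(selected), self.shape[1]), dtype=self.dtype)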
533class DatasetCreatorAdapter(DataAdapter):
534 """Adapter that handles dataset functions."""
536 def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs):
537 super().__init__(x, **kwargs)
539 if not isinstance(x, dataset_creator.DatasetCreator):
540 raise TypeError(
541 "The input of a `DatasetCreatorAdapter` should be a "
542 "`DatasetCreator` but it received type {}.".format(type(x))
543 )
544 if steps is None:
545 if not kwargs.get("pss_evaluation_shards"):
546 raise ValueError(
547 "When using a "
548 "`tf.keras.utils.experimental.DatasetCreator`, "
549 "`steps_per_epoch`, `validation_steps`, `steps`, or "
550 "`pss_evaluation_shards` argument must be provided in "
551 "`Model.fit`, `Model.evaluate`, or `Model.predict`."
552 )
553 self.dataset_creator = x
554 self.steps = steps
555 self.strategy = distribution_strategy
557 @staticmethod
558 def can_handle(x, y=None):
559 if isinstance(x, dataset_creator.DatasetCreator):
560 assert y is None
561 return True
563 def should_recreate_iterator(self):
564 # We expect users to shuffle the dataset in their `dataset_fn` supplied
565 # to `DatasetCreator`. Since that is a buffered shuffle, we intend to
566 # not reset the dataset so the batches that are not shuffled can still
567 # be pulled.
568 return False
570 def get_size(self):
571 return None # To be inferred by `DataHandler`.
573 def get_dataset(self):
574 return self.strategy.distribute_datasets_from_function(
575 self.dataset_creator, options=self.dataset_creator.input_options
576 )
578 def batch_size(self):
579 raise NotImplementedError()
581 def has_partial_batch(self):
582 raise NotImplementedError()
584 def partial_batch_size(self):
585 raise NotImplementedError()
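# A minimal sketch (illustration only) of the input `DatasetCreatorAdapter`
# expects: a `DatasetCreator` wrapping a `dataset_fn` that builds a sharded,
# shuffled, repeated dataset per worker. Shapes/values here are assumptions.
def _example_dataset_creator():
    def dataset_fn(input_context):
        batch_size = input_context.get_per_replica_batch_size(64)
        ds = tf.data.Dataset.from_tensor_slices(
            (tf.ones((128, 4)), tf.ones((128, 1)))
        )
        ds = ds.shard(
            input_context.num_input_pipelines, input_context.input_pipeline_id
        )
        return ds.shuffle(128).repeat().batch(batch_size)

    # Because the dataset repeats, a `steps` argument (e.g. `steps_per_epoch`)
    # must be passed to `Model.fit`/`evaluate`/`predict`, as enforced above.
    return dataset_creator.DatasetCreator(dataset_fn)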
588class CompositeTensorDataAdapter(DataAdapter):
589 """Adapter that handles composite tensor."""
591 @staticmethod
592 def can_handle(x, y=None):
593 flat_inputs = tf.nest.flatten(x)
594 if y is not None:
595 flat_inputs += tf.nest.flatten(y)
597 def _is_composite(v):
598 # Dataset/iterator/DistributedDataset inherits from CompositeTensor
599 # but should be handled by DatasetAdapter and GeneratorAdapter.
600 if (
601 tf_utils.is_extension_type(v)
602 and not isinstance(v, (tf.data.Dataset, tf.data.Iterator))
603 and not _is_distributed_dataset(v)
604 ):
605 return True
606 # Support Scipy sparse tensors if scipy is installed
607 return _is_scipy_sparse(v)
609 def _is_tensor_or_composite(v):
610 if isinstance(v, (tf.Tensor, np.ndarray)):
611 return True
612 return _is_composite(v)
614 return any(_is_composite(v) for v in flat_inputs) and all(
615 _is_tensor_or_composite(v) for v in flat_inputs
616 )
618 def __init__(
619 self,
620 x,
621 y=None,
622 sample_weights=None,
623 sample_weight_modes=None,
624 batch_size=None,
625 steps=None,
626 shuffle=False,
627 **kwargs,
628 ):
629 super().__init__(x, y, **kwargs)
630 x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
631 sample_weight_modes = broadcast_sample_weight_modes(
632 sample_weights, sample_weight_modes
633 )
635 # If sample_weights are not specified for an output use 1.0 as weights.
636 (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
637 y, sample_weights, sample_weight_modes, check_all_flat=True
638 )
640 inputs = pack_x_y_sample_weight(x, y, sample_weights)
642 dataset = tf.data.Dataset.from_tensor_slices(inputs)
643 num_samples = int(tf.nest.flatten(x)[0].shape[0])
644 if shuffle:
645 dataset = dataset.shuffle(num_samples)
647 # If batch_size is not passed but steps is, calculate from the input
648 # data. Defaults to `32` for backwards compatibility.
649 if not batch_size:
650 batch_size = int(math.ceil(num_samples / steps)) if steps else 32
652 dataset = dataset.batch(batch_size)
653 self._size = int(math.ceil(num_samples / batch_size))
654 self._batch_size = batch_size
655 self._has_partial_batch = self._size != (num_samples // batch_size)
657 self._partial_batch_size = None
658 if self._has_partial_batch:
659 self._partial_batch_size = (
660 num_samples - (self._size - 1) * self._batch_size
661 )
663 self._dataset = dataset.prefetch(tf.data.AUTOTUNE)
665 def get_dataset(self):
666 return self._dataset
668 def get_size(self):
669 return self._size
671 def batch_size(self):
672 return self._batch_size
674 def has_partial_batch(self):
675 return self._has_partial_batch
677 def partial_batch_size(self):
678 return self._partial_batch_size
680 def should_recreate_iterator(self):
681 return True
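# A minimal sketch (illustration only) of inputs routed to
# `CompositeTensorDataAdapter`: at least one extension type (here a
# `tf.RaggedTensor`) mixed with dense tensors/arrays.
def _example_composite_inputs():
    x = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
    y = np.array([0, 1, 0])
    assert CompositeTensorDataAdapter.can_handle(x, y)
    return CompositeTensorDataAdapter(x, y, batch_size=2).get_dataset()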
684class ListsOfScalarsDataAdapter(DataAdapter):
685 """Adapter that handles lists of scalars and lists of lists of scalars."""
687 @staticmethod
688 def can_handle(x, y=None):
689 handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x)
690 handles_y = True
691 if y is not None:
692 handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y)
693 return handles_x and handles_y
695 @staticmethod
696 def _is_list_of_scalars(inp):
697 if isinstance(inp, (float, int, str, bytes, bytearray)):
698 return True
699 if isinstance(inp, (list, tuple)) and inp:
700 return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
701 return False
703 def __init__(
704 self,
705 x,
706 y=None,
707 sample_weights=None,
708 sample_weight_modes=None,
709 batch_size=None,
710 shuffle=False,
711 **kwargs,
712 ):
713 super().__init__(x, y, **kwargs)
714 x = np.asarray(x)
715 if y is not None:
716 y = np.asarray(y)
717 if sample_weights is not None:
718 sample_weights = np.asarray(sample_weights)
719 sample_weight_modes = broadcast_sample_weight_modes(
720 sample_weights, sample_weight_modes
721 )
723 self._internal_adapter = TensorLikeDataAdapter(
724 x,
725 y=y,
726 sample_weights=sample_weights,
727 sample_weight_modes=sample_weight_modes,
728 batch_size=batch_size,
729 shuffle=shuffle,
730 **kwargs,
731 )
733 def get_dataset(self):
734 return self._internal_adapter.get_dataset()
736 def get_size(self):
737 return self._internal_adapter.get_size()
739 def batch_size(self):
740 return self._internal_adapter.batch_size()
742 def has_partial_batch(self):
743 return self._internal_adapter.has_partial_batch()
745 def partial_batch_size(self):
746 return self._internal_adapter.partial_batch_size()
748 def should_recreate_iterator(self):
749 return True
752class DatasetAdapter(DataAdapter):
753 """Adapter that handles `tf.data.Dataset`."""
755 @staticmethod
756 def can_handle(x, y=None):
757 return isinstance(
758 x, (tf.compat.v1.data.Dataset, tf.data.Dataset)
759 ) or _is_distributed_dataset(x)
761 def __init__(self, x, y=None, sample_weights=None, steps=None, **kwargs):
762 super().__init__(x, y, **kwargs)
763 # Note that the dataset instance is immutable, so it's fine to reuse
764 # the user-provided dataset.
765 self._dataset = x
767 # The user-provided steps.
768 self._user_steps = steps
770 self._validate_args(
771 y, sample_weights, steps, kwargs.get("pss_evaluation_shards")
772 )
774 def get_dataset(self):
775 return self._dataset
777 def get_size(self):
778 return # Inferred in `DataHandler`.
780 def batch_size(self):
781 return None
783 def has_partial_batch(self):
784 return False
786 def partial_batch_size(self):
787 return None
789 def should_recreate_iterator(self):
790 # Since DistributedDatasets have no cardinality, the user must provide
791 # all steps that need to be run, calling `.repeat()` as needed.
792 if _is_distributed_dataset(self._dataset):
793 return False
795 # If user doesn't supply `steps`, or if they supply `steps` that
796 # exactly equals the size of the `Dataset`, create a new iterator
797 # each epoch.
798 return (
799 self._user_steps is None
800 or tf.data.experimental.cardinality(self._dataset).numpy()
801 == self._user_steps
802 )
804 def _validate_args(self, y, sample_weights, steps, pss_evaluation_shards):
805 """Validates `__init__` arguments."""
806 # Arguments that shouldn't be passed.
807 if not is_none_or_empty(y):
808 raise ValueError(
809 "`y` argument is not supported when using dataset as input."
810 )
811 if not is_none_or_empty(sample_weights):
812 raise ValueError(
813 "`sample_weight` argument is not supported when using "
814 "dataset as input."
815 )
817 if steps is None:
818 if _is_distributed_dataset(self._dataset):
819 if not pss_evaluation_shards:
820 raise ValueError(
821 "When providing a distributed dataset, you must "
822 "specify the number of steps to run."
823 )
824 else:
825 size = tf.data.experimental.cardinality(self._dataset).numpy()
826 if size == tf.data.experimental.INFINITE_CARDINALITY:
827 if pss_evaluation_shards:
828 raise ValueError(
829 "When performing exact evaluation, the dataset "
830 "must be finite. Make sure not to call `repeat()` "
831 "on your dataset."
832 )
833 else:
834 raise ValueError(
835 "When providing an infinite dataset, you must "
836 "specify the number of steps to run (if you did "
837 "not intend to create an infinite dataset, make "
838 "sure to not call `repeat()` on the dataset)."
839 )
842class GeneratorDataAdapter(DataAdapter):
843 """Adapter that handles python generators and iterators."""
845 @staticmethod
846 def can_handle(x, y=None):
847 return (
848 (hasattr(x, "__next__") or hasattr(x, "next"))
849 and hasattr(x, "__iter__")
850 and not isinstance(x, data_utils.Sequence)
851 )
853 def __init__(
854 self,
855 x,
856 y=None,
857 sample_weights=None,
858 workers=1,
859 use_multiprocessing=False,
860 max_queue_size=10,
861 model=None,
862 **kwargs,
863 ):
864 # Generators should never shuffle as exhausting the generator in order
865 # to shuffle the batches is inefficient.
866 kwargs.pop("shuffle", None)
868 if not is_none_or_empty(y):
869 raise ValueError(
870 "`y` argument is not supported when using "
871 "python generator as input."
872 )
873 if not is_none_or_empty(sample_weights):
874 raise ValueError(
875 "`sample_weight` argument is not supported when using "
876 "python generator as input."
877 )
879 super().__init__(x, y, **kwargs)
881 # Since we have to know the dtype of the python generator when we build
882 # the dataset, we have to look at a batch to infer the structure.
883 peek, x = self._peek_and_restore(x)
884 peek = self._standardize_batch(peek)
885 peek = _process_tensorlike(peek)
887 # Need to build the Model on concrete input shapes.
888 if model is not None and not model.built:
889 concrete_x, _, _ = unpack_x_y_sample_weight(peek)
890 try:
891 model.distribute_strategy.run(
892 lambda x: model(x, training=False), args=(concrete_x,)
893 )
894 except NotImplementedError:
895 # The above call may fail if the model is a container-like class
896 # that does not implement its own forward pass (e.g. a GAN or
897 # VAE where the forward pass is handled by subcomponents). Such
898 # a model does not need to be built.
899 pass
901 self._first_batch_size = int(tf.nest.flatten(peek)[0].shape[0])
903 def _get_tensor_spec(t):
904 # TODO(b/226395276): Remove _with_tensor_ranks_only usage.
905 return type_spec.type_spec_from_value(t)._with_tensor_ranks_only()
907 output_signature = tf.nest.map_structure(_get_tensor_spec, peek)
909 # Note that the dataset API takes a callable that creates a generator
910 # object, rather than the generator itself, which is why we define a
911 # function here.
912 generator_fn = self._handle_multiprocessing(
913 x, workers, use_multiprocessing, max_queue_size
914 )
916 def wrapped_generator():
917 for data in generator_fn():
918 yield self._standardize_batch(data)
920 dataset = tf.data.Dataset.from_generator(
921 wrapped_generator, output_signature=output_signature
922 )
924 if workers == 1 and not use_multiprocessing:
925 dataset = dataset.prefetch(1)
927 self._dataset = dataset.prefetch(tf.data.AUTOTUNE)
929 def _standardize_batch(self, data):
930 """Standardizes a batch output by a generator."""
931 # Removes `None`s.
932 x, y, sample_weight = unpack_x_y_sample_weight(data)
933 data = pack_x_y_sample_weight(x, y, sample_weight)
935 data = tf.__internal__.nest.list_to_tuple(data)
937 def _convert_dtype(t):
938 if isinstance(t, np.ndarray) and issubclass(
939 t.dtype.type, np.floating
940 ):
941 return np.array(t, dtype=backend.floatx())
942 return t
944 data = tf.nest.map_structure(_convert_dtype, data)
945 return data
947 @staticmethod
948 def _peek_and_restore(x):
949 peek = next(x)
950 return peek, itertools.chain([peek], x)
952 def _handle_multiprocessing(
953 self, x, workers, use_multiprocessing, max_queue_size
954 ):
955 """Create a callable, possibly including an Enqueuer."""
956 if workers > 1 or (workers > 0 and use_multiprocessing):
958 def generator_fn():
959 enqueuer = data_utils.GeneratorEnqueuer(
960 x, use_multiprocessing=use_multiprocessing
961 )
962 enqueuer.start(workers=workers, max_queue_size=max_queue_size)
963 return enqueuer.get()
965 else:
966 generator_fn = lambda: x
967 return generator_fn
969 def get_dataset(self):
970 return self._dataset
972 def get_size(self):
973 return None
975 def batch_size(self):
976 return None
978 def representative_batch_size(self):
979 return self._first_batch_size
981 def has_partial_batch(self):
982 return False
984 def partial_batch_size(self):
985 return
987 def should_recreate_iterator(self):
988 return False
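# A minimal sketch (illustration only) of a Python generator that
# `GeneratorDataAdapter` can wrap. It must yield already-batched `(x, y)` (or
# `(x, y, sample_weight)`) tuples, since the adapter never batches or
# shuffles generator output itself.
def _example_batch_generator(batch_size=8, num_batches=10):
    for _ in range(num_batches):
        x = np.random.random((batch_size, 4)).astype("float32")
        y = np.random.randint(0, 2, size=(batch_size, 1))
        yield x, y
# Usage sketch: GeneratorDataAdapter(_example_batch_generator()).get_dataset()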
991class KerasSequenceAdapter(GeneratorDataAdapter):
992 """Adapter that handles `keras.utils.Sequence`."""
994 @staticmethod
995 def can_handle(x, y=None):
996 return isinstance(x, data_utils.Sequence)
998 def __init__(
999 self,
1000 x,
1001 y=None,
1002 sample_weights=None,
1003 shuffle=False,
1004 workers=1,
1005 use_multiprocessing=False,
1006 max_queue_size=10,
1007 model=None,
1008 **kwargs,
1009 ):
1010 if not is_none_or_empty(y):
1011 raise ValueError(
1012 "`y` argument is not supported when using "
1013 "`keras.utils.Sequence` as input."
1014 )
1015 if not is_none_or_empty(sample_weights):
1016 raise ValueError(
1017 "`sample_weight` argument is not supported when using "
1018 "`keras.utils.Sequence` as input."
1019 )
1021 self._shuffle_sequence = shuffle
1022 self._keras_sequence = x
1023 self._enqueuer = None
1024 super().__init__(
1025 x,
1026 shuffle=False, # Shuffle is handled in the _handle_multiprocessing override.
1027 workers=workers,
1028 use_multiprocessing=use_multiprocessing,
1029 max_queue_size=max_queue_size,
1030 model=model,
1031 **kwargs,
1032 )
1034 @staticmethod
1035 def _peek_and_restore(x):
1036 return x[0], x
1038 def _handle_multiprocessing(
1039 self, x, workers, use_multiprocessing, max_queue_size
1040 ):
1041 if workers > 1 or (workers > 0 and use_multiprocessing):
1043 def generator_fn():
1044 self._enqueuer = data_utils.OrderedEnqueuer(
1045 x,
1046 use_multiprocessing=use_multiprocessing,
1047 shuffle=self._shuffle_sequence,
1048 )
1049 self._enqueuer.start(
1050 workers=workers, max_queue_size=max_queue_size
1051 )
1052 return self._enqueuer.get()
1054 else:
1056 def generator_fn():
1057 order = range(len(x))
1058 if self._shuffle_sequence:
1059 # Match the shuffle convention in OrderedEnqueuer.
1060 order = list(order)
1061 random.shuffle(order)
1063 for i in order:
1064 yield x[i]
1066 return generator_fn
1068 def get_size(self):
1069 return len(self._keras_sequence)
1071 def should_recreate_iterator(self):
1072 return True
1074 def on_epoch_end(self):
1075 if self._enqueuer:
1076 self._enqueuer.stop()
1077 self._keras_sequence.on_epoch_end()
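# A minimal sketch (illustration only) of a `keras.utils.Sequence` handled by
# `KerasSequenceAdapter`: `__len__` returns the number of batches and
# `__getitem__` returns one already-batched `(x, y)` pair.
class _ExampleSequence(data_utils.Sequence):
    def __init__(self, x, y, batch_size=32):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return int(math.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]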
1080ALL_ADAPTER_CLS = [
1081 ListsOfScalarsDataAdapter,
1082 TensorLikeDataAdapter,
1083 GenericArrayLikeDataAdapter,
1084 DatasetAdapter,
1085 GeneratorDataAdapter,
1086 KerasSequenceAdapter,
1087 CompositeTensorDataAdapter,
1088 DatasetCreatorAdapter,
1089]
1091UNSHARDABLE_DATASET_TYPES = [
1092 from_generator_op._GeneratorDataset,
1093 range_op._RangeDataset,
1094 from_sparse_tensor_slices_op._SparseTensorSliceDataset,
1095 from_tensors_op._TensorDataset,
1096 from_tensor_slices_op._TensorSliceDataset,
1097]
1100def select_data_adapter(x, y):
1101 """Selects a data adapter that can handle a given x and y."""
1102 adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
1103 if not adapter_cls:
1104 # TODO(scottzhu): This should be a less implementation-specific error.
1105 raise ValueError(
1106 "Failed to find data adapter that can handle input: {}, {}".format(
1107 _type_name(x), _type_name(y)
1108 )
1109 )
1110 elif len(adapter_cls) > 1:
1111 raise RuntimeError(
1112 "Data adapters should be mutually exclusive for "
1113 "handling inputs. Found multiple adapters {} to handle "
1114 "input: {}, {}".format(adapter_cls, _type_name(x), _type_name(y))
1115 )
1116 # Instrument the data adapter usage before returning it
1117 keras_data_adapter_gauge.get_cell(adapter_cls[0].__name__).set(True)
1118 return adapter_cls[0]
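# Usage sketch (illustration only) for `select_data_adapter`: exactly one
# adapter class should claim a given `(x, y)` pair.
def _example_select_adapter():
    x = np.ones((64, 3), dtype="float32")
    y = np.zeros((64, 1), dtype="float32")
    adapter_cls = select_data_adapter(x, y)  # TensorLikeDataAdapter here.
    return adapter_cls(x, y, batch_size=16).get_dataset()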
1121def _type_name(x):
1122 """Generates a description of the type of an object."""
1123 if isinstance(x, dict):
1124 key_types = set(_type_name(key) for key in x.keys())
1125 val_types = set(_type_name(key) for key in x.values())
1126 return f"({type(x)} containing {key_types} keys and {val_types} values)"
1127 if isinstance(x, (list, tuple)):
1128 types = set(_type_name(val) for val in x)
1129 return f"({type(x)} containing values of types {types})"
1130 return str(type(x))
1133def _process_tensorlike(inputs):
1134 """Process tensor-like inputs.
1136 This function:
1138 (1) Converts `Numpy` arrays to `Tensor`s.
1139 (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
1140 (3) Converts `pandas.Series` to `Tensor`s
1141 (4) Converts `list`s to `tuple`s (for `tf.data` support).
1143 Args:
1144 inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
1146 Returns:
1147 Structure of `Tensor`s or tensor-like.
1148 """
1150 def _convert_single_tensor(x):
1151 if _is_pandas_series(x):
1152 x = np.expand_dims(x.to_numpy(), axis=-1)
1154 if isinstance(x, np.ndarray):
1155 dtype = None
1156 if issubclass(x.dtype.type, np.floating):
1157 dtype = backend.floatx()
1158 return tf.convert_to_tensor(x, dtype=dtype)
1159 elif _is_scipy_sparse(x):
1160 return _scipy_sparse_to_sparse_tensor(x)
1161 return x
1163 inputs = tf.nest.map_structure(_convert_single_tensor, inputs)
1164 return tf.__internal__.nest.list_to_tuple(inputs)
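# A quick sketch (illustration only) of what `_process_tensorlike` does to a
# typical structure: float NumPy arrays become `floatx()` tensors, SciPy
# sparse matrices become `SparseTensor`s, and lists become tuples.
def _example_process_tensorlike():
    inputs = ([np.zeros((4, 2), dtype="float64"), np.arange(4)],)
    # Returns ((<float32 Tensor (4, 2)>, <int64 Tensor (4,)>),).
    return _process_tensorlike(inputs)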
1167def is_none_or_empty(inputs):
1168 # Util method to check if the input is None or an empty list.
1169 # The python "not" check would raise an error like the one below if the
1170 # input were a numpy array:
1171 # "The truth value of an array with more than one element is ambiguous.
1172 # Use a.any() or a.all()"
1173 return inputs is None or not tf.nest.flatten(inputs)
1176def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
1177 """Match sample_weight_modes structure with output structure."""
1178 if target_structure is None or not tf.nest.flatten(target_structure):
1179 return sample_weight_modes
1181 if isinstance(sample_weight_modes, str):
1182 if isinstance(target_structure, dict):
1183 return {key: sample_weight_modes for key in target_structure.keys()}
1184 return [sample_weight_modes for _ in target_structure]
1186 if sample_weight_modes:
1187 try:
1188 tf.nest.assert_same_structure(
1189 training_utils.list_to_tuple(target_structure),
1190 training_utils.list_to_tuple(sample_weight_modes),
1191 )
1192 except (ValueError, TypeError):
1193 target_str = str(
1194 tf.nest.map_structure(lambda _: "...", target_structure)
1195 )
1196 mode_str = str(
1197 tf.nest.map_structure(lambda _: "...", sample_weight_modes)
1198 )
1200 # Attempt to coerce sample_weight_modes to the target structure.
1201 # This implicitly depends on the fact that Model flattens outputs
1202 # for its internal representation.
1203 try:
1204 sample_weight_modes = tf.nest.pack_sequence_as(
1205 target_structure, tf.nest.flatten(sample_weight_modes)
1206 )
1207 logging.warning(
1208 "sample_weight modes were coerced from\n "
1209 "{}\n to \n {}".format(target_str, mode_str)
1210 )
1211 except (ValueError, TypeError):
1212 raise ValueError(
1213 "Unable to match target structure and sample_weight_modes "
1214 "structure:\n {}\n to \n {}".format(
1215 target_str, mode_str
1216 )
1217 )
1219 return sample_weight_modes
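# A quick sketch (illustration only) of `broadcast_sample_weight_modes`: a
# single string mode is expanded to match the target structure.
def _example_broadcast_modes():
    targets = {"out_a": np.zeros((4,)), "out_b": np.zeros((4,))}
    # Returns {"out_a": "temporal", "out_b": "temporal"}.
    return broadcast_sample_weight_modes(targets, "temporal")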
1222class DataHandler:
1223 """Handles iterating over epoch-level `tf.data.Iterator` objects."""
1225 def __init__(
1226 self,
1227 x,
1228 y=None,
1229 sample_weight=None,
1230 batch_size=None,
1231 steps_per_epoch=None,
1232 initial_epoch=0,
1233 epochs=1,
1234 shuffle=False,
1235 class_weight=None,
1236 max_queue_size=10,
1237 workers=1,
1238 use_multiprocessing=False,
1239 model=None,
1240 steps_per_execution=None,
1241 distribute=True,
1242 pss_evaluation_shards=0,
1243 ):
1244 """Initializes a `DataHandler`.
1246 Arguments:
1247 x: See `Model.fit`.
1248 y: See `Model.fit`.
1249 sample_weight: See `Model.fit`.
1250 batch_size: See `Model.fit`.
1251 steps_per_epoch: See `Model.fit`.
1252 initial_epoch: See `Model.fit`.
1253 epochs: See `Model.fit`.
1254 shuffle: See `Model.fit`.
1255 class_weight: See `Model.fit`.
1256 max_queue_size: See `Model.fit`.
1257 workers: See `Model.fit`.
1258 use_multiprocessing: See `Model.fit`.
1259 model: The `Model` instance. Needed in order to correctly `build` the
1260 `Model` using generator-like inputs (see `GeneratorDataAdapter`).
1261 steps_per_execution: See `Model.compile`.
1262 distribute: Whether to distribute the `tf.dataset`.
1263 `PreprocessingLayer.adapt` does not support distributed datasets,
1264 `Model` should always set this to `True`.
1265 pss_evaluation_shards: See `Model.fit`.
1266 """
1268 self._initial_epoch = initial_epoch
1269 self._initial_step = 0
1270 self._epochs = epochs
1271 self._insufficient_data = False
1272 self._model = model
1274 self._steps_per_epoch = steps_per_epoch
1276 # `steps_per_execution_value` is the cached initial value.
1277 # `steps_per_execution` is mutable and may be changed by the DataAdapter
1278 # to handle partial executions.
1279 if steps_per_execution is None:
1280 self._steps_per_execution = tf.Variable(1)
1281 else:
1282 self._steps_per_execution = steps_per_execution
1284 adapter_cls = select_data_adapter(x, y)
1285 self._adapter = adapter_cls(
1286 x,
1287 y,
1288 batch_size=batch_size,
1289 steps=steps_per_epoch,
1290 epochs=epochs - initial_epoch,
1291 sample_weights=sample_weight,
1292 shuffle=shuffle,
1293 max_queue_size=max_queue_size,
1294 workers=workers,
1295 use_multiprocessing=use_multiprocessing,
1296 distribution_strategy=tf.distribute.get_strategy(),
1297 model=model,
1298 pss_evaluation_shards=pss_evaluation_shards,
1299 )
1301 strategy = tf.distribute.get_strategy()
1303 self._current_step = 0
1304 self._step_increment = self._steps_per_execution.numpy().item() - 1
1305 self._insufficient_data = False
1307 self._configure_dataset_and_inferred_steps(
1308 strategy, x, steps_per_epoch, class_weight, distribute
1309 )
1311 def _configure_dataset_and_inferred_steps(
1312 self, strategy, x, steps_per_epoch, class_weight, distribute
1313 ):
1314 """Configure the `_dataset` and `_inferred_steps` attributes."""
1315 del x
1316 dataset = self._adapter.get_dataset()
1317 if class_weight:
1318 dataset = dataset.map(_make_class_weight_map_fn(class_weight))
1319 self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)
1321 # `PreprocessingLayer.adapt` does not currently support distributed
1322 # datasets, so we pass `distribute=False` there.
1323 if distribute and not _is_distributed_dataset(dataset):
1324 dataset = strategy.experimental_distribute_dataset(dataset)
1325 self._dataset = dataset
1326 self._validate_data_handler()
1328 def enumerate_epochs(self):
1329 """Yields `(epoch, tf.data.Iterator)`."""
1330 with self._truncate_execution_to_epoch():
1331 data_iterator = iter(self._dataset)
1332 for epoch in range(self._initial_epoch, self._epochs):
1333 if self._insufficient_data: # Set by `catch_stop_iteration`.
1334 break
1335 if self._adapter.should_recreate_iterator():
1336 data_iterator = iter(self._dataset)
1337 if not isinstance(self._dataset, DistributedDataset):
1338 steps = self._infer_steps(
1339 self._steps_per_epoch, self._dataset
1340 )
1341 if steps is not None:
1342 self._inferred_steps = steps
1343 yield epoch, data_iterator
1344 self._adapter.on_epoch_end()
1346 @contextlib.contextmanager
1347 def _truncate_execution_to_epoch(self):
1348 """Truncates steps per execution to at most one epoch."""
1349 should_truncate = (
1350 self._inferred_steps is not None
1351 and self._steps_per_execution.numpy().item() > self._inferred_steps
1352 )
1353 original_value = self._steps_per_execution.numpy().item()
1354 try:
1355 if should_truncate:
1356 self._steps_per_execution.assign(self._inferred_steps)
1357 yield
1358 finally:
1359 if should_truncate:
1360 self._steps_per_execution.assign(original_value)
1362 def sync(self):
1363 context.async_wait()
1365 @contextlib.contextmanager
1366 def catch_stop_iteration(self):
1367 """Catches errors when an iterator runs out of data."""
1368 with distributed_training_utils.maybe_preemption_handler_scope(
1369 self._model
1370 ):
1371 try:
1372 yield
1373 self.sync()
1374 except (StopIteration, tf.errors.OutOfRangeError):
1375 if self._inferred_steps is None:
1376 self._inferred_steps = self._current_step
1377 else:
1378 self._insufficient_data = True
1379 total_epochs = self._epochs - self._initial_epoch
1380 logging.warning(
1381 "Your input ran out of data; interrupting training. "
1382 "Make sure that your dataset or generator can generate "
1383 "at least `steps_per_epoch * epochs` batches (in this "
1384 "case, {} batches). You may need to use the repeat() "
1385 "function when building your dataset.".format(
1386 total_epochs * self._inferred_steps
1387 )
1388 )
1390 def steps(self):
1391 """Yields steps for the current epoch."""
1392 self._current_step = self._initial_step
1393 self._initial_step = 0
1394 # `self._inferred_steps` can be changed by `catch_stop_iteration`.
1395 while (
1396 self._inferred_steps is None
1397 or self._current_step < self._inferred_steps
1398 ):
1399 if self._insufficient_data: # Set by `catch_stop_iteration`.
1400 break
1401 original_spe = self._steps_per_execution.numpy().item()
1402 can_run_full_execution = (
1403 original_spe == 1
1404 or self._inferred_steps is None
1405 or self._inferred_steps - self._current_step >= original_spe
1406 )
1408 if can_run_full_execution:
1409 self._step_increment = original_spe - 1
1410 yield self._current_step
1411 self._current_step += original_spe
1412 else:
1413 # Last partial execution.
1414 steps_remaining = self._inferred_steps - self._current_step
1415 self._steps_per_execution.assign(steps_remaining)
1416 self._step_increment = steps_remaining - 1
1417 yield self._current_step
1418 self._current_step += steps_remaining
1419 self._steps_per_execution.assign(original_spe)
1421 @property
1422 def step_increment(self):
1423 """The number to increment the step for `on_batch_end` methods."""
1424 return self._step_increment
1426 @property
1427 def inferred_steps(self):
1428 """The inferred steps per epoch of the created `Dataset`.
1430 This will be `None` in the case where:
1432 (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`,
1433 (2) `steps_per_epoch` was not provided, and
1434 (3) The first epoch of iteration has not yet completed.
1436 Returns:
1437 The inferred steps per epoch of the created `Dataset`.
1438 """
1439 return self._inferred_steps
1441 @property
1442 def should_sync(self):
1443 # Catch OutOfRangeError for Datasets of unknown size.
1444 # This blocks until the batch has finished executing.
1445 # TODO(b/150292341): Allow multiple async steps here.
1446 return self._inferred_steps is None
1448 def _log_indefinite_training_warning(self):
1449 logging.warning(
1450 "The training loop will run indefinitely since you have "
1451 "set `steps_per_epoch=-1`. Please use batch-level "
1452 "callbacks to save checkpoints or log training progress, "
1453 "etc"
1454 )
1456 def _infer_steps(self, steps, dataset):
1457 """Infers steps_per_epoch needed to loop through a dataset."""
1458 if steps == -1:
1459 self._log_indefinite_training_warning()
1460 return None
1462 if steps is not None:
1463 return steps
1465 adapter_steps = self._adapter.get_size()
1466 if adapter_steps is not None:
1467 return adapter_steps
1469 # tf.distribute's `PerWorkerDataset` does not inherit from
1470 # `tf.data.Dataset` and in those cases we give up on inferring steps.
1471 if not isinstance(dataset, tf.data.Dataset):
1472 return None
1474 size = tf.data.experimental.cardinality(dataset)
1475 if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None:
1476 raise ValueError(
1477 "When passing an infinitely repeating dataset, please specify "
1478 "a `steps_per_epoch` value so that epoch level "
1479 "callbacks continue to work. The value can be arbitrary, or a "
1480 "number that you think correctly defines the size of an epoch. "
1481 "Epoch-level callbacks will then be called at this interval."
1482 )
1483 if size >= 0:
1484 return size.numpy().item()
1485 return None
1487 @property
1488 def _samples(self):
1489 return self._adapter.get_samples()
1491 def _validate_data_handler(self):
1492 # TODO(b/152094471): Support this with DistIter.get_next_as_optional.
1493 if (
1494 self._steps_per_execution.numpy().item() > 1
1495 and self._inferred_steps is None
1496 ):
1497 raise ValueError(
1498 "Could not infer the size of the data. With "
1499 "`steps_per_execution > 1`, you must specify the number of "
1500 "steps to run."
1501 )
1504class _ClusterCoordinatorDataHandler(DataHandler):
1505 """A `DataHandler` that is compatible with `ClusterCoordinator`."""
1507 def __init__(self, x, y=None, **kwargs):
1508 if not _is_distributed_dataset(x) and not isinstance(
1509 x, (dataset_creator.DatasetCreator, tf.data.Dataset)
1510 ):
1511 x = self._convert_to_dataset_creator(x, y, **kwargs)
1513 super().__init__(x=x, **kwargs)
1515 def _convert_to_dataset_creator(self, x, y, **kwargs):
1516 """Converts non-tf.data.Dataset to `DatasetCreator` instances."""
1518 def _dataset_fn(input_context):
1519 del input_context
1520 data_adapter_cls = select_data_adapter(x, y)
1521 return data_adapter_cls(x=x, y=y, **kwargs).get_dataset()
1523 # This check is needed because types like `tf.data.Dataset` don't work
1524 # with PSS yet. So only apply this logic to the types we can support.
1525 if isinstance(x, _get_tensor_types()) and isinstance(
1526 y, _get_tensor_types()
1527 ):
1528 return dataset_creator.DatasetCreator(_dataset_fn)
1529 else:
1530 raise NotImplementedError(
1531 "Only `tf.keras.utils.experimental.DatasetCreator`, "
1532 "`tf.Tensor`, numpy arrays and pandas dataframes are "
1533 "supported types at this time."
1534 )
1536 def _configure_dataset_and_inferred_steps(
1537 self, strategy, x, steps_per_epoch, class_weight, distribute
1538 ):
1539 if isinstance(x, dataset_creator.DatasetCreator):
1541 def per_worker_dataset_fn():
1543 return strategy.distribute_datasets_from_function(
1544 x, options=x.input_options
1545 )
1547 coordinator = self._model._cluster_coordinator
1548 self._dataset = coordinator.create_per_worker_dataset(
1549 per_worker_dataset_fn
1550 )
1551 else:
1552 assert distribute
1553 if not _is_distributed_dataset(x):
1554 x = strategy.experimental_distribute_dataset(x)
1556 coordinator = self._model._cluster_coordinator
1557 self._dataset = coordinator.create_per_worker_dataset(x)
1559 if steps_per_epoch == -1:
1560 self._inferred_steps = None
1561 self._log_indefinite_training_warning()
1562 else:
1563 self._inferred_steps = steps_per_epoch
1565 def sync(self):
1566 self._model._cluster_coordinator.join()
1569class _ClusterCoordinatorExactEvalDataHandler(_ClusterCoordinatorDataHandler):
1570 def __init__(self, x, y=None, **kwargs):
1571 super().__init__(x=x, **kwargs)
1572 self._total_shards = kwargs.get("pss_evaluation_shards")
1574 def _warn_if_not_file_shardable(self, dataset):
1575 # Traverse backwards to find source dataset and check if that is one of
1576 # the unshardable types
1577 # TODO(b/268521864): expand this to inspect dataset function graphs and
1578 # use the auto-sharding logic rather than re-creating it here.
1579 cur_dataset = dataset
1580 while hasattr(cur_dataset, "_input_dataset"):
1581 cur_dataset = cur_dataset._input_dataset
1582 if type(cur_dataset) in UNSHARDABLE_DATASET_TYPES:
1583 logging.warning(
1584 "Found source dataset of type {}. This type is not "
1585 "efficiently shardable, so exact evaluation may be "
1586 "slower than inexact evaluation. Try converting to "
1587 "a TFRecord or other file-based dataset if "
1588 "performance is a concern.".format(type(cur_dataset))
1589 )
1591 def _configure_dataset_and_inferred_steps(
1592 self, strategy, x, steps_per_epoch, class_weight, distribute
1593 ):
1594 if isinstance(x, dataset_creator.DatasetCreator):
1596 def per_worker_dataset_fn():
1597 ddf = strategy.distribute_datasets_from_function(
1598 x, options=x.input_options
1599 )
1600 return ddf
1602 coordinator = self._model._cluster_coordinator
1603 self._dataset = coordinator.create_per_worker_dataset(
1604 per_worker_dataset_fn
1605 )
1606 logging.info("dataset element spec: %r", self._dataset.element_spec)
1607 self._dataset = self._dataset.build()
1608 else:
1609 # TODO(b/268226218): Support DistributedDataset input
1610 if not _is_distributed_dataset(x):
1611 self._warn_if_not_file_shardable(x)
1612 x = strategy.experimental_distribute_dataset(x)
1614 coordinator = self._model._cluster_coordinator
1615 self._dataset = coordinator.create_per_worker_dataset(x)
1616 self._dataset = self._dataset.build()
1618 if steps_per_epoch == -1:
1619 self._inferred_steps = None
1620 self._log_indefinite_training_warning()
1621 else:
1622 self._inferred_steps = steps_per_epoch
1624 def enumerate_epochs(self):
1625 """Yields `(epoch, dataset)`."""
1626 for epoch in range(self._initial_epoch, self._epochs):
1627 yield epoch, self._dataset
1628 self._adapter.on_epoch_end()
1630 def steps(self):
1631 """Yields steps for the current epoch."""
1632 for step in range(self._total_shards):
1633 yield step
1636@keras_export("keras.__internal__.utils.get_data_handler", v1=[])
1637def get_data_handler(*args, **kwargs):
1638 """Creates a `DataHandler`, providing standardized access to a `Dataset`.
1640 See `DataHandler` for the list and definition of the arguments. See the
1641 implementation of the `Model.fit()`, `evaluate()`, or `predict()` methods
1642 for complete usage examples. As a rule of thumb, `get_data_handler()`
1643 accepts the same inputs as the `x` argument of `Model.fit()`.
1645 Example:
1647 ```python
1648 def step(iterator):
1649 data = next(iterator)
1650 # result <= Do something with data
1651 return result
1652 tf_step = tf.function(step, reduce_retracing=True)
1654 # Assume x is a tf.data Dataset.
1655 data_handler = data_adapter.get_data_handler(x=x)
1656 # Epoch iteration
1657 for epo_idx, iterator in data_handler.enumerate_epochs():
1658 # Stop on dataset exhaustion.
1659 with data_handler.catch_stop_iteration():
1660 for step in data_handler.steps(): # Step iteration
1661 step_result = step(iterator)
1662 ```
1664 Args:
1665 *args: Arguments passed to the `DataHandler` constructor.
1666 **kwargs: Arguments passed to the `DataHandler` constructor.
1668 Returns:
1669 A `DataHandler` object. If the model's cluster coordinator is set (e.g.
1670 the model was defined under a parameter-server strategy), returns a
1671 `_ClusterCoordinatorDataHandler`.
1673 """
1674 if getattr(kwargs["model"], "_cluster_coordinator", None):
1675 if kwargs.get("pss_evaluation_shards"):
1676 return _ClusterCoordinatorExactEvalDataHandler(*args, **kwargs)
1677 return _ClusterCoordinatorDataHandler(*args, **kwargs)
1678 return DataHandler(*args, **kwargs)
1681def _make_class_weight_map_fn(class_weight):
1682 """Applies class weighting to a `Dataset`.
1684 The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
1685 `y` must be a single `Tensor`.
1687 Args:
1688 class_weight: A map where the keys are integer class ids and values are
1689 the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
1691 Returns:
1692 A function that can be used with `tf.data.Dataset.map` to apply class
1693 weighting.
1694 """
1695 class_ids = list(sorted(class_weight.keys()))
1696 expected_class_ids = list(range(len(class_ids)))
1697 if class_ids != expected_class_ids:
1698 error_msg = (
1699 "Expected `class_weight` to be a dict with keys from 0 to one less "
1700 "than the number of classes, found {}"
1701 ).format(class_weight)
1702 raise ValueError(error_msg)
1704 class_weight_tensor = tf.convert_to_tensor(
1705 [class_weight[int(c)] for c in class_ids]
1706 )
1708 def _class_weights_map_fn(*data):
1709 """Convert `class_weight` to `sample_weight`."""
1710 x, y, sw = unpack_x_y_sample_weight(data)
1712 if tf.nest.is_nested(y):
1713 raise ValueError(
1714 "`class_weight` is only supported for Models with a single "
1715 "output."
1716 )
1718 if y.shape.rank >= 2:
1719 y_classes = tf.__internal__.smart_cond.smart_cond(
1720 backend.shape(y)[-1] > 1,
1721 lambda: backend.argmax(y, axis=-1),
1722 lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int64),
1723 )
1724 else:
1725 # Special casing for rank 1, where we can guarantee sparse encoding.
1726 y_classes = tf.cast(tf.round(y), tf.int64)
1728 cw = tf.gather(class_weight_tensor, y_classes)
1729 if sw is not None:
1730 cw = tf.cast(cw, sw.dtype)
1731 # `class_weight` and `sample_weight` are multiplicative.
1732 # If class_weight has more than 2 dimensions, we need to reshape
1733 # sample_weight to make broadcasting possible for multiplication.
1734 rank_delta = cw.shape.rank - sw.shape.rank
1735 sw = tf.reshape(sw, sw.shape + [1] * rank_delta)
1736 sw = sw * cw
1737 else:
1738 sw = cw
1739 return x, y, sw
1741 return _class_weights_map_fn
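# A minimal sketch (illustration only) of `_make_class_weight_map_fn` applied
# to a batched `(x, y)` dataset: class weights become per-sample weights.
def _example_apply_class_weight():
    x = tf.ones((6, 3))
    y = tf.constant([0.0, 1.0, 1.0, 0.0, 1.0, 0.0])
    ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(3)
    # Each element becomes (x, y, sample_weight) with weight 1.0 for class 0
    # and 2.0 for class 1.
    return ds.map(_make_class_weight_map_fn({0: 1.0, 1: 2.0}))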
1744def train_validation_split(arrays, validation_split):
1745 """Split arrays into train and validation subsets in deterministic order.
1747 The last part of the data will become the validation data.
1749 Args:
1750 arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
1751 of Tensors and NumPy arrays.
1752 validation_split: Float between 0 and 1. The proportion of the dataset to
1753 include in the validation split. The rest of the dataset will be
1754 included in the training split.
1755 Returns:
1756 `(train_arrays, validation_arrays)`
1757 """
1759 def _can_split(t):
1760 tensor_types = _get_tensor_types()
1761 return isinstance(t, tensor_types) or t is None
1763 flat_arrays = tf.nest.flatten(arrays)
1764 unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
1765 if unsplitable:
1766 raise ValueError(
1767 "`validation_split` is only supported for Tensors or NumPy "
1768 "arrays, found following types in the input: {}".format(unsplitable)
1769 )
1771 if all(t is None for t in flat_arrays):
1772 return arrays, arrays
1774 first_non_none = None
1775 for t in flat_arrays:
1776 if t is not None:
1777 first_non_none = t
1778 break
1780 # Assumes all arrays have the same batch shape or are `None`.
1781 batch_dim = int(first_non_none.shape[0])
1782 split_at = int(math.floor(batch_dim * (1.0 - validation_split)))
1784 if split_at == 0 or split_at == batch_dim:
1785 raise ValueError(
1786 "Training data contains {batch_dim} samples, which is not "
1787 "sufficient to split it into a validation and training set as "
1788 "specified by `validation_split={validation_split}`. Either "
1789 "provide more data, or a different value for the "
1790 "`validation_split` argument.".format(
1791 batch_dim=batch_dim, validation_split=validation_split
1792 )
1793 )
1795 def _split(t, start, end):
1796 if t is None:
1797 return t
1798 return t[start:end]
1800 train_arrays = tf.nest.map_structure(
1801 functools.partial(_split, start=0, end=split_at), arrays
1802 )
1803 val_arrays = tf.nest.map_structure(
1804 functools.partial(_split, start=split_at, end=batch_dim), arrays
1805 )
1807 return train_arrays, val_arrays
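# A quick sketch (illustration only) of `train_validation_split`: the last
# `validation_split` fraction of the (unshuffled) data becomes validation.
def _example_train_validation_split():
    x = np.arange(10)
    y = np.arange(10) * 2
    (x_train, y_train), (x_val, y_val) = train_validation_split(
        (x, y), validation_split=0.2
    )
    return (x_train, y_train), (x_val, y_val)  # x_val is [8, 9].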
1810@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
1811def unpack_x_y_sample_weight(data):
1812 """Unpacks user-provided data tuple.
1814 This is a convenience utility to be used when overriding
1815 `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
1816 This utility makes it easy to support data of the form `(x,)`,
1817 `(x, y)`, or `(x, y, sample_weight)`.
1819 Standalone usage:
1821 >>> features_batch = tf.ones((10, 5))
1822 >>> labels_batch = tf.zeros((10, 5))
1823 >>> data = (features_batch, labels_batch)
1824 >>> # `y` and `sample_weight` will default to `None` if not provided.
1825 >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
1826 >>> sample_weight is None
1827 True
1829 Example in overridden `Model.train_step`:
1831 ```python
1832 class MyModel(tf.keras.Model):
1834 def train_step(self, data):
1835 # If `sample_weight` is not provided, all samples will be weighted
1836 # equally.
1837 x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
1839 with tf.GradientTape() as tape:
1840 y_pred = self(x, training=True)
1841 loss = self.compiled_loss(
1842 y, y_pred, sample_weight, regularization_losses=self.losses)
1843 trainable_variables = self.trainable_variables
1844 gradients = tape.gradient(loss, trainable_variables)
1845 self.optimizer.apply_gradients(zip(gradients, trainable_variables))
1847 self.compiled_metrics.update_state(y, y_pred, sample_weight)
1848 return {m.name: m.result() for m in self.metrics}
1849 ```
1851 Args:
1852 data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
1854 Returns:
1855 The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
1856 not provided.
1857 """
1858 if isinstance(data, list):
1859 data = tuple(data)
1860 if not isinstance(data, tuple):
1861 return (data, None, None)
1862 elif len(data) == 1:
1863 return (data[0], None, None)
1864 elif len(data) == 2:
1865 return (data[0], data[1], None)
1866 elif len(data) == 3:
1867 return (data[0], data[1], data[2])
1868 else:
1869 error_msg = (
1870 "Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
1871 "or `(x, y, sample_weight)`, found: {}"
1872 ).format(data)
1873 raise ValueError(error_msg)
1876@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
1877def pack_x_y_sample_weight(x, y=None, sample_weight=None):
1878 """Packs user-provided data into a tuple.
1880 This is a convenience utility for packing data into the tuple formats
1881 that `Model.fit` uses.
1883 Standalone usage:
1885 >>> x = tf.ones((10, 1))
1886 >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
1887 >>> isinstance(data, tf.Tensor)
1888 True
1889 >>> y = tf.ones((10, 1))
1890 >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
1891 >>> isinstance(data, tuple)
1892 True
1893 >>> x, y = data
1895 Args:
1896 x: Features to pass to `Model`.
1897 y: Ground-truth targets to pass to `Model`.
1898 sample_weight: Sample weight for each element.
1900 Returns:
1901 Tuple in the format used in `Model.fit`.
1902 """
1903 if y is None:
1904 # For single x-input, we do no tuple wrapping since in this case
1905 # there is no ambiguity. This also makes NumPy and Dataset
1906 # consistent in that the user does not have to wrap their Dataset
1907 # data in an unnecessary tuple.
1908 if not isinstance(x, (tuple, list)):
1909 return x
1910 else:
1911 return (x,)
1912 elif sample_weight is None:
1913 return (x, y)
1914 else:
1915 return (x, y, sample_weight)
1918def single_batch_iterator(
1919 strategy, x, y=None, sample_weight=None, class_weight=None
1920):
1921 """Creates a single-batch dataset."""
1922 x, y, sample_weight = _process_tensorlike((x, y, sample_weight))
1923 if y is None:
1924 data = (x,)
1925 elif sample_weight is None:
1926 data = (x, y)
1927 else:
1928 data = (x, y, sample_weight)
1930 _check_data_cardinality(data)
1931 dataset = tf.data.Dataset.from_tensors(data)
1932 if class_weight:
1933 dataset = dataset.map(_make_class_weight_map_fn(class_weight))
1934 dataset = strategy.experimental_distribute_dataset(dataset)
1935 return iter(dataset)
1938def _check_data_cardinality(data):
1939 num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
1940 if len(num_samples) > 1:
1941 msg = "Data cardinality is ambiguous:\n"
1942 for label, single_data in zip(["x", "y", "sample_weight"], data):
1943 msg += " {} sizes: {}\n".format(
1944 label,
1945 ", ".join(
1946 str(i.shape[0]) for i in tf.nest.flatten(single_data)
1947 ),
1948 )
1949 msg += "Make sure all arrays contain the same number of samples."
1950 raise ValueError(msg)
1953def _get_tensor_types():
1954 if pd is None:
1955 return (tf.Tensor, np.ndarray)
1956 else:
1957 return (tf.Tensor, np.ndarray, pd.Series, pd.DataFrame)
1960def _is_scipy_sparse(x):
1961 try:
1962 from scipy.sparse import issparse
1964 return issparse(x)
1965 except ImportError:
1966 return False
1969def _is_pandas_series(x):
1970 if pd is None:
1971 return False
1972 else:
1973 return isinstance(x, pd.Series)
1976def _scipy_sparse_to_sparse_tensor(t):
1977 """Converts a SciPy sparse matrix to a SparseTensor."""
1978 sparse_coo = t.tocoo()
1979 row, col = sparse_coo.row, sparse_coo.col
1980 data, shape = sparse_coo.data, sparse_coo.shape
1981 if issubclass(data.dtype.type, np.floating):
1982 data = data.astype(backend.floatx())
1983 indices = np.concatenate(
1984 (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1
1985 )
1986 return tf.SparseTensor(indices, data, shape)
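# A quick sketch (illustration only) of `_scipy_sparse_to_sparse_tensor`;
# SciPy is an optional dependency of this module.
def _example_scipy_to_sparse_tensor():
    from scipy.sparse import coo_matrix

    mat = coo_matrix(([1.0, 2.0], ([0, 2], [1, 0])), shape=(3, 3))
    # Returns a tf.SparseTensor with dense shape (3, 3) and floatx() values.
    return _scipy_sparse_to_sparse_tensor(mat)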
1989def _is_distributed_dataset(ds):
1990 return isinstance(
1991 ds,
1992 (
1993 tf.distribute.DistributedDataset,
1994 tf.experimental.dtensor.DTensorDataset,
1995 ),
1996 )