Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/data_adapter.py: 25% of 746 statements
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Adapter module that convert different input data objects into tf.dataset."""
17import abc
18import contextlib
19import functools
20import itertools
21import math
22import random
24import numpy as np
26from tensorflow.python.data.experimental.ops import cardinality
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.data.ops import iterator_ops
29from tensorflow.python.data.ops import options as options_lib
30from tensorflow.python.distribute import distribute_lib
31from tensorflow.python.distribute import input_lib
32from tensorflow.python.eager import context
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import smart_cond
37from tensorflow.python.framework import sparse_tensor
38from tensorflow.python.framework import tensor_conversion
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.keras import backend
41from tensorflow.python.keras.engine import training_utils
42from tensorflow.python.keras.utils import data_utils
43from tensorflow.python.keras.utils import dataset_creator
44from tensorflow.python.keras.utils import tf_utils
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import random_ops
48from tensorflow.python.ops import script_ops
49from tensorflow.python.platform import tf_logging as logging
50from tensorflow.python.types import data as data_types
51from tensorflow.python.util import nest
52from tensorflow.python.util.tf_export import keras_export
55class DataAdapter(object, metaclass=abc.ABCMeta):
56 """Base class for input data adapter.
58 In TF 2.0, tf.data is the preferred API for users to feed in data. In order
59 to simplify the training code path, all input data objects will be
60 converted to `tf.data.Dataset` if possible.
62 Note that since this class mainly targets TF 2.0, it may make a number of
63 assumptions under the hood, e.g. an eager context by default, distribution
64 strategies, etc. At the same time, some legacy features may not be
65 supported, e.g. the v1 dataset API `Iterator`.
67 Sample usage of this class:
69 ```
70 x = tf.data.Dataset.range(100)
71 adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
72 applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
73 if len(applicable_adapters) != 1:
74 raise ValueError("Expect only one adapter class to handle the input")
76 dataset = applicable_adapters[0](x).get_dataset()
77 for data in dataset:
78 # training
79 ```
80 """
82 @staticmethod
83 def can_handle(x, y=None):
84 """Whether the current DataAdapter could handle the input x and y.
86 Structure-wise, x and y can be a single object, a list of objects when
87 there are multiple inputs/outputs, or a dictionary of objects when the
88 inputs/outputs are named.
90 Args:
91 x: input features.
92 y: target labels. Note that y could be None in the case of prediction.
94 Returns:
95 boolean
96 """
97 raise NotImplementedError
99 @abc.abstractmethod
100 def __init__(self, x, y=None, **kwargs):
101 """Create a DataAdapter based on data inputs.
103 The caller must make sure to call `can_handle()` first before invoking this
104 method. Providing an unsupported data type will result in unexpected behavior.
106 Args:
107 x: input features.
108 y: target labels. Note that y could be None in the case of prediction.
109 **kwargs: Other keyword arguments for DataAdapter during the construction
110 of the tf.data.Dataset. For example:
111 - Numpy data might have `sample_weights` which will be used for
112 weighting the loss function during training.
113 - Numpy data might need to have `batch_size` parameter when constructing
114 the dataset and iterator.
115 - Certain input might need to be distribution strategy aware. When
116 `distribution_strategy` is passed, the created dataset needs to respect
117 the strategy.
118 DataAdapter might choose to ignore any keyword argument if it doesn't
119 use it, or raise an exception if any required argument is not provided.
120 """
121 if not self.can_handle(x, y):
122 raise ValueError("{} Cannot handle input {}, {}".format(
123 self.__class__, x, y))
125 @abc.abstractmethod
126 def get_dataset(self):
127 """Get a dataset instance for the current DataAdapter.
129 Note that the returned dataset does not repeat across epochs, so the caller
130 might need to create a new iterator for the same dataset at the beginning
131 of each epoch. This behavior might change in the future.
133 Returns:
134 A tf.data.Dataset. The caller might use the dataset in different
135 contexts, e.g. iter(dataset) in eager mode to get values directly, or, in
136 graph mode, provide the iterator tensor to the Keras model function.
137 """
138 raise NotImplementedError
140 @abc.abstractmethod
141 def get_size(self):
142 """Return the size (number of batches) for the dataset created.
144 For certain types of input data the number of batches is known, e.g. for
145 Numpy data the size equals number_of_elements / batch_size (rounded up).
146 Whereas for a dataset or a python generator the size is unknown, since it
147 may or may not have an end state.
149 Returns:
150 int, the number of batches for the dataset, or None if it is unknown. The
151 caller could use this to control the training loop, show a progress bar,
152 or handle an unexpected StopIteration error.
153 """
154 raise NotImplementedError
156 @abc.abstractmethod
157 def batch_size(self):
158 """Return the batch size of the dataset created.
160 For certain types of input data the batch size is known, and even
161 required, e.g. for a numpy array. Whereas for a dataset, the batch size is
162 unknown unless we take a peek.
164 Returns:
165 int, the batch size of the dataset, or None if it is unknown.
166 """
167 raise NotImplementedError
169 def representative_batch_size(self):
170 """Return a representative size for batches in the dataset.
172 This is not guaranteed to be the batch size for all batches in the
173 dataset. It just needs to be a rough approximation for batch sizes in
174 the dataset.
176 Returns:
177 int, a representative size for batches found in the dataset,
178 or None if it is unknown.
179 """
180 return self.batch_size()
182 @abc.abstractmethod
183 def has_partial_batch(self):
184 """Whether the dataset has partial batch at the end."""
185 raise NotImplementedError
187 @abc.abstractmethod
188 def partial_batch_size(self):
189 """The size of the final partial batch for dataset.
191 Will return None if has_partial_batch is False or batch_size is None.
192 """
193 raise NotImplementedError
195 @abc.abstractmethod
196 def should_recreate_iterator(self):
197 """Returns whether a new iterator should be created every epoch."""
198 raise NotImplementedError
200 def get_samples(self):
201 """Returns number of samples in the data, or `None`."""
202 if not self.get_size() or not self.batch_size():
203 return None
204 total_sample = self.get_size() * self.batch_size()
205 if self.has_partial_batch():
206 total_sample -= (self.batch_size() - self.partial_batch_size())
207 return total_sample
209 def on_epoch_end(self):
210 """A hook called after each epoch."""
211 pass
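
# Illustrative sketch (not part of the original module): how `get_samples()`
# recovers the sample count from the sizes an adapter reports. With 100
# samples and a batch size of 32 an adapter reports 4 batches and a final
# partial batch of 4, so 4 * 32 - (32 - 4) == 100. The helper name is
# hypothetical.
def _example_sample_count(size, batch_size, partial_batch_size):
  total = size * batch_size
  if partial_batch_size:
    total -= batch_size - partial_batch_size
  return total  # _example_sample_count(4, 32, 4) == 100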
214class TensorLikeDataAdapter(DataAdapter):
215 """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
217 @staticmethod
218 def can_handle(x, y=None):
219 # TODO(kaftan): Check performance implications of using a flatten
220 # here for other types of inputs.
221 flat_inputs = nest.flatten(x)
222 if y is not None:
223 flat_inputs += nest.flatten(y)
225 tensor_types = _get_tensor_types()
227 def _is_tensor(v):
228 if isinstance(v, tensor_types):
229 return True
230 return False
232 return all(_is_tensor(v) for v in flat_inputs)
234 def __init__(self,
235 x,
236 y=None,
237 sample_weights=None,
238 sample_weight_modes=None,
239 batch_size=None,
240 epochs=1,
241 steps=None,
242 shuffle=False,
243 **kwargs):
244 super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
245 x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
246 sample_weight_modes = broadcast_sample_weight_modes(
247 sample_weights, sample_weight_modes)
249 # If sample_weights are not specified for an output use 1.0 as weights.
250 (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
251 y, sample_weights, sample_weight_modes, check_all_flat=True)
253 inputs = pack_x_y_sample_weight(x, y, sample_weights)
255 num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)).pop()
256 _check_data_cardinality(inputs)
258 # If batch_size is not passed but steps is, calculate from the input data.
259 # Default to 32 for backwards compat.
260 if not batch_size:
261 batch_size = int(math.ceil(num_samples / steps)) if steps else 32
263 self._size = int(math.ceil(num_samples / batch_size))
264 self._batch_size = batch_size
266 num_full_batches = int(num_samples // batch_size)
267 self._partial_batch_size = num_samples % batch_size
269 if isinstance(shuffle, str):
270 shuffle = shuffle.lower()
272 self._shuffle = shuffle
273 # Vectorized version of shuffle.
274 # This is a performance improvement over using `from_tensor_slices`.
275 # The indices of the data are shuffled and batched, and these indices
276 # are then zipped with the data and used to extract a batch of the data
277 # at each step. The performance improvements here come from:
278 # 1. vectorized batch using gather
279 # 2. parallelized map
280 # 3. pipelined permutation generation
281 # 4. optimized permutation batching
282 # 5. disabled static optimizations
284 indices_dataset = dataset_ops.DatasetV2.range(1)
285 if shuffle != "batch":
286 indices_dataset = indices_dataset.repeat(epochs)
288 def permutation(_):
289 # It turns out to be more performant to make a new set of indices rather
290 # than reusing the same range Tensor. (presumably because of buffer
291 # forwarding.)
292 indices = math_ops.range(num_samples, dtype=dtypes.int64)
293 if shuffle and shuffle != "batch":
294 indices = random_ops.random_shuffle(indices)
295 return indices
297 # We prefetch a single element. Computing large permutations can take quite
298 # a while so we don't want to wait for prefetching over an epoch boundary to
299 # trigger the next permutation. On the other hand, too many simultaneous
300 # shuffles can contend on a hardware level and degrade all performance.
301 indices_dataset = indices_dataset.map(permutation).prefetch(1)
303 def slice_batch_indices(indices):
304 """Convert a Tensor of indices into a dataset of batched indices.
306 This step can be accomplished in several ways. The most natural is to
307 slice the Tensor in a Dataset map. (With a condition on the upper index to
308 handle the partial batch.) However it turns out that coercing the Tensor
309 into a shape which is divisible by the batch size (and handling the last
310 partial batch separately) allows for a much more favorable memory access
311 pattern and improved performance.
313 Args:
314 indices: Tensor which determines the data order for an entire epoch.
316 Returns:
317 A Dataset of batched indices.
318 """
319 num_in_full_batch = num_full_batches * batch_size
320 first_k_indices = array_ops.slice(indices, [0], [num_in_full_batch])
321 first_k_indices = array_ops.reshape(
322 first_k_indices, [num_full_batches, batch_size])
324 flat_dataset = dataset_ops.DatasetV2.from_tensor_slices(first_k_indices)
325 if self._partial_batch_size:
326 index_remainder = dataset_ops.DatasetV2.from_tensors(array_ops.slice(
327 indices, [num_in_full_batch], [self._partial_batch_size]))
328 flat_dataset = flat_dataset.concatenate(index_remainder)
330 if shuffle == "batch":
331 # 1024 is a magic constant that has not been properly evaluated
332 flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
333 return flat_dataset
335 indices_dataset = indices_dataset.flat_map(slice_batch_indices)
337 dataset = self.slice_inputs(indices_dataset, inputs)
339 if shuffle == "batch":
340 def shuffle_batch(*batch):
341 return nest.map_structure(random_ops.random_shuffle, batch)
342 dataset = dataset.map(shuffle_batch)
344 self._dataset = dataset
346 def slice_inputs(self, indices_dataset, inputs):
347 """Slice inputs into a Dataset of batches.
349 Given a Dataset of batch indices and the unsliced inputs,
350 this step slices the inputs in a parallelized fashion
351 and produces a dataset of input batches.
353 Args:
354 indices_dataset: A Dataset of batched indices
355 inputs: A python data structure that contains the inputs, targets,
356 and possibly sample weights.
358 Returns:
359 A Dataset of input batches matching the batch indices.
360 """
361 dataset = dataset_ops.DatasetV2.zip((
362 indices_dataset,
363 dataset_ops.DatasetV2.from_tensors(inputs).repeat()
364 ))
366 def grab_batch(i, data):
367 return nest.map_structure(lambda d: array_ops.gather(d, i, axis=0), data)
369 dataset = dataset.map(
370 grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
372 # Default optimizations are disabled to avoid the overhead of (unnecessary)
373 # input pipeline graph serialization and deserialization
374 options = options_lib.Options()
375 options.experimental_optimization.apply_default_optimizations = False
376 if self._shuffle:
377 # See b/141490660 for more details.
378 options.experimental_external_state_policy = (
379 options_lib.ExternalStatePolicy.IGNORE)
380 dataset = dataset.with_options(options)
381 return dataset
383 def get_dataset(self):
384 return self._dataset
386 def get_size(self):
387 return self._size
389 def batch_size(self):
390 return self._batch_size
392 def has_partial_batch(self):
393 return self._partial_batch_size > 0
395 def partial_batch_size(self):
396 return self._partial_batch_size or None
398 def should_recreate_iterator(self):
399 # An infinite dataset is always created here.
400 return False
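
# Illustrative sketch (not part of the original module): a condensed version
# of the vectorized batching scheme described in the comments of
# TensorLikeDataAdapter.__init__ above -- shuffle a Tensor of indices, reshape
# it into full batches, then gather each batch of rows from the in-memory
# data. The helper name and the dropped partial batch are simplifications for
# this sketch.
def _example_vectorized_batches(data, batch_size):
  num_samples = int(data.shape[0])
  num_full_batches = num_samples // batch_size
  indices = random_ops.random_shuffle(
      math_ops.range(num_samples, dtype=dtypes.int64))
  batched_indices = array_ops.reshape(
      array_ops.slice(indices, [0], [num_full_batches * batch_size]),
      [num_full_batches, batch_size])
  indices_dataset = dataset_ops.DatasetV2.from_tensor_slices(batched_indices)
  data_dataset = dataset_ops.DatasetV2.from_tensors(data).repeat()
  zipped = dataset_ops.DatasetV2.zip((indices_dataset, data_dataset))
  return zipped.map(lambda i, d: array_ops.gather(d, i, axis=0))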
403class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
404 """Adapter that handles array-like data without forcing it into memory.
406 This adapter handles array-like datasets that may be too big to fully
407 fit into memory.
409 Specifically, this adapter handles any Python class which implements:
410 `__getitem__`, `__len__`, `shape`, and `dtype` with the same meanings
411 as Numpy, but it ignores any case where all the inputs are Tensors or Numpy
412 arrays (because that case is handled by the base TensorLikeDataAdapter).
414 It ignores scipy sparse matrices and Composite Tensors because those are
415 handled by the CompositeTensorDataAdapter.
417 It also does not handle lists/tuples of scalars, because those are handled
418 by the ListsOfScalarsDataAdapter.
419 """
421 @staticmethod
422 def can_handle(x, y=None):
423 flat_inputs = nest.flatten(x)
424 if y is not None:
425 flat_inputs += nest.flatten(y)
427 def _is_array_like(v):
428 """Return True if v is a Tensor, array, or is array-like."""
429 return (
430 hasattr(v, "__getitem__") and
431 hasattr(v, "shape") and
432 hasattr(v, "dtype") and
433 hasattr(v, "__len__")
434 )
436 if (not TensorLikeDataAdapter.can_handle(x, y) and
437 not CompositeTensorDataAdapter.can_handle(x, y)):
438 return all(_is_array_like(v) for v in flat_inputs)
439 else:
440 return False
442 def __init__(self, *args, **kwargs):
443 logging.warning(
444 "Keras is training/fitting/evaluating on array-like data. Keras may "
445 "not be optimized for this format, so if your input data format is "
446 "supported by TensorFlow I/O (https://github.com/tensorflow/io) we "
447 "recommend using that to load a Dataset instead.")
449 super(GenericArrayLikeDataAdapter, self).__init__(*args, **kwargs)
451 def slice_inputs(self, indices_dataset, inputs):
452 """Slice inputs into a Dataset of batches.
454 Given a Dataset of batch indices and the unsliced inputs,
455 this step slices the inputs in a parallelized fashion
456 and produces a dataset of input batches.
458 Args:
459 indices_dataset: A Dataset of batched indices
460 inputs: A python data structure that contains the inputs, targets,
461 and possibly sample weights.
463 Returns:
464 A Dataset of input batches matching the batch indices.
465 """
466 flat_inputs = nest.flatten(inputs)
467 def dynamic_shape_like(t):
468 shape = list(t.shape)
469 shape[0] = None
470 return tuple(shape)
472 flat_dtypes = [inp.dtype for inp in flat_inputs]
473 contiguous = True
474 if self._shuffle and self._shuffle != "batch":
475 contiguous = False
477 def grab_batch(indices):
478 """Grab a batch of data from the inputs."""
479 # This uses a py_function to avoid converting the array-like
480 # into a Tensor before slicing it, because converting the array-like
481 # to a Tensor may force it into memory.
482 def py_method(ind):
483 def slice_array(data):
484 return training_utils.slice_arrays(data, ind.numpy(),
485 contiguous=contiguous)
486 return [slice_array(inp) for inp in flat_inputs]
488 flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes)
489 for v, original_inp in zip(flat_out, flat_inputs):
490 v.set_shape(dynamic_shape_like(original_inp))
491 return nest.pack_sequence_as(inputs, flat_out)
493 dataset = indices_dataset.map(
494 grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
496 return dataset
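
# Illustrative sketch (not part of the original module): a minimal array-like
# wrapper of the kind GenericArrayLikeDataAdapter.can_handle() accepts -- it
# exposes `__getitem__`, `__len__`, `shape` and `dtype` without requiring the
# data to be materialized as a single in-memory Tensor. The class name is
# hypothetical.
class _ExampleArrayLike(object):
  """Array-like view over a NumPy array, for illustration only."""

  def __init__(self, array):
    self._array = np.asarray(array)

  def __getitem__(self, key):
    return self._array[key]

  def __len__(self):
    return len(self._array)

  @property
  def shape(self):
    return self._array.shape

  @property
  def dtype(self):
    return self._array.dtype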
499class DatasetCreatorAdapter(DataAdapter):
500 """Adapter that handles dataset functions."""
502 def __init__(self, x, y, steps=None, distribution_strategy=None, **kwargs):
503 super(DatasetCreatorAdapter, self).__init__(x, **kwargs)
505 if not isinstance(x, dataset_creator.DatasetCreator):
506 raise TypeError("The input of a `DatasetCreatorAdapter` should be a "
507 "`DatasetCreator` but it received type {}.".format(
508 type(x)))
509 if steps is None:
510 raise ValueError("When using a "
511 "`tf.keras.utils.experimental.DatasetCreator`, "
512 "`steps_per_epoch`, `validation_steps` or `steps` "
513 "argument must be provided in `Model.fit`, "
514 "`Model.evaluate`, or `Model.predict`.")
515 self.dataset_creator = x
516 self.steps = steps
517 self.strategy = distribution_strategy
519 @staticmethod
520 def can_handle(x, y=None):
521 if isinstance(x, dataset_creator.DatasetCreator):
522 assert y is None
523 return True
525 def should_recreate_iterator(self):
526 # We expect users to shuffle the dataset in their `dataset_fn` supplied to
527 # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset
528 # the dataset so the batches that are not shuffled can still be pulled.
529 return False
531 def get_size(self):
532 return None # To be inferred by `DataHandler`.
534 def get_dataset(self):
535 return self.strategy.distribute_datasets_from_function(
536 self.dataset_creator, options=self.dataset_creator.input_options)
538 def batch_size(self):
539 raise NotImplementedError()
541 def has_partial_batch(self):
542 raise NotImplementedError()
544 def partial_batch_size(self):
545 raise NotImplementedError()
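
# Illustrative sketch (not part of the original module): the kind of input
# DatasetCreatorAdapter expects -- a `DatasetCreator` wrapping a function that
# builds a dataset per input context, with the number of steps supplied
# separately to `Model.fit`. The function name and the commented usage are
# hypothetical.
def _example_dataset_fn(input_context):
  del input_context  # Sharding is ignored in this sketch.
  return dataset_ops.DatasetV2.range(8).batch(2)

# creator = dataset_creator.DatasetCreator(_example_dataset_fn)
# model.fit(creator, epochs=1, steps_per_epoch=4)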
548class CompositeTensorDataAdapter(DataAdapter):
549 """Adapter that handles composite tensor."""
551 @staticmethod
552 def can_handle(x, y=None):
553 flat_inputs = nest.flatten(x)
554 if y is not None:
555 flat_inputs += nest.flatten(y)
557 def _is_composite(v):
558 # Dataset/iterator/DistributedDataset inherits from CompositeTensor but
559 # should be handled by DatasetAdapter and GeneratorAdapter.
560 if (tf_utils.is_extension_type(v) and
561 not isinstance(v,
562 (dataset_ops.DatasetV2, iterator_ops.IteratorBase)) and
563 not _is_distributed_dataset(v)):
564 return True
565 # Support Scipy sparse tensors if scipy is installed
566 return _is_scipy_sparse(v)
568 def _is_tensor_or_composite(v):
569 if isinstance(v, (ops.Tensor, np.ndarray)):
570 return True
571 return _is_composite(v)
573 return (any(_is_composite(v) for v in flat_inputs) and
574 all(_is_tensor_or_composite(v) for v in flat_inputs))
576 def __init__(self,
577 x,
578 y=None,
579 sample_weights=None,
580 sample_weight_modes=None,
581 batch_size=None,
582 steps=None,
583 shuffle=False,
584 **kwargs):
585 super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
586 x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
587 sample_weight_modes = broadcast_sample_weight_modes(
588 sample_weights, sample_weight_modes)
590 # If sample_weights are not specified for an output use 1.0 as weights.
591 (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
592 y, sample_weights, sample_weight_modes, check_all_flat=True)
594 inputs = pack_x_y_sample_weight(x, y, sample_weights)
596 dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs)
597 num_samples = int(nest.flatten(x)[0].shape[0])
598 if shuffle:
599 dataset = dataset.shuffle(num_samples)
601 # If batch_size is not passed but steps is, calculate from the input data.
602 # Default to 32 for backwards compat.
603 if not batch_size:
604 batch_size = int(math.ceil(num_samples / steps)) if steps else 32
606 dataset = dataset.batch(batch_size)
607 self._size = int(math.ceil(num_samples / batch_size))
608 self._batch_size = batch_size
609 self._has_partial_batch = (self._size != (num_samples // batch_size))
611 self._partial_batch_size = None
612 if self._has_partial_batch:
613 self._partial_batch_size = (
614 num_samples - (self._size - 1) * self._batch_size)
616 self._dataset = dataset
618 def get_dataset(self):
619 return self._dataset
621 def get_size(self):
622 return self._size
624 def batch_size(self):
625 return self._batch_size
627 def has_partial_batch(self):
628 return self._has_partial_batch
630 def partial_batch_size(self):
631 return self._partial_batch_size
633 def should_recreate_iterator(self):
634 return True
637class ListsOfScalarsDataAdapter(DataAdapter):
638 """Adapter that handles lists of scalars and lists of lists of scalars."""
640 @staticmethod
641 def can_handle(x, y=None):
642 handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x)
643 handles_y = True
644 if y is not None:
645 handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y)
646 return handles_x and handles_y
648 @staticmethod
649 def _is_list_of_scalars(inp):
650 if isinstance(inp, (float, int, str, bytes, bytearray)):
651 return True
652 if isinstance(inp, (list, tuple)) and inp:
653 return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
654 return False
656 def __init__(self,
657 x,
658 y=None,
659 sample_weights=None,
660 sample_weight_modes=None,
661 batch_size=None,
662 shuffle=False,
663 **kwargs):
664 super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs)
665 x = np.asarray(x)
666 if y is not None:
667 y = np.asarray(y)
668 if sample_weights is not None:
669 sample_weights = np.asarray(sample_weights)
670 sample_weight_modes = broadcast_sample_weight_modes(
671 sample_weights, sample_weight_modes)
673 self._internal_adapter = TensorLikeDataAdapter(
674 x,
675 y=y,
676 sample_weights=sample_weights,
677 sample_weight_modes=sample_weight_modes,
678 batch_size=batch_size,
679 shuffle=shuffle,
680 **kwargs)
682 def get_dataset(self):
683 return self._internal_adapter.get_dataset()
685 def get_size(self):
686 return self._internal_adapter.get_size()
688 def batch_size(self):
689 return self._internal_adapter.batch_size()
691 def has_partial_batch(self):
692 return self._internal_adapter.has_partial_batch()
694 def partial_batch_size(self):
695 return self._internal_adapter.partial_batch_size()
697 def should_recreate_iterator(self):
698 return True
701class DatasetAdapter(DataAdapter):
702 """Adapter that handles `tf.data.Dataset`."""
704 @staticmethod
705 def can_handle(x, y=None):
706 return (isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)) or
707 _is_distributed_dataset(x))
709 def __init__(self,
710 x,
711 y=None,
712 sample_weights=None,
713 steps=None,
714 **kwargs):
715 super(DatasetAdapter, self).__init__(x, y, **kwargs)
716 # Note that the dataset instance is immutable, so it's fine to reuse the
717 # user-provided dataset.
718 self._dataset = x
720 # The user-provided steps.
721 self._user_steps = steps
723 self._validate_args(y, sample_weights, steps)
725 def get_dataset(self):
726 return self._dataset
728 def get_size(self):
729 return # Inferred in `DataHandler`.
731 def batch_size(self):
732 return None
734 def has_partial_batch(self):
735 return False
737 def partial_batch_size(self):
738 return None
740 def should_recreate_iterator(self):
741 # Since DistributedDatasets have no cardinality, the user must provide
742 # all steps that need to be run, calling `.repeat()` as needed.
743 if _is_distributed_dataset(self._dataset):
744 return False
746 # If user doesn't supply `steps`, or if they supply `steps` that
747 # exactly equals the size of the `Dataset`, create a new iterator
748 # each epoch.
749 return (self._user_steps is None or
750 cardinality.cardinality(self._dataset).numpy() == self._user_steps)
752 def _validate_args(self, y, sample_weights, steps):
753 """Validates `__init__` arguments."""
754 # Arguments that shouldn't be passed.
755 if not is_none_or_empty(y):
756 raise ValueError("`y` argument is not supported when using "
757 "dataset as input.")
758 if not is_none_or_empty(sample_weights):
759 raise ValueError("`sample_weight` argument is not supported when using "
760 "dataset as input.")
762 if steps is None:
763 if _is_distributed_dataset(self._dataset):
764 raise ValueError("When providing a distributed dataset, you must "
765 "specify the number of steps to run.")
767 size = cardinality.cardinality(self._dataset).numpy()
768 if size == cardinality.INFINITE and steps is None:
769 raise ValueError(
770 "When providing an infinite dataset, you must specify "
771 "the number of steps to run (if you did not intend to "
772 "create an infinite dataset, make sure to not call "
773 "`repeat()` on the dataset).")
776class GeneratorDataAdapter(DataAdapter):
777 """Adapter that handles python generators and iterators."""
779 @staticmethod
780 def can_handle(x, y=None):
781 return ((hasattr(x, "__next__") or hasattr(x, "next"))
782 and hasattr(x, "__iter__")
783 and not isinstance(x, data_utils.Sequence))
785 def __init__(self,
786 x,
787 y=None,
788 sample_weights=None,
789 workers=1,
790 use_multiprocessing=False,
791 max_queue_size=10,
792 model=None,
793 **kwargs):
794 # Generators should never shuffle as exhausting the generator in order to
795 # shuffle the batches is inefficient.
796 kwargs.pop("shuffle", None)
798 if not is_none_or_empty(y):
799 raise ValueError("`y` argument is not supported when using "
800 "python generator as input.")
801 if not is_none_or_empty(sample_weights):
802 raise ValueError("`sample_weight` argument is not supported when using "
803 "python generator as input.")
805 super(GeneratorDataAdapter, self).__init__(x, y, **kwargs)
807 # Since we have to know the dtype of the python generator when we build the
808 # dataset, we have to look at a batch to infer the structure.
809 peek, x = self._peek_and_restore(x)
810 peek = self._standardize_batch(peek)
811 peek = _process_tensorlike(peek)
813 # Need to build the Model on concrete input shapes.
814 if model is not None and not model.built:
815 concrete_x, _, _ = unpack_x_y_sample_weight(peek)
816 model.distribute_strategy.run(
817 lambda x: model(x, training=False), args=(concrete_x,))
819 self._first_batch_size = int(nest.flatten(peek)[0].shape[0])
821 def _get_dynamic_shape(t):
822 shape = t.shape
823 # Unknown number of dimensions, `as_list` cannot be called.
824 if shape.rank is None:
825 return shape
826 return tensor_shape.TensorShape([None for _ in shape.as_list()])
828 output_shapes = nest.map_structure(_get_dynamic_shape, peek)
829 output_types = nest.map_structure(lambda t: t.dtype, peek)
831 # Note that dataset API takes a callable that creates a generator object,
832 # rather than the generator itself, which is why we define a function here.
833 generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing,
834 max_queue_size)
836 def wrapped_generator():
837 for data in generator_fn():
838 yield self._standardize_batch(data)
840 dataset = dataset_ops.DatasetV2.from_generator(
841 wrapped_generator, output_types, output_shapes=output_shapes)
843 if workers == 1 and not use_multiprocessing:
844 dataset = dataset.prefetch(1)
846 self._dataset = dataset
848 def _standardize_batch(self, data):
849 """Standardizes a batch output by a generator."""
850 # Removes `None`s.
851 x, y, sample_weight = unpack_x_y_sample_weight(data)
852 data = pack_x_y_sample_weight(x, y, sample_weight)
854 data = nest.list_to_tuple(data)
856 def _convert_dtype(t):
857 if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)):
858 return np.array(t, dtype=backend.floatx())
859 return t
861 data = nest.map_structure(_convert_dtype, data)
862 return data
864 @staticmethod
865 def _peek_and_restore(x):
866 peek = next(x)
867 return peek, itertools.chain([peek], x)
869 def _handle_multiprocessing(self, x, workers, use_multiprocessing,
870 max_queue_size):
871 """Create a callable, possibly including an Enqueuer."""
872 if workers > 1 or (workers > 0 and use_multiprocessing):
873 def generator_fn():
874 enqueuer = data_utils.GeneratorEnqueuer(
875 x, use_multiprocessing=use_multiprocessing)
876 enqueuer.start(workers=workers, max_queue_size=max_queue_size)
877 return enqueuer.get()
878 else:
879 generator_fn = lambda: x
880 return generator_fn
882 def get_dataset(self):
883 return self._dataset
885 def get_size(self):
886 return None
888 def batch_size(self):
889 return None
891 def representative_batch_size(self):
892 return self._first_batch_size
894 def has_partial_batch(self):
895 return False
897 def partial_batch_size(self):
898 return
900 def should_recreate_iterator(self):
901 return False
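
# Illustrative sketch (not part of the original module): the overall flow
# GeneratorDataAdapter implements -- peek one batch to infer dtypes and
# (dynamic) shapes, then wrap a callable that recreates the generator with
# `Dataset.from_generator`. The helper name and the toy generator are
# hypothetical.
def _example_generator_dataset():
  def gen():
    for i in range(4):
      yield np.arange(6, dtype="float32").reshape(2, 3) + i

  output_types = dtypes.float32
  output_shapes = tensor_shape.TensorShape([None, 3])
  return dataset_ops.DatasetV2.from_generator(
      gen, output_types, output_shapes=output_shapes)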
904class KerasSequenceAdapter(GeneratorDataAdapter):
905 """Adapter that handles `keras.utils.Sequence`."""
907 @staticmethod
908 def can_handle(x, y=None):
909 return isinstance(x, data_utils.Sequence)
911 def __init__(self,
912 x,
913 y=None,
914 sample_weights=None,
915 shuffle=False,
916 workers=1,
917 use_multiprocessing=False,
918 max_queue_size=10,
919 model=None,
920 **kwargs):
921 if not is_none_or_empty(y):
922 raise ValueError("`y` argument is not supported when using "
923 "`keras.utils.Sequence` as input.")
924 if not is_none_or_empty(sample_weights):
925 raise ValueError("`sample_weight` argument is not supported when using "
926 "`keras.utils.Sequence` as input.")
928 self._size = len(x)
929 self._shuffle_sequence = shuffle
930 self._keras_sequence = x
931 self._enqueuer = None
932 super(KerasSequenceAdapter, self).__init__(
933 x,
934 shuffle=False, # Shuffle is handled in the _handle_multiprocessing override.
935 workers=workers,
936 use_multiprocessing=use_multiprocessing,
937 max_queue_size=max_queue_size,
938 model=model,
939 **kwargs)
941 @staticmethod
942 def _peek_and_restore(x):
943 return x[0], x
945 def _handle_multiprocessing(self, x, workers, use_multiprocessing,
946 max_queue_size):
947 if workers > 1 or (workers > 0 and use_multiprocessing):
948 def generator_fn():
949 self._enqueuer = data_utils.OrderedEnqueuer(
950 x, use_multiprocessing=use_multiprocessing,
951 shuffle=self._shuffle_sequence)
952 self._enqueuer.start(workers=workers, max_queue_size=max_queue_size)
953 return self._enqueuer.get()
954 else:
955 def generator_fn():
956 order = range(len(x))
957 if self._shuffle_sequence:
958 # Match the shuffle convention in OrderedEnqueuer.
959 order = list(order)
960 random.shuffle(order)
962 for i in order:
963 yield x[i]
965 return generator_fn
967 def get_size(self):
968 return self._size
970 def should_recreate_iterator(self):
971 return True
973 def on_epoch_end(self):
974 if self._enqueuer:
975 self._enqueuer.stop()
976 self._keras_sequence.on_epoch_end()
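
# Illustrative sketch (not part of the original module): a minimal
# `keras.utils.Sequence` of the kind KerasSequenceAdapter handles. Each index
# returns one `(x, y)` batch; `__len__` gives the number of batches per epoch.
# The class name and batch contents are hypothetical.
class _ExampleSequence(data_utils.Sequence):

  def __len__(self):
    return 4

  def __getitem__(self, idx):
    x = np.full((2, 3), idx, dtype="float32")
    y = np.full((2, 1), idx, dtype="float32")
    return x, y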
979ALL_ADAPTER_CLS = [
980 ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
981 GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
982 KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter
983]
986def select_data_adapter(x, y):
987 """Selects a data adapter than can handle a given x and y."""
988 adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
989 if not adapter_cls:
990 # TODO(scottzhu): This should be a less implementation-specific error.
991 raise ValueError(
992 "Failed to find data adapter that can handle "
993 "input: {}, {}".format(
994 _type_name(x), _type_name(y)))
995 elif len(adapter_cls) > 1:
996 raise RuntimeError(
997 "Data adapters should be mutually exclusive for "
998 "handling inputs. Found multiple adapters {} to handle "
999 "input: {}, {}".format(
1000 adapter_cls, _type_name(x), _type_name(y)))
1001 return adapter_cls[0]
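
# Illustrative sketch (not part of the original module): for plain in-memory
# NumPy inputs the selection below resolves to TensorLikeDataAdapter, the only
# adapter whose `can_handle` accepts them. The helper name is hypothetical.
def _example_select_adapter():
  return select_data_adapter(np.zeros((4, 2)), np.zeros((4, 1)))
  # -> TensorLikeDataAdapter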
1004def _type_name(x):
1005 """Generates a description of the type of an object."""
1006 if isinstance(x, dict):
1007 key_types = set(_type_name(key) for key in x.keys())
1008 val_types = set(_type_name(key) for key in x.values())
1009 return "({} containing {} keys and {} values)".format(
1010 type(x), key_types, val_types)
1011 if isinstance(x, (list, tuple)):
1012 types = set(_type_name(val) for val in x)
1013 return "({} containing values of types {})".format(
1014 type(x), types)
1015 return str(type(x))
1018def _process_tensorlike(inputs):
1019 """Process tensor-like inputs.
1021 This function:
1023 (1) Converts `Numpy` arrays to `Tensor`s.
1024 (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
1025 (3) Converts `list`s to `tuple`s (for `tf.data` support).
1027 Args:
1028 inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
1030 Returns:
1031 Structure of `Tensor`s or tensor-like.
1032 """
1034 def _convert_numpy_and_scipy(x):
1035 if isinstance(x, np.ndarray):
1036 dtype = None
1037 if issubclass(x.dtype.type, np.floating):
1038 dtype = backend.floatx()
1039 return tensor_conversion.convert_to_tensor_v2_with_dispatch(
1040 x, dtype=dtype
1041 )
1042 elif _is_scipy_sparse(x):
1043 return _scipy_sparse_to_sparse_tensor(x)
1044 return x
1046 inputs = nest.map_structure(_convert_numpy_and_scipy, inputs)
1047 return nest.list_to_tuple(inputs)
1050def is_none_or_empty(inputs):
1051 # Util method to check if the input is None or an empty list.
1052 # A plain python "not" check would raise an error like the one below if the
1053 # input is a numpy array:
1054 # "The truth value of an array with more than one element is ambiguous.
1055 # Use a.any() or a.all()"
1056 return inputs is None or not nest.flatten(inputs)
1059def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
1060 """Match sample_weight_modes structure with output structure."""
1061 if target_structure is None or not nest.flatten(target_structure):
1062 return sample_weight_modes
1064 if isinstance(sample_weight_modes, str):
1065 if isinstance(target_structure, dict):
1066 return {key: sample_weight_modes for key in target_structure.keys()}
1067 return [sample_weight_modes for _ in target_structure]
1069 if sample_weight_modes:
1070 try:
1071 nest.assert_same_structure(
1072 training_utils.list_to_tuple(target_structure),
1073 training_utils.list_to_tuple(sample_weight_modes))
1074 except (ValueError, TypeError):
1075 target_str = str(nest.map_structure(lambda _: "...", target_structure))
1076 mode_str = str(nest.map_structure(lambda _: "...", sample_weight_modes))
1078 # Attempt to coerce sample_weight_modes to the target structure. This
1079 # implicitly depends on the fact that Model flattens outputs for its
1080 # internal representation.
1081 try:
1082 sample_weight_modes = nest.pack_sequence_as(
1083 target_structure, nest.flatten(sample_weight_modes))
1084 logging.warning(
1085 "sample_weight modes were coerced from\n {}\n to \n {}"
1086 .format(target_str, mode_str))
1087 except (ValueError, TypeError):
1088 raise ValueError(
1089 "Unable to match target structure and sample_weight_modes "
1090 "structure:\n {}\n to \n {}".format(target_str, mode_str))
1092 return sample_weight_modes
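
# Illustrative sketch (not part of the original module): the common broadcast
# cases -- a single mode string is replicated across a list- or
# dict-structured target. The helper name and output names are hypothetical.
def _example_broadcast_modes():
  targets = [np.zeros((2, 1)), np.zeros((2, 1))]  # Stand-ins for two outputs.
  as_list = broadcast_sample_weight_modes(targets, "temporal")
  # as_list == ["temporal", "temporal"]
  as_dict = broadcast_sample_weight_modes({"out_a": targets[0]}, "temporal")
  # as_dict == {"out_a": "temporal"}
  return as_list, as_dict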
1095class DataHandler(object):
1096 """Handles iterating over epoch-level `tf.data.Iterator` objects."""
1098 def __init__(self,
1099 x,
1100 y=None,
1101 sample_weight=None,
1102 batch_size=None,
1103 steps_per_epoch=None,
1104 initial_epoch=0,
1105 epochs=1,
1106 shuffle=False,
1107 class_weight=None,
1108 max_queue_size=10,
1109 workers=1,
1110 use_multiprocessing=False,
1111 model=None,
1112 steps_per_execution=None,
1113 distribute=True):
1114 """Initializes a `DataHandler`.
1116 Arguments:
1117 x: See `Model.fit`.
1118 y: See `Model.fit`.
1119 sample_weight: See `Model.fit`.
1120 batch_size: See `Model.fit`.
1121 steps_per_epoch: See `Model.fit`.
1122 initial_epoch: See `Model.fit`.
1123 epochs: See `Model.fit`.
1124 shuffle: See `Model.fit`.
1125 class_weight: See `Model.fit`.
1126 max_queue_size: See `Model.fit`.
1127 workers: See `Model.fit`.
1128 use_multiprocessing: See `Model.fit`.
1129 model: The `Model` instance. Needed in order to correctly `build` the
1130 `Model` using generator-like inputs (see `GeneratorDataAdapter`).
1131 steps_per_execution: See `Model.compile`.
1132 distribute: Whether to distribute the `tf.data.Dataset`.
1133 `PreprocessingLayer.adapt` does not support distributed datasets,
1134 `Model` should always set this to `True`.
1135 """
1137 self._initial_epoch = initial_epoch
1138 self._epochs = epochs
1139 self._insufficient_data = False
1140 self._model = model
1142 # `steps_per_execution_value` is the cached initial value.
1143 # `steps_per_execution` is mutable and may be changed by the DataAdapter
1144 # to handle partial executions.
1145 if steps_per_execution is None:
1146 self._steps_per_execution = 1
1147 self._steps_per_execution_value = 1
1148 else:
1149 self._steps_per_execution = steps_per_execution
1150 self._steps_per_execution_value = steps_per_execution.numpy().item()
1152 adapter_cls = select_data_adapter(x, y)
1153 self._adapter = adapter_cls(
1154 x,
1155 y,
1156 batch_size=batch_size,
1157 steps=steps_per_epoch,
1158 epochs=epochs - initial_epoch,
1159 sample_weights=sample_weight,
1160 shuffle=shuffle,
1161 max_queue_size=max_queue_size,
1162 workers=workers,
1163 use_multiprocessing=use_multiprocessing,
1164 distribution_strategy=distribute_lib.get_strategy(),
1165 model=model)
1167 strategy = distribute_lib.get_strategy()
1169 self._current_step = 0
1170 self._step_increment = self._steps_per_execution_value - 1
1171 self._insufficient_data = False
1173 self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch,
1174 class_weight, distribute)
1176 def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
1177 class_weight, distribute):
1178 """Configure the `_dataset` and `_inferred_steps` attributes."""
1179 del x
1180 dataset = self._adapter.get_dataset()
1181 if class_weight:
1182 dataset = dataset.map(_make_class_weight_map_fn(class_weight))
1183 self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)
1185 # `PreprocessingLayer.adapt` does not currently support distributed
1186 # datasets, so we pass `distribute=False` there.
1187 if distribute and not _is_distributed_dataset(dataset):
1188 dataset = strategy.experimental_distribute_dataset(dataset)
1189 self._dataset = dataset
1190 self._validate_data_handler()
1192 def enumerate_epochs(self):
1193 """Yields `(epoch, tf.data.Iterator)`."""
1194 with self._truncate_execution_to_epoch():
1195 data_iterator = iter(self._dataset)
1196 for epoch in range(self._initial_epoch, self._epochs):
1197 if self._insufficient_data: # Set by `catch_stop_iteration`.
1198 break
1199 if self._adapter.should_recreate_iterator():
1200 data_iterator = iter(self._dataset)
1201 yield epoch, data_iterator
1202 self._adapter.on_epoch_end()
1204 @contextlib.contextmanager
1205 def _truncate_execution_to_epoch(self):
1206 """Truncates steps per execution to at most one epoch."""
1207 should_truncate = (
1208 self._inferred_steps is not None and
1209 self._steps_per_execution_value > self._inferred_steps)
1210 original_value = self._steps_per_execution_value
1211 try:
1212 if should_truncate:
1213 self._steps_per_execution.assign(self._inferred_steps)
1214 self._steps_per_execution_value = self._inferred_steps
1215 yield
1216 finally:
1217 if should_truncate:
1218 self._steps_per_execution.assign(original_value)
1219 self._steps_per_execution_value = original_value
1221 def sync(self):
1222 context.async_wait()
1224 @contextlib.contextmanager
1225 def catch_stop_iteration(self):
1226 """Catches errors when an iterator runs out of data."""
1227 try:
1228 yield
1229 self.sync()
1230 except (StopIteration, errors.OutOfRangeError):
1231 if self._inferred_steps is None:
1232 self._inferred_steps = self._current_step
1233 else:
1234 self._insufficient_data = True
1235 total_epochs = self._epochs - self._initial_epoch
1236 logging.warning(
1237 "Your input ran out of data; interrupting training. "
1238 "Make sure that your dataset or generator can generate at "
1239 "least `steps_per_epoch * epochs` batches (in this case, "
1240 "{} batches). You may need to use the repeat() function "
1241 "when building your dataset.".format(total_epochs *
1242 self._inferred_steps))
1244 def steps(self):
1245 """Yields steps for the current epoch."""
1246 self._current_step = 0
1247 # `self._inferred_steps` can be changed by `catch_stop_iteration`.
1248 while (self._inferred_steps is None or
1249 self._current_step < self._inferred_steps):
1250 if self._insufficient_data: # Set by `catch_stop_iteration`.
1251 break
1253 can_run_full_execution = (
1254 self._steps_per_execution_value == 1 or
1255 self._inferred_steps is None or
1256 self._inferred_steps - self._current_step >=
1257 self._steps_per_execution_value)
1259 if can_run_full_execution:
1260 self._step_increment = self._steps_per_execution_value - 1
1261 yield self._current_step
1262 self._current_step += self._steps_per_execution_value
1263 else:
1264 # Last partial execution.
1265 steps_remaining = self._inferred_steps - self._current_step
1266 self._steps_per_execution.assign(steps_remaining)
1267 self._step_increment = steps_remaining - 1
1268 yield self._current_step
1269 self._current_step += steps_remaining
1270 self._steps_per_execution.assign(self._steps_per_execution_value)
1272 @property
1273 def step_increment(self):
1274 """The number to increment the step for `on_batch_end` methods."""
1275 return self._step_increment
1277 @property
1278 def inferred_steps(self):
1279 """The inferred steps per epoch of the created `Dataset`.
1281 This will be `None` in the case where:
1283 (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and
1284 (2) `steps_per_epoch` was not provided, and
1285 (3) The first epoch of iteration has not yet completed.
1287 Returns:
1288 The inferred steps per epoch of the created `Dataset`.
1289 """
1290 return self._inferred_steps
1292 @property
1293 def should_sync(self):
1294 # Catch OutOfRangeError for Datasets of unknown size.
1295 # This blocks until the batch has finished executing.
1296 # TODO(b/150292341): Allow multiple async steps here.
1297 return self._inferred_steps is None
1299 def _log_indefinite_training_warning(self):
1300 logging.warning("The training loop will run indefinitely since you have "
1301 "set `steps_per_epoch=-1`. Please use batch-level "
1302 "callbacks to save checkpoints or log training progress, "
1303 "etc")
1305 def _infer_steps(self, steps, dataset):
1306 """Infers steps_per_epoch needed to loop through a dataset."""
1307 if steps == -1:
1308 self._log_indefinite_training_warning()
1309 return None
1311 if steps is not None:
1312 return steps
1314 adapter_steps = self._adapter.get_size()
1315 if adapter_steps is not None:
1316 return adapter_steps
1318 size = cardinality.cardinality(dataset)
1319 if size == cardinality.INFINITE and steps is None:
1320 raise ValueError(
1321 "When passing an infinitely repeating dataset, please specify a "
1322 "`steps_per_epoch` value so that epoch level "
1323 "callbacks continue to work. The value can be arbitrary, or a number "
1324 "that you think correctly defines the size of an epoch. "
1325 "Epoch-level callbacks will then be called at this interval.")
1326 if size >= 0:
1327 return size.numpy().item()
1328 return None
1330 @property
1331 def _samples(self):
1332 return self._adapter.get_samples()
1334 def _validate_data_handler(self):
1335 # TODO(b/152094471): Support this with DistIter.get_next_as_optional.
1336 if self._steps_per_execution_value > 1 and self._inferred_steps is None:
1337 raise ValueError(
1338 "Could not infer the size of the data. With "
1339 "`steps_per_execution > 1`, you must specify the number of steps "
1340 "to run.")
1343class _ClusterCoordinatorDataHandler(DataHandler):
1344 """A `DataHandler` that is compatible with `ClusterCoordinator`."""
1346 def __init__(self, x, y=None, **kwargs):
1347 if not isinstance(x, dataset_creator.DatasetCreator):
1348 x = self._convert_to_dataset_creator(x, y, **kwargs)
1350 super().__init__(x=x, **kwargs)
1352 def _convert_to_dataset_creator(self, x, y, **kwargs):
1353 """Converts non-tf.data.Dataset to `DatasetCreator` instances."""
1355 def _dataset_fn(input_context):
1356 del input_context
1357 data_adapter_cls = select_data_adapter(x, y)
1358 return data_adapter_cls(x=x, y=y, **kwargs).get_dataset()
1360 # This check is needed because types like `tf.data.Dataset` don't work with
1361 # PSS yet. So only apply this logic to the types we can support.
1362 if (isinstance(x, _get_tensor_types()) and
1363 isinstance(y, _get_tensor_types())):
1364 return dataset_creator.DatasetCreator(_dataset_fn)
1365 else:
1366 raise NotImplementedError(
1367 "Only `tf.keras.utils.experimental.DatasetCreator`, `tf.Tensor`, "
1368 "numpy arrays and pandas dataframes are supported types at this "
1369 "time.")
1371 def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
1372 class_weight, distribute):
1373 if not isinstance(x, dataset_creator.DatasetCreator):
1374 raise TypeError("When using `ParameterServerStrategy`, `x` must be a "
1375 "`DatasetCreator`.")
1377 def per_worker_dataset_fn():
1379 return strategy.distribute_datasets_from_function(
1380 x, options=x.input_options)
1382 self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access
1383 per_worker_dataset_fn)
1385 if steps_per_epoch == -1:
1386 self._inferred_steps = None
1387 self._log_indefinite_training_warning()
1388 else:
1389 self._inferred_steps = steps_per_epoch
1391 def sync(self):
1392 self._model._cluster_coordinator.join() # pylint: disable=protected-access
1395def get_data_handler(*args, **kwargs):
1396 if getattr(kwargs["model"], "_cluster_coordinator", None):
1397 return _ClusterCoordinatorDataHandler(*args, **kwargs)
1398 return DataHandler(*args, **kwargs)
1401def _make_class_weight_map_fn(class_weight):
1402 """Applies class weighting to a `Dataset`.
1404 The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
1405 `y` must be a single `Tensor`.
1407 Args:
1408 class_weight: A map where the keys are integer class ids and values are
1409 the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
1411 Returns:
1412 A function that can be used with `tf.data.Dataset.map` to apply class
1413 weighting.
1414 """
1415 class_ids = list(sorted(class_weight.keys()))
1416 expected_class_ids = list(range(len(class_ids)))
1417 if class_ids != expected_class_ids:
1418 error_msg = (
1419 "Expected `class_weight` to be a dict with keys from 0 to one less "
1420 "than the number of classes, found {}").format(class_weight)
1421 raise ValueError(error_msg)
1423 class_weight_tensor = tensor_conversion.convert_to_tensor_v2_with_dispatch(
1424 [class_weight[int(c)] for c in class_ids]
1425 )
1427 def _class_weights_map_fn(*data):
1428 """Convert `class_weight` to `sample_weight`."""
1429 x, y, sw = unpack_x_y_sample_weight(data)
1431 if nest.is_nested(y):
1432 raise ValueError(
1433 "`class_weight` is only supported for Models with a single output.")
1435 if y.shape.rank > 2:
1436 raise ValueError("`class_weight` not supported for "
1437 "3+ dimensional targets.")
1439 y_classes = smart_cond.smart_cond(
1440 y.shape.rank == 2 and backend.shape(y)[1] > 1,
1441 lambda: backend.argmax(y, axis=1),
1442 lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
1444 cw = array_ops.gather_v2(class_weight_tensor, y_classes)
1445 if sw is not None:
1446 cw = math_ops.cast(cw, sw.dtype)
1447 sw, cw = expand_1d((sw, cw))
1448 # `class_weight` and `sample_weight` are multiplicative.
1449 sw = sw * cw
1450 else:
1451 sw = cw
1453 return x, y, sw
1455 return _class_weights_map_fn
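
# Illustrative sketch (not part of the original module): applying the map
# function returned by `_make_class_weight_map_fn` to an `(x, y)` dataset
# attaches a per-sample weight looked up from the class-weight table. The
# helper name and the toy data are hypothetical.
def _example_apply_class_weight():
  x = np.zeros((4, 3), dtype="float32")
  y = np.array([[0], [1], [1], [0]], dtype="int64")
  dataset = dataset_ops.DatasetV2.from_tensor_slices((x, y)).batch(2)
  # Elements of the returned dataset are `(x, y, sample_weight)` triples.
  return dataset.map(_make_class_weight_map_fn({0: 0.5, 1: 2.0}))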
1458def expand_1d(data):
1459 """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
1461 def _expand_single_1d_tensor(t):
1462 # Leaves `CompositeTensor`s as-is.
1463 if (isinstance(t, ops.Tensor) and
1464 isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
1465 return array_ops.expand_dims_v2(t, axis=-1)
1466 return t
1468 return nest.map_structure(_expand_single_1d_tensor, data)
1471def train_validation_split(arrays, validation_split):
1472 """Split arrays into train and validation subsets in deterministic order.
1474 The last part of data will become validation data.
1476 Args:
1477 arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
1478 of Tensors and NumPy arrays.
1479 validation_split: Float between 0 and 1. The proportion of the dataset to
1480 include in the validation split. The rest of the dataset will be included
1481 in the training split.
1482 Returns:
1483 `(train_arrays, validation_arrays)`
1484 """
1486 def _can_split(t):
1487 tensor_types = _get_tensor_types()
1488 return isinstance(t, tensor_types) or t is None
1490 flat_arrays = nest.flatten(arrays)
1491 unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
1492 if unsplitable:
1493 raise ValueError(
1494 "`validation_split` is only supported for Tensors or NumPy "
1495 "arrays, found following types in the input: {}".format(unsplitable))
1497 if all(t is None for t in flat_arrays):
1498 return arrays, arrays
1500 first_non_none = None
1501 for t in flat_arrays:
1502 if t is not None:
1503 first_non_none = t
1504 break
1506 # Assumes all arrays have the same batch shape or are `None`.
1507 batch_dim = int(first_non_none.shape[0])
1508 split_at = int(math.floor(batch_dim * (1. - validation_split)))
1510 if split_at == 0 or split_at == batch_dim:
1511 raise ValueError(
1512 "Training data contains {batch_dim} samples, which is not sufficient "
1513 "to split it into a validation and training set as specified by "
1514 "`validation_split={validation_split}`. Either provide more data, or a "
1515 "different value for the `validation_split` argument." .format(
1516 batch_dim=batch_dim, validation_split=validation_split))
1518 def _split(t, start, end):
1519 if t is None:
1520 return t
1521 return t[start:end]
1523 train_arrays = nest.map_structure(
1524 functools.partial(_split, start=0, end=split_at), arrays)
1525 val_arrays = nest.map_structure(
1526 functools.partial(_split, start=split_at, end=batch_dim), arrays)
1528 return train_arrays, val_arrays
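
# Illustrative sketch (not part of the original module): with
# `validation_split=0.25` the last quarter of each array becomes validation
# data, in order and without shuffling. The helper name is hypothetical.
def _example_split():
  x = np.arange(8)
  (x_train,), (x_val,) = train_validation_split((x,), validation_split=0.25)
  # x_train == array([0, 1, 2, 3, 4, 5]); x_val == array([6, 7])
  return x_train, x_val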
1531@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
1532def unpack_x_y_sample_weight(data):
1533 """Unpacks user-provided data tuple.
1535 This is a convenience utility to be used when overriding
1536 `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
1537 This utility makes it easy to support data of the form `(x,)`,
1538 `(x, y)`, or `(x, y, sample_weight)`.
1540 Standalone usage:
1542 >>> features_batch = tf.ones((10, 5))
1543 >>> labels_batch = tf.zeros((10, 5))
1544 >>> data = (features_batch, labels_batch)
1545 >>> # `y` and `sample_weight` will default to `None` if not provided.
1546 >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
1547 >>> sample_weight is None
1548 True
1550 Example in overridden `Model.train_step`:
1552 ```python
1553 class MyModel(tf.keras.Model):
1555 def train_step(self, data):
1556 # If `sample_weight` is not provided, all samples will be weighted
1557 # equally.
1558 x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
1560 with tf.GradientTape() as tape:
1561 y_pred = self(x, training=True)
1562 loss = self.compiled_loss(
1563 y, y_pred, sample_weight, regularization_losses=self.losses)
1564 trainable_variables = self.trainable_variables
1565 gradients = tape.gradient(loss, trainable_variables)
1566 self.optimizer.apply_gradients(zip(gradients, trainable_variables))
1568 self.compiled_metrics.update_state(y, y_pred, sample_weight)
1569 return {m.name: m.result() for m in self.metrics}
1570 ```
1572 Args:
1573 data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
1575 Returns:
1576 The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
1577 provided.
1578 """
1579 if not isinstance(data, tuple):
1580 return (data, None, None)
1581 elif len(data) == 1:
1582 return (data[0], None, None)
1583 elif len(data) == 2:
1584 return (data[0], data[1], None)
1585 elif len(data) == 3:
1586 return (data[0], data[1], data[2])
1587 else:
1588 error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
1589 "or `(x, y, sample_weight)`, found: {}").format(data)
1590 raise ValueError(error_msg)
1593@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
1594def pack_x_y_sample_weight(x, y=None, sample_weight=None):
1595 """Packs user-provided data into a tuple.
1597 This is a convenience utility for packing data into the tuple formats
1598 that `Model.fit` uses.
1600 Standalone usage:
1602 >>> x = tf.ones((10, 1))
1603 >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
1604 >>> isinstance(data, tf.Tensor)
1605 True
1606 >>> y = tf.ones((10, 1))
1607 >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
1608 >>> isinstance(data, tuple)
1609 True
1610 >>> x, y = data
1612 Args:
1613 x: Features to pass to `Model`.
1614 y: Ground-truth targets to pass to `Model`.
1615 sample_weight: Sample weight for each element.
1617 Returns:
1618 Tuple in the format used in `Model.fit`.
1619 """
1620 if y is None:
1621 # For single x-input, we do no tuple wrapping since in this case
1622 # there is no ambiguity. This also makes NumPy and Dataset
1623 # consistent in that the user does not have to wrap their Dataset
1624 # data in an unnecessary tuple.
1625 if not nest.is_nested(x):
1626 return x
1627 else:
1628 return (x,)
1629 elif sample_weight is None:
1630 return (x, y)
1631 else:
1632 return (x, y, sample_weight)
1635def single_batch_iterator(strategy,
1636 x,
1637 y=None,
1638 sample_weight=None,
1639 class_weight=None):
1640 """Creates a single-batch dataset."""
1641 x, y, sample_weight = _process_tensorlike((x, y, sample_weight))
1642 if y is None:
1643 data = (x,)
1644 elif sample_weight is None:
1645 data = (x, y)
1646 else:
1647 data = (x, y, sample_weight)
1649 _check_data_cardinality(data)
1650 dataset = dataset_ops.DatasetV2.from_tensors(data)
1651 if class_weight:
1652 dataset = dataset.map(_make_class_weight_map_fn(class_weight))
1653 dataset = strategy.experimental_distribute_dataset(dataset)
1654 return iter(dataset)
1657def _check_data_cardinality(data):
1658 num_samples = set(int(i.shape[0]) for i in nest.flatten(data))
1659 if len(num_samples) > 1:
1660 msg = "Data cardinality is ambiguous:\n"
1661 for label, single_data in zip(["x", "y", "sample_weight"], data):
1662 msg += " {} sizes: {}\n".format(
1663 label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data)))
1664 msg += "Make sure all arrays contain the same number of samples."
1665 raise ValueError(msg)
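
# Illustrative sketch (not part of the original module): _check_data_cardinality
# raises when the leading dimensions of the flattened inputs disagree, e.g. 10
# feature rows paired with 8 labels. The helper name is hypothetical.
def _example_cardinality_mismatch():
  try:
    _check_data_cardinality((np.zeros((10, 3)), np.zeros((8, 1))))
  except ValueError as e:
    return str(e)  # "Data cardinality is ambiguous: ..."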
1668def _get_tensor_types():
1669 try:
1670 import pandas as pd # pylint: disable=g-import-not-at-top
1672 return (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame)
1673 except ImportError:
1674 return (ops.Tensor, np.ndarray)
1677def _is_scipy_sparse(x):
1678 try:
1679 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
1681 return issparse(x)
1682 except ImportError:
1683 return False
1686def _scipy_sparse_to_sparse_tensor(t):
1687 """Converts a SciPy sparse matrix to a SparseTensor."""
1688 sparse_coo = t.tocoo()
1689 row, col = sparse_coo.row, sparse_coo.col
1690 data, shape = sparse_coo.data, sparse_coo.shape
1691 if issubclass(data.dtype.type, np.floating):
1692 data = data.astype(backend.floatx())
1693 indices = np.concatenate(
1694 (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1)
1695 return sparse_tensor.SparseTensor(indices, data, shape)
1698def _is_distributed_dataset(ds):
1699 return isinstance(ds, input_lib.DistributedDatasetInterface)