Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_utils_v1.py: 15%
760 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
17import abc
18import atexit
19import collections
20import functools
21import multiprocessing.pool
22import threading
23import time
25import numpy as np
26import tensorflow.compat.v2 as tf
28from keras.src import backend
29from keras.src import callbacks as cbks
30from keras.src import losses
31from keras.src import metrics as metrics_module
32from keras.src.utils import data_utils
33from keras.src.utils import generic_utils
34from keras.src.utils import losses_utils
35from keras.src.utils import tf_inspect
37# isort: off
38from tensorflow.python.platform import tf_logging as logging
41def is_composite_or_composite_value(tensor):
42 """Returns true if 'tensor' is a CompositeTensor or a CT Value object."""
43 # TODO(b/125094323): This should be isinstance(CompositeTensor) or
44 # isinstance(CompositeTensorValue) once we support that.
45 return isinstance(
46 tensor,
47 (
48 tf.__internal__.CompositeTensor,
49 tf.compat.v1.SparseTensorValue,
50 tf.compat.v1.ragged.RaggedTensorValue,
51 ),
52 )
55class Aggregator(object, metaclass=abc.ABCMeta):
56 """Abstract base class used to aggregate batch-level outputs of a loop.
58 Attributes:
59 use_steps: Whether the loop is using `steps` or `batch_size`.
60 num_samples: Total number of samples: `batch_size * num_batches`.
61 steps: Total number of steps.
62 batch_size: Batch size. It is used for validation checks between inputs
63 and outputs.
64 results: What to return at the end of the aggregation loop.
65 """
67 def __init__(
68 self, use_steps, num_samples=None, steps=None, batch_size=None
69 ):
70 self.use_steps = use_steps
71 self.num_samples = num_samples
72 self.steps = steps
73 self.batch_size = batch_size
74 self.results = []
76 @abc.abstractmethod
77 def create(self, batch_outs):
78 """Creates the initial results from the first batch outputs.
80 Args:
81 batch_outs: A list of batch-level outputs.
82 """
83 raise NotImplementedError("Must be implemented in subclasses.")
85 @abc.abstractmethod
86 def aggregate(self, batch_outs, batch_start=None, batch_end=None):
87 """Aggregates batch-level results into total results.
89 Args:
90 batch_outs: A list of batch-level outputs.
91 batch_start: The start index of this batch. Always `None` if
92 `use_steps` is `True`.
93 batch_end: The end index of this batch. Always `None` if `use_steps`
94 is `True`.
95 """
96 raise NotImplementedError("Must be implemented in subclasses.")
98 @abc.abstractmethod
99 def finalize(self):
100 """Prepares the total results to be returned."""
101 raise NotImplementedError("Must be implemented in subclasses.")
104class MetricsAggregator(Aggregator):
105 """Aggregator that calculates loss and metrics info.
107 Attributes:
108 use_steps: Whether the loop is using `steps` or `batch_size`.
109 num_samples: Total number of samples: `batch_size*num_batches`.
110 steps: Total number of steps, i.e., the number of times to iterate over a dataset
111 to cover all samples.
112 """
114 def __init__(self, use_steps, num_samples=None, steps=None):
115 super().__init__(
116 use_steps=use_steps,
117 num_samples=num_samples,
118 steps=steps,
119 batch_size=None,
120 )
122 def create(self, batch_outs):
123 self.results = [0.0] * len(batch_outs)
125 def aggregate(self, batch_outs, batch_start=None, batch_end=None):
126 # Loss.
127 if self.use_steps:
128 self.results[0] += batch_outs[0]
129 else:
130 self.results[0] += batch_outs[0] * (batch_end - batch_start)
131 # Metrics (always stateful, just grab current values).
132 self.results[1:] = batch_outs[1:]
134 def finalize(self):
135 if not self.results:
136 raise ValueError("Empty training data.")
137 self.results[0] /= self.num_samples or self.steps
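# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): batch losses are weighted by batch size and averaged
# over the total sample count in finalize(), while stateful metrics simply
# report their latest value.
def _metrics_aggregator_sketch():
    agg = MetricsAggregator(use_steps=False, num_samples=10)
    agg.create([0.5, 0.8])  # [loss, accuracy] from the first batch
    agg.aggregate([0.5, 0.8], batch_start=0, batch_end=6)   # 6 samples
    agg.aggregate([0.3, 0.9], batch_start=6, batch_end=10)  # 4 samples
    agg.finalize()
    return agg.results  # -> [0.42, 0.9]: loss averaged over all 10 samples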
140def _append_sparse_tensor_value(target, to_append):
141 """Append sparse tensor value objects."""
142 # Make sure the sparse tensors are of the same size (except for the 0th
143 # dim).
144 if len(target.dense_shape) != len(to_append.dense_shape):
145 raise RuntimeError(
146 "Unable to concatenate %s and %s. The inner dense shapes do not "
147 "have the same number of dimensions (%s vs %s)"
148 % (target, to_append, target.dense_shape, to_append.dense_shape)
149 )
151 if target.dense_shape[1:] != to_append.dense_shape[1:]:
152 raise RuntimeError(
153 "Unable to concatenate %s and %s. The inner dense shapes do not "
154 "match inner dimensions (%s vs %s)"
155 % (
156 target,
157 to_append,
158 target.dense_shape[1:],
159 to_append.dense_shape[1:],
160 )
161 )
163 # Add the to_append indices to target, updating the 0th value, and keeping
164 # track of the maximum so we know the final dense_shape of this tensor.
165 base_dim0_value = target.dense_shape[0]
166 max_dim0_value = target.dense_shape[0]
167 new_indices = target.indices
168 for index in to_append.indices:
169 # Here, we iterate through the sparse indices of the tensor to append.
170 # For each index, we update its zeroth value (the batch index) by adding
171 # the number of batch items in the tensor we are appending to (so an
172 # index of [0, 0, 1] for a value that is being appended to a tensor with
173 # 0th dim size 3 would become [3, 0, 1].)
174 index[0] += base_dim0_value
175 max_dim0_value = max(max_dim0_value, index[0])
176 new_indices = np.append(new_indices, [index], axis=0)
178 # Extend the values array to contain all of the appended values. These will
179 # be in the same order as the indices added above.
180 new_values = np.concatenate((target.values, to_append.values), axis=0)
182 # Create a new dense shape by replacing the value for the 0th dimension
183 # with the new max dim0 value.
184 new_dense_shape = list(target.dense_shape)
185 new_dense_shape[0] = max_dim0_value + 1
186 new_dense_shape = tuple(new_dense_shape)
188 return tf.compat.v1.SparseTensorValue(
189 indices=new_indices, values=new_values, dense_shape=new_dense_shape
190 )
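# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): appending one SparseTensorValue batch to another
# shifts the appended indices by the first value's dim-0 size and grows the
# dense shape accordingly.
def _append_sparse_tensor_value_sketch():
    a = tf.compat.v1.SparseTensorValue(
        indices=np.array([[0, 0]]), values=np.array([1.0]), dense_shape=[2, 3]
    )
    b = tf.compat.v1.SparseTensorValue(
        indices=np.array([[1, 2]]), values=np.array([2.0]), dense_shape=[2, 3]
    )
    merged = _append_sparse_tensor_value(a, b)
    return merged  # indices [[0, 0], [3, 2]], values [1.0, 2.0], shape (4, 3)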
193def _append_ragged_tensor_value(target, to_append):
194 """Append ragged tensor value objects."""
195 # Make sure the ragged tensors are of the same size (save for the 0th dim).
196 if len(target.shape) != len(to_append.shape):
197 raise RuntimeError(f"Unable to concatenate {target} and {to_append}")
199 if target.shape[1:] != to_append.shape[1:]:
200 raise RuntimeError(f"Unable to concatenate {target} and {to_append}")
202 adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1]
203 new_row_splits = np.append(target.row_splits, adjusted_row_splits)
204 if isinstance(target.values, tf.compat.v1.ragged.RaggedTensorValue):
205 new_values = _append_ragged_tensor_value(
206 target.values, to_append.values
207 )
208 else:
209 new_values = np.concatenate((target.values, to_append.values), axis=0)
211 return tf.compat.v1.ragged.RaggedTensorValue(new_values, new_row_splits)
214def _append_composite_tensor(target, to_append):
215 """Helper function to append composite tensors to each other in the 0 axis.
217 In order to support batching within a fit/evaluate/predict call, we need
218 to be able to aggregate within a CompositeTensor. Unfortunately, the CT
219 API currently does not make this easy - especially in V1 mode, where we're
220 working with CompositeTensor Value objects that have no connection with the
221 CompositeTensors that created them.
223 Args:
224 target: CompositeTensor or CompositeTensor value object that will be
225 appended to.
226 to_append: CompositeTensor or CompositeTensor value object to append
227 to `target`.
229 Returns:
230 A CompositeTensor or CompositeTensor value object.
232 Raises:
233 RuntimeError: if concatenation is not possible.
234 """
235 if type(target) is not type(to_append):
236 raise RuntimeError(
237 f"Unable to concatenate {type(target)} and {type(to_append)}"
238 )
240 # Perform type-specific concatenation.
241 # TODO(b/125094323): This should be replaced by a simple call to
242 # target.append() that should work on all of the below classes.
244 # If we're seeing a CompositeTensor here, we know it's because we're in
245 # Eager mode (or else we'd have evaluated the CT to a CT Value object
246 # already). Therefore, it's safe to call concat() on it without evaluating
247 # the result any further. If not - that is, if we're seeing a
248 # SparseTensorValue or a RaggedTensorValue - we need to hand-update it
249 # since we're outside of the graph anyways.
250 if isinstance(target, tf.SparseTensor):
251 # We need to invoke the sparse version of concatenate here - tf.concat
252 # won't work.
253 return tf.compat.v1.sparse_concat(sp_inputs=[target, to_append], axis=0)
254 elif isinstance(target, tf.RaggedTensor):
255 return tf.concat([target, to_append], axis=0)
256 elif isinstance(target, tf.compat.v1.SparseTensorValue):
257 return _append_sparse_tensor_value(target, to_append)
258 elif isinstance(target, tf.compat.v1.ragged.RaggedTensorValue):
259 return _append_ragged_tensor_value(target, to_append)
260 else:
261 raise RuntimeError(
262 f"Attempted to concatenate unsupported object {type(target)}."
263 )
266class ConcatAggregator(Aggregator):
267 """Combine tensor-likes which cannot be merged on the fly.
269 This class expects to aggregate a single tensor-like rather than a nested
270 structure of tensor-likes.
271 """
273 def __init__(self, batch_size):
274 self.composite = None
275 super().__init__(
276 use_steps=True, num_samples=None, steps=None, batch_size=batch_size
277 )
279 def create(self, batch_element):
280 self.composite = is_composite_or_composite_value(batch_element)
282 def aggregate(self, batch_element, batch_start=None, batch_end=None):
284 # TODO(psv): Add num_samples check here to detect when output batch
285 # #samples is < batch size and != input batch #samples.
286 if self.batch_size and self.batch_size < batch_element.shape[0]:
287 raise ValueError(
288 "Mismatch between expected batch size and model output batch "
289 "size. Output shape = {}, "
290 "expected output shape = shape {}".format(
291 batch_element.shape,
292 (self.batch_size,) + batch_element.shape[1:],
293 )
294 )
295 self.results.append(batch_element)
297 def finalize(self):
298 # Special case of single batch inference which skips a copy.
299 if len(self.results) == 1:
300 self.results = self.results[0]
302 elif self.composite:
303 # TODO(taylorrobie): efficiently concatenate.
304 results = self.results[0]
305 for r in self.results[1:]:
306 results = _append_composite_tensor(results, r)
307 self.results = results
309 else:
310 self.results = np.concatenate(self.results, axis=0)
313_COPY_THREADS = 4
314_COPY_POOL = None
317def get_copy_pool():
318 """Shared threadpool for copying arrays.
320 Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
321 creating a pool per SliceAggregator.
323 Returns:
324 The global copy threadpool.
325 """
326 global _COPY_POOL
327 if _COPY_POOL is None:
328 _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
329 atexit.register(_COPY_POOL.close)
330 return _COPY_POOL
333class SliceAggregator(Aggregator):
334 """Combine arrays where the final size is known.
336 This class expects to aggregate a single tensor-like rather than a nested
337 structure of tensor-likes.
339 NumPy copies are an operation that threads handle quite well because all of
340 the heavy lifting is in C and does not need the GIL. Moreover, we can
341 perform lock-free writes to the same buffer in multiple threads because the
342 nature of result aggregation guarantees that either the indices are disjoint
343 or the aggregator will throw an exception in finalize. Additionally, because
344 aggregation is performed on the slowest varying dimension, assignments for a
345 given batch will write to contiguous blocks of memory, further minimizing
346 contention.
348 There is, however, some scheduling and context switching overhead which will
349 offset the gains from pipelining the slice assignment. Below a given
350 threshold it is faster to simply assign in the main thread rather than
351 enqueue the assignment in a side thread. The exact threshold will vary from
352 system to system, but the time is not very sensitive to the exact transition
353 so a value of 2 ** 14 was chosen which should be reasonable on most systems.
354 """
356 _BINARY_SIZE_THRESHOLD = 2**14
357 _MAX_COPY_SECONDS = 300
359 def __init__(self, num_samples, batch_size):
360 self._async_copies = []
361 self._pool = get_copy_pool()
362 self._errors = []
363 super().__init__(
364 use_steps=False,
365 num_samples=num_samples,
366 steps=None,
367 batch_size=batch_size,
368 )
370 def create(self, batch_element):
371 # This step does not need to be pipelined because NumPy empty array
372 # initialization is effectively instantaneous.
373 shape = (self.num_samples,) + batch_element.shape[1:]
374 dtype = batch_element.dtype
376 self.results = np.empty(shape=shape, dtype=dtype)
378 def aggregate(self, batch_element, batch_start, batch_end):
379 # Fail early.
380 if self._errors:
381 raise self._errors[0]
383 # In the special case of single batch inference, no copy is needed.
384 if batch_end - batch_start == self.num_samples:
385 if self.num_samples != batch_element.shape[0]:
386 raise ValueError(
387 "Mismatch between expected batch size and model "
388 "output batch size. Output shape = {}, "
389 "expected output shape = shape {}".format(
390 batch_element.shape, self.results.shape
391 )
392 )
394 self.results = batch_element
395 return
397 # This is an approximate threshold, so we don't need to consider the
398 # number of bytes per element.
399 num_elements = np.prod(batch_element.shape)
400 if num_elements < self._BINARY_SIZE_THRESHOLD:
401 self.results[batch_start:batch_end] = batch_element
402 else:
403 is_finished = threading.Event()
404 self._pool.apply_async(
405 self._slice_assign,
406 args=(batch_element, batch_start, batch_end, is_finished),
407 )
408 self._async_copies.append(is_finished)
410 def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
411 """Legacy utility method to slice input arrays."""
412 try:
413 self.results[batch_start:batch_end] = batch_element
415 except Exception as e:
416 # `_slice_assign` should only be called in threads and exceptions
417 # raised in threads do not carry over to the main thread. So instead
418 # we perform a broad catch in the thread and then store the
419 # exception to be re-raised in the main thread.
420 self._errors.append(e)
422 finally:
423 is_finished.set()
425 def finalize(self):
426 start_time = time.time()
427 for is_finished in self._async_copies:
428 timeout = max(
429 [0.0, self._MAX_COPY_SECONDS - (time.time() - start_time)]
430 )
431 if not is_finished.wait(timeout):
432 raise ValueError("Timed out waiting for copy to complete.")
434 if self._errors:
435 raise self._errors[0]
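# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): per-batch NumPy outputs are written into one
# preallocated (num_samples, ...) buffer. Slices below _BINARY_SIZE_THRESHOLD
# elements are copied inline; larger ones go to the shared thread pool and are
# joined in finalize().
def _slice_aggregator_sketch():
    agg = SliceAggregator(num_samples=5, batch_size=2)
    agg.create(np.zeros((2, 3)))               # allocates a (5, 3) buffer
    agg.aggregate(np.ones((2, 3)), 0, 2)
    agg.aggregate(np.full((2, 3), 2.0), 2, 4)
    agg.aggregate(np.full((1, 3), 3.0), 4, 5)  # final partial batch
    agg.finalize()
    return agg.results                         # (5, 3) array of all batches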
438class OutputsAggregator(Aggregator):
439 """Aggregator that concatenates outputs."""
441 _structure = None
443 def create(self, batch_outs):
444 # SparseTensorValue is a named tuple which nest will flatten, so we need
445 # to guard it to properly handle the structure.
446 self._structure = tf.__internal__.nest.get_traverse_shallow_structure(
447 lambda x: not is_composite_or_composite_value(x), batch_outs
448 )
449 batch_outs = tf.__internal__.nest.flatten_up_to(
450 self._structure, batch_outs
451 )
453 for batch_element in batch_outs:
454 if is_composite_or_composite_value(batch_element):
455 # If the output is not a ndarray, it will be either a composite
456 # tensor or a composite tensor's Value object. In either case,
457 # we can't allocate an array to hold the object - we'll handle
458 # it later.
459 self.results.append(ConcatAggregator(self.batch_size))
460 elif isinstance(batch_element, np.ndarray):
461 self.results.append(
462 (
463 ConcatAggregator(self.batch_size)
464 if self.use_steps
465 else SliceAggregator(self.num_samples, self.batch_size)
466 )
467 )
468 else:
469 # This is not a ndarray, a CompositeTensor, or a
470 # CompositeTensorValue. Fail fast rather than trying to
471 # concatenate it.
472 raise RuntimeError(
473 "Attempted to aggregate unsupported object {}.".format(
474 batch_element
475 )
476 )
478 self.results[-1].create(batch_element)
480 def aggregate(self, batch_outs, batch_start=None, batch_end=None):
481 batch_outs = tf.__internal__.nest.flatten_up_to(
482 self._structure, batch_outs
483 )
484 for batch_element, result in zip(batch_outs, self.results):
485 result.aggregate(batch_element, batch_start, batch_end)
487 def finalize(self):
488 for result in self.results:
489 result.finalize()
490 self.results = [i.results for i in self.results]
491 self.results = tf.nest.pack_sequence_as(self._structure, self.results)
494def get_progbar(model, count_mode, include_metrics=True):
495 """Get Progbar."""
496 if include_metrics:
497 stateful_metric_names = getattr(model, "metrics_names", None)
498 if stateful_metric_names:
499 stateful_metric_names = stateful_metric_names[1:] # Exclude `loss`
500 else:
501 stateful_metric_names = None
502 return cbks.ProgbarLogger(
503 count_mode, stateful_metrics=stateful_metric_names
504 )
507def check_num_samples(ins, batch_size=None, steps=None, steps_name="steps"):
508 """Determine the number of samples provided for training and evaluation.
510 The number of samples is not defined when running with `steps`,
511 in which case the number of samples is set to `None`.
513 Args:
514 ins: List of tensors to be fed to the Keras function.
515 batch_size: Integer batch size or `None` if not defined.
516 steps: Total number of steps (batches of samples) before declaring
517 `_predict_loop` finished. Ignored with the default value of `None`.
518 steps_name: The public API's parameter name for `steps`.
520 Raises:
521 ValueError: when `steps` is `None` and the attribute `ins.shape`
522 does not exist. Also raises ValueError when `steps` is not `None`
523 and `batch_size` is not `None` because they are mutually
524 exclusive.
526 Returns:
527 When steps is `None`, returns the number of samples to be
528 processed based on the size of the first dimension of the
529 first input numpy array. When steps is not `None` and
530 `batch_size` is `None`, returns `None`.
531 """
532 if steps is not None and batch_size is not None:
533 raise ValueError(
534 "If " + steps_name + " is set, the `batch_size` must be None."
535 )
536 if check_steps_argument(ins, steps, steps_name):
537 return None
539 if hasattr(ins[0], "shape"):
540 return int(ins[0].shape[0])
541 return None # Edge case where ins == [static_learning_phase]
544def standardize_single_array(x, expected_shape=None):
545 """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1."""
546 if x is None:
547 return None
549 if is_composite_or_composite_value(x):
550 return x
552 if isinstance(x, int):
553 raise ValueError(
554 f"Expected an array data type but received an integer: {x}"
555 )
557 if (
558 x.shape is not None
559 and len(x.shape) == 1
560 and (expected_shape is None or len(expected_shape) != 1)
561 ):
562 if tf.is_tensor(x):
563 x = tf.compat.v1.expand_dims(x, axis=1)
564 else:
565 x = np.expand_dims(x, 1)
566 return x
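# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): rank-1 data is expanded to a column vector unless
# the expected shape itself is rank 1.
def _standardize_single_array_sketch():
    x = np.arange(4)                                                  # (4,)
    as_column = standardize_single_array(x)                           # (4, 1)
    unchanged = standardize_single_array(x, expected_shape=(None,))   # (4,)
    return as_column.shape, unchanged.shape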
569def get_composite_shape(tensor):
570 """Returns the shape of the passed composite tensor."""
571 if isinstance(tensor, tf.compat.v1.SparseTensorValue):
572 # SparseTensorValues use a 'dense_shape' attribute
573 return tensor.dense_shape
574 else:
575 return tensor.shape
578def standardize_input_data(
579 data, names, shapes=None, check_batch_axis=True, exception_prefix=""
580):
581 """Normalizes inputs and targets provided by users.
583 Users may pass data as a list of arrays, dictionary of arrays,
584 or as a single array. We normalize this to an ordered list of
585 arrays (same order as `names`), while checking that the provided
586 arrays have shapes that match the network's expectations.
588 Args:
589 data: User-provided input data (polymorphic).
590 names: List of expected array names.
591 shapes: Optional list of expected array shapes.
592 check_batch_axis: Boolean; whether to check that the batch axis of the
593 arrays matches the expected value found in `shapes`.
594 exception_prefix: String prefix used for exception formatting.
596 Returns:
597 List of standardized input arrays (one array per model input).
599 Raises:
600 ValueError: in case of improperly formatted user-provided data.
601 """
602 try:
603 data_len = len(data)
604 except TypeError:
605 # For instance if data is `None` or a symbolic Tensor.
606 data_len = None
608 if not names:
609 if data_len and not isinstance(data, dict):
610 raise ValueError(
611 "Error when checking model "
612 + exception_prefix
613 + ": expected no data, but got:",
614 data,
615 )
616 return []
617 if data is None:
618 return [None for _ in range(len(names))]
620 if isinstance(data, dict):
621 try:
622 data = [
623 data[x].values
624 if data[x].__class__.__name__ == "DataFrame"
625 else data[x]
626 for x in names
627 ]
628 except KeyError as e:
629 raise ValueError(
630 'No data provided for "'
631 + e.args[0]
632 + '". Need data for each key in: '
633 + str(names)
634 )
635 elif isinstance(data, (list, tuple)):
636 if isinstance(data[0], (list, tuple)):
637 data = [np.asarray(d) for d in data]
638 elif len(names) == 1 and isinstance(data[0], (float, int)):
639 data = [np.asarray(data)]
640 else:
641 data = [
642 x.values if x.__class__.__name__ == "DataFrame" else x
643 for x in data
644 ]
645 else:
646 data = data.values if data.__class__.__name__ == "DataFrame" else data
647 data = [data]
649 if shapes is not None:
650 data = [
651 standardize_single_array(x, shape)
652 for (x, shape) in zip(data, shapes)
653 ]
654 else:
655 data = [standardize_single_array(x) for x in data]
657 if len(data) != len(names):
658 if data and hasattr(data[0], "shape"):
659 raise ValueError(
660 "Error when checking model "
661 + exception_prefix
662 + ": the list of Numpy arrays that you are passing to "
663 "your model is not the size the model expected. "
664 "Expected to see "
665 + str(len(names))
666 + " array(s), "
667 + "for inputs "
668 + str(names)
669 + " but instead got the following list of "
670 + str(len(data))
671 + " arrays: "
672 + str(data)[:200]
673 + "..."
674 )
675 elif len(names) > 1:
676 raise ValueError(
677 "Error when checking model "
678 + exception_prefix
679 + ": you are passing a list as input to your model, "
680 "but the model expects a list of "
681 + str(len(names))
682 + " Numpy arrays instead. The list you passed was: "
683 + str(data)[:200]
684 )
685 elif len(data) == 1 and not hasattr(data[0], "shape"):
686 raise TypeError(
687 "Error when checking model "
688 + exception_prefix
689 + ": data should be a Numpy array, or list/dict of "
690 "Numpy arrays. Found: " + str(data)[:200] + "..."
691 )
692 elif len(names) == 1:
693 data = [np.asarray(data)]
695 # Check shapes compatibility.
696 if shapes:
697 for i in range(len(names)):
698 if shapes[i] is not None:
699 if tf.is_tensor(data[i]):
700 tensorshape = data[i].shape
701 if not tensorshape:
702 continue
703 data_shape = tuple(tensorshape.as_list())
704 elif is_composite_or_composite_value(data[i]):
705 tensorshape = get_composite_shape(data[i])
706 data_shape = tuple(tensorshape.as_list())
707 else:
708 data_shape = data[i].shape
710 shape = shapes[i]
711 if len(data_shape) != len(shape):
712 raise ValueError(
713 "Error when checking "
714 + exception_prefix
715 + ": expected "
716 + names[i]
717 + " to have "
718 + str(len(shape))
719 + " dimensions, but got array with shape "
720 + str(data_shape)
721 )
722 if not check_batch_axis:
723 data_shape = data_shape[1:]
724 shape = shape[1:]
725 for dim, ref_dim in zip(data_shape, shape):
726 if (
727 ref_dim != dim
728 and ref_dim is not None
729 and dim is not None
730 ):
731 raise ValueError(
732 "Error when checking "
733 + exception_prefix
734 + ": expected "
735 + names[i]
736 + " to have shape "
737 + str(shape)
738 + " but got array with shape "
739 + str(data_shape)
740 )
741 return data
744def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
745 """Maps `sample_weight` or `class_weight` to model outputs.
747 Args:
748 x_weight: User-provided `sample_weight` or `class_weight` argument.
749 output_names: List of output names (strings) in the model.
750 weight_type: A string used purely for exception printing.
752 Returns:
753 A list of `sample_weight` or `class_weight` with exactly one
754 element per model output.
756 Raises:
757 ValueError: In case of invalid user-provided argument.
758 """
759 if x_weight is None or (
760 isinstance(x_weight, (list, tuple)) and len(x_weight) == 0
761 ):
762 return [None for _ in output_names]
763 if len(output_names) == 1:
764 if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
765 return x_weight
766 if isinstance(x_weight, dict) and output_names[0] in x_weight:
767 return [x_weight[output_names[0]]]
768 else:
769 return [x_weight]
770 if isinstance(x_weight, (list, tuple)):
771 if len(x_weight) != len(output_names):
772 raise ValueError(
773 "Provided `"
774 + weight_type
775 + "` was a list of "
776 + str(len(x_weight))
777 + " elements, but the model has "
778 + str(len(output_names))
779 + " outputs. You should provide one `"
780 + weight_type
781 + "`array per model output."
782 )
783 return x_weight
784 if isinstance(x_weight, collections.abc.Mapping):
785 generic_utils.check_for_unexpected_keys(
786 weight_type, x_weight, output_names
787 )
788 x_weights = []
789 for name in output_names:
790 x_weights.append(x_weight.get(name))
791 return x_weights
792 else:
793 raise TypeError(
794 "The model has multiple outputs, so `"
795 + weight_type
796 + "` should be either a list or a dict. Provided `"
797 + weight_type
798 + "` type not understood: "
799 + str(x_weight)
800 )
803def standardize_class_weights(class_weight, output_names):
804 return standardize_sample_or_class_weights(
805 class_weight, output_names, "class_weight"
806 )
809def standardize_sample_weights(sample_weight, output_names):
810 return standardize_sample_or_class_weights(
811 sample_weight, output_names, "sample_weight"
812 )
815def check_array_lengths(inputs, targets, weights=None):
816 """Does user input validation for numpy arrays.
818 Args:
819 inputs: list of Numpy arrays of inputs.
820 targets: list of Numpy arrays of targets.
821 weights: list of Numpy arrays of sample weights.
823 Raises:
824 ValueError: in case of incorrectly formatted data.
825 """
827 def is_tensor_or_composite_tensor(x):
828 return tf.is_tensor(x) or is_composite_or_composite_value(x)
830 def set_of_lengths(x):
831 # Returns a set with the variation between
832 # different shapes, with None => 0
833 if x is None:
834 return {}
835 else:
836 return set(
837 [
838 y.shape[0]
839 for y in x
840 if y is not None and not is_tensor_or_composite_tensor(y)
841 ]
842 )
844 set_x = set_of_lengths(inputs)
845 set_y = set_of_lengths(targets)
846 set_w = set_of_lengths(weights)
847 if len(set_x) > 1:
848 raise ValueError(
849 "All input arrays (x) should have "
850 "the same number of samples. Got array shapes: "
851 + str([x.shape for x in inputs])
852 )
853 if len(set_y) > 1:
854 raise ValueError(
855 "All target arrays (y) should have "
856 "the same number of samples. Got array shapes: "
857 + str([y.shape for y in targets])
858 )
859 if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
860 raise ValueError(
861 "Input arrays should have "
862 "the same number of samples as target arrays. "
863 "Found "
864 + str(list(set_x)[0])
865 + " input samples and "
866 + str(list(set_y)[0])
867 + " target samples."
868 )
869 if len(set_w) > 1:
870 raise ValueError(
871 "All sample_weight arrays should have "
872 "the same number of samples. Got array shapes: "
873 + str([w.shape for w in weights])
874 )
875 if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
876 raise ValueError(
877 "Sample_weight arrays should have "
878 "the same number of samples as target arrays. Got "
879 + str(list(set_y)[0])
880 + " input samples and "
881 + str(list(set_w)[0])
882 + " target samples."
883 )
886def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
887 """Does validation on the compatibility of targets and loss functions.
889 This helps prevent users from using loss functions incorrectly. This check
890 is purely for UX purposes.
892 Args:
893 targets: list of Numpy arrays of targets.
894 loss_fns: list of loss functions.
895 output_shapes: list of shapes of model outputs.
897 Raises:
898 ValueError: if a loss function or target array
899 is incompatible with an output.
900 """
901 key_loss_fns = {
902 losses.mean_squared_error,
903 losses.binary_crossentropy,
904 losses.categorical_crossentropy,
905 }
906 key_loss_classes = (
907 losses.MeanSquaredError,
908 losses.BinaryCrossentropy,
909 losses.CategoricalCrossentropy,
910 )
911 for y, loss, shape in zip(targets, loss_fns, output_shapes):
912 if y is None or loss is None or tf.is_tensor(y):
913 continue
914 if losses.is_categorical_crossentropy(loss):
915 if y.shape[-1] == 1:
916 raise ValueError(
917 "You are passing a target array of shape "
918 + str(y.shape)
919 + " while using as loss `categorical_crossentropy`. "
920 "`categorical_crossentropy` expects "
921 "targets to be binary matrices (1s and 0s) "
922 "of shape (samples, classes). "
923 "If your targets are integer classes, "
924 "you can convert them to the expected format via:\n"
925 "```\n"
926 "from keras.src.utils import to_categorical\n"
927 "y_binary = to_categorical(y_int)\n"
928 "```\n"
929 "\n"
930 "Alternatively, you can use the loss function "
931 "`sparse_categorical_crossentropy` instead, "
932 "which does expect integer targets."
933 )
935 is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
936 if isinstance(loss, key_loss_classes) or (
937 is_loss_wrapper and (loss.fn in key_loss_fns)
938 ):
939 for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
940 if out_dim is not None and target_dim != out_dim:
941 loss_name = loss.name
942 if loss_name is None:
943 loss_type = loss.fn if is_loss_wrapper else type(loss)
944 loss_name = loss_type.__name__
945 raise ValueError(
946 "A target array with shape "
947 + str(y.shape)
948 + " was passed for an output of shape "
949 + str(shape)
950 + " while using as loss `"
951 + loss_name
952 + "`. "
953 "This loss expects targets to have the same shape "
954 "as the output."
955 )
958def collect_per_output_metric_info(
959 metrics,
960 output_names,
961 output_shapes,
962 loss_fns,
963 from_serialized=False,
964 is_weighted=False,
965):
966 """Maps metric names and functions to model outputs.
968 Args:
969 metrics: a list or a list of lists or a dict of metric functions.
970 output_names: a list of the names (strings) of model outputs.
971 output_shapes: a list of the shapes of model outputs.
972 loss_fns: a list of the loss functions corresponding to the model
973 outputs.
974 from_serialized: whether the model the metrics are being sourced from is
975 being initialized from a serialized format.
976 is_weighted: Boolean indicating whether the given metrics are weighted.
978 Returns:
979 A list (one entry per model output) of dicts.
980 For instance, if the model has 2 outputs, and for the first output
981 we want to compute "binary_accuracy" and "binary_crossentropy",
982 and just "binary_accuracy" for the second output,
983 the list would look like: `[{
984 'acc': binary_accuracy(),
985 'ce': binary_crossentropy(),
986 }, {
987 'acc': binary_accuracy(),
988 }]`
990 Raises:
991 TypeError: if an incorrect type is passed for the `metrics` argument.
992 """
993 if not metrics:
994 return [{} for _ in output_names]
996 if isinstance(metrics, list):
997 any_sub_list = any(isinstance(m, list) for m in metrics)
998 if any_sub_list:
999 if len(metrics) != len(output_names):
1000 raise ValueError(
1001 "When passing a list of lists as `metrics`, "
1002 "it should have one entry per model output. "
1003 "The model has "
1004 + str(len(output_names))
1005 + " outputs, but you passed metrics="
1006 + str(metrics)
1007 )
1008 # User has provided a list of len = len(outputs).
1009 nested_metrics = [generic_utils.to_list(m) for m in metrics]
1010 else:
1011 # If it is a single list we then apply all metrics to all outputs.
1012 if len(output_names) > 1:
1013 nested_metrics = []
1014 for _ in output_names:
1015 nested_metrics.append(
1016 [metrics_module.clone_metric(m) for m in metrics]
1017 )
1018 else:
1019 nested_metrics = [metrics]
1020 elif isinstance(metrics, collections.abc.Mapping):
1021 generic_utils.check_for_unexpected_keys(
1022 "metrics", metrics, output_names
1023 )
1024 nested_metrics = []
1025 for name in output_names:
1026 output_metrics = generic_utils.to_list(metrics.get(name, []))
1027 nested_metrics.append(output_metrics)
1028 else:
1029 raise TypeError(
1030 "Type of `metrics` argument not understood. "
1031 "Expected a list or dictionary, found: " + str(metrics)
1032 )
1034 per_output_metrics = []
1035 for i, metrics in enumerate(nested_metrics):
1036 metrics_dict = collections.OrderedDict()
1037 for metric in metrics:
1038 metric_name = get_metric_name(metric, is_weighted)
1039 metric_fn = get_metric_function(
1040 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]
1041 )
1042 metric_fn._from_serialized = from_serialized
1044 # If the metric function is not stateful, we create a stateful
1045 # version.
1046 if not isinstance(metric_fn, metrics_module.Metric):
1047 metric_fn = metrics_module.MeanMetricWrapper(
1048 metric_fn, name=metric_name
1049 )
1050 # If the metric is being revived from something stateless, such
1051 # as a string (e.g. "accuracy"), we may need to later reapply
1052 # transformations such as renaming.
1053 metric_fn._from_serialized = False
1054 metrics_dict[metric_name] = metric_fn
1055 per_output_metrics.append(metrics_dict)
1057 return per_output_metrics
1060def batch_shuffle(index_array, batch_size):
1061 """Shuffles an array in a batch-wise fashion.
1063 Useful for shuffling HDF5 arrays
1064 (where one cannot access arbitrary indices).
1066 Args:
1067 index_array: array of indices to be shuffled.
1068 batch_size: integer.
1070 Returns:
1071 The `index_array` array, shuffled in a batch-wise fashion.
1072 """
1073 batch_count = int(len(index_array) / batch_size)
1074 # to reshape we need to be cleanly divisible by batch size
1075 # we stash extra items and reappend them after shuffling
1076 last_batch = index_array[batch_count * batch_size :]
1077 index_array = index_array[: batch_count * batch_size]
1078 index_array = index_array.reshape((batch_count, batch_size))
1079 np.random.shuffle(index_array)
1080 index_array = index_array.flatten()
1081 return np.append(index_array, last_batch)
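# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): indices are shuffled in whole batches, so reads from
# an HDF5-backed array stay contiguous within each batch.
def _batch_shuffle_sketch():
    shuffled = batch_shuffle(np.arange(10), batch_size=4)
    # The two full batches [0..3] and [4..7] are shuffled as units; the
    # leftover [8, 9] is re-appended at the end.
    return shuffled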
1084def standardize_weights(
1085 y, sample_weight=None, class_weight=None, sample_weight_mode=None
1086):
1087 """Performs sample weight validation and standardization.
1089 Everything gets normalized to a single sample-wise (or timestep-wise)
1090 weight array. If both `sample_weight` and `class_weight` are provided,
1091 the weights are multiplied.
1093 Args:
1094 y: Numpy array or Tensor of model targets to be weighted.
1095 sample_weight: User-provided `sample_weight` argument.
1096 class_weight: User-provided `class_weight` argument.
1097 sample_weight_mode: One of `None` or `"temporal"`. `"temporal"`
1098 indicated that we expect 2D weight data that will be applied to the
1099 last 2 dimensions of the targets (i.e. we are weighting timesteps, not
1100 samples).
1102 Returns:
1103 A numpy array of target weights, one entry per sample to weight.
1105 Raises:
1106 ValueError: In case of invalid user-provided arguments.
1107 """
1108 # Iterator may return sample_weight as 1-tuple
1109 if isinstance(sample_weight, tuple):
1110 sample_weight = sample_weight[0]
1111 if sample_weight_mode is not None and sample_weight_mode != "samplewise":
1112 if sample_weight_mode != "temporal":
1113 raise ValueError(
1114 '"sample_weight_mode should be None or "temporal". Found: '
1115 + str(sample_weight_mode)
1116 )
1117 if len(y.shape) < 3:
1118 raise ValueError(
1119 "Found a sample_weight array for an input with shape "
1120 + str(y.shape)
1121 + ". "
1122 "Timestep-wise sample weighting (use of "
1123 'sample_weight_mode="temporal") is restricted to '
1124 "outputs that are at least 3D, i.e. that have "
1125 "a time dimension."
1126 )
1127 if sample_weight is not None and len(sample_weight.shape) != 2:
1128 raise ValueError(
1129 "Found a sample_weight array with shape "
1130 + str(sample_weight.shape)
1131 + ". "
1132 "In order to use timestep-wise sample weighting, "
1133 "you should pass a 2D sample_weight array."
1134 )
1135 else:
1136 if sample_weight is not None and len(sample_weight.shape) != 1:
1137 raise ValueError(
1138 "Found a sample_weight array with shape {}. In order to "
1139 "use timestep-wise sample weights, you should specify "
1140 'sample_weight_mode="temporal" in compile(); found "{}" '
1141 "instead. If you just mean to use sample-wise weights, "
1142 "make sure your sample_weight array is 1D.".format(
1143 sample_weight.shape, sample_weight_mode
1144 )
1145 )
1147 if sample_weight is not None:
1148 if len(sample_weight.shape) > len(y.shape):
1149 raise ValueError(
1150 "Found a sample_weight with shape"
1151 + str(sample_weight.shape)
1152 + ".Expected sample_weight with rank less than or equal to "
1153 + str(len(y.shape))
1154 )
1156 if (
1157 not tf.is_tensor(sample_weight)
1158 and y.shape[: sample_weight.ndim] != sample_weight.shape
1159 ):
1160 raise ValueError(
1161 "Found a sample_weight array with shape "
1162 + str(sample_weight.shape)
1163 + " for an input with shape "
1164 + str(y.shape)
1165 + ". sample_weight cannot be broadcast."
1166 )
1168 # Class weights applied per-sample.
1169 class_sample_weight = None
1170 if isinstance(class_weight, dict):
1171 if len(y.shape) > 2:
1172 raise ValueError(
1173 "`class_weight` not supported for 3+ dimensional targets."
1174 )
1176 if tf.is_tensor(y):
1177 # Few classes are expected, so densifying is reasonable.
1178 keys = np.array(sorted(class_weight.keys()))
1179 values = np.array([class_weight[i] for i in keys])
1180 weight_vector = np.zeros(np.max(keys) + 1)
1181 weight_vector[:] = np.nan
1182 weight_vector[keys] = values
1184 y_classes = tf.__internal__.smart_cond.smart_cond(
1185 len(y.shape.as_list()) == 2 and backend.shape(y)[1] > 1,
1186 lambda: backend.argmax(y, axis=1),
1187 lambda: tf.cast(backend.reshape(y, (-1,)), tf.int64),
1188 )
1189 class_sample_weight = tf.compat.v1.gather(weight_vector, y_classes)
1190 tf.debugging.check_numerics(
1191 class_sample_weight,
1192 "Invalid classes or class weights detected. NaN values "
1193 "indicate that an appropriate class weight could not be "
1194 "determined.",
1195 )
1196 class_sample_weight = tf.cast(class_sample_weight, backend.floatx())
1197 if sample_weight is not None:
1198 sample_weight = tf.cast(
1199 tf.convert_to_tensor(sample_weight), backend.floatx()
1200 )
1201 else:
1202 y_classes = y
1203 if len(y.shape) == 2:
1204 if y.shape[1] > 1:
1205 y_classes = np.argmax(y, axis=1)
1206 elif y.shape[1] == 1:
1207 y_classes = np.reshape(y, y.shape[0])
1209 class_sample_weight = np.asarray(
1210 [class_weight[cls] for cls in y_classes if cls in class_weight]
1211 )
1213 if len(class_sample_weight) != len(y_classes):
1214 # subtract the sets to pick all missing classes
1215 existing_classes = set(y_classes)
1216 existing_class_weight = set(class_weight.keys())
1217 raise ValueError(
1218 "`class_weight` must contain all classes in the data."
1219 " The classes %s exist in the data but not in "
1220 "`class_weight`."
1221 % (existing_classes - existing_class_weight)
1222 )
1224 if class_sample_weight is not None and sample_weight is not None:
1225 # Multiply weights if both are provided.
1226 return class_sample_weight * sample_weight
1227 if sample_weight is not None:
1228 return sample_weight
1229 if class_sample_weight is not None:
1230 return class_sample_weight
1231 return None
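# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): class weights are looked up per sample and
# multiplied into the user-provided sample weights.
def _standardize_weights_sketch():
    y = np.array([0, 1, 1, 0])
    sample_weight = np.array([1.0, 1.0, 2.0, 2.0])
    class_weight = {0: 0.5, 1: 2.0}
    # Result: array([0.5, 2.0, 4.0, 1.0]) -- per-class weight times
    # per-sample weight.
    return standardize_weights(
        y, sample_weight=sample_weight, class_weight=class_weight
    )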
1234def has_symbolic_tensors(ls):
1235 if tf.executing_eagerly():
1236 return False
1237 return has_tensors(ls)
1240def has_tensors(ls):
1241 """Returns true if `ls` contains tensors."""
1242 # Note: at some point in time ragged tensors didn't count as tensors, so
1243 # this returned false for ragged tensors. Making this return true fails some
1244 # tests which would then require a steps_per_epoch argument.
1245 if isinstance(ls, (list, tuple)):
1246 return any(
1247 tf.is_tensor(v) and not isinstance(v, tf.RaggedTensor) for v in ls
1248 )
1249 if isinstance(ls, dict):
1250 return any(
1251 tf.is_tensor(v) and not isinstance(v, tf.RaggedTensor)
1252 for _, v in ls.items()
1253 )
1254 return tf.is_tensor(ls) and not isinstance(ls, tf.RaggedTensor)
1257def get_metric_name(metric, weighted=False):
1258 """Returns the name corresponding to the given metric input.
1260 Args:
1261 metric: Metric function name or reference.
1262 weighted: Boolean indicating if the given metric is weighted.
1264 Returns:
1265 The metric name.
1266 """
1267 if tf.__internal__.tf2.enabled():
1268 # We keep the string that the user has set in compile as the metric
1269 # name.
1270 if isinstance(metric, str):
1271 return metric
1273 metric = metrics_module.get(metric)
1274 return metric.name if hasattr(metric, "name") else metric.__name__
1275 else:
1276 metric_name_prefix = "weighted_" if weighted else ""
1277 if metric in ("accuracy", "acc", "crossentropy", "ce"):
1278 if metric in ("accuracy", "acc"):
1279 suffix = "acc"
1280 elif metric in ("crossentropy", "ce"):
1281 suffix = "ce"
1282 else:
1283 metric_fn = metrics_module.get(metric)
1284 # Get metric name as string
1285 if hasattr(metric_fn, "name"):
1286 suffix = metric_fn.name
1287 else:
1288 suffix = metric_fn.__name__
1289 metric_name = metric_name_prefix + suffix
1290 return metric_name
1293def get_metric_function(metric, output_shape=None, loss_fn=None):
1294 """Returns the metric function corresponding to the given metric input.
1296 Args:
1297 metric: Metric function name or reference.
1298 output_shape: The shape of the output that this metric will be
1299 calculated for.
1300 loss_fn: The loss function used.
1302 Returns:
1303 The metric function.
1304 """
1305 if metric not in ["accuracy", "acc", "crossentropy", "ce"]:
1306 return metrics_module.get(metric)
1308 is_sparse_categorical_crossentropy = isinstance(
1309 loss_fn, losses.SparseCategoricalCrossentropy
1310 ) or (
1311 isinstance(loss_fn, losses.LossFunctionWrapper)
1312 and loss_fn.fn == losses.sparse_categorical_crossentropy
1313 )
1315 is_binary_crossentropy = isinstance(loss_fn, losses.BinaryCrossentropy) or (
1316 isinstance(loss_fn, losses.LossFunctionWrapper)
1317 and loss_fn.fn == losses.binary_crossentropy
1318 )
1320 if metric in ["accuracy", "acc"]:
1321 if output_shape[-1] == 1 or is_binary_crossentropy:
1322 return metrics_module.binary_accuracy
1323 elif is_sparse_categorical_crossentropy:
1324 return metrics_module.sparse_categorical_accuracy
1325 # If the output_shape[-1] is not 1, then we know output is
1326 # `categorical`. We assume it is sparse categorical only if loss is
1327 # explicitly given as sparse categorical crossentropy loss.
1328 return metrics_module.categorical_accuracy
1329 else:
1330 if output_shape[-1] == 1 or is_binary_crossentropy:
1331 return metrics_module.binary_crossentropy
1332 elif is_sparse_categorical_crossentropy:
1333 return metrics_module.sparse_categorical_crossentropy
1334 return metrics_module.categorical_crossentropy
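# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): the generic "acc"/"accuracy" string resolves to a
# concrete accuracy variant based on the output shape and the loss in use.
def _get_metric_function_sketch():
    binary = get_metric_function("acc", output_shape=(None, 1))
    sparse = get_metric_function(
        "acc",
        output_shape=(None, 10),
        loss_fn=losses.SparseCategoricalCrossentropy(),
    )
    categorical = get_metric_function("acc", output_shape=(None, 10))
    # binary_accuracy, sparse_categorical_accuracy, categorical_accuracy
    return binary, sparse, categorical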
1337def call_metric_function(
1338 metric_fn, y_true, y_pred=None, weights=None, mask=None
1339):
1340 """Invokes metric function and returns the metric result tensor."""
1341 if mask is not None:
1342 mask = tf.cast(mask, y_pred.dtype)
1343 if weights is None:
1344 # Use mask as sample weight.
1345 weights = mask
1346 else:
1347 # Update dimensions of weights to match with mask.
1348 weights = tf.cast(weights, dtype=y_pred.dtype)
1349 mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
1350 mask, sample_weight=weights
1351 )
1352 weights *= mask
1354 if y_pred is not None:
1355 return metric_fn(y_true, y_pred, sample_weight=weights)
1356 # `Mean` metric only takes a single value.
1357 return metric_fn(y_true, sample_weight=weights)
1360def get_loss_function(loss):
1361 """Returns the loss corresponding to the loss input in `compile` API."""
1362 if loss is None or isinstance(loss, losses.Loss):
1363 return loss
1365 if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
1366 # It is not safe to assume that the loss takes no constructor arguments.
1367 raise ValueError(
1368 "Received uninstantiated Loss class: {}\n"
1369 "Please call loss classes "
1370 "before passing them to Model.compile.".format(loss)
1371 )
1373 # Deserialize loss configuration, if needed.
1374 if isinstance(loss, collections.abc.Mapping):
1375 loss = losses.get(loss)
1377 # Custom callable class.
1378 if callable(loss) and not hasattr(loss, "__name__"):
1379 return loss
1381 # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
1382 # in `LossFunctionWrapper` class.
1383 loss_fn = losses.get(loss)
1385 # For losses which are given as strings/functions in the compile API,
1386 # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
1387 # (both in distribution strategy context and otherwise).
1388 return losses.LossFunctionWrapper(
1389 loss_fn,
1390 name=loss_fn.__name__,
1391 reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
1392 )
1395def validate_dataset_input(x, y, sample_weight, validation_split=None):
1396 """Validates user input arguments when a dataset iterator is passed.
1398 Args:
1399 x: Input data. A `tf.data` dataset or iterator.
1400 y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
1401 Expected to be `None` when `x` is a dataset iterator.
1402 sample_weight: An optional sample-weight array passed by the user to
1403 weight the importance of each sample in `x`. Expected to be `None` when
1404 `x` is a dataset iterator
1405 validation_split: Float between 0 and 1. Fraction of the training data to
1406 be used as validation data. Expected to be `None` when `x` is a dataset
1407 iterator.
1409 Raises:
1410 ValueError: if argument `y` or `sample_weight` or `validation_split` are
1411 provided by user.
1412 """
1413 if y is not None:
1414 raise ValueError(
1415 "You passed a dataset or dataset iterator (%s) as "
1416 "input `x` to your model. In that case, you should "
1417 "not specify a target (`y`) argument, since the dataset "
1418 "or dataset iterator generates both input data and "
1419 "target data. "
1420 "Received: %s" % (x, y)
1421 )
1422 if sample_weight is not None:
1423 raise ValueError(
1424 "`sample_weight` argument is not supported when input "
1425 "`x` is a dataset or a dataset iterator. Instead, you"
1426 "can provide sample_weight as the third element of your"
1427 "dataset, i.e. (inputs, targets, sample_weight). "
1428 "Received: x=%s, sample_weight=%s" % (x, sample_weight)
1429 )
1430 if validation_split is not None and validation_split != 0.0:
1431 raise ValueError(
1432 "`validation_split` argument is not supported when "
1433 "input `x` is a dataset or a dataset iterator. "
1434 "Received: x=%s, validation_split=%f" % (x, validation_split)
1435 )
1438def validate_input_types(inp, orig_inp, allow_dict=True, field_name="inputs"):
1439 """Helper function to validate either inputs or targets."""
1440 if isinstance(inp, (list, tuple)):
1441 if not all(isinstance(v, np.ndarray) or tf.is_tensor(v) for v in inp):
1442 raise ValueError(
1443 "Please provide as model inputs either a single array or a "
1444 f"list of arrays. You passed: {field_name}={str(orig_inp)}"
1445 )
1446 elif isinstance(inp, dict):
1447 if not allow_dict:
1448 raise ValueError(
1449 f"You cannot pass a dictionary as model {field_name}."
1450 )
1451 elif not isinstance(inp, np.ndarray) and not tf.is_tensor(inp):
1452 raise ValueError(
1453 "Please provide as model inputs either a single array or a list of "
1454 "arrays. You passed: {}={}".format(field_name, orig_inp)
1455 )
1458def check_generator_arguments(
1459 y=None, sample_weight=None, validation_split=None
1460):
1461 """Validates arguments passed when using a generator."""
1462 if y is not None:
1463 raise ValueError(
1464 "`y` argument is not supported when data is"
1465 "a generator or Sequence instance. Instead pass targets"
1466 " as the second element of the generator."
1467 )
1468 if sample_weight is not None:
1469 raise ValueError(
1470 "`sample_weight` argument is not supported when data is"
1471 "a generator or Sequence instance. Instead pass sample"
1472 " weights as the third element of the generator."
1473 )
1474 if validation_split:
1475 raise ValueError(
1476 "If your data is in the form of a Python generator, "
1477 "you cannot use `validation_split`."
1478 )
1481def check_steps_argument(input_data, steps, steps_name):
1482 """Validates `steps` argument based on input data's type.
1484 The cases when `steps` value must be provided are when
1485 1. input data passed is an iterator.
1486 2. model was built on top of symbolic tensors, input data is not
1487 required and is `None`.
1488 3. input data passed is a symbolic tensor.
1490 Args:
1491 input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
1492 tf.data.Dataset iterator or `None`.
1493 steps: Integer or `None`. Total number of steps (batches of samples) to
1494 execute.
1495 steps_name: The public API's parameter name for `steps`.
1497 Returns:
1498 boolean, True if `steps` argument is required, else False.
1500 Raises:
1501 ValueError: if `steps` argument is required for given input data type
1502 but not provided.
1503 """
1504 is_x_iterator = isinstance(
1505 input_data, (tf.compat.v1.data.Iterator, tf.data.Iterator)
1506 )
1507 if (
1508 input_data is None
1509 or is_x_iterator
1510 or has_symbolic_tensors(input_data)
1511 or (isinstance(input_data, list) and not input_data)
1512 ):
1513 if steps is None:
1514 input_type_str = (
1515 "a Dataset iterator" if is_x_iterator else "data tensors"
1516 )
1517 raise ValueError(
1518 "When using {input_type} as input to a model, you should"
1519 " specify the `{steps_name}` argument.".format(
1520 input_type=input_type_str, steps_name=steps_name
1521 )
1522 )
1523 return True
1525 if isinstance(input_data, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
1526 return True
1528 if steps is not None:
1529 list_types = (np.ndarray, list, tuple)
1530 if isinstance(input_data, list_types) or (
1531 isinstance(input_data, dict)
1532 and any(isinstance(v, list_types) for v in input_data.values())
1533 ):
1534 logging.warning(
1535 "When passing input data as arrays, do not specify "
1536 "`steps_per_epoch`/`steps` argument. "
1537 "Please use `batch_size` instead."
1538 )
1539 return False
1542def cast_single_tensor(x, dtype=None):
1543 if isinstance(x, np.ndarray):
1544 x = tf.convert_to_tensor(x)
1545 dtype = dtype or backend.floatx()
1546 if x.dtype.is_floating:
1547 return tf.cast(x, dtype=dtype)
1548 return x
1551def cast_if_floating_dtype_and_mismatch(targets, outputs):
1552 """Returns target data tensors using correct datatype.
1554 Checks that each target and output pair are the same datatype. If not, casts
1555 the target to the output's datatype.
1557 Args:
1558 targets: tensor or list of targets.
1559 outputs: tensor or list of outputs.
1561 Returns:
1562 Targets in appropriate datatype.
1563 """
1564 if tf.is_tensor(targets):
1565 # There is one target, so output[0] should be the only output.
1566 return cast_single_tensor(targets, dtype=outputs[0].dtype)
1567 new_targets = []
1568 for target, out in zip(targets, outputs):
1569 if isinstance(target, np.ndarray):
1570 target = tf.convert_to_tensor(target)
1571 if target.dtype != out.dtype:
1572 new_targets.append(cast_single_tensor(target, dtype=out.dtype))
1573 else:
1574 new_targets.append(target)
1575 return new_targets
1578def cast_if_floating_dtype(x, dtype=None):
1579 """Casts the given data tensors to the default floating point type.
1581 Casts only if the input is already a floating point type.
1582 Args:
1583 x: tensor or list/tuple of tensors.
1584 dtype: The dtype to which Tensors should be cast.
1586 Returns:
1587 Converted input.
1588 """
1589 return tf.nest.map_structure(
1590 functools.partial(cast_single_tensor, dtype=dtype), x
1591 )
1594def cast_to_model_input_dtypes(x, model):
1595 """Casts the given data tensors to the dtypes of the model inputs.
1597 Args:
1598 x: tensor or list/tuple of tensors.
1599 model: The model.
1601 Returns:
1602 Converted input. Each tensor is casted to the corresponding input in
1603 `model.inputs`.
1604 """
1605 input_dtypes = tf.nest.map_structure(lambda t: t.dtype, model.inputs)
1606 return tf.nest.map_structure(tf.cast, x, input_dtypes)
1609def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
1610 """Prepares sample weight modes for the model.
1612 Args:
1613 training_endpoints: List of model _TrainingEndpoints.
1614 sample_weight_mode: sample weight mode user input passed from compile API.
1616 Raises:
1617 ValueError: In case of invalid `sample_weight_mode` input.
1618 """
1620 if isinstance(sample_weight_mode, collections.abc.Mapping):
1621 generic_utils.check_for_unexpected_keys(
1622 "sample_weight_mode",
1623 sample_weight_mode,
1624 [e.output_name for e in training_endpoints],
1625 )
1627 for end_point in training_endpoints:
1628 if not end_point.should_skip_target_weights():
1629 if end_point.output_name not in sample_weight_mode:
1630 raise ValueError(
1631 "Output "
1632 + end_point.output_name
1633 + "missing from `_sample_weight_modes` dictionary"
1634 )
1635 else:
1636 end_point.sample_weight_mode = sample_weight_mode.get(
1637 end_point.output_name
1638 )
1639 elif isinstance(sample_weight_mode, (list, tuple)):
1640 if len(sample_weight_mode) != len(training_endpoints):
1641 raise ValueError(
1642 "When passing a list as sample_weight_mode, "
1643 "it should have one entry per model output. "
1644 "The model has "
1645 + str(len(training_endpoints))
1646 + " outputs, but you passed "
1647 + str(len(sample_weight_mode))
1648 + "_sample_weight_modes."
1649 )
1650 for mode, endpoint in zip(sample_weight_mode, training_endpoints):
1651 if not endpoint.should_skip_target_weights():
1652 endpoint.sample_weight_mode = mode
1653 else:
1654 for endpoint in training_endpoints:
1655 if not endpoint.should_skip_target_weights():
1656 endpoint.sample_weight_mode = sample_weight_mode
1659def prepare_loss_functions(loss, output_names):
1660 """Converts loss to a list of loss functions.
1662 Args:
1663 loss: String (name of objective function), objective function or
1664 `tf.keras.losses.Loss` instance. See `tf.keras.losses`.
1665 If the model has multiple
1666 outputs, you can use a different loss on each output by passing a
1667 dictionary or a list of losses. The loss value that will be minimized
1668 by the model will then be the sum of all individual losses.
1669 output_names: List of model output names.
1671 Returns:
1672 A list of loss objective functions.
1674 Raises:
1675 ValueError: If loss is a dict with keys not in model output names,
1676 or if loss is a list with len not equal to model outputs.
1677 """
1678 if isinstance(loss, collections.abc.Mapping):
1679 generic_utils.check_for_unexpected_keys("loss", loss, output_names)
1680 loss_functions = []
1681 for name in output_names:
1682 if name not in loss:
1683 logging.warning(
1684 "Output {0} missing from loss dictionary. We assume "
1685 "this was done on purpose. The fit and evaluate APIs will "
1686 f"not be expecting any data to be passed to {name}."
1687 )
1688 loss_functions.append(get_loss_function(loss.get(name, None)))
1689 elif isinstance(loss, str):
1690 loss_functions = [get_loss_function(loss) for _ in output_names]
1691 elif isinstance(loss, collections.abc.Sequence):
1692 if len(loss) != len(output_names):
1693 raise ValueError(
1694 "When passing a list as loss, it should have one entry "
1695 "per model outputs. The model has {} outputs, but you "
1696 "passed loss={}".format(len(output_names), loss)
1697 )
1698 loss_functions = tf.nest.map_structure(get_loss_function, loss)
1699 else:
1700 loss_functions = [
1701 get_loss_function(loss) for _ in range(len(output_names))
1702 ]
1704 return loss_functions
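# Illustrative sketch (not part of the original module): accepted `loss` forms
# for a hypothetical model with output names ["main", "aux"]:
#
#     prepare_loss_functions("mse", ["main", "aux"])            # same loss for both
#     prepare_loss_functions(["mse", "mae"], ["main", "aux"])   # one loss per output
#     prepare_loss_functions({"main": "mse"}, ["main", "aux"])
#     # the dict form logs a warning for the missing "aux" entry and leaves
#     # its loss function unset.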
1707def prepare_loss_weights(training_endpoints, loss_weights=None):
1708 """Converts loss weights to a list of loss weights.
1710 The resulting loss weights are populated on the training endpoints.
1712 Args:
1713 training_endpoints: List of model training endpoints.
1714 loss_weights: Optional list or dictionary specifying scalar coefficients
1715 (Python floats) to weight the loss contributions of different model
1716 outputs. The loss value that will be minimized by the model will then
1717 be the *weighted sum* of all individual losses, weighted by the
1718 `loss_weights` coefficients. If a list, it is expected to have a 1:1
1719 mapping to the model's outputs. If a dict, it is expected to map
1720 output names (strings) to scalar coefficients.
1722 Raises:
1723 ValueError: If `loss_weights` is a dict with a key not in model output
1724 names, or a list whose length does not match the number of model outputs.
1725 """
1726 if loss_weights is None:
1727 for e in training_endpoints:
1728 e.loss_weight = 1.0
1729 elif isinstance(loss_weights, collections.abc.Mapping):
1730 generic_utils.check_for_unexpected_keys(
1731 "loss_weights",
1732 loss_weights,
1733 [e.output_name for e in training_endpoints],
1734 )
1735 for e in training_endpoints:
1736 e.loss_weight = loss_weights.get(e.output_name, 1.0)
1737 elif isinstance(loss_weights, list):
1738 if len(loss_weights) != len(training_endpoints):
1739 raise ValueError(
1740 "When passing a list as loss_weights, "
1741 "it should have one entry per model output. "
1742 "The model has "
1743 + str(len(training_endpoints))
1744 + " outputs, but you passed loss_weights="
1745 + str(loss_weights)
1746 )
1747 for w, e in zip(loss_weights, training_endpoints):
1748 e.loss_weight = w
1749 else:
1750 raise TypeError(
1751 "Could not interpret loss_weights argument: "
1752 + str(loss_weights)
1753 + " - expected a list of dicts."
1754 )
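# Illustrative sketch (not part of the original module): for a hypothetical
# two-output model with outputs named "output_1" and "output_2", either of
# the following assigns weight 1.0 to the first output and 0.2 to the second;
# omitting `loss_weights` gives every output a weight of 1.0:
#
#     loss_weights = [1.0, 0.2]
#     loss_weights = {"output_1": 1.0, "output_2": 0.2}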
1757# TODO(rohanj): This is a hack to get around not depending on feature_column and
1758# create a cyclical dependency. Figure out a cleaner solution
1759def is_feature_layer(layer):
1760 """Returns whether `layer` is a FeatureLayer or not."""
1761 return getattr(layer, "_is_feature_layer", False)
1764def is_eager_dataset_or_iterator(data):
1765 return tf.executing_eagerly() and isinstance(
1766 data, (tf.compat.v1.data.Dataset, tf.data.Dataset, tf.data.Iterator)
1767 )
1770def get_dataset_graph_def(dataset):
1771 if tf.executing_eagerly():
1772 graph_def_str = dataset._as_serialized_graph().numpy()
1773 else:
1774 graph_def_str = backend.get_value(dataset._as_serialized_graph())
1775 return tf.compat.v1.GraphDef().FromString(graph_def_str)
1778def verify_dataset_shuffled(x):
1779 """Verifies that the dataset is shuffled.
1781 Args:
1782 x: Dataset passed as an input to the model.
1784 Returns:
1785 boolean, whether the input dataset is shuffled or not.
1786 """
1787 assert isinstance(x, tf.data.Dataset)
1788 graph_def = get_dataset_graph_def(x)
1789 for node in graph_def.node:
1790 if node.op.startswith("ShuffleDataset"):
1791 return True
1792 # Also check graph_def.library.function for ds.interleave or ds.flat_map
1793 for function in graph_def.library.function:
1794 for node in function.node_def:
1795 if node.op.startswith("ShuffleDataset"):
1796 return True
1797 logging.warning(
1798 "Expected a shuffled dataset but input dataset `x` is "
1799 "not shuffled. Please invoke `shuffle()` on input dataset."
1800 )
1801 return False
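# Illustrative sketch (not part of the original module): the check inspects
# the dataset's serialized graph for `ShuffleDataset*` ops, so it detects a
# shuffle applied anywhere in the input pipeline:
#
#     ds = tf.data.Dataset.range(10).shuffle(10).batch(2)
#     verify_dataset_shuffled(ds)                                   # True
#     verify_dataset_shuffled(tf.data.Dataset.range(10).batch(2))   # False, logs a warning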
1804def is_dataset_or_iterator(data):
1805 return isinstance(
1806 data,
1807 (
1808 tf.compat.v1.data.Dataset,
1809 tf.data.Dataset,
1810 tf.compat.v1.data.Iterator,
1811 tf.data.Iterator,
1812 ),
1813 )
1816def get_iterator(dataset):
1817 """Create and initialize an iterator from a dataset."""
1818 if tf.executing_eagerly():
1819 iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
1820 else:
1821 iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
1822 initialize_iterator(iterator)
1823 return iterator
1826def initialize_iterator(iterator):
1827 if not tf.executing_eagerly():
1828 init_op = iterator.initializer
1829 backend.get_session((init_op,)).run(init_op)
1832def extract_tensors_from_dataset(dataset):
1833 """Extract tuple of tensors `inputs, targets, sample_weight` from a dataset.
1835 Args:
1836 dataset: Dataset instance.
1838 Returns:
1839 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
1840 """
1841 iterator = get_iterator(dataset)
1842 inputs, targets, sample_weight = unpack_iterator_input(iterator)
1843 return inputs, targets, sample_weight
1846def unpack_iterator_input(iterator):
1847 """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.
1849 Args:
1850 iterator: Instance of a dataset iterator.
1852 Returns:
1853 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
1854 """
1855 try:
1856 next_element = iterator.get_next()
1857 except tf.errors.OutOfRangeError:
1858 raise RuntimeError(
1859 "Your dataset iterator ran out of data; "
1860 "Make sure that your dataset can generate "
1861 "required number of samples."
1862 )
1864 if isinstance(next_element, (list, tuple)):
1865 if len(next_element) not in [2, 3]:
1866 raise ValueError(
1867 "Please provide model inputs as a list or tuple of 2 or 3 "
1868 "elements: (input, target) or (input, target, sample_weights) "
1869 "Received %s" % next_element
1870 )
1871 if len(next_element) == 2:
1872 x, y = next_element
1873 weights = None
1874 else:
1875 x, y, weights = next_element
1876 else:
1877 x = next_element
1878 y = None
1879 weights = None
1880 return x, y, weights
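# Illustrative sketch (not part of the original module): a dataset yielding
# (inputs, targets) pairs unpacks with `weights` left as None, while a
# three-element structure also yields the sample weights:
#
#     ds = tf.data.Dataset.from_tensor_slices(
#         (tf.zeros([4, 3]), tf.ones([4, 1]))
#     ).batch(2)
#     x, y, weights = extract_tensors_from_dataset(ds)   # weights is None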
1883def infer_steps_for_dataset(
1884 model, dataset, steps, epochs=1, steps_name="steps"
1885):
1886 """Infers steps_per_epoch needed to loop through a dataset.
1888 Args:
1889 model: Keras model instance.
1890 dataset: Input data of type tf.data.Dataset.
1891 steps: Number of steps to draw from the dataset (may be None if
1892 unknown).
1893 epochs: Number of times to iterate over the dataset.
1894 steps_name: The string name of the steps argument, either `steps`,
1895 `validation_steps`, or `steps_per_epoch`. Only used for error message
1896 formatting.
1898 Returns:
1899 Integer or `None`. Inferred number of steps to loop through the dataset.
1900 `None` is returned if 1) the size of the dataset is unknown and `steps`
1901 was not specified, or 2) this is multi-worker training and auto sharding
1902 is enabled.
1904 Raises:
1905 ValueError: In case of invalid argument values.
1906 """
1907 assert isinstance(dataset, tf.data.Dataset)
1908 if model._in_multi_worker_mode() and (
1909 dataset.options().experimental_distribute.auto_shard_policy
1910 != tf.data.experimental.AutoShardPolicy.OFF
1911 ):
1912 # If the dataset would be auto-sharded, we should not infer a local
1913 # steps_per_epoch due to the possible imbalanced sharding between
1914 # workers.
1915 return None
1917 size = backend.get_value(tf.data.experimental.cardinality(dataset))
1918 if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None:
1919 raise ValueError(
1920 "When passing an infinitely repeating dataset, you "
1921 "must specify the `%s` argument." % (steps_name,)
1922 )
1923 if size >= 0:
1924 if steps is not None and steps * epochs > size:
1925 if epochs > 1:
1926 raise ValueError(
1927 "The dataset you passed contains %s batches, but you "
1928 "passed `epochs=%s` and `%s=%s`, which is a total of "
1929 "%s steps. We cannot draw that many steps from this "
1930 "dataset. We suggest to set `%s=%s`."
1931 % (
1932 size,
1933 epochs,
1934 steps_name,
1935 steps,
1936 steps * epochs,
1937 steps_name,
1938 size // epochs,
1939 )
1940 )
1941 else:
1942 raise ValueError(
1943 "The dataset you passed contains %s batches, but you "
1944 "passed `%s=%s`. We cannot draw that many steps from "
1945 "this dataset. We suggest to set `%s=%s`."
1946 % (size, steps_name, steps, steps_name, size)
1947 )
1948 if steps is None:
1949 if size >= 0:
1950 return size
1951 return None
1952 return steps
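# Illustrative sketch (not part of the original module), assuming `model` is a
# compiled Keras model that is not running in multi-worker mode:
#
#     ds = tf.data.Dataset.range(100).batch(10)   # cardinality of 10 batches
#     infer_steps_for_dataset(model, ds, steps=None)           # -> 10 (full dataset)
#     infer_steps_for_dataset(model, ds.repeat(), steps=5)     # -> 5
#     # an infinitely repeating dataset with steps=None raises a ValueError.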
1955class ModelInputs:
1956 """Encapsulates model inputs.
1958 Allows for transforming model inputs while keeping the same structure.
1959 """
1961 def __init__(self, inputs):
1962 self._inputs = inputs
1963 self._is_dict = isinstance(self._inputs, dict)
1964 self._is_single_input = not isinstance(
1965 self._inputs, (list, tuple, dict)
1966 )
1968 self._flattened_inputs = []
1969 self._input_names = []
1971 if self._is_dict:
1972 for k in sorted(self._inputs.keys()):
1973 self._flattened_inputs.append(self._inputs[k])
1974 self._input_names.append(k)
1975 else:
1976 self._flattened_inputs = tf.nest.flatten(self._inputs)
1977 self._input_names = [
1978 "input_%d" % (i + 1) for i in range(len(self._flattened_inputs))
1979 ]
1981 def get_input_names(self):
1982 """Returns keys to name inputs by.
1984 If the inputs were provided as a list, tuple, or single entry, keys of the
1985 form 'input_%d' are generated. For a dictionary, a sorted list of its keys is returned.
1986 """
1987 return self._input_names
1989 def get_symbolic_inputs(self, return_single_as_list=False):
1990 """Returns inputs to be set as self.inputs for a model."""
1991 # TODO(karmel): There is a side-effect here where what you get
1992 # with as_list and as_dict depends on whether you have called this
1993 # method first, since it modifies in place.
1994 for i, (k, v) in enumerate(
1995 zip(self._input_names, self._flattened_inputs)
1996 ):
1997 if isinstance(v, (list, float, int)):
1998 v = np.asarray(v)
1999 if v.ndim == 1:
2000 v = np.expand_dims(v, 1)
2002 if isinstance(v, np.ndarray):
2003 # We fix the placeholder shape except the batch size.
2004 # This is suboptimal, but it is the best we can do with the info
2005 # we have. The user should call
2006 # `model._set_inputs(placeholders)` to specify custom
2007 # placeholders if the need arises.
2008 shape = (None,) + tuple(v.shape[1:])
2009 if shape == (None,):
2010 shape = (None, 1)
2011 dtype = tf.as_dtype(v.dtype)
2012 if dtype.is_floating:
2013 dtype = backend.floatx()
2014 v = backend.placeholder(shape=shape, name=k, dtype=dtype)
2015 elif isinstance(v, tf.TensorSpec):
2016 shape = (None,) + tuple(v.shape.as_list()[1:])
2017 if shape == (None,):
2018 shape = (None, 1)
2019 v = backend.placeholder(shape=shape, name=k, dtype=v.dtype)
2021 self._flattened_inputs[i] = v
2023 if self._is_dict:
2024 return dict(zip(self._input_names, self._flattened_inputs))
2025 if self._is_single_input and not return_single_as_list:
2026 return self._flattened_inputs[0]
2027 return self._flattened_inputs
2029 def as_dict(self):
2030 """An iterable over a dictionary version of inputs."""
2031 for k, v in zip(self._input_names, self._flattened_inputs):
2032 yield k, v
2034 def as_list(self):
2035 """Returning the inputs as a list."""
2036 return self._flattened_inputs
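# Illustrative sketch (not part of the original module): wrapping a NumPy
# array yields a generated input name and a symbolic placeholder whose batch
# dimension is left unspecified:
#
#     mi = ModelInputs(np.zeros((8, 3), dtype="float32"))
#     mi.get_input_names()             # ['input_1']
#     sym = mi.get_symbolic_inputs()   # placeholder with shape (None, 3)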
2039# Allow use of methods not exposed to the user.
2042def generic_output_names(outputs_list):
2043 return ["output_%d" % (i + 1) for i in range(len(outputs_list))]
2046def should_run_validation(validation_freq, epoch):
2047 """Checks if validation should be run this epoch.
2049 Args:
2050 validation_freq: Integer or list. If an integer, specifies how many
2051 training epochs to run before a new validation run is performed. If a
2052 list, specifies the epochs on which to run validation.
2053 epoch: Integer, the number of the training epoch just completed.
2055 Returns:
2056 Bool, True if validation should be run.
2058 Raises:
2059 ValueError: if `validation_freq` is an integer less than 1, or if
2060 it is neither an integer nor a container.
2061 """
2062 # `epoch` is 0-indexed internally but 1-indexed in the public API.
2063 one_indexed_epoch = epoch + 1
2065 if isinstance(validation_freq, int):
2066 if validation_freq < 1:
2067 raise ValueError("`validation_freq` can not be less than 1.")
2068 return one_indexed_epoch % validation_freq == 0
2070 if not isinstance(validation_freq, collections.abc.Container):
2071 raise ValueError(
2072 "`validation_freq` must be an Integer or "
2073 "`collections.abc.Container` (e.g. list, tuple, etc.)"
2074 )
2075 return one_indexed_epoch in validation_freq
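# Illustrative sketch (not part of the original module): `epoch` below is the
# 0-indexed epoch that has just completed:
#
#     should_run_validation(2, epoch=0)        # False (1st epoch completed)
#     should_run_validation(2, epoch=1)        # True  (every 2nd epoch)
#     should_run_validation([1, 5], epoch=4)   # True  (epoch 5 is listed)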
2078def split_training_and_validation_data(x, y, sample_weights, validation_split):
2079 """Split input data into train/eval section based on validation_split."""
2080 if has_symbolic_tensors(x):
2081 raise ValueError(
2082 "If your data is in the form of symbolic tensors, "
2083 "you cannot use `validation_split`."
2084 )
2085 if hasattr(x[0], "shape"):
2086 split_at = int(x[0].shape[0] * (1.0 - validation_split))
2087 else:
2088 split_at = int(len(x[0]) * (1.0 - validation_split))
2089 x, val_x = (
2090 generic_utils.slice_arrays(x, 0, split_at),
2091 generic_utils.slice_arrays(x, split_at),
2092 )
2093 y, val_y = (
2094 generic_utils.slice_arrays(y, 0, split_at),
2095 generic_utils.slice_arrays(y, split_at),
2096 )
2097 if sample_weights:
2098 sample_weights, val_sample_weights = (
2099 generic_utils.slice_arrays(sample_weights, 0, split_at),
2100 generic_utils.slice_arrays(sample_weights, split_at),
2101 )
2102 else:
2103 val_sample_weights = None
2104 return x, y, sample_weights, val_x, val_y, val_sample_weights
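# Illustrative sketch (not part of the original module): `x`, `y`, and
# `sample_weights` are expected to be lists of arrays; the split point is
# computed from the length of the first array:
#
#     x = [np.arange(10).reshape(10, 1)]
#     y = [np.arange(10)]
#     x, y, sw, val_x, val_y, val_sw = split_training_and_validation_data(
#         x, y, sample_weights=None, validation_split=0.2
#     )
#     # x/y keep the first 8 samples, val_x/val_y the remaining 2.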
2107def unpack_validation_data(validation_data, raise_if_ambiguous=True):
2108 """Unpack validation data based input type.
2110 The validation data is not touched if its dataset or dataset iterator.
2111 For other type of input (Numpy or tensor), it will be unpacked into tuple of
2112 3 which is x, y and sample weights.
2114 Args:
2115 validation_data: A dataset, a dataset iterator, or a tuple of NumPy/tensor data.
2116 raise_if_ambiguous: Boolean, whether to raise if `validation_data` cannot
2117 be parsed. Otherwise simply return `(validation_data, None, None)` and
2118 defer the decision to the caller.
2120 Returns:
2121 Tuple of 3: `(x, y, sample_weights)` for NumPy and tensor input.
2122 """
2123 if isinstance(
2124 validation_data,
2125 (
2126 tf.compat.v1.data.Iterator,
2127 tf.data.Iterator,
2128 tf.data.Dataset,
2129 data_utils.Sequence,
2130 ),
2131 ) or not hasattr(validation_data, "__len__"):
2132 val_x = validation_data
2133 val_y = None
2134 val_sample_weight = None
2135 elif len(validation_data) == 2:
2136 try:
2137 (
2138 val_x,
2139 val_y,
2140 ) = validation_data
2141 val_sample_weight = None
2142 except ValueError:
2143 val_x, val_y, val_sample_weight = validation_data, None, None
2144 elif len(validation_data) == 3:
2145 try:
2146 (
2147 val_x,
2148 val_y,
2149 val_sample_weight,
2150 ) = validation_data
2151 except ValueError:
2152 val_x, val_y, val_sample_weight = validation_data, None, None
2153 else:
2154 if raise_if_ambiguous:
2155 raise ValueError(
2156 "When passing a `validation_data` argument, "
2157 "it must contain either 2 items (x_val, y_val), "
2158 "or 3 items (x_val, y_val, val_sample_weights), "
2159 "or alternatively it could be a dataset or a "
2160 "dataset or a dataset iterator. "
2161 "However we received `validation_data=%s`" % validation_data
2162 )
2163 val_x, val_y, val_sample_weight = validation_data, None, None
2164 return val_x, val_y, val_sample_weight
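# Illustrative sketch (not part of the original module):
#
#     unpack_validation_data((x_val, y_val))           # -> (x_val, y_val, None)
#     unpack_validation_data((x_val, y_val, sw_val))   # -> (x_val, y_val, sw_val)
#     unpack_validation_data(some_dataset)             # -> (some_dataset, None, None)
#
# where `x_val`, `y_val`, `sw_val`, and `some_dataset` are hypothetical
# placeholders for the user's validation arrays or tf.data.Dataset.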
2167class TrainingLoop:
2168 """TrainingLoop is a wrapper class around the training logic.
2170 This class encapsulates the fit/eval/predict logic for different kinds of
2171 data input and model configuration.
2173 Note that TrainingLoop is stateless, which means it doesn't hold any
2174 internal state and can be reused with different models and inputs.
2175 """
2177 def fit(
2178 self,
2179 model,
2180 x=None,
2181 y=None,
2182 batch_size=None,
2183 epochs=1,
2184 verbose=1,
2185 callbacks=None,
2186 validation_split=0.0,
2187 validation_data=None,
2188 shuffle=True,
2189 class_weight=None,
2190 sample_weight=None,
2191 initial_epoch=0,
2192 steps_per_epoch=None,
2193 validation_steps=None,
2194 validation_freq=1,
2195 **kwargs,
2196 ):
2197 """Train the model with the inputs and targets."""
2198 raise NotImplementedError()
2200 def evaluate(
2201 self,
2202 model,
2203 x=None,
2204 y=None,
2205 batch_size=None,
2206 verbose=1,
2207 sample_weight=None,
2208 steps=None,
2209 callbacks=None,
2210 **kwargs,
2211 ):
2212 """Returns the loss value & metrics values for the model in test
2213 mode."""
2214 raise NotImplementedError()
2216 def predict(
2217 self,
2218 model,
2219 x,
2220 batch_size=None,
2221 verbose=0,
2222 steps=None,
2223 callbacks=None,
2224 **kwargs,
2225 ):
2226 raise NotImplementedError()