# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Part of the Keras training engine related to plain array data."""
17import functools
19import numpy as np
20import tensorflow.compat.v2 as tf
22from keras.src import backend
23from keras.src import callbacks as cbks
24from keras.src.distribute import distributed_training_utils_v1
25from keras.src.engine import training_utils_v1
26from keras.src.utils import io_utils
27from keras.src.utils.generic_utils import make_batches
28from keras.src.utils.generic_utils import slice_arrays
29from keras.src.utils.mode_keys import ModeKeys
31# isort: off
32from tensorflow.python.platform import tf_logging as logging
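
# SciPy is an optional dependency; treat sparse-input support as unavailable
# when it cannot be imported.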
try:
    from scipy.sparse import issparse
except ImportError:
    issparse = None


def model_iteration(
    model,
    inputs,
    targets=None,
    sample_weights=None,
    batch_size=None,
    epochs=1,
    verbose=1,
    callbacks=None,
    val_inputs=None,
    val_targets=None,
    val_sample_weights=None,
    shuffle=True,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_freq=1,
    mode=ModeKeys.TRAIN,
    validation_in_fit=False,
    prepared_feed_values_from_dataset=False,
    steps_name="steps",
    **kwargs,
):
64 """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.
66 Args:
67 model: Keras Model instance.
68 inputs: Either a list or dictionary of arrays, or a dataset instance.
69 targets: List/dictionary of input arrays.
70 sample_weights: Optional list of sample weight arrays.
71 batch_size: Integer batch size or None if unknown.
72 epochs: Number of times to iterate over the data
73 verbose: 0, 1, or 2. Verbosity mode.
74 0 = silent, 1 = progress bar, 2 = one line per epoch.
75 Note that the progress bar is not particularly useful when
76 logged to a file, so verbose=2 is recommended when not running
77 interactively (eg, in a production environment).
78 callbacks: List of callbacks to be called during training
79 val_inputs: Either a list or dictionary of arrays, or a dataset
80 instance.
81 val_targets: List/dictionary of target arrays.
82 val_sample_weights: Optional list of sample weight arrays.
83 shuffle: Whether to shuffle the data at the beginning of each epoch
84 concatenation of list the display names of the outputs of `f` and the
85 list of display names of the outputs of `f_val`.
86 initial_epoch: Epoch at which to start training (useful for resuming a
87 previous training run)
88 steps_per_epoch: Total number of steps (batches of samples) before
89 declaring one epoch finished and starting the next epoch. Ignored with
90 the default value of `None`.
91 validation_steps: Number of steps to run validation for (only if doing
92 validation from data tensors). Ignored with the default value of
93 `None`.
94 validation_freq: Only relevant if validation data is provided. Integer
95 or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
96 an integer, specifies how many training epochs to run before a new
97 validation run is performed, e.g. `validation_freq=2` runs validation
98 every 2 epochs. If a Container, specifies the epochs on which to run
99 validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the
100 end of the 1st, 2nd, and 10th epochs.
101 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
102 validation_in_fit: if true, then this method is invoked from within
103 training iteration (for validation). In the case where `val_inputs` is
104 a dataset, this flag indicates that its iterator and feed values are
105 already created so should properly reuse resources.
106 prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
107 tensors returned from `_prepare_feed_values` call on the validation
108 dataset, so do not call it again on `inputs`. Should only be used for
109 inline validation (i.e., only if `validation_in_fit` is also True).
110 steps_name: The string name of the steps argument, either `steps`,
111 `validation_steps`, or `steps_per_epoch`. Only used for error message
112 formatting.
113 **kwargs: Additional arguments for backwards compatibility.
115 Returns:
116 - In TRAIN mode: `History` object.
117 - In TEST mode: Evaluation metrics.
118 - In PREDICT mode: Outputs of the Model called on inputs.
120 Raises:
121 ValueError: in case of invalid arguments.
122 """
    # Backwards compatibility.
    if "steps" in kwargs:
        steps_per_epoch = kwargs.pop("steps")
    if kwargs:
        raise TypeError(f"Unknown arguments: {kwargs}")

    # In case we were passed a dataset, we extract symbolic tensors from it.
    reset_dataset_after_each_epoch = False
    input_iterator = None
    is_dataset = isinstance(
        inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)
    )
    # TODO(fchollet): consider moving `steps_per_epoch` inference to
    # _standardize_user_data and set reset_dataset_after_each_epoch as an
    # attribute on the dataset instance.
    if is_dataset:
        if steps_per_epoch is None:
            reset_dataset_after_each_epoch = True
        steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
            model,
            inputs,
            steps_per_epoch,
            epochs=epochs,
            steps_name=steps_name,
        )
        input_iterator = _get_iterator(inputs, model._distribution_strategy)

    # Enter tf.distribute.Strategy scope.
    if model._distribution_strategy:
        scope = distributed_training_utils_v1.distributed_scope(
            strategy=model._distribution_strategy,
            learning_phase=(1 if mode == ModeKeys.TRAIN else 0),
        )
        scope.__enter__()
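
    # Datasets are always consumed step-by-step; plain array inputs use step
    # counting only when `steps_per_epoch` is given.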
    use_steps = is_dataset or steps_per_epoch is not None
    do_validation = val_inputs is not None

    # Prepare input data.
    inputs = input_iterator or inputs
    if validation_in_fit and prepared_feed_values_from_dataset:
        # When invoking validation in the training loop, avoid creating the
        # iterator and list of feed values for the same validation dataset
        # multiple times (which would essentially call `iterator.get_next()`
        # repeatedly, slowing down execution and eventually leading to OOM
        # errors).
        ins = inputs
    else:
        ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
        # `ins` is a function when a distribute strategy is used in Eager mode.
        # In that case `is_dataset` is True. The code branches that have
        # requirements about the type of `ins` do not trigger in the
        # distributed case.

    if not is_dataset:
        num_samples_or_steps = _get_num_samples_or_steps(
            ins, batch_size, steps_per_epoch
        )
    else:
        num_samples_or_steps = steps_per_epoch

    # Update sample_weight_mode of the model if sample_weights is specified by
    # the user. We need to call this function after we have a handle on the
    # inputs (both numpy arrays and datasets) in order to determine if the user
    # has specified sample_weights.
    _update_sample_weight_mode(model, mode, ins)

    # Get step function and loop type. As part of building the execution
    # function we recompile the metrics based on the updated
    # sample_weight_mode value.
    f = _make_execution_function(model, mode)

    # Prepare validation data. Hold references to the iterator and the input
    # list to properly reinitialize and reuse in multiple validation passes.
    val_iterator = None
    if isinstance(val_inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
        if validation_steps is None:
            # Because we pass an iterator feed instead of a Dataset to the eval
            # model_iteration() call, it will not trigger the dataset-input path
            # that determines the number of steps required. To avoid this issue,
            # set validation_steps here if validation_steps is None.
            validation_steps = training_utils_v1.infer_steps_for_dataset(
                model,
                val_inputs,
                validation_steps,
                epochs=epochs,
                steps_name="validation_steps",
            )
        val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
        val_inputs = _prepare_feed_values(
            model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST
        )
        # Get num steps for printing.
        val_samples_or_steps = validation_steps
    else:
        # Get num samples for printing.
        val_samples_or_steps = (
            val_inputs and tf.nest.flatten(val_inputs)[0].shape[0] or None
        )

    if mode == ModeKeys.TRAIN and verbose:
        _print_train_info(
            num_samples_or_steps, val_samples_or_steps, is_dataset
        )

    # Configure callbacks.
    count_mode = "steps" if use_steps else "samples"
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        samples=num_samples_or_steps,
        count_mode=count_mode,
        verbose=verbose,
        mode=mode,
    )

    # Find beforehand arrays that need sparse-to-dense conversion.
    if issparse is not None and not use_steps:
        indices_for_conversion_to_dense = []
        feed = _get_model_feed(model, mode)
        for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
            if issparse(input_data) and not backend.is_sparse(feed_tensor):
                indices_for_conversion_to_dense.append(i)

    # Select aggregation method.
    if mode == ModeKeys.PREDICT:
        aggregator = training_utils_v1.OutputsAggregator(
            use_steps,
            num_samples=None if steps_per_epoch else num_samples_or_steps,
            steps=steps_per_epoch,
        )
    else:
        aggregator = training_utils_v1.MetricsAggregator(
            use_steps,
            num_samples=None if steps_per_epoch else num_samples_or_steps,
            steps=steps_per_epoch,
        )
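
    # When the model was compiled under a distribution strategy, training runs
    # on a per-replica clone, so sync its weights from the original model
    # before starting.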
    if model._compile_distribution:
        distributed_training_utils_v1._copy_weights_to_distributed_model(
            model, mode
        )
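
    # Reset the stop flag and notify callbacks that the outer loop is starting.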
    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode
    )
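
    # Main loop over epochs; a callback may request early stopping by setting
    # `model.stop_training`.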
    for epoch in range(initial_epoch, epochs):
        if callbacks.model.stop_training:
            break

        # Setup work for each epoch
        epoch_logs = {}
        if mode != ModeKeys.PREDICT:
            # Collecting and resetting metrics has non-zero cost and will
            # needlessly slow down model.predict.
            model.reset_metrics()
        if mode == ModeKeys.TRAIN:
            callbacks.on_epoch_begin(epoch, epoch_logs)

        if use_steps:
            # Step-wise loop.
            if steps_per_epoch is None:
                # Loop over dataset until `OutOfRangeError` is raised.
                target_steps = np.inf
            else:
                # Loop over dataset for the specified number of steps.
                target_steps = steps_per_epoch

            step = 0
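            # Run batches until `target_steps` is reached or the data source
            # raises `OutOfRangeError`.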
            while step < target_steps:
                batch_logs = {"batch": step, "size": 1}
                callbacks._call_batch_hook(mode, "begin", step, batch_logs)

                # Get outputs.
                try:
                    # `ins` can be callable in tf.distribute.Strategy + eager
                    # case.
                    if not callable(ins) or (
                        model._distribution_strategy
                        and not distributed_training_utils_v1.is_distributing_by_cloning(  # noqa: E501
                            model
                        )
                    ):
                        actual_inputs = ins
                    else:
                        actual_inputs = ins()
                    batch_outs = f(actual_inputs)
                except tf.errors.OutOfRangeError:
                    if is_dataset:
                        # The dataset passed by the user ran out of batches.
                        # Now we know the cardinality of the dataset. If
                        # steps_per_epoch was specified, then running out of
                        # data is unexpected, so we stop training and inform the
                        # user.
                        if steps_per_epoch:
                            callbacks.model.stop_training = True
                            logging.warning(
                                "Your dataset ran out of data; interrupting "
                                "training. Make sure that your dataset can "
                                "generate at least `%s * epochs` batches (in "
                                "this case, %d batches). You may need to use "
                                "the repeat() function when building your "
                                "dataset."
                                % (steps_name, steps_per_epoch * epochs)
                            )
                        elif step > 0:
                            steps_per_epoch = step
                            aggregator.steps = steps_per_epoch
                    else:
                        # We ran out of batches while the user passed an
                        # iterator (legacy).
                        callbacks.model.stop_training = True
                        logging.warning(
                            "Your dataset iterator ran out of data; "
                            "interrupting training. Make sure that your "
                            "iterator can generate at least `%s * epochs` "
                            "batches (in this case, %d batches). You may need "
                            "to use the repeat() function when building your "
                            "dataset." % (steps_name, steps_per_epoch * epochs)
                        )
                    break

                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                if model._distribution_strategy:
                    batch_outs = distributed_training_utils_v1._per_replica_aggregate_batch(  # noqa: E501
                        model._distribution_strategy, batch_outs, model, mode
                    )

                # Aggregate results.
                if step == 0:
                    aggregator.create(batch_outs)
                aggregator.aggregate(batch_outs)

                # Callbacks batch end.
                batch_logs = callbacks.make_logs(
                    model, batch_logs, batch_outs, mode
                )
                callbacks._call_batch_hook(mode, "end", step, batch_logs)
                step += 1

            if callbacks.model.stop_training:
                break
        else:
            # Sample-wise loop.
            index_array = np.arange(num_samples_or_steps)
            if shuffle == "batch":
                index_array = training_utils_v1.batch_shuffle(
                    index_array, batch_size
                )
            elif shuffle:
                np.random.shuffle(index_array)
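            # Split the (possibly shuffled) index array into batch-sized
            # (start, end) slices.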
            batches = make_batches(num_samples_or_steps, batch_size)
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                # Slice into a batch.
                if len(batches) == 1:
                    # If we only have one batch, do not slice. This takes care
                    # of composite tensors in non-Dataset modes; we currently
                    # don't support slicing them.
                    # TODO(b/133517906): Add slicing support.
                    ins_batch = ins
                else:
                    try:
                        if ins and isinstance(ins[-1], int):
                            # Do not slice the training phase flag.
                            ins_batch = slice_arrays(ins[:-1], batch_ids) + [
                                ins[-1]
                            ]
                        else:
                            ins_batch = slice_arrays(ins, batch_ids)
                    except TypeError:
                        raise TypeError(
                            "TypeError while preparing batch. "
                            "If using HDF5 input data, "
                            'pass shuffle="batch".'
                        )

                # Sparse to dense conversion.
                if issparse is not None:
                    for i in indices_for_conversion_to_dense:
                        ins_batch[i] = ins_batch[i].toarray()

                # Callbacks batch_begin.
                batch_logs = {"batch": batch_index, "size": len(batch_ids)}
                callbacks._call_batch_hook(
                    mode, "begin", batch_index, batch_logs
                )

                # Get outputs.
                batch_outs = f(ins_batch)
                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                # Aggregate results.
                if batch_index == 0:
                    aggregator.create(batch_outs)
                aggregator.aggregate(batch_outs, batch_start, batch_end)

                # Callbacks batch end.
                batch_logs = callbacks.make_logs(
                    model, batch_logs, batch_outs, mode
                )
                callbacks._call_batch_hook(mode, "end", batch_index, batch_logs)

                if callbacks.model.stop_training:
                    break
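
        # Merge the per-batch outputs collected above into epoch-level results
        # (stacked predictions or averaged metrics).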
        aggregator.finalize()
        results = aggregator.results
        epoch_logs = callbacks.make_logs(model, epoch_logs, results, mode)
        if len(results) == 1:
            results = results[0]

        # Run the test loop every `validation_freq` epochs during training.
        if (
            do_validation
            and training_utils_v1.should_run_validation(validation_freq, epoch)
            and not callbacks.model.stop_training
        ):

            if model._compile_distribution:
                # Since we create a new clone from the original model we need
                # to copy the weights back to the original model before we
                # can run validation.
                distributed_training_utils_v1._copy_weights_to_original_model(
                    model, ModeKeys.TRAIN
                )

            val_results = model_iteration(
                model,
                val_inputs,
                targets=val_targets,
                sample_weights=val_sample_weights,
                batch_size=batch_size,
                steps_per_epoch=validation_steps,
                callbacks=callbacks,
                verbose=0,
                mode=ModeKeys.TEST,
                validation_in_fit=True,
                prepared_feed_values_from_dataset=(val_iterator is not None),
                steps_name="validation_steps",
            )
            if not isinstance(val_results, list):
                val_results = [val_results]
            epoch_logs = callbacks.make_logs(
                model, epoch_logs, val_results, mode, prefix="val_"
            )
            if val_iterator and epoch < epochs - 1:
                _reinitialize_iterator(
                    val_iterator, model._distribution_strategy
                )

        if mode == ModeKeys.TRAIN:
            # Epochs only apply to `fit`.
            callbacks.on_epoch_end(epoch, epoch_logs)

        # Reinitialize dataset iterator for the next epoch.
        if reset_dataset_after_each_epoch and epoch < epochs - 1:
            _reinitialize_iterator(input_iterator, model._distribution_strategy)
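
    # Record that the loop finished without error and fire the end-of-loop
    # callback hook.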
    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._distribution_strategy:
        if model._compile_distribution:
            # TODO(priyag, psv): Copy back metrics to the original model as
            # well?
            distributed_training_utils_v1._copy_weights_to_original_model(
                model, mode
            )
        scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results


def _get_model_feed(model, mode):
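    """Returns the list of tensors that must be fed to the model in `mode`."""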
    if mode == ModeKeys.PREDICT:
        feed = model._feed_inputs
    else:
        feed = (
            model._feed_inputs
            + model._feed_targets
            + model._feed_sample_weights
        )
    return feed


def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset):
    increment = "steps" if is_dataset else "samples"
    msg = f"Train on {num_samples_or_steps} {increment}"
    if val_samples_or_steps:
        msg += f", validate on {val_samples_or_steps} {increment}"
    io_utils.print_msg(msg)


def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
    """Returns total number of samples when training in batch mode or steps."""
    if steps_per_epoch:
        return steps_per_epoch
    return training_utils_v1.check_num_samples(
        ins, batch_size, steps_per_epoch, "steps_per_epoch"
    )


def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
    """Prepare feed values to the model execution function.

    Args:
        model: Model to prepare feed values for.
        inputs: List or dict of model inputs.
        targets: Optional list of model targets.
        sample_weights: Optional list of sample weight arrays.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

    Returns:
        Feed values for the model in the given mode.
    """
    if model._distribution_strategy:
        if isinstance(inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
            inputs = distributed_training_utils_v1.get_iterator(
                inputs, model._distribution_strategy
            )

        def get_distributed_inputs():
            return distributed_training_utils_v1._prepare_feed_values(
                model, inputs, targets, sample_weights, mode
            )

        # In the eager case, we want to call the input method per step, so
        # return a lambda from here that can be called. Note that this is
        # applicable only in Distribution Strategy case as it follows the same
        # code path for both eager and graph modes.
        # TODO(priyag,omalleyt): Either we should move the training DS with
        # IteratorBase to use training_generator code path, or figure out how to
        # set a symbolic Iterator out of a Dataset when in eager mode.
        if tf.executing_eagerly():
            return get_distributed_inputs
        else:
            return get_distributed_inputs()

    if isinstance(
        inputs,
        (
            tf.compat.v1.data.Dataset,
            tf.data.Dataset,
            tf.compat.v1.data.Iterator,
        ),
    ):
        inputs, targets, sample_weights = model._standardize_user_data(
            inputs, extract_tensors_from_dataset=True
        )

    inputs = training_utils_v1.ModelInputs(inputs).as_list()
    targets = list(targets or [])
    sample_weights = list(sample_weights or [])
    ins = inputs + targets + sample_weights
    if mode == ModeKeys.TRAIN and not isinstance(
        backend.symbolic_learning_phase(), int
    ):
        ins += [True]  # Add learning phase value.
    return ins


def _get_iterator(inputs, distribution_strategy=None):
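    """Returns an iterator over `inputs`, distribution-aware when a strategy
    is passed."""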
    if distribution_strategy:
        return distributed_training_utils_v1.get_iterator(
            inputs, distribution_strategy
        )
    return training_utils_v1.get_iterator(inputs)


def _reinitialize_iterator(iterator, distribution_strategy=None):
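    """Re-initializes `iterator` so it can be consumed for another pass."""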
    if distribution_strategy:
        distributed_training_utils_v1.initialize_iterator(
            iterator, distribution_strategy
        )
    else:
        training_utils_v1.initialize_iterator(iterator)


def _make_execution_function(model, mode):
    """Makes function to run one step of model execution."""
    if model._distribution_strategy:
        return distributed_training_utils_v1._make_execution_function(
            model, mode
        )
    return model._make_execution_function(mode)


def _update_sample_weight_mode(model, mode, inputs):
    """Updates the sample_weight_mode of a given model."""
    # Add a quick return to prevent us from calling model._feed_targets that
    # accesses certain model properties that may not be set in the `PREDICT`
    # mode.
    if mode == ModeKeys.PREDICT:
        return

    sample_weights = None
    # `inputs` is the model's inputs + targets + sample_weights +
    # learning phase placeholder if specified. To update the sample_weight_mode
    # we need to determine if the user has passed sample weights as part of the
    # input.
    if not callable(inputs):
        sample_weights = inputs[
            len(model._feed_inputs) + len(model._feed_targets) :
        ]
        has_learning_phase_pl = mode == ModeKeys.TRAIN and not isinstance(
            backend.symbolic_learning_phase(), int
        )
        if has_learning_phase_pl:
            sample_weights = sample_weights[:-1]
        model._update_sample_weight_modes(sample_weights=sample_weights)

    # Call the DistributionStrategy specific function to update the
    # sample_weight_mode on the model.
    if model._distribution_strategy:
        distributed_training_utils_v1._update_sample_weight_modes(
            model, mode, sample_weights
        )


# For backwards compatibility for internal users of these loops.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False
)
predict_loop = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False
)


class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
    """TrainingLoop that handles array-like inputs.

    This is the default handler for most input data types, including
    symbolic tensors, Numpy array-likes, and Datasets/iterators in graph
    mode (since they generate symbolic tensors). This class is used to
    handle models with `run_eagerly=False`.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
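
        # Convert and validate the user-provided inputs (arrays, tensors,
        # dicts) into the canonical lists the v1 training loop expects.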
        x, y, sample_weights = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps_per_epoch",
            steps=steps_per_epoch,
            validation_split=validation_split,
            shuffle=shuffle,
        )
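
        # Validation data comes either from an explicit `validation_data`
        # argument or from holding out a `validation_split` fraction of the
        # training arrays.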
        if validation_data:
            val_x, val_y, val_sample_weights = model._prepare_validation_data(
                validation_data, batch_size, validation_steps
            )
        elif validation_split and 0.0 < validation_split < 1.0:
            (
                x,
                y,
                sample_weights,
                val_x,
                val_y,
                val_sample_weights,
            ) = training_utils_v1.split_training_and_validation_data(
                x, y, sample_weights, validation_split
            )
        else:
            if validation_steps:
                raise ValueError(
                    "`validation_steps` should not be specified if "
                    "`validation_data` is None."
                )
            val_x, val_y, val_sample_weights = None, None, None

        return fit_loop(
            model,
            inputs=x,
            targets=y,
            sample_weights=sample_weights,
            batch_size=batch_size,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_inputs=val_x,
            val_targets=val_y,
            val_sample_weights=val_sample_weights,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            steps_name="steps_per_epoch",
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        x, y, sample_weights = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps",
            steps=steps,
        )
        return test_loop(
            model,
            inputs=x,
            targets=y,
            sample_weights=sample_weights,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        x, _, _ = model._standardize_user_data(
            x, check_steps=True, steps_name="steps", steps=steps
        )
        return predict_loop(
            model,
            x,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )