# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to plain array data."""
# pylint: disable=protected-access

import functools

import numpy as np

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import data as data_types
from tensorflow.python.util import nest

try:
  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
except ImportError:
  issparse = None
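
# When SciPy is unavailable, `issparse` is None and the sparse-to-dense
# conversion logic in `model_iteration` below is skipped entirely.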


def model_iteration(model,
                    inputs,
                    targets=None,
                    sample_weights=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    val_inputs=None,
                    val_targets=None,
                    val_sample_weights=None,
                    shuffle=True,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1,
                    mode=ModeKeys.TRAIN,
                    validation_in_fit=False,
                    prepared_feed_values_from_dataset=False,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
    model: Keras Model instance.
    inputs: Either a list or dictionary of arrays, or a dataset instance.
    targets: List/dictionary of target arrays.
    sample_weights: Optional list of sample weight arrays.
    batch_size: Integer batch size or None if unknown.
    epochs: Number of times to iterate over the data.
    verbose: 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one
      line per epoch. Note that the progress bar is not particularly useful
      when logged to a file, so verbose=2 is recommended when not running
      interactively (e.g., in a production environment).
    callbacks: List of callbacks to be called during training.
    val_inputs: Either a list or dictionary of arrays, or a dataset instance.
    val_targets: List/dictionary of target arrays.
    val_sample_weights: Optional list of sample weight arrays.
    shuffle: Whether to shuffle the data at the beginning of each epoch.
    initial_epoch: Epoch at which to start training (useful for resuming a
      previous training run).
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished and starting the next epoch. Ignored with
      the default value of `None`.
    validation_steps: Number of steps to run validation for (only if doing
      validation from data tensors). Ignored with the default value of `None`.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g. `validation_freq=2` runs validation
      every 2 epochs. If a Container, specifies the epochs on which to run
      validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the end
      of the 1st, 2nd, and 10th epochs.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    validation_in_fit: if true, then this method is invoked from within a
      training iteration (for validation). In the case where `val_inputs` is a
      dataset, this flag indicates that its iterator and feed values are
      already created, so this invocation should reuse those resources.
    prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
      tensors returned from the `_prepare_feed_values` call on the validation
      dataset, so do not call it again on `inputs`. Should only be used for
      inline validation (i.e., only if `validation_in_fit` is also True).
    steps_name: The string name of the steps argument, either `steps`,
      `validation_steps`, or `steps_per_epoch`. Only used for error message
      formatting.
    **kwargs: Additional arguments for backwards compatibility.

  Returns:
    - In TRAIN mode: `History` object.
    - In TEST mode: Evaluation metrics.
    - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
    ValueError: in case of invalid arguments.
  """
  # Backwards compatibility.
  if 'steps' in kwargs:
    steps_per_epoch = kwargs.pop('steps')
  if kwargs:
    raise TypeError('Unknown arguments: %s' % (kwargs,))

  # In case we were passed a dataset, we extract symbolic tensors from it.
  reset_dataset_after_each_epoch = False
  input_iterator = None
  is_dataset = isinstance(inputs,
                          (data_types.DatasetV1, data_types.DatasetV2))
  # TODO(fchollet): consider moving `steps_per_epoch` inference to
  # _standardize_user_data and set reset_dataset_after_each_epoch as an
  # attribute on the dataset instance.
  if is_dataset:
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
    input_iterator = _get_iterator(inputs, model._distribution_strategy)

  # Enter tf.distribute.Strategy scope.
  if model._distribution_strategy:
    scope = distributed_training_utils_v1.distributed_scope(
        strategy=model._distribution_strategy,
        learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
    scope.__enter__()

  use_steps = is_dataset or steps_per_epoch is not None
  do_validation = val_inputs is not None

  # Prepare input data.
  inputs = input_iterator or inputs
  if validation_in_fit and prepared_feed_values_from_dataset:
    # When invoking validation inside the training loop, avoid creating the
    # iterator and the list of feed values for the same validation dataset
    # multiple times (which would repeatedly call `iterator.get_next()`,
    # slowing down execution and eventually leading to OOM errors).
    ins = inputs
  else:
    ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
    # `ins` is a function when a distribute strategy is used in Eager mode. In
    # that case `is_dataset` is True. The code branches that have requirements
    # about the type of `ins` do not trigger in the distributed case.

  if not is_dataset:
    num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size,
                                                     steps_per_epoch)
  else:
    num_samples_or_steps = steps_per_epoch

  # Update sample_weight_mode of the model if sample_weights is specified by
  # the user. We need to call this function after we have a handle on the
  # inputs (both numpy arrays and datasets) in order to determine if the user
  # has specified sample_weights.
  _update_sample_weight_mode(model, mode, ins)

  # Get step function and loop type. As part of building the execution
  # function we recompile the metrics based on the updated
  # sample_weight_mode value.
  f = _make_execution_function(model, mode)

  # Prepare validation data. Hold references to the iterator and the input list
  # to properly reinitialize and reuse in multiple validation passes.
  val_iterator = None
  if isinstance(val_inputs, (data_types.DatasetV1, data_types.DatasetV2)):
    if validation_steps is None:
      # Because we pass an iterator feed instead of a Dataset to the eval
      # model_iteration() call, it will not trigger the dataset-input path
      # that determines the number of steps required. To avoid this issue,
      # set validation_steps here if validation_steps is None.
      validation_steps = training_utils_v1.infer_steps_for_dataset(
          model,
          val_inputs,
          validation_steps,
          epochs=epochs,
          steps_name='validation_steps')
    val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
    val_inputs = _prepare_feed_values(
        model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST)
    # Get num steps for printing.
    val_samples_or_steps = validation_steps
  else:
    # Get num samples for printing.
    val_samples_or_steps = val_inputs and nest.flatten(
        val_inputs)[0].shape[0] or None

  if mode == ModeKeys.TRAIN and verbose:
    _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset)

  # Configure callbacks.
  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=num_samples_or_steps,
      count_mode=count_mode,
      verbose=verbose,
      mode=mode)

  # Find, ahead of time, the arrays that need sparse-to-dense conversion.
  if issparse is not None and not use_steps:
    indices_for_conversion_to_dense = []
    feed = _get_model_feed(model, mode)
    for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
      if issparse(input_data) and not backend.is_sparse(feed_tensor):
        indices_for_conversion_to_dense.append(i)

  # Select aggregation method.
  if mode == ModeKeys.PREDICT:
    aggregator = training_utils_v1.OutputsAggregator(
        use_steps,
        num_samples=None if steps_per_epoch else num_samples_or_steps,
        steps=steps_per_epoch)
  else:
    aggregator = training_utils_v1.MetricsAggregator(
        use_steps,
        num_samples=None if steps_per_epoch else num_samples_or_steps,
        steps=steps_per_epoch)

  if model._compile_distribution:
    distributed_training_utils_v1._copy_weights_to_distributed_model(
        model, mode)

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch
    epoch_logs = {}
    if mode != ModeKeys.PREDICT:
      # Collecting and resetting metrics has non-zero cost and will needlessly
      # slow down model.predict.
      model.reset_metrics()
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)

    if use_steps:
      # Step-wise loop.
      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_logs = {'batch': step, 'size': 1}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

        # Get outputs.
        try:
          # `ins` can be callable in tf.distribute.Strategy + eager case.
          if not callable(ins) or (model._distribution_strategy and
                                   not distributed_training_utils_v1
                                   .is_distributing_by_cloning(model)):
            actual_inputs = ins
          else:
            actual_inputs = ins()
          batch_outs = f(actual_inputs)
        except errors.OutOfRangeError:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.steps = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if model._distribution_strategy:
          batch_outs = (
              distributed_training_utils_v1._per_replica_aggregate_batch(
                  model._distribution_strategy, batch_outs, model, mode))

        # Aggregate results.
        if step == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break
    else:
      # Sample-wise loop.
      index_array = np.arange(num_samples_or_steps)
      if shuffle == 'batch':
        index_array = training_utils_v1.batch_shuffle(index_array, batch_size)
      elif shuffle:
        np.random.shuffle(index_array)
      batches = make_batches(num_samples_or_steps, batch_size)
      for batch_index, (batch_start, batch_end) in enumerate(batches):
        batch_ids = index_array[batch_start:batch_end]
        # Slice into a batch.
        if len(batches) == 1:
          # If we only have one batch, do not slice. This takes care of
          # composite tensors in non-Dataset modes; we currently don't support
          # slicing them.
          # TODO(b/133517906): Add slicing support.
          ins_batch = ins
        else:
          try:
            if ins and isinstance(ins[-1], int):
              # Do not slice the training phase flag.
              ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
            else:
              ins_batch = slice_arrays(ins, batch_ids)
          except TypeError:
            raise TypeError('TypeError while preparing batch. '
                            'If using HDF5 input data, '
                            'pass shuffle="batch".')

        # Sparse to dense conversion.
        if issparse is not None:
          for i in indices_for_conversion_to_dense:
            ins_batch[i] = ins_batch[i].toarray()

        # Callbacks batch_begin.
        batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
        callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)

        # Get outputs.
        batch_outs = f(ins_batch)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        # Aggregate results.
        if batch_index == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs, batch_start, batch_end)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)

        if callbacks.model.stop_training:
          break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every `validation_freq` epochs during training.
    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        distributed_training_utils_v1._copy_weights_to_original_model(
            model, ModeKeys.TRAIN)

      val_results = model_iteration(
          model,
          val_inputs,
          targets=val_targets,
          sample_weights=val_sample_weights,
          batch_size=batch_size,
          steps_per_epoch=validation_steps,
          callbacks=callbacks,
          verbose=0,
          mode=ModeKeys.TEST,
          validation_in_fit=True,
          prepared_feed_values_from_dataset=(val_iterator is not None),
          steps_name='validation_steps')
      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')
      if val_iterator and epoch < epochs - 1:
        _reinitialize_iterator(val_iterator, model._distribution_strategy)

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)

    # Reinitialize dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      _reinitialize_iterator(input_iterator, model._distribution_strategy)

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._distribution_strategy:
    if model._compile_distribution:
      # TODO(priyag, psv): Copy back metrics to the original model as well?
      distributed_training_utils_v1._copy_weights_to_original_model(model, mode)
    scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
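
# A minimal sketch of how this loop is typically driven (hypothetical `model`,
# `x_train`, and `y_train`; in practice `Model.fit` reaches this code through
# `ArrayLikeTrainingLoop.fit` below):
#
#   history = model_iteration(
#       model, x_train, targets=y_train, batch_size=32, epochs=5,
#       mode=ModeKeys.TRAIN)
#   # In TRAIN mode the return value is `model.history`.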


def _get_model_feed(model, mode):
  if mode == ModeKeys.PREDICT:
    feed = model._feed_inputs
  else:
    feed = (
        model._feed_inputs + model._feed_targets + model._feed_sample_weights)
  return feed
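
# Note: in PREDICT mode only the input placeholders are fed; in TRAIN/TEST the
# target and sample-weight placeholders are fed as well, matching the `ins`
# layout built by `_prepare_feed_values`.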


def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset):
  increment = 'steps' if is_dataset else 'samples'
  msg = 'Train on {0} {increment}'.format(
      num_samples_or_steps, increment=increment)
  if val_samples_or_steps:
    msg += ', validate on {0} {increment}'.format(
        val_samples_or_steps, increment=increment)
  print(msg)
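
# For example, `_print_train_info(60000, 10000, is_dataset=False)` prints:
#   Train on 60000 samples, validate on 10000 samples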


def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
  """Returns total number of samples (when training in batch mode) or steps."""
  if steps_per_epoch:
    return steps_per_epoch
  return training_utils_v1.check_num_samples(ins, batch_size, steps_per_epoch,
                                             'steps_per_epoch')
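
# For example, `_get_num_samples_or_steps(ins, 32, steps_per_epoch=100)`
# returns 100; with `steps_per_epoch=None` it falls back to counting samples
# via `training_utils_v1.check_num_samples`.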


def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Prepares feed values for the model execution function.

  Args:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode.
  """
  if model._distribution_strategy:
    if isinstance(inputs, (data_types.DatasetV1, data_types.DatasetV2)):
      inputs = distributed_training_utils_v1.get_iterator(
          inputs, model._distribution_strategy)

    def get_distributed_inputs():
      return distributed_training_utils_v1._prepare_feed_values(
          model, inputs, targets, sample_weights, mode)

    # In the eager case, we want to call the input method per step, so return
    # a lambda from here that can be called. Note that this is applicable only
    # in the Distribution Strategy case, as it follows the same code path for
    # both eager and graph modes.
    # TODO(priyag,omalleyt): Either we should move the training DS with
    # IteratorBase to use training_generator code path, or figure out how to
    # set a symbolic Iterator out of a Dataset when in eager mode.
    if context.executing_eagerly():
      return get_distributed_inputs
    else:
      return get_distributed_inputs()

  if isinstance(inputs, (data_types.DatasetV1, data_types.DatasetV2,
                         iterator_ops.Iterator)):
    inputs, targets, sample_weights = model._standardize_user_data(
        inputs,
        extract_tensors_from_dataset=True)

  inputs = training_utils_v1.ModelInputs(inputs).as_list()
  targets = list(targets or [])
  sample_weights = list(sample_weights or [])
  ins = inputs + targets + sample_weights
  if mode == ModeKeys.TRAIN and not isinstance(
      backend.symbolic_learning_phase(), int):
    ins += [True]  # Add learning phase value.
  return ins
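
# Illustration (hypothetical single-input, single-output model in TRAIN mode):
# the returned list is laid out as
#   ins = [input_array] + [target_array] + [sample_weight_array] + [True]
# where the trailing `True` is the learning-phase value, appended only when the
# backend learning phase is symbolic rather than a fixed int.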


def _get_iterator(inputs, distribution_strategy=None):
  if distribution_strategy:
    return distributed_training_utils_v1.get_iterator(
        inputs, distribution_strategy)
  return training_utils_v1.get_iterator(inputs)


def _reinitialize_iterator(iterator, distribution_strategy=None):
  if distribution_strategy:
    distributed_training_utils_v1.initialize_iterator(
        iterator, distribution_strategy)
  else:
    training_utils_v1.initialize_iterator(iterator)


def _make_execution_function(model, mode):
  """Makes function to run one step of model execution."""
  if model._distribution_strategy:
    return distributed_training_utils_v1._make_execution_function(model, mode)
  return model._make_execution_function(mode)


def _update_sample_weight_mode(model, mode, inputs):
  """Updates the sample_weight_mode of a given model."""
  # Add a quick return to prevent us from calling model._feed_targets, which
  # accesses certain model properties that may not be set in `PREDICT` mode.
  if mode == ModeKeys.PREDICT:
    return

  sample_weights = None
  # `inputs` is the model's inputs + targets + sample_weights + the learning
  # phase placeholder, if specified. To update the sample_weight_mode we need
  # to determine whether the user has passed sample weights as part of the
  # input.
  if not callable(inputs):
    sample_weights = inputs[len(model._feed_inputs) + len(model._feed_targets):]
    has_learning_phase_pl = (mode == ModeKeys.TRAIN and
                             not isinstance(backend.symbolic_learning_phase(),
                                            int))
    if has_learning_phase_pl:
      sample_weights = sample_weights[:-1]
    model._update_sample_weight_modes(sample_weights=sample_weights)

  # Call the DistributionStrategy specific function to update the
  # sample_weight_mode on the model.
  if model._distribution_strategy:
    distributed_training_utils_v1._update_sample_weight_modes(model, mode,
                                                              sample_weights)

# For backwards compatibility for internal users of these loops.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_loop = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
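
# For example, `test_loop(model, inputs=x, targets=y, steps=steps)` expands to
# `model_iteration(model, x, targets=y, steps=steps, mode=ModeKeys.TEST,
# shuffle=False)`; the legacy `steps` kwarg is remapped to `steps_per_epoch`
# inside `model_iteration`.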


class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles array-like inputs.

  This is the default handler for most input data types: symbolic tensors,
  Numpy array-likes, and Datasets and iterators in graph mode (since they
  generate symbolic tensors). This class is used to handle models with
  `run_eagerly=False`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch, x)

    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    if validation_data:
      val_x, val_y, val_sample_weights = model._prepare_validation_data(
          validation_data, batch_size, validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights, val_x, val_y, val_sample_weights
      ) = training_utils_v1.split_training_and_validation_data(
          x, y, sample_weights, validation_split)
    else:
      if validation_steps:
        raise ValueError('`validation_steps` should not be specified if '
                         '`validation_data` is None.')
      val_x, val_y, val_sample_weights = None, None, None

    return fit_loop(
        model,
        inputs=x,
        targets=y,
        sample_weights=sample_weights,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_x,
        val_targets=val_y,
        val_sample_weights=val_sample_weights,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return test_loop(
        model,
        inputs=x,
        targets=y,
        sample_weights=sample_weights,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_loop(
        model,
        x,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)
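
# A minimal usage sketch (hypothetical `model`, `x_train`, and `y_train`; the
# loop selection normally happens inside `Model.fit` for graph-mode,
# array-like inputs):
#
#   loop = ArrayLikeTrainingLoop()
#   history = loop.fit(model, x=x_train, y=y_train, batch_size=32, epochs=5,
#                      validation_split=0.1)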