# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to Python generators of array data.
"""
# pylint: disable=protected-access

import functools
import math

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import data as data_types
from tensorflow.python.util import nest


def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
    model: Keras Model instance.
    data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
      `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished and starting the next epoch. Ignored with
      the default value of `None`.
    epochs: Number of times to iterate over the data.
    verbose: 0, 1, or 2. Verbosity mode.
      0 = silent, 1 = progress bar, 2 = one line per epoch.
      Note that the progress bar is not particularly useful when
      logged to a file, so verbose=2 is recommended when not running
      interactively (e.g., in a production environment).
    callbacks: List of callbacks to be called during training.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
      `(x, y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g. `validation_freq=2` runs
      validation every 2 epochs. If a Container, specifies the epochs on
      which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
      validation at the end of the 1st, 2nd, and 10th epochs.
    class_weight: Dictionary mapping class indices to a weight for the class.
    max_queue_size: Integer. Maximum size for the generator queue. If
      unspecified, `max_queue_size` will default to 10.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not
      pass non-picklable arguments to the generator as they can't be passed
      easily to child processes.
    shuffle: Boolean. Whether to shuffle the order of the batches at the
      beginning of each epoch. Only used with instances of `Sequence`
      (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
      `None`.
    initial_epoch: Epoch at which to start training (useful for resuming a
      previous training run).
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    batch_size: Integer batch size or None if unknown. Will only be used if
      `data` is in NumPy/Tensor format.
    steps_name: The string name of the steps argument, either `steps`,
      `validation_steps`, or `steps_per_epoch`. Only used for error message
      formatting.
    **kwargs: Additional arguments for backwards compatibility. `steps` is
      accepted as an alias for `steps_per_epoch`.

  Returns:
    - In TRAIN mode: `History` object.
    - In TEST mode: Evaluation metrics.
    - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
    ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (data_types.DatasetV2, data_types.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Create the queue for the generator.
  enqueuer = None
  if not is_dataset:
    generator, enqueuer = _make_enqueued_generator(
        generator,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        shuffle=shuffle)

  num_samples_or_steps, use_steps = _get_num_samples_or_steps(
      data, steps_per_epoch)

  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      samples=num_samples_or_steps,
      count_mode=count_mode,
      verbose=verbose,
      mode=mode)

  if mode == ModeKeys.PREDICT:
    aggregator = training_utils_v1.OutputsAggregator(
        True, steps=steps_per_epoch)
  else:
    aggregator = training_utils_v1.MetricsAggregator(
        True, steps=steps_per_epoch)

  should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
  if should_set_learning_phase:
    learning_phase_scope = backend.eager_learning_phase_scope(
        1 if mode == ModeKeys.TRAIN else 0)
    learning_phase_scope.__enter__()

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch.
    model.reset_metrics()
    epoch_logs = {}
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)

    if steps_per_epoch is None:
      # Loop over dataset until `OutOfRangeError` is raised.
      target_steps = np.inf
    else:
      # Loop over dataset for the specified number of steps.
      target_steps = steps_per_epoch

    step = 0
    while step < target_steps:
      batch_data = _get_next_batch(generator)
      if batch_data is None:
        if is_dataset:
          # The dataset passed by the user ran out of batches.
          # Now we know the cardinality of the dataset.
          # If steps_per_epoch was specified, then running out of data is
          # unexpected, so we stop training and inform the user.
          if steps_per_epoch:
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset ran out of data; interrupting training. '
                'Make sure that your dataset can generate at least '
                '`%s * epochs` batches (in this case, %d batches). '
                'You may need to use the repeat() function when '
                'building your dataset.'
                % (steps_name, steps_per_epoch * epochs))
          elif step > 0:
            steps_per_epoch = step
            aggregator.steps = steps_per_epoch
        else:
          # We ran out of batches while the user passed an iterator (legacy).
          callbacks.model.stop_training = True
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your iterator '
              'can generate at least `%s * epochs` '
              'batches (in this case, %d batches). You may need to '
              'use the repeat() function when building your '
              'dataset.' % (steps_name, steps_per_epoch * epochs))
        break

      # `batch_size` used for validation data if validation
      # data is NumPy/EagerTensors.
      batch_size = int(nest.flatten(batch_data)[0].shape[0])

      # Callbacks batch begin.
      batch_logs = {'batch': step, 'size': batch_size}
      callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

      is_deferred = not model._is_compiled
      batch_outs = batch_function(*batch_data)
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

      if step == 0:
        aggregator.create(batch_outs)

        if is_deferred:
          # Set callbacks params. We do this here when model is compiled only
          # in the first iteration of this loop (deferred build scenario).
          cbks.set_callback_parameters(
              callbacks,
              model,
              do_validation=do_validation,
              batch_size=batch_size,
              epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              samples=num_samples_or_steps,
              verbose=verbose,
              mode=mode)

      # Aggregate results.
      aggregator.aggregate(batch_outs)

      # Callbacks batch end.
      batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
      callbacks._call_batch_hook(mode, 'end', step, batch_logs)
      step += 1

      if callbacks.model.stop_training:
        break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every epoch during training.
    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):
      val_results = model_iteration(
          model,
          validation_data,
          steps_per_epoch=validation_steps,
          batch_size=batch_size,
          class_weight=class_weight,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          callbacks=callbacks,
          verbose=verbose,
          mode=ModeKeys.TEST,
          steps_name='validation_steps')

      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)

    # Recreate dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      generator = dataset_ops.make_one_shot_iterator(original_dataset)

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if enqueuer is not None:
    enqueuer.stop()

  if should_set_learning_phase:
    learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


# Maintain compatibility with the existing names.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_generator = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
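
# A minimal usage sketch of the loop above, assuming a compiled `tf.keras`
# model; the shapes, step counts, and generator below are illustrative
# assumptions, not part of this module's API.
#
#   import tensorflow as tf
#
#   model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
#   model.compile(optimizer='sgd', loss='mse')
#
#   def batch_gen(batch_size=32):
#     while True:  # must yield at least `steps_per_epoch * epochs` batches
#       x = np.random.random((batch_size, 8))
#       y = np.random.random((batch_size, 1))
#       yield x, y  # an `(input, target)` tuple; see `_get_next_batch` below
#
#   history = fit_generator(
#       model, batch_gen(), steps_per_epoch=10, epochs=2, verbose=0)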


def _get_next_batch(generator):
  """Retrieves the next batch of input data."""
  try:
    generator_output = next(generator)
  except (StopIteration, errors.OutOfRangeError):
    return None

  if not isinstance(generator_output, tuple):
    # Always wrap in a tuple.
    generator_output = (generator_output,)
  if len(generator_output) not in [1, 2, 3]:
    raise ValueError(
        'Output of generator should be a tuple of 1 or 2 or 3 '
        'elements: (input,) or (input, target) or '
        '(input, target, sample_weights). Received {}'.format(generator_output))
  return generator_output
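
# A minimal sketch of generator outputs that satisfy the tuple contract
# enforced above; the shapes are illustrative assumptions.
#
#   def toy_gen():
#     x = np.zeros((4, 8))
#     y = np.zeros((4, 1))
#     w = np.ones((4,))
#     yield x             # wrapped into `(x,)` by `_get_next_batch`
#     yield (x, y)        # (input, target)
#     yield (x, y, w)     # (input, target, sample_weights)
#   # Anything longer than a 3-tuple raises the ValueError above.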


def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  Args:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not pass
      non-picklable arguments to the generator as they can't be passed easily to
      child processes.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before declaring
      one epoch finished and starting the next epoch. Ignored with the default
      value of `None`.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x,
      y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional arguments for backwards compatibility.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.IteratorBase))
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  if any(k != 'steps' for k in kwargs):
    raise ValueError('Invalid arguments passed: {}'.format(
        [k for k in kwargs if k != 'steps']))
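
# A minimal sketch of the validation above rejecting a plain generator that
# arrives without a step count; the argument values are illustrative.
#
#   _validate_arguments(
#       is_sequence=False, is_dataset=False, use_multiprocessing=False,
#       workers=1, steps_per_epoch=None, validation_data=None,
#       validation_steps=None, mode=ModeKeys.TRAIN, kwargs={})
#   # -> ValueError: Please specify the number of steps via the
#   #    `steps_per_epoch` argument.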


def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Args:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or EagerTensors.
      If a tuple, the elements represent `(x, y, sample_weights)` and may be
      `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays or
      EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. If `None` the
      number of steps will be read from the data (for
      `keras.utils.data_utils.Sequence` types).
    epochs: Total number of epochs to run.
    shuffle: Whether the data should be shuffled.

  Returns:
    - Generator, `keras.utils.data_utils.Sequence`, or `Iterator`.

  Raises:
    - ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`, `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.IteratorBase):
    if isinstance(data, data_utils.Sequence):
      if steps_per_epoch is None:
        steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, data_types.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  num_samples = int(nest.flatten(data)[0].shape[0])
  if batch_size is None:
    raise ValueError(
        'When passing input data as arrays, do not specify '
        '`steps_per_epoch`/`steps` argument. Please use `batch_size` instead.')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for (batch_start, batch_end) in batches:
        batch_ids = index_array[batch_start:batch_end]
        flat_batch_data = training_utils.slice_arrays(
            nest.flatten(data), batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch
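
# A minimal sketch of the array-to-generator path above; the array sizes and
# batch size are illustrative assumptions.
#
#   x = np.arange(10.).reshape(10, 1)
#   y = np.arange(10.).reshape(10, 1)
#   gen, steps = convert_to_generator_like((x, y), batch_size=4, epochs=1)
#   # steps == 3, i.e. ceil(10 / 4); each `next(gen)` yields an (x, y) batch.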


def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Create a buffered queue of next elements of the generator."""
  is_sequence = isinstance(generator, data_utils.Sequence)
  enqueuer = None
  if workers > 0:
    if is_sequence:
      enqueuer = data_utils.OrderedEnqueuer(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      enqueuer = data_utils.GeneratorEnqueuer(
          generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()
  else:
    if is_sequence:
      output_generator = data_utils.iter_sequence_infinite(generator)
    else:
      output_generator = generator
  return output_generator, enqueuer
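
# A minimal sketch of a `Sequence` that `_make_enqueued_generator` wraps in an
# `OrderedEnqueuer` when `workers > 0`; the length and shapes are illustrative
# assumptions.
#
#   class ToySequence(data_utils.Sequence):
#
#     def __len__(self):
#       return 5  # batches per epoch
#
#     def __getitem__(self, idx):
#       x = np.full((2, 3), idx, dtype='float32')
#       y = np.zeros((2, 1), dtype='float32')
#       return x, y
#
#   gen, enqueuer = _make_enqueued_generator(ToySequence(), workers=2)
#   x_batch, y_batch = next(gen)
#   enqueuer.stop()  # always stop the background workers when done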


def _make_execution_function(model, mode, class_weight=None):
  """Makes function to run one step of model execution."""
  if mode == ModeKeys.TRAIN:
    f = functools.partial(model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    f = model.test_on_batch
  else:
    # Match signature of other modes to allow
    # 1, 2, or 3-tuples from generator
    def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

    f = predict_on_batch

  # Maintain stateful metrics across batch-level calls.
  if mode != ModeKeys.PREDICT:
    f = functools.partial(f, reset_metrics=False)

  return f
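
# A minimal sketch of the uniform per-batch calling convention produced above;
# `model`, `x_batch`, and `y_batch` are illustrative assumptions.
#
#   train_fn = _make_execution_function(model, ModeKeys.TRAIN)
#   loss = train_fn(x_batch, y_batch)    # train_on_batch, metrics not reset
#
#   predict_fn = _make_execution_function(model, ModeKeys.PREDICT)
#   preds = predict_fn(x_batch)          # extra tuple elements are ignored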


def _get_num_samples_or_steps(data, steps_per_epoch):
  """Returns number of samples or steps, and whether to use steps count mode."""
  flat_inputs = nest.flatten(data)
  if hasattr(flat_inputs[0], 'shape'):
    return int(flat_inputs[0].shape[0]), False
  return steps_per_epoch, True


class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
  """Generator-like.

  Input is a Python generator or a `Sequence` object.

  The difference between this class and `GeneratorLikeTrainingLoop` is that
  this class only handles inputs with x, y, and sample_weight fused into one
  argument.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    training_utils_v1.check_generator_arguments(
        y, sample_weight, validation_split=validation_split)
    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.check_generator_arguments(y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)


class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
  """A non-distributed Dataset or iterator in eager execution."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)
    if (isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)) and
        shuffle):
      training_utils_v1.verify_dataset_shuffled(x)

    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    # Make sure that y, sample_weights, validation_split are not passed.
    training_utils_v1.validate_dataset_input(x, y, sample_weight)
    return evaluate_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)


class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles inputs like a Python generator.

  This is the default handler for most input data types: it covers symbolic
  tensors and NumPy array-likes, as well as Datasets and iterators in graph
  mode (since they generate symbolic tensors). This loop also handles models
  with `run_eagerly=True`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    if validation_data:
      validation_data = model._prepare_validation_data(validation_data,
                                                       batch_size,
                                                       validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights, val_x, val_y,
       val_sample_weights) = (
           training_utils_v1.split_training_and_validation_data(
               x, y, sample_weights, validation_split))
      validation_data = (val_x, val_y, val_sample_weights)
    else:
      if validation_steps:
        raise ValueError('`validation_steps` should not be specified if '
                         '`validation_data` is None.')

    return fit_generator(
        model, (x, y, sample_weights),
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        workers=0,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return evaluate_generator(
        model, (x, y, sample_weights),
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, _, _ = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    return predict_generator(
        model,
        x,
        steps=steps,
        batch_size=batch_size,
        verbose=verbose,
        workers=0,
        callbacks=callbacks)
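
# A minimal sketch of how the `TrainingLoop` subclasses above share one
# interface, so callers can dispatch on input type and drive any of them
# uniformly; the model and data are illustrative assumptions.
#
#   loop = GeneratorLikeTrainingLoop()
#   history = loop.fit(
#       model, x=np.ones((32, 8)), y=np.zeros((32, 1)),
#       batch_size=8, epochs=2, verbose=0)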