Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_generator_v1.py: 15%
233 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Part of the Keras training engine related to Python generators of array data.
16"""
18import functools
19import math
21import numpy as np
22import tensorflow.compat.v2 as tf
24from keras.src import backend
25from keras.src import callbacks as cbks
26from keras.src.engine import training_utils
27from keras.src.engine import training_utils_v1
28from keras.src.utils import data_utils
29from keras.src.utils import generic_utils
30from keras.src.utils.mode_keys import ModeKeys
32# isort: off
33from tensorflow.python.platform import tf_logging as logging
def model_iteration(
    model,
    data,
    steps_per_epoch=None,
    epochs=1,
    verbose=1,
    callbacks=None,
    validation_data=None,
    validation_steps=None,
    validation_freq=1,
    class_weight=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    shuffle=False,
    initial_epoch=0,
    mode=ModeKeys.TRAIN,
    batch_size=None,
    steps_name="steps",
    **kwargs,
):
    """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

    The generator enqueuer (if one is created) and the eager learning-phase
    scope (if one is entered) are always released before this function
    returns or propagates an exception.

    Args:
        model: Keras Model instance.
        data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
          `(x, y, sample_weights)`) or a generator or
          `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
        steps_per_epoch: Total number of steps (batches of samples) before
          declaring one epoch finished and starting the next epoch. Ignored with
          the default value of `None`.
        epochs: Number of times to iterate over the data.
        verbose: 0, 1, or 2. Verbosity mode.
          0 = silent, 1 = progress bar, 2 = one line per epoch.
          Note that the progress bar is not particularly useful when
          logged to a file, so verbose=2 is recommended when not running
          interactively (eg, in a production environment).
        callbacks: List of callbacks to be called during training.
        validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
          `(x, y)` or `(x, y, sample_weights)`) or a generator or
          `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
        validation_steps: Total number of steps (batches of samples) before
          declaring validation finished.
        validation_freq: Only relevant if validation data is provided. Integer
          or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
          an integer, specifies how many training epochs to run before a new
          validation run is performed, e.g. `validation_freq=2` runs validation
          every 2 epochs. If a Container, specifies the epochs on which to run
          validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the
          end of the 1st, 2nd, and 10th epochs.
        class_weight: Dictionary mapping class indices to a weight for the
          class.
        max_queue_size: Integer. Maximum size for the generator queue. If
          unspecified, `max_queue_size` will default to 10.
        workers: Integer. Maximum number of processes to spin up when using
          process-based threading. If unspecified, `workers` will default to 1.
          If 0, will execute the generator on the main thread.
        use_multiprocessing: Boolean. If `True`, use process-based threading. If
          unspecified, `use_multiprocessing` will default to `False`. Note that
          because this implementation relies on multiprocessing, you should not
          pass non-picklable arguments to the generator as they can't be passed
          easily to children processes.
        shuffle: Boolean. Whether to shuffle the order of the batches at the
          beginning of each epoch. Only used with instances of `Sequence`
          (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
          `None`.
        initial_epoch: Epoch at which to start training (useful for resuming a
          previous training run).
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
        batch_size: Integer batch size or None if unknown. Will only be used if
          `data` is in NumPy/Tensor format.
        steps_name: The string name of the steps argument, either `steps`,
          `validation_steps`, or `steps_per_epoch`. Only used for error message
          formatting.
        **kwargs: Additional arguments for backwards compatibility. `steps` is
          accepted as an alias for `steps_per_epoch`.

    Returns:
        - In TRAIN mode: `History` object.
        - In TEST mode: Evaluation metrics.
        - In PREDICT mode: Outputs of the Model called on inputs.

    Raises:
        ValueError: in case of invalid arguments.
    """
    if "steps" in kwargs:
        steps_per_epoch = kwargs["steps"]

    # Determine the number of steps per epoch and whether we should reset the
    # dataset at the end of each epoch.
    reset_dataset_after_each_epoch = False
    original_dataset = None
    is_dataset = isinstance(data, (tf.data.Dataset, tf.compat.v1.data.Dataset))
    if is_dataset:
        original_dataset = data
        if steps_per_epoch is None:
            reset_dataset_after_each_epoch = True
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                data,
                steps_per_epoch,
                epochs=epochs,
                steps_name=steps_name,
            )

    # Convert to a format that supports `next(generator)`.
    generator, steps_per_epoch = convert_to_generator_like(
        data,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs - initial_epoch,
        shuffle=shuffle,
    )

    do_validation = validation_data is not None
    is_sequence = isinstance(generator, data_utils.Sequence)
    _validate_arguments(
        is_sequence,
        is_dataset,
        use_multiprocessing,
        workers,
        steps_per_epoch,
        validation_data,
        validation_steps,
        mode,
        kwargs,
    )

    batch_function = _make_execution_function(
        model, mode, class_weight=class_weight
    )

    # Everything acquired from here on (enqueuer workers, eager learning-phase
    # scope) is released in the `finally` clause below, so that an exception
    # raised by a batch or a callback does not leak them.
    enqueuer = None
    learning_phase_scope = None
    try:
        # Create the queue for the generator.
        if not is_dataset:
            generator, enqueuer = _make_enqueued_generator(
                generator,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                max_queue_size=max_queue_size,
                shuffle=shuffle,
            )

        num_samples_or_steps, use_steps = _get_num_samples_or_steps(
            data, steps_per_epoch
        )

        count_mode = "steps" if use_steps else "samples"
        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=do_validation,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            batch_size=batch_size,
            samples=num_samples_or_steps,
            count_mode=count_mode,
            verbose=verbose,
            mode=mode,
        )

        if mode == ModeKeys.PREDICT:
            aggregator = training_utils_v1.OutputsAggregator(
                True, steps=steps_per_epoch
            )
        else:
            aggregator = training_utils_v1.MetricsAggregator(
                True, steps=steps_per_epoch
            )

        should_set_learning_phase = (
            tf.executing_eagerly() and model.run_eagerly
        )
        if should_set_learning_phase:
            scope = backend.eager_learning_phase_scope(
                1 if mode == ModeKeys.TRAIN else 0
            )
            scope.__enter__()
            # Only record the scope once it has actually been entered.
            learning_phase_scope = scope

        callbacks.model.stop_training = False
        callbacks._call_begin_hook(mode)

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, mode
        )

        for epoch in range(initial_epoch, epochs):
            if callbacks.model.stop_training:
                break

            # Setup work for each epoch.
            model.reset_metrics()
            epoch_logs = {}
            if mode == ModeKeys.TRAIN:
                callbacks.on_epoch_begin(epoch, epoch_logs)

            if steps_per_epoch is None:
                # Loop over dataset until `OutOfRangeError` is raised.
                target_steps = np.inf
            else:
                # Loop over dataset for the specified number of steps.
                target_steps = steps_per_epoch

            step = 0
            while step < target_steps:
                batch_data = _get_next_batch(generator)
                if batch_data is None:
                    if is_dataset:
                        # The dataset passed by the user ran out of batches.
                        # Now we know the cardinality of the dataset. If
                        # steps_per_epoch was specified, then running out of
                        # data is unexpected, so we stop training and inform
                        # the user.
                        if steps_per_epoch:
                            callbacks.model.stop_training = True
                            logging.warning(
                                "Your dataset ran out of data; interrupting "
                                "training. Make sure that your dataset can "
                                "generate at least `%s * epochs` batches (in "
                                "this case, %d batches). You may need to use "
                                "the repeat() function when building your "
                                "dataset."
                                % (steps_name, steps_per_epoch * epochs)
                            )
                        elif step > 0:
                            steps_per_epoch = step
                            aggregator.steps = steps_per_epoch
                    else:
                        # We ran out of batches while the user passed an
                        # iterator (legacy).
                        callbacks.model.stop_training = True
                        logging.warning(
                            "Your dataset iterator ran out of data; "
                            "interrupting training. Make sure that your "
                            "iterator can generate at least `%s * epochs` "
                            "batches (in this case, %d batches). You may "
                            "need to use the repeat() function when building "
                            "your dataset."
                            % (steps_name, steps_per_epoch * epochs)
                        )
                    break

                # `batch_size` used for validation data if validation
                # data is NumPy/EagerTensors.
                batch_size = int(tf.nest.flatten(batch_data)[0].shape[0])

                # Callbacks batch begin.
                batch_logs = {"batch": step, "size": batch_size}
                callbacks._call_batch_hook(mode, "begin", step, batch_logs)

                # Must be sampled before the batch runs: the first call may
                # compile the model (deferred build scenario).
                is_deferred = not model._is_compiled
                batch_outs = batch_function(*batch_data)
                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                if step == 0:
                    aggregator.create(batch_outs)

                    if is_deferred:
                        # Set callbacks params. We do this here when model is
                        # compiled only in the first iteration of this loop
                        # (deferred build scenario).
                        cbks.set_callback_parameters(
                            callbacks,
                            model,
                            do_validation=do_validation,
                            batch_size=batch_size,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            samples=num_samples_or_steps,
                            verbose=verbose,
                            mode=mode,
                        )

                # Aggregate results.
                aggregator.aggregate(batch_outs)

                # Callbacks batch end.
                batch_logs = callbacks.make_logs(
                    model, batch_logs, batch_outs, mode
                )
                callbacks._call_batch_hook(mode, "end", step, batch_logs)
                step += 1

                if callbacks.model.stop_training:
                    break

            aggregator.finalize()
            results = aggregator.results
            epoch_logs = callbacks.make_logs(model, epoch_logs, results, mode)
            if len(results) == 1:
                results = results[0]

            # Run the test loop every epoch during training.
            if (
                do_validation
                and training_utils_v1.should_run_validation(
                    validation_freq, epoch
                )
                and not callbacks.model.stop_training
            ):
                val_results = model_iteration(
                    model,
                    validation_data,
                    steps_per_epoch=validation_steps,
                    batch_size=batch_size,
                    class_weight=class_weight,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    max_queue_size=max_queue_size,
                    callbacks=callbacks,
                    verbose=verbose,
                    mode=ModeKeys.TEST,
                    steps_name="validation_steps",
                )

                if not isinstance(val_results, list):
                    val_results = [val_results]
                epoch_logs = callbacks.make_logs(
                    model, epoch_logs, val_results, mode, prefix="val_"
                )

            if mode == ModeKeys.TRAIN:
                # Epochs only apply to `fit`.
                callbacks.on_epoch_end(epoch, epoch_logs)

            # Recreate dataset iterator for the next epoch.
            if reset_dataset_after_each_epoch and epoch < epochs - 1:
                generator = tf.compat.v1.data.make_one_shot_iterator(
                    original_dataset
                )

        model._successful_loop_finish = True
        callbacks._call_end_hook(mode)
    finally:
        # Release resources in the same order as the success path did
        # originally; the nested `finally` guarantees the learning-phase scope
        # is exited even if stopping the enqueuer raises.
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if learning_phase_scope is not None:
                learning_phase_scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results
# Maintain compatibility with the existing names.
# Each alias is `model_iteration` with `mode` pre-bound; the evaluate and
# predict aliases additionally pre-bind `shuffle=False`. Because these are
# `functools.partial` objects, callers can still override the bound keywords.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False
)
predict_generator = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False
)
385def _get_next_batch(generator):
386 """Retrieves the next batch of input data."""
387 try:
388 generator_output = next(generator)
389 except (StopIteration, tf.errors.OutOfRangeError):
390 return None
392 if not isinstance(generator_output, tuple):
393 # Always wrap in a tuple.
394 generator_output = (generator_output,)
395 if len(generator_output) not in [1, 2, 3]:
396 raise ValueError(
397 "Output of generator should be a tuple of 1 or 2 or 3 "
398 "elements: (input,) or (input, target) or "
399 "(input, target, sample_weights). Received {}".format(
400 generator_output
401 )
402 )
403 return generator_output
def _validate_arguments(
    is_sequence,
    is_dataset,
    use_multiprocessing,
    workers,
    steps_per_epoch,
    validation_data,
    validation_steps,
    mode,
    kwargs,
):
    """Check the loop arguments and raise on invalid combinations.

    Args:
        is_sequence: Boolean, whether data is a
          `keras.utils.data_utils.Sequence` instance.
        is_dataset: Boolean, whether data is a dataset instance.
        use_multiprocessing: Boolean. If `True`, use process-based threading.
        workers: Integer. Maximum number of processes to spin up when using
          process-based threading.
        steps_per_epoch: Total number of steps (batches of samples) before
          declaring one epoch finished and starting the next epoch.
        validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
          `(x, y)` or `(x, y, sample_weights)`) or a generator or
          `keras.utils.data_utils.Sequence` object or Eager Iterator or
          Dataset.
        validation_steps: Total number of steps (batches of samples) before
          declaring validation finished.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. Only used
          to pick the argument name shown in the error message.
        kwargs: Additional arguments for backwards compatibility. Only the
          key `steps` is accepted.

    Raises:
        ValueError: If `steps_per_epoch` is `None` for non-dataset input, if
          `validation_steps` is missing for generator/iterator validation data
          that is not a `Sequence`, or if `kwargs` holds any key other than
          `steps`.
    """
    # A plain generator fanned out over several processes is only warned
    # about, never rejected.
    risky_generator = not is_sequence and use_multiprocessing and workers > 1
    if risky_generator:
        logging.warning(
            UserWarning(
                "Using a generator with `use_multiprocessing=True`"
                " and multiple workers may duplicate your data."
                " Please consider using the `keras.utils.Sequence`"
                " class."
            )
        )

    if steps_per_epoch is None and not is_dataset:
        if mode == ModeKeys.TRAIN:
            arg_name = "steps_per_epoch"
        else:
            arg_name = "steps"
        raise ValueError(
            f"Please specify the number of steps via the `{arg_name}` argument."
        )

    validation_is_generator_like = data_utils.is_generator_or_sequence(
        validation_data
    ) or isinstance(validation_data, tf.data.Iterator)
    if validation_is_generator_like:
        validation_is_sequence = isinstance(
            validation_data, data_utils.Sequence
        )
        if not validation_is_sequence and not validation_steps:
            raise ValueError("Please specify the `validation_steps` argument.")

    unexpected = [k for k in kwargs if k != "steps"]
    if unexpected:
        raise ValueError(f"Invalid arguments passed: {unexpected}")
def convert_to_generator_like(
    data, batch_size=None, steps_per_epoch=None, epochs=1, shuffle=False
):
    """Turn the supported input kinds into something `next()` can consume.

    Args:
        data: Either a generator or `keras.utils.data_utils.Sequence` object or
          `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or
          EagerTensors. If a tuple, the elements represent `(x, y,
          sample_weights)` and may be `None` or `[None]`.
        batch_size: Used when creating a generator out of tuples of NumPy
          arrays or EagerTensors.
        steps_per_epoch: Steps of the generator to run each epoch. If `None`
          the number of steps is read from the data for
          `keras.utils.data_utils.Sequence` inputs. For array inputs it is
          always recomputed from the sample count and `batch_size`.
        epochs: Total number of epochs to run. Only used for array inputs,
          where it bounds how many passes the created generator yields.
        shuffle: Whether the data should be shuffled. Only used for array
          inputs.

    Returns:
        A `(generator_like, steps_per_epoch)` pair, where `generator_like` is
        a generator, `keras.utils.data_utils.Sequence`, or `Iterator`.

    Raises:
        ValueError: If `batch_size` is not provided for NumPy or EagerTensor
          inputs.
    """
    if isinstance(data, tuple):
        # Drop elements made up solely of `None` (e.g. absent `targets` or
        # `sample_weights`).
        data = tuple(
            part
            for part in data
            if any(leaf is not None for leaf in tf.nest.flatten(part))
        )

    already_iterable = data_utils.is_generator_or_sequence(
        data
    ) or isinstance(data, tf.data.Iterator)
    if already_iterable:
        if isinstance(data, data_utils.Sequence) and steps_per_epoch is None:
            steps_per_epoch = len(data)
        return data, steps_per_epoch

    if isinstance(data, tf.data.Dataset):
        iterator = tf.compat.v1.data.make_one_shot_iterator(data)
        return iterator, steps_per_epoch

    # Create generator from NumPy or EagerTensor Input.
    first_array = tf.nest.flatten(data)[0]
    num_samples = int(first_array.shape[0])
    if batch_size is None:
        raise ValueError(
            "When passing input data as arrays, do not specify "
            "`steps_per_epoch`/`steps` argument. "
            "Please use `batch_size` instead."
        )
    steps_per_epoch = int(math.ceil(num_samples / batch_size))

    def _gen(arrays):
        """Yields batches sliced from a structure of NumPy/EagerTensors."""
        order = np.arange(num_samples)
        for _ in range(epochs):
            if shuffle:
                np.random.shuffle(order)
            for start, stop in generic_utils.make_batches(
                num_samples, batch_size
            ):
                selected = order[start:stop]
                sliced = training_utils.slice_arrays(
                    tf.nest.flatten(arrays),
                    selected,
                    contiguous=(not shuffle),
                )
                yield tf.nest.pack_sequence_as(arrays, sliced)

    return _gen(data), steps_per_epoch
def _make_enqueued_generator(
    generator,
    workers=1,
    use_multiprocessing=False,
    max_queue_size=10,
    shuffle=False,
):
    """Wrap `generator` in a buffered queue when workers are requested.

    Args:
        generator: A generator-like object or a
          `keras.utils.data_utils.Sequence`.
        workers: Number of workers passed to the enqueuer. When not greater
          than 0 no enqueuer is created.
        use_multiprocessing: Forwarded to the enqueuer.
        max_queue_size: Forwarded to the enqueuer's `start`.
        shuffle: Forwarded to `OrderedEnqueuer` (Sequence inputs only).

    Returns:
        An `(output_generator, enqueuer)` pair. `enqueuer` is `None` when no
        workers were requested; otherwise it has already been started.
    """
    is_sequence = isinstance(generator, data_utils.Sequence)

    if not workers > 0:
        # No enqueuer: read directly from the source.
        if is_sequence:
            return data_utils.iter_sequence_infinite(generator), None
        return generator, None

    if is_sequence:
        enqueuer = data_utils.OrderedEnqueuer(
            generator,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle,
        )
    else:
        enqueuer = data_utils.GeneratorEnqueuer(
            generator, use_multiprocessing=use_multiprocessing
        )
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    return enqueuer.get(), enqueuer
def _make_execution_function(model, mode, class_weight=None):
    """Build the callable that runs a single batch for `mode`.

    Args:
        model: Object exposing `train_on_batch`, `test_on_batch` and
          `predict_on_batch`.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
        class_weight: Bound as the `class_weight` keyword in TRAIN mode only.

    Returns:
        A callable accepting the 1, 2 or 3 elements of a batch tuple as
        positional arguments.
    """
    if mode == ModeKeys.TRAIN:
        step_fn = functools.partial(
            model.train_on_batch, class_weight=class_weight
        )
    elif mode == ModeKeys.TEST:
        step_fn = model.test_on_batch
    else:
        # Match signature of other modes to allow
        # 1, 2, or 3-tuples from generator
        def predict_on_batch(x, y=None, sample_weights=None):
            return model.predict_on_batch(x)

        step_fn = predict_on_batch

    if mode == ModeKeys.PREDICT:
        return step_fn

    # Maintain stateful metrics across batch-level calls.
    return functools.partial(step_fn, reset_metrics=False)
def _get_num_samples_or_steps(data, steps_per_epoch):
    """Report the progress total and whether it counts steps.

    Returns:
        `(num_samples, False)` when the first flattened element of `data` has
        a `shape` attribute (its leading dimension is the sample count),
        otherwise `(steps_per_epoch, True)`.
    """
    first_input = tf.nest.flatten(data)[0]
    if not hasattr(first_input, "shape"):
        return steps_per_epoch, True
    return int(first_input.shape[0]), False
class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for Python generator or `Sequence` inputs.

    In contrast to `GeneratorLikeTrainingLoop`, this loop only handles inputs
    in which x, y and sample_weight are fused into one parameter; `x` is
    forwarded to the generator functions unchanged.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Validate the arguments, then delegate to `fit_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        training_utils_v1.check_generator_arguments(
            y, sample_weight, validation_split=validation_split
        )
        loop_options = {
            "steps_per_epoch": steps_per_epoch,
            "epochs": epochs,
            "verbose": verbose,
            "callbacks": callbacks,
            "validation_data": validation_data,
            "validation_steps": validation_steps,
            "validation_freq": validation_freq,
            "class_weight": class_weight,
            "max_queue_size": max_queue_size,
            "workers": workers,
            "use_multiprocessing": use_multiprocessing,
            "shuffle": shuffle,
            "initial_epoch": initial_epoch,
            "steps_name": "steps_per_epoch",
        }
        return fit_generator(model, x, **loop_options)

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Validate the arguments, then delegate to `evaluate_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        training_utils_v1.check_generator_arguments(y, sample_weight)
        loop_options = {
            "steps": steps,
            "verbose": verbose,
            "callbacks": callbacks,
            "max_queue_size": max_queue_size,
            "workers": workers,
            "use_multiprocessing": use_multiprocessing,
        }
        return evaluate_generator(model, x, **loop_options)

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Validate the batch size, then delegate to `predict_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        loop_options = {
            "steps": steps,
            "verbose": verbose,
            "callbacks": callbacks,
            "max_queue_size": max_queue_size,
            "workers": workers,
            "use_multiprocessing": use_multiprocessing,
        }
        return predict_generator(model, x, **loop_options)
class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
    """A non-distributed Dataset or iterator in eager execution.

    Every method delegates to the generator loop with `workers=0`; extra
    keyword arguments are accepted and ignored.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Validate the dataset arguments, then delegate to `fit_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        # Make sure that y, sample_weights, validation_split are not passed.
        training_utils_v1.validate_dataset_input(
            x, y, sample_weight, validation_split
        )

        dataset_types = (tf.compat.v1.data.Dataset, tf.data.Dataset)
        x_is_dataset = isinstance(x, dataset_types)
        if x_is_dataset and shuffle:
            training_utils_v1.verify_dataset_shuffled(x)

        loop_options = {
            "steps_per_epoch": steps_per_epoch,
            "epochs": epochs,
            "verbose": verbose,
            "callbacks": callbacks,
            "validation_data": validation_data,
            "validation_steps": validation_steps,
            "validation_freq": validation_freq,
            "class_weight": class_weight,
            "workers": 0,
            "shuffle": shuffle,
            "initial_epoch": initial_epoch,
            "steps_name": "steps_per_epoch",
        }
        return fit_generator(model, x, **loop_options)

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Validate the dataset arguments, then run `evaluate_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        # Make sure that y, sample_weights, validation_split are not passed.
        training_utils_v1.validate_dataset_input(x, y, sample_weight)
        loop_options = {
            "steps": steps,
            "verbose": verbose,
            "workers": 0,
            "callbacks": callbacks,
        }
        return evaluate_generator(model, x, **loop_options)

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Validate the batch size, then delegate to `predict_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        loop_options = {
            "steps": steps,
            "verbose": verbose,
            "workers": 0,
            "callbacks": callbacks,
        }
        return predict_generator(model, x, **loop_options)
class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
    """TrainingLoop that handles inputs like a Python generator.

    This is the default handler for most of the input data types, including
    symbolic tensors or NumPy array-like, Datasets and iterators in graph mode
    (since they generate symbolic tensors). It is used to handle models with
    `run_eagerly` = True.

    Inputs are first standardized through `model._standardize_user_data` and
    then fed to the generator loop with `workers=0`.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Standardize inputs, resolve validation data, run `fit_generator`.

        Raises:
            ValueError: If `validation_steps` is given while there is neither
              `validation_data` nor a `validation_split` strictly between 0
              and 1.
        """
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
        standardized = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps_per_epoch",
            steps=steps_per_epoch,
            validation_split=validation_split,
            shuffle=shuffle,
        )
        x, y, sample_weights = standardized

        if validation_data:
            validation_data = model._prepare_validation_data(
                validation_data, batch_size, validation_steps
            )
        elif validation_split and 0.0 < validation_split < 1.0:
            # Carve the validation set out of the training arrays.
            split = training_utils_v1.split_training_and_validation_data(
                x, y, sample_weights, validation_split
            )
            x, y, sample_weights, val_x, val_y, val_sample_weights = split
            validation_data = (val_x, val_y, val_sample_weights)
        elif validation_steps:
            raise ValueError(
                "`validation_steps` should not be specified if "
                "`validation_data` is None."
            )

        loop_options = {
            "steps_per_epoch": steps_per_epoch,
            "batch_size": batch_size,
            "epochs": epochs,
            "verbose": verbose,
            "callbacks": callbacks,
            "validation_data": validation_data,
            "validation_steps": validation_steps,
            "validation_freq": validation_freq,
            "workers": 0,
            "shuffle": shuffle,
            "initial_epoch": initial_epoch,
            "steps_name": "steps_per_epoch",
        }
        return fit_generator(model, (x, y, sample_weights), **loop_options)

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Standardize inputs, then delegate to `evaluate_generator`."""
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        standardized = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps",
            steps=steps,
        )
        x, y, sample_weights = standardized
        loop_options = {
            "steps": steps,
            "batch_size": batch_size,
            "verbose": verbose,
            "workers": 0,
            "callbacks": callbacks,
        }
        return evaluate_generator(
            model, (x, y, sample_weights), **loop_options
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Standardize `x`, then delegate to `predict_generator`."""
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        # Only the standardized inputs are needed for prediction.
        x, _, _ = model._standardize_user_data(
            x, check_steps=True, steps_name="steps", steps=steps
        )
        loop_options = {
            "steps": steps,
            "batch_size": batch_size,
            "verbose": verbose,
            "workers": 0,
            "callbacks": callbacks,
        }
        return predict_generator(model, x, **loop_options)