Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_distributed_v1.py: 10%
310 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src import callbacks as cbks
from keras.src.distribute import distribute_coordinator_utils as dc
from keras.src.distribute import distributed_training_utils_v1 as dist_utils
from keras.src.engine import partial_batch_padding_handler as padding_util
from keras.src.engine import training_arrays_v1
from keras.src.engine import training_utils_v1
from keras.src.utils.generic_utils import Progbar
from keras.src.utils.mode_keys import ModeKeys

# isort: off
from tensorflow.python.distribute import input_lib
from tensorflow.python.platform import tf_logging as logging


def _per_replica_execution_function(model, mode):
    exec_func = model._make_execution_function(mode)
    return (
        exec_func.inputs,
        exec_func.outputs,
        exec_func.updates_op,
        exec_func.session_kwargs,
    )


def _build_model(strategy, model, mode, inputs, targets=None):
    if model._compile_distribution:
        dist_utils.clone_model_on_replicas(
            model, strategy, mode, inputs=inputs, targets=targets
        )
    else:
        dist_utils._build_distributed_network(
            model, strategy, mode, inputs, targets
        )


def _make_train_step_fn(model, mode, strategy, output_labels):
    """Create step fn.

    Args:
        model: a Keras Model instance.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
        strategy: a `tf.distribute.Strategy` instance.
        output_labels: the output labels for the step function.

    Returns:
        A step function to run by `tf.distribute.Strategy`.
    """

    def _step_fn(ctx, inputs):
        """A step fn that returns update ops."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        # When the input feature is a dictionary of tensors, the dictionary
        # is flattened to an array and passed as a model input. This results
        # in an input mismatch when the model input layer names are not
        # sorted in alphabetical order, since `nest.flatten()` sorts
        # dictionary elements by key. Therefore, transform the input tensors
        # into an array ordered along `model._feed_input_names`.
        if isinstance(inputs, dict):
            inputs = [
                inputs[input_name] for input_name in model._feed_input_names
            ]

        _build_model(strategy, model, mode, inputs, targets)

        (
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        ) = strategy.extended.call_for_each_replica(
            _per_replica_execution_function,
            args=(dist_utils.get_distributed_model(model, mode), mode),
        )
        (
            all_inputs,
            all_outputs,
            all_updates,
            all_session_args,
        ) = dist_utils.unwrap_values(
            strategy,
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        )
        combined_fn = backend.function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name="distributed_" + str(mode) + "_function",
            **all_session_args
        )

        for label, output in zip(output_labels, combined_fn.outputs):
            if label == "loss":
                reduce_op = tf.distribute.ReduceOp.SUM
            else:
                # We reduce all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                reduce_op = tf.distribute.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)
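
        # For illustration, with hypothetical per-replica values: losses of
        # 0.5 and 0.7 on two replicas combine under ReduceOp.SUM to 1.2,
        # while an accuracy of 0.8 and 0.9 combines under ReduceOp.MEAN to
        # 0.85.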

        # TODO(priyag, sourabhbajaj): Ignoring these things from the
        # combined_fn: feed_dict, session kwargs, run options, run_metadata
        # for now. These should be handled appropriately.
        return combined_fn.updates_op

    return _step_fn
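

# A minimal sketch of how the step function is consumed, assuming `model`,
# `strategy`, `labels`, `iterator` and `steps_per_run` are built as in
# `experimental_tpu_fit_loop` below (illustrative only, not executed here):
#
#     step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, strategy, labels)
#     ctx = strategy.extended.experimental_run_steps_on_iterator(
#         step_fn,
#         iterator,
#         iterations=steps_per_run,
#         initial_loop_values={"loss": tf.constant(1e7)},
#     )
#     train_op, output_tensors = ctx.run_op, ctx.last_step_outputs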


def experimental_tpu_fit_loop(
    model,
    dataset,
    epochs=100,
    verbose=1,
    callbacks=None,
    initial_epoch=0,
    steps_per_epoch=None,
    val_dataset=None,
    validation_steps=None,
    validation_freq=1,
):
    """Fit loop for training with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset that returns inputs and targets.
        epochs: Number of times to iterate over the data.
        verbose: Integer, Verbosity mode, 0, 1 or 2.
        callbacks: List of callbacks to be called during training.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run).
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Ignored with the default value of `None`.
        val_dataset: Dataset for validation data.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Ignored with the default value of `None`.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
            If an integer, specifies how many training epochs to run before a
            new validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.

    Returns:
        The model's training `History` object.

    Raises:
        ValueError: in case of invalid arguments.
    """
    mode = ModeKeys.TRAIN

    current_strategy = model._distribution_strategy
    iteration_value = min(
        steps_per_epoch, current_strategy.extended.steps_per_run
    )
    steps_per_run = backend.variable(
        value=iteration_value, dtype="int32", name="steps_per_run"
    )

    # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=1
    )
    scope.__enter__()

    out_labels = model.metrics_names or []

    step_fn = _make_train_step_fn(
        model, ModeKeys.TRAIN, current_strategy, out_labels
    )

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {}
    initial_loop_values["loss"] = tf.constant(1e7)
    for m in model._get_training_eval_metrics():
        tensor = m.result()
        initial_loop_values[m.name] = tf.zeros(tensor.shape, tensor.dtype)

    ctx = current_strategy.extended.experimental_run_steps_on_iterator(
        step_fn,
        iterator,
        iterations=steps_per_run,
        initial_loop_values=initial_loop_values,
    )
    train_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    do_validation = bool(validation_steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=verbose,
        count_mode="steps",
        mode=mode,
    )

    # Calculate the steps each time on the device.
    steps_to_run = [current_strategy.extended.steps_per_run] * (
        steps_per_epoch // current_strategy.extended.steps_per_run
    )
    if steps_per_epoch % current_strategy.extended.steps_per_run:
        steps_to_run.append(
            steps_per_epoch % current_strategy.extended.steps_per_run
        )
    target_steps = len(steps_to_run)
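
    # For illustration, with hypothetical values: steps_per_epoch=100 and
    # steps_per_run=8 give steps_to_run = [8] * 12 + [4], i.e. twelve runs of
    # 8 steps plus a final run of 4 steps, so target_steps is 13.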

    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode
    )

    for epoch in range(initial_epoch, epochs):
        dist_utils._reset_metrics(model)
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        step_index = 0
        prev_step_count = None
        current_step = 0
        while current_step < target_steps:
            step_count = steps_to_run[current_step]
            batch_logs = {
                "batch": step_index,
                "size": 1,
                "num_steps": step_count,
            }
            callbacks._call_batch_hook(mode, "begin", step_index, batch_logs)
            if prev_step_count is None or step_count != prev_step_count:
                backend.get_session().run(steps_per_run.assign(step_count))
                prev_step_count = step_count
            try:
                _, outputs = backend.batch_get_value([train_op, output_tensors])
            except tf.errors.OutOfRangeError:
                logging.warning(
                    "Your dataset iterator ran out of data; "
                    "interrupting training. Make sure that your dataset "
                    "can generate at least `steps_per_epoch * epochs` "
                    "batches (in this case, %d batches)."
                    % (steps_per_epoch * epochs)
                )
                break

            batch_logs.update(outputs)
            callbacks._call_batch_hook(mode, "end", step_index, batch_logs)
            step_index = step_index + step_count
            current_step += 1

            if callbacks.model.stop_training:
                break

        if do_validation and training_utils_v1.should_run_validation(
            validation_freq, epoch
        ):
            logging.info("Running validation at fit epoch: %s", epoch)

            if model._compile_distribution:
                # Since we create a new clone from the original model we need
                # to copy the weights back to the original model before we
                # can run validation.
                dist_utils._copy_weights_to_original_model(
                    model, ModeKeys.TRAIN
                )

            val_outs = experimental_tpu_test_loop(
                model,
                val_dataset,
                steps=validation_steps,
                verbose=verbose,
                callbacks=callbacks,
            )
            if not isinstance(val_outs, list):
                val_outs = [val_outs]
            # Same labels assumed.
            for label, val_out in zip(out_labels, val_outs):
                epoch_logs["val_" + label] = val_out

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._compile_distribution:
        # Copy the weights back from the replicated model to the original
        # model.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
    scope.__exit__(None, None, None)
    return model.history


def experimental_tpu_test_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Test loop for evaluating with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring predictions finished.
            Ignored with the default value of `None`.
        callbacks: List of callbacks to be called during evaluation.

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the outputs.
    """
    mode = ModeKeys.TEST
    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    )
    scope.__enter__()

    out_labels = model.metrics_names

    def _test_step_fn(inputs):
        """A fn that returns output of single test step."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        (
            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs, targets)
            )
        )

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode
        )
        with tf.control_dependencies([updates]):
            return [tf.identity(out) for out in outputs]

    test_input_data = iterator.get_next()
    per_replica_outputs = current_strategy.run(
        _test_step_fn, args=(test_input_data,)
    )
    output_tensors = {}
    for label, output in zip(out_labels, per_replica_outputs):
        if label == "loss":
            reduce_op = tf.distribute.ReduceOp.SUM
        else:
            # We reduce all other metrics using mean for now. This is a
            # temporary workaround until new metrics are in place.
            reduce_op = tf.distribute.ReduceOp.MEAN
        output_tensors[label] = current_strategy.reduce(
            reduce_op, output, axis=None
        )
    test_op = tf.group(list(output_tensors.values()))

    if verbose >= 1:
        progbar = Progbar(target=steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    dist_utils._reset_metrics(model)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=False,
        epochs=1,
        steps_per_epoch=steps,
        verbose=verbose,
        count_mode="steps",
        mode=ModeKeys.TEST,
    )
    callbacks._call_begin_hook(mode)

    outs = [0.0] * len(model.metrics_names)
    if steps is not None:
        target_steps = steps
    else:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )

    current_step = 0
    while current_step < target_steps:
        batch_logs = {"batch": current_step, "size": 1}
        callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
        try:
            _, batch_outs = backend.batch_get_value([test_op, output_tensors])
        except tf.errors.OutOfRangeError:
            warning_msg = (
                "Make sure that your dataset can generate at least "
                "`steps` batches (in this case, {} batches).".format(steps)
            )

            logging.warning(
                "Your dataset iterator ran out of data; "
                "interrupting evaluation. " + warning_msg
            )
            target_steps = current_step
            break
        for i, label in enumerate(model.metrics_names):
            if i == 0:
                # Loss is a stateless metric, so accumulate it manually.
                outs[i] += batch_outs[label]
            else:
                # For all stateful metrics, the aggregation is handled by
                # mirrored vars.
                outs[i] = batch_outs[label]
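
        # For illustration, with hypothetical values: over target_steps=3
        # batches with per-batch losses 0.9, 0.6 and 0.3, outs[0] accumulates
        # to 1.8 and is divided by target_steps after the loop, giving a mean
        # loss of 0.6; stateful metrics are simply read each step.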

        batch_logs = callbacks.make_logs(model, batch_logs, outs, mode)
        callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
        if verbose == 1:
            progbar.update(current_step + 1)
        current_step += 1

    if verbose >= 1:
        # Progress bar finishes at the end.
        progbar.update(target_steps)
    callbacks._call_end_hook(mode)

    scope.__exit__(None, None, None)
    if len(outs) > 0:
        outs[0] /= target_steps

    if len(outs) == 1:
        return outs[0]
    return outs


def experimental_tpu_predict_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Predict loop for predicting with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring `_predict_loop` finished.
            Ignored with the default value of `None`.
        callbacks: List of callbacks to be called during prediction.

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).
    """
    mode = ModeKeys.PREDICT
    dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
    padding_handler = None
    if not dataset_fully_shaped:
        # TODO(hongjunchoi): Investigate whether operations from
        # PartialBatchPaddingHandler are unnecessarily pruned out
        # during graph optimization.
        padding_handler = padding_util.PartialBatchPaddingHandler(
            model._feed_output_shapes
        )
        batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(
            dataset
        )
        padding_handler.padded_batch_size = batch_size
        padding_handler.padding_mask = dataset.reduce(
            padding_handler.padding_mask, padding_handler.update_mask
        )

        dataset = dataset.map(padding_handler.pad_batch)
        dataset = dataset.unbatch()
        # At this point, it is guaranteed that the dataset does not
        # have partial batches. Thus, we set `drop_remainder=True` to
        # get static shape information about the elements in the dataset.
        dataset = dataset.batch(batch_size, drop_remainder=True)

        if prefetch_buffer is not None:
            dataset = dataset.prefetch(prefetch_buffer)
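
    # For illustration, with hypothetical values: with batch_size=8 and a
    # final partial batch of 5 examples, the padding handler pads that batch
    # to 8, records a mask marking the 3 padded rows, and `apply_mask` strips
    # the corresponding padded predictions before this function returns.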

    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    )
    scope.__enter__()

    def _predict_step_fn(inputs):
        """A fn that returns output of single prediction step."""

        (
            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs)
            )
        )

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode
        )

        with tf.control_dependencies([updates]):
            return [tf.identity(out) for out in outputs]

    # TODO(hongjunchoi): When a numpy array is passed as an input to
    # `predict()`, use the numpy arrays directly to avoid accumulating
    # unnecessary input pipeline ops.
    predict_input_data = iterator.get_next()
    per_replica_outputs = current_strategy.run(
        _predict_step_fn, args=(predict_input_data,)
    )
    output_tensors = dist_utils.flatten_per_replica_values(
        current_strategy, per_replica_outputs
    )

    if verbose >= 1:
        progbar = Progbar(target=steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    dist_utils._reset_metrics(model)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=False,
        epochs=1,
        steps_per_epoch=steps,
        verbose=verbose,
        count_mode="steps",
        mode=mode,
    )
    callbacks._call_begin_hook(mode)

    # Since we do not know how many samples we will see, we cannot
    # pre-allocate the returned Numpy arrays. Instead, we store one array per
    # batch seen and concatenate them upon returning.
    num_model_outputs = len(model.output_names)
    unconcatenated_outs = [[] for _ in range(num_model_outputs)]
    if steps is not None:
        target_steps = steps
    else:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )

    current_step = 0
    while current_step < target_steps:
        batch_logs = {"batch": current_step, "size": 1}
        callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
        try:
            predict_ops = tf.group(output_tensors)
            _, batch_outs = backend.batch_get_value(
                [predict_ops, output_tensors]
            )

        except tf.errors.OutOfRangeError:
            warning_msg = (
                "Make sure that your dataset can generate at least "
                "`steps` batches (in this case, {} batches).".format(steps)
            )

            logging.warning(
                "Your dataset iterator ran out of data; "
                "interrupting evaluation. " + warning_msg
            )
            break

        # TODO(priyag): maybe need to unwrap the outputs first for
        # MirroredStrategy.
        for i in range(num_model_outputs):
            output_start_index = i * current_strategy.num_replicas_in_sync
            output_end_index = (
                output_start_index + current_strategy.num_replicas_in_sync
            )
            single_model_output = batch_outs[
                output_start_index:output_end_index
            ]
            unconcatenated_outs[i].extend(single_model_output)
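
        # For illustration, with hypothetical values: with 2 replicas in sync
        # and 2 model outputs, batch_outs holds 4 per-replica arrays ordered
        # [out0_replica0, out0_replica1, out1_replica0, out1_replica1], so
        # output i occupies indices [i * 2, i * 2 + 2).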

        batch_logs = callbacks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
        if verbose == 1:
            progbar.update(current_step + 1)
        current_step += 1

    if verbose >= 1:
        # Progress bar finishes at the end.
        progbar.update(current_step)

    callbacks._call_end_hook(mode)

    scope.__exit__(None, None, None)

    if len(unconcatenated_outs) == 1:
        prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
    else:
        prediction_result = [
            np.concatenate(out, axis=0) for out in unconcatenated_outs
        ]

    if padding_handler:
        prediction_result = padding_handler.apply_mask(prediction_result)

    return prediction_result


class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with single worker."""

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs
    ):
        """Fit loop for Distribution Strategies."""
        dist_utils.validate_callbacks(
            input_callbacks=callbacks, optimizer=model.optimizer
        )
        dist_utils.validate_inputs(x, y)

        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            model._distribution_strategy,
            x,
            batch_size,
            steps_per_epoch,
            ModeKeys.TRAIN,
            validation_split=validation_split,
        )
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle,
            epochs=epochs,
        )
        if not dist_utils.is_distributing_by_cloning(model):
            with model._distribution_strategy.scope():
                (dataset, _, _) = model._standardize_user_data(
                    dataset,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    batch_size=batch_size,
                    validation_split=validation_split,
                    shuffle=shuffle,
                )

        val_dataset = None
        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weights,
            ) = training_utils_v1.unpack_validation_data(validation_data)
            dist_utils.validate_inputs(val_x, val_y)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy,
                val_x,
                batch_size,
                validation_steps,
                ModeKeys.TEST,
            )

            val_dataset = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_sample_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True,
            )
        elif validation_split:
            raise ValueError(
                "validation_split argument is not supported with "
                "distribution strategies."
            )

        if backend.is_tpu_strategy(model._distribution_strategy):
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                dataset,
                steps_per_epoch,
                epochs,
                steps_name="steps_per_epoch",
            )
            if steps_per_epoch is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps_per_epoch argument."
                )

            if not tf.executing_eagerly():
                # Run TPU training in a custom loop in graph mode.
                return experimental_tpu_fit_loop(
                    model,
                    dataset,
                    epochs=epochs,
                    verbose=verbose,
                    callbacks=callbacks,
                    val_dataset=val_dataset,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    validation_freq=validation_freq,
                )

        return training_arrays_v1.fit_loop(
            model,
            dataset,
            batch_size=batch_size,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_inputs=val_dataset,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            steps_name="steps_per_epoch",
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Evaluate loop for Distribution Strategies."""
        dist_utils.validate_inputs(x, y)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True,
        )

        if backend.is_tpu_strategy(model._distribution_strategy):
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )

            if not tf.executing_eagerly():
                # Run TPU evaluation in a custom loop in graph mode.
                return experimental_tpu_test_loop(
                    model,
                    dataset,
                    verbose=verbose,
                    steps=steps,
                    callbacks=callbacks,
                )

        return training_arrays_v1.test_loop(
            model,
            inputs=dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Predict loop for Distribution Strategies."""
        dist_utils.validate_inputs(x=x, y=None)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x, batch_size=batch_size, allow_partial_batch=True
        )
        if backend.is_tpu_strategy(model._distribution_strategy):
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )
            if not tf.executing_eagerly():
                return experimental_tpu_predict_loop(
                    model,
                    dataset,
                    verbose=verbose,
                    steps=steps,
                    callbacks=callbacks,
                )
        return training_arrays_v1.predict_loop(
            model,
            dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )


def _train_with_multi_worker(method):
    """Decorator for multi-worker training with a distribution strategy."""

    def wrapper(model, **kwargs):
        def _worker_fn(_):
            callbacks = kwargs.pop("callbacks", None)
            filtered_callbacks = dist_utils.filter_distributed_callbacks(
                callbacks, model
            )
            kwargs["callbacks"] = filtered_callbacks
            return method(model, **kwargs)

        return dc.run_distribute_coordinator(
            _worker_fn, model._distribution_strategy
        )

    return wrapper


class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with multiple workers."""

    def __init__(self, single_worker_loop):
        self._single_worker_loop = single_worker_loop

    def fit(self, *args, **kwargs):
        return _train_with_multi_worker(self._single_worker_loop.fit)(
            *args, **kwargs
        )

    def evaluate(self, *args, **kwargs):
        return _train_with_multi_worker(self._single_worker_loop.evaluate)(
            *args, **kwargs
        )

    def predict(self, *args, **kwargs):
        # Currently predict is still using the single worker implementation.
        return self._single_worker_loop.predict(*args, **kwargs)