Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_distributed_v1.py: 12%
317 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Part of the Keras training engine related to distributed training."""
16# pylint: disable=protected-access
18import numpy as np
19from tensorflow.python.distribute import distribute_lib
20from tensorflow.python.distribute import input_lib
21from tensorflow.python.distribute import reduce_util as ds_reduce_util
22from tensorflow.python.eager import context
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.keras import backend
27from tensorflow.python.keras import callbacks as cbks
28from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
29from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
30from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
31from tensorflow.python.keras.engine import training_arrays_v1
32from tensorflow.python.keras.engine import training_utils_v1
33from tensorflow.python.keras.utils.generic_utils import Progbar
34from tensorflow.python.keras.utils.mode_keys import ModeKeys
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.platform import tf_logging as logging
40def _per_replica_execution_function(model, mode):
41 exec_func = model._make_execution_function(mode)
42 return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
43 exec_func.session_kwargs)
def _build_model(strategy, model, mode, inputs, targets=None):
  """Builds the distributed counterpart of `model` for `mode`.

  When `model._compile_distribution` is truthy the model is cloned onto the
  replicas; otherwise a distributed network is built for it.

  Args:
    strategy: strategy object forwarded to the `dist_utils` helper.
    model: model whose `_compile_distribution` flag selects the helper.
    mode: mode forwarded to the helper.
    inputs: inputs forwarded to the helper.
    targets: optional targets forwarded to the helper.
  """
  if not model._compile_distribution:
    dist_utils._build_distributed_network(
        model, strategy, mode, inputs, targets)
    return

  dist_utils.clone_model_on_replicas(
      model, strategy, mode, inputs=inputs, targets=targets)
55def _make_train_step_fn(model, mode, strategy, output_labels):
56 """Create step fn.
58 Args:
59 model: a Keras Model instance.
60 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
61 strategy: a `tf.distribute.Strategy` instance.
62 output_labels: the output labels for the step function.
64 Returns:
65 A step function to run by `tf.distribute.Strategy`.
66 """
68 def _step_fn(ctx, inputs):
69 """A step fn that returns update ops."""
70 if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
71 inputs, targets = inputs
72 else:
73 targets = None
75 # When input feature is a dictionary of tensors, dictionary is flattended
76 # to an array and passed as a model input. This results in input mismatch
77 # when model input layer names are not sorted in alphabetical order as
78 # `nest.flatten()`sorts dictionary elements by keys. As so, transform input
79 # tensors into an array and order it along `model._feed_input_names`.
80 if isinstance(inputs, dict):
81 inputs = [inputs[input_name] for input_name in model._feed_input_names]
83 _build_model(strategy, model, mode, inputs, targets)
85 (grouped_inputs, grouped_outputs, grouped_updates,
86 grouped_session_args) = strategy.extended.call_for_each_replica(
87 _per_replica_execution_function,
88 args=(dist_utils.get_distributed_model(model, mode), mode))
89 (all_inputs, all_outputs, all_updates,
90 all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
91 grouped_outputs,
92 grouped_updates,
93 grouped_session_args)
94 combined_fn = backend.function(
95 all_inputs,
96 all_outputs,
97 updates=all_updates,
98 name='distributed_' + str(mode) + '_function',
99 **all_session_args)
101 for label, output in zip(output_labels, combined_fn.outputs):
102 if label == 'loss':
103 reduce_op = ds_reduce_util.ReduceOp.SUM
104 else:
105 # We reduce all other metrics using mean for now. This is temporary
106 # workaround until new metrics are in place.
107 reduce_op = ds_reduce_util.ReduceOp.MEAN
108 ctx.set_last_step_output(label, output, reduce_op)
110 # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
111 # feed_dict, session kwargs, run options, run_metadata for now. These should
112 # be handled appropriately
113 return combined_fn.updates_op
115 return _step_fn
def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset that returns inputs and targets
      epochs: Number of times to iterate over the data
      verbose: Integer, Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Must be an integer here: it is compared with and
          divided by the strategy's `steps_per_run`.
      val_dataset: Dataset for validation data.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Validation is skipped when this is falsy.
      validation_freq: Only relevant if validation data is provided. Integer or
          `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
          integer, specifies how many training epochs to run before a new
          validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
      The `model.history` object.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = backend.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = array_ops.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate the steps each time on the device: as many full
  # `steps_per_run` chunks as fit in an epoch, plus one shorter chunk for the
  # remainder, if any.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      # Only reassign the on-device step counter when the chunk size changes.
      if prev_step_count is None or step_count != prev_step_count:
        backend.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = backend.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        # The batch count is passed as a logging argument. Formatting with
        # `'...' % steps_per_epoch * epochs` would apply `%` first and then
        # repeat the whole message `epochs` times.
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).',
                        steps_per_epoch * epochs)
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history
def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer, Verbosity mode 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring evaluation finished. Required: `None` raises
          `ValueError`.
      callbacks: List of callbacks to be called during evaluation

  Returns:
      A single value when `model.metrics_names` has exactly one entry,
      otherwise a list with one value per entry of `model.metrics_names`.
      The first entry is accumulated over the executed steps and divided by
      their number; the other entries hold the value from the last step.

  Raises:
      ValueError: if `steps` is `None`.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names

  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribute_lib.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  test_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _test_step_fn, args=(test_input_data,))
  output_tensors = {}
  for label, output in zip(out_labels, per_replica_outputs):
    if label == 'loss':
      reduce_op = ds_reduce_util.ReduceOp.SUM
    else:
      # We reduce all other metrics using mean for now. This is temporary
      # workaround until new metrics are in place.
      reduce_op = ds_reduce_util.ReduceOp.MEAN
    output_tensors[label] = current_strategy.reduce(reduce_op, output,
                                                    axis=None)
  test_op = control_flow_ops.group(list(output_tensors.values()))

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = backend.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      # Average only over the steps that actually ran.
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is stateless metrics.
        outs[i] += batch_outs[label]
      else:
        # For all stateful metrics, the aggregation is handled by mirrored vars.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(target_steps)
  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
  # Average the accumulated first entry over the executed steps. Skip this
  # when there are no outputs (the previous `len(outs) >= 0` test was always
  # true and indexed an empty list) or when no step ran (division by zero).
  if outs and target_steps:
    outs[0] /= (target_steps)

  if len(outs) == 1:
    return outs[0]
  return outs
def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer, Verbosity mode 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished. Required: `None` raises
          `ValueError`.
      callbacks: List of callbacks to be called during prediction

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).

  Raises:
      ValueError: if `steps` is `None`.
  """
  mode = ModeKeys.PREDICT
  dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.unbatch()
    # Upon this point, it is guaranteed that the dataset does not
    # have partial batches. Thus, we set `drop_remainder=True` to
    # get static shape information about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)

  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  def _predict_step_fn(inputs):
    """A fn that returns output of single prediction step."""

    (distribute_lib.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  # TODO(hongjunchoi): When numpy array is passed as an input to `predict()`
  # use numpy arrays directly to avoid cumulating unnecessary input pipeline
  # ops.
  predict_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _predict_step_fn, args=(predict_input_data,))
  output_tensors = dist_utils.flatten_per_replica_values(
      current_strategy, per_replica_outputs)

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  num_model_outputs = len(model.output_names)
  unconcatenated_outs = [[] for _ in range(num_model_outputs)]
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  # The grouping op depends only on `output_tensors`, so it is created once
  # here instead of on every step of the loop below.
  predict_ops = control_flow_ops.group(output_tensors)

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = backend.batch_get_value([predict_ops, output_tensors])

    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
    # `batch_outs` is sliced into consecutive runs of `num_replicas_in_sync`
    # values, one run per model output.
    for i in range(num_model_outputs):
      output_start_index = i * current_strategy.num_replicas_in_sync
      output_end_index = (
          output_start_index + current_strategy.num_replicas_in_sync)
      single_model_output = batch_outs[output_start_index:output_end_index]
      unconcatenated_outs[i].extend(single_model_output)

    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(current_step)

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(out, axis=0) for out in unconcatenated_outs
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result
class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Single-worker training loop used with a distribution strategy."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Trains `model` under its distribution strategy.

    Standardizes the training (and optional validation) data into datasets,
    then dispatches to `experimental_tpu_fit_loop` for a TPU strategy outside
    eager execution, or to `training_arrays_v1.fit_loop` otherwise.

    Raises:
      ValueError: if `validation_split` is set without `validation_data`, or
        if `steps_per_epoch` cannot be inferred for a TPU strategy.
    """
    dist_utils.validate_callbacks(
        input_callbacks=callbacks, optimizer=model.optimizer)
    dist_utils.validate_inputs(x, y)

    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        model._distribution_strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    dataset = model._distribution_standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        epochs=epochs)

    cloning = dist_utils.is_distributing_by_cloning(model)
    if not cloning:
      with model._distribution_strategy.scope():
        standardized = model._standardize_user_data(
            dataset,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)
        (dataset, _, _) = standardized

    if validation_data:
      unpacked = training_utils_v1.unpack_validation_data(validation_data)
      val_x, val_y, val_sample_weights = unpacked
      dist_utils.validate_inputs(val_x, val_y)
      _, validation_steps = dist_utils.process_batch_and_step_size(
          model._distribution_strategy, val_x, batch_size, validation_steps,
          ModeKeys.TEST)

      val_dataset = model._distribution_standardize_user_data(
          val_x,
          val_y,
          sample_weight=val_sample_weights,
          class_weight=None,
          batch_size=batch_size,
          validation_split=validation_split,
          shuffle=shuffle,
          allow_partial_batch=True)
    else:
      if validation_split:
        raise ValueError('validation_split argument is not supported with '
                         'distribution strategies.')
      val_dataset = None

    on_tpu = backend.is_tpu_strategy(model._distribution_strategy)
    if on_tpu:
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
      if steps_per_epoch is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps_per_epoch argument.')

    if on_tpu and not context.executing_eagerly():
      # Run TPU training in a custom loop in graph mode.
      return experimental_tpu_fit_loop(
          model,
          dataset,
          epochs=epochs,
          verbose=verbose,
          callbacks=callbacks,
          val_dataset=val_dataset,
          initial_epoch=initial_epoch,
          steps_per_epoch=steps_per_epoch,
          validation_steps=validation_steps,
          validation_freq=validation_freq)

    return training_arrays_v1.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluates `model` under its distribution strategy.

    Dispatches to `experimental_tpu_test_loop` for a TPU strategy outside
    eager execution, or to `training_arrays_v1.test_loop` otherwise.

    Raises:
      ValueError: if `steps` cannot be inferred for a TPU strategy.
    """
    dist_utils.validate_inputs(x, y)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        allow_partial_batch=True)

    on_tpu = backend.is_tpu_strategy(model._distribution_strategy)
    if on_tpu:
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')

    if on_tpu and not context.executing_eagerly():
      # Run TPU evaluation in a custom loop in graph mode.
      return experimental_tpu_test_loop(
          model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs prediction for `model` under its distribution strategy.

    Dispatches to `experimental_tpu_predict_loop` for a TPU strategy outside
    eager execution, or to `training_arrays_v1.predict_loop` otherwise.

    Raises:
      ValueError: if `steps` cannot be inferred for a TPU strategy.
    """
    dist_utils.validate_inputs(x=x, y=None)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x, batch_size=batch_size, allow_partial_batch=True)

    on_tpu = backend.is_tpu_strategy(model._distribution_strategy)
    if on_tpu:
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')

    if on_tpu and not context.executing_eagerly():
      return experimental_tpu_predict_loop(
          model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)
760def _train_with_multi_worker(method):
761 """Decorator that handles multi worker training with distribution strategy."""
763 def wrapper(model, **kwargs):
764 def _worker_fn(_):
765 callbacks = kwargs.pop('callbacks', None)
766 filtered_callbacks = dist_utils.filter_distributed_callbacks(
767 callbacks, model)
768 kwargs['callbacks'] = filtered_callbacks
769 return method(model, **kwargs)
771 return dc.run_distribute_coordinator(
772 _worker_fn,
773 model._distribution_strategy)
775 return wrapper
class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Multi-worker training loop that delegates to a single-worker loop."""

  def __init__(self, single_worker_loop):
    """Stores the single-worker loop whose methods are delegated to."""
    self._single_worker_loop = single_worker_loop

  def fit(self, *args, **kwargs):
    """Runs the single-worker `fit` through the multi-worker decorator."""
    distributed_fit = _train_with_multi_worker(self._single_worker_loop.fit)
    return distributed_fit(*args, **kwargs)

  def evaluate(self, *args, **kwargs):
    """Runs the single-worker `evaluate` through the multi-worker decorator."""
    distributed_evaluate = _train_with_multi_worker(
        self._single_worker_loop.evaluate)
    return distributed_evaluate(*args, **kwargs)

  def predict(self, *args, **kwargs):
    """Calls the single-worker `predict` directly."""
    # Currently predict is still using the single worker implementation.
    single_worker_predict = self._single_worker_loop.predict
    return single_worker_predict(*args, **kwargs)