Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/distribute/distributed_training_utils_v1.py: 15%
436 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to distributed training."""
17import functools
19import numpy as np
20import tensorflow.compat.v2 as tf
22from keras.src import backend
23from keras.src import callbacks
24from keras.src import metrics as metrics_module
25from keras.src import optimizers
26from keras.src.distribute import distribute_coordinator_utils as dc
27from keras.src.distribute import distributed_training_utils as dist_utils
28from keras.src.engine import training_utils_v1
29from keras.src.optimizers.legacy import optimizer_v2
30from keras.src.utils import tf_contextlib
31from keras.src.utils.mode_keys import ModeKeys
33# isort: off
34from tensorflow.python.platform import tf_logging as logging
37def set_weights(distribution_strategy, dist_model, weights):
38 """Sets the weights of the replicated models.
40 The weights of the replicated models are set to the weights of the original
41 model. The weights of the replicated model are Mirrored variables and hence
42 we need to use the `update` call within a DistributionStrategy scope.
44 Args:
45 distribution_strategy: DistributionStrategy used to distribute training
46 and validation.
47 dist_model: The replicated models on the different devices.
48 weights: The weights of the original model.
49 """
50 assign_ops = []
51 for layer in dist_model.layers:
52 num_param = len(layer.weights)
53 layer_weights = weights[:num_param]
54 for sw, w in zip(layer.weights, layer_weights):
55 if tf.compat.v1.executing_eagerly_outside_functions():
56 sw.assign(w)
57 else:
58 assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
59 weights = weights[num_param:]
61 if not tf.compat.v1.executing_eagerly_outside_functions():
62 backend.get_session(assign_ops).run(assign_ops)
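A minimal usage sketch (not part of the original module): in TF2 eager mode, `set_weights` assigns the source model's weights layer by layer onto the mirrored variables of a clone built under the strategy scope. The `MirroredStrategy` and the tiny `tf.keras` model below are illustrative assumptions; `tf` and `set_weights` come from this module.

strategy = tf.distribute.MirroredStrategy()  # one replica per visible device
source = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
with strategy.scope():
    replica_model = tf.keras.models.clone_model(source)  # variables are mirrored
set_weights(strategy, replica_model, source.get_weights())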
65def unwrap_values(
66 distribution_strategy,
67 grouped_inputs,
68 grouped_outputs,
69 grouped_updates=None,
70 grouped_session_args=None,
71 with_loss_tensor=False,
72):
73 """Unwrap the list of values contained in the PerReplica parameters.
75 This function calls `flatten_per_replica_values` to parse each of the input
76 parameters into a list of values on the different devices. If we set
77 `with_loss_tensor` to be True, we also call `reduce` on the list of losses
78 on the different devices to give us one loss tensor.
80 Args:
81 distribution_strategy: DistributionStrategy used to distribute training
82 and validation.
83 grouped_inputs: PerReplica inputs returned from the train or test function
84 that we ran on each device.
85 grouped_outputs: PerReplica outputs returned from the train or test
86 function that we ran on each device.
87 grouped_updates: PerReplica updates returned from the train or test
88 function that we ran on each device.
89 grouped_session_args: PerReplica session args returned from the train or
90 test function that we ran on each device.
91 with_loss_tensor: Boolean that indicates if we need to add the reduced
92 loss tensor as one of the outputs.
94 Returns:
95 Values of each of the PerReplica parameters.
97 """
98 # Unwrap per device values returned from each model's train function.
99 # This will be used to construct the main train function.
100 all_inputs = flatten_per_replica_values(
101 distribution_strategy, grouped_inputs
102 )
103 all_outputs = unwrap_outputs(
104 distribution_strategy, grouped_outputs, with_loss_tensor
105 )
107 if grouped_updates:
108 all_updates = flatten_per_replica_values(
109 distribution_strategy, grouped_updates
110 )
111 else:
112 all_updates = None
114 all_session_args = {}
115 if grouped_session_args:
116 grouped_feed_dict = grouped_session_args.get("feed_dict")
117 if grouped_feed_dict:
118 all_session_args["feed_dict"] = flatten_per_replica_values(
119 distribution_strategy, grouped_feed_dict
120 )
122 grouped_fetches = grouped_session_args.get("fetches")
123 if grouped_fetches:
124 all_session_args["fetches"] = flatten_per_replica_values(
125 distribution_strategy, grouped_fetches
126 )
128 # TODO(priyag): Return only non empty/None values
129 return all_inputs, all_outputs, all_updates, all_session_args
132def unwrap_output_dict(strategy, grouped_outputs, mode):
133 """Unwrap the list of outputs contained in the PerReplica parameters."""
134 if mode == ModeKeys.PREDICT:
135 return flatten_per_replica_values(strategy, grouped_outputs)
137 # In the case of fit/eval, `grouped_outputs` is a dict, whereas in
138 # predict the output has the same structure as the model output. They
139 # need to be treated differently.
140 total_loss = strategy.reduce(
141 tf.distribute.ReduceOp.SUM, grouped_outputs["total_loss"][0], axis=None
142 )
143 output_losses = flatten_per_replica_values(
144 strategy, grouped_outputs["output_losses"]
145 )
146 metrics = flatten_per_replica_values(strategy, grouped_outputs["metrics"])
147 batch_size = strategy.reduce(
148 tf.distribute.ReduceOp.SUM, grouped_outputs["batch_size"], axis=None
149 )
150 if (
151 backend.is_tpu_strategy(strategy)
152 and tf.compat.v1.executing_eagerly_outside_functions()
153 ):
154 # Choose 1 value per replica in the TPU case since all replicas produce
155 # the same output.
156 # We only do this in eager mode for now, since this function is used in
157 # both graph and eager mode; the graph code path currently doesn't use
158 # experimental_run, so this special-casing will need to be removed once
159 # we converge the graph code path as well.
160 output_losses = output_losses[:: strategy.num_replicas_in_sync]
161 metrics = metrics[:: strategy.num_replicas_in_sync]
162 return {
163 "total_loss": [total_loss],
164 "output_losses": output_losses,
165 "metrics": metrics,
166 "batch_size": batch_size,
167 }
170def unwrap_outputs(
171 distribution_strategy, grouped_outputs, with_loss_tensor=False
172):
173 """Unwrap the list of outputs contained in the PerReplica parameters.
175 This function calls `flatten_per_replica_values` to parse each of the input
176 parameters into a list of outputs on the different devices. If we set
177 `with_loss_tensor` to be True, we also call `reduce` on the list of losses
178 on the different devices to give us one loss tensor.
180 Args:
181 distribution_strategy: DistributionStrategy used to distribute training
182 and validation.
183 grouped_outputs: PerReplica outputs returned from the train or test
184 function that we ran on each device.
185 with_loss_tensor: Boolean that indicates if we need to add the reduced
186 loss tensor as one of the outputs.
188 Returns:
189 Values of each of the PerReplica outputs.
191 """
192 if not with_loss_tensor:
193 return flatten_per_replica_values(
194 distribution_strategy, grouped_outputs
195 )
197 if not isinstance(grouped_outputs, list):
198 grouped_outputs = [grouped_outputs]
199 # reduce loss tensor before adding it to the list of fetches
200 loss = distribution_strategy.reduce(
201 tf.distribute.ReduceOp.SUM, grouped_outputs[0], axis=None
202 )
203 all_outputs = flatten_per_replica_values(
204 distribution_strategy, grouped_outputs[1:]
205 )
206 if (
207 backend.is_tpu_strategy(distribution_strategy)
208 and tf.compat.v1.executing_eagerly_outside_functions()
209 ):
210 # Choose 1 value per replica in the TPU case since all replicas produce
211 # the same output.
212 # We only do this in eager mode for now, since this function is used in
213 # both graph and eager mode; the graph code path currently doesn't use
214 # experimental_run, so this special-casing will need to be removed once
215 # we converge the graph code path as well.
216 all_outputs = all_outputs[:: distribution_strategy.num_replicas_in_sync]
217 return [loss] + all_outputs
220def flatten_per_replica_values(distribution_strategy, per_replica_values):
221 """Unwraps and flattens a nest of PerReplica parameters.
223 PerReplica values have one value associated with each device. Each entry in
224 the PerReplica dict has a device `key` and the corresponding value on the
225 device as the `value`. In this function we take a PerReplica value or a list
226 of PerReplica values and return all the values in the PerReplica dict.
228 Args:
229 distribution_strategy: DistributionStrategy used to distribute training
230 and validation.
231 per_replica_values: List of PerReplica object or a single PerReplica
232 object.
234 Returns:
235 List of values of all the PerReplica objects.
237 """
239 # This function takes a PerReplica object or a list of PerReplica objects
240 # and returns all the values associated with it.
241 return [
242 e
243 for flattened in tf.nest.flatten(per_replica_values)
244 for e in distribution_strategy.unwrap(flattened)
245 ]
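For illustration (a sketch, not from the original file): with a strategy that has N replicas in sync, each PerReplica entry in the nest contributes N per-device values to the flattened list. The `MirroredStrategy` and the value function below are assumptions; on a single-device machine this degrades to one replica.

strategy = tf.distribute.MirroredStrategy()
per_replica = strategy.experimental_distribute_values_from_function(
    lambda ctx: tf.constant(float(ctx.replica_id_in_sync_group))
)
flat = flatten_per_replica_values(strategy, {"x": per_replica, "y": per_replica})
assert len(flat) == 2 * strategy.num_replicas_in_sync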
248def validate_callbacks(input_callbacks, optimizer):
249 """Validate whether given callbacks are supported by DistributionStrategy.
251 Args:
252 input_callbacks: List of callbacks passed by the user to fit.
253 optimizer: Optimizer instance used to train the model.
255 Raises:
256 ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is passed
257 and the optimizer is not a Keras Optimizer V2. (A TensorBoard callback
258 with `write_grads` enabled does not raise; it only triggers a warning
259 and has the flag reset to `False`.)
260 """
261 if input_callbacks:
262 for callback in input_callbacks:
263 if isinstance(
264 callback,
265 (callbacks.LearningRateScheduler, callbacks.ReduceLROnPlateau),
266 ):
268 if not isinstance(optimizer, optimizer_v2.OptimizerV2):
269 raise ValueError(
270 "You must specify a Keras Optimizer V2 when using "
271 "%s callback with DistributionStrategy." % callback
272 )
274 # If users want to use the TensorBoard callback they cannot use
275 # certain features of the callback that involve accessing model
276 # attributes and running ops.
277 if isinstance(callback, callbacks.TensorBoard):
278 if getattr(callback, "write_grads", False):
279 logging.warning(
280 UserWarning(
281 "`write_grads` in the TensorBoard callback is not "
282 "supported when using DistributionStrategy. "
283 "Setting `write_grads` to `False`."
284 )
285 )
286 callback.write_grads = False
289def validate_distributed_dataset_inputs(
290 distribution_strategy, x, y, sample_weights=None
291):
292 """Validate all the components of a DistributedValue Dataset input.
294 Args:
295 distribution_strategy: The current DistributionStrategy used to call
296 `fit`/`evaluate`.
297 x: Input Dataset DistributedValue object. For example, when we use
298 `MirroredStrategy` this is a PerReplica object with a tensor for each
299 device set in the dict. x can also be a tuple or dict. The keys of the
300 dict should match the names of the input layers of the model.
301 y: Target Dataset DistributedValue object. For example, when we use
302 `MirroredStrategy` this is a PerReplica object with a tensor for each
303 device set in the dict. y can also be a tuple or dict. The keys of the
304 dict should match the names of the output layers of the model.
305 sample_weights: Sample weights Dataset DistributedValue object. For
306 example, when we use `MirroredStrategy` this is a PerReplica object
307 with a tensor for each device set in the dict.
309 Returns:
310 The unwrapped values list of the x and y DistributedValues inputs.
312 Raises:
313 ValueError: If x and y do not have support for being evaluated as
314 tensors, if they contain elements that are not tensors, or if their
315 elements have a shape or dtype mismatch.
316 """
317 # If the input and target used to call the model are not dataset tensors,
318 # we need to raise an error. When using a DistributionStrategy, the input
319 # and targets to a model should be from a `tf.data.Dataset`.
321 # If each element of x and y are not tensors, we cannot standardize and
322 # validate the input and targets.
323 x_values_list = validate_per_replica_inputs(distribution_strategy, x)
325 if y is not None:
326 y_values_list = validate_per_replica_inputs(distribution_strategy, y)
327 else:
328 y_values_list = None
330 if sample_weights is not None:
331 sample_weights_list = validate_per_replica_inputs(
332 distribution_strategy, sample_weights
333 )
334 else:
335 sample_weights_list = None
337 # Return the unwrapped values to avoid calling `unwrap` a second time.
338 return x_values_list, y_values_list, sample_weights_list
341def validate_per_replica_inputs(distribution_strategy, x):
342 """Validates PerReplica dataset input list.
344 Args:
345 distribution_strategy: The current DistributionStrategy used to call
346 `fit`, `evaluate` and `predict`.
347 x: A list of PerReplica objects that represent the input or
348 target values.
350 Returns:
351 List containing the first element of each of the PerReplica objects in
352 the input list.
354 Raises:
355 ValueError: If any of the objects in the `per_replica_list` is not a
356 tensor.
358 """
359 # Convert the inputs and targets into a list of PerReplica objects.
360 per_replica_list = tf.nest.flatten(x)
361 x_values_list = []
362 for x in per_replica_list:
363 # At this point x should contain only tensors.
364 x_values = distribution_strategy.unwrap(x)
365 for value in x_values:
366 if not tf.is_tensor(value):
367 raise ValueError(
368 "Dataset input to the model should be tensors instead "
369 "they are of type {}".format(type(value))
370 )
372 if not tf.executing_eagerly():
373 # Validate that the shape and dtype of all the elements in x are the
374 # same.
375 validate_all_tensor_shapes(x, x_values)
376 validate_all_tensor_types(x, x_values)
378 x_values_list.append(x_values[0])
379 return x_values_list
382def validate_all_tensor_types(x, x_values):
383 x_dtype = x_values[0].dtype
384 for i in range(1, len(x_values)):
385 if x_dtype != x_values[i].dtype:
386 raise ValueError(
387 "Input tensor dtypes do not match for distributed tensor"
388 " inputs {}".format(x)
389 )
392def validate_all_tensor_shapes(x, x_values):
393 # Validate that all the elements in x have the same shape.
394 x_shape = x_values[0].shape.as_list()
395 for i in range(1, len(x_values)):
396 if x_shape != x_values[i].shape.as_list():
397 raise ValueError(
398 "Input tensor shapes do not match for distributed tensor"
399 " inputs {}".format(x)
400 )
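A quick illustrative check of the two validators above (the tensors are made-up examples): mismatched shapes or dtypes across the per-replica components raise the corresponding ValueError.

try:
    validate_all_tensor_shapes("inputs", [tf.zeros((2, 3)), tf.zeros((2, 4))])
except ValueError as e:
    print(e)  # Input tensor shapes do not match ...

try:
    validate_all_tensor_types("inputs", [tf.zeros((2,)), tf.zeros((2,), tf.int32)])
except ValueError as e:
    print(e)  # Input tensor dtypes do not match ...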
403def _wait_for_variable_initialization(session):
404 """Utility to wait for variables to be initialized."""
405 all_variables = backend._get_variables(backend.get_graph())
406 candidate_vars = []
407 for v in all_variables:
408 if not getattr(v, "_keras_initialized", False):
409 candidate_vars.append(v)
411 if not candidate_vars:
412 return
414 while True:
415 is_initialized = session.run(
416 [tf.compat.v1.is_variable_initialized(v) for v in candidate_vars]
417 )
418 uninitialized_vars = []
419 for flag, v in zip(is_initialized, candidate_vars):
420 if not flag:
421 uninitialized_vars.append(v)
422 v._keras_initialized = True
423 if not uninitialized_vars:
424 break
427def init_restore_or_wait_for_variables():
428 """Initialize or restore variables or wait for variables to be
429 initialized."""
430 backend._initialize_variables(backend._get_session())
433def validate_inputs(x, y):
434 """Validate inputs when using DistributionStrategy.
436 Args:
437 x: Model Inputs.
438 y: Model Targets.
440 Raises:
441 ValueError: If the input is a `tf.compat.v1.data.Iterator`; inputs must
442 be a `tf.data.Dataset` or a numpy array (when we use MirroredStrategy).
443 """
444 if isinstance(x, tf.compat.v1.data.Iterator) or isinstance(
445 y, tf.compat.v1.data.Iterator
446 ):
447 raise ValueError(
448 "`DistributionStrategy` does not support inputs of type "
449 "Iterator. You must pass a `tf.data.Dataset` object or a "
450 "numpy array as input."
451 )
454def is_dataset_shape_fully_defined(dataset):
455 """Returns whether a dataset contains a final partial batch."""
456 shapes = tf.nest.flatten(tf.compat.v1.data.get_output_shapes(dataset))
457 unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
458 return not unknown_shapes
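A small sketch of the shape check (illustrative only): batching without `drop_remainder` leaves the batch dimension unknown, which this helper reports as not fully defined.

ds = tf.data.Dataset.range(10).batch(3)
print(is_dataset_shape_fully_defined(ds))   # False -> a final partial batch is possible
ds = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
print(is_dataset_shape_fully_defined(ds))   # True  -> every batch has a static shape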
461def process_batch_and_step_size(
462 strategy, inputs, batch_size, steps_per_epoch, mode, validation_split=0.0
463):
464 """Process the batch size and step size based on input and dist strategy."""
465 first_x_value = tf.nest.flatten(inputs)[0]
466 if isinstance(first_x_value, np.ndarray):
467 num_samples = first_x_value.shape[0]
468 if validation_split and 0.0 < validation_split < 1.0:
469 num_samples = int(num_samples * (1 - validation_split))
470 # Until support for partial batch is implemented across all
471 # functions and distribution strategy, we pass `mode` to selectively
472 # relax the constraint to consume all the training samples.
473 steps_per_epoch, batch_size = get_input_params(
474 strategy, num_samples, steps_per_epoch, batch_size, mode=mode
475 )
476 return batch_size, steps_per_epoch
479def get_input_params(
480 distribution_strategy, num_samples, steps, batch_size, mode=None
481):
482 """Calculate the number of batches and steps/steps_per_epoch.
484 Args:
485 distribution_strategy: The DistributionStrategy used to compile the model.
486 num_samples: The number of samples from which we determine the batch size
487 and steps.
488 steps: The specified number of steps.
489 batch_size: The specified batch_size.
490 mode: ModeKey representing whether input will be used for training,
491 evaluation, or prediction. This is used to relax the constraints on
492 consuming all the training samples to keep compatibility till we support
493 partial batches. If None, then partial batches are not allowed.
495 Returns:
496 steps: The steps or steps_per_epoch argument, depending on whether the
497 user is calling `fit`, `evaluate` or `predict`. When training, we don't
498 require the number of samples to be consumed completely.
499 batch_size: The batch size to be used in model iterations.
501 Raises:
502 ValueError: If the number of batches or steps evaluates to 0.
504 """
505 # TODO(b/118776054): Use global batch size for Keras/DS support.
506 # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
507 use_per_replica_batch = not dist_utils.global_batch_size_supported(
508 distribution_strategy
509 )
511 # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except
512 # for `fit()` on TPUStrategy.
513 # In graph mode, the zero batch case in batch norm is not handled due to
514 # XLA-GPU regression. Uneven batch sizes are not allowed except
515 # for `test()` and `predict()` on TPUStrategy.
516 if tf.executing_eagerly():
517 allow_partial_batch = (
518 mode != ModeKeys.TRAIN
519 or not backend.is_tpu_strategy(distribution_strategy)
520 )
521 else:
522 allow_partial_batch = mode == ModeKeys.TRAIN or (
523 (mode == ModeKeys.PREDICT or mode == ModeKeys.TEST)
524 and backend.is_tpu_strategy(distribution_strategy)
525 )
527 if steps is None:
528 if batch_size is None:
529 # If neither the batch size nor the number of steps is set, we choose
530 # the global batch size as the minimum of the number of samples and 32;
531 # 32 is chosen to provide backward compatibility.
532 global_batch_size = min(num_samples, 32)
533 else:
534 # If the user provided the batch size, we need to handle the difference
535 # between strategies that use the global batch size and those that use
536 # the per-replica batch size.
537 global_batch_size = batch_size
538 if use_per_replica_batch:
539 global_batch_size *= distribution_strategy.num_replicas_in_sync
540 if allow_partial_batch:
541 steps = np.ceil(num_samples / global_batch_size).astype(int)
542 else:
543 if num_samples % global_batch_size:
544 raise ValueError(
545 "The number of samples %s is not divisible by "
546 "batch size %s." % (num_samples, global_batch_size)
547 )
548 steps = num_samples // global_batch_size
549 else:
550 if batch_size is None:
551 # We calculate the batch size based on the number of steps specified
552 if num_samples % steps:
553 raise ValueError(
554 "The number of samples %s is not divisible by "
555 "steps %s. Please change the number of steps to a "
556 "value that can consume all the samples"
557 % (num_samples, steps)
558 )
559 global_batch_size = num_samples // steps
560 else:
561 # If the user provided the batch size, we need to handle the difference
562 # between strategies that use the global batch size and those that use
563 # the per-replica batch size.
564 global_batch_size = batch_size
565 if use_per_replica_batch:
566 global_batch_size *= distribution_strategy.num_replicas_in_sync
568 min_num_samples = global_batch_size * steps
569 if allow_partial_batch:
570 min_num_samples = (
571 global_batch_size * (steps - 1) + 1 if steps > 1 else 0
572 )
574 if num_samples < min_num_samples:
575 raise ValueError(
576 "Number of samples %s is less than samples required "
577 "for specified batch_size %s and steps %s"
578 % (num_samples, global_batch_size, steps)
579 )
581 # We need to return the per replica or global batch size based on the
582 # strategy
583 if use_per_replica_batch:
584 if global_batch_size % distribution_strategy.num_replicas_in_sync:
585 raise ValueError(
586 "The batch size (%s) could not be sharded evenly across the "
587 "sync replicas (%s) in the distribution strategy."
588 % (
589 global_batch_size,
590 distribution_strategy.num_replicas_in_sync,
591 )
592 )
593 batch_size = (
594 global_batch_size // distribution_strategy.num_replicas_in_sync
595 )
596 else:
597 batch_size = global_batch_size
599 return steps, batch_size
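A worked example of the arithmetic above (the values and the single-device `OneDeviceStrategy` are illustrative assumptions): 1000 samples with `batch_size=50` yield 20 steps, and the per-replica batch size equals the global one because only one replica is in sync.

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
steps, batch_size = get_input_params(
    strategy, num_samples=1000, steps=None, batch_size=50, mode=ModeKeys.TRAIN
)
# Expected: steps == 20, batch_size == 50.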
602def get_batch_dimension(iterator):
603 shapes = tf.nest.flatten(tf.compat.v1.data.get_output_shapes(iterator))
604 # Take the batch size from the first element, as it should be the same for
605 # all.
606 dims = shapes[0].dims
607 return dims[0] if dims else None
610def get_iterator(dataset, distribution_strategy):
611 with distribution_strategy.scope():
612 iterator = distribution_strategy.make_dataset_iterator(dataset)
613 initialize_iterator(iterator, distribution_strategy)
614 return iterator
617def initialize_iterator(iterator, distribution_strategy):
618 with distribution_strategy.scope():
619 init_op = tf.group(iterator.initializer)
620 if not tf.executing_eagerly():
621 backend.get_session((init_op,)).run(init_op)
624def _get_input_from_iterator(iterator, model):
625 """Get elements from the iterator and verify the input shape and type."""
626 next_element = iterator.get_next()
628 # `len(nest.flatten(x))` does not count empty elements such as {}.
629 # len(nest.flatten([[0,1,2], {}])) is 3 and not 4. The `next_element` is
630 # going to get flattened in `_prepare_feed_values` to work around that.
631 # Empty elements are going to get filtered out as part of the flattening.
632 if len(tf.nest.flatten(next_element)) == len(model.inputs):
633 x = next_element
634 y = None
635 sample_weights = None
636 elif len(tf.nest.flatten(next_element)) == (
637 len(model.inputs) + len(model.outputs)
638 ):
639 x, y = next_element
640 sample_weights = None
641 else:
642 x, y, sample_weights = next_element
644 # Validate that all the elements in x and y are of the same type and shape.
645 validate_distributed_dataset_inputs(
646 model._distribution_strategy, x, y, sample_weights
647 )
648 return x, y, sample_weights
651def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
652 """Prepare feed values to the model execution function.
654 Args:
655 model: Model to prepare feed values for.
656 inputs: List or dict of model inputs.
657 targets: Optional list of model targets.
658 sample_weights: Optional list of sample weight arrays.
659 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
661 Returns:
662 Feed values for the model in the given mode.
663 """
664 strategy = model._distribution_strategy
665 inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
666 if backend.is_tpu_strategy(strategy):
667 if sample_weights is not None:
668 raise ValueError("TPUStrategy does not support sample weights.")
670 # When the inputs are a dict, we want to flatten them in the same order
671 # as the input layers, so that the data is fed into the input layers in
672 # the correct order.
673 if isinstance(inputs, dict):
674 inputs = [inputs[key] for key in model._feed_input_names]
675 if is_distributing_by_cloning(model):
676 inputs = flatten_per_replica_values(strategy, inputs)
677 targets = flatten_per_replica_values(strategy, targets)
678 # Expand 1-dimensional inputs.
679 # TODO(b/124535720): Remove once this standardize-data logic is shared
680 # with main flow.
681 inputs, targets = tf.nest.map_structure(
682 training_utils_v1.standardize_single_array, (inputs, targets)
683 )
684 else:
685 inputs = training_utils_v1.ModelInputs(inputs).as_list()
687 if mode == ModeKeys.PREDICT:
688 sample_weights = []
689 targets = []
690 elif sample_weights is not None and is_distributing_by_cloning(model):
691 if tf.executing_eagerly() and not model._compile_distribution:
692 raise NotImplementedError(
693 "`sample_weight` is not supported when using "
694 "tf.distribute.Strategy in eager mode and "
695 "cloning=True."
696 )
697 sample_weights = flatten_per_replica_values(strategy, sample_weights)
699 ins = [inputs, targets, sample_weights]
700 return tuple(ins)
703def is_distributing_by_cloning(model):
704 """Decide whether this model is going to be distributed via cloning.
706 We are going to distribute the model by cloning in graph mode.
708 Args:
709 model: Keras model to distribute.
711 Returns:
712 True if the `model` is going to be distributed using cloning and False
713 otherwise.
714 """
715 if (
716 backend.is_tpu_strategy(model._distribution_strategy)
717 and tf.executing_eagerly
718 ): # b/137580852
719 return False
720 elif tf.compat.v1.executing_eagerly_outside_functions():
721 return bool(model._compile_distribution)
722 return True
725def _custom_compile_for_predict(model):
726 """Custom compile for TPU predict mode."""
727 if not model.built:
728 # Model is not compilable because it does not know its number of inputs
729 # and outputs, nor their shapes and names. We will compile after the
730 # first time the model gets called on training data.
731 return
732 model._is_compiled = True
733 model.total_loss = None
734 model.train_function = None
735 model.test_function = None
736 model.predict_function = None
739def _build_network_on_replica(model, mode, inputs=None, targets=None):
740 """Build an updated model on replicas.
742 We create a new Keras model while sharing the variables from the old graph.
743 Building a new sub-graph is required since the original keras model creates
744 placeholders for the input and the output that are not accessible till we
745 call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.
747 The sharing of weights and layers between the old and the new model
748 guarantees that we're using Strategy variables and that any updates on
749 either model are reflected correctly in callbacks and loop iterations.
751 We need to make sure we share the optimizers between the old and the new
752 model as well so that optimizer state is not lost if the user is running fit
753 multiple times.
755 Args:
756 model: Model to be replicated across Replicas
757 mode: Which of fit/eval/predict is building the distributed network
758 inputs: Input variables to be passed to the model
759 targets: Target tensor to be passed to model.compile
761 Returns:
762 A new model with shared layers with the old model.
763 """
764 # Need to do imports here since we run into a circular dependency error.
765 from keras.src import models
766 from keras.src.engine import sequential
768 # We rely on the internal methods to avoid exposing `share_weights` in
769 # the public API.
770 if isinstance(model, sequential.Sequential):
771 updated_model = models._clone_sequential_model(
772 model, input_tensors=inputs, layer_fn=models.share_weights
773 )
774 else:
775 updated_model = models._clone_functional_model(
776 model, input_tensors=inputs, layer_fn=models.share_weights
777 )
778 # Callable losses added directly to a functional Model need to be added
779 # here.
780 updated_model._callable_losses = model._callable_losses
782 # Recast all low precision outputs back to float32 since we only cast
783 # the inputs to bfloat16 and not the targets. This is done so that we
784 # can preserve precision when calculating the loss value.
785 def _upcast_low_precision_outputs(output):
786 if output.dtype == tf.bfloat16:
787 return tf.cast(output, tf.float32)
788 else:
789 return output
791 updated_model.outputs = [
792 _upcast_low_precision_outputs(o) for o in updated_model.outputs
793 ]
795 if isinstance(targets, tuple):
796 targets = tf.nest.flatten(targets)
798 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case
799 _custom_compile_for_predict(updated_model)
800 else:
801 updated_model.compile(
802 model.optimizer,
803 model.loss,
804 metrics=metrics_module.clone_metrics(model._compile_metrics),
805 loss_weights=model.loss_weights,
806 sample_weight_mode=model.sample_weight_mode,
807 weighted_metrics=metrics_module.clone_metrics(
808 model._compile_weighted_metrics
809 ),
810 target_tensors=targets,
811 )
812 return updated_model
815def _build_distributed_network(
816 model, strategy, mode, inputs=None, targets=None
817):
818 """Create a cloned model on each replica."""
819 with backend.get_graph().as_default(), strategy.scope():
820 distributed_model = strategy.extended.call_for_each_replica(
821 _build_network_on_replica, args=(model, mode, inputs, targets)
822 )
823 set_distributed_model(model, mode, distributed_model)
826def _clone_and_build_model(model, mode, inputs=None, targets=None):
827 """Clone and build the given keras_model."""
828 # We need to set the import here since we run into a circular dependency
829 # error.
830 from keras.src import models
832 cloned_model = models.clone_model(model, input_tensors=inputs)
834 # Compile and build model.
835 if isinstance(model.optimizer, optimizers.TFOptimizer):
836 optimizer = model.optimizer
837 else:
838 optimizer_config = model.optimizer.get_config()
839 optimizer = model.optimizer.__class__.from_config(optimizer_config)
841 # Recast all low precision outputs back to float32 since we only cast
842 # the inputs to bfloat16 and not the targets. This is done so that we
843 # can preserve precision when calculating the loss value.
844 def _upcast_low_precision_outputs(output):
845 if output.dtype == tf.bfloat16:
846 return tf.cast(output, tf.float32)
847 else:
848 return output
850 cloned_model.outputs = [
851 _upcast_low_precision_outputs(o) for o in cloned_model.outputs
852 ]
854 if isinstance(targets, tuple):
855 targets = tf.nest.flatten(targets)
856 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case
857 _custom_compile_for_predict(cloned_model)
858 else:
859 cloned_model.compile(
860 optimizer,
861 model.loss,
862 metrics=metrics_module.clone_metrics(model._compile_metrics),
863 loss_weights=model.loss_weights,
864 sample_weight_mode=model.sample_weight_mode,
865 weighted_metrics=metrics_module.clone_metrics(
866 model._compile_weighted_metrics
867 ),
868 target_tensors=targets,
869 )
870 return cloned_model
873def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
874 """Create a cloned model on each replica."""
875 with backend.get_graph().as_default(), strategy.scope():
876 distributed_model = strategy.extended.call_for_each_replica(
877 _clone_and_build_model, args=(model, mode, inputs, targets)
878 )
879 set_distributed_model(model, mode, distributed_model)
880 if mode == ModeKeys.TRAIN:
881 model._make_callback_model(distributed_model)
884def _make_execution_function(model, mode):
885 """Makes or reuses function to run one step of distributed model
886 execution."""
887 if is_distributing_by_cloning(model):
888 return _make_execution_function_with_cloning(model, mode)
890 distributed_function = get_distributed_function(model, mode)
891 if distributed_function:
892 return distributed_function
894 distribution_function = _make_execution_function_without_cloning(
895 model, mode
896 )
897 set_distributed_function(model, mode, distribution_function)
898 return distribution_function
901def _make_execution_function_without_cloning(model, mode):
902 """Creates a function to run one step of distributed model execution."""
903 strategy = model._distribution_strategy
905 with strategy.scope():
906 per_replica_function = _make_replica_execution_function(model, mode)
908 def distributed_function(input_fn):
909 """A single step of the distributed execution across replicas."""
910 x, y, sample_weights = input_fn()
911 # Call `Model.{train,test,predict}_on_batch` on every replica
912 # passing PerReplicas as arguments. On every replica inside this
913 # call, each PerReplica object will return the value for that
914 # replica. The outputs are PerReplicas too.
915 outputs = strategy.run(
916 per_replica_function, args=(x, y, sample_weights)
917 )
918 # Out of PerReplica outputs reduce or pick values to return.
919 all_outputs = unwrap_outputs(
920 strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT)
921 )
922 return all_outputs
924 if not model.run_eagerly:
925 distributed_function = tf.function(distributed_function)
927 def execution_function(input_fn):
928 # `numpy` translates Tensors to values in Eager mode.
929 return [out.numpy() for out in distributed_function(input_fn)]
931 else:
932 execution_function = distributed_function
934 return execution_function
937def _make_replica_execution_function(model, mode):
938 """A single step of the distributed execution on a replica."""
939 if mode == ModeKeys.TRAIN:
940 func = model.train_on_batch
941 elif mode == ModeKeys.TEST:
942 func = model.test_on_batch
943 else:
945 def predict_on_batch(x, y=None, sample_weights=None):
946 del y, sample_weights
947 return model.predict_on_batch(x)
949 func = predict_on_batch
951 if mode != ModeKeys.PREDICT:
952 # `reset_metrics` is set to False to maintain stateful metrics across
953 # batch-level calls.
954 func = functools.partial(func, reset_metrics=False)
956 return func
959def _make_replicated_models_with_cloning(model, mode):
960 """Build models on each replica."""
961 strategy = model._distribution_strategy
963 # If distributed_model is not built, create one for `mode`.
964 if model._compile_distribution:
965 clone_model_on_replicas(model, strategy, mode)
966 else:
967 _build_distributed_network(model, strategy, mode)
970def _make_execution_function_with_cloning(model, mode):
971 """Clones or re-uses models to run one step of distributed model
972 execution."""
973 distributed_model = get_distributed_model(model, mode)
974 # TODO(b/134069401): Create a cache for the distributed model and exec
975 # function that incorporates additional attributes to be part of the cache
976 # key than just the mode.
977 # If distributed model for a particular `mode` is already built, use the
978 # `_distributed_function` on that distributed model.
979 # If you have updated the sample_weight_mode on the model, then you will
980 # need to recompile metrics and recreate the execution function. This is
981 # indicated by the `_recompile_exec_function` property.
982 if (
983 distributed_model
984 and hasattr(distributed_model, "_distributed_function")
985 and not (
986 hasattr(distributed_model, "_recompile_exec_function")
987 and distributed_model._recompile_exec_function
988 )
989 ):
990 return distributed_model._distributed_function
992 if not distributed_model:
993 _make_replicated_models_with_cloning(model, mode)
994 distributed_model = get_distributed_model(model, mode)
995 assert distributed_model
997 # Also create an execution function on that distributed model.
998 if tf.executing_eagerly():
999 distributed_function = _make_eager_execution_function(model, mode)
1000 else:
1001 distributed_function = _make_graph_execution_function(model, mode)
1003 # We cache the distributed execution function on the model since creating
1004 # distributed models and execution functions are expensive.
1005 distributed_model._distributed_function = distributed_function
1006 distributed_model._recompile_exec_function = False
1007 return distributed_function
1010def _make_graph_execution_function(model, mode):
1011 """Makes function to run one step of distributed model in graph mode."""
1013 def _per_replica_function(model):
1014 f = model._make_execution_function(mode)
1015 return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)
1017 strategy = model._distribution_strategy
1018 with strategy.scope():
1019 # Create train ops on each of the devices when we call
1020 # `_per_replica_fit_function`.
1021 (
1022 grouped_inputs,
1023 grouped_outputs,
1024 grouped_updates,
1025 grouped_session_args,
1026 ) = strategy.extended.call_for_each_replica(
1027 _per_replica_function, args=(get_distributed_model(model, mode),)
1028 )
1030 # Initialize the variables in the replicated model. This is necessary
1031 # for multi-worker training because on some workers, initialization is
1032 # not needed. This method does initialization or waiting for
1033 # initialization according to the context object of distribute
1034 # coordinator.
1035 init_restore_or_wait_for_variables()
1037 # Unwrap all the per device values returned from
1038 # `call_for_each_replica`. Unwrapping per device values gives you a
1039 # list of values that can be used to construct a new train function that
1040 # is composed of update ops on all the devices over which the model is
1041 # distributed.
1042 (
1043 all_inputs,
1044 all_outputs,
1045 all_updates,
1046 all_session_args,
1047 ) = unwrap_values(
1048 strategy,
1049 grouped_inputs,
1050 grouped_outputs,
1051 grouped_updates,
1052 grouped_session_args,
1053 with_loss_tensor=(mode != ModeKeys.PREDICT),
1054 )
1056 return backend.function(
1057 all_inputs,
1058 all_outputs,
1059 updates=all_updates,
1060 name=f"distributed_{mode}_function",
1061 **all_session_args,
1062 )
1065def _make_eager_execution_function(model, mode):
1066 """Makes function to run one step of distributed model eager execution."""
1068 def _per_replica_function(model):
1069 f = model._make_execution_function(mode)
1070 return (f.inputs, f.outputs)
1072 # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of
1073 # using the global one.
1074 strategy = model._distribution_strategy
1075 global_graph = backend.get_graph()
1077 with global_graph.as_default(), strategy.scope():
1078 # First we gather the relevant portions of the model across all
1079 # replicas. `backend._scratch_graph(global_graph)` signals to Keras
1080 # that it should not lift to a separate graph when creating the
1081 # per-replica functions.
1082 with backend._scratch_graph(global_graph):
1083 # Create train ops on each of the devices when we call
1084 # `_per_replica_fit_function`.
1085 grouped = strategy.extended.call_for_each_replica(
1086 _per_replica_function,
1087 args=(get_distributed_model(model, mode),),
1088 )
1089 grouped_inputs, grouped_outputs = grouped
1091 # Unwrap all the per device values returned from
1092 # `call_for_each_replica`. Unwrapping per device values gives you a
1093 # list of values that can be used to construct a new train function
1094 # that is composed of inputs/outputs on all the devices over which
1095 # the model is distributed.
1096 (all_inputs, all_outputs, _, _) = unwrap_values(
1097 strategy,
1098 grouped_inputs,
1099 grouped_outputs,
1100 with_loss_tensor=(mode != ModeKeys.PREDICT),
1101 )
1103 # Finally, a joint Keras function is created; this one will be created
1104 # in a separate FuncGraph.
1105 return backend.function(
1106 all_inputs,
1107 all_outputs,
1108 name=f"eager_distributed_{mode}_function",
1109 )
1112def _copy_weights_to_distributed_model(original_model, mode):
1113 """Copies weights from original model to distributed models."""
1114 strategy = original_model._distribution_strategy
1115 distributed_model = get_distributed_model(original_model, mode)
1116 if strategy:
1117 # Copy the weights from the original model to each of the replicated
1118 # models.
1119 orig_model_weights = original_model.get_weights()
1120 first_model = strategy.unwrap(distributed_model)[0]
1121 set_weights(strategy, first_model, orig_model_weights)
1124def _copy_weights_to_original_model(model, mode):
1125 """Copies weights from first distributed model back to original model."""
1126 if model._distribution_strategy and mode == ModeKeys.TRAIN:
1127 distributed_model = get_distributed_model(model, mode)
1128 updated_weights = model._distribution_strategy.unwrap(
1129 distributed_model
1130 )[0].get_weights()
1131 model.set_weights(updated_weights)
1134def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
1135 """Aggregates the per-replica batch-level outputs from a distributed
1136 step."""
1137 if strategy is not None and mode == ModeKeys.PREDICT:
1138 total_batch_outs = []
1139 for i in range(len(model.outputs)):
1140 num_replicas = strategy.num_replicas_in_sync
1141 nested_outs = batch_outs[
1142 i * num_replicas : i * num_replicas + num_replicas
1143 ]
1144 total_batch_outs.append(
1145 concat_along_batch_dimension(tf.nest.flatten(nested_outs))
1146 )
1147 return total_batch_outs
1148 return batch_outs
1151def _reset_metrics(model):
1152 if model._distribution_strategy:
1153 for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]:
1154 distributed_model = get_distributed_model(model, mode)
1155 if distributed_model:
1156 first_model = model._distribution_strategy.unwrap(
1157 distributed_model
1158 )[0]
1159 first_model.reset_metrics()
1162def get_distributed_model(model, mode):
1163 key = _generate_cache_key(mode)
1164 return model._distributed_model_cache.get(key, None)
1167def set_distributed_model(model, mode, distributed_model):
1168 key = _generate_cache_key(mode)
1169 model._distributed_model_cache[key] = distributed_model
1172def get_distributed_function(model, mode):
1173 key = _generate_cache_key(mode)
1174 return model._distributed_function_cache.get(key, None)
1177def set_distributed_function(model, mode, distributed_function):
1178 key = _generate_cache_key(mode)
1179 model._distributed_function_cache[key] = distributed_function
1182def _generate_cache_key(mode):
1183 key = hash(mode)
1184 return key
1187@tf_contextlib.contextmanager
1188def distributed_scope(strategy, learning_phase):
1189 with strategy.scope(), backend.learning_phase_scope(learning_phase):
1190 yield
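Hedged usage sketch: this context manager just nests the strategy scope and the Keras learning-phase scope, e.g. while building replicated training components (the `MirroredStrategy` and the `Dense` layer are illustrative; `learning_phase=1` denotes training).

strategy = tf.distribute.MirroredStrategy()
with distributed_scope(strategy, learning_phase=1):
    # Variables created here are placed under the strategy, and
    # backend.learning_phase() reports training mode.
    layer = tf.keras.layers.Dense(2)
    layer.build((None, 4))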
1193def is_current_worker_chief():
1194 return dc.get_current_worker_context().is_chief
1197def filter_distributed_callbacks(callbacks_list, model):
1198 """Filter Callbacks based on the worker context when running multi-worker.
1200 Args:
1201 callbacks_list: A list of `Callback` instances.
1202 model: Keras model instance.
1204 Returns:
1205 The list of `Callback` instances that should be run on this worker.
1206 """
1208 if not model._in_multi_worker_mode():
1209 raise ValueError(
1210 "filter_distributed_callbacks() should only be called when Keras "
1211 "is in multi worker mode."
1212 )
1214 callbacks_list = callbacks_list or []
1215 if not [
1216 c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
1217 ]:
1218 # TODO(rchao): Consider providing a ModelCheckpoint here if the user
1219 # fails to (possibly with tempfile directory).
1220 logging.warning(
1221 "ModelCheckpoint callback is not provided. "
1222 "Workers will need to restart training if any fails."
1223 )
1225 if callbacks_list is None or is_current_worker_chief():
1226 return callbacks_list
1228 # Some Callbacks should only run on the chief worker.
1229 return [
1230 callback
1231 for callback in callbacks_list
1232 if not callback._chief_worker_only
1233 ]
1236def _update_sample_weight_modes(model, mode, sample_weights):
1237 """Update sample_weight_mode of the distributed model."""
1238 if is_distributing_by_cloning(model):
1239 distributed_model = get_distributed_model(model, mode)
1240 if not distributed_model:
1241 _make_replicated_models_with_cloning(model, mode)
1242 distributed_model = get_distributed_model(model, mode)
1243 distributed_model._recompile_exec_function = any(
1244 [e.sample_weights_mismatch() for e in model._training_endpoints]
1245 )
1247 if sample_weights:
1248 distributed_models = flatten_per_replica_values(
1249 model._distribution_strategy, distributed_model
1250 )
1251 # sample_weights is a tuple of 1 list where the number of elements
1252 # in the list is equal to the number of replicas in sync.
1253 sample_weights = sample_weights[0]
1254 if sample_weights and None not in sample_weights:
1255 for m, sw in zip(distributed_models, sample_weights):
1256 m._update_sample_weight_modes(sample_weights=[sw])
1259def concat_along_batch_dimension(outputs):
1260 """Concats prediction outputs along the batch dimension."""
1261 if isinstance(outputs[0], tf.SparseTensor):
1262 return tf.sparse.concat(axis=0, sp_inputs=outputs)
1263 if isinstance(outputs[0], tf.RaggedTensor):
1264 return tf.concat(outputs, axis=0)
1265 return np.concatenate(outputs)
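Toy illustration of the three concatenation paths above (the per-replica outputs are made up): dense outputs go through NumPy, while ragged and sparse outputs keep their composite types.

dense = [np.ones((2, 3)), np.zeros((1, 3))]
print(concat_along_batch_dimension(dense).shape)          # (3, 3)
ragged = [tf.ragged.constant([[1, 2], [3]]), tf.ragged.constant([[4]])]
print(concat_along_batch_dimension(ragged).shape)         # (3, None)
sparse = [tf.sparse.eye(2), tf.sparse.eye(2)]
print(concat_along_batch_dimension(sparse).dense_shape)   # [4 2]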