Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/distribute/distributed_training_utils.py: 23%
48 statements
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""

import contextlib

import tensorflow.compat.v2 as tf
from absl import flags

from keras.src import backend

FLAGS = flags.FLAGS


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
    return distribution_strategy.extended._global_batch_size
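

# Illustrative usage sketch (not part of the original module): how a caller
# might branch on this check when sizing an input pipeline. The choice of
# `MirroredStrategy` and the batch-size arithmetic are assumptions made for
# the example only.
def _example_global_batch_size_check(per_replica_batch_size=32):
    strategy = tf.distribute.MirroredStrategy()
    if global_batch_size_supported(strategy):
        # The strategy interprets dataset batch sizes as global, so scale the
        # per-replica size up by the number of replicas in sync.
        return per_replica_batch_size * strategy.num_replicas_in_sync
    # Otherwise keep a per-replica batch size.
    return per_replica_batch_size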


def call_replica_local_fn(fn, *args, **kwargs):
    """Call a function that uses replica-local variables.

    This function correctly handles calling `fn` in a cross-replica
    context.

    Args:
        fn: The function to call.
        *args: Positional arguments to `fn`.
        **kwargs: Keyword arguments to `fn`.

    Returns:
        The result of calling `fn`.
    """
    # TODO(b/132666209): Remove this function when we support assign_*
    # for replica-local variables.
    strategy = None
    if "strategy" in kwargs:
        strategy = kwargs.pop("strategy")
    else:
        if tf.distribute.has_strategy():
            strategy = tf.distribute.get_strategy()

    # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
    is_tpu = backend.is_tpu_strategy(strategy)
    if (not is_tpu) and strategy and tf.distribute.in_cross_replica_context():
        with strategy.scope():
            return strategy.extended.call_for_each_replica(fn, args, kwargs)
    return fn(*args, **kwargs)
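

# Illustrative usage sketch (not part of the original module): updating a
# Keras metric (whose variables are replica-local under a distribution
# strategy) from a cross-replica context. The metric, its inputs, and the
# choice of `MirroredStrategy` are assumptions made for the example only.
def _example_call_replica_local_fn():
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        accuracy = tf.keras.metrics.BinaryAccuracy()

        def _update():
            accuracy.update_state([1.0], [0.9])
            return accuracy.result()

        # Inside `strategy.scope()` we are in a cross-replica context, so
        # `call_replica_local_fn` dispatches `_update` to each replica.
        return call_replica_local_fn(_update, strategy=strategy)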


def is_distributed_variable(v):
    """Returns whether `v` is a distributed variable."""
    return isinstance(v, tf.distribute.DistributedValues) and isinstance(
        v, tf.Variable
    )
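

# Illustrative usage sketch (not part of the original module): a variable
# created under a `MirroredStrategy` scope is expected to be both a
# `tf.distribute.DistributedValues` and a `tf.Variable`, so the check above
# should return True for it and False for a plain variable.
def _example_is_distributed_variable():
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        mirrored_var = tf.Variable(1.0)
    plain_var = tf.Variable(1.0)
    return is_distributed_variable(mirrored_var), is_distributed_variable(
        plain_var
    )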


def get_strategy():
    """Creates a `tf.distribute.Strategy` object from flags.

    Example usage:

    ```python
    strategy = utils.get_strategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

    model.compile(...)
    train_ds, test_ds = ...
    model.fit(train_ds, validation_data=test_ds, epochs=10)
    ```

    Returns:
        `tf.distribute.Strategy` instance.
    """
    cls = FLAGS.keras_distribute_strategy_class
    accepted_strats = {
        "tpu",
        "multi_worker_mirrored",
        "mirrored",
        "parameter_server",
        "one_device",
    }
    if cls == "tpu":
        tpu_addr = FLAGS.keras_distribute_strategy_tpu_addr
        if not tpu_addr:
            raise ValueError(
                "When using a TPU strategy, you must set the flag "
                "`keras_distribute_strategy_tpu_addr` (TPU address)."
            )
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_addr
        )
        tf.config.experimental_connect_to_cluster(cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    elif cls == "multi_worker_mirrored":
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    elif cls == "mirrored":
        strategy = tf.distribute.MirroredStrategy()
    elif cls == "parameter_server":
        cluster_resolver = (
            tf.distribute.cluster_resolver.TFConfigClusterResolver()
        )
        strategy = tf.distribute.experimental.ParameterServerStrategy(
            cluster_resolver
        )
    elif cls == "one_device":
        strategy = tf.distribute.OneDeviceStrategy("/gpu:0")
    else:
        raise ValueError(
            "Unknown distribution strategy flag. Received: "
            f"keras_distribute_strategy_class={cls}. "
            f"It should be one of {accepted_strats}"
        )
    return strategy
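

# Illustrative usage sketch (not part of the original module): the absl flags
# read above (`keras_distribute_strategy_class`, and for TPUs
# `keras_distribute_strategy_tpu_addr`) are assumed to be defined and parsed
# by the embedding binary; they are not defined in this module.
def _example_get_strategy():
    FLAGS.keras_distribute_strategy_class = "mirrored"
    strategy = get_strategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    return model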


def maybe_preemption_handler_scope(model):
    """Returns the model's preemption-watch scope, or a null context."""
    if getattr(model, "_preemption_handler", None):
        preemption_checkpoint_scope = (
            model._preemption_handler.watch_preemption_scope()
        )
    else:
        preemption_checkpoint_scope = contextlib.nullcontext()

    return preemption_checkpoint_scope
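

# Illustrative usage sketch (not part of the original module): the returned
# value is used as a context manager around a training step, so the step is
# wrapped in the preemption-watch scope only when the model carries the
# (internal) `_preemption_handler` attribute. `train_step` and `data` are
# placeholders for the example only.
def _example_maybe_preemption_handler_scope(model, train_step, data):
    with maybe_preemption_handler_scope(model):
        return train_step(data)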