# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""
# pylint:disable=protected-access

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.keras import backend
from tensorflow.python.ops import variables


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access
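
# Usage sketch (illustrative only, not part of this module's API): callers that
# build input pipelines typically branch on this flag. `global_batch_size` and
# `strategy` below are assumed caller-side names:
#   if global_batch_size_supported(strategy):
#     batch_size = global_batch_size
#   else:
#     batch_size = global_batch_size // strategy.num_replicas_in_sync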


def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Args:
    fn: The function to call.
    *args: Positional arguments to `fn`.
    **kwargs: Keyword arguments to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  elif distribute_lib.has_strategy():
    strategy = distribute_lib.get_strategy()

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = backend.is_tpu_strategy(strategy)
  if not is_tpu and strategy and distribute_lib.in_cross_replica_context():
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
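
# Usage sketch (illustrative only): Keras metric updates assign to
# replica-local variables, so a caller in a cross-replica context can wrap the
# update in this helper. `metric` and `value` are assumed stand-ins for a
# `tf.keras.metrics.Metric` instance and an update tensor:
#   call_replica_local_fn(metric.update_state, value, strategy=strategy)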


def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable."""
  return (isinstance(v, values_lib.DistributedValues) and
          isinstance(v, variables.Variable))
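
# Usage sketch (illustrative only, names are caller-side assumptions): a
# `tf.Variable` created under a strategy scope becomes a distributed
# (e.g. mirrored) variable, which this predicate detects, while a plain
# variable does not:
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     v = tf.Variable(1.0)
#   is_distributed_variable(v)                   # True
#   is_distributed_variable(tf.Variable(1.0))    # False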