Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py: 40%

20 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities related to distributed training.""" 

16# pylint:disable=protected-access 

17 

18from tensorflow.python.distribute import distribute_lib 

19from tensorflow.python.distribute import values as values_lib 

20from tensorflow.python.keras import backend 

21from tensorflow.python.ops import variables 

22 

23 

24# TODO(b/118776054): Currently we support global batch size for TPUStrategy and 

25# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is 

26# no longer needed. 

def global_batch_size_supported(distribution_strategy):
  """Returns the `_global_batch_size` value of the strategy's `extended` object.

  Args:
    distribution_strategy: Object whose `extended._global_batch_size`
      attribute is read.

  Returns:
    The value of `distribution_strategy.extended._global_batch_size`,
    returned as-is.
  """
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access

29 

30 

def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Args:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword arguments to `fn`. If a `strategy` key is present it is
      removed and used as the distribution strategy; it is not forwarded to
      `fn`.

  Returns:
    The result of calling `fn`. When a non-TPU strategy is in effect and the
    caller is in a cross-replica context, this is whatever
    `strategy.extended.call_for_each_replica(fn, args, kwargs)` returns.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  if 'strategy' in kwargs:
    # An explicitly passed strategy wins, even if its value is None; the
    # ambient strategy is only consulted when the key is absent.
    strategy = kwargs.pop('strategy')
  elif distribute_lib.has_strategy():
    strategy = distribute_lib.get_strategy()
  else:
    strategy = None

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = backend.is_tpu_strategy(strategy)
  if (not is_tpu) and strategy and distribute_lib.in_cross_replica_context():
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  # No usable strategy, a TPU strategy, or already in a replica context:
  # call `fn` directly.
  return fn(*args, **kwargs)

60 

61 

def is_distributed_variable(v):
  """Returns whether `v` is a distributed variable.

  Args:
    v: The object to check.

  Returns:
    True only if `v` is an instance of both `values_lib.DistributedValues`
    and `variables.Variable`; False otherwise.
  """
  return (isinstance(v, values_lib.DistributedValues) and
          isinstance(v, variables.Variable))