# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""

import contextlib

import tensorflow.compat.v2 as tf
from absl import flags

from keras.src import backend

FLAGS = flags.FLAGS


# TODO(b/118776054): Currently we support global batch size for TPUStrategy
# and core MirroredStrategy only. Remove this check when contrib
# MirroredStrategy is no longer needed.
def global_batch_size_supported(distribution_strategy):
    return distribution_strategy.extended._global_batch_size
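
# Illustrative sketch, not part of the upstream module: `MirroredStrategy` is
# one of the strategies whose `extended` object carries `_global_batch_size`,
# so a caller can branch on this helper before building its input pipeline.
#
#   strategy = tf.distribute.MirroredStrategy()
#   if global_batch_size_supported(strategy):
#       ...  # feed the global (not per-replica) batch size to the dataset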


def call_replica_local_fn(fn, *args, **kwargs):
    """Call a function that uses replica-local variables.

    This function correctly handles calling `fn` in a cross-replica
    context.

    Args:
        fn: The function to call.
        *args: Positional arguments to `fn`.
        **kwargs: Keyword arguments to `fn`.

    Returns:
        The result of calling `fn`.
    """
    # TODO(b/132666209): Remove this function when we support assign_*
    # for replica-local variables.
    strategy = None
    if "strategy" in kwargs:
        strategy = kwargs.pop("strategy")
    else:
        if tf.distribute.has_strategy():
            strategy = tf.distribute.get_strategy()

    # TODO(b/120571621): TPUStrategy does not implement replica-local
    # variables.
    is_tpu = backend.is_tpu_strategy(strategy)
    if (not is_tpu) and strategy and tf.distribute.in_cross_replica_context():
        with strategy.scope():
            return strategy.extended.call_for_each_replica(fn, args, kwargs)
    return fn(*args, **kwargs)
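
# Illustrative sketch, not part of the upstream module: updating a
# hypothetical sync-on-read metric variable `total` (created under
# `strategy.scope()`) from a cross-replica context.
#
#   def _update(value):
#       total.assign_add(value)
#
#   call_replica_local_fn(_update, tf.constant(1.0), strategy=strategy)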


def is_distributed_variable(v):
    """Returns whether `v` is a distributed variable."""
    return isinstance(v, tf.distribute.DistributedValues) and isinstance(
        v, tf.Variable
    )
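
# Illustrative sketch, not part of the upstream module: a variable created
# under a `MirroredStrategy` scope is mirrored across replicas, so this
# predicate would typically return True for it.
#
#   with tf.distribute.MirroredStrategy().scope():
#       v = tf.Variable(1.0)
#   is_distributed_variable(v)  # True
#   is_distributed_variable(tf.Variable(1.0))  # False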


def get_strategy():
    """Creates a `tf.distribute.Strategy` object from flags.

    Example usage:

    ```python
    strategy = utils.get_strategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

    model.compile(...)
    train_ds, test_ds = ...
    model.fit(train_ds, validation_data=test_ds, epochs=10)
    ```

    Returns:
        A `tf.distribute.Strategy` instance.
    """
    cls = FLAGS.keras_distribute_strategy_class
    accepted_strats = {
        "tpu",
        "multi_worker_mirrored",
        "mirrored",
        "parameter_server",
        "one_device",
    }
    if cls == "tpu":
        tpu_addr = FLAGS.keras_distribute_strategy_tpu_addr
        if not tpu_addr:
            raise ValueError(
                "When using a TPU strategy, you must set the flag "
                "`keras_distribute_strategy_tpu_addr` (TPU address)."
            )
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_addr
        )
        tf.config.experimental_connect_to_cluster(cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    elif cls == "multi_worker_mirrored":
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    elif cls == "mirrored":
        strategy = tf.distribute.MirroredStrategy()
    elif cls == "parameter_server":
        cluster_resolver = (
            tf.distribute.cluster_resolver.TFConfigClusterResolver()
        )
        strategy = tf.distribute.experimental.ParameterServerStrategy(
            cluster_resolver
        )
    elif cls == "one_device":
        strategy = tf.distribute.OneDeviceStrategy("/gpu:0")
    else:
        raise ValueError(
            "Unknown distribution strategy flag. Received: "
            f"keras_distribute_strategy_class={cls}. "
            f"It should be one of {accepted_strats}"
        )
    return strategy
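
# Illustrative sketch, not part of the upstream module: the
# `keras_distribute_strategy_class` flag is assumed to be defined by the
# surrounding benchmark/test tooling. With it set to "mirrored", this helper
# would return a `tf.distribute.MirroredStrategy`.
#
#   FLAGS.keras_distribute_strategy_class = "mirrored"
#   strategy = get_strategy()
#   with strategy.scope():
#       model = tf.keras.Sequential([tf.keras.layers.Dense(10)])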


def maybe_preemption_handler_scope(model):
    """Returns a preemption-watching scope for `model`, if one is set up.

    If `model` has a `_preemption_handler` attribute, returns its
    `watch_preemption_scope()` context manager; otherwise returns a no-op
    `contextlib.nullcontext()`.
    """
    if getattr(model, "_preemption_handler", None):
        preemption_checkpoint_scope = (
            model._preemption_handler.watch_preemption_scope()
        )
    else:
        preemption_checkpoint_scope = contextlib.nullcontext()

    return preemption_checkpoint_scope
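
# Illustrative sketch, not part of the upstream module: wrapping a training
# step so that, when a tf.distribute.experimental.PreemptionCheckpointHandler
# has been attached to the model as `_preemption_handler`, a checkpoint can
# be saved before the worker is preempted. `train_step` and `data` are
# hypothetical.
#
#   with maybe_preemption_handler_scope(model):
#       logs = train_step(data)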