Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/utils.py: 17% of 76 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)


# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.ops import clip_ops
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
  """Returns all-reduced gradients aggregated via summation.

  Args:
    grads_and_vars: List of (gradient, variable) pairs.

  Returns:
    List of (gradient, variable) pairs where gradients have been all-reduced.
  """
  grads_and_vars = list(grads_and_vars)
  filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
  if filtered_grads_and_vars:
    if strategy_supports_no_merge_call():
      grads = [pair[0] for pair in filtered_grads_and_vars]
      reduced = distribute_lib.get_strategy().extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
          ds_reduce_util.ReduceOp.SUM, grads)
    else:
      # TODO(b/183257003): Remove this branch
      reduced = distribute_lib.get_replica_context().merge_call(
          _all_reduce_sum_fn, args=(filtered_grads_and_vars,))
  else:
    reduced = []
  # Copy 'reduced' but add None gradients back in.
  reduced_with_nones = []
  reduced_pos = 0
  for g, v in grads_and_vars:
    if g is None:
      reduced_with_nones.append((None, v))
    else:
      reduced_with_nones.append((reduced[reduced_pos], v))
      reduced_pos += 1
  assert reduced_pos == len(reduced), "Failed to add all gradients"
  return reduced_with_nones
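
# Hedged usage sketch (an editorial addition, not part of the original file;
# the `_example_*` helper name is hypothetical). It assumes TF 2.x and a
# single-replica `MirroredStrategy`, where the all-reduce sum is an identity;
# the point illustrated is that `None` gradients pass through untouched. This
# is a sketch only and has not been exercised on multi-device setups.
def _example_all_reduce_sum():
  import tensorflow as tf  # assumes TF 2.x is installed
  strategy = tf.distribute.MirroredStrategy()  # one replica on a single device
  with strategy.scope():
    v1 = tf.Variable(1.0)
    v2 = tf.Variable(2.0)

  def replica_step():
    # The (None, v2) pair survives the round trip as (None, v2).
    return all_reduce_sum_gradients([(tf.constant(0.5), v1), (None, v2)])

  return strategy.run(replica_step)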

def filter_empty_gradients(grads_and_vars):
  """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars

  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)

  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        "Gradients do not exist for variables %s when minimizing the loss.",
        [v.name for v in vars_with_empty_grads])
  return filtered

def make_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):
    if isinstance(distribute_lib.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [
        (clip_ops.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipnorm_fn

def make_global_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by global norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):
    if isinstance(distribute_lib.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CentralStorageStrategy`")

    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
    clipped_grads_and_vars = list(zip(clipped_grads, variables))
    return clipped_grads_and_vars

  return gradient_clipnorm_fn

def make_gradient_clipvalue_fn(clipvalue):
  """Creates a gradient transformation function for clipping by value."""
  if clipvalue is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipvalue_fn(grads_and_vars):
    if isinstance(distribute_lib.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [
        (clip_ops.clip_by_value(g, -clipvalue, clipvalue), v)
        for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipvalue_fn

def _all_reduce_sum_fn(distribution, grads_and_vars):
  return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
                                               grads_and_vars)

def strategy_supports_no_merge_call():
  """Returns whether the current Strategy can run in pure replica context."""
  if not distribute_lib.has_strategy():
    return True
  strategy = distribute_lib.get_strategy()
  return not strategy.extended._use_merge_call()  # pylint: disable=protected-access
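
# Usage sketch (an editorial addition; the helper name is hypothetical):
# outside any tf.distribute scope there is no active strategy, so the helper
# reports True and all_reduce_sum_gradients takes the merge_call-free branch.
def _example_no_merge_call():
  return strategy_supports_no_merge_call()  # True with no strategy in scope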