# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

from tensorflow_addons.optimizers import AveragedOptimizerWrapper
from tensorflow_addons.utils import types

from typing import Union
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class MovingAverage(AveragedOptimizerWrapper):
    """Optimizer that computes a moving average of the variables.

    Empirically it has been found that using the moving average of the trained
    parameters of a deep network is better than using its trained parameters
    directly. This optimizer allows you to compute this moving average and swap
    the variables at save time so that any code outside of the training loop
    will use the averaged values by default instead of the original ones.

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.01)
    opt = tfa.optimizers.MovingAverage(opt)
    ```
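
    The averaged values can be written back into the model before saving, so
    that the exported weights are the averages (a sketch; `model`, `x`, and
    `y` are placeholders for a compiled `tf.keras.Model` and training data):

    ```python
    model.compile(optimizer=opt, loss="mse")
    model.fit(x, y)
    opt.assign_average_vars(model.variables)  # copy averages into the weights
    model.save("averaged_model")
    ```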

    """

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        average_decay: types.FloatTensorLike = 0.99,
        num_updates: Union[None, int, tf.Variable] = None,
        start_step: int = 0,
        dynamic_decay: bool = False,
        name: str = "MovingAverage",
        **kwargs,
    ):
        r"""Construct a new MovingAverage optimizer.

        Args:
            optimizer: str or `tf.keras.optimizers.legacy.Optimizer` that will be
                used to compute and apply gradients.
            average_decay: float. Decay to use to maintain the moving averages
                of trained variables.
            num_updates: Optional count of the number of updates applied to
                variables.
            start_step: int. What step to start the moving average.
            dynamic_decay: bool. Whether to change the decay based on the number
                of optimizer updates. Decay will start at 0.1 and gradually
                increase up to `average_decay` after each optimizer update.
            name: Optional name for the operations created when applying
                gradients. Defaults to "MovingAverage".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue`
                clips gradients by value; `decay` is included for backward
                compatibility to allow time inverse decay of learning rate;
                `lr` is included for backward compatibility, and using
                `learning_rate` instead is recommended.
        """
        super().__init__(optimizer, name, **kwargs)
        self._num_updates = num_updates
        if self._num_updates is not None:
            if isinstance(self._num_updates, tf.Variable):
                tf.debugging.assert_integer(
                    self._num_updates,
                    (
                        'type of argument "num_updates" must be '
                        "int; got {} instead".format(self._num_updates.dtype)
                    ),
                )
            num_updates = tf.cast(self._num_updates, tf.float32, name="num_updates")
            average_decay = tf.minimum(
                average_decay, (1.0 + num_updates) / (10.0 + num_updates)
            )
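            # Cap the decay early in training: with `num_updates` = n the
            # effective decay is at most (1 + n) / (10 + n), e.g. 0.1 for
            # n = 0 and 0.91 for n = 90, matching the `num_updates` behavior
            # of `tf.train.ExponentialMovingAverage`.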

        self._set_hyper("average_decay", average_decay)
        self._start_step = start_step
        self._dynamic_decay = dynamic_decay

    @tf.function
    def _get_decay(self, step: tf.Tensor):
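        """Return the moving-average decay to use at `step`.

        Three regimes: before `start_step` the decay is 0, so the average is
        simply overwritten with the current value; with `dynamic_decay` the
        decay ramps from 0.1 toward `average_decay` as steps accumulate;
        otherwise the constant `average_decay` is used.
        """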

        average_decay = self._get_hyper("average_decay", tf.dtypes.float32)

        step = tf.cast(step, tf.float32)
        if step < self._start_step:
            return tf.constant(0.0, tf.float32)
        elif self._dynamic_decay:
            step_count = step - self._start_step
            return tf.minimum(average_decay, (1.0 + step_count) / (10.0 + step_count))
        else:
            return average_decay

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super()._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]["tfa_ma_decay"] = self._get_decay(
            self._optimizer.iterations
        )
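        # The decay is computed once per `apply_gradients` call and cached in
        # `apply_state`, so `average_op` does not recompute it per variable.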

    def average_op(self, var, average_var, local_apply_state):
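        """Update `average_var` toward `var` with the cached decay.

        Equivalent to:
            average_var = decay * average_var + (1 - decay) * var
        """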

        return tf.keras.backend.moving_average_update(
            average_var, var, local_apply_state["tfa_ma_decay"]
        )

    def get_config(self):
        config = {
            "average_decay": self._serialize_hyperparameter("average_decay"),
            "num_updates": self._num_updates,
            "start_step": self._start_step,
            "dynamic_decay": self._dynamic_decay,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
        for var in var_list:
            self.add_slot(var, "average", var.read_value())

        self._average_weights = [self.get_slot(var, "average") for var in var_list]
        self._model_weights = var_list

    def shadow_copy(self, model_weights):
        """Creates shadow variables for the given model weights."""
        for var in model_weights:
            self.add_slot(var, "average", initializer="zeros")
        self._average_weights = [
            self.get_slot(var, "average") for var in model_weights
        ]
        self._model_weights = model_weights

    @property
    def has_shadow_copy(self):
        """Whether this optimizer has created shadow variables."""
        return self._model_weights is not None

    def swap_weights(self):
        """Swap the average and moving weights.

        This is a convenience method to allow one to evaluate the averaged
        weights at test time. Loads the weights stored in
        `self._average_weights` into the model, keeping a copy of the original
        model weights. Swapping twice will return the original weights.
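
        For example (a sketch; `strategy`, `model`, `opt`, and `eval_data` are
        placeholders for a `tf.distribute.Strategy`, a model trained under it,
        this wrapper, and an evaluation dataset):

        ```python
        with strategy.scope():
            opt.swap_weights()  # the model now holds the averaged weights
            model.evaluate(eval_data)
            opt.swap_weights()  # swap back to the original training weights
        ```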

        """
        if tf.distribute.in_cross_replica_context():
            strategy = tf.distribute.get_strategy()
            return strategy.run(self._swap_weights, args=())
        else:
            raise ValueError(
                "Swapping weights must occur under a tf.distribute.Strategy"
            )

    @tf.function
    def _swap_weights(self):
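        """Swap each (average, model) variable pair in place.

        Uses the classic three-step arithmetic swap so no temporary buffer is
        needed: a += b; b = a - b (b now holds the original a); a -= b.
        """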

169 def fn_0(a, b): 

170 return a.assign_add(b, use_locking=self._use_locking) 

171 

172 def fn_1(b, a): 

173 return b.assign(a - b, use_locking=self._use_locking) 

174 

175 def fn_2(a, b): 

176 return a.assign_sub(b, use_locking=self._use_locking) 

177 

178 def swap(strategy, a, b): 

179 """Swap `a` and `b` and mirror to all devices.""" 

180 for a_element, b_element in zip(a, b): 

181 strategy.extended.update( 

182 a_element, fn_0, args=(b_element,) 

183 ) # a = a + b 

184 strategy.extended.update( 

185 b_element, fn_1, args=(a_element,) 

186 ) # b = a - b 

187 strategy.extended.update( 

188 a_element, fn_2, args=(b_element,) 

189 ) # a = a - b 

190 

191 ctx = tf.distribute.get_replica_context() 

192 return ctx.merge_call(swap, args=(self._average_weights, self._model_weights))