Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/adamw.py: 20%

65 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AdamW optimizer implementation."""


import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export

@register_keras_serializable()
@keras_export(
    "keras.optimizers.AdamW",
    "keras.optimizers.experimental.AdamW",
    "keras.dtensor.experimental.optimizers.AdamW",
    v1=[],
)
class AdamW(optimizer.Optimizer):
    r"""Optimizer that implements the AdamW algorithm.

    AdamW optimization is a stochastic gradient descent method that is based on
    adaptive estimation of first-order and second-order moments with an added
    method to decay weights per the techniques discussed in the paper
    'Decoupled Weight Decay Regularization' by
    [Loshchilov & Hutter, 2019](https://arxiv.org/abs/1711.05101).

    According to
    [Kingma & Ba, 2014](http://arxiv.org/abs/1412.6980),
    the underlying Adam method is "*computationally
    efficient, has little memory requirement, invariant to diagonal rescaling of
    gradients, and is well suited for problems that are large in terms of
    data/parameters*".
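
    Usage:

    A minimal standalone sketch (the learning rate below is illustrative, and
    `minimize` comes from the base optimizer API, not from this file):

    >>> opt = tf.keras.optimizers.AdamW(learning_rate=0.01, weight_decay=0.004)
    >>> var1 = tf.Variable(10.0)
    >>> loss = lambda: (var1**2) / 2.0  # d(loss)/d(var1) == var1
    >>> _ = [opt.minimize(loss, [var1]) for _ in range(3)]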

    Args:
      learning_rate: A `tf.Tensor`, floating point value, a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.001.
      beta_1: A float value or a constant float tensor, or a callable
        that takes no arguments and returns the actual value to use. The
        exponential decay rate for the 1st moment estimates. Defaults to 0.9.
      beta_2: A float value or a constant float tensor, or a callable
        that takes no arguments and returns the actual value to use. The
        exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
        1e-7.
      amsgrad: Boolean. Whether to apply the AMSGrad variant of this algorithm
        from the paper "On the Convergence of Adam and Beyond". Defaults to
        `False`.
      {{base_optimizer_keyword_args}}

    Reference:
      - [Loshchilov & Hutter, 2019](https://arxiv.org/abs/1711.05101)
      - [Kingma & Ba, 2014](http://arxiv.org/abs/1412.6980) for `adam`
      - [Reddi et al., 2018](
        https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.

    Notes:

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).
    """

    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="AdamW",
        **kwargs
    ):
        super().__init__(
            name=name,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.weight_decay = weight_decay
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.amsgrad = amsgrad

        if self.weight_decay is None:
            raise ValueError(
                "Missing value of `weight_decay` which is required and"
                " must be a float value."
            )

    def build(self, var_list):
        """Initialize optimizer variables.

        AdamW optimizer has 3 types of variables: momentums, velocities, and
        velocity_hat (only set when amsgrad is applied).

        Args:
          var_list: list of model variables to build AdamW variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self._built = True
        self._momentums = []
        self._velocities = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="v"
                )
            )
        if self.amsgrad:
            self._velocity_hats = []
            for var in var_list:
                self._velocity_hats.append(
                    self.add_variable_from_reference(
                        model_variable=var, variable_name="vhat"
                    )
                )

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        beta_1_power = None
        beta_2_power = None
        lr = tf.cast(self.learning_rate, variable.dtype)
        local_step = tf.cast(self.iterations + 1, variable.dtype)
        beta_1_power = tf.pow(tf.cast(self.beta_1, variable.dtype), local_step)
        beta_2_power = tf.pow(tf.cast(self.beta_2, variable.dtype), local_step)

        var_key = self._var_key(variable)
        m = self._momentums[self._index_dict[var_key]]
        v = self._velocities[self._index_dict[var_key]]

        alpha = lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            m.assign_add(-m * (1 - self.beta_1))
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1 - self.beta_1), gradient.indices
                )
            )
            v.assign_add(-v * (1 - self.beta_2))
            v.scatter_add(
                tf.IndexedSlices(
                    tf.square(gradient.values) * (1 - self.beta_2),
                    gradient.indices,
                )
            )
            if self.amsgrad:
                v_hat = self._velocity_hats[self._index_dict[var_key]]
                v_hat.assign(tf.maximum(v_hat, v))
                v = v_hat
            variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
        else:
            # Dense gradients.
            m.assign_add((gradient - m) * (1 - self.beta_1))
            v.assign_add((tf.square(gradient) - v) * (1 - self.beta_2))
            if self.amsgrad:
                v_hat = self._velocity_hats[self._index_dict[var_key]]
                v_hat.assign(tf.maximum(v_hat, v))
                v = v_hat
            variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "weight_decay": self.weight_decay,
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
            }
        )
        return config


AdamW.__doc__ = AdamW.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
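
A note on the listing above: `update_step` implements only the Adam moment
update, and the decoupled weight-decay term that gives AdamW its name is stored
here as `self.weight_decay` but applied by the base `optimizer.Optimizer` class
outside this file. The sketch below is a standalone, plain-Python illustration
of how the two pieces combine for one scalar parameter; it is not the library's
code path, and the learning-rate scaling of the decay term is an assumption
about the base class's convention.

import math


def adamw_step(theta, grad, m, v, step, lr=0.001, weight_decay=0.004,
               beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """Return the updated (theta, m, v) after one illustrative AdamW step."""
    # Exponential moving averages of the gradient and its square, mirroring
    # the dense branch of `update_step`.
    m = m + (grad - m) * (1 - beta_1)
    v = v + (grad * grad - v) * (1 - beta_2)
    # Bias-corrected step size, matching `alpha` in `update_step`.
    alpha = lr * math.sqrt(1 - beta_2**step) / (1 - beta_1**step)
    # Decoupled weight decay: shrink the parameter directly rather than adding
    # an L2 term to the gradient (lr scaling assumed here).
    theta = theta - lr * weight_decay * theta
    # Adam step with the "epsilon hat" placement used above.
    theta = theta - (m * alpha) / (math.sqrt(v) + epsilon)
    return theta, m, v


theta, m, v = 1.0, 0.0, 0.0
for step in range(1, 4):
    grad = 2.0 * theta  # gradient of f(theta) = theta**2
    theta, m, v = adamw_step(theta, grad, m, v, step)

Keeping the decay out of the gradient is the point of the decoupled
formulation: folded into the gradient as L2 regularization, the decay would
also be divided by `sqrt(v) + epsilon` and so would vary with the gradient
history.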