Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/adabelief.py: 10%

143 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AdaBelief optimizer."""

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typing import Union, Callable, Dict


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdaBelief(KerasLegacyOptimizer):
    """Variant of the Adam optimizer.

    It achieves fast convergence like Adam and generalization comparable
    to SGD. It adapts the step size depending on its "belief" in the
    gradient direction: the optimizer adaptively scales the step size by
    the difference between the predicted and observed gradients.

    It implements the AdaBelief optimizer proposed by
    Juntang Zhuang et al. in [AdaBelief Optimizer: Adapting stepsizes by the
    belief in observed gradients](https://arxiv.org/abs/2010.07468).

    Example of usage:

    ```python
    opt = tfa.optimizers.AdaBelief(lr=1e-3)
    ```

    Note: `amsgrad` is not described in the original paper. Use it with
    caution.

    You can enable warmup by setting `total_steps` and `warmup_proportion`,
    and enable rectification as in RAdam by setting `rectify`:

    ```python
    opt = tfa.optimizers.AdaBelief(
        lr=1e-3,
        total_steps=10000,
        warmup_proportion=0.1,
        min_lr=1e-5,
        rectify=True,
    )
    ```

    In the above example, the learning rate increases linearly
    from 0 to `lr` in 1000 steps, then decreases linearly from `lr` to
    `min_lr` in 9000 steps.

    Note that `rectify` is independent of warmup; any combination of the
    two can be used.

    Lookahead, proposed by Michael R. Zhang et al. in the paper
    [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610v1),
    can be integrated with AdaBelief, which is called 'ranger_adabelief'
    in the author's implementation
    https://github.com/juntang-zhuang/Adabelief-Optimizer.
    The mechanism can be enabled by using the Lookahead wrapper. For example:

    ```python
    adabelief = tfa.optimizers.AdaBelief()
    ranger = tfa.optimizers.Lookahead(adabelief, sync_period=6, slow_step_size=0.5)
    ```
    """

    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-14,
        weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
        amsgrad: bool = False,
        rectify: bool = True,
        sma_threshold: FloatTensorLike = 5.0,
        total_steps: int = 0,
        warmup_proportion: FloatTensorLike = 0.1,
        min_lr: FloatTensorLike = 0.0,
        name: str = "AdaBelief",
        **kwargs,
    ):
        r"""Construct a new AdaBelief optimizer.

        Args:
            learning_rate: A `Tensor` or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                The learning rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. Default=1e-14.
                Note that AdaBelief uses epsilon within sqrt (default=1e-14),
                while Adam uses epsilon outside sqrt (default=1e-7).
            weight_decay: A `Tensor` or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                Weight decay for each parameter.
            amsgrad: boolean. Whether to apply the AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                Beyond".
            sma_threshold: A float value. The threshold for the simple moving
                average.
            rectify: boolean. Whether to apply learning rate rectification as
                in RAdam.
            total_steps: An integer. Total number of training steps. Enable
                warmup by setting a value greater than zero.
            warmup_proportion: A floating point value. The proportion of
                increasing steps.
            min_lr: A floating point value. Minimum learning rate after warmup.
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdaBelief".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue`
                clips gradients by value; `decay` is included for backward
                compatibility to allow time inverse decay of the learning rate;
                `lr` is included for backward compatibility, it is recommended
                to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)
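        # `learning_rate` and `weight_decay` may be passed as serialized
        # schedule configs (plain dicts); deserialize them back into
        # `LearningRateSchedule` objects before storing the hyperparameters.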

        if isinstance(learning_rate, Dict):
            learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)

        if isinstance(weight_decay, Dict):
            weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)

        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("sma_threshold", sma_threshold)
        self._set_hyper("total_steps", float(total_steps))
        self._set_hyper("warmup_proportion", warmup_proportion)
        self._set_hyper("min_lr", min_lr)
        self.epsilon = epsilon or tf.keras.backend.epsilon()
        self.amsgrad = amsgrad
        self.rectify = rectify
        self._has_weight_decay = weight_decay != 0.0
        self._initial_total_steps = total_steps

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")
        if self.amsgrad:
            for var in var_list:
                self.add_slot(var, "vhat")
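    # `set_weights` below accepts weight lists that carry an extra `vhat` slot
    # per variable (e.g. saved from an optimizer with `amsgrad=True`) and
    # truncates them to match the slots of this instance.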

    def set_weights(self, weights):
        params = self.weights
        num_vars = int((len(params) - 1) / 2)
        if len(weights) == 3 * num_vars + 1:
            weights = weights[: len(params)]
        super().set_weights(weights)
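    # Weight decay may be a constant or a `LearningRateSchedule`; in the latter
    # case it is evaluated at the current iteration, mirroring `_decayed_lr`.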

    def _decayed_wd(self, var_dtype):
        wd_t = self._get_hyper("weight_decay", var_dtype)
        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
            wd_t = tf.cast(wd_t(self.iterations), var_dtype)
        return wd_t

    def _resource_apply_dense(self, grad, var):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        wd_t = self._decayed_wd(var_dtype)
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
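        # Optional warmup/decay schedule: for the first `warmup_steps` the
        # learning rate ramps linearly up to `lr_t`, then decays linearly
        # towards `min_lr` over the remaining `decay_steps`.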

        if self._initial_total_steps > 0:
            total_steps = self._get_hyper("total_steps", var_dtype)
            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
            min_lr = self._get_hyper("min_lr", var_dtype)
            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
            )
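        # Quantities used by RAdam-style rectification: `sma_inf` is the
        # maximum length of the approximated simple moving average and `sma_t`
        # its length at the current step.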

        sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)

        m_t = m.assign(
            beta_1_t * m + (1.0 - beta_1_t) * grad,
            use_locking=self._use_locking,
        )
        m_corr_t = m_t / (1.0 - beta_1_power)
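        # Key difference from Adam: the second moment tracks the squared
        # deviation of the gradient from its EMA, (grad - m_t)^2, i.e. the
        # "belief" in the observed gradient, and epsilon is added inside the
        # running average (hence inside the sqrt below).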

        v_t = v.assign(
            beta_2_t * v + (1.0 - beta_2_t) * tf.math.square(grad - m_t) + epsilon_t,
            use_locking=self._use_locking,
        )
        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))
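        # RAdam-style rectification: scale the adaptive update by `r_t` when
        # the SMA length exceeds `sma_threshold`, otherwise fall back to the
        # un-adapted (momentum-only) update `m_corr_t`.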

        if self.rectify:
            r_t_numerator = (sma_t - 4.0) * (sma_t - 2.0) * sma_inf
            r_t_denominator = (sma_inf - 4.0) * (sma_inf - 2.0) * sma_t
            r_t = tf.sqrt(r_t_numerator / r_t_denominator)
            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
            var_t = tf.where(
                sma_t >= sma_threshold,
                r_t * m_corr_t / (v_corr_t + epsilon_t),
                m_corr_t,
            )
        else:
            var_t = m_corr_t / (v_corr_t + epsilon_t)
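        # Decoupled (AdamW-style) weight decay: added to the update directly
        # rather than folded into the gradient.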

        if self._has_weight_decay:
            var_t += wd_t * var

        var_update = var.assign_sub(lr_t * var_t, use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        wd_t = self._decayed_wd(var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
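        # The warmup schedule and rectification below mirror
        # `_resource_apply_dense`; only the moment updates differ, using
        # scatter ops to add the gradient contribution at `indices`.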

        if self._initial_total_steps > 0:
            total_steps = self._get_hyper("total_steps", var_dtype)
            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
            min_lr = self._get_hyper("min_lr", var_dtype)
            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
            )

        sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
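        # Each slot is first decayed in place, then the scaled gradient is
        # scatter-added at `indices`. For `v`, the gradient deviation is taken
        # against the rows of the already-updated `m_t`.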

        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta_1_t)
        m_t = m.assign(m * beta_1_t, use_locking=self._use_locking)
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
        m_corr_t = m_t / (1.0 - beta_1_power)

        v = self.get_slot(var, "v")
        m_t_indices = tf.gather(m_t, indices)
        v_scaled_g_values = (
            tf.math.square(grad - m_t_indices) * (1 - beta_2_t) + epsilon_t
        )
        v_t = v.assign(v * beta_2_t, use_locking=self._use_locking)
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))

        if self.rectify:
            r_t_numerator = (sma_t - 4.0) * (sma_t - 2.0) * sma_inf
            r_t_denominator = (sma_inf - 4.0) * (sma_inf - 2.0) * sma_t
            r_t = tf.sqrt(r_t_numerator / r_t_denominator)
            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
            var_t = tf.where(
                sma_t >= sma_threshold,
                r_t * m_corr_t / (v_corr_t + epsilon_t),
                m_corr_t,
            )
        else:
            var_t = m_corr_t / (v_corr_t + epsilon_t)

        if self._has_weight_decay:
            var_t += wd_t * var
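        # The full update `var_t` is computed densely; only the rows at
        # `indices` are gathered and applied to the variable via scatter-add.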

        var_update = self._resource_scatter_add(
            var, indices, tf.gather(-lr_t * var_t, indices)
        )

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "decay": self._serialize_hyperparameter("decay"),
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
                "rectify": self.rectify,
                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                "warmup_proportion": self._serialize_hyperparameter(
                    "warmup_proportion"
                ),
                "min_lr": self._serialize_hyperparameter("min_lr"),
            }
        )
        return config