Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/rectified_adam.py: 13%

128 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Rectified Adam (RAdam) optimizer.""" 

16import tensorflow as tf 

17from tensorflow_addons.utils.types import FloatTensorLike 

18 

19from tensorflow_addons.optimizers import KerasLegacyOptimizer 

20from typing import Union, Callable, Dict 

21from typeguard import typechecked 

22 

23 

@tf.keras.utils.register_keras_serializable(package="Addons")
class RectifiedAdam(KerasLegacyOptimizer):
    """Variant of the Adam optimizer whose adaptive learning rate is rectified
    so as to have a consistent variance.

    It implements the Rectified Adam (a.k.a. RAdam) proposed by
    Liyuan Liu et al. in [On The Variance Of The Adaptive Learning Rate
    And Beyond](https://arxiv.org/pdf/1908.03265v1.pdf).

    Example of usage:

    ```python
    opt = tfa.optimizers.RectifiedAdam(lr=1e-3)
    ```

    Note: `amsgrad` is not described in the original paper. Use it with
    caution.

    RAdam is not a replacement for the heuristic warmup; the settings should
    be kept if warmup has already been employed and tuned in the baseline
    method. You can enable warmup by setting `total_steps` and
    `warmup_proportion`:

    ```python
    opt = tfa.optimizers.RectifiedAdam(
        lr=1e-3,
        total_steps=10000,
        warmup_proportion=0.1,
        min_lr=1e-5,
    )
    ```

    In the above example, the learning rate will increase linearly
    from 0 to `lr` in 1000 steps, then decrease linearly from `lr` to `min_lr`
    in 9000 steps.

    Lookahead, proposed by Michael R. Zhang et.al in the paper
    [Lookahead Optimizer: k steps forward, 1 step back]
    (https://arxiv.org/abs/1907.08610v1), can be integrated with RAdam,
    which is announced by Less Wright and the new combined optimizer can also
    be called "Ranger". The mechanism can be enabled by using the lookahead
    wrapper. For example:

    ```python
    radam = tfa.optimizers.RectifiedAdam()
    ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
    ```
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-7,
        weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
        amsgrad: bool = False,
        sma_threshold: FloatTensorLike = 5.0,
        total_steps: int = 0,
        warmup_proportion: FloatTensorLike = 0.1,
        min_lr: FloatTensorLike = 0.0,
        name: str = "RectifiedAdam",
        **kwargs,
    ):
        r"""Construct a new RAdam optimizer.

        Args:
            learning_rate: A `Tensor` or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`,
                or a serialized (dict) form of such a schedule.
                The learning rate.
            beta_1: A float value or a constant float tensor.
                The exponential decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor.
                The exponential decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability.
            weight_decay: A `Tensor` or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`,
                or a serialized (dict) form of such a schedule.
                Weight decay for each parameter.
            amsgrad: boolean. Whether to apply AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                beyond".
            sma_threshold: A float value.
                The threshold for the simple moving average (SMA) length;
                rectification is only applied when `sma_t >= sma_threshold`.
            total_steps: An integer value. Total number of training steps.
                Enable warmup by setting a positive value.
            warmup_proportion: A floating point value.
                The proportion of increasing steps.
            min_lr: A floating point value. Minimum learning rate after warmup.
            name: Optional name for the operations created when applying
                gradients. Defaults to "RectifiedAdam".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients
                by norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility, recommended to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)

        # A dict is assumed to be a serialized LearningRateSchedule
        # (e.g. coming from a deserialized optimizer config).
        if isinstance(learning_rate, Dict):
            learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)

        if isinstance(weight_decay, Dict):
            weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)

        # `lr` kwarg takes precedence over `learning_rate` for backward
        # compatibility with older Keras-style configs.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("sma_threshold", sma_threshold)
        self._set_hyper("total_steps", float(total_steps))
        self._set_hyper("warmup_proportion", warmup_proportion)
        self._set_hyper("min_lr", min_lr)
        # Fall back to the Keras backend epsilon if 0/None was passed.
        self.epsilon = epsilon or tf.keras.backend.epsilon()
        self.amsgrad = amsgrad
        # Cached flags so the apply ops can skip work at graph-build time.
        self._has_weight_decay = weight_decay != 0.0
        self._initial_total_steps = total_steps

    def _create_slots(self, var_list):
        # First and second moment accumulators; `vhat` only for AMSGrad.
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")
        if self.amsgrad:
            for var in var_list:
                self.add_slot(var, "vhat")

    def set_weights(self, weights):
        params = self.weights
        # Without AMSGrad this optimizer holds 2 slots per variable plus the
        # `iterations` counter, i.e. len(params) == 2 * num_vars + 1.
        num_vars = int((len(params) - 1) / 2)
        # If `weights` was saved by an AMSGrad variant (3 slots per variable),
        # drop the trailing `vhat` entries so shapes line up.
        if len(weights) == 3 * num_vars + 1:
            weights = weights[: len(params)]
        super().set_weights(weights)

    def _decayed_wd(self, var_dtype):
        """Return the weight-decay value for the current step, resolving a
        `LearningRateSchedule` at `self.iterations` if one was supplied."""
        wd_t = self._get_hyper("weight_decay", var_dtype)
        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
            wd_t = tf.cast(wd_t(self.iterations), var_dtype)
        return wd_t

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Precompute per-(device, dtype) coefficients shared by the dense and
        sparse apply ops: bias corrections, the SMA length `sma_t`, the
        variance-rectification term `r_t`, and the (optionally warmed-up)
        learning rate."""
        super()._prepare_local(var_device, var_dtype, apply_state)
        lr_t = self._decayed_lr(var_dtype)
        wd_t = self._decayed_wd(var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        # Steps are 1-based in the RAdam formulas.
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
        one_minus_beta_1_t = 1.0 - beta_1_t
        recip_one_minus_beta_1_power = 1.0 / (1.0 - beta_1_power)
        one_minus_beta_2_t = 1.0 - beta_2_t
        recip_one_minus_beta_2_power = 1.0 / (1.0 - beta_2_power)
        # Maximum SMA length (rho_infinity in the paper) and its value at the
        # current step (rho_t).
        sma_inf = 2.0 / one_minus_beta_2_t - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power * recip_one_minus_beta_2_power
        # Variance rectification term (r_t in the paper, Eq. 5).
        r_t = tf.sqrt(
            (sma_t - 4.0)
            / (sma_inf - 4.0)
            * (sma_t - 2.0)
            / (sma_inf - 2.0)
            * sma_inf
            / sma_t
        )
        sma_threshold = self._get_hyper("sma_threshold", var_dtype)
        # When the SMA length is below the threshold the rectified update is
        # skipped and an un-adapted (momentum-only) update is used instead.
        sma_t_ge_sma_threshold = sma_t >= sma_threshold
        if self._initial_total_steps > 0:
            # Linear warmup from 0 to lr over `warmup_steps`, then linear
            # decay from lr down to `min_lr` over the remaining steps.
            total_steps = self._get_hyper("total_steps", var_dtype)
            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
            min_lr = self._get_hyper("min_lr", var_dtype)
            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
            )
        apply_state[(var_device, var_dtype)].update(
            dict(
                lr_t=lr_t,
                wd_t=wd_t,
                beta_1_t=beta_1_t,
                beta_2_t=beta_2_t,
                epsilon_t=tf.convert_to_tensor(self.epsilon, var_dtype),
                local_step=local_step,
                beta_1_power=beta_1_power,
                beta_2_power=beta_2_power,
                sma_inf=sma_inf,
                sma_t=sma_t,
                one_minus_beta_1_t=one_minus_beta_1_t,
                recip_one_minus_beta_1_power=recip_one_minus_beta_1_power,
                one_minus_beta_2_t=one_minus_beta_2_t,
                recip_one_minus_beta_2_power=recip_one_minus_beta_2_power,
                r_t=r_t,
                sma_t_ge_sma_threshold=sma_t_ge_sma_threshold,
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply a dense gradient update to `var` using the coefficients
        precomputed in `_prepare_local`."""
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coef = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # m_t = beta1 * m + (1 - beta1) * g
        m_t = m.assign(
            coef["beta_1_t"] * m + coef["one_minus_beta_1_t"] * grad,
            use_locking=self._use_locking,
        )
        # Bias-corrected first moment.
        m_corr_t = m_t * coef["recip_one_minus_beta_1_power"]

        # v_t = beta2 * v + (1 - beta2) * g^2
        v_t = v.assign(
            coef["beta_2_t"] * v + coef["one_minus_beta_2_t"] * tf.square(grad),
            use_locking=self._use_locking,
        )
        if self.amsgrad:
            # AMSGrad keeps the running element-wise maximum of v_t.
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t * coef["recip_one_minus_beta_2_power"])
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t * coef["recip_one_minus_beta_2_power"])

        # Rectified Adam step when the SMA is long enough; plain
        # bias-corrected momentum otherwise.
        var_t = tf.where(
            coef["sma_t_ge_sma_threshold"],
            coef["r_t"] * m_corr_t / (v_corr_t + coef["epsilon_t"]),
            m_corr_t,
        )

        # Decoupled weight decay (added to the step, scaled by lr below).
        if self._has_weight_decay:
            var_t += coef["wd_t"] * var

        var_update = var.assign_sub(coef["lr_t"] * var_t, use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply a sparse gradient update (values `grad` at `indices`) to
        `var` using the coefficients precomputed in `_prepare_local`."""
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coef = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # m_t = beta1 * m; then scatter-add (1 - beta1) * g at `indices`.
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coef["one_minus_beta_1_t"]
        m_t = m.assign(m * coef["beta_1_t"], use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
        m_corr_t = m_t * coef["recip_one_minus_beta_1_power"]

        # v_t = beta2 * v; then scatter-add (1 - beta2) * g^2 at `indices`.
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coef["one_minus_beta_2_t"]
        v_t = v.assign(v * coef["beta_2_t"], use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            # AMSGrad keeps the running element-wise maximum of v_t.
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t * coef["recip_one_minus_beta_2_power"])
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t * coef["recip_one_minus_beta_2_power"])

        # Rectified Adam step when the SMA is long enough; plain
        # bias-corrected momentum otherwise.
        var_t = tf.where(
            coef["sma_t_ge_sma_threshold"],
            coef["r_t"] * m_corr_t / (v_corr_t + coef["epsilon_t"]),
            m_corr_t,
        )

        if self._has_weight_decay:
            var_t += coef["wd_t"] * var

        # Only the rows selected by `indices` are updated on `var`.
        with tf.control_dependencies([var_t]):
            var_update = self._resource_scatter_add(
                var, indices, tf.gather(-coef["lr_t"] * var_t, indices)
            )

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)

    def get_config(self):
        """Return the optimizer configuration as a JSON-serializable dict."""
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "decay": self._serialize_hyperparameter("decay"),
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                "warmup_proportion": self._serialize_hyperparameter(
                    "warmup_proportion"
                ),
                "min_lr": self._serialize_hyperparameter("min_lr"),
            }
        )
        return config