Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/adabelief.py: 10% (143 statements)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AdaBelief optimizer."""

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typing import Union, Callable, Dict


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdaBelief(KerasLegacyOptimizer):
    """Variant of the Adam optimizer.

    It achieves convergence as fast as Adam and generalization comparable
    to SGD. It adapts the step size depending on its "belief" in the
    gradient direction: the optimizer adaptively scales the step size by
    the difference between the predicted and observed gradients.

    It implements the AdaBelief optimizer proposed by
    Juntang Zhuang et al. in [AdaBelief Optimizer: Adapting stepsizes by the
    belief in observed gradients](https://arxiv.org/abs/2010.07468).

    Example of usage:

    ```python
    opt = tfa.optimizers.AdaBelief(lr=1e-3)
    ```

    Note: `amsgrad` is not described in the original paper. Use it with
    caution.

    You can enable warmup by setting `total_steps` and `warmup_proportion`,
    and enable rectification as in RAdam by setting `rectify`:
    ```python
    opt = tfa.optimizers.AdaBelief(
        lr=1e-3,
        total_steps=10000,
        warmup_proportion=0.1,
        min_lr=1e-5,
        rectify=True,
    )
    ```

    In the above example, the learning rate increases linearly
    from 0 to `lr` in 1000 steps, then decreases linearly from `lr` to
    `min_lr` in 9000 steps.

    Note: `rectify` is independent of warmup; any combination of the two
    can be used.

    Lookahead, proposed by Michael R. Zhang et al. in the paper
    [Lookahead Optimizer: k steps forward, 1 step back]
    (https://arxiv.org/abs/1907.08610v1), can be combined with AdaBelief,
    which is called 'ranger_adabelief' in the author's implementation
    https://github.com/juntang-zhuang/Adabelief-Optimizer.
    The mechanism can be enabled by using the Lookahead wrapper. For example:

    ```python
    adabelief = tfa.optimizers.AdaBelief()
    ranger = tfa.optimizers.Lookahead(adabelief, sync_period=6, slow_step_size=0.5)
    ```
    """

    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-14,
        weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
        amsgrad: bool = False,
        rectify: bool = True,
        sma_threshold: FloatTensorLike = 5.0,
        total_steps: int = 0,
        warmup_proportion: FloatTensorLike = 0.1,
        min_lr: FloatTensorLike = 0.0,
        name: str = "AdaBelief",
        **kwargs,
    ):
94 r"""Construct a new RAdam optimizer.
96 Args:
97 learning_rate: A `Tensor` or a floating point value, or a schedule
98 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
99 The learning rate.
100 beta_1: A float value or a constant float tensor. The exponential
101 decay rate for the 1st moment estimates.
102 beta_2: A float value or a constant float tensor. The exponential
103 decay rate for the 2nd moment estimates.
104 epsilon: A small constant for numerical stability. Default=1e-14.
105 Note that AdaBelief uses epsilon within sqrt (default=1e-14),
106 while Adam uses epsilon outside sqrt (default=1e-7).
107 weight_decay: A `Tensor` or a floating point value, or a schedule
108 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
109 Weight decay for each parameter.
110 amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm
111 from the paper "On the Convergence of Adam and beyond".
112 sma_threshold. A float value. The threshold for simple mean
113 average.
114 rectify: boolean. Whether to apply learning rate rectification as
115 from RAdam.
116 total_steps: An integer. Total number of training steps. Enable
117 warmup by setting a value greater than zero.
118 warmup_proportion: A floating point value. The proportion of
119 increasing steps.
120 min_lr: A floating point value. Minimum learning rate after warmup.
121 name: Optional name for the operations created when applying
122 gradients. Defaults to "RectifiedAdam".
123 **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
124 `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue`
125 is clip gradients by value, `decay` is included for backward
126 compatibility to allow time inverse decay of learning rate. `lr`
127 is included for backward compatibility, recommended to use
128 `learning_rate` instead.
129 """
        super().__init__(name, **kwargs)

        if isinstance(learning_rate, Dict):
            learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)

        if isinstance(weight_decay, Dict):
            weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)

        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("sma_threshold", sma_threshold)
        self._set_hyper("total_steps", float(total_steps))
        self._set_hyper("warmup_proportion", warmup_proportion)
        self._set_hyper("min_lr", min_lr)
        self.epsilon = epsilon or tf.keras.backend.epsilon()
        self.amsgrad = amsgrad
        self.rectify = rectify
        self._has_weight_decay = weight_decay != 0.0
        self._initial_total_steps = total_steps
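
    # One "m" (first moment) slot and one "v" (variance of the gradient
    # around "m") slot per variable, plus a "vhat" slot when amsgrad is
    # enabled.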
    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")
        if self.amsgrad:
            for var in var_list:
                self.add_slot(var, "vhat")
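
    # If the incoming weights include "vhat" slots (saved with amsgrad=True)
    # while this instance only tracks "m" and "v", drop the extras so the
    # counts match this optimizer's own variables.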
    def set_weights(self, weights):
        params = self.weights
        num_vars = int((len(params) - 1) / 2)
        if len(weights) == 3 * num_vars + 1:
            weights = weights[: len(params)]
        super().set_weights(weights)
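
    # Weight decay may be a schedule; evaluate it at the current iteration.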
    def _decayed_wd(self, var_dtype):
        wd_t = self._get_hyper("weight_decay", var_dtype)
        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
            wd_t = tf.cast(wd_t(self.iterations), var_dtype)
        return wd_t

    def _resource_apply_dense(self, grad, var):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        wd_t = self._decayed_wd(var_dtype)
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
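
        # Warmup (enabled when total_steps > 0): the learning rate ramps up
        # linearly from 0 to lr_t over warmup_steps, then decays linearly
        # toward min_lr over the remaining steps.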
        if self._initial_total_steps > 0:
            total_steps = self._get_hyper("total_steps", var_dtype)
            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
            min_lr = self._get_hyper("min_lr", var_dtype)
            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
            )
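
        # Length of the approximated simple moving average (SMA), used by the
        # RAdam-style rectification below.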
        sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)

        m_t = m.assign(
            beta_1_t * m + (1.0 - beta_1_t) * grad,
            use_locking=self._use_locking,
        )
        m_corr_t = m_t / (1.0 - beta_1_power)
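
        # Core AdaBelief change relative to Adam: the second-moment estimate
        # tracks the variance of the gradient around its EMA, i.e.
        # (grad - m_t) ** 2, rather than grad ** 2.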
        v_t = v.assign(
            beta_2_t * v + (1.0 - beta_2_t) * tf.math.square(grad - m_t) + epsilon_t,
            use_locking=self._use_locking,
        )
        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))
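
        # RAdam-style rectification: use the rectified adaptive step when the
        # SMA length exceeds sma_threshold, otherwise fall back to the
        # bias-corrected momentum update.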
        if self.rectify:
            r_t_numerator = (sma_t - 4.0) * (sma_t - 2.0) * sma_inf
            r_t_denominator = (sma_inf - 4.0) * (sma_inf - 2.0) * sma_t
            r_t = tf.sqrt(r_t_numerator / r_t_denominator)
            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
            var_t = tf.where(
                sma_t >= sma_threshold,
                r_t * m_corr_t / (v_corr_t + epsilon_t),
                m_corr_t,
            )
        else:
            var_t = m_corr_t / (v_corr_t + epsilon_t)

        if self._has_weight_decay:
            var_t += wd_t * var

        var_update = var.assign_sub(lr_t * var_t, use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)
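
    # Sparse counterpart of _resource_apply_dense, used when gradients arrive
    # as tf.IndexedSlices (e.g. from embedding lookups).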
    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        wd_t = self._decayed_wd(var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)

        if self._initial_total_steps > 0:
            total_steps = self._get_hyper("total_steps", var_dtype)
            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
            min_lr = self._get_hyper("min_lr", var_dtype)
            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
            )

        sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
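
        # Decay the full slot tensors, then scatter the per-index gradient
        # contributions into the rows selected by `indices`.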
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta_1_t)
        m_t = m.assign(m * beta_1_t, use_locking=self._use_locking)
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
        m_corr_t = m_t / (1.0 - beta_1_power)

        v = self.get_slot(var, "v")
        m_t_indices = tf.gather(m_t, indices)
        v_scaled_g_values = (
            tf.math.square(grad - m_t_indices) * (1 - beta_2_t) + epsilon_t
        )
        v_t = v.assign(v * beta_2_t, use_locking=self._use_locking)
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))

        if self.rectify:
            r_t_numerator = (sma_t - 4.0) * (sma_t - 2.0) * sma_inf
            r_t_denominator = (sma_inf - 4.0) * (sma_inf - 2.0) * sma_t
            r_t = tf.sqrt(r_t_numerator / r_t_denominator)
            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
            var_t = tf.where(
                sma_t >= sma_threshold,
                r_t * m_corr_t / (v_corr_t + epsilon_t),
                m_corr_t,
            )
        else:
            var_t = m_corr_t / (v_corr_t + epsilon_t)

        if self._has_weight_decay:
            var_t += wd_t * var

        var_update = self._resource_scatter_add(
            var, indices, tf.gather(-lr_t * var_t, indices)
        )

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)
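
    # Expose all hyperparameters so the optimizer can be serialized and
    # re-created by Keras.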
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "decay": self._serialize_hyperparameter("decay"),
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
                "rectify": self.rectify,
                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                "warmup_proportion": self._serialize_hyperparameter(
                    "warmup_proportion"
                ),
                "min_lr": self._serialize_hyperparameter("min_lr"),
            }
        )
        return config