Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/rectified_adam.py: 13%
128 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Rectified Adam (RAdam) optimizer."""
16import tensorflow as tf
17from tensorflow_addons.utils.types import FloatTensorLike
19from tensorflow_addons.optimizers import KerasLegacyOptimizer
20from typing import Union, Callable, Dict
21from typeguard import typechecked
@tf.keras.utils.register_keras_serializable(package="Addons")
class RectifiedAdam(KerasLegacyOptimizer):
    """Variant of the Adam optimizer whose adaptive learning rate is rectified
    so as to have a consistent variance.

    It implements the Rectified Adam (a.k.a. RAdam) proposed by
    Liyuan Liu et al. in [On The Variance Of The Adaptive Learning Rate
    And Beyond](https://arxiv.org/pdf/1908.03265v1.pdf).

    Example of usage:

    ```python
    opt = tfa.optimizers.RectifiedAdam(lr=1e-3)
    ```

    Note: `amsgrad` is not described in the original paper. Use it with
    caution.

    RAdam is not a placement of the heuristic warmup, the settings should be
    kept if warmup has already been employed and tuned in the baseline method.
    You can enable warmup by setting `total_steps` and `warmup_proportion`:

    ```python
    opt = tfa.optimizers.RectifiedAdam(
        lr=1e-3,
        total_steps=10000,
        warmup_proportion=0.1,
        min_lr=1e-5,
    )
    ```

    In the above example, the learning rate will increase linearly
    from 0 to `lr` in 1000 steps, then decrease linearly from `lr` to `min_lr`
    in 9000 steps.

    Lookahead, proposed by Michael R. Zhang et.al in the paper
    [Lookahead Optimizer: k steps forward, 1 step back]
    (https://arxiv.org/abs/1907.08610v1), can be integrated with RAdam,
    which is announced by Less Wright and the new combined optimizer can also
    be called "Ranger". The mechanism can be enabled by using the lookahead
    wrapper. For example:

    ```python
    radam = tfa.optimizers.RectifiedAdam()
    ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
    ```
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-7,
        weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
        amsgrad: bool = False,
        sma_threshold: FloatTensorLike = 5.0,
        total_steps: int = 0,
        warmup_proportion: FloatTensorLike = 0.1,
        min_lr: FloatTensorLike = 0.0,
        name: str = "RectifiedAdam",
        **kwargs,
    ):
        r"""Construct a new RAdam optimizer.

        Args:
            learning_rate: A `Tensor` or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                The learning rate.
            beta_1: A float value or a constant float tensor.
                The exponential decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor.
                The exponential decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability.
            weight_decay: A `Tensor` or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                Weight decay for each parameter.
            amsgrad: boolean. Whether to apply AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                beyond".
            sma_threshold. A float value.
                The threshold for simple mean average.
            total_steps: An integer value. Total number of training steps.
                Enable warmup by setting a positive value.
            warmup_proportion: A floating point value.
                The proportion of increasing steps.
            min_lr: A floating point value. Minimum learning rate after warmup.
            name: Optional name for the operations created when applying
                gradients. Defaults to "RectifiedAdam".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients
                by norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility, recommended to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)

        # These hyperparameters may arrive as serialized
        # LearningRateSchedule configs (plain dicts, e.g. when the optimizer
        # itself is deserialized); turn them back into schedule objects.
        # FIX: test against the builtin `dict`, not `typing.Dict` —
        # isinstance() checks against the deprecated typing aliases are
        # discouraged (the aliases are deprecated since Python 3.9) and
        # `dict` is behavior-identical here.
        if isinstance(learning_rate, dict):
            learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)

        if isinstance(weight_decay, dict):
            weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)

        # `lr` kwarg takes precedence over `learning_rate` for backward
        # compatibility with older Keras optimizer signatures.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("sma_threshold", sma_threshold)
        # Stored as float so it can be cast to the variable dtype later.
        self._set_hyper("total_steps", float(total_steps))
        self._set_hyper("warmup_proportion", warmup_proportion)
        self._set_hyper("min_lr", min_lr)
        # Fall back to the backend epsilon when 0/None is passed.
        self.epsilon = epsilon or tf.keras.backend.epsilon()
        self.amsgrad = amsgrad
        self._has_weight_decay = weight_decay != 0.0
        # Warmup/decay scheduling is only active when total_steps > 0.
        self._initial_total_steps = total_steps

    def _create_slots(self, var_list):
        # First and second moment accumulators for every variable; `vhat`
        # (running max of v) only when the AMSGrad variant is enabled.
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")
        if self.amsgrad:
            for var in var_list:
                self.add_slot(var, "vhat")

    def set_weights(self, weights):
        """Set optimizer weights, dropping `vhat` slots when necessary.

        If `weights` was saved from an optimizer with `amsgrad=True`
        (iteration counter + 3 slots per variable) but this optimizer only
        keeps 2 slots per variable, truncate the extra slot weights so the
        restore still succeeds.
        """
        params = self.weights
        # params = [iterations] + m-slots + v-slots (+ vhat-slots if amsgrad).
        num_vars = int((len(params) - 1) / 2)
        if len(weights) == 3 * num_vars + 1:
            weights = weights[: len(params)]
        super().set_weights(weights)

    def _decayed_wd(self, var_dtype):
        # Resolve weight decay: evaluate a schedule at the current iteration,
        # otherwise return the constant hyperparameter cast to var_dtype.
        wd_t = self._get_hyper("weight_decay", var_dtype)
        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
            wd_t = tf.cast(wd_t(self.iterations), var_dtype)
        return wd_t

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Precompute per-step scalars shared by dense and sparse updates."""
        super()._prepare_local(var_device, var_dtype, apply_state)
        lr_t = self._decayed_lr(var_dtype)
        wd_t = self._decayed_wd(var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
        one_minus_beta_1_t = 1.0 - beta_1_t
        recip_one_minus_beta_1_power = 1.0 / (1.0 - beta_1_power)
        one_minus_beta_2_t = 1.0 - beta_2_t
        recip_one_minus_beta_2_power = 1.0 / (1.0 - beta_2_power)
        # Length of the approximated SMA (rho_inf in the paper) and its
        # per-step value rho_t.
        sma_inf = 2.0 / one_minus_beta_2_t - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power * recip_one_minus_beta_2_power
        # Variance rectification term r_t (Eq. 5 of the RAdam paper). For
        # early steps where sma_t <= 4 this sqrt is NaN, but those entries
        # are masked out by `sma_t_ge_sma_threshold` in the apply methods.
        r_t = tf.sqrt(
            (sma_t - 4.0)
            / (sma_inf - 4.0)
            * (sma_t - 2.0)
            / (sma_inf - 2.0)
            * sma_inf
            / sma_t
        )
        sma_threshold = self._get_hyper("sma_threshold", var_dtype)
        sma_t_ge_sma_threshold = sma_t >= sma_threshold
        if self._initial_total_steps > 0:
            # Linear warmup from 0 to lr over `warmup_steps`, then linear
            # decay from lr to `min_lr` over the remaining steps.
            total_steps = self._get_hyper("total_steps", var_dtype)
            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
            min_lr = self._get_hyper("min_lr", var_dtype)
            # Clamp to >= 1 to avoid division by zero when warmup covers
            # the whole run.
            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
            )
        apply_state[(var_device, var_dtype)].update(
            dict(
                lr_t=lr_t,
                wd_t=wd_t,
                beta_1_t=beta_1_t,
                beta_2_t=beta_2_t,
                epsilon_t=tf.convert_to_tensor(self.epsilon, var_dtype),
                local_step=local_step,
                beta_1_power=beta_1_power,
                beta_2_power=beta_2_power,
                sma_inf=sma_inf,
                sma_t=sma_t,
                one_minus_beta_1_t=one_minus_beta_1_t,
                recip_one_minus_beta_1_power=recip_one_minus_beta_1_power,
                one_minus_beta_2_t=one_minus_beta_2_t,
                recip_one_minus_beta_2_power=recip_one_minus_beta_2_power,
                r_t=r_t,
                sma_t_ge_sma_threshold=sma_t_ge_sma_threshold,
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coef = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # m_t = beta1 * m + (1 - beta1) * g ; bias-corrected below.
        m_t = m.assign(
            coef["beta_1_t"] * m + coef["one_minus_beta_1_t"] * grad,
            use_locking=self._use_locking,
        )
        m_corr_t = m_t * coef["recip_one_minus_beta_1_power"]

        # v_t = beta2 * v + (1 - beta2) * g^2.
        v_t = v.assign(
            coef["beta_2_t"] * v + coef["one_minus_beta_2_t"] * tf.square(grad),
            use_locking=self._use_locking,
        )
        if self.amsgrad:
            # AMSGrad: use the running max of v for the denominator.
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t * coef["recip_one_minus_beta_2_power"])
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t * coef["recip_one_minus_beta_2_power"])

        # Rectified update when the variance estimate is reliable
        # (sma_t >= threshold); plain (SGD-with-momentum-like) step otherwise.
        var_t = tf.where(
            coef["sma_t_ge_sma_threshold"],
            coef["r_t"] * m_corr_t / (v_corr_t + coef["epsilon_t"]),
            m_corr_t,
        )

        # Decoupled weight decay (AdamW-style), applied before the lr scale.
        if self._has_weight_decay:
            var_t += coef["wd_t"] * var

        var_update = var.assign_sub(coef["lr_t"] * var_t, use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coef = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # Sparse moment update: scale the whole slot by beta, then
        # scatter-add the gradient contribution only at `indices`. The
        # control dependency orders the scatter after the scale.
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coef["one_minus_beta_1_t"]
        m_t = m.assign(m * coef["beta_1_t"], use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
        m_corr_t = m_t * coef["recip_one_minus_beta_1_power"]

        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coef["one_minus_beta_2_t"]
        v_t = v.assign(v * coef["beta_2_t"], use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
            v_corr_t = tf.sqrt(vhat_t * coef["recip_one_minus_beta_2_power"])
        else:
            vhat_t = None
            v_corr_t = tf.sqrt(v_t * coef["recip_one_minus_beta_2_power"])

        # Same rectified-vs-unrectified selection as the dense path.
        var_t = tf.where(
            coef["sma_t_ge_sma_threshold"],
            coef["r_t"] * m_corr_t / (v_corr_t + coef["epsilon_t"]),
            m_corr_t,
        )

        if self._has_weight_decay:
            var_t += coef["wd_t"] * var

        # Only the rows named by `indices` are written back; gather selects
        # the corresponding rows of the (negated, lr-scaled) update.
        with tf.control_dependencies([var_t]):
            var_update = self._resource_scatter_add(
                var, indices, tf.gather(-coef["lr_t"] * var_t, indices)
            )

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return tf.group(*updates)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "decay": self._serialize_hyperparameter("decay"),
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
                # total_steps was stored as float; round-trip it back to int.
                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                "warmup_proportion": self._serialize_hyperparameter(
                    "warmup_proportion"
                ),
                "min_lr": self._serialize_hyperparameter("min_lr"),
            }
        )
        return config