Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/adamw.py: 20% of 65 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AdamW optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export(
    "keras.optimizers.AdamW",
    "keras.optimizers.experimental.AdamW",
    "keras.dtensor.experimental.optimizers.AdamW",
    v1=[],
)
class AdamW(optimizer.Optimizer):
35 r"""Optimizer that implements the AdamW algorithm.
37 AdamW optimization is a stochastic gradient descent method that is based on
38 adaptive estimation of first-order and second-order moments with an added
39 method to decay weights per the techniques discussed in the paper,
40 'Decoupled Weight Decay Regularization' by
41 [Loshchilov, Hutter et al., 2019](https://arxiv.org/abs/1711.05101).
43 According to
44 [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
45 the underying Adam method is "*computationally
46 efficient, has little memory requirement, invariant to diagonal rescaling of
47 gradients, and is well suited for problems that are large in terms of
48 data/parameters*".
50 Args:
51 learning_rate: A `tf.Tensor`, floating point value, a schedule that is a
52 `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
53 that takes no arguments and returns the actual value to use. The
54 learning rate. Defaults to 0.001.
55 beta_1: A float value or a constant float tensor, or a callable
56 that takes no arguments and returns the actual value to use. The
57 exponential decay rate for the 1st moment estimates. Defaults to 0.9.
58 beta_2: A float value or a constant float tensor, or a callable
59 that takes no arguments and returns the actual value to use. The
60 exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
61 epsilon: A small constant for numerical stability. This epsilon is
62 "epsilon hat" in the Kingma and Ba paper (in the formula just before
63 Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
64 1e-7.
65 amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from
66 the paper "On the Convergence of Adam and beyond". Defaults to `False`.
67 {{base_optimizer_keyword_args}}
69 Reference:
70 - [Loshchilov et al., 2019](https://arxiv.org/abs/1711.05101)
71 - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam`
72 - [Reddi et al., 2018](
73 https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.
75 Notes:
77 The sparse implementation of this algorithm (used when the gradient is an
78 IndexedSlices object, typically because of `tf.gather` or an embedding
79 lookup in the forward pass) does apply momentum to variable slices even if
80 they were not used in the forward pass (meaning they have a gradient equal
81 to zero). Momentum decay (beta1) is also applied to the entire momentum
82 accumulator. This means that the sparse behavior is equivalent to the dense
83 behavior (in contrast to some momentum implementations which ignore momentum
84 unless a variable slice was actually used).
85 """

    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="AdamW",
        **kwargs
    ):
        super().__init__(
            name=name,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.weight_decay = weight_decay
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.amsgrad = amsgrad

        if self.weight_decay is None:
            raise ValueError(
                "Missing value of `weight_decay` which is required and"
                " must be a float value."
            )
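
        # Note: the decoupled weight decay itself is handled by the shared
        # `optimizer.Optimizer` base class rather than in `update_step` below.
        # A rough sketch of that decoupled rule (per Loshchilov & Hutter,
        # referenced in the class docstring; its exact placement in the base
        # class is an assumption here):
        #
        #     variable -= learning_rate * weight_decay * variable
        #
        # e.g. with learning_rate=0.001 and weight_decay=0.004, a weight of
        # 1.0 shrinks by 0.000004 each step, independently of the Adam update.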

    def build(self, var_list):
        """Initialize optimizer variables.

        AdamW optimizer has 3 types of variables: momentums, velocities, and
        velocity_hat (only set when amsgrad is applied).

        Args:
            var_list: list of model variables to build AdamW variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self._built = True
        self._momentums = []
        self._velocities = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="v"
                )
            )
        if self.amsgrad:
            self._velocity_hats = []
            for var in var_list:
                self._velocity_hats.append(
                    self.add_variable_from_reference(
                        model_variable=var, variable_name="vhat"
                    )
                )
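
    # Sketch of the state `build` creates, assuming a model with two variables
    # and `amsgrad=True` (the slot names below are illustrative):
    #
    #     self._momentums      -> [m_0, m_1]        one "m" slot per variable
    #     self._velocities     -> [v_0, v_1]        one "v" slot per variable
    #     self._velocity_hats  -> [vhat_0, vhat_1]  only when amsgrad is set
    #
    # Each slot is a new optimizer variable with the same shape and dtype as
    # the model variable it was created from.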

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        beta_1_power = None
        beta_2_power = None
        lr = tf.cast(self.learning_rate, variable.dtype)
        local_step = tf.cast(self.iterations + 1, variable.dtype)
        beta_1_power = tf.pow(tf.cast(self.beta_1, variable.dtype), local_step)
        beta_2_power = tf.pow(tf.cast(self.beta_2, variable.dtype), local_step)

        var_key = self._var_key(variable)
        m = self._momentums[self._index_dict[var_key]]
        v = self._velocities[self._index_dict[var_key]]

        alpha = lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            m.assign_add(-m * (1 - self.beta_1))
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1 - self.beta_1), gradient.indices
                )
            )
            v.assign_add(-v * (1 - self.beta_2))
            v.scatter_add(
                tf.IndexedSlices(
                    tf.square(gradient.values) * (1 - self.beta_2),
                    gradient.indices,
                )
            )
            if self.amsgrad:
                v_hat = self._velocity_hats[self._index_dict[var_key]]
                v_hat.assign(tf.maximum(v_hat, v))
                v = v_hat
            variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
        else:
            # Dense gradients.
            m.assign_add((gradient - m) * (1 - self.beta_1))
            v.assign_add((tf.square(gradient) - v) * (1 - self.beta_2))
            if self.amsgrad:
                v_hat = self._velocity_hats[self._index_dict[var_key]]
                v_hat.assign(tf.maximum(v_hat, v))
                v = v_hat
            variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
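
    # The dense branch above is equivalent to this NumPy sketch (illustrative;
    # bias correction is folded into `alpha` exactly as in the code, and the
    # decoupled weight decay is handled separately by the base class):
    #
    #     import numpy as np
    #
    #     def dense_adam_step(theta, g, m, v, t,
    #                         lr=0.001, b1=0.9, b2=0.999, eps=1e-7):
    #         m += (g - m) * (1 - b1)                  # 1st-moment EMA
    #         v += (g ** 2 - v) * (1 - b2)             # 2nd-moment EMA
    #         alpha = lr * np.sqrt(1 - b2 ** t) / (1 - b1 ** t)
    #         theta -= m * alpha / (np.sqrt(v) + eps)  # bias-corrected update
    #         return theta, m, v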

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "weight_decay": self.weight_decay,
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
                "amsgrad": self.amsgrad,
            }
        )
        return config
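
# Serialization round-trip sketch (illustrative): `get_config` returns a plain
# dict of hyperparameters, so an equivalent optimizer can be rebuilt with the
# inherited `from_config` classmethod:
#
#     opt = AdamW(learning_rate=1e-3, weight_decay=4e-3, amsgrad=True)
#     clone = AdamW.from_config(opt.get_config())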


AdamW.__doc__ = AdamW.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)