Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py: 38%
100 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""
import importlib.util

import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

from typeguard import typechecked
from typing import Union, Callable, Type, Optional, List

class DecoupledWeightDecayExtension:
    """This class allows extending optimizers with decoupled weight decay.

    It implements the decoupled weight decay described by [Loshchilov & Hutter]
    (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
    decoupled from the optimization steps w.r.t. the loss function.
    For SGD variants, this simplifies hyperparameter search since it decouples
    the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    This class alone is not an optimizer but rather extends existing
    optimizers with decoupled weight decay. We explicitly define the two
    examples used in the above paper (SGDW and AdamW), but in general this can
    extend any OptimizerX class by using
    `ExtendedCls = extend_with_decoupled_weight_decay(OptimizerX)`.
    Weight decay can then be set when instantiating the optimizer:
    `optimizerX = ExtendedCls(weight_decay=0.001, learning_rate=0.001)`.
    For this to work, it must be the first class that the optimizer with
    weight decay inherits from, e.g.

    ```python
    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
        def __init__(self, weight_decay, *args, **kwargs):
            super(AdamW, self).__init__(weight_decay, *args, **kwargs)
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of `var` in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        exclude_from_weight_decay: Optional[List[str]] = None,
        **kwargs,
    ):
        """Extension class that adds weight decay to an optimizer.

        Args:
            weight_decay: A `Tensor`, a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
                to decay the variable by, in the update step.
            exclude_from_weight_decay: List of regex patterns of
                variables excluded from weight decay. Variables whose name
                contains a substring matching the pattern will be excluded
                (see the sketch after this list).
                Note `decay_var_list` in `minimize` or `apply_gradients` takes
                priority over `exclude_from_weight_decay` if specified.
            **kwargs: Additional keyword arguments forwarded to the base
                optimizer's constructor (e.g. `clipnorm`, `clipvalue`, `lr`).
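
        As an illustrative sketch (the pattern strings below are assumptions,
        not part of this module), bias and layer-norm variables could be kept
        out of the decay like this:

        ```python
        optimizer = tfa.optimizers.AdamW(
            weight_decay=1e-4,
            learning_rate=1e-3,
            exclude_from_weight_decay=["bias", "layer_norm"],
        )
        ```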
        """
        wd = kwargs.pop("weight_decay", weight_decay)
        super().__init__(**kwargs)
        self._decay_var_list = None  # is set in minimize or apply_gradients
        self._set_hyper("weight_decay", wd)
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "exclude_from_weight_decay": self.exclude_from_weight_decay,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # LR handling copied from optimizer_v2.OptimizerV2
        if "learning_rate" in config:
            if isinstance(config["learning_rate"], dict):
                config["learning_rate"] = tf.keras.optimizers.schedules.deserialize(
                    config["learning_rate"], custom_objects=custom_objects
                )

        if "weight_decay" in config:
            if isinstance(config["weight_decay"], dict):
                config["weight_decay"] = tf.keras.optimizers.schedules.deserialize(
                    config["weight_decay"], custom_objects=custom_objects
                )

        return cls(**config)
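
    # A hedged round-trip sketch: `get_config`/`from_config` above let a
    # (possibly schedule-valued) `weight_decay` survive serialization, e.g.
    #   opt = AdamW(weight_decay=1e-4, learning_rate=1e-3)
    #   restored = AdamW.from_config(opt.get_config())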

    def minimize(
        self,
        loss,
        var_list,
        grad_loss=None,
        name=None,
        decay_var_list=None,
        tape=None,
    ):
        """Minimize `loss` by updating `var_list`.

        This method simply computes gradients using `tf.GradientTape` and calls
        `apply_gradients()`. If you want to process the gradients before
        applying them, call `tf.GradientTape` and `apply_gradients()` explicitly
        instead of using this function.

        Args:
            loss: `Tensor` or callable. If a callable, `loss` should take no
                arguments and return the value to minimize. If a `Tensor`, the
                `tape` argument must be passed.
            var_list: list or tuple of `Variable` objects to update to
                minimize `loss`, or a callable returning the list or tuple of
                `Variable` objects. Use a callable when the variable list would
                otherwise be incomplete before `minimize`, since the variables
                are created the first time `loss` is called.
            grad_loss: Optional. A `Tensor` holding the gradient computed for
                `loss`.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in `var_list`. Note `decay_var_list` takes
                priority over `exclude_from_weight_decay` if specified.
            name: Optional name for the returned operation.
            tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
                `Tensor`, the tape that computed the `loss` must be provided.
        Returns:
            An Operation that updates the variables in `var_list`.
        Raises:
            ValueError: If some of the variables are not `Variable` objects.
        """
        self._set_decay_var_list(var_list, decay_var_list)
        return super().minimize(
            loss, var_list=var_list, grad_loss=grad_loss, name=name, tape=tape
        )

    def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs):
        """Apply gradients to variables.

        This is the second part of `minimize()`. It returns an `Operation` that
        applies gradients.

        Args:
            grads_and_vars: List of (gradient, variable) pairs.
            name: Optional name for the returned operation. Defaults to the
                name passed to the `Optimizer` constructor.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in `grads_and_vars`. Note `decay_var_list`
                takes priority over `exclude_from_weight_decay` if specified.
            **kwargs: Additional arguments to pass to the base optimizer's
                apply_gradients method, e.g., TF 2.2 added an argument
                `experimental_aggregate_gradients`.
        Returns:
            An `Operation` that applies the specified gradients.
        Raises:
            TypeError: If `grads_and_vars` is malformed.
            ValueError: If none of the variables have gradients.
        """
        grads_and_vars = list(grads_and_vars)
        self._set_decay_var_list((v for _, v in grads_and_vars), decay_var_list)
        return super().apply_gradients(grads_and_vars, name=name, **kwargs)

    def _decay_weights_op(self, var, apply_state=None):
        if self._do_use_weight_decay(var):
            var_device, var_dtype = var.device, var.dtype.base_dtype
            coefficients = (apply_state or {}).get(
                (var_device, var_dtype)
            ) or self._fallback_apply_state(var_device, var_dtype)

            return var.assign_sub(coefficients["wd_t"] * var, self._use_locking)
        return tf.no_op()

    def _decay_weights_sparse_op(self, var, indices, apply_state=None):
        if self._do_use_weight_decay(var):
            var_device, var_dtype = var.device, var.dtype.base_dtype
            coefficients = (apply_state or {}).get(
                (var_device, var_dtype)
            ) or self._fallback_apply_state(var_device, var_dtype)

            update = -coefficients["wd_t"] * tf.gather(var, indices)
            return self._resource_scatter_add(var, indices, update)
        return tf.no_op()
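
    # The two decay ops above implement the decoupled update
    #     var <- var - wd_t * var
    # (restricted to the rows selected by `indices` in the sparse case); the
    # `_resource_apply_*` overrides below sequence this decay before the base
    # optimizer's own gradient step.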

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(DecoupledWeightDecayExtension, self)._prepare_local(
            var_device, var_dtype, apply_state
        )

        if "weight_decay" in self._hyper:
            wd_t = tf.identity(self._decayed_wd(var_dtype))
            apply_state[(var_device, var_dtype)]["wd_t"] = wd_t

    def _decayed_wd(self, var_dtype):
        wd_t = self._get_hyper("weight_decay", var_dtype)

        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
            wd_t = tf.cast(wd_t(self.iterations), var_dtype)

        return wd_t
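
    # Like the learning rate, `wd_t` is resolved per step: a plain
    # hyperparameter is simply cast to the variable dtype, while a
    # `LearningRateSchedule` is evaluated at the current `self.iterations`.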

    # Here, we overwrite the apply functions that the base optimizer calls.
    # super().apply_x resolves to the apply_x function of the BaseOptimizer.

    def _resource_apply_dense(self, grad, var, apply_state=None):
        with tf.control_dependencies(
            [self._decay_weights_op(var, apply_state=apply_state)]
        ):
            return super()._resource_apply_dense(grad, var, apply_state=apply_state)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        decay_op = self._decay_weights_sparse_op(var, indices, apply_state=apply_state)
        with tf.control_dependencies([decay_op]):
            return super()._resource_apply_sparse(
                grad, var, indices, apply_state=apply_state
            )

    def _set_decay_var_list(self, var_list, decay_var_list=None):
        if decay_var_list:
            self._decay_var_list = set(v.ref() for v in decay_var_list)
        elif self.exclude_from_weight_decay:
            self._decay_var_list = set(
                v.ref()
                for v in var_list
                if not is_variable_matched_by_regexes(v, self.exclude_from_weight_decay)
            )
        else:
            self._decay_var_list = None
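
    # Precedence: an explicit `decay_var_list` wins; otherwise the
    # `exclude_from_weight_decay` regexes filter `var_list`; otherwise
    # `_decay_var_list` stays `None` and every optimized variable is decayed.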

    def _do_use_weight_decay(self, var):
        """Whether to apply weight decay to `var`."""
        if self._decay_var_list is None:
            return True
        return var.ref() in self._decay_var_list


if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
    keras_legacy_optimizer = Union[
        tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer
    ]
else:
    keras_legacy_optimizer = tf.keras.optimizers.Optimizer
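
# Newer Keras releases (TF 2.11+) expose the previous optimizer implementation
# under `tf.keras.optimizers.legacy`; the Union above (and the ADAM_CLASS /
# SGD_CLASS aliases further down) keep this module working on both layouts.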


@typechecked
def extend_with_decoupled_weight_decay(
    base_optimizer: Type[keras_legacy_optimizer],
) -> Type[keras_legacy_optimizer]:
    """Factory function returning an optimizer class with decoupled weight
    decay.

    Returns an optimizer class. An instance of the returned class computes the
    update step of `base_optimizer` and additionally decays the weights.
    E.g., the class returned by
    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
    equivalent to `tfa.optimizers.AdamW`.

    The API of the new optimizer class slightly differs from the API of the
    base optimizer:
    - The first argument to the constructor is the weight decay rate.
    - Optional keyword argument `exclude_from_weight_decay` accepts a list of
      regex patterns of variables excluded from weight decay. Variables whose
      name contains a substring matching the pattern will be excluded.
    - `minimize` and `apply_gradients` accept the optional keyword argument
      `decay_var_list`, which specifies the variables that should be decayed.
      Note this takes priority over `exclude_from_weight_decay` if specified.
      If both are `None`, all variables that are optimized are decayed.

    Usage example:
    ```python
    # MyAdamW is a new class
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    # Create a MyAdamW object
    optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
    # update var1, var2 but only decay var1
    optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of `var` in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```

    Note: you might want to register your own custom optimizer using
    `tf.keras.utils.get_custom_objects()`.
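
    A minimal sketch of such a registration (the dictionary key shown here is
    just an illustrative choice):

    ```python
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    tf.keras.utils.get_custom_objects()["MyAdamW"] = MyAdamW
    ```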

    Args:
        base_optimizer: An optimizer class that inherits from
            tf.optimizers.Optimizer.

    Returns:
        A new optimizer class that inherits from DecoupledWeightDecayExtension
        and base_optimizer.
    """

    class OptimizerWithDecoupledWeightDecay(
        DecoupledWeightDecayExtension, base_optimizer
    ):
        """`base_optimizer` with decoupled weight decay.

        This class computes the update step of `base_optimizer` and
        additionally decays the variable with the weight decay being
        decoupled from the optimization steps w.r.t. the loss
        function, as described by [Loshchilov & Hutter]
        (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
        simplifies hyperparameter search since it decouples the settings
        of weight decay and learning rate. For adaptive gradient
        algorithms, it regularizes variables with large gradients more
        than L2 regularization would, which was shown to yield better
        training loss and generalization error in the paper above.
        """

        @typechecked
        def __init__(
            self,
            weight_decay: Union[FloatTensorLike, Callable],
            *args,
            **kwargs,
        ):
            # super delegation is necessary here
            super().__init__(weight_decay, *args, **kwargs)

    return OptimizerWithDecoupledWeightDecay


if hasattr(tf.keras.optimizers, "legacy"):
    ADAM_CLASS = tf.keras.optimizers.legacy.Adam
    SGD_CLASS = tf.keras.optimizers.legacy.SGD
else:
    ADAM_CLASS = tf.keras.optimizers.Adam
    SGD_CLASS = tf.keras.optimizers.SGD


@tf.keras.utils.register_keras_serializable(package="Addons")
class SGDW(DecoupledWeightDecayExtension, SGD_CLASS):
    """Optimizer that implements the Momentum algorithm with weight decay.

    This is an implementation of the SGDW optimizer described in "Decoupled
    Weight Decay Regularization" by [Loshchilov & Hutter]
    (https://arxiv.org/pdf/1711.05101.pdf).
    It computes the update step of `tf.keras.optimizers.SGD` and additionally
    decays the variable. Note that this is different from adding
    L2 regularization on the variables to the loss. Decoupling the weight decay
    from other hyperparameters (in particular the learning rate) simplifies
    hyperparameter search.

    For further information see the documentation of the SGD Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.SGDW(
        learning_rate=lr, weight_decay=wd, momentum=0.9)
    ```
    """

    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        momentum: Union[FloatTensorLike, Callable] = 0.0,
        nesterov: bool = False,
        name: str = "SGDW",
        **kwargs,
    ):
        """Construct a new SGDW optimizer.

        For further information see the documentation of the SGD Optimizer.

        Args:
            weight_decay: A `Tensor` or a floating point value. The weight
                decay.
            learning_rate: float hyperparameter >= 0. Learning rate.
            momentum: float hyperparameter >= 0 that accelerates SGD in the
                relevant direction and dampens oscillations.
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGDW'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` clips
                gradients by norm; `clipvalue` clips gradients by value.
                `decay` is included for backward compatibility to allow time
                inverse decay of learning rate. `lr` is included for backward
                compatibility; it is recommended to use `learning_rate` instead.
                `exclude_from_weight_decay` accepts a list of regex patterns of
                variables excluded from weight decay.
        """
        super().__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            name=name,
            **kwargs,
        )


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdamW(DecoupledWeightDecayExtension, ADAM_CLASS):
    """Optimizer that implements the Adam algorithm with weight decay.

    This is an implementation of the AdamW optimizer described in "Decoupled
    Weight Decay Regularization" by [Loshchilov & Hutter]
    (https://arxiv.org/pdf/1711.05101.pdf).

    It computes the update step of `tf.keras.optimizers.Adam` and additionally
    decays the variable. Note that this is different from adding L2
    regularization on the variables to the loss: it regularizes variables with
    large gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    For further information see the documentation of the Adam Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: Union[FloatTensorLike, Callable] = 0.9,
        beta_2: Union[FloatTensorLike, Callable] = 0.999,
        epsilon: FloatTensorLike = 1e-07,
        amsgrad: bool = False,
        name: str = "AdamW",
        **kwargs,
    ):
        """Construct a new AdamW optimizer.

        For further information see the documentation of the Adam Optimizer.

        Args:
            weight_decay: A Tensor or a floating point value. The weight decay.
            learning_rate: A Tensor or a floating point value. The learning
                rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. This epsilon is
                "epsilon hat" in the Kingma and Ba paper (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the
                paper.
            amsgrad: boolean. Whether to apply the AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                Beyond".
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdamW".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` clips
                gradients by norm; `clipvalue` clips gradients by value.
                `decay` is included for backward compatibility to allow time
                inverse decay of learning rate. `lr` is included for backward
                compatibility; it is recommended to use `learning_rate` instead.
                `exclude_from_weight_decay` accepts a list of regex patterns of
                variables excluded from weight decay.
        """
        super().__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs,
        )
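
# A minimal end-to-end sketch (the model and data names below are hypothetical,
# not part of this module):
#   model.compile(
#       optimizer=AdamW(weight_decay=1e-4, learning_rate=1e-3),
#       loss="sparse_categorical_crossentropy",
#   )
#   model.fit(x_train, y_train, epochs=3)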