Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/lamb.py: 19%
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer-wise Adaptive Moments (LAMB) optimizer.

See paper [Large Batch Optimization for Deep Learning: Training BERT in
76 minutes](https://arxiv.org/abs/1904.00962).
"""

import warnings

from typing import Optional, Union, Callable, List
from typeguard import typechecked

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


@tf.keras.utils.register_keras_serializable(package="Addons")
class LAMB(KerasLegacyOptimizer):
    """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

    See paper [Large Batch Optimization for Deep Learning: Training BERT
    in 76 minutes](https://arxiv.org/abs/1904.00962).
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-6,
        weight_decay: FloatTensorLike = 0.0,
        exclude_from_weight_decay: Optional[List[str]] = None,
        exclude_from_layer_adaptation: Optional[List[str]] = None,
        name: str = "LAMB",
        **kwargs,
    ):
53 """Construct a new LAMB optimizer.
55 Args:
56 learning_rate: A `Tensor` or a floating point value. or a schedule
57 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
58 The learning rate.
59 beta_1: A `float` value or a constant `float` tensor.
60 The exponential decay rate for the 1st moment estimates.
61 beta_2: A `float` value or a constant `float` tensor.
62 The exponential decay rate for the 2nd moment estimates.
63 epsilon: A small constant for numerical stability.
64 weight_decay: weight decay.
65 exclude_from_weight_decay: List of regex patterns of
66 variables excluded from weight decay. Variables whose name
67 contain a substring matching the pattern will be excluded.
68 exclude_from_layer_adaptation: List of regex patterns of
69 variables excluded from layer adaptation. Variables whose name
70 contain a substring matching the pattern will be excluded.
71 name: Optional name for the operations created when applying
72 gradients. Defaults to "LAMB".
73 **kwargs: keyword arguments. Allowed to be {`clipnorm`,
74 `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
75 norm; `clipvalue` is clip gradients by value, `decay` is
76 included for backward compatibility to allow time inverse
77 decay of learning rate. `lr` is included for backward
78 compatibility, recommended to use `learning_rate` instead.
79 """
81 if "weight_decay_rate" in kwargs:
82 warnings.warn(
83 "weight_decay_rate has been renamed to weight_decay,"
84 "and will be deprecated in Addons 0.18.",
85 DeprecationWarning,
86 )
87 weight_decay = kwargs["weight_decay_rate"]
88 del kwargs["weight_decay_rate"]
90 super().__init__(name, **kwargs)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters.
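        #
        # Concretely (see `_resource_apply_dense` below), the decay term is
        # added to the bias-corrected Adam update after the moment estimates
        # are formed, rather than being folded into the loss:
        #
        #     update = m_hat / (sqrt(v_hat) + epsilon) + weight_decay * var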
98 self._set_hyper("weight_decay", weight_decay)
99 self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
101 # This is learning rate decay for using keras learning rate schedule.
102 self._set_hyper("decay", self._initial_decay)
103 self._set_hyper("beta_1", beta_1)
104 self._set_hyper("beta_2", beta_2)
105 self.epsilon = epsilon or tf.backend_config.epsilon()
106 self.exclude_from_weight_decay = exclude_from_weight_decay
107 # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
108 # the arg is None.
109 if exclude_from_layer_adaptation:
110 self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
111 else:
112 self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def _create_slots(self, var_list):
        # Create slots for the first and second moments.
        # Separate for-loops to respect the ordering of slot variables from v1.
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super()._prepare_local(var_device, var_dtype, apply_state)

        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
        beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
        weight_decay = tf.identity(self._get_hyper("weight_decay", var_dtype))
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
        apply_state[(var_device, var_dtype)].update(
            dict(
                weight_decay=weight_decay,
                epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
                beta_1_t=beta_1_t,
                beta_1_power=beta_1_power,
                one_minus_beta_1_t=1 - beta_1_t,
                beta_2_t=beta_2_t,
                beta_2_power=beta_2_power,
                one_minus_beta_2_t=1 - beta_2_t,
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
        m_t = m.assign(m_t, use_locking=self._use_locking)
        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
        v_t = v.assign(v_t, use_locking=self._use_locking)

        m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
        v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + coefficients["epsilon"])

        if self._do_use_weight_decay(var):
            update += coefficients["weight_decay"] * var
        ratio = 1.0
        if self._do_layer_adaptation(var):
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0,
            )

        var_update = var - ratio * coefficients["lr_t"] * update
        return var.assign(var_update, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)
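
        # Sparse path: each moment slot is first decayed in place, then the new
        # gradient contribution is scatter-added only at the updated `indices`;
        # the control dependencies keep the two steps ordered.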
        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
        v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + coefficients["epsilon"])

        if self._do_use_weight_decay(var):
            update += coefficients["weight_decay"] * var

        ratio = 1.0
        if self._do_layer_adaptation(var):
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0,
            )

        var_update = var.assign_sub(
            ratio * coefficients["lr_t"] * update, use_locking=self._use_locking
        )
        return tf.group(*[var_update, m_t, v_t])

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "decay": self._serialize_hyperparameter("decay"),
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "epsilon": self.epsilon,
                "exclude_from_weight_decay": self.exclude_from_weight_decay,
                "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation,
            }
        )
        return config
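
    # Serialization sketch (standard Keras behavior): because the class is
    # registered via `register_keras_serializable(package="Addons")`, an
    # optimizer can be rebuilt from its config, e.g.
    # `LAMB.from_config(opt.get_config())`.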

    def _do_use_weight_decay(self, variable):
        """Whether to use L2 weight decay for `variable`."""
        return not is_variable_matched_by_regexes(
            variable, self.exclude_from_weight_decay
        )

    def _do_layer_adaptation(self, variable):
        """Whether to do layer-wise learning rate adaptation for
        `variable`."""
        return not is_variable_matched_by_regexes(
            variable, self.exclude_from_layer_adaptation
        )