Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/yogi.py: 10%
155 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Yogi: Extension of yogi adaptive nonconvex optimizer in Keras.
Implementation of Additive Averaging.
m_{t+1} = beta1 * m_t + (1 - beta1) * g_t
v_{t+1} = v_t + (1 - beta2) * sign(g_t^2 - v_t) * g_t^2
Experiments show better performance across NLP and Vision tasks.
Paper:
https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
23"""
25import tensorflow as tf
26from tensorflow_addons.utils.types import FloatTensorLike
28from tensorflow_addons.optimizers import KerasLegacyOptimizer
29from typeguard import typechecked
30from typing import Union, Callable
def _solve(a, b, c):
    """Solve a one-dimensional L1-regularized quadratic, elementwise.

    Minimizes, for every coordinate,
      f(w) = 1/2 * a * w^2 + b * w + c * |w|
    whose closed-form minimizer is the soft-thresholded value
      w* = -(b - sign(b) * c) / a   if |b| > c
      w* = 0                        otherwise

    REQUIRES: Dimensionality of a and b must be same

    Args:
      a: A Tensor
      b: A Tensor
      c: A Tensor with one element.

    Returns:
      A Tensor w, which is solution for the equation
    """
    # 1 where the linear term exceeds the threshold c, 0 elsewhere
    # (expressed in b's dtype so it can be used as a multiplicative mask).
    outside_threshold = tf.cast(tf.abs(b) > c, dtype=b.dtype)
    # Unconstrained shrinkage solution; only meaningful where the mask is 1.
    shrunk = (c * tf.sign(b) - b) / a
    return outside_threshold * shrunk
@tf.keras.utils.register_keras_serializable(package="Addons")
class Yogi(KerasLegacyOptimizer):
    """Optimizer that implements the Yogi algorithm in Keras.

    See Algorithm 2 of
    https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf.
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.01,
        beta1: FloatTensorLike = 0.9,
        beta2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-3,
        l1_regularization_strength: FloatTensorLike = 0.0,
        l2_regularization_strength: FloatTensorLike = 0.0,
        initial_accumulator_value: FloatTensorLike = 1e-6,
        activation: str = "sign",
        name: str = "Yogi",
        **kwargs,
    ):
        """Construct a new Yogi optimizer.

        Args:
          learning_rate: A Tensor or a floating point value.
            The learning rate.
          beta1: A float value or a constant float tensor.
            The exponential decay rate for the 1st moment estimates.
            When it equals 0 no first-moment slot is used and the raw
            gradient drives the update.
          beta2: A float value or a constant float tensor.
            The exponential decay rate for the 2nd moment estimates.
          epsilon: A constant trading off adaptivity and noise.
          l1_regularization_strength: A float value, must be greater than or
            equal to zero.
          l2_regularization_strength: A float value, must be greater than or
            equal to zero.
          initial_accumulator_value: The starting value for accumulators.
            Only positive values are allowed.
          activation: Use hard sign (`"sign"`) or soft tanh (`"tanh"`) to
            determine the sign. Any other value raises `NotImplementedError`
            when gradients are applied.
          name: Optional name for the operations created when applying
            gradients. Defaults to "Yogi".
          **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
            `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue`
            is clip gradients by value, `decay` is included for backward
            compatibility to allow time inverse decay of learning rate. `lr`
            is included for backward compatibility, recommended to use
            `learning_rate` instead.
        """
        super().__init__(name, **kwargs)
        # A legacy `lr` keyword, when present, takes precedence over
        # `learning_rate`.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("beta_1", beta1)
        self._set_hyper("beta_2", beta2)
        self._set_hyper("epsilon", epsilon)
        self._set_hyper("l1_regularization_strength", l1_regularization_strength)
        self._set_hyper("l2_regularization_strength", l2_regularization_strength)

        # Raw Python-side copies, used for branching while the update is built.
        self._beta1 = beta1
        self._activation = activation
        self._initial_accumulator_value = initial_accumulator_value
        self._l1_regularization_strength = l1_regularization_strength
        self._l2_regularization_strength = l2_regularization_strength

    def _create_slots(self, var_list):
        """See `tf.train.Optimizer._create_slots()`."""
        # Every variable gets a second-moment accumulator "v". The first-moment
        # accumulator "m" is created under exactly the condition the apply
        # methods use to read it (beta1 != 0), so the slot always exists when
        # the momentum path is taken.
        for var in var_list:
            init = tf.constant_initializer(self._initial_accumulator_value)
            self.add_slot(var, "v", init)
            if self._beta1 != 0.0:
                self.add_slot(var, "m")

    def _yogi_coefficients(self, var_dtype):
        """Fetch the per-step hyperparameters, cast to `var_dtype`.

        Args:
          var_dtype: dtype the hyperparameters are read/cast as.

        Returns:
          A tuple `(lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t)`, where `lr`
          is the decayed learning rate multiplied by
          `sqrt(1 - beta2^step) / (1 - beta1^step)` with
          `step = iterations + 1`.
        """
        lr_t = self._decayed_lr(var_dtype)
        beta1_t = self._get_hyper("beta_1", var_dtype)
        beta2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = self._get_hyper("epsilon", var_dtype)
        l1_t = self._get_hyper("l1_regularization_strength", var_dtype)
        l2_t = self._get_hyper("l2_regularization_strength", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta1_power = tf.pow(beta1_t, local_step)
        beta2_power = tf.pow(beta2_t, local_step)

        lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)
        return lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t

    def _yogi_sign(self, diff):
        """Return the hard (`sign`) or soft (`tanh(10 * x)`) sign of `diff`.

        Raises:
          NotImplementedError: if `activation` is neither "sign" nor "tanh".
        """
        if self._activation == "sign":
            return tf.sign(diff)
        if self._activation == "tanh":
            return tf.tanh(10 * diff)
        raise NotImplementedError("Activation function can be sign or tanh")

    def _yogi_prox(self, new_var, per_coord_lr, l1_t, l2_t):
        """Apply the L1/L2 proximal operator to a gradient-descent result.

        With L1 enabled the combined L1/L2 problem is solved via `_solve`;
        with only L2 enabled the value is shrunk multiplicatively; with
        neither enabled `new_var` is returned unchanged.
        """
        if self._l1_regularization_strength > 0:
            return _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
        if self._l2_regularization_strength > 0:
            return new_var / (1 + l2_t * per_coord_lr)
        return new_var

    def _resource_apply_dense(self, grad, var):
        """See `tf.train.Optimizer._apply_dense()`.

        Args:
          grad: A tensor holding the dense gradient for `var`.
          var: A `tf.Variable` object.

        Returns:
          An op which updates `var` and its slots.
        """
        var_dtype = var.dtype.base_dtype
        lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t = self._yogi_coefficients(
            var_dtype
        )

        update_vs = []
        if self._beta1 == 0.0:
            # No momentum: the raw gradient is the descent direction.
            direction = grad
        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
            m = self.get_slot(var, "m")
            m_t = m.assign(
                m * beta1_t + grad * (1 - beta1_t), use_locking=self._use_locking
            )
            direction = m_t
            update_vs.append(m_t)

        # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
        v = self.get_slot(var, "v")
        grad2 = grad * grad
        sign = self._yogi_sign(grad2 - v)
        v_t = v.assign_add(
            (1 - beta2_t) * sign * grad2, use_locking=self._use_locking
        )
        v_sqrt = tf.sqrt(v_t)

        # Yogi effective LR
        per_coord_lr = lr / (v_sqrt + epsilon_t)

        # Variable update
        # Step 1: Gradient descent
        new_var = var - per_coord_lr * direction
        # Step 2: Prox operator
        new_var = self._yogi_prox(new_var, per_coord_lr, l1_t, l2_t)
        # Step 3: Update
        var_update = var.assign(new_var, use_locking=self._use_locking)

        update_vs.append(var_update)
        update_vs.append(v_t)

        # Create an op that groups all the above operations
        return tf.group(*update_vs)

    def _resource_apply_sparse(self, grad, var, indices):
        """Applies sparse gradients to a variable.

        Args:
          grad: A tensor for the `values` of `tf.IndexedSlices`.
          var: A `tf.Variable` object.
          indices: A tensor for the `indices` of `tf.IndexedSlices`.

        Returns:
          An op which updates `var` with `grad` and `indices`.
        """
        var_dtype = var.dtype.base_dtype
        lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t = self._yogi_coefficients(
            var_dtype
        )

        update_vs = []
        if self._beta1 == 0.0:
            # No momentum: the raw gradient slices are the descent direction.
            direction = grad
        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
            # The whole accumulator is decayed first; the scaled gradient is
            # then added only on the rows selected by `indices`.
            m = self.get_slot(var, "m")
            m_scaled_g_values = grad * (1 - beta1_t)
            m_t = m.assign(m * beta1_t, use_locking=self._use_locking)
            with tf.control_dependencies([m_t]):
                m_slice = tf.gather(m, indices) + m_scaled_g_values
                m_t = self._resource_scatter_update(m, indices, m_slice)
            direction = m_slice
            update_vs.append(m_t)

        # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2, on the gathered rows
        v = self.get_slot(var, "v")
        grad2 = grad * grad
        v_slice = tf.gather(v, indices)
        sign = self._yogi_sign(grad2 - v_slice)
        v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
        v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
        v_sqrt = tf.sqrt(v_scaled_g_values)

        # Yogi effective LR
        per_coord_lr = lr / (v_sqrt + epsilon_t)

        # Variable update
        # Step 1: Gradient descent
        var_slice = tf.gather(var, indices)
        new_var = var_slice - per_coord_lr * direction
        # Step 2: Prox operator
        new_var = self._yogi_prox(new_var, per_coord_lr, l1_t, l2_t)
        # Step 3: Update
        var_update = self._resource_scatter_update(var, indices, new_var)

        update_vs.append(var_update)
        update_vs.append(v_t)

        # Create an op that groups all the above operations
        return tf.group(*update_vs)

    def get_config(self):
        """Return the optimizer configuration as a serializable dict.

        The keys added here mirror the constructor's argument names
        (`beta1`, `beta2`, ...), on top of the base-class config.
        """
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "decay": self._serialize_hyperparameter("decay"),
                "beta1": self._serialize_hyperparameter("beta_1"),
                "beta2": self._serialize_hyperparameter("beta_2"),
                "epsilon": self._serialize_hyperparameter("epsilon"),
                "l1_regularization_strength": self._serialize_hyperparameter(
                    "l1_regularization_strength"
                ),
                "l2_regularization_strength": self._serialize_hyperparameter(
                    "l2_regularization_strength"
                ),
                "activation": self._activation,
                "initial_accumulator_value": self._initial_accumulator_value,
            }
        )
        return config