Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/lion.py: 31%
42 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lion optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export("keras.optimizers.Lion", v1=[])
class Lion(optimizer.Optimizer):
29 """Optimizer that implements the Lion algorithm.
31 The Lion optimizer is a stochastic-gradient-descent method that uses the
32 sign operator to control the magnitude of the update, unlike other adaptive
33 optimizers such as Adam that rely on second-order moments. This make
34 Lion more memory-efficient as it only keeps track of the momentum. According
35 to the authors (see reference), its performance gain over Adam grows with
36 the batch size. Because the update of Lion is produced through the sign
37 operation, resulting in a larger norm, a suitable learning rate for Lion is
38 typically 3-10x smaller than that for AdamW. The weight decay for Lion
39 should be in turn 3-10x larger than that for AdamW to maintain a
40 similar strength (lr * wd).
42 Args:
43 learning_rate: A `tf.Tensor`, floating point value, a schedule that is a
44 `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
45 that takes no arguments and returns the actual value to use. The
46 learning rate. Defaults to 0.0001.
47 beta_1: A float value or a constant float tensor, or a callable
48 that takes no arguments and returns the actual value to use. The rate
49 to combine the current gradient and the 1st moment estimate.
50 beta_2: A float value or a constant float tensor, or a callable
51 that takes no arguments and returns the actual value to use. The
52 exponential decay rate for the 1st moment estimate.
53 {{base_optimizer_keyword_args}}
55 References:
56 - [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
57 - [Authors' implementation](
58 http://github.com/google/automl/tree/master/lion)
60 """

    def __init__(
        self,
        learning_rate=0.0001,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="Lion",
        **kwargs,
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs,
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        if beta_1 <= 0 or beta_1 > 1:
            raise ValueError(
                f"`beta_1`={beta_1} must be between ]0, 1]. Otherwise, "
                "the optimizer degenerates to SignSGD."
            )

    def build(self, var_list):
        """Initialize optimizer variables.

        Lion optimizer has one variable `momentums`.

        Args:
            var_list: list of model variables to build Lion variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self.momentums = []
        for var in var_list:
            self.momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
        self._built = True

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        beta_1 = tf.cast(self.beta_1, variable.dtype)
        beta_2 = tf.cast(self.beta_2, variable.dtype)
        var_key = self._var_key(variable)
        m = self.momentums[self._index_dict[var_key]]

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients (use m as a buffer)
            m.assign(m * beta_1)
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1.0 - beta_1), gradient.indices
                )
            )
            variable.assign_sub(lr * tf.math.sign(m))
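
            # At this point m holds beta_1 * m_old + (1 - beta_1) * g (the
            # value whose sign was just applied). Rescaling it by
            # beta_2 / beta_1 and scatter-adding (1 - beta_2 / beta_1) * g
            # yields the new momentum beta_2 * m_old + (1 - beta_2) * g
            # without needing a second buffer.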
            m.assign(m * beta_2 / beta_1)
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1.0 - beta_2 / beta_1), gradient.indices
                )
            )
        else:
            # Dense gradients
            variable.assign_sub(
                lr * tf.math.sign(m * beta_1 + gradient * (1.0 - beta_1))
            )
            m.assign(m * beta_2 + gradient * (1.0 - beta_2))

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config


Lion.__doc__ = Lion.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
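

# A minimal usage sketch (illustrative only, not part of the original module):
# builds a tiny Keras model and runs a single training epoch with Lion. The
# model, data shapes, and hyperparameters are arbitrary assumptions chosen for
# the example.
if __name__ == "__main__":
    import numpy as np

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    # Per the docstring above, Lion's learning rate is typically 3-10x smaller
    # than the rate one would use with AdamW.
    model.compile(optimizer=Lion(learning_rate=1e-4), loss="mse")
    x = np.random.rand(32, 4).astype("float32")
    y = np.random.rand(32, 1).astype("float32")
    model.fit(x, y, epochs=1, verbose=0)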