# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lion optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export("keras.optimizers.Lion", v1=[])
class Lion(optimizer.Optimizer):

29 """Optimizer that implements the Lion algorithm. 

30 

31 The Lion optimizer is a stochastic-gradient-descent method that uses the 

32 sign operator to control the magnitude of the update, unlike other adaptive 

33 optimizers such as Adam that rely on second-order moments. This make 

34 Lion more memory-efficient as it only keeps track of the momentum. According 

35 to the authors (see reference), its performance gain over Adam grows with 

36 the batch size. Because the update of Lion is produced through the sign 

37 operation, resulting in a larger norm, a suitable learning rate for Lion is 

38 typically 3-10x smaller than that for AdamW. The weight decay for Lion 

39 should be in turn 3-10x larger than that for AdamW to maintain a 

40 similar strength (lr * wd). 

41 

42 Args: 

43 learning_rate: A `tf.Tensor`, floating point value, a schedule that is a 

44 `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable 

45 that takes no arguments and returns the actual value to use. The 

46 learning rate. Defaults to 0.0001. 

47 beta_1: A float value or a constant float tensor, or a callable 

48 that takes no arguments and returns the actual value to use. The rate 

49 to combine the current gradient and the 1st moment estimate. 

50 beta_2: A float value or a constant float tensor, or a callable 

51 that takes no arguments and returns the actual value to use. The 

52 exponential decay rate for the 1st moment estimate. 

53 {{base_optimizer_keyword_args}} 

54 

55 References: 

56 - [Chen et al., 2023](http://arxiv.org/abs/2302.06675) 

57 - [Authors' implementation]( 

58 http://github.com/google/automl/tree/master/lion) 
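
    Usage (an illustrative sketch; it assumes the `minimize` API inherited
    from the base `Optimizer` class, and the names `opt`, `var`, and `loss`
    are only examples):

    ```python
    import tensorflow as tf

    opt = tf.keras.optimizers.Lion(learning_rate=0.001)
    var = tf.Variable(10.0)
    loss = lambda: (var**2) / 2.0  # d(loss)/d(var) == var
    opt.minimize(loss, [var])
    # The first step moves `var` by exactly lr * sign(gradient): the momentum
    # slot starts at zero, so the direction is sign(10.0) == 1 and the new
    # value is 10.0 - 0.001 = 9.999.
    ```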


    """


    def __init__(
        self,
        learning_rate=0.0001,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="Lion",
        **kwargs,
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs,
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
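        # With beta_1 == 0 the update direction would reduce to sign(gradient)
        # alone (plain SignSGD), hence the ]0, 1] constraint enforced below.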

        if beta_1 <= 0 or beta_1 > 1:
            raise ValueError(
                f"`beta_1`={beta_1} must be between ]0, 1]. Otherwise, "
                "the optimizer degenerates to SignSGD."
            )


    def build(self, var_list):
        """Initialize optimizer variables.

        The Lion optimizer has one variable, `momentums`.

        Args:
            var_list: list of model variables to build Lion variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self.momentums = []
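        # One momentum slot ("m") per model variable; this is Lion's only
        # per-parameter state (Adam additionally stores a second moment).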

        for var in var_list:
            self.momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
        self._built = True


    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        beta_1 = tf.cast(self.beta_1, variable.dtype)
        beta_2 = tf.cast(self.beta_2, variable.dtype)
        var_key = self._var_key(variable)
        m = self.momentums[self._index_dict[var_key]]

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients (use m as a buffer)
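            # Buffer trick: `m` is reused for both the update direction and
            # the new momentum. With m0 the incoming momentum and g the
            # gradient,
            #   stage 1: m <- beta_1 * m0 + (1 - beta_1) * g     (direction)
            #   stage 2: m <- (beta_2 / beta_1) * m + (1 - beta_2 / beta_1) * g
            #              = beta_2 * m0 + (1 - beta_2) * g
            # which matches the momentum rule of the dense branch without
            # allocating a second slot variable.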

            m.assign(m * beta_1)
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1.0 - beta_1), gradient.indices
                )
            )
            variable.assign_sub(lr * tf.math.sign(m))

            m.assign(m * beta_2 / beta_1)
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1.0 - beta_2 / beta_1), gradient.indices
                )
            )
        else:
            # Dense gradients
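            # Lion update rule:
            #   variable <- variable - lr * sign(beta_1 * m + (1 - beta_1) * g)
            #   m        <- beta_2 * m + (1 - beta_2) * g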

            variable.assign_sub(
                lr * tf.math.sign(m * beta_1 + gradient * (1.0 - beta_1))
            )
            m.assign(m * beta_2 + gradient * (1.0 - beta_2))


    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config
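
    # Note (illustrative, assuming the standard Keras serialization path):
    # this config round-trips through the base class's `from_config`, e.g.
    # `Lion.from_config(opt.get_config())` rebuilds an optimizer with the
    # same learning_rate, beta_1, and beta_2.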



Lion.__doc__ = Lion.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)