Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/ftrl.py: 23%

57 statements, coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FTRL optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export(
    "keras.optimizers.experimental.Ftrl", "keras.optimizers.Ftrl", v1=[]
)
class Ftrl(optimizer.Optimizer):
    r"""Optimizer that implements the FTRL algorithm.

    "Follow The Regularized Leader" (FTRL) is an optimization algorithm
    developed at Google for click-through rate prediction in the early
    2010s. It is most suitable for shallow models with large and sparse
    feature spaces. The algorithm is described by
    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
    The Keras version supports both online L2 regularization (the L2
    regularization described in the paper above) and shrinkage-type L2
    regularization (the addition of an L2 penalty to the loss function).

    Initialization:

    ```python
    n = 0
    sigma = 0
    z = 0
    ```

    Update rule for one variable `w`:

    ```python
    prev_n = n
    n = n + g ** 2
    sigma = (n ** -lr_power - prev_n ** -lr_power) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
        w = 0
    else:
        w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
    ```

    Notation:

    - `lr` is the learning rate
    - `g` is the gradient for the variable
    - `lambda_1` is the L1 regularization strength
    - `lambda_2` is the L2 regularization strength
    - `lr_power` is the power used to scale `n`
    - `alpha` is the paper's name for the learning rate `lr`, and `beta` is
      the `beta` argument below

    See the documentation for the `l2_shrinkage_regularization_strength`
    parameter for more details on shrinkage; when it is enabled, the gradient
    `g` is replaced with a shrinkage-adjusted gradient.
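
    Usage (a minimal sketch; the toy quadratic loss below is illustrative
    only):

    ```python
    opt = tf.keras.optimizers.Ftrl(learning_rate=0.001)
    var = tf.Variable(1.0)
    loss = lambda: (var ** 2) / 2.0  # d(loss) / d(var) == var
    # Run a single FTRL step on `var`.
    opt.minimize(loss, [var])
    ```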

    Args:
        learning_rate: A `Tensor`, floating point value, a schedule that is a
            `tf.keras.optimizers.schedules.LearningRateSchedule`, or a
            callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        learning_rate_power: A float value, must be less than or equal to
            zero. Controls how the learning rate decreases during training.
            Use zero for a fixed learning rate.
        initial_accumulator_value: The starting value for accumulators. Only
            zero or positive values are allowed.
        l1_regularization_strength: A float value, must be greater than or
            equal to zero. Defaults to `0.0`.
        l2_regularization_strength: A float value, must be greater than or
            equal to zero. Defaults to `0.0`.
        l2_shrinkage_regularization_strength: A float value, must be greater
            than or equal to zero. This differs from L2 above in that the L2
            above is a stabilization penalty, whereas this L2 shrinkage is a
            magnitude penalty. When input is sparse, shrinkage will only
            happen on the active weights.
        beta: A float value, representing the beta value from the paper.
            Defaults to `0.0`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.001,
        learning_rate_power=-0.5,
        initial_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0,
        l2_shrinkage_regularization_strength=0.0,
        beta=0.0,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="Ftrl",
        **kwargs,
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs,
        )

        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` needs to be positive or zero. "
                "Received: initial_accumulator_value="
                f"{initial_accumulator_value}."
            )
        if learning_rate_power > 0.0:
            raise ValueError(
                "`learning_rate_power` needs to be negative or zero. "
                f"Received: learning_rate_power={learning_rate_power}."
            )
        if l1_regularization_strength < 0.0:
            raise ValueError(
                "`l1_regularization_strength` needs to be positive or zero. "
                "Received: l1_regularization_strength="
                f"{l1_regularization_strength}."
            )
        if l2_regularization_strength < 0.0:
            raise ValueError(
                "`l2_regularization_strength` needs to be positive or zero. "
                "Received: l2_regularization_strength="
                f"{l2_regularization_strength}."
            )
        if l2_shrinkage_regularization_strength < 0.0:
            raise ValueError(
                "`l2_shrinkage_regularization_strength` needs to be positive "
                "or zero. Received: l2_shrinkage_regularization_strength"
                f"={l2_shrinkage_regularization_strength}."
            )

        self._learning_rate = self._build_learning_rate(learning_rate)
        self.learning_rate_power = learning_rate_power
        self.initial_accumulator_value = initial_accumulator_value
        self.l1_regularization_strength = l1_regularization_strength
        self.l2_regularization_strength = l2_regularization_strength
        self.l2_shrinkage_regularization_strength = (
            l2_shrinkage_regularization_strength
        )
        self.beta = beta

    def build(self, var_list):
        """Initialize optimizer variables.

        Args:
            var_list: list of model variables to build Ftrl variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self._accumulators = []
        self._linears = []
        for var in var_list:
            self._accumulators.append(
                self.add_variable_from_reference(
                    model_variable=var,
                    variable_name="accumulator",
                    initial_value=tf.cast(
                        tf.fill(
                            dims=var.shape, value=self.initial_accumulator_value
                        ),
                        dtype=var.dtype,
                    ),
                )
            )

            self._linears.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="linear"
                )
            )
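            # Note: `linear` (the docstring's `z` accumulator) is created
            # without an explicit initial value, so it starts at zero.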

        self._built = True

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""

        lr = tf.cast(self.learning_rate, variable.dtype)
        var_key = self._var_key(variable)
        accum = self._accumulators[self._index_dict[var_key]]
        linear = self._linears[self._index_dict[var_key]]

        lr_power = self.learning_rate_power
        l2_reg = self.l2_regularization_strength
        l2_reg = l2_reg + self.beta / (2.0 * lr)
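        # Folding `beta` into the L2 term: through the `2 * l2_reg` part of
        # `quadratic` below, this contributes `beta / lr`, i.e. the
        # `beta / alpha` piece of the paper's `(beta + sqrt(n)) / alpha`
        # denominator.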

        # FTRL uses the same implementation for sparse and dense gradient
        # updates.
        grad_to_use = (
            gradient + 2 * self.l2_shrinkage_regularization_strength * variable
        )
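        # The added term is the gradient of the shrinkage penalty
        # `l2_shrinkage * ||w||**2`; it only flows into `linear`, while the
        # accumulator below keeps using the raw gradient.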

        new_accum = accum + tf.pow(gradient, 2)
        linear.assign_add(
            grad_to_use
            - (tf.pow(new_accum, -lr_power) - tf.pow(accum, -lr_power))
            / lr
            * variable
        )
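        # The assignment above is the docstring's `z = z + g - sigma * w`,
        # with `sigma = (n ** -lr_power - prev_n ** -lr_power) / lr`.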

        quadratic = tf.pow(new_accum, -lr_power) / lr + 2 * l2_reg

        linear_clipped = tf.clip_by_value(
            linear,
            -self.l1_regularization_strength,
            self.l1_regularization_strength,
        )
        variable.assign((linear_clipped - linear) / quadratic)
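        # When `abs(linear) <= l1`, `linear_clipped - linear` is zero and the
        # weight is zeroed out; otherwise this reproduces the docstring's
        # closed form `(sgn(z) * lambda_1 - z) / quadratic`.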

        accum.assign(new_accum)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "learning_rate_power": self.learning_rate_power,
                "initial_accumulator_value": self.initial_accumulator_value,
                "l1_regularization_strength": self.l1_regularization_strength,
                "l2_regularization_strength": self.l2_regularization_strength,
                "l2_shrinkage_regularization_strength": self.l2_shrinkage_regularization_strength,  # noqa: E501
                "beta": self.beta,
            }
        )
        return config


Ftrl.__doc__ = Ftrl.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)