Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/ftrl.py: 21%

57 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Ftrl-proximal optimizer implementation.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.optimizers.legacy import optimizer_v2 

21 

22# isort: off 

23from tensorflow.python.util.tf_export import keras_export 

24 

25 

@keras_export(
    "keras.optimizers.legacy.Ftrl",
    v1=["keras.optimizers.Ftrl", "keras.optimizers.legacy.Ftrl"],
)
class Ftrl(optimizer_v2.OptimizerV2):
    r"""Optimizer that implements the FTRL algorithm.

    "Follow The Regularized Leader" (FTRL) is an optimization algorithm
    developed at Google for click-through rate prediction in the early 2010s. It
    is most suitable for shallow models with large and sparse feature spaces.
    The algorithm is described by
    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
    The Keras version has support for both online L2 regularization
    (the L2 regularization described in the paper
    above) and shrinkage-type L2 regularization
    (which is the addition of an L2 penalty to the loss function).

    Initialization:

    ```python
    n = 0
    sigma = 0
    z = 0
    ```

    Update rule for one variable `w`:

    ```python
    prev_n = n
    n = n + g ** 2
    sigma = (sqrt(n) - sqrt(prev_n)) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
      w = 0
    else:
      w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
    ```

    Notation:

    - `lr` is the learning rate
    - `g` is the gradient for the variable
    - `lambda_1` is the L1 regularization strength
    - `lambda_2` is the L2 regularization strength

    Check the documentation for the `l2_shrinkage_regularization_strength`
    parameter for more details when shrinkage is enabled, in which case gradient
    is replaced with a gradient with shrinkage.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate.
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero. Defaults to `0.0`.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero. Defaults to `0.0`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to `"Ftrl"`.
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
        When input is sparse shrinkage will only happen on the active weights.
      beta: A float value, representing the beta value from the paper.
        Defaults to `0.0`.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Reference:
      - [McMahan et al., 2013](
        https://research.google.com/pubs/archive/41159.pdf)
    """

    def __init__(
        self,
        learning_rate=0.001,
        learning_rate_power=-0.5,
        initial_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0,
        name="Ftrl",
        l2_shrinkage_regularization_strength=0.0,
        beta=0.0,
        **kwargs,
    ):
        """Validates hyperparameters and registers them on the optimizer.

        Raises:
          ValueError: If `initial_accumulator_value`,
            `l1_regularization_strength`, `l2_regularization_strength`, or
            `l2_shrinkage_regularization_strength` is negative, or if
            `learning_rate_power` is positive.
        """
        super().__init__(name, **kwargs)

        # Reject hyperparameter values that fall outside the ranges the
        # class docstring documents; each message echoes the bad value.
        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` needs to be "
                "positive or zero. Received: "
                f"initial_accumulator_value={initial_accumulator_value}."
            )
        if learning_rate_power > 0.0:
            raise ValueError(
                "`learning_rate_power` needs to be "
                "negative or zero. Received: "
                f"learning_rate_power={learning_rate_power}."
            )
        if l1_regularization_strength < 0.0:
            raise ValueError(
                "`l1_regularization_strength` needs to be positive or zero. "
                "Received: l1_regularization_strength="
                f"{l1_regularization_strength}."
            )
        if l2_regularization_strength < 0.0:
            raise ValueError(
                "`l2_regularization_strength` needs to be positive or zero. "
                "Received: l2_regularization_strength="
                f"{l2_regularization_strength}."
            )
        if l2_shrinkage_regularization_strength < 0.0:
            raise ValueError(
                "`l2_shrinkage_regularization_strength` needs to be positive "
                "or zero. Received: l2_shrinkage_regularization_strength"
                f"={l2_shrinkage_regularization_strength}."
            )

        # Register trackable hyperparameters via the OptimizerV2 machinery so
        # they participate in checkpointing and `get_config` serialization.
        self._set_hyper("learning_rate", learning_rate)
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("learning_rate_power", learning_rate_power)
        self._set_hyper(
            "l1_regularization_strength", l1_regularization_strength
        )
        self._set_hyper(
            "l2_regularization_strength", l2_regularization_strength
        )
        self._set_hyper("beta", beta)
        # These two are kept as plain attributes (not hypers); see
        # `_create_slots` and `_prepare_local` for how they are consumed.
        self._initial_accumulator_value = initial_accumulator_value
        self._l2_shrinkage_regularization_strength = (
            l2_shrinkage_regularization_strength
        )

    def _create_slots(self, var_list):
        """Creates per-variable "accumulator" and "linear" slot variables.

        The accumulator slot (`n` in the update rule) starts at
        `initial_accumulator_value`; the linear slot (`z`) defaults to zeros.
        """
        # Create the "accum" and "linear" slots.
        for var in var_list:
            dtype = var.dtype.base_dtype
            init = tf.compat.v1.constant_initializer(
                self._initial_accumulator_value, dtype=dtype
            )
            self.add_slot(var, "accumulator", init)
            self.add_slot(var, "linear")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Caches dtype-cast hyperparameter tensors for one (device, dtype).

        Extends the base-class entries (e.g. `lr_t`) with the FTRL-specific
        coefficients read later by the `_resource_apply_*` methods.
        """
        super()._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)].update(
            dict(
                learning_rate_power=tf.identity(
                    self._get_hyper("learning_rate_power", var_dtype)
                ),
                l1_regularization_strength=tf.identity(
                    self._get_hyper("l1_regularization_strength", var_dtype)
                ),
                l2_regularization_strength=tf.identity(
                    self._get_hyper("l2_regularization_strength", var_dtype)
                ),
                beta=tf.identity(self._get_hyper("beta", var_dtype)),
                # Not a registered hyper, so cast the raw Python float here.
                l2_shrinkage_regularization_strength=tf.cast(
                    self._l2_shrinkage_regularization_strength, var_dtype
                ),
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Applies a dense gradient update to `var` via the fused FTRL op.

        Dispatches to `ResourceApplyFtrl` when no shrinkage is configured,
        otherwise to `ResourceApplyFtrlV2` (which takes `l2_shrinkage`).
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        # Fall back to freshly computed coefficients when `apply_state` was
        # not pre-populated (e.g. apply_gradients called outside the usual
        # distributed path).
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # Adjust L2 regularization strength to include beta to avoid the
        # underlying TensorFlow ops needing to include it.
        adjusted_l2_regularization_strength = coefficients[
            "l2_regularization_strength"
        ] + coefficients["beta"] / (2.0 * coefficients["lr_t"])

        accum = self.get_slot(var, "accumulator")
        linear = self.get_slot(var, "linear")

        if self._l2_shrinkage_regularization_strength <= 0.0:
            return tf.raw_ops.ResourceApplyFtrl(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2_regularization_strength,
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )
        else:
            return tf.raw_ops.ResourceApplyFtrlV2(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2_regularization_strength,
                l2_shrinkage=coefficients[
                    "l2_shrinkage_regularization_strength"
                ],
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Applies a sparse (gathered) gradient update to `var`.

        Mirrors `_resource_apply_dense` but routes through the sparse fused
        ops, which only touch the rows selected by `indices`.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # Adjust L2 regularization strength to include beta to avoid the
        # underlying TensorFlow ops needing to include it.
        adjusted_l2_regularization_strength = coefficients[
            "l2_regularization_strength"
        ] + coefficients["beta"] / (2.0 * coefficients["lr_t"])

        accum = self.get_slot(var, "accumulator")
        linear = self.get_slot(var, "linear")

        if self._l2_shrinkage_regularization_strength <= 0.0:
            return tf.raw_ops.ResourceSparseApplyFtrl(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                indices=indices,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2_regularization_strength,
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )
        else:
            return tf.raw_ops.ResourceSparseApplyFtrlV2(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                indices=indices,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2_regularization_strength,
                l2_shrinkage=coefficients[
                    "l2_shrinkage_regularization_strength"
                ],
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )

    def get_config(self):
        """Returns the optimizer configuration as a JSON-serializable dict.

        The dict contains every constructor argument, so
        `Ftrl(**Ftrl(...).get_config())`-style round-tripping works (via the
        base class `from_config`).
        """
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    "learning_rate"
                ),
                "decay": self._initial_decay,
                "initial_accumulator_value": self._initial_accumulator_value,
                "learning_rate_power": self._serialize_hyperparameter(
                    "learning_rate_power"
                ),
                "l1_regularization_strength": self._serialize_hyperparameter(
                    "l1_regularization_strength"
                ),
                "l2_regularization_strength": self._serialize_hyperparameter(
                    "l2_regularization_strength"
                ),
                "beta": self._serialize_hyperparameter("beta"),
                "l2_shrinkage_regularization_strength": self._l2_shrinkage_regularization_strength,  # noqa: E501
            }
        )
        return config

310