Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/ftrl.py: 21%
57 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ftrl-proximal optimizer implementation."""
18import tensorflow.compat.v2 as tf
20from keras.src.optimizers.legacy import optimizer_v2
22# isort: off
23from tensorflow.python.util.tf_export import keras_export
@keras_export(
    "keras.optimizers.legacy.Ftrl",
    v1=["keras.optimizers.Ftrl", "keras.optimizers.legacy.Ftrl"],
)
class Ftrl(optimizer_v2.OptimizerV2):
    r"""Optimizer that implements the FTRL algorithm.

    "Follow The Regularized Leader" (FTRL) is an optimization algorithm
    developed at Google for click-through rate prediction in the early 2010s.
    It is most suitable for shallow models with large and sparse feature
    spaces. The algorithm is described by
    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
    The Keras version has support for both online L2 regularization
    (the L2 regularization described in the paper
    above) and shrinkage-type L2 regularization
    (which is the addition of an L2 penalty to the loss function).

    Initialization:

    ```python
    n = 0
    sigma = 0
    z = 0
    ```

    Update rule for one variable `w`:

    ```python
    prev_n = n
    n = n + g ** 2
    sigma = (sqrt(n) - sqrt(prev_n)) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
      w = 0
    else:
      w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
    ```

    Notation:

    - `lr` is the learning rate
    - `g` is the gradient for the variable
    - `lambda_1` is the L1 regularization strength
    - `lambda_2` is the L2 regularization strength

    Check the documentation for the `l2_shrinkage_regularization_strength`
    parameter for more details when shrinkage is enabled, in which case
    gradient is replaced with a gradient with shrinkage.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning
        rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate.
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero. Defaults to `0.0`.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero. Defaults to `0.0`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to `"Ftrl"`.
      l2_shrinkage_regularization_strength: A float value, must be greater
        than or equal to zero. This differs from L2 above in that the L2 above
        is a stabilization penalty, whereas this L2 shrinkage is a magnitude
        penalty. When input is sparse shrinkage will only happen on the active
        weights.
      beta: A float value, representing the beta value from the paper.
        Defaults to `0.0`.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Reference:
      - [McMahan et al., 2013](
        https://research.google.com/pubs/archive/41159.pdf)
    """

    def __init__(
        self,
        learning_rate=0.001,
        learning_rate_power=-0.5,
        initial_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0,
        name="Ftrl",
        l2_shrinkage_regularization_strength=0.0,
        beta=0.0,
        **kwargs,
    ):
        """Construct the optimizer; see the class docstring for arguments.

        Raises:
          ValueError: If any of the regularization strengths or the initial
            accumulator value is negative, or if `learning_rate_power` is
            positive.
        """
        super().__init__(name, **kwargs)

        # Validate hyperparameters eagerly so misconfiguration fails at
        # construction time rather than inside the first apply step.
        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` needs to be "
                "positive or zero. Received: "
                f"initial_accumulator_value={initial_accumulator_value}."
            )
        if learning_rate_power > 0.0:
            raise ValueError(
                "`learning_rate_power` needs to be "
                "negative or zero. Received: "
                f"learning_rate_power={learning_rate_power}."
            )
        if l1_regularization_strength < 0.0:
            raise ValueError(
                "`l1_regularization_strength` needs to be positive or zero. "
                "Received: l1_regularization_strength="
                f"{l1_regularization_strength}."
            )
        if l2_regularization_strength < 0.0:
            raise ValueError(
                "`l2_regularization_strength` needs to be positive or zero. "
                "Received: l2_regularization_strength="
                f"{l2_regularization_strength}."
            )
        if l2_shrinkage_regularization_strength < 0.0:
            raise ValueError(
                "`l2_shrinkage_regularization_strength` needs to be positive "
                "or zero. Received: l2_shrinkage_regularization_strength"
                f"={l2_shrinkage_regularization_strength}."
            )

        self._set_hyper("learning_rate", learning_rate)
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("learning_rate_power", learning_rate_power)
        self._set_hyper(
            "l1_regularization_strength", l1_regularization_strength
        )
        self._set_hyper(
            "l2_regularization_strength", l2_regularization_strength
        )
        self._set_hyper("beta", beta)
        # These two are plain attributes (not hypers): they are consumed
        # directly in `_create_slots` / `_prepare_local`.
        self._initial_accumulator_value = initial_accumulator_value
        self._l2_shrinkage_regularization_strength = (
            l2_shrinkage_regularization_strength
        )

    def _create_slots(self, var_list):
        """Create the "accum" and "linear" slots for each trainable var."""
        for var in var_list:
            dtype = var.dtype.base_dtype
            # "accumulator" starts at the user-provided value; "linear"
            # defaults to zeros.
            init = tf.compat.v1.constant_initializer(
                self._initial_accumulator_value, dtype=dtype
            )
            self.add_slot(var, "accumulator", init)
            self.add_slot(var, "linear")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Cache per-(device, dtype) hyperparameter tensors for apply ops."""
        super()._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)].update(
            dict(
                learning_rate_power=tf.identity(
                    self._get_hyper("learning_rate_power", var_dtype)
                ),
                l1_regularization_strength=tf.identity(
                    self._get_hyper("l1_regularization_strength", var_dtype)
                ),
                l2_regularization_strength=tf.identity(
                    self._get_hyper("l2_regularization_strength", var_dtype)
                ),
                beta=tf.identity(self._get_hyper("beta", var_dtype)),
                l2_shrinkage_regularization_strength=tf.cast(
                    self._l2_shrinkage_regularization_strength, var_dtype
                ),
            )
        )

    def _adjusted_l2_regularization_strength(self, coefficients):
        """Return L2 strength with `beta` folded in.

        The underlying TensorFlow FTRL ops take no `beta` argument, so the
        beta term from the paper is absorbed into the L2 coefficient as
        `l2 + beta / (2 * lr)`. Shared by the dense and sparse apply paths.
        """
        return coefficients["l2_regularization_strength"] + coefficients[
            "beta"
        ] / (2.0 * coefficients["lr_t"])

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply a dense gradient update via the raw FTRL kernels."""
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        adjusted_l2 = self._adjusted_l2_regularization_strength(coefficients)

        accum = self.get_slot(var, "accumulator")
        linear = self.get_slot(var, "linear")

        # The V2 kernel is only needed when shrinkage is enabled.
        if self._l2_shrinkage_regularization_strength <= 0.0:
            return tf.raw_ops.ResourceApplyFtrl(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2,
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )
        else:
            return tf.raw_ops.ResourceApplyFtrlV2(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2,
                l2_shrinkage=coefficients[
                    "l2_shrinkage_regularization_strength"
                ],
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply a sparse gradient update via the raw sparse FTRL kernels."""
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        adjusted_l2 = self._adjusted_l2_regularization_strength(coefficients)

        accum = self.get_slot(var, "accumulator")
        linear = self.get_slot(var, "linear")

        # The V2 kernel is only needed when shrinkage is enabled.
        if self._l2_shrinkage_regularization_strength <= 0.0:
            return tf.raw_ops.ResourceSparseApplyFtrl(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                indices=indices,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2,
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )
        else:
            return tf.raw_ops.ResourceSparseApplyFtrlV2(
                var=var.handle,
                accum=accum.handle,
                linear=linear.handle,
                grad=grad,
                indices=indices,
                lr=coefficients["lr_t"],
                l1=coefficients["l1_regularization_strength"],
                l2=adjusted_l2,
                l2_shrinkage=coefficients[
                    "l2_shrinkage_regularization_strength"
                ],
                lr_power=coefficients["learning_rate_power"],
                use_locking=self._use_locking,
            )

    def get_config(self):
        """Return the optimizer configuration as a serializable dict."""
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    "learning_rate"
                ),
                "decay": self._initial_decay,
                "initial_accumulator_value": self._initial_accumulator_value,
                "learning_rate_power": self._serialize_hyperparameter(
                    "learning_rate_power"
                ),
                "l1_regularization_strength": self._serialize_hyperparameter(
                    "l1_regularization_strength"
                ),
                "l2_regularization_strength": self._serialize_hyperparameter(
                    "l2_regularization_strength"
                ),
                "beta": self._serialize_hyperparameter("beta"),
                "l2_shrinkage_regularization_strength": self._l2_shrinkage_regularization_strength,  # noqa: E501
            }
        )
        return config