# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  "Follow The Regularized Leader" (FTRL) is an optimization algorithm developed
  at Google for click-through rate prediction in the early 2010s. It is most
  suitable for shallow models with large and sparse feature spaces.
  The algorithm is described by
  [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
  The Keras version has support for both online L2 regularization
  (the L2 regularization described in the paper
  above) and shrinkage-type L2 regularization
  (which is the addition of an L2 penalty to the loss function).

  Initialization:

  ```python
  n = 0
  sigma = 0
  z = 0
  ```

  Update rule for one variable `w`:

  ```python
  prev_n = n
  n = n + g ** 2
  sigma = (sqrt(n) - sqrt(prev_n)) / lr
  z = z + g - sigma * w
  if abs(z) < lambda_1:
    w = 0
  else:
    w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
  ```

  Notation:

  - `lr` is the learning rate
  - `g` is the gradient for the variable
  - `lambda_1` is the L1 regularization strength
  - `lambda_2` is the L2 regularization strength
  - `alpha` is the paper's symbol for the learning rate (the same quantity as
    `lr` above), and `beta` is the `beta` hyperparameter described in the
    arguments below

  Check the documentation for the `l2_shrinkage_regularization_strength`
  parameter for more details when shrinkage is enabled, in which case the
  gradient is replaced with a gradient with shrinkage.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    learning_rate_power: A float value, must be less than or equal to zero.
      Controls how the learning rate decreases during training. Use zero for
      a fixed learning rate.
    initial_accumulator_value: The starting value for accumulators.
      Only zero or positive values are allowed.
    l1_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    l2_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"Ftrl"`.
    l2_shrinkage_regularization_strength: A float value, must be greater than
      or equal to zero. This differs from L2 above in that the L2 above is a
      stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
      When input is sparse, shrinkage will only happen on the active weights.
    beta: A float value, representing the beta value from the paper.
      Defaults to 0.0.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Reference:
    - [McMahan et al., 2013](
      https://research.google.com/pubs/archive/41159.pdf)
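
  Usage (a minimal illustrative sketch; the hyperparameter values are
  arbitrary, not tuned recommendations):

  ```python
  opt = tf.keras.optimizers.Ftrl(
      learning_rate=0.01, l1_regularization_strength=0.001)
  var = tf.Variable([1.0, 2.0])
  loss = lambda: tf.reduce_sum(var ** 2)
  opt.minimize(loss, var_list=[var])
  ```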
100 """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               beta=0.0,
               **kwargs):
    super(Ftrl, self).__init__(name, **kwargs)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._set_hyper('beta', beta)
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
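    # The `accumulator` slot holds the running sum of squared gradients
    # (`n` in the class docstring, seeded with `initial_accumulator_value`)
    # and the `linear` slot holds `z` from the update rule.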
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
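    # Cache the hyperparameters as tensors keyed by (device, dtype) so the
    # apply ops below can read them without converting them again on each call.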
    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate_power=array_ops.identity(
                self._get_hyper('learning_rate_power', var_dtype)),
            l1_regularization_strength=array_ops.identity(
                self._get_hyper('l1_regularization_strength', var_dtype)),
            l2_regularization_strength=array_ops.identity(
                self._get_hyper('l2_regularization_strength', var_dtype)),
            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
            l2_shrinkage_regularization_strength=math_ops.cast(
                self._l2_shrinkage_regularization_strength, var_dtype)))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
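    # Concretely: adjusted_l2 = l2_regularization_strength + beta / (2 * lr).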
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')
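
    # The plain FTRL op suffices when shrinkage is disabled; the V2 op
    # additionally applies the l2_shrinkage term to the gradient.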
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceSparseApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceSparseApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
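    # Report the constructor arguments so Keras can serialize and later
    # re-create this optimizer (for example when a compiled model is saved).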
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._initial_decay,
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'beta':
            self._serialize_hyperparameter('beta'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config