Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py: 25%

60 statements  


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export

@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  "Follow The Regularized Leader" (FTRL) is an optimization algorithm developed
  at Google for click-through rate prediction in the early 2010s. It is most
  suitable for shallow models with large and sparse feature spaces.
  The algorithm is described by
  [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
  The Keras version has support for both online L2 regularization
  (the L2 regularization described in the paper
  above) and shrinkage-type L2 regularization
  (which is the addition of an L2 penalty to the loss function).

  Initialization:

  ```python
  n = 0
  sigma = 0
  z = 0
  ```

  Update rule for one variable `w`:

  ```python
  prev_n = n
  n = n + g ** 2
  sigma = (sqrt(n) - sqrt(prev_n)) / lr
  z = z + g - sigma * w
  if abs(z) < lambda_1:
    w = 0
  else:
    w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
  ```

  Notation:

  - `lr` is the learning rate
  - `g` is the gradient for the variable
  - `lambda_1` is the L1 regularization strength
  - `lambda_2` is the L2 regularization strength
  - `alpha` is the learning rate (the same quantity as `lr` above) and `beta`
    is the `beta` argument; together they form the paper's per-coordinate
    learning rate `alpha / (beta + sqrt(n))`

  When shrinkage is enabled, the gradient `g` above is replaced by a gradient
  with shrinkage; see the documentation of the
  `l2_shrinkage_regularization_strength` parameter below for details.
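
  For illustration, a single update of one scalar weight with
  `lr = alpha = 0.5`, `beta = 0`, `lambda_1 = lambda_2 = 0`, starting from
  `w = 1`, `n = 0`, `z = 0`, and a gradient `g = 2` (values chosen only to
  keep the arithmetic simple):

  ```python
  prev_n = 0
  n = 0 + 2 ** 2                        # n = 4, so sqrt(n) = 2
  sigma = (2 - 0) / 0.5                 # (sqrt(n) - sqrt(prev_n)) / lr = 4
  z = 0 + 2 - 4 * 1                     # z = -2
  # abs(z) = 2 >= lambda_1 = 0, so the else-branch applies:
  w = (0 - (-2)) / ((0 + 2) / 0.5 + 0)  # w = 0.5
  ```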

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    learning_rate_power: A float value, must be less than or equal to zero.
      Controls how the learning rate decreases during training. Use zero for
      a fixed learning rate.
    initial_accumulator_value: The starting value for accumulators.
      Only zero or positive values are allowed.
    l1_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    l2_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"Ftrl"`.
    l2_shrinkage_regularization_strength: A float value, must be greater than
      or equal to zero. This differs from L2 above in that the L2 above is a
      stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
      When input is sparse, shrinkage will only happen on the active weights.
    beta: A float value, representing the beta value from the paper.
      Defaults to 0.0.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.
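
  A minimal usage sketch (the variable and loss below are toy placeholders):

  ```python
  import tensorflow as tf

  opt = tf.keras.optimizers.Ftrl(
      learning_rate=0.05, l1_regularization_strength=0.01)
  var = tf.Variable([1.0, 2.0])
  loss = lambda: tf.reduce_sum(var ** 2)  # toy loss
  opt.minimize(loss, var_list=[var])      # runs one FTRL update on `var`
  ```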

  Reference:
    - [McMahan et al., 2013](
      https://research.google.com/pubs/archive/41159.pdf)
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               beta=0.0,
               **kwargs):
    super(Ftrl, self).__init__(name, **kwargs)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._set_hyper('beta', beta)
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
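    # Cache tensor versions of the hyperparameters, keyed by (device, dtype),
    # so the `_resource_apply_*` methods below can look them up from
    # `apply_state`.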

    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate_power=array_ops.identity(
                self._get_hyper('learning_rate_power', var_dtype)),
            l1_regularization_strength=array_ops.identity(
                self._get_hyper('l1_regularization_strength', var_dtype)),
            l2_regularization_strength=array_ops.identity(
                self._get_hyper('l2_regularization_strength', var_dtype)),
            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
            l2_shrinkage_regularization_strength=math_ops.cast(
                self._l2_shrinkage_regularization_strength, var_dtype)))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))
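    # (The FTRL kernels called below take no `beta` argument; adding
    # beta / (2 * lr) to the L2 term is how `beta` enters the update.)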

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceSparseApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceSparseApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._initial_decay,
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'beta':
            self._serialize_hyperparameter('beta'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config