Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/yogi.py: 10%

155 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Yogi: Extension of yogi adaptive nonconvex optimizer in Keras. 

16 

17Implementation of Additive Averaging. 

18m_t+1 = beta1*m_t + (1-beta1)*g_t 

19v_t+1 = v_t + (1-beta2)*sign(g_t^2-v_t)*(g_t^2) 

20Experiments show better performance across NLP and Vision tasks. 

21Paper: 

22https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf 

23""" 

24 

25import tensorflow as tf 

26from tensorflow_addons.utils.types import FloatTensorLike 

27 

28from tensorflow_addons.optimizers import KerasLegacyOptimizer 

29from typeguard import typechecked 

30from typing import Union, Callable 

31 

32 

def _solve(a, b, c):
    """Solve the L1-regularized scalar quadratic problem elementwise.

    Minimizes, for each coordinate,
      f(w) = 1/2 * a * w^2 + b * w + c * |w|
    whose closed-form minimizer is
      w* = -(b - sign(b) * c) / a   if |b| > c
      w* = 0                        otherwise

    REQUIRES: Dimensionality of a and b must be same

    Args:
      a: A Tensor
      b: A Tensor
      c: A Tensor with one element.

    Returns:
      A Tensor w, which is solution for the equation
    """
    # 1 where the linear term exceeds the L1 threshold, 0 elsewhere,
    # expressed in b's dtype so it can scale the candidate solution.
    keep_mask = tf.cast(tf.abs(b) > c, dtype=b.dtype)
    # Candidate minimizer, assuming the threshold is exceeded.
    shrunk = (c * tf.sign(b) - b) / a
    return keep_mask * shrunk

51 

52 

@tf.keras.utils.register_keras_serializable(package="Addons")
class Yogi(KerasLegacyOptimizer):
    """Optimizer that implements the Yogi algorithm in Keras.

    See Algorithm 2 of
    https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf.

    For a gradient g_t the accumulators are updated as
      m_t = beta1 * m_{t-1} + (1 - beta1) * g_t      (only when beta1 != 0)
      v_t = v_{t-1} + (1 - beta2) * s(g_t^2 - v_{t-1}) * g_t^2
    where `s` is `sign` or `tanh(10 * x)` depending on `activation`. The
    variable then takes a step of size `lr / (sqrt(v_t) + epsilon)` along
    `m_t` (or `g_t` when beta1 == 0), followed by an optional L1/L2
    proximal step.
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.01,
        beta1: FloatTensorLike = 0.9,
        beta2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-3,
        l1_regularization_strength: FloatTensorLike = 0.0,
        l2_regularization_strength: FloatTensorLike = 0.0,
        initial_accumulator_value: FloatTensorLike = 1e-6,
        activation: str = "sign",
        name: str = "Yogi",
        **kwargs,
    ):
        """Construct a new Yogi optimizer.

        Args:
          learning_rate: A Tensor or a floating point value.
            The learning rate.
          beta1: A float value or a constant float tensor.
            The exponential decay rate for the 1st moment estimates.
          beta2: A float value or a constant float tensor.
            The exponential decay rate for the 2nd moment estimates.
          epsilon: A constant trading off adaptivity and noise.
          l1_regularization_strength: A float value, must be greater than or
            equal to zero.
          l2_regularization_strength: A float value, must be greater than or
            equal to zero.
          initial_accumulator_value: The starting value for accumulators.
            Only positive values are allowed.
          activation: Use hard sign or soft tanh to determine sign.
            Must be "sign" or "tanh"; any other value raises
            `NotImplementedError` when gradients are applied.
          name: Optional name for the operations created when applying
            gradients. Defaults to "Yogi".
          **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
            `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue`
            is clip gradients by value, `decay` is included for backward
            compatibility to allow time inverse decay of learning rate. `lr`
            is included for backward compatibility, recommended to use
            `learning_rate` instead.
        """
        super().__init__(name, **kwargs)
        # The legacy `lr` keyword, when given, takes precedence over
        # `learning_rate`.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("beta_1", beta1)
        self._set_hyper("beta_2", beta2)
        self._set_hyper("epsilon", epsilon)
        self._set_hyper("l1_regularization_strength", l1_regularization_strength)
        self._set_hyper("l2_regularization_strength", l2_regularization_strength)

        # Raw Python-side copies, used for structural decisions (which slots
        # to create, which update branch to build) rather than for arithmetic.
        self._beta1 = beta1
        self._activation = activation
        self._initial_accumulator_value = initial_accumulator_value
        self._l1_regularization_strength = l1_regularization_strength
        self._l2_regularization_strength = l2_regularization_strength

    def _create_slots(self, var_list):
        """See `tf.train.Optimizer._create_slots()`."""
        # Every variable gets a second-moment slot "v"; the first-moment slot
        # "m" is only created when momentum is enabled (beta1 > 0).
        for var in var_list:
            init = tf.constant_initializer(self._initial_accumulator_value)
            self.add_slot(var, "v", init)
            if self._beta1 > 0.0:
                self.add_slot(var, "m")

    def _yogi_coefficients(self, var_dtype):
        """Fetch hyperparameters as `var_dtype` and compute the step size.

        Args:
          var_dtype: dtype the hyperparameters are read in.

        Returns:
          A tuple `(lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t)` where `lr`
          is the decayed learning rate multiplied by
          `sqrt(1 - beta2^t) / (1 - beta1^t)` with `t = iterations + 1`.
        """
        lr_t = self._decayed_lr(var_dtype)
        beta1_t = self._get_hyper("beta_1", var_dtype)
        beta2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = self._get_hyper("epsilon", var_dtype)
        l1_t = self._get_hyper("l1_regularization_strength", var_dtype)
        l2_t = self._get_hyper("l2_regularization_strength", var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta1_power = tf.pow(beta1_t, local_step)
        beta2_power = tf.pow(beta2_t, local_step)

        lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)
        return lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t

    def _yogi_sign(self, grad2, v):
        """Return the direction in which `v` moves towards `grad2`.

        Args:
          grad2: Squared gradient.
          v: Current second-moment values, same shape as `grad2`.

        Returns:
          `sign(grad2 - v)` for the "sign" activation, or the smooth
          surrogate `tanh(10 * (grad2 - v))` for "tanh".

        Raises:
          NotImplementedError: If the configured activation is neither
            "sign" nor "tanh".
        """
        if self._activation == "sign":
            return tf.sign(grad2 - v)
        if self._activation == "tanh":
            return tf.tanh(10 * (grad2 - v))
        raise NotImplementedError("Activation function can be sign or tanh")

    def _yogi_prox(self, new_var, per_coord_lr, l1_t, l2_t):
        """Apply the L1/L2 proximal operator to the gradient-step result.

        Args:
          new_var: Variable values after the plain gradient step.
          per_coord_lr: Per-coordinate effective learning rate.
          l1_t: L1 regularization strength as a tensor.
          l2_t: L2 regularization strength as a tensor.

        Returns:
          The regularized values, or `new_var` unchanged when both
          regularization strengths are zero.
        """
        if self._l1_regularization_strength > 0:
            # L1 (optionally combined with L2): soft-thresholding via the
            # closed-form solution of the scalar problem.
            return _solve(1 + l2_t * per_coord_lr, -new_var, l1_t * per_coord_lr)
        if self._l2_regularization_strength > 0:
            # Pure L2: simple shrinkage.
            return new_var / (1 + l2_t * per_coord_lr)
        return new_var

    def _resource_apply_dense(self, grad, var):
        """See `tf.train.Optimizer._apply_dense()`.

        Args:
          grad: Dense gradient tensor for `var`.
          var: The variable to update.

        Returns:
          An op grouping the updates of `var` and its slots.
        """
        var_dtype = var.dtype.base_dtype
        lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t = self._yogi_coefficients(
            var_dtype
        )

        if self._beta1 == 0.0:
            # No momentum: step directly along the raw gradient.
            m_t = None
            direction = grad
        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
            m = self.get_slot(var, "m")
            m_t = m.assign(
                m * beta1_t + grad * (1 - beta1_t), use_locking=self._use_locking
            )
            direction = m_t

        # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
        v = self.get_slot(var, "v")
        grad2 = grad * grad
        sign = self._yogi_sign(grad2, v)
        v_t = v.assign_add(
            (1 - beta2_t) * sign * grad2, use_locking=self._use_locking
        )
        v_sqrt = tf.sqrt(v_t)

        # Yogi effective LR
        per_coord_lr = lr / (v_sqrt + epsilon_t)

        # Variable update
        # Step 1: Gradient descent
        new_var = var - per_coord_lr * direction
        # Step 2: Prox operator
        new_var = self._yogi_prox(new_var, per_coord_lr, l1_t, l2_t)
        # Step 3: Update
        var_update = var.assign(new_var, use_locking=self._use_locking)

        update_vs = [var_update]
        if m_t is not None:
            update_vs.append(m_t)
        update_vs.append(v_t)

        # Create an op that groups all the above operations
        return tf.group(*update_vs)

    def _resource_apply_sparse(self, grad, var, indices):
        """Applies sparse gradients to a variable.

        Only the entries of `v` and `var` selected by `indices` are
        rewritten. When momentum is enabled, `m` is first decayed by
        `beta1` everywhere and then the selected entries receive the
        gradient contribution.

        Args:
          grad: A tensor for the `values` of `tf.IndexedSlices`.
          var: A `tf.Variable` object.
          indices: A tensor for the `indices` of `tf.IndexedSlices`.

        Returns:
          An op which updates `var` with `grad` and `indices`.
        """
        var_dtype = var.dtype.base_dtype
        lr, beta1_t, beta2_t, epsilon_t, l1_t, l2_t = self._yogi_coefficients(
            var_dtype
        )

        if self._beta1 == 0.0:
            # No momentum: step directly along the raw gradient values.
            m_t = None
            direction = grad
        else:
            # m_t = beta1 * m + (1 - beta1) * g_t
            m = self.get_slot(var, "m")
            m_scaled_g_values = grad * (1 - beta1_t)
            m_t = m.assign(m * beta1_t, use_locking=self._use_locking)
            # The gather must observe the decayed `m`, so it is sequenced
            # after the assign.
            with tf.control_dependencies([m_t]):
                m_slice = tf.gather(m, indices) + m_scaled_g_values
                m_t = self._resource_scatter_update(m, indices, m_slice)
            direction = m_slice

        # v_t = v + (1 - beta2) * sign(g_t^2 - v) * g_t^2
        v = self.get_slot(var, "v")
        grad2 = grad * grad
        # Gathered once and reused for both the sign and the new value.
        v_slice = tf.gather(v, indices)
        sign = self._yogi_sign(grad2, v_slice)
        v_scaled_g_values = v_slice + (1 - beta2_t) * sign * grad2
        v_t = self._resource_scatter_update(v, indices, v_scaled_g_values)
        v_sqrt = tf.sqrt(v_scaled_g_values)

        # Yogi effective LR
        per_coord_lr = lr / (v_sqrt + epsilon_t)

        # Variable update
        # Step 1: Gradient descent
        var_slice = tf.gather(var, indices)
        new_var = var_slice - per_coord_lr * direction
        # Step 2: Prox operator
        new_var = self._yogi_prox(new_var, per_coord_lr, l1_t, l2_t)
        # Step 3: Update
        var_update = self._resource_scatter_update(var, indices, new_var)

        update_vs = [var_update]
        if m_t is not None:
            update_vs.append(m_t)
        update_vs.append(v_t)

        # Create an op that groups all the above operations
        return tf.group(*update_vs)

    def get_config(self):
        """Return the optimizer configuration as a serializable dict.

        The keys added here mirror the constructor argument names
        (`beta1`, `beta2`, ...), extending the base-class config.
        """
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "decay": self._serialize_hyperparameter("decay"),
                "beta1": self._serialize_hyperparameter("beta_1"),
                "beta2": self._serialize_hyperparameter("beta_2"),
                "epsilon": self._serialize_hyperparameter("epsilon"),
                "l1_regularization_strength": self._serialize_hyperparameter(
                    "l1_regularization_strength"
                ),
                "l2_regularization_strength": self._serialize_hyperparameter(
                    "l2_regularization_strength"
                ),
                "activation": self._activation,
                "initial_accumulator_value": self._initial_accumulator_value,
            }
        )
        return config