# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""

import importlib
import tensorflow as tf
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

from typeguard import typechecked
from typing import Union, Callable, Type, Optional, List


class DecoupledWeightDecayExtension:
    """This class allows to extend optimizers with decoupled weight decay.

    It implements the decoupled weight decay described by [Loshchilov & Hutter]
    (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
    decoupled from the optimization steps w.r.t. the loss function.
    For SGD variants, this simplifies hyperparameter search since it decouples
    the settings of weight decay and learning rate.
    For adaptive gradient algorithms, it regularizes variables with large
    gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    This class alone is not an optimizer but rather extends existing
    optimizers with decoupled weight decay. We explicitly define the two
    examples used in the above paper (SGDW and AdamW), but in general this can
    extend any OptimizerX class by using
    `ExtendedCls = extend_with_decoupled_weight_decay(OptimizerX)`.
    Weight decay can then be set when instantiating the optimizer:
    `optimizerX = ExtendedCls(weight_decay=0.001, learning_rate=0.001)`.
    In order for it to work, it must be the first class the Optimizer with
    weight decay inherits from, e.g.

    ```python
    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
        def __init__(self, weight_decay, *args, **kwargs):
            super(AdamW, self).__init__(weight_decay, *args, **kwargs)
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of `var` in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        exclude_from_weight_decay: Optional[List[str]] = None,
        **kwargs,
    ):
        """Extension class that adds weight decay to an optimizer.

        Args:
            weight_decay: A `Tensor`, a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
                to decay the variable by, in the update step.
            exclude_from_weight_decay: List of regex patterns of
                variables excluded from weight decay. Variables whose name
                contains a substring matching the pattern will be excluded.
                Note `decay_var_list` in `minimize` or `apply_gradients` takes
                priority over `exclude_from_weight_decay` if specified.
            **kwargs: Optional keyword arguments forwarded to the base
                optimizer.
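
        For example, variables can be excluded from decay by name pattern
        (an illustrative sketch; the pattern strings are assumptions and
        should be adapted to your own variable names):

        ```python
        optimizer = tfa.optimizers.AdamW(
            learning_rate=1e-3,
            weight_decay=1e-4,
            exclude_from_weight_decay=["bias", "layer_norm"],
        )
        # Variables whose names contain "bias" or "layer_norm" are updated
        # by Adam but are not decayed.
        ```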

        """
        wd = kwargs.pop("weight_decay", weight_decay)
        super().__init__(**kwargs)
        self._decay_var_list = None  # is set in minimize or apply_gradients
        self._set_hyper("weight_decay", wd)
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "exclude_from_weight_decay": self.exclude_from_weight_decay,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # LR handling copied from optimizer_v2.OptimizerV2
        if "learning_rate" in config:
            if isinstance(config["learning_rate"], dict):
                config["learning_rate"] = tf.keras.optimizers.schedules.deserialize(
                    config["learning_rate"], custom_objects=custom_objects
                )

        if "weight_decay" in config:
            if isinstance(config["weight_decay"], dict):
                config["weight_decay"] = tf.keras.optimizers.schedules.deserialize(
                    config["weight_decay"], custom_objects=custom_objects
                )

        return cls(**config)

    def minimize(
        self,
        loss,
        var_list,
        grad_loss=None,
        name=None,
        decay_var_list=None,
        tape=None,
    ):
        """Minimize `loss` by updating `var_list`.

        This method simply computes the gradients using `tf.GradientTape` and
        calls `apply_gradients()`. If you want to process the gradients before
        applying them, call `tf.GradientTape` and `apply_gradients()`
        explicitly instead of using this function.

        Args:
            loss: `Tensor` or callable. If a callable, `loss` should take no
                arguments and return the value to minimize. If a `Tensor`, the
                `tape` argument must be passed.
            var_list: list or tuple of `Variable` objects to update to
                minimize `loss`, or a callable returning the list or tuple of
                `Variable` objects. Use a callable when the variable list would
                otherwise be incomplete before `minimize`, since the variables
                are created the first time `loss` is called.
            grad_loss: Optional. A `Tensor` holding the gradient computed for
                `loss`.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in `var_list`. Note `decay_var_list` takes
                priority over `exclude_from_weight_decay` if specified.
            name: Optional name for the returned operation.
            tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
                `Tensor`, the tape that computed the `loss` must be provided.
        Returns:
            An Operation that updates the variables in `var_list`.
        Raises:
            ValueError: If some of the variables are not `Variable` objects.
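
        A minimal sketch with a callable loss (illustrative; `optimizer` is
        assumed to be an instance of an extended optimizer such as
        `tfa.optimizers.AdamW`):

        ```python
        var1 = tf.Variable(1.0)
        var2 = tf.Variable(2.0)
        loss = lambda: var1 ** 2 + var2 ** 2
        # Update both variables, but only decay `var1`.
        optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
        ```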

        """
        self._set_decay_var_list(var_list, decay_var_list)
        return super().minimize(
            loss, var_list=var_list, grad_loss=grad_loss, name=name, tape=tape
        )

    def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwargs):
        """Apply gradients to variables.

        This is the second part of `minimize()`. It returns an `Operation` that
        applies gradients.

        Args:
            grads_and_vars: List of (gradient, variable) pairs.
            name: Optional name for the returned operation. Defaults to the
                name passed to the `Optimizer` constructor.
            decay_var_list: Optional list of variables to be decayed. Defaults
                to all variables in `grads_and_vars`. Note `decay_var_list`
                takes priority over `exclude_from_weight_decay` if specified.
            **kwargs: Additional arguments to pass to the base optimizer's
                `apply_gradients` method, e.g., TF 2.2 added an argument
                `experimental_aggregate_gradients`.
        Returns:
            An `Operation` that applies the specified gradients.
        Raises:
            TypeError: If `grads_and_vars` is malformed.
            ValueError: If none of the variables have gradients.
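
        A minimal sketch of the manual tape/apply path (illustrative;
        `optimizer`, `model`, `loss_fn`, `x`, and `y` are assumed to exist):

        ```python
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x))
        grads = tape.gradient(loss, model.trainable_variables)
        # Decay every trainable variable except the biases.
        decay_vars = [
            v for v in model.trainable_variables if "bias" not in v.name
        ]
        optimizer.apply_gradients(
            zip(grads, model.trainable_variables), decay_var_list=decay_vars
        )
        ```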

        """
        grads_and_vars = list(grads_and_vars)
        self._set_decay_var_list((v for _, v in grads_and_vars), decay_var_list)
        return super().apply_gradients(grads_and_vars, name=name, **kwargs)

    def _decay_weights_op(self, var, apply_state=None):
        if self._do_use_weight_decay(var):
            var_device, var_dtype = var.device, var.dtype.base_dtype
            coefficients = (apply_state or {}).get(
                (var_device, var_dtype)
            ) or self._fallback_apply_state(var_device, var_dtype)

            return var.assign_sub(coefficients["wd_t"] * var, self._use_locking)
        return tf.no_op()

    def _decay_weights_sparse_op(self, var, indices, apply_state=None):
        if self._do_use_weight_decay(var):
            var_device, var_dtype = var.device, var.dtype.base_dtype
            coefficients = (apply_state or {}).get(
                (var_device, var_dtype)
            ) or self._fallback_apply_state(var_device, var_dtype)

            update = -coefficients["wd_t"] * tf.gather(var, indices)
            return self._resource_scatter_add(var, indices, update)
        return tf.no_op()

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(DecoupledWeightDecayExtension, self)._prepare_local(
            var_device, var_dtype, apply_state
        )

        if "weight_decay" in self._hyper:
            wd_t = tf.identity(self._decayed_wd(var_dtype))
            apply_state[(var_device, var_dtype)]["wd_t"] = wd_t

    def _decayed_wd(self, var_dtype):
        wd_t = self._get_hyper("weight_decay", var_dtype)

        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
            wd_t = tf.cast(wd_t(self.iterations), var_dtype)

        return wd_t

    # Here, we override the apply functions that the base optimizer calls.
    # super().apply_x resolves to the apply_x function of the BaseOptimizer.

    def _resource_apply_dense(self, grad, var, apply_state=None):
        with tf.control_dependencies(
            [self._decay_weights_op(var, apply_state=apply_state)]
        ):
            return super()._resource_apply_dense(grad, var, apply_state=apply_state)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        decay_op = self._decay_weights_sparse_op(var, indices, apply_state=apply_state)
        with tf.control_dependencies([decay_op]):
            return super()._resource_apply_sparse(
                grad, var, indices, apply_state=apply_state
            )

    def _set_decay_var_list(self, var_list, decay_var_list=None):
        if decay_var_list:
            self._decay_var_list = set(v.ref() for v in decay_var_list)
        elif self.exclude_from_weight_decay:
            self._decay_var_list = set(
                v.ref()
                for v in var_list
                if not is_variable_matched_by_regexes(v, self.exclude_from_weight_decay)
            )
        else:
            self._decay_var_list = None

    def _do_use_weight_decay(self, var):
        """Whether to use weight decay for `var`."""
        if self._decay_var_list is None:
            return True
        return var.ref() in self._decay_var_list


if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
    keras_legacy_optimizer = Union[
        tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer
    ]
else:
    keras_legacy_optimizer = tf.keras.optimizers.Optimizer


@typechecked
def extend_with_decoupled_weight_decay(
    base_optimizer: Type[keras_legacy_optimizer],
) -> Type[keras_legacy_optimizer]:
    """Factory function returning an optimizer class with decoupled weight
    decay.

    Returns an optimizer class. An instance of the returned class computes the
    update step of `base_optimizer` and additionally decays the weights.
    E.g., the class returned by
    `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
    equivalent to `tfa.optimizers.AdamW`.

    The API of the new optimizer class slightly differs from the API of the
    base optimizer:
    - The first argument to the constructor is the weight decay rate.
    - Optional keyword argument `exclude_from_weight_decay` accepts a list of
      regex patterns of variables excluded from weight decay. Variables whose
      name contains a substring matching the pattern will be excluded.
    - `minimize` and `apply_gradients` accept the optional keyword argument
      `decay_var_list`, which specifies the variables that should be decayed.
      Note this takes priority over `exclude_from_weight_decay` if specified.
      If both are `None`, all variables that are optimized are decayed.

    Usage example:
    ```python
    # MyAdamW is a new class
    MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
    # Create a MyAdamW object
    optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
    # update var1, var2 but only decay var1
    optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
    ```

    Note: this extension decays weights BEFORE applying the update based
    on the gradient, i.e. this extension only has the desired behaviour for
    optimizers which do not depend on the value of `var` in the update step!

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```

    Note: you might want to register your own custom optimizer using
    `tf.keras.utils.get_custom_objects()`.
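
    For instance, a minimal registration sketch (illustrative; `MyAdamW` as
    created in the usage example above):

    ```python
    tf.keras.utils.get_custom_objects()["MyAdamW"] = MyAdamW
    ```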

    Args:
        base_optimizer: An optimizer class that inherits from
            tf.optimizers.Optimizer.

    Returns:
        A new optimizer class that inherits from DecoupledWeightDecayExtension
        and base_optimizer.
    """

    class OptimizerWithDecoupledWeightDecay(
        DecoupledWeightDecayExtension, base_optimizer
    ):
        """Base_optimizer with decoupled weight decay.

        This class computes the update step of `base_optimizer` and
        additionally decays the variable with the weight decay being
        decoupled from the optimization steps w.r.t. the loss
        function, as described by [Loshchilov & Hutter]
        (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
        simplifies hyperparameter search since it decouples the settings
        of weight decay and learning rate. For adaptive gradient
        algorithms, it regularizes variables with large gradients more
        than L2 regularization would, which was shown to yield better
        training loss and generalization error in the paper above.
        """

        @typechecked
        def __init__(
            self,
            weight_decay: Union[FloatTensorLike, Callable],
            *args,
            **kwargs,
        ):
            # super delegation is necessary here
            super().__init__(weight_decay, *args, **kwargs)

    return OptimizerWithDecoupledWeightDecay


if hasattr(tf.keras.optimizers, "legacy"):
    ADAM_CLASS = tf.keras.optimizers.legacy.Adam
    SGD_CLASS = tf.keras.optimizers.legacy.SGD
else:
    ADAM_CLASS = tf.keras.optimizers.Adam
    SGD_CLASS = tf.keras.optimizers.SGD


@tf.keras.utils.register_keras_serializable(package="Addons")
class SGDW(DecoupledWeightDecayExtension, SGD_CLASS):
    """Optimizer that implements the Momentum algorithm with weight_decay.

    This is an implementation of the SGDW optimizer described in "Decoupled
    Weight Decay Regularization" by [Loshchilov & Hutter]
    (https://arxiv.org/pdf/1711.05101.pdf).
    It computes the update step of `tf.keras.optimizers.SGD` and additionally
    decays the variable. Note that this is different from adding
    L2 regularization on the variables to the loss. Decoupling the weight decay
    from other hyperparameters (in particular the learning rate) simplifies
    hyperparameter search.

    For further information see the documentation of the SGD Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.SGDW(
        learning_rate=lr, weight_decay=wd, momentum=0.9)
    ```
    """

    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        momentum: Union[FloatTensorLike, Callable] = 0.0,
        nesterov: bool = False,
        name: str = "SGDW",
        **kwargs,
    ):
        """Construct a new SGDW optimizer.

        For further information see the documentation of the SGD Optimizer.

        Args:
            weight_decay: A `Tensor` or a floating point value. The weight
                decay.
            learning_rate: float hyperparameter >= 0. Learning rate.
            momentum: float hyperparameter >= 0 that accelerates SGD in the
                relevant direction and dampens oscillations.
            nesterov: boolean. Whether to apply Nesterov momentum.
            name: Optional name prefix for the operations created when applying
                gradients. Defaults to 'SGDW'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` clips
                gradients by norm; `clipvalue` clips gradients by value.
                `decay` is included for backward compatibility to allow time
                inverse decay of learning rate. `lr` is included for backward
                compatibility; it is recommended to use `learning_rate` instead.
                `exclude_from_weight_decay` accepts a list of regex patterns of
                variables excluded from weight decay.
        """
        super().__init__(
            weight_decay,
            learning_rate=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            name=name,
            **kwargs,
        )


@tf.keras.utils.register_keras_serializable(package="Addons")
class AdamW(DecoupledWeightDecayExtension, ADAM_CLASS):
    """Optimizer that implements the Adam algorithm with weight decay.

    This is an implementation of the AdamW optimizer described in "Decoupled
    Weight Decay Regularization" by [Loshchilov & Hutter]
    (https://arxiv.org/pdf/1711.05101.pdf).

    It computes the update step of `tf.keras.optimizers.Adam` and additionally
    decays the variable. Note that this is different from adding L2
    regularization on the variables to the loss: it regularizes variables with
    large gradients more than L2 regularization would, which was shown to yield
    better training loss and generalization error in the paper above.

    For further information see the documentation of the Adam Optimizer.

    This optimizer can also be instantiated as
    ```python
    extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)(
        weight_decay=weight_decay)
    ```

    Note: when applying a decay to the learning rate, be sure to manually apply
    the decay to the `weight_decay` as well. For example:

    ```python
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    ```
    """

    @typechecked
    def __init__(
        self,
        weight_decay: Union[FloatTensorLike, Callable],
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: Union[FloatTensorLike, Callable] = 0.9,
        beta_2: Union[FloatTensorLike, Callable] = 0.999,
        epsilon: FloatTensorLike = 1e-07,
        amsgrad: bool = False,
        name: str = "AdamW",
        **kwargs,
    ):
        """Construct a new AdamW optimizer.

        For further information see the documentation of the Adam Optimizer.

        Args:
            weight_decay: A `Tensor` or a floating point value. The weight
                decay.
            learning_rate: A `Tensor` or a floating point value. The learning
                rate.
            beta_1: A float value or a constant float tensor. The exponential
                decay rate for the 1st moment estimates.
            beta_2: A float value or a constant float tensor. The exponential
                decay rate for the 2nd moment estimates.
            epsilon: A small constant for numerical stability. This epsilon is
                "epsilon hat" in the Kingma and Ba paper (in the formula just
                before Section 2.1), not the epsilon in Algorithm 1 of the
                paper.
            amsgrad: boolean. Whether to apply the AMSGrad variant of this
                algorithm from the paper "On the Convergence of Adam and
                Beyond".
            name: Optional name for the operations created when applying
                gradients. Defaults to "AdamW".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` clips
                gradients by norm; `clipvalue` clips gradients by value.
                `decay` is included for backward compatibility to allow time
                inverse decay of learning rate. `lr` is included for backward
                compatibility; it is recommended to use `learning_rate` instead.
                `exclude_from_weight_decay` accepts a list of regex patterns of
                variables excluded from weight decay.
        """
        super().__init__(
            weight_decay,
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs,
        )