Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/lookahead.py: 37%

76 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
from tensorflow_addons.utils import types

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class Lookahead(KerasLegacyOptimizer):
    """This class allows extending optimizers with the lookahead mechanism.

    The mechanism is proposed by Michael R. Zhang et al. in the paper
    [Lookahead Optimizer: k steps forward, 1 step back]
    (https://arxiv.org/abs/1907.08610v1). The optimizer iteratively updates two
    sets of weights: the search directions for the "fast weights" are chosen by
    the inner optimizer, while the "slow weights" are updated every `k` steps
    based on the directions of the fast weights, at which point the two sets of
    weights are synchronized. This method improves learning stability and
    lowers the variance of its inner optimizer.

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate)
    opt = tfa.optimizers.Lookahead(opt)
    ```
    """

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        sync_period: int = 6,
        slow_step_size: types.FloatTensorLike = 0.5,
        name: str = "Lookahead",
        **kwargs,
    ):
        r"""Wrap optimizer with the lookahead mechanism.

        Args:
            optimizer: The original optimizer that will be used to compute
                and apply the gradients.
            sync_period: An integer. The synchronization period of lookahead.
                Enable the lookahead mechanism by setting it to a positive value.
            slow_step_size: A floating point value.
                The ratio for updating the slow weights.
            name: Optional name for the operations created when applying
                gradients. Defaults to "Lookahead".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients
                by norm; `clipvalue` clips gradients by value; `decay` is
                included for backward compatibility to allow time-inverse
                decay of the learning rate; `lr` is included for backward
                compatibility, but `learning_rate` is recommended instead.
        """

        super().__init__(name, **kwargs)

        if isinstance(optimizer, str):
            if (
                hasattr(tf.keras.optimizers, "legacy")
                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
            ):
                optimizer = tf.keras.optimizers.get(
                    optimizer, use_legacy_optimizer=True
                )
            else:
                optimizer = tf.keras.optimizers.get(optimizer)
        if not isinstance(optimizer, KerasLegacyOptimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.legacy.Optimizer "
            )

        self._optimizer = optimizer
        self._set_hyper("sync_period", sync_period)
        self._set_hyper("slow_step_size", slow_step_size)
        self._initialized = False
        self._track_trackable(self._optimizer, "lh_base_optimizer")

    def _create_slots(self, var_list):
        self._optimizer._create_slots(
            var_list=var_list
        )  # pylint: disable=protected-access
        for var in var_list:
            self.add_slot(var, "slow", initializer=var)

    def _create_hypers(self):
        self._optimizer._create_hypers()  # pylint: disable=protected-access

    def _prepare(self, var_list):
        return self._optimizer._prepare(
            var_list=var_list
        )  # pylint: disable=protected-access

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        self._optimizer._iterations = (
            self.iterations
        )  # pylint: disable=protected-access
        return super().apply_gradients(grads_and_vars, name, **kwargs)

    def _look_ahead_op(self, var):
        var_dtype = var.dtype.base_dtype
        slow_var = self.get_slot(var, "slow")
        local_step = tf.cast(self.iterations + 1, tf.dtypes.int64)
        sync_period = self._get_hyper("sync_period", tf.dtypes.int64)
        slow_step_size = self._get_hyper("slow_step_size", var_dtype)
        # Interpolate the slow weights towards the current (fast) weights.
        step_back = slow_var + slow_step_size * (var - slow_var)
        # Synchronize only when the step counter is a multiple of sync_period.
        sync_cond = tf.equal(
            tf.math.floordiv(local_step, sync_period) * sync_period, local_step
        )
        with tf.control_dependencies([step_back]):
            slow_update = slow_var.assign(
                tf.where(sync_cond, step_back, slow_var),
                use_locking=self._use_locking,
            )
            var_update = var.assign(
                tf.where(sync_cond, step_back, var),
                use_locking=self._use_locking,
            )
        return tf.group(slow_update, var_update)

    @property
    def weights(self):
        return self._weights + self._optimizer.weights

    def _resource_apply_dense(self, grad, var):
        train_op = self._optimizer._resource_apply_dense(
            grad, var
        )  # pylint: disable=protected-access
        with tf.control_dependencies([train_op]):
            look_ahead_op = self._look_ahead_op(var)
        return tf.group(train_op, look_ahead_op)

    def _resource_apply_sparse(self, grad, var, indices):
        train_op = (
            self._optimizer._resource_apply_sparse(  # pylint: disable=protected-access
                grad, var, indices
            )
        )
        with tf.control_dependencies([train_op]):
            look_ahead_op = self._look_ahead_op(var)
        return tf.group(train_op, look_ahead_op)

    def get_config(self):
        config = {
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
            "sync_period": self._serialize_hyperparameter("sync_period"),
            "slow_step_size": self._serialize_hyperparameter("slow_step_size"),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)

    @property
    def lr(self):
        return self.learning_rate

    @lr.setter
    def lr(self, lr):
        self.learning_rate = lr

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)
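

# A minimal usage sketch, assuming a legacy Keras SGD inner optimizer
# (`tf.keras.optimizers.legacy` exists on TF >= 2.11; older releases fall back
# to `tf.keras.optimizers.SGD`, which is already the legacy implementation):
if __name__ == "__main__":
    if hasattr(tf.keras.optimizers, "legacy"):
        base_opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
    else:
        base_opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    opt = Lookahead(base_opt, sync_period=6, slow_step_size=0.5)

    var = tf.Variable([1.0, 2.0])
    grad = tf.constant([0.1, 0.1])
    # Every 6th step the fast weights are pulled back towards the slow weights.
    for _ in range(12):
        opt.apply_gradients([(grad, var)])
    print(var.numpy())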