Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/lamb.py: 19%


# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer-wise Adaptive Moments (LAMB) optimizer.

See paper [Large Batch Optimization for Deep Learning: Training BERT in
76 minutes](https://arxiv.org/abs/1904.00962).
"""

import warnings

from typing import Optional, Union, Callable, List
from typeguard import typechecked

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.types import FloatTensorLike
from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


@tf.keras.utils.register_keras_serializable(package="Addons")
class LAMB(KerasLegacyOptimizer):
    """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

    See paper [Large Batch Optimization for Deep Learning: Training BERT
    in 76 minutes](https://arxiv.org/abs/1904.00962).
    """

    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        beta_1: FloatTensorLike = 0.9,
        beta_2: FloatTensorLike = 0.999,
        epsilon: FloatTensorLike = 1e-6,
        weight_decay: FloatTensorLike = 0.0,
        exclude_from_weight_decay: Optional[List[str]] = None,
        exclude_from_layer_adaptation: Optional[List[str]] = None,
        name: str = "LAMB",
        **kwargs,
    ):
53 """Construct a new LAMB optimizer. 

54 

55 Args: 

56 learning_rate: A `Tensor` or a floating point value. or a schedule 

57 that is a `tf.keras.optimizers.schedules.LearningRateSchedule` 

58 The learning rate. 

59 beta_1: A `float` value or a constant `float` tensor. 

60 The exponential decay rate for the 1st moment estimates. 

61 beta_2: A `float` value or a constant `float` tensor. 

62 The exponential decay rate for the 2nd moment estimates. 

63 epsilon: A small constant for numerical stability. 

64 weight_decay: weight decay. 

65 exclude_from_weight_decay: List of regex patterns of 

66 variables excluded from weight decay. Variables whose name 

67 contain a substring matching the pattern will be excluded. 

68 exclude_from_layer_adaptation: List of regex patterns of 

69 variables excluded from layer adaptation. Variables whose name 

70 contain a substring matching the pattern will be excluded. 

71 name: Optional name for the operations created when applying 

72 gradients. Defaults to "LAMB". 

73 **kwargs: keyword arguments. Allowed to be {`clipnorm`, 

74 `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by 

75 norm; `clipvalue` is clip gradients by value, `decay` is 

76 included for backward compatibility to allow time inverse 

77 decay of learning rate. `lr` is included for backward 

78 compatibility, recommended to use `learning_rate` instead. 

79 """ 

80 

81 if "weight_decay_rate" in kwargs: 

82 warnings.warn( 

83 "weight_decay_rate has been renamed to weight_decay," 

84 "and will be deprecated in Addons 0.18.", 

85 DeprecationWarning, 

86 ) 

87 weight_decay = kwargs["weight_decay_rate"] 

88 del kwargs["weight_decay_rate"] 

89 

90 super().__init__(name, **kwargs) 

91 

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters.
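        # Here the decay term is added directly to the parameter update in
        # `_resource_apply_dense` / `_resource_apply_sparse` (gated by
        # `_do_use_weight_decay`), so it is scaled by the learning rate and
        # trust ratio but never enters the m/v statistics.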

        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

        # This is the learning rate decay used with a Keras learning rate schedule.
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self.epsilon = epsilon or tf.keras.backend.epsilon()
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
        # the arg is None.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def _create_slots(self, var_list):
        # Create slots for the first and second moments.
        # Separate for-loops to respect the ordering of slot variables from v1.
        for var in var_list:
            self.add_slot(var, "m")
        for var in var_list:
            self.add_slot(var, "v")

    def _prepare_local(self, var_device, var_dtype, apply_state):
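        # Cache the per-(device, dtype) constants used by the update rules
        # below (decay rates, bias-correction powers, epsilon, weight decay)
        # so they are only computed once per step.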

        super()._prepare_local(var_device, var_dtype, apply_state)

        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
        beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
        weight_decay = tf.identity(self._get_hyper("weight_decay", var_dtype))
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)
        apply_state[(var_device, var_dtype)].update(
            dict(
                weight_decay=weight_decay,
                epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
                beta_1_t=beta_1_t,
                beta_1_power=beta_1_power,
                one_minus_beta_1_t=1 - beta_1_t,
                beta_2_t=beta_2_t,
                beta_2_power=beta_2_power,
                one_minus_beta_2_t=1 - beta_2_t,
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
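        # Dense LAMB update, following the paper:
        #   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
        #   v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
        #   u_t = m_hat_t / (sqrt(v_hat_t) + epsilon) [+ weight_decay * w_t]
        #   w_{t+1} = w_t - lr * (||w_t|| / ||u_t||) * u_t
        # where m_hat/v_hat are the bias-corrected moments and the trust
        # ratio ||w_t|| / ||u_t|| falls back to 1 when either norm is zero.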

        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
        m_t = m.assign(m_t, use_locking=self._use_locking)
        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
        v_t = v.assign(v_t, use_locking=self._use_locking)

        m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
        v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + coefficients["epsilon"])

        if self._do_use_weight_decay(var):
            update += coefficients["weight_decay"] * var

        ratio = 1.0
        if self._do_layer_adaptation(var):
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0,
            )

        var_update = var - ratio * coefficients["lr_t"] * update
        return var.assign(var_update, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
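        # Sparse counterpart of `_resource_apply_dense`: the m/v slots are
        # first decayed in place and then updated with scatter-adds restricted
        # to `indices`; the bias correction, weight decay and trust-ratio
        # steps are identical to the dense path.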

        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
        v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

        v_sqrt = tf.sqrt(v_t_hat)
        update = m_t_hat / (v_sqrt + coefficients["epsilon"])

        if self._do_use_weight_decay(var):
            update += coefficients["weight_decay"] * var

        ratio = 1.0
        if self._do_layer_adaptation(var):
            w_norm = tf.norm(var, ord=2)
            g_norm = tf.norm(update, ord=2)
            ratio = tf.where(
                tf.greater(w_norm, 0),
                tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0,
            )

        var_update = var.assign_sub(
            ratio * coefficients["lr_t"] * update, use_locking=self._use_locking
        )
        return tf.group(*[var_update, m_t, v_t])

    def get_config(self):
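        # Serialize every hyperparameter so Keras can rebuild the optimizer
        # (e.g. when a compiled model is saved and reloaded); the class itself
        # is registered under the "Addons" package by the decorator above.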

        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "weight_decay": self._serialize_hyperparameter("weight_decay"),
                "decay": self._serialize_hyperparameter("decay"),
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "epsilon": self.epsilon,
                "exclude_from_weight_decay": self.exclude_from_weight_decay,
                "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation,
            }
        )
        return config

    def _do_use_weight_decay(self, variable):
        """Whether to use L2 weight decay for `variable`."""
        return not is_variable_matched_by_regexes(
            variable, self.exclude_from_weight_decay
        )

    def _do_layer_adaptation(self, variable):
        """Whether to do layer-wise learning rate adaptation for
        `variable`."""
        return not is_variable_matched_by_regexes(
            variable, self.exclude_from_layer_adaptation
        )
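

# Minimal usage sketch (illustrative, with an assumed model and random data):
# constructs a LAMB optimizer with decoupled weight decay, excluding biases
# and LayerNormalization parameters via assumed variable-name patterns, and
# runs one training epoch.
if __name__ == "__main__":
    import numpy as np

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dense(1),
        ]
    )
    optimizer = LAMB(
        learning_rate=1e-3,
        weight_decay=0.01,
        # Skip weight decay for biases and LayerNormalization parameters,
        # as is common for BERT-style training.
        exclude_from_weight_decay=["bias", "layer_normalization"],
    )
    model.compile(optimizer=optimizer, loss="mse")
    x = np.random.normal(size=(64, 16)).astype("float32")
    y = np.random.normal(size=(64, 1)).astype("float32")
    model.fit(x, y, batch_size=16, epochs=1, verbose=0)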