Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/nadam.py: 18%

88 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src import backend_config
from keras.src.optimizers.legacy import optimizer_v2
from keras.src.optimizers.schedules import learning_rate_schedule

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(
    "keras.optimizers.legacy.Nadam",
    v1=["keras.optimizers.Nadam", "keras.optimizers.legacy.Nadam"],
)
class Nadam(optimizer_v2.OptimizerV2):
    r"""Optimizer that implements the NAdam algorithm.
    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
    Nesterov momentum.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      beta_1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta_2: A float value or a constant float tensor. The exponential decay
        rate for the exponentially weighted infinity norm.
      epsilon: A small constant for numerical stability.
      name: Optional name for the operations created when applying gradients.
        Defaults to `"Nadam"`.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set, the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Usage Example:
      >>> opt = tf.keras.optimizers.legacy.Nadam(learning_rate=0.2)
      >>> var1 = tf.Variable(10.0)
      >>> loss = lambda: (var1 ** 2) / 2.0
      >>> step_count = opt.minimize(loss, [var1]).numpy()
      >>> "{:.1f}".format(var1.numpy())
      '9.8'

    Reference:
      - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
    """

    _HAS_AGGREGATE_GRAD = True

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        name="Nadam",
        **kwargs
    ):
        # Backwards compatibility with keras NAdam optimizer.
        kwargs["decay"] = kwargs.pop("schedule_decay", 0.004)
        learning_rate = kwargs.get("lr", learning_rate)
        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            raise ValueError(
                "The Nadam optimizer does not support "
                "tf.keras.optimizers.LearningRateSchedules as the "
                "learning rate."
            )

        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self.epsilon = epsilon or backend_config.epsilon()
        self._m_cache = None
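    # Constructor usage note (an illustration, not part of the update math):
    # the legacy v1 names are remapped above, so for example
    #   Nadam(lr=0.002, schedule_decay=0.004)
    # configures the same hyperparameters as
    #   Nadam(learning_rate=0.002)
    # because `schedule_decay` defaults to 0.004 in either case.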

    def _create_slots(self, var_list):
        var_dtype = var_list[0].dtype.base_dtype
        if self._m_cache is None:
            self._m_cache = self.add_weight(
                "momentum_cache",
                shape=[],
                dtype=var_dtype,
                initializer="ones",
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
            self._weights.append(self._m_cache)
        # Separate for-loops to respect the ordering of slot variables from v1.
        for var in var_list:
            # Create slots for the first moments.
            self.add_slot(var, "m")
        for var in var_list:
            # Create slots for the second moments.
            self.add_slot(var, "v")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        lr_t = tf.identity(self._get_hyper("learning_rate", var_dtype))
        beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
        beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
        local_step = tf.cast(self.iterations + 1, var_dtype)
        next_step = tf.cast(self.iterations + 2, var_dtype)

        decay_base = tf.cast(0.96, var_dtype)

        # Despite the names, `m_t` and `m_t_1` are the momentum schedule
        # coefficients for the current and next step (mu_t and mu_{t+1}),
        # not first-moment estimates.
        m_t = beta_1_t * (
            1.0 - 0.5 * (tf.pow(decay_base, self._initial_decay * local_step))
        )
        m_t_1 = beta_1_t * (
            1.0 - 0.5 * (tf.pow(decay_base, self._initial_decay * next_step))
        )

        # `_m_cache` holds the running product of the schedule coefficients,
        # prod(mu_1 .. mu_t).
        m_schedule_new = tf.cast(self._m_cache_read, var_dtype) * m_t
        if var_dtype is self._m_cache.dtype:
            m_schedule_new = tf.identity(
                tf.compat.v1.assign(
                    self._m_cache, m_schedule_new, use_locking=self._use_locking
                )
            )
        m_schedule_next = m_schedule_new * m_t_1

        apply_state[(var_device, var_dtype)] = dict(
            lr_t=lr_t,
            neg_lr_t=-lr_t,
            epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_2_t=beta_2_t,
            m_t=m_t,
            m_t_1=m_t_1,
            one_minus_beta_1_t=1 - beta_1_t,
            one_minus_beta_2_t=1 - beta_2_t,
            one_minus_m_t=1.0 - m_t,
            one_minus_m_schedule_new=1.0 - m_schedule_new,
            one_minus_m_schedule_next=1.0 - m_schedule_next,
            v_t_prime_denominator=1.0 - tf.pow(beta_2_t, local_step),
        )

    def _prepare(self, var_list):
        # Get the value of the momentum cache before starting to apply
        # gradients, so every _prepare_local() call in this step sees the
        # same cached schedule product.
        self._m_cache_read = tf.identity(self._m_cache)
        return super()._prepare(var_list)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        g_prime = grad / coefficients["one_minus_m_schedule_new"]
        m_t = (
            coefficients["beta_1_t"] * m
            + coefficients["one_minus_beta_1_t"] * grad
        )
        m_t = tf.compat.v1.assign(m, m_t, use_locking=self._use_locking)
        m_t_prime = m_t / coefficients["one_minus_m_schedule_next"]
        v_t = coefficients["beta_2_t"] * v + coefficients[
            "one_minus_beta_2_t"
        ] * tf.square(grad)
        v_t = tf.compat.v1.assign(v, v_t, use_locking=self._use_locking)
        v_t_prime = v_t / coefficients["v_t_prime_denominator"]
        m_t_bar = (
            coefficients["one_minus_m_t"] * g_prime
            + coefficients["m_t_1"] * m_t_prime
        )
        var_t = var - coefficients["lr_t"] * m_t_bar / (
            tf.sqrt(v_t_prime) + coefficients["epsilon"]
        )
        return tf.compat.v1.assign(var, var_t, use_locking=self._use_locking).op

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        g_prime = grad / coefficients["one_minus_m_schedule_new"]

        # m_t = beta1 * m + (1 - beta1) * g_t
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = tf.compat.v1.assign(
            m, m * coefficients["beta_1_t"], use_locking=self._use_locking
        )

        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
            m_t_slice = tf.gather(m_t, indices)

        m_t_prime = m_t_slice / coefficients["one_minus_m_schedule_next"]
        m_t_bar = (
            coefficients["one_minus_m_t"] * g_prime
            + coefficients["m_t_1"] * m_t_prime
        )

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = tf.compat.v1.assign(
            v, v * coefficients["beta_2_t"], use_locking=self._use_locking
        )

        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
            v_t_slice = tf.gather(v_t, indices)

        v_t_prime = v_t_slice / coefficients["v_t_prime_denominator"]
        v_prime_sqrt_plus_eps = tf.sqrt(v_t_prime) + coefficients["epsilon"]

        var_update = self._resource_scatter_add(
            var,
            indices,
            coefficients["neg_lr_t"] * m_t_bar / v_prime_sqrt_plus_eps,
        )
        return tf.group(*[var_update, m_t_bar, v_t])

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    "learning_rate"
                ),
                "decay": self._initial_decay,
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "epsilon": self.epsilon,
            }
        )
        return config