Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/nadam.py: 19%

67 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export(
    "keras.optimizers.experimental.Nadam", "keras.optimizers.Nadam", v1=[]
)
class Nadam(optimizer.Optimizer):
    r"""Optimizer that implements the Nadam algorithm.

    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
    Nesterov momentum.

    Args:
        learning_rate: A `tf.Tensor`, floating point value, a schedule that is
            a `tf.keras.optimizers.schedules.LearningRateSchedule`, or a
            callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just
            before Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to `1e-7`.
        {{base_optimizer_keyword_args}}

    Reference:
        - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="Nadam",
        **kwargs
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Initialize optimizer variables.

        Nadam optimizer has 2 types of variables: momentums and velocities.

        Args:
            var_list: list of model variables to build Nadam variables on.
        """
        super().build(var_list)
        if getattr(self, "_built", False):
            return
        self._built = True
        self._momentums = []
        self._velocities = []
        self._u_product = tf.Variable(1.0, dtype=var_list[0].dtype)
        # Keep a counter of how many times `_u_product` has been computed, to
        # avoid duplicated computations.
        self._u_product_counter = 1

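        # For each variable in `var_list`, one momentum ("m") slot and one
        # velocity ("v") slot are created below. Per the base Optimizer's
        # `add_variable_from_reference` (assumed default behavior), each slot
        # is zero-initialized with the same shape and dtype as its variable.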

        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="v"
                )
            )

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        var_dtype = variable.dtype
        lr = tf.cast(self.learning_rate, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        next_step = tf.cast(self.iterations + 2, var_dtype)
        decay = tf.cast(0.96, var_dtype)
        beta_1 = tf.cast(self.beta_1, var_dtype)
        beta_2 = tf.cast(self.beta_2, var_dtype)
        u_t = beta_1 * (1.0 - 0.5 * (tf.pow(decay, local_step)))
        u_t_1 = beta_1 * (1.0 - 0.5 * (tf.pow(decay, next_step)))
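        # `u_t` and `u_t_1` follow the momentum-cache schedule from Dozat
        # (2015): the effective momentum coefficient at step t is
        # beta_1 * (1 - 0.5 * 0.96**t), evaluated for the current step and
        # for the next one (the Nesterov look-ahead below needs both).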

        def get_cached_u_product():
            return self._u_product

        def compute_new_u_product():
            u_product_t = self._u_product * u_t
            self._u_product.assign(u_product_t)
            self._u_product_counter += 1
            return u_product_t

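        # `update_step` runs once per variable within a single optimizer
        # iteration, but the running product only needs to advance once per
        # iteration. The counter tracks the step for which `_u_product` is
        # already current, so subsequent variables in the same step reuse the
        # cached value instead of multiplying in `u_t` again.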

        u_product_t = tf.cond(
            self._u_product_counter == (self.iterations + 2),
            true_fn=get_cached_u_product,
            false_fn=compute_new_u_product,
        )
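        # `u_product_t` is the running product u_1 * ... * u_t, the Nadam
        # analogue of Adam's beta_1**t for bias-correcting the first moment;
        # `u_product_t_1` below extends it one step ahead for the look-ahead
        # term, and `beta_2_power` is the usual second-moment correction.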

        u_product_t_1 = u_product_t * u_t_1
        beta_2_power = tf.pow(beta_2, local_step)

        var_key = self._var_key(variable)
        m = self._momentums[self._index_dict[var_key]]
        v = self._velocities[self._index_dict[var_key]]

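        # Both branches below perform the same update: refresh the biased
        # first and second moment estimates m and v, build the
        # Nesterov-corrected estimate m_hat (next step's momentum coefficient
        # applied to m, plus the bias-corrected current gradient) and the
        # bias-corrected v_hat, then apply the Adam-style step
        #     variable -= lr * m_hat / (sqrt(v_hat) + epsilon).
        # The sparse branch only differs in how m and v are updated.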

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
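            # Each slot is updated in two pieces: first every row is decayed
            # (m -> beta_1 * m, v -> beta_2 * v, written as `assign_add` of
            # the negative complement), then the new gradient contribution is
            # scattered into only the rows listed in `gradient.indices`.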

            m.assign_add(-m * (1 - beta_1))
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1 - beta_1), gradient.indices
                )
            )
            v.assign_add(-v * (1 - beta_2))
            v.scatter_add(
                tf.IndexedSlices(
                    tf.square(gradient.values) * (1 - beta_2), gradient.indices
                )
            )
            m_hat = u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient / (
                1 - u_product_t
            )
            v_hat = v / (1 - beta_2_power)

            variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
        else:
            # Dense gradients.
            m.assign_add((gradient - m) * (1 - beta_1))
            v.assign_add((tf.square(gradient) - v) * (1 - beta_2))
            m_hat = u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient / (
                1 - u_product_t
            )
            v_hat = v / (1 - beta_2_power)

            variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return config


Nadam.__doc__ = Nadam.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
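
# Serialization round-trip sketch (illustrative; `get_config` and
# `from_config` follow the standard Keras `Optimizer` contract rather than
# anything specific to this file):
#
#     opt = Nadam(learning_rate=1e-3)
#     restored = Nadam.from_config(opt.get_config())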
