Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/nadam.py: 23% (94 statements)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Nadam')
class Nadam(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the NAdam algorithm.
  Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
  Nesterov momentum.

  Args:
    learning_rate: A Tensor or a floating point value. The learning rate.
    beta_1: A float value or a constant float tensor. The exponential decay
      rate for the 1st moment estimates.
    beta_2: A float value or a constant float tensor. The exponential decay
      rate for the exponentially weighted infinity norm.
    epsilon: A small constant for numerical stability.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Nadam"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Usage Example:
    >>> opt = tf.keras.optimizers.Nadam(learning_rate=0.2)
    >>> var1 = tf.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0
    >>> step_count = opt.minimize(loss, [var1]).numpy()
    >>> "{:.1f}".format(var1.numpy())
    9.8

  Reference:
    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               name='Nadam',
               **kwargs):
    # Backwards compatibility with keras NAdam optimizer.
    kwargs['decay'] = kwargs.pop('schedule_decay', 0.004)
    learning_rate = kwargs.get('lr', learning_rate)
    if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
      raise ValueError('The Nadam optimizer does not support '
                       'tf.keras.optimizers.LearningRateSchedules as the '
                       'learning rate.')

    super(Nadam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self._m_cache = None

  def _create_slots(self, var_list):
    var_dtype = var_list[0].dtype.base_dtype
    if self._m_cache is None:
      self._m_cache = self.add_weight(
          'momentum_cache',
          shape=[],
          dtype=var_dtype,
          initializer='ones',
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._m_cache)
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      # Create slots for the first moments.
      self.add_slot(var, 'm')
    for var in var_list:
      # Create slots for the second moments.
      self.add_slot(var, 'v')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    next_step = math_ops.cast(self.iterations + 2, var_dtype)

    decay_base = math_ops.cast(0.96, var_dtype)

    m_t = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * local_step)))
    m_t_1 = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * next_step)))

    m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t
    if var_dtype is self._m_cache.dtype:
      m_schedule_new = array_ops.identity(state_ops.assign(
          self._m_cache, m_schedule_new, use_locking=self._use_locking))
    m_schedule_next = m_schedule_new * m_t_1

    apply_state[(var_device, var_dtype)] = dict(
        lr_t=lr_t,
        neg_lr_t=-lr_t,  # pylint: disable=invalid-unary-operand-type
        epsilon=tensor_conversion.convert_to_tensor_v2_with_dispatch(
            self.epsilon, var_dtype
        ),
        beta_1_t=beta_1_t,
        beta_2_t=beta_2_t,
        m_t=m_t,
        m_t_1=m_t_1,
        one_minus_beta_1_t=1 - beta_1_t,
        one_minus_beta_2_t=1 - beta_2_t,
        one_minus_m_t=1.0 - m_t,
        one_minus_m_schedule_new=1.0 - m_schedule_new,
        one_minus_m_schedule_next=1.0 - m_schedule_next,
        v_t_prime_denominator=1.0 - math_ops.pow(beta_2_t, local_step),
    )
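
  # `self._m_cache` holds the running product of the momentum schedule,
  # prod_{i <= t} mu_i, where mu_t = beta_1 * (1 - 0.5 * 0.96**(decay * t))
  # is computed in `_prepare_local` above. The cache is assigned only for the
  # dtype that matches the cache variable itself, so the product advances
  # exactly once per step even when variables span several dtypes; `_prepare`
  # snapshots it before any per-variable updates run.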

  def _prepare(self, var_list):
    # Get the value of the momentum cache before starting to apply gradients.
    self._m_cache_read = array_ops.identity(self._m_cache)
    return super(Nadam, self)._prepare(var_list)
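
  # The two `_resource_apply_*` methods below implement the Nadam step from
  # Dozat (2015), referenced in the class docstring, using the coefficients
  # prepared above:
  #   g'      = g / (1 - prod_{i <= t} mu_i)
  #   m_t     = beta_1 * m_{t-1} + (1 - beta_1) * g
  #   m'_t    = m_t / (1 - prod_{i <= t+1} mu_i)
  #   v_t     = beta_2 * v_{t-1} + (1 - beta_2) * g^2
  #   v'_t    = v_t / (1 - beta_2^t)
  #   m_bar_t = (1 - mu_t) * g' + mu_{t+1} * m'_t
  #   theta_t = theta_{t-1} - lr * m_bar_t / (sqrt(v'_t) + epsilon)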

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']
    m_t = (coefficients['beta_1_t'] * m +
           coefficients['one_minus_beta_1_t'] * grad)
    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    m_t_prime = m_t / coefficients['one_minus_m_schedule_next']
    v_t = (coefficients['beta_2_t'] * v +
           coefficients['one_minus_beta_2_t'] * math_ops.square(grad))
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
    v_t_prime = v_t / coefficients['v_t_prime_denominator']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)
    var_t = var - coefficients['lr_t'] * m_t_bar / (
        math_ops.sqrt(v_t_prime) + coefficients['epsilon'])
    return state_ops.assign(var, var_t, use_locking=self._use_locking).op
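
  # The sparse path applies the same update, but only to the rows selected by
  # `indices`: each slot is first decayed in place, then scatter-added with
  # the scaled gradient values, and the final delta is scatter-added into the
  # variable using the negative learning rate.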

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']

    # m_t = beta1 * m + (1 - beta1) * g_t
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
      m_t_slice = array_ops.gather(m_t, indices)

    m_t_prime = m_t_slice / coefficients['one_minus_m_schedule_next']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
      v_t_slice = array_ops.gather(v_t, indices)

    v_t_prime = v_t_slice / coefficients['v_t_prime_denominator']
    v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + coefficients['epsilon']

    var_update = self._resource_scatter_add(
        var, indices,
        coefficients['neg_lr_t'] * m_t_bar / v_prime_sqrt_plus_eps)
    return control_flow_ops.group(*[var_update, m_t_bar, v_t])

  def get_config(self):
    config = super(Nadam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
    })
    return config
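

# Example usage (illustrative sketch, not part of the module above). It mirrors
# the docstring example and shows how the dict returned by `get_config()` can
# be passed back through the standard Keras `from_config()` classmethod to
# rebuild an optimizer. This assumes an environment where the class above is
# what `tf.keras.optimizers.Nadam` resolves to, per the `keras_export`
# decorator.
if __name__ == '__main__':
  import tensorflow as tf

  opt = tf.keras.optimizers.Nadam(learning_rate=0.2)
  var1 = tf.Variable(10.0)
  loss = lambda: (var1 ** 2) / 2.0
  opt.minimize(loss, [var1])
  print('{:.1f}'.format(var1.numpy()))  # Prints 9.8, as in the docstring example.

  config = opt.get_config()  # Dict of hyperparameters: learning_rate, decay, ...
  restored = tf.keras.optimizers.Nadam.from_config(config)
  print(restored.get_config()['learning_rate'])  # 0.2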