Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/nadam.py: 18% (88 statements)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src import backend_config
from keras.src.optimizers.legacy import optimizer_v2
from keras.src.optimizers.schedules import learning_rate_schedule

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(
    "keras.optimizers.legacy.Nadam",
    v1=["keras.optimizers.Nadam", "keras.optimizers.legacy.Nadam"],
)
class Nadam(optimizer_v2.OptimizerV2):
    r"""Optimizer that implements the NAdam algorithm.

    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
    Nesterov momentum.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      beta_1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta_2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability.
      name: Optional name for the operations created when applying gradients.
        Defaults to `"Nadam"`.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set, the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Usage Example:

      >>> opt = tf.keras.optimizers.legacy.Nadam(learning_rate=0.2)
      >>> var1 = tf.Variable(10.0)
      >>> loss = lambda: (var1 ** 2) / 2.0
      >>> step_count = opt.minimize(loss, [var1]).numpy()
      >>> "{:.1f}".format(var1.numpy())
      9.8

    Reference:
      - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
    """

    _HAS_AGGREGATE_GRAD = True

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        name="Nadam",
        **kwargs
    ):
        # Backwards compatibility with keras NAdam optimizer.
        kwargs["decay"] = kwargs.pop("schedule_decay", 0.004)
        learning_rate = kwargs.get("lr", learning_rate)
        if isinstance(
            learning_rate, learning_rate_schedule.LearningRateSchedule
        ):
            raise ValueError(
                "The Nadam optimizer does not support "
                "tf.keras.optimizers.LearningRateSchedules as the "
                "learning rate."
            )

        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self.epsilon = epsilon or backend_config.epsilon()
        self._m_cache = None

    def _create_slots(self, var_list):
        var_dtype = var_list[0].dtype.base_dtype
        if self._m_cache is None:
            self._m_cache = self.add_weight(
                "momentum_cache",
                shape=[],
                dtype=var_dtype,
                initializer="ones",
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
            self._weights.append(self._m_cache)
        # Separate for-loops to respect the ordering of slot variables from
        # v1.
        for var in var_list:
            # Create slots for the first moments.
            self.add_slot(var, "m")
        for var in var_list:
            # Create slots for the second moments.
            self.add_slot(var, "v")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        lr_t = tf.identity(self._get_hyper("learning_rate", var_dtype))
        beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
        beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
        local_step = tf.cast(self.iterations + 1, var_dtype)
        next_step = tf.cast(self.iterations + 2, var_dtype)

        decay_base = tf.cast(0.96, var_dtype)
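
        # m_t and m_t_1 are the momentum-schedule values for the current and
        # next step (mu_t and mu_{t+1} in Dozat, 2015):
        #     mu_t = beta_1 * (1 - 0.5 * 0.96**(decay * t))
        # `_m_cache` holds the running product of the schedule, which is used
        # below to bias-correct the first moment.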
        m_t = beta_1_t * (
            1.0 - 0.5 * (tf.pow(decay_base, self._initial_decay * local_step))
        )
        m_t_1 = beta_1_t * (
            1.0 - 0.5 * (tf.pow(decay_base, self._initial_decay * next_step))
        )

        m_schedule_new = tf.cast(self._m_cache_read, var_dtype) * m_t
        if var_dtype is self._m_cache.dtype:
            m_schedule_new = tf.identity(
                tf.compat.v1.assign(
                    self._m_cache, m_schedule_new, use_locking=self._use_locking
                )
            )
        m_schedule_next = m_schedule_new * m_t_1

        apply_state[(var_device, var_dtype)] = dict(
            lr_t=lr_t,
            neg_lr_t=-lr_t,
            epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_2_t=beta_2_t,
            m_t=m_t,
            m_t_1=m_t_1,
            one_minus_beta_1_t=1 - beta_1_t,
            one_minus_beta_2_t=1 - beta_2_t,
            one_minus_m_t=1.0 - m_t,
            one_minus_m_schedule_new=1.0 - m_schedule_new,
            one_minus_m_schedule_next=1.0 - m_schedule_next,
            v_t_prime_denominator=1.0 - tf.pow(beta_2_t, local_step),
        )

    def _prepare(self, var_list):
        # Get the value of the momentum cache before starting to apply
        # gradients.
        self._m_cache_read = tf.identity(self._m_cache)
        return super()._prepare(var_list)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")
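
        # Dense Nadam step: update the moving averages m and v, bias-correct
        # them with the momentum schedule, form the Nesterov look-ahead term
        # m_t_bar from the corrected gradient and first moment, then take an
        # Adam-style step.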
        g_prime = grad / coefficients["one_minus_m_schedule_new"]
        m_t = (
            coefficients["beta_1_t"] * m
            + coefficients["one_minus_beta_1_t"] * grad
        )
        m_t = tf.compat.v1.assign(m, m_t, use_locking=self._use_locking)
        m_t_prime = m_t / coefficients["one_minus_m_schedule_next"]
        v_t = coefficients["beta_2_t"] * v + coefficients[
            "one_minus_beta_2_t"
        ] * tf.square(grad)
        v_t = tf.compat.v1.assign(v, v_t, use_locking=self._use_locking)
        v_t_prime = v_t / coefficients["v_t_prime_denominator"]
        m_t_bar = (
            coefficients["one_minus_m_t"] * g_prime
            + coefficients["m_t_1"] * m_t_prime
        )
        var_t = var - coefficients["lr_t"] * m_t_bar / (
            tf.sqrt(v_t_prime) + coefficients["epsilon"]
        )
        return tf.compat.v1.assign(var, var_t, use_locking=self._use_locking).op

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        g_prime = grad / coefficients["one_minus_m_schedule_new"]

        # m_t = beta1 * m + (1 - beta1) * g_t
        m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
        m_t = tf.compat.v1.assign(
            m, m * coefficients["beta_1_t"], use_locking=self._use_locking
        )

        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
            m_t_slice = tf.gather(m_t, indices)

        m_t_prime = m_t_slice / coefficients["one_minus_m_schedule_next"]
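
        # m_t_bar is the Nesterov look-ahead term: the bias-corrected
        # gradient and first moment, mixed using the current and next
        # momentum-schedule values.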
        m_t_bar = (
            coefficients["one_minus_m_t"] * g_prime
            + coefficients["m_t_1"] * m_t_prime
        )

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
        v_t = tf.compat.v1.assign(
            v, v * coefficients["beta_2_t"], use_locking=self._use_locking
        )

        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
            v_t_slice = tf.gather(v_t, indices)

        v_t_prime = v_t_slice / coefficients["v_t_prime_denominator"]
        v_prime_sqrt_plus_eps = tf.sqrt(v_t_prime) + coefficients["epsilon"]

        var_update = self._resource_scatter_add(
            var,
            indices,
            coefficients["neg_lr_t"] * m_t_bar / v_prime_sqrt_plus_eps,
        )
        return tf.group(*[var_update, m_t_bar, v_t])

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    "learning_rate"
                ),
                "decay": self._initial_decay,
                "beta_1": self._serialize_hyperparameter("beta_1"),
                "beta_2": self._serialize_hyperparameter("beta_2"),
                "epsilon": self.epsilon,
            }
        )
        return config
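

# Usage sketch: the optimizer can be serialized and restored through its
# config; `from_config` is inherited from `optimizer_v2.OptimizerV2`:
#
#     opt = Nadam(learning_rate=0.2, beta_1=0.85)
#     restored = Nadam.from_config(opt.get_config())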