# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export(
    "keras.optimizers.experimental.Nadam", "keras.optimizers.Nadam", v1=[]
)
class Nadam(optimizer.Optimizer):
    r"""Optimizer that implements the Nadam algorithm.

    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
    Nesterov momentum.

    Args:
        learning_rate: A `tf.Tensor`, floating point value, a schedule that is
            a `tf.keras.optimizers.schedules.LearningRateSchedule`, or a
            callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable that
            takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
        beta_2: A float value or a constant float tensor, or a callable that
            takes no arguments and returns the actual value to use. The
            exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just
            before Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to `1e-7`.
        {{base_optimizer_keyword_args}}

    Reference:
        - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="Nadam",
        **kwargs
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def build(self, var_list):
        """Initialize optimizer variables.

        Nadam optimizer has 2 types of variables: momentums and velocities.

        Args:
            var_list: list of model variables to build Nadam variables on.
        """
        super().build(var_list)
        if getattr(self, "_built", False):
            return
        self._built = True
        self._momentums = []
        self._velocities = []
        self._u_product = tf.Variable(1.0, dtype=var_list[0].dtype)
        # Keep a counter of how many times `_u_product` has been computed, to
        # avoid duplicated computations.
        self._u_product_counter = 1

        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
            self._velocities.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="v"
                )
            )

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        var_dtype = variable.dtype
        lr = tf.cast(self.learning_rate, var_dtype)
        local_step = tf.cast(self.iterations + 1, var_dtype)
        next_step = tf.cast(self.iterations + 2, var_dtype)
        decay = tf.cast(0.96, var_dtype)
        beta_1 = tf.cast(self.beta_1, var_dtype)
        beta_2 = tf.cast(self.beta_2, var_dtype)
        u_t = beta_1 * (1.0 - 0.5 * (tf.pow(decay, local_step)))
        u_t_1 = beta_1 * (1.0 - 0.5 * (tf.pow(decay, next_step)))
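        # Both coefficients implement the warming momentum schedule used by
        # Nadam (cf. Dozat, 2015): u_t = beta_1 * (1 - 0.5 * 0.96^t),
        # evaluated here at the current step and at the next step.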

        def get_cached_u_product():
            return self._u_product

        def compute_new_u_product():
            u_product_t = self._u_product * u_t
            self._u_product.assign(u_product_t)
            self._u_product_counter += 1
            return u_product_t

        u_product_t = tf.cond(
            self._u_product_counter == (self.iterations + 2),
            true_fn=get_cached_u_product,
            false_fn=compute_new_u_product,
        )
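        # `update_step` runs once per variable, but the momentum cache is
        # shared across the whole step: the counter check above advances
        # `_u_product` only for the first variable of each iteration and
        # reuses the cached value for the rest.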
        u_product_t_1 = u_product_t * u_t_1
        beta_2_power = tf.pow(beta_2, local_step)

        var_key = self._var_key(variable)
        m = self._momentums[self._index_dict[var_key]]
        v = self._velocities[self._index_dict[var_key]]

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            m.assign_add(-m * (1 - beta_1))
            m.scatter_add(
                tf.IndexedSlices(
                    gradient.values * (1 - beta_1), gradient.indices
                )
            )
            v.assign_add(-v * (1 - beta_2))
            v.scatter_add(
                tf.IndexedSlices(
                    tf.square(gradient.values) * (1 - beta_2), gradient.indices
                )
            )
            m_hat = u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient / (
                1 - u_product_t
            )
            v_hat = v / (1 - beta_2_power)

            variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
        else:
            # Dense gradients.
            m.assign_add((gradient - m) * (1 - beta_1))
            v.assign_add((tf.square(gradient) - v) * (1 - beta_2))
            m_hat = u_t_1 * m / (1 - u_product_t_1) + (1 - u_t) * gradient / (
                1 - u_product_t
            )
            v_hat = v / (1 - beta_2_power)

            variable.assign_sub((m_hat * lr) / (tf.sqrt(v_hat) + self.epsilon))
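
    # For reference, the update applied above to each variable theta with
    # gradient g at step t is (matching the code, in the paper's notation):
    #
    #     m_t     = beta_1 * m_{t-1} + (1 - beta_1) * g_t
    #     v_t     = beta_2 * v_{t-1} + (1 - beta_2) * g_t**2
    #     m_hat   = u_{t+1} * m_t / (1 - prod_{i=1..t+1} u_i)
    #               + (1 - u_t) * g_t / (1 - prod_{i=1..t} u_i)
    #     v_hat   = v_t / (1 - beta_2**t)
    #     theta_t = theta_{t-1} - lr * m_hat / (sqrt(v_hat) + epsilon)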

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
                "epsilon": self.epsilon,
            }
        )
        return config


Nadam.__doc__ = Nadam.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
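
# A minimal usage sketch (the model, data, and loss below are hypothetical
# placeholders; the Keras API calls themselves are standard):
#
#     import tensorflow as tf
#
#     model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
#     model.compile(optimizer=Nadam(learning_rate=1e-3), loss="mse")
#     model.fit(x_train, y_train, epochs=2)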