Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/nadam.py: 23%
94 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Nadam')
class Nadam(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the NAdam algorithm.

  Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
  Nesterov momentum.

  Args:
    learning_rate: A Tensor or a floating point value. The learning rate.
    beta_1: A float value or a constant float tensor. The exponential decay
      rate for the 1st moment estimates.
    beta_2: A float value or a constant float tensor. The exponential decay
      rate for the 2nd moment estimates.
    epsilon: A small constant for numerical stability.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Nadam"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Usage Example:
    >>> opt = tf.keras.optimizers.Nadam(learning_rate=0.2)
    >>> var1 = tf.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0
    >>> step_count = opt.minimize(loss, [var1]).numpy()
    >>> "{:.1f}".format(var1.numpy())
    '9.8'

  Reference:
    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
  """
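
  # Unlike plain Adam, the Nesterov variant below needs a per-step momentum
  # schedule. The running product of the schedule coefficients is kept in
  # `self._m_cache` (created in `_create_slots`) and consumed in
  # `_prepare_local`.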

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               name='Nadam',
               **kwargs):
    # Backwards compatibility with keras NAdam optimizer.
    kwargs['decay'] = kwargs.pop('schedule_decay', 0.004)
    learning_rate = kwargs.get('lr', learning_rate)
    if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
      raise ValueError('The Nadam optimizer does not support '
                       'tf.keras.optimizers.LearningRateSchedules as the '
                       'learning rate.')

    super(Nadam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self._m_cache = None
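
  # Note on the legacy keyword arguments handled above: `schedule_decay` is
  # mapped onto the v2 `decay` hyperparameter and `lr` onto `learning_rate`,
  # so `Nadam(lr=0.002, schedule_decay=0.004)` behaves like
  # `Nadam(learning_rate=0.002)` with `decay=0.004`.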

  def _create_slots(self, var_list):
    var_dtype = var_list[0].dtype.base_dtype
    if self._m_cache is None:
      self._m_cache = self.add_weight(
          'momentum_cache',
          shape=[],
          dtype=var_dtype,
          initializer='ones',
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._m_cache)
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      # Create slots for the first moments.
      self.add_slot(var, 'm')
    for var in var_list:
      # Create slots for the second moments.
      self.add_slot(var, 'v')
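
  # Slot layout created above: one 'm' (first-moment) and one 'v'
  # (second-moment) slot per trainable variable, plus a single scalar
  # 'momentum_cache' shared by all variables that accumulates the product of
  # the momentum schedule coefficients.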

  def _prepare_local(self, var_device, var_dtype, apply_state):
    lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    next_step = math_ops.cast(self.iterations + 2, var_dtype)

    decay_base = math_ops.cast(0.96, var_dtype)

    m_t = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * local_step)))
    m_t_1 = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * next_step)))

    m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t
    if var_dtype is self._m_cache.dtype:
      m_schedule_new = array_ops.identity(state_ops.assign(
          self._m_cache, m_schedule_new, use_locking=self._use_locking))
    m_schedule_next = m_schedule_new * m_t_1

    apply_state[(var_device, var_dtype)] = dict(
        lr_t=lr_t,
        neg_lr_t=-lr_t,  # pylint: disable=invalid-unary-operand-type
        epsilon=tensor_conversion.convert_to_tensor_v2_with_dispatch(
            self.epsilon, var_dtype
        ),
        beta_1_t=beta_1_t,
        beta_2_t=beta_2_t,
        m_t=m_t,
        m_t_1=m_t_1,
        one_minus_beta_1_t=1 - beta_1_t,
        one_minus_beta_2_t=1 - beta_2_t,
        one_minus_m_t=1.0 - m_t,
        one_minus_m_schedule_new=1.0 - m_schedule_new,
        one_minus_m_schedule_next=1.0 - m_schedule_next,
        v_t_prime_denominator=1.0 - math_ops.pow(beta_2_t, local_step),
    )
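
  # For step t (1-based, i.e. `self.iterations + 1`), the schedule above is
  #   mu_t     = beta_1 * (1 - 0.5 * 0.96**(decay * t))        -> `m_t`
  #   mu_{t+1} = beta_1 * (1 - 0.5 * 0.96**(decay * (t + 1)))  -> `m_t_1`
  # with running products
  #   m_schedule_new  = prod_{i <= t} mu_i    (persisted in `self._m_cache`)
  #   m_schedule_next = m_schedule_new * mu_{t+1}
  # These supply the Nesterov-style bias corrections used in the apply methods.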

  def _prepare(self, var_list):
    # Get the value of the momentum cache before starting to apply gradients.
    self._m_cache_read = array_ops.identity(self._m_cache)
    return super(Nadam, self)._prepare(var_list)
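
  # Reading the cache exactly once, before any per-dtype preparation, means
  # every (device, dtype) bucket in `_prepare_local` scales the same
  # pre-update schedule value rather than one another's updates.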

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']
    m_t = (coefficients['beta_1_t'] * m +
           coefficients['one_minus_beta_1_t'] * grad)
    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    m_t_prime = m_t / coefficients['one_minus_m_schedule_next']
    v_t = (coefficients['beta_2_t'] * v +
           coefficients['one_minus_beta_2_t'] * math_ops.square(grad))
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
    v_t_prime = v_t / coefficients['v_t_prime_denominator']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)
    var_t = var - coefficients['lr_t'] * m_t_bar / (
        math_ops.sqrt(v_t_prime) + coefficients['epsilon'])
    return state_ops.assign(var, var_t, use_locking=self._use_locking).op
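
  # Written out, the dense update above is
  #   g'    = g / (1 - m_schedule_new)
  #   m     = beta_1 * m + (1 - beta_1) * g
  #   m_hat = m / (1 - m_schedule_next)
  #   v     = beta_2 * v + (1 - beta_2) * g**2
  #   v_hat = v / (1 - beta_2**t)
  #   m_bar = (1 - mu_t) * g' + mu_{t+1} * m_hat
  #   var  <- var - lr * m_bar / (sqrt(v_hat) + epsilon)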

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']

    # m_t = beta1 * m + (1 - beta1) * g_t
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
      m_t_slice = array_ops.gather(m_t, indices)

    m_t_prime = m_t_slice / coefficients['one_minus_m_schedule_next']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
      v_t_slice = array_ops.gather(v_t, indices)

    v_t_prime = v_t_slice / coefficients['v_t_prime_denominator']
    v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + coefficients['epsilon']

    var_update = self._resource_scatter_add(
        var, indices,
        coefficients['neg_lr_t'] * m_t_bar / v_prime_sqrt_plus_eps)
    return control_flow_ops.group(*[var_update, m_t_bar, v_t])
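
  # The sparse path scatter-adds the new gradient contributions into `m` and
  # `v` and applies the Nadam step only to the rows selected by `indices`;
  # note that the `beta` decay in the `state_ops.assign` calls above still
  # touches every row of `m` and `v`.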

  def get_config(self):
    config = super(Nadam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
    })
    return config
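
# Config round-trip, a minimal sketch (values here are illustrative):
#   opt = Nadam(learning_rate=0.002, beta_1=0.9, beta_2=0.999)
#   config = opt.get_config()
#   restored = Nadam.from_config(config)  # `from_config` is inherited from
#                                         # OptimizerV2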