Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/legacy/rmsprop.py: 18%
93 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""RMSprop optimizer implementation."""
17import numpy as np
18import tensorflow.compat.v2 as tf
20from keras.src import backend_config
21from keras.src.optimizers.legacy import optimizer_v2
23# isort: off
24from tensorflow.python.util.tf_export import keras_export
@keras_export(
    "keras.optimizers.legacy.RMSprop",
    v1=["keras.optimizers.RMSprop", "keras.optimizers.legacy.RMSprop"],
)
class RMSprop(optimizer_v2.OptimizerV2):
    r"""Optimizer that implements the RMSprop algorithm.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    This implementation of RMSprop uses plain momentum, not Nesterov momentum.

    The centered version additionally maintains a moving average of the
    gradients, and uses that average to estimate the variance.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to `0.001`.
      rho: Discounting factor for the history/coming gradient. Defaults to
        `0.9`.
      momentum: A scalar, a scalar `Tensor`/`Variable`, or a callable. Python
        numbers must lie in `[0, 1]`; a Python value of `0` disables
        momentum. Defaults to `0.0`.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper. A falsy
        value (e.g. `None` or `0`) is replaced by `backend_config.epsilon()`.
        Defaults to `1e-7`.
      centered: Boolean. If `True`, gradients are normalized by the estimated
        variance of the gradient; if False, by the uncentered second moment.
        Setting this to `True` may help with training, but is slightly more
        expensive in terms of computation and memory. Defaults to `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to `"RMSprop"`.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`; `lr` (legacy alias of `learning_rate`)
        and `decay` are also accepted for backward compatibility.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Note that in the dense implementation of this algorithm, variables and their
    corresponding accumulators (momentum, gradient moving average, square
    gradient moving average) will be updated even if the gradient is zero
    (i.e. accumulators will decay, momentum will be applied). The sparse
    implementation (used when the gradient is an `IndexedSlices` object,
    typically because of `tf.gather` or an embedding lookup in the forward pass)
    will not update variable slices unless those slices were used in the
    forward pass (nor is there an "eventual" correction to account for these
    omitted updates). With momentum enabled the accumulators are likewise only
    updated for the used slices. With momentum disabled, however, the whole
    square-gradient accumulator (and, when `centered=True`, the whole gradient
    moving average) is decayed by `rho` on every step; only the new gradient
    contribution is restricted to the used slices. The sparse path leads to
    more efficient updates for large embedding lookup tables (where most of the
    slices are not accessed in a particular graph execution), but differs from
    the published algorithm.

    Usage:

    >>> opt = tf.keras.optimizers.legacy.RMSprop(learning_rate=0.1)
    >>> var1 = tf.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0    # d(loss) / d(var1) = var1
    >>> step_count = opt.minimize(loss, [var1]).numpy()
    >>> var1.numpy()
    9.683772

    Reference:
      - [Hinton, 2012](
        http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    """

    _HAS_AGGREGATE_GRAD = True

    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        name="RMSprop",
        **kwargs,
    ):
        """Construct a new RMSprop optimizer.

        Args:
          learning_rate: A `Tensor`, floating point value, or a schedule that is
            a `tf.keras.optimizers.schedules.LearningRateSchedule`, or a
            callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
          rho: Discounting factor for the history/coming gradient. Defaults to
            `0.9`.
          momentum: A scalar, a scalar `Tensor`/`Variable`, or a callable.
            Python numbers must lie in `[0, 1]`. Momentum slots are created
            for tensors, variables, callables and positive Python numbers.
            Defaults to `0.0`.
          epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just before
            Section 2.1), not the epsilon in Algorithm 1 of the paper. A falsy
            value falls back to `backend_config.epsilon()`. Defaults to 1e-7.
          centered: Boolean. If `True`, gradients are normalized by the
            estimated variance of the gradient; if False, by the uncentered
            second moment. Setting this to `True` may help with training, but
            is slightly more expensive in terms of computation and memory.
            Defaults to `False`.
          name: Optional name prefix for the operations created when applying
            gradients. Defaults to "RMSprop".
          **kwargs: keyword arguments forwarded to the base optimizer. Allowed
            to be {`clipnorm`, `clipvalue`, `global_clipnorm`, `lr`, `decay`}.
            `clipnorm` is clip gradients by norm; `clipvalue` is clip gradients
            by value; `global_clipnorm` clips all gradients by their global
            norm; `decay` is included for backward compatibility to allow time
            inverse decay of learning rate. `lr` is included for backward
            compatibility, recommended to use `learning_rate` instead.

        Raises:
          ValueError: If `momentum` is a Python number outside `[0, 1]`
            (including NaN).

        @compatibility(eager)
        When eager execution is enabled, `learning_rate`, `decay`, `momentum`,
        and `epsilon` can each be a callable that takes no arguments and returns
        the actual value to use. This can be useful for changing these values
        across different invocations of optimizer functions.
        @end_compatibility
        """
        super().__init__(name, **kwargs)
        # `lr` is the legacy spelling; when present it wins over
        # `learning_rate`.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("rho", rho)

        # Validate plain Python numbers before deciding whether momentum is
        # enabled, so that negative (or NaN) values raise instead of silently
        # turning momentum off.
        if isinstance(momentum, (int, float)) and not 0 <= momentum <= 1:
            raise ValueError(
                "`momentum` must be between [0, 1]. Received: "
                f"momentum={momentum} (of type {type(momentum)})."
            )
        # Tensors, variables and callables may take a non-zero value later, so
        # they always get momentum slots. Checking `tf.Variable` explicitly
        # avoids evaluating `variable > 0` as a Python bool.
        self._momentum = bool(
            isinstance(momentum, (tf.Tensor, tf.Variable))
            or callable(momentum)
            or momentum > 0
        )
        self._set_hyper("momentum", momentum)

        # Note: `or` means an epsilon of 0 is also replaced by the default.
        self.epsilon = epsilon or backend_config.epsilon()
        self.centered = centered

    def _create_slots(self, var_list):
        """Create the per-variable accumulator slots.

        `"rms"` is always created; `"momentum"` only when momentum is enabled
        and `"mg"` (gradient moving average) only when `centered` is True.
        The creation order is kept stable because it determines the order of
        the optimizer's weights.
        """
        for var in var_list:
            self.add_slot(var, "rms")
        if self._momentum:
            for var in var_list:
                self.add_slot(var, "momentum")
        if self.centered:
            for var in var_list:
                self.add_slot(var, "mg")

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Precompute per-(device, dtype) coefficients used by the updates.

        Adds `neg_lr_t`, `epsilon`, `rho`, `momentum` and `one_minus_rho` to
        `apply_state[(var_device, var_dtype)]`, next to the `lr_t` entry the
        base class populates.
        """
        super()._prepare_local(var_device, var_dtype, apply_state)

        rho = tf.identity(self._get_hyper("rho", var_dtype))
        apply_state[(var_device, var_dtype)].update(
            dict(
                neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
                epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
                rho=rho,
                momentum=tf.identity(self._get_hyper("momentum", var_dtype)),
                one_minus_rho=1.0 - rho,
            )
        )

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply a dense gradient `grad` to `var`.

        With momentum enabled this delegates to the fused
        `ResourceApplyRMSProp` / `ResourceApplyCenteredRMSProp` raw ops.
        Otherwise the update is written out explicitly:

            rms = rho * rms + (1 - rho) * grad**2
            mg  = rho * mg  + (1 - rho) * grad            (centered only)
            var = var - lr * grad / (sqrt(rms - mg**2) + epsilon)

        where the `- mg**2` term is only present when centered.

        Returns:
          The result of the fused raw op, or the assignment op for `var`.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        rms = self.get_slot(var, "rms")
        if self._momentum:
            mom = self.get_slot(var, "momentum")
            if self.centered:
                mg = self.get_slot(var, "mg")
                return tf.raw_ops.ResourceApplyCenteredRMSProp(
                    var=var.handle,
                    mg=mg.handle,
                    ms=rms.handle,
                    mom=mom.handle,
                    lr=coefficients["lr_t"],
                    rho=coefficients["rho"],
                    momentum=coefficients["momentum"],
                    epsilon=coefficients["epsilon"],
                    grad=grad,
                    use_locking=self._use_locking,
                )
            else:
                return tf.raw_ops.ResourceApplyRMSProp(
                    var=var.handle,
                    ms=rms.handle,
                    mom=mom.handle,
                    lr=coefficients["lr_t"],
                    rho=coefficients["rho"],
                    momentum=coefficients["momentum"],
                    epsilon=coefficients["epsilon"],
                    grad=grad,
                    use_locking=self._use_locking,
                )
        else:
            rms_t = coefficients["rho"] * rms + coefficients[
                "one_minus_rho"
            ] * tf.square(grad)
            rms_t = tf.compat.v1.assign(
                rms, rms_t, use_locking=self._use_locking
            )
            denom_t = rms_t
            if self.centered:
                mg = self.get_slot(var, "mg")
                mg_t = (
                    coefficients["rho"] * mg
                    + coefficients["one_minus_rho"] * grad
                )
                mg_t = tf.compat.v1.assign(
                    mg, mg_t, use_locking=self._use_locking
                )
                # Estimated variance: E[g^2] - E[g]^2.
                denom_t = rms_t - tf.square(mg_t)
            var_t = var - coefficients["lr_t"] * grad / (
                tf.sqrt(denom_t) + coefficients["epsilon"]
            )
            return tf.compat.v1.assign(
                var, var_t, use_locking=self._use_locking
            ).op

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply a sparse gradient whose rows correspond to `indices` of `var`.

        With momentum enabled this delegates to the fused
        `ResourceSparseApplyRMSProp` / `ResourceSparseApplyCenteredRMSProp` raw
        ops. Otherwise:
          - the full `rms` (and `mg`) slot is first multiplied by `rho`;
          - `(1 - rho) * grad**2` (and `(1 - rho) * grad`) is then
            scatter-added at `indices`;
          - only the rows of `var` at `indices` are updated.

        Returns:
          The result of the fused raw op, or a `tf.group` of the update ops.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        rms = self.get_slot(var, "rms")
        if self._momentum:
            mom = self.get_slot(var, "momentum")
            if self.centered:
                mg = self.get_slot(var, "mg")
                return tf.raw_ops.ResourceSparseApplyCenteredRMSProp(
                    var=var.handle,
                    mg=mg.handle,
                    ms=rms.handle,
                    mom=mom.handle,
                    lr=coefficients["lr_t"],
                    rho=coefficients["rho"],
                    momentum=coefficients["momentum"],
                    epsilon=coefficients["epsilon"],
                    grad=grad,
                    indices=indices,
                    use_locking=self._use_locking,
                )
            else:
                return tf.raw_ops.ResourceSparseApplyRMSProp(
                    var=var.handle,
                    ms=rms.handle,
                    mom=mom.handle,
                    lr=coefficients["lr_t"],
                    rho=coefficients["rho"],
                    momentum=coefficients["momentum"],
                    epsilon=coefficients["epsilon"],
                    grad=grad,
                    indices=indices,
                    use_locking=self._use_locking,
                )
        else:
            rms_scaled_g_values = (grad * grad) * coefficients["one_minus_rho"]
            # Decays the whole accumulator, not just the touched rows.
            rms_t = tf.compat.v1.assign(
                rms, rms * coefficients["rho"], use_locking=self._use_locking
            )
            # The scatter-add must observe the decayed accumulator.
            with tf.control_dependencies([rms_t]):
                rms_t = self._resource_scatter_add(
                    rms, indices, rms_scaled_g_values
                )
                rms_slice = tf.gather(rms_t, indices)
            denom_slice = rms_slice
            if self.centered:
                mg = self.get_slot(var, "mg")
                mg_scaled_g_values = grad * coefficients["one_minus_rho"]
                mg_t = tf.compat.v1.assign(
                    mg, mg * coefficients["rho"], use_locking=self._use_locking
                )
                with tf.control_dependencies([mg_t]):
                    mg_t = self._resource_scatter_add(
                        mg, indices, mg_scaled_g_values
                    )
                    mg_slice = tf.gather(mg_t, indices)
                    # Estimated variance of the touched rows.
                    denom_slice = rms_slice - tf.square(mg_slice)
            var_update = self._resource_scatter_add(
                var,
                indices,
                coefficients["neg_lr_t"]
                * grad
                / (tf.sqrt(denom_slice) + coefficients["epsilon"]),
            )
            if self.centered:
                return tf.group(*[var_update, rms_t, mg_t])
            return tf.group(*[var_update, rms_t])

    def set_weights(self, weights):
        """Set the optimizer weights, tolerating Keras V1 weight lists.

        Keras V1 optimizers did not store the iteration count at the head of
        their weight list. If `weights` is exactly one entry shorter than
        `self.weights`, an iteration value of 0 is prepended before delegating
        to the base implementation. Any sequence (list or tuple) is accepted.
        """
        params = self.weights
        if len(params) == len(weights) + 1:
            # `list(...)` so tuples (or other sequences) can be prepended to.
            weights = [np.array(0)] + list(weights)
        super().set_weights(weights)

    def get_config(self):
        """Return the base optimizer config extended with RMSprop settings."""
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    "learning_rate"
                ),
                "decay": self._initial_decay,
                "rho": self._serialize_hyperparameter("rho"),
                "momentum": self._serialize_hyperparameter("momentum"),
                "epsilon": self.epsilon,
                "centered": self.centered,
            }
        )
        return config
# Alternate-capitalization alias: `RMSProp` refers to the same class as
# `RMSprop`.
RMSProp = RMSprop