Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/rmsprop.py: 23%
99 statements
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop optimizer implementation."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.RMSprop")
class RMSprop(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the RMSprop algorithm.

  The gist of RMSprop is to:

  - Maintain a moving (discounted) average of the square of gradients
  - Divide the gradient by the root of this average

  This implementation of RMSprop uses plain momentum, not Nesterov momentum.

  The centered version additionally maintains a moving average of the
  gradients, and uses that average to estimate the variance.
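
  Informally, with `momentum == 0.0` and `centered=False`, the non-fused code
  path below updates each parameter `w` with gradient `g` roughly as:

    rms = rho * rms + (1 - rho) * g ** 2
    w = w - learning_rate * g / (sqrt(rms) + epsilon)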

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.001.
    rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
    momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
    epsilon: A small constant for numerical stability. This epsilon is
      "epsilon hat" in the Kingma and Ba paper (in the formula just before
      Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
      1e-7.
    centered: Boolean. If `True`, gradients are normalized by the estimated
      variance of the gradient; if False, by the uncentered second moment.
      Setting this to `True` may help with training, but is slightly more
      expensive in terms of computation and memory. Defaults to `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"RMSprop"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Note that in the dense implementation of this algorithm, variables and their
  corresponding accumulators (momentum, gradient moving average, square
  gradient moving average) will be updated even if the gradient is zero
  (i.e. accumulators will decay, momentum will be applied). The sparse
  implementation (used when the gradient is an `IndexedSlices` object,
  typically because of `tf.gather` or an embedding lookup in the forward pass)
  will not update variable slices or their accumulators unless those slices
  were used in the forward pass (nor is there an "eventual" correction to
  account for these omitted updates). This leads to more efficient updates for
  large embedding lookup tables (where most of the slices are not accessed in
  a particular graph execution), but differs from the published algorithm.
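
  For example (a minimal sketch; the table shape and the gathered indices are
  arbitrary), a gradient produced by `tf.gather` arrives as a
  `tf.IndexedSlices`, so only the gathered rows of the variable receive an
  update:

  >>> table = tf.Variable(tf.ones([100, 4]))
  >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  >>> with tf.GradientTape() as tape:
  ...   loss = tf.reduce_sum(tf.gather(table, [3, 7]) ** 2)
  >>> grads = tape.gradient(loss, [table])  # `grads[0]` is an `IndexedSlices`.
  >>> _ = opt.apply_gradients(zip(grads, [table]))  # Only rows 3 and 7 move.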

  Usage:

  >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> var1.numpy()
  9.683772

  Reference:
    - [Hinton, 2012](
      http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               rho=0.9,
               momentum=0.0,
               epsilon=1e-7,
               centered=False,
               name="RMSprop",
               **kwargs):
    """Construct a new RMSprop optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.001.
      rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
      momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
        1e-7.
      centered: Boolean. If `True`, gradients are normalized by the estimated
        variance of the gradient; if False, by the uncentered second moment.
        Setting this to `True` may help with training, but is slightly more
        expensive in terms of computation and memory. Defaults to `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSprop".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate; `lr` is included for
        backward compatibility, but using `learning_rate` is recommended
        instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("rho", rho)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be between [0, 1].")
    self._set_hyper("momentum", momentum)

    self.epsilon = epsilon or backend_config.epsilon()
    self.centered = centered

  def _create_slots(self, var_list):
    for var in var_list:
      self.add_slot(var, "rms")
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")
    if self.centered:
      for var in var_list:
        self.add_slot(var, "mg")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(RMSprop, self)._prepare_local(var_device, var_dtype, apply_state)
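
    # Cache the per-(device, dtype) coefficients used by the apply methods so
    # each hyperparameter is converted to a tensor once per step: rho,
    # 1 - rho, the negated learning rate, epsilon and momentum.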
    rho = array_ops.identity(self._get_hyper("rho", var_dtype))
    apply_state[(var_device, var_dtype)].update(
        dict(
            neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
            epsilon=tensor_conversion.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype
            ),
            rho=rho,
            momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)),
            one_minus_rho=1.0 - rho,
        )
    )

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
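    # When momentum is enabled, delegate to the fused (centered) RMSProp
    # training kernels; otherwise compute the update with primitive ops below.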
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
    else:
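      # No momentum: decay the mean-square accumulator towards grad**2 and
      # take a step of -lr * grad / (sqrt(denom) + epsilon), where denom is
      # the (optionally centered) second-moment estimate.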
      rms_t = (coefficients["rho"] * rms +
               coefficients["one_minus_rho"] * math_ops.square(grad))
      rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
      denom_t = rms_t
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_t = coefficients["rho"] * mg + coefficients["one_minus_rho"] * grad
        mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
        denom_t = rms_t - math_ops.square(mg_t)
      var_t = var - coefficients["lr_t"] * grad / (
          math_ops.sqrt(denom_t) + coefficients["epsilon"])
      return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
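    # Same structure as the dense case: fused sparse kernels when momentum is
    # enabled, otherwise an explicit decay/scatter-add sequence below.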
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceSparseApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceSparseApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
    else:
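      # No momentum: decay the accumulators, scatter-add the scaled gradient
      # contribution for the touched rows, and apply the parameter step only
      # to the rows selected by `indices`.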
      rms_scaled_g_values = (grad * grad) * coefficients["one_minus_rho"]
      rms_t = state_ops.assign(rms, rms * coefficients["rho"],
                               use_locking=self._use_locking)
      with ops.control_dependencies([rms_t]):
        rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
        rms_slice = array_ops.gather(rms_t, indices)
      denom_slice = rms_slice
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_scaled_g_values = grad * coefficients["one_minus_rho"]
        mg_t = state_ops.assign(mg, mg * coefficients["rho"],
                                use_locking=self._use_locking)
        with ops.control_dependencies([mg_t]):
          mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
          mg_slice = array_ops.gather(mg_t, indices)
          denom_slice = rms_slice - math_ops.square(mg_slice)
      var_update = self._resource_scatter_add(
          var, indices, coefficients["neg_lr_t"] * grad / (
              math_ops.sqrt(denom_slice) + coefficients["epsilon"]))
      if self.centered:
        return control_flow_ops.group(*[var_update, rms_t, mg_t])
      return control_flow_ops.group(*[var_update, rms_t])

  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility of Keras V1 optimizer
    # since it does not include iteration at head of the weight list. Set
    # iteration to 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

  def get_config(self):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self.epsilon,
        "centered": self.centered,
    })
    return config


RMSProp = RMSprop  # Alias for the alternative capitalization of the class name.