# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGD optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export(
    "keras.optimizers.experimental.SGD",
    "keras.optimizers.SGD",
    "keras.dtensor.experimental.optimizers.SGD",
    v1=[],
)
class SGD(optimizer.Optimizer):
    r"""Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum` is 0:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
        learning_rate: A `Tensor`, floating point value, or a schedule that is
            a `keras.optimizers.schedules.LearningRateSchedule`, or a callable
            that takes no arguments and returns the actual value to use. The
            learning rate. Defaults to 0.01.
        momentum: float hyperparameter in `[0, 1]` that accelerates gradient
            descent in the relevant direction and dampens oscillations.
            Defaults to 0, i.e., vanilla gradient descent.
        nesterov: boolean. Whether to apply Nesterov momentum.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}

    Usage:

    >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    >>> var = tf.Variable(1.0)
    >>> loss = lambda: (var**2) / 2.0  # d(loss)/d(var) = var
    >>> opt.minimize(loss, [var])
    >>> # Step is `- learning_rate * grad`
    >>> var.numpy()
    0.9

    >>> opt = tf.keras.optimizers.SGD(0.1, momentum=0.9)
    >>> var = tf.Variable(1.0)
    >>> val0 = var.value()
    >>> loss = lambda: (var**2) / 2.0  # d(loss)/d(var) = var
    >>> # First step is `- learning_rate * grad`
    >>> opt.minimize(loss, [var])
    >>> val1 = var.value()
    >>> (val0 - val1).numpy()
    0.1
    >>> # On later steps, the step size grows because of momentum
    >>> opt.minimize(loss, [var])
    >>> val2 = var.value()
    >>> (val1 - val2).numpy()
    0.18

    Reference:
        - For `nesterov=True`, see [Sutskever et al., 2013](
          http://proceedings.mlr.press/v28/sutskever13.pdf).
    """

    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        jit_compile=True,
        name="SGD",
        **kwargs
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.momentum = momentum
        self.nesterov = nesterov
        if isinstance(momentum, (int, float)) and (
            momentum < 0 or momentum > 1
        ):
            raise ValueError("`momentum` must be in the range [0, 1].")
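
    # The learning rate may also be a schedule rather than a float; a minimal
    # sketch with hypothetical decay values:
    #   lr = tf.keras.optimizers.schedules.ExponentialDecay(
    #       initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9
    #   )
    #   opt = SGD(learning_rate=lr)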

    def build(self, var_list):
        """Initialize optimizer variables.

        SGD creates one momentum slot per model variable, collected in
        `self.momentums`; the slots only affect the update when
        `self.momentum` is nonzero.

        Args:
            var_list: list of model variables to build SGD variables on.
        """
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self.momentums = []
        for var in var_list:
            self.momentums.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="m"
                )
            )
        self._built = True
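
    # For example, after `build([v1, v2])` there is one momentum slot per
    # variable, fetched in `update_step` through
    # `self.momentums[self._index_dict[self._var_key(var)]]`.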

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        var_key = self._var_key(variable)
        momentum = tf.cast(self.momentum, variable.dtype)
        m = self.momentums[self._index_dict[var_key]]

        # TODO(b/204321487): Add nesterov acceleration.
        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            add_value = tf.IndexedSlices(
                -gradient.values * lr, gradient.indices
            )
            if m is not None:
                m.assign(m * momentum)
                m.scatter_add(add_value)
                if self.nesterov:
                    variable.scatter_add(add_value)
                    variable.assign_add(m * momentum)
                else:
                    variable.assign_add(m)
            else:
                variable.scatter_add(add_value)
        else:
            # Dense gradients.
            if m is not None:
                m.assign(-gradient * lr + m * momentum)
                if self.nesterov:
                    variable.assign_add(-gradient * lr + m * momentum)
                else:
                    variable.assign_add(m)
            else:
                variable.assign_add(-gradient * lr)
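
    # In the dense nesterov branch above, `m` already holds the new velocity
    # (m = momentum * m_old - lr * g), so `variable += momentum * m - lr * g`
    # reproduces the docstring rule `w = w + momentum * velocity - lr * g`.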

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config


SGD.__doc__ = SGD.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
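

# A minimal smoke test of the update rules and of config round-tripping; this
# is a sketch assuming TensorFlow is installed, with illustrative
# hyperparameters rather than recommended ones.
if __name__ == "__main__":
    for nesterov in (False, True):
        opt = SGD(learning_rate=0.1, momentum=0.9, nesterov=nesterov)
        var = tf.Variable(1.0)
        for _ in range(2):
            # loss = var**2 / 2, so the gradient equals var itself.
            opt.minimize(lambda: (var**2) / 2.0, [var])
        print(f"nesterov={nesterov}: var={var.numpy():.4f}")

    # `get_config` / `from_config` should reproduce the hyperparameters,
    # which is what `register_keras_serializable` relies on.
    restored = SGD.from_config(opt.get_config())
    assert restored.momentum == opt.momentum
    assert restored.nesterov == opt.nesterov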