Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/gradient_descent.py: 34% of 50 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGD optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.SGD")
class SGD(optimizer_v2.OptimizerV2):
  r"""Gradient descent (with momentum) optimizer.

  Update rule for parameter `w` with gradient `g` when `momentum` is 0:

  ```python
  w = w - learning_rate * g
  ```

  Update rule when `momentum` is larger than 0:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + velocity
  ```

  When `nesterov=True`, this rule becomes:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + momentum * velocity - learning_rate * g
  ```
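
  As a concrete illustration of a single momentum step (a minimal sketch of
  the formulas above using plain Python floats, not the kernels this
  optimizer actually calls):

  ```python
  learning_rate, momentum = 0.1, 0.9
  w, velocity, g = 1.0, 0.0, 1.0   # weight, momentum state, gradient
  velocity = momentum * velocity - learning_rate * g   # velocity -> -0.1
  w = w + velocity                                      # w -> 0.9
  ```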

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.01.
    momentum: float hyperparameter >= 0 that accelerates gradient descent
      in the relevant direction and dampens oscillations. Defaults to 0,
      i.e., vanilla gradient descent.
    nesterov: boolean. Whether to apply Nesterov momentum.
      Defaults to `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"SGD"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.
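
  For example, gradient clipping can be requested directly at construction
  time (the values here are only illustrative):

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, clipnorm=1.0)
  ```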

  Usage:

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  >>> var = tf.Variable(1.0)
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> # Step is `- learning_rate * grad`
  >>> var.numpy()
  0.9

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  >>> var = tf.Variable(1.0)
  >>> val0 = var.value()
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> # First step is `- learning_rate * grad`
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val1 = var.value()
  >>> (val0 - val1).numpy()
  0.1
  >>> # On later steps, step-size increases because of momentum
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val2 = var.value()
  >>> (val1 - val2).numpy()
  0.18
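
  The optimizer is most commonly passed to `Model.compile`. A minimal sketch
  (the model architecture and loss are only illustrative):

  ```python
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01,
                                                  momentum=0.9),
                loss="mse")
  ```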

  Reference:
      - For `nesterov=True`, see [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.01,
               momentum=0.0,
               nesterov=False,
               name="SGD",
               **kwargs):
    super(SGD, self).__init__(name, **kwargs)
    # "lr" is accepted as a legacy alias for "learning_rate".
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.nesterov = nesterov

  def _create_slots(self, var_list):
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(SGD, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
        self._get_hyper("momentum", var_dtype))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    if self._momentum:
      momentum_var = self.get_slot(var, "momentum")
      return gen_training_ops.ResourceApplyKerasMomentum(
          var=var.handle,
          accum=momentum_var.handle,
          lr=coefficients["lr_t"],
          grad=grad,
          momentum=coefficients["momentum"],
          use_locking=self._use_locking,
          use_nesterov=self.nesterov)
    else:
      return gen_training_ops.ResourceApplyGradientDescent(
          var=var.handle,
          alpha=coefficients["lr_t"],
          delta=grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices,
                                               **kwargs):
    if self._momentum:
      return super(SGD, self)._resource_apply_sparse_duplicate_indices(
          grad, var, indices, **kwargs)
    else:
      # Without momentum, the sparse update is just a scatter-add of
      # `-lr * grad`, so duplicate indices need no special handling.
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = (kwargs.get("apply_state", {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

      return gen_resource_variable_ops.ResourceScatterAdd(
          resource=var.handle,
          indices=indices,
          updates=-grad * coefficients["lr_t"])

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    momentum_var = self.get_slot(var, "momentum")
    return gen_training_ops.ResourceSparseApplyKerasMomentum(
        var=var.handle,
        accum=momentum_var.handle,
        lr=coefficients["lr_t"],
        grad=grad,
        indices=indices,
        momentum=coefficients["momentum"],
        use_locking=self._use_locking,
        use_nesterov=self.nesterov)

  def get_config(self):
    config = super(SGD, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "momentum": self._serialize_hyperparameter("momentum"),
        "nesterov": self.nesterov,
    })
    return config