Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/rmsprop.py: 23%

99 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop optimizer implementation."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.RMSprop")
class RMSprop(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the RMSprop algorithm.

  The gist of RMSprop is to:

  - Maintain a moving (discounted) average of the square of gradients
  - Divide the gradient by the root of this average

  This implementation of RMSprop uses plain momentum, not Nesterov momentum.

  The centered version additionally maintains a moving average of the
  gradients, and uses that average to estimate the variance.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.001.
    rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
    momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
    epsilon: A small constant for numerical stability. This epsilon is
      "epsilon hat" in the Kingma and Ba paper (in the formula just before
      Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
      1e-7.
    centered: Boolean. If `True`, gradients are normalized by the estimated
      variance of the gradient; if False, by the uncentered second moment.
      Setting this to `True` may help with training, but is slightly more
      expensive in terms of computation and memory. Defaults to `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"RMSprop"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Note that in the dense implementation of this algorithm, variables and their
  corresponding accumulators (momentum, gradient moving average, square
  gradient moving average) will be updated even if the gradient is zero
  (i.e. accumulators will decay, momentum will be applied). The sparse
  implementation (used when the gradient is an `IndexedSlices` object,
  typically because of `tf.gather` or an embedding lookup in the forward pass)
  will not update variable slices or their accumulators unless those slices
  were used in the forward pass (nor is there an "eventual" correction to
  account for these omitted updates). This leads to more efficient updates for
  large embedding lookup tables (where most of the slices are not accessed in
  a particular graph execution), but differs from the published algorithm.

  Usage:

  >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> var1.numpy()
  9.683772

  Reference:
    - [Hinton, 2012](
      http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
  """
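
  # Illustrative sketch, not part of the upstream file: the per-step update
  # described above and implemented in `_resource_apply_dense` below, written
  # as NumPy-style pseudocode. `lr`, `rho`, `eps` and `momentum` stand for the
  # corresponding hyperparameters; the momentum branch is handled by fused
  # kernels but follows the same scheme.
  #
  #   rms = rho * rms + (1 - rho) * grad ** 2
  #   if centered:
  #     mg = rho * mg + (1 - rho) * grad
  #     denom = rms - mg ** 2        # estimated variance of the gradient
  #   else:
  #     denom = rms                  # uncentered second moment
  #   increment = lr * grad / (sqrt(denom) + eps)
  #   if momentum > 0:
  #     mom = momentum * mom + increment
  #     var = var - mom
  #   else:
  #     var = var - increment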


  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               rho=0.9,
               momentum=0.0,
               epsilon=1e-7,
               centered=False,
               name="RMSprop",
               **kwargs):
    """Construct a new RMSprop optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.001.
      rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
      momentum: A scalar or a scalar `Tensor`. Defaults to 0.0.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
        1e-7.
      centered: Boolean. If `True`, gradients are normalized by the estimated
        variance of the gradient; if False, by the uncentered second moment.
        Setting this to `True` may help with training, but is slightly more
        expensive in terms of computation and memory. Defaults to `False`.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSprop".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time-inverse decay of the learning rate. `lr` is included for
        backward compatibility; `learning_rate` is recommended instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("rho", rho)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be between [0, 1].")
    self._set_hyper("momentum", momentum)

    self.epsilon = epsilon or backend_config.epsilon()
    self.centered = centered
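
  # Illustrative sketch, not part of the upstream file: the constructor above
  # also accepts the legacy `lr` keyword (routed through `kwargs` onto
  # `learning_rate`) and the clipping options handled by the base class, e.g.:
  #
  #   opt = RMSprop(lr=0.01)   # equivalent to RMSprop(learning_rate=0.01)
  #   opt = RMSprop(learning_rate=0.01, momentum=0.9, centered=True,
  #                 clipnorm=1.0)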


  def _create_slots(self, var_list):
    for var in var_list:
      self.add_slot(var, "rms")
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")
    if self.centered:
      for var in var_list:
        self.add_slot(var, "mg")
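
  # Illustrative sketch, not part of the upstream file: the slots created
  # above hold the per-variable accumulators and can be inspected once the
  # optimizer has been applied to a variable, e.g.:
  #
  #   rms_acc = opt.get_slot(var, "rms")
  #   mom_acc = opt.get_slot(var, "momentum")   # only created if momentum > 0
  #   mg_acc = opt.get_slot(var, "mg")          # only created if centered=True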


  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(RMSprop, self)._prepare_local(var_device, var_dtype, apply_state)

    rho = array_ops.identity(self._get_hyper("rho", var_dtype))
    apply_state[(var_device, var_dtype)].update(
        dict(
            neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
            epsilon=tensor_conversion.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype
            ),
            rho=rho,
            momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)),
            one_minus_rho=1.0 - rho,
        )
    )

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            use_locking=self._use_locking)
    else:
      rms_t = (coefficients["rho"] * rms +
               coefficients["one_minus_rho"] * math_ops.square(grad))
      rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
      denom_t = rms_t
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_t = coefficients["rho"] * mg + coefficients["one_minus_rho"] * grad
        mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
        denom_t = rms_t - math_ops.square(mg_t)
      var_t = var - coefficients["lr_t"] * grad / (
          math_ops.sqrt(denom_t) + coefficients["epsilon"])
      return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    rms = self.get_slot(var, "rms")
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return gen_training_ops.ResourceSparseApplyCenteredRMSProp(
            var=var.handle,
            mg=mg.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
      else:
        return gen_training_ops.ResourceSparseApplyRMSProp(
            var=var.handle,
            ms=rms.handle,
            mom=mom.handle,
            lr=coefficients["lr_t"],
            rho=coefficients["rho"],
            momentum=coefficients["momentum"],
            epsilon=coefficients["epsilon"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking)
    else:
      rms_scaled_g_values = (grad * grad) * coefficients["one_minus_rho"]
      rms_t = state_ops.assign(rms, rms * coefficients["rho"],
                               use_locking=self._use_locking)
      with ops.control_dependencies([rms_t]):
        rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
        rms_slice = array_ops.gather(rms_t, indices)
        denom_slice = rms_slice
        if self.centered:
          mg = self.get_slot(var, "mg")
          mg_scaled_g_values = grad * coefficients["one_minus_rho"]
          mg_t = state_ops.assign(mg, mg * coefficients["rho"],
                                  use_locking=self._use_locking)
          with ops.control_dependencies([mg_t]):
            mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
            mg_slice = array_ops.gather(mg_t, indices)
            denom_slice = rms_slice - math_ops.square(mg_slice)
        var_update = self._resource_scatter_add(
            var, indices, coefficients["neg_lr_t"] * grad / (
                math_ops.sqrt(denom_slice) + coefficients["epsilon"]))
        if self.centered:
          return control_flow_ops.group(*[var_update, rms_t, mg_t])
        return control_flow_ops.group(*[var_update, rms_t])
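
  # Illustrative sketch, not part of the upstream file: with an
  # `IndexedSlices` gradient (e.g. from `tf.gather` or an embedding lookup),
  # only the gathered rows of the variable and of its accumulators are decayed
  # and updated, as described in the class docstring:
  #
  #   emb = tf.Variable(tf.ones([100, 4]))
  #   opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  #   with tf.GradientTape() as tape:
  #     loss = tf.reduce_sum(tf.gather(emb, [0, 3]) ** 2)
  #   opt.apply_gradients([(tape.gradient(loss, emb), emb)])
  #   # rows 0 and 3 (and the matching rows of the "rms" slot) change;
  #   # every other row is left untouched.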


  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility of Keras V1 optimizer
    # since it does not include iteration at head of the weight list. Set
    # iteration to 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

  def get_config(self):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self.epsilon,
        "centered": self.centered,
    })
    return config


RMSProp = RMSprop
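

# Illustrative sketch, not part of the upstream file: `get_config` serializes
# the hyperparameters, so the optimizer round-trips through the standard Keras
# config mechanism used by model saving:
#
#   opt = RMSprop(learning_rate=0.01, rho=0.95, momentum=0.5, centered=True)
#   config = opt.get_config()
#   restored = RMSprop.from_config(config)
#   assert restored.get_config() == config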