Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/rmsprop.py: 18%

72 statements  

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop optimizer implementation."""

import tensorflow.compat.v2 as tf

from keras.src.optimizers import optimizer
from keras.src.saving.object_registration import register_keras_serializable

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export(
    "keras.optimizers.experimental.RMSprop",
    "keras.optimizers.RMSprop",
    "keras.dtensor.experimental.optimizers.RMSprop",
    v1=[],
)
class RMSprop(optimizer.Optimizer):
    r"""Optimizer that implements the RMSprop algorithm.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    This implementation of RMSprop uses plain momentum, not Nesterov momentum.

    The centered version additionally maintains a moving average of the
    gradients, and uses that average to estimate the variance.

    Args:
        learning_rate: Initial value for the learning rate: either a floating
            point value, or a
            `tf.keras.optimizers.schedules.LearningRateSchedule` instance.
            Defaults to 0.001.
        rho: float, defaults to 0.9. Discounting factor for the old gradients.
        momentum: float, defaults to 0.0. If not 0.0, the optimizer tracks the
            momentum value, with a decay rate equal to `1 - momentum`.
        epsilon: A small constant for numerical stability. This epsilon is
            "epsilon hat" in the Kingma and Ba paper (in the formula just
            before Section 2.1), not the epsilon in Algorithm 1 of the paper.
            Defaults to 1e-7.
        centered: Boolean. If `True`, gradients are normalized by the
            estimated variance of the gradient; if `False`, by the uncentered
            second moment. Setting this to `True` may help with training, but
            is slightly more expensive in terms of computation and memory.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}

    Usage:

    >>> opt = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    >>> var1 = tf.Variable(10.0)
    >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1
    >>> opt.minimize(loss, [var1])
    >>> var1.numpy()
    9.683772

    Reference:
        - [Hinton, 2012](
            http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    """
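    # Rough per-variable pseudocode for the dense update performed by
    # `update_step` below (`g` is the gradient, `lr` the learning rate):
    #
    #     velocity = rho * velocity + (1 - rho) * g**2
    #     if centered:
    #         average_grad = rho * average_grad + (1 - rho) * g
    #         denom = velocity - average_grad**2 + epsilon
    #     else:
    #         denom = velocity + epsilon
    #     increment = lr * g / sqrt(denom)
    #     if momentum > 0:
    #         mom = momentum * mom + increment
    #         variable -= mom
    #     else:
    #         variable -= increment
    #
    # The sparse (`tf.IndexedSlices`) branch decays the full accumulators and
    # then scatter-adds the new terms only at the gradient's indices.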

    def __init__(
        self,
        learning_rate=0.001,
        rho=0.9,
        momentum=0.0,
        epsilon=1e-7,
        centered=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=100,
        jit_compile=True,
        name="RMSprop",
        **kwargs
    ):
        super().__init__(
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            jit_compile=jit_compile,
            name=name,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.rho = rho
        self.momentum = momentum
        self.epsilon = epsilon
        self.centered = centered

    def build(self, var_list):
        super().build(var_list)
        if hasattr(self, "_built") and self._built:
            return
        self._built = True

        self._velocities = []
        for var in var_list:
            self._velocities.append(
                self.add_variable_from_reference(var, "velocity")
            )

        self._momentums = []
        if self.momentum > 0:
            for var in var_list:
                self._momentums.append(
                    self.add_variable_from_reference(var, "momentum")
                )

        self._average_gradients = []
        if self.centered:
            for var in var_list:
                self._average_gradients.append(
                    self.add_variable_from_reference(var, "average_gradient")
                )
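    # The slots created in `build` above (one "velocity" per trainable
    # variable, plus a "momentum" slot when momentum > 0 and an
    # "average_gradient" slot when centered=True) are looked up per variable
    # in `update_step` below.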

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)

        var_key = self._var_key(variable)
        velocity = self._velocities[self._index_dict[var_key]]
        momentum = None
        if self.momentum > 0:
            momentum = self._momentums[self._index_dict[var_key]]
        average_grad = None
        if self.centered:
            average_grad = self._average_gradients[self._index_dict[var_key]]

        rho = self.rho

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            velocity.assign(rho * velocity)
            velocity.scatter_add(
                tf.IndexedSlices(
                    tf.square(gradient.values) * (1 - rho), gradient.indices
                )
            )
            if self.centered:
                average_grad.assign(rho * average_grad)
                average_grad.scatter_add(
                    tf.IndexedSlices(
                        gradient.values * (1 - rho), gradient.indices
                    )
                )
                denominator = velocity - tf.square(average_grad) + self.epsilon
            else:
                denominator = velocity + self.epsilon
            denominator_slices = tf.gather(denominator, gradient.indices)
            increment = tf.IndexedSlices(
                lr * gradient.values * tf.math.rsqrt(denominator_slices),
                gradient.indices,
            )

            if self.momentum > 0:
                momentum.assign(self.momentum * momentum)
                momentum.scatter_add(increment)
                variable.assign_add(-momentum)
            else:
                variable.scatter_add(-increment)
        else:
            # Dense gradients.
            velocity.assign(rho * velocity + (1 - rho) * tf.square(gradient))
            if self.centered:
                average_grad.assign(rho * average_grad + (1 - rho) * gradient)
                denominator = velocity - tf.square(average_grad) + self.epsilon
            else:
                denominator = velocity + self.epsilon
            increment = lr * gradient * tf.math.rsqrt(denominator)
            if self.momentum > 0:
                momentum.assign(self.momentum * momentum + increment)
                variable.assign_add(-momentum)
            else:
                variable.assign_add(-increment)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "rho": self.rho,
                "momentum": self.momentum,
                "epsilon": self.epsilon,
                "centered": self.centered,
            }
        )
        return config
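    # A sketch of the serialization round trip this config supports, using
    # the standard Keras `from_config` classmethod inherited from the base
    # optimizer:
    #
    #     config = RMSprop(rho=0.95).get_config()
    #     restored = RMSprop.from_config(config)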

RMSprop.__doc__ = RMSprop.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
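# A minimal usage sketch (hypothetical names: `model`, `x_train`, and
# `y_train` are not defined in this module):
#
#     optimizer = RMSprop(learning_rate=0.001, rho=0.9, momentum=0.9)
#     model.compile(optimizer=optimizer, loss="mse")
#     model.fit(x_train, y_train, epochs=10)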