Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/__init__.py: 47%

106 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in optimizer classes.

For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""

# Imports needed for deserialization.

import platform

import tensorflow.compat.v2 as tf
from absl import logging

from keras.src import backend
from keras.src.optimizers import adadelta
from keras.src.optimizers import adafactor
from keras.src.optimizers import adagrad
from keras.src.optimizers import adam
from keras.src.optimizers import adamax
from keras.src.optimizers import adamw
from keras.src.optimizers import ftrl
from keras.src.optimizers import lion
from keras.src.optimizers import nadam
from keras.src.optimizers import optimizer as base_optimizer
from keras.src.optimizers import rmsprop
from keras.src.optimizers import sgd
from keras.src.optimizers.legacy import adadelta as adadelta_legacy
from keras.src.optimizers.legacy import adagrad as adagrad_legacy
from keras.src.optimizers.legacy import adam as adam_legacy
from keras.src.optimizers.legacy import adamax as adamax_legacy
from keras.src.optimizers.legacy import ftrl as ftrl_legacy
from keras.src.optimizers.legacy import gradient_descent as gradient_descent_legacy
from keras.src.optimizers.legacy import nadam as nadam_legacy
from keras.src.optimizers.legacy import optimizer_v2 as base_optimizer_legacy
from keras.src.optimizers.legacy import rmsprop as rmsprop_legacy
from keras.src.optimizers.legacy.adadelta import Adadelta
from keras.src.optimizers.legacy.adagrad import Adagrad
from keras.src.optimizers.legacy.adam import Adam
from keras.src.optimizers.legacy.adamax import Adamax
from keras.src.optimizers.legacy.ftrl import Ftrl

# Symbols to be accessed under keras.optimizers. To be replaced with
# optimizers v2022 when they graduate out of experimental.
from keras.src.optimizers.legacy.gradient_descent import SGD
from keras.src.optimizers.legacy.nadam import Nadam
from keras.src.optimizers.legacy.rmsprop import RMSprop
from keras.src.optimizers.optimizer_v1 import Optimizer
from keras.src.optimizers.optimizer_v1 import TFOptimizer
from keras.src.optimizers.schedules import learning_rate_schedule
from keras.src.saving.legacy import serialization as legacy_serialization
from keras.src.saving.serialization_lib import deserialize_keras_object
from keras.src.saving.serialization_lib import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export

# pylint: disable=line-too-long

74def serialize(optimizer, use_legacy_format=False): 

75 """Serialize the optimizer configuration to JSON compatible python dict. 

76 

77 The configuration can be used for persistence and reconstruct the 

78 `Optimizer` instance again. 

79 

80 >>> tf.keras.optimizers.serialize(tf.keras.optimizers.legacy.SGD()) 

81 {'module': 'keras.optimizers.legacy', 'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01, 'decay': 0.0, 'momentum': 0.0, 'nesterov': False}, 'registered_name': None}""" # noqa: E501 

82 """ 

83 Args: 

84 optimizer: An `Optimizer` instance to serialize. 

85 

86 Returns: 

87 Python dict which contains the configuration of the input optimizer. 

88 """ 

89 if use_legacy_format: 

90 return legacy_serialization.serialize_keras_object(optimizer) 

91 return serialize_keras_object(optimizer) 

92 

93 
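

# Illustrative usage sketch, not part of the original module: `serialize`
# emits the new serialization format by default and the flat legacy format
# when `use_legacy_format=True`. Wrapped in a private helper (a hypothetical
# name) so that nothing executes at import time.
def _example_serialize_formats():  # pragma: no cover
    opt = sgd.SGD(learning_rate=0.01)
    # New format: includes "module", "class_name", "config", "registered_name".
    new_format = serialize(opt)
    # Legacy format: only "class_name" and "config".
    legacy_format = serialize(opt, use_legacy_format=True)
    return new_format, legacy_format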


def is_arm_mac():
    return platform.system() == "Darwin" and platform.processor() == "arm"


@keras_export("keras.optimizers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False, **kwargs):
    """Inverse of the `serialize` function.

    Args:
        config: Optimizer configuration dictionary.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras Optimizer instance.
    """
    # loss_scale_optimizer has a direct dependency on optimizer; import it
    # here rather than at the top to avoid a cyclic dependency.
    from keras.src.mixed_precision import (
        loss_scale_optimizer,
    )

    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if len(config["config"]) > 0:
        # If the optimizer config is not empty, then we use the value of
        # `is_legacy_optimizer` to override `use_legacy_optimizer`. If
        # `is_legacy_optimizer` does not exist in the config, it means we are
        # using the legacy optimizer.
        use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True)
    if (
        tf.__internal__.tf2.enabled()
        and tf.executing_eagerly()
        and not is_arm_mac()
        and not use_legacy_optimizer
    ):
        # We observed a slowdown of the optimizer on M1 Macs, so we fall back
        # to the legacy optimizer for M1 users for now; see b/263339144 for
        # more context.
        all_classes = {
            "adadelta": adadelta.Adadelta,
            "adagrad": adagrad.Adagrad,
            "adam": adam.Adam,
            "adamax": adamax.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam.Nadam,
            "rmsprop": rmsprop.RMSprop,
            "sgd": sgd.SGD,
            "ftrl": ftrl.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizerV3,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer.
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }
    else:
        all_classes = {
            "adadelta": adadelta_legacy.Adadelta,
            "adagrad": adagrad_legacy.Adagrad,
            "adam": adam_legacy.Adam,
            "adamax": adamax_legacy.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam_legacy.Nadam,
            "rmsprop": rmsprop_legacy.RMSprop,
            "sgd": gradient_descent_legacy.SGD,
            "ftrl": ftrl_legacy.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizer,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer.
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }

    # Make deserialization case-insensitive for built-in optimizers.
    if config["class_name"].lower() in all_classes:
        config["class_name"] = config["class_name"].lower()

    if use_legacy_format:
        return legacy_serialization.deserialize_keras_object(
            config,
            module_objects=all_classes,
            custom_objects=custom_objects,
            printable_module_name="optimizer",
        )

    return deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name="optimizer",
    )
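

# Illustrative usage sketch, not part of the original module: a legacy-format
# config produced by `serialize` round-trips through `deserialize`. New-style
# optimizers follow the same path, but which class family is instantiated
# depends on the TF2/eager/ARM-Mac checks above. Wrapped in a private helper
# (a hypothetical name) so that nothing executes at import time.
def _example_deserialize_roundtrip():  # pragma: no cover
    original = Adam(learning_rate=1e-3)  # legacy Adam, re-exported above
    config = serialize(original, use_legacy_format=True)
    restored = deserialize(config, use_legacy_format=True)
    # A legacy-style config resolves back to the legacy class.
    assert isinstance(restored, adam_legacy.Adam)
    return restored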


@keras_export(
    "keras.__internal__.optimizers.convert_to_legacy_optimizer", v1=[]
)
def convert_to_legacy_optimizer(optimizer):
    """Convert an experimental (new-style) optimizer to a legacy optimizer.

    This function takes in a `keras.optimizers.Optimizer`
    instance and converts it to the corresponding
    `keras.optimizers.legacy.Optimizer` instance.
    For example, `keras.optimizers.Adam(...)` is converted to
    `keras.optimizers.legacy.Adam(...)`.

    Args:
        optimizer: An instance of `keras.optimizers.Optimizer`.
    """
    # loss_scale_optimizer has a direct dependency on optimizer; import it
    # here rather than at the top to avoid a cyclic dependency.
    from keras.src.mixed_precision import (
        loss_scale_optimizer,
    )

    if not isinstance(optimizer, base_optimizer.Optimizer):
        raise ValueError(
            "`convert_to_legacy_optimizer` should only be called "
            "on instances of `tf.keras.optimizers.Optimizer`, but "
            f"received {optimizer} of type {type(optimizer)}."
        )
    optimizer_name = optimizer.__class__.__name__.lower()
    config = optimizer.get_config()
    # Remove fields that only exist in the experimental optimizer.
    keys_to_remove = [
        "weight_decay",
        "use_ema",
        "ema_momentum",
        "ema_overwrite_frequency",
        "jit_compile",
        "is_legacy_optimizer",
    ]
    for key in keys_to_remove:
        config.pop(key, None)

    if isinstance(optimizer, loss_scale_optimizer.LossScaleOptimizerV3):
        # For LossScaleOptimizers, recursively convert the inner optimizer.
        config["inner_optimizer"] = convert_to_legacy_optimizer(
            optimizer.inner_optimizer
        )
        if optimizer_name == "lossscaleoptimizerv3":
            optimizer_name = "lossscaleoptimizer"

    # The learning rate can be a custom `LearningRateSchedule`, which is
    # stored as a dict in the config and cannot be deserialized from there,
    # so pass the schedule object through directly.
    if hasattr(optimizer, "_learning_rate") and isinstance(
        optimizer._learning_rate, learning_rate_schedule.LearningRateSchedule
    ):
        config["learning_rate"] = optimizer._learning_rate
    legacy_optimizer_config = {
        "class_name": optimizer_name,
        "config": config,
    }
    return deserialize(legacy_optimizer_config, use_legacy_optimizer=True)
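

# Illustrative usage sketch, not part of the original module: converting a
# new-style optimizer into its `keras.optimizers.legacy` counterpart, which is
# what `get` does automatically on M1 Macs or when TF2 is disabled. Wrapped in
# a private helper (a hypothetical name) so that nothing executes at import
# time.
def _example_convert_to_legacy():  # pragma: no cover
    new_adam = adam.Adam(learning_rate=1e-3)
    legacy_adam = convert_to_legacy_optimizer(new_adam)
    # The result is the legacy implementation of the same optimizer class.
    assert isinstance(legacy_adam, adam_legacy.Adam)
    return legacy_adam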


@keras_export("keras.optimizers.get")
def get(identifier, **kwargs):
    """Retrieves a Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - String: name of an optimizer.
            - Dictionary: configuration dictionary.
            - Keras Optimizer instance (it will be returned unchanged).
            - TensorFlow Optimizer instance (it will be wrapped as a Keras
              Optimizer).

    Returns:
        A Keras Optimizer instance.

    Raises:
        ValueError: If `identifier` cannot be interpreted.
    """
    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if isinstance(
        identifier,
        (
            Optimizer,
            base_optimizer_legacy.OptimizerV2,
        ),
    ):
        return identifier
    elif isinstance(identifier, base_optimizer.Optimizer):
        if tf.__internal__.tf2.enabled() and not is_arm_mac():
            return identifier
        else:
            # If TF2 is disabled or we are on an M1 Mac, convert to the legacy
            # optimizer. We observed a slowdown of the optimizer on M1 Macs,
            # so we fall back to the legacy optimizer for now; see b/263339144
            # for more context.
            optimizer_name = identifier.__class__.__name__
            logging.warning(
                "There is a known slowdown when using v2.11+ Keras optimizers "
                "on M1/M2 Macs. Falling back to the "
                "legacy Keras optimizer, i.e., "
                f"`tf.keras.optimizers.legacy.{optimizer_name}`."
            )
            return convert_to_legacy_optimizer(identifier)

    # Wrap legacy TF optimizer instances.
    elif isinstance(identifier, tf.compat.v1.train.Optimizer):
        opt = TFOptimizer(identifier)
        backend.track_tf_optimizer(opt)
        return opt
    elif isinstance(identifier, dict):
        use_legacy_format = "module" not in identifier
        return deserialize(
            identifier,
            use_legacy_optimizer=use_legacy_optimizer,
            use_legacy_format=use_legacy_format,
        )
    elif isinstance(identifier, str):
        config = {"class_name": str(identifier), "config": {}}
        return get(
            config,
            use_legacy_optimizer=use_legacy_optimizer,
        )
    else:
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )
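

# Illustrative usage sketch, not part of the original module: the common ways
# to resolve an optimizer through `get`. Wrapped in a private helper (a
# hypothetical name) so that nothing executes at import time.
def _example_get():  # pragma: no cover
    # From a string identifier (built-in names match case-insensitively).
    by_name = get("adam")
    # From a configuration dictionary; legacy-style configs resolve to the
    # legacy classes unless the config marks itself as a new optimizer.
    by_config = get(
        {"class_name": "rmsprop", "config": {"learning_rate": 0.01}}
    )
    # An existing Keras optimizer instance is returned as-is, except on M1
    # Macs or with TF2 disabled, where it is converted to a legacy optimizer.
    by_instance = get(sgd.SGD())
    # A raw TF1 optimizer is wrapped in a `TFOptimizer`.
    wrapped = get(tf.compat.v1.train.GradientDescentOptimizer(0.01))
    return by_name, by_config, by_instance, wrapped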