Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/optimizers/__init__.py: 47%
106 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in optimizer classes.

For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""

# Imports needed for deserialization.

import platform

import tensorflow.compat.v2 as tf
from absl import logging

from keras.src import backend
from keras.src.optimizers import adadelta
from keras.src.optimizers import adafactor
from keras.src.optimizers import adagrad
from keras.src.optimizers import adam
from keras.src.optimizers import adamax
from keras.src.optimizers import adamw
from keras.src.optimizers import ftrl
from keras.src.optimizers import lion
from keras.src.optimizers import nadam
from keras.src.optimizers import optimizer as base_optimizer
from keras.src.optimizers import rmsprop
from keras.src.optimizers import sgd
from keras.src.optimizers.legacy import adadelta as adadelta_legacy
from keras.src.optimizers.legacy import adagrad as adagrad_legacy
from keras.src.optimizers.legacy import adam as adam_legacy
from keras.src.optimizers.legacy import adamax as adamax_legacy
from keras.src.optimizers.legacy import ftrl as ftrl_legacy
from keras.src.optimizers.legacy import gradient_descent as gradient_descent_legacy
from keras.src.optimizers.legacy import nadam as nadam_legacy
from keras.src.optimizers.legacy import optimizer_v2 as base_optimizer_legacy
from keras.src.optimizers.legacy import rmsprop as rmsprop_legacy
from keras.src.optimizers.legacy.adadelta import Adadelta
from keras.src.optimizers.legacy.adagrad import Adagrad
from keras.src.optimizers.legacy.adam import Adam
from keras.src.optimizers.legacy.adamax import Adamax
from keras.src.optimizers.legacy.ftrl import Ftrl

# Symbols to be accessed under keras.optimizers. To be replaced with
# optimizers v2022 when they graduate out of experimental.
from keras.src.optimizers.legacy.gradient_descent import SGD
from keras.src.optimizers.legacy.nadam import Nadam
from keras.src.optimizers.legacy.rmsprop import RMSprop
from keras.src.optimizers.optimizer_v1 import Optimizer
from keras.src.optimizers.optimizer_v1 import TFOptimizer
from keras.src.optimizers.schedules import learning_rate_schedule
from keras.src.saving.legacy import serialization as legacy_serialization
from keras.src.saving.serialization_lib import deserialize_keras_object
from keras.src.saving.serialization_lib import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export

# pylint: disable=line-too-long


@keras_export("keras.optimizers.serialize")
def serialize(optimizer, use_legacy_format=False):
    """Serialize the optimizer configuration to a JSON-compatible Python dict.

    The configuration can be used for persistence and to reconstruct the
    `Optimizer` instance again.

    >>> tf.keras.optimizers.serialize(tf.keras.optimizers.legacy.SGD())
    {'module': 'keras.optimizers.legacy', 'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01, 'decay': 0.0, 'momentum': 0.0, 'nesterov': False}, 'registered_name': None}"""  # noqa: E501
    """
    Args:
        optimizer: An `Optimizer` instance to serialize.

    Returns:
        Python dict which contains the configuration of the input optimizer.
    """
    if use_legacy_format:
        return legacy_serialization.serialize_keras_object(optimizer)
    return serialize_keras_object(optimizer)
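

# Illustrative usage sketch, not part of the original module: the dict returned
# by `serialize` is JSON-compatible and, as the docstring notes, can be passed
# back to `deserialize` to reconstruct an equivalent optimizer. The helper name
# `_serialize_roundtrip_example` is hypothetical and exists only to demonstrate
# the intended usage.
def _serialize_roundtrip_example():
    original = SGD(learning_rate=0.01)  # legacy SGD imported above
    config = serialize(original)
    restored = deserialize(config)
    return config, restored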


def is_arm_mac():
    return platform.system() == "Darwin" and platform.processor() == "arm"


@keras_export("keras.optimizers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False, **kwargs):
    """Inverse of the `serialize` function.

    Args:
        config: Optimizer configuration dictionary.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.
        use_legacy_format: Boolean, whether the config is in the legacy Keras
            serialization format.

    Returns:
        A Keras Optimizer instance.
    """
    # loss_scale_optimizer has a direct dependency on optimizer; import it here
    # rather than at the top to avoid the cyclic dependency.
    from keras.src.mixed_precision import (
        loss_scale_optimizer,
    )

    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if len(config["config"]) > 0:
        # If the optimizer config is not empty, then we use the value of
        # `is_legacy_optimizer` to override `use_legacy_optimizer`. If
        # `is_legacy_optimizer` does not exist in config, it means we are
        # using the legacy optimizer.
        use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True)
    if (
        tf.__internal__.tf2.enabled()
        and tf.executing_eagerly()
        and not is_arm_mac()
        and not use_legacy_optimizer
    ):
        # We observed a slowdown of optimizer on M1 Mac, so we fall back to the
        # legacy optimizer for M1 users now, see b/263339144 for more context.
        all_classes = {
            "adadelta": adadelta.Adadelta,
            "adagrad": adagrad.Adagrad,
            "adam": adam.Adam,
            "adamax": adamax.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam.Nadam,
            "rmsprop": rmsprop.RMSprop,
            "sgd": sgd.SGD,
            "ftrl": ftrl.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizerV3,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer.
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }
    else:
        all_classes = {
            "adadelta": adadelta_legacy.Adadelta,
            "adagrad": adagrad_legacy.Adagrad,
            "adam": adam_legacy.Adam,
            "adamax": adamax_legacy.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam_legacy.Nadam,
            "rmsprop": rmsprop_legacy.RMSprop,
            "sgd": gradient_descent_legacy.SGD,
            "ftrl": ftrl_legacy.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizer,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer.
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }

    # Make deserialization case-insensitive for built-in optimizers.
    if config["class_name"].lower() in all_classes:
        config["class_name"] = config["class_name"].lower()

    if use_legacy_format:
        return legacy_serialization.deserialize_keras_object(
            config,
            module_objects=all_classes,
            custom_objects=custom_objects,
            printable_module_name="optimizer",
        )

    return deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name="optimizer",
    )
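

# Illustrative usage sketch, not part of the original module: `deserialize`
# also accepts a plain legacy-format dict with only "class_name" and "config".
# The class-name lookup is case-insensitive for built-in optimizers, and
# `use_legacy_optimizer=True` selects the `keras.optimizers.legacy` classes.
# The helper name `_deserialize_by_name_example` is hypothetical.
def _deserialize_by_name_example():
    config = {"class_name": "Adam", "config": {}}
    return deserialize(
        config, use_legacy_format=True, use_legacy_optimizer=True
    )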


@keras_export(
    "keras.__internal__.optimizers.convert_to_legacy_optimizer", v1=[]
)
def convert_to_legacy_optimizer(optimizer):
    """Convert an experimental optimizer to a legacy optimizer.

    This function takes in a `keras.optimizers.Optimizer`
    instance and converts it to the corresponding
    `keras.optimizers.legacy.Optimizer` instance.
    For example, `keras.optimizers.Adam(...)` to
    `keras.optimizers.legacy.Adam(...)`.

    Args:
        optimizer: An instance of `keras.optimizers.Optimizer`.

    Returns:
        The corresponding `keras.optimizers.legacy.Optimizer` instance.
    """
    # loss_scale_optimizer has a direct dependency on optimizer; import it here
    # rather than at the top to avoid the cyclic dependency.
    from keras.src.mixed_precision import (
        loss_scale_optimizer,
    )

    if not isinstance(optimizer, base_optimizer.Optimizer):
        raise ValueError(
            "`convert_to_legacy_optimizer` should only be called "
            "on instances of `tf.keras.optimizers.Optimizer`, but "
            f"received {optimizer} of type {type(optimizer)}."
        )
    optimizer_name = optimizer.__class__.__name__.lower()
    config = optimizer.get_config()
    # Remove fields that only exist in the experimental optimizer.
    keys_to_remove = [
        "weight_decay",
        "use_ema",
        "ema_momentum",
        "ema_overwrite_frequency",
        "jit_compile",
        "is_legacy_optimizer",
    ]
    for key in keys_to_remove:
        config.pop(key, None)

    if isinstance(optimizer, loss_scale_optimizer.LossScaleOptimizerV3):
        # For LossScaleOptimizers, recursively convert the inner optimizer.
        config["inner_optimizer"] = convert_to_legacy_optimizer(
            optimizer.inner_optimizer
        )
        if optimizer_name == "lossscaleoptimizerv3":
            optimizer_name = "lossscaleoptimizer"

    # Learning rate can be a custom LearningRateSchedule, which is stored as
    # a dict in config, and cannot be deserialized.
    if hasattr(optimizer, "_learning_rate") and isinstance(
        optimizer._learning_rate, learning_rate_schedule.LearningRateSchedule
    ):
        config["learning_rate"] = optimizer._learning_rate
    legacy_optimizer_config = {
        "class_name": optimizer_name,
        "config": config,
    }
    return deserialize(legacy_optimizer_config, use_legacy_optimizer=True)
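

# Illustrative usage sketch, not part of the original module: converting a
# v2.11+ (experimental) optimizer into its `keras.optimizers.legacy`
# counterpart with an equivalent configuration. The helper name
# `_convert_to_legacy_example` is hypothetical.
def _convert_to_legacy_example():
    experimental_adam = adam.Adam(learning_rate=1e-3)
    # Returns an instance of the legacy Adam optimizer.
    return convert_to_legacy_optimizer(experimental_adam)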


@keras_export("keras.optimizers.get")
def get(identifier, **kwargs):
    """Retrieves a Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - String: name of an optimizer.
            - Dictionary: configuration dictionary.
            - Keras Optimizer instance (it will be returned unchanged).
            - TensorFlow Optimizer instance (it will be wrapped as a Keras
              Optimizer).

    Returns:
        A Keras Optimizer instance.

    Raises:
        ValueError: If `identifier` cannot be interpreted.
    """
    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if isinstance(
        identifier,
        (
            Optimizer,
            base_optimizer_legacy.OptimizerV2,
        ),
    ):
        return identifier
    elif isinstance(identifier, base_optimizer.Optimizer):
        if tf.__internal__.tf2.enabled() and not is_arm_mac():
            return identifier
        else:
            # If TF2 is disabled or we are on an M1 Mac, convert to the legacy
            # optimizer. We observed a slowdown of optimizer on M1 Mac, so we
            # fall back to the legacy optimizer for now, see b/263339144
            # for more context.
            optimizer_name = identifier.__class__.__name__
            logging.warning(
                "There is a known slowdown when using v2.11+ Keras optimizers "
                "on M1/M2 Macs. Falling back to the "
                "legacy Keras optimizer, i.e., "
                f"`tf.keras.optimizers.legacy.{optimizer_name}`."
            )
            return convert_to_legacy_optimizer(identifier)
    # Wrap legacy TF optimizer instances.
    elif isinstance(identifier, tf.compat.v1.train.Optimizer):
        opt = TFOptimizer(identifier)
        backend.track_tf_optimizer(opt)
        return opt
    elif isinstance(identifier, dict):
        use_legacy_format = "module" not in identifier
        return deserialize(
            identifier,
            use_legacy_optimizer=use_legacy_optimizer,
            use_legacy_format=use_legacy_format,
        )
    elif isinstance(identifier, str):
        config = {"class_name": str(identifier), "config": {}}
        return get(
            config,
            use_legacy_optimizer=use_legacy_optimizer,
        )
    else:
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )
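

# Illustrative usage sketch, not part of the original module: the identifier
# forms accepted by `get`. The helper name `_get_examples` is hypothetical.
def _get_examples():
    by_name = get("rmsprop")  # string identifier, case-insensitive
    by_config = get({"class_name": "sgd", "config": {}})  # config dict
    # Existing optimizer instances are returned as-is (or converted to the
    # legacy optimizer on M1/M2 Macs).
    passthrough = get(by_name)
    return by_name, by_config, passthrough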