Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizers.py: 56%
41 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Built-in optimizer classes.
18For more examples see the base class `tf.keras.optimizers.Optimizer`.
19"""
21from tensorflow.python.keras import backend
22from tensorflow.python.keras.optimizer_v1 import Optimizer
23from tensorflow.python.keras.optimizer_v1 import TFOptimizer
24from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
25from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
26from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
27from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
28from tensorflow.python.keras.optimizer_v2 import ftrl
29from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
30from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
31from tensorflow.python.keras.optimizer_v2 import optimizer_v2
32from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
33from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
34from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
35from tensorflow.python.training import optimizer as tf_optimizer_module
36from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  """Convert an optimizer's configuration into a JSON-compatible Python dict.

  The resulting dict can be persisted and later used to reconstruct the
  `Optimizer` instance.

  >>> tf.keras.optimizers.serialize(tf.keras.optimizers.SGD())
  {'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01,
                                   'decay': 0.0, 'momentum': 0.0,
                                   'nesterov': False}}

  Args:
    optimizer: The `Optimizer` instance to serialize.

  Returns:
    A Python dict holding the configuration of `optimizer`.
  """
  serialized_config = serialize_keras_object(optimizer)
  return serialized_config
@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Class names of the built-in optimizers are matched case-insensitively. The
  `config` dictionary supplied by the caller is never modified.

  Args:
    config: Optimizer configuration dictionary. It must contain a
      `'class_name'` entry holding a string.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras Optimizer instance.
  """
  # loss_scale_optimizer has a direct dependency on this module, so import it
  # here rather than at the top of the file to avoid a cyclic dependency.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers. The
  # normalized name goes into a shallow copy so that the caller's dictionary
  # is left untouched.
  lowered_name = config['class_name'].lower()
  if lowered_name in all_classes:
    config = dict(config, class_name=lowered_name)
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Args:
    identifier: Optimizer identifier, one of
      - String: name of an optimizer.
      - Dictionary: configuration dictionary.
      - Keras Optimizer instance (it will be returned unchanged).
      - TensorFlow Optimizer instance (it will be wrapped as a Keras
        Optimizer).

  Returns:
    A Keras Optimizer instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  # Keras optimizers are handed back as-is.
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier

  # Legacy TF optimizer instances get wrapped and registered with the backend.
  if isinstance(identifier, tf_optimizer_module.Optimizer):
    wrapped = TFOptimizer(identifier)
    backend.track_tf_optimizer(wrapped)
    return wrapped

  # A dict is treated as a complete serialized configuration.
  if isinstance(identifier, dict):
    return deserialize(identifier)

  # A string names the optimizer class; it is built with an empty config.
  if isinstance(identifier, str):
    return deserialize({'class_name': str(identifier), 'config': {}})

  raise ValueError(
      'Could not interpret optimizer identifier: {}'.format(identifier))