Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizers.py: 56%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=invalid-name 

16"""Built-in optimizer classes. 

17 

18For more examples see the base class `tf.keras.optimizers.Optimizer`. 

19""" 

20 

21from tensorflow.python.keras import backend 

22from tensorflow.python.keras.optimizer_v1 import Optimizer 

23from tensorflow.python.keras.optimizer_v1 import TFOptimizer 

24from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2 

25from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2 

26from tensorflow.python.keras.optimizer_v2 import adam as adam_v2 

27from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 

28from tensorflow.python.keras.optimizer_v2 import ftrl 

29from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 

30from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 

31from tensorflow.python.keras.optimizer_v2 import optimizer_v2 

32from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 

33from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 

34from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 

35from tensorflow.python.training import optimizer as tf_optimizer_module 

36from tensorflow.python.util.tf_export import keras_export 

37 

38 

@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  """Converts an optimizer into a JSON-compatible configuration dict.

  The returned dict captures the optimizer's class name and config, so the
  same `Optimizer` can later be rebuilt with `deserialize`.

  >>> tf.keras.optimizers.serialize(tf.keras.optimizers.SGD())
  {'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01,
  'decay': 0.0, 'momentum': 0.0,
  'nesterov': False}}

  Args:
    optimizer: The `Optimizer` instance to be serialized.

  Returns:
    A Python dict describing the configuration of `optimizer`.
  """
  # All of the work is delegated to the generic Keras object serializer.
  serialized_config = serialize_keras_object(optimizer)
  return serialized_config

58 

59 

@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Built-in optimizer class names are matched case-insensitively (for example
  `'Adam'`, `'adam'` and `'ADAM'` all resolve to `Adam`). The caller's
  `config` dict is never modified; a shallow copy is used when the class name
  has to be normalized.

  Args:
    config: Optimizer configuration dictionary, e.g. as produced by
      `serialize`. Must contain a `'class_name'` entry.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras Optimizer instance.

  Raises:
    KeyError: If `config` has no `'class_name'` entry.
  """
  # loss_scale_optimizer has a direct dependency of optimizer, import here
  # rather than top to avoid the cyclic dependency.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers. Rebind
  # `config` to a shallow copy instead of assigning into it, so the caller's
  # dict does not come back with a rewritten 'class_name' as a side effect.
  lowered_name = config['class_name'].lower()
  if lowered_name in all_classes:
    config = dict(config, class_name=lowered_name)
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')

99 

100 

@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Args:
    identifier: Optimizer identifier, one of
      - String: name of an optimizer (e.g. `'adam'`).
      - Dictionary: configuration dictionary, e.g. as produced by
        `serialize`. The caller's dict is not modified.
      - Keras Optimizer instance (it will be returned unchanged).
      - TensorFlow Optimizer instance (it will be wrapped as a Keras
        Optimizer).

  Returns:
    A Keras Optimizer instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier
  if isinstance(identifier, tf_optimizer_module.Optimizer):
    # Wrap legacy TF optimizer instances and hand the wrapper to
    # `backend.track_tf_optimizer`.
    opt = TFOptimizer(identifier)
    backend.track_tf_optimizer(opt)
    return opt
  if isinstance(identifier, dict):
    # Pass a shallow copy: `deserialize` may normalize 'class_name', and that
    # must not leak back into the dict the caller handed us.
    return deserialize(dict(identifier))
  if isinstance(identifier, str):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  raise ValueError(
      'Could not interpret optimizer identifier: {}'.format(identifier))