# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Proximal Adagrad optimizer."""

from typing import Callable, Union

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils.types import FloatTensorLike


@tf.keras.utils.register_keras_serializable(package="Addons")
class ProximalAdagrad(KerasLegacyOptimizer):
    """Optimizer that implements the Proximal Adagrad algorithm.

    References:
        - [Efficient Learning using Forward-Backward Splitting](
          http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf).
    """

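    # A sketch of the update this optimizer performs, following TensorFlow's
    # documented semantics for the proximal Adagrad raw ops (elementwise):
    #
    #     accum += grad * grad
    #     prox_v = var - lr * grad / sqrt(accum)
    #     var = sign(prox_v) / (1 + lr * l2) * max(|prox_v| - lr * l1, 0)
    #
    # That is, an Adagrad step followed by the proximal operator of the
    # l1/l2 regularizer, which shrinks weights toward zero and can drive
    # them exactly to zero when l1 > 0. A runnable usage sketch is at the
    # bottom of this file.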
    @typechecked
    def __init__(
        self,
        learning_rate: Union[FloatTensorLike, Callable] = 0.001,
        initial_accumulator_value: float = 0.1,
        l1_regularization_strength: float = 0.0,
        l2_regularization_strength: float = 0.0,
        name: str = "ProximalAdagrad",
        **kwargs,
    ):
        """Construct a new Proximal Adagrad optimizer.

        Args:
            learning_rate: A Tensor or a floating point value, or a schedule
                that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                The learning rate.
            initial_accumulator_value: A floating point value.
                Starting value for the accumulators; must be non-negative.
            l1_regularization_strength: A floating point value.
                The l1 regularization term; must be greater than or
                equal to zero.
            l2_regularization_strength: A floating point value.
                The l2 regularization term; must be greater than or
                equal to zero.
            name: Optional name for the operations created when applying
                gradients. Defaults to "ProximalAdagrad".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients
                by norm; `clipvalue` clips gradients by value; `decay` is
                included for backward compatibility to allow time inverse
                decay of the learning rate; `lr` is included for backward
                compatibility and it is recommended to use `learning_rate`
                instead.

        Raises:
            ValueError: If `initial_accumulator_value`,
                `l1_regularization_strength`, or
                `l2_regularization_strength` is negative.
        """
        if initial_accumulator_value < 0.0:
            raise ValueError("`initial_accumulator_value` must be non-negative.")
        if l1_regularization_strength < 0.0:
            raise ValueError("`l1_regularization_strength` must be non-negative.")
        if l2_regularization_strength < 0.0:
            raise ValueError("`l2_regularization_strength` must be non-negative.")
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("l1_regularization_strength", l1_regularization_strength)
        self._set_hyper("l2_regularization_strength", l2_regularization_strength)
        self._initial_accumulator_value = initial_accumulator_value

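    # One "accumulator" slot per variable holds Adagrad's running sum of
    # squared gradients, seeded with `initial_accumulator_value`.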
    def _create_slots(self, var_list):
        for var in var_list:
            init = tf.keras.initializers.constant(self._initial_accumulator_value)
            self.add_slot(var, "accumulator", init)

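    # Dense update: look up the per-(device, dtype) coefficients prepared in
    # `_prepare_local` (falling back to `_fallback_apply_state` when absent)
    # and delegate to the fused proximal-Adagrad kernel.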
    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        acc = self.get_slot(var, "accumulator")
        return tf.raw_ops.ResourceApplyProximalAdagrad(
            var=var.handle,
            accum=acc.handle,
            lr=coefficients["lr_t"],
            l1=coefficients["l1_regularization_strength"],
            l2=coefficients["l2_regularization_strength"],
            grad=grad,
            use_locking=self._use_locking,
        )

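    # Cache dtype-cast tensor copies of the regularization strengths next to
    # the `lr_t` that the base class already prepares, keyed by
    # (device, dtype), so the apply ops above read ready-made tensors.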
    def _prepare_local(self, var_device, var_dtype, apply_state):
        super()._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)].update(
            {
                "l1_regularization_strength": tf.identity(
                    self._get_hyper("l1_regularization_strength", var_dtype)
                ),
                "l2_regularization_strength": tf.identity(
                    self._get_hyper("l2_regularization_strength", var_dtype)
                ),
            }
        )

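    # Sparse update: same rule as the dense path, applied only to the rows
    # of `var` selected by `indices` (e.g. embedding lookups).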
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = (apply_state or {}).get(
            (var_device, var_dtype)
        ) or self._fallback_apply_state(var_device, var_dtype)

        acc = self.get_slot(var, "accumulator")
        return tf.raw_ops.ResourceSparseApplyProximalAdagrad(
            var=var.handle,
            accum=acc.handle,
            lr=coefficients["lr_t"],
            l1=coefficients["l1_regularization_strength"],
            l2=coefficients["l2_regularization_strength"],
            grad=grad,
            indices=indices,
            use_locking=self._use_locking,
        )

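    # Serialization hook: together with the `register_keras_serializable`
    # decorator above, this lets the optimizer be rebuilt via `from_config`
    # and round-trip through Keras model saving and loading.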
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter("learning_rate"),
                "initial_accumulator_value": self._initial_accumulator_value,
                "l1_regularization_strength": self._serialize_hyperparameter(
                    "l1_regularization_strength"
                ),
                "l2_regularization_strength": self._serialize_hyperparameter(
                    "l2_regularization_strength"
                ),
            }
        )
        return config
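

# A minimal, hypothetical usage sketch (not part of the original module):
# minimizing ||var||^2 with an l1 penalty should shrink the weights toward
# zero. The toy variable and loop below are illustrative assumptions; run
# this file directly to execute them.
if __name__ == "__main__":
    var = tf.Variable([1.0, -2.0])
    opt = ProximalAdagrad(learning_rate=0.1, l1_regularization_strength=0.01)
    for _ in range(10):
        with tf.GradientTape() as tape:
            loss = tf.reduce_sum(tf.square(var))
        grads = tape.gradient(loss, [var])
        opt.apply_gradients(zip(grads, [var]))
    print(var.numpy())  # entries move toward zero

    # The config round-trips, so the optimizer can be rebuilt from it.
    restored = ProximalAdagrad.from_config(opt.get_config())
    print(restored.get_config()["l1_regularization_strength"])  # 0.01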