Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/cocob.py: 25%

48 statements  


# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""COntinuous COin Betting (COCOB) Backprop optimizer."""

from typeguard import typechecked
import tensorflow as tf

from tensorflow_addons.optimizers import KerasLegacyOptimizer


@tf.keras.utils.register_keras_serializable(package="Addons")
class COCOB(KerasLegacyOptimizer):
    """Optimizer that implements the COCOB Backprop algorithm.

    Reference:
        - [COntinuous COin Betting (COCOB) Backprop optimizer](https://arxiv.org/abs/1705.07795)
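    Usage (illustrative sketch, not from the original source; assumes an
    existing `tf.keras` model bound to the name `model`):

        opt = COCOB(alpha=100.0)
        model.compile(optimizer=opt, loss="mse")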

    """

    @typechecked
    def __init__(
        self,
        alpha: float = 100,
        use_locking: bool = False,
        name: str = "COCOB",
        **kwargs,
    ):
        """Constructs a new COCOB-Backprop optimizer.

        Arguments:
            `alpha`: Default value is set to 100, as per the paper.
                This has the effect of restricting the value of the
                parameters in the first iterations of the algorithm.
                (Refer to the paper for an in-depth discussion.)

        Raises:
            `ValueError`: If the value of `alpha` is less than 1.
            `NotImplementedError`: If the data is in sparse format.
        """
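        # Illustrative example (not from the original source): COCOB(alpha=0.5)
        # fails the check below and raises the ValueError documented above.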

        if alpha < 1:
            raise ValueError("`alpha` must be greater than or equal to 1")

        super().__init__(name, **kwargs)
        self._set_hyper("alpha", alpha)
        self._alpha = alpha

    def _create_slots(self, var_list):
        for v in var_list:
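            # Per-variable state used by _resource_apply_dense (descriptive
            # summary of the slots created below):
            #   "lr":            running maximum of |grad| (initialized to 1e-8)
            #   "grad_norm_sum": running sum of |grad|
            #   "gradients_sum": running sum of grad
            #   "tilde_w":       offset added to the variable at the previous step
            #   "reward":        accumulated non-negative reward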

            self.add_slot(v, "lr", initializer=tf.keras.initializers.Constant(1e-8))
            self.add_slot(v, "grad_norm_sum")
            self.add_slot(v, "gradients_sum")
            self.add_slot(v, "tilde_w")
            self.add_slot(v, "reward")

    def _resource_apply_dense(self, grad, handle, apply_state=None):
        gradients_sum = self.get_slot(handle, "gradients_sum")
        grad_norm_sum = self.get_slot(handle, "grad_norm_sum")
        tilde_w = self.get_slot(handle, "tilde_w")
        lr = self.get_slot(handle, "lr")
        reward = self.get_slot(handle, "reward")
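        # One COCOB-Backprop step, as implemented below (informal shorthand for
        # the paper's per-coordinate quantities; the mapping is a reading of
        # this code, not taken verbatim from the reference):
        #   L_t     = max(L_{t-1}, |g_t|)                  -> lr_update
        #   theta_t = theta_{t-1} + g_t                    -> gradients_sum_update
        #   G_t     = G_{t-1} + |g_t|                      -> grad_norm_sum_update
        #   R_t     = max(R_{t-1} - g_t * w~_{t-1}, 0)     -> reward_update
        #   w~_t    = -theta_t * (R_t + L_t) / (L_t * max(G_t + L_t, alpha * L_t))
        #   w_t     = w_{t-1} - w~_{t-1} + w~_t            -> var_update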

        lr_update = tf.maximum(lr, tf.abs(grad))
        gradients_sum_update = gradients_sum + grad
        grad_norm_sum_update = grad_norm_sum + tf.abs(grad)
        reward_update = tf.maximum(reward - grad * tilde_w, 0)

        grad_max = tf.maximum(grad_norm_sum_update + lr_update, self._alpha * lr_update)
        rewards_lr_sum = reward_update + lr_update
        new_w = -gradients_sum_update / (lr_update * (grad_max)) * rewards_lr_sum

        var_update = handle - tilde_w + new_w
        tilde_w_update = new_w

        gradients_sum_update_op = gradients_sum.assign(gradients_sum_update)
        grad_norm_sum_update_op = grad_norm_sum.assign(grad_norm_sum_update)
        var_update_op = handle.assign(var_update)
        tilde_w_update_op = tilde_w.assign(tilde_w_update)
        lr_update_op = lr.assign(lr_update)
        reward_update_op = reward.assign(reward_update)

        return tf.group(
            *[
                gradients_sum_update_op,
                var_update_op,
                grad_norm_sum_update_op,
                tilde_w_update_op,
                reward_update_op,
                lr_update_op,
            ]
        )

    def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
        raise NotImplementedError()

    def get_config(self):
        config = {
            "alpha": self._serialize_hyperparameter("alpha"),
        }
        base_config = super().get_config()
        return {**base_config, **config}
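
# Serialization sketch (illustrative, not part of the original module): since
# "alpha" is registered as a hyperparameter, it round-trips through the config.
#
#     opt = COCOB(alpha=150.0)
#     cfg = opt.get_config()
#     assert cfg["alpha"] == 150.0
#     restored = COCOB.from_config(cfg)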