Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/average_wrapper.py: 38% (88 statements)


# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import abc
import warnings

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils import types
from typeguard import typechecked


class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        name: str = "AverageOptimizer",
        **kwargs,
    ):
        super().__init__(name, **kwargs)

        if isinstance(optimizer, str):
            # Resolve string identifiers, requesting a legacy optimizer when
            # this version of Keras exposes the legacy namespace.
            if (
                hasattr(tf.keras.optimizers, "legacy")
                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
            ):
                optimizer = tf.keras.optimizers.get(
                    optimizer, use_legacy_optimizer=True
                )
            else:
                optimizer = tf.keras.optimizers.get(optimizer)

        if not isinstance(optimizer, KerasLegacyOptimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.legacy.Optimizer"
            )

        self._optimizer = optimizer
        self._track_trackable(self._optimizer, "awg_optimizer")

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
        for var in var_list:
            self.add_slot(var, "average")

    def _create_hypers(self):
        self._optimizer._create_hypers()

    def _prepare_local(self, var_device, var_dtype, apply_state):
        return self._optimizer._prepare_local(var_device, var_dtype, apply_state)

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        # Keep the wrapped optimizer's step counter in sync with the wrapper's.
        self._optimizer._iterations = self.iterations
        return super().apply_gradients(grads_and_vars, name, **kwargs)

    @abc.abstractmethod
    def average_op(self, var, average_var, local_apply_state):
        raise NotImplementedError

    def _apply_average_op(self, train_op, var, apply_state):
        apply_state = apply_state or {}
        local_apply_state = apply_state.get((var.device, var.dtype.base_dtype))
        if local_apply_state is None:
            local_apply_state = self._fallback_apply_state(
                var.device, var.dtype.base_dtype
            )
        average_var = self.get_slot(var, "average")
        return self.average_op(var, average_var, local_apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        # Run the wrapped optimizer's update, then update the "average" slot.
        if "apply_state" in self._optimizer._dense_apply_args:
            train_op = self._optimizer._resource_apply_dense(
                grad, var, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_dense(grad, var)
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse(
                grad, var, indices, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_sparse(grad, var, indices)
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def _resource_apply_sparse_duplicate_indices(
        self, grad, var, indices, apply_state=None
    ):
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse_duplicate_indices(
                grad, var, indices, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_sparse_duplicate_indices(
                grad, var, indices
            )
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def assign_average_vars(self, var_list):
        """Assign variables in var_list with their respective averages.

        Args:
            var_list: List of model variables to be assigned to their average.

        Returns:
            assign_op: The op corresponding to the assignment operation of
                variables to their average.

        Example:
        ```python
        model = tf.keras.Sequential([...])
        opt = tfa.optimizers.SWA(
            tf.keras.optimizers.SGD(lr=2.0), 100, 10)
        model.compile(opt, ...)
        model.fit(x, y, ...)

        # Update the weights to their mean before saving
        opt.assign_average_vars(model.variables)

        model.save('model.h5')
        ```
        """
        assign_ops = []
        for var in var_list:
            try:
                assign_ops.append(
                    var.assign(
                        self.get_slot(var, "average"),
                        use_locking=self._use_locking,
                    )
                )
            except Exception as e:
                warnings.warn("Unable to assign average slot to {} : {}".format(var, e))
        return tf.group(assign_ops)

    def get_config(self):
        config = {
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)

    @property
    def weights(self):
        return self._weights + self._optimizer.weights

    # `lr` is kept as an alias of `learning_rate` for backwards compatibility;
    # both delegate to the wrapped optimizer's hyperparameter.
    @property
    def lr(self):
        return self._optimizer._get_hyper("learning_rate")

    @lr.setter
    def lr(self, lr):
        self._optimizer._set_hyper("learning_rate", lr)

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)
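

For reference, here is a minimal sketch of a concrete subclass, assuming the legacy Keras optimizer API used above. The class name `_ExampleMovingAverage` and its `decay` hyperparameter are illustrative only and are not part of this module; tensorflow-addons ships its own concrete wrappers, `tfa.optimizers.MovingAverage` and `tfa.optimizers.SWA`.

# Illustrative sketch only -- not part of average_wrapper.py.
class _ExampleMovingAverage(AveragedOptimizerWrapper):
    """Hypothetical subclass keeping an exponential moving average of variables."""

    def __init__(self, optimizer, decay: float = 0.99, name: str = "ExampleMovingAverage", **kwargs):
        super().__init__(optimizer, name=name, **kwargs)
        self._set_hyper("decay", decay)

    def average_op(self, var, average_var, local_apply_state):
        # average <- decay * average + (1 - decay) * var
        decay = self._get_hyper("decay", var.dtype.base_dtype)
        return average_var.assign(
            decay * average_var + (1.0 - decay) * var,
            use_locking=self._use_locking,
        )

    def get_config(self):
        config = {"decay": self._serialize_hyperparameter("decay")}
        return {**super().get_config(), **config}


# Hypothetical usage (model, data, and hyperparameters are placeholders), on a
# TensorFlow build that exposes tf.keras.optimizers.legacy:
#
#     opt = _ExampleMovingAverage(tf.keras.optimizers.legacy.SGD(0.01), decay=0.95)
#     model.compile(opt, loss="mse")
#     model.fit(x, y, epochs=5)
#     opt.assign_average_vars(model.variables)  # copy averages into the model
#     model.save("model.h5")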