Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/models/sharpness_aware_minimization.py: 25%

79 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sharpness Aware Minimization implementation."""

import copy

import tensorflow.compat.v2 as tf

from keras.src.engine import data_adapter
from keras.src.layers import deserialize as deserialize_layer
from keras.src.models import Model
from keras.src.saving.object_registration import register_keras_serializable
from keras.src.saving.serialization_lib import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export("keras.models.experimental.SharpnessAwareMinimization", v1=[])
class SharpnessAwareMinimization(Model):
    """Sharpness aware minimization (SAM) training flow.

    Sharpness-aware minimization (SAM) is a technique that improves model
    generalization and provides robustness to label noise. Splitting each
    data batch into smaller mini batches has been shown to improve SAM's
    performance, so users can control how batches are split by setting the
    `num_batch_splits` argument.

    Args:
        model: `tf.keras.Model` instance. The inner model that does the
            forward-backward pass.
        rho: float, defaults to 0.05. The gradient scaling factor (the size
            of the weight perturbation).
        num_batch_splits: int, defaults to None. The number of mini batches
            into which each data batch is split. If None, batches are not
            split into sub-batches.
        name: string, defaults to None. The name of the SAM model.

    Reference:
        [Pierre Foret et al., 2020](https://arxiv.org/abs/2010.01412)
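
    Example:

    A minimal usage sketch; the toy model, the `adam`/`mse` compile settings,
    and the random data below are illustrative placeholders, not part of this
    API:

    ```python
    import numpy as np
    import tensorflow as tf

    inner_model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, activation="relu"),
            tf.keras.layers.Dense(1),
        ]
    )
    sam_model = tf.keras.models.experimental.SharpnessAwareMinimization(
        inner_model, rho=0.05
    )
    sam_model.compile(optimizer="adam", loss="mse", metrics=["mae"])

    x = np.random.random((32, 4)).astype("float32")
    y = np.random.random((32, 1)).astype("float32")
    sam_model.fit(x, y, batch_size=8, epochs=1)
    ```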

    """

    def __init__(self, model, rho=0.05, num_batch_splits=None, name=None):
        super().__init__(name=name)
        self.model = model
        self.rho = rho
        self.num_batch_splits = num_batch_splits

    def train_step(self, data):
        """The logic of one SAM training step.

        Args:
            data: A nested structure of `Tensor`s. It should have the
                structure `(x, y, sample_weight)` or `(x, y)`.

        Returns:
            A dict mapping metric names to running average values.
        """
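        # SAM, as implemented below, takes two passes over each (sub-)batch:
        # (1) compute gradients `g` at the current weights, (2) perturb every
        # trainable variable by epsilon_w = rho * g / (||g||_2 + 1e-12) (the
        # "ascent" step), (3) recompute gradients at the perturbed weights,
        # (4) undo the perturbation, and (5) apply the recomputed gradients
        # (summed over sub-batches) with the wrapped optimizer.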

        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
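        # When `num_batch_splits` is set, the batch is divided into that many
        # sub-batches below; `tf.split` requires the batch dimension to be
        # evenly divisible by `num_batch_splits`.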

        if self.num_batch_splits is not None:
            x_split = tf.split(x, self.num_batch_splits)
            y_split = tf.split(y, self.num_batch_splits)
        else:
            x_split = [x]
            y_split = [y]

        gradients_all_batches = []
        pred_all_batches = []
        for x_batch, y_batch in zip(x_split, y_split):
            epsilon_w_cache = []
            with tf.GradientTape() as tape:
                pred = self.model(x_batch)
                loss = self.compiled_loss(y_batch, pred)
            pred_all_batches.append(pred)
            trainable_variables = self.model.trainable_variables
            gradients = tape.gradient(loss, trainable_variables)

            gradients_order2_norm = self._gradients_order2_norm(gradients)
            scale = self.rho / (gradients_order2_norm + 1e-12)
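            # Multiplying each gradient by `scale` yields the per-variable
            # perturbation epsilon_w; the 1e-12 term only guards against
            # division by zero when all gradients vanish.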

            for gradient, variable in zip(gradients, trainable_variables):
                epsilon_w = gradient * scale
                self._distributed_apply_epsilon_w(
                    variable, epsilon_w, tf.distribute.get_strategy()
                )
                epsilon_w_cache.append(epsilon_w)

            with tf.GradientTape() as tape:
                pred = self(x_batch)
                loss = self.compiled_loss(y_batch, pred)
            gradients = tape.gradient(loss, trainable_variables)
            if len(gradients_all_batches) == 0:
                for gradient in gradients:
                    gradients_all_batches.append([gradient])
            else:
                for gradient, gradient_all_batches in zip(
                    gradients, gradients_all_batches
                ):
                    gradient_all_batches.append(gradient)
            for variable, epsilon_w in zip(
                trainable_variables, epsilon_w_cache
            ):
                # Restore the variable to its original value before
                # `apply_gradients()`.
                self._distributed_apply_epsilon_w(
                    variable, -epsilon_w, tf.distribute.get_strategy()
                )

        gradients = []
        for gradient_all_batches in gradients_all_batches:
            gradients.append(tf.reduce_sum(gradient_all_batches, axis=0))
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        pred = tf.concat(pred_all_batches, axis=0)
        self.compiled_metrics.update_state(y, pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs):
        """Forward pass of SAM.

        SAM delegates the forward pass call to the wrapped model.

        Args:
            inputs: Tensor. The model inputs.

        Returns:
            A Tensor, the outputs of the wrapped model for the given `inputs`.
        """
        return self.model(inputs)

    def get_config(self):
        config = super().get_config()
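        # Only the wrapped model (as a serialized Keras object) and `rho` are
        # stored; `num_batch_splits` is not included in the config here.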

        config.update(
            {
                "model": serialize_keras_object(self.model),
                "rho": self.rho,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Avoid mutating the input dict.
        config = copy.deepcopy(config)
        model = deserialize_layer(
            config.pop("model"), custom_objects=custom_objects
        )
        config["model"] = model
        return super().from_config(config, custom_objects)

    def _distributed_apply_epsilon_w(self, var, epsilon_w, strategy):
        # Helper function to apply epsilon_w on model variables.
        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.ParameterServerStrategy,
                tf.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            # Under PSS and CSS, the AggregatingVariable has to be kept in sync.
            def distribute_apply(strategy, var, epsilon_w):
                strategy.extended.update(
                    var,
                    lambda x, y: x.assign_add(y),
                    args=(epsilon_w,),
                    group=False,
                )

            tf.__internal__.distribute.interim.maybe_merge_call(
                distribute_apply, tf.distribute.get_strategy(), var, epsilon_w
            )
        else:
            var.assign_add(epsilon_w)

    def _gradients_order2_norm(self, gradients):
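        # Global L2 norm across all gradients: stack the per-tensor norms and
        # take the norm of that vector; `None` gradients (e.g. for unused
        # variables) are skipped.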

        norm = tf.norm(
            tf.stack([tf.norm(grad) for grad in gradients if grad is not None])
        )
        return norm