Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/models/sharpness_aware_minimization.py: 25%
79 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sharpness Aware Minimization implementation."""

import copy

import tensorflow.compat.v2 as tf

from keras.src.engine import data_adapter
from keras.src.layers import deserialize as deserialize_layer
from keras.src.models import Model
from keras.src.saving.object_registration import register_keras_serializable
from keras.src.saving.serialization_lib import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export


@register_keras_serializable()
@keras_export("keras.models.experimental.SharpnessAwareMinimization", v1=[])
class SharpnessAwareMinimization(Model):
    """Sharpness aware minimization (SAM) training flow.

    Sharpness-aware minimization (SAM) is a technique that improves model
    generalization and provides robustness to label noise. Mini-batch
    splitting has been shown to improve SAM's performance, so users can
    control how mini-batches are split by setting the `num_batch_splits`
    argument.

    Args:
      model: `tf.keras.Model` instance. The inner model that does the
        forward-backward pass.
      rho: float, defaults to 0.05. The gradient scaling factor.
      num_batch_splits: int, defaults to None. The number of mini-batches to
        split each data batch into. If None, batches are not split into
        sub-batches.
      name: string, defaults to None. The name of the SAM model.

    Reference:
      [Pierre Foret et al., 2020](https://arxiv.org/abs/2010.01412)
    """
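    # Usage (a minimal sketch; `inner_model`, `x_train`, and `y_train` are
    # hypothetical stand-ins for a user-defined `tf.keras.Model` and data):
    #
    #   sam = SharpnessAwareMinimization(inner_model, rho=0.05)
    #   sam.compile(optimizer="sgd", loss="mse", metrics=["mae"])
    #   sam.fit(x_train, y_train, epochs=2)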

    def __init__(self, model, rho=0.05, num_batch_splits=None, name=None):
        super().__init__(name=name)
        self.model = model
        self.rho = rho
        self.num_batch_splits = num_batch_splits

    def train_step(self, data):
        """The logic of one SAM training step.

        Args:
          data: A nested structure of `Tensor`s. It should be of structure
            (x, y, sample_weight) or (x, y).

        Returns:
          A dict mapping metric names to running average values.
        """
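        # Each SAM step runs two passes per sub-batch: an ascent pass that
        # perturbs the weights towards higher loss, and a descent pass whose
        # gradients (computed at the perturbed weights) drive the optimizer
        # update.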
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        if self.num_batch_splits is not None:
            x_split = tf.split(x, self.num_batch_splits)
            y_split = tf.split(y, self.num_batch_splits)
        else:
            x_split = [x]
            y_split = [y]

        gradients_all_batches = []
        pred_all_batches = []
        for x_batch, y_batch in zip(x_split, y_split):
            epsilon_w_cache = []
            with tf.GradientTape() as tape:
                pred = self.model(x_batch)
                loss = self.compiled_loss(y_batch, pred)
            pred_all_batches.append(pred)
            trainable_variables = self.model.trainable_variables
            gradients = tape.gradient(loss, trainable_variables)

            gradients_order2_norm = self._gradients_order2_norm(gradients)
            scale = self.rho / (gradients_order2_norm + 1e-12)
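            # Perturb each variable by epsilon_w = rho * g / (||g||_2 + 1e-12),
            # where ||g||_2 is the global L2 norm over all gradients.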
            for gradient, variable in zip(gradients, trainable_variables):
                epsilon_w = gradient * scale
                self._distributed_apply_epsilon_w(
                    variable, epsilon_w, tf.distribute.get_strategy()
                )
                epsilon_w_cache.append(epsilon_w)
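            # Second pass: compute gradients of the loss at the perturbed
            # weights (w + epsilon_w); these sharpness-aware gradients are
            # accumulated across sub-batches.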
            with tf.GradientTape() as tape:
                pred = self(x_batch)
                loss = self.compiled_loss(y_batch, pred)
            gradients = tape.gradient(loss, trainable_variables)
            if len(gradients_all_batches) == 0:
                for gradient in gradients:
                    gradients_all_batches.append([gradient])
            else:
                for gradient, gradient_all_batches in zip(
                    gradients, gradients_all_batches
                ):
                    gradient_all_batches.append(gradient)
            for variable, epsilon_w in zip(
                trainable_variables, epsilon_w_cache
            ):
                # Restore the variable to its original value before
                # `apply_gradients()`.
                self._distributed_apply_epsilon_w(
                    variable, -epsilon_w, tf.distribute.get_strategy()
                )
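        # Sum the per-sub-batch sharpness-aware gradients and apply a single
        # optimizer update with the weights restored to their original values.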
        gradients = []
        for gradient_all_batches in gradients_all_batches:
            gradients.append(tf.reduce_sum(gradient_all_batches, axis=0))
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        pred = tf.concat(pred_all_batches, axis=0)
        self.compiled_metrics.update_state(y, pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

    def call(self, inputs):
        """Forward pass of SAM.

        SAM delegates the forward pass call to the wrapped model.

        Args:
          inputs: Tensor. The model inputs.

        Returns:
          A Tensor, the outputs of the wrapped model for given `inputs`.
        """
        return self.model(inputs)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "model": serialize_keras_object(self.model),
                "rho": self.rho,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # Avoid mutating the input dict.
        config = copy.deepcopy(config)
        model = deserialize_layer(
            config.pop("model"), custom_objects=custom_objects
        )
        config["model"] = model
        return super().from_config(config, custom_objects)
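    # `get_config()` stores the wrapped model via `serialize_keras_object`;
    # `from_config()` rebuilds it with the layer deserializer, so custom inner
    # models may need `custom_objects` (or registration) to round-trip.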

    def _distributed_apply_epsilon_w(self, var, epsilon_w, strategy):
        # Helper function to apply epsilon_w on model variables.
        if isinstance(
            tf.distribute.get_strategy(),
            (
                tf.distribute.experimental.ParameterServerStrategy,
                tf.distribute.experimental.CentralStorageStrategy,
            ),
        ):
            # Under PSS and CSS, the AggregatingVariable has to be kept in
            # sync.
            def distribute_apply(strategy, var, epsilon_w):
                strategy.extended.update(
                    var,
                    lambda x, y: x.assign_add(y),
                    args=(epsilon_w,),
                    group=False,
                )

            tf.__internal__.distribute.interim.maybe_merge_call(
                distribute_apply, tf.distribute.get_strategy(), var, epsilon_w
            )
        else:
            var.assign_add(epsilon_w)

    def _gradients_order2_norm(self, gradients):
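        # Global L2 norm over all non-None gradients: sqrt(sum_i ||g_i||^2).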
        norm = tf.norm(
            tf.stack([tf.norm(grad) for grad in gradients if grad is not None])
        )
        return norm