Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/stochastic_weight_averaging.py: 39% (31 statements; coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""An implementation of the Stochastic Weight Averaging optimizer.
17The Stochastic Weight Averaging mechanism was proposed by Pavel Izmailov
18et. al in the paper [Averaging Weights Leads to Wider Optima and Better
19Generalization](https://arxiv.org/abs/1803.05407). The optimizer
20implements averaging of multiple points along the trajectory of SGD.
21This averaging has shown to improve model performance on validation/test
22sets whilst possibly causing a small increase in loss on the training
23set.
24"""
import tensorflow as tf
from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
from tensorflow_addons.utils import types
from typeguard import typechecked
@tf.keras.utils.register_keras_serializable(package="Addons")
class SWA(AveragedOptimizerWrapper):
    """This class extends optimizers with Stochastic Weight Averaging (SWA).

    The Stochastic Weight Averaging mechanism was proposed by Pavel Izmailov
    et al. in the paper [Averaging Weights Leads to Wider Optima and
    Better Generalization](https://arxiv.org/abs/1803.05407). The optimizer
    implements averaging of multiple points along the trajectory of SGD. The
    optimizer expects an inner optimizer, which it uses to apply the
    gradients to the variables, and itself computes a running average of the
    variables every `k` steps (which generally corresponds to the end
    of a cycle when a cyclic learning rate is employed).

    The number of steps after which averaging first begins can also be
    specified. Say we want averaging to happen every `k` steps after the
    first `m` steps: after step `m` a snapshot of the variables is taken,
    and the weights are then averaged at steps `m + k`, `m + 2k` and so on.
    The `assign_average_vars` function can be called at the end of training
    to assign the averaged weights to the model's variables.

    Note: If your model has batch-normalization layers, you need to run the
    final weights through the data to compute the running mean and
    variance corresponding to the activations for each layer of the network.
    From the paper: "If the DNN uses batch normalization we run one
    additional pass over the data, to compute the running mean and standard
    deviation of the activations for each layer of the network with SWA
    weights after the training is finished, since these statistics are not
    collected during training. For most deep learning libraries, such as
    PyTorch or TensorFlow, one can typically collect these statistics by
    making a forward pass over the data in training mode"
    ([Averaging Weights Leads to Wider Optima and Better
    Generalization](https://arxiv.org/abs/1803.05407)).

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate)
    opt = tfa.optimizers.SWA(opt, start_averaging=m, average_period=k)
    ```
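
    A fuller usage sketch (illustrative only: `model`, `train_data`, `m`, `k`
    and `epochs` below are placeholders, not part of this module) showing how
    the averaged weights can be assigned after training and how the
    batch-normalization statistics can be refreshed:

    ```python
    opt = tfa.optimizers.SWA(
        tf.keras.optimizers.SGD(learning_rate=0.01),
        start_averaging=m,
        average_period=k,
    )
    model.compile(optimizer=opt, loss="mse")
    model.fit(train_data, epochs=epochs)

    # Copy the running averages maintained by SWA into the model's variables.
    opt.assign_average_vars(model.variables)

    # With batch-normalization layers, make one forward pass over the data in
    # training mode so the moving statistics are recomputed for the averaged
    # weights.
    for x, _ in train_data:
        model(x, training=True)
    ```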
72 """

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        start_averaging: int = 0,
        average_period: int = 10,
        name: str = "SWA",
        **kwargs,
    ):
        r"""Wrap optimizer with the Stochastic Weight Averaging mechanism.

        Args:
            optimizer: The original optimizer that will be used to compute and
                apply the gradients.
            start_averaging: An integer. Threshold to start averaging using
                SWA. Averaging begins only after `start_averaging` iterations
                and must be >= 0. If `start_averaging = m`, the first snapshot
                will be taken after the `m`th application of gradients (where
                the first iteration is iteration 0); see the example below
                this list.
            average_period: An integer. The synchronization period of SWA.
                Averaging occurs every `average_period` steps; must be >= 1.
            name: Optional name for the operations created when applying
                gradients. Defaults to 'SWA'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by
                norm; `clipvalue` clips gradients by value; `decay` is
                included for backward compatibility to allow time-inverse
                decay of the learning rate; `lr` is included for backward
                compatibility, but using `learning_rate` is recommended
                instead.
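
        For example (illustrative values), with `start_averaging=10` and
        `average_period=5` the running average is updated at iterations
        10, 15, 20, and so on.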
104 """
        super().__init__(optimizer, name, **kwargs)

        if average_period < 1:
            raise ValueError("average_period must be >= 1")
        if start_averaging < 0:
            raise ValueError("start_averaging must be >= 0")

        self._set_hyper("average_period", average_period)
        self._set_hyper("start_averaging", start_averaging)

    @tf.function
    def average_op(self, var, average_var, local_apply_state):
        average_period = self._get_hyper("average_period", tf.dtypes.int64)
        start_averaging = self._get_hyper("start_averaging", tf.dtypes.int64)
        # Number of times snapshots of weights have been taken (using max to
        # avoid negative values of num_snapshots).
        num_snapshots = tf.math.maximum(
            tf.cast(0, tf.int64),
            tf.math.floordiv(self.iterations - start_averaging, average_period),
        )

        # The average update should happen iff two conditions are met:
        # 1. A minimum number of iterations (start_averaging) has taken place.
        # 2. The current iteration is one in which a snapshot should be taken.
        checkpoint = start_averaging + num_snapshots * average_period
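        # Illustrative trace: with start_averaging=10 and average_period=5, at
        # iteration 15 num_snapshots == 1, so checkpoint == 10 + 1 * 5 == 15
        # and the running average is refreshed by the branch below.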
        if self.iterations >= start_averaging and self.iterations == checkpoint:
            num_snapshots = tf.cast(num_snapshots, tf.float32)
            average_value = (average_var * num_snapshots + var) / (num_snapshots + 1.0)
            return average_var.assign(average_value, use_locking=self._use_locking)

        return average_var

    def get_config(self):
        config = {
            "average_period": self._serialize_hyperparameter("average_period"),
            "start_averaging": self._serialize_hyperparameter("start_averaging"),
        }
        base_config = super().get_config()
        return {**base_config, **config}
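

# A hedged serialization sketch (illustrative, not part of this module): since
# the class is registered as Keras-serializable, a wrapper built as
#     opt = SWA(tf.keras.optimizers.SGD(0.01), start_averaging=100, average_period=10)
# can typically be rebuilt from its config via
#     restored = SWA.from_config(opt.get_config())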