Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/stochastic_weight_averaging.py: 39%

31 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""An implementation of the Stochastic Weight Averaging optimizer. 

16 

17The Stochastic Weight Averaging mechanism was proposed by Pavel Izmailov 

18et. al in the paper [Averaging Weights Leads to Wider Optima and Better 

19Generalization](https://arxiv.org/abs/1803.05407). The optimizer 

20implements averaging of multiple points along the trajectory of SGD. 

21This averaging has shown to improve model performance on validation/test 

22sets whilst possibly causing a small increase in loss on the training 

23set. 

24""" 

25 

26import tensorflow as tf 

27from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper 

28from tensorflow_addons.utils import types 

29 

30from typeguard import typechecked 

31 

32 

33@tf.keras.utils.register_keras_serializable(package="Addons") 

34class SWA(AveragedOptimizerWrapper): 

35 """This class extends optimizers with Stochastic Weight Averaging (SWA). 

36 

37 The Stochastic Weight Averaging mechanism was proposed by Pavel Izmailov 

38 et. al in the paper [Averaging Weights Leads to Wider Optima and 

39 Better Generalization](https://arxiv.org/abs/1803.05407). The optimizer 

40 implements averaging of multiple points along the trajectory of SGD. The 

41 optimizer expects an inner optimizer which will be used to apply the 

42 gradients to the variables and itself computes a running average of the 

43 variables every `k` steps (which generally corresponds to the end 

44 of a cycle when a cyclic learning rate is employed). 

45 

46 We also allow the specification of the number of steps averaging 

47 should first happen after. Let's say, we want averaging to happen every `k` 

48 steps after the first `m` steps. After step `m` we'd take a snapshot of the 

49 variables and then average the weights appropriately at step `m + k`, 

50 `m + 2k` and so on. The assign_average_vars function can be called at the 

51 end of training to obtain the averaged_weights from the optimizer. 

52 

53 Note: If your model has batch-normalization layers you would need to run 

54 the final weights through the data to compute the running mean and 

55 variance corresponding to the activations for each layer of the network. 

56 From the paper: If the DNN uses batch normalization we run one 

57 additional pass over the data, to compute the running mean and standard 

58 deviation of the activations for each layer of the network with SWA 

59 weights after the training is finished, since these statistics are not 

60 collected during training. For most deep learning libraries, such as 

61 PyTorch or Tensorflow, one can typically collect these statistics by 

62 making a forward pass over the data in training mode 

63 ([Averaging Weights Leads to Wider Optima and Better 

64 Generalization](https://arxiv.org/abs/1803.05407)) 

65 

66 Example of usage: 

67 

68 ```python 

69 opt = tf.keras.optimizers.SGD(learning_rate) 

70 opt = tfa.optimizers.SWA(opt, start_averaging=m, average_period=k) 

71 ``` 

72 """ 

73 

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        start_averaging: int = 0,
        average_period: int = 10,
        name: str = "SWA",
        **kwargs,
    ):
        r"""Wrap optimizer with the Stochastic Weight Averaging mechanism.

        Args:
            optimizer: The original optimizer that will be used to compute
                and apply the gradients.
            start_averaging: An integer. Threshold to start averaging using
                SWA. Averaging begins only after `start_averaging` iterations
                and must be >= 0. If `start_averaging = m`, the first snapshot
                will be taken after the `m`th application of gradients (where
                the first iteration is iteration 0).
            average_period: An integer. The synchronization period of SWA.
                The averaging occurs every `average_period` steps. The
                averaging period must be >= 1.
            name: Optional name for the operations created when applying
                gradients. Defaults to 'SWA'.
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by
                norm; `clipvalue` clips gradients by value; `decay` is
                included for backward compatibility to allow time-inverse
                decay of the learning rate; `lr` is included for backward
                compatibility, but it is recommended to use `learning_rate`
                instead.
        """
        super().__init__(optimizer, name, **kwargs)

        if average_period < 1:
            raise ValueError("average_period must be >= 1")
        if start_averaging < 0:
            raise ValueError("start_averaging must be >= 0")

        self._set_hyper("average_period", average_period)
        self._set_hyper("start_averaging", start_averaging)

    @tf.function
    def average_op(self, var, average_var, local_apply_state):
        average_period = self._get_hyper("average_period", tf.dtypes.int64)
        start_averaging = self._get_hyper("start_averaging", tf.dtypes.int64)
        # Number of times snapshots of the weights have been taken (using max
        # to avoid negative values of num_snapshots).
        num_snapshots = tf.math.maximum(
            tf.cast(0, tf.int64),
            tf.math.floordiv(self.iterations - start_averaging, average_period),
        )

        # The average update should happen iff two conditions are met:
        # 1. A minimum number of iterations (start_averaging) has taken place.
        # 2. The current iteration is one at which a snapshot should be taken.
        checkpoint = start_averaging + num_snapshots * average_period
        if self.iterations >= start_averaging and self.iterations == checkpoint:
            num_snapshots = tf.cast(num_snapshots, tf.float32)
            # Incremental running mean: new_avg = (n * old_avg + var) / (n + 1).
            average_value = (average_var * num_snapshots + var) / (num_snapshots + 1.0)
            return average_var.assign(average_value, use_locking=self._use_locking)

        return average_var
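
    # For example, with `start_averaging=10` and `average_period=5`, the
    # variables are folded into the running average at iterations 10, 15,
    # 20, and so on (values chosen purely for illustration).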

    def get_config(self):
        config = {
            "average_period": self._serialize_hyperparameter("average_period"),
            "start_averaging": self._serialize_hyperparameter("start_averaging"),
        }
        base_config = super().get_config()
        return {**base_config, **config}
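

# Illustrative usage sketch (not part of the upstream TensorFlow Addons
# module): wrap a plain SGD optimizer with SWA, train on synthetic data,
# copy the averaged weights back into the model with `assign_average_vars`,
# and refresh the batch-normalization statistics with one extra forward
# pass, as the class docstring recommends. The model, data, and
# hyperparameter values below are arbitrary placeholders.
if __name__ == "__main__":
    import numpy as np

    x = np.random.randn(256, 8).astype("float32")
    y = np.random.randn(256, 1).astype("float32")

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1),
        ]
    )

    # Take the first snapshot after 100 gradient updates, then fold the
    # variables into the running average every 10 steps.
    opt = SWA(
        tf.keras.optimizers.SGD(learning_rate=0.01),
        start_averaging=100,
        average_period=10,
    )
    model.compile(optimizer=opt, loss="mse")
    model.fit(x, y, batch_size=32, epochs=20, verbose=0)

    # Assign the averaged weights to the model's trainable variables.
    opt.assign_average_vars(model.trainable_variables)

    # The model contains batch normalization, so run one additional pass
    # over the data in training mode to recompute the moving statistics for
    # the averaged weights (see the note in the class docstring).
    for start in range(0, len(x), 32):
        model(x[start : start + 32], training=True)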