Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/callbacks/average_model_checkpoint.py: 31%

29 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15 

from typing import Union

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper

19 

20 

class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    r"""The callback that saves average model weights.

    The callback that should be used with optimizers that extend
    `tfa.optimizers.AveragedOptimizerWrapper`, i.e.,
    `tfa.optimizers.MovingAverage` and
    `tfa.optimizers.StochasticAverage` optimizers.
    It saves and, optionally, assigns the averaged weights.

    Args:
        update_weights: If `True`, assign the moving average weights
            to the model, and save them. If False, keep the old
            non-averaged weights, but the saved model uses the
            average weights.
        save_freq: `"epoch"` or an integer number of batches; forwarded
            unchanged to `tf.keras.callbacks.ModelCheckpoint`.

    See `tf.keras.callbacks.ModelCheckpoint` for the other args.

    Raises:
        TypeError: From `set_model` or `_save_model` when the model's
            optimizer is not an `AveragedOptimizerWrapper`.
    """

    @typechecked
    def __init__(
        self,
        update_weights: bool,
        filepath: str,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        save_weights_only: bool = False,
        mode: str = "auto",
        save_freq: Union[str, int] = "epoch",
        **kwargs,
    ):
        self.update_weights = update_weights
        # Forward by keyword so the call does not depend on the positional
        # order of the base-class signature.
        super().__init__(
            filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            save_freq=save_freq,
            **kwargs,
        )

    def _get_optimizer(self):
        """Return the model's optimizer, unwrapping a loss-scale wrapper."""
        optimizer = self.model.optimizer
        # Matched by class name rather than `isinstance`, so no loss-scale
        # optimizer class has to be imported here.
        if type(optimizer).__name__ in ("LossScaleOptimizer", "LossScaleOptimizerV1"):
            optimizer = optimizer.inner_optimizer

        return optimizer

    def _get_averaged_optimizer(self):
        """Return the unwrapped optimizer, checking that it averages weights.

        Raises:
            TypeError: If the optimizer is not an `AveragedOptimizerWrapper`.
        """
        optimizer = self._get_optimizer()
        if not isinstance(optimizer, AveragedOptimizerWrapper):
            raise TypeError(
                "AverageModelCheckpoint is only used when training "
                "with MovingAverage or StochasticAverage"
            )
        return optimizer

    def set_model(self, model):
        """Attach `model` and reject it early if its optimizer is unsupported.

        Raises:
            TypeError: If the model's optimizer is not an
                `AveragedOptimizerWrapper`.
        """
        super().set_model(model)
        self._get_averaged_optimizer()

    def _save_model(self, *args, **kwargs):
        """Save the model with the averaged weights assigned.

        When `update_weights` is `True` the averaged weights stay on the
        model after saving. Otherwise the weights the model had before the
        call are put back once the save finishes, whether or not it
        succeeded.

        Returns:
            Whatever the parent `_save_model` returns.

        Raises:
            TypeError: If the model's optimizer is not an
                `AveragedOptimizerWrapper`.
        """
        # Explicit check instead of `assert`, which is removed under `-O`.
        optimizer = self._get_averaged_optimizer()

        if self.update_weights:
            optimizer.assign_average_vars(self.model.trainable_weights)
            return super()._save_model(*args, **kwargs)

        # Note: `model.get_weights()` gives us the weights (non-ref)
        # whereas `model.variables` returns references to the variables.
        non_avg_weights = self.model.get_weights()
        optimizer.assign_average_vars(self.model.trainable_weights)
        try:
            # The result is propagated in case the parent returns a value.
            return super()._save_model(*args, **kwargs)
        finally:
            # Restore the training weights even if saving raised, so a failed
            # checkpoint does not leave the model on the averaged weights.
            self.model.set_weights(non_avg_weights)

96 return result