Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/moving_average.py: 29%
72 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16import tensorflow as tf
18from tensorflow_addons.optimizers import AveragedOptimizerWrapper
19from tensorflow_addons.utils import types
21from typing import Union
22from typeguard import typechecked
25@tf.keras.utils.register_keras_serializable(package="Addons")
26class MovingAverage(AveragedOptimizerWrapper):
27 """Optimizer that computes a moving average of the variables.
29 Empirically it has been found that using the moving average of the trained
30 parameters of a deep network is better than using its trained parameters
31 directly. This optimizer allows you to compute this moving average and swap
32 the variables at save time so that any code outside of the training loop
33 will use by default the average values instead of the original ones.
35 Example of usage:
37 ```python
38 opt = tf.keras.optimizers.SGD(learning_rate)
39 opt = tfa.optimizers.MovingAverage(opt)
41 ```
42 """
44 @typechecked
45 def __init__(
46 self,
47 optimizer: types.Optimizer,
48 average_decay: types.FloatTensorLike = 0.99,
49 num_updates: Union[None, int, tf.Variable] = None,
50 start_step: int = 0,
51 dynamic_decay: bool = False,
52 name: str = "MovingAverage",
53 **kwargs,
54 ):
55 r"""Construct a new MovingAverage optimizer.
57 Args:
58 optimizer: str or `tf.keras.optimizers.legacy.Optimizer` that will be
59 used to compute and apply gradients.
60 average_decay: float. Decay to use to maintain the moving averages
61 of trained variables.
62 num_updates: Optional count of the number of updates applied to
63 variables.
64 start_step: int. What step to start the moving average.
65 dynamic_decay: bool. Whether to change the decay based on the number
66 of optimizer updates. Decay will start at 0.1 and gradually
67 increase up to `average_decay` after each optimizer update.
68 name: Optional name for the operations created when applying
69 gradients. Defaults to "MovingAverage".
70 **kwargs: keyword arguments. Allowed to be {`clipnorm`,
71 `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
72 norm; `clipvalue` is clip gradients by value, `decay` is
73 included for backward compatibility to allow time inverse
74 decay of learning rate. `lr` is included for backward
75 compatibility, recommended to use `learning_rate` instead.
76 """
77 super().__init__(optimizer, name, **kwargs)
78 self._num_updates = num_updates
79 if self._num_updates is not None:
80 if isinstance(self._num_updates, tf.Variable):
81 tf.debugging.assert_integer(
82 self._num_updates,
83 (
84 'type of argument "num_updates" must be '
85 "int; got {} instead".format(self._num_updates.dtype)
86 ),
87 )
88 num_updates = tf.cast(self._num_updates, tf.float32, name="num_updates")
89 average_decay = tf.minimum(
90 average_decay, (1.0 + num_updates) / (10.0 + num_updates)
91 )
93 self._set_hyper("average_decay", average_decay)
94 self._start_step = start_step
95 self._dynamic_decay = dynamic_decay
97 @tf.function
98 def _get_decay(self, step: tf.Tensor):
99 average_decay = self._get_hyper("average_decay", tf.dtypes.float32)
101 step = tf.cast(step, tf.float32)
102 if step < self._start_step:
103 return tf.constant(0.0, tf.float32)
104 elif self._dynamic_decay:
105 step_count = step - self._start_step
106 return tf.minimum(average_decay, (1.0 + step_count) / (10.0 + step_count))
107 else:
108 return average_decay
110 def _prepare_local(self, var_device, var_dtype, apply_state):
111 super()._prepare_local(var_device, var_dtype, apply_state)
112 apply_state[(var_device, var_dtype)]["tfa_ma_decay"] = self._get_decay(
113 self._optimizer.iterations
114 )
116 def average_op(self, var, average_var, local_apply_state):
117 return tf.keras.backend.moving_average_update(
118 average_var, var, local_apply_state["tfa_ma_decay"]
119 )
121 def get_config(self):
122 config = {
123 "average_decay": self._serialize_hyperparameter("average_decay"),
124 "num_updates": self._num_updates,
125 "start_step": self._start_step,
126 "dynamic_decay": self._dynamic_decay,
127 }
128 base_config = super().get_config()
129 return {**base_config, **config}
131 def _create_slots(self, var_list):
132 self._optimizer._create_slots(var_list=var_list)
133 for var in var_list:
134 self.add_slot(var, "average", var.read_value())
136 self._average_weights = [self.get_slot(var, "average") for var in var_list]
137 self._model_weights = var_list
139 def shadow_copy(self, model_weights):
140 """Creates shadow variables for the given model weights."""
141 for var in model_weights:
142 self.add_slot(var, "average", initializer="zeros")
143 self._average_weights = [self.get_slot(var, "average") for var in model_weights]
144 self._model_weights = model_weights
146 @property
147 def has_shadow_copy(self):
148 """Whether this optimizer has created shadow variables."""
149 return self._model_weights is not None
151 def swap_weights(self):
152 """Swap the average and moving weights.
154 This is a convenience method to allow one to evaluate the averaged weights
155 at test time. Loads the weights stored in `self._average_weights` into the model,
156 keeping a copy of the original model weights. Swapping twice will return
157 the original weights.
158 """
159 if tf.distribute.in_cross_replica_context():
160 strategy = tf.distribute.get_strategy()
161 return strategy.run(self._swap_weights, args=())
162 else:
163 raise ValueError(
164 "Swapping weights must occur under a " "tf.distribute.Strategy"
165 )
167 @tf.function
168 def _swap_weights(self):
169 def fn_0(a, b):
170 return a.assign_add(b, use_locking=self._use_locking)
172 def fn_1(b, a):
173 return b.assign(a - b, use_locking=self._use_locking)
175 def fn_2(a, b):
176 return a.assign_sub(b, use_locking=self._use_locking)
178 def swap(strategy, a, b):
179 """Swap `a` and `b` and mirror to all devices."""
180 for a_element, b_element in zip(a, b):
181 strategy.extended.update(
182 a_element, fn_0, args=(b_element,)
183 ) # a = a + b
184 strategy.extended.update(
185 b_element, fn_1, args=(a_element,)
186 ) # b = a - b
187 strategy.extended.update(
188 a_element, fn_2, args=(b_element,)
189 ) # a = a - b
191 ctx = tf.distribute.get_replica_context()
192 return ctx.merge_call(swap, args=(self._average_weights, self._model_weights))