# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
from tensorflow_addons.utils import types
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class Lookahead(KerasLegacyOptimizer):
    """This class allows extending optimizers with the lookahead mechanism.

    The mechanism was proposed by Michael R. Zhang et al. in the paper
    [Lookahead Optimizer: k steps forward, 1 step back]
    (https://arxiv.org/abs/1907.08610v1). The optimizer iteratively updates
    two sets of weights: the search directions for the "fast weights" are
    chosen by the inner optimizer, while the "slow weights" are updated every
    `k` steps based on the directions of the fast weights, after which the two
    sets of weights are synchronized. This method improves learning stability
    and lowers the variance of its inner optimizer.

    Example of usage:

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate)
    opt = tfa.optimizers.Lookahead(opt)
    ```
    """

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        sync_period: int = 6,
        slow_step_size: types.FloatTensorLike = 0.5,
        name: str = "Lookahead",
        **kwargs,
    ):
        r"""Wrap optimizer with the lookahead mechanism.

        Args:
            optimizer: The original optimizer that will be used to compute
                and apply the gradients.
            sync_period: An integer. The synchronization period of lookahead.
                Enable the lookahead mechanism by setting it to a positive value.
            slow_step_size: A floating point value.
                The ratio for updating the slow weights.
            name: Optional name for the operations created when applying
                gradients. Defaults to "Lookahead".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients
                by norm; `clipvalue` is clip gradients by value; `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility; it is recommended to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)

        # Resolve string identifiers to a (legacy, if available) Keras
        # optimizer instance.
        if isinstance(optimizer, str):
            if (
                hasattr(tf.keras.optimizers, "legacy")
                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
            ):
                optimizer = tf.keras.optimizers.get(
                    optimizer, use_legacy_optimizer=True
                )
            else:
                optimizer = tf.keras.optimizers.get(optimizer)
        if not isinstance(optimizer, KerasLegacyOptimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.legacy.Optimizer"
            )

        self._optimizer = optimizer
        self._set_hyper("sync_period", sync_period)
        self._set_hyper("slow_step_size", slow_step_size)
        self._initialized = False
        self._track_trackable(self._optimizer, "lh_base_optimizer")

    def _create_slots(self, var_list):
        self._optimizer._create_slots(
            var_list=var_list
        )  # pylint: disable=protected-access
        for var in var_list:
            # The "slow" copy of each variable starts out equal to the
            # (fast) variable itself.
            self.add_slot(var, "slow", initializer=var)

    def _create_hypers(self):
        self._optimizer._create_hypers()  # pylint: disable=protected-access

    def _prepare(self, var_list):
        return self._optimizer._prepare(
            var_list=var_list
        )  # pylint: disable=protected-access

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        # Share the iteration counter with the wrapped optimizer before
        # delegating the update.
        self._optimizer._iterations = (
            self.iterations
        )  # pylint: disable=protected-access
        return super().apply_gradients(grads_and_vars, name, **kwargs)

    def _look_ahead_op(self, var):
        var_dtype = var.dtype.base_dtype
        slow_var = self.get_slot(var, "slow")
        local_step = tf.cast(self.iterations + 1, tf.dtypes.int64)
        sync_period = self._get_hyper("sync_period", tf.dtypes.int64)
        slow_step_size = self._get_hyper("slow_step_size", var_dtype)
        # Candidate slow weights: move the slow weights towards the fast
        # weights by a fraction `slow_step_size`.
        step_back = slow_var + slow_step_size * (var - slow_var)
        # Synchronize only when the (1-based) local step is a multiple of
        # `sync_period`.
        sync_cond = tf.equal(
            tf.math.floordiv(local_step, sync_period) * sync_period, local_step
        )
        with tf.control_dependencies([step_back]):
            slow_update = slow_var.assign(
                tf.where(sync_cond, step_back, slow_var),
                use_locking=self._use_locking,
            )
            var_update = var.assign(
                tf.where(sync_cond, step_back, var),
                use_locking=self._use_locking,
            )
        return tf.group(slow_update, var_update)
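
    # Worked schedule (illustrative): with `sync_period=6` and
    # `slow_step_size=0.5`, `sync_cond` holds on the 6th, 12th, 18th, ...
    # update (`local_step` is the 1-based step count); at each such step both
    # the slow and the fast weights are set to the midpoint of their current
    # values.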

    @property
    def weights(self):
        return self._weights + self._optimizer.weights

    def _resource_apply_dense(self, grad, var):
        train_op = self._optimizer._resource_apply_dense(
            grad, var
        )  # pylint: disable=protected-access
        with tf.control_dependencies([train_op]):
            # Run the lookahead synchronization only after the inner
            # optimizer's update has been applied.
            look_ahead_op = self._look_ahead_op(var)
        return tf.group(train_op, look_ahead_op)

    def _resource_apply_sparse(self, grad, var, indices):
        train_op = (
            self._optimizer._resource_apply_sparse(  # pylint: disable=protected-access
                grad, var, indices
            )
        )
        with tf.control_dependencies([train_op]):
            look_ahead_op = self._look_ahead_op(var)
        return tf.group(train_op, look_ahead_op)

    def get_config(self):
        config = {
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
            "sync_period": self._serialize_hyperparameter("sync_period"),
            "slow_step_size": self._serialize_hyperparameter("slow_step_size"),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)

    @property
    def lr(self):
        return self.learning_rate

    @lr.setter
    def lr(self, lr):
        self.learning_rate = lr

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)
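

# Minimal usage sketch (illustrative): wrap a legacy Keras SGD optimizer with
# Lookahead and fit a tiny model on random data. The model architecture, data
# shapes, and hyperparameter values below are assumptions chosen for
# demonstration only, and the snippet assumes a TF release that still ships
# `tf.keras.optimizers.legacy`.
if __name__ == "__main__":
    import numpy as np

    base_opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
    opt = Lookahead(base_opt, sync_period=6, slow_step_size=0.5)

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer=opt, loss="mse")

    x = np.random.rand(256, 4).astype("float32")
    y = np.random.rand(256, 1).astype("float32")
    model.fit(x, y, batch_size=32, epochs=2, verbose=0)
    print("final loss:", float(model.evaluate(x, y, verbose=0)))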