# Source: tensorflow_addons/optimizers/average_wrapper.py
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import abc
import warnings

import tensorflow as tf
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from tensorflow_addons.utils import types
from typeguard import typechecked


class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
    """Base class for optimizers that keep a moving average of model variables.

    The wrapper delegates the training update to the wrapped optimizer,
    stores a per-variable "average" slot, and lets subclasses define how
    that slot is updated via ``average_op``.
    """

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        name: str = "AverageOptimizer",
        **kwargs,
    ):
        super().__init__(name, **kwargs)

        if isinstance(optimizer, str):
            if (
                hasattr(tf.keras.optimizers, "legacy")
                and KerasLegacyOptimizer == tf.keras.optimizers.legacy.Optimizer
            ):
                optimizer = tf.keras.optimizers.get(
                    optimizer, use_legacy_optimizer=True
                )
            else:
                optimizer = tf.keras.optimizers.get(optimizer)

        if not isinstance(optimizer, KerasLegacyOptimizer):
            raise TypeError(
                "optimizer is not an object of tf.keras.optimizers.legacy.Optimizer"
            )

        self._optimizer = optimizer
        self._track_trackable(self._optimizer, "awg_optimizer")

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
        for var in var_list:
            self.add_slot(var, "average")

    def _create_hypers(self):
        self._optimizer._create_hypers()

    def _prepare_local(self, var_device, var_dtype, apply_state):
        return self._optimizer._prepare_local(var_device, var_dtype, apply_state)

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        self._optimizer._iterations = self.iterations
        return super().apply_gradients(grads_and_vars, name, **kwargs)
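
    # A concrete subclass supplies `average_op`: given the trained variable
    # `var`, its "average" slot `average_var`, and the per-device
    # `local_apply_state` mapping, it returns the op that updates
    # `average_var`. An illustrative (hypothetical) subclass sketch is given
    # at the end of this module.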
    @abc.abstractmethod
    def average_op(self, var, average_var, local_apply_state):
        raise NotImplementedError

    def _apply_average_op(self, train_op, var, apply_state):
        apply_state = apply_state or {}
        local_apply_state = apply_state.get((var.device, var.dtype.base_dtype))
        if local_apply_state is None:
            local_apply_state = self._fallback_apply_state(
                var.device, var.dtype.base_dtype
            )
        average_var = self.get_slot(var, "average")
        return self.average_op(var, average_var, local_apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        if "apply_state" in self._optimizer._dense_apply_args:
            train_op = self._optimizer._resource_apply_dense(
                grad, var, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_dense(grad, var)
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse(
                grad, var, indices, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_sparse(grad, var, indices)
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def _resource_apply_sparse_duplicate_indices(
        self, grad, var, indices, apply_state=None
    ):
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse_duplicate_indices(
                grad, var, indices, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_sparse_duplicate_indices(
                grad, var, indices
            )
        average_op = self._apply_average_op(train_op, var, apply_state)
        return tf.group(train_op, average_op)

    def assign_average_vars(self, var_list):
        """Assign variables in var_list with their respective averages.

        Args:
            var_list: List of model variables to be assigned to their average.

        Returns:
            assign_op: The op corresponding to the assignment operation of
            variables to their average.

        Example:
        ```python
        model = tf.keras.Sequential([...])
        opt = tfa.optimizers.SWA(
            tf.keras.optimizers.SGD(learning_rate=2.0), 100, 10)
        model.compile(opt, ...)
        model.fit(x, y, ...)

        # Update the weights to their mean before saving
        opt.assign_average_vars(model.variables)

        model.save('model.h5')
        ```
        """
        assign_ops = []
        for var in var_list:
            try:
                assign_ops.append(
                    var.assign(
                        self.get_slot(var, "average"),
                        use_locking=self._use_locking,
                    )
                )
            except Exception as e:
                warnings.warn("Unable to assign average slot to {}: {}".format(var, e))
        return tf.group(assign_ops)

    def get_config(self):
        config = {
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)

    @property
    def weights(self):
        return self._weights + self._optimizer.weights

    @property
    def lr(self):
        return self._optimizer._get_hyper("learning_rate")

    @lr.setter
    def lr(self, lr):
        self._optimizer._set_hyper("learning_rate", lr)

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)
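

# Illustrative sketch (not part of tensorflow_addons): a minimal concrete
# subclass showing how `average_op` and `get_config` are typically filled in.
# The class name and the `decay` hyperparameter below are hypothetical; the
# concrete wrappers shipped with tensorflow_addons are `MovingAverage` and
# `SWA`.
class _ExampleEMAWrapper(AveragedOptimizerWrapper):
    """Keeps the "average" slot as an exponential moving average of each variable."""

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        decay: float = 0.99,
        name: str = "ExampleEMAWrapper",
        **kwargs,
    ):
        super().__init__(optimizer, name=name, **kwargs)
        self._set_hyper("decay", decay)

    def average_op(self, var, average_var, local_apply_state):
        decay = tf.cast(self._get_hyper("decay"), var.dtype.base_dtype)
        # Move the stored average a fraction of the way toward the current value.
        return average_var.assign(
            decay * average_var + (1.0 - decay) * var,
            use_locking=self._use_locking,
        )

    def get_config(self):
        config = {"decay": self._serialize_hyperparameter("decay")}
        return {**super().get_config(), **config}


# Such a wrapper is used like any other Keras optimizer, e.g.
# opt = _ExampleEMAWrapper("sgd"); model.compile(opt, loss="mse"); after
# training, opt.assign_average_vars(model.variables) swaps in the averages.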