# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Discriminative Layer Training Optimizer for TensorFlow."""

from typing import List, Union

import tensorflow as tf

from packaging.version import Version
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked

if Version(tf.__version__).release >= Version("2.16").release:
    # Determine if loading keras 2 or 3.
    if (
        hasattr(tf.keras, "version")
        and Version(tf.keras.version()).release >= Version("3.0").release
    ):
        # New versions of Keras require importing from `keras.src` when
        # importing internal symbols.
        from keras.src import backend
        from keras.src.utils import tf_utils
    else:
        from tf_keras.src import backend
        from tf_keras.src.utils import tf_utils
elif Version(tf.__version__).release >= Version("2.13").release:
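    # TF 2.13 through 2.15 ship Keras 2 with its internals under `keras.src`.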
    from keras.src import backend
    from keras.src.utils import tf_utils
else:
    from keras import backend
    from keras.utils import tf_utils


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiOptimizer(KerasLegacyOptimizer):
    """Multi Optimizer Wrapper for Discriminative Layer Training.

    Creates a wrapper around a set of instantiated optimizer layer pairs.
    Generally useful for transfer learning of deep networks.

    Each optimizer will optimize only the weights associated with its paired layer.
    This can be used to implement discriminative layer training by assigning
    different learning rates to each optimizer layer pair.
    `(tf.keras.optimizers.legacy.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
    Please note that the layers must be instantiated before instantiating the optimizer.

    Args:
        optimizers_and_layers: a list of tuples of an optimizer and a layer or model.
            Each tuple should contain exactly 1 instantiated optimizer and 1 object that
            subclasses `tf.keras.Model`, `tf.keras.Sequential` or `tf.keras.layers.Layer`.
            Nested layers and models will be automatically discovered.
            Alternatively, in place of a single layer, you can pass a list of layers.
        optimizer_specs: specialized list for serialization.
            Should be left as None for almost all cases.
            If you are loading a serialized version of this optimizer,
            please use `tf.keras.models.load_model` after saving a model compiled with this optimizer.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.Input(shape=(4,)),
    ...     tf.keras.layers.Dense(8),
    ...     tf.keras.layers.Dense(16),
    ...     tf.keras.layers.Dense(32),
    ... ])
    >>> optimizers = [
    ...     tf.keras.optimizers.Adam(learning_rate=1e-4),
    ...     tf.keras.optimizers.Adam(learning_rate=1e-2)
    ... ]
    >>> optimizers_and_layers = [(optimizers[0], model.layers[0]), (optimizers[1], model.layers[1:])]
    >>> optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
    >>> model.compile(optimizer=optimizer, loss="mse")

    Reference:
        - [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
        - [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)

    Note: Currently, `tfa.optimizers.MultiOptimizer` does not support callbacks that modify optimizers.
        However, you can instantiate optimizer layer pairs with
        `tf.keras.optimizers.schedules.LearningRateSchedule`
        instead of a static learning rate.

    This code should function on CPU, GPU, and TPU. Apply with `tf.distribute.Strategy().scope()` context as you
        would with any other optimizer.
    """

    @typechecked
    def __init__(
        self,
        optimizers_and_layers: Union[list, None] = None,
        optimizer_specs: Union[list, None] = None,
        name: str = "MultiOptimizer",
        **kwargs,
    ):

        super(MultiOptimizer, self).__init__(name, **kwargs)
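
        # Two construction paths: build fresh specs from (optimizer, layers)
        # pairs, or reuse pre-built specs when deserializing a saved model.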
        if optimizer_specs is None and optimizers_and_layers is not None:
            self.optimizer_specs = [
                self.create_optimizer_spec(optimizer, layers_or_model)
                for optimizer, layers_or_model in optimizers_and_layers
            ]

        elif optimizer_specs is not None and optimizers_and_layers is None:
            self.optimizer_specs = [
                self.maybe_initialize_optimizer_spec(spec) for spec in optimizer_specs
            ]

        else:
            raise RuntimeError(
                "Must specify one of `optimizers_and_layers` or `optimizer_specs`."
            )

    def apply_gradients(self, grads_and_vars, **kwargs):
        """Wrapped `apply_gradients` method.

        Returns an operation to be executed.
        """
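
        # Rebuild each spec's transient list of (gradient, variable) pairs on
        # every call.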
        for spec in self.optimizer_specs:
            spec["gv"] = []
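
        # Route each incoming (grad, var) pair to every optimizer whose spec
        # lists that variable's name.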
        for grad, var in tuple(grads_and_vars):
            for spec in self.optimizer_specs:
                for name in spec["weights"]:
                    if var.name == name:
                        spec["gv"].append((grad, var))
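
        # Each wrapped optimizer applies only its own slice of the gradients.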
        update_ops = [
            spec["optimizer"].apply_gradients(spec["gv"], **kwargs)
            for spec in self.optimizer_specs
        ]
        update_group = tf.group(update_ops)

        any_symbolic = any(
            isinstance(i, tf.Operation) or tf_utils.is_symbolic_tensor(i)
            for i in update_ops
        )

        if not tf.executing_eagerly() or any_symbolic:
            # If the current context is graph mode or any of the update ops are
            # symbolic then the step update should be carried out under a graph
            # context. (eager updates execute immediately)
            with backend._current_graph(  # pylint: disable=protected-access
                update_ops
            ).as_default():
                with tf.control_dependencies([update_group]):
                    return self.iterations.assign_add(1, read_value=False)

        return self.iterations.assign_add(1)

    def get_config(self):
        config = super(MultiOptimizer, self).get_config()
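        # Serialize only the optimizer and its weight names; the transient
        # "gv" gradient/variable pairs are not serializable.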
        optimizer_specs_without_gv = []
        for optimizer_spec in self.optimizer_specs:
            optimizer_specs_without_gv.append(
                {
                    "optimizer": optimizer_spec["optimizer"],
                    "weights": optimizer_spec["weights"],
                }
            )
        config.update({"optimizer_specs": optimizer_specs_without_gv})
        return config

    @classmethod
    def create_optimizer_spec(
        cls,
        optimizer: KerasLegacyOptimizer,
        layers_or_model: Union[
            tf.keras.Model,
            tf.keras.Sequential,
            tf.keras.layers.Layer,
            List[tf.keras.layers.Layer],
        ],
    ):
        """Creates a serializable optimizer spec.

        The name of each variable is used rather than `var.ref()` to enable serialization and deserialization.
        """
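        # A list of layers is flattened; a single layer or model contributes
        # all of its (possibly nested) weights via `.weights`.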
        if isinstance(layers_or_model, list):
            weights = [
                var.name for sublayer in layers_or_model for var in sublayer.weights
            ]
        else:
            weights = [var.name for var in layers_or_model.weights]

        return {
            "optimizer": optimizer,
            "weights": weights,
        }

    @classmethod
    def maybe_initialize_optimizer_spec(cls, optimizer_spec):
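        # Specs restored from a config still hold the optimizer as a config
        # dict; rebuild the optimizer object before first use.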
        if isinstance(optimizer_spec["optimizer"], dict):
            optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize(
                optimizer_spec["optimizer"]
            )

        return optimizer_spec

    def __repr__(self):
        return "Multi Optimizer with %i optimizer layer pairs" % len(
            self.optimizer_specs
        )