# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Discriminative Layer Training Optimizer for TensorFlow."""

from typing import List, Union

import tensorflow as tf

from packaging.version import Version
from tensorflow_addons.optimizers import KerasLegacyOptimizer
from typeguard import typechecked

if Version(tf.__version__).release >= Version("2.16").release:
    # Determine if loading keras 2 or 3.
    if (
        hasattr(tf.keras, "version")
        and Version(tf.keras.version()).release >= Version("3.0").release
    ):
        # New versions of Keras require importing from `keras.src` when
        # importing internal symbols.
        from keras.src import backend
        from keras.src.utils import tf_utils
    else:
        from tf_keras.src import backend
        from tf_keras.src.utils import tf_utils
elif Version(tf.__version__).release >= Version("2.13").release:
    from keras.src import backend
    from keras.src.utils import tf_utils
else:
    from keras import backend
    from keras.utils import tf_utils



@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiOptimizer(KerasLegacyOptimizer):
    """Multi Optimizer Wrapper for Discriminative Layer Training.

    Creates a wrapper around a set of instantiated optimizer layer pairs.
    Generally useful for transfer learning of deep networks.

    Each optimizer will optimize only the weights associated with its paired layer.
    This can be used to implement discriminative layer training by assigning
    different learning rates to each optimizer layer pair.
    `(tf.keras.optimizers.legacy.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
    Please note that the layers must be instantiated before instantiating the optimizer.

    Args:
        optimizers_and_layers: a list of tuples of an optimizer and a layer or model.
            Each tuple should contain exactly 1 instantiated optimizer and 1 object that
            subclasses `tf.keras.Model`, `tf.keras.Sequential` or `tf.keras.layers.Layer`.
            Nested layers and models will be automatically discovered.
            Alternatively, in place of a single layer, you can pass a list of layers.
        optimizer_specs: specialized list for serialization.
            Should be left as None for almost all cases.
            If you are loading a serialized version of this optimizer,
            please use `tf.keras.models.load_model` after saving a model compiled with this optimizer.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.Input(shape=(4,)),
    ...     tf.keras.layers.Dense(8),
    ...     tf.keras.layers.Dense(16),
    ...     tf.keras.layers.Dense(32),
    ... ])
    >>> optimizers = [
    ...     tf.keras.optimizers.Adam(learning_rate=1e-4),
    ...     tf.keras.optimizers.Adam(learning_rate=1e-2)
    ... ]
    >>> optimizers_and_layers = [(optimizers[0], model.layers[0]), (optimizers[1], model.layers[1:])]
    >>> optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
    >>> model.compile(optimizer=optimizer, loss="mse")

    Reference:
        - [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
        - [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)

    Note: Currently, `tfa.optimizers.MultiOptimizer` does not support callbacks that modify optimizers.
        However, you can instantiate optimizer layer pairs with
        `tf.keras.optimizers.schedules.LearningRateSchedule`
        instead of a static learning rate.
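
        For example, a minimal sketch using `ExponentialDecay` (just one of the
        built-in `tf.keras.optimizers.schedules` classes):

        >>> lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        ...     initial_learning_rate=1e-2, decay_steps=1000, decay_rate=0.9)
        >>> scheduled_opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)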


    This code should function on CPU, GPU, and TPU. Apply it within a
    `tf.distribute.Strategy` scope as you would with any other optimizer.
    """


    @typechecked
    def __init__(
        self,
        optimizers_and_layers: Union[list, None] = None,
        optimizer_specs: Union[list, None] = None,
        name: str = "MultiOptimizer",
        **kwargs,
    ):

        super(MultiOptimizer, self).__init__(name, **kwargs)

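        # Two construction paths: build specs from (optimizer, layers) pairs,
        # or accept pre-built specs (used when restoring from a serialized config).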

        if optimizer_specs is None and optimizers_and_layers is not None:
            self.optimizer_specs = [
                self.create_optimizer_spec(optimizer, layers_or_model)
                for optimizer, layers_or_model in optimizers_and_layers
            ]

        elif optimizer_specs is not None and optimizers_and_layers is None:
            self.optimizer_specs = [
                self.maybe_initialize_optimizer_spec(spec) for spec in optimizer_specs
            ]

        else:
            raise RuntimeError(
                "Must specify exactly one of `optimizers_and_layers` or `optimizer_specs`."
            )


    def apply_gradients(self, grads_and_vars, **kwargs):
        """Wrapped `apply_gradients` method.

        Returns an operation to be executed.
        """

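        # Route each (gradient, variable) pair to the optimizer that owns the
        # variable, matching on variable name (recorded in the spec's "weights").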

        for spec in self.optimizer_specs:
            spec["gv"] = []

        for grad, var in tuple(grads_and_vars):
            for spec in self.optimizer_specs:
                for name in spec["weights"]:
                    if var.name == name:
                        spec["gv"].append((grad, var))


        update_ops = [
            spec["optimizer"].apply_gradients(spec["gv"], **kwargs)
            for spec in self.optimizer_specs
        ]
        update_group = tf.group(update_ops)

        any_symbolic = any(
            isinstance(i, tf.Operation) or tf_utils.is_symbolic_tensor(i)
            for i in update_ops
        )

        if not tf.executing_eagerly() or any_symbolic:
            # If the current context is graph mode or any of the update ops are
            # symbolic, then the step update should be carried out under a graph
            # context. (Eager updates execute immediately.)
            with backend._current_graph(  # pylint: disable=protected-access
                update_ops
            ).as_default():
                with tf.control_dependencies([update_group]):
                    return self.iterations.assign_add(1, read_value=False)

        return self.iterations.assign_add(1)


    def get_config(self):
        config = super(MultiOptimizer, self).get_config()
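        # Serialize only each optimizer and its weight names; the transient
        # per-step "gv" gradient/variable lists are dropped.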

        optimizer_specs_without_gv = []
        for optimizer_spec in self.optimizer_specs:
            optimizer_specs_without_gv.append(
                {
                    "optimizer": optimizer_spec["optimizer"],
                    "weights": optimizer_spec["weights"],
                }
            )
        config.update({"optimizer_specs": optimizer_specs_without_gv})
        return config


    @classmethod
    def create_optimizer_spec(
        cls,
        optimizer: KerasLegacyOptimizer,
        layers_or_model: Union[
            tf.keras.Model,
            tf.keras.Sequential,
            tf.keras.layers.Layer,
            List[tf.keras.layers.Layer],
        ],
    ):
        """Creates a serializable optimizer spec.

        The name of each variable is used rather than `var.ref()` to enable serialization and deserialization.
        """
        if isinstance(layers_or_model, list):
            weights = [
                var.name for sublayer in layers_or_model for var in sublayer.weights
            ]
        else:
            weights = [var.name for var in layers_or_model.weights]

        return {
            "optimizer": optimizer,
            "weights": weights,
        }


    @classmethod
    def maybe_initialize_optimizer_spec(cls, optimizer_spec):
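        # Specs restored from a saved config carry the optimizer as a config
        # dict; rebuild it into an optimizer instance before use.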

        if isinstance(optimizer_spec["optimizer"], dict):
            optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize(
                optimizer_spec["optimizer"]
            )

        return optimizer_spec


    def __repr__(self):
        return "Multi Optimizer with %i optimizer layer pairs" % len(
            self.optimizer_specs
        )