# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import logging

import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class WeightNormalization(tf.keras.layers.Wrapper):
    """Performs weight normalization.

    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem.

    See [Weight Normalization: A Simple Reparameterization to Accelerate
    Training of Deep Neural Networks](https://arxiv.org/abs/1602.07868).

    Wrap `tf.keras.layers.Conv2D`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = WeightNormalization(tf.keras.layers.Conv2D(2, 2), data_init=False)
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `tf.keras.layers.Dense`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = WeightNormalization(tf.keras.layers.Dense(10), data_init=False)
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])
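
    Wrap an RNN such as `tf.keras.layers.LSTM` (an illustrative sketch, not
    from the original docstring; for RNNs the wrapper normalizes the cell's
    recurrent kernel):

    >>> x = np.random.rand(1, 10, 2)
    >>> lstm = WeightNormalization(tf.keras.layers.LSTM(4), data_init=False)
    >>> y = lstm(x)
    >>> y.shape
    TensorShape([1, 4])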

    Args:
        layer: A `tf.keras.layers.Layer` instance.
        data_init: If `True`, use data-dependent variable initialization.

    Raises:
        ValueError: If not initialized with a `Layer` instance.
        ValueError: If `Layer` does not contain a `kernel` of weights.
        NotImplementedError: If `data_init` is True and running graph execution.
    """

    @typechecked
    def __init__(self, layer: tf.keras.layers.Layer, data_init: bool = True, **kwargs):
        super().__init__(layer, **kwargs)
        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)

        if self.data_init and self.is_rnn:
            logging.warning(
                "WeightNormalization: Using `data_init=True` with RNNs "
                "is advised against by the paper. Use `data_init=False`."
            )

    def build(self, input_shape):
        """Build `Layer`."""
        input_shape = tf.TensorShape(input_shape)
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if not self.layer.built:
            self.layer.build(input_shape)

        # RNN layers keep their kernels on the cell rather than on the layer.
        kernel_layer = self.layer.cell if self.is_rnn else self.layer

        if not hasattr(kernel_layer, "kernel"):
            raise ValueError(
                "`WeightNormalization` must wrap a layer that"
                " contains a `kernel` for weights"
            )

        if self.is_rnn:
            kernel = kernel_layer.recurrent_kernel
        else:
            kernel = kernel_layer.kernel

        # The kernel's filter or unit dimension is -1; normalize over all
        # remaining axes.
        self.layer_depth = int(kernel.shape[-1])
        self.kernel_norm_axes = list(range(kernel.shape.rank - 1))

        # `g` holds one magnitude per output unit; `v` aliases the wrapped
        # layer's kernel variable and keeps its direction.
        self.g = self.add_weight(
            name="g",
            shape=(self.layer_depth,),
            initializer="ones",
            dtype=kernel.dtype,
            trainable=True,
        )
        self.v = kernel

        # Scalar flag flipped to True once `g` (and optionally the bias)
        # has been initialized on the first call.
        self._initialized = self.add_weight(
            name="initialized",
            shape=None,
            initializer="zeros",
            dtype=tf.dtypes.bool,
            trainable=False,
        )

        if self.data_init:
            # Used for data initialization in self._data_dep_init.
            # The clone shares the raw weights but is frozen and (for non-RNN
            # layers) stripped of its activation, so it exposes pre-activation
            # statistics to the data-dependent init.
            with tf.name_scope("data_dep_init"):
                layer_config = tf.keras.layers.serialize(self.layer)
                layer_config["config"]["trainable"] = False
                self._naked_clone_layer = tf.keras.layers.deserialize(layer_config)
                self._naked_clone_layer.build(input_shape)
                self._naked_clone_layer.set_weights(self.layer.get_weights())
                if not self.is_rnn:
                    self._naked_clone_layer.activation = None

        self.built = True

    def call(self, inputs):
        """Call `Layer`."""

        def _do_nothing():
            return tf.identity(self.g)

        def _update_weights():
            # Ensure we read `self.g` after _update_weights.
            with tf.control_dependencies(self._initialize_weights(inputs)):
                return tf.identity(self.g)

        # Initialize `g` (and optionally the bias) exactly once; on later
        # calls the branch reduces to a read of the current `g`.
        g = tf.cond(self._initialized, _do_nothing, _update_weights)

        with tf.name_scope("compute_weights"):
            # Replace the kernel with the normalized weight variable:
            # w = g * v / ||v||.
            kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * g

            if self.is_rnn:
                self.layer.cell.recurrent_kernel = kernel
                update_kernel = tf.identity(self.layer.cell.recurrent_kernel)
            else:
                self.layer.kernel = kernel
                update_kernel = tf.identity(self.layer.kernel)

            # Ensure we calculate result after updating kernel.
            with tf.control_dependencies([update_kernel]):
                outputs = self.layer(inputs)
                return outputs

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

    def _initialize_weights(self, inputs):
        """Initialize weight `g`.

        The initial value of `g` is taken either from the norm of `v`,
        or from the input values if `self.data_init` is True.
        """
        with tf.control_dependencies(
            [
                tf.debugging.assert_equal(  # pylint: disable=bad-continuation
                    self._initialized, False, message="The layer has been initialized."
                )
            ]
        ):
            if self.data_init:
                assign_tensors = self._data_dep_init(inputs)
            else:
                assign_tensors = self._init_norm()
            assign_tensors.append(self._initialized.assign(True))
            return assign_tensors

    def _init_norm(self):
        """Set the weight `g` to the norm of the weight vector.

        With g = ||v||, the effective kernel g * v / ||v|| initially
        equals the original kernel, so the forward pass is unchanged.
        """
        with tf.name_scope("init_norm"):
            v_flat = tf.reshape(self.v, [-1, self.layer_depth])
            v_norm = tf.linalg.norm(v_flat, axis=0)
            g_tensor = self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
            return [g_tensor]

    def _data_dep_init(self, inputs):
        """Data dependent initialization.

        Following the paper: given the pre-activation mean `m_init` and
        variance `v_init` on the first batch, set g <- g / sqrt(v_init)
        and bias <- -m_init / sqrt(v_init), so that initial outputs are
        approximately standardized.
        """
        with tf.name_scope("data_dep_init"):
            # Generate data dependent init values.
            x_init = self._naked_clone_layer(inputs)
            data_norm_axes = list(range(x_init.shape.rank - 1))
            m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
            scale_init = 1.0 / tf.math.sqrt(v_init + 1e-10)

            # RNNs have fused kernels that are tiled.
            # Repeat scale_init to match the shape of the fused kernel.
            # Note: this is only to support the operation;
            # the paper advises against RNN + data-dependent init.
            if scale_init.shape[0] != self.g.shape[0]:
                rep = int(self.g.shape[0] / scale_init.shape[0])
                scale_init = tf.tile(scale_init, [rep])

            # Assign data dependent init values.
            g_tensor = self.g.assign(self.g * scale_init)
            if hasattr(self.layer, "bias") and self.layer.bias is not None:
                bias_tensor = self.layer.bias.assign(-m_init * scale_init)
                return [g_tensor, bias_tensor]
            else:
                return [g_tensor]

    def get_config(self):
        config = {"data_init": self.data_init}
        base_config = super().get_config()
        return {**base_config, **config}

    def remove(self):
        """Fold the learned normalization into a plain kernel variable and
        return the wrapped layer."""
        kernel = tf.Variable(
            tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * self.g,
            name="recurrent_kernel" if self.is_rnn else "kernel",
        )

        if self.is_rnn:
            self.layer.cell.recurrent_kernel = kernel
        else:
            self.layer.kernel = kernel

        return self.layer
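

# A minimal usage sketch (not part of the original module; names and shapes
# here are illustrative). Train with the wrapper, then call `remove()` to
# fold the learned normalization back into a plain kernel for inference.
if __name__ == "__main__":
    import numpy as np

    x = np.random.rand(4, 16).astype("float32")
    wn_dense = WeightNormalization(tf.keras.layers.Dense(8), data_init=False)
    y = wn_dense(x)  # first call initializes `g`, then applies the layer
    plain_dense = wn_dense.remove()  # plain Dense carrying g * v / ||v||
    print(y.shape, plain_dense(x).shape)  # (4, 8) (4, 8)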