# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import logging

import numpy as np
import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class WeightNormalization(tf.keras.layers.Wrapper):
    """Performs weight normalization.

    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem.

    See [Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks](https://arxiv.org/abs/1602.07868).
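
    The reparameterization is `w = g * v / ||v||`: `v` carries the direction
    of the original weight tensor and the trainable vector `g` carries one
    magnitude per output unit.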

    Wrap `tf.keras.layers.Conv2D`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = WeightNormalization(tf.keras.layers.Conv2D(2, 2), data_init=False)
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `tf.keras.layers.Dense`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = WeightNormalization(tf.keras.layers.Dense(10), data_init=False)
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])
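
    After training, the reparameterization can be folded back into a plain
    kernel with `remove()` (a minimal sketch):

    >>> wn = WeightNormalization(tf.keras.layers.Dense(3), data_init=False)
    >>> _ = wn(np.random.rand(2, 4))
    >>> type(wn.remove()).__name__
    'Dense'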

    Args:
        layer: A `tf.keras.layers.Layer` instance.
        data_init: If `True`, use data-dependent variable initialization.

    Raises:
        ValueError: If not initialized with a `Layer` instance.
        ValueError: If `Layer` does not contain a `kernel` of weights.
        NotImplementedError: If `data_init` is `True` and running graph execution.
    """

    @typechecked
    def __init__(self, layer: tf.keras.layers.Layer, data_init: bool = True, **kwargs):
        super().__init__(layer, **kwargs)
        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)

        if self.data_init and self.is_rnn:
            logging.warning(
                "WeightNormalization: Using `data_init=True` with RNNs "
                "is advised against by the paper. Use `data_init=False`."
            )

    def build(self, input_shape):
        """Build `Layer`"""
        input_shape = tf.TensorShape(input_shape)
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if not self.layer.built:
            self.layer.build(input_shape)

        kernel_layer = self.layer.cell if self.is_rnn else self.layer

        if not hasattr(kernel_layer, "kernel"):
            raise ValueError(
                "`WeightNormalization` must wrap a layer that"
                " contains a `kernel` for weights"
            )

        if self.is_rnn:
            kernel = kernel_layer.recurrent_kernel
        else:
            kernel = kernel_layer.kernel

        # The kernel's filter or unit dimension is -1
        self.layer_depth = int(kernel.shape[-1])
        self.kernel_norm_axes = list(range(kernel.shape.rank - 1))
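
        # One trainable scale `g` per output unit; `v` aliases the wrapped
        # layer's unnormalized kernel, so the effective weight used in `call`
        # is w = g * v / ||v||.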
        self.g = self.add_weight(
            name="g",
            shape=(self.layer_depth,),
            initializer="ones",
            dtype=kernel.dtype,
            trainable=True,
        )
        self.v = kernel
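
        # Boolean variable (saved with the layer's weights) recording whether
        # the one-time initialization of `g` has already run.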
        self._initialized = self.add_weight(
            name="initialized",
            shape=None,
            initializer="zeros",
            dtype=tf.dtypes.bool,
            trainable=False,
        )

        if self.data_init:
            # Used for data initialization in self._data_dep_init.
            with tf.name_scope("data_dep_init"):
                layer_config = tf.keras.layers.serialize(self.layer)
                layer_config["config"]["trainable"] = False
                self._naked_clone_layer = tf.keras.layers.deserialize(layer_config)
                self._naked_clone_layer.build(input_shape)
                self._naked_clone_layer.set_weights(self.layer.get_weights())
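                # Data-dependent init uses pre-activation statistics, so the
                # clone's activation is removed (RNN layers are left as-is).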
                if not self.is_rnn:
                    self._naked_clone_layer.activation = None

        self.built = True

    def call(self, inputs):
        """Call `Layer`"""

        def _do_nothing():
            return tf.identity(self.g)

        def _update_weights():
            # Ensure we read `self.g` only after the initialization ops run.
            with tf.control_dependencies(self._initialize_weights(inputs)):
                return tf.identity(self.g)
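
        # First call: run the (possibly data-dependent) initialization and
        # flip `self._initialized`; every later call is a no-op read of `g`.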
        g = tf.cond(self._initialized, _do_nothing, _update_weights)

        with tf.name_scope("compute_weights"):
            # Replace kernel by normalized weight variable.
            kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * g

            if self.is_rnn:
                self.layer.cell.recurrent_kernel = kernel
                update_kernel = tf.identity(self.layer.cell.recurrent_kernel)
            else:
                self.layer.kernel = kernel
                update_kernel = tf.identity(self.layer.kernel)

            # Ensure we calculate result after updating kernel.
            with tf.control_dependencies([update_kernel]):
                outputs = self.layer(inputs)
                return outputs

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

    def _initialize_weights(self, inputs):
        """Initialize weight g.

        The initial value of g is computed either from the norm of v, or
        from the input data if self.data_init is True.
        """
        with tf.control_dependencies(
            [
                tf.debugging.assert_equal(  # pylint: disable=bad-continuation
                    self._initialized, False, message="The layer has been initialized."
                )
            ]
        ):
            if self.data_init:
                assign_tensors = self._data_dep_init(inputs)
            else:
                assign_tensors = self._init_norm()
            assign_tensors.append(self._initialized.assign(True))
            return assign_tensors

    def _init_norm(self):
        """Set the weight g with the norm of the weight vector."""
        with tf.name_scope("init_norm"):
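            # g <- ||v|| per output unit, so the effective kernel
            # g * v / ||v|| initially equals the wrapped layer's kernel.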
            v_flat = tf.reshape(self.v, [-1, self.layer_depth])
            v_norm = tf.linalg.norm(v_flat, axis=0)
            g_tensor = self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
            return [g_tensor]

    def _data_dep_init(self, inputs):
        """Data dependent initialization."""
        with tf.name_scope("data_dep_init"):
            # Generate data dependent init values
            x_init = self._naked_clone_layer(inputs)
            data_norm_axes = list(range(x_init.shape.rank - 1))
            m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
            scale_init = 1.0 / tf.math.sqrt(v_init + 1e-10)

            # RNN cells use fused kernels, so tile scale_init to match the
            # fused kernel's shape. This only keeps the op working; the paper
            # advises against combining RNNs with data-dependent init.
            if scale_init.shape[0] != self.g.shape[0]:
                rep = int(self.g.shape[0] / scale_init.shape[0])
                scale_init = tf.tile(scale_init, [rep])
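
            # Following the paper's scheme: g is scaled by 1/std and the bias
            # set to -mean/std of the pre-activations computed above.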
            # Assign data dependent init values
            g_tensor = self.g.assign(self.g * scale_init)
            if hasattr(self.layer, "bias") and self.layer.bias is not None:
                bias_tensor = self.layer.bias.assign(-m_init * scale_init)
                return [g_tensor, bias_tensor]
            else:
                return [g_tensor]

    def get_config(self):
        config = {"data_init": self.data_init}
        base_config = super().get_config()
        return {**base_config, **config}

    def remove(self):
        """Remove the weight-normalization reparameterization.

        Folds `g * v / ||v||` into a plain kernel variable and returns the
        wrapped layer.
        """
        kernel = tf.Variable(
            tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * self.g,
            name="recurrent_kernel" if self.is_rnn else "kernel",
        )

        if self.is_rnn:
            self.layer.cell.recurrent_kernel = kernel
        else:
            self.layer.kernel = kernel

        return self.layer