# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer Normalization layer."""

import tensorflow.compat.v2 as tf

from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.dtensor import utils
from keras.src.engine.base_layer import Layer
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.LayerNormalization")
class LayerNormalization(Layer):
    """Layer normalization layer (Ba et al., 2016).

    Normalize the activations of the previous layer for each given example in a
    batch independently, rather than across a batch like Batch Normalization.
    That is, it applies a transformation that maintains the mean activation
    within each example close to 0 and the activation standard deviation close
    to 1.

    Given a tensor `inputs`, moments are calculated and normalization
    is performed across the axes specified in `axis`.

    Example:

    >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
    >>> print(data)
    tf.Tensor(
    [[ 0. 10.]
     [20. 30.]
     [40. 50.]
     [60. 70.]
     [80. 90.]], shape=(5, 2), dtype=float32)

    >>> layer = tf.keras.layers.LayerNormalization(axis=1)
    >>> output = layer(data)
    >>> print(output)
    tf.Tensor(
    [[-1. 1.]
     [-1. 1.]
     [-1. 1.]
     [-1. 1.]
     [-1. 1.]], shape=(5, 2), dtype=float32)

    Notice that with Layer Normalization the normalization happens across the
    axes *within* each example, rather than across different examples in the
    batch.

    If `scale` or `center` are enabled, the layer will scale the normalized
    outputs by broadcasting them with a trainable variable `gamma`, and center
    the outputs by broadcasting with a trainable variable `beta`. `gamma` will
    default to a ones tensor and `beta` will default to a zeros tensor, so that
    centering and scaling are no-ops before training has begun.

    So, with scaling and centering enabled the normalization equations
    are as follows:

    Let the intermediate activations for a mini-batch be the `inputs`.

    For each sample `x_i` in `inputs` with `k` features, we compute the mean
    and variance of the sample:

    ```python
    mean_i = sum(x_i[j] for j in range(k)) / k
    var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
    ```

    and then compute a normalized `x_i_normalized`, including a small factor
    `epsilon` for numerical stability.

    ```python
    x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
    ```

    And finally `x_i_normalized` is linearly transformed by `gamma` and `beta`,
    which are learned parameters:

    ```python
    output_i = x_i_normalized * gamma + beta
    ```
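
    As a quick worked check (a sketch that ignores the small `epsilon`), take
    the first row `[0., 10.]` of the example above:

    ```python
    mean_i = (0. + 10.) / 2                         # 5.0
    var_i = ((0. - 5.) ** 2 + (10. - 5.) ** 2) / 2  # 25.0
    (0. - 5.) / 25. ** 0.5                          # -1.0
    (10. - 5.) / 25. ** 0.5                         #  1.0
    ```

    which matches the `[-1. 1.]` rows printed by the layer above.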

    `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
    this part of the inputs' shape must be fully defined.

    For example:

    >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
    >>> layer.build([5, 20, 30, 40])
    >>> print(layer.beta.shape)
    (20, 30, 40)
    >>> print(layer.gamma.shape)
    (20, 30, 40)

    Note that other implementations of layer normalization may choose to define
    `gamma` and `beta` over a separate set of axes from the axes being
    normalized across. For example, Group Normalization
    ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
    corresponds to a Layer Normalization that normalizes across height, width,
    and channel and has `gamma` and `beta` span only the channel dimension.
    So, this Layer Normalization implementation will not match a Group
    Normalization layer with group size set to 1.

    Args:
      axis: Integer or List/Tuple. The axis or axes to normalize across.
        Typically this is the features axis/axes. The left-out axes are
        typically the batch axis/axes. This argument defaults to `-1`, the last
        dimension in the input.
      epsilon: Small float added to variance to avoid dividing by zero.
        Defaults to 1e-3.
      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored. Defaults to True.
      scale: If True, multiply by `gamma`. If False, `gamma` is not used.
        Defaults to True. When the next layer is linear (this also applies to
        e.g. `nn.relu`), this can be disabled since the scaling will be done by
        the next layer.
      beta_initializer: Initializer for the beta weight. Defaults to zeros.
      gamma_initializer: Initializer for the gamma weight. Defaults to ones.
      beta_regularizer: Optional regularizer for the beta weight. None by
        default.
      gamma_regularizer: Optional regularizer for the gamma weight. None by
        default.
      beta_constraint: Optional constraint for the beta weight. None by
        default.
      gamma_constraint: Optional constraint for the gamma weight. None by
        default.

    Input shape:
      Arbitrary. Use the keyword argument `input_shape` (tuple of
      integers, does not include the samples axis) when using this layer as the
      first layer in a model.

    Output shape:
      Same shape as input.

    Reference:
      - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
    """

    @utils.allow_initializer_layout
    def __init__(
        self,
        axis=-1,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        if isinstance(axis, (list, tuple)):
            self.axis = list(axis)
        elif isinstance(axis, int):
            self.axis = axis
        else:
            raise TypeError(
                "Expected an int or a list/tuple of ints for the "
                "argument 'axis', but received: %r" % axis
            )

        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

        self.supports_masking = True

        # Indicates whether a faster fused implementation can be used. This
        # will be set to True or False in build().
        self._fused = None

    def _fused_can_be_used(self, ndims):
        """Returns false if fused implementation cannot be used.

        Check if the axis is contiguous and can be collapsed into the last
        axis. The self.axis is assumed to have no duplicates.
        """
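        # For example, with ndims=4, axis=[1, 2, 3] is contiguous and ends at
        # the last axis, so it can be collapsed into a single trailing axis;
        # axis=[1, 2] or axis=[0, 2, 3] would fail the contiguity check below.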
        axis = sorted(self.axis)
        can_use_fused = False

        if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
            can_use_fused = True

        # fused_batch_norm will silently raise epsilon to be at least 1.001e-5,
        # so we cannot use the fused version if epsilon is below that value.
        # Also, the variable dtype must be float32, as fused_batch_norm only
        # supports float32 variables.
        if self.epsilon < 1.001e-5 or self.dtype != "float32":
            can_use_fused = False

        return can_use_fused

    def build(self, input_shape):
        self.axis = tf_utils.validate_axis(self.axis, input_shape)
        input_shape = tf.TensorShape(input_shape)
        rank = input_shape.rank

        param_shape = [input_shape[dim] for dim in self.axis]
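        # gamma and beta span exactly the normalized axes; e.g. axis=[1, 2, 3]
        # on an input of shape [5, 20, 30, 40] gives param_shape [20, 30, 40],
        # matching the docstring example above.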
        if self.scale:
            self.gamma = self.add_weight(
                name="gamma",
                shape=param_shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
                experimental_autocast=False,
            )
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                name="beta",
                shape=param_shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
                experimental_autocast=False,
            )
        else:
            self.beta = None

        self._fused = self._fused_can_be_used(rank)
        self.built = True

    def call(self, inputs):
        # TODO(b/229545225): Remove the RaggedTensor check.
        is_ragged = isinstance(inputs, tf.RaggedTensor)
        if is_ragged:
            inputs_lengths = inputs.nested_row_lengths()
            inputs = inputs.to_tensor()
        inputs = tf.cast(inputs, self.compute_dtype)
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)

        # Broadcasting only necessary for norm when the axis is not just
        # the last dimension
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape.dims[dim].value
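
        # For example, ndims=3 with axis=[1] and a static middle dimension H
        # gives broadcast_shape [1, H, 1], which _broadcast below uses to
        # reshape gamma/beta so they broadcast against the inputs.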

        def _broadcast(v):
            if (
                v is not None
                and len(v.shape) != ndims
                and self.axis != [ndims - 1]
            ):
                return tf.reshape(v, broadcast_shape)
            return v

        if not self._fused:
            input_dtype = inputs.dtype
            if (
                input_dtype in ("float16", "bfloat16")
                and self.dtype == "float32"
            ):
                # If mixed precision is used, cast inputs to float32 so that
                # this is at least as numerically stable as the fused version.
                inputs = tf.cast(inputs, "float32")

            # Calculate the moments over the normalization axes (the layer
            # activations).
            mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            # Compute layer normalization using the batch_normalization
            # function.
            outputs = tf.nn.batch_normalization(
                inputs,
                mean,
                variance,
                offset=offset,
                scale=scale,
                variance_epsilon=self.epsilon,
            )
            outputs = tf.cast(outputs, input_dtype)
        else:
            # Collapse dims before self.axis, and dims in self.axis
            pre_dim, in_dim = (1, 1)
            axis = sorted(self.axis)
            tensor_shape = tf.shape(inputs)
            for dim in range(0, ndims):
                dim_tensor = tensor_shape[dim]
                if dim < axis[0]:
                    pre_dim = pre_dim * dim_tensor
                else:
                    assert dim in axis
                    in_dim = in_dim * dim_tensor
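
            # For example, inputs of shape [N, H, W, C] with axis=[1, 2, 3]
            # collapse to pre_dim = N and in_dim = H * W * C, which are then
            # packed into the NCHW-shaped tensor below.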
            squeezed_shape = [1, pre_dim, in_dim, 1]
            # This fused operation requires reshaped inputs to be NCHW.
            data_format = "NCHW"

            inputs = tf.reshape(inputs, squeezed_shape)

            # self.gamma and self.beta have the wrong shape for
            # fused_batch_norm, so we cannot pass them as the scale and offset
            # parameters. Therefore, we create two constant tensors in correct
            # shapes for fused_batch_norm and later construct a separate
            # calculation on the scale and offset.
            scale = tf.ones([pre_dim], dtype=self.dtype)
            offset = tf.zeros([pre_dim], dtype=self.dtype)

            # Compute layer normalization using the fused_batch_norm function.
            outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
                inputs,
                scale=scale,
                offset=offset,
                epsilon=self.epsilon,
                data_format=data_format,
            )

            outputs = tf.reshape(outputs, tensor_shape)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            if scale is not None:
                outputs = outputs * tf.cast(scale, outputs.dtype)
            if offset is not None:
                outputs = outputs + tf.cast(offset, outputs.dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if is_ragged:
            outputs = tf.RaggedTensor.from_tensor(outputs, inputs_lengths)
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(
                self.gamma_initializer
            ),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(
                self.gamma_regularizer
            ),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
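

# A minimal usage sketch, mirroring the docstring example above: each row of a
# (5, 2) tensor is normalized over axis=1, and the layer can be round-tripped
# through get_config()/from_config().
#
#     import numpy as np
#     import tensorflow as tf
#
#     data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
#     layer = tf.keras.layers.LayerNormalization(axis=1)
#     output = layer(data)          # each row becomes approximately [-1., 1.]
#     config = layer.get_config()
#     restored = tf.keras.layers.LayerNormalization.from_config(config)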