Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/normalizations.py: 19% (192 statements). Generated by coverage.py v7.4.0 at 2024-01-03 07:57 +0000.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Original implementation from keras_contrib/layer/normalization
# =============================================================================

import logging

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package="Addons")
class GroupNormalization(tf.keras.layers.Layer):
    """Group normalization layer.

    Source: "Group Normalization" (Yuxin Wu & Kaiming He, 2018)
    https://arxiv.org/abs/1803.08494

    Group Normalization divides the channels into groups and computes,
    within each group, the mean and variance for normalization.
    Empirically, its accuracy is more stable than batch normalization
    across a wide range of small batch sizes, provided the learning rate
    is adjusted linearly with the batch size.

    Relation to Layer Normalization:
    If the number of groups is set to 1, this operation becomes identical
    to Layer Normalization.

    Relation to Instance Normalization:
    If the number of groups is set to the input dimension (one group per
    channel), this operation becomes identical to Instance Normalization.

    Args:
        groups: Integer, the number of groups for Group Normalization.
            Can be in the range [1, N] where N is the input dimension.
            The input dimension must be divisible by the number of groups.
            Defaults to 32.
        axis: Integer, the axis that should be normalized.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    Output shape:
        Same shape as input.
    """

    @typechecked
    def __init__(
        self,
        groups: int = 32,
        axis: int = -1,
        epsilon: float = 1e-3,
        center: bool = True,
        scale: bool = True,
        beta_initializer: types.Initializer = "zeros",
        gamma_initializer: types.Initializer = "ones",
        beta_regularizer: types.Regularizer = None,
        gamma_regularizer: types.Regularizer = None,
        beta_constraint: types.Constraint = None,
        gamma_constraint: types.Constraint = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        self._check_axis()

    def build(self, input_shape):
        self._check_if_input_shape_is_none(input_shape)
        self._set_number_of_groups_for_instance_norm(input_shape)
        self._check_size_of_dimensions(input_shape)
        self._create_input_spec(input_shape)

        self._add_gamma_weight(input_shape)
        self._add_beta_weight(input_shape)
        self.built = True
        super().build(input_shape)

    def call(self, inputs):
        input_shape = tf.keras.backend.int_shape(inputs)
        tensor_input_shape = tf.shape(inputs)

        reshaped_inputs, group_shape = self._reshape_into_groups(
            inputs, input_shape, tensor_input_shape
        )

        normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)

        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            outputs = tf.reshape(normalized_inputs, tensor_input_shape)
        else:
            outputs = normalized_inputs

        return outputs

    def get_config(self):
        config = {
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
            "gamma_initializer": tf.keras.initializers.serialize(
                self.gamma_initializer
            ),
            "beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": tf.keras.regularizers.serialize(
                self.gamma_regularizer
            ),
            "beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
            "gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
        group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            group_shape[self.axis] = input_shape[self.axis] // self.groups
            group_shape.insert(self.axis, self.groups)
            group_shape = tf.stack(group_shape)
            reshaped_inputs = tf.reshape(inputs, group_shape)
            return reshaped_inputs, group_shape
        else:
            return inputs, group_shape
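
    # Shape sketch (illustrative comment, not in the original source): for a
    # channels-last input of shape (N, H, W, C) with G groups, the reshape
    # above yields (N, H, W, G, C // G), so statistics can be computed per
    # (sample, group) pair. When C // G == 1 the layer behaves as instance
    # norm and skips the reshape entirely.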

    def _apply_normalization(self, reshaped_inputs, input_shape):
        group_shape = tf.keras.backend.int_shape(reshaped_inputs)
        group_reduction_axes = list(range(1, len(group_shape)))
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        # Reduce over every non-batch axis except the group axis, so the
        # mean and variance are computed separately for each group.
        if not is_instance_norm:
            axis = -2 if self.axis == -1 else self.axis - 1
        else:
            axis = -1 if self.axis == -1 else self.axis - 1
        group_reduction_axes.pop(axis)

        mean, variance = tf.nn.moments(
            reshaped_inputs, group_reduction_axes, keepdims=True
        )

        gamma, beta = self._get_reshaped_weights(input_shape)
        normalized_inputs = tf.nn.batch_normalization(
            reshaped_inputs,
            mean=mean,
            variance=variance,
            scale=gamma,
            offset=beta,
            variance_epsilon=self.epsilon,
        )
        return normalized_inputs

    def _get_reshaped_weights(self, input_shape):
        broadcast_shape = self._create_broadcast_shape(input_shape)
        gamma = None
        beta = None
        if self.scale:
            gamma = tf.reshape(self.gamma, broadcast_shape)

        if self.center:
            beta = tf.reshape(self.beta, broadcast_shape)
        return gamma, beta

    def _check_if_input_shape_is_none(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError(
                f"Axis {self.axis} of input tensor should have a defined "
                f"dimension but the layer received an input with shape "
                f"{input_shape}."
            )

    def _set_number_of_groups_for_instance_norm(self, input_shape):
        dim = input_shape[self.axis]

        if self.groups == -1:
            self.groups = dim

    def _check_size_of_dimensions(self, input_shape):
        dim = input_shape[self.axis]
        if dim < self.groups:
            raise ValueError(
                f"Number of groups ({self.groups}) cannot be more than the "
                f"number of channels ({dim})."
            )

        if dim % self.groups != 0:
            raise ValueError(
                f"Number of channels ({dim}) must be a multiple of the "
                f"number of groups ({self.groups})."
            )

    def _check_axis(self):
        if self.axis == 0:
            raise ValueError(
                "You are trying to normalize your batch axis. Do you want to "
                "use tf.keras.layers.BatchNormalization instead?"
            )

    def _create_input_spec(self, input_shape):
        dim = input_shape[self.axis]
        self.input_spec = tf.keras.layers.InputSpec(
            ndim=len(input_shape), axes={self.axis: dim}
        )

    def _add_gamma_weight(self, input_shape):
        dim = input_shape[self.axis]
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name="gamma",
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
            )
        else:
            self.gamma = None

    def _add_beta_weight(self, input_shape):
        dim = input_shape[self.axis]
        shape = (dim,)

        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name="beta",
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
            )
        else:
            self.beta = None

    def _create_broadcast_shape(self, input_shape):
        broadcast_shape = [1] * len(input_shape)
        is_instance_norm = (input_shape[self.axis] // self.groups) == 1
        if not is_instance_norm:
            broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
            broadcast_shape.insert(self.axis, self.groups)
        else:
            broadcast_shape[self.axis] = self.groups
        return broadcast_shape
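

# Usage sketch (illustrative, not part of the original file): a minimal,
# hedged example of GroupNormalization on a channels-last tensor. The layer
# and argument names come from this module; the shapes are assumptions.
#
#     import tensorflow as tf
#     from tensorflow_addons.layers import GroupNormalization
#
#     x = tf.random.normal([8, 32, 32, 64])        # (N, H, W, C)
#     gn = GroupNormalization(groups=16, axis=-1)  # 64 channels / 16 groups
#     y = gn(x)                                    # same shape as x
#     assert y.shape == x.shape
#
# With groups=1 the layer matches Layer Normalization; with groups=-1 (or
# groups equal to the channel count) it matches Instance Normalization.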


@tf.keras.utils.register_keras_serializable(package="Addons")
class InstanceNormalization(GroupNormalization):
    """Instance normalization layer.

    Instance Normalization is a specific case of `GroupNormalization`, since
    it normalizes all features of one channel. The group size is equal to the
    channel size. Empirically, its accuracy is more stable than batch
    normalization across a wide range of small batch sizes, provided the
    learning rate is adjusted linearly with the batch size.

    Arguments:
        axis: Integer, the axis that should be normalized.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    Output shape:
        Same shape as input.

    References:
        - [Instance Normalization: The Missing Ingredient for Fast Stylization]
          (https://arxiv.org/abs/1607.08022)
    """

    def __init__(self, **kwargs):
        if "groups" in kwargs:
            logging.warning("The given value for groups will be overwritten.")

        kwargs["groups"] = -1
        super().__init__(**kwargs)
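

# Usage sketch (illustrative, not part of the original file): Instance
# Normalization is GroupNormalization with one group per channel, so no
# `groups` argument is needed. The shapes below are assumptions.
#
#     import tensorflow as tf
#     from tensorflow_addons.layers import InstanceNormalization
#
#     x = tf.random.normal([4, 28, 28, 3])   # (N, H, W, C)
#     inorm = InstanceNormalization(axis=-1)
#     y = inorm(x)                           # normalized per sample, per channel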


@tf.keras.utils.register_keras_serializable(package="Addons")
class FilterResponseNormalization(tf.keras.layers.Layer):
    """Filter response normalization layer.

    Filter Response Normalization (FRN) is a normalization method that
    enables models trained with per-channel normalization to achieve high
    accuracy. It performs better than other normalization techniques for
    small batches and is on par with Batch Normalization for larger batch
    sizes.

    Arguments:
        axis: List of axes that should be normalized. This should represent
            the spatial dimensions.
        epsilon: Small positive float value added to variance to avoid
            dividing by zero.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
        learned_epsilon: (bool) Whether to add another learnable
            epsilon parameter or not.
        name: Optional name for the layer.

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model. As of now, this
        layer only works on a 4-D tensor of shape [N x H x W x C].

    TODO: Add support for NCHW data format and FC layers.

    Output shape:
        Same shape as input.

    References:
        - [Filter Response Normalization Layer: Eliminating Batch Dependence
          in the Training of Deep Neural Networks]
          (https://arxiv.org/abs/1911.09737)
    """

    def __init__(
        self,
        epsilon: float = 1e-6,
        axis: list = [1, 2],
        beta_initializer: types.Initializer = "zeros",
        gamma_initializer: types.Initializer = "ones",
        beta_regularizer: types.Regularizer = None,
        gamma_regularizer: types.Regularizer = None,
        beta_constraint: types.Constraint = None,
        gamma_constraint: types.Constraint = None,
        learned_epsilon: bool = False,
        learned_epsilon_constraint: types.Constraint = None,
        name: str = None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.epsilon = epsilon
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        self.use_eps_learned = learned_epsilon
        self.supports_masking = True

        if self.use_eps_learned:
            self.eps_learned_initializer = tf.keras.initializers.Constant(1e-4)
            self.eps_learned_constraint = tf.keras.constraints.get(
                learned_epsilon_constraint
            )
            self.eps_learned = self.add_weight(
                shape=(1,),
                name="learned_epsilon",
                dtype=self.dtype,
                initializer=self.eps_learned_initializer,
                regularizer=None,
                constraint=self.eps_learned_constraint,
            )
        else:
            self.eps_learned_initializer = None
            self.eps_learned_constraint = None

        self._check_axis(axis)

    def build(self, input_shape):
        if len(tf.TensorShape(input_shape)) != 4:
            raise ValueError("Only 4-D tensors (CNNs) are supported as of now.")
        self._check_if_input_shape_is_none(input_shape)
        self._create_input_spec(input_shape)
        self._add_gamma_weight(input_shape)
        self._add_beta_weight(input_shape)
        super().build(input_shape)
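
    # Descriptive note (added comment, matching the formulation in the FRN
    # paper and the code below): each channel is normalized by the mean of
    # squares over the spatial dimensions:
    #     nu2 = mean(x**2) over (H, W)
    #     y   = gamma * x / sqrt(nu2 + eps) + beta
    # Unlike batch norm, no mean is subtracted and no batch statistics are
    # used, which is what removes the batch-size dependence.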

    def call(self, inputs):
        epsilon = tf.math.abs(tf.cast(self.epsilon, dtype=self.dtype))
        if self.use_eps_learned:
            epsilon += tf.math.abs(self.eps_learned)
        nu2 = tf.reduce_mean(tf.square(inputs), axis=self.axis, keepdims=True)
        normalized_inputs = inputs * tf.math.rsqrt(nu2 + epsilon)
        return self.gamma * normalized_inputs + self.beta

    def get_config(self):
        config = {
            "axis": self.axis,
            "epsilon": self.epsilon,
            "learned_epsilon": self.use_eps_learned,
            "beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
            "gamma_initializer": tf.keras.initializers.serialize(
                self.gamma_initializer
            ),
            "beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": tf.keras.regularizers.serialize(
                self.gamma_regularizer
            ),
            "beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
            "gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
            "learned_epsilon_constraint": tf.keras.constraints.serialize(
                self.eps_learned_constraint
            ),
        }
        base_config = super().get_config()
        return dict(**base_config, **config)

    def _create_input_spec(self, input_shape):
        ndims = len(tf.TensorShape(input_shape))
        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x

        # Validate axes.
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError("Invalid axis: %d" % x)

        if len(self.axis) != len(set(self.axis)):
            raise ValueError("Duplicate axis: %s" % self.axis)

        axis_to_dim = {x: input_shape[x] for x in self.axis}
        self.input_spec = tf.keras.layers.InputSpec(ndim=ndims, axes=axis_to_dim)

    def _check_axis(self, axis):
        if not isinstance(axis, list):
            raise TypeError("Expected a list of values but got {}.".format(type(axis)))
        # Copy so later in-place edits cannot mutate the shared default list.
        self.axis = list(axis)

        if self.axis != [1, 2]:
            raise ValueError(
                "FilterResponseNormalization operates on a per-channel basis. "
                "Axis values should be a list of spatial dimensions."
            )

    def _check_if_input_shape_is_none(self, input_shape):
        dim1, dim2 = input_shape[self.axis[0]], input_shape[self.axis[1]]
        if dim1 is None or dim2 is None:
            raise ValueError(
                "Axis {} of input tensor should have a defined dimension but "
                "the layer received an input with shape {}.".format(
                    self.axis, input_shape
                )
            )

    def _add_gamma_weight(self, input_shape):
        # Get the channel dimension.
        dim = input_shape[-1]
        shape = [1, 1, 1, dim]
        # Initialize gamma with shape (1, 1, 1, C).
        self.gamma = self.add_weight(
            shape=shape,
            name="gamma",
            dtype=self.dtype,
            initializer=self.gamma_initializer,
            regularizer=self.gamma_regularizer,
            constraint=self.gamma_constraint,
        )

    def _add_beta_weight(self, input_shape):
        # Get the channel dimension.
        dim = input_shape[-1]
        shape = [1, 1, 1, dim]
        # Initialize beta with shape (1, 1, 1, C).
        self.beta = self.add_weight(
            shape=shape,
            name="beta",
            dtype=self.dtype,
            initializer=self.beta_initializer,
            regularizer=self.beta_regularizer,
            constraint=self.beta_constraint,
        )
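

# Usage sketch (illustrative, not part of the original file): a minimal FRN
# example on a 4-D channels-last tensor, the only layout this layer supports.
# The shapes and the learned_epsilon choice are assumptions.
#
#     import tensorflow as tf
#     from tensorflow_addons.layers import FilterResponseNormalization
#
#     x = tf.random.normal([8, 32, 32, 64])             # (N, H, W, C)
#     frn = FilterResponseNormalization(learned_epsilon=True)
#     y = frn(x)                                        # same shape as x
#
# FRN is typically paired with a thresholded linear unit (TLU) activation,
# as recommended in the paper.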