Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/normalization/group_normalization.py: 24%
86 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The Keras Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Group normalization layer."""

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.layers import InputSpec
from keras.src.layers import Layer
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.GroupNormalization", v1=[])
class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes the
    mean and variance within each group for normalization. Empirically, its
    accuracy is more stable than batch normalization across a wide range of
    small batch sizes, if the learning rate is adjusted linearly with the
    batch size.

    Relation to Layer Normalization:
    If the number of groups is set to 1, then this operation becomes nearly
    identical to Layer Normalization (see the Layer Normalization docs for
    details).

    Relation to Instance Normalization:
    If the number of groups is set to the input dimension (i.e. the number
    of groups is equal to the number of channels), then this operation
    becomes identical to Instance Normalization.

    Args:
        groups: Integer, the number of groups for Group Normalization. Can
            be in the range `[1, N]` where N is the input dimension. The
            input dimension must be divisible by the number of groups.
            Defaults to 32.
        axis: Integer or List/Tuple. The axis or axes to normalize across.
            Typically, this is the features axis/axes. The left-out axes
            are typically the batch axis/axes. Defaults to `-1`, the last
            dimension in the input.
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3.
        center: If True, add offset of `beta` to the normalized tensor. If
            False, `beta` is ignored. Defaults to True.
        scale: If True, multiply by `gamma`. If False, `gamma` is not used.
            Defaults to True. When the next layer is linear (also e.g.
            `nn.relu`), this can be disabled since the scaling will be done
            by the next layer.
        beta_initializer: Initializer for the beta weight. Defaults to
            zeros.
        gamma_initializer: Initializer for the gamma weight. Defaults to
            ones.
        beta_regularizer: Optional regularizer for the beta weight. None by
            default.
        gamma_regularizer: Optional regularizer for the gamma weight. None
            by default.
        beta_constraint: Optional constraint for the beta weight. None by
            default.
        gamma_constraint: Optional constraint for the gamma weight. None by
            default.

    Input shape:
        Arbitrary. Use the keyword argument `input_shape` (tuple of
        integers, does not include the samples axis) when using this layer
        as the first layer in a model.

    Output shape:
        Same shape as input.
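
    Example (a minimal usage sketch; the shapes and group count below are
    illustrative assumptions, not prescribed values):

    >>> x = tf.random.normal((2, 4, 4, 8))  # 8 channels, divisible by 2
    >>> layer = GroupNormalization(groups=2)
    >>> y = layer(x)
    >>> y.shape
    TensorShape([2, 4, 4, 8])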

    Reference:
        - [Yuxin Wu & Kaiming He, 2018](https://arxiv.org/abs/1803.08494)
    """

    def __init__(
        self,
        groups=32,
        axis=-1,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        tf_utils.validate_axis(self.axis, input_shape)

        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError(
                f"Axis {self.axis} of input tensor should have a defined "
                "dimension but the layer received an input with shape "
                f"{input_shape}."
            )
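
        # groups == -1 is shorthand for one group per channel, i.e. the
        # Instance Normalization case described in the class docstring.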
        if self.groups == -1:
            self.groups = dim

        if dim < self.groups:
            raise ValueError(
                f"Number of groups ({self.groups}) cannot be more than the "
                f"number of channels ({dim})."
            )

        if dim % self.groups != 0:
            raise ValueError(
                f"Number of channels ({dim}) must be a multiple "
                f"of the number of groups ({self.groups})."
            )

        self.input_spec = InputSpec(
            ndim=len(input_shape), axes={self.axis: dim}
        )
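
        # gamma (scale) and beta (offset) are learned per channel; they are
        # later reshaped for broadcasting against the grouped tensor in
        # `_get_reshaped_weights`.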
        if self.scale:
            self.gamma = self.add_weight(
                shape=(dim,),
                name="gamma",
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
            )
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                shape=(dim,),
                name="beta",
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
            )
        else:
            self.beta = None

        super().build(input_shape)
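
    # The forward pass proceeds in three steps: split the channel axis into
    # (groups, channels // groups), normalize with per-group statistics,
    # then reshape back to the original input shape.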
    def call(self, inputs):
        input_shape = tf.shape(inputs)

        reshaped_inputs = self._reshape_into_groups(inputs)

        normalized_inputs = self._apply_normalization(
            reshaped_inputs, input_shape
        )

        return tf.reshape(normalized_inputs, input_shape)

    def _reshape_into_groups(self, inputs):
        input_shape = tf.shape(inputs)
        group_shape = [input_shape[i] for i in range(inputs.shape.rank)]

        group_shape[self.axis] = input_shape[self.axis] // self.groups
        group_shape.insert(self.axis, self.groups)
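        # For example (illustrative shapes): NHWC inputs of shape
        # (N, H, W, C) with axis=-1 become (N, H, W, groups, C // groups).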
        group_shape = tf.stack(group_shape)
        reshaped_inputs = tf.reshape(inputs, group_shape)
        return reshaped_inputs

    def _apply_normalization(self, reshaped_inputs, input_shape):
        group_reduction_axes = list(range(1, reshaped_inputs.shape.rank))

        axis = -2 if self.axis == -1 else self.axis - 1
        group_reduction_axes.pop(axis)
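        # After popping the group axis, moments are reduced over every
        # remaining non-batch axis. E.g. for a reshaped tensor of shape
        # (N, H, W, groups, C // groups), statistics are taken over
        # (H, W, C // groups), yielding one mean/variance per sample and
        # per group.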

        mean, variance = tf.nn.moments(
            reshaped_inputs, group_reduction_axes, keepdims=True
        )

        gamma, beta = self._get_reshaped_weights(input_shape)
        normalized_inputs = tf.nn.batch_normalization(
            reshaped_inputs,
            mean=mean,
            variance=variance,
            scale=gamma,
            offset=beta,
            variance_epsilon=self.epsilon,
        )
        return normalized_inputs

    def _get_reshaped_weights(self, input_shape):
        broadcast_shape = self._create_broadcast_shape(input_shape)
        gamma = None
        beta = None
        if self.scale:
            gamma = tf.reshape(self.gamma, broadcast_shape)

        if self.center:
            beta = tf.reshape(self.beta, broadcast_shape)
        return gamma, beta

    def _create_broadcast_shape(self, input_shape):
        broadcast_shape = [1] * backend.int_shape(input_shape)[0]

        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(self.axis, self.groups)
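        # E.g. for the NHWC case above this yields
        # [1, 1, 1, groups, C // groups], so the per-channel weights
        # broadcast against the grouped tensor.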

        return broadcast_shape

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(
                self.gamma_initializer
            ),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(
                self.gamma_regularizer
            ),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        base_config = super().get_config()
        return {**base_config, **config}
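

# A minimal smoke-test sketch (the shapes and group count here are
# illustrative assumptions, not part of the library's API). Per-sample,
# per-group activations should come out with mean ~0 and variance ~1.
if __name__ == "__main__":
    x = tf.random.normal((2, 4, 4, 8))  # 8 channels, divisible by groups=2
    layer = GroupNormalization(groups=2)
    y = layer(x)
    print(y.shape)  # (2, 4, 4, 8) -- same shape as the input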