Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/normalization/group_normalization.py: 24%

86 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The Keras Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Group normalization layer""" 

16 

17import tensorflow.compat.v2 as tf 

18 

19from keras.src import backend 

20from keras.src import constraints 

21from keras.src import initializers 

22from keras.src import regularizers 

23from keras.src.layers import InputSpec 

24from keras.src.layers import Layer 

25from keras.src.utils import tf_utils 

26 

27# isort: off 

28from tensorflow.python.util.tf_export import keras_export 

29 

30 

@keras_export("keras.layers.GroupNormalization", v1=[])
class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes
    within each group the mean and variance for normalization.
    Empirically, its accuracy is more stable than batch norm in a wide
    range of small batch sizes, if learning rate is adjusted linearly
    with batch sizes.

    Relation to Layer Normalization:
    If the number of groups is set to 1, then this operation becomes nearly
    identical to Layer Normalization (see Layer Normalization docs for
    details).

    Relation to Instance Normalization:
    If the number of groups is set to the input dimension (number of groups
    is equal to number of channels), then this operation becomes identical
    to Instance Normalization.

    Args:
        groups: Integer, the number of groups for Group Normalization. Can
            be in the range [1, N] where N is the input dimension. The input
            dimension must be divisible by the number of groups.
            Defaults to 32. The special value -1 is resolved to the number
            of channels at build time (Instance Normalization).
        axis: Integer or List/Tuple. The axis or axes to normalize across.
            Typically this is the features axis/axes. The left-out axes are
            typically the batch axis/axes. This argument defaults to `-1`,
            the last dimension in the input.
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3.
        center: If True, add offset of `beta` to normalized tensor. If
            False, `beta` is ignored. Defaults to True.
        scale: If True, multiply by `gamma`. If False, `gamma` is not used.
            Defaults to True. When the next layer is linear (also e.g.
            `nn.relu`), this can be disabled since the scaling will be done
            by the next layer.
        beta_initializer: Initializer for the beta weight. Defaults to
            zeros.
        gamma_initializer: Initializer for the gamma weight. Defaults to
            ones.
        beta_regularizer: Optional regularizer for the beta weight. None by
            default.
        gamma_regularizer: Optional regularizer for the gamma weight. None
            by default.
        beta_constraint: Optional constraint for the beta weight. None by
            default.
        gamma_constraint: Optional constraint for the gamma weight. None by
            default.

    Input shape: Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis) when using
        this layer as the first layer in a model.

    Output shape: Same shape as input.

    Reference:
        - [Yuxin Wu & Kaiming He, 2018](https://arxiv.org/abs/1803.08494)
    """

    def __init__(
        self,
        groups=32,
        axis=-1,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        # Deserialize string/dict specs into concrete objects up front.
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate the channel axis and create `gamma`/`beta` weights.

        Raises:
            ValueError: If the channel dimension is undefined, smaller than
                `groups`, or not divisible by `groups`.
        """
        tf_utils.validate_axis(self.axis, input_shape)

        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError(
                f"Axis {self.axis} of input tensor should have a defined "
                "dimension but the layer received an input with shape "
                f"{input_shape}."
            )

        # -1 is shorthand for "one group per channel" (Instance Norm).
        if self.groups == -1:
            self.groups = dim

        if dim < self.groups:
            raise ValueError(
                f"Number of groups ({self.groups}) cannot be more than the "
                f"number of channels ({dim})."
            )

        if dim % self.groups != 0:
            # Note: the channel count must be divisible by the group count,
            # not the other way around (previous message had it inverted).
            raise ValueError(
                f"Number of channels ({dim}) must be a multiple "
                f"of the number of groups ({self.groups})."
            )

        self.input_spec = InputSpec(
            ndim=len(input_shape), axes={self.axis: dim}
        )

        if self.scale:
            self.gamma = self.add_weight(
                shape=(dim,),
                name="gamma",
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
            )
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                shape=(dim,),
                name="beta",
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
            )
        else:
            self.beta = None

        super().build(input_shape)

    def call(self, inputs):
        # Split the channel axis into (groups, channels_per_group),
        # normalize within each group, then restore the original shape.
        input_shape = tf.shape(inputs)

        reshaped_inputs = self._reshape_into_groups(inputs)

        normalized_inputs = self._apply_normalization(
            reshaped_inputs, input_shape
        )

        return tf.reshape(normalized_inputs, input_shape)

    def _reshape_into_groups(self, inputs):
        """Reshape the channel axis into (groups, channels // groups)."""
        input_shape = tf.shape(inputs)
        group_shape = [input_shape[i] for i in range(inputs.shape.rank)]

        group_shape[self.axis] = input_shape[self.axis] // self.groups
        group_shape.insert(self.axis, self.groups)
        group_shape = tf.stack(group_shape)
        reshaped_inputs = tf.reshape(inputs, group_shape)
        return reshaped_inputs

    def _apply_normalization(self, reshaped_inputs, input_shape):
        """Normalize the grouped tensor with per-group mean/variance."""
        # Reduce over every non-batch axis except the group axis.
        group_reduction_axes = list(range(1, reshaped_inputs.shape.rank))

        # After the insert in `_reshape_into_groups`, the group axis sits
        # one position before the (shrunken) channel axis.
        axis = -2 if self.axis == -1 else self.axis - 1
        group_reduction_axes.pop(axis)

        mean, variance = tf.nn.moments(
            reshaped_inputs, group_reduction_axes, keepdims=True
        )

        gamma, beta = self._get_reshaped_weights(input_shape)
        normalized_inputs = tf.nn.batch_normalization(
            reshaped_inputs,
            mean=mean,
            variance=variance,
            scale=gamma,
            offset=beta,
            variance_epsilon=self.epsilon,
        )
        return normalized_inputs

    def _get_reshaped_weights(self, input_shape):
        """Return `gamma`/`beta` broadcast to the grouped shape (or None)."""
        broadcast_shape = self._create_broadcast_shape(input_shape)
        gamma = None
        beta = None
        if self.scale:
            gamma = tf.reshape(self.gamma, broadcast_shape)

        if self.center:
            beta = tf.reshape(self.beta, broadcast_shape)
        return gamma, beta

    def _create_broadcast_shape(self, input_shape):
        """Build a shape of ones with the group/channel axes filled in."""
        # `input_shape` is the 1-D tensor from `tf.shape`; its static length
        # is the input rank.
        broadcast_shape = [1] * backend.int_shape(input_shape)[0]

        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(self.axis, self.groups)

        return broadcast_shape

    def compute_output_shape(self, input_shape):
        # Normalization does not change the shape.
        return input_shape

    def get_config(self):
        config = {
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(
                self.gamma_initializer
            ),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(
                self.gamma_regularizer
            ),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        base_config = super().get_config()
        return {**base_config, **config}