Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/normalization/layer_normalization.py: 17%

105 statements

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer Normalization layer."""

import tensorflow.compat.v2 as tf

from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.dtensor import utils
from keras.src.engine.base_layer import Layer
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.LayerNormalization")
class LayerNormalization(Layer):
    """Layer normalization layer (Ba et al., 2016).

    Normalize the activations of the previous layer for each given example in
    a batch independently, rather than across a batch like Batch
    Normalization. That is, it applies a transformation that keeps the mean
    activation within each example close to 0 and the activation standard
    deviation close to 1.

    Given a tensor `inputs`, moments are calculated and normalization
    is performed across the axes specified in `axis`.

    Example:

    >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
    >>> print(data)
    tf.Tensor(
    [[ 0. 10.]
     [20. 30.]
     [40. 50.]
     [60. 70.]
     [80. 90.]], shape=(5, 2), dtype=float32)

    >>> layer = tf.keras.layers.LayerNormalization(axis=1)
    >>> output = layer(data)
    >>> print(output)
    tf.Tensor(
    [[-1. 1.]
     [-1. 1.]
     [-1. 1.]
     [-1. 1.]
     [-1. 1.]], shape=(5, 2), dtype=float32)

    Notice that with Layer Normalization the normalization happens across the
    axes *within* each example, rather than across different examples in the
    batch.

    If `scale` or `center` are enabled, the layer will scale the normalized
    outputs by broadcasting them with a trainable variable `gamma`, and center
    the outputs by broadcasting with a trainable variable `beta`. `gamma` will
    default to a ones tensor and `beta` will default to a zeros tensor, so that
    centering and scaling are no-ops before training has begun.

    So, with scaling and centering enabled the normalization equations
    are as follows:

    Let the intermediate activations for a mini-batch be the `inputs`.

    For each sample `x_i` in `inputs` with `k` features, we compute the mean
    and variance of the sample:

    ```python
    mean_i = sum(x_i[j] for j in range(k)) / k
    var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
    ```

    and then compute a normalized `x_i_normalized`, including a small factor
    `epsilon` for numerical stability.

    ```python
    x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
    ```

    And finally `x_i_normalized` is linearly transformed by `gamma` and
    `beta`, which are learned parameters:

    ```python
    output_i = x_i_normalized * gamma + beta
    ```
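
    Putting these steps together (with `scale=True` and `center=True`), the
    overall per-sample transformation can be written, in the same notation, as:

    ```python
    output_i = gamma * (x_i - mean_i) / sqrt(var_i + epsilon) + beta
    ```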

    `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
    this part of the inputs' shape must be fully defined.

    For example:

    >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
    >>> layer.build([5, 20, 30, 40])
    >>> print(layer.beta.shape)
    (20, 30, 40)
    >>> print(layer.gamma.shape)
    (20, 30, 40)

    Note that other implementations of layer normalization may choose to define
    `gamma` and `beta` over a separate set of axes from the axes being
    normalized across. For example, Group Normalization
    ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
    corresponds to a Layer Normalization that normalizes across height, width,
    and channel and has `gamma` and `beta` span only the channel dimension.
    So, this Layer Normalization implementation will not match a Group
    Normalization layer with group size set to 1.

    Args:
      axis: Integer or List/Tuple. The axis or axes to normalize across.
        Typically this is the features axis/axes. The left-out axes are
        typically the batch axis/axes. This argument defaults to `-1`, the last
        dimension in the input.
      epsilon: Small float added to variance to avoid dividing by zero.
        Defaults to 1e-3.
      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored. Defaults to True.
      scale: If True, multiply by `gamma`. If False, `gamma` is not used.
        Defaults to True. When the next layer is linear (or, for example,
        `nn.relu`), this can be disabled since the scaling will be done by
        the next layer.
      beta_initializer: Initializer for the beta weight. Defaults to zeros.
      gamma_initializer: Initializer for the gamma weight. Defaults to ones.
      beta_regularizer: Optional regularizer for the beta weight. None by
        default.
      gamma_regularizer: Optional regularizer for the gamma weight. None by
        default.
      beta_constraint: Optional constraint for the beta weight. None by default.
      gamma_constraint: Optional constraint for the gamma weight. None by
        default.

    Input shape:
      Arbitrary. Use the keyword argument `input_shape` (tuple of
      integers, does not include the samples axis) when using this layer as the
      first layer in a model.

    Output shape:
      Same shape as input.

    Reference:
      - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
    """

    @utils.allow_initializer_layout
    def __init__(
        self,
        axis=-1,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        if isinstance(axis, (list, tuple)):
            self.axis = list(axis)
        elif isinstance(axis, int):
            self.axis = axis
        else:
            raise TypeError(
                "Expected an int or a list/tuple of ints for the "
                "argument 'axis', but received: %r" % axis
            )

        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

        self.supports_masking = True

        # Indicates whether a faster fused implementation can be used. This
        # will be set to True or False in build().
        self._fused = None

    def _fused_can_be_used(self, ndims):
        """Returns false if fused implementation cannot be used.

        Check if the axis is contiguous and can be collapsed into the last axis.
        The self.axis is assumed to have no duplicates.
        """
        axis = sorted(self.axis)
        can_use_fused = False

        if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
            can_use_fused = True
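        # For example (hypothetical shapes): with ndims=4 and
        # self.axis=[2, 3] the check above passes (a contiguous block ending
        # at the last axis), while self.axis=[1, 3] or self.axis=[1, 2]
        # fails it.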

        # fused_batch_norm will silently raise epsilon to be at least 1.001e-5,
        # so we cannot use the fused version if epsilon is below that value.
        # Also, the variable dtype must be float32, as fused_batch_norm only
        # supports float32 variables.
        if self.epsilon < 1.001e-5 or self.dtype != "float32":
            can_use_fused = False

        return can_use_fused

    def build(self, input_shape):
        self.axis = tf_utils.validate_axis(self.axis, input_shape)
        input_shape = tf.TensorShape(input_shape)
        rank = input_shape.rank
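
        # The learned parameters span only the normalized axes: for example
        # (hypothetical shapes), self.axis=[1, 2, 3] with input shape
        # [5, 20, 30, 40] gives param_shape == [20, 30, 40], matching the
        # `gamma`/`beta` shapes shown in the class docstring.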

        param_shape = [input_shape[dim] for dim in self.axis]
        if self.scale:
            self.gamma = self.add_weight(
                name="gamma",
                shape=param_shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
                experimental_autocast=False,
            )
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                name="beta",
                shape=param_shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
                experimental_autocast=False,
            )
        else:
            self.beta = None

        self._fused = self._fused_can_be_used(rank)
        self.built = True

    def call(self, inputs):
        # TODO(b/229545225): Remove the RaggedTensor check.
        is_ragged = isinstance(inputs, tf.RaggedTensor)
        if is_ragged:
            inputs_lengths = inputs.nested_row_lengths()
            inputs = inputs.to_tensor()
        inputs = tf.cast(inputs, self.compute_dtype)
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)

        # Broadcasting only necessary for norm when the axis is not just
        # the last dimension
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape.dims[dim].value
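
        # For illustration (hypothetical shapes): with ndims=3 and
        # self.axis=[1], broadcast_shape is [1, d1, 1], so _broadcast below
        # reshapes a `gamma`/`beta` of shape (d1,) to (1, d1, 1) to align with
        # the (batch, d1, d2) inputs; when self.axis is just the last
        # dimension, the variables are returned unchanged.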

        def _broadcast(v):
            if (
                v is not None
                and len(v.shape) != ndims
                and self.axis != [ndims - 1]
            ):
                return tf.reshape(v, broadcast_shape)
            return v

        if not self._fused:
            input_dtype = inputs.dtype
            if (
                input_dtype in ("float16", "bfloat16")
                and self.dtype == "float32"
            ):
                # If mixed precision is used, cast inputs to float32 so that
                # this is at least as numerically stable as the fused version.
                inputs = tf.cast(inputs, "float32")

            # Calculate the moments on the last axis (layer activations).
            mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
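
            # With these (possibly None) scale/offset values, the
            # batch_normalization call below computes, elementwise,
            #   (inputs - mean) / sqrt(variance + epsilon) * scale + offset,
            # i.e. the formula from the class docstring; the multiply/add is
            # skipped when scale/offset is None.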

            # Compute layer normalization using the batch_normalization
            # function.
            outputs = tf.nn.batch_normalization(
                inputs,
                mean,
                variance,
                offset=offset,
                scale=scale,
                variance_epsilon=self.epsilon,
            )
            outputs = tf.cast(outputs, input_dtype)
        else:
            # Collapse dims before self.axis, and dims in self.axis
            pre_dim, in_dim = (1, 1)
            axis = sorted(self.axis)
            tensor_shape = tf.shape(inputs)
            for dim in range(0, ndims):
                dim_tensor = tensor_shape[dim]
                if dim < axis[0]:
                    pre_dim = pre_dim * dim_tensor
                else:
                    assert dim in axis
                    in_dim = in_dim * dim_tensor
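
            # Worked example (hypothetical shapes): for inputs of shape
            # [2, 3, 4] with self.axis=[1, 2], pre_dim == 2 and in_dim == 12,
            # so the reshape below produces [1, 2, 12, 1]; in NCHW terms each
            # of the 2 "channels" then holds one example's features to be
            # normalized.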

            squeezed_shape = [1, pre_dim, in_dim, 1]
            # This fused operation requires reshaped inputs to be NCHW.
            data_format = "NCHW"

            inputs = tf.reshape(inputs, squeezed_shape)

            # self.gamma and self.beta have the wrong shape for
            # fused_batch_norm, so we cannot pass them as the scale and offset
            # parameters. Therefore, we create two constant tensors in correct
            # shapes for fused_batch_norm and later construct a separate
            # calculation on the scale and offset.
            scale = tf.ones([pre_dim], dtype=self.dtype)
            offset = tf.zeros([pre_dim], dtype=self.dtype)

            # Compute layer normalization using the fused_batch_norm function.
            outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
                inputs,
                scale=scale,
                offset=offset,
                epsilon=self.epsilon,
                data_format=data_format,
            )

            outputs = tf.reshape(outputs, tensor_shape)

            scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

            if scale is not None:
                outputs = outputs * tf.cast(scale, outputs.dtype)
            if offset is not None:
                outputs = outputs + tf.cast(offset, outputs.dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if is_ragged:
            outputs = tf.RaggedTensor.from_tensor(outputs, inputs_lengths)
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(self.gamma_initializer),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))