Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/normalization/batch_normalization.py: 11%
494 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The V2 implementation of Normalization layers."""
17import warnings
19import tensorflow.compat.v2 as tf
21from keras.src import backend
22from keras.src import constraints
23from keras.src import initializers
24from keras.src import regularizers
25from keras.src.dtensor import utils
26from keras.src.engine.base_layer import Layer
27from keras.src.engine.input_spec import InputSpec
28from keras.src.utils import control_flow_util
29from keras.src.utils import tf_utils
31# isort: off
32from tensorflow.python.ops.control_flow_ops import (
33 get_enclosing_xla_context,
34)
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.util import deprecation
37from tensorflow.python.util.tf_export import keras_export
40class BatchNormalizationBase(Layer):
41 r"""Layer that normalizes its inputs.
43 Batch normalization applies a transformation that maintains the mean output
44 close to 0 and the output standard deviation close to 1.
46 Importantly, batch normalization works differently during training and
47 during inference.
49 **During training** (i.e. when using `fit()` or when calling the layer/model
50 with the argument `training=True`), the layer normalizes its output using
51 the mean and standard deviation of the current batch of inputs. That is to
52 say, for each channel being normalized, the layer returns
53 `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:
55 - `epsilon` is a small constant (configurable as part of the constructor
56 arguments)
57 - `gamma` is a learned scaling factor (initialized as 1), which
58 can be disabled by passing `scale=False` to the constructor.
59 - `beta` is a learned offset factor (initialized as 0), which
60 can be disabled by passing `center=False` to the constructor.
62 **During inference** (i.e. when using `evaluate()` or `predict()`) or when
63 calling the layer/model with the argument `training=False` (which is the
64 default), the layer normalizes its output using a moving average of the
65 mean and standard deviation of the batches it has seen during training. That
66 is to say, it returns
67 `gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta`.
69 `self.moving_mean` and `self.moving_var` are non-trainable variables that
70 are updated each time the layer is called in training mode, as such:
72 - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
73 - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`
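  For example, with the default `momentum=0.99`, a stored `moving_mean` of
  `0.0` and a batch mean of `2.0` yield an updated `moving_mean` of
  `0.0 * 0.99 + 2.0 * 0.01 = 0.02`.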
75 As such, the layer will only normalize its inputs during inference
76 *after having been trained on data that has statistics similar to the
77 inference data*.
79 Args:
80 axis: Integer or a list of integers, the axis that should be normalized
81 (typically the features axis). For instance, after a `Conv2D` layer with
82 `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
83 momentum: Momentum for the moving average.
84 epsilon: Small float added to variance to avoid dividing by zero.
85 center: If True, add offset of `beta` to normalized tensor. If False,
86 `beta` is ignored.
87 scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
88 the next layer is linear (this also applies to, e.g., `nn.relu`), this can
89 be disabled since the scaling will be done by the next layer.
90 beta_initializer: Initializer for the beta weight.
91 gamma_initializer: Initializer for the gamma weight.
92 moving_mean_initializer: Initializer for the moving mean.
93 moving_variance_initializer: Initializer for the moving variance.
94 beta_regularizer: Optional regularizer for the beta weight.
95 gamma_regularizer: Optional regularizer for the gamma weight.
96 beta_constraint: Optional constraint for the beta weight.
97 gamma_constraint: Optional constraint for the gamma weight.
98 renorm: Whether to use [Batch Renormalization](
99 https://arxiv.org/abs/1702.03275). This adds extra variables during
100 training. The inference is the same for either value of this
101 parameter.
102 renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
103 scalar `Tensors` used to clip the renorm correction. The correction `(r,
104 d)` is used as `corrected_value = normalized_value * r + d`, with `r`
105 clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
106 dmax are set to inf, 0, inf, respectively.
107 renorm_momentum: Momentum used to update the moving means and standard
108 deviations with renorm. Unlike `momentum`, this affects training and
109 should be neither too small (which would add noise) nor too large (which
110 would give stale estimates). Note that `momentum` is still applied to
111 get the means and variances for inference.
112 fused: if `True`, use a faster, fused implementation, or raise a
113 ValueError if the fused implementation cannot be used. If `None`, use
114 the faster implementation if possible. If `False`, do not use the fused
115 implementation. Note that in TensorFlow 1.x the meaning of this
116 argument is different: there, `fused=False` makes the layer use the
117 system-recommended implementation. You cannot use `fused=True` if a
118 mask is passed in the `call()` method.
119 trainable: Boolean, if `True` the variables will be marked as trainable.
120 virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
121 which means batch normalization is performed across the whole batch.
122 When `virtual_batch_size` is not `None`, instead perform "Ghost Batch
123 Normalization", which creates virtual sub-batches which are each
124 normalized separately (with shared gamma, beta, and moving statistics).
125 Must divide the actual batch size during execution.
126 adjustment: A function taking the `Tensor` containing the (dynamic) shape
127 of the input tensor and returning a pair (scale, bias) to apply to the
128 normalized values (before gamma and beta), only during training. For
129 example, if `axis=-1`,
130 `adjustment = lambda shape: (
131 tf.random.uniform(shape[-1:], 0.93, 1.07),
132 tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
133 value by up to 7% up or down, then shift the result by up to 0.1
134 (with independent scaling and bias for each feature but shared
135 across all examples), and finally apply gamma and/or beta. If
136 `None`, no adjustment is applied. Cannot be specified if
137 virtual_batch_size is specified.
138 synchronized: If True, synchronizes the global batch statistics (mean and
139 variance) for the layer across all devices at each training step in a
140 distributed training strategy. If False, each replica uses its own
141 local batch statistics. Only relevant when used inside a
142 `tf.distribute` strategy.
144 Call arguments:
145 inputs: Input tensor (of any rank).
146 training: Python boolean indicating whether the layer should behave in
147 training mode or in inference mode.
148 - `training=True`: The layer will normalize its inputs using the mean
149 and variance of the current batch of inputs.
150 - `training=False`: The layer will normalize its inputs using the mean
151 and variance of its moving statistics, learned during training.
152 mask: Binary tensor of shape broadcastable to `inputs` tensor, indicating
153 the positions for which the mean and variance should be computed.
155 Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
156 integers, does not include the samples axis) when using this layer as the
157 first layer in a model.
159 Output shape: Same shape as input.
161 Reference:
162 - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
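  Example (a minimal sketch using the public `tf.keras.layers.BatchNormalization`
  endpoint, which subclasses this base layer):
  ```python
  import tensorflow as tf
  layer = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=1e-3)
  x = tf.random.uniform((4, 3))
  y_train = layer(x, training=True)   # batch statistics; moving stats are updated
  y_infer = layer(x, training=False)  # moving mean / moving variance are used
  ```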
163 """
165 # By default, the base class uses V2 behavior. The BatchNormalization V1
166 # subclass sets this to False to use the V1 behavior.
167 _USE_V2_BEHAVIOR = True
169 def __init__(
170 self,
171 axis=-1,
172 momentum=0.99,
173 epsilon=1e-3,
174 center=True,
175 scale=True,
176 beta_initializer="zeros",
177 gamma_initializer="ones",
178 moving_mean_initializer="zeros",
179 moving_variance_initializer="ones",
180 beta_regularizer=None,
181 gamma_regularizer=None,
182 beta_constraint=None,
183 gamma_constraint=None,
184 renorm=False,
185 renorm_clipping=None,
186 renorm_momentum=0.99,
187 fused=None,
188 trainable=True,
189 virtual_batch_size=None,
190 adjustment=None,
191 name=None,
192 synchronized=False,
193 **kwargs,
194 ):
195 super().__init__(name=name, **kwargs)
196 if isinstance(axis, (list, tuple)):
197 self.axis = axis[:]
198 elif isinstance(axis, int):
199 self.axis = axis
200 else:
201 raise TypeError(
202 "Expected an int or a list/tuple of ints for the "
203 "argument 'axis', but received: %r" % axis
204 )
205 if synchronized and fused:
206 raise ValueError(
207 "`fused=True` is not supported when `synchronized=True`."
208 )
209 self.synchronized = synchronized
210 if self.synchronized:
211 fused = False
213 self.momentum = momentum
214 self.epsilon = epsilon
215 self.center = center
216 self.scale = scale
217 self.beta_initializer = initializers.get(beta_initializer)
218 self.gamma_initializer = initializers.get(gamma_initializer)
219 self.moving_mean_initializer = initializers.get(moving_mean_initializer)
220 self.moving_variance_initializer = initializers.get(
221 moving_variance_initializer
222 )
223 self.beta_regularizer = regularizers.get(beta_regularizer)
224 self.gamma_regularizer = regularizers.get(gamma_regularizer)
225 self.beta_constraint = constraints.get(beta_constraint)
226 self.gamma_constraint = constraints.get(gamma_constraint)
227 self.renorm = renorm
228 self.virtual_batch_size = virtual_batch_size
229 self.adjustment = adjustment
230 if self._USE_V2_BEHAVIOR:
231 if fused:
232 self._raise_if_fused_cannot_be_used()
233 # We leave fused as None if self._fused_can_be_used()==True, since
234 # we still may set it to False in self.build() if the input rank is
235 # not 4.
236 elif fused is None and not self._fused_can_be_used():
237 fused = False
238 elif fused is None:
239 fused = True
240 self.supports_masking = True
242 self.fused = fused
243 self._bessels_correction_test_only = True
244 self.trainable = trainable
246 if renorm:
247 renorm_clipping = renorm_clipping or {}
248 keys = ["rmax", "rmin", "dmax"]
249 if set(renorm_clipping) - set(keys):
250 raise ValueError(
251 "Received invalid keys for `renorm_clipping` argument: "
252 f"{renorm_clipping}. Supported values: {keys}."
253 )
254 self.renorm_clipping = renorm_clipping
255 self.renorm_momentum = renorm_momentum
257 def _raise_if_fused_cannot_be_used(self):
258 """Raises a ValueError if fused implementation cannot be used.
260 In addition to the checks done in this function, the input tensor's rank
261 must be 4 or 5. The input rank check can only be done once the input
262 shape is known.
263 """
264 # Note the ValueErrors in this function are caught and not reraised in
265 # _fused_can_be_used(). No other exception besides ValueError should be
266 # raised here.
268 # Currently fused batch norm doesn't support renorm. It also only
269 # supports a channel dimension on axis 1 or 3 (rank=4) / 1 or 4 (rank=5),
270 # when no virtual batch size or adjustment is used.
271 if self.renorm:
272 raise ValueError(
273 "Passing both `fused=True` and `renorm=True` is not supported"
274 )
275 axis = [self.axis] if isinstance(self.axis, int) else self.axis
276 # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, when the
277 # input rank is 4. Similarly, the valid axis is -4, -1, 1, 4 when the
278 # rank is 5. The combination of ranks and axes will be checked later.
279 if len(axis) > 1 or axis[0] not in (-4, -3, -1, 1, 3, 4):
280 raise ValueError(
281 "Passing `fused=True` is only supported when axis is 1 "
282 "or 3 for input rank = 4 or 1 or 4 for input rank = 5. "
283 "Got axis %s" % (axis,)
284 )
285 if self.virtual_batch_size is not None:
286 raise ValueError(
287 "Passing `fused=True` is not supported when "
288 "`virtual_batch_size` is specified."
289 )
290 if self.adjustment is not None:
291 raise ValueError(
292 "Passing `fused=True` is not supported when "
293 "`adjustment` is specified."
294 )
295 # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
296 if self._compute_dtype not in ("float16", "bfloat16", "float32", None):
297 raise ValueError(
298 "Passing `fused=True` is only supported when the compute "
299 "dtype is float16, bfloat16, or float32. Got dtype: %s"
300 % (self._compute_dtype,)
301 )
303 def _fused_can_be_used(self):
304 try:
305 self._raise_if_fused_cannot_be_used()
306 return True
307 except ValueError:
308 return False
310 @property
311 def trainable(self):
312 return self._trainable
314 @trainable.setter
315 def trainable(self, value):
316 self._trainable = value
318 @property
319 def _param_dtype(self):
320 # Raise parameters of fp16 batch norm to fp32
321 if self.dtype == tf.float16 or self.dtype == tf.bfloat16:
322 return tf.float32
323 else:
324 return self.dtype or tf.float32
326 def build(self, input_shape):
327 self.axis = tf_utils.validate_axis(self.axis, input_shape)
328 input_shape = tf.TensorShape(input_shape)
329 rank = input_shape.rank
331 if self.virtual_batch_size is not None:
332 if self.virtual_batch_size <= 0:
333 raise ValueError(
334 "`virtual_batch_size` must be a positive integer that "
335 "divides the true batch size of the input tensor. "
336 f"Received: virtual_batch_size={self.virtual_batch_size}"
337 )
338 # If using virtual batches, the first dimension must be the batch
339 # dimension and cannot be the batch norm axis
340 if 0 in self.axis:
341 raise ValueError(
342 "When using `virtual_batch_size`, the batch dimension "
343 "must be 0 and thus axis cannot include 0. "
344 f"Received axis={self.axis}"
345 )
346 if self.adjustment is not None:
347 raise ValueError(
348 "When using `virtual_batch_size`, adjustment cannot "
349 "be specified"
350 )
352 if self.fused in (None, True):
353 # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape
354 # the output back to its original shape accordingly.
355 if self._USE_V2_BEHAVIOR:
356 if self.fused is None:
357 self.fused = rank in (4, 5)
358 elif self.fused and rank not in (4, 5):
359 raise ValueError(
360 "Batch normalization layers with `fused=True` only "
361 "support 4D or 5D input tensors. "
362 f"Received tensor with shape: {tuple(input_shape)}"
363 )
364 else:
365 assert self.fused is not None
366 self.fused = rank in (4, 5) and self._fused_can_be_used()
367 # TODO(chrisying): fused batch norm is currently not supported for
368 # multi-axis batch norm and by extension virtual batches. In some
369 # cases, it might be possible to use fused batch norm, but it would
370 # require reshaping the Tensor to 4D with the axis in 1 or 3
371 # (preferred 1) which is particularly tricky. A compromise might be
372 # to just support the most common use case (turning 5D w/ virtual
373 # batch to NCHW)
375 if self.fused:
376 if self.axis == [1] and rank == 4:
377 self._data_format = "NCHW"
378 elif self.axis == [1] and rank == 5:
379 self._data_format = "NCDHW"
380 elif self.axis == [3] and rank == 4:
381 self._data_format = "NHWC"
382 elif self.axis == [4] and rank == 5:
383 self._data_format = "NDHWC"
384 elif rank == 5:
385 # 5D tensors that can be passed in but should not use fused
386 # batch norm due to unsupported axis.
387 self.fused = False
388 else:
389 if rank == 4:
390 raise ValueError(
391 "Unsupported axis. The use of `fused=True` is only "
392 "possible with `axis=1` or `axis=3` for 4D input "
393 f"tensors. Received: axis={tuple(self.axis)}"
394 )
395 else:
396 raise ValueError(
397 "Unsupported axis. The use of `fused=True` is only "
398 "possible with `axis=1` or `axis=4` for 5D input "
399 f"tensors. Received: axis={tuple(self.axis)}"
400 )
402 axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
403 for x in axis_to_dim:
404 if axis_to_dim[x] is None:
405 raise ValueError(
406 "Input has undefined `axis` dimension. Received input "
407 f"with shape {tuple(input_shape)} "
408 f"and axis={tuple(self.axis)}"
409 )
410 self.input_spec = InputSpec(ndim=rank, axes=axis_to_dim)
412 if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
413 # Single axis batch norm (most common/default use-case)
414 param_shape = (list(axis_to_dim.values())[0],)
415 else:
416 # Parameter shape is the original shape but with 1 in all non-axis
417 # dims
418 param_shape = [
419 axis_to_dim[i] if i in axis_to_dim else 1 for i in range(rank)
420 ]
421 if self.virtual_batch_size is not None:
422 # When using virtual batches, add an extra dim at index 1
423 param_shape.insert(1, 1)
424 for idx, x in enumerate(self.axis):
425 self.axis[idx] = x + 1 # Account for added dimension
426 self._param_shape = param_shape
427 if self.scale:
428 self.gamma = self.add_weight(
429 name="gamma",
430 shape=param_shape,
431 dtype=self._param_dtype,
432 initializer=self.gamma_initializer,
433 regularizer=self.gamma_regularizer,
434 constraint=self.gamma_constraint,
435 trainable=True,
436 experimental_autocast=False,
437 )
438 else:
439 self.gamma = None
441 if self.center:
442 self.beta = self.add_weight(
443 name="beta",
444 shape=param_shape,
445 dtype=self._param_dtype,
446 initializer=self.beta_initializer,
447 regularizer=self.beta_regularizer,
448 constraint=self.beta_constraint,
449 trainable=True,
450 experimental_autocast=False,
451 )
452 else:
453 self.beta = None
455 try:
456 # Disable variable partitioning when creating the moving mean and
457 # variance
458 if hasattr(self, "_scope") and self._scope:
459 partitioner = self._scope.partitioner
460 self._scope.set_partitioner(None)
461 else:
462 partitioner = None
463 self.moving_mean = self.add_weight(
464 name="moving_mean",
465 shape=param_shape,
466 dtype=self._param_dtype,
467 initializer=self.moving_mean_initializer,
468 synchronization=tf.VariableSynchronization.ON_READ,
469 trainable=False,
470 aggregation=tf.VariableAggregation.MEAN,
471 experimental_autocast=False,
472 )
474 self.moving_variance = self.add_weight(
475 name="moving_variance",
476 shape=param_shape,
477 dtype=self._param_dtype,
478 initializer=self.moving_variance_initializer,
479 synchronization=tf.VariableSynchronization.ON_READ,
480 trainable=False,
481 aggregation=tf.VariableAggregation.MEAN,
482 experimental_autocast=False,
483 )
485 if self.renorm:
486 # In batch renormalization we track the inference moving stddev
487 # instead of the moving variance to more closely align with the
488 # paper.
489 def moving_stddev_initializer(*args, **kwargs):
490 return tf.sqrt(
491 self.moving_variance_initializer(*args, **kwargs)
492 )
494 with tf.distribute.get_strategy().extended.colocate_vars_with(
495 self.moving_variance
496 ):
497 self.moving_stddev = self.add_weight(
498 name="moving_stddev",
499 shape=param_shape,
500 dtype=self._param_dtype,
501 initializer=moving_stddev_initializer,
502 synchronization=tf.VariableSynchronization.ON_READ,
503 trainable=False,
504 aggregation=tf.VariableAggregation.MEAN,
505 experimental_autocast=False,
506 )
508 # Create variables to maintain the moving mean and standard
509 # deviation. These are used in training and thus are different
510 # from the moving averages above. The renorm variables are
511 # colocated with moving_mean and moving_stddev.
512 # NOTE: below, the outer `with device` block causes the current
513 # device stack to be cleared. The nested ones use a `lambda` to
514 # set the desired device and ignore any devices that may be set
515 # by the custom getter.
516 def _renorm_variable(name, shape, initializer="zeros"):
517 """Create a renorm variable."""
518 var = self.add_weight(
519 name=name,
520 shape=shape,
521 dtype=self._param_dtype,
522 initializer=initializer,
523 synchronization=tf.VariableSynchronization.ON_READ,
524 trainable=False,
525 aggregation=tf.VariableAggregation.MEAN,
526 experimental_autocast=False,
527 )
528 return var
530 with tf.distribute.get_strategy().extended.colocate_vars_with(
531 self.moving_mean
532 ):
533 self.renorm_mean = _renorm_variable(
534 "renorm_mean", param_shape, self.moving_mean_initializer
535 )
536 with tf.distribute.get_strategy().extended.colocate_vars_with(
537 self.moving_stddev
538 ):
539 self.renorm_stddev = _renorm_variable(
540 "renorm_stddev", param_shape, moving_stddev_initializer
541 )
542 finally:
543 if partitioner:
544 self._scope.set_partitioner(partitioner)
545 self.built = True
547 def call(self, inputs, training=None, mask=None):
548 inputs = tf.cast(inputs, self.compute_dtype)
549 training = self._get_training_value(training)
550 # Determine a boolean value for `training`: could be True, False, or
551 # None.
552 training_value = control_flow_util.constant_value(training)
553 _raise_for_non_sync_bn_with_renorm_and_dtensor_strategy(
554 synchronized=self.synchronized,
555 training=training,
556 renorm=self.renorm,
557 )
559 if self.virtual_batch_size is not None:
560 # Virtual batches (aka ghost batches) can be simulated by reshaping
561 # the Tensor and reusing the existing batch norm implementation
562 original_shape = tf.shape(inputs)
563 original_shape = tf.concat(
564 [tf.constant([-1]), original_shape[1:]], axis=0
565 )
567 if tf.__internal__.tf2.enabled():
568 expanded_shape = (
569 [self.virtual_batch_size, -1] if training_value else [-1, 1]
570 )
571 expanded_shape = tf.concat(
572 [
573 tf.constant(expanded_shape),
574 original_shape[1:],
575 ],
576 axis=0,
577 )
578 else:
579 # Preserve incorrect legacy behavior for backwards compatibility
580 expanded_shape = tf.concat(
581 [
582 tf.constant([self.virtual_batch_size, -1]),
583 original_shape[1:],
584 ],
585 axis=0,
586 )
588 # Will cause errors if virtual_batch_size does not divide the batch
589 # size
590 inputs = tf.reshape(inputs, expanded_shape)
592 def undo_virtual_batching(outputs):
593 outputs = tf.reshape(outputs, original_shape)
594 return outputs
596 if self.fused:
597 outputs = self._fused_batch_norm(
598 inputs, mask=mask, training=training
599 )
600 if self.virtual_batch_size is not None:
601 # Currently never reaches here since fused_batch_norm does not
602 # support virtual batching
603 outputs = undo_virtual_batching(outputs)
604 return outputs
606 inputs_dtype = inputs.dtype.base_dtype
607 if inputs_dtype in (tf.float16, tf.bfloat16):
608 # Do all math in float32 if given 16-bit inputs for numeric
609 # stability. In particular, it's very easy for variance to overflow
610 # in float16 and for safety we also choose to cast bfloat16 to
611 # float32.
612 inputs = tf.cast(inputs, tf.float32)
614 # Compute the axes along which to reduce the mean / variance
615 input_shape = inputs.shape
616 ndims = len(input_shape)
617 reduction_axes = [i for i in range(ndims) if i not in self.axis]
618 if self.virtual_batch_size is not None:
619 del reduction_axes[1] # Do not reduce along virtual batch dim
621 # Broadcasting only necessary for single-axis batch norm where the axis
622 # is not the last dimension
623 broadcast_shape = [1] * ndims
624 broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
626 def _broadcast(v):
627 if (
628 v is not None
629 and len(v.shape) != ndims
630 and reduction_axes != list(range(ndims - 1))
631 ):
632 return tf.reshape(v, broadcast_shape)
633 return v
635 scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
637 def _compose_transforms(scale, offset, then_scale, then_offset):
638 if then_scale is not None:
639 scale *= then_scale
640 offset *= then_scale
641 if then_offset is not None:
642 offset += then_offset
643 return (scale, offset)
645 if training_value == False: # noqa: E712
646 mean, variance = self.moving_mean, self.moving_variance
647 else:
648 # The following long block handles the mean/variance update during
649 # the training stage in various different settings.
650 if self.adjustment:
651 adj_scale, adj_bias = self.adjustment(tf.shape(inputs))
652 # Adjust only during training.
653 adj_scale = control_flow_util.smart_cond(
654 training, lambda: adj_scale, lambda: tf.ones_like(adj_scale)
655 )
656 adj_bias = control_flow_util.smart_cond(
657 training, lambda: adj_bias, lambda: tf.zeros_like(adj_bias)
658 )
659 scale, offset = _compose_transforms(
660 adj_scale, adj_bias, scale, offset
661 )
663 # Some of the computations here are not necessary when training is
664 # False but `training` is not a constant. However, doing them anyway
665 # makes the code simpler.
666 keep_dims = (
667 self.virtual_batch_size is not None or len(self.axis) > 1
668 )
669 mean, variance = self._moments(
670 tf.cast(inputs, self._param_dtype),
671 reduction_axes,
672 keep_dims=keep_dims,
673 mask=mask,
674 )
676 moving_mean = self.moving_mean
677 moving_variance = self.moving_variance
679 mean = control_flow_util.smart_cond(
680 training,
681 lambda: mean,
682 lambda: tf.convert_to_tensor(moving_mean),
683 )
684 variance = control_flow_util.smart_cond(
685 training,
686 lambda: variance,
687 lambda: tf.convert_to_tensor(moving_variance),
688 )
690 if self.virtual_batch_size is not None:
691 # This isn't strictly correct since in ghost batch norm, you are
692 # supposed to sequentially update the moving_mean and
693 # moving_variance with each sub-batch. However, since the moving
694 # statistics are only used during evaluation, it is more
695 # efficient to just update in one step and should not make a
696 # significant difference in the result.
697 new_mean = tf.reduce_mean(mean, axis=1, keepdims=True)
698 new_variance = tf.reduce_mean(variance, axis=1, keepdims=True)
699 else:
700 if (
701 utils.running_with_dtensor_strategy()
702 and not self.synchronized
703 ):
704 new_mean = tf.math.reduce_mean(mean, axis=reduction_axes)
705 new_variance = tf.math.reduce_mean(
706 variance, axis=reduction_axes
707 )
708 else:
709 new_mean, new_variance = mean, variance
711 if self._support_zero_size_input():
712 # Keras assumes that the batch dimension is the first dimension for
713 # Batch Normalization.
714 input_batch_size = tf.shape(inputs)[0]
715 else:
716 input_batch_size = None
718 if self.renorm:
719 (
720 r,
721 d,
722 new_mean,
723 new_variance,
724 ) = self._renorm_correction_and_moments(
725 new_mean, new_variance, training, input_batch_size
726 )
727 # When training, the normalized values (say, x) will be
728 # transformed as x * gamma + beta without renorm, and (x * r +
729 # d) * gamma + beta = x * (r * gamma) + (d * gamma + beta) with
730 # renorm.
731 r = _broadcast(tf.stop_gradient(r, name="renorm_r"))
732 d = _broadcast(tf.stop_gradient(d, name="renorm_d"))
733 scale, offset = _compose_transforms(r, d, scale, offset)
735 def _do_update(var, value):
736 """Compute the updates for mean and variance."""
737 return self._assign_moving_average(
738 var, value, self.momentum, input_batch_size
739 )
741 def mean_update():
742 true_branch = lambda: _do_update(self.moving_mean, new_mean)
743 false_branch = lambda: self.moving_mean
744 return control_flow_util.smart_cond(
745 training, true_branch, false_branch
746 )
748 def variance_update():
749 """Update the moving variance."""
751 def true_branch_renorm():
752 # We apply epsilon as part of the moving_stddev to mirror
753 # the training code path.
754 moving_stddev = _do_update(
755 self.moving_stddev, tf.sqrt(new_variance + self.epsilon)
756 )
757 return self._assign_new_value(
758 self.moving_variance,
759 # Apply relu in case floating point rounding causes it
760 # to go negative.
761 backend.relu(
762 moving_stddev * moving_stddev - self.epsilon
763 ),
764 )
766 if self.renorm:
767 true_branch = true_branch_renorm
768 else:
769 true_branch = lambda: _do_update(
770 self.moving_variance, new_variance
771 )
773 false_branch = lambda: self.moving_variance
774 return control_flow_util.smart_cond(
775 training, true_branch, false_branch
776 )
778 self.add_update(mean_update)
779 self.add_update(variance_update)
780 # End of handling mean/variance calculation and update.
782 mean = tf.cast(mean, inputs.dtype)
783 variance = tf.cast(variance, inputs.dtype)
784 if offset is not None:
785 offset = tf.cast(offset, inputs.dtype)
786 if scale is not None:
787 scale = tf.cast(scale, inputs.dtype)
788 outputs = tf.nn.batch_normalization(
789 inputs,
790 _broadcast(mean),
791 _broadcast(variance),
792 offset,
793 scale,
794 self.epsilon,
795 )
796 if inputs_dtype in (tf.float16, tf.bfloat16):
797 outputs = tf.cast(outputs, inputs_dtype)
799 # If some components of the shape got lost due to adjustments, fix that.
800 outputs.set_shape(input_shape)
802 if self.virtual_batch_size is not None:
803 outputs = undo_virtual_batching(outputs)
804 return outputs
806 def compute_output_shape(self, input_shape):
807 return input_shape
809 def get_config(self):
810 config = {
811 "axis": self.axis,
812 "momentum": self.momentum,
813 "epsilon": self.epsilon,
814 "center": self.center,
815 "scale": self.scale,
816 "beta_initializer": initializers.serialize(self.beta_initializer),
817 "gamma_initializer": initializers.serialize(self.gamma_initializer),
818 "moving_mean_initializer": initializers.serialize(
819 self.moving_mean_initializer
820 ),
821 "moving_variance_initializer": initializers.serialize(
822 self.moving_variance_initializer
823 ),
824 "beta_regularizer": regularizers.serialize(self.beta_regularizer),
825 "gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
826 "beta_constraint": constraints.serialize(self.beta_constraint),
827 "gamma_constraint": constraints.serialize(self.gamma_constraint),
828 }
829 # Only add TensorFlow-specific parameters if they are set, so as to
830 # preserve model compatibility with external Keras.
831 if self.renorm:
832 config["renorm"] = True
833 config["renorm_clipping"] = self.renorm_clipping
834 config["renorm_momentum"] = self.renorm_momentum
835 if self.virtual_batch_size is not None:
836 config["virtual_batch_size"] = self.virtual_batch_size
837 # Note: adjustment is not serializable.
838 if self.adjustment is not None:
839 logging.warning(
840 "The `adjustment` function of this `BatchNormalization` "
841 "layer cannot be serialized and has been omitted from "
842 "the layer config. It will not be included when "
843 "re-creating the layer from the saved config."
844 )
845 base_config = super().get_config()
846 return dict(list(base_config.items()) + list(config.items()))
848 ######################## Start of private methods ##########################
849 def _support_zero_size_input(self):
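        # Returns True when the active tf.distribute strategy may produce
        # partial (possibly empty) per-replica batches, in which case the
        # moment computation and moving-average updates must guard against a
        # zero batch size.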
850 if not tf.distribute.has_strategy():
851 return False
852 strategy = tf.distribute.get_strategy()
853 # TODO(b/195085185): remove experimental_enable_get_next_as_optional
854 # after migrating all users.
855 return getattr(
856 strategy.extended,
857 "enable_partial_batch_handling",
858 getattr(
859 strategy.extended,
860 "experimental_enable_get_next_as_optional",
861 False,
862 ),
863 )
865 def _assign_moving_average(self, variable, value, momentum, inputs_size):
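        # Implements the exponential moving average from the class docstring,
        # `variable = variable * momentum + value * (1 - momentum)`, rewritten
        # as `variable -= (variable - value) * (1 - momentum)` so it can be
        # applied with a single assign_sub below.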
866 def calculate_update_delta():
867 decay = tf.convert_to_tensor(1.0 - momentum, name="decay")
868 if decay.dtype != variable.dtype.base_dtype:
869 decay = tf.cast(decay, variable.dtype.base_dtype)
870 update_delta = (variable - tf.cast(value, variable.dtype)) * decay
871 if inputs_size is not None:
872 update_delta = tf.where(
873 inputs_size > 0,
874 update_delta,
875 backend.zeros_like(update_delta),
876 )
877 return update_delta
879 with backend.name_scope("AssignMovingAvg") as scope:
880 if tf.compat.v1.executing_eagerly_outside_functions():
881 return variable.assign_sub(calculate_update_delta(), name=scope)
882 else:
883 with tf.compat.v1.colocate_with(variable):
884 return tf.compat.v1.assign_sub(
885 variable, calculate_update_delta(), name=scope
886 )
888 def _assign_new_value(self, variable, value):
889 with backend.name_scope("AssignNewValue") as scope:
890 if tf.compat.v1.executing_eagerly_outside_functions():
891 return variable.assign(value, name=scope)
892 else:
893 with tf.compat.v1.colocate_with(variable):
894 return tf.compat.v1.assign(variable, value, name=scope)
896 def _fused_batch_norm(self, inputs, mask, training):
897 """Returns the output of fused batch norm."""
898 if mask is not None:
899 warnings.warn(
900 "Masking is not supported with `fused=True`. "
901 "You should either turn off fusing "
902 "(`fused=False`) or you should not pass a `mask` "
903 "argument when calling the layer. "
904 "For the moment `mask` will be ignored for the "
905 "normalization."
906 )
907 if self.center:
908 beta = self.beta
909 else:
910 beta = backend.constant(
911 0.0, dtype=self._param_dtype, shape=self._param_shape
912 )
913 if self.scale:
914 gamma = self.gamma
915 else:
916 gamma = backend.constant(
917 1.0, dtype=self._param_dtype, shape=self._param_shape
918 )
920 # TODO(b/129279393): Support zero batch input in non
921 # DistributionStrategy code as well.
922 if self._support_zero_size_input():
923 # Keras assumes that the batch dimension is the first dimension for
924 # Batch Normalization.
925 input_batch_size = tf.shape(inputs)[0]
926 else:
927 input_batch_size = None
929 # TODO(rmlarsen): Support using fused avg updates for non-eager
930 # execution after fixing graph pattern matching and enabling
931 # fused_batch_norm to take exponential_avg_factor as a tensor input.
932 use_fused_avg_updates = (
933 tf.compat.v1.executing_eagerly_outside_functions()
934 and isinstance(self.momentum, (float, int))
935 and get_enclosing_xla_context() is None
936 )
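        # With fused average updates, the fused op blends the running
        # statistics with factor `1 - momentum` and returns them, so
        # `mean_update` / `variance_update` below assign the returned values
        # directly instead of applying the moving-average update manually.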
937 if use_fused_avg_updates:
938 exponential_avg_factor = 1.0 - self.momentum
939 else:
940 exponential_avg_factor = None
942 def _maybe_add_or_remove_bessels_correction(variance, remove=True):
943 r"""Add or remove Bessel's correction."""
944 # Removes Bessel's correction if remove == True, adds it otherwise.
945 # This is to be consistent with non-fused batch norm. Note that the
946 # variance computed by fused batch norm is with Bessel's correction.
947 # This is only used in legacy V1 batch norm tests.
948 if self._bessels_correction_test_only:
949 return variance
950 sample_size = tf.cast(
951 tf.size(inputs) / tf.size(variance), variance.dtype
952 )
953 if remove:
954 factor = (
955 sample_size - tf.cast(1.0, variance.dtype)
956 ) / sample_size
957 else:
958 factor = sample_size / (
959 sample_size - tf.cast(1.0, variance.dtype)
960 )
961 return variance * factor
963 def _fused_batch_norm_training():
964 return tf.compat.v1.nn.fused_batch_norm(
965 inputs,
966 gamma,
967 beta,
968 mean=self.moving_mean,
969 variance=_maybe_add_or_remove_bessels_correction(
970 self.moving_variance, remove=False
971 ),
972 epsilon=self.epsilon,
973 is_training=True,
974 data_format=self._data_format,
975 exponential_avg_factor=exponential_avg_factor,
976 )
978 def _fused_batch_norm_inference():
979 return tf.compat.v1.nn.fused_batch_norm(
980 inputs,
981 gamma,
982 beta,
983 mean=self.moving_mean,
984 variance=self.moving_variance,
985 epsilon=self.epsilon,
986 is_training=False,
987 data_format=self._data_format,
988 )
990 output, mean, variance = control_flow_util.smart_cond(
991 training, _fused_batch_norm_training, _fused_batch_norm_inference
992 )
993 variance = _maybe_add_or_remove_bessels_correction(
994 variance, remove=True
995 )
997 training_value = control_flow_util.constant_value(training)
998 if training_value or training_value is None:
999 if not use_fused_avg_updates:
1000 if training_value is None:
1001 momentum = control_flow_util.smart_cond(
1002 training, lambda: self.momentum, lambda: 1.0
1003 )
1004 else:
1005 momentum = tf.convert_to_tensor(self.momentum)
1007 def mean_update():
1008 """Update self.moving_mean with the most recent data point."""
1009 if use_fused_avg_updates:
1010 if input_batch_size is not None:
1011 new_mean = control_flow_util.smart_cond(
1012 input_batch_size > 0,
1013 lambda: mean,
1014 lambda: self.moving_mean,
1015 )
1016 else:
1017 new_mean = mean
1018 return self._assign_new_value(self.moving_mean, new_mean)
1019 else:
1020 return self._assign_moving_average(
1021 self.moving_mean, mean, momentum, input_batch_size
1022 )
1024 def variance_update():
1025 """Update self.moving_variance with the most recent data
1026 point."""
1027 if use_fused_avg_updates:
1028 if input_batch_size is not None:
1029 new_variance = control_flow_util.smart_cond(
1030 input_batch_size > 0,
1031 lambda: variance,
1032 lambda: self.moving_variance,
1033 )
1034 else:
1035 new_variance = variance
1036 return self._assign_new_value(
1037 self.moving_variance, new_variance
1038 )
1039 else:
1040 return self._assign_moving_average(
1041 self.moving_variance,
1042 variance,
1043 momentum,
1044 input_batch_size,
1045 )
1047 self.add_update(mean_update)
1048 self.add_update(variance_update)
1050 return output
1052 def _renorm_correction_and_moments(
1053 self, mean, variance, training, inputs_size
1054 ):
1055 """Returns the correction and update values for renorm."""
1056 stddev = tf.sqrt(variance + self.epsilon)
1057 # Compute the average mean and standard deviation, as if they were
1058 # initialized with this batch's moments.
1059 renorm_mean = self.renorm_mean
1060 # Avoid divide by zero early on in training.
1061 renorm_stddev = tf.maximum(self.renorm_stddev, tf.sqrt(self.epsilon))
1062 # Compute the corrections for batch renorm.
1063 r = stddev / renorm_stddev
1064 d = (mean - renorm_mean) / renorm_stddev
1065 # Ensure the corrections use pre-update moving averages.
1066 with tf.control_dependencies([r, d]):
1067 mean = tf.identity(mean)
1068 stddev = tf.identity(stddev)
1069 rmin, rmax, dmax = [
1070 self.renorm_clipping.get(key) for key in ["rmin", "rmax", "dmax"]
1071 ]
1072 if rmin is not None:
1073 r = tf.maximum(r, rmin)
1074 if rmax is not None:
1075 r = tf.minimum(r, rmax)
1076 if dmax is not None:
1077 d = tf.maximum(d, -dmax)
1078 d = tf.minimum(d, dmax)
1079 # When not training, use r=1, d=0.
1080 r = control_flow_util.smart_cond(
1081 training, lambda: r, lambda: tf.ones_like(r)
1082 )
1083 d = control_flow_util.smart_cond(
1084 training, lambda: d, lambda: tf.zeros_like(d)
1085 )
1087 def _update_renorm_variable(var, value, inputs_size):
1088 """Updates a moving average and weight, returns the unbiased
1089 value."""
1090 value = tf.identity(value)
1092 def _do_update():
1093 """Updates the var, returns the updated value."""
1094 new_var = self._assign_moving_average(
1095 var, value, self.renorm_momentum, inputs_size
1096 )
1097 return new_var
1099 def _fake_update():
1100 return tf.identity(var)
1102 return control_flow_util.smart_cond(
1103 training, _do_update, _fake_update
1104 )
1106 # TODO(yuefengz): colocate the operations
1107 update_new_mean = _update_renorm_variable(
1108 self.renorm_mean, mean, inputs_size
1109 )
1110 update_new_stddev = _update_renorm_variable(
1111 self.renorm_stddev, stddev, inputs_size
1112 )
1114 # Update the inference mode moving averages with the batch value.
1115 with tf.control_dependencies([update_new_mean, update_new_stddev]):
1116 out_mean = tf.identity(mean)
1117 out_variance = tf.identity(variance)
1119 return (r, d, out_mean, out_variance)
1121 def _calculate_mean_and_var(
1122 self, inputs, reduction_axes, keep_dims, mask=None
1123 ):
1124 if self.synchronized:
1125 return self._sync_calculate_mean_and_var(
1126 inputs, reduction_axes, keep_dims, mask=mask
1127 )
1128 return self._no_sync_calculate_mean_and_var(
1129 inputs, reduction_axes, keep_dims, mask=mask
1130 )
1132 def _no_sync_calculate_mean_and_var(
1133 self, inputs, reduction_axes, keep_dims, mask=None
1134 ):
1135 if mask is None:
1136 return tf.nn.moments(inputs, reduction_axes, keepdims=keep_dims)
1137 else:
1138 mask_weights = tf.cast(
1139 mask, self.compute_dtype, name="mask_weights"
1140 )
1141 mask_weights = tf.expand_dims(
1142 mask_weights, axis=-1, name="mask_weights_broadcasted"
1143 )
1144 return tf.nn.weighted_moments(
1145 inputs,
1146 axes=reduction_axes,
1147 frequency_weights=mask_weights,
1148 keepdims=keep_dims,
1149 )
1151 def _sync_calculate_mean_and_var(
1152 self, x, reduction_axes, keep_dims, mask=None
1153 ):
1154 with backend.name_scope("moments"):
1155 # The dynamic range of fp16 is too limited to support the collection
1156 # of sufficient statistics. As a workaround we simply perform the
1157 # operations on 32-bit floats before converting the mean and
1158 # variance back to fp16
1159 y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
1160 replica_ctx = tf.distribute.get_replica_context()
1162 if not replica_ctx:
1163 return self._no_sync_calculate_mean_and_var(
1164 x, reduction_axes, keep_dims, mask=mask
1165 )
1167 if mask is not None:
1168 mask_weights = tf.cast(mask, y.dtype, name="mask_weights")
1169 mask_weights = tf.expand_dims(
1170 mask_weights, axis=-1, name="mask_weights_broadcasted"
1171 )
1172 y *= mask_weights
1173 local_count = tf.broadcast_to(
1174 mask_weights, tf.shape(y), name="count"
1175 )
1176 else:
1177 local_count = tf.ones_like(y, name="count")
1179 local_sum = tf.reduce_sum(y, axis=reduction_axes, keepdims=True)
1180 local_squared_sum = tf.reduce_sum(
1181 tf.square(y), axis=reduction_axes, keepdims=True
1182 )
1183 local_count = tf.reduce_sum(
1184 local_count, axis=reduction_axes, keepdims=True
1185 )
1187 # TODO(b/163099951): batch the all-reduces once we sort out the
1188 # ordering issue for NCCL. We don't have a mechanism to launch
1189 # NCCL in the same order in each replica nowadays, so we limit
1190 # NCCL to batch all-reduces.
1191 y_sum = replica_ctx.all_reduce(
1192 tf.distribute.ReduceOp.SUM, local_sum
1193 )
1194 y_squared_sum = replica_ctx.all_reduce(
1195 tf.distribute.ReduceOp.SUM, local_squared_sum
1196 )
1197 count_sum = replica_ctx.all_reduce(
1198 tf.distribute.ReduceOp.SUM, local_count
1199 )
1201 mean = y_sum / count_sum
1202 y_squared_mean = y_squared_sum / count_sum
1203 # var = E(x^2) - E(x)^2
1204 variance = y_squared_mean - tf.square(mean)
1205 if not keep_dims:
1206 mean = tf.squeeze(mean, reduction_axes)
1207 variance = tf.squeeze(variance, reduction_axes)
1208 if x.dtype == tf.float16:
1209 return (
1210 tf.cast(mean, tf.float16),
1211 tf.cast(variance, tf.float16),
1212 )
1213 else:
1214 return (mean, variance)
1216 def _dtensor_calculate_mean_and_var(
1217 self, inputs, reduction_axes, keep_dims, mask=None
1218 ):
1219 if self.synchronized:
1220 return self._dtensor_sync_calculate_mean_and_var(
1221 inputs, reduction_axes, keep_dims, mask=mask
1222 )
1223 return self._dtensor_no_sync_calculate_mean_and_var(
1224 inputs, reduction_axes, keep_dims, mask=mask
1225 )
1227 def _dtensor_no_sync_calculate_mean_and_var(
1228 self, inputs, reduction_axes, keep_dims, mask=None
1229 ):
1230 replica_tensor = _expand_tensor_with_local_replica_group(inputs)
1231 local_batch_size = tf.shape(replica_tensor)[1]
1233 # Since we added a new axis in the beginning, all the values in
1234 # reduction_axes need to be incremented by 1.
1235 updated_reduction_axes = [n + 1 for n in reduction_axes]
1237 if mask is None:
1238 mean, var = tf.nn.moments(
1239 replica_tensor, updated_reduction_axes, keepdims=keep_dims
1240 )
1241 else:
1242 mask_weights = tf.cast(
1243 mask, self.compute_dtype, name="mask_weights"
1244 )
1245 mask_weights = tf.expand_dims(
1246 mask_weights, axis=-1, name="mask_weights_broadcasted"
1247 )
1248 mask_weights = _expand_tensor_with_local_replica_group(mask_weights)
1249 mean, var = tf.nn.weighted_moments(
1250 replica_tensor,
1251 axes=updated_reduction_axes,
1252 frequency_weights=mask_weights,
1253 keepdims=keep_dims,
1254 )
1255 # Also note that the mean/var we have here will have an extra dim in
1256 # axis 0, which represents the number of local replicas. Downstream,
1257 # the mean/var will be used to update the moving_mean/var
1258 # and also to normalize the inputs. To make the shapes match, we will
1259 # expand the tensor shape from [num_replica, x, y] to
1260 # [batch_size, x, y] so that it can be properly used for
1261 # normalization. When it reaches the mean/var update, separate
1262 # logic will reduce_mean the value over the batch
1263 # dim.
1264 mean = tf.repeat(mean, local_batch_size, axis=0)
1265 var = tf.repeat(var, local_batch_size, axis=0)
1266 if not keep_dims:
1267 # We need to fill the reduced dims so that the mean/var can be
1268 # properly broadcast to the input shapes. In the example above,
1269 # the original reduction_axes is [0, 1]. We ignore the first 0
1270 # (batch dim) here since we already expanded and used it as num_replica.
1271 for dim in reduction_axes[1:]:
1272 mean = tf.expand_dims(mean, axis=dim)
1273 var = tf.expand_dims(var, axis=dim)
1274 return mean, var
1276 def _dtensor_sync_calculate_mean_and_var(
1277 self, inputs, reduction_axes, keep_dims, mask=None
1278 ):
1279 # In the DTensor sync BN, since the input tensor is already in the
1280 # global context, we just need to use the normal moments/weighted_moments
1281 # to calculate mean/var, which is the same as the non-sync BN in
1282 # normal mode.
1283 return self._no_sync_calculate_mean_and_var(
1284 inputs, reduction_axes, keep_dims, mask
1285 )
1287 def _moments(self, inputs, reduction_axes, keep_dims, mask=None):
1288 if utils.running_with_dtensor_strategy():
1289 mean, variance = self._dtensor_calculate_mean_and_var(
1290 inputs, reduction_axes, keep_dims, mask=mask
1291 )
1292 else:
1293 mean, variance = self._calculate_mean_and_var(
1294 inputs, reduction_axes, keep_dims, mask=mask
1295 )
1296 # TODO(b/129279393): Support zero batch input in non
1297 # DistributionStrategy code as well.
1298 if self._support_zero_size_input():
1299 input_batch_size = tf.shape(inputs)[0]
1300 mean = tf.where(
1301 input_batch_size > 0, mean, backend.zeros_like(mean)
1302 )
1303 variance = tf.where(
1304 input_batch_size > 0, variance, backend.zeros_like(variance)
1305 )
1306 return mean, variance
1308 def _get_training_value(self, training=None):
1309 if training is None:
1310 training = backend.learning_phase()
1311 if self._USE_V2_BEHAVIOR:
1312 if isinstance(training, int):
1313 training = bool(training)
1314 if not self.trainable:
1315 # When the layer is not trainable, it overrides the value passed
1316 # from model.
1317 training = False
1318 return training
1321@keras_export("keras.layers.BatchNormalization", v1=[])
1322class BatchNormalization(BatchNormalizationBase):
1323 """Layer that normalizes its inputs.
1325 Batch normalization applies a transformation that maintains the mean output
1326 close to 0 and the output standard deviation close to 1.
1328 Importantly, batch normalization works differently during training and
1329 during inference.
1331 **During training** (i.e. when using `fit()` or when calling the layer/model
1332 with the argument `training=True`), the layer normalizes its output using
1333 the mean and standard deviation of the current batch of inputs. That is to
1334 say, for each channel being normalized, the layer returns
1335 `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:
1337 - `epsilon` is a small constant (configurable as part of the constructor
1338 arguments)
1339 - `gamma` is a learned scaling factor (initialized as 1), which
1340 can be disabled by passing `scale=False` to the constructor.
1341 - `beta` is a learned offset factor (initialized as 0), which
1342 can be disabled by passing `center=False` to the constructor.
1344 **During inference** (i.e. when using `evaluate()` or `predict()`) or when
1345 calling the layer/model with the argument `training=False` (which is the
1346 default), the layer normalizes its output using a moving average of the
1347 mean and standard deviation of the batches it has seen during training. That
1348 is to say, it returns
1349 `gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta`.
1351 `self.moving_mean` and `self.moving_var` are non-trainable variables that
1352 are updated each time the layer is called in training mode, as such:
1354 - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
1355 - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`
1357 As such, the layer will only normalize its inputs during inference
1358 *after having been trained on data that has statistics similar to the
1359 inference data*.
1361 When `synchronized=True` is set and if this layer is used within a
1362 `tf.distribute` strategy, there will be an `allreduce` call
1363 to aggregate batch statistics across all replicas at every
1364 training step. Setting `synchronized` has no impact when the model is
1365 trained without specifying any distribution strategy.
1367 Example usage:
1369 ```python
1370 strategy = tf.distribute.MirroredStrategy()
1372 with strategy.scope():
1373 model = tf.keras.Sequential()
1374 model.add(tf.keras.layers.Dense(16))
1375 model.add(tf.keras.layers.BatchNormalization(synchronized=True))
1376 ```
1378 Args:
1379 axis: Integer, the axis that should be normalized (typically the features
1380 axis). For instance, after a `Conv2D` layer with
1381 `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
1382 momentum: Momentum for the moving average.
1383 epsilon: Small float added to variance to avoid dividing by zero.
1384 center: If True, add offset of `beta` to normalized tensor. If False,
1385 `beta` is ignored.
1386 scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
1387 the next layer is linear (this also applies to, e.g., `nn.relu`), this can
1388 be disabled since the scaling will be done by the next layer.
1389 beta_initializer: Initializer for the beta weight.
1390 gamma_initializer: Initializer for the gamma weight.
1391 moving_mean_initializer: Initializer for the moving mean.
1392 moving_variance_initializer: Initializer for the moving variance.
1393 beta_regularizer: Optional regularizer for the beta weight.
1394 gamma_regularizer: Optional regularizer for the gamma weight.
1395 beta_constraint: Optional constraint for the beta weight.
1396 gamma_constraint: Optional constraint for the gamma weight.
1397 synchronized: If True, synchronizes the global batch statistics (mean and
1398 variance) for the layer across all devices at each training step in a
1399 distributed training strategy. If False, each replica uses its own
1400 local batch statistics. Only relevant when used inside a
1401 `tf.distribute` strategy.
1403 Call arguments:
1404 inputs: Input tensor (of any rank).
1405 training: Python boolean indicating whether the layer should behave in
1406 training mode or in inference mode.
1407 - `training=True`: The layer will normalize its inputs using the mean
1408 and variance of the current batch of inputs.
1409 - `training=False`: The layer will normalize its inputs using the mean
1410 and variance of its moving statistics, learned during training.
1412 Input shape:
1413 Arbitrary. Use the keyword argument `input_shape` (tuple of
1414 integers, does not include the samples axis) when using this layer as the
1415 first layer in a model.
1417 Output shape:
1418 Same shape as input.
1420 Reference:
1421 - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
1423 **About setting `layer.trainable = False` on a `BatchNormalization` layer:**
1425 The meaning of setting `layer.trainable = False` is to freeze the layer,
1426 i.e. its internal state will not change during training:
1427 its trainable weights will not be updated
1428 during `fit()` or `train_on_batch()`, and its state updates will not be run.
1430 Usually, this does not necessarily mean that the layer is run in inference
1431 mode (which is normally controlled by the `training` argument that can
1432 be passed when calling a layer). "Frozen state" and "inference mode"
1433 are two separate concepts.
1435 However, in the case of the `BatchNormalization` layer, **setting
1436 `trainable = False` on the layer means that the layer will be
1437 subsequently run in inference mode** (meaning that it will use
1438 the moving mean and the moving variance to normalize the current batch,
1439 rather than using the mean and variance of the current batch).
1441 This behavior has been introduced in TensorFlow 2.0, in order
1442 to enable `layer.trainable = False` to produce the most commonly
1443 expected behavior in the convnet fine-tuning use case.
1445 Note that:
1446 - Setting `trainable` on a model containing other layers will
1447 recursively set the `trainable` value of all inner layers.
1448 - If the value of the `trainable`
1449 attribute is changed after calling `compile()` on a model,
1450 the new value doesn't take effect for this model
1451 until `compile()` is called again.
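  Example (a minimal freezing sketch):
  ```python
  import tensorflow as tf
  bn = tf.keras.layers.BatchNormalization()
  x = tf.ones((2, 4))
  _ = bn(x, training=True)   # uses batch statistics; updates moving statistics
  bn.trainable = False
  _ = bn(x, training=True)   # frozen: runs in inference mode, no state updates
  ```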
1452 """
1454 _USE_V2_BEHAVIOR = True
1456 @utils.allow_initializer_layout
1457 def __init__(
1458 self,
1459 axis=-1,
1460 momentum=0.99,
1461 epsilon=1e-3,
1462 center=True,
1463 scale=True,
1464 beta_initializer="zeros",
1465 gamma_initializer="ones",
1466 moving_mean_initializer="zeros",
1467 moving_variance_initializer="ones",
1468 beta_regularizer=None,
1469 gamma_regularizer=None,
1470 beta_constraint=None,
1471 gamma_constraint=None,
1472 synchronized=False,
1473 **kwargs,
1474 ):
1475 # Currently we only support aggregating over the global batch size.
1476 super().__init__(
1477 axis=axis,
1478 momentum=momentum,
1479 epsilon=epsilon,
1480 center=center,
1481 scale=scale,
1482 beta_initializer=beta_initializer,
1483 gamma_initializer=gamma_initializer,
1484 moving_mean_initializer=moving_mean_initializer,
1485 moving_variance_initializer=moving_variance_initializer,
1486 beta_regularizer=beta_regularizer,
1487 gamma_regularizer=gamma_regularizer,
1488 beta_constraint=beta_constraint,
1489 gamma_constraint=gamma_constraint,
1490 synchronized=synchronized,
1491 **kwargs,
1492 )
1495@keras_export("keras.layers.experimental.SyncBatchNormalization", v1=[])
1496@deprecation.deprecated_endpoints(
1497 "keras.layers.experimental.SyncBatchNormalization"
1498)
1499class SyncBatchNormalization(BatchNormalizationBase):
1500 """Deprecated. Please use `tf.keras.layers.BatchNormalization` instead.
1502 Caution: `tf.keras.layers.experimental.SyncBatchNormalization` endpoint is
1503 deprecated and will be removed in a future release. Please use
1504 `tf.keras.layers.BatchNormalization` with the parameter `synchronized`
1505 set to `True`.
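  Example migration (a minimal sketch):
  ```python
  import tensorflow as tf
  # Before (deprecated):
  # layer = tf.keras.layers.experimental.SyncBatchNormalization()
  # After:
  layer = tf.keras.layers.BatchNormalization(synchronized=True)
  ```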
1506 """
1508 def __init__(
1509 self,
1510 axis=-1,
1511 momentum=0.99,
1512 epsilon=1e-3,
1513 center=True,
1514 scale=True,
1515 beta_initializer="zeros",
1516 gamma_initializer="ones",
1517 moving_mean_initializer="zeros",
1518 moving_variance_initializer="ones",
1519 beta_regularizer=None,
1520 gamma_regularizer=None,
1521 beta_constraint=None,
1522 gamma_constraint=None,
1523 **kwargs,
1524 ):
1525 warning = (
1526 "`tf.keras.layers.experimental.SyncBatchNormalization` endpoint is "
1527 "deprecated and will be removed in a future release. Please use "
1528 "`tf.keras.layers.BatchNormalization` with parameter "
1529 "`synchronized` set to True."
1530 )
1531 logging.log_first_n(logging.WARN, warning, 1)
1532 super().__init__(
1533 axis=axis,
1534 momentum=momentum,
1535 epsilon=epsilon,
1536 center=center,
1537 scale=scale,
1538 beta_initializer=beta_initializer,
1539 gamma_initializer=gamma_initializer,
1540 moving_mean_initializer=moving_mean_initializer,
1541 moving_variance_initializer=moving_variance_initializer,
1542 beta_regularizer=beta_regularizer,
1543 gamma_regularizer=gamma_regularizer,
1544 beta_constraint=beta_constraint,
1545 gamma_constraint=gamma_constraint,
1546 synchronized=True,
1547 **kwargs,
1548 )
1551def _expand_tensor_with_local_replica_group(inputs):
1552 """Reshape the input tensor to have an extra dimension of replica group.
1554 Under DTensor usage, the normal batch norm still needs to operate on
1555 a local batch size, which means we can't directly compute mean/var on a
1556 global tensor. In order to compute a local mean/var, we have to add a new
1557 dimension to the tensor so that the ops will not cross the replica boundary.
1558 E.g., for a global tensor with shape [8, x, y] and 2 local replicas, the
1559 output of this will be [2, 4, x, y], where the first dim is the number of
1560 replicas and the second dim is the local batch size. The following ops can
1561 then reduce along the local batch dimension.
1563 Note that this function should only be used under a DTensor-based strategy,
1564 and it will use the current strategy in the context to get the number of
1565 replicas.
1567 Args:
1568 inputs: Tensor with shape [global_batch_size, ...]
1570 Returns:
1571 Tensor with shape [num_replica, local_batch_size, ...]
1572 """
1573 # TODO(b/272382109): Implement this as an Op.
1574 input_shape = tf.shape(inputs)
1575 global_batch_size = input_shape[0]
1576 num_replica = tf.distribute.get_strategy().num_replicas_in_sync
1577 local_batch_size = global_batch_size // num_replica
1578 replica_shape = tf.stack([num_replica, local_batch_size])
1579 replica_shape = tf.concat([replica_shape, input_shape[1:]], axis=0)
1580 return tf.reshape(inputs, replica_shape)
1583def _raise_for_non_sync_bn_with_renorm_and_dtensor_strategy(
1584 synchronized, training, renorm
1585):
1586 if (
1587 utils.running_with_dtensor_strategy()
1588 and not synchronized
1589 and training == True
1590 and renorm
1591 ):
1592 raise NotImplementedError(
1593 "Renorm for BatchNormalization under DTensor based distribution "
1594 "strategy is not supported at the moment. Please file a feature "
1595 "request if this is blocking your adoption."
1596 )