
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The V2 implementation of Normalization layers.""" 

16 

17import warnings 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src import backend 

22from keras.src import constraints 

23from keras.src import initializers 

24from keras.src import regularizers 

25from keras.src.dtensor import utils 

26from keras.src.engine.base_layer import Layer 

27from keras.src.engine.input_spec import InputSpec 

28from keras.src.utils import control_flow_util 

29from keras.src.utils import tf_utils 

30 

31# isort: off 

32from tensorflow.python.ops.control_flow_ops import ( 

33 get_enclosing_xla_context, 

34) 

35from tensorflow.python.platform import tf_logging as logging 

36from tensorflow.python.util import deprecation 

37from tensorflow.python.util.tf_export import keras_export 

38 

39 

40class BatchNormalizationBase(Layer): 

41 r"""Layer that normalizes its inputs. 

42 

43 Batch normalization applies a transformation that maintains the mean output 

44 close to 0 and the output standard deviation close to 1. 

45 

46 Importantly, batch normalization works differently during training and 

47 during inference. 

48 

49 **During training** (i.e. when using `fit()` or when calling the layer/model 

50 with the argument `training=True`), the layer normalizes its output using 

51 the mean and standard deviation of the current batch of inputs. That is to 

52 say, for each channel being normalized, the layer returns 

53 `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where: 

54 

55 - `epsilon` is a small constant (configurable as part of the constructor

56 arguments) 

57 - `gamma` is a learned scaling factor (initialized as 1), which 

58 can be disabled by passing `scale=False` to the constructor. 

59 - `beta` is a learned offset factor (initialized as 0), which 

60 can be disabled by passing `center=False` to the constructor. 

61 

62 **During inference** (i.e. when using `evaluate()` or `predict()`) or when 

63 calling the layer/model with the argument `training=False` (which is the 

64 default), the layer normalizes its output using a moving average of the 

65 mean and standard deviation of the batches it has seen during training. That 

66 is to say, it returns 

67 `gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta`. 

68 

69 `self.moving_mean` and `self.moving_var` are non-trainable variables that 

70 are updated each time the layer is called in training mode, as follows:

71 

72 - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)` 

73 - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)` 

74 

75 As such, the layer will only normalize its inputs during inference 

76 *after having been trained on data that has similar statistics as the 

77 inference data*. 
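  For intuition, here is a small standalone sketch (not part of the layer
  implementation; the batch means below are arbitrary example values) of how a
  moving statistic evolves with the default `momentum=0.99`:

  ```python
  momentum = 0.99
  moving_mean = 0.0  # as produced by the default "zeros" initializer
  for batch_mean in [0.5, 0.4, 0.6]:  # per-batch means seen during training
      moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
  # moving_mean is now ~0.015: a slowly-moving estimate that the layer
  # uses at inference time.
  ```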

78 

79 Args: 

80 axis: Integer or a list of integers, the axis that should be normalized 

81 (typically the features axis). For instance, after a `Conv2D` layer with 

82 `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. 

83 momentum: Momentum for the moving average. 

84 epsilon: Small float added to variance to avoid dividing by zero. 

85 center: If True, add offset of `beta` to normalized tensor. If False, 

86 `beta` is ignored. 

87 scale: If True, multiply by `gamma`. If False, `gamma` is not used. When 

88 the next layer is linear (or, e.g., `nn.relu`), this can be disabled

89 since the scaling will be done by the next layer. 

90 beta_initializer: Initializer for the beta weight. 

91 gamma_initializer: Initializer for the gamma weight. 

92 moving_mean_initializer: Initializer for the moving mean. 

93 moving_variance_initializer: Initializer for the moving variance. 

94 beta_regularizer: Optional regularizer for the beta weight. 

95 gamma_regularizer: Optional regularizer for the gamma weight. 

96 beta_constraint: Optional constraint for the beta weight. 

97 gamma_constraint: Optional constraint for the gamma weight. 

98 renorm: Whether to use [Batch Renormalization]( 

99 https://arxiv.org/abs/1702.03275). This adds extra variables during 

100 training. The inference is the same for either value of this 

101 parameter. 

102 renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to 

103 scalar `Tensors` used to clip the renorm correction. The correction `(r, 

104 d)` is used as `corrected_value = normalized_value * r + d`, with `r` 

105 clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, 

106 dmax are set to inf, 0, inf, respectively. 

107 renorm_momentum: Momentum used to update the moving means and standard 

108 deviations with renorm. Unlike `momentum`, this affects training and 

109 should be neither too small (which would add noise) nor too large (which 

110 would give stale estimates). Note that `momentum` is still applied to 

111 get the means and variances for inference. 

112 fused: if `True`, use a faster, fused implementation, or raise a 

113 ValueError if the fused implementation cannot be used. If `None`, use 

114 the faster implementation if possible. If False, do not use the fused

115 implementation. Note that in TensorFlow 1.x, the meaning of 

116 `fused=True` is different: if `False`, the layer uses the 

117 system-recommended implementation. You cannot use `fused=True` if a 

118 mask is passed in the `call()` method. 

119 trainable: Boolean, if `True` the variables will be marked as trainable. 

120 virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, 

121 which means batch normalization is performed across the whole batch. 

122 When `virtual_batch_size` is not `None`, instead perform "Ghost Batch 

123 Normalization", which creates virtual sub-batches which are each 

124 normalized separately (with shared gamma, beta, and moving statistics). 

125 Must divide the actual batch size during execution. 

126 adjustment: A function taking the `Tensor` containing the (dynamic) shape 

127 of the input tensor and returning a pair (scale, bias) to apply to the 

128 normalized values (before gamma and beta), only during training. For 

129 example, if `axis=-1`, 

130 `adjustment = lambda shape: ( 

131 tf.random.uniform(shape[-1:], 0.93, 1.07), 

132 tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized 

133 value by up to 7% up or down, then shift the result by up to 0.1 

134 (with independent scaling and bias for each feature but shared 

135 across all examples), and finally apply gamma and/or beta. If 

136 `None`, no adjustment is applied. Cannot be specified if 

137 virtual_batch_size is specified. 

138 synchronized: If True, synchronizes the global batch statistics (mean and 

139 variance) for the layer across all devices at each training step in a 

140 distributed training strategy. If False, each replica uses its own 

141 local batch statistics. Only relevant when used inside a 

142 `tf.distribute` strategy. 

143 

144 Call arguments: 

145 inputs: Input tensor (of any rank). 

146 training: Python boolean indicating whether the layer should behave in 

147 training mode or in inference mode. 

148 - `training=True`: The layer will normalize its inputs using the mean 

149 and variance of the current batch of inputs. 

150 - `training=False`: The layer will normalize its inputs using the mean 

151 and variance of its moving statistics, learned during training. 

152 mask: Binary tensor of shape broadcastable to `inputs` tensor, indicating 

153 the positions for which the mean and variance should be computed. 

154 

155 Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of 

156 integers, does not include the samples axis) when using this layer as the 

157 first layer in a model. 

158 

159 Output shape: Same shape as input. 

160 

161 Reference: 

162 - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). 

163 """ 

164 

165 # By default, the base class uses V2 behavior. The BatchNormalization V1 

166 # subclass sets this to False to use the V1 behavior. 

167 _USE_V2_BEHAVIOR = True 

168 

169 def __init__( 

170 self, 

171 axis=-1, 

172 momentum=0.99, 

173 epsilon=1e-3, 

174 center=True, 

175 scale=True, 

176 beta_initializer="zeros", 

177 gamma_initializer="ones", 

178 moving_mean_initializer="zeros", 

179 moving_variance_initializer="ones", 

180 beta_regularizer=None, 

181 gamma_regularizer=None, 

182 beta_constraint=None, 

183 gamma_constraint=None, 

184 renorm=False, 

185 renorm_clipping=None, 

186 renorm_momentum=0.99, 

187 fused=None, 

188 trainable=True, 

189 virtual_batch_size=None, 

190 adjustment=None, 

191 name=None, 

192 synchronized=False, 

193 **kwargs, 

194 ): 

195 super().__init__(name=name, **kwargs) 

196 if isinstance(axis, (list, tuple)): 

197 self.axis = axis[:] 

198 elif isinstance(axis, int): 

199 self.axis = axis 

200 else: 

201 raise TypeError( 

202 "Expected an int or a list/tuple of ints for the " 

203 "argument 'axis', but received: %r" % axis 

204 ) 

205 if synchronized and fused: 

206 raise ValueError( 

207 "`fused=True` is not supported when `synchronized=True`." 

208 ) 

209 self.synchronized = synchronized 

210 if self.synchronized: 

211 fused = False 

212 

213 self.momentum = momentum 

214 self.epsilon = epsilon 

215 self.center = center 

216 self.scale = scale 

217 self.beta_initializer = initializers.get(beta_initializer) 

218 self.gamma_initializer = initializers.get(gamma_initializer) 

219 self.moving_mean_initializer = initializers.get(moving_mean_initializer) 

220 self.moving_variance_initializer = initializers.get( 

221 moving_variance_initializer 

222 ) 

223 self.beta_regularizer = regularizers.get(beta_regularizer) 

224 self.gamma_regularizer = regularizers.get(gamma_regularizer) 

225 self.beta_constraint = constraints.get(beta_constraint) 

226 self.gamma_constraint = constraints.get(gamma_constraint) 

227 self.renorm = renorm 

228 self.virtual_batch_size = virtual_batch_size 

229 self.adjustment = adjustment 

230 if self._USE_V2_BEHAVIOR: 

231 if fused: 

232 self._raise_if_fused_cannot_be_used() 

233 # We leave fused as None if self._fused_can_be_used()==True, since 

234 # we still may set it to False in self.build() if the input rank is 

235 # not 4 or 5.

236 elif fused is None and not self._fused_can_be_used(): 

237 fused = False 

238 elif fused is None: 

239 fused = True 

240 self.supports_masking = True 

241 

242 self.fused = fused 

243 self._bessels_correction_test_only = True 

244 self.trainable = trainable 

245 

246 if renorm: 

247 renorm_clipping = renorm_clipping or {} 

248 keys = ["rmax", "rmin", "dmax"] 

249 if set(renorm_clipping) - set(keys): 

250 raise ValueError( 

251 "Received invalid keys for `renorm_clipping` argument: " 

252 f"{renorm_clipping}. Supported values: {keys}." 

253 ) 

254 self.renorm_clipping = renorm_clipping 

255 self.renorm_momentum = renorm_momentum 

256 

257 def _raise_if_fused_cannot_be_used(self): 

258 """Raises a ValueError if fused implementation cannot be used. 

259 

260 In addition to the checks done in this function, the input tensor's rank

261 must be 4 or 5. The input rank check can only be done once the input 

262 shape is known. 

263 """ 

264 # Note the ValueErrors in this function are caught and not reraised in 

265 # _fused_can_be_used(). No other exception besides ValueError should be 

266 # raised here. 

267 

268 # Currently fused batch norm doesn't support renorm. It also only 

269 # supports a channel dimension on axis 1 or 3 (rank=4) / 1 or 4 (rank=5),

270 # when no virtual batch size or adjustment is used. 

271 if self.renorm: 

272 raise ValueError( 

273 "Passing both `fused=True` and `renorm=True` is not supported" 

274 ) 

275 axis = [self.axis] if isinstance(self.axis, int) else self.axis 

276 # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, when the 

277 # input rank is 4. Similarly, the valid axis is -4, -1, 1, 4 when the 

278 # rank is 5. The combination of ranks and axes will be checked later. 

279 if len(axis) > 1 or axis[0] not in (-4, -3, -1, 1, 3, 4): 

280 raise ValueError( 

281 "Passing `fused=True` is only supported when axis is 1 " 

282 "or 3 for input rank = 4 or 1 or 4 for input rank = 5. " 

283 "Got axis %s" % (axis,) 

284 ) 

285 if self.virtual_batch_size is not None: 

286 raise ValueError( 

287 "Passing `fused=True` is not supported when " 

288 "`virtual_batch_size` is specified." 

289 ) 

290 if self.adjustment is not None: 

291 raise ValueError( 

292 "Passing `fused=True` is not supported when " 

293 "`adjustment` is specified." 

294 ) 

295 # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check. 

296 if self._compute_dtype not in ("float16", "bfloat16", "float32", None): 

297 raise ValueError( 

298 "Passing `fused=True` is only supported when the compute " 

299 "dtype is float16, bfloat16, or float32. Got dtype: %s" 

300 % (self._compute_dtype,) 

301 ) 

302 

303 def _fused_can_be_used(self): 

304 try: 

305 self._raise_if_fused_cannot_be_used() 

306 return True 

307 except ValueError: 

308 return False 

309 

310 @property 

311 def trainable(self): 

312 return self._trainable 

313 

314 @trainable.setter 

315 def trainable(self, value): 

316 self._trainable = value 

317 

318 @property 

319 def _param_dtype(self): 

320 # Raise parameters of fp16 batch norm to fp32 

321 if self.dtype == tf.float16 or self.dtype == tf.bfloat16: 

322 return tf.float32 

323 else: 

324 return self.dtype or tf.float32 

325 

326 def build(self, input_shape): 

327 self.axis = tf_utils.validate_axis(self.axis, input_shape) 

328 input_shape = tf.TensorShape(input_shape) 

329 rank = input_shape.rank 

330 

331 if self.virtual_batch_size is not None: 

332 if self.virtual_batch_size <= 0: 

333 raise ValueError( 

334 "`virtual_batch_size` must be a positive integer that " 

335 "divides the true batch size of the input tensor. " 

336 f"Received: virtual_batch_size={self.virtual_batch_size}" 

337 ) 

338 # If using virtual batches, the first dimension must be the batch 

339 # dimension and cannot be the batch norm axis 

340 if 0 in self.axis: 

341 raise ValueError( 

342 "When using `virtual_batch_size`, the batch dimension " 

343 "must be 0 and thus axis cannot include 0. " 

344 f"Received axis={self.axis}" 

345 ) 

346 if self.adjustment is not None: 

347 raise ValueError( 

348 "When using `virtual_batch_size`, adjustment cannot " 

349 "be specified" 

350 ) 

351 

352 if self.fused in (None, True): 

353 # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape 

354 # the output back to its original shape accordingly. 

355 if self._USE_V2_BEHAVIOR: 

356 if self.fused is None: 

357 self.fused = rank in (4, 5) 

358 elif self.fused and rank not in (4, 5): 

359 raise ValueError( 

360 "Batch normalization layers with `fused=True` only " 

361 "support 4D or 5D input tensors. " 

362 f"Received tensor with shape: {tuple(input_shape)}" 

363 ) 

364 else: 

365 assert self.fused is not None 

366 self.fused = rank in (4, 5) and self._fused_can_be_used() 

367 # TODO(chrisying): fused batch norm is currently not supported for 

368 # multi-axis batch norm and by extension virtual batches. In some 

369 # cases, it might be possible to use fused batch norm but would 

370 # require reshaping the Tensor to 4D with the axis in 1 or 3 

371 # (preferred 1) which is particularly tricky. A compromise might be 

372 # to just support the most common use case (turning 5D w/ virtual 

373 # batch to NCHW) 

374 

375 if self.fused: 

376 if self.axis == [1] and rank == 4: 

377 self._data_format = "NCHW" 

378 elif self.axis == [1] and rank == 5: 

379 self._data_format = "NCDHW" 

380 elif self.axis == [3] and rank == 4: 

381 self._data_format = "NHWC" 

382 elif self.axis == [4] and rank == 5: 

383 self._data_format = "NDHWC" 

384 elif rank == 5: 

385 # 5D tensors that can be passed in but should not use fused 

386 # batch norm due to unsupported axis. 

387 self.fused = False 

388 else: 

389 if rank == 4: 

390 raise ValueError( 

391 "Unsupported axis. The use of `fused=True` is only " 

392 "possible with `axis=1` or `axis=3` for 4D input " 

393 f"tensors. Received: axis={tuple(self.axis)}" 

394 ) 

395 else: 

396 raise ValueError( 

397 "Unsupported axis. The use of `fused=True` is only " 

398 "possible with `axis=1` or `axis=4` for 5D input " 

399 f"tensors. Received: axis={tuple(self.axis)}" 

400 ) 

401 

402 axis_to_dim = {x: input_shape.dims[x].value for x in self.axis} 

403 for x in axis_to_dim: 

404 if axis_to_dim[x] is None: 

405 raise ValueError( 

406 "Input has undefined `axis` dimension. Received input " 

407 f"with shape {tuple(input_shape)} " 

408 f"and axis={tuple(self.axis)}" 

409 ) 

410 self.input_spec = InputSpec(ndim=rank, axes=axis_to_dim) 

411 

412 if len(axis_to_dim) == 1 and self.virtual_batch_size is None: 

413 # Single axis batch norm (most common/default use-case) 

414 param_shape = (list(axis_to_dim.values())[0],) 

415 else: 

416 # Parameter shape is the original shape but with 1 in all non-axis 

417 # dims 

418 param_shape = [ 

419 axis_to_dim[i] if i in axis_to_dim else 1 for i in range(rank) 

420 ] 

421 if self.virtual_batch_size is not None: 

422 # When using virtual batches, add an extra dim at index 1 

423 param_shape.insert(1, 1) 

424 for idx, x in enumerate(self.axis): 

425 self.axis[idx] = x + 1 # Account for added dimension 

426 self._param_shape = param_shape 

427 if self.scale: 

428 self.gamma = self.add_weight( 

429 name="gamma", 

430 shape=param_shape, 

431 dtype=self._param_dtype, 

432 initializer=self.gamma_initializer, 

433 regularizer=self.gamma_regularizer, 

434 constraint=self.gamma_constraint, 

435 trainable=True, 

436 experimental_autocast=False, 

437 ) 

438 else: 

439 self.gamma = None 

440 

441 if self.center: 

442 self.beta = self.add_weight( 

443 name="beta", 

444 shape=param_shape, 

445 dtype=self._param_dtype, 

446 initializer=self.beta_initializer, 

447 regularizer=self.beta_regularizer, 

448 constraint=self.beta_constraint, 

449 trainable=True, 

450 experimental_autocast=False, 

451 ) 

452 else: 

453 self.beta = None 

454 

455 try: 

456 # Disable variable partitioning when creating the moving mean and 

457 # variance 

458 if hasattr(self, "_scope") and self._scope: 

459 partitioner = self._scope.partitioner 

460 self._scope.set_partitioner(None) 

461 else: 

462 partitioner = None 

463 self.moving_mean = self.add_weight( 

464 name="moving_mean", 

465 shape=param_shape, 

466 dtype=self._param_dtype, 

467 initializer=self.moving_mean_initializer, 

468 synchronization=tf.VariableSynchronization.ON_READ, 

469 trainable=False, 

470 aggregation=tf.VariableAggregation.MEAN, 

471 experimental_autocast=False, 

472 ) 

473 

474 self.moving_variance = self.add_weight( 

475 name="moving_variance", 

476 shape=param_shape, 

477 dtype=self._param_dtype, 

478 initializer=self.moving_variance_initializer, 

479 synchronization=tf.VariableSynchronization.ON_READ, 

480 trainable=False, 

481 aggregation=tf.VariableAggregation.MEAN, 

482 experimental_autocast=False, 

483 ) 

484 

485 if self.renorm: 

486 # In batch renormalization we track the inference moving stddev 

487 # instead of the moving variance to more closely align with the 

488 # paper. 

489 def moving_stddev_initializer(*args, **kwargs): 

490 return tf.sqrt( 

491 self.moving_variance_initializer(*args, **kwargs) 

492 ) 

493 

494 with tf.distribute.get_strategy().extended.colocate_vars_with( 

495 self.moving_variance 

496 ): 

497 self.moving_stddev = self.add_weight( 

498 name="moving_stddev", 

499 shape=param_shape, 

500 dtype=self._param_dtype, 

501 initializer=moving_stddev_initializer, 

502 synchronization=tf.VariableSynchronization.ON_READ, 

503 trainable=False, 

504 aggregation=tf.VariableAggregation.MEAN, 

505 experimental_autocast=False, 

506 ) 

507 

508 # Create variables to maintain the moving mean and standard 

509 # deviation. These are used in training and thus are different 

510 # from the moving averages above. The renorm variables are 

511 # colocated with moving_mean and moving_stddev. 

512 # NOTE: below, the outer `with device` block causes the current 

513 # device stack to be cleared. The nested ones use a `lambda` to 

514 # set the desired device and ignore any devices that may be set 

515 # by the custom getter. 

516 def _renorm_variable(name, shape, initializer="zeros"): 

517 """Create a renorm variable.""" 

518 var = self.add_weight( 

519 name=name, 

520 shape=shape, 

521 dtype=self._param_dtype, 

522 initializer=initializer, 

523 synchronization=tf.VariableSynchronization.ON_READ, 

524 trainable=False, 

525 aggregation=tf.VariableAggregation.MEAN, 

526 experimental_autocast=False, 

527 ) 

528 return var 

529 

530 with tf.distribute.get_strategy().extended.colocate_vars_with( 

531 self.moving_mean 

532 ): 

533 self.renorm_mean = _renorm_variable( 

534 "renorm_mean", param_shape, self.moving_mean_initializer 

535 ) 

536 with tf.distribute.get_strategy().extended.colocate_vars_with( 

537 self.moving_stddev 

538 ): 

539 self.renorm_stddev = _renorm_variable( 

540 "renorm_stddev", param_shape, moving_stddev_initializer 

541 ) 

542 finally: 

543 if partitioner: 

544 self._scope.set_partitioner(partitioner) 

545 self.built = True 

546 

547 def call(self, inputs, training=None, mask=None): 

548 inputs = tf.cast(inputs, self.compute_dtype) 

549 training = self._get_training_value(training) 

550 # Determine a boolean value for `training`: could be True, False, or 

551 # None. 

552 training_value = control_flow_util.constant_value(training) 

553 _raise_for_non_sync_bn_with_renorm_and_dtensor_strategy( 

554 synchronized=self.synchronized, 

555 training=training, 

556 renorm=self.renorm, 

557 ) 

558 

559 if self.virtual_batch_size is not None: 

560 # Virtual batches (aka ghost batches) can be simulated by reshaping 

561 # the Tensor and reusing the existing batch norm implementation 

562 original_shape = tf.shape(inputs) 

563 original_shape = tf.concat( 

564 [tf.constant([-1]), original_shape[1:]], axis=0 

565 ) 

566 

567 if tf.__internal__.tf2.enabled(): 

568 expanded_shape = ( 

569 [self.virtual_batch_size, -1] if training_value else [-1, 1] 

570 ) 

571 expanded_shape = tf.concat( 

572 [ 

573 tf.constant(expanded_shape), 

574 original_shape[1:], 

575 ], 

576 axis=0, 

577 ) 

578 else: 

579 # Preserve incorrect legacy behavior for backwards compatibility 

580 expanded_shape = tf.concat( 

581 [ 

582 tf.constant([self.virtual_batch_size, -1]), 

583 original_shape[1:], 

584 ], 

585 axis=0, 

586 ) 

587 

588 # Will cause errors if virtual_batch_size does not divide the batch 

589 # size 

590 inputs = tf.reshape(inputs, expanded_shape) 

591 

592 def undo_virtual_batching(outputs): 

593 outputs = tf.reshape(outputs, original_shape) 

594 return outputs 

595 

596 if self.fused: 

597 outputs = self._fused_batch_norm( 

598 inputs, mask=mask, training=training 

599 ) 

600 if self.virtual_batch_size is not None: 

601 # Currently never reaches here since fused_batch_norm does not 

602 # support virtual batching 

603 outputs = undo_virtual_batching(outputs) 

604 return outputs 

605 

606 inputs_dtype = inputs.dtype.base_dtype 

607 if inputs_dtype in (tf.float16, tf.bfloat16): 

608 # Do all math in float32 if given 16-bit inputs for numeric 

609 # stability. In particular, it's very easy for variance to overflow 

610 # in float16 and for safety we also choose to cast bfloat16 to 

611 # float32. 

612 inputs = tf.cast(inputs, tf.float32) 

613 

614 # Compute the axes along which to reduce the mean / variance 

615 input_shape = inputs.shape 

616 ndims = len(input_shape) 

617 reduction_axes = [i for i in range(ndims) if i not in self.axis] 

618 if self.virtual_batch_size is not None: 

619 del reduction_axes[1] # Do not reduce along virtual batch dim 

620 

621 # Broadcasting only necessary for single-axis batch norm where the axis 

622 # is not the last dimension 

623 broadcast_shape = [1] * ndims 

624 broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value 

625 

626 def _broadcast(v): 

627 if ( 

628 v is not None 

629 and len(v.shape) != ndims 

630 and reduction_axes != list(range(ndims - 1)) 

631 ): 

632 return tf.reshape(v, broadcast_shape) 

633 return v 

634 

635 scale, offset = _broadcast(self.gamma), _broadcast(self.beta) 

636 

637 def _compose_transforms(scale, offset, then_scale, then_offset): 

638 if then_scale is not None: 

639 scale *= then_scale 

640 offset *= then_scale 

641 if then_offset is not None: 

642 offset += then_offset 

643 return (scale, offset) 

644 

645 if training_value == False: # noqa: E712 

646 mean, variance = self.moving_mean, self.moving_variance 

647 else: 

648 # The following long block handles the mean/variance update during

649 # the training stage in a variety of different settings.

650 if self.adjustment: 

651 adj_scale, adj_bias = self.adjustment(tf.shape(inputs)) 

652 # Adjust only during training. 

653 adj_scale = control_flow_util.smart_cond( 

654 training, lambda: adj_scale, lambda: tf.ones_like(adj_scale) 

655 ) 

656 adj_bias = control_flow_util.smart_cond( 

657 training, lambda: adj_bias, lambda: tf.zeros_like(adj_bias) 

658 ) 

659 scale, offset = _compose_transforms( 

660 adj_scale, adj_bias, scale, offset 

661 ) 

662 

663 # Some of the computations here are not necessary when training==False,

664 # but `training` is not a constant here, so they are built anyway. This

665 # keeps the code simpler.

666 keep_dims = ( 

667 self.virtual_batch_size is not None or len(self.axis) > 1 

668 ) 

669 mean, variance = self._moments( 

670 tf.cast(inputs, self._param_dtype), 

671 reduction_axes, 

672 keep_dims=keep_dims, 

673 mask=mask, 

674 ) 

675 

676 moving_mean = self.moving_mean 

677 moving_variance = self.moving_variance 

678 

679 mean = control_flow_util.smart_cond( 

680 training, 

681 lambda: mean, 

682 lambda: tf.convert_to_tensor(moving_mean), 

683 ) 

684 variance = control_flow_util.smart_cond( 

685 training, 

686 lambda: variance, 

687 lambda: tf.convert_to_tensor(moving_variance), 

688 ) 

689 

690 if self.virtual_batch_size is not None: 

691 # This isn't strictly correct since in ghost batch norm, you are 

692 # supposed to sequentially update the moving_mean and 

693 # moving_variance with each sub-batch. However, since the moving 

694 # statistics are only used during evaluation, it is more 

695 # efficient to just update in one step and should not make a 

696 # significant difference in the result. 

697 new_mean = tf.reduce_mean(mean, axis=1, keepdims=True) 

698 new_variance = tf.reduce_mean(variance, axis=1, keepdims=True) 

699 else: 

700 if ( 

701 utils.running_with_dtensor_strategy() 

702 and not self.synchronized 

703 ): 

704 new_mean = tf.math.reduce_mean(mean, axis=reduction_axes) 

705 new_variance = tf.math.reduce_mean( 

706 variance, axis=reduction_axes 

707 ) 

708 else: 

709 new_mean, new_variance = mean, variance 

710 

711 if self._support_zero_size_input(): 

712 # Keras assumes that batch dimension is the first dimension for 

713 # Batch Normalization. 

714 input_batch_size = tf.shape(inputs)[0] 

715 else: 

716 input_batch_size = None 

717 

718 if self.renorm: 

719 ( 

720 r, 

721 d, 

722 new_mean, 

723 new_variance, 

724 ) = self._renorm_correction_and_moments( 

725 new_mean, new_variance, training, input_batch_size 

726 ) 

727 # When training, the normalized values (say, x) will be 

728 # transformed as x * gamma + beta without renorm, and (x * r + 

729 # d) * gamma + beta = x * (r * gamma) + (d * gamma + beta) with 

730 # renorm. 

731 r = _broadcast(tf.stop_gradient(r, name="renorm_r")) 

732 d = _broadcast(tf.stop_gradient(d, name="renorm_d")) 

733 scale, offset = _compose_transforms(r, d, scale, offset) 

734 

735 def _do_update(var, value): 

736 """Compute the updates for mean and variance.""" 

737 return self._assign_moving_average( 

738 var, value, self.momentum, input_batch_size 

739 ) 

740 

741 def mean_update(): 

742 true_branch = lambda: _do_update(self.moving_mean, new_mean) 

743 false_branch = lambda: self.moving_mean 

744 return control_flow_util.smart_cond( 

745 training, true_branch, false_branch 

746 ) 

747 

748 def variance_update(): 

749 """Update the moving variance.""" 

750 

751 def true_branch_renorm(): 

752 # We apply epsilon as part of the moving_stddev to mirror 

753 # the training code path. 

754 moving_stddev = _do_update( 

755 self.moving_stddev, tf.sqrt(new_variance + self.epsilon) 

756 ) 

757 return self._assign_new_value( 

758 self.moving_variance, 

759 # Apply relu in case floating point rounding causes it 

760 # to go negative. 

761 backend.relu( 

762 moving_stddev * moving_stddev - self.epsilon 

763 ), 

764 ) 

765 

766 if self.renorm: 

767 true_branch = true_branch_renorm 

768 else: 

769 true_branch = lambda: _do_update( 

770 self.moving_variance, new_variance 

771 ) 

772 

773 false_branch = lambda: self.moving_variance 

774 return control_flow_util.smart_cond( 

775 training, true_branch, false_branch 

776 ) 

777 

778 self.add_update(mean_update) 

779 self.add_update(variance_update) 

780 # End of handling mean/variance calculation and update. 

781 

782 mean = tf.cast(mean, inputs.dtype) 

783 variance = tf.cast(variance, inputs.dtype) 

784 if offset is not None: 

785 offset = tf.cast(offset, inputs.dtype) 

786 if scale is not None: 

787 scale = tf.cast(scale, inputs.dtype) 

788 outputs = tf.nn.batch_normalization( 

789 inputs, 

790 _broadcast(mean), 

791 _broadcast(variance), 

792 offset, 

793 scale, 

794 self.epsilon, 

795 ) 

796 if inputs_dtype in (tf.float16, tf.bfloat16): 

797 outputs = tf.cast(outputs, inputs_dtype) 

798 

799 # If some components of the shape got lost due to adjustments, fix that. 

800 outputs.set_shape(input_shape) 

801 

802 if self.virtual_batch_size is not None: 

803 outputs = undo_virtual_batching(outputs) 

804 return outputs 

805 

806 def compute_output_shape(self, input_shape): 

807 return input_shape 

808 

809 def get_config(self): 

810 config = { 

811 "axis": self.axis, 

812 "momentum": self.momentum, 

813 "epsilon": self.epsilon, 

814 "center": self.center, 

815 "scale": self.scale, 

816 "beta_initializer": initializers.serialize(self.beta_initializer), 

817 "gamma_initializer": initializers.serialize(self.gamma_initializer), 

818 "moving_mean_initializer": initializers.serialize( 

819 self.moving_mean_initializer 

820 ), 

821 "moving_variance_initializer": initializers.serialize( 

822 self.moving_variance_initializer 

823 ), 

824 "beta_regularizer": regularizers.serialize(self.beta_regularizer), 

825 "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), 

826 "beta_constraint": constraints.serialize(self.beta_constraint), 

827 "gamma_constraint": constraints.serialize(self.gamma_constraint), 

828 } 

829 # Only add TensorFlow-specific parameters if they are set, so as to 

830 # preserve model compatibility with external Keras. 

831 if self.renorm: 

832 config["renorm"] = True 

833 config["renorm_clipping"] = self.renorm_clipping 

834 config["renorm_momentum"] = self.renorm_momentum 

835 if self.virtual_batch_size is not None: 

836 config["virtual_batch_size"] = self.virtual_batch_size 

837 # Note: adjustment is not serializable. 

838 if self.adjustment is not None: 

839 logging.warning( 

840 "The `adjustment` function of this `BatchNormalization` " 

841 "layer cannot be serialized and has been omitted from " 

842 "the layer config. It will not be included when " 

843 "re-creating the layer from the saved config." 

844 ) 

845 base_config = super().get_config() 

846 return dict(list(base_config.items()) + list(config.items())) 

847 

848 ######################## Start of private methods ########################## 

849 def _support_zero_size_input(self): 

850 if not tf.distribute.has_strategy(): 

851 return False 

852 strategy = tf.distribute.get_strategy() 

853 # TODO(b/195085185): remove experimental_enable_get_next_as_optional 

854 # after migrating all users. 

855 return getattr( 

856 strategy.extended, 

857 "enable_partial_batch_handling", 

858 getattr( 

859 strategy.extended, 

860 "experimental_enable_get_next_as_optional", 

861 False, 

862 ), 

863 ) 

864 

865 def _assign_moving_average(self, variable, value, momentum, inputs_size): 
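        # This implements variable -= (variable - value) * (1 - momentum),
        # which is algebraically the same as the update documented in the
        # class docstring: variable = variable * momentum + value * (1 - momentum).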

866 def calculate_update_delta(): 

867 decay = tf.convert_to_tensor(1.0 - momentum, name="decay") 

868 if decay.dtype != variable.dtype.base_dtype: 

869 decay = tf.cast(decay, variable.dtype.base_dtype) 

870 update_delta = (variable - tf.cast(value, variable.dtype)) * decay 

871 if inputs_size is not None: 

872 update_delta = tf.where( 

873 inputs_size > 0, 

874 update_delta, 

875 backend.zeros_like(update_delta), 

876 ) 

877 return update_delta 

878 

879 with backend.name_scope("AssignMovingAvg") as scope: 

880 if tf.compat.v1.executing_eagerly_outside_functions(): 

881 return variable.assign_sub(calculate_update_delta(), name=scope) 

882 else: 

883 with tf.compat.v1.colocate_with(variable): 

884 return tf.compat.v1.assign_sub( 

885 variable, calculate_update_delta(), name=scope 

886 ) 

887 

888 def _assign_new_value(self, variable, value): 

889 with backend.name_scope("AssignNewValue") as scope: 

890 if tf.compat.v1.executing_eagerly_outside_functions(): 

891 return variable.assign(value, name=scope) 

892 else: 

893 with tf.compat.v1.colocate_with(variable): 

894 return tf.compat.v1.assign(variable, value, name=scope) 

895 

896 def _fused_batch_norm(self, inputs, mask, training): 

897 """Returns the output of fused batch norm.""" 

898 if mask is not None: 

899 warnings.warn( 

900 "Masking is not supported with `fused=True`. " 

901 "You should either turn off fusing " 

902 "(`fused=False`) or you should not pass a `mask` " 

903 "argument when calling the layer. " 

904 "For the moment `mask` will be ignored for the " 

905 "normalization." 

906 ) 

907 if self.center: 

908 beta = self.beta 

909 else: 

910 beta = backend.constant( 

911 0.0, dtype=self._param_dtype, shape=self._param_shape 

912 ) 

913 if self.scale: 

914 gamma = self.gamma 

915 else: 

916 gamma = backend.constant( 

917 1.0, dtype=self._param_dtype, shape=self._param_shape 

918 ) 

919 

920 # TODO(b/129279393): Support zero batch input in non 

921 # DistributionStrategy code as well. 

922 if self._support_zero_size_input(): 

923 # Keras assumes that batch dimension is the first dimension for 

924 # Batch Normalization. 

925 input_batch_size = tf.shape(inputs)[0] 

926 else: 

927 input_batch_size = None 

928 

929 # TODO(rmlarsen): Support using fused avg updates for non-eager 

930 # execution after fixing graph pattern matching and enabling 

931 # fused_batch_norm to take exponential_avg_factor as a tensor input. 

932 use_fused_avg_updates = ( 

933 tf.compat.v1.executing_eagerly_outside_functions() 

934 and isinstance(self.momentum, (float, int)) 

935 and get_enclosing_xla_context() is None 

936 ) 
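        # Note: with fused average updates, fused_batch_norm blends the batch
        # statistics into the moving statistics it is given (via
        # exponential_avg_factor = 1 - momentum), so mean_update() and
        # variance_update() below assign the returned values directly instead
        # of calling _assign_moving_average().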

937 if use_fused_avg_updates: 

938 exponential_avg_factor = 1.0 - self.momentum 

939 else: 

940 exponential_avg_factor = None 

941 

942 def _maybe_add_or_remove_bessels_correction(variance, remove=True): 

943 r"""Add or remove Bessel's correction.""" 

944 # Removes Bessel's correction if remove == True, adds it otherwise. 

945 # This is to be consistent with non-fused batch norm. Note that the 

946 # variance computed by fused batch norm is with Bessel's correction. 

947 # This is only used in legacy V1 batch norm tests. 

948 if self._bessels_correction_test_only: 

949 return variance 

950 sample_size = tf.cast( 

951 tf.size(inputs) / tf.size(variance), variance.dtype 

952 ) 

953 if remove: 

954 factor = ( 

955 sample_size - tf.cast(1.0, variance.dtype) 

956 ) / sample_size 

957 else: 

958 factor = sample_size / ( 

959 sample_size - tf.cast(1.0, variance.dtype) 

960 ) 

961 return variance * factor 

962 

963 def _fused_batch_norm_training(): 

964 return tf.compat.v1.nn.fused_batch_norm( 

965 inputs, 

966 gamma, 

967 beta, 

968 mean=self.moving_mean, 

969 variance=_maybe_add_or_remove_bessels_correction( 

970 self.moving_variance, remove=False 

971 ), 

972 epsilon=self.epsilon, 

973 is_training=True, 

974 data_format=self._data_format, 

975 exponential_avg_factor=exponential_avg_factor, 

976 ) 

977 

978 def _fused_batch_norm_inference(): 

979 return tf.compat.v1.nn.fused_batch_norm( 

980 inputs, 

981 gamma, 

982 beta, 

983 mean=self.moving_mean, 

984 variance=self.moving_variance, 

985 epsilon=self.epsilon, 

986 is_training=False, 

987 data_format=self._data_format, 

988 ) 

989 

990 output, mean, variance = control_flow_util.smart_cond( 

991 training, _fused_batch_norm_training, _fused_batch_norm_inference 

992 ) 

993 variance = _maybe_add_or_remove_bessels_correction( 

994 variance, remove=True 

995 ) 

996 

997 training_value = control_flow_util.constant_value(training) 

998 if training_value or training_value is None: 

999 if not use_fused_avg_updates: 

1000 if training_value is None: 

1001 momentum = control_flow_util.smart_cond( 

1002 training, lambda: self.momentum, lambda: 1.0 

1003 ) 

1004 else: 

1005 momentum = tf.convert_to_tensor(self.momentum) 

1006 

1007 def mean_update(): 

1008 """Update self.moving_mean with the most recent data point.""" 

1009 if use_fused_avg_updates: 

1010 if input_batch_size is not None: 

1011 new_mean = control_flow_util.smart_cond( 

1012 input_batch_size > 0, 

1013 lambda: mean, 

1014 lambda: self.moving_mean, 

1015 ) 

1016 else: 

1017 new_mean = mean 

1018 return self._assign_new_value(self.moving_mean, new_mean) 

1019 else: 

1020 return self._assign_moving_average( 

1021 self.moving_mean, mean, momentum, input_batch_size 

1022 ) 

1023 

1024 def variance_update(): 

1025 """Update self.moving_variance with the most recent data 

1026 point.""" 

1027 if use_fused_avg_updates: 

1028 if input_batch_size is not None: 

1029 new_variance = control_flow_util.smart_cond( 

1030 input_batch_size > 0, 

1031 lambda: variance, 

1032 lambda: self.moving_variance, 

1033 ) 

1034 else: 

1035 new_variance = variance 

1036 return self._assign_new_value( 

1037 self.moving_variance, new_variance 

1038 ) 

1039 else: 

1040 return self._assign_moving_average( 

1041 self.moving_variance, 

1042 variance, 

1043 momentum, 

1044 input_batch_size, 

1045 ) 

1046 

1047 self.add_update(mean_update) 

1048 self.add_update(variance_update) 

1049 

1050 return output 

1051 

1052 def _renorm_correction_and_moments( 

1053 self, mean, variance, training, inputs_size 

1054 ): 

1055 """Returns the correction and update values for renorm.""" 

1056 stddev = tf.sqrt(variance + self.epsilon) 

1057 # Compute the average mean and standard deviation, as if they were 

1058 # initialized with this batch's moments. 

1059 renorm_mean = self.renorm_mean 

1060 # Avoid divide by zero early on in training. 

1061 renorm_stddev = tf.maximum(self.renorm_stddev, tf.sqrt(self.epsilon)) 

1062 # Compute the corrections for batch renorm. 

1063 r = stddev / renorm_stddev 

1064 d = (mean - renorm_mean) / renorm_stddev 

1065 # Ensure the corrections use pre-update moving averages. 

1066 with tf.control_dependencies([r, d]): 

1067 mean = tf.identity(mean) 

1068 stddev = tf.identity(stddev) 

1069 rmin, rmax, dmax = [ 

1070 self.renorm_clipping.get(key) for key in ["rmin", "rmax", "dmax"] 

1071 ] 

1072 if rmin is not None: 

1073 r = tf.maximum(r, rmin) 

1074 if rmax is not None: 

1075 r = tf.minimum(r, rmax) 

1076 if dmax is not None: 

1077 d = tf.maximum(d, -dmax) 

1078 d = tf.minimum(d, dmax) 

1079 # When not training, use r=1, d=0. 

1080 r = control_flow_util.smart_cond( 

1081 training, lambda: r, lambda: tf.ones_like(r) 

1082 ) 

1083 d = control_flow_util.smart_cond( 

1084 training, lambda: d, lambda: tf.zeros_like(d) 

1085 ) 

1086 

1087 def _update_renorm_variable(var, value, inputs_size): 

1088 """Updates a moving average and weight, returns the unbiased 

1089 value.""" 

1090 value = tf.identity(value) 

1091 

1092 def _do_update(): 

1093 """Updates the var, returns the updated value.""" 

1094 new_var = self._assign_moving_average( 

1095 var, value, self.renorm_momentum, inputs_size 

1096 ) 

1097 return new_var 

1098 

1099 def _fake_update(): 

1100 return tf.identity(var) 

1101 

1102 return control_flow_util.smart_cond( 

1103 training, _do_update, _fake_update 

1104 ) 

1105 

1106 # TODO(yuefengz): colocate the operations 

1107 update_new_mean = _update_renorm_variable( 

1108 self.renorm_mean, mean, inputs_size 

1109 ) 

1110 update_new_stddev = _update_renorm_variable( 

1111 self.renorm_stddev, stddev, inputs_size 

1112 ) 

1113 

1114 # Update the inference mode moving averages with the batch value. 

1115 with tf.control_dependencies([update_new_mean, update_new_stddev]): 

1116 out_mean = tf.identity(mean) 

1117 out_variance = tf.identity(variance) 

1118 

1119 return (r, d, out_mean, out_variance) 

1120 

1121 def _calculate_mean_and_var( 

1122 self, inputs, reduction_axes, keep_dims, mask=None 

1123 ): 

1124 if self.synchronized: 

1125 return self._sync_calculate_mean_and_var( 

1126 inputs, reduction_axes, keep_dims, mask=mask 

1127 ) 

1128 return self._no_sync_calculate_mean_and_var( 

1129 inputs, reduction_axes, keep_dims, mask=mask 

1130 ) 

1131 

1132 def _no_sync_calculate_mean_and_var( 

1133 self, inputs, reduction_axes, keep_dims, mask=None 

1134 ): 

1135 if mask is None: 

1136 return tf.nn.moments(inputs, reduction_axes, keepdims=keep_dims) 

1137 else: 

1138 mask_weights = tf.cast( 

1139 mask, self.compute_dtype, name="mask_weights" 

1140 ) 

1141 mask_weights = tf.expand_dims( 

1142 mask_weights, axis=-1, name="mask_weights_broadcasted" 

1143 ) 

1144 return tf.nn.weighted_moments( 

1145 inputs, 

1146 axes=reduction_axes, 

1147 frequency_weights=mask_weights, 

1148 keepdims=keep_dims, 

1149 ) 

1150 

1151 def _sync_calculate_mean_and_var( 

1152 self, x, reduction_axes, keep_dims, mask=None 

1153 ): 

1154 with backend.name_scope("moments"): 

1155 # The dynamic range of fp16 is too limited to support the collection 

1156 # of sufficient statistics. As a workaround we simply perform the 

1157 # operations on 32-bit floats before converting the mean and 

1158 # variance back to fp16 

1159 y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x 

1160 replica_ctx = tf.distribute.get_replica_context() 

1161 

1162 if not replica_ctx: 

1163 return self._no_sync_calculate_mean_and_var( 

1164 x, reduction_axes, keep_dims, mask=mask 

1165 ) 

1166 

1167 if mask is not None: 

1168 mask_weights = tf.cast(mask, y.dtype, name="mask_weights") 

1169 mask_weights = tf.expand_dims( 

1170 mask_weights, axis=-1, name="mask_weights_broadcasted" 

1171 ) 

1172 y *= mask_weights 

1173 local_count = tf.broadcast_to( 

1174 mask_weights, tf.shape(y), name="count" 

1175 ) 

1176 else: 

1177 local_count = tf.ones_like(y, name="count") 

1178 

1179 local_sum = tf.reduce_sum(y, axis=reduction_axes, keepdims=True) 

1180 local_squared_sum = tf.reduce_sum( 

1181 tf.square(y), axis=reduction_axes, keepdims=True 

1182 ) 

1183 local_count = tf.reduce_sum( 

1184 local_count, axis=reduction_axes, keepdims=True 

1185 ) 

1186 

1187 # TODO(b/163099951): batch the all-reduces once we sort out the 

1188 # ordering issue for NCCL. We don't have a mechanism to launch 

1189 # NCCL in the same order in each replica nowadays, so we limit 

1190 # NCCL to batch all-reduces. 

1191 y_sum = replica_ctx.all_reduce( 

1192 tf.distribute.ReduceOp.SUM, local_sum 

1193 ) 

1194 y_squared_sum = replica_ctx.all_reduce( 

1195 tf.distribute.ReduceOp.SUM, local_squared_sum 

1196 ) 

1197 count_sum = replica_ctx.all_reduce( 

1198 tf.distribute.ReduceOp.SUM, local_count 

1199 ) 

1200 

1201 mean = y_sum / count_sum 

1202 y_squared_mean = y_squared_sum / count_sum 

1203 # var = E(x^2) - E(x)^2 

1204 variance = y_squared_mean - tf.square(mean) 

1205 if not keep_dims: 

1206 mean = tf.squeeze(mean, reduction_axes) 

1207 variance = tf.squeeze(variance, reduction_axes) 

1208 if x.dtype == tf.float16: 

1209 return ( 

1210 tf.cast(mean, tf.float16), 

1211 tf.cast(variance, tf.float16), 

1212 ) 

1213 else: 

1214 return (mean, variance) 

1215 

1216 def _dtensor_calculate_mean_and_var( 

1217 self, inputs, reduction_axes, keep_dims, mask=None 

1218 ): 

1219 if self.synchronized: 

1220 return self._dtensor_sync_calculate_mean_and_var( 

1221 inputs, reduction_axes, keep_dims, mask=mask 

1222 ) 

1223 return self._dtensor_no_sync_calculate_mean_and_var( 

1224 inputs, reduction_axes, keep_dims, mask=mask 

1225 ) 

1226 

1227 def _dtensor_no_sync_calculate_mean_and_var( 

1228 self, inputs, reduction_axes, keep_dims, mask=None 

1229 ): 

1230 replica_tensor = _expand_tensor_with_local_replica_group(inputs) 

1231 local_batch_size = tf.shape(replica_tensor)[1] 

1232 

1233 # Since we added a new axis at the beginning, all the values in

1234 # reduction_axes need to be incremented by 1.
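        # For example, reduction_axes [0, 2] becomes [1, 3].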

1235 updated_reduction_axes = [n + 1 for n in reduction_axes] 

1236 

1237 if mask is None: 

1238 mean, var = tf.nn.moments( 

1239 replica_tensor, updated_reduction_axes, keepdims=keep_dims 

1240 ) 

1241 else: 

1242 mask_weights = tf.cast( 

1243 mask, self.compute_dtype, name="mask_weights" 

1244 ) 

1245 mask_weights = tf.expand_dims( 

1246 mask_weights, axis=-1, name="mask_weights_broadcasted" 

1247 ) 

1248 mask_weights = _expand_tensor_with_local_replica_group(mask_weights) 

1249 mean, var = tf.nn.weighted_moments( 

1250 replica_tensor, 

1251 axes=updated_reduction_axes, 

1252 frequency_weights=mask_weights, 

1253 keepdims=keep_dims, 

1254 ) 

1255 # Also note that the mean/var we have here will have an extra dim in

1256 # axis 0, which represents the number of local replicas. Downstream,

1257 # the mean/var will be used to update the moving_mean/var and also

1258 # to normalize the inputs. To make the shapes match, we expand the

1259 # tensor shape from [num_replica, x, y] to [batch_size, x, y] so

1260 # that it can be properly used for normalization. When it reaches

1261 # the mean/var update, separate logic will reduce_mean the value

1262 # over the batch

1263 # dim.

1264 mean = tf.repeat(mean, local_batch_size, axis=0) 

1265 var = tf.repeat(var, local_batch_size, axis=0) 

1266 if not keep_dims: 

1267 # We need to fill the reduced dims so that the mean/var can be 

1268 # properly broadcast to the input shapes. In the example above, 

1269 # the original reduction_axes is [0, 1]. We ignore the first 0 

1270 # (batch dim) here since we already expand and use it as num_replica 

1271 for dim in reduction_axes[1:]: 

1272 mean = tf.expand_dims(mean, axis=dim) 

1273 var = tf.expand_dims(var, axis=dim) 

1274 return mean, var 

1275 

1276 def _dtensor_sync_calculate_mean_and_var( 

1277 self, inputs, reduction_axes, keep_dims, mask=None 

1278 ): 

1279 # In the DTensor sync BN, since the input tensor is already in global 

1280 # context, we just need to use the normal moments/weighted_moments 

1281 # to calculate mean/var, which is the same as the non-sync BN in normal

1282 # mode. 

1283 return self._no_sync_calculate_mean_and_var( 

1284 inputs, reduction_axes, keep_dims, mask 

1285 ) 

1286 

1287 def _moments(self, inputs, reduction_axes, keep_dims, mask=None): 

1288 if utils.running_with_dtensor_strategy(): 

1289 mean, variance = self._dtensor_calculate_mean_and_var( 

1290 inputs, reduction_axes, keep_dims, mask=mask 

1291 ) 

1292 else: 

1293 mean, variance = self._calculate_mean_and_var( 

1294 inputs, reduction_axes, keep_dims, mask=mask 

1295 ) 

1296 # TODO(b/129279393): Support zero batch input in non 

1297 # DistributionStrategy code as well. 

1298 if self._support_zero_size_input(): 

1299 input_batch_size = tf.shape(inputs)[0] 

1300 mean = tf.where( 

1301 input_batch_size > 0, mean, backend.zeros_like(mean) 

1302 ) 

1303 variance = tf.where( 

1304 input_batch_size > 0, variance, backend.zeros_like(variance) 

1305 ) 

1306 return mean, variance 

1307 

1308 def _get_training_value(self, training=None): 

1309 if training is None: 

1310 training = backend.learning_phase() 

1311 if self._USE_V2_BEHAVIOR: 

1312 if isinstance(training, int): 

1313 training = bool(training) 

1314 if not self.trainable: 

1315 # When the layer is not trainable, it overrides the value passed 

1316 # from model. 

1317 training = False 

1318 return training 

1319 

1320 

1321@keras_export("keras.layers.BatchNormalization", v1=[]) 

1322class BatchNormalization(BatchNormalizationBase): 

1323 """Layer that normalizes its inputs. 

1324 

1325 Batch normalization applies a transformation that maintains the mean output 

1326 close to 0 and the output standard deviation close to 1. 

1327 

1328 Importantly, batch normalization works differently during training and 

1329 during inference. 

1330 

1331 **During training** (i.e. when using `fit()` or when calling the layer/model 

1332 with the argument `training=True`), the layer normalizes its output using 

1333 the mean and standard deviation of the current batch of inputs. That is to 

1334 say, for each channel being normalized, the layer returns 

1335 `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where: 

1336 

1337 - `epsilon` is a small constant (configurable as part of the constructor

1338 arguments) 

1339 - `gamma` is a learned scaling factor (initialized as 1), which 

1340 can be disabled by passing `scale=False` to the constructor. 

1341 - `beta` is a learned offset factor (initialized as 0), which 

1342 can be disabled by passing `center=False` to the constructor. 

1343 

1344 **During inference** (i.e. when using `evaluate()` or `predict()`) or when

1345 calling the layer/model with the argument `training=False` (which is the 

1346 default), the layer normalizes its output using a moving average of the 

1347 mean and standard deviation of the batches it has seen during training. That 

1348 is to say, it returns 

1349 `gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta`. 

1350 

1351 `self.moving_mean` and `self.moving_var` are non-trainable variables that 

1352 are updated each time the layer is called in training mode, as follows:

1353 

1354 - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)` 

1355 - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)` 

1356 

1357 As such, the layer will only normalize its inputs during inference 

1358 *after having been trained on data that has similar statistics as the 

1359 inference data*. 

1360 

1361 When `synchronized=True` is set and if this layer is used within a 

1362 `tf.distribute` strategy, there will be an `allreduce` call 

1363 to aggregate batch statistics across all replicas at every 

1364 training step. Setting `synchronized` has no impact when the model is 

1365 trained without specifying any distribution strategy. 

1366 

1367 Example usage: 

1368 

1369 ```python 

1370 strategy = tf.distribute.MirroredStrategy() 

1371 

1372 with strategy.scope(): 

1373 model = tf.keras.Sequential() 

1374 model.add(tf.keras.layers.Dense(16)) 

1375 model.add(tf.keras.layers.BatchNormalization(synchronized=True)) 

1376 ``` 

1377 

1378 Args: 

1379 axis: Integer, the axis that should be normalized (typically the features 

1380 axis). For instance, after a `Conv2D` layer with 

1381 `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. 

1382 momentum: Momentum for the moving average. 

1383 epsilon: Small float added to variance to avoid dividing by zero. 

1384 center: If True, add offset of `beta` to normalized tensor. If False, 

1385 `beta` is ignored. 

1386 scale: If True, multiply by `gamma`. If False, `gamma` is not used. When 

1387 the next layer is linear (or, e.g., `nn.relu`), this can be disabled

1388 since the scaling will be done by the next layer. 

1389 beta_initializer: Initializer for the beta weight. 

1390 gamma_initializer: Initializer for the gamma weight. 

1391 moving_mean_initializer: Initializer for the moving mean. 

1392 moving_variance_initializer: Initializer for the moving variance. 

1393 beta_regularizer: Optional regularizer for the beta weight. 

1394 gamma_regularizer: Optional regularizer for the gamma weight. 

1395 beta_constraint: Optional constraint for the beta weight. 

1396 gamma_constraint: Optional constraint for the gamma weight. 

1397 synchronized: If True, synchronizes the global batch statistics (mean and 

1398 variance) for the layer across all devices at each training step in a 

1399 distributed training strategy. If False, each replica uses its own 

1400 local batch statistics. Only relevant when used inside a 

1401 `tf.distribute` strategy. 

1402 

1403 Call arguments: 

1404 inputs: Input tensor (of any rank). 

1405 training: Python boolean indicating whether the layer should behave in 

1406 training mode or in inference mode. 

1407 - `training=True`: The layer will normalize its inputs using the mean 

1408 and variance of the current batch of inputs. 

1409 - `training=False`: The layer will normalize its inputs using the mean 

1410 and variance of its moving statistics, learned during training. 

1411 

1412 Input shape: 

1413 Arbitrary. Use the keyword argument `input_shape` (tuple of 

1414 integers, does not include the samples axis) when using this layer as the 

1415 first layer in a model. 

1416 

1417 Output shape: 

1418 Same shape as input. 

1419 

1420 Reference: 

1421 - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). 

1422 

1423 **About setting `layer.trainable = False` on a `BatchNormalization` layer:** 

1424 

1425 The meaning of setting `layer.trainable = False` is to freeze the layer, 

1426 i.e. its internal state will not change during training: 

1427 its trainable weights will not be updated 

1428 during `fit()` or `train_on_batch()`, and its state updates will not be run. 

1429 

1430 Usually, this does not necessarily mean that the layer is run in inference 

1431 mode (which is normally controlled by the `training` argument that can 

1432 be passed when calling a layer). "Frozen state" and "inference mode" 

1433 are two separate concepts. 

1434 

1435 However, in the case of the `BatchNormalization` layer, **setting 

1436 `trainable = False` on the layer means that the layer will be 

1437 subsequently run in inference mode** (meaning that it will use 

1438 the moving mean and the moving variance to normalize the current batch, 

1439 rather than using the mean and variance of the current batch). 

1440 

1441 This behavior has been introduced in TensorFlow 2.0, in order 

1442 to enable `layer.trainable = False` to produce the most commonly 

1443 expected behavior in the convnet fine-tuning use case. 
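  For example, the typical fine-tuning pattern this behavior is designed for
  looks like the following sketch (the layer sizes below are arbitrary
  placeholders):

  ```python
  base_model = tf.keras.Sequential([
      tf.keras.layers.Dense(16),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.Dense(1),
  ])
  # Freeze the model: its `BatchNormalization` layers now run in inference
  # mode and their moving statistics are no longer updated during `fit()`.
  base_model.trainable = False
  ```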

1444 

1445 Note that: 

1446 - Setting `trainable` on a model containing other layers will

1447 recursively set the `trainable` value of all inner layers. 

1448 - If the value of the `trainable` 

1449 attribute is changed after calling `compile()` on a model, 

1450 the new value doesn't take effect for this model 

1451 until `compile()` is called again. 

1452 """ 

1453 

1454 _USE_V2_BEHAVIOR = True 

1455 

1456 @utils.allow_initializer_layout 

1457 def __init__( 

1458 self, 

1459 axis=-1, 

1460 momentum=0.99, 

1461 epsilon=1e-3, 

1462 center=True, 

1463 scale=True, 

1464 beta_initializer="zeros", 

1465 gamma_initializer="ones", 

1466 moving_mean_initializer="zeros", 

1467 moving_variance_initializer="ones", 

1468 beta_regularizer=None, 

1469 gamma_regularizer=None, 

1470 beta_constraint=None, 

1471 gamma_constraint=None, 

1472 synchronized=False, 

1473 **kwargs, 

1474 ): 

1475 # Currently we only support aggregating over the global batch size. 

1476 super().__init__( 

1477 axis=axis, 

1478 momentum=momentum, 

1479 epsilon=epsilon, 

1480 center=center, 

1481 scale=scale, 

1482 beta_initializer=beta_initializer, 

1483 gamma_initializer=gamma_initializer, 

1484 moving_mean_initializer=moving_mean_initializer, 

1485 moving_variance_initializer=moving_variance_initializer, 

1486 beta_regularizer=beta_regularizer, 

1487 gamma_regularizer=gamma_regularizer, 

1488 beta_constraint=beta_constraint, 

1489 gamma_constraint=gamma_constraint, 

1490 synchronized=synchronized, 

1491 **kwargs, 

1492 ) 

1493 

1494 

1495@keras_export("keras.layers.experimental.SyncBatchNormalization", v1=[]) 

1496@deprecation.deprecated_endpoints( 

1497 "keras.layers.experimental.SyncBatchNormalization" 

1498) 

1499class SyncBatchNormalization(BatchNormalizationBase): 

1500 """Deprecated. Please use `tf.keras.layers.BatchNormalization` instead. 

1501 

1502 Caution: the `tf.keras.layers.experimental.SyncBatchNormalization`

1503 endpoint is deprecated and will be removed in a future release. Please

1504 use `tf.keras.layers.BatchNormalization` with the parameter

1505 `synchronized` set to `True`.
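  For example (a migration sketch; the two layers below are intended to be
  equivalent):

  ```python
  # Deprecated:
  norm = tf.keras.layers.experimental.SyncBatchNormalization()
  # Preferred:
  norm = tf.keras.layers.BatchNormalization(synchronized=True)
  ```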

1506 """ 

1507 

1508 def __init__( 

1509 self, 

1510 axis=-1, 

1511 momentum=0.99, 

1512 epsilon=1e-3, 

1513 center=True, 

1514 scale=True, 

1515 beta_initializer="zeros", 

1516 gamma_initializer="ones", 

1517 moving_mean_initializer="zeros", 

1518 moving_variance_initializer="ones", 

1519 beta_regularizer=None, 

1520 gamma_regularizer=None, 

1521 beta_constraint=None, 

1522 gamma_constraint=None, 

1523 **kwargs, 

1524 ): 

1525 warning = ( 

1526 "`tf.keras.layers.experimental.SyncBatchNormalization` endpoint is " 

1527 "deprecated and will be removed in a future release. Please use " 

1528 "`tf.keras.layers.BatchNormalization` with parameter " 

1529 "`synchronized` set to True." 

1530 ) 

1531 logging.log_first_n(logging.WARN, warning, 1) 

1532 super().__init__( 

1533 axis=axis, 

1534 momentum=momentum, 

1535 epsilon=epsilon, 

1536 center=center, 

1537 scale=scale, 

1538 beta_initializer=beta_initializer, 

1539 gamma_initializer=gamma_initializer, 

1540 moving_mean_initializer=moving_mean_initializer, 

1541 moving_variance_initializer=moving_variance_initializer, 

1542 beta_regularizer=beta_regularizer, 

1543 gamma_regularizer=gamma_regularizer, 

1544 beta_constraint=beta_constraint, 

1545 gamma_constraint=gamma_constraint, 

1546 synchronized=True, 

1547 **kwargs, 

1548 ) 

1549 

1550 

1551def _expand_tensor_with_local_replica_group(inputs): 

1552 """Reshape the input tensor to have an extra dimension of replica group. 

1553 

1554 Under DTensor usage, the normal batch norm still needs to operate on a

1555 local batch, which means we can't directly compute mean/var on a global

1556 tensor. In order to compute a local mean/var, we have to add a new

1557 dimension to the tensor so that the ops will not cross the replica

1558 boundary. E.g., for a global tensor with shape [8, x, y] and 2 local

1559 replicas, the output of this will be [2, 4, x, y], where the first dim

1560 is the number of replicas and the second dim is the local batch size.

1561 The following ops can then reduce over the local batch dimension.

1562 

1563 Note that this function should only be used under a DTensor-based

1564 strategy, and it will use the current strategy in the context to get

1565 the number of replicas.

1566 

1567 Args: 

1568 inputs: Tensor with shape [global_batch_size, ...] 

1569 

1570 Returns: 

1571 Tensor with shape [num_replica, local_batch_size, ...] 
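  Example (a shape-only sketch, assuming 2 replicas in sync):

  ```python
  inputs = tf.zeros([8, 3, 5])              # global batch of 8
  local = tf.reshape(inputs, [2, 4, 3, 5])  # 2 replicas x local batch of 4
  ```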

1572 """ 

1573 # TODO(b/272382109): Implement this as an Op.

1574 input_shape = tf.shape(inputs) 

1575 global_batch_size = input_shape[0] 

1576 num_replica = tf.distribute.get_strategy().num_replicas_in_sync 

1577 local_batch_size = global_batch_size // num_replica 

1578 replica_shape = tf.stack([num_replica, local_batch_size]) 

1579 replica_shape = tf.concat([replica_shape, input_shape[1:]], axis=0) 

1580 return tf.reshape(inputs, replica_shape) 

1581 

1582 

1583def _raise_for_non_sync_bn_with_renorm_and_dtensor_strategy( 

1584 synchronized, training, renorm 

1585): 

1586 if ( 

1587 utils.running_with_dtensor_strategy() 

1588 and not synchronized 

1589 and training == True 

1590 and renorm 

1591 ): 

1592 raise NotImplementedError( 

1593 "Renorm for BatchNormalization under DTensor based distribution " 

1594 "strategy is not supported at the moment. Please file a feature " 

1595 "request if this is blocking your adoption." 

1596 ) 

1597