Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/utils.py: 21%
34 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Additional Utilities used for tfa.optimizers."""
17import re
18import tensorflow as tf
19from typing import List
def fit_bn(model, *args, **kwargs):
    """Resets batch normalization layers of model, and recalculates the
    statistics for each batchnorm layer by running a pass on the data.

    The moving mean / moving variance of every
    `tf.keras.layers.BatchNormalization` layer found in `model.layers` is
    reset to zeros / ones. Then `model.fit` is run for exactly one epoch while
    the model is marked non-trainable and its metrics list is temporarily
    emptied. The model's original `trainable` flag and metrics are restored
    afterwards, even if `fit` raises.

    Args:
        model: An instance of tf.keras.Model
        *args, **kwargs: Params that'll be passed to `.fit` method of model.
            Any `epochs` entry in `kwargs` is overridden with 1.

    Raises:
        TypeError: If `model` is not an instance of `tf.keras.Model`.
        ValueError: If `model` has not been built yet.
    """
    # A single pass over the data is used to re-estimate the statistics;
    # a caller-supplied `epochs` is intentionally overridden.
    kwargs["epochs"] = 1
    if not isinstance(model, tf.keras.Model):
        raise TypeError("model must be an instance of tf.keras.Model")

    if not model.built:
        raise ValueError("Call `fit_bn` after the model is built and trained")

    # Only the top-level `model.layers` are visited. BatchNormalization layers
    # nested inside sub-models are not reset here.
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
            layer.moving_variance.assign(tf.ones_like(layer.moving_variance))

    _trainable = model.trainable
    _metrics = model._metrics  # private Keras attribute
    # Freeze the model and clear its metrics for the duration of the pass.
    # Presumably the intent is to refresh only the BN statistics, without
    # updating weights or reporting metrics.
    # NOTE(review): Keras documents that a BatchNormalization layer with
    # trainable=False runs in inference mode and does not update its moving
    # statistics. Confirm this still re-estimates them on the TF/Keras
    # version in use.
    model.trainable = False
    model._metrics = []
    try:
        model.fit(*args, **kwargs)
    finally:
        # Always restore the caller's configuration, even if `fit` fails.
        model.trainable = _trainable
        model._metrics = _metrics
def get_variable_name(variable) -> str:
    """Return `variable.name` with any trailing `:<digits>` suffix removed.

    For example, "dense/kernel:0" becomes "dense/kernel". Only the last
    `:<digits>` group is stripped, because the leading `.*` is greedy: "a:1:2"
    becomes "a:1". Names without such a suffix are returned unchanged.
    """
    full_name = variable.name
    suffix_match = re.match(r"^(.*):\d+$", full_name)
    return full_name if suffix_match is None else suffix_match.group(1)
def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool:
    """Return True if any pattern in `regexes` matches the variable's name.

    Each pattern is applied with `re.search` against the raw `variable.name`,
    including any `:<digits>` suffix. Patterns are tried in order and
    evaluation stops at the first match. When `regexes` is empty or falsy,
    the result is False and `variable.name` is never read.
    """
    if not regexes:
        return False
    full_name = variable.name
    return any(re.search(pattern, full_name) for pattern in regexes)