Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/optimizers/utils.py: 21%

34 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Additional Utilities used for tfa.optimizers.""" 

16 

17import re 

18import tensorflow as tf 

19from typing import List 

20 

21 

def fit_bn(model, *args, **kwargs):
    """Reset the batch normalization statistics of ``model`` and recompute them.

    Every ``tf.keras.layers.BatchNormalization`` layer found in
    ``model.layers`` has its moving mean reset to zeros and its moving
    variance reset to ones. ``model.fit`` is then run for a single epoch with
    ``model.trainable`` set to ``False`` and the model's metrics temporarily
    cleared. The original ``trainable`` flag and metrics are restored
    afterwards, even if ``fit`` raises.

    Args:
        model: An instance of ``tf.keras.Model``. It must already be built.
        *args, **kwargs: Forwarded to ``model.fit``. Any ``epochs`` value in
            ``kwargs`` is overridden with ``1``.

    Raises:
        TypeError: If ``model`` is not an instance of ``tf.keras.Model``.
        ValueError: If ``model`` has not been built yet.
    """
    kwargs["epochs"] = 1
    if not isinstance(model, tf.keras.Model):
        raise TypeError("model must be an instance of tf.keras.Model")

    if not model.built:
        raise ValueError("Call `fit_bn` after the model is built and trained")

    # NOTE(review): only the layers directly listed in ``model.layers`` are
    # scanned; BatchNormalization layers nested inside sub-models are not
    # reset — confirm this is intended.
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
            layer.moving_variance.assign(tf.ones_like(layer.moving_variance))

    saved_trainable = model.trainable
    saved_metrics = model._metrics
    # Freeze the weights and drop metrics so the pass below only serves to
    # re-estimate the batch statistics.
    # NOTE(review): the Keras BatchNormalization docs state that in TF2 a BN
    # layer with ``trainable=False`` runs in inference mode and does not update
    # its moving statistics; confirm the statistics really are re-estimated
    # here once ``model.trainable = False`` propagates to the layers.
    model.trainable = False
    model._metrics = []
    try:
        model.fit(*args, **kwargs)
    finally:
        # Always hand the model back in the caller's configuration, even if
        # ``fit`` raises.
        model.trainable = saved_trainable
        model._metrics = saved_metrics

56 

57 

def get_variable_name(variable) -> str:
    """Return ``variable.name`` with a trailing ``:<digits>`` suffix removed.

    If the name does not end in a colon followed by one or more digits, it is
    returned unchanged.
    """
    full_name = variable.name
    suffix_match = re.match(r"^(.*):\d+$", full_name)
    return full_name if suffix_match is None else suffix_match.group(1)

65 

66 

def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool:
    """Return True if any pattern in ``regexes`` is found in ``variable.name``.

    Each pattern is tried with ``re.search`` against the full variable name,
    including any ``:<digits>`` suffix. An empty or ``None`` ``regexes`` yields
    False without reading ``variable.name``.
    """
    if not regexes:
        return False
    full_name = variable.name
    return any(re.search(pattern, full_name) for pattern in regexes)