Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/tlu.py: 31%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements Thresholded Linear Unit.""" 

16 

17import tensorflow as tf 

18from typeguard import typechecked 

19 

20from tensorflow_addons.utils import types 

21 

22 

@tf.keras.utils.register_keras_serializable(package="Addons")
class TLU(tf.keras.layers.Layer):
    r"""Thresholded Linear Unit.

    An activation function which is similar to ReLU
    but with a learned threshold that benefits models using FRN (Filter
    Response Normalization). Original paper: https://arxiv.org/pdf/1911.09737.

    By default the layer computes $\max(x, \tau)$ element-wise, where
    $\tau$ is a trainable threshold. With `affine=True` it computes
    $\max(x, \alpha x + \tau)$ instead, where $\alpha$ is also trainable.
    Both parameters have the shape of a single input sample (every axis
    except the batch axis), so each position gets its own parameters.

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
        The non-batch dimensions are used as the parameter shape, so they
        are expected to be fully defined at build time. Once built, the
        layer only accepts inputs with the same rank and the same
        non-batch dimensions.

    Output shape:
        Same shape as the input.

    Args:
        affine: `bool`. Whether to make it TLU-Affine or not,
            which has the form $\max(x, \alpha x + \tau)$.
        tau_initializer: Initializer for the threshold `tau`.
        tau_regularizer: Optional regularizer applied to `tau`.
        tau_constraint: Optional constraint applied to `tau`.
        alpha_initializer: Initializer for the slope `alpha`. Only used
            when `affine=True`.
        alpha_regularizer: Optional regularizer applied to `alpha`. Only
            used when `affine=True`.
        alpha_constraint: Optional constraint applied to `alpha`. Only
            used when `affine=True`.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    @typechecked
    def __init__(
        self,
        affine: bool = False,
        tau_initializer: types.Initializer = "zeros",
        tau_regularizer: types.Regularizer = None,
        tau_constraint: types.Constraint = None,
        alpha_initializer: types.Initializer = "zeros",
        alpha_regularizer: types.Regularizer = None,
        alpha_constraint: types.Constraint = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The layer is element-wise, so an incoming mask can pass through unchanged.
        self.supports_masking = True
        self.affine = affine
        self.tau_initializer = tf.keras.initializers.get(tau_initializer)
        self.tau_regularizer = tf.keras.regularizers.get(tau_regularizer)
        self.tau_constraint = tf.keras.constraints.get(tau_constraint)
        # The alpha_* attributes exist only for the affine variant; `build`,
        # `call` and `get_config` read them only under `self.affine`.
        if self.affine:
            self.alpha_initializer = tf.keras.initializers.get(alpha_initializer)
            self.alpha_regularizer = tf.keras.regularizers.get(alpha_regularizer)
            self.alpha_constraint = tf.keras.constraints.get(alpha_constraint)

    def build(self, input_shape):
        """Create `tau` (and `alpha` when affine) shaped like one input sample.

        Args:
            input_shape: Shape of the input including the leading batch axis.
        """
        # One parameter per non-batch position.
        param_shape = list(input_shape[1:])
        self.tau = self.add_weight(
            shape=param_shape,
            name="tau",
            initializer=self.tau_initializer,
            regularizer=self.tau_regularizer,
            constraint=self.tau_constraint,
            synchronization=tf.VariableSynchronization.AUTO,
            aggregation=tf.VariableAggregation.MEAN,
        )
        if self.affine:
            self.alpha = self.add_weight(
                shape=param_shape,
                name="alpha",
                initializer=self.alpha_initializer,
                regularizer=self.alpha_regularizer,
                constraint=self.alpha_constraint,
                synchronization=tf.VariableSynchronization.AUTO,
                aggregation=tf.VariableAggregation.MEAN,
            )

        # The parameters are sized to the build-time sample shape, so pin
        # the rank and every non-batch axis; later inputs must match them.
        axes = {i: input_shape[i] for i in range(1, len(input_shape))}
        self.input_spec = tf.keras.layers.InputSpec(ndim=len(input_shape), axes=axes)
        # Defer to the base implementation to mark the layer as built,
        # instead of setting `self.built` by hand.
        super().build(input_shape)

    def call(self, inputs):
        """Return `max(inputs, tau)`, or `max(inputs, alpha * inputs + tau)` if affine.

        `tau` and `alpha` broadcast across the batch axis.
        """
        v = self.alpha * inputs if self.affine else 0
        return tf.maximum(inputs, self.tau + v)

    def get_config(self):
        """Return the serializable layer configuration.

        The alpha_* entries are included only when `affine` is True. When
        `affine` is False those attributes are never created.
        """
        config = {
            "tau_initializer": tf.keras.initializers.serialize(self.tau_initializer),
            "tau_regularizer": tf.keras.regularizers.serialize(self.tau_regularizer),
            "tau_constraint": tf.keras.constraints.serialize(self.tau_constraint),
            "affine": self.affine,
        }

        if self.affine:
            config["alpha_initializer"] = tf.keras.initializers.serialize(
                self.alpha_initializer
            )
            config["alpha_regularizer"] = tf.keras.regularizers.serialize(
                self.alpha_regularizer
            )
            config["alpha_constraint"] = tf.keras.constraints.serialize(
                self.alpha_constraint
            )

        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged; the layer is element-wise."""
        return input_shape