# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Thresholded Linear Unit."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils import types

@tf.keras.utils.register_keras_serializable(package="Addons")
class TLU(tf.keras.layers.Layer):
    r"""Thresholded Linear Unit.

    An activation function similar to ReLU, but with a learned threshold that
    benefits models using FRN (Filter Response Normalization).
    Original paper: https://arxiv.org/pdf/1911.09737.
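
    Usage (a minimal sketch; with the default zero-initialized `tau` the layer
    behaves like ReLU until `tau` is trained):

    >>> import tensorflow_addons as tfa
    >>> x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
    >>> tlu = tfa.layers.TLU()
    >>> tlu(x)
    <tf.Tensor: shape=(5,), dtype=float32, numpy=array([0., 0., 0., 1., 2.], dtype=float32)>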

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    Output shape:
        Same shape as the input.

    Args:
        affine: `bool`. Whether to use the TLU-Affine variant,
            which has the form $\max(x, \alpha * x + \tau)$.
42 """
    @typechecked
    def __init__(
        self,
        affine: bool = False,
        tau_initializer: types.Initializer = "zeros",
        tau_regularizer: types.Regularizer = None,
        tau_constraint: types.Constraint = None,
        alpha_initializer: types.Initializer = "zeros",
        alpha_regularizer: types.Regularizer = None,
        alpha_constraint: types.Constraint = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.affine = affine
        self.tau_initializer = tf.keras.initializers.get(tau_initializer)
        self.tau_regularizer = tf.keras.regularizers.get(tau_regularizer)
        self.tau_constraint = tf.keras.constraints.get(tau_constraint)
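        # The alpha parameters are only materialized for the TLU-Affine variant.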
        if self.affine:
            self.alpha_initializer = tf.keras.initializers.get(alpha_initializer)
            self.alpha_regularizer = tf.keras.regularizers.get(alpha_regularizer)
            self.alpha_constraint = tf.keras.constraints.get(alpha_constraint)

    def build(self, input_shape):
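        # One threshold (and slope, when affine) parameter per feature element;
        # the batch axis is excluded from the parameter shape.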
        param_shape = list(input_shape[1:])
        self.tau = self.add_weight(
            shape=param_shape,
            name="tau",
            initializer=self.tau_initializer,
            regularizer=self.tau_regularizer,
            constraint=self.tau_constraint,
            synchronization=tf.VariableSynchronization.AUTO,
            aggregation=tf.VariableAggregation.MEAN,
        )
        if self.affine:
            self.alpha = self.add_weight(
                shape=param_shape,
                name="alpha",
                initializer=self.alpha_initializer,
                regularizer=self.alpha_regularizer,
                constraint=self.alpha_constraint,
                synchronization=tf.VariableSynchronization.AUTO,
                aggregation=tf.VariableAggregation.MEAN,
            )
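
        # Record the expected rank and per-axis sizes so later calls with an
        # incompatible input shape are rejected.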
        axes = {i: input_shape[i] for i in range(1, len(input_shape))}
        self.input_spec = tf.keras.layers.InputSpec(ndim=len(input_shape), axes=axes)
        self.built = True

    def call(self, inputs):
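        # TLU: max(x, tau); TLU-Affine: max(x, alpha * x + tau).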
        v = self.alpha * inputs if self.affine else 0
        return tf.maximum(inputs, self.tau + v)

    def get_config(self):
        config = {
            "tau_initializer": tf.keras.initializers.serialize(self.tau_initializer),
            "tau_regularizer": tf.keras.regularizers.serialize(self.tau_regularizer),
            "tau_constraint": tf.keras.constraints.serialize(self.tau_constraint),
            "affine": self.affine,
        }
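
        # The alpha attributes only exist when affine=True, so serialize them
        # conditionally.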
        if self.affine:
            config["alpha_initializer"] = tf.keras.initializers.serialize(
                self.alpha_initializer
            )
            config["alpha_regularizer"] = tf.keras.regularizers.serialize(
                self.alpha_regularizer
            )
            config["alpha_constraint"] = tf.keras.constraints.serialize(
                self.alpha_constraint
            )

        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        return input_shape