Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/activation/prelu.py: 38%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Parametric Rectified Linear Unit activation layer."""
18from keras.src import backend
19from keras.src import constraints
20from keras.src import initializers
21from keras.src import regularizers
22from keras.src.engine.base_layer import Layer
23from keras.src.engine.input_spec import InputSpec
24from keras.src.utils import tf_utils
26# isort: off
27from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.PReLU")
class PReLU(Layer):
    """Parametric Rectified Linear Unit.

    It follows:

    ```
        f(x) = alpha * x for x < 0
        f(x) = x for x >= 0
    ```

    where `alpha` is a learned array with the same shape as `x` excluding the
    batch axis. Every axis listed in `shared_axes` has size 1 in `alpha`, so a
    single learned value is broadcast along that axis.

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    Output shape:
        Same shape as the input.

    Args:
        alpha_initializer: Initializer function for the weights.
        alpha_regularizer: Regularizer for the weights.
        alpha_constraint: Constraint for the weights.
        shared_axes: The axes along which to share learnable
            parameters for the activation function. Axes index the full
            input shape (axis 0 is the batch axis and cannot be shared);
            negative values count back from the last axis.
            For example, if the incoming feature maps
            are from a 2D convolution
            with output shape `(batch, height, width, channels)`,
            and you wish to share parameters across space
            so that each filter only has one set of parameters,
            set `shared_axes=[1, 2]` (equivalently `[-3, -2]`).
    """

    def __init__(
        self,
        alpha_initializer="zeros",
        alpha_regularizer=None,
        alpha_constraint=None,
        shared_axes=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.alpha_initializer = initializers.get(alpha_initializer)
        self.alpha_regularizer = regularizers.get(alpha_regularizer)
        self.alpha_constraint = constraints.get(alpha_constraint)
        # Store `shared_axes` as a list (or None) exactly as the user gave it;
        # a single scalar axis is wrapped so later code can always iterate.
        if shared_axes is None:
            self.shared_axes = None
        elif not isinstance(shared_axes, (list, tuple)):
            self.shared_axes = [shared_axes]
        else:
            self.shared_axes = list(shared_axes)

    def _normalized_shared_axes(self, ndim):
        """Resolve `self.shared_axes` to positive axis indices for rank `ndim`.

        Negative axes count from the end of the input shape, so `-1` on a
        rank-3 input resolves to `2`. `self.shared_axes` itself is not
        modified, so `get_config` still reports the user's original values.

        Args:
            ndim: Rank of the input, including the batch axis.

        Returns:
            A list of axis indices in `[1, ndim)`; empty when no axes are
            shared.

        Raises:
            ValueError: If an axis resolves to the batch axis (0) or lies
                outside the input's rank.
        """
        if not self.shared_axes:
            return []
        resolved_axes = []
        for axis in self.shared_axes:
            resolved = axis + ndim if axis < 0 else axis
            # `alpha` has no batch dimension, so axis 0 cannot be shared;
            # anything >= ndim does not exist in the input.
            if not 1 <= resolved < ndim:
                raise ValueError(
                    f"Invalid value in `shared_axes`: {axis}. Shared axes "
                    f"must refer to non-batch dimensions of an input of "
                    f"rank {ndim}, i.e. lie in [1, {ndim - 1}] or "
                    f"[{1 - ndim}, -1]."
                )
            resolved_axes.append(resolved)
        return resolved_axes

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Create the `alpha` weight and the input spec for `input_shape`.

        Args:
            input_shape: Full input shape including the batch axis.

        Raises:
            ValueError: If `shared_axes` contains an axis that is invalid for
                an input of this rank.
        """
        ndim = len(input_shape)
        shared_axes = self._normalized_shared_axes(ndim)
        # `alpha` spans every non-batch axis; shared axes collapse to size 1
        # so one value is broadcast along them in `call`. Index `axis - 1`
        # because `param_shape` drops the batch axis.
        param_shape = list(input_shape[1:])
        for axis in shared_axes:
            param_shape[axis - 1] = 1
        self.alpha = self.add_weight(
            shape=param_shape,
            name="alpha",
            initializer=self.alpha_initializer,
            regularizer=self.alpha_regularizer,
            constraint=self.alpha_constraint,
        )
        # Set input spec: when some axes are shared, every non-shared,
        # non-batch axis must keep the size seen here (those are the axes
        # `alpha` has full extent along). With no shared axes only the rank
        # is checked.
        axes = {}
        if shared_axes:
            for i in range(1, ndim):
                if i not in shared_axes:
                    axes[i] = input_shape[i]
        self.input_spec = InputSpec(ndim=ndim, axes=axes)
        self.built = True

    def call(self, inputs):
        """Return `relu(x) - alpha * relu(-x)`.

        This equals `x` where `x >= 0` and `alpha * x` where `x < 0`.
        """
        pos = backend.relu(inputs)
        neg = -self.alpha * backend.relu(-inputs)
        return pos + neg

    def get_config(self):
        """Return the layer config, with this layer's keys over the base's."""
        config = {
            "alpha_initializer": initializers.serialize(self.alpha_initializer),
            "alpha_regularizer": regularizers.serialize(self.alpha_regularizer),
            "alpha_constraint": constraints.serialize(self.alpha_constraint),
            "shared_axes": self.shared_axes,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """PReLU is elementwise, so the output shape equals the input shape."""
        return input_shape