Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/activation/prelu.py: 38%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15"""Parametric Rectified Linear Unit activation layer.""" 

16 

17 

18from keras.src import backend 

19from keras.src import constraints 

20from keras.src import initializers 

21from keras.src import regularizers 

22from keras.src.engine.base_layer import Layer 

23from keras.src.engine.input_spec import InputSpec 

24from keras.src.utils import tf_utils 

25 

26# isort: off 

27from tensorflow.python.util.tf_export import keras_export 

28 

29 

@keras_export("keras.layers.PReLU")
class PReLU(Layer):
    """Parametric Rectified Linear Unit.

    It follows:

    ```
    f(x) = alpha * x for x < 0
    f(x) = x for x >= 0
    ```

    where `alpha` is a learned array with the same shape as x.

    Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

    Output shape:
      Same shape as the input.

    Args:
      alpha_initializer: Initializer function for the weights.
      alpha_regularizer: Regularizer for the weights.
      alpha_constraint: Constraint for the weights.
      shared_axes: The axes along which to share learnable
        parameters for the activation function.
        For example, if the incoming feature maps
        are from a 2D convolution
        with output shape `(batch, height, width, channels)`,
        and you wish to share parameters across space
        so that each filter only has one set of parameters,
        set `shared_axes=[1, 2]`.
    """

    def __init__(
        self,
        alpha_initializer="zeros",
        alpha_regularizer=None,
        alpha_constraint=None,
        shared_axes=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self.alpha_initializer = initializers.get(alpha_initializer)
        self.alpha_regularizer = regularizers.get(alpha_regularizer)
        self.alpha_constraint = constraints.get(alpha_constraint)
        # Normalize `shared_axes` to either None or a plain list so the
        # rest of the layer can iterate over it uniformly.
        if shared_axes is None:
            self.shared_axes = None
        elif isinstance(shared_axes, (list, tuple)):
            self.shared_axes = list(shared_axes)
        else:
            # A single axis was given as a bare int.
            self.shared_axes = [shared_axes]

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        # alpha has one entry per input element, minus the batch dimension;
        # shared axes collapse to size 1 so the parameter broadcasts.
        param_shape = list(input_shape[1:])
        if self.shared_axes is not None:
            for axis in self.shared_axes:
                # shared_axes are 1-based (axis 0 is the batch dimension).
                param_shape[axis - 1] = 1
        self.alpha = self.add_weight(
            shape=param_shape,
            name="alpha",
            initializer=self.alpha_initializer,
            regularizer=self.alpha_regularizer,
            constraint=self.alpha_constraint,
        )
        # Pin every non-shared axis of the input spec to its build-time size.
        axes = {}
        if self.shared_axes:
            axes = {
                axis: input_shape[axis]
                for axis in range(1, len(input_shape))
                if axis not in self.shared_axes
            }
        self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
        self.built = True

    def call(self, inputs):
        # relu(x) keeps the positive part; -alpha * relu(-x) supplies the
        # learned slope on the negative part.
        positive = backend.relu(inputs)
        negative = -self.alpha * backend.relu(-inputs)
        return positive + negative

    def get_config(self):
        # Layer config wins over base config on key collisions,
        # matching the original list-concatenation merge order.
        base_config = super().get_config()
        config = {
            "alpha_initializer": initializers.serialize(self.alpha_initializer),
            "alpha_regularizer": regularizers.serialize(self.alpha_regularizer),
            "alpha_constraint": constraints.serialize(self.alpha_constraint),
            "shared_axes": self.shared_axes,
        }
        return {**base_config, **config}

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        # Element-wise activation: output shape equals input shape.
        return input_shape

125