keras/src/layers/activation/thresholded_relu.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Thresholded Rectified Linear Unit activation layer."""


import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine.base_layer import Layer
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.ThresholdedReLU")
class ThresholdedReLU(Layer):
    """Thresholded Rectified Linear Unit.

    It follows:

    ```
        f(x) = x for x > theta
        f(x) = 0 otherwise
    ```

    Input shape:
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    Output shape:
        Same shape as the input.

    Args:
        theta: Float >= 0. Threshold location of activation.
    """

    def __init__(self, theta=1.0, **kwargs):
        super().__init__(**kwargs)
        if theta is None:
            raise ValueError(
                "Theta of a Thresholded ReLU layer cannot be None, expecting a "
                f"float. Received: {theta}"
            )
        if theta < 0:
            raise ValueError(
                "The theta value of a Thresholded ReLU layer "
                f"should be >=0. Received: {theta}"
            )
        self.supports_masking = True
        self.theta = backend.cast_to_floatx(theta)

    def call(self, inputs):
        dtype = self.compute_dtype
        return inputs * tf.cast(tf.greater(inputs, self.theta), dtype)
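    # How `call` works: `tf.greater(inputs, self.theta)` produces a boolean
    # mask, which the cast turns into 0.0/1.0 values in the layer's
    # compute_dtype; multiplying by that mask zeroes every element with
    # inputs <= theta in a single elementwise op, preserving shape and dtype.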

    def get_config(self):
        config = {"theta": float(self.theta)}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        return input_shape
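
# A minimal usage sketch, not part of the upstream module: it assumes a
# TensorFlow 2.x environment in which this file's `keras.src` imports
# resolve, and the input values are illustrative only.
if __name__ == "__main__":
    import numpy as np

    layer = ThresholdedReLU(theta=1.0)
    x = np.array([[-1.0, 0.5, 1.0, 2.0]], dtype="float32")
    # The cutoff is strict (`x > theta`), so 1.0 itself is zeroed.
    print(layer(x).numpy())  # [[0. 0. 0. 2.]]

    # get_config stores theta as a plain float, so a config round trip
    # reconstructs an equivalent layer.
    clone = ThresholdedReLU.from_config(layer.get_config())
    print(float(clone.theta))  # 1.0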