Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/activations/hardshrink.py: 33%

12 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15 

16import tensorflow as tf 

17from tensorflow_addons.utils.types import Number, TensorLike 

18 

19 

@tf.keras.utils.register_keras_serializable(package="Addons")
def hardshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> tf.Tensor:
    r"""Hard shrink function.

    Sets to zero every element of `x` that lies in the closed interval
    `[lower, upper]` and passes every other element through unchanged:

    $$
    \mathrm{hardshrink}(x) =
    \begin{cases}
        x & \text{if } x < \text{lower} \\
        x & \text{if } x > \text{upper} \\
        0 & \text{otherwise}
    \end{cases}.
    $$

    Usage:

    >>> x = tf.constant([1.0, 0.0, 1.0])
    >>> tfa.activations.hardshrink(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 0., 1.], dtype=float32)>

    Args:
        x: A `Tensor`. Must be one of the following types:
            `bfloat16`, `float16`, `float32`, `float64`.
        lower: `float`, lower bound of the interval whose values are set to zero.
        upper: `float`, upper bound of the interval whose values are set to zero.

    Returns:
        A `Tensor`. Has the same type as `x`.

    Raises:
        ValueError: If `lower` is greater than `upper`.
    """
    if lower > upper:
        raise ValueError(
            f"The value of lower is {lower} and should"
            " not be higher than the value "
            f"variable upper, which is {upper} ."
        )
    x = tf.convert_to_tensor(x)
    # Elements strictly outside [lower, upper] survive. Multiplying by a 0/1
    # mask (instead of selecting with tf.where) means a NaN input yields NaN
    # rather than 0.
    outside = tf.logical_or(x < lower, x > upper)
    return x * tf.cast(outside, x.dtype)