Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/activations/hardshrink.py: 33%
12 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16import tensorflow as tf
17from tensorflow_addons.utils.types import Number, TensorLike
@tf.keras.utils.register_keras_serializable(package="Addons")
def hardshrink(x: TensorLike, lower: Number = -0.5, upper: Number = 0.5) -> tf.Tensor:
    r"""Apply the hard-shrink activation element-wise.

    Entries of `x` lying strictly outside the closed interval
    `[lower, upper]` are passed through unchanged; entries inside that
    interval are replaced by zero:

    $$
    \mathrm{hardshrink}(x) =
    \begin{cases}
        x & \text{if } x < \text{lower} \\
        x & \text{if } x > \text{upper} \\
        0 & \text{otherwise}
    \end{cases}.
    $$

    Usage:

    >>> x = tf.constant([1.0, 0.0, 1.0])
    >>> tfa.activations.hardshrink(x)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 0., 1.], dtype=float32)>

    Args:
        x: A `Tensor`, or a value convertible to one, of type `bfloat16`,
            `float16`, `float32` or `float64`.
        lower: `float`, lower edge of the zeroed interval (inclusive).
        upper: `float`, upper edge of the zeroed interval (inclusive).

    Returns:
        A `Tensor` with the same dtype and shape as `x`.

    Raises:
        ValueError: If `lower` is greater than `upper`.
    """
    if lower > upper:
        raise ValueError(
            f"The value of lower is {lower} and should not be higher than "
            f"the value variable upper, which is {upper} ."
        )

    x = tf.convert_to_tensor(x)

    # True wherever the entry escapes the [lower, upper] band and is kept.
    keep = tf.logical_or(x < lower, x > upper)

    # Multiplying by a 0/1 mask (instead of selecting with tf.where) keeps
    # NaN inputs as NaN, since NaN * 0 is NaN. It also means negative
    # in-band entries come out as -0.0.
    return x * tf.cast(keep, x.dtype)