Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/harmonic_mean.py: 71%

14 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements HarmonicMean.""" 

16 

17import tensorflow as tf 

18 

19from typeguard import typechecked 

20from tensorflow_addons.utils.types import AcceptableDTypes 

21 

22 

@tf.keras.utils.register_keras_serializable(package="Addons")
class HarmonicMean(tf.keras.metrics.Mean):
    """Compute Harmonic Mean.

    The harmonic mean is the reciprocal of the arithmetic mean of the
    reciprocals of the observed values. This metric reuses the running
    average kept by `tf.keras.metrics.Mean`:

    * on update, it feeds `1 / values` into the parent accumulator;
    * on read, it takes the reciprocal of the parent's result.

    Note: `tfa.metrics.HarmonicMean` can be used the same as
    `tf.keras.metrics.Mean`.

    Args:
        name: (Optional) String name of the metric instance.
        dtype: (Optional) Data type of the metric result.

    Usage:

    >>> metric = tfa.metrics.HarmonicMean()
    >>> metric.update_state([1, 4, 4])
    >>> metric.result().numpy()
    2.0
    """

    @typechecked
    def __init__(
        self,
        name: str = "harmonic_mean",
        dtype: AcceptableDTypes = None,
        **kwargs,
    ):
        # All state handling is delegated to the parent Mean metric.
        super().__init__(name=name, dtype=dtype, **kwargs)

    def update_state(self, values, sample_weight=None) -> None:
        """Accumulate the reciprocals of `values` into the running mean.

        Args:
            values: Observations. They are cast to the metric's dtype before
                inversion.
            sample_weight: (Optional) Forwarded unchanged to
                `tf.keras.metrics.Mean.update_state`.
        """
        as_metric_dtype = tf.cast(values, dtype=self.dtype)
        inverted = tf.math.reciprocal(as_metric_dtype)
        super().update_state(inverted, sample_weight)

    def result(self) -> tf.Tensor:
        """Return the harmonic mean of everything accumulated so far.

        The parent result is the mean of the reciprocals, so inverting it
        yields the harmonic mean. `reciprocal_no_nan` maps a zero mean to 0
        instead of producing inf.
        """
        mean_of_reciprocals = super().result()
        return tf.math.reciprocal_no_nan(mean_of_reciprocals)