Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/harmonic_mean.py: 71%
14 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements HarmonicMean."""
17import tensorflow as tf
19from typeguard import typechecked
20from tensorflow_addons.utils.types import AcceptableDTypes
@tf.keras.utils.register_keras_serializable(package="Addons")
class HarmonicMean(tf.keras.metrics.Mean):
    """Streaming harmonic mean of the values passed to `update_state`.

    The harmonic mean is the reciprocal of the arithmetic mean of the
    reciprocals of a set of numbers. This metric feeds the reciprocals of
    incoming values into `tf.keras.metrics.Mean` and inverts the running
    mean when `result()` is read. It can therefore be used anywhere a
    `tf.keras.metrics.Mean` is used.

    Args:
        name: (Optional) String name of the metric instance.
        dtype: (Optional) Data type of the metric result.

    Usage:

    >>> metric = tfa.metrics.HarmonicMean()
    >>> metric.update_state([1, 4, 4])
    >>> metric.result().numpy()
    2.0
    """

    @typechecked
    def __init__(
        self,
        name: str = "harmonic_mean",
        dtype: AcceptableDTypes = None,
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

    def update_state(self, values, sample_weight=None) -> None:
        """Accumulate the reciprocals of `values` into the running mean.

        Args:
            values: Numbers to include; cast to the metric's dtype first.
            sample_weight: Optional weights, forwarded unchanged to
                `tf.keras.metrics.Mean.update_state`.
        """
        # Cast before inverting so that integer inputs (as in the doctest)
        # are divided in the metric's floating dtype.
        as_metric_dtype = tf.cast(values, dtype=self.dtype)
        reciprocals = tf.math.reciprocal(as_metric_dtype)
        super().update_state(reciprocals, sample_weight)

    def result(self) -> tf.Tensor:
        """Return the reciprocal of the running mean of reciprocals."""
        mean_of_reciprocals = super().result()
        # reciprocal_no_nan maps a zero mean to 0 rather than inf.
        return tf.math.reciprocal_no_nan(mean_of_reciprocals)