Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/geometric_mean.py: 45%
33 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements GeometricMean."""
17import tensorflow as tf
18from tensorflow.keras import backend as K
19from tensorflow.keras.metrics import Metric
21from typeguard import typechecked
22from tensorflow_addons.utils.types import AcceptableDTypes
23from tensorflow_addons.metrics.utils import sample_weight_shape_match
@tf.keras.utils.register_keras_serializable(package="Addons")
class GeometricMean(Metric):
    """Compute Geometric Mean

    The geometric mean is a kind of mean. Unlike the arithmetic mean
    that uses the sum of values, it uses the product of the values to
    represent typical values in a set of numbers.

    The metric is accumulated in log space: `total` holds the weighted sum
    of `log(values)` and `count` holds the sum of the sample weights, so the
    result is `exp(total / count)`.

    Note: `tfa.metrics.GeometricMean` can be used the same as `tf.keras.metrics.Mean`.

    Args:
        name: (Optional) String name of the metric instance.
        dtype: (Optional) Data type of the metric result.

    Usage:

    >>> metric = tfa.metrics.GeometricMean()
    >>> metric.update_state([1, 3, 5, 7, 9])
    >>> metric.result().numpy()
    3.9362833
    """

    @typechecked
    def __init__(
        self, name: str = "geometric_mean", dtype: AcceptableDTypes = None, **kwargs
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Running weighted sum of log(values).
        self.total = self.add_weight(
            "total", shape=None, initializer="zeros", dtype=dtype
        )
        # Running sum of the sample weights.
        self.count = self.add_weight(
            "count", shape=None, initializer="zeros", dtype=dtype
        )

    def update_state(self, values, sample_weight=None) -> None:
        """Accumulate `values` (optionally weighted) into the metric state.

        Args:
            values: Values to accumulate; cast to the metric dtype.
            sample_weight: (Optional) Per-value weights. Passed through
                `sample_weight_shape_match` before use -- presumably that
                helper aligns the weights with `values` and supplies a
                default when `None`; confirm against
                `tensorflow_addons.metrics.utils`.
        """
        values = tf.cast(values, dtype=self.dtype)
        sample_weight = sample_weight_shape_match(values, sample_weight)
        sample_weight = tf.cast(sample_weight, dtype=self.dtype)

        self.count.assign_add(tf.reduce_sum(sample_weight))
        # Once `total` is infinite the result is already pinned (see
        # `result`), so only `count` keeps being updated.
        if not tf.math.is_inf(self.total):
            log_v = tf.math.log(values)
            # `multiply_no_nan` returns 0 wherever the weight is 0, so a
            # zero-weighted value of 0 (log -> -inf) is ignored instead of
            # turning `total` into NaN via 0 * -inf.
            log_v = tf.math.multiply_no_nan(log_v, sample_weight)
            log_v = tf.reduce_sum(log_v)
            self.total.assign_add(log_v)

    def result(self) -> tf.Tensor:
        """Return the geometric mean of everything accumulated so far.

        Returns:
            `exp(total / count)` cast to the metric dtype, or a constant 0
            when `total` is infinite (e.g. a 0 value was accumulated, whose
            log is -inf).
        """
        # NOTE(review): `is_inf` is true for +inf as well as -inf, so a
        # +inf `total` is also reported as 0 -- confirm this is intended.
        if tf.math.is_inf(self.total):
            return tf.constant(0, dtype=self.dtype)
        ret = tf.math.exp(self.total / self.count)
        return tf.cast(ret, dtype=self.dtype)

    def reset_state(self) -> None:
        """Reset every metric variable to 0."""
        K.batch_set_value([(v, 0) for v in self.variables])

    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        # Required in Tensorflow < 2.5.0
        return self.reset_state()