Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/geometric_mean.py: 45%

33 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements GeometricMean.""" 

16 

17import tensorflow as tf 

18from tensorflow.keras import backend as K 

19from tensorflow.keras.metrics import Metric 

20 

21from typeguard import typechecked 

22from tensorflow_addons.utils.types import AcceptableDTypes 

23from tensorflow_addons.metrics.utils import sample_weight_shape_match 

24 

25 

@tf.keras.utils.register_keras_serializable(package="Addons")
class GeometricMean(Metric):
    """Compute Geometric Mean

    The geometric mean is a kind of mean. Unlike the arithmetic mean
    that uses the sum of values, it uses the product of the values to
    represent typical values in a set of numbers.

    Internally the metric keeps two scalar state variables: ``total``, the
    weighted sum of ``log(value)``, and ``count``, the sum of the sample
    weights. ``result()`` reports ``exp(total / count)``.

    Note: `tfa.metrics.GeometricMean` can be used the same as `tf.keras.metrics.Mean`.

    Edge cases:
      * Once any value with a non-zero weight equals 0, ``total`` becomes
        ``-inf``. Further updates are then skipped and ``result()`` returns
        0 until the state is reset.
      * Values whose weight is 0 contribute nothing, even when the value is 0.
      * Negative values make ``log`` return NaN, so the result becomes NaN.
      * Calling ``result()`` before any weight has been accumulated computes
        ``exp(0 / 0)``, i.e. NaN.

    Args:
      name: (Optional) String name of the metric instance.
      dtype: (Optional) Data type of the metric result.

    Usage:

    >>> metric = tfa.metrics.GeometricMean()
    >>> metric.update_state([1, 3, 5, 7, 9])
    >>> metric.result().numpy()
    3.9362833
    """

    @typechecked
    def __init__(
        self, name: str = "geometric_mean", dtype: AcceptableDTypes = None, **kwargs
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Weighted running sum of log(values).
        self.total = self.add_weight(
            "total", shape=None, initializer="zeros", dtype=dtype
        )
        # Running sum of sample weights (the denominator of the log-mean).
        self.count = self.add_weight(
            "count", shape=None, initializer="zeros", dtype=dtype
        )

    def update_state(self, values, sample_weight=None) -> None:
        """Accumulate the weighted log-sum of ``values``.

        Args:
          values: Tensor-like of values to average; cast to the metric dtype.
          sample_weight: (Optional) weights for ``values``. They are passed
            through ``sample_weight_shape_match``, which is assumed to
            broadcast them to ``values``' shape and supply ones when ``None``
            (confirm against that helper). The result is cast to the metric
            dtype.
        """
        values = tf.cast(values, dtype=self.dtype)
        sample_weight = sample_weight_shape_match(values, sample_weight)
        sample_weight = tf.cast(sample_weight, dtype=self.dtype)

        self.count.assign_add(tf.reduce_sum(sample_weight))
        # A -inf total means a zero value was already seen and the result is
        # pinned at 0, so further log accumulation is skipped.
        if not tf.math.is_inf(self.total):
            log_v = tf.math.log(values)
            # multiply_no_nan returns 0 wherever the weight is 0. This keeps a
            # masked-out zero value (log = -inf) from turning the total into
            # NaN via 0 * -inf, which plain multiplication would do.
            log_v = tf.math.multiply_no_nan(log_v, sample_weight)
            self.total.assign_add(tf.reduce_sum(log_v))

    def result(self) -> tf.Tensor:
        """Return ``exp(total / count)``, or 0 once ``total`` is infinite."""
        if tf.math.is_inf(self.total):
            return tf.constant(0, dtype=self.dtype)
        ret = tf.math.exp(self.total / self.count)
        return tf.cast(ret, dtype=self.dtype)

    def reset_state(self) -> None:
        """Set every state variable (``total`` and ``count``) back to 0."""
        K.batch_set_value([(v, 0) for v in self.variables])

    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        # Required in Tensorflow < 2.5.0
        return self.reset_state()