Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/utils.py: 37%

35 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for metrics.""" 

16 

17import numpy as np 

18import tensorflow as tf 

19from tensorflow_addons.utils.types import AcceptableDTypes 

20 

21from typeguard import typechecked 

22from typing import Optional, Callable 

23 

24 

class MeanMetricWrapper(tf.keras.metrics.Mean):
    """Wraps a stateless metric function with the Mean metric.

    On every `update_state` call the wrapped `fn` is evaluated on the batch
    and its output is accumulated by the parent `tf.keras.metrics.Mean`.
    """

    @typechecked
    def __init__(
        self,
        fn: Callable,
        name: Optional[str] = None,
        dtype: AcceptableDTypes = None,
        **kwargs,
    ):
        """Creates a `MeanMetricWrapper` instance.

        Args:
          fn: The metric function to wrap, with signature
            `fn(y_true, y_pred, **kwargs)`.
          name: (Optional) string name of the metric instance.
          dtype: (Optional) data type of the metric result.
          **kwargs: The keyword arguments that are passed on to `fn`.
        """
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        # Stored so they can be forwarded to `fn` on each update and
        # reported by `get_config`.
        self._fn_kwargs = kwargs

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates metric statistics.

        `y_true` and `y_pred` should have the same shape.

        Args:
          y_true: The ground truth values.
          y_pred: The predicted values.
          sample_weight: Optional weighting of each example. Defaults to 1.
            Can be a `Tensor` whose rank is either 0, or the same rank as
            `y_true`, and must be broadcastable to `y_true`.

        Returns:
          Update op.
        """
        # Both inputs are cast to the metric's dtype before calling `fn`.
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        # TODO: Add checks for ragged tensors and dimensions:
        # `ragged_assert_compatible_and_get_flat_values`
        # and `squeeze_or_expand_dimensions`
        matches = self._fn(y_true, y_pred, **self._fn_kwargs)
        return super().update_state(matches, sample_weight=sample_weight)

    def get_config(self):
        """Returns the metric config: the base-class config merged with the
        keyword arguments forwarded to `fn`.

        NOTE(review): `fn` itself is not included in the config, so
        reconstructing this class directly via `from_config` would lack it;
        presumably subclasses supply a fixed `fn` -- confirm.
        """
        # Shallow copy of the forwarded kwargs (keys win over base keys on
        # collision, same as before).
        config = dict(self._fn_kwargs)
        base_config = super().get_config()
        return {**base_config, **config}

73 

74 

def _get_model(metric, num_output):
    """Builds, compiles and fits a tiny model to smoke-test `metric`.

    The model is compiled with `metrics=["acc", metric]` and fitted for one
    epoch on random data, which exercises the metric through the
    `tf.keras.Model` compile/fit API.

    Args:
      metric: Metric passed to `model.compile` alongside "acc".
      num_output: Number of units in the softmax output layer; also the
        width of the random label array.

    Returns:
      The compiled and fitted `tf.keras.Sequential` model, so callers can
      inspect it further (callers that ignore the result are unaffected).
    """
    # Test API compatibility with tf.keras Model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(64, activation="relu"))
    model.add(tf.keras.layers.Dense(num_output, activation="softmax"))
    model.compile(
        optimizer="adam", loss="categorical_crossentropy", metrics=["acc", metric]
    )

    # Uniform random inputs and labels: the aim is only to run the metric
    # through fit(), not to learn anything (labels are not one-hot).
    data = np.random.random((10, 3))
    labels = np.random.random((10, num_output))
    model.fit(data, labels, epochs=1, batch_size=5, verbose=0)
    return model

87 

88 

def sample_weight_shape_match(v, sample_weight):
    """Returns sample weights shaped like `v`.

    Args:
      v: Tensor (or array-like) whose shape the weights should match.
      sample_weight: `None`, a single-element value (Python scalar, or a
        size-1 array/tensor of any rank), or a full set of weights.

    Returns:
      * `None` -> `tf.ones_like(v)`.
      * exactly one element -> that value filled to the shape of `v`.
      * otherwise -> `tf.convert_to_tensor(sample_weight)`; no shape check
        is performed in this case.

    NOTE: only the `None` branch derives its dtype from `v`; the other
    branches keep the dtype of `sample_weight`, so callers may need to cast.
    """
    if sample_weight is None:
        return tf.ones_like(v)
    if tf.is_tensor(sample_weight):
        # np.size cannot inspect a symbolic (graph-mode) tensor; use the
        # static element count instead (None when not fully known).
        size = sample_weight.shape.num_elements()
    else:
        size = np.size(sample_weight)
    if size == 1:
        # tf.fill requires a scalar fill value, so collapse size-1 inputs
        # such as [0.5] to rank 0. tf.shape (dynamic) is used instead of
        # v.shape (static) so dims unknown at trace time are supported.
        return tf.fill(tf.shape(v), tf.reshape(sample_weight, []))
    return tf.convert_to_tensor(sample_weight)