Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/multilabel_confusion_matrix.py: 34%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements Multi-label confusion matrix scores.""" 

16 

17import warnings 

18 

19import tensorflow as tf 

20from tensorflow.keras import backend as K 

21from tensorflow.keras.metrics import Metric 

22import numpy as np 

23 

24from typeguard import typechecked 

25from tensorflow_addons.utils.types import AcceptableDTypes, FloatTensorLike 

26 

27 

class MultiLabelConfusionMatrix(Metric):
    """Computes Multi-label confusion matrix.

    Class-wise confusion matrix is computed for the
    evaluation of classification.

    If multi-class input is provided, it will be treated
    as multilabel data.

    Consider classification problem with two classes
    (i.e., num_classes=2).

    Resultant matrix `M` will be in the shape of `(num_classes, 2, 2)`.

    Every class `i` has a dedicated matrix of shape `(2, 2)` that contains:

    - true negatives for class `i` in `M[i, 0, 0]`
    - false positives for class `i` in `M[i, 0, 1]`
    - false negatives for class `i` in `M[i, 1, 0]`
    - true positives for class `i` in `M[i, 1, 1]`

    Inputs are cast to `int32` before counting and any non-zero entry is
    treated as positive. Probabilistic predictions should therefore be
    thresholded to 0/1 beforehand: a cast truncates e.g. `0.9` to `0`.

    Args:
        num_classes: `int`, the number of labels the prediction task can have.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Usage:

    >>> # multilabel confusion matrix
    >>> y_true = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.int32)
    >>> y_pred = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.int32)
    >>> metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=3)
    >>> metric.update_state(y_true, y_pred)
    >>> result = metric.result()
    >>> result.numpy()  #doctest: -DONT_ACCEPT_BLANKLINE
    array([[[1., 0.],
            [0., 1.]],
    <BLANKLINE>
           [[1., 0.],
            [0., 1.]],
    <BLANKLINE>
           [[0., 1.],
            [1., 0.]]], dtype=float32)
    >>> # if multiclass input is provided
    >>> y_true = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.int32)
    >>> y_pred = np.array([[1, 0, 0], [0, 0, 1]], dtype=np.int32)
    >>> metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=3)
    >>> metric.update_state(y_true, y_pred)
    >>> result = metric.result()
    >>> result.numpy()  #doctest: -DONT_ACCEPT_BLANKLINE
    array([[[1., 0.],
            [0., 1.]],
    <BLANKLINE>
           [[1., 0.],
            [1., 0.]],
    <BLANKLINE>
           [[1., 1.],
            [0., 0.]]], dtype=float32)

    """

    @typechecked
    def __init__(
        self,
        num_classes: FloatTensorLike,
        name: str = "Multilabel_confusion_matrix",
        dtype: AcceptableDTypes = None,
        **kwargs,
    ):
        # NOTE(review): extra ``kwargs`` are accepted for signature
        # compatibility but are not forwarded to ``Metric.__init__``.
        super().__init__(name=name, dtype=dtype)
        self.num_classes = num_classes
        # One accumulator per confusion-matrix cell, each holding a count
        # per class. Creation order is kept stable (TP, FP, FN, TN) because
        # it determines the order of ``self.variables``.
        self.true_positives = self.add_weight(
            "true_positives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )
        self.false_positives = self.add_weight(
            "false_positives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )
        self.false_negatives = self.add_weight(
            "false_negatives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )
        self.true_negatives = self.add_weight(
            "true_negatives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates per-class confusion counts for one batch.

        Counts are reduced over axis 0, presumably the batch axis, with the
        class axis last (TODO confirm against callers).

        Args:
            y_true: Ground-truth labels; cast to `int32`, non-zero = positive.
            y_pred: Predictions; cast to `int32`, non-zero = positive.
            sample_weight: Ignored. A warning is emitted when it is not None.
        """
        if sample_weight is not None:
            warnings.warn(
                "`sample_weight` is not None. Be aware that MultiLabelConfusionMatrix "
                "does not take `sample_weight` into account when computing the metric "
                "value."
            )

        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        # A single boolean "positive" mask per input guarantees that every
        # entry falls into exactly one of TP / FP / FN / TN. Previously
        # negatives were tested with `!= 1`, so values such as 2 or -1 were
        # counted both as positive and as negative. Masks also avoid the
        # int32 overflow that `y_true * y_pred` could hit for large values.
        true_pos_mask = tf.math.not_equal(y_true, 0)
        pred_pos_mask = tf.math.not_equal(y_pred, 0)
        true_neg_mask = tf.math.logical_not(true_pos_mask)
        pred_neg_mask = tf.math.logical_not(pred_pos_mask)

        true_positive = tf.math.count_nonzero(
            tf.math.logical_and(true_pos_mask, pred_pos_mask), axis=0
        )
        false_positive = tf.math.count_nonzero(
            tf.math.logical_and(true_neg_mask, pred_pos_mask), axis=0
        )
        false_negative = tf.math.count_nonzero(
            tf.math.logical_and(true_pos_mask, pred_neg_mask), axis=0
        )
        true_negative = tf.math.count_nonzero(
            tf.math.logical_and(true_neg_mask, pred_neg_mask), axis=0
        )

        # `count_nonzero` yields int64; cast to the metric dtype before adding.
        self.true_positives.assign_add(tf.cast(true_positive, self.dtype))
        self.false_positives.assign_add(tf.cast(false_positive, self.dtype))
        self.false_negatives.assign_add(tf.cast(false_negative, self.dtype))
        self.true_negatives.assign_add(tf.cast(true_negative, self.dtype))

    def result(self):
        """Returns the accumulated counts as a `(num_classes, 2, 2)` tensor.

        For class `i` the slice is `[[TN, FP], [FN, TP]]`.
        """
        # Shape (4, num_classes), rows ordered TN, FP, FN, TP.
        flat_confusion_matrix = tf.convert_to_tensor(
            [
                self.true_negatives,
                self.false_positives,
                self.false_negatives,
                self.true_positives,
            ]
        )
        # Transpose to (num_classes, 4), then fold each row into a 2x2 matrix.
        confusion_matrix = tf.reshape(tf.transpose(flat_confusion_matrix), [-1, 2, 2])

        return confusion_matrix

    def get_config(self):
        """Returns the serializable config of the metric."""

        config = {
            "num_classes": self.num_classes,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def reset_state(self):
        """Resets every accumulator to zeros of length `num_classes`."""
        reset_value = np.zeros(self.num_classes, dtype=np.int32)
        K.batch_set_value([(v, reset_value) for v in self.variables])

    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        # Required in Tensorflow < 2.5.0
        return self.reset_state()