Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/multilabel_confusion_matrix.py: 34%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements Multi-label confusion matrix scores."""
17import warnings
19import tensorflow as tf
20from tensorflow.keras import backend as K
21from tensorflow.keras.metrics import Metric
22import numpy as np
24from typeguard import typechecked
25from tensorflow_addons.utils.types import AcceptableDTypes, FloatTensorLike
class MultiLabelConfusionMatrix(Metric):
    """Computes Multi-label confusion matrix.

    Class-wise confusion matrix is computed for the
    evaluation of classification.

    If multi-class input is provided, it will be treated
    as multilabel data.

    Consider classification problem with two classes
    (i.e num_classes=2).

    Resultant matrix `M` will be in the shape of `(num_classes, 2, 2)`.

    Every class `i` has a dedicated matrix of shape `(2, 2)` that contains:

    - true negatives for class `i` in `M(0,0)`
    - false positives for class `i` in `M(0,1)`
    - false negatives for class `i` in `M(1,0)`
    - true positives for class `i` in `M(1,1)`

    Inputs are cast to `int32` before counting. Any non-zero entry is
    treated as a positive and a zero entry as a negative, so for each class
    every sample contributes to exactly one of the four cells.

    Args:
        num_classes: `int`, the number of labels the prediction task can have.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        **kwargs: Accepted for compatibility; currently not forwarded to
            the base `Metric`.

    Usage:

    >>> # multilabel confusion matrix
    >>> y_true = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.int32)
    >>> y_pred = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.int32)
    >>> metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=3)
    >>> metric.update_state(y_true, y_pred)
    >>> result = metric.result()
    >>> result.numpy()  #doctest: -DONT_ACCEPT_BLANKLINE
    array([[[1., 0.],
            [0., 1.]],
    <BLANKLINE>
           [[1., 0.],
            [0., 1.]],
    <BLANKLINE>
           [[0., 1.],
            [1., 0.]]], dtype=float32)
    >>> # if multiclass input is provided
    >>> y_true = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.int32)
    >>> y_pred = np.array([[1, 0, 0], [0, 0, 1]], dtype=np.int32)
    >>> metric = tfa.metrics.MultiLabelConfusionMatrix(num_classes=3)
    >>> metric.update_state(y_true, y_pred)
    >>> result = metric.result()
    >>> result.numpy()  #doctest: -DONT_ACCEPT_BLANKLINE
    array([[[1., 0.],
            [0., 1.]],
    <BLANKLINE>
           [[1., 0.],
            [1., 0.]],
    <BLANKLINE>
           [[1., 1.],
            [0., 0.]]], dtype=float32)

    """

    @typechecked
    def __init__(
        self,
        num_classes: FloatTensorLike,
        name: str = "Multilabel_confusion_matrix",
        dtype: AcceptableDTypes = None,
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype)
        self.num_classes = num_classes
        # One accumulator per confusion-matrix cell, each holding one counter
        # per class.
        self.true_positives = self.add_weight(
            "true_positives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )
        self.false_positives = self.add_weight(
            "false_positives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )
        self.false_negatives = self.add_weight(
            "false_negatives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )
        self.true_negatives = self.add_weight(
            "true_negatives",
            shape=[self.num_classes],
            initializer="zeros",
            dtype=self.dtype,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates per-class confusion counts for one batch.

        Args:
            y_true: Ground-truth labels. Counts are reduced over axis 0, so
                this is presumably `(batch, num_classes)`; confirm with callers.
            y_pred: Predicted labels, same shape as `y_true`.
            sample_weight: Ignored; a warning is emitted when it is not None.
        """
        if sample_weight is not None:
            warnings.warn(
                "`sample_weight` is not None. Be aware that MultiLabelConfusionMatrix "
                "does not take `sample_weight` into account when computing the metric "
                "value."
            )

        # Casting to int truncates fractional values (e.g. 0.9 -> 0), so
        # predictions are expected to be thresholded before being passed in.
        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        # Use ONE definition of positive (non-zero) and negative (zero) for
        # all four cells. Previously TN used `!= 1` while TP/FP/FN used
        # "non-zero", which double-counted labels outside {0, 1}.
        true_is_pos = tf.math.not_equal(y_true, 0)
        pred_is_pos = tf.math.not_equal(y_pred, 0)
        true_is_neg = tf.math.logical_not(true_is_pos)
        pred_is_neg = tf.math.logical_not(pred_is_pos)

        def _count_per_class(mask):
            # Number of True entries per class (reduced over axis 0), cast
            # to the metric's dtype so it can be added to the accumulators.
            return tf.cast(tf.math.count_nonzero(mask, axis=0), self.dtype)

        self.true_positives.assign_add(
            _count_per_class(tf.math.logical_and(true_is_pos, pred_is_pos))
        )
        self.false_positives.assign_add(
            _count_per_class(tf.math.logical_and(true_is_neg, pred_is_pos))
        )
        self.false_negatives.assign_add(
            _count_per_class(tf.math.logical_and(true_is_pos, pred_is_neg))
        )
        self.true_negatives.assign_add(
            _count_per_class(tf.math.logical_and(true_is_neg, pred_is_neg))
        )

    def result(self):
        """Returns the `(num_classes, 2, 2)` matrix `[[TN, FP], [FN, TP]]`."""
        # Shape (4, num_classes): rows are TN, FP, FN, TP.
        flat_confusion_matrix = tf.stack(
            [
                self.true_negatives,
                self.false_positives,
                self.false_negatives,
                self.true_positives,
            ]
        )
        # Transpose to (num_classes, 4), then fold each row into a 2x2 block.
        return tf.reshape(tf.transpose(flat_confusion_matrix), [-1, 2, 2])

    def get_config(self):
        """Returns the serializable config of the metric."""
        config = {
            "num_classes": self.num_classes,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def reset_state(self):
        """Zeros all per-class accumulators."""
        # batch_set_value converts the int zeros to each variable's dtype.
        reset_value = np.zeros(self.num_classes, dtype=np.int32)
        K.batch_set_value([(v, reset_value) for v in self.variables])

    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        # Required in Tensorflow < 2.5.0
        return self.reset_state()