Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/confusion_matrix.py: 33%
60 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Confusion matrix related utilities."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import array_ops_stack
21from tensorflow.python.ops import check_ops
22from tensorflow.python.ops import cond
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.util import deprecation
26from tensorflow.python.util import dispatch
27from tensorflow.python.util.tf_export import tf_export
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  Only a last dimension whose static size is compatible with 1 is a squeeze
  candidate. A tensor whose rank is statically known to be 0 has no last
  dimension and is never squeezed.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims

    # A tensor can only be squeezed if its rank is unknown, or if it has at
    # least one dimension and its last dimension may be 1. The `rank > 0`
    # guard matters: a known-scalar shape has an empty `dims`, so `dims[-1]`
    # would raise IndexError.
    predictions_squeezable = (
        predictions_rank is None or
        (predictions_rank > 0 and
         predictions_shape.dims[-1].is_compatible_with(1)))
    labels_squeezable = (
        labels_rank is None or
        (labels_rank > 0 and labels_shape.dims[-1].is_compatible_with(1)))

    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if rank_diff == expected_rank_diff + 1 and predictions_squeezable:
        predictions = array_ops.squeeze(predictions, [-1])
      elif rank_diff == expected_rank_diff - 1 and labels_squeezable:
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank: at least one static rank is unknown, so the decision
    # to squeeze is made by graph ops evaluated at run time.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if predictions_squeezable:
      predictions = cond.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if labels_squeezable:
      labels = cond.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
@tf_export('math.confusion_matrix', v1=[])
@dispatch.add_dispatch_support
def confusion_matrix(labels,
                     predictions,
                     num_classes=None,
                     weights=None,
                     dtype=dtypes.int32,
                     name=None):
  """Computes the confusion matrix from predictions and labels.

  The matrix columns represent the prediction labels and the rows represent the
  real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for a given classification task. Both
  prediction and labels must be 1-D arrays of the same shape in order for this
  function to work.

  If `num_classes` is `None`, then `num_classes` will be set to one plus the
  maximum value in either predictions or labels (or 0 if both are empty, which
  yields a `[0, 0]` matrix). Class labels are expected to start at 0. For
  example, if `num_classes` is 3, then the possible labels would be
  `[0, 1, 2]`.

  If `weights` is not `None`, then each prediction contributes its
  corresponding weight to the total value of the confusion matrix cell.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
  resulting in a 5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
                 have. If this value is not provided, it will be calculated
                 using both predictions and labels array.
    weights: An optional `Tensor` whose shape matches `predictions`. It is
      cast to `dtype`.
    dtype: Data type of the confusion matrix.
    name: Scope name.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  with ops.name_scope(name, 'confusion_matrix',
                      (predictions, labels, num_classes, weights)):
    labels, predictions = remove_squeezable_dimensions(
        ops.convert_to_tensor(labels, name='labels'),
        ops.convert_to_tensor(predictions, name='predictions'))
    predictions = math_ops.cast(predictions, dtypes.int64)
    labels = math_ops.cast(labels, dtypes.int64)

    # Sanity checks - underflow or overflow can cause memory corruption.
    labels = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            labels, message='`labels` contains negative values')],
        labels)
    predictions = control_flow_ops.with_dependencies(
        [check_ops.assert_non_negative(
            predictions, message='`predictions` contains negative values')],
        predictions)

    if num_classes is None:
      # One plus the largest class id seen. `reduce_max` of an empty tensor
      # returns the lowest representable value of its dtype, which would give
      # a negative matrix shape. Clamping at 0 makes empty inputs produce a
      # [0, 0] matrix. Non-empty inputs are unaffected, since the values are
      # asserted non-negative above.
      num_classes = math_ops.maximum(
          math_ops.maximum(math_ops.reduce_max(predictions),
                           math_ops.reduce_max(labels)) + 1,
          0)
    else:
      num_classes_int64 = math_ops.cast(num_classes, dtypes.int64)
      labels = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              labels, num_classes_int64, message='`labels` out of bound')],
          labels)
      predictions = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              predictions, num_classes_int64,
              message='`predictions` out of bound')],
          predictions)

    if weights is not None:
      weights = ops.convert_to_tensor(weights, name='weights')
      predictions.get_shape().assert_is_compatible_with(weights.get_shape())
      weights = math_ops.cast(weights, dtype)

    # Each (label, prediction) pair addresses one cell; scatter_nd sums the
    # updates that land on the same cell.
    shape = array_ops_stack.stack([num_classes, num_classes])
    indices = array_ops_stack.stack([labels, predictions], axis=1)
    values = (array_ops.ones_like(predictions, dtype)
              if weights is None else weights)
    return array_ops.scatter_nd(
        indices=indices,
        updates=values,
        shape=math_ops.cast(shape, dtypes.int64))
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
def confusion_matrix_v1(labels,
                        predictions,
                        num_classes=None,
                        dtype=dtypes.int32,
                        name=None,
                        weights=None):
  """Computes the confusion matrix from predictions and labels.

  Rows of the resulting matrix correspond to the real labels and columns to
  the predicted labels. The result is always a 2-D array of shape `[n, n]`,
  where `n` is the number of valid labels for the classification task.
  `labels` and `predictions` must be 1-D arrays of the same shape.

  When `num_classes` is `None`, it is taken to be one plus the largest value
  found in either `predictions` or `labels`. Class labels start at 0, so a
  `num_classes` of 3 means the possible labels are `[0, 1, 2]`.

  When `weights` is given, every prediction adds its corresponding weight to
  the matrix cell it falls in, rather than adding 1.

  For example:

  ```python
    tf.math.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
        [[0 0 0 0 0]
         [0 0 1 0 0]
         [0 0 1 0 0]
         [0 0 0 0 0]
         [0 0 0 0 1]]
  ```

  Here the possible labels are taken to be `[0, 1, 2, 3, 4]`, which gives a
  5x5 confusion matrix.

  Args:
    labels: 1-D `Tensor` of real labels for the classification task.
    predictions: 1-D `Tensor` of predictions for a given classification.
    num_classes: The possible number of labels the classification task can
      have. If omitted, it is computed from both `predictions` and `labels`.
    dtype: Data type of the confusion matrix.
    name: Scope name.
    weights: An optional `Tensor` whose shape matches `predictions`.

  Returns:
    A `Tensor` of type `dtype` with shape `[n, n]` holding the confusion
    matrix, where `n` is the number of possible labels in the classification
    task.

  Raises:
    ValueError: If both predictions and labels are not 1-D vectors and have
      mismatched shapes, or if `weights` is not `None` and its shape doesn't
      match `predictions`.
  """
  # The v1 signature places `weights` last, unlike the v2 function, so forward
  # every argument by keyword to make the mapping explicit.
  return confusion_matrix(
      labels=labels,
      predictions=predictions,
      num_classes=num_classes,
      weights=weights,
      dtype=dtype,
      name=name)