Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/confusion_metrics.py: 28%
331 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Confusion metrics, i.e. metrics based on True/False positives/negatives."""
17import abc
19import numpy as np
20import tensorflow.compat.v2 as tf
22from keras.src import activations
23from keras.src import backend
24from keras.src.dtensor import utils as dtensor_utils
25from keras.src.metrics import base_metric
26from keras.src.utils import metrics_utils
27from keras.src.utils.generic_utils import to_list
28from keras.src.utils.tf_utils import is_tensor_or_variable
30# isort: off
31from tensorflow.python.util.tf_export import keras_export
34class _ConfusionMatrixConditionCount(base_metric.Metric):
35 """Calculates the number of occurrences of the given confusion matrix condition.
37 Args:
38 confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
39 thresholds: (Optional) Defaults to 0.5. A float value or a python
40 list/tuple of float threshold values in [0, 1]. A threshold is compared
41 with prediction values to determine the truth value of predictions
42 (i.e., above the threshold is `true`, below is `false`). One metric
43 value is generated for each threshold value.
44 name: (Optional) string name of the metric instance.
45 dtype: (Optional) data type of the metric result.
46 """
48 def __init__(
49 self, confusion_matrix_cond, thresholds=None, name=None, dtype=None
50 ):
51 super().__init__(name=name, dtype=dtype)
52 self._confusion_matrix_cond = confusion_matrix_cond
53 self.init_thresholds = thresholds
54 self.thresholds = metrics_utils.parse_init_thresholds(
55 thresholds, default_threshold=0.5
56 )
57 self._thresholds_distributed_evenly = (
58 metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
59 )
60 self.accumulator = self.add_weight(
61 "accumulator", shape=(len(self.thresholds),), initializer="zeros"
62 )
64 def update_state(self, y_true, y_pred, sample_weight=None):
65 """Accumulates the metric statistics.
67 Args:
68 y_true: The ground truth values.
69 y_pred: The predicted values.
70 sample_weight: Optional weighting of each example. Defaults to 1. Can
71 be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
72 and must be broadcastable to `y_true`.
74 Returns:
75 Update op.
76 """
77 return metrics_utils.update_confusion_matrix_variables(
78 {self._confusion_matrix_cond: self.accumulator},
79 y_true,
80 y_pred,
81 thresholds=self.thresholds,
82 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
83 sample_weight=sample_weight,
84 )
86 def result(self):
87 if len(self.thresholds) == 1:
88 result = self.accumulator[0]
89 else:
90 result = self.accumulator
91 return tf.convert_to_tensor(result)
93 def reset_state(self):
94 backend.batch_set_value(
95 [(v, np.zeros(v.shape.as_list())) for v in self.variables]
96 )
98 def get_config(self):
99 config = {"thresholds": self.init_thresholds}
100 base_config = super().get_config()
101 return dict(list(base_config.items()) + list(config.items()))
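# Illustrative sketch (editor's addition, not part of the original source):
# each `_ConfusionMatrixConditionCount` subclass keeps one accumulator slot per
# threshold, so passing a list of thresholds yields a vector result. Assuming
# the public `tf.keras.metrics.FalsePositives` API defined below:
#
#   m = tf.keras.metrics.FalsePositives(thresholds=[0.4, 0.6])
#   m.update_state([0, 0, 1], [0.5, 0.7, 0.7])
#   m.result()  # expected roughly [2., 1.]: two FPs above 0.4, one above 0.6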
104@keras_export("keras.metrics.FalsePositives")
105class FalsePositives(_ConfusionMatrixConditionCount):
106 """Calculates the number of false positives.
108 If `sample_weight` is given, calculates the sum of the weights of
109 false positives. This metric creates one local variable, `accumulator`,
110 that is used to keep track of the number of false positives.
112 If `sample_weight` is `None`, weights default to 1.
113 Use `sample_weight` of 0 to mask values.
115 Args:
116 thresholds: (Optional) Defaults to 0.5. A float value, or a Python
117 list/tuple of float threshold values in [0, 1]. A threshold is compared
118 with prediction values to determine the truth value of predictions
119 (i.e., above the threshold is `true`, below is `false`). If used with a
120 loss function that sets `from_logits=True` (i.e. no sigmoid applied to
121 predictions), `thresholds` should be set to 0. One metric value is
122 generated for each threshold value.
123 name: (Optional) string name of the metric instance.
124 dtype: (Optional) data type of the metric result.
126 Standalone usage:
128 >>> m = tf.keras.metrics.FalsePositives()
129 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
130 >>> m.result().numpy()
131 2.0
133 >>> m.reset_state()
134 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
135 >>> m.result().numpy()
136 1.0
138 Usage with `compile()` API:
140 ```python
141 model.compile(optimizer='sgd',
142 loss='mse',
143 metrics=[tf.keras.metrics.FalsePositives()])
144 ```
146 Usage with a loss with `from_logits=True`:
148 ```python
149 model.compile(optimizer='adam',
150 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
151 metrics=[tf.keras.metrics.FalsePositives(thresholds=0)])
152 ```
153 """
155 @dtensor_utils.inject_mesh
156 def __init__(self, thresholds=None, name=None, dtype=None):
157 super().__init__(
158 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
159 thresholds=thresholds,
160 name=name,
161 dtype=dtype,
162 )
165@keras_export("keras.metrics.FalseNegatives")
166class FalseNegatives(_ConfusionMatrixConditionCount):
167 """Calculates the number of false negatives.
169 If `sample_weight` is given, calculates the sum of the weights of
170 false negatives. This metric creates one local variable, `accumulator`,
171 that is used to keep track of the number of false negatives.
173 If `sample_weight` is `None`, weights default to 1.
174 Use `sample_weight` of 0 to mask values.
176 Args:
177 thresholds: (Optional) Defaults to 0.5. A float value, or a Python
178 list/tuple of float threshold values in [0, 1]. A threshold is compared
179 with prediction values to determine the truth value of predictions
180 (i.e., above the threshold is `true`, below is `false`). If used with a
181 loss function that sets `from_logits=True` (i.e. no sigmoid applied to
182 predictions), `thresholds` should be set to 0. One metric value is
183 generated for each threshold value.
184 name: (Optional) string name of the metric instance.
185 dtype: (Optional) data type of the metric result.
187 Standalone usage:
189 >>> m = tf.keras.metrics.FalseNegatives()
190 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
191 >>> m.result().numpy()
192 2.0
194 >>> m.reset_state()
195 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
196 >>> m.result().numpy()
197 1.0
199 Usage with `compile()` API:
201 ```python
202 model.compile(optimizer='sgd',
203 loss='mse',
204 metrics=[tf.keras.metrics.FalseNegatives()])
205 ```
207 Usage with a loss with `from_logits=True`:
209 ```python
210 model.compile(optimizer='adam',
211 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
212 metrics=[tf.keras.metrics.FalseNegatives(thresholds=0)])
213 ```
214 """
216 @dtensor_utils.inject_mesh
217 def __init__(self, thresholds=None, name=None, dtype=None):
218 super().__init__(
219 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
220 thresholds=thresholds,
221 name=name,
222 dtype=dtype,
223 )
226@keras_export("keras.metrics.TrueNegatives")
227class TrueNegatives(_ConfusionMatrixConditionCount):
228 """Calculates the number of true negatives.
230 If `sample_weight` is given, calculates the sum of the weights of
231 true negatives. This metric creates one local variable, `accumulator`,
232 that is used to keep track of the number of true negatives.
234 If `sample_weight` is `None`, weights default to 1.
235 Use `sample_weight` of 0 to mask values.
237 Args:
238 thresholds: (Optional) Defaults to 0.5. A float value, or a Python
239 list/tuple of float threshold values in [0, 1]. A threshold is compared
240 with prediction values to determine the truth value of predictions
241 (i.e., above the threshold is `true`, below is `false`). If used with a
242 loss function that sets `from_logits=True` (i.e. no sigmoid applied to
243 predictions), `thresholds` should be set to 0. One metric value is
244 generated for each threshold value.
245 name: (Optional) string name of the metric instance.
246 dtype: (Optional) data type of the metric result.
248 Standalone usage:
250 >>> m = tf.keras.metrics.TrueNegatives()
251 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
252 >>> m.result().numpy()
253 2.0
255 >>> m.reset_state()
256 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
257 >>> m.result().numpy()
258 1.0
260 Usage with `compile()` API:
262 ```python
263 model.compile(optimizer='sgd',
264 loss='mse',
265 metrics=[tf.keras.metrics.TrueNegatives()])
266 ```
268 Usage with a loss with `from_logits=True`:
270 ```python
271 model.compile(optimizer='adam',
272 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
273 metrics=[tf.keras.metrics.TrueNegatives(thresholds=0)])
274 ```
275 """
277 @dtensor_utils.inject_mesh
278 def __init__(self, thresholds=None, name=None, dtype=None):
279 super().__init__(
280 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
281 thresholds=thresholds,
282 name=name,
283 dtype=dtype,
284 )
287@keras_export("keras.metrics.TruePositives")
288class TruePositives(_ConfusionMatrixConditionCount):
289 """Calculates the number of true positives.
291 If `sample_weight` is given, calculates the sum of the weights of
292 true positives. This metric creates one local variable, `true_positives`,
293 that is used to keep track of the number of true positives.
295 If `sample_weight` is `None`, weights default to 1.
296 Use `sample_weight` of 0 to mask values.
298 Args:
299 thresholds: (Optional) Defaults to 0.5. A float value, or a Python
300 list/tuple of float threshold values in [0, 1]. A threshold is compared
301 with prediction values to determine the truth value of predictions
302 (i.e., above the threshold is `true`, below is `false`). If used with a
303 loss function that sets `from_logits=True` (i.e. no sigmoid applied to
304 predictions), `thresholds` should be set to 0. One metric value is
305 generated for each threshold value.
306 name: (Optional) string name of the metric instance.
307 dtype: (Optional) data type of the metric result.
309 Standalone usage:
311 >>> m = tf.keras.metrics.TruePositives()
312 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
313 >>> m.result().numpy()
314 2.0
316 >>> m.reset_state()
317 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
318 >>> m.result().numpy()
319 1.0
321 Usage with `compile()` API:
323 ```python
324 model.compile(optimizer='sgd',
325 loss='mse',
326 metrics=[tf.keras.metrics.TruePositives()])
327 ```
329 Usage with a loss with `from_logits=True`:
331 ```python
332 model.compile(optimizer='adam',
333 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
334 metrics=[tf.keras.metrics.TruePositives(thresholds=0)])
335 ```
336 """
338 @dtensor_utils.inject_mesh
339 def __init__(self, thresholds=None, name=None, dtype=None):
340 super().__init__(
341 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
342 thresholds=thresholds,
343 name=name,
344 dtype=dtype,
345 )
348@keras_export("keras.metrics.Precision")
349class Precision(base_metric.Metric):
350 """Computes the precision of the predictions with respect to the labels.
352 The metric creates two local variables, `true_positives` and
353 `false_positives` that are used to compute the precision. This value is
354 ultimately returned as `precision`, an idempotent operation that simply
355 divides `true_positives` by the sum of `true_positives` and
356 `false_positives`.
358 If `sample_weight` is `None`, weights default to 1.
359 Use `sample_weight` of 0 to mask values.
361 If `top_k` is set, we'll calculate precision as how often on average a class
362 among the top-k classes with the highest predicted values of a batch entry
363 is correct and can be found in the label for that entry.
365 If `class_id` is specified, we calculate precision by considering only the
366 entries in the batch for which `class_id` is above the threshold and/or in
367 the top-k highest predictions, and computing the fraction of them for which
368 `class_id` is indeed a correct label.
370 Args:
371 thresholds: (Optional) A float value, or a Python list/tuple of float
372 threshold values in [0, 1]. A threshold is compared with prediction
373 values to determine the truth value of predictions (i.e., above the
374 threshold is `true`, below is `false`). If used with a loss function
375 that sets `from_logits=True` (i.e. no sigmoid applied to predictions),
376 `thresholds` should be set to 0. One metric value is generated for each
377 threshold value. If neither thresholds nor top_k are set, the default is
378 to calculate precision with `thresholds=0.5`.
379 top_k: (Optional) Unset by default. An int value specifying the top-k
380 predictions to consider when calculating precision.
381 class_id: (Optional) Integer class ID for which we want binary metrics.
382 This must be in the half-open interval `[0, num_classes)`, where
383 `num_classes` is the last dimension of predictions.
384 name: (Optional) string name of the metric instance.
385 dtype: (Optional) data type of the metric result.
387 Standalone usage:
389 >>> m = tf.keras.metrics.Precision()
390 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
391 >>> m.result().numpy()
392 0.6666667
394 >>> m.reset_state()
395 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
396 >>> m.result().numpy()
397 1.0
399 >>> # With top_k=2, it will calculate precision over y_true[:2]
400 >>> # and y_pred[:2]
401 >>> m = tf.keras.metrics.Precision(top_k=2)
402 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
403 >>> m.result().numpy()
404 0.0
406 >>> # With top_k=4, it will calculate precision over y_true[:4]
407 >>> # and y_pred[:4]
408 >>> m = tf.keras.metrics.Precision(top_k=4)
409 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
410 >>> m.result().numpy()
411 0.5
413 Usage with `compile()` API:
415 ```python
416 model.compile(optimizer='sgd',
417 loss='mse',
418 metrics=[tf.keras.metrics.Precision()])
419 ```
421 Usage with a loss with `from_logits=True`:
423 ```python
424 model.compile(optimizer='adam',
425 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
426 metrics=[tf.keras.metrics.Precision(thresholds=0)])
427 ```
428 """
430 @dtensor_utils.inject_mesh
431 def __init__(
432 self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
433 ):
434 super().__init__(name=name, dtype=dtype)
435 self.init_thresholds = thresholds
436 self.top_k = top_k
437 self.class_id = class_id
439 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
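# Editorial note (assumption based on the line above and on metrics_utils
# behavior): when `top_k` is set, the default threshold is NEG_INF, so any
# prediction kept in the top-k is treated as a predicted positive regardless
# of its score.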
440 self.thresholds = metrics_utils.parse_init_thresholds(
441 thresholds, default_threshold=default_threshold
442 )
443 self._thresholds_distributed_evenly = (
444 metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
445 )
446 self.true_positives = self.add_weight(
447 "true_positives", shape=(len(self.thresholds),), initializer="zeros"
448 )
449 self.false_positives = self.add_weight(
450 "false_positives",
451 shape=(len(self.thresholds),),
452 initializer="zeros",
453 )
455 def update_state(self, y_true, y_pred, sample_weight=None):
456 """Accumulates true positive and false positive statistics.
458 Args:
459 y_true: The ground truth values, with the same dimensions as `y_pred`.
460 Will be cast to `bool`.
461 y_pred: The predicted values. Each element must be in the range
462 `[0, 1]`.
463 sample_weight: Optional weighting of each example. Defaults to 1. Can
464 be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
465 and must be broadcastable to `y_true`.
467 Returns:
468 Update op.
469 """
470 return metrics_utils.update_confusion_matrix_variables(
471 {
472 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
473 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
474 },
475 y_true,
476 y_pred,
477 thresholds=self.thresholds,
478 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
479 top_k=self.top_k,
480 class_id=self.class_id,
481 sample_weight=sample_weight,
482 )
484 def result(self):
485 result = tf.math.divide_no_nan(
486 self.true_positives,
487 tf.math.add(self.true_positives, self.false_positives),
488 )
489 return result[0] if len(self.thresholds) == 1 else result
491 def reset_state(self):
492 num_thresholds = len(to_list(self.thresholds))
493 backend.batch_set_value(
494 [
495 (v, np.zeros((num_thresholds,)))
496 for v in (self.true_positives, self.false_positives)
497 ]
498 )
500 def get_config(self):
501 config = {
502 "thresholds": self.init_thresholds,
503 "top_k": self.top_k,
504 "class_id": self.class_id,
505 }
506 base_config = super().get_config()
507 return dict(list(base_config.items()) + list(config.items()))
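# Illustrative sketch (editor's addition, not from the original file): with
# `class_id` set, only that column of `y_true`/`y_pred` is scored. Under that
# assumption, for one-hot labels:
#
#   m = tf.keras.metrics.Precision(class_id=1)
#   m.update_state([[0, 1, 0], [0, 0, 1]], [[0.2, 0.7, 0.1], [0.1, 0.6, 0.3]])
#   m.result()  # expected ~0.5: class 1 predicted positive twice, correct once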
510@keras_export("keras.metrics.Recall")
511class Recall(base_metric.Metric):
512 """Computes the recall of the predictions with respect to the labels.
514 This metric creates two local variables, `true_positives` and
515 `false_negatives`, that are used to compute the recall. This value is
516 ultimately returned as `recall`, an idempotent operation that simply divides
517 `true_positives` by the sum of `true_positives` and `false_negatives`.
519 If `sample_weight` is `None`, weights default to 1.
520 Use `sample_weight` of 0 to mask values.
522 If `top_k` is set, recall will be computed as how often on average a class
523 among the labels of a batch entry is in the top-k predictions.
525 If `class_id` is specified, we calculate recall by considering only the
526 entries in the batch for which `class_id` is in the label, and computing the
527 fraction of them for which `class_id` is above the threshold and/or in the
528 top-k predictions.
530 Args:
531 thresholds: (Optional) A float value, or a Python list/tuple of float
532 threshold values in [0, 1]. A threshold is compared with prediction
533 values to determine the truth value of predictions (i.e., above the
534 threshold is `true`, below is `false`). If used with a loss function
535 that sets `from_logits=True` (i.e. no sigmoid applied to predictions),
536 `thresholds` should be set to 0. One metric value is generated for each
537 threshold value. If neither thresholds nor top_k are set, the default is
538 to calculate recall with `thresholds=0.5`.
539 top_k: (Optional) Unset by default. An int value specifying the top-k
540 predictions to consider when calculating recall.
541 class_id: (Optional) Integer class ID for which we want binary metrics.
542 This must be in the half-open interval `[0, num_classes)`, where
543 `num_classes` is the last dimension of predictions.
544 name: (Optional) string name of the metric instance.
545 dtype: (Optional) data type of the metric result.
547 Standalone usage:
549 >>> m = tf.keras.metrics.Recall()
550 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
551 >>> m.result().numpy()
552 0.6666667
554 >>> m.reset_state()
555 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
556 >>> m.result().numpy()
557 1.0
559 Usage with `compile()` API:
561 ```python
562 model.compile(optimizer='sgd',
563 loss='mse',
564 metrics=[tf.keras.metrics.Recall()])
565 ```
567 Usage with a loss with `from_logits=True`:
569 ```python
570 model.compile(optimizer='adam',
571 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
572 metrics=[tf.keras.metrics.Recall(thresholds=0)])
573 ```
574 """
576 @dtensor_utils.inject_mesh
577 def __init__(
578 self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None
579 ):
580 super().__init__(name=name, dtype=dtype)
581 self.init_thresholds = thresholds
582 self.top_k = top_k
583 self.class_id = class_id
585 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
586 self.thresholds = metrics_utils.parse_init_thresholds(
587 thresholds, default_threshold=default_threshold
588 )
589 self._thresholds_distributed_evenly = (
590 metrics_utils.is_evenly_distributed_thresholds(self.thresholds)
591 )
592 self.true_positives = self.add_weight(
593 "true_positives", shape=(len(self.thresholds),), initializer="zeros"
594 )
595 self.false_negatives = self.add_weight(
596 "false_negatives",
597 shape=(len(self.thresholds),),
598 initializer="zeros",
599 )
601 def update_state(self, y_true, y_pred, sample_weight=None):
602 """Accumulates true positive and false negative statistics.
604 Args:
605 y_true: The ground truth values, with the same dimensions as `y_pred`.
606 Will be cast to `bool`.
607 y_pred: The predicted values. Each element must be in the range
608 `[0, 1]`.
609 sample_weight: Optional weighting of each example. Defaults to 1. Can
610 be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
611 and must be broadcastable to `y_true`.
613 Returns:
614 Update op.
615 """
616 return metrics_utils.update_confusion_matrix_variables(
617 {
618 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
619 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
620 },
621 y_true,
622 y_pred,
623 thresholds=self.thresholds,
624 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
625 top_k=self.top_k,
626 class_id=self.class_id,
627 sample_weight=sample_weight,
628 )
630 def result(self):
631 result = tf.math.divide_no_nan(
632 self.true_positives,
633 tf.math.add(self.true_positives, self.false_negatives),
634 )
635 return result[0] if len(self.thresholds) == 1 else result
637 def reset_state(self):
638 num_thresholds = len(to_list(self.thresholds))
639 backend.batch_set_value(
640 [
641 (v, np.zeros((num_thresholds,)))
642 for v in (self.true_positives, self.false_negatives)
643 ]
644 )
646 def get_config(self):
647 config = {
648 "thresholds": self.init_thresholds,
649 "top_k": self.top_k,
650 "class_id": self.class_id,
651 }
652 base_config = super().get_config()
653 return dict(list(base_config.items()) + list(config.items()))
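# Illustrative sketch (editor's addition): with `top_k`, recall measures how
# often a true label lands among the k highest-scoring predictions. Assuming a
# single four-class example:
#
#   m = tf.keras.metrics.Recall(top_k=2)
#   m.update_state([0, 1, 1, 1], [0.9, 0.8, 0.1, 0.2])
#   m.result()  # expected ~0.33: one of three positive labels is in the top 2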
656class SensitivitySpecificityBase(base_metric.Metric, metaclass=abc.ABCMeta):
657 """Abstract base class for computing sensitivity and specificity.
659 For additional information about specificity and sensitivity, see
660 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
661 """
663 def __init__(
664 self, value, num_thresholds=200, class_id=None, name=None, dtype=None
665 ):
666 super().__init__(name=name, dtype=dtype)
667 if num_thresholds <= 0:
668 raise ValueError(
669 "Argument `num_thresholds` must be an integer > 0. "
670 f"Received: num_thresholds={num_thresholds}"
671 )
672 self.value = value
673 self.class_id = class_id
674 self.true_positives = self.add_weight(
675 "true_positives", shape=(num_thresholds,), initializer="zeros"
676 )
677 self.true_negatives = self.add_weight(
678 "true_negatives", shape=(num_thresholds,), initializer="zeros"
679 )
680 self.false_positives = self.add_weight(
681 "false_positives", shape=(num_thresholds,), initializer="zeros"
682 )
683 self.false_negatives = self.add_weight(
684 "false_negatives", shape=(num_thresholds,), initializer="zeros"
685 )
687 # Compute `num_thresholds` thresholds in [0, 1]
688 if num_thresholds == 1:
689 self.thresholds = [0.5]
690 self._thresholds_distributed_evenly = False
691 else:
692 thresholds = [
693 (i + 1) * 1.0 / (num_thresholds - 1)
694 for i in range(num_thresholds - 2)
695 ]
696 self.thresholds = [0.0] + thresholds + [1.0]
697 self._thresholds_distributed_evenly = True
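# For example (editor's note), num_thresholds=5 produces the evenly spaced
# thresholds [0.0, 0.25, 0.5, 0.75, 1.0]: three interior points plus the two
# endpoints appended above.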
699 def update_state(self, y_true, y_pred, sample_weight=None):
700 """Accumulates confusion matrix statistics.
702 Args:
703 y_true: The ground truth values.
704 y_pred: The predicted values.
705 sample_weight: Optional weighting of each example. Defaults to 1. Can
706 be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
707 and must be broadcastable to `y_true`.
709 Returns:
710 Update op.
711 """
712 return metrics_utils.update_confusion_matrix_variables(
713 {
714 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
715 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
716 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
717 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
718 },
719 y_true,
720 y_pred,
721 thresholds=self.thresholds,
722 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
723 class_id=self.class_id,
724 sample_weight=sample_weight,
725 )
727 def reset_state(self):
728 num_thresholds = len(self.thresholds)
729 confusion_matrix_variables = (
730 self.true_positives,
731 self.true_negatives,
732 self.false_positives,
733 self.false_negatives,
734 )
735 backend.batch_set_value(
736 [
737 (v, np.zeros((num_thresholds,)))
738 for v in confusion_matrix_variables
739 ]
740 )
742 def get_config(self):
743 config = {"class_id": self.class_id}
744 base_config = super().get_config()
745 return dict(list(base_config.items()) + list(config.items()))
747 def _find_max_under_constraint(self, constrained, dependent, predicate):
748 """Returns the maximum of `dependent` that satisfies the
749 constraint.
751 Args:
752 constrained: Over these values the constraint
753 is specified. A rank-1 tensor.
754 dependent: From these values the maximum that satisfies the
755 constraint is selected. Values in this tensor and in
756 `constrained` are linked by having the same threshold at each
757 position, hence this tensor must have the same shape.
758 predicate: A binary boolean functor to be applied to arguments
759 `constrained` and `self.value`, e.g. `tf.greater`.
761 Returns:
762 The maximal dependent value; 0.0 if no value satisfies the constraint.
763 """
764 feasible = tf.where(predicate(constrained, self.value))
765 feasible_exists = tf.greater(tf.size(feasible), 0)
766 max_dependent = tf.reduce_max(tf.gather(dependent, feasible))
768 return tf.where(feasible_exists, max_dependent, 0.0)
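# Worked example (editor's addition): with constrained=[0.2, 0.6, 0.9],
# dependent=[0.9, 0.7, 0.4], self.value=0.5 and predicate=tf.greater_equal,
# the feasible indices are {1, 2}, so the returned maximum is 0.7. If no entry
# satisfied the constraint, 0.0 would be returned instead.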
771@keras_export("keras.metrics.SensitivityAtSpecificity")
772class SensitivityAtSpecificity(SensitivitySpecificityBase):
773 """Computes best sensitivity where specificity is >= specified value.
775 In other words, it reports the sensitivity at a given level of specificity.
777 `Sensitivity` measures the proportion of actual positives that are correctly
778 identified as such (tp / (tp + fn)).
779 `Specificity` measures the proportion of actual negatives that are correctly
780 identified as such (tn / (tn + fp)).
782 This metric creates four local variables, `true_positives`,
783 `true_negatives`, `false_positives` and `false_negatives` that are used to
784 compute the sensitivity at the given specificity. The threshold for the
785 given specificity value is computed and used to evaluate the corresponding
786 sensitivity.
788 If `sample_weight` is `None`, weights default to 1.
789 Use `sample_weight` of 0 to mask values.
791 If `class_id` is specified, we calculate precision by considering only the
792 entries in the batch for which `class_id` is above the threshold, and
793 computing the fraction of them for which `class_id` is
794 indeed a correct label.
796 For additional information about specificity and sensitivity, see
797 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
799 Args:
800 specificity: A scalar value in range `[0, 1]`.
801 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
802 use for matching the given specificity.
803 class_id: (Optional) Integer class ID for which we want binary metrics.
804 This must be in the half-open interval `[0, num_classes)`, where
805 `num_classes` is the last dimension of predictions.
806 name: (Optional) string name of the metric instance.
807 dtype: (Optional) data type of the metric result.
809 Standalone usage:
811 >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
812 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
813 >>> m.result().numpy()
814 0.5
816 >>> m.reset_state()
817 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
818 ... sample_weight=[1, 1, 2, 2, 1])
819 >>> m.result().numpy()
820 0.333333
822 Usage with `compile()` API:
824 ```python
825 model.compile(
826 optimizer='sgd',
827 loss='mse',
828 metrics=[tf.keras.metrics.SensitivityAtSpecificity()])
829 ```
830 """
832 @dtensor_utils.inject_mesh
833 def __init__(
834 self,
835 specificity,
836 num_thresholds=200,
837 class_id=None,
838 name=None,
839 dtype=None,
840 ):
841 if specificity < 0 or specificity > 1:
842 raise ValueError(
843 "Argument `specificity` must be in the range [0, 1]. "
844 f"Received: specificity={specificity}"
845 )
846 self.specificity = specificity
847 self.num_thresholds = num_thresholds
848 super().__init__(
849 specificity,
850 num_thresholds=num_thresholds,
851 class_id=class_id,
852 name=name,
853 dtype=dtype,
854 )
856 def result(self):
857 specificities = tf.math.divide_no_nan(
858 self.true_negatives,
859 tf.math.add(self.true_negatives, self.false_positives),
860 )
861 sensitivities = tf.math.divide_no_nan(
862 self.true_positives,
863 tf.math.add(self.true_positives, self.false_negatives),
864 )
865 return self._find_max_under_constraint(
866 specificities, sensitivities, tf.greater_equal
867 )
869 def get_config(self):
870 config = {
871 "num_thresholds": self.num_thresholds,
872 "specificity": self.specificity,
873 }
874 base_config = super().get_config()
875 return dict(list(base_config.items()) + list(config.items()))
878@keras_export("keras.metrics.SpecificityAtSensitivity")
879class SpecificityAtSensitivity(SensitivitySpecificityBase):
880 """Computes best specificity where sensitivity is >= specified value.
882 `Sensitivity` measures the proportion of actual positives that are correctly
883 identified as such (tp / (tp + fn)).
884 `Specificity` measures the proportion of actual negatives that are correctly
885 identified as such (tn / (tn + fp)).
887 This metric creates four local variables, `true_positives`,
888 `true_negatives`, `false_positives` and `false_negatives` that are used to
889 compute the specificity at the given sensitivity. The threshold for the
890 given sensitivity value is computed and used to evaluate the corresponding
891 specificity.
893 If `sample_weight` is `None`, weights default to 1.
894 Use `sample_weight` of 0 to mask values.
896 If `class_id` is specified, we calculate precision by considering only the
897 entries in the batch for which `class_id` is above the threshold, and
898 computing the fraction of them for which `class_id` is
899 indeed a correct label.
901 For additional information about specificity and sensitivity, see
902 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
904 Args:
905 sensitivity: A scalar value in range `[0, 1]`.
906 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
907 use for matching the given sensitivity.
908 class_id: (Optional) Integer class ID for which we want binary metrics.
909 This must be in the half-open interval `[0, num_classes)`, where
910 `num_classes` is the last dimension of predictions.
911 name: (Optional) string name of the metric instance.
912 dtype: (Optional) data type of the metric result.
914 Standalone usage:
916 >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
917 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
918 >>> m.result().numpy()
919 0.66666667
921 >>> m.reset_state()
922 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
923 ... sample_weight=[1, 1, 2, 2, 2])
924 >>> m.result().numpy()
925 0.5
927 Usage with `compile()` API:
929 ```python
930 model.compile(
931 optimizer='sgd',
932 loss='mse',
933 metrics=[tf.keras.metrics.SpecificityAtSensitivity()])
934 ```
935 """
937 @dtensor_utils.inject_mesh
938 def __init__(
939 self,
940 sensitivity,
941 num_thresholds=200,
942 class_id=None,
943 name=None,
944 dtype=None,
945 ):
946 if sensitivity < 0 or sensitivity > 1:
947 raise ValueError(
948 "Argument `sensitivity` must be in the range [0, 1]. "
949 f"Received: sensitivity={sensitivity}"
950 )
951 self.sensitivity = sensitivity
952 self.num_thresholds = num_thresholds
953 super().__init__(
954 sensitivity,
955 num_thresholds=num_thresholds,
956 class_id=class_id,
957 name=name,
958 dtype=dtype,
959 )
961 def result(self):
962 sensitivities = tf.math.divide_no_nan(
963 self.true_positives,
964 tf.math.add(self.true_positives, self.false_negatives),
965 )
966 specificities = tf.math.divide_no_nan(
967 self.true_negatives,
968 tf.math.add(self.true_negatives, self.false_positives),
969 )
970 return self._find_max_under_constraint(
971 sensitivities, specificities, tf.greater_equal
972 )
974 def get_config(self):
975 config = {
976 "num_thresholds": self.num_thresholds,
977 "sensitivity": self.sensitivity,
978 }
979 base_config = super().get_config()
980 return dict(list(base_config.items()) + list(config.items()))
983@keras_export("keras.metrics.PrecisionAtRecall")
984class PrecisionAtRecall(SensitivitySpecificityBase):
985 """Computes best precision where recall is >= specified value.
987 This metric creates four local variables, `true_positives`,
988 `true_negatives`, `false_positives` and `false_negatives` that are used to
989 compute the precision at the given recall. The threshold for the given
990 recall value is computed and used to evaluate the corresponding precision.
992 If `sample_weight` is `None`, weights default to 1.
993 Use `sample_weight` of 0 to mask values.
995 If `class_id` is specified, we calculate precision by considering only the
996 entries in the batch for which `class_id` is above the threshold, and
997 computing the fraction of them for which `class_id` is
998 indeed a correct label.
1000 Args:
1001 recall: A scalar value in range `[0, 1]`.
1002 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1003 use for matching the given recall.
1004 class_id: (Optional) Integer class ID for which we want binary metrics.
1005 This must be in the half-open interval `[0, num_classes)`, where
1006 `num_classes` is the last dimension of predictions.
1007 name: (Optional) string name of the metric instance.
1008 dtype: (Optional) data type of the metric result.
1010 Standalone usage:
1012 >>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
1013 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1014 >>> m.result().numpy()
1015 0.5
1017 >>> m.reset_state()
1018 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1019 ... sample_weight=[2, 2, 2, 1, 1])
1020 >>> m.result().numpy()
1021 0.33333333
1023 Usage with `compile()` API:
1025 ```python
1026 model.compile(
1027 optimizer='sgd',
1028 loss='mse',
1029 metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)])
1030 ```
1031 """
1033 @dtensor_utils.inject_mesh
1034 def __init__(
1035 self, recall, num_thresholds=200, class_id=None, name=None, dtype=None
1036 ):
1037 if recall < 0 or recall > 1:
1038 raise ValueError(
1039 "Argument `recall` must be in the range [0, 1]. "
1040 f"Received: recall={recall}"
1041 )
1042 self.recall = recall
1043 self.num_thresholds = num_thresholds
1044 super().__init__(
1045 value=recall,
1046 num_thresholds=num_thresholds,
1047 class_id=class_id,
1048 name=name,
1049 dtype=dtype,
1050 )
1052 def result(self):
1053 recalls = tf.math.divide_no_nan(
1054 self.true_positives,
1055 tf.math.add(self.true_positives, self.false_negatives),
1056 )
1057 precisions = tf.math.divide_no_nan(
1058 self.true_positives,
1059 tf.math.add(self.true_positives, self.false_positives),
1060 )
1061 return self._find_max_under_constraint(
1062 recalls, precisions, tf.greater_equal
1063 )
1065 def get_config(self):
1066 config = {"num_thresholds": self.num_thresholds, "recall": self.recall}
1067 base_config = super().get_config()
1068 return dict(list(base_config.items()) + list(config.items()))
1071@keras_export("keras.metrics.RecallAtPrecision")
1072class RecallAtPrecision(SensitivitySpecificityBase):
1073 """Computes best recall where precision is >= specified value.
1075 For a given score-label distribution, the required precision might not
1076 be achievable; in this case 0.0 is returned as recall.
1078 This metric creates four local variables, `true_positives`,
1079 `true_negatives`, `false_positives` and `false_negatives` that are used to
1080 compute the recall at the given precision. The threshold for the given
1081 precision value is computed and used to evaluate the corresponding recall.
1083 If `sample_weight` is `None`, weights default to 1.
1084 Use `sample_weight` of 0 to mask values.
1086 If `class_id` is specified, we calculate precision by considering only the
1087 entries in the batch for which `class_id` is above the threshold, and
1088 computing the fraction of them for which `class_id` is
1089 indeed a correct label.
1091 Args:
1092 precision: A scalar value in range `[0, 1]`.
1093 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1094 use for matching the given precision.
1095 class_id: (Optional) Integer class ID for which we want binary metrics.
1096 This must be in the half-open interval `[0, num_classes)`, where
1097 `num_classes` is the last dimension of predictions.
1098 name: (Optional) string name of the metric instance.
1099 dtype: (Optional) data type of the metric result.
1101 Standalone usage:
1103 >>> m = tf.keras.metrics.RecallAtPrecision(0.8)
1104 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1105 >>> m.result().numpy()
1106 0.5
1108 >>> m.reset_state()
1109 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1110 ... sample_weight=[1, 0, 0, 1])
1111 >>> m.result().numpy()
1112 1.0
1114 Usage with `compile()` API:
1116 ```python
1117 model.compile(
1118 optimizer='sgd',
1119 loss='mse',
1120 metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)])
1121 ```
1122 """
1124 @dtensor_utils.inject_mesh
1125 def __init__(
1126 self,
1127 precision,
1128 num_thresholds=200,
1129 class_id=None,
1130 name=None,
1131 dtype=None,
1132 ):
1133 if precision < 0 or precision > 1:
1134 raise ValueError(
1135 "Argument `precision` must be in the range [0, 1]. "
1136 f"Received: precision={precision}"
1137 )
1138 self.precision = precision
1139 self.num_thresholds = num_thresholds
1140 super().__init__(
1141 value=precision,
1142 num_thresholds=num_thresholds,
1143 class_id=class_id,
1144 name=name,
1145 dtype=dtype,
1146 )
1148 def result(self):
1149 precisions = tf.math.divide_no_nan(
1150 self.true_positives,
1151 tf.math.add(self.true_positives, self.false_positives),
1152 )
1153 recalls = tf.math.divide_no_nan(
1154 self.true_positives,
1155 tf.math.add(self.true_positives, self.false_negatives),
1156 )
1157 return self._find_max_under_constraint(
1158 precisions, recalls, tf.greater_equal
1159 )
1161 def get_config(self):
1162 config = {
1163 "num_thresholds": self.num_thresholds,
1164 "precision": self.precision,
1165 }
1166 base_config = super().get_config()
1167 return dict(list(base_config.items()) + list(config.items()))
1170@keras_export("keras.metrics.AUC")
1171class AUC(base_metric.Metric):
1172 """Approximates the AUC (Area under the curve) of the ROC or PR curves.
1174 The AUC (Area under the curve) of the ROC (Receiver operating
1175 characteristic; default) or PR (Precision Recall) curves are quality
1176 measures of binary classifiers. Unlike the accuracy, and like cross-entropy
1177 losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
1179 This class approximates AUCs using a Riemann sum. During the metric
1180 accumulation phase, predictions are accumulated within predefined buckets
1181 by value. The AUC is then computed by interpolating per-bucket averages.
1182 These buckets define the evaluated operational points.
1184 This metric creates four local variables, `true_positives`,
1185 `true_negatives`, `false_positives` and `false_negatives` that are used to
1186 compute the AUC. To discretize the AUC curve, a linearly spaced set of
1187 thresholds is used to compute pairs of recall and precision values. The area
1188 under the ROC-curve is therefore computed using the height of the recall
1189 values by the false positive rate, while the area under the PR-curve is
1190 computed using the height of the precision values by the recall.
1192 This value is ultimately returned as `auc`, an idempotent operation that
1193 computes the area under a discretized curve of precision versus recall
1194 values (computed using the aforementioned variables). The `num_thresholds`
1195 variable controls the degree of discretization with larger numbers of
1196 thresholds more closely approximating the true AUC. The quality of the
1197 approximation may vary dramatically depending on `num_thresholds`. The
1198 `thresholds` parameter can be used to manually specify thresholds which
1199 split the predictions more evenly.
1201 For a best approximation of the real AUC, `predictions` should be
1202 distributed approximately uniformly in the range [0, 1] (if
1203 `from_logits=False`). The quality of the AUC approximation may be poor if
1204 this is not the case. Setting `summation_method` to 'minoring' or 'majoring'
1205 can help quantify the error in the approximation by providing lower or upper
1206 bound estimates of the AUC.
1208 If `sample_weight` is `None`, weights default to 1.
1209 Use `sample_weight` of 0 to mask values.
1211 Args:
1212 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1213 use when discretizing the roc curve. Values must be > 1.
1214 curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
1215 [default] or 'PR' for the Precision-Recall-curve.
1216 summation_method: (Optional) Specifies the [Riemann summation method](
1217 https://en.wikipedia.org/wiki/Riemann_sum) used.
1218 'interpolation' (default) applies mid-point summation scheme for
1219 `ROC`. For PR-AUC, interpolates (true/false) positives but not the
1220 ratio that is precision (see Davis & Goadrich 2006 for details);
1221 'minoring' applies left summation for increasing intervals and right
1222 summation for decreasing intervals; 'majoring' does the opposite.
1223 name: (Optional) string name of the metric instance.
1224 dtype: (Optional) data type of the metric result.
1225 thresholds: (Optional) A list of floating point values to use as the
1226 thresholds for discretizing the curve. If set, the `num_thresholds`
1227 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
1228 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
1229 be automatically included with these to correctly handle predictions
1230 equal to exactly 0 or 1.
1231 multi_label: boolean indicating whether multilabel data should be
1232 treated as such, wherein AUC is computed separately for each label and
1233 then averaged across labels, or (when False) if the data should be
1234 flattened into a single label before AUC computation. In the latter
1235 case, when multilabel data is passed to AUC, each label-prediction pair
1236 is treated as an individual data point. Should be set to False for
1237 multi-class data.
1238 num_labels: (Optional) The number of labels, used when `multi_label` is
1239 True. If `num_labels` is not specified, then state variables get created
1240 on the first call to `update_state`.
1241 label_weights: (Optional) list, array, or tensor of non-negative weights
1242 used to compute AUCs for multilabel data. When `multi_label` is True,
1243 the weights are applied to the individual label AUCs when they are
1244 averaged to produce the multi-label AUC. When it's False, they are used
1245 to weight the individual label predictions in computing the confusion
1246 matrix on the flattened data. Note that this is unlike class_weights in
1247 that class_weights weights the example depending on the value of its
1248 label, whereas label_weights depends only on the index of that label
1249 before flattening; therefore `label_weights` should not be used for
1250 multi-class data.
1251 from_logits: boolean indicating whether the predictions (`y_pred` in
1252 `update_state`) are probabilities or sigmoid logits. As a rule of thumb,
1253 when using a keras loss, the `from_logits` constructor argument of the
1254 loss should match the AUC `from_logits` constructor argument.
1256 Standalone usage:
1258 >>> m = tf.keras.metrics.AUC(num_thresholds=3)
1259 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1260 >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
1261 >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
1262 >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
1263 >>> # auc = ((((1+0.5)/2)*(1-0)) + (((0.5+0)/2)*(0-0))) = 0.75
1264 >>> m.result().numpy()
1265 0.75
1267 >>> m.reset_state()
1268 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1269 ... sample_weight=[1, 0, 0, 1])
1270 >>> m.result().numpy()
1271 1.0
1273 Usage with `compile()` API:
1275 ```python
1276 # Reports the AUC of a model outputting a probability.
1277 model.compile(optimizer='sgd',
1278 loss=tf.keras.losses.BinaryCrossentropy(),
1279 metrics=[tf.keras.metrics.AUC()])
1281 # Reports the AUC of a model outputting a logit.
1282 model.compile(optimizer='sgd',
1283 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
1284 metrics=[tf.keras.metrics.AUC(from_logits=True)])
1285 ```
1286 """
1288 @dtensor_utils.inject_mesh
1289 def __init__(
1290 self,
1291 num_thresholds=200,
1292 curve="ROC",
1293 summation_method="interpolation",
1294 name=None,
1295 dtype=None,
1296 thresholds=None,
1297 multi_label=False,
1298 num_labels=None,
1299 label_weights=None,
1300 from_logits=False,
1301 ):
1302 # Validate configurations.
1303 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
1304 metrics_utils.AUCCurve
1305 ):
1306 raise ValueError(
1307 f'Invalid `curve` argument value "{curve}". '
1308 f"Expected one of: {list(metrics_utils.AUCCurve)}"
1309 )
1310 if isinstance(
1311 summation_method, metrics_utils.AUCSummationMethod
1312 ) and summation_method not in list(metrics_utils.AUCSummationMethod):
1313 raise ValueError(
1314 "Invalid `summation_method` "
1315 f'argument value "{summation_method}". '
1316 f"Expected one of: {list(metrics_utils.AUCSummationMethod)}"
1317 )
1319 # Update properties.
1320 self._init_from_thresholds = thresholds is not None
1321 if thresholds is not None:
1322 # If specified, use the supplied thresholds.
1323 self.num_thresholds = len(thresholds) + 2
1324 thresholds = sorted(thresholds)
1325 self._thresholds_distributed_evenly = (
1326 metrics_utils.is_evenly_distributed_thresholds(
1327 np.array([0.0] + thresholds + [1.0])
1328 )
1329 )
1330 else:
1331 if num_thresholds <= 1:
1332 raise ValueError(
1333 "Argument `num_thresholds` must be an integer > 1. "
1334 f"Received: num_thresholds={num_thresholds}"
1335 )
1337 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
1338 # (0, 1).
1339 self.num_thresholds = num_thresholds
1340 thresholds = [
1341 (i + 1) * 1.0 / (num_thresholds - 1)
1342 for i in range(num_thresholds - 2)
1343 ]
1344 self._thresholds_distributed_evenly = True
1346 # Add an endpoint "threshold" below zero and above one for either
1347 # threshold method to account for floating point imprecisions.
1348 self._thresholds = np.array(
1349 [0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()]
1350 )
1352 if isinstance(curve, metrics_utils.AUCCurve):
1353 self.curve = curve
1354 else:
1355 self.curve = metrics_utils.AUCCurve.from_str(curve)
1356 if isinstance(summation_method, metrics_utils.AUCSummationMethod):
1357 self.summation_method = summation_method
1358 else:
1359 self.summation_method = metrics_utils.AUCSummationMethod.from_str(
1360 summation_method
1361 )
1362 super().__init__(name=name, dtype=dtype)
1364 # Handle multilabel arguments.
1365 self.multi_label = multi_label
1366 self.num_labels = num_labels
1367 if label_weights is not None:
1368 label_weights = tf.constant(label_weights, dtype=self.dtype)
1369 tf.debugging.assert_non_negative(
1370 label_weights,
1371 message="All values of `label_weights` must be non-negative.",
1372 )
1373 self.label_weights = label_weights
1375 else:
1376 self.label_weights = None
1378 self._from_logits = from_logits
1380 self._built = False
1381 if self.multi_label:
1382 if num_labels:
1383 shape = tf.TensorShape([None, num_labels])
1384 self._build(shape)
1385 else:
1386 if num_labels:
1387 raise ValueError(
1388 "`num_labels` is needed only when `multi_label` is True."
1389 )
1390 self._build(None)
1392 @property
1393 def thresholds(self):
1394 """The thresholds used for evaluating AUC."""
1395 return list(self._thresholds)
1397 def _build(self, shape):
1398 """Initialize TP, FP, TN, and FN tensors, given the shape of the
1399 data."""
1400 if self.multi_label:
1401 if shape.ndims != 2:
1402 raise ValueError(
1403 "`y_true` must have rank 2 when `multi_label=True`. "
1404 f"Found rank {shape.ndims}. "
1405 f"Full shape received for `y_true`: {shape}"
1406 )
1407 self._num_labels = shape[1]
1408 variable_shape = tf.TensorShape(
1409 [self.num_thresholds, self._num_labels]
1410 )
1411 else:
1412 variable_shape = tf.TensorShape([self.num_thresholds])
1414 self._build_input_shape = shape
1415 # Create metric variables
1416 self.true_positives = self.add_weight(
1417 "true_positives", shape=variable_shape, initializer="zeros"
1418 )
1419 self.true_negatives = self.add_weight(
1420 "true_negatives", shape=variable_shape, initializer="zeros"
1421 )
1422 self.false_positives = self.add_weight(
1423 "false_positives", shape=variable_shape, initializer="zeros"
1424 )
1425 self.false_negatives = self.add_weight(
1426 "false_negatives", shape=variable_shape, initializer="zeros"
1427 )
1429 if self.multi_label:
1430 with tf.init_scope():
1431 # This should only be necessary for handling v1 behavior. In v2,
1432 # AUC should be initialized outside of any tf.functions, and
1433 # therefore in eager mode.
1434 if not tf.executing_eagerly():
1435 backend._initialize_variables(backend._get_session())
1437 self._built = True
1439 def update_state(self, y_true, y_pred, sample_weight=None):
1440 """Accumulates confusion matrix statistics.
1442 Args:
1443 y_true: The ground truth values.
1444 y_pred: The predicted values.
1445 sample_weight: Optional weighting of each example. Defaults to 1. Can
1446 be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
1447 and must be broadcastable to `y_true`.
1449 Returns:
1450 Update op.
1451 """
1452 if not self._built:
1453 self._build(tf.TensorShape(y_pred.shape))
1455 if self.multi_label or (self.label_weights is not None):
1456 # y_true should have shape (number of examples, number of labels).
1457 shapes = [(y_true, ("N", "L"))]
1458 if self.multi_label:
1459 # TP, TN, FP, and FN should all have shape
1460 # (number of thresholds, number of labels).
1461 shapes.extend(
1462 [
1463 (self.true_positives, ("T", "L")),
1464 (self.true_negatives, ("T", "L")),
1465 (self.false_positives, ("T", "L")),
1466 (self.false_negatives, ("T", "L")),
1467 ]
1468 )
1469 if self.label_weights is not None:
1470 # label_weights should be of length equal to the number of
1471 # labels.
1472 shapes.append((self.label_weights, ("L",)))
1473 tf.debugging.assert_shapes(
1474 shapes, message="Number of labels is not consistent."
1475 )
1477 # Only forward label_weights to update_confusion_matrix_variables when
1478 # multi_label is False. Otherwise the averaging of individual label AUCs
1479 # is handled in AUC.result
1480 label_weights = None if self.multi_label else self.label_weights
1482 if self._from_logits:
1483 y_pred = activations.sigmoid(y_pred)
1485 return metrics_utils.update_confusion_matrix_variables(
1486 {
1487 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501
1488 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501
1489 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501
1490 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501
1491 },
1492 y_true,
1493 y_pred,
1494 self._thresholds,
1495 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1496 sample_weight=sample_weight,
1497 multi_label=self.multi_label,
1498 label_weights=label_weights,
1499 )
1501 def interpolate_pr_auc(self):
1502 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
1504 https://www.biostat.wisc.edu/~page/rocpr.pdf
1506 Note here we derive & use a closed formula not present in the paper
1507 as follows:
1509 Precision = TP / (TP + FP) = TP / P
1511 Modeling all of TP (true positive), FP (false positive) and their sum
1512 P = TP + FP (predicted positive) as varying linearly within each
1513 interval [A, B] between successive thresholds, we get
1515 Precision slope = dTP / dP
1516 = (TP_B - TP_A) / (P_B - P_A)
1517 = (TP - TP_A) / (P - P_A)
1518 Precision = (TP_A + slope * (P - P_A)) / P
1520 The area within the interval is (slope / total_pos_weight) times
1522 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
1523 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
1525 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
1527 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
1529 Bringing back the factor (slope / total_pos_weight) we'd put aside, we
1530 get
1532 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
1534 where dTP == TP_B - TP_A.
1536 Note that when P_A == 0 the above calculation simplifies into
1538 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
1540 which is really equivalent to imputing constant precision throughout the
1541 first bucket having >0 true positives.
1543 Returns:
1544 pr_auc: an approximation of the area under the P-R curve.
1545 """
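# Editorial comments on the variables below (added for readability, not
# original code): `dtp` is the change in true positives between successive
# thresholds, `p` is the predicted-positive count TP + FP at each threshold,
# `dp` its change per interval, and `prec_slope` approximates the precision
# slope dTP/dP from the derivation above.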
1546 dtp = (
1547 self.true_positives[: self.num_thresholds - 1]
1548 - self.true_positives[1:]
1549 )
1550 p = tf.math.add(self.true_positives, self.false_positives)
1551 dp = p[: self.num_thresholds - 1] - p[1:]
1552 prec_slope = tf.math.divide_no_nan(
1553 dtp, tf.maximum(dp, 0), name="prec_slope"
1554 )
1555 intercept = self.true_positives[1:] - tf.multiply(prec_slope, p[1:])
1557 safe_p_ratio = tf.where(
1558 tf.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
1559 tf.math.divide_no_nan(
1560 p[: self.num_thresholds - 1],
1561 tf.maximum(p[1:], 0),
1562 name="recall_relative_ratio",
1563 ),
1564 tf.ones_like(p[1:]),
1565 )
1567 pr_auc_increment = tf.math.divide_no_nan(
1568 prec_slope * (dtp + intercept * tf.math.log(safe_p_ratio)),
1569 tf.maximum(self.true_positives[1:] + self.false_negatives[1:], 0),
1570 name="pr_auc_increment",
1571 )
1573 if self.multi_label:
1574 by_label_auc = tf.reduce_sum(
1575 pr_auc_increment, name=self.name + "_by_label", axis=0
1576 )
1577 if self.label_weights is None:
1578 # Evenly weighted average of the label AUCs.
1579 return tf.reduce_mean(by_label_auc, name=self.name)
1580 else:
1581 # Weighted average of the label AUCs.
1582 return tf.math.divide_no_nan(
1583 tf.reduce_sum(
1584 tf.multiply(by_label_auc, self.label_weights)
1585 ),
1586 tf.reduce_sum(self.label_weights),
1587 name=self.name,
1588 )
1589 else:
1590 return tf.reduce_sum(pr_auc_increment, name="interpolate_pr_auc")
1592 def result(self):
1593 if (
1594 self.curve == metrics_utils.AUCCurve.PR
1595 and self.summation_method
1596 == metrics_utils.AUCSummationMethod.INTERPOLATION
1597 ):
1598 # This use case is different and is handled separately.
1599 return self.interpolate_pr_auc()
1601 # Set `x` and `y` values for the curves based on `curve` config.
1602 recall = tf.math.divide_no_nan(
1603 self.true_positives,
1604 tf.math.add(self.true_positives, self.false_negatives),
1605 )
1606 if self.curve == metrics_utils.AUCCurve.ROC:
1607 fp_rate = tf.math.divide_no_nan(
1608 self.false_positives,
1609 tf.math.add(self.false_positives, self.true_negatives),
1610 )
1611 x = fp_rate
1612 y = recall
1613 else: # curve == 'PR'.
1614 precision = tf.math.divide_no_nan(
1615 self.true_positives,
1616 tf.math.add(self.true_positives, self.false_positives),
1617 )
1618 x = recall
1619 y = precision
1621 # Find the rectangle heights based on `summation_method`.
1622 if (
1623 self.summation_method
1624 == metrics_utils.AUCSummationMethod.INTERPOLATION
1625 ):
1626 # Note: the case ('PR', 'interpolation') has been handled above.
1627 heights = (y[: self.num_thresholds - 1] + y[1:]) / 2.0
1628 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
1629 heights = tf.minimum(y[: self.num_thresholds - 1], y[1:])
1630 # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
1631 else:
1632 heights = tf.maximum(y[: self.num_thresholds - 1], y[1:])
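# Editorial worked example (matches the docstring's num_thresholds=3 case):
# with recall y = [1, 0.5, 0] and fp_rate x = [1, 0, 0], interpolation gives
# heights [0.75, 0.25] and widths x[:-1] - x[1:] = [1, 0], so the Riemann sum
# below evaluates to 0.75.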
1634 # Sum up the areas of all the rectangles.
1635 if self.multi_label:
1636 riemann_terms = tf.multiply(
1637 x[: self.num_thresholds - 1] - x[1:], heights
1638 )
1639 by_label_auc = tf.reduce_sum(
1640 riemann_terms, name=self.name + "_by_label", axis=0
1641 )
1643 if self.label_weights is None:
1644 # Unweighted average of the label AUCs.
1645 return tf.reduce_mean(by_label_auc, name=self.name)
1646 else:
1647 # Weighted average of the label AUCs.
1648 return tf.math.divide_no_nan(
1649 tf.reduce_sum(
1650 tf.multiply(by_label_auc, self.label_weights)
1651 ),
1652 tf.reduce_sum(self.label_weights),
1653 name=self.name,
1654 )
1655 else:
1656 return tf.reduce_sum(
1657 tf.multiply(x[: self.num_thresholds - 1] - x[1:], heights),
1658 name=self.name,
1659 )
1661 def reset_state(self):
1662 if self._built:
1663 confusion_matrix_variables = (
1664 self.true_positives,
1665 self.true_negatives,
1666 self.false_positives,
1667 self.false_negatives,
1668 )
1669 if self.multi_label:
1670 backend.batch_set_value(
1671 [
1672 (v, np.zeros((self.num_thresholds, self._num_labels)))
1673 for v in confusion_matrix_variables
1674 ]
1675 )
1676 else:
1677 backend.batch_set_value(
1678 [
1679 (v, np.zeros((self.num_thresholds,)))
1680 for v in confusion_matrix_variables
1681 ]
1682 )
1684 def get_config(self):
1685 if is_tensor_or_variable(self.label_weights):
1686 label_weights = backend.eval(self.label_weights)
1687 else:
1688 label_weights = self.label_weights
1689 config = {
1690 "num_thresholds": self.num_thresholds,
1691 "curve": self.curve.value,
1692 "summation_method": self.summation_method.value,
1693 "multi_label": self.multi_label,
1694 "num_labels": self.num_labels,
1695 "label_weights": label_weights,
1696 "from_logits": self._from_logits,
1697 }
1698 # optimization to avoid serializing a large number of generated
1699 # thresholds
1700 if self._init_from_thresholds:
1701 # We remove the endpoint thresholds as an inverse of how the
1702 # thresholds were initialized. This ensures that a metric
1703 # initialized from this config has the same thresholds.
1704 config["thresholds"] = self.thresholds[1:-1]
1705 base_config = super().get_config()
1706 return dict(list(base_config.items()) + list(config.items()))
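# Illustrative sketch (editor's addition, not part of the original file): in
# multi-label mode the confusion-matrix variables get shape
# (num_thresholds, num_labels) and per-label AUCs are averaged, e.g.
#
#   m = tf.keras.metrics.AUC(multi_label=True, num_labels=2)
#   m.update_state([[0, 1], [1, 0]], [[0.3, 0.8], [0.6, 0.2]])
#   m.result()  # mean of the two per-label AUCs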