Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/iou_metrics.py: 41%
102 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""IoU metrics."""
17from typing import List
18from typing import Optional
19from typing import Tuple
20from typing import Union
22import numpy as np
23import tensorflow.compat.v2 as tf
25from keras.src import backend
26from keras.src.dtensor import utils as dtensor_utils
27from keras.src.metrics import base_metric
29# isort: off
30from tensorflow.python.util.tf_export import keras_export
class _IoUBase(base_metric.Metric):
    """Computes the confusion matrix for Intersection-Over-Union metrics.

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    From IoUs of individual classes, the MeanIoU can be computed as the mean of
    the individual IoUs.

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This base class only accumulates the confusion matrix; subclasses are
    responsible for turning it into a metric value in `result()`.

    Args:
        num_classes: The possible number of labels the prediction task can have.
            This value must be provided, since a confusion matrix of size
            `(num_classes, num_classes)` will be allocated.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_true: Whether labels are encoded using integers or
            dense floating point vectors. If `False`, the `tf.argmax` function
            will be used to determine each sample's most likely associated
            label.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `tf.argmax` function
            will be used to determine each sample's most likely associated
            label.
        axis: (Optional) Defaults to -1. The dimension containing the logits.
    """

    def __init__(
        self,
        num_classes: int,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_true: bool = True,
        sparse_y_pred: bool = True,
        axis: int = -1,
    ):
        super().__init__(name=name, dtype=dtype)
        self.num_classes = num_classes
        self.ignore_class = ignore_class
        self.sparse_y_true = sparse_y_true
        self.sparse_y_pred = sparse_y_pred
        self.axis = axis

        # Variable to accumulate the predictions in the confusion matrix.
        self.total_cm = self.add_weight(
            "total_confusion_matrix",
            shape=(num_classes, num_classes),
            initializer="zeros",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Dense (non-sparse) labels/predictions are first reduced with
        `tf.argmax` along `self.axis`. Both are then cast to the metric dtype
        and flattened to rank 1 before being added to the running confusion
        matrix.

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to 1.
                Either a scalar, which is applied to every element, or a
                `Tensor` holding one weight per element of `y_true` (after any
                `argmax`). Non-scalar weights are flattened, so they must
                line up element-wise with the flattened labels.

        Returns:
            Update op.
        """
        if not self.sparse_y_true:
            y_true = tf.argmax(y_true, axis=self.axis)
        if not self.sparse_y_pred:
            y_pred = tf.argmax(y_pred, axis=self.axis)

        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)

        # Flatten unconditionally: this is a no-op for rank-1 inputs and,
        # unlike a `shape.ndims > 1` check, it also works when the static rank
        # is unknown (`ndims` is None) and when the inputs are scalars.
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)
            if sample_weight.shape.ndims == 0:
                # `tf.math.confusion_matrix` requires `weights` to match the
                # predictions element-wise (and boolean masking below needs a
                # non-scalar), so expand a scalar weight to every element.
                sample_weight = tf.broadcast_to(
                    sample_weight, tf.shape(y_pred)
                )
            else:
                sample_weight = tf.reshape(sample_weight, [-1])

        if self.ignore_class is not None:
            # Drop every position whose label is the ignored class, keeping
            # labels, predictions and weights aligned.
            ignore_class = tf.cast(self.ignore_class, y_true.dtype)
            valid_mask = tf.not_equal(y_true, ignore_class)
            y_true = y_true[valid_mask]
            y_pred = y_pred[valid_mask]
            if sample_weight is not None:
                sample_weight = sample_weight[valid_mask]

        # Accumulate the prediction to current confusion matrix.
        current_cm = tf.math.confusion_matrix(
            y_true,
            y_pred,
            self.num_classes,
            weights=sample_weight,
            dtype=self._dtype,
        )
        return self.total_cm.assign_add(current_cm)

    def reset_state(self):
        """Resets the accumulated confusion matrix to all zeros."""
        backend.set_value(
            self.total_cm, np.zeros((self.num_classes, self.num_classes))
        )
@keras_export("keras.metrics.IoU")
class IoU(_IoUBase):
    """Computes the Intersection-Over-Union metric for specific target classes.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Note, this class first computes IoUs for all individual classes, then
    returns the mean of IoUs for the classes that are specified by
    `target_class_ids`. If `target_class_ids` has only one id value, the IoU of
    that specific class is returned.

    Args:
        num_classes: The possible number of labels the prediction task can have.
            A confusion matrix of dimension = [num_classes, num_classes] will be
            allocated to accumulate predictions from which the metric is
            calculated.
        target_class_ids: A non-empty tuple or list of target class ids for
            which the metric is returned. Every id must lie in
            `[0, num_classes)`. To compute IoU for a specific class, a list (or
            tuple) of a single id value should be provided.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_true: Whether labels are encoded using integers or
            dense floating point vectors. If `False`, the `tf.argmax` function
            will be used to determine each sample's most likely associated
            label.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `tf.argmax` function
            will be used to determine each sample's most likely associated
            label.
        axis: (Optional) Defaults to -1. The dimension containing the logits.

    Raises:
        ValueError: If `target_class_ids` is empty or contains an id outside
            `[0, num_classes)`.

    Standalone usage:

    >>> # cm = [[1, 1],
    >>> #        [1, 1]]
    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
    >>> # iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # iou = [0.33, 0.33]
    >>> m = tf.keras.metrics.IoU(num_classes=2, target_class_ids=[0])
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
    >>> m.result().numpy()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
    ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
    >>> # cm = [[0.3, 0.3],
    >>> #        [0.3, 0.1]]
    >>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4],
    >>> # true_positives = [0.3, 0.1]
    >>> # iou = [0.33, 0.14]
    >>> m.result().numpy()
    0.33333334

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.IoU(num_classes=2, target_class_ids=[0])])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        target_class_ids: Union[List[int], Tuple[int, ...]],
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_true: bool = True,
        sparse_y_pred: bool = True,
        axis: int = -1,
    ):
        super().__init__(
            name=name,
            num_classes=num_classes,
            ignore_class=ignore_class,
            sparse_y_true=sparse_y_true,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
            dtype=dtype,
        )
        # Materialize once so that a one-shot iterable is not exhausted by
        # the validation below before it is stored.
        target_class_ids = list(target_class_ids)
        if not target_class_ids:
            raise ValueError(
                "`target_class_ids` must contain at least one class id. "
                f"Received: target_class_ids={target_class_ids}"
            )
        # Ids are later used with `tf.gather` in `result()`, so reject
        # out-of-range values (including negatives) up front.
        largest_id = max(target_class_ids)
        if largest_id >= num_classes:
            raise ValueError(
                f"Target class id {largest_id} "
                "is out of range, which is "
                f"[0, {num_classes})."
            )
        smallest_id = min(target_class_ids)
        if smallest_id < 0:
            raise ValueError(
                f"Target class id {smallest_id} "
                "is out of range, which is "
                f"[0, {num_classes})."
            )
        self.target_class_ids = target_class_ids

    def result(self):
        """Compute the intersection-over-union via the confusion matrix.

        Returns:
            A scalar tensor: the mean IoU over the classes in
            `target_class_ids` whose denominator is non-zero. Target classes
            whose denominator is zero are left out of the mean. If no target
            class qualifies, the result is 0.
        """
        # `tf.math.confusion_matrix` places true labels on rows and
        # predictions on columns. So the axis=0 sums are per-class predicted
        # totals and the axis=1 sums are per-class true totals. The variable
        # names follow the "sum_row"/"sum_col" wording of the docstring
        # examples.
        sum_over_row = tf.cast(
            tf.reduce_sum(self.total_cm, axis=0), dtype=self._dtype
        )
        sum_over_col = tf.cast(
            tf.reduce_sum(self.total_cm, axis=1), dtype=self._dtype
        )
        true_positives = tf.cast(
            tf.linalg.tensor_diag_part(self.total_cm), dtype=self._dtype
        )

        # sum_over_row + sum_over_col =
        #     2 * true_positives + false_positives + false_negatives.
        denominator = sum_over_row + sum_over_col - true_positives

        # Only keep the target classes
        true_positives = tf.gather(true_positives, self.target_class_ids)
        denominator = tf.gather(denominator, self.target_class_ids)

        # If the denominator is 0, we need to ignore the class.
        num_valid_entries = tf.reduce_sum(
            tf.cast(tf.not_equal(denominator, 0), dtype=self._dtype)
        )

        iou = tf.math.divide_no_nan(true_positives, denominator)

        return tf.math.divide_no_nan(
            tf.reduce_sum(iou, name="mean_iou"), num_valid_entries
        )

    def get_config(self):
        """Returns the constructor arguments merged with the base config."""
        config = {
            "num_classes": self.num_classes,
            "target_class_ids": self.target_class_ids,
            "ignore_class": self.ignore_class,
            "sparse_y_true": self.sparse_y_true,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
@keras_export("keras.metrics.BinaryIoU")
class BinaryIoU(IoU):
    """Computes the Intersection-Over-Union metric for class 0 and/or 1.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute IoUs for a binary classification task
    where the predictions are provided as continuous scores (for example
    probabilities or logits). First a `threshold` is applied to the predicted
    values: values strictly below `threshold` are converted to class 0, and
    values greater than or equal to `threshold` are converted to class 1.

    IoUs for classes 0 and 1 are then computed, and the mean of IoUs for the
    classes that are specified by `target_class_ids` is returned.

    Note: if `y_pred` already holds hard 0/1 class predictions, any `threshold`
    in `(0, 1]` (including the default `0.5`) makes this metric behave like
    `IoU(num_classes=2, target_class_ids=...)`. With `threshold=0`, a
    prediction of exactly 0 would be mapped to class 1.

    Args:
        target_class_ids: A tuple or list of target class ids for which the
            metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With
            `[0]` (or `[1]`), the IoU metric for class 0 (or class 1,
            respectively) is returned. With `[0, 1]`, the mean of IoUs for the
            two classes is returned.
        threshold: A threshold applied to the predictions. A prediction is
            mapped to class 0 if it is below `threshold`, and to class 1 if it
            is greater than or equal to `threshold`. For example, use `0.5`
            for probabilities or `0.0` for logits.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7])
    >>> m.result().numpy()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7],
    ...                sample_weight=[0.2, 0.3, 0.4, 0.1])
    >>> # cm = [[0.2, 0.4],
    >>> #        [0.3, 0.1]]
    >>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5],
    >>> # true_positives = [0.2, 0.1]
    >>> # iou = [0.222, 0.125]
    >>> m.result().numpy()
    0.17361112

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.BinaryIoU(target_class_ids=[0], threshold=0.5)])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        target_class_ids: Union[List[int], Tuple[int, ...]] = (0, 1),
        threshold: float = 0.5,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
    ):
        # The binary case always uses a 2x2 confusion matrix.
        super().__init__(
            num_classes=2,
            target_class_ids=target_class_ids,
            name=name,
            dtype=dtype,
        )
        self.threshold = threshold

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Before the confusion matrix is updated, the predicted values are
        thresholded to be:
            0 for values that are smaller than the `threshold`
            1 for values that are larger or equal to the `threshold`

        Args:
            y_true: The ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting of each example. Defaults to 1.
                Can be a `Tensor` whose rank is either 0, or the same rank as
                `y_true`, and must be broadcastable to `y_true`.

        Returns:
            Update op.
        """
        y_pred = tf.cast(y_pred, self._dtype)
        y_pred = tf.cast(y_pred >= self.threshold, self._dtype)
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        """Returns exactly the arguments accepted by `__init__`."""
        return {
            "target_class_ids": self.target_class_ids,
            "threshold": self.threshold,
            "name": self.name,
            "dtype": self._dtype,
        }
@keras_export("keras.metrics.MeanIoU")
class MeanIoU(IoU):
    """Computes the mean Intersection-Over-Union metric over all classes.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This is `IoU` with every class id in `range(num_classes)` selected as a
    target. The per-class IoUs are therefore all computed and their mean is
    returned.

    Args:
        num_classes: The possible number of labels the prediction task can have.
            This value must be provided, since a confusion matrix of dimension =
            [num_classes, num_classes] will be allocated.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_true: Whether labels are encoded using integers or
            dense floating point vectors. If `False`, the `tf.argmax` function
            will be used to determine each sample's most likely associated
            label.
        sparse_y_pred: Whether predictions are encoded using integers or
            dense floating point vectors. If `False`, the `tf.argmax` function
            will be used to determine each sample's most likely associated
            label.
        axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> # cm = [[1, 1],
    >>> #        [1, 1]]
    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
    >>> # iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
    >>> m = tf.keras.metrics.MeanIoU(num_classes=2)
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
    >>> m.result().numpy()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
    ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
    >>> m.result().numpy()
    0.23809525

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_true: bool = True,
        sparse_y_pred: bool = True,
        axis: int = -1,
    ):
        # Selecting every class id turns the parent's "mean over target
        # classes" into a mean over all classes.
        super().__init__(
            num_classes=num_classes,
            target_class_ids=list(range(num_classes)),
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=sparse_y_true,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
        )

    def get_config(self):
        """Returns the constructor arguments (target ids are implied)."""
        config = dict(
            num_classes=self.num_classes,
            name=self.name,
            dtype=self._dtype,
            ignore_class=self.ignore_class,
            sparse_y_true=self.sparse_y_true,
            sparse_y_pred=self.sparse_y_pred,
            axis=self.axis,
        )
        return config
@keras_export("keras.metrics.OneHotIoU")
class OneHotIoU(IoU):
    """Computes the Intersection-Over-Union metric for one-hot encoded labels.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute IoU for multi-class classification tasks
    where the labels are one-hot encoded (the last axis should have one
    dimension per class). Note that the predictions should also have the same
    shape. To compute the IoU, first the labels and predictions are converted
    back into integer format by taking the argmax over the class axis. Then the
    same computation steps as for the base `IoU` class apply.

    Note, if there is only one channel in the labels and predictions, this class
    is the same as class `IoU`. In this case, use `IoU` instead.

    Also, make sure that `num_classes` is equal to the number of classes in the
    data, to avoid a "labels out of bound" error when the confusion matrix is
    computed.

    Args:
        num_classes: The possible number of labels the prediction task can have.
            A confusion matrix of shape `(num_classes, num_classes)` will be
            allocated to accumulate predictions from which the metric is
            calculated.
        target_class_ids: A tuple or list of target class ids for which the
            metric is returned. To compute IoU for a specific class, a list (or
            tuple) of a single id value should be provided.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_pred: Whether predictions are encoded using natural numbers or
            probability distribution vectors. If `False`, the `tf.argmax`
            function will be used to determine each sample's most likely
            associated label.
        axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> y_true = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> y_pred = tf.constant([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
    ...                       [0.1, 0.4, 0.5]])
    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]
    >>> m = tf.keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
    >>> m.update_state(
    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    >>> # cm = [[0, 0, 0.2+0.4],
    >>> #       [0.3, 0, 0],
    >>> #       [0, 0, 0.1]]
    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
    >>> # true_positives = [0, 0, 0.1]
    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
    >>> m.result().numpy()
    0.071

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[1])])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        target_class_ids: Union[List[int], Tuple[int, ...]],
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_pred: bool = False,
        axis: int = -1,
    ):
        # Labels are always one-hot here, so `sparse_y_true` is fixed to False
        # (labels are reduced with argmax) and is not a constructor argument.
        super().__init__(
            num_classes=num_classes,
            target_class_ids=target_class_ids,
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=False,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
        )

    def get_config(self):
        """Returns exactly the arguments accepted by `__init__`."""
        return {
            "num_classes": self.num_classes,
            "target_class_ids": self.target_class_ids,
            "name": self.name,
            "dtype": self._dtype,
            "ignore_class": self.ignore_class,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
@keras_export("keras.metrics.OneHotMeanIoU")
class OneHotMeanIoU(MeanIoU):
    """Computes mean Intersection-Over-Union metric for one-hot encoded labels.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute the mean IoU for multi-class
    classification tasks where the labels are one-hot encoded (the last axis
    should have one dimension per class). Note that the predictions should also
    have the same shape. To compute the mean IoU, first the labels and
    predictions are converted back into integer format by taking the argmax over
    the class axis. Then the same computation steps as for the base `MeanIoU`
    class apply.

    Note, if there is only one channel in the labels and predictions, this class
    is the same as class `MeanIoU`. In this case, use `MeanIoU` instead.

    Also, make sure that `num_classes` is equal to the number of classes in the
    data, to avoid a "labels out of bound" error when the confusion matrix is
    computed.

    Args:
        num_classes: The possible number of labels the prediction task can have.
            A confusion matrix of shape `(num_classes, num_classes)` will be
            allocated to accumulate predictions from which the metric is
            calculated.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        ignore_class: Optional integer. The ID of a class to be ignored during
            metric computation. This is useful, for example, in segmentation
            problems featuring a "void" class (commonly -1 or 255) in
            segmentation maps. By default (`ignore_class=None`), all classes are
            considered.
        sparse_y_pred: Whether predictions are encoded using natural numbers or
            probability distribution vectors. If `False`, the `tf.argmax`
            function will be used to determine each sample's most likely
            associated label.
        axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> y_true = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> y_pred = tf.constant([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
    ...                       [0.1, 0.4, 0.5]])
    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]
    >>> m = tf.keras.metrics.OneHotMeanIoU(num_classes=3)
    >>> m.update_state(
    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    >>> # cm = [[0, 0, 0.2+0.4],
    >>> #       [0.3, 0, 0],
    >>> #       [0, 0, 0.1]]
    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
    >>> # true_positives = [0, 0, 0.1]
    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3
    >>> m.result().numpy()
    0.048

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.OneHotMeanIoU(num_classes=3)])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_pred: bool = False,
        axis: int = -1,
    ):
        # Labels are always one-hot here, so `sparse_y_true` is fixed to False
        # (labels are reduced with argmax) and is not a constructor argument.
        super().__init__(
            num_classes=num_classes,
            axis=axis,
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=False,
            sparse_y_pred=sparse_y_pred,
        )

    def get_config(self):
        """Returns exactly the arguments accepted by `__init__`."""
        return {
            "num_classes": self.num_classes,
            "name": self.name,
            "dtype": self._dtype,
            "ignore_class": self.ignore_class,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }