Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/iou_metrics.py: 41%

102 statements  


# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""IoU metrics."""

from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.dtensor import utils as dtensor_utils
from keras.src.metrics import base_metric

# isort: off
from tensorflow.python.util.tf_export import keras_export


class _IoUBase(base_metric.Metric):
    """Computes the confusion matrix for Intersection-Over-Union metrics.

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    From IoUs of individual classes, the MeanIoU can be computed as the mean of
    the individual IoUs.

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
      num_classes: The possible number of labels the prediction task can have.
        This value must be provided, since a confusion matrix of size
        `(num_classes, num_classes)` will be allocated.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      ignore_class: Optional integer. The ID of a class to be ignored during
        metric computation. This is useful, for example, in segmentation
        problems featuring a "void" class (commonly -1 or 255) in segmentation
        maps. By default (`ignore_class=None`), all classes are considered.
      sparse_y_true: Whether labels are encoded using integers or
        dense floating point vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      sparse_y_pred: Whether predictions are encoded using integers or
        dense floating point vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      axis: (Optional) Defaults to -1. The dimension containing the logits.
    """

    def __init__(
        self,
        num_classes: int,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_true: bool = True,
        sparse_y_pred: bool = True,
        axis: int = -1,
    ):
        super().__init__(name=name, dtype=dtype)
        self.num_classes = num_classes
        self.ignore_class = ignore_class
        self.sparse_y_true = sparse_y_true
        self.sparse_y_pred = sparse_y_pred
        self.axis = axis

        # Variable to accumulate the predictions in the confusion matrix.
        self.total_cm = self.add_weight(
            "total_confusion_matrix",
            shape=(num_classes, num_classes),
            initializer="zeros",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Args:
          y_true: The ground truth values.
          y_pred: The predicted values.
          sample_weight: Optional weighting of each example. Defaults to 1. Can
            be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
            and must be broadcastable to `y_true`.

        Returns:
          Update op.
        """

        if not self.sparse_y_true:
            y_true = tf.argmax(y_true, axis=self.axis)
        if not self.sparse_y_pred:
            y_pred = tf.argmax(y_pred, axis=self.axis)

        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)

        # Flatten the input if its rank > 1.
        if y_pred.shape.ndims > 1:
            y_pred = tf.reshape(y_pred, [-1])

        if y_true.shape.ndims > 1:
            y_true = tf.reshape(y_true, [-1])

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)
            if sample_weight.shape.ndims > 1:
                sample_weight = tf.reshape(sample_weight, [-1])

        if self.ignore_class is not None:
            ignore_class = tf.cast(self.ignore_class, y_true.dtype)
            valid_mask = tf.not_equal(y_true, ignore_class)
            y_true = y_true[valid_mask]
            y_pred = y_pred[valid_mask]
            if sample_weight is not None:
                sample_weight = sample_weight[valid_mask]

        # Accumulate the prediction to current confusion matrix.
        current_cm = tf.math.confusion_matrix(
            y_true,
            y_pred,
            self.num_classes,
            weights=sample_weight,
            dtype=self._dtype,
        )
        return self.total_cm.assign_add(current_cm)

    def reset_state(self):
        backend.set_value(
            self.total_cm, np.zeros((self.num_classes, self.num_classes))
        )
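
# ---------------------------------------------------------------------------
# Editor's note: the helper below is an illustrative sketch, not part of the
# original module. It mirrors, in plain NumPy, the bookkeeping that
# `_IoUBase.update_state` performs with TensorFlow ops: flatten labels and
# predictions, optionally drop an `ignore_class`, and accumulate a
# `(num_classes, num_classes)` confusion matrix weighted by `sample_weight`.
# The name `_np_weighted_confusion_matrix` is an editor-chosen placeholder,
# not an API exported by this file.
def _np_weighted_confusion_matrix(
    y_true, y_pred, num_classes, sample_weight=None, ignore_class=None
):
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)
    weights = (
        np.ones(y_true.shape, dtype=float)
        if sample_weight is None
        else np.asarray(sample_weight, dtype=float).reshape(-1)
    )
    if ignore_class is not None:
        valid = y_true != ignore_class
        y_true, y_pred, weights = y_true[valid], y_pred[valid], weights[valid]
    cm = np.zeros((num_classes, num_classes))
    # Row index = true class, column index = predicted class, matching
    # `tf.math.confusion_matrix`.
    np.add.at(cm, (y_true, y_pred), weights)
    return cm


# Example (matches the unweighted `IoU` docstring example below):
#   _np_weighted_confusion_matrix([0, 0, 1, 1], [0, 1, 0, 1], num_classes=2)
#   -> [[1., 1.], [1., 1.]]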


@keras_export("keras.metrics.IoU")
class IoU(_IoUBase):
    """Computes the Intersection-Over-Union metric for specific target classes.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Note, this class first computes IoUs for all individual classes, then
    returns the mean of IoUs for the classes that are specified by
    `target_class_ids`. If `target_class_ids` has only one id value, the IoU of
    that specific class is returned.

    Args:
      num_classes: The possible number of labels the prediction task can have.
        A confusion matrix of dimension = [num_classes, num_classes] will be
        allocated to accumulate predictions from which the metric is
        calculated.
      target_class_ids: A tuple or list of target class ids for which the
        metric is returned. To compute IoU for a specific class, a list (or
        tuple) of a single id value should be provided.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      ignore_class: Optional integer. The ID of a class to be ignored during
        metric computation. This is useful, for example, in segmentation
        problems featuring a "void" class (commonly -1 or 255) in segmentation
        maps. By default (`ignore_class=None`), all classes are considered.
      sparse_y_true: Whether labels are encoded using integers or
        dense floating point vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      sparse_y_pred: Whether predictions are encoded using integers or
        dense floating point vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> # cm = [[1, 1],
    >>> #       [1, 1]]
    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
    >>> # iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # iou = [0.33, 0.33]
    >>> m = tf.keras.metrics.IoU(num_classes=2, target_class_ids=[0])
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
    >>> m.result().numpy()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
    ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
    >>> # cm = [[0.3, 0.3],
    >>> #       [0.3, 0.1]]
    >>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4],
    >>> # true_positives = [0.3, 0.1]
    >>> # iou = [0.33, 0.14]
    >>> m.result().numpy()
    0.33333334

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.IoU(num_classes=2, target_class_ids=[0])])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        target_class_ids: Union[List[int], Tuple[int, ...]],
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_true: bool = True,
        sparse_y_pred: bool = True,
        axis: int = -1,
    ):
        super().__init__(
            name=name,
            num_classes=num_classes,
            ignore_class=ignore_class,
            sparse_y_true=sparse_y_true,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
            dtype=dtype,
        )
        if max(target_class_ids) >= num_classes:
            raise ValueError(
                f"Target class id {max(target_class_ids)} "
                "is out of range, which is "
                f"[{0}, {num_classes})."
            )
        self.target_class_ids = list(target_class_ids)

    def result(self):
        """Compute the intersection-over-union via the confusion matrix."""
        sum_over_row = tf.cast(
            tf.reduce_sum(self.total_cm, axis=0), dtype=self._dtype
        )
        sum_over_col = tf.cast(
            tf.reduce_sum(self.total_cm, axis=1), dtype=self._dtype
        )
        true_positives = tf.cast(
            tf.linalg.tensor_diag_part(self.total_cm), dtype=self._dtype
        )

        # sum_over_row + sum_over_col =
        #     2 * true_positives + false_positives + false_negatives.
        denominator = sum_over_row + sum_over_col - true_positives

        # Only keep the target classes.
        true_positives = tf.gather(true_positives, self.target_class_ids)
        denominator = tf.gather(denominator, self.target_class_ids)

        # If the denominator is 0, we need to ignore the class.
        num_valid_entries = tf.reduce_sum(
            tf.cast(tf.not_equal(denominator, 0), dtype=self._dtype)
        )

        iou = tf.math.divide_no_nan(true_positives, denominator)

        return tf.math.divide_no_nan(
            tf.reduce_sum(iou, name="mean_iou"), num_valid_entries
        )

    def get_config(self):
        config = {
            "num_classes": self.num_classes,
            "target_class_ids": self.target_class_ids,
            "ignore_class": self.ignore_class,
            "sparse_y_true": self.sparse_y_true,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
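
# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original module. It
# re-derives `IoU.result()` with NumPy so the confusion-matrix algebra above
# can be checked by hand; the name `_np_iou_from_cm` is an editor-chosen
# placeholder.
def _np_iou_from_cm(cm, target_class_ids):
    cm = np.asarray(cm, dtype=float)
    sum_over_row = cm.sum(axis=0)
    sum_over_col = cm.sum(axis=1)
    true_positives = np.diag(cm)
    # Per class: TP + FP + FN, i.e. the union of predictions and labels.
    denominator = sum_over_row + sum_over_col - true_positives
    true_positives = true_positives[target_class_ids]
    denominator = denominator[target_class_ids]
    iou = np.divide(
        true_positives,
        denominator,
        out=np.zeros_like(true_positives),
        where=denominator != 0,
    )
    # Classes that never occur (denominator == 0) are excluded from the mean,
    # exactly as `num_valid_entries` does above.
    num_valid = np.count_nonzero(denominator)
    return iou.sum() / num_valid if num_valid else 0.0


# For the weighted docstring example above, cm = [[0.3, 0.3], [0.3, 0.1]] and
# `_np_iou_from_cm(cm, [0])` gives 0.3 / (0.6 + 0.6 - 0.3) = 0.3333...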


@keras_export("keras.metrics.BinaryIoU")
class BinaryIoU(IoU):
    """Computes the Intersection-Over-Union metric for class 0 and/or 1.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute IoUs for a binary classification task
    where the predictions are provided as logits. First a `threshold` is
    applied to the predicted values such that those that are below the
    `threshold` are converted to class 0 and those that are above the
    `threshold` are converted to class 1.

    IoUs for classes 0 and 1 are then computed, and the mean of IoUs for the
    classes that are specified by `target_class_ids` is returned.

    Note: with `threshold=0`, this metric has the same behavior as `IoU`.

    Args:
      target_class_ids: A tuple or list of target class ids for which the
        metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With `[0]`
        (or `[1]`), the IoU metric for class 0 (or class 1, respectively) is
        returned. With `[0, 1]`, the mean of IoUs for the two classes is
        returned.
      threshold: A threshold that applies to the prediction logits to convert
        them to either predicted class 0 if the logit is below `threshold` or
        predicted class 1 if the logit is above `threshold`.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7])
    >>> m.result().numpy()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7],
    ...                sample_weight=[0.2, 0.3, 0.4, 0.1])
    >>> # cm = [[0.2, 0.4],
    >>> #       [0.3, 0.1]]
    >>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5],
    >>> # true_positives = [0.2, 0.1]
    >>> # iou = [0.222, 0.125]
    >>> m.result().numpy()
    0.17361112

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.BinaryIoU(target_class_ids=[0], threshold=0.5)])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        target_class_ids: Union[List[int], Tuple[int, ...]] = (0, 1),
        threshold=0.5,
        name=None,
        dtype=None,
    ):

        super().__init__(
            num_classes=2,
            target_class_ids=target_class_ids,
            name=name,
            dtype=dtype,
        )
        self.threshold = threshold

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix statistics.

        Before the confusion matrix is updated, the predicted values are
        thresholded to be:
          0 for values that are smaller than the `threshold`
          1 for values that are larger or equal to the `threshold`

        Args:
          y_true: The ground truth values.
          y_pred: The predicted values.
          sample_weight: Optional weighting of each example. Defaults to 1. Can
            be a `Tensor` whose rank is either 0, or the same rank as `y_true`,
            and must be broadcastable to `y_true`.

        Returns:
          Update op.
        """
        y_pred = tf.cast(y_pred, self._dtype)
        y_pred = tf.cast(y_pred >= self.threshold, self._dtype)
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        return {
            "target_class_ids": self.target_class_ids,
            "threshold": self.threshold,
            "name": self.name,
            "dtype": self._dtype,
        }
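
# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original module. It
# spells out the arithmetic of the weighted `BinaryIoU` docstring example
# using the editor-added helpers defined earlier in this listing
# (`_np_weighted_confusion_matrix` and `_np_iou_from_cm`).
def _check_binary_iou_docstring_example():
    y_true = [0, 1, 0, 1]
    y_pred = [0.1, 0.2, 0.4, 0.7]
    sample_weight = [0.2, 0.3, 0.4, 0.1]
    # `BinaryIoU.update_state` thresholds predictions before accumulating:
    # values >= threshold become class 1, everything else class 0.
    threshold = 0.3
    y_pred_classes = [int(p >= threshold) for p in y_pred]  # [0, 0, 1, 1]
    cm = _np_weighted_confusion_matrix(
        y_true, y_pred_classes, num_classes=2, sample_weight=sample_weight
    )
    # cm == [[0.2, 0.4], [0.3, 0.1]], so iou = [0.2 / 0.9, 0.1 / 0.8] and the
    # mean over both classes is ~0.1736, matching the docstring value.
    return _np_iou_from_cm(cm, target_class_ids=[0, 1])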


@keras_export("keras.metrics.MeanIoU")
class MeanIoU(IoU):
    """Computes the mean Intersection-Over-Union metric.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Note that this class first computes IoUs for all individual classes, then
    returns the mean of these values.

    Args:
      num_classes: The possible number of labels the prediction task can have.
        This value must be provided, since a confusion matrix of dimension =
        [num_classes, num_classes] will be allocated.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      ignore_class: Optional integer. The ID of a class to be ignored during
        metric computation. This is useful, for example, in segmentation
        problems featuring a "void" class (commonly -1 or 255) in segmentation
        maps. By default (`ignore_class=None`), all classes are considered.
      sparse_y_true: Whether labels are encoded using integers or
        dense floating point vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      sparse_y_pred: Whether predictions are encoded using integers or
        dense floating point vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> # cm = [[1, 1],
    >>> #       [1, 1]]
    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
    >>> # iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
    >>> m = tf.keras.metrics.MeanIoU(num_classes=2)
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
    >>> m.result().numpy()
    0.33333334

    >>> m.reset_state()
    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
    ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
    >>> m.result().numpy()
    0.23809525

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_true: bool = True,
        sparse_y_pred: bool = True,
        axis: int = -1,
    ):
        target_class_ids = list(range(num_classes))
        super().__init__(
            name=name,
            num_classes=num_classes,
            target_class_ids=target_class_ids,
            axis=axis,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=sparse_y_true,
            sparse_y_pred=sparse_y_pred,
        )

    def get_config(self):
        return {
            "num_classes": self.num_classes,
            "name": self.name,
            "dtype": self._dtype,
            "ignore_class": self.ignore_class,
            "sparse_y_true": self.sparse_y_true,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
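
# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original module. It
# shows the `ignore_class` path of `MeanIoU` on a toy segmentation map:
# pixels labelled 255 (a common "void" value) are dropped before the
# confusion matrix is accumulated, so they contribute nothing to the mean
# IoU. The tensors below are made up for illustration.
def _demo_mean_iou_with_void_class():
    metric = MeanIoU(num_classes=3, ignore_class=255)
    y_true = tf.constant([[0, 1, 255], [2, 255, 1]])
    y_pred = tf.constant([[0, 2, 0], [2, 1, 1]])
    # Inputs of rank > 1 are flattened internally; only the four non-void
    # pixels enter the confusion matrix.
    metric.update_state(y_true, y_pred)
    return metric.result()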


@keras_export("keras.metrics.OneHotIoU")
class OneHotIoU(IoU):
    """Computes the Intersection-Over-Union metric for one-hot encoded labels.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute IoU for multi-class classification tasks
    where the labels are one-hot encoded (the last axis should have one
    dimension per class). Note that the predictions should also have the same
    shape. To compute the IoU, first the labels and predictions are converted
    back into integer format by taking the argmax over the class axis. Then the
    same computation steps as for the base `IoU` class apply.

    Note, if there is only one channel in the labels and predictions, this
    class is the same as class `IoU`. In this case, use `IoU` instead.

    Also, make sure that `num_classes` is equal to the number of classes in the
    data, to avoid a "labels out of bound" error when the confusion matrix is
    computed.

    Args:
      num_classes: The possible number of labels the prediction task can have.
        A confusion matrix of shape `(num_classes, num_classes)` will be
        allocated to accumulate predictions from which the metric is
        calculated.
      target_class_ids: A tuple or list of target class ids for which the
        metric is returned. To compute IoU for a specific class, a list (or
        tuple) of a single id value should be provided.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      ignore_class: Optional integer. The ID of a class to be ignored during
        metric computation. This is useful, for example, in segmentation
        problems featuring a "void" class (commonly -1 or 255) in segmentation
        maps. By default (`ignore_class=None`), all classes are considered.
      sparse_y_pred: Whether predictions are encoded using natural numbers or
        probability distribution vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> y_true = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> y_pred = tf.constant([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
    ...                       [0.1, 0.4, 0.5]])
    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]
    >>> m = tf.keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
    >>> m.update_state(
    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    >>> # cm = [[0, 0, 0.2+0.4],
    >>> #       [0.3, 0, 0],
    >>> #       [0, 0, 0.1]]
    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
    >>> # true_positives = [0, 0, 0.1]
    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
    >>> m.result().numpy()
    0.071

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[1])])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        target_class_ids: Union[List[int], Tuple[int, ...]],
        name=None,
        dtype=None,
        ignore_class: Optional[int] = None,
        sparse_y_pred: bool = False,
        axis: int = -1,
    ):
        super().__init__(
            num_classes=num_classes,
            target_class_ids=target_class_ids,
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=False,
            sparse_y_pred=sparse_y_pred,
            axis=axis,
        )

    def get_config(self):
        return {
            "num_classes": self.num_classes,
            "target_class_ids": self.target_class_ids,
            "name": self.name,
            "dtype": self._dtype,
            "ignore_class": self.ignore_class,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }
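
# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original module. The
# one-hot variants above simply pass `sparse_y_true=False` (and, by default,
# `sparse_y_pred=False`) to the base class, so labels and predictions are
# reduced to integer class ids with an argmax over `axis` before the usual
# IoU bookkeeping. The toy tensors below are made up for illustration.
def _demo_one_hot_to_sparse():
    y_true = tf.constant([[0, 0, 1], [1, 0, 0]])  # one-hot labels
    y_pred = tf.constant([[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]])  # class scores
    sparse_true = tf.argmax(y_true, axis=-1)  # -> [2, 0]
    sparse_pred = tf.argmax(y_pred, axis=-1)  # -> [2, 0]
    return sparse_true, sparse_pred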


@keras_export("keras.metrics.OneHotMeanIoU")
class OneHotMeanIoU(MeanIoU):
    """Computes mean Intersection-Over-Union metric for one-hot encoded labels.

    General definition and computation:

    Intersection-Over-Union is a common evaluation metric for semantic image
    segmentation.

    For an individual class, the IoU metric is defined as follows:

    ```
    iou = true_positives / (true_positives + false_positives + false_negatives)
    ```

    To compute IoUs, the predictions are accumulated in a confusion matrix,
    weighted by `sample_weight` and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    This class can be used to compute the mean IoU for multi-class
    classification tasks where the labels are one-hot encoded (the last axis
    should have one dimension per class). Note that the predictions should also
    have the same shape. To compute the mean IoU, first the labels and
    predictions are converted back into integer format by taking the argmax
    over the class axis. Then the same computation steps as for the base
    `MeanIoU` class apply.

    Note, if there is only one channel in the labels and predictions, this
    class is the same as class `MeanIoU`. In this case, use `MeanIoU` instead.

    Also, make sure that `num_classes` is equal to the number of classes in the
    data, to avoid a "labels out of bound" error when the confusion matrix is
    computed.

    Args:
      num_classes: The possible number of labels the prediction task can have.
        A confusion matrix of shape `(num_classes, num_classes)` will be
        allocated to accumulate predictions from which the metric is
        calculated.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      ignore_class: Optional integer. The ID of a class to be ignored during
        metric computation. This is useful, for example, in segmentation
        problems featuring a "void" class (commonly -1 or 255) in segmentation
        maps. By default (`ignore_class=None`), all classes are considered.
      sparse_y_pred: Whether predictions are encoded using natural numbers or
        probability distribution vectors. If `False`, the `tf.argmax` function
        will be used to determine each sample's most likely associated label.
      axis: (Optional) Defaults to -1. The dimension containing the logits.

    Standalone usage:

    >>> y_true = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])
    >>> y_pred = tf.constant([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],
    ...                       [0.1, 0.4, 0.5]])
    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]
    >>> m = tf.keras.metrics.OneHotMeanIoU(num_classes=3)
    >>> m.update_state(
    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    >>> # cm = [[0, 0, 0.2+0.4],
    >>> #       [0.3, 0, 0],
    >>> #       [0, 0, 0.1]]
    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]
    >>> # true_positives = [0, 0, 0.1]
    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives)
    >>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3
    >>> m.result().numpy()
    0.048

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.OneHotMeanIoU(num_classes=3)])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(
        self,
        num_classes: int,
        name: Optional[str] = None,
        dtype: Optional[Union[str, tf.dtypes.DType]] = None,
        ignore_class: Optional[int] = None,
        sparse_y_pred: bool = False,
        axis: int = -1,
    ):
        super().__init__(
            num_classes=num_classes,
            axis=axis,
            name=name,
            dtype=dtype,
            ignore_class=ignore_class,
            sparse_y_true=False,
            sparse_y_pred=sparse_y_pred,
        )

    def get_config(self):
        return {
            "num_classes": self.num_classes,
            "name": self.name,
            "dtype": self._dtype,
            "ignore_class": self.ignore_class,
            "sparse_y_pred": self.sparse_y_pred,
            "axis": self.axis,
        }