Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/confusion_metrics.py: 28%

331 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Confusion metrics, i.e. metrics based on True/False positives/negatives.""" 

16 

17import abc 

18 

19import numpy as np 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import activations 

23from keras.src import backend 

24from keras.src.dtensor import utils as dtensor_utils 

25from keras.src.metrics import base_metric 

26from keras.src.utils import metrics_utils 

27from keras.src.utils.generic_utils import to_list 

28from keras.src.utils.tf_utils import is_tensor_or_variable 

29 

30# isort: off 

31from tensorflow.python.util.tf_export import keras_export 

32 

33 

34class _ConfusionMatrixConditionCount(base_metric.Metric): 

35 """Calculates the number of the given confusion matrix condition. 

36 

37 Args: 

38 confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions. 

39 thresholds: (Optional) Defaults to 0.5. A float value or a python 

40 list/tuple of float threshold values in [0, 1]. A threshold is compared 

41 with prediction values to determine the truth value of predictions 

42 (i.e., above the threshold is `true`, below is `false`). One metric 

43 value is generated for each threshold value. 

44 name: (Optional) string name of the metric instance. 

45 dtype: (Optional) data type of the metric result. 

46 """ 

47 

48 def __init__( 

49 self, confusion_matrix_cond, thresholds=None, name=None, dtype=None 

50 ): 

51 super().__init__(name=name, dtype=dtype) 

52 self._confusion_matrix_cond = confusion_matrix_cond 

53 self.init_thresholds = thresholds 

54 self.thresholds = metrics_utils.parse_init_thresholds( 

55 thresholds, default_threshold=0.5 

56 ) 

57 self._thresholds_distributed_evenly = ( 

58 metrics_utils.is_evenly_distributed_thresholds(self.thresholds) 

59 ) 

60 self.accumulator = self.add_weight( 

61 "accumulator", shape=(len(self.thresholds),), initializer="zeros" 

62 ) 

63 

64 def update_state(self, y_true, y_pred, sample_weight=None): 

65 """Accumulates the metric statistics. 

66 

67 Args: 

68 y_true: The ground truth values. 

69 y_pred: The predicted values. 

70 sample_weight: Optional weighting of each example. Defaults to 1. Can 

71 be a `Tensor` whose rank is either 0, or the same rank as `y_true`, 

72 and must be broadcastable to `y_true`. 

73 

74 Returns: 

75 Update op. 

76 """ 

77 return metrics_utils.update_confusion_matrix_variables( 

78 {self._confusion_matrix_cond: self.accumulator}, 

79 y_true, 

80 y_pred, 

81 thresholds=self.thresholds, 

82 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

83 sample_weight=sample_weight, 

84 ) 

85 

86 def result(self): 

87 if len(self.thresholds) == 1: 

88 result = self.accumulator[0] 

89 else: 

90 result = self.accumulator 

91 return tf.convert_to_tensor(result) 

92 

93 def reset_state(self): 

94 backend.batch_set_value( 

95 [(v, np.zeros(v.shape.as_list())) for v in self.variables] 

96 ) 

97 

98 def get_config(self): 

99 config = {"thresholds": self.init_thresholds} 

100 base_config = super().get_config() 

101 return dict(list(base_config.items()) + list(config.items())) 

102 

103 
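# Illustrative sketch, not part of the library source: the threshold handling
# described above, shown with FalsePositives. A list of thresholds yields one
# accumulated count per threshold. Assumes TensorFlow is installed.
import tensorflow as tf

fp = tf.keras.metrics.FalsePositives(thresholds=[0.3, 0.5, 0.9])
fp.update_state([0, 0, 1, 1], [0.2, 0.6, 0.4, 0.95])
print(fp.result().numpy())  # [1. 1. 0.] -- one false-positive count per threshold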

104@keras_export("keras.metrics.FalsePositives") 

105class FalsePositives(_ConfusionMatrixConditionCount): 

106 """Calculates the number of false positives. 

107 

108 If `sample_weight` is given, calculates the sum of the weights of 

109 false positives. This metric creates one local variable, `accumulator` 

110 that is used to keep track of the number of false positives. 

111 

112 If `sample_weight` is `None`, weights default to 1. 

113 Use `sample_weight` of 0 to mask values. 

114 

115 Args: 

116 thresholds: (Optional) Defaults to 0.5. A float value, or a Python 

117 list/tuple of float threshold values in [0, 1]. A threshold is compared 

118 with prediction values to determine the truth value of predictions 

119 (i.e., above the threshold is `true`, below is `false`). If used with a 

120 loss function that sets `from_logits=True` (i.e. no sigmoid applied to 

121 predictions), `thresholds` should be set to 0. One metric value is 

122 generated for each threshold value. 

123 name: (Optional) string name of the metric instance. 

124 dtype: (Optional) data type of the metric result. 

125 

126 Standalone usage: 

127 

128 >>> m = tf.keras.metrics.FalsePositives() 

129 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) 

130 >>> m.result().numpy() 

131 2.0 

132 

133 >>> m.reset_state() 

134 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

135 >>> m.result().numpy() 

136 1.0 

137 

138 Usage with `compile()` API: 

139 

140 ```python 

141 model.compile(optimizer='sgd', 

142 loss='mse', 

143 metrics=[tf.keras.metrics.FalsePositives()]) 

144 ``` 

145 

146 Usage with a loss with `from_logits=True`: 

147 

148 ```python 

149 model.compile(optimizer='adam', 

150 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

151 metrics=[tf.keras.metrics.FalsePositives(thresholds=0)]) 

152 ``` 

153 """ 

154 

155 @dtensor_utils.inject_mesh 

156 def __init__(self, thresholds=None, name=None, dtype=None): 

157 super().__init__( 

158 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, 

159 thresholds=thresholds, 

160 name=name, 

161 dtype=dtype, 

162 ) 

163 

164 

165@keras_export("keras.metrics.FalseNegatives") 

166class FalseNegatives(_ConfusionMatrixConditionCount): 

167 """Calculates the number of false negatives. 

168 

169 If `sample_weight` is given, calculates the sum of the weights of 

170 false negatives. This metric creates one local variable, `accumulator` 

171 that is used to keep track of the number of false negatives. 

172 

173 If `sample_weight` is `None`, weights default to 1. 

174 Use `sample_weight` of 0 to mask values. 

175 

176 Args: 

177 thresholds: (Optional) Defaults to 0.5. A float value, or a Python 

178 list/tuple of float threshold values in [0, 1]. A threshold is compared 

179 with prediction values to determine the truth value of predictions 

180 (i.e., above the threshold is `true`, below is `false`). If used with a 

181 loss function that sets `from_logits=True` (i.e. no sigmoid applied to 

182 predictions), `thresholds` should be set to 0. One metric value is 

183 generated for each threshold value. 

184 name: (Optional) string name of the metric instance. 

185 dtype: (Optional) data type of the metric result. 

186 

187 Standalone usage: 

188 

189 >>> m = tf.keras.metrics.FalseNegatives() 

190 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) 

191 >>> m.result().numpy() 

192 2.0 

193 

194 >>> m.reset_state() 

195 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0]) 

196 >>> m.result().numpy() 

197 1.0 

198 

199 Usage with `compile()` API: 

200 

201 ```python 

202 model.compile(optimizer='sgd', 

203 loss='mse', 

204 metrics=[tf.keras.metrics.FalseNegatives()]) 

205 ``` 

206 

207 Usage with a loss with `from_logits=True`: 

208 

209 ```python 

210 model.compile(optimizer='adam', 

211 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

212 metrics=[tf.keras.metrics.FalseNegatives(thresholds=0)]) 

213 ``` 

214 """ 

215 

216 @dtensor_utils.inject_mesh 

217 def __init__(self, thresholds=None, name=None, dtype=None): 

218 super().__init__( 

219 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES, 

220 thresholds=thresholds, 

221 name=name, 

222 dtype=dtype, 

223 ) 

224 

225 

226@keras_export("keras.metrics.TrueNegatives") 

227class TrueNegatives(_ConfusionMatrixConditionCount): 

228 """Calculates the number of true negatives. 

229 

230 If `sample_weight` is given, calculates the sum of the weights of 

231 true negatives. This metric creates one local variable, `accumulator` 

232 that is used to keep track of the number of true negatives. 

233 

234 If `sample_weight` is `None`, weights default to 1. 

235 Use `sample_weight` of 0 to mask values. 

236 

237 Args: 

238 thresholds: (Optional) Defaults to 0.5. A float value, or a Python 

239 list/tuple of float threshold values in [0, 1]. A threshold is compared 

240 with prediction values to determine the truth value of predictions 

241 (i.e., above the threshold is `true`, below is `false`). If used with a 

242 loss function that sets `from_logits=True` (i.e. no sigmoid applied to 

243 predictions), `thresholds` should be set to 0. One metric value is 

244 generated for each threshold value. 

245 name: (Optional) string name of the metric instance. 

246 dtype: (Optional) data type of the metric result. 

247 

248 Standalone usage: 

249 

250 >>> m = tf.keras.metrics.TrueNegatives() 

251 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) 

252 >>> m.result().numpy() 

253 2.0 

254 

255 >>> m.reset_state() 

256 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0]) 

257 >>> m.result().numpy() 

258 1.0 

259 

260 Usage with `compile()` API: 

261 

262 ```python 

263 model.compile(optimizer='sgd', 

264 loss='mse', 

265 metrics=[tf.keras.metrics.TrueNegatives()]) 

266 ``` 

267 

268 Usage with a loss with `from_logits=True`: 

269 

270 ```python 

271 model.compile(optimizer='adam', 

272 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

273 metrics=[tf.keras.metrics.TrueNegatives(thresholds=0)]) 

274 ``` 

275 """ 

276 

277 @dtensor_utils.inject_mesh 

278 def __init__(self, thresholds=None, name=None, dtype=None): 

279 super().__init__( 

280 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES, 

281 thresholds=thresholds, 

282 name=name, 

283 dtype=dtype, 

284 ) 

285 

286 

287@keras_export("keras.metrics.TruePositives") 

288class TruePositives(_ConfusionMatrixConditionCount): 

289 """Calculates the number of true positives. 

290 

291 If `sample_weight` is given, calculates the sum of the weights of 

292 true positives. This metric creates one local variable, `accumulator`

293 that is used to keep track of the number of true positives.

294 

295 If `sample_weight` is `None`, weights default to 1. 

296 Use `sample_weight` of 0 to mask values. 

297 

298 Args: 

299 thresholds: (Optional) Defaults to 0.5. A float value, or a Python 

300 list/tuple of float threshold values in [0, 1]. A threshold is compared 

301 with prediction values to determine the truth value of predictions 

302 (i.e., above the threshold is `true`, below is `false`). If used with a 

303 loss function that sets `from_logits=True` (i.e. no sigmoid applied to 

304 predictions), `thresholds` should be set to 0. One metric value is 

305 generated for each threshold value. 

306 name: (Optional) string name of the metric instance. 

307 dtype: (Optional) data type of the metric result. 

308 

309 Standalone usage: 

310 

311 >>> m = tf.keras.metrics.TruePositives() 

312 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 

313 >>> m.result().numpy() 

314 2.0 

315 

316 >>> m.reset_state() 

317 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

318 >>> m.result().numpy() 

319 1.0 

320 

321 Usage with `compile()` API: 

322 

323 ```python 

324 model.compile(optimizer='sgd', 

325 loss='mse', 

326 metrics=[tf.keras.metrics.TruePositives()]) 

327 ``` 

328 

329 Usage with a loss with `from_logits=True`: 

330 

331 ```python 

332 model.compile(optimizer='adam', 

333 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

334 metrics=[tf.keras.metrics.TruePositives(thresholds=0)]) 

335 ``` 

336 """ 

337 

338 @dtensor_utils.inject_mesh 

339 def __init__(self, thresholds=None, name=None, dtype=None): 

340 super().__init__( 

341 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES, 

342 thresholds=thresholds, 

343 name=name, 

344 dtype=dtype, 

345 ) 

346 

347 

348@keras_export("keras.metrics.Precision") 

349class Precision(base_metric.Metric): 

350 """Computes the precision of the predictions with respect to the labels. 

351 

352 The metric creates two local variables, `true_positives` and 

353 `false_positives` that are used to compute the precision. This value is 

354 ultimately returned as `precision`, an idempotent operation that simply 

355 divides `true_positives` by the sum of `true_positives` and 

356 `false_positives`. 

357 

358 If `sample_weight` is `None`, weights default to 1. 

359 Use `sample_weight` of 0 to mask values. 

360 

361 If `top_k` is set, we'll calculate precision as how often on average a class 

362 among the top-k classes with the highest predicted values of a batch entry 

363 is correct and can be found in the label for that entry. 

364 

365 If `class_id` is specified, we calculate precision by considering only the 

366 entries in the batch for which `class_id` is above the threshold and/or in 

367 the top-k highest predictions, and computing the fraction of them for which 

368 `class_id` is indeed a correct label. 

369 

370 Args: 

371 thresholds: (Optional) A float value, or a Python list/tuple of float 

372 threshold values in [0, 1]. A threshold is compared with prediction 

373 values to determine the truth value of predictions (i.e., above the 

374 threshold is `true`, below is `false`). If used with a loss function 

375 that sets `from_logits=True` (i.e. no sigmoid applied to predictions), 

376 `thresholds` should be set to 0. One metric value is generated for each 

377 threshold value. If neither thresholds nor top_k are set, the default is 

378 to calculate precision with `thresholds=0.5`. 

379 top_k: (Optional) Unset by default. An int value specifying the top-k 

380 predictions to consider when calculating precision. 

381 class_id: (Optional) Integer class ID for which we want binary metrics. 

382 This must be in the half-open interval `[0, num_classes)`, where 

383 `num_classes` is the last dimension of predictions. 

384 name: (Optional) string name of the metric instance. 

385 dtype: (Optional) data type of the metric result. 

386 

387 Standalone usage: 

388 

389 >>> m = tf.keras.metrics.Precision() 

390 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 

391 >>> m.result().numpy() 

392 0.6666667 

393 

394 >>> m.reset_state() 

395 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

396 >>> m.result().numpy() 

397 1.0 

398 

399 >>> # With top_k=2, it will calculate precision over y_true[:2] 

400 >>> # and y_pred[:2] 

401 >>> m = tf.keras.metrics.Precision(top_k=2) 

402 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) 

403 >>> m.result().numpy() 

404 0.0 

405 

406 >>> # With top_k=4, it will calculate precision over y_true[:4] 

407 >>> # and y_pred[:4] 

408 >>> m = tf.keras.metrics.Precision(top_k=4) 

409 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) 

410 >>> m.result().numpy() 

411 0.5 

412 

413 Usage with `compile()` API: 

414 

415 ```python 

416 model.compile(optimizer='sgd', 

417 loss='mse', 

418 metrics=[tf.keras.metrics.Precision()]) 

419 ``` 

420 

421 Usage with a loss with `from_logits=True`: 

422 

423 ```python 

424 model.compile(optimizer='adam', 

425 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

426 metrics=[tf.keras.metrics.Precision(thresholds=0)]) 

427 ``` 

428 """ 

429 

430 @dtensor_utils.inject_mesh 

431 def __init__( 

432 self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None 

433 ): 

434 super().__init__(name=name, dtype=dtype) 

435 self.init_thresholds = thresholds 

436 self.top_k = top_k 

437 self.class_id = class_id 

438 

439 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 

440 self.thresholds = metrics_utils.parse_init_thresholds( 

441 thresholds, default_threshold=default_threshold 

442 ) 

443 self._thresholds_distributed_evenly = ( 

444 metrics_utils.is_evenly_distributed_thresholds(self.thresholds) 

445 ) 

446 self.true_positives = self.add_weight( 

447 "true_positives", shape=(len(self.thresholds),), initializer="zeros" 

448 ) 

449 self.false_positives = self.add_weight( 

450 "false_positives", 

451 shape=(len(self.thresholds),), 

452 initializer="zeros", 

453 ) 

454 

455 def update_state(self, y_true, y_pred, sample_weight=None): 

456 """Accumulates true positive and false positive statistics. 

457 

458 Args: 

459 y_true: The ground truth values, with the same dimensions as `y_pred`. 

460 Will be cast to `bool`. 

461 y_pred: The predicted values. Each element must be in the range 

462 `[0, 1]`. 

463 sample_weight: Optional weighting of each example. Defaults to 1. Can 

464 be a `Tensor` whose rank is either 0, or the same rank as `y_true`, 

465 and must be broadcastable to `y_true`. 

466 

467 Returns: 

468 Update op. 

469 """ 

470 return metrics_utils.update_confusion_matrix_variables( 

471 { 

472 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 

473 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501 

474 }, 

475 y_true, 

476 y_pred, 

477 thresholds=self.thresholds, 

478 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

479 top_k=self.top_k, 

480 class_id=self.class_id, 

481 sample_weight=sample_weight, 

482 ) 

483 

484 def result(self): 

485 result = tf.math.divide_no_nan( 

486 self.true_positives, 

487 tf.math.add(self.true_positives, self.false_positives), 

488 ) 

489 return result[0] if len(self.thresholds) == 1 else result 

490 

491 def reset_state(self): 

492 num_thresholds = len(to_list(self.thresholds)) 

493 backend.batch_set_value( 

494 [ 

495 (v, np.zeros((num_thresholds,))) 

496 for v in (self.true_positives, self.false_positives) 

497 ] 

498 ) 

499 

500 def get_config(self): 

501 config = { 

502 "thresholds": self.init_thresholds, 

503 "top_k": self.top_k, 

504 "class_id": self.class_id, 

505 } 

506 base_config = super().get_config() 

507 return dict(list(base_config.items()) + list(config.items())) 

508 

509 
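# Illustrative sketch, not part of the library source: the `class_id` argument
# described in the docstring above restricts the computation to one column of
# `y_true`/`y_pred`, i.e. precision for that class versus the rest.
import tensorflow as tf

p = tf.keras.metrics.Precision(class_id=1)
y_true = [[0, 1, 0], [0, 0, 1], [0, 1, 0]]                     # one-hot labels
y_pred = [[0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4]]   # per-class scores
p.update_state(y_true, y_pred)
print(p.result().numpy())  # 0.5 -- class 1 predicted twice, correct once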

510@keras_export("keras.metrics.Recall") 

511class Recall(base_metric.Metric): 

512 """Computes the recall of the predictions with respect to the labels. 

513 

514 This metric creates two local variables, `true_positives` and 

515 `false_negatives`, that are used to compute the recall. This value is 

516 ultimately returned as `recall`, an idempotent operation that simply divides 

517 `true_positives` by the sum of `true_positives` and `false_negatives`. 

518 

519 If `sample_weight` is `None`, weights default to 1. 

520 Use `sample_weight` of 0 to mask values. 

521 

522 If `top_k` is set, recall will be computed as how often on average a class 

523 among the labels of a batch entry is in the top-k predictions. 

524 

525 If `class_id` is specified, we calculate recall by considering only the 

526 entries in the batch for which `class_id` is in the label, and computing the 

527 fraction of them for which `class_id` is above the threshold and/or in the 

528 top-k predictions. 

529 

530 Args: 

531 thresholds: (Optional) A float value, or a Python list/tuple of float 

532 threshold values in [0, 1]. A threshold is compared with prediction 

533 values to determine the truth value of predictions (i.e., above the 

534 threshold is `true`, below is `false`). If used with a loss function 

535 that sets `from_logits=True` (i.e. no sigmoid applied to predictions), 

536 `thresholds` should be set to 0. One metric value is generated for each 

537 threshold value. If neither thresholds nor top_k are set, the default is 

538 to calculate recall with `thresholds=0.5`. 

539 top_k: (Optional) Unset by default. An int value specifying the top-k 

540 predictions to consider when calculating recall. 

541 class_id: (Optional) Integer class ID for which we want binary metrics. 

542 This must be in the half-open interval `[0, num_classes)`, where 

543 `num_classes` is the last dimension of predictions. 

544 name: (Optional) string name of the metric instance. 

545 dtype: (Optional) data type of the metric result. 

546 

547 Standalone usage: 

548 

549 >>> m = tf.keras.metrics.Recall() 

550 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 

551 >>> m.result().numpy() 

552 0.6666667 

553 

554 >>> m.reset_state() 

555 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

556 >>> m.result().numpy() 

557 1.0 

558 

559 Usage with `compile()` API: 

560 

561 ```python 

562 model.compile(optimizer='sgd', 

563 loss='mse', 

564 metrics=[tf.keras.metrics.Recall()]) 

565 ``` 

566 

567 Usage with a loss with `from_logits=True`: 

568 

569 ```python 

570 model.compile(optimizer='adam', 

571 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

572 metrics=[tf.keras.metrics.Recall(thresholds=0)]) 

573 ``` 

574 """ 

575 

576 @dtensor_utils.inject_mesh 

577 def __init__( 

578 self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None 

579 ): 

580 super().__init__(name=name, dtype=dtype) 

581 self.init_thresholds = thresholds 

582 self.top_k = top_k 

583 self.class_id = class_id 

584 

585 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 

586 self.thresholds = metrics_utils.parse_init_thresholds( 

587 thresholds, default_threshold=default_threshold 

588 ) 

589 self._thresholds_distributed_evenly = ( 

590 metrics_utils.is_evenly_distributed_thresholds(self.thresholds) 

591 ) 

592 self.true_positives = self.add_weight( 

593 "true_positives", shape=(len(self.thresholds),), initializer="zeros" 

594 ) 

595 self.false_negatives = self.add_weight( 

596 "false_negatives", 

597 shape=(len(self.thresholds),), 

598 initializer="zeros", 

599 ) 

600 

601 def update_state(self, y_true, y_pred, sample_weight=None): 

602 """Accumulates true positive and false negative statistics. 

603 

604 Args: 

605 y_true: The ground truth values, with the same dimensions as `y_pred`. 

606 Will be cast to `bool`. 

607 y_pred: The predicted values. Each element must be in the range 

608 `[0, 1]`. 

609 sample_weight: Optional weighting of each example. Defaults to 1. Can 

610 be a `Tensor` whose rank is either 0, or the same rank as `y_true`, 

611 and must be broadcastable to `y_true`. 

612 

613 Returns: 

614 Update op. 

615 """ 

616 return metrics_utils.update_confusion_matrix_variables( 

617 { 

618 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 

619 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 

620 }, 

621 y_true, 

622 y_pred, 

623 thresholds=self.thresholds, 

624 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

625 top_k=self.top_k, 

626 class_id=self.class_id, 

627 sample_weight=sample_weight, 

628 ) 

629 

630 def result(self): 

631 result = tf.math.divide_no_nan( 

632 self.true_positives, 

633 tf.math.add(self.true_positives, self.false_negatives), 

634 ) 

635 return result[0] if len(self.thresholds) == 1 else result 

636 

637 def reset_state(self): 

638 num_thresholds = len(to_list(self.thresholds)) 

639 backend.batch_set_value( 

640 [ 

641 (v, np.zeros((num_thresholds,))) 

642 for v in (self.true_positives, self.false_negatives) 

643 ] 

644 ) 

645 

646 def get_config(self): 

647 config = { 

648 "thresholds": self.init_thresholds, 

649 "top_k": self.top_k, 

650 "class_id": self.class_id, 

651 } 

652 base_config = super().get_config() 

653 return dict(list(base_config.items()) + list(config.items())) 

654 

655 
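# Illustrative sketch, not part of the library source: Precision and Recall
# accept the same `thresholds` list, so the precision/recall trade-off can be
# traced point by point; the classes below automate this kind of sweep.
import tensorflow as tf

thresholds = [0.1, 0.5, 0.9]
precision = tf.keras.metrics.Precision(thresholds=thresholds)
recall = tf.keras.metrics.Recall(thresholds=thresholds)
y_true = [0, 0, 1, 1]
y_pred = [0.2, 0.6, 0.7, 0.95]
precision.update_state(y_true, y_pred)
recall.update_state(y_true, y_pred)
print(precision.result().numpy())  # [0.5       0.6666667 1.       ]
print(recall.result().numpy())     # [1.  1.  0.5]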

656class SensitivitySpecificityBase(base_metric.Metric, metaclass=abc.ABCMeta): 

657 """Abstract base class for computing sensitivity and specificity. 

658 

659 For additional information about specificity and sensitivity, see 

660 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 

661 """ 

662 

663 def __init__( 

664 self, value, num_thresholds=200, class_id=None, name=None, dtype=None 

665 ): 

666 super().__init__(name=name, dtype=dtype) 

667 if num_thresholds <= 0: 

668 raise ValueError( 

669 "Argument `num_thresholds` must be an integer > 0. " 

670 f"Received: num_thresholds={num_thresholds}" 

671 ) 

672 self.value = value 

673 self.class_id = class_id 

674 self.true_positives = self.add_weight( 

675 "true_positives", shape=(num_thresholds,), initializer="zeros" 

676 ) 

677 self.true_negatives = self.add_weight( 

678 "true_negatives", shape=(num_thresholds,), initializer="zeros" 

679 ) 

680 self.false_positives = self.add_weight( 

681 "false_positives", shape=(num_thresholds,), initializer="zeros" 

682 ) 

683 self.false_negatives = self.add_weight( 

684 "false_negatives", shape=(num_thresholds,), initializer="zeros" 

685 ) 

686 

687 # Compute `num_thresholds` thresholds in [0, 1] 

688 if num_thresholds == 1: 

689 self.thresholds = [0.5] 

690 self._thresholds_distributed_evenly = False 

691 else: 

692 thresholds = [ 

693 (i + 1) * 1.0 / (num_thresholds - 1) 

694 for i in range(num_thresholds - 2) 

695 ] 

696 self.thresholds = [0.0] + thresholds + [1.0] 

697 self._thresholds_distributed_evenly = True 

698 
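# Illustrative sketch, not part of the library source: the grid built in
# __init__ above is `num_thresholds` evenly spaced values in [0, 1] with both
# endpoints included. For example, with num_thresholds = 5:
demo_num_thresholds = 5
demo_inner = [
    (i + 1) * 1.0 / (demo_num_thresholds - 1)
    for i in range(demo_num_thresholds - 2)
]
print([0.0] + demo_inner + [1.0])  # [0.0, 0.25, 0.5, 0.75, 1.0]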

699 def update_state(self, y_true, y_pred, sample_weight=None): 

700 """Accumulates confusion matrix statistics. 

701 

702 Args: 

703 y_true: The ground truth values. 

704 y_pred: The predicted values. 

705 sample_weight: Optional weighting of each example. Defaults to 1. Can 

706 be a `Tensor` whose rank is either 0, or the same rank as `y_true`, 

707 and must be broadcastable to `y_true`. 

708 

709 Returns: 

710 Update op. 

711 """ 

712 return metrics_utils.update_confusion_matrix_variables( 

713 { 

714 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 

715 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501 

716 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501 

717 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 

718 }, 

719 y_true, 

720 y_pred, 

721 thresholds=self.thresholds, 

722 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

723 class_id=self.class_id, 

724 sample_weight=sample_weight, 

725 ) 

726 

727 def reset_state(self): 

728 num_thresholds = len(self.thresholds) 

729 confusion_matrix_variables = ( 

730 self.true_positives, 

731 self.true_negatives, 

732 self.false_positives, 

733 self.false_negatives, 

734 ) 

735 backend.batch_set_value( 

736 [ 

737 (v, np.zeros((num_thresholds,))) 

738 for v in confusion_matrix_variables 

739 ] 

740 ) 

741 

742 def get_config(self): 

743 config = {"class_id": self.class_id} 

744 base_config = super().get_config() 

745 return dict(list(base_config.items()) + list(config.items())) 

746 

747 def _find_max_under_constraint(self, constrained, dependent, predicate): 

748 """Returns the maximum of dependent_statistic that satisfies the 

749 constraint. 

750 

751 Args: 

752 constrained: Over these values the constraint 

753 is specified. A rank-1 tensor. 

754 dependent: From these values the maximum that satisfies the

755 constraint is selected. Values in this tensor and in 

756 `constrained` are linked by having the same threshold at each 

757 position, hence this tensor must have the same shape. 

758 predicate: A binary boolean functor to be applied to arguments 

759 `constrained` and `self.value`, e.g. `tf.greater`. 

760 

761 Returns: 

762 The maximal dependent value, or 0.0 if no value satisfies the constraint.

763 """ 

764 feasible = tf.where(predicate(constrained, self.value)) 

765 feasible_exists = tf.greater(tf.size(feasible), 0) 

766 max_dependent = tf.reduce_max(tf.gather(dependent, feasible)) 

767 

768 return tf.where(feasible_exists, max_dependent, 0.0) 

769 

770 
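# Illustrative sketch, not part of the library source: a worked example of the
# constrained maximisation in _find_max_under_constraint above. Among the
# thresholds whose constrained statistic satisfies the predicate, the largest
# dependent statistic is returned (0.0 if none qualifies).
import tensorflow as tf

constrained = tf.constant([0.2, 0.6, 0.9])  # e.g. specificity per threshold
dependent = tf.constant([1.0, 0.7, 0.3])    # e.g. sensitivity per threshold
target = 0.5                                # the constraint value

feasible = tf.where(tf.greater_equal(constrained, target))  # indices [[1], [2]]
best = tf.where(
    tf.greater(tf.size(feasible), 0),
    tf.reduce_max(tf.gather(dependent, feasible)),
    0.0,
)
print(best.numpy())  # 0.7 -- the best sensitivity with specificity >= 0.5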

771@keras_export("keras.metrics.SensitivityAtSpecificity") 

772class SensitivityAtSpecificity(SensitivitySpecificityBase): 

773 """Computes best sensitivity where specificity is >= specified value. 

774 

775 That is, it reports the best sensitivity achievable at the given specificity.

776 

777 `Sensitivity` measures the proportion of actual positives that are correctly 

778 identified as such (tp / (tp + fn)). 

779 `Specificity` measures the proportion of actual negatives that are correctly 

780 identified as such (tn / (tn + fp)). 

781 

782 This metric creates four local variables, `true_positives`, 

783 `true_negatives`, `false_positives` and `false_negatives` that are used to 

784 compute the sensitivity at the given specificity. The threshold for the 

785 given specificity value is computed and used to evaluate the corresponding 

786 sensitivity. 

787 

788 If `sample_weight` is `None`, weights default to 1. 

789 Use `sample_weight` of 0 to mask values. 

790 

791 If `class_id` is specified, the metric is computed by considering only the

792 `class_id` entries of the labels and predictions, i.e. the problem is

793 restricted to the binary classification of class `class_id` versus the

794 rest.

795 

796 For additional information about specificity and sensitivity, see 

797 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 

798 

799 Args: 

800 specificity: A scalar value in range `[0, 1]`. 

801 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

802 use for matching the given specificity. 

803 class_id: (Optional) Integer class ID for which we want binary metrics. 

804 This must be in the half-open interval `[0, num_classes)`, where 

805 `num_classes` is the last dimension of predictions. 

806 name: (Optional) string name of the metric instance. 

807 dtype: (Optional) data type of the metric result. 

808 

809 Standalone usage: 

810 

811 >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) 

812 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 

813 >>> m.result().numpy() 

814 0.5 

815 

816 >>> m.reset_state() 

817 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 

818 ... sample_weight=[1, 1, 2, 2, 1]) 

819 >>> m.result().numpy() 

820 0.333333 

821 

822 Usage with `compile()` API: 

823 

824 ```python 

825 model.compile( 

826 optimizer='sgd', 

827 loss='mse', 

828 metrics=[tf.keras.metrics.SensitivityAtSpecificity()]) 

829 ``` 

830 """ 

831 

832 @dtensor_utils.inject_mesh 

833 def __init__( 

834 self, 

835 specificity, 

836 num_thresholds=200, 

837 class_id=None, 

838 name=None, 

839 dtype=None, 

840 ): 

841 if specificity < 0 or specificity > 1: 

842 raise ValueError( 

843 "Argument `specificity` must be in the range [0, 1]. " 

844 f"Received: specificity={specificity}" 

845 ) 

846 self.specificity = specificity 

847 self.num_thresholds = num_thresholds 

848 super().__init__( 

849 specificity, 

850 num_thresholds=num_thresholds, 

851 class_id=class_id, 

852 name=name, 

853 dtype=dtype, 

854 ) 

855 

856 def result(self): 

857 specificities = tf.math.divide_no_nan( 

858 self.true_negatives, 

859 tf.math.add(self.true_negatives, self.false_positives), 

860 ) 

861 sensitivities = tf.math.divide_no_nan( 

862 self.true_positives, 

863 tf.math.add(self.true_positives, self.false_negatives), 

864 ) 

865 return self._find_max_under_constraint( 

866 specificities, sensitivities, tf.greater_equal 

867 ) 

868 

869 def get_config(self): 

870 config = { 

871 "num_thresholds": self.num_thresholds, 

872 "specificity": self.specificity, 

873 } 

874 base_config = super().get_config() 

875 return dict(list(base_config.items()) + list(config.items())) 

876 

877 

878@keras_export("keras.metrics.SpecificityAtSensitivity") 

879class SpecificityAtSensitivity(SensitivitySpecificityBase): 

880 """Computes best specificity where sensitivity is >= specified value. 

881 

882 `Sensitivity` measures the proportion of actual positives that are correctly 

883 identified as such (tp / (tp + fn)). 

884 `Specificity` measures the proportion of actual negatives that are correctly 

885 identified as such (tn / (tn + fp)). 

886 

887 This metric creates four local variables, `true_positives`, 

888 `true_negatives`, `false_positives` and `false_negatives` that are used to 

889 compute the specificity at the given sensitivity. The threshold for the 

890 given sensitivity value is computed and used to evaluate the corresponding 

891 specificity. 

892 

893 If `sample_weight` is `None`, weights default to 1. 

894 Use `sample_weight` of 0 to mask values. 

895 

896 If `class_id` is specified, the metric is computed by considering only the

897 `class_id` entries of the labels and predictions, i.e. the problem is

898 restricted to the binary classification of class `class_id` versus the

899 rest.

900 

901 For additional information about specificity and sensitivity, see 

902 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 

903 

904 Args: 

905 sensitivity: A scalar value in range `[0, 1]`. 

906 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

907 use for matching the given sensitivity. 

908 class_id: (Optional) Integer class ID for which we want binary metrics. 

909 This must be in the half-open interval `[0, num_classes)`, where 

910 `num_classes` is the last dimension of predictions. 

911 name: (Optional) string name of the metric instance. 

912 dtype: (Optional) data type of the metric result. 

913 

914 Standalone usage: 

915 

916 >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) 

917 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 

918 >>> m.result().numpy() 

919 0.66666667 

920 

921 >>> m.reset_state() 

922 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 

923 ... sample_weight=[1, 1, 2, 2, 2]) 

924 >>> m.result().numpy() 

925 0.5 

926 

927 Usage with `compile()` API: 

928 

929 ```python 

930 model.compile( 

931 optimizer='sgd', 

932 loss='mse', 

933 metrics=[tf.keras.metrics.SpecificityAtSensitivity()]) 

934 ``` 

935 """ 

936 

937 @dtensor_utils.inject_mesh 

938 def __init__( 

939 self, 

940 sensitivity, 

941 num_thresholds=200, 

942 class_id=None, 

943 name=None, 

944 dtype=None, 

945 ): 

946 if sensitivity < 0 or sensitivity > 1: 

947 raise ValueError( 

948 "Argument `sensitivity` must be in the range [0, 1]. " 

949 f"Received: sensitivity={sensitivity}" 

950 ) 

951 self.sensitivity = sensitivity 

952 self.num_thresholds = num_thresholds 

953 super().__init__( 

954 sensitivity, 

955 num_thresholds=num_thresholds, 

956 class_id=class_id, 

957 name=name, 

958 dtype=dtype, 

959 ) 

960 

961 def result(self): 

962 sensitivities = tf.math.divide_no_nan( 

963 self.true_positives, 

964 tf.math.add(self.true_positives, self.false_negatives), 

965 ) 

966 specificities = tf.math.divide_no_nan( 

967 self.true_negatives, 

968 tf.math.add(self.true_negatives, self.false_positives), 

969 ) 

970 return self._find_max_under_constraint( 

971 sensitivities, specificities, tf.greater_equal 

972 ) 

973 

974 def get_config(self): 

975 config = { 

976 "num_thresholds": self.num_thresholds, 

977 "sensitivity": self.sensitivity, 

978 } 

979 base_config = super().get_config() 

980 return dict(list(base_config.items()) + list(config.items())) 

981 

982 

983@keras_export("keras.metrics.PrecisionAtRecall") 

984class PrecisionAtRecall(SensitivitySpecificityBase): 

985 """Computes best precision where recall is >= specified value. 

986 

987 This metric creates four local variables, `true_positives`, 

988 `true_negatives`, `false_positives` and `false_negatives` that are used to 

989 compute the precision at the given recall. The threshold for the given 

990 recall value is computed and used to evaluate the corresponding precision. 

991 

992 If `sample_weight` is `None`, weights default to 1. 

993 Use `sample_weight` of 0 to mask values. 

994 

995 If `class_id` is specified, the metric is computed by considering only the

996 `class_id` entries of the labels and predictions, i.e. the problem is

997 restricted to the binary classification of class `class_id` versus the

998 rest.

999 

1000 Args: 

1001 recall: A scalar value in range `[0, 1]`. 

1002 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1003 use for matching the given recall. 

1004 class_id: (Optional) Integer class ID for which we want binary metrics. 

1005 This must be in the half-open interval `[0, num_classes)`, where 

1006 `num_classes` is the last dimension of predictions. 

1007 name: (Optional) string name of the metric instance. 

1008 dtype: (Optional) data type of the metric result. 

1009 

1010 Standalone usage: 

1011 

1012 >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) 

1013 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 

1014 >>> m.result().numpy() 

1015 0.5 

1016 

1017 >>> m.reset_state() 

1018 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 

1019 ... sample_weight=[2, 2, 2, 1, 1]) 

1020 >>> m.result().numpy() 

1021 0.33333333 

1022 

1023 Usage with `compile()` API: 

1024 

1025 ```python 

1026 model.compile( 

1027 optimizer='sgd', 

1028 loss='mse', 

1029 metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)]) 

1030 ``` 

1031 """ 

1032 

1033 @dtensor_utils.inject_mesh 

1034 def __init__( 

1035 self, recall, num_thresholds=200, class_id=None, name=None, dtype=None 

1036 ): 

1037 if recall < 0 or recall > 1: 

1038 raise ValueError( 

1039 "Argument `recall` must be in the range [0, 1]. " 

1040 f"Received: recall={recall}" 

1041 ) 

1042 self.recall = recall 

1043 self.num_thresholds = num_thresholds 

1044 super().__init__( 

1045 value=recall, 

1046 num_thresholds=num_thresholds, 

1047 class_id=class_id, 

1048 name=name, 

1049 dtype=dtype, 

1050 ) 

1051 

1052 def result(self): 

1053 recalls = tf.math.divide_no_nan( 

1054 self.true_positives, 

1055 tf.math.add(self.true_positives, self.false_negatives), 

1056 ) 

1057 precisions = tf.math.divide_no_nan( 

1058 self.true_positives, 

1059 tf.math.add(self.true_positives, self.false_positives), 

1060 ) 

1061 return self._find_max_under_constraint( 

1062 recalls, precisions, tf.greater_equal 

1063 ) 

1064 

1065 def get_config(self): 

1066 config = {"num_thresholds": self.num_thresholds, "recall": self.recall} 

1067 base_config = super().get_config() 

1068 return dict(list(base_config.items()) + list(config.items())) 

1069 

1070 

1071@keras_export("keras.metrics.RecallAtPrecision") 

1072class RecallAtPrecision(SensitivitySpecificityBase): 

1073 """Computes best recall where precision is >= specified value. 

1074 

1075 For a given score-label distribution, the required precision might not

1076 be achievable; in this case, 0.0 is returned as the recall.

1077 

1078 This metric creates four local variables, `true_positives`, 

1079 `true_negatives`, `false_positives` and `false_negatives` that are used to 

1080 compute the recall at the given precision. The threshold for the given 

1081 precision value is computed and used to evaluate the corresponding recall. 

1082 

1083 If `sample_weight` is `None`, weights default to 1. 

1084 Use `sample_weight` of 0 to mask values. 

1085 

1086 If `class_id` is specified, the metric is computed by considering only the

1087 `class_id` entries of the labels and predictions, i.e. the problem is

1088 restricted to the binary classification of class `class_id` versus the

1089 rest.

1090 

1091 Args: 

1092 precision: A scalar value in range `[0, 1]`. 

1093 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1094 use for matching the given precision. 

1095 class_id: (Optional) Integer class ID for which we want binary metrics. 

1096 This must be in the half-open interval `[0, num_classes)`, where 

1097 `num_classes` is the last dimension of predictions. 

1098 name: (Optional) string name of the metric instance. 

1099 dtype: (Optional) data type of the metric result. 

1100 

1101 Standalone usage: 

1102 

1103 >>> m = tf.keras.metrics.RecallAtPrecision(0.8) 

1104 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 

1105 >>> m.result().numpy() 

1106 0.5 

1107 

1108 >>> m.reset_state() 

1109 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 

1110 ... sample_weight=[1, 0, 0, 1]) 

1111 >>> m.result().numpy() 

1112 1.0 

1113 

1114 Usage with `compile()` API: 

1115 

1116 ```python 

1117 model.compile( 

1118 optimizer='sgd', 

1119 loss='mse', 

1120 metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)]) 

1121 ``` 

1122 """ 

1123 

1124 @dtensor_utils.inject_mesh 

1125 def __init__( 

1126 self, 

1127 precision, 

1128 num_thresholds=200, 

1129 class_id=None, 

1130 name=None, 

1131 dtype=None, 

1132 ): 

1133 if precision < 0 or precision > 1: 

1134 raise ValueError( 

1135 "Argument `precision` must be in the range [0, 1]. " 

1136 f"Received: precision={precision}" 

1137 ) 

1138 self.precision = precision 

1139 self.num_thresholds = num_thresholds 

1140 super().__init__( 

1141 value=precision, 

1142 num_thresholds=num_thresholds, 

1143 class_id=class_id, 

1144 name=name, 

1145 dtype=dtype, 

1146 ) 

1147 

1148 def result(self): 

1149 precisions = tf.math.divide_no_nan( 

1150 self.true_positives, 

1151 tf.math.add(self.true_positives, self.false_positives), 

1152 ) 

1153 recalls = tf.math.divide_no_nan( 

1154 self.true_positives, 

1155 tf.math.add(self.true_positives, self.false_negatives), 

1156 ) 

1157 return self._find_max_under_constraint( 

1158 precisions, recalls, tf.greater_equal 

1159 ) 

1160 

1161 def get_config(self): 

1162 config = { 

1163 "num_thresholds": self.num_thresholds, 

1164 "precision": self.precision, 

1165 } 

1166 base_config = super().get_config() 

1167 return dict(list(base_config.items()) + list(config.items())) 

1168 

1169 
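# Illustrative sketch, not part of the library source: as noted in the
# RecallAtPrecision docstring above, when no threshold reaches the requested
# precision the metric falls back to 0.0.
import tensorflow as tf

m = tf.keras.metrics.RecallAtPrecision(0.99)
m.update_state([0, 0, 1, 1], [0.6, 0.9, 0.3, 0.7])  # best precision here is 0.5
print(m.result().numpy())  # 0.0 -- the 0.99 precision target is unattainable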

1170@keras_export("keras.metrics.AUC") 

1171class AUC(base_metric.Metric): 

1172 """Approximates the AUC (Area under the curve) of the ROC or PR curves. 

1173 

1174 The AUC (Area under the curve) of the ROC (Receiver operating 

1175 characteristic; default) or PR (Precision Recall) curves are quality 

1176 measures of binary classifiers. Unlike accuracy, and like cross-entropy

1177 losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. 

1178 

1179 This class approximates AUCs using a Riemann sum. During the metric 

1180 accumulation phase, predictions are accumulated within predefined buckets

1181 by value. The AUC is then computed by interpolating per-bucket averages. 

1182 These buckets define the evaluated operational points. 

1183 

1184 This metric creates four local variables, `true_positives`, 

1185 `true_negatives`, `false_positives` and `false_negatives` that are used to 

1186 compute the AUC. To discretize the AUC curve, a linearly spaced set of 

1187 thresholds is used to compute pairs of recall and precision values. The area 

1188 under the ROC-curve is therefore computed using the height of the recall 

1189 values by the false positive rate, while the area under the PR-curve is

1190 computed using the height of the precision values by the recall. 

1191 

1192 This value is ultimately returned as `auc`, an idempotent operation that 

1193 computes the area under a discretized curve of precision versus recall 

1194 values (computed using the aforementioned variables). The `num_thresholds` 

1195 variable controls the degree of discretization with larger numbers of 

1196 thresholds more closely approximating the true AUC. The quality of the 

1197 approximation may vary dramatically depending on `num_thresholds`. The 

1198 `thresholds` parameter can be used to manually specify thresholds which 

1199 split the predictions more evenly. 

1200 

1201 For the best approximation of the real AUC, `predictions` should be

1202 distributed approximately uniformly in the range [0, 1] (if 

1203 `from_logits=False`). The quality of the AUC approximation may be poor if 

1204 this is not the case. Setting `summation_method` to 'minoring' or 'majoring' 

1205 can help quantify the error in the approximation by providing a lower or upper

1206 bound estimate of the AUC. 

1207 

1208 If `sample_weight` is `None`, weights default to 1. 

1209 Use `sample_weight` of 0 to mask values. 

1210 

1211 Args: 

1212 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1213 use when discretizing the roc curve. Values must be > 1. 

1214 curve: (Optional) Specifies the name of the curve to be computed, 'ROC' 

1215 [default] or 'PR' for the Precision-Recall-curve. 

1216 summation_method: (Optional) Specifies the [Riemann summation method]( 

1217 https://en.wikipedia.org/wiki/Riemann_sum) used. 

1218 'interpolation' (default) applies mid-point summation scheme for 

1219 `ROC`. For PR-AUC, interpolates (true/false) positives but not the 

1220 ratio that is precision (see Davis & Goadrich 2006 for details); 

1221 'minoring' applies left summation for increasing intervals and right 

1222 summation for decreasing intervals; 'majoring' does the opposite. 

1223 name: (Optional) string name of the metric instance. 

1224 dtype: (Optional) data type of the metric result. 

1225 thresholds: (Optional) A list of floating point values to use as the 

1226 thresholds for discretizing the curve. If set, the `num_thresholds` 

1227 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds 

1228 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will 

1229 be automatically included with these to correctly handle predictions 

1230 equal to exactly 0 or 1. 

1231 multi_label: boolean indicating whether multilabel data should be 

1232 treated as such, wherein AUC is computed separately for each label and 

1233 then averaged across labels, or (when False) if the data should be 

1234 flattened into a single label before AUC computation. In the latter 

1235 case, when multilabel data is passed to AUC, each label-prediction pair 

1236 is treated as an individual data point. Should be set to False for 

1237 multi-class data. 

1238 num_labels: (Optional) The number of labels, used when `multi_label` is 

1239 True. If `num_labels` is not specified, then state variables get created 

1240 on the first call to `update_state`. 

1241 label_weights: (Optional) list, array, or tensor of non-negative weights 

1242 used to compute AUCs for multilabel data. When `multi_label` is True, 

1243 the weights are applied to the individual label AUCs when they are 

1244 averaged to produce the multi-label AUC. When it's False, they are used 

1245 to weight the individual label predictions in computing the confusion 

1246 matrix on the flattened data. Note that this is unlike class_weights in 

1247 that class_weights weights the example depending on the value of its 

1248 label, whereas label_weights depends only on the index of that label 

1249 before flattening; therefore `label_weights` should not be used for 

1250 multi-class data. 

1251 from_logits: boolean indicating whether the predictions (`y_pred` in 

1252 `update_state`) are probabilities or sigmoid logits. As a rule of thumb, 

1253 when using a keras loss, the `from_logits` constructor argument of the 

1254 loss should match the AUC `from_logits` constructor argument. 

1255 

1256 Standalone usage: 

1257 

1258 >>> m = tf.keras.metrics.AUC(num_thresholds=3) 

1259 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 

1260 >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] 

1261 >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] 

1262 >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0] 

1263 >>> # auc = ((((1+0.5)/2)*(1-0)) + (((0.5+0)/2)*(0-0))) = 0.75 

1264 >>> m.result().numpy() 

1265 0.75 

1266 

1267 >>> m.reset_state() 

1268 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 

1269 ... sample_weight=[1, 0, 0, 1]) 

1270 >>> m.result().numpy() 

1271 1.0 

1272 

1273 Usage with `compile()` API: 

1274 

1275 ```python 

1276 # Reports the AUC of a model outputting a probability. 

1277 model.compile(optimizer='sgd', 

1278 loss=tf.keras.losses.BinaryCrossentropy(), 

1279 metrics=[tf.keras.metrics.AUC()]) 

1280 

1281 # Reports the AUC of a model outputting a logit. 

1282 model.compile(optimizer='sgd', 

1283 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

1284 metrics=[tf.keras.metrics.AUC(from_logits=True)]) 

1285 ``` 

1286 """ 

1287 

1288 @dtensor_utils.inject_mesh 

1289 def __init__( 

1290 self, 

1291 num_thresholds=200, 

1292 curve="ROC", 

1293 summation_method="interpolation", 

1294 name=None, 

1295 dtype=None, 

1296 thresholds=None, 

1297 multi_label=False, 

1298 num_labels=None, 

1299 label_weights=None, 

1300 from_logits=False, 

1301 ): 

1302 # Validate configurations. 

1303 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( 

1304 metrics_utils.AUCCurve 

1305 ): 

1306 raise ValueError( 

1307 f'Invalid `curve` argument value "{curve}". ' 

1308 f"Expected one of: {list(metrics_utils.AUCCurve)}" 

1309 ) 

1310 if isinstance( 

1311 summation_method, metrics_utils.AUCSummationMethod 

1312 ) and summation_method not in list(metrics_utils.AUCSummationMethod): 

1313 raise ValueError( 

1314 "Invalid `summation_method` " 

1315 f'argument value "{summation_method}". ' 

1316 f"Expected one of: {list(metrics_utils.AUCSummationMethod)}" 

1317 ) 

1318 

1319 # Update properties. 

1320 self._init_from_thresholds = thresholds is not None 

1321 if thresholds is not None: 

1322 # If specified, use the supplied thresholds. 

1323 self.num_thresholds = len(thresholds) + 2 

1324 thresholds = sorted(thresholds) 

1325 self._thresholds_distributed_evenly = ( 

1326 metrics_utils.is_evenly_distributed_thresholds( 

1327 np.array([0.0] + thresholds + [1.0]) 

1328 ) 

1329 ) 

1330 else: 

1331 if num_thresholds <= 1: 

1332 raise ValueError( 

1333 "Argument `num_thresholds` must be an integer > 1. " 

1334 f"Received: num_thresholds={num_thresholds}" 

1335 ) 

1336 

1337 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in 

1338 # (0, 1). 

1339 self.num_thresholds = num_thresholds 

1340 thresholds = [ 

1341 (i + 1) * 1.0 / (num_thresholds - 1) 

1342 for i in range(num_thresholds - 2) 

1343 ] 

1344 self._thresholds_distributed_evenly = True 

1345 

1346 # Add an endpoint "threshold" below zero and above one for either 

1347 # threshold method to account for floating point imprecisions. 

1348 self._thresholds = np.array( 

1349 [0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()] 

1350 ) 

1351 

1352 if isinstance(curve, metrics_utils.AUCCurve): 

1353 self.curve = curve 

1354 else: 

1355 self.curve = metrics_utils.AUCCurve.from_str(curve) 

1356 if isinstance(summation_method, metrics_utils.AUCSummationMethod): 

1357 self.summation_method = summation_method 

1358 else: 

1359 self.summation_method = metrics_utils.AUCSummationMethod.from_str( 

1360 summation_method 

1361 ) 

1362 super().__init__(name=name, dtype=dtype) 

1363 

1364 # Handle multilabel arguments. 

1365 self.multi_label = multi_label 

1366 self.num_labels = num_labels 

1367 if label_weights is not None: 

1368 label_weights = tf.constant(label_weights, dtype=self.dtype) 

1369 tf.debugging.assert_non_negative( 

1370 label_weights, 

1371 message="All values of `label_weights` must be non-negative.", 

1372 ) 

1373 self.label_weights = label_weights 

1374 

1375 else: 

1376 self.label_weights = None 

1377 

1378 self._from_logits = from_logits 

1379 

1380 self._built = False 

1381 if self.multi_label: 

1382 if num_labels: 

1383 shape = tf.TensorShape([None, num_labels]) 

1384 self._build(shape) 

1385 else: 

1386 if num_labels: 

1387 raise ValueError( 

1388 "`num_labels` is needed only when `multi_label` is True." 

1389 ) 

1390 self._build(None) 

1391 
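# Illustrative sketch, not part of the library source: when explicit
# `thresholds` are passed, two endpoint thresholds just below 0 and just above
# 1 are appended, so `num_thresholds == len(thresholds) + 2`.
import tensorflow as tf

auc = tf.keras.metrics.AUC(thresholds=[0.25, 0.5, 0.75])
print(auc.num_thresholds)  # 5
print(auc.thresholds)      # approximately [-1e-07, 0.25, 0.5, 0.75, 1.0000001]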

1392 @property 

1393 def thresholds(self): 

1394 """The thresholds used for evaluating AUC.""" 

1395 return list(self._thresholds) 

1396 

1397 def _build(self, shape): 

1398 """Initialize TP, FP, TN, and FN tensors, given the shape of the 

1399 data.""" 

1400 if self.multi_label: 

1401 if shape.ndims != 2: 

1402 raise ValueError( 

1403 "`y_true` must have rank 2 when `multi_label=True`. " 

1404 f"Found rank {shape.ndims}. " 

1405 f"Full shape received for `y_true`: {shape}" 

1406 ) 

1407 self._num_labels = shape[1] 

1408 variable_shape = tf.TensorShape( 

1409 [self.num_thresholds, self._num_labels] 

1410 ) 

1411 else: 

1412 variable_shape = tf.TensorShape([self.num_thresholds]) 

1413 

1414 self._build_input_shape = shape 

1415 # Create metric variables 

1416 self.true_positives = self.add_weight( 

1417 "true_positives", shape=variable_shape, initializer="zeros" 

1418 ) 

1419 self.true_negatives = self.add_weight( 

1420 "true_negatives", shape=variable_shape, initializer="zeros" 

1421 ) 

1422 self.false_positives = self.add_weight( 

1423 "false_positives", shape=variable_shape, initializer="zeros" 

1424 ) 

1425 self.false_negatives = self.add_weight( 

1426 "false_negatives", shape=variable_shape, initializer="zeros" 

1427 ) 

1428 

1429 if self.multi_label: 

1430 with tf.init_scope(): 

1431 # This should only be necessary for handling v1 behavior. In v2, 

1432 # AUC should be initialized outside of any tf.functions, and 

1433 # therefore in eager mode. 

1434 if not tf.executing_eagerly(): 

1435 backend._initialize_variables(backend._get_session()) 

1436 

1437 self._built = True 

1438 

1439 def update_state(self, y_true, y_pred, sample_weight=None): 

1440 """Accumulates confusion matrix statistics. 

1441 

1442 Args: 

1443 y_true: The ground truth values. 

1444 y_pred: The predicted values. 

1445 sample_weight: Optional weighting of each example. Defaults to 1. Can 

1446 be a `Tensor` whose rank is either 0, or the same rank as `y_true`, 

1447 and must be broadcastable to `y_true`. 

1448 

1449 Returns: 

1450 Update op. 

1451 """ 

1452 if not self._built: 

1453 self._build(tf.TensorShape(y_pred.shape)) 

1454 

1455 if self.multi_label or (self.label_weights is not None): 

1456 # y_true should have shape (number of examples, number of labels). 

1457 shapes = [(y_true, ("N", "L"))] 

1458 if self.multi_label: 

1459 # TP, TN, FP, and FN should all have shape 

1460 # (number of thresholds, number of labels). 

1461 shapes.extend( 

1462 [ 

1463 (self.true_positives, ("T", "L")), 

1464 (self.true_negatives, ("T", "L")), 

1465 (self.false_positives, ("T", "L")), 

1466 (self.false_negatives, ("T", "L")), 

1467 ] 

1468 ) 

1469 if self.label_weights is not None: 

1470 # label_weights should be of length equal to the number of 

1471 # labels. 

1472 shapes.append((self.label_weights, ("L",))) 

1473 tf.debugging.assert_shapes( 

1474 shapes, message="Number of labels is not consistent." 

1475 ) 

1476 

1477 # Only forward label_weights to update_confusion_matrix_variables when 

1478 # multi_label is False. Otherwise the averaging of individual label AUCs 

1479 # is handled in AUC.result 

1480 label_weights = None if self.multi_label else self.label_weights 

1481 

1482 if self._from_logits: 

1483 y_pred = activations.sigmoid(y_pred) 

1484 

1485 return metrics_utils.update_confusion_matrix_variables( 

1486 { 

1487 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 

1488 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501 

1489 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501 

1490 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 

1491 }, 

1492 y_true, 

1493 y_pred, 

1494 self._thresholds, 

1495 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

1496 sample_weight=sample_weight, 

1497 multi_label=self.multi_label, 

1498 label_weights=label_weights, 

1499 ) 

1500 
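# --- Illustrative sketch (not part of the covered module) --------------------
# update_state() above applies a sigmoid when the metric was built with
# from_logits=True, so feeding raw logits there should match feeding the
# corresponding probabilities to a default AUC. The numbers are made up.
import tensorflow as tf

y_true = tf.constant([0.0, 0.0, 1.0, 1.0])
logits = tf.constant([-2.0, -0.5, 0.3, 1.7])

auc_logits = tf.keras.metrics.AUC(from_logits=True)
auc_logits.update_state(y_true, logits)

auc_probs = tf.keras.metrics.AUC()
auc_probs.update_state(y_true, tf.sigmoid(logits))

print(float(auc_logits.result()), float(auc_probs.result()))  # identical values
# ------------------------------------------------------------------------------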

1501 def interpolate_pr_auc(self): 

1502 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 

1503 

1504 https://www.biostat.wisc.edu/~page/rocpr.pdf 

1505 

1506 Note here we derive & use a closed formula not present in the paper 

1507 as follows: 

1508 

1509 Precision = TP / (TP + FP) = TP / P 

1510 

1511 Modeling all of TP (true positive), FP (false positive) and their sum 

1512 P = TP + FP (predicted positive) as varying linearly within each 

1513 interval [A, B] between successive thresholds, we get 

1514 

1515 Precision slope = dTP / dP 

1516 = (TP_B - TP_A) / (P_B - P_A) 

1517 = (TP - TP_A) / (P - P_A) 

1518 Precision = (TP_A + slope * (P - P_A)) / P 

1519 

1520 The area within the interval is (slope / total_pos_weight) times 

1521 

1522 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 

1523 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 

1524 

1525 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 

1526 

1527 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 

1528 

1529 Bringing back the factor (slope / total_pos_weight) we'd put aside, we 

1530 get 

1531 

1532 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 

1533 

1534 where dTP == TP_B - TP_A. 

1535 

1536 Note that when P_A == 0 the above calculation simplifies into 

1537 

1538 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 

1539 

1540 which is equivalent to imputing a constant precision throughout the 

1541 first bucket that has >0 true positives. 

1542 

1543 Returns: 

1544 pr_auc: an approximation of the area under the P-R curve. 

1545 """ 

1546 dtp = ( 

1547 self.true_positives[: self.num_thresholds - 1] 

1548 - self.true_positives[1:] 

1549 ) 

1550 p = tf.math.add(self.true_positives, self.false_positives) 

1551 dp = p[: self.num_thresholds - 1] - p[1:] 

1552 prec_slope = tf.math.divide_no_nan( 

1553 dtp, tf.maximum(dp, 0), name="prec_slope" 

1554 ) 

1555 intercept = self.true_positives[1:] - tf.multiply(prec_slope, p[1:]) 

1556 

1557 safe_p_ratio = tf.where( 

1558 tf.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0), 

1559 tf.math.divide_no_nan( 

1560 p[: self.num_thresholds - 1], 

1561 tf.maximum(p[1:], 0), 

1562 name="recall_relative_ratio", 

1563 ), 

1564 tf.ones_like(p[1:]), 

1565 ) 

1566 

1567 pr_auc_increment = tf.math.divide_no_nan( 

1568 prec_slope * (dtp + intercept * tf.math.log(safe_p_ratio)), 

1569 tf.maximum(self.true_positives[1:] + self.false_negatives[1:], 0), 

1570 name="pr_auc_increment", 

1571 ) 

1572 

1573 if self.multi_label: 

1574 by_label_auc = tf.reduce_sum( 

1575 pr_auc_increment, name=self.name + "_by_label", axis=0 

1576 ) 

1577 if self.label_weights is None: 

1578 # Evenly weighted average of the label AUCs. 

1579 return tf.reduce_mean(by_label_auc, name=self.name) 

1580 else: 

1581 # Weighted average of the label AUCs. 

1582 return tf.math.divide_no_nan( 

1583 tf.reduce_sum( 

1584 tf.multiply(by_label_auc, self.label_weights) 

1585 ), 

1586 tf.reduce_sum(self.label_weights), 

1587 name=self.name, 

1588 ) 

1589 else: 

1590 return tf.reduce_sum(pr_auc_increment, name="interpolate_pr_auc") 

1591 
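# --- Illustrative sketch (not part of the covered module) --------------------
# A self-contained numpy walk-through of the closed-form increment derived in
# the interpolate_pr_auc() docstring above. The tp/fp/fn counts at three
# thresholds are invented purely to make the algebra concrete.
import numpy as np

tp = np.array([10.0, 6.0, 2.0])    # true positives per (ascending) threshold
fp = np.array([8.0, 3.0, 1.0])     # false positives per threshold
fn = np.array([0.0, 4.0, 8.0])     # false negatives per threshold

p = tp + fp                        # predicted positives, P = TP + FP
dtp = tp[:-1] - tp[1:]
dp = p[:-1] - p[1:]

slope = np.divide(dtp, np.maximum(dp, 0), out=np.zeros_like(dtp), where=dp > 0)
intercept = tp[1:] - slope * p[1:]
safe_ratio = np.where((p[:-1] > 0) & (p[1:] > 0), p[:-1] / p[1:], 1.0)

denom = np.maximum(tp[1:] + fn[1:], 0)   # total positive weight per interval
increment = np.divide(slope * (dtp + intercept * np.log(safe_ratio)),
                      denom, out=np.zeros_like(denom), where=denom > 0)
print("interpolated PR AUC:", increment.sum())
# ------------------------------------------------------------------------------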

1592 def result(self): 

1593 if ( 

1594 self.curve == metrics_utils.AUCCurve.PR 

1595 and self.summation_method 

1596 == metrics_utils.AUCSummationMethod.INTERPOLATION 

1597 ): 

1598 # This use case is different and is handled separately. 

1599 return self.interpolate_pr_auc() 

1600 

1601 # Set `x` and `y` values for the curves based on `curve` config. 

1602 recall = tf.math.divide_no_nan( 

1603 self.true_positives, 

1604 tf.math.add(self.true_positives, self.false_negatives), 

1605 ) 

1606 if self.curve == metrics_utils.AUCCurve.ROC: 

1607 fp_rate = tf.math.divide_no_nan( 

1608 self.false_positives, 

1609 tf.math.add(self.false_positives, self.true_negatives), 

1610 ) 

1611 x = fp_rate 

1612 y = recall 

1613 else: # curve == 'PR'. 

1614 precision = tf.math.divide_no_nan( 

1615 self.true_positives, 

1616 tf.math.add(self.true_positives, self.false_positives), 

1617 ) 

1618 x = recall 

1619 y = precision 

1620 

1621 # Find the rectangle heights based on `summation_method`. 

1622 if ( 

1623 self.summation_method 

1624 == metrics_utils.AUCSummationMethod.INTERPOLATION 

1625 ): 

1626 # Note: the case ('PR', 'interpolation') has been handled above. 

1627 heights = (y[: self.num_thresholds - 1] + y[1:]) / 2.0 

1628 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: 

1629 heights = tf.minimum(y[: self.num_thresholds - 1], y[1:]) 

1630 # else: self.summation_method == metrics_utils.AUCSummationMethod.MAJORING 

1631 else: 

1632 heights = tf.maximum(y[: self.num_thresholds - 1], y[1:]) 

1633 

1634 # Sum up the areas of all the rectangles. 

1635 if self.multi_label: 

1636 riemann_terms = tf.multiply( 

1637 x[: self.num_thresholds - 1] - x[1:], heights 

1638 ) 

1639 by_label_auc = tf.reduce_sum( 

1640 riemann_terms, name=self.name + "_by_label", axis=0 

1641 ) 

1642 

1643 if self.label_weights is None: 

1644 # Unweighted average of the label AUCs. 

1645 return tf.reduce_mean(by_label_auc, name=self.name) 

1646 else: 

1647 # Weighted average of the label AUCs. 

1648 return tf.math.divide_no_nan( 

1649 tf.reduce_sum( 

1650 tf.multiply(by_label_auc, self.label_weights) 

1651 ), 

1652 tf.reduce_sum(self.label_weights), 

1653 name=self.name, 

1654 ) 

1655 else: 

1656 return tf.reduce_sum( 

1657 tf.multiply(x[: self.num_thresholds - 1] - x[1:], heights), 

1658 name=self.name, 

1659 ) 

1660 
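# --- Illustrative sketch (not part of the covered module) --------------------
# result() above integrates the curve with rectangle heights chosen by
# summation_method, so on the same (made-up) data the three variants should be
# ordered: minoring <= interpolation <= majoring.
import tensorflow as tf

y_true = [0, 0, 1, 0, 1, 1, 1, 0]
y_pred = [0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9, 0.5]

for method in ("minoring", "interpolation", "majoring"):
    auc = tf.keras.metrics.AUC(num_thresholds=11, summation_method=method)
    auc.update_state(y_true, y_pred)
    print(method, float(auc.result()))
# ------------------------------------------------------------------------------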

1661 def reset_state(self): 

1662 if self._built: 

1663 confusion_matrix_variables = ( 

1664 self.true_positives, 

1665 self.true_negatives, 

1666 self.false_positives, 

1667 self.false_negatives, 

1668 ) 

1669 if self.multi_label: 

1670 backend.batch_set_value( 

1671 [ 

1672 (v, np.zeros((self.num_thresholds, self._num_labels))) 

1673 for v in confusion_matrix_variables 

1674 ] 

1675 ) 

1676 else: 

1677 backend.batch_set_value( 

1678 [ 

1679 (v, np.zeros((self.num_thresholds,))) 

1680 for v in confusion_matrix_variables 

1681 ] 

1682 ) 

1683 
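# --- Illustrative sketch (not part of the covered module) --------------------
# reset_state() above zeroes all four confusion-matrix accumulators, so the
# same metric object can be reused, e.g. between epochs.
import tensorflow as tf

auc = tf.keras.metrics.AUC()
auc.update_state([0, 1, 1], [0.1, 0.9, 0.4])
print(float(auc.result()))   # non-zero after accumulating a batch
auc.reset_state()
print(float(auc.result()))   # 0.0 again: counters were reset
# ------------------------------------------------------------------------------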

1684 def get_config(self): 

1685 if is_tensor_or_variable(self.label_weights): 

1686 label_weights = backend.eval(self.label_weights) 

1687 else: 

1688 label_weights = self.label_weights 

1689 config = { 

1690 "num_thresholds": self.num_thresholds, 

1691 "curve": self.curve.value, 

1692 "summation_method": self.summation_method.value, 

1693 "multi_label": self.multi_label, 

1694 "num_labels": self.num_labels, 

1695 "label_weights": label_weights, 

1696 "from_logits": self._from_logits, 

1697 } 

1698 # Optimization to avoid serializing a large number of generated 

1699 # thresholds. 

1700 if self._init_from_thresholds: 

1701 # We remove the endpoint thresholds as an inverse of how the 

1702 # thresholds were initialized. This ensures that a metric 

1703 # initialized from this config has the same thresholds. 

1704 config["thresholds"] = self.thresholds[1:-1] 

1705 base_config = super().get_config() 

1706 return dict(list(base_config.items()) + list(config.items())) 

1707
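# --- Illustrative sketch (not part of the covered module) --------------------
# get_config() above strips the two endpoint thresholds that were added at
# construction time, so rebuilding the metric from its config is expected to
# reproduce the original threshold list.
import tensorflow as tf

auc = tf.keras.metrics.AUC(thresholds=[0.25, 0.5, 0.75])
config = auc.get_config()
print(config["thresholds"])                       # endpoints removed: [0.25, 0.5, 0.75]

rebuilt = tf.keras.metrics.AUC.from_config(config)
print(rebuilt.thresholds == auc.thresholds)       # True: endpoints re-added on rebuild
# ------------------------------------------------------------------------------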