Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/metrics.py: 35%

836 statements  


1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=g-classes-have-attributes 

16# pylint: disable=g-doc-return-or-yield 

17"""Built-in metrics.""" 

18 

19import abc 

20import types 

21import warnings 

22 

23import numpy as np 

24 

25from tensorflow.python.autograph.core import ag_ctx 

26from tensorflow.python.autograph.impl import api as autograph 

27from tensorflow.python.distribute import distribute_lib 

28from tensorflow.python.eager import context 

29from tensorflow.python.eager import def_function 

30from tensorflow.python.framework import constant_op 

31from tensorflow.python.framework import dtypes 

32from tensorflow.python.framework import ops 

33from tensorflow.python.framework import tensor_conversion 

34from tensorflow.python.framework import tensor_shape 

35from tensorflow.python.keras import activations 

36from tensorflow.python.keras import backend 

37from tensorflow.python.keras.engine import base_layer 

38from tensorflow.python.keras.engine import base_layer_utils 

39from tensorflow.python.keras.engine import keras_tensor 

40from tensorflow.python.keras.losses import binary_crossentropy 

41from tensorflow.python.keras.losses import categorical_crossentropy 

42from tensorflow.python.keras.losses import categorical_hinge 

43from tensorflow.python.keras.losses import hinge 

44from tensorflow.python.keras.losses import kullback_leibler_divergence 

45from tensorflow.python.keras.losses import logcosh 

46from tensorflow.python.keras.losses import mean_absolute_error 

47from tensorflow.python.keras.losses import mean_absolute_percentage_error 

48from tensorflow.python.keras.losses import mean_squared_error 

49from tensorflow.python.keras.losses import mean_squared_logarithmic_error 

50from tensorflow.python.keras.losses import poisson 

51from tensorflow.python.keras.losses import sparse_categorical_crossentropy 

52from tensorflow.python.keras.losses import squared_hinge 

53from tensorflow.python.keras.saving.saved_model import metric_serialization 

54from tensorflow.python.keras.utils import generic_utils 

55from tensorflow.python.keras.utils import losses_utils 

56from tensorflow.python.keras.utils import metrics_utils 

57from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 

58from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 

59from tensorflow.python.keras.utils.generic_utils import to_list 

60from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable 

61from tensorflow.python.ops import array_ops 

62from tensorflow.python.ops import check_ops 

63from tensorflow.python.ops import confusion_matrix 

64from tensorflow.python.ops import init_ops 

65from tensorflow.python.ops import math_ops 

66from tensorflow.python.ops import nn 

67from tensorflow.python.ops import variables as variables_module 

68from tensorflow.python.ops import weights_broadcast_ops 

69from tensorflow.python.util import dispatch 

70from tensorflow.python.util import nest 

71from tensorflow.python.util.tf_export import keras_export 

72from tensorflow.tools.docs import doc_controls 

73 

74 

75@keras_export('keras.metrics.Metric') 

76class Metric(base_layer.Layer, metaclass=abc.ABCMeta): 

77 """Encapsulates metric logic and state. 

78 

79 Args: 

80 name: (Optional) string name of the metric instance. 

81 dtype: (Optional) data type of the metric result. 

82 **kwargs: Additional layer keyword arguments. 

83 

84 Standalone usage: 

85 

86 ```python 

87 m = SomeMetric(...) 

88 for input in ...: 

89 m.update_state(input) 

90 print('Final result: ', m.result().numpy()) 

91 ``` 

92 

93 Usage with `compile()` API: 

94 

95 ```python 

96 model = tf.keras.Sequential() 

97 model.add(tf.keras.layers.Dense(64, activation='relu')) 

98 model.add(tf.keras.layers.Dense(64, activation='relu')) 

99 model.add(tf.keras.layers.Dense(10, activation='softmax')) 

100 

101 model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01), 

102 loss=tf.keras.losses.CategoricalCrossentropy(), 

103 metrics=[tf.keras.metrics.CategoricalAccuracy()]) 

104 

105 data = np.random.random((1000, 32)) 

106 labels = np.random.random((1000, 10)) 

107 

108 dataset = tf.data.Dataset.from_tensor_slices((data, labels)) 

109 dataset = dataset.batch(32) 

110 

111 model.fit(dataset, epochs=10) 

112 ``` 

113 

114 To be implemented by subclasses: 

115 * `__init__()`: All state variables should be created in this method by 

116 calling `self.add_weight()` like: `self.var = self.add_weight(...)` 

117 * `update_state()`: Performs all updates to the state variables, like: 

118 `self.var.assign_add(...)`. 

119 * `result()`: Computes and returns a value for the metric 

120 from the state variables. 

121 

122 Example subclass implementation: 

123 

124 ```python 

125 class BinaryTruePositives(tf.keras.metrics.Metric): 

126 

127 def __init__(self, name='binary_true_positives', **kwargs): 

128 super(BinaryTruePositives, self).__init__(name=name, **kwargs) 

129 self.true_positives = self.add_weight(name='tp', initializer='zeros') 

130 

131 def update_state(self, y_true, y_pred, sample_weight=None): 

132 y_true = tf.cast(y_true, tf.bool) 

133 y_pred = tf.cast(y_pred, tf.bool) 

134 

135 values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True)) 

136 values = tf.cast(values, self.dtype) 

137 if sample_weight is not None: 

138 sample_weight = tf.cast(sample_weight, self.dtype) 

139 sample_weight = tf.broadcast_to(sample_weight, values.shape) 

140 values = tf.multiply(values, sample_weight) 

141 self.true_positives.assign_add(tf.reduce_sum(values)) 

142 

143 def result(self): 

144 return self.true_positives 

145 ``` 

146 """ 

147 

148 def __init__(self, name=None, dtype=None, **kwargs): 

149 super(Metric, self).__init__(name=name, dtype=dtype, **kwargs) 

150 self.stateful = True # All metric layers are stateful. 

151 self.built = True 

152 if not base_layer_utils.v2_dtype_behavior_enabled(): 

153 # We only do this when the V2 behavior is not enabled, as when it is 

154 # enabled, the dtype already defaults to floatx. 

155 self._dtype = (backend.floatx() if dtype is None 

156 else dtypes.as_dtype(dtype).name) 

157 

158 def __new__(cls, *args, **kwargs): 

159 obj = super(Metric, cls).__new__(cls) 

160 

161 # If `update_state` is not in eager/tf.function and it is not from a 

162 # built-in metric, wrap it in `tf.function`. This is so that users writing 

163 # custom metrics in v1 need not worry about control dependencies and 

164 # return ops. 

165 if (base_layer_utils.is_in_eager_or_tf_function() or 

166 is_built_in(cls)): 

167 obj_update_state = obj.update_state 

168 

169 def update_state_fn(*args, **kwargs): 

170 control_status = ag_ctx.control_status_ctx() 

171 ag_update_state = autograph.tf_convert(obj_update_state, control_status) 

172 return ag_update_state(*args, **kwargs) 

173 else: 

174 if isinstance(obj.update_state, def_function.Function): 

175 update_state_fn = obj.update_state 

176 else: 

177 update_state_fn = def_function.function(obj.update_state) 

178 

179 obj.update_state = types.MethodType( 

180 metrics_utils.update_state_wrapper(update_state_fn), obj) 

181 

182 obj_result = obj.result 

183 

184 def result_fn(*args, **kwargs): 

185 control_status = ag_ctx.control_status_ctx() 

186 ag_result = autograph.tf_convert(obj_result, control_status) 

187 return ag_result(*args, **kwargs) 

188 

189 obj.result = types.MethodType(metrics_utils.result_wrapper(result_fn), obj) 

190 

191 return obj 

192 

193 def __call__(self, *args, **kwargs): 

194 """Accumulates statistics and then computes metric result value. 

195 

196 Args: 

197 *args: A mini-batch of inputs to the Metric, 

198 passed on to `update_state()`. 

199 **kwargs: Additional keyword arguments, likewise passed on to `update_state()`. 

200 

201 Returns: 

202 The metric value tensor. 

203 """ 

204 

205 def replica_local_fn(*args, **kwargs): 

206 """Updates the state of the metric in a replica-local context.""" 

207 if any( 

208 isinstance(arg, keras_tensor.KerasTensor) 

209 for arg in nest.flatten((args, kwargs))): 

210 update_op = None 

211 else: 

212 update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable 

213 update_ops = [] 

214 if update_op is not None: 

215 update_ops.append(update_op) 

216 with ops.control_dependencies(update_ops): 

217 result_t = self.result() # pylint: disable=not-callable 

218 

219 # We are adding the metric object as metadata on the result tensor. 

220 # This is required when we want to use a metric with `add_metric` API on 

221 # a Model/Layer in graph mode. This metric instance will later be used 

222 # to reset variable state after each epoch of training. 

223 # Example: 

224 # model = Model() 

225 # mean = Mean() 

226 # model.add_metric(mean(values), name='mean') 

227 result_t._metric_obj = self # pylint: disable=protected-access 

228 return result_t 

229 

230 from tensorflow.python.keras.distribute import distributed_training_utils # pylint:disable=g-import-not-at-top 

231 return distributed_training_utils.call_replica_local_fn( 

232 replica_local_fn, *args, **kwargs) 

233 

234 @property 

235 def dtype(self): 

236 return self._dtype 

237 

238 def get_config(self): 

239 """Returns the serializable config of the metric.""" 

240 return {'name': self.name, 'dtype': self.dtype} 

241 

242 def reset_state(self): 

243 """Resets all of the metric state variables. 

244 

245 This function is called between epochs/steps, 

246 when a metric is evaluated during training. 

247 """ 

248 if not generic_utils.is_default(self.reset_states): 

249 warnings.warn('Metric %s implements a `reset_states()` method; rename it ' 

250 'to `reset_state()` (without the final "s"). The name ' 

251 '`reset_states()` has been deprecated to improve API ' 

252 'consistency.' % (self.__class__.__name__,)) 

253 return self.reset_states() 

254 else: 

255 backend.batch_set_value([(v, 0) for v in self.variables]) 

256 

257 @abc.abstractmethod 

258 def update_state(self, *args, **kwargs): 

259 """Accumulates statistics for the metric. 

260 

261 Note: This function is executed as a graph function in graph mode. 

262 This means: 

263 a) Operations on the same resource are executed in textual order. 

264 This should make it easier to do things like add the updated 

265 value of a variable to another, for example. 

266 b) You don't need to worry about collecting the update ops to execute. 

267 All update ops added to the graph by this function will be executed. 

268 As a result, code should generally work the same way with graph or 

269 eager execution. 

270 

271 Args: 

272 *args: A mini-batch of inputs to the Metric. 

273 **kwargs: Additional keyword arguments for the mini-batch of inputs. 

274 """ 

275 raise NotImplementedError('Must be implemented in subclasses.') 

276 

277 @abc.abstractmethod 

278 def result(self): 

279 """Computes and returns the metric value tensor. 

280 

281 Result computation is an idempotent operation that simply calculates the 

282 metric value using the state variables. 

283 """ 

284 raise NotImplementedError('Must be implemented in subclasses.') 

285 

286 ### For use by subclasses ### 

287 @doc_controls.for_subclass_implementers 

288 def add_weight( 

289 self, 

290 name, 

291 shape=(), 

292 aggregation=variables_module.VariableAggregation.SUM, 

293 synchronization=variables_module.VariableSynchronization.ON_READ, 

294 initializer=None, 

295 dtype=None): 

296 """Adds state variable. Only for use by subclasses.""" 

297 if distribute_lib.has_strategy(): 

298 strategy = distribute_lib.get_strategy() 

299 else: 

300 strategy = None 

301 

302 # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU. 

303 if backend.is_tpu_strategy(strategy): 

304 synchronization = variables_module.VariableSynchronization.ON_WRITE 

305 

306 with ops.init_scope(): 

307 return super(Metric, self).add_weight( 

308 name=name, 

309 shape=shape, 

310 dtype=self._dtype if dtype is None else dtype, 

311 trainable=False, 

312 initializer=initializer, 

313 collections=[], 

314 synchronization=synchronization, 

315 aggregation=aggregation) 

316 

317 ### End: For use by subclasses ### 

318 

319 @property 

320 def trainable_weights(self): 

321 # Overridden from Layer class to track submetric weights. 

322 if self.trainable: 

323 trainable_weights = self._trainable_weights 

324 for m in self._metrics: 

325 trainable_weights += m.trainable_weights 

326 return self._dedup_weights(trainable_weights) 

327 else: 

328 return [] 

329 

330 @property 

331 def non_trainable_weights(self): 

332 # Overridden from Layer class to track submetric weights. 

333 if self.trainable: 

334 non_trainable_weights = self._non_trainable_weights 

335 for m in self._metrics: 

336 non_trainable_weights += m.non_trainable_weights 

337 else: 

338 non_trainable_weights = ( 

339 self._non_trainable_weights + self._trainable_weights) 

340 for m in self._metrics: 

341 non_trainable_weights += m.weights 

342 return self._dedup_weights(non_trainable_weights) 

343 

344 @property 

345 def _trackable_saved_model_saver(self): 

346 return metric_serialization.MetricSavedModelSaver(self) 

347 

348 @generic_utils.default 

349 @doc_controls.do_not_generate_docs 

350 def reset_states(self): 

351 # Backwards compatibility alias of `reset_state`. New classes should 

352 # only implement `reset_state`. 

353 return self.reset_state() 

354 

355 

356class Reduce(Metric): 

357 """Encapsulates metrics that perform a reduce operation on the values. 

358 

359 Args: 

360 reduction: a `tf.keras.metrics.Reduction` enum value. 

361 name: string name of the metric instance. 

362 dtype: (Optional) data type of the metric result. 
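
  A minimal sketch of how a subclass selects a reduction, mirroring the `Sum`
  and `Mean` classes defined later in this module (the `RunningTotal` name is
  purely illustrative):

  ```python
  class RunningTotal(Reduce):
    # Keeps a running (optionally weighted) sum of the values it sees.
    def __init__(self, name='running_total', dtype=None):
      super(RunningTotal, self).__init__(
          reduction=metrics_utils.Reduction.SUM, name=name, dtype=dtype)
  ```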

363 """ 

364 

365 def __init__(self, reduction, name, dtype=None): 

366 super(Reduce, self).__init__(name=name, dtype=dtype) 

367 self.reduction = reduction 

368 self.total = self.add_weight( 

369 'total', initializer=init_ops.zeros_initializer) 

370 if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 

371 metrics_utils.Reduction.WEIGHTED_MEAN]: 

372 self.count = self.add_weight( 

373 'count', initializer=init_ops.zeros_initializer) 

374 

375 def update_state(self, values, sample_weight=None): 

376 """Accumulates statistics for computing the metric. 

377 

378 Args: 

379 values: Per-example value. 

380 sample_weight: Optional weighting of each example. Defaults to 1. 

381 

382 Returns: 

383 Update op. 

384 """ 

385 [values], sample_weight = \ 

386 metrics_utils.ragged_assert_compatible_and_get_flat_values( 

387 [values], sample_weight) 

388 try: 

389 values = math_ops.cast(values, self._dtype) 

390 except (ValueError, TypeError): 

391 msg = ('The output of a metric function can only be a single Tensor. ' 

392 'Got: %s' % (values,)) 

393 if isinstance(values, dict): 

394 msg += ('. To return a dict of values, implement a custom Metric ' 

395 'subclass.') 

396 raise RuntimeError(msg) 

397 if sample_weight is not None: 

398 sample_weight = math_ops.cast(sample_weight, self._dtype) 

399 # Update dimensions of weights to match with values if possible. 

400 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( 

401 values, sample_weight=sample_weight) 

402 try: 

403 # Broadcast weights if possible. 

404 sample_weight = weights_broadcast_ops.broadcast_weights( 

405 sample_weight, values) 

406 except ValueError: 

407 # Reduce values to same ndim as weight array 

408 ndim = backend.ndim(values) 

409 weight_ndim = backend.ndim(sample_weight) 

410 if self.reduction == metrics_utils.Reduction.SUM: 

411 values = math_ops.reduce_sum( 

412 values, axis=list(range(weight_ndim, ndim))) 

413 else: 

414 values = math_ops.reduce_mean( 

415 values, axis=list(range(weight_ndim, ndim))) 
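        # When the weights have fewer dimensions than the values, the extra
        # trailing axes of `values` have been summed or averaged away above so
        # that the multiplication below broadcasts cleanly.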

416 values = math_ops.multiply(values, sample_weight) 

417 

418 value_sum = math_ops.reduce_sum(values) 

419 with ops.control_dependencies([value_sum]): 

420 update_total_op = self.total.assign_add(value_sum) 

421 

422 # Exit early if the reduction doesn't have a denominator. 

423 if self.reduction == metrics_utils.Reduction.SUM: 

424 return update_total_op 

425 

426 # Update `count` for reductions that require a denominator. 

427 if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE: 

428 num_values = math_ops.cast(array_ops.size(values), self._dtype) 

429 elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN: 

430 if sample_weight is None: 

431 num_values = math_ops.cast(array_ops.size(values), self._dtype) 

432 else: 

433 num_values = math_ops.reduce_sum(sample_weight) 

434 else: 

435 raise NotImplementedError( 

436 'reduction [%s] not implemented' % self.reduction) 

437 

438 with ops.control_dependencies([update_total_op]): 

439 return self.count.assign_add(num_values) 

440 

441 def result(self): 

442 if self.reduction == metrics_utils.Reduction.SUM: 

443 return array_ops.identity(self.total) 

444 elif self.reduction in [ 

445 metrics_utils.Reduction.WEIGHTED_MEAN, 

446 metrics_utils.Reduction.SUM_OVER_BATCH_SIZE 

447 ]: 

448 return math_ops.div_no_nan(self.total, self.count) 

449 else: 

450 raise NotImplementedError( 

451 'reduction [%s] not implemented' % self.reduction) 

452 

453 

454@keras_export('keras.metrics.Sum') 

455class Sum(Reduce): 

456 """Computes the (weighted) sum of the given values. 

457 

458 For example, if values is [1, 3, 5, 7] then the sum is 16. 

459 If the weights were specified as [1, 1, 0, 0] then the sum would be 4. 

460 

461 This metric creates one variable, `total`, that is used to compute the sum of 

462 `values`. This is ultimately returned as `sum`. 

463 

464 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 

465 to mask values. 

466 

467 Args: 

468 name: (Optional) string name of the metric instance. 

469 dtype: (Optional) data type of the metric result. 

470 

471 Standalone usage: 

472 

473 >>> m = tf.keras.metrics.Sum() 

474 >>> m.update_state([1, 3, 5, 7]) 

475 >>> m.result().numpy() 

476 16.0 

477 
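  For the weighted case described above, a minimal illustration:

  ```python
  m = tf.keras.metrics.Sum()
  m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  m.result().numpy()  # 4.0, since only the first two values carry weight
  ```
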

478 Usage with `compile()` API: 

479 

480 ```python 

481 model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs)) 

482 model.compile(optimizer='sgd', loss='mse') 

483 ``` 

484 """ 

485 

486 def __init__(self, name='sum', dtype=None): 

487 super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM, 

488 name=name, dtype=dtype) 

489 

490 

491@keras_export('keras.metrics.Mean') 

492class Mean(Reduce): 

493 """Computes the (weighted) mean of the given values. 

494 

495 For example, if values is [1, 3, 5, 7] then the mean is 4. 

496 If the weights were specified as [1, 1, 0, 0] then the mean would be 2. 

497 

498 This metric creates two variables, `total` and `count` that are used to 

499 compute the average of `values`. This average is ultimately returned as `mean` 

500 which is an idempotent operation that simply divides `total` by `count`. 

501 

502 If `sample_weight` is `None`, weights default to 1. 

503 Use `sample_weight` of 0 to mask values. 

504 

505 Args: 

506 name: (Optional) string name of the metric instance. 

507 dtype: (Optional) data type of the metric result. 

508 

509 Standalone usage: 

510 

511 >>> m = tf.keras.metrics.Mean() 

512 >>> m.update_state([1, 3, 5, 7]) 

513 >>> m.result().numpy() 

514 4.0 

515 >>> m.reset_state() 

516 >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) 

517 >>> m.result().numpy() 

518 2.0 

519 

520 Usage with `compile()` API: 

521 

522 ```python 

523 model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs)) 

524 model.compile(optimizer='sgd', loss='mse') 

525 ``` 

526 """ 

527 

528 def __init__(self, name='mean', dtype=None): 

529 super(Mean, self).__init__( 

530 reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype) 

531 

532 

533@keras_export('keras.metrics.MeanRelativeError') 

534class MeanRelativeError(Mean): 

535 """Computes the mean relative error by normalizing with the given values. 

536 

537 This metric creates two local variables, `total` and `count` that are used to 

538 compute the mean relative error. This is weighted by `sample_weight`, and 

539 it is ultimately returned as `mean_relative_error`: 

540 an idempotent operation that simply divides `total` by `count`. 

541 

542 If `sample_weight` is `None`, weights default to 1. 

543 Use `sample_weight` of 0 to mask values. 

544 

545 Args: 

546 normalizer: The normalizer values with same shape as predictions. 

547 name: (Optional) string name of the metric instance. 

548 dtype: (Optional) data type of the metric result. 

549 

550 Standalone usage: 

551 

552 >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3]) 

553 >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8]) 

554 

555 >>> # metric = mean(|y_pred - y_true| / normalizer) 

556 >>> # = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3]) 

557 >>> # = 5/4 = 1.25 

558 >>> m.result().numpy() 

559 1.25 

560 

561 Usage with `compile()` API: 

562 

563 ```python 

564 model.compile( 

565 optimizer='sgd', 

566 loss='mse', 

567 metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])]) 

568 ``` 

569 """ 

570 

571 def __init__(self, normalizer, name=None, dtype=None): 

572 super(MeanRelativeError, self).__init__(name=name, dtype=dtype) 

573 normalizer = math_ops.cast(normalizer, self._dtype) 

574 self.normalizer = normalizer 

575 

576 def update_state(self, y_true, y_pred, sample_weight=None): 

577 """Accumulates metric statistics. 

578 

579 Args: 

580 y_true: The ground truth values. 

581 y_pred: The predicted values. 

582 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

583 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

584 be broadcastable to `y_true`. 

585 

586 Returns: 

587 Update op. 

588 """ 

589 y_true = math_ops.cast(y_true, self._dtype) 

590 y_pred = math_ops.cast(y_pred, self._dtype) 

591 [y_pred, y_true], sample_weight = \ 

592 metrics_utils.ragged_assert_compatible_and_get_flat_values( 

593 [y_pred, y_true], sample_weight) 

594 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 

595 y_pred, y_true) 

596 

597 y_pred, self.normalizer = losses_utils.remove_squeezable_dimensions( 

598 y_pred, self.normalizer) 

599 y_pred.shape.assert_is_compatible_with(y_true.shape) 

600 relative_errors = math_ops.div_no_nan( 

601 math_ops.abs(y_true - y_pred), self.normalizer) 

602 

603 return super(MeanRelativeError, self).update_state( 

604 relative_errors, sample_weight=sample_weight) 

605 

606 def get_config(self): 

607 n = self.normalizer 

608 config = {'normalizer': backend.eval(n) if is_tensor_or_variable(n) else n} 

609 base_config = super(MeanRelativeError, self).get_config() 

610 return dict(list(base_config.items()) + list(config.items())) 

611 

612 

613@keras_export('keras.metrics.MeanMetricWrapper') 

614class MeanMetricWrapper(Mean): 

615 """Wraps a stateless metric function with the Mean metric. 

616 

617 You could use this class to quickly build a mean metric from a function. The 

618 function needs to have the signature `fn(y_true, y_pred)` and return a 

619 per-sample loss array. `MeanMetricWrapper.result()` will return 

620 the average metric value across all samples seen so far. 

621 

622 For example: 

623 

624 ```python 

625 def accuracy(y_true, y_pred): 

626 return tf.cast(tf.math.equal(y_true, y_pred), tf.float32) 

627 

628 accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy) 

629 

630 keras_model.compile(..., metrics=accuracy_metric) 

631 ``` 

632 
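  Keyword arguments given to the wrapper are forwarded to `fn` on every update;
  a minimal illustration (the 0.7 threshold is arbitrary):

  ```python
  # binary_accuracy accepts a `threshold` argument, which the wrapper passes on.
  acc_07 = tf.keras.metrics.MeanMetricWrapper(
      fn=tf.keras.metrics.binary_accuracy, name='binary_accuracy_07',
      threshold=0.7)
  ```
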

633 Args: 

634 fn: The metric function to wrap, with signature `fn(y_true, y_pred, 

635 **kwargs)`. 

636 name: (Optional) string name of the metric instance. 

637 dtype: (Optional) data type of the metric result. 

638 **kwargs: Keyword arguments to pass on to `fn`. 

639 """ 

640 

641 def __init__(self, fn, name=None, dtype=None, **kwargs): 

642 super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype) 

643 self._fn = fn 

644 self._fn_kwargs = kwargs 

645 

646 def update_state(self, y_true, y_pred, sample_weight=None): 

647 """Accumulates metric statistics. 

648 

649 `y_true` and `y_pred` should have the same shape. 

650 

651 Args: 

652 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 

653 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 

654 sample_weight: Optional `sample_weight` acts as a 

655 coefficient for the metric. If a scalar is provided, then the metric is 

656 simply scaled by the given value. If `sample_weight` is a tensor of size 

657 `[batch_size]`, then the metric for each sample of the batch is rescaled 

658 by the corresponding element in the `sample_weight` vector. If the shape 

659 of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted 

660 to this shape), then each metric element of `y_pred` is scaled by the 

661 corresponding value of `sample_weight`. (Note on `dN-1`: all metric 

662 functions reduce by 1 dimension, usually the last axis (-1)). 

663 

664 Returns: 

665 Update op. 

666 """ 

667 y_true = math_ops.cast(y_true, self._dtype) 

668 y_pred = math_ops.cast(y_pred, self._dtype) 

669 [y_true, y_pred], sample_weight = ( 

670 metrics_utils.ragged_assert_compatible_and_get_flat_values( 

671 [y_true, y_pred], sample_weight)) 

672 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 

673 y_pred, y_true) 

674 

675 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx()) 

676 matches = ag_fn(y_true, y_pred, **self._fn_kwargs) 

677 return super(MeanMetricWrapper, self).update_state( 

678 matches, sample_weight=sample_weight) 

679 

680 def get_config(self): 

681 config = {} 

682 

683 if type(self) is MeanMetricWrapper: # pylint: disable=unidiomatic-typecheck 

684 # Only include function argument when the object is a MeanMetricWrapper 

685 # and not a subclass. 

686 config['fn'] = self._fn 

687 

688 for k, v in self._fn_kwargs.items(): 

689 config[k] = backend.eval(v) if is_tensor_or_variable(v) else v 

690 base_config = super(MeanMetricWrapper, self).get_config() 

691 return dict(list(base_config.items()) + list(config.items())) 

692 

693 @classmethod 

694 def from_config(cls, config): 

695 # Note that while MeanMetricWrapper itself isn't public, objects of this 

696 # class may be created and added to the model by calling model.compile. 

697 fn = config.pop('fn', None) 

698 if cls is MeanMetricWrapper: 

699 return cls(get(fn), **config) 

700 return super(MeanMetricWrapper, cls).from_config(config) 

701 

702 

703@keras_export('keras.metrics.Accuracy') 

704class Accuracy(MeanMetricWrapper): 

705 """Calculates how often predictions equal labels. 

706 

707 This metric creates two local variables, `total` and `count` that are used to 

708 compute the frequency with which `y_pred` matches `y_true`. This frequency is 

709 ultimately returned as `accuracy`: an idempotent operation that simply 

710 divides `total` by `count`. 

711 

712 If `sample_weight` is `None`, weights default to 1. 

713 Use `sample_weight` of 0 to mask values. 

714 

715 Args: 

716 name: (Optional) string name of the metric instance. 

717 dtype: (Optional) data type of the metric result. 

718 

719 Standalone usage: 

720 

721 >>> m = tf.keras.metrics.Accuracy() 

722 >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]]) 

723 >>> m.result().numpy() 

724 0.75 

725 

726 >>> m.reset_state() 

727 >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]], 

728 ... sample_weight=[1, 1, 0, 0]) 

729 >>> m.result().numpy() 

730 0.5 

731 

732 Usage with `compile()` API: 

733 

734 ```python 

735 model.compile(optimizer='sgd', 

736 loss='mse', 

737 metrics=[tf.keras.metrics.Accuracy()]) 

738 ``` 

739 """ 

740 

741 def __init__(self, name='accuracy', dtype=None): 

742 super(Accuracy, self).__init__(accuracy, name, dtype=dtype) 

743 

744 

745@keras_export('keras.metrics.BinaryAccuracy') 

746class BinaryAccuracy(MeanMetricWrapper): 

747 """Calculates how often predictions match binary labels. 

748 

749 This metric creates two local variables, `total` and `count` that are used to 

750 compute the frequency with which `y_pred` matches `y_true`. This frequency is 

751 ultimately returned as `binary accuracy`: an idempotent operation that simply 

752 divides `total` by `count`. 

753 

754 If `sample_weight` is `None`, weights default to 1. 

755 Use `sample_weight` of 0 to mask values. 

756 

757 Args: 

758 name: (Optional) string name of the metric instance. 

759 dtype: (Optional) data type of the metric result. 

760 threshold: (Optional) Float representing the threshold for deciding 

761 whether prediction values are 1 or 0. 

762 

763 Standalone usage: 

764 

765 >>> m = tf.keras.metrics.BinaryAccuracy() 

766 >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]]) 

767 >>> m.result().numpy() 

768 0.75 

769 

770 >>> m.reset_state() 

771 >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]], 

772 ... sample_weight=[1, 0, 0, 1]) 

773 >>> m.result().numpy() 

774 0.5 

775 
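  The `threshold` argument controls when a prediction counts as 1; a minimal
  illustration reusing the inputs above (the 0.7 cutoff is arbitrary):

  ```python
  m = tf.keras.metrics.BinaryAccuracy(threshold=0.7)
  m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
  m.result().numpy()  # 1.0, since 0.6 now falls below the threshold and counts as 0
  ```
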

776 Usage with `compile()` API: 

777 

778 ```python 

779 model.compile(optimizer='sgd', 

780 loss='mse', 

781 metrics=[tf.keras.metrics.BinaryAccuracy()]) 

782 ``` 

783 """ 

784 

785 def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5): 

786 super(BinaryAccuracy, self).__init__( 

787 binary_accuracy, name, dtype=dtype, threshold=threshold) 

788 

789 

790@keras_export('keras.metrics.CategoricalAccuracy') 

791class CategoricalAccuracy(MeanMetricWrapper): 

792 """Calculates how often predictions match one-hot labels. 

793 

794 You can provide logits of classes as `y_pred`, since the argmax of 

795 logits and of probabilities is the same. 

796 

797 This metric creates two local variables, `total` and `count` that are used to 

798 compute the frequency with which `y_pred` matches `y_true`. This frequency is 

799 ultimately returned as `categorical accuracy`: an idempotent operation that 

800 simply divides `total` by `count`. 

801 

802 `y_true` should be passed in as a one-hot vector and `y_pred` as a vector of 

803 probabilities, rather than as integer labels. If necessary, use `tf.one_hot` to expand `y_true` into such a vector. 

804 

805 If `sample_weight` is `None`, weights default to 1. 

806 Use `sample_weight` of 0 to mask values. 

807 

808 Args: 

809 name: (Optional) string name of the metric instance. 

810 dtype: (Optional) data type of the metric result. 

811 

812 Standalone usage: 

813 

814 >>> m = tf.keras.metrics.CategoricalAccuracy() 

815 >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], 

816 ... [0.05, 0.95, 0]]) 

817 >>> m.result().numpy() 

818 0.5 

819 

820 >>> m.reset_state() 

821 >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], 

822 ... [0.05, 0.95, 0]], 

823 ... sample_weight=[0.7, 0.3]) 

824 >>> m.result().numpy() 

825 0.3 

826 
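  For reference, the `tf.one_hot` expansion mentioned above turns integer labels
  into the one-hot rows used in this example (a minimal illustration):

  ```python
  tf.one_hot([2, 1], depth=3)  # [[0., 0., 1.], [0., 1., 0.]]
  ```
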

827 Usage with `compile()` API: 

828 

829 ```python 

830 model.compile( 

831 optimizer='sgd', 

832 loss='mse', 

833 metrics=[tf.keras.metrics.CategoricalAccuracy()]) 

834 ``` 

835 """ 

836 

837 def __init__(self, name='categorical_accuracy', dtype=None): 

838 super(CategoricalAccuracy, self).__init__( 

839 categorical_accuracy, name, dtype=dtype) 

840 

841 

842@keras_export('keras.metrics.SparseCategoricalAccuracy') 

843class SparseCategoricalAccuracy(MeanMetricWrapper): 

844 """Calculates how often predictions match integer labels. 

845 

846 ```python 

847 acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1))) 

848 ``` 

849 

850 You can provide logits of classes as `y_pred`, since the argmax of 

851 logits and of probabilities is the same. 

852 

853 This metric creates two local variables, `total` and `count` that are used to 

854 compute the frequency with which `y_pred` matches `y_true`. This frequency is 

855 ultimately returned as `sparse categorical accuracy`: an idempotent operation 

856 that simply divides `total` by `count`. 

857 

858 If `sample_weight` is `None`, weights default to 1. 

859 Use `sample_weight` of 0 to mask values. 

860 

861 Args: 

862 name: (Optional) string name of the metric instance. 

863 dtype: (Optional) data type of the metric result. 

864 

865 Standalone usage: 

866 

867 >>> m = tf.keras.metrics.SparseCategoricalAccuracy() 

868 >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) 

869 >>> m.result().numpy() 

870 0.5 

871 

872 >>> m.reset_state() 

873 >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]], 

874 ... sample_weight=[0.7, 0.3]) 

875 >>> m.result().numpy() 

876 0.3 

877 

878 Usage with `compile()` API: 

879 

880 ```python 

881 model.compile( 

882 optimizer='sgd', 

883 loss='mse', 

884 metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) 

885 ``` 

886 """ 

887 

888 def __init__(self, name='sparse_categorical_accuracy', dtype=None): 

889 super(SparseCategoricalAccuracy, self).__init__( 

890 sparse_categorical_accuracy, name, dtype=dtype) 

891 

892 

893@keras_export('keras.metrics.TopKCategoricalAccuracy') 

894class TopKCategoricalAccuracy(MeanMetricWrapper): 

895 """Computes how often targets are in the top `K` predictions. 

896 

897 Args: 

898 k: (Optional) Number of top elements to look at for computing accuracy. 

899 Defaults to 5. 

900 name: (Optional) string name of the metric instance. 

901 dtype: (Optional) data type of the metric result. 

902 

903 Standalone usage: 

904 

905 >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1) 

906 >>> m.update_state([[0, 0, 1], [0, 1, 0]], 

907 ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 

908 >>> m.result().numpy() 

909 0.5 

910 

911 >>> m.reset_state() 

912 >>> m.update_state([[0, 0, 1], [0, 1, 0]], 

913 ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], 

914 ... sample_weight=[0.7, 0.3]) 

915 >>> m.result().numpy() 

916 0.3 

917 

918 Usage with `compile()` API: 

919 

920 ```python 

921 model.compile(optimizer='sgd', 

922 loss='mse', 

923 metrics=[tf.keras.metrics.TopKCategoricalAccuracy()]) 

924 ``` 

925 """ 

926 

927 def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None): 

928 super(TopKCategoricalAccuracy, self).__init__( 

929 top_k_categorical_accuracy, name, dtype=dtype, k=k) 

930 

931 

932@keras_export('keras.metrics.SparseTopKCategoricalAccuracy') 

933class SparseTopKCategoricalAccuracy(MeanMetricWrapper): 

934 """Computes how often integer targets are in the top `K` predictions. 

935 

936 Args: 

937 k: (Optional) Number of top elements to look at for computing accuracy. 

938 Defaults to 5. 

939 name: (Optional) string name of the metric instance. 

940 dtype: (Optional) data type of the metric result. 

941 

942 Standalone usage: 

943 

944 >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1) 

945 >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) 

946 >>> m.result().numpy() 

947 0.5 

948 

949 >>> m.reset_state() 

950 >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], 

951 ... sample_weight=[0.7, 0.3]) 

952 >>> m.result().numpy() 

953 0.3 

954 

955 Usage with `compile()` API: 

956 

957 ```python 

958 model.compile( 

959 optimizer='sgd', 

960 loss='mse', 

961 metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()]) 

962 ``` 

963 """ 

964 

965 def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None): 

966 super(SparseTopKCategoricalAccuracy, self).__init__( 

967 sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k) 

968 

969 

970class _ConfusionMatrixConditionCount(Metric): 

971 """Calculates the number of the given confusion matrix condition. 

972 

973 Args: 

974 confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions. 

975 thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple 

976 of float threshold values in [0, 1]. A threshold is compared with 

977 prediction values to determine the truth value of predictions (i.e., above 

978 the threshold is `true`, below is `false`). One metric value is generated 

979 for each threshold value. 

980 name: (Optional) string name of the metric instance. 

981 dtype: (Optional) data type of the metric result. 

982 """ 

983 

984 def __init__(self, 

985 confusion_matrix_cond, 

986 thresholds=None, 

987 name=None, 

988 dtype=None): 

989 super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype) 

990 self._confusion_matrix_cond = confusion_matrix_cond 

991 self.init_thresholds = thresholds 

992 self.thresholds = metrics_utils.parse_init_thresholds( 

993 thresholds, default_threshold=0.5) 

994 self._thresholds_distributed_evenly = ( 

995 metrics_utils.is_evenly_distributed_thresholds(self.thresholds)) 

996 self.accumulator = self.add_weight( 

997 'accumulator', 

998 shape=(len(self.thresholds),), 

999 initializer=init_ops.zeros_initializer) 

1000 

1001 def update_state(self, y_true, y_pred, sample_weight=None): 

1002 """Accumulates the metric statistics. 

1003 

1004 Args: 

1005 y_true: The ground truth values. 

1006 y_pred: The predicted values. 

1007 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

1008 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

1009 be broadcastable to `y_true`. 

1010 

1011 Returns: 

1012 Update op. 

1013 """ 

1014 return metrics_utils.update_confusion_matrix_variables( 

1015 {self._confusion_matrix_cond: self.accumulator}, 

1016 y_true, 

1017 y_pred, 

1018 thresholds=self.thresholds, 

1019 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

1020 sample_weight=sample_weight) 

1021 

1022 def result(self): 

1023 if len(self.thresholds) == 1: 

1024 result = self.accumulator[0] 

1025 else: 

1026 result = self.accumulator 

1027 return tensor_conversion.convert_to_tensor_v2_with_dispatch(result) 

1028 

1029 def reset_state(self): 

1030 num_thresholds = len(to_list(self.thresholds)) 

1031 backend.batch_set_value( 

1032 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 

1033 

1034 def get_config(self): 

1035 config = {'thresholds': self.init_thresholds} 

1036 base_config = super(_ConfusionMatrixConditionCount, self).get_config() 

1037 return dict(list(base_config.items()) + list(config.items())) 

1038 

1039 

1040@keras_export('keras.metrics.FalsePositives') 

1041class FalsePositives(_ConfusionMatrixConditionCount): 

1042 """Calculates the number of false positives. 

1043 

1044 If `sample_weight` is given, calculates the sum of the weights of 

1045 false positives. This metric creates one local variable, `accumulator` 

1046 that is used to keep track of the number of false positives. 

1047 

1048 If `sample_weight` is `None`, weights default to 1. 

1049 Use `sample_weight` of 0 to mask values. 

1050 

1051 Args: 

1052 thresholds: (Optional) Defaults to 0.5. A float value or a python 

1053 list/tuple of float threshold values in [0, 1]. A threshold is compared 

1054 with prediction values to determine the truth value of predictions 

1055 (i.e., above the threshold is `true`, below is `false`). One metric 

1056 value is generated for each threshold value. 

1057 name: (Optional) string name of the metric instance. 

1058 dtype: (Optional) data type of the metric result. 

1059 

1060 Standalone usage: 

1061 

1062 >>> m = tf.keras.metrics.FalsePositives() 

1063 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) 

1064 >>> m.result().numpy() 

1065 2.0 

1066 

1067 >>> m.reset_state() 

1068 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

1069 >>> m.result().numpy() 

1070 1.0 

1071 
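  When a list of `thresholds` is given, one count is kept per threshold; a
  minimal illustration (the threshold values are arbitrary):

  ```python
  m = tf.keras.metrics.FalsePositives(thresholds=[0.3, 0.7])
  m.update_state([0, 0, 1, 1], [0.1, 0.4, 0.6, 0.9])
  # m.result() is now a length-2 tensor: one false-positive count per threshold.
  ```
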

1072 Usage with `compile()` API: 

1073 

1074 ```python 

1075 model.compile(optimizer='sgd', 

1076 loss='mse', 

1077 metrics=[tf.keras.metrics.FalsePositives()]) 

1078 ``` 

1079 """ 

1080 

1081 def __init__(self, thresholds=None, name=None, dtype=None): 

1082 super(FalsePositives, self).__init__( 

1083 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, 

1084 thresholds=thresholds, 

1085 name=name, 

1086 dtype=dtype) 

1087 

1088 

1089@keras_export('keras.metrics.FalseNegatives') 

1090class FalseNegatives(_ConfusionMatrixConditionCount): 

1091 """Calculates the number of false negatives. 

1092 

1093 If `sample_weight` is given, calculates the sum of the weights of 

1094 false negatives. This metric creates one local variable, `accumulator` 

1095 that is used to keep track of the number of false negatives. 

1096 

1097 If `sample_weight` is `None`, weights default to 1. 

1098 Use `sample_weight` of 0 to mask values. 

1099 

1100 Args: 

1101 thresholds: (Optional) Defaults to 0.5. A float value or a python 

1102 list/tuple of float threshold values in [0, 1]. A threshold is compared 

1103 with prediction values to determine the truth value of predictions 

1104 (i.e., above the threshold is `true`, below is `false`). One metric 

1105 value is generated for each threshold value. 

1106 name: (Optional) string name of the metric instance. 

1107 dtype: (Optional) data type of the metric result. 

1108 

1109 Standalone usage: 

1110 

1111 >>> m = tf.keras.metrics.FalseNegatives() 

1112 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) 

1113 >>> m.result().numpy() 

1114 2.0 

1115 

1116 >>> m.reset_state() 

1117 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0]) 

1118 >>> m.result().numpy() 

1119 1.0 

1120 

1121 Usage with `compile()` API: 

1122 

1123 ```python 

1124 model.compile(optimizer='sgd', 

1125 loss='mse', 

1126 metrics=[tf.keras.metrics.FalseNegatives()]) 

1127 ``` 

1128 """ 

1129 

1130 def __init__(self, thresholds=None, name=None, dtype=None): 

1131 super(FalseNegatives, self).__init__( 

1132 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES, 

1133 thresholds=thresholds, 

1134 name=name, 

1135 dtype=dtype) 

1136 

1137 

1138@keras_export('keras.metrics.TrueNegatives') 

1139class TrueNegatives(_ConfusionMatrixConditionCount): 

1140 """Calculates the number of true negatives. 

1141 

1142 If `sample_weight` is given, calculates the sum of the weights of 

1143 true negatives. This metric creates one local variable, `accumulator` 

1144 that is used to keep track of the number of true negatives. 

1145 

1146 If `sample_weight` is `None`, weights default to 1. 

1147 Use `sample_weight` of 0 to mask values. 

1148 

1149 Args: 

1150 thresholds: (Optional) Defaults to 0.5. A float value or a python 

1151 list/tuple of float threshold values in [0, 1]. A threshold is compared 

1152 with prediction values to determine the truth value of predictions 

1153 (i.e., above the threshold is `true`, below is `false`). One metric 

1154 value is generated for each threshold value. 

1155 name: (Optional) string name of the metric instance. 

1156 dtype: (Optional) data type of the metric result. 

1157 

1158 Standalone usage: 

1159 

1160 >>> m = tf.keras.metrics.TrueNegatives() 

1161 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) 

1162 >>> m.result().numpy() 

1163 2.0 

1164 

1165 >>> m.reset_state() 

1166 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0]) 

1167 >>> m.result().numpy() 

1168 1.0 

1169 

1170 Usage with `compile()` API: 

1171 

1172 ```python 

1173 model.compile(optimizer='sgd', 

1174 loss='mse', 

1175 metrics=[tf.keras.metrics.TrueNegatives()]) 

1176 ``` 

1177 """ 

1178 

1179 def __init__(self, thresholds=None, name=None, dtype=None): 

1180 super(TrueNegatives, self).__init__( 

1181 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES, 

1182 thresholds=thresholds, 

1183 name=name, 

1184 dtype=dtype) 

1185 

1186 

1187@keras_export('keras.metrics.TruePositives') 

1188class TruePositives(_ConfusionMatrixConditionCount): 

1189 """Calculates the number of true positives. 

1190 

1191 If `sample_weight` is given, calculates the sum of the weights of 

1192 true positives. This metric creates one local variable, `accumulator` 

1193 that is used to keep track of the number of true positives. 

1194 

1195 If `sample_weight` is `None`, weights default to 1. 

1196 Use `sample_weight` of 0 to mask values. 

1197 

1198 Args: 

1199 thresholds: (Optional) Defaults to 0.5. A float value or a python 

1200 list/tuple of float threshold values in [0, 1]. A threshold is compared 

1201 with prediction values to determine the truth value of predictions 

1202 (i.e., above the threshold is `true`, below is `false`). One metric 

1203 value is generated for each threshold value. 

1204 name: (Optional) string name of the metric instance. 

1205 dtype: (Optional) data type of the metric result. 

1206 

1207 Standalone usage: 

1208 

1209 >>> m = tf.keras.metrics.TruePositives() 

1210 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 

1211 >>> m.result().numpy() 

1212 2.0 

1213 

1214 >>> m.reset_state() 

1215 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

1216 >>> m.result().numpy() 

1217 1.0 

1218 

1219 Usage with `compile()` API: 

1220 

1221 ```python 

1222 model.compile(optimizer='sgd', 

1223 loss='mse', 

1224 metrics=[tf.keras.metrics.TruePositives()]) 

1225 ``` 

1226 """ 

1227 

1228 def __init__(self, thresholds=None, name=None, dtype=None): 

1229 super(TruePositives, self).__init__( 

1230 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES, 

1231 thresholds=thresholds, 

1232 name=name, 

1233 dtype=dtype) 

1234 

1235 

1236@keras_export('keras.metrics.Precision') 

1237class Precision(Metric): 

1238 """Computes the precision of the predictions with respect to the labels. 

1239 

1240 The metric creates two local variables, `true_positives` and `false_positives` 

1241 that are used to compute the precision. This value is ultimately returned as 

1242 `precision`, an idempotent operation that simply divides `true_positives` 

1243 by the sum of `true_positives` and `false_positives`. 

1244 

1245 If `sample_weight` is `None`, weights default to 1. 

1246 Use `sample_weight` of 0 to mask values. 

1247 

1248 If `top_k` is set, we calculate precision as how often, on average, a class 

1249 among the top-k classes with the highest predicted values of a batch entry is 

1250 correct, i.e. can be found in the label for that entry. 

1251 

1252 If `class_id` is specified, we calculate precision by considering only the 

1253 entries in the batch for which `class_id` is above the threshold and/or in the 

1254 top-k highest predictions, and computing the fraction of them for which 

1255 `class_id` is indeed a correct label. 

1256 

1257 Args: 

1258 thresholds: (Optional) A float value or a python list/tuple of float 

1259 threshold values in [0, 1]. A threshold is compared with prediction 

1260 values to determine the truth value of predictions (i.e., above the 

1261 threshold is `true`, below is `false`). One metric value is generated 

1262 for each threshold value. If neither thresholds nor top_k are set, the 

1263 default is to calculate precision with `thresholds=0.5`. 

1264 top_k: (Optional) Unset by default. An int value specifying the top-k 

1265 predictions to consider when calculating precision. 

1266 class_id: (Optional) Integer class ID for which we want binary metrics. 

1267 This must be in the half-open interval `[0, num_classes)`, where 

1268 `num_classes` is the last dimension of predictions. 

1269 name: (Optional) string name of the metric instance. 

1270 dtype: (Optional) data type of the metric result. 

1271 

1272 Standalone usage: 

1273 

1274 >>> m = tf.keras.metrics.Precision() 

1275 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 

1276 >>> m.result().numpy() 

1277 0.6666667 

1278 

1279 >>> m.reset_state() 

1280 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

1281 >>> m.result().numpy() 

1282 1.0 

1283 

1284 >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2] 

1285 >>> m = tf.keras.metrics.Precision(top_k=2) 

1286 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) 

1287 >>> m.result().numpy() 

1288 0.0 

1289 

1290 >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4] 

1291 >>> m = tf.keras.metrics.Precision(top_k=4) 

1292 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) 

1293 >>> m.result().numpy() 

1294 0.5 

1295 
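  The `class_id` argument described above restricts the computation to a single
  class; a minimal illustration with one-hot labels and per-class probabilities
  (the values are arbitrary):

  ```python
  m = tf.keras.metrics.Precision(class_id=1)
  m.update_state([[0, 1, 0], [0, 0, 1]], [[0.1, 0.8, 0.1], [0.2, 0.7, 0.1]])
  # Class 1 is predicted for both samples (0.8 and 0.7 exceed the default 0.5
  # threshold) but is the true label for only the first, so precision is 0.5.
  ```
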

1296 Usage with `compile()` API: 

1297 

1298 ```python 

1299 model.compile(optimizer='sgd', 

1300 loss='mse', 

1301 metrics=[tf.keras.metrics.Precision()]) 

1302 ``` 

1303 """ 

1304 

1305 def __init__(self, 

1306 thresholds=None, 

1307 top_k=None, 

1308 class_id=None, 

1309 name=None, 

1310 dtype=None): 

1311 super(Precision, self).__init__(name=name, dtype=dtype) 

1312 self.init_thresholds = thresholds 

1313 self.top_k = top_k 

1314 self.class_id = class_id 

1315 

1316 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 

1317 self.thresholds = metrics_utils.parse_init_thresholds( 

1318 thresholds, default_threshold=default_threshold) 

1319 self._thresholds_distributed_evenly = ( 

1320 metrics_utils.is_evenly_distributed_thresholds(self.thresholds)) 

1321 self.true_positives = self.add_weight( 

1322 'true_positives', 

1323 shape=(len(self.thresholds),), 

1324 initializer=init_ops.zeros_initializer) 

1325 self.false_positives = self.add_weight( 

1326 'false_positives', 

1327 shape=(len(self.thresholds),), 

1328 initializer=init_ops.zeros_initializer) 

1329 

1330 def update_state(self, y_true, y_pred, sample_weight=None): 

1331 """Accumulates true positive and false positive statistics. 

1332 

1333 Args: 

1334 y_true: The ground truth values, with the same dimensions as `y_pred`. 

1335 Will be cast to `bool`. 

1336 y_pred: The predicted values. Each element must be in the range `[0, 1]`. 

1337 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

1338 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

1339 be broadcastable to `y_true`. 

1340 

1341 Returns: 

1342 Update op. 

1343 """ 

1344 return metrics_utils.update_confusion_matrix_variables( 

1345 { 

1346 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 

1347 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives 

1348 }, 

1349 y_true, 

1350 y_pred, 

1351 thresholds=self.thresholds, 

1352 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

1353 top_k=self.top_k, 

1354 class_id=self.class_id, 

1355 sample_weight=sample_weight) 

1356 

1357 def result(self): 

1358 result = math_ops.div_no_nan(self.true_positives, 

1359 self.true_positives + self.false_positives) 

1360 return result[0] if len(self.thresholds) == 1 else result 

1361 

1362 def reset_state(self): 

1363 num_thresholds = len(to_list(self.thresholds)) 

1364 backend.batch_set_value([(v, np.zeros((num_thresholds,))) 

1365 for v in (self.true_positives, 

1366 self.false_positives)]) 

1367 

1368 def get_config(self): 

1369 config = { 

1370 'thresholds': self.init_thresholds, 

1371 'top_k': self.top_k, 

1372 'class_id': self.class_id 

1373 } 

1374 base_config = super(Precision, self).get_config() 

1375 return dict(list(base_config.items()) + list(config.items())) 

1376 

1377 

1378@keras_export('keras.metrics.Recall') 

1379class Recall(Metric): 

1380 """Computes the recall of the predictions with respect to the labels. 

1381 

1382 This metric creates two local variables, `true_positives` and 

1383 `false_negatives`, that are used to compute the recall. This value is 

1384 ultimately returned as `recall`, an idempotent operation that simply divides 

1385 `true_positives` by the sum of `true_positives` and `false_negatives`. 

1386 

1387 If `sample_weight` is `None`, weights default to 1. 

1388 Use `sample_weight` of 0 to mask values. 

1389 

1390 If `top_k` is set, recall will be computed as how often on average a class 

1391 among the labels of a batch entry is in the top-k predictions. 

1392 

1393 If `class_id` is specified, we calculate recall by considering only the 

1394 entries in the batch for which `class_id` is in the label, and computing the 

1395 fraction of them for which `class_id` is above the threshold and/or in the 

1396 top-k predictions. 

1397 

1398 Args: 

1399 thresholds: (Optional) A float value or a python list/tuple of float 

1400 threshold values in [0, 1]. A threshold is compared with prediction 

1401 values to determine the truth value of predictions (i.e., above the 

1402 threshold is `true`, below is `false`). One metric value is generated 

1403 for each threshold value. If neither thresholds nor top_k are set, the 

1404 default is to calculate recall with `thresholds=0.5`. 

1405 top_k: (Optional) Unset by default. An int value specifying the top-k 

1406 predictions to consider when calculating recall. 

1407 class_id: (Optional) Integer class ID for which we want binary metrics. 

1408 This must be in the half-open interval `[0, num_classes)`, where 

1409 `num_classes` is the last dimension of predictions. 

1410 name: (Optional) string name of the metric instance. 

1411 dtype: (Optional) data type of the metric result. 

1412 

1413 Standalone usage: 

1414 

1415 >>> m = tf.keras.metrics.Recall() 

1416 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) 

1417 >>> m.result().numpy() 

1418 0.6666667 

1419 

1420 >>> m.reset_state() 

1421 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 

1422 >>> m.result().numpy() 

1423 1.0 

1424 
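  As with `Precision`, the `class_id` argument restricts recall to a single
  class; a minimal illustration (the values are arbitrary):

  ```python
  m = tf.keras.metrics.Recall(class_id=1)
  m.update_state([[0, 1, 0], [0, 0, 1]], [[0.1, 0.8, 0.1], [0.2, 0.7, 0.1]])
  # The only sample labeled with class 1 is also predicted as class 1
  # (0.8 exceeds the default 0.5 threshold), so recall is 1.0.
  ```
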

1425 Usage with `compile()` API: 

1426 

1427 ```python 

1428 model.compile(optimizer='sgd', 

1429 loss='mse', 

1430 metrics=[tf.keras.metrics.Recall()]) 

1431 ``` 

1432 """ 

1433 

1434 def __init__(self, 

1435 thresholds=None, 

1436 top_k=None, 

1437 class_id=None, 

1438 name=None, 

1439 dtype=None): 

1440 super(Recall, self).__init__(name=name, dtype=dtype) 

1441 self.init_thresholds = thresholds 

1442 self.top_k = top_k 

1443 self.class_id = class_id 

1444 

1445 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF 

1446 self.thresholds = metrics_utils.parse_init_thresholds( 

1447 thresholds, default_threshold=default_threshold) 

1448 self._thresholds_distributed_evenly = ( 

1449 metrics_utils.is_evenly_distributed_thresholds(self.thresholds)) 

1450 self.true_positives = self.add_weight( 

1451 'true_positives', 

1452 shape=(len(self.thresholds),), 

1453 initializer=init_ops.zeros_initializer) 

1454 self.false_negatives = self.add_weight( 

1455 'false_negatives', 

1456 shape=(len(self.thresholds),), 

1457 initializer=init_ops.zeros_initializer) 

1458 

1459 def update_state(self, y_true, y_pred, sample_weight=None): 

1460 """Accumulates true positive and false negative statistics. 

1461 

1462 Args: 

1463 y_true: The ground truth values, with the same dimensions as `y_pred`. 

1464 Will be cast to `bool`. 

1465 y_pred: The predicted values. Each element must be in the range `[0, 1]`. 

1466 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

1467 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

1468 be broadcastable to `y_true`. 

1469 

1470 Returns: 

1471 Update op. 

1472 """ 

1473 return metrics_utils.update_confusion_matrix_variables( 

1474 { 

1475 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 

1476 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives 

1477 }, 

1478 y_true, 

1479 y_pred, 

1480 thresholds=self.thresholds, 

1481 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

1482 top_k=self.top_k, 

1483 class_id=self.class_id, 

1484 sample_weight=sample_weight) 

1485 

1486 def result(self): 

1487 result = math_ops.div_no_nan(self.true_positives, 

1488 self.true_positives + self.false_negatives) 

1489 return result[0] if len(self.thresholds) == 1 else result 
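# Editorial note -- illustrative sketch, not part of the original source: for
# the unweighted docstring example above, y_true=[0, 1, 1, 1] and
# y_pred=[1, 0, 1, 1] with the default threshold of 0.5 accumulate
# true_positives=[2.] and false_negatives=[1.], so the result is
# 2 / (2 + 1) = 0.6666667, matching the doctest.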

1490 

1491 def reset_state(self): 

1492 num_thresholds = len(to_list(self.thresholds)) 

1493 backend.batch_set_value([(v, np.zeros((num_thresholds,))) 

1494 for v in (self.true_positives, 

1495 self.false_negatives)]) 

1496 

1497 def get_config(self): 

1498 config = { 

1499 'thresholds': self.init_thresholds, 

1500 'top_k': self.top_k, 

1501 'class_id': self.class_id 

1502 } 

1503 base_config = super(Recall, self).get_config() 

1504 return dict(list(base_config.items()) + list(config.items())) 

1505 

1506 

1507class SensitivitySpecificityBase(Metric, metaclass=abc.ABCMeta): 

1508 """Abstract base class for computing sensitivity and specificity. 

1509 

1510 For additional information about specificity and sensitivity, see 

1511 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 

1512 """ 

1513 

1514 def __init__(self, 

1515 value, 

1516 num_thresholds=200, 

1517 class_id=None, 

1518 name=None, 

1519 dtype=None): 

1520 super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype) 

1521 if num_thresholds <= 0: 

1522 raise ValueError('`num_thresholds` must be > 0.') 

1523 self.value = value 

1524 self.class_id = class_id 

1525 self.true_positives = self.add_weight( 

1526 'true_positives', 

1527 shape=(num_thresholds,), 

1528 initializer=init_ops.zeros_initializer) 

1529 self.true_negatives = self.add_weight( 

1530 'true_negatives', 

1531 shape=(num_thresholds,), 

1532 initializer=init_ops.zeros_initializer) 

1533 self.false_positives = self.add_weight( 

1534 'false_positives', 

1535 shape=(num_thresholds,), 

1536 initializer=init_ops.zeros_initializer) 

1537 self.false_negatives = self.add_weight( 

1538 'false_negatives', 

1539 shape=(num_thresholds,), 

1540 initializer=init_ops.zeros_initializer) 

1541 

1542 # Compute `num_thresholds` thresholds in [0, 1] 

1543 if num_thresholds == 1: 

1544 self.thresholds = [0.5] 

1545 self._thresholds_distributed_evenly = False 

1546 else: 

1547 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 

1548 for i in range(num_thresholds - 2)] 

1549 self.thresholds = [0.0] + thresholds + [1.0] 

1550 self._thresholds_distributed_evenly = True 
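# Editorial note -- illustrative sketch, not part of the original source: with
# num_thresholds=5 the list comprehension above yields [0.25, 0.5, 0.75], so
# self.thresholds becomes [0.0, 0.25, 0.5, 0.75, 1.0], i.e. num_thresholds
# evenly spaced values in [0, 1], which is why _thresholds_distributed_evenly
# is set to True in this branch.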

1551 

1552 def update_state(self, y_true, y_pred, sample_weight=None): 

1553 """Accumulates confusion matrix statistics. 

1554 

1555 Args: 

1556 y_true: The ground truth values. 

1557 y_pred: The predicted values. 

1558 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

1559 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

1560 be broadcastable to `y_true`. 

1561 

1562 Returns: 

1563 Update op. 

1564 """ 

1565 return metrics_utils.update_confusion_matrix_variables( 

1566 { 

1567 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 

1568 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, 

1569 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, 

1570 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, 

1571 }, 

1572 y_true, 

1573 y_pred, 

1574 thresholds=self.thresholds, 

1575 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

1576 class_id=self.class_id, 

1577 sample_weight=sample_weight) 

1578 

1579 def reset_state(self): 

1580 num_thresholds = len(self.thresholds) 

1581 confusion_matrix_variables = (self.true_positives, self.true_negatives, 

1582 self.false_positives, self.false_negatives) 

1583 backend.batch_set_value([ 

1584 (v, np.zeros((num_thresholds,))) for v in confusion_matrix_variables 

1585 ]) 

1586 

1587 def get_config(self): 

1588 config = {'class_id': self.class_id} 

1589 base_config = super(SensitivitySpecificityBase, self).get_config() 

1590 return dict(list(base_config.items()) + list(config.items())) 

1591 

1592 def _find_max_under_constraint(self, constrained, dependent, predicate): 

1593 """Returns the maximum of dependent_statistic that satisfies the constraint. 

1594 

1595 Args: 

1596 constrained: A rank-1 tensor of values over which the constraint

1597 is specified.

1598 dependent: From these values the maximum that satisfies the

1599 constraint is selected. Values in this tensor and in 

1600 `constrained` are linked by having the same threshold at each 

1601 position, hence this tensor must have the same shape. 

1602 predicate: A binary boolean functor to be applied to arguments 

1603 `constrained` and `self.value`, e.g. `tf.greater`. 

1604 

1605 Returns the maximal dependent value; if no value satisfies the constraint, 0.0.

1606 """ 

1607 feasible = array_ops.where_v2(predicate(constrained, self.value)) 

1608 feasible_exists = math_ops.greater(array_ops.size(feasible), 0) 

1609 max_dependent = math_ops.reduce_max(array_ops.gather(dependent, feasible)) 

1610 

1611 return array_ops.where_v2(feasible_exists, max_dependent, 0.0) 
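# Editorial note -- rough NumPy equivalent of the selection above (assumes
# rank-1 inputs; names and values are illustrative only):
#
#   import numpy as np
#   constrained = np.array([0.2, 0.6, 0.9])  # e.g. specificity per threshold
#   dependent = np.array([0.9, 0.7, 0.3])    # e.g. sensitivity per threshold
#   feasible = np.where(constrained >= 0.5)[0]        # predicate(constrained, value)
#   best = dependent[feasible].max() if feasible.size else 0.0  # -> 0.7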

1612 

1613 

1614@keras_export('keras.metrics.SensitivityAtSpecificity') 

1615class SensitivityAtSpecificity(SensitivitySpecificityBase): 

1616 """Computes best sensitivity where specificity is >= specified value. 

1617 

1618 That is, it reports the sensitivity achieved at the given specificity.

1619 

1620 `Sensitivity` measures the proportion of actual positives that are correctly 

1621 identified as such (tp / (tp + fn)). 

1622 `Specificity` measures the proportion of actual negatives that are correctly 

1623 identified as such (tn / (tn + fp)). 

1624 

1625 This metric creates four local variables, `true_positives`, `true_negatives`, 

1626 `false_positives` and `false_negatives` that are used to compute the 

1627 sensitivity at the given specificity. The threshold for the given specificity 

1628 value is computed and used to evaluate the corresponding sensitivity. 

1629 

1630 If `sample_weight` is `None`, weights default to 1. 

1631 Use `sample_weight` of 0 to mask values. 

1632 

1633 If `class_id` is specified, we calculate the metric by considering only the

1634 entries in the batch for which `class_id` is above the threshold in the

1635 predictions, and computing the fraction of them for which `class_id` is

1636 indeed a correct label.

1637 

1638 For additional information about specificity and sensitivity, see 

1639 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 

1640 

1641 Args: 

1642 specificity: A scalar value in range `[0, 1]`. 

1643 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1644 use for matching the given specificity. 

1645 class_id: (Optional) Integer class ID for which we want binary metrics. 

1646 This must be in the half-open interval `[0, num_classes)`, where 

1647 `num_classes` is the last dimension of predictions. 

1648 name: (Optional) string name of the metric instance. 

1649 dtype: (Optional) data type of the metric result. 

1650 

1651 Standalone usage: 

1652 

1653 >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) 

1654 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 

1655 >>> m.result().numpy() 

1656 0.5 

1657 

1658 >>> m.reset_state() 

1659 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 

1660 ... sample_weight=[1, 1, 2, 2, 1]) 

1661 >>> m.result().numpy() 

1662 0.333333 

1663 

1664 Usage with `compile()` API: 

1665 

1666 ```python 

1667 model.compile( 

1668 optimizer='sgd', 

1669 loss='mse', 

1670 metrics=[tf.keras.metrics.SensitivityAtSpecificity()]) 

1671 ``` 

1672 """ 

1673 

1674 def __init__(self, 

1675 specificity, 

1676 num_thresholds=200, 

1677 class_id=None, 

1678 name=None, 

1679 dtype=None): 

1680 if specificity < 0 or specificity > 1: 

1681 raise ValueError('`specificity` must be in the range [0, 1].') 

1682 self.specificity = specificity 

1683 self.num_thresholds = num_thresholds 

1684 super(SensitivityAtSpecificity, self).__init__( 

1685 specificity, 

1686 num_thresholds=num_thresholds, 

1687 class_id=class_id, 

1688 name=name, 

1689 dtype=dtype) 

1690 

1691 def result(self): 

1692 specificities = math_ops.div_no_nan( 

1693 self.true_negatives, self.true_negatives + self.false_positives) 

1694 sensitivities = math_ops.div_no_nan( 

1695 self.true_positives, self.true_positives + self.false_negatives) 

1696 return self._find_max_under_constraint( 

1697 specificities, sensitivities, math_ops.greater_equal) 

1698 

1699 def get_config(self): 

1700 config = { 

1701 'num_thresholds': self.num_thresholds, 

1702 'specificity': self.specificity 

1703 } 

1704 base_config = super(SensitivityAtSpecificity, self).get_config() 

1705 return dict(list(base_config.items()) + list(config.items())) 

1706 

1707 

1708@keras_export('keras.metrics.SpecificityAtSensitivity') 

1709class SpecificityAtSensitivity(SensitivitySpecificityBase): 

1710 """Computes best specificity where sensitivity is >= specified value. 

1711 

1712 `Sensitivity` measures the proportion of actual positives that are correctly 

1713 identified as such (tp / (tp + fn)). 

1714 `Specificity` measures the proportion of actual negatives that are correctly 

1715 identified as such (tn / (tn + fp)). 

1716 

1717 This metric creates four local variables, `true_positives`, `true_negatives`, 

1718 `false_positives` and `false_negatives` that are used to compute the 

1719 specificity at the given sensitivity. The threshold for the given sensitivity 

1720 value is computed and used to evaluate the corresponding specificity. 

1721 

1722 If `sample_weight` is `None`, weights default to 1. 

1723 Use `sample_weight` of 0 to mask values. 

1724 

1725 If `class_id` is specified, we calculate the metric by considering only the

1726 entries in the batch for which `class_id` is above the threshold in the

1727 predictions, and computing the fraction of them for which `class_id` is

1728 indeed a correct label.

1729 

1730 For additional information about specificity and sensitivity, see 

1731 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 

1732 

1733 Args: 

1734 sensitivity: A scalar value in range `[0, 1]`. 

1735 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1736 use for matching the given sensitivity. 

1737 class_id: (Optional) Integer class ID for which we want binary metrics. 

1738 This must be in the half-open interval `[0, num_classes)`, where 

1739 `num_classes` is the last dimension of predictions. 

1740 name: (Optional) string name of the metric instance. 

1741 dtype: (Optional) data type of the metric result. 

1742 

1743 Standalone usage: 

1744 

1745 >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) 

1746 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 

1747 >>> m.result().numpy() 

1748 0.66666667 

1749 

1750 >>> m.reset_state() 

1751 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 

1752 ... sample_weight=[1, 1, 2, 2, 2]) 

1753 >>> m.result().numpy() 

1754 0.5 

1755 

1756 Usage with `compile()` API: 

1757 

1758 ```python 

1759 model.compile( 

1760 optimizer='sgd', 

1761 loss='mse', 

1762 metrics=[tf.keras.metrics.SpecificityAtSensitivity()]) 

1763 ``` 

1764 """ 

1765 

1766 def __init__(self, 

1767 sensitivity, 

1768 num_thresholds=200, 

1769 class_id=None, 

1770 name=None, 

1771 dtype=None): 

1772 if sensitivity < 0 or sensitivity > 1: 

1773 raise ValueError('`sensitivity` must be in the range [0, 1].') 

1774 self.sensitivity = sensitivity 

1775 self.num_thresholds = num_thresholds 

1776 super(SpecificityAtSensitivity, self).__init__( 

1777 sensitivity, 

1778 num_thresholds=num_thresholds, 

1779 class_id=class_id, 

1780 name=name, 

1781 dtype=dtype) 

1782 

1783 def result(self): 

1784 sensitivities = math_ops.div_no_nan( 

1785 self.true_positives, self.true_positives + self.false_negatives) 

1786 specificities = math_ops.div_no_nan( 

1787 self.true_negatives, self.true_negatives + self.false_positives) 

1788 return self._find_max_under_constraint( 

1789 sensitivities, specificities, math_ops.greater_equal) 

1790 

1791 def get_config(self): 

1792 config = { 

1793 'num_thresholds': self.num_thresholds, 

1794 'sensitivity': self.sensitivity 

1795 } 

1796 base_config = super(SpecificityAtSensitivity, self).get_config() 

1797 return dict(list(base_config.items()) + list(config.items())) 

1798 

1799 

1800@keras_export('keras.metrics.PrecisionAtRecall') 

1801class PrecisionAtRecall(SensitivitySpecificityBase): 

1802 """Computes best precision where recall is >= specified value. 

1803 

1804 This metric creates four local variables, `true_positives`, `true_negatives`, 

1805 `false_positives` and `false_negatives` that are used to compute the 

1806 precision at the given recall. The threshold for the given recall 

1807 value is computed and used to evaluate the corresponding precision. 

1808 

1809 If `sample_weight` is `None`, weights default to 1. 

1810 Use `sample_weight` of 0 to mask values. 

1811 

1812 If `class_id` is specified, we calculate precision by considering only the

1813 entries in the batch for which `class_id` is above the threshold in the

1814 predictions, and computing the fraction of them for which `class_id` is

1815 indeed a correct label.

1816 

1817 Args: 

1818 recall: A scalar value in range `[0, 1]`. 

1819 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1820 use for matching the given recall. 

1821 class_id: (Optional) Integer class ID for which we want binary metrics. 

1822 This must be in the half-open interval `[0, num_classes)`, where 

1823 `num_classes` is the last dimension of predictions. 

1824 name: (Optional) string name of the metric instance. 

1825 dtype: (Optional) data type of the metric result. 

1826 

1827 Standalone usage: 

1828 

1829 >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) 

1830 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 

1831 >>> m.result().numpy() 

1832 0.5 

1833 

1834 >>> m.reset_state() 

1835 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 

1836 ... sample_weight=[2, 2, 2, 1, 1]) 

1837 >>> m.result().numpy() 

1838 0.33333333 

1839 

1840 Usage with `compile()` API: 

1841 

1842 ```python 

1843 model.compile( 

1844 optimizer='sgd', 

1845 loss='mse', 

1846 metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)]) 

1847 ``` 

1848 """ 

1849 

1850 def __init__(self, 

1851 recall, 

1852 num_thresholds=200, 

1853 class_id=None, 

1854 name=None, 

1855 dtype=None): 

1856 if recall < 0 or recall > 1: 

1857 raise ValueError('`recall` must be in the range [0, 1].') 

1858 self.recall = recall 

1859 self.num_thresholds = num_thresholds 

1860 super(PrecisionAtRecall, self).__init__( 

1861 value=recall, 

1862 num_thresholds=num_thresholds, 

1863 class_id=class_id, 

1864 name=name, 

1865 dtype=dtype) 

1866 

1867 def result(self): 

1868 recalls = math_ops.div_no_nan( 

1869 self.true_positives, self.true_positives + self.false_negatives) 

1870 precisions = math_ops.div_no_nan( 

1871 self.true_positives, self.true_positives + self.false_positives) 

1872 return self._find_max_under_constraint( 

1873 recalls, precisions, math_ops.greater_equal) 

1874 

1875 def get_config(self): 

1876 config = {'num_thresholds': self.num_thresholds, 'recall': self.recall} 

1877 base_config = super(PrecisionAtRecall, self).get_config() 

1878 return dict(list(base_config.items()) + list(config.items())) 

1879 

1880 

1881@keras_export('keras.metrics.RecallAtPrecision') 

1882class RecallAtPrecision(SensitivitySpecificityBase): 

1883 """Computes best recall where precision is >= specified value. 

1884 

1885 For a given score-label-distribution, the required precision might not

1886 be achievable; in this case, 0.0 is returned as the recall.

1887 

1888 This metric creates four local variables, `true_positives`, `true_negatives`, 

1889 `false_positives` and `false_negatives` that are used to compute the 

1890 recall at the given precision. The threshold for the given precision 

1891 value is computed and used to evaluate the corresponding recall. 

1892 

1893 If `sample_weight` is `None`, weights default to 1. 

1894 Use `sample_weight` of 0 to mask values. 

1895 

1896 If `class_id` is specified, we calculate the metric by considering only the

1897 entries in the batch for which `class_id` is above the threshold in the

1898 predictions, and computing the fraction of them for which `class_id` is

1899 indeed a correct label.

1900 

1901 Args: 

1902 precision: A scalar value in range `[0, 1]`. 

1903 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

1904 use for matching the given precision. 

1905 class_id: (Optional) Integer class ID for which we want binary metrics. 

1906 This must be in the half-open interval `[0, num_classes)`, where 

1907 `num_classes` is the last dimension of predictions. 

1908 name: (Optional) string name of the metric instance. 

1909 dtype: (Optional) data type of the metric result. 

1910 

1911 Standalone usage: 

1912 

1913 >>> m = tf.keras.metrics.RecallAtPrecision(0.8) 

1914 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 

1915 >>> m.result().numpy() 

1916 0.5 

1917 

1918 >>> m.reset_state() 

1919 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 

1920 ... sample_weight=[1, 0, 0, 1]) 

1921 >>> m.result().numpy() 

1922 1.0 

1923 

1924 Usage with `compile()` API: 

1925 

1926 ```python 

1927 model.compile( 

1928 optimizer='sgd', 

1929 loss='mse', 

1930 metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)]) 

1931 ``` 

1932 """ 

1933 

1934 def __init__(self, 

1935 precision, 

1936 num_thresholds=200, 

1937 class_id=None, 

1938 name=None, 

1939 dtype=None): 

1940 if precision < 0 or precision > 1: 

1941 raise ValueError('`precision` must be in the range [0, 1].') 

1942 self.precision = precision 

1943 self.num_thresholds = num_thresholds 

1944 super(RecallAtPrecision, self).__init__( 

1945 value=precision, 

1946 num_thresholds=num_thresholds, 

1947 class_id=class_id, 

1948 name=name, 

1949 dtype=dtype) 

1950 

1951 def result(self): 

1952 precisions = math_ops.div_no_nan( 

1953 self.true_positives, self.true_positives + self.false_positives) 

1954 recalls = math_ops.div_no_nan( 

1955 self.true_positives, self.true_positives + self.false_negatives) 

1956 return self._find_max_under_constraint( 

1957 precisions, recalls, math_ops.greater_equal) 

1958 

1959 def get_config(self): 

1960 config = {'num_thresholds': self.num_thresholds, 

1961 'precision': self.precision} 

1962 base_config = super(RecallAtPrecision, self).get_config() 

1963 return dict(list(base_config.items()) + list(config.items())) 

1964 

1965 

1966@keras_export('keras.metrics.AUC') 

1967class AUC(Metric): 

1968 """Approximates the AUC (Area under the curve) of the ROC or PR curves. 

1969 

1970 The AUC (Area under the curve) of the ROC (Receiver operating 

1971 characteristic; default) or PR (Precision-Recall) curves is a quality measure

1972 of binary classifiers. Unlike accuracy, and like cross-entropy

1973 losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. 

1974 

1975 This class approximates AUCs using a Riemann sum. During the metric 

1976 accumulation phase, predictions are accumulated within predefined buckets

1977 by value. The AUC is then computed by interpolating per-bucket averages. These 

1978 buckets define the evaluated operational points. 

1979 

1980 This metric creates four local variables, `true_positives`, `true_negatives`, 

1981 `false_positives` and `false_negatives` that are used to compute the AUC. 

1982 To discretize the AUC curve, a linearly spaced set of thresholds is used to 

1983 compute pairs of recall and precision values. The area under the ROC-curve is 

1984 therefore computed using the height of the recall values by the false positive 

1985 rate, while the area under the PR-curve is computed using the height of

1986 the precision values by the recall. 

1987 

1988 This value is ultimately returned as `auc`, an idempotent operation that 

1989 computes the area under a discretized curve of precision versus recall values 

1990 (computed using the aforementioned variables). The `num_thresholds` variable 

1991 controls the degree of discretization with larger numbers of thresholds more 

1992 closely approximating the true AUC. The quality of the approximation may vary 

1993 dramatically depending on `num_thresholds`. The `thresholds` parameter can be 

1994 used to manually specify thresholds which split the predictions more evenly. 

1995 

1996 For the best approximation of the real AUC, `predictions` should be distributed

1997 approximately uniformly in the range [0, 1] (if `from_logits=False`). The 

1998 quality of the AUC approximation may be poor if this is not the case. Setting 

1999 `summation_method` to 'minoring' or 'majoring' can help quantify the error in 

2000 the approximation by providing a lower or an upper bound estimate of the AUC.

2001 

2002 If `sample_weight` is `None`, weights default to 1. 

2003 Use `sample_weight` of 0 to mask values. 

2004 

2005 Args: 

2006 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 

2007 use when discretizing the roc curve. Values must be > 1. 

2008 curve: (Optional) Specifies the name of the curve to be computed, 'ROC' 

2009 [default] or 'PR' for the Precision-Recall-curve. 

2010 summation_method: (Optional) Specifies the [Riemann summation method]( 

2011 https://en.wikipedia.org/wiki/Riemann_sum) used. 

2012 'interpolation' (default) applies the mid-point summation scheme for `ROC`.

2013 For PR-AUC, interpolates (true/false) positives but not the ratio that 

2014 is precision (see Davis & Goadrich 2006 for details); 

2015 'minoring' applies left summation 

2016 for increasing intervals and right summation for decreasing intervals; 

2017 'majoring' does the opposite. 

2018 name: (Optional) string name of the metric instance. 

2019 dtype: (Optional) data type of the metric result. 

2020 thresholds: (Optional) A list of floating point values to use as the 

2021 thresholds for discretizing the curve. If set, the `num_thresholds` 

2022 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds 

2023 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will 

2024 be automatically included with these to correctly handle predictions 

2025 equal to exactly 0 or 1. 

2026 multi_label: boolean indicating whether multilabel data should be 

2027 treated as such, wherein AUC is computed separately for each label and 

2028 then averaged across labels, or (when False) whether the data should be

2029 flattened into a single label before AUC computation. In the latter 

2030 case, when multilabel data is passed to AUC, each label-prediction pair 

2031 is treated as an individual data point. Should be set to False for 

2032 multi-class data. 

2033 num_labels: (Optional) The number of labels, used when `multi_label` is 

2034 True. If `num_labels` is not specified, then state variables get created 

2035 on the first call to `update_state`. 

2036 label_weights: (Optional) list, array, or tensor of non-negative weights 

2037 used to compute AUCs for multilabel data. When `multi_label` is True, 

2038 the weights are applied to the individual label AUCs when they are 

2039 averaged to produce the multi-label AUC. When it's False, they are used 

2040 to weight the individual label predictions in computing the confusion 

2041 matrix on the flattened data. Note that this is unlike class_weights in 

2042 that class_weights weights the example depending on the value of its 

2043 label, whereas label_weights depends only on the index of that label 

2044 before flattening; therefore `label_weights` should not be used for 

2045 multi-class data. 

2046 from_logits: boolean indicating whether the predictions (`y_pred` in 

2047 `update_state`) are probabilities or sigmoid logits. As a rule of thumb, 

2048 when using a keras loss, the `from_logits` constructor argument of the 

2049 loss should match the AUC `from_logits` constructor argument. 

2050 

2051 Standalone usage: 

2052 

2053 >>> m = tf.keras.metrics.AUC(num_thresholds=3) 

2054 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 

2055 >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] 

2056 >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] 

2057 >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0] 

2058 >>> # auc = ((((1+0.5)/2)*(1-0)) + (((0.5+0)/2)*(0-0))) = 0.75 

2059 >>> m.result().numpy() 

2060 0.75 

2061 

2062 >>> m.reset_state() 

2063 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 

2064 ... sample_weight=[1, 0, 0, 1]) 

2065 >>> m.result().numpy() 

2066 1.0 

2067 

2068 Usage with `compile()` API: 

2069 

2070 ```python 

2071 # Reports the AUC of a model outputting a probability.

2072 model.compile(optimizer='sgd', 

2073 loss=tf.keras.losses.BinaryCrossentropy(), 

2074 metrics=[tf.keras.metrics.AUC()]) 

2075 

2076 # Reports the AUC of a model outputting a logit.

2077 model.compile(optimizer='sgd', 

2078 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 

2079 metrics=[tf.keras.metrics.AUC(from_logits=True)]) 

2080 ``` 

2081 """ 

2082 

2083 def __init__(self, 

2084 num_thresholds=200, 

2085 curve='ROC', 

2086 summation_method='interpolation', 

2087 name=None, 

2088 dtype=None, 

2089 thresholds=None, 

2090 multi_label=False, 

2091 num_labels=None, 

2092 label_weights=None, 

2093 from_logits=False): 

2094 # Validate configurations. 

2095 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( 

2096 metrics_utils.AUCCurve): 

2097 raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format( 

2098 curve, list(metrics_utils.AUCCurve))) 

2099 if isinstance( 

2100 summation_method, 

2101 metrics_utils.AUCSummationMethod) and summation_method not in list( 

2102 metrics_utils.AUCSummationMethod): 

2103 raise ValueError( 

2104 'Invalid summation method: "{}". Valid options are: "{}"'.format( 

2105 summation_method, list(metrics_utils.AUCSummationMethod))) 

2106 

2107 # Update properties. 

2108 if thresholds is not None: 

2109 # If specified, use the supplied thresholds. 

2110 self.num_thresholds = len(thresholds) + 2 

2111 thresholds = sorted(thresholds) 

2112 self._thresholds_distributed_evenly = ( 

2113 metrics_utils.is_evenly_distributed_thresholds( 

2114 np.array([0.0] + thresholds + [1.0]))) 

2115 else: 

2116 if num_thresholds <= 1: 

2117 raise ValueError('`num_thresholds` must be > 1.') 

2118 

2119 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in 

2120 # (0, 1). 

2121 self.num_thresholds = num_thresholds 

2122 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 

2123 for i in range(num_thresholds - 2)] 

2124 self._thresholds_distributed_evenly = True 

2125 

2126 # Add an endpoint "threshold" below zero and above one for either 

2127 # threshold method to account for floating point imprecisions. 

2128 self._thresholds = np.array([0.0 - backend.epsilon()] + thresholds + 

2129 [1.0 + backend.epsilon()]) 
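# Editorial note -- illustrative, not part of the original source: with
# num_thresholds=3 the interpolation above yields thresholds=[0.5], so
# self._thresholds becomes approximately [0.0 - 1e-7, 0.5, 1.0 + 1e-7]
# (backend.epsilon() defaults to 1e-7), matching the threshold values quoted
# in the class docstring's standalone-usage example.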

2130 

2131 if isinstance(curve, metrics_utils.AUCCurve): 

2132 self.curve = curve 

2133 else: 

2134 self.curve = metrics_utils.AUCCurve.from_str(curve) 

2135 if isinstance(summation_method, metrics_utils.AUCSummationMethod): 

2136 self.summation_method = summation_method 

2137 else: 

2138 self.summation_method = metrics_utils.AUCSummationMethod.from_str( 

2139 summation_method) 

2140 super(AUC, self).__init__(name=name, dtype=dtype) 

2141 

2142 # Handle multilabel arguments. 

2143 self.multi_label = multi_label 

2144 if label_weights is not None: 

2145 label_weights = constant_op.constant(label_weights, dtype=self.dtype) 

2146 checks = [ 

2147 check_ops.assert_non_negative( 

2148 label_weights, 

2149 message='All values of `label_weights` must be non-negative.') 

2150 ] 

2151 with ops.control_dependencies(checks): 

2152 self.label_weights = label_weights 

2153 

2154 else: 

2155 self.label_weights = None 

2156 

2157 self._from_logits = from_logits 

2158 

2159 self._built = False 

2160 if self.multi_label: 

2161 if num_labels: 

2162 shape = tensor_shape.TensorShape([None, num_labels]) 

2163 self._build(shape) 

2164 else: 

2165 if num_labels: 

2166 raise ValueError( 

2167 '`num_labels` is needed only when `multi_label` is True.') 

2168 self._build(None) 

2169 

2170 @property 

2171 def thresholds(self): 

2172 """The thresholds used for evaluating AUC.""" 

2173 return list(self._thresholds) 

2174 

2175 def _build(self, shape): 

2176 """Initialize TP, FP, TN, and FN tensors, given the shape of the data.""" 

2177 if self.multi_label: 

2178 if shape.ndims != 2: 

2179 raise ValueError('`y_true` must have rank=2 when `multi_label` is ' 

2180 'True. Found rank %s.' % shape.ndims) 

2181 self._num_labels = shape[1] 

2182 variable_shape = tensor_shape.TensorShape( 

2183 [tensor_shape.Dimension(self.num_thresholds), self._num_labels]) 

2184 

2185 else: 

2186 variable_shape = tensor_shape.TensorShape( 

2187 [tensor_shape.Dimension(self.num_thresholds)]) 

2188 self._build_input_shape = shape 

2189 # Create metric variables 

2190 self.true_positives = self.add_weight( 

2191 'true_positives', 

2192 shape=variable_shape, 

2193 initializer=init_ops.zeros_initializer) 

2194 self.true_negatives = self.add_weight( 

2195 'true_negatives', 

2196 shape=variable_shape, 

2197 initializer=init_ops.zeros_initializer) 

2198 self.false_positives = self.add_weight( 

2199 'false_positives', 

2200 shape=variable_shape, 

2201 initializer=init_ops.zeros_initializer) 

2202 self.false_negatives = self.add_weight( 

2203 'false_negatives', 

2204 shape=variable_shape, 

2205 initializer=init_ops.zeros_initializer) 

2206 

2207 if self.multi_label: 

2208 with ops.init_scope(): 

2209 # This should only be necessary for handling v1 behavior. In v2, AUC 

2210 # should be initialized outside of any tf.functions, and therefore in 

2211 # eager mode. 

2212 if not context.executing_eagerly(): 

2213 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access 

2214 

2215 self._built = True 

2216 

2217 def update_state(self, y_true, y_pred, sample_weight=None): 

2218 """Accumulates confusion matrix statistics. 

2219 

2220 Args: 

2221 y_true: The ground truth values. 

2222 y_pred: The predicted values. 

2223 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

2224 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

2225 be broadcastable to `y_true`. 

2226 

2227 Returns: 

2228 Update op. 

2229 """ 

2230 deps = [] 

2231 if not self._built: 

2232 self._build(tensor_shape.TensorShape(y_pred.shape)) 

2233 

2234 if self.multi_label or (self.label_weights is not None): 

2235 # y_true should have shape (number of examples, number of labels). 

2236 shapes = [ 

2237 (y_true, ('N', 'L')) 

2238 ] 

2239 if self.multi_label: 

2240 # TP, TN, FP, and FN should all have shape 

2241 # (number of thresholds, number of labels). 

2242 shapes.extend([(self.true_positives, ('T', 'L')), 

2243 (self.true_negatives, ('T', 'L')), 

2244 (self.false_positives, ('T', 'L')), 

2245 (self.false_negatives, ('T', 'L'))]) 

2246 if self.label_weights is not None: 

2247 # label_weights should be of length equal to the number of labels. 

2248 shapes.append((self.label_weights, ('L',))) 

2249 deps = [ 

2250 check_ops.assert_shapes( 

2251 shapes, message='Number of labels is not consistent.') 

2252 ] 

2253 

2254 # Only forward label_weights to update_confusion_matrix_variables when 

2255 # multi_label is False. Otherwise the averaging of individual label AUCs is 

2256 # handled in AUC.result 

2257 label_weights = None if self.multi_label else self.label_weights 

2258 

2259 if self._from_logits: 

2260 y_pred = activations.sigmoid(y_pred) 

2261 

2262 with ops.control_dependencies(deps): 

2263 return metrics_utils.update_confusion_matrix_variables( 

2264 { 

2265 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: 

2266 self.true_positives, 

2267 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: 

2268 self.true_negatives, 

2269 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: 

2270 self.false_positives, 

2271 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: 

2272 self.false_negatives, 

2273 }, 

2274 y_true, 

2275 y_pred, 

2276 self._thresholds, 

2277 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 

2278 sample_weight=sample_weight, 

2279 multi_label=self.multi_label, 

2280 label_weights=label_weights) 

2281 

2282 def interpolate_pr_auc(self): 

2283 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 

2284 

2285 https://www.biostat.wisc.edu/~page/rocpr.pdf 

2286 

2287 Note here we derive & use a closed formula not present in the paper 

2288 as follows: 

2289 

2290 Precision = TP / (TP + FP) = TP / P 

2291 

2292 Modeling all of TP (true positive), FP (false positive) and their sum 

2293 P = TP + FP (predicted positive) as varying linearly within each interval 

2294 [A, B] between successive thresholds, we get 

2295 

2296 Precision slope = dTP / dP 

2297 = (TP_B - TP_A) / (P_B - P_A) 

2298 = (TP - TP_A) / (P - P_A) 

2299 Precision = (TP_A + slope * (P - P_A)) / P 

2300 

2301 The area within the interval is (slope / total_pos_weight) times 

2302 

2303 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 

2304 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 

2305 

2306 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 

2307 

2308 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 

2309 

2310 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get 

2311 

2312 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 

2313 

2314 where dTP == TP_B - TP_A. 

2315 

2316 Note that when P_A == 0 the above calculation simplifies into 

2317 

2318 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 

2319 

2320 which is really equivalent to imputing constant precision throughout the 

2321 first bucket having >0 true positives. 

2322 

2323 Returns: 

2324 pr_auc: an approximation of the area under the P-R curve. 

2325 """ 

2326 dtp = self.true_positives[:self.num_thresholds - 

2327 1] - self.true_positives[1:] 

2328 p = self.true_positives + self.false_positives 

2329 dp = p[:self.num_thresholds - 1] - p[1:] 

2330 prec_slope = math_ops.div_no_nan( 

2331 dtp, math_ops.maximum(dp, 0), name='prec_slope') 

2332 intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:]) 

2333 

2334 safe_p_ratio = array_ops.where( 

2335 math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0), 

2336 math_ops.div_no_nan( 

2337 p[:self.num_thresholds - 1], 

2338 math_ops.maximum(p[1:], 0), 

2339 name='recall_relative_ratio'), 

2340 array_ops.ones_like(p[1:])) 

2341 

2342 pr_auc_increment = math_ops.div_no_nan( 

2343 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), 

2344 math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 0), 

2345 name='pr_auc_increment') 

2346 

2347 if self.multi_label: 

2348 by_label_auc = math_ops.reduce_sum( 

2349 pr_auc_increment, name=self.name + '_by_label', axis=0) 

2350 if self.label_weights is None: 

2351 # Evenly weighted average of the label AUCs. 

2352 return math_ops.reduce_mean(by_label_auc, name=self.name) 

2353 else: 

2354 # Weighted average of the label AUCs. 

2355 return math_ops.div_no_nan( 

2356 math_ops.reduce_sum( 

2357 math_ops.multiply(by_label_auc, self.label_weights)), 

2358 math_ops.reduce_sum(self.label_weights), 

2359 name=self.name) 

2360 else: 

2361 return math_ops.reduce_sum(pr_auc_increment, name='interpolate_pr_auc') 
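# Editorial note -- rough NumPy sketch of the per-interval formula above for
# the single-label case (illustrative values taken from the AUC class
# docstring example; the real code guards zero denominators with div_no_nan):
#
#   import numpy as np
#   tp = np.array([2., 1., 0.])   # true positives per threshold
#   fp = np.array([2., 0., 0.])   # false positives per threshold
#   fn = np.array([0., 1., 2.])   # false negatives per threshold
#   p = tp + fp                                        # predicted positives
#   dtp, dp = tp[:-1] - tp[1:], p[:-1] - p[1:]
#   slope = dtp / dp
#   intercept = tp[1:] - slope * p[1:]
#   ratio = np.where((p[:-1] > 0) & (p[1:] > 0),
#                    p[:-1] / np.maximum(p[1:], 1e-12), 1.0)
#   increment = slope * (dtp + intercept * np.log(ratio)) / (tp[1:] + fn[1:])
#   pr_auc = increment.sum()   # ~0.82 for these toy values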

2362 

2363 def result(self): 

2364 if (self.curve == metrics_utils.AUCCurve.PR and 

2365 self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION 

2366 ): 

2367 # This use case is different and is handled separately. 

2368 return self.interpolate_pr_auc() 

2369 

2370 # Set `x` and `y` values for the curves based on `curve` config. 

2371 recall = math_ops.div_no_nan(self.true_positives, 

2372 self.true_positives + self.false_negatives) 

2373 if self.curve == metrics_utils.AUCCurve.ROC: 

2374 fp_rate = math_ops.div_no_nan(self.false_positives, 

2375 self.false_positives + self.true_negatives) 

2376 x = fp_rate 

2377 y = recall 

2378 else: # curve == 'PR'. 

2379 precision = math_ops.div_no_nan( 

2380 self.true_positives, self.true_positives + self.false_positives) 

2381 x = recall 

2382 y = precision 

2383 

2384 # Find the rectangle heights based on `summation_method`. 

2385 if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION: 

2386 # Note: the case ('PR', 'interpolation') has been handled above. 

2387 heights = (y[:self.num_thresholds - 1] + y[1:]) / 2. 

2388 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: 

2389 heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:]) 

2390 else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: 

2391 heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:]) 

2392 

2393 # Sum up the areas of all the rectangles. 

2394 if self.multi_label: 

2395 riemann_terms = math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], 

2396 heights) 

2397 by_label_auc = math_ops.reduce_sum( 

2398 riemann_terms, name=self.name + '_by_label', axis=0) 

2399 

2400 if self.label_weights is None: 

2401 # Unweighted average of the label AUCs. 

2402 return math_ops.reduce_mean(by_label_auc, name=self.name) 

2403 else: 

2404 # Weighted average of the label AUCs. 

2405 return math_ops.div_no_nan( 

2406 math_ops.reduce_sum( 

2407 math_ops.multiply(by_label_auc, self.label_weights)), 

2408 math_ops.reduce_sum(self.label_weights), 

2409 name=self.name) 

2410 else: 

2411 return math_ops.reduce_sum( 

2412 math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights), 

2413 name=self.name) 

2414 

2415 def reset_state(self): 

2416 if self._built: 

2417 confusion_matrix_variables = (self.true_positives, self.true_negatives, 

2418 self.false_positives, self.false_negatives) 

2419 if self.multi_label: 

2420 backend.batch_set_value( 

2421 [(v, np.zeros((self.num_thresholds, self._num_labels))) 

2422 for v in confusion_matrix_variables]) 

2423 else: 

2424 backend.batch_set_value([(v, np.zeros((self.num_thresholds,))) 

2425 for v in confusion_matrix_variables]) 

2426 

2427 def get_config(self): 

2428 if is_tensor_or_variable(self.label_weights): 

2429 label_weights = backend.eval(self.label_weights) 

2430 else: 

2431 label_weights = self.label_weights 

2432 config = { 

2433 'num_thresholds': self.num_thresholds, 

2434 'curve': self.curve.value, 

2435 'summation_method': self.summation_method.value, 

2436 # We remove the endpoint thresholds as an inverse of how the thresholds 

2437 # were initialized. This ensures that a metric initialized from this 

2438 # config has the same thresholds. 

2439 'thresholds': self.thresholds[1:-1], 

2440 'multi_label': self.multi_label, 

2441 'label_weights': label_weights 

2442 } 

2443 base_config = super(AUC, self).get_config() 

2444 return dict(list(base_config.items()) + list(config.items())) 

2445 

2446 

2447@keras_export('keras.metrics.CosineSimilarity') 

2448class CosineSimilarity(MeanMetricWrapper): 

2449 """Computes the cosine similarity between the labels and predictions. 

2450 

2451 `cosine similarity = (a . b) / ||a|| ||b||` 

2452 

2453 See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity). 

2454 

2455 This metric keeps the average cosine similarity between `predictions` and 

2456 `labels` over a stream of data. 

2457 

2458 Args: 

2459 name: (Optional) string name of the metric instance. 

2460 dtype: (Optional) data type of the metric result. 

2461 axis: (Optional) Defaults to -1. The dimension along which the cosine 

2462 similarity is computed. 

2463 

2464 Standalone usage: 

2465 

2466 >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]] 

2467 >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]] 

2468 >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] 

2469 >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) 

2470 >>> # = ((0. + 0.) + (0.5 + 0.5)) / 2 

2471 >>> m = tf.keras.metrics.CosineSimilarity(axis=1) 

2472 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) 

2473 >>> m.result().numpy() 

2474 0.49999997 

2475 

2476 >>> m.reset_state() 

2477 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]], 

2478 ... sample_weight=[0.3, 0.7]) 

2479 >>> m.result().numpy() 

2480 0.6999999 

2481 

2482 Usage with `compile()` API: 

2483 

2484 ```python 

2485 model.compile( 

2486 optimizer='sgd', 

2487 loss='mse', 

2488 metrics=[tf.keras.metrics.CosineSimilarity(axis=1)]) 

2489 ``` 

2490 """ 

2491 

2492 def __init__(self, name='cosine_similarity', dtype=None, axis=-1): 

2493 super(CosineSimilarity, self).__init__( 

2494 cosine_similarity, name, dtype=dtype, axis=axis) 

2495 

2496 

2497@keras_export('keras.metrics.MeanAbsoluteError') 

2498class MeanAbsoluteError(MeanMetricWrapper): 

2499 """Computes the mean absolute error between the labels and predictions. 

2500 

2501 Args: 

2502 name: (Optional) string name of the metric instance. 

2503 dtype: (Optional) data type of the metric result. 

2504 

2505 Standalone usage: 

2506 

2507 >>> m = tf.keras.metrics.MeanAbsoluteError() 

2508 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2509 >>> m.result().numpy() 

2510 0.25 

2511 

2512 >>> m.reset_state() 

2513 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2514 ... sample_weight=[1, 0]) 

2515 >>> m.result().numpy() 

2516 0.5 

2517 

2518 Usage with `compile()` API: 

2519 

2520 ```python 

2521 model.compile( 

2522 optimizer='sgd', 

2523 loss='mse', 

2524 metrics=[tf.keras.metrics.MeanAbsoluteError()]) 

2525 ``` 

2526 """ 

2527 

2528 def __init__(self, name='mean_absolute_error', dtype=None): 

2529 super(MeanAbsoluteError, self).__init__( 

2530 mean_absolute_error, name, dtype=dtype) 

2531 

2532 

2533@keras_export('keras.metrics.MeanAbsolutePercentageError') 

2534class MeanAbsolutePercentageError(MeanMetricWrapper): 

2535 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 

2536 

2537 Args: 

2538 name: (Optional) string name of the metric instance. 

2539 dtype: (Optional) data type of the metric result. 

2540 

2541 Standalone usage: 

2542 

2543 >>> m = tf.keras.metrics.MeanAbsolutePercentageError() 

2544 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2545 >>> m.result().numpy() 

2546 250000000.0 

2547 

2548 >>> m.reset_state() 

2549 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2550 ... sample_weight=[1, 0]) 

2551 >>> m.result().numpy() 

2552 500000000.0 

2553 

2554 Usage with `compile()` API: 

2555 

2556 ```python 

2557 model.compile( 

2558 optimizer='sgd', 

2559 loss='mse', 

2560 metrics=[tf.keras.metrics.MeanAbsolutePercentageError()]) 

2561 ``` 

2562 """ 

2563 

2564 def __init__(self, name='mean_absolute_percentage_error', dtype=None): 

2565 super(MeanAbsolutePercentageError, self).__init__( 

2566 mean_absolute_percentage_error, name, dtype=dtype) 

2567 

2568 

2569@keras_export('keras.metrics.MeanSquaredError') 

2570class MeanSquaredError(MeanMetricWrapper): 

2571 """Computes the mean squared error between `y_true` and `y_pred`. 

2572 

2573 Args: 

2574 name: (Optional) string name of the metric instance. 

2575 dtype: (Optional) data type of the metric result. 

2576 

2577 Standalone usage: 

2578 

2579 >>> m = tf.keras.metrics.MeanSquaredError() 

2580 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2581 >>> m.result().numpy() 

2582 0.25 

2583 

2584 >>> m.reset_state() 

2585 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2586 ... sample_weight=[1, 0]) 

2587 >>> m.result().numpy() 

2588 0.5 

2589 

2590 Usage with `compile()` API: 

2591 

2592 ```python 

2593 model.compile( 

2594 optimizer='sgd', 

2595 loss='mse', 

2596 metrics=[tf.keras.metrics.MeanSquaredError()]) 

2597 ``` 

2598 """ 

2599 

2600 def __init__(self, name='mean_squared_error', dtype=None): 

2601 super(MeanSquaredError, self).__init__( 

2602 mean_squared_error, name, dtype=dtype) 

2603 

2604 

2605@keras_export('keras.metrics.MeanSquaredLogarithmicError') 

2606class MeanSquaredLogarithmicError(MeanMetricWrapper): 

2607 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 

2608 

2609 Args: 

2610 name: (Optional) string name of the metric instance. 

2611 dtype: (Optional) data type of the metric result. 

2612 

2613 Standalone usage: 

2614 

2615 >>> m = tf.keras.metrics.MeanSquaredLogarithmicError() 

2616 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2617 >>> m.result().numpy() 

2618 0.12011322 

2619 

2620 >>> m.reset_state() 

2621 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2622 ... sample_weight=[1, 0]) 

2623 >>> m.result().numpy() 

2624 0.24022643 

2625 

2626 Usage with `compile()` API: 

2627 

2628 ```python 

2629 model.compile( 

2630 optimizer='sgd', 

2631 loss='mse', 

2632 metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()]) 

2633 ``` 

2634 """ 

2635 

2636 def __init__(self, name='mean_squared_logarithmic_error', dtype=None): 

2637 super(MeanSquaredLogarithmicError, self).__init__( 

2638 mean_squared_logarithmic_error, name, dtype=dtype) 

2639 

2640 

2641@keras_export('keras.metrics.Hinge') 

2642class Hinge(MeanMetricWrapper): 

2643 """Computes the hinge metric between `y_true` and `y_pred`. 

2644 

2645 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 

2646 provided, we will convert them to -1 or 1.

2647 

2648 Args: 

2649 name: (Optional) string name of the metric instance. 

2650 dtype: (Optional) data type of the metric result. 

2651 

2652 Standalone usage: 

2653 

2654 >>> m = tf.keras.metrics.Hinge() 

2655 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 

2656 >>> m.result().numpy() 

2657 1.3 

2658 

2659 >>> m.reset_state() 

2660 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 

2661 ... sample_weight=[1, 0]) 

2662 >>> m.result().numpy() 

2663 1.1 

2664 

2665 Usage with `compile()` API: 

2666 

2667 ```python 

2668 model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()]) 

2669 ``` 

2670 """ 

2671 

2672 def __init__(self, name='hinge', dtype=None): 

2673 super(Hinge, self).__init__(hinge, name, dtype=dtype) 

2674 

2675 

2676@keras_export('keras.metrics.SquaredHinge') 

2677class SquaredHinge(MeanMetricWrapper): 

2678 """Computes the squared hinge metric between `y_true` and `y_pred`. 

2679 

2680 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 

2681 provided, we will convert them to -1 or 1.

2682 

2683 Args: 

2684 name: (Optional) string name of the metric instance. 

2685 dtype: (Optional) data type of the metric result. 

2686 

2687 Standalone usage: 

2688 

2689 >>> m = tf.keras.metrics.SquaredHinge() 

2690 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 

2691 >>> m.result().numpy() 

2692 1.86 

2693 

2694 >>> m.reset_state() 

2695 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 

2696 ... sample_weight=[1, 0]) 

2697 >>> m.result().numpy() 

2698 1.46 

2699 

2700 Usage with `compile()` API: 

2701 

2702 ```python 

2703 model.compile( 

2704 optimizer='sgd', 

2705 loss='mse', 

2706 metrics=[tf.keras.metrics.SquaredHinge()]) 

2707 ``` 

2708 """ 

2709 

2710 def __init__(self, name='squared_hinge', dtype=None): 

2711 super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype) 

2712 

2713 

2714@keras_export('keras.metrics.CategoricalHinge') 

2715class CategoricalHinge(MeanMetricWrapper): 

2716 """Computes the categorical hinge metric between `y_true` and `y_pred`. 

2717 

2718 Args: 

2719 name: (Optional) string name of the metric instance. 

2720 dtype: (Optional) data type of the metric result. 

2721 

2722 Standalone usage: 

2723 

2724 >>> m = tf.keras.metrics.CategoricalHinge() 

2725 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 

2726 >>> m.result().numpy() 

2727 1.4000001 

2728 

2729 >>> m.reset_state() 

2730 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 

2731 ... sample_weight=[1, 0]) 

2732 >>> m.result().numpy() 

2733 1.2 

2734 

2735 Usage with `compile()` API: 

2736 

2737 ```python 

2738 model.compile( 

2739 optimizer='sgd', 

2740 loss='mse', 

2741 metrics=[tf.keras.metrics.CategoricalHinge()]) 

2742 ``` 

2743 """ 

2744 

2745 def __init__(self, name='categorical_hinge', dtype=None): 

2746 super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype) 

2747 

2748 

2749@keras_export('keras.metrics.RootMeanSquaredError') 

2750class RootMeanSquaredError(Mean): 

2751 """Computes root mean squared error metric between `y_true` and `y_pred`. 

2752 

2753 Standalone usage: 

2754 

2755 >>> m = tf.keras.metrics.RootMeanSquaredError() 

2756 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2757 >>> m.result().numpy() 

2758 0.5 

2759 

2760 >>> m.reset_state() 

2761 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2762 ... sample_weight=[1, 0]) 

2763 >>> m.result().numpy() 

2764 0.70710677 

2765 

2766 Usage with `compile()` API: 

2767 

2768 ```python 

2769 model.compile( 

2770 optimizer='sgd', 

2771 loss='mse', 

2772 metrics=[tf.keras.metrics.RootMeanSquaredError()]) 

2773 ``` 

2774 """ 

2775 

2776 def __init__(self, name='root_mean_squared_error', dtype=None): 

2777 super(RootMeanSquaredError, self).__init__(name, dtype=dtype) 

2778 

2779 def update_state(self, y_true, y_pred, sample_weight=None): 

2780 """Accumulates root mean squared error statistics. 

2781 

2782 Args: 

2783 y_true: The ground truth values. 

2784 y_pred: The predicted values. 

2785 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

2786 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

2787 be broadcastable to `y_true`. 

2788 

2789 Returns: 

2790 Update op. 

2791 """ 

2792 y_true = math_ops.cast(y_true, self._dtype) 

2793 y_pred = math_ops.cast(y_pred, self._dtype) 

2794 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 

2795 y_pred, y_true) 

2796 error_sq = math_ops.squared_difference(y_pred, y_true) 

2797 return super(RootMeanSquaredError, self).update_state( 

2798 error_sq, sample_weight=sample_weight) 

2799 

2800 def result(self): 

2801 return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count)) 
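# Editorial note -- illustrative, not part of the original source: `total` /
# `count` is the running (weighted) mean of the squared errors accumulated in
# update_state, so for the unweighted docstring example the mean squared error
# is 0.25 and the reported result is sqrt(0.25) = 0.5.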

2802 

2803 

2804@keras_export('keras.metrics.LogCoshError') 

2805class LogCoshError(MeanMetricWrapper): 

2806 """Computes the logarithm of the hyperbolic cosine of the prediction error. 

2807 

2808 `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true) 

2809 

2810 Args: 

2811 name: (Optional) string name of the metric instance. 

2812 dtype: (Optional) data type of the metric result. 

2813 

2814 Standalone usage: 

2815 

2816 >>> m = tf.keras.metrics.LogCoshError() 

2817 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2818 >>> m.result().numpy() 

2819 0.10844523 

2820 

2821 >>> m.reset_state() 

2822 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2823 ... sample_weight=[1, 0]) 

2824 >>> m.result().numpy() 

2825 0.21689045 

2826 

2827 Usage with `compile()` API: 

2828 

2829 ```python 

2830 model.compile(optimizer='sgd', 

2831 loss='mse', 

2832 metrics=[tf.keras.metrics.LogCoshError()]) 

2833 ``` 

2834 """ 

2835 

2836 def __init__(self, name='logcosh', dtype=None): 

2837 super(LogCoshError, self).__init__(logcosh, name, dtype=dtype) 
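# Editorial note -- illustrative check, not part of the original source: for
# the unweighted docstring example the only nonzero error is
# y_pred - y_true = 1, so the per-sample values are roughly
# [log(cosh(1)) / 2, 0] ~= [0.2169, 0] and their mean is ~0.1084, matching
# the doctest above.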

2838 

2839 

2840@keras_export('keras.metrics.Poisson') 

2841class Poisson(MeanMetricWrapper): 

2842 """Computes the Poisson metric between `y_true` and `y_pred`. 

2843 

2844 `metric = y_pred - y_true * log(y_pred)` 

2845 

2846 Args: 

2847 name: (Optional) string name of the metric instance. 

2848 dtype: (Optional) data type of the metric result. 

2849 

2850 Standalone usage: 

2851 

2852 >>> m = tf.keras.metrics.Poisson() 

2853 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 

2854 >>> m.result().numpy() 

2855 0.49999997 

2856 

2857 >>> m.reset_state() 

2858 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 

2859 ... sample_weight=[1, 0]) 

2860 >>> m.result().numpy() 

2861 0.99999994 

2862 

2863 Usage with `compile()` API: 

2864 

2865 ```python 

2866 model.compile(optimizer='sgd', 

2867 loss='mse', 

2868 metrics=[tf.keras.metrics.Poisson()]) 

2869 ``` 

2870 """ 

2871 

2872 def __init__(self, name='poisson', dtype=None): 

2873 super(Poisson, self).__init__(poisson, name, dtype=dtype) 

2874 

2875 

2876@keras_export('keras.metrics.KLDivergence') 

2877class KLDivergence(MeanMetricWrapper): 

2878 """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`. 

2879 

2880 `metric = y_true * log(y_true / y_pred)` 

2881 

2882 Args: 

2883 name: (Optional) string name of the metric instance. 

2884 dtype: (Optional) data type of the metric result. 

2885 

2886 Standalone usage: 

2887 

2888 >>> m = tf.keras.metrics.KLDivergence() 

2889 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 

2890 >>> m.result().numpy() 

2891 0.45814306 

2892 

2893 >>> m.reset_state() 

2894 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 

2895 ... sample_weight=[1, 0]) 

2896 >>> m.result().numpy() 

2897 0.9162892 

2898 

2899 Usage with `compile()` API: 

2900 

2901 ```python 

2902 model.compile(optimizer='sgd', 

2903 loss='mse', 

2904 metrics=[tf.keras.metrics.KLDivergence()]) 

2905 ``` 

2906 """ 

2907 

2908 def __init__(self, name='kullback_leibler_divergence', dtype=None): 

2909 super(KLDivergence, self).__init__( 

2910 kullback_leibler_divergence, name, dtype=dtype) 

2911 

2912 

2913@keras_export('keras.metrics.MeanIoU') 

2914class MeanIoU(Metric): 

2915 """Computes the mean Intersection-Over-Union metric. 

2916 

2917 Mean Intersection-Over-Union is a common evaluation metric for semantic image 

2918 segmentation, which first computes the IOU for each semantic class and then 

2919 computes the average over classes. IOU is defined as follows: 

2920 IOU = true_positive / (true_positive + false_positive + false_negative). 

2921 The predictions are accumulated in a confusion matrix, weighted by 

2922 `sample_weight` and the metric is then calculated from it. 

2923 

2924 If `sample_weight` is `None`, weights default to 1. 

2925 Use `sample_weight` of 0 to mask values. 

2926 

2927 Args: 

2928 num_classes: The possible number of labels the prediction task can have. 

2929 This value must be provided, since a confusion matrix of dimension = 

2930 [num_classes, num_classes] will be allocated. 

2931 name: (Optional) string name of the metric instance. 

2932 dtype: (Optional) data type of the metric result. 

2933 

2934 Standalone usage: 

2935 

2936 >>> # cm = [[1, 1], 

2937 >>> # [1, 1]] 

2938 >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] 

2939 >>> # iou = true_positives / (sum_row + sum_col - true_positives)

2940 >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33 

2941 >>> m = tf.keras.metrics.MeanIoU(num_classes=2) 

2942 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) 

2943 >>> m.result().numpy() 

2944 0.33333334 

2945 

2946 >>> m.reset_state() 

2947 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1], 

2948 ... sample_weight=[0.3, 0.3, 0.3, 0.1]) 

2949 >>> m.result().numpy() 

2950 0.23809525 

2951 

2952 Usage with `compile()` API: 

2953 

2954 ```python 

2955 model.compile( 

2956 optimizer='sgd', 

2957 loss='mse', 

2958 metrics=[tf.keras.metrics.MeanIoU(num_classes=2)]) 

2959 ``` 

2960 """ 

2961 

2962 def __init__(self, num_classes, name=None, dtype=None): 

2963 super(MeanIoU, self).__init__(name=name, dtype=dtype) 

2964 self.num_classes = num_classes 

2965 

2966 # Variable to accumulate the predictions in the confusion matrix. 

2967 self.total_cm = self.add_weight( 

2968 'total_confusion_matrix', 

2969 shape=(num_classes, num_classes), 

2970 initializer=init_ops.zeros_initializer) 

2971 

2972 def update_state(self, y_true, y_pred, sample_weight=None): 

2973 """Accumulates the confusion matrix statistics. 

2974 

2975 Args: 

2976 y_true: The ground truth values. 

2977 y_pred: The predicted values. 

2978 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 

2979 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 

2980 be broadcastable to `y_true`. 

2981 

2982 Returns: 

2983 Update op. 

2984 """ 

2985 

2986 y_true = math_ops.cast(y_true, self._dtype) 

2987 y_pred = math_ops.cast(y_pred, self._dtype) 

2988 

2989 # Flatten the input if its rank > 1. 

2990 if y_pred.shape.ndims > 1: 

2991 y_pred = array_ops.reshape(y_pred, [-1]) 

2992 

2993 if y_true.shape.ndims > 1: 

2994 y_true = array_ops.reshape(y_true, [-1]) 

2995 

2996 if sample_weight is not None: 

2997 sample_weight = math_ops.cast(sample_weight, self._dtype) 

2998 if sample_weight.shape.ndims > 1: 

2999 sample_weight = array_ops.reshape(sample_weight, [-1]) 

3000 

3001 # Accumulate the prediction to current confusion matrix. 

3002 current_cm = confusion_matrix.confusion_matrix( 

3003 y_true, 

3004 y_pred, 

3005 self.num_classes, 

3006 weights=sample_weight, 

3007 dtype=self._dtype) 

3008 return self.total_cm.assign_add(current_cm) 

3009 

3010 def result(self): 

3011 """Compute the mean intersection-over-union via the confusion matrix.""" 

3012 sum_over_row = math_ops.cast( 

3013 math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype) 

3014 sum_over_col = math_ops.cast( 

3015 math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype) 

3016 true_positives = math_ops.cast( 

3017 array_ops.tensor_diag_part(self.total_cm), dtype=self._dtype) 

3018 

3019 # sum_over_row + sum_over_col = 

3020 # 2 * true_positives + false_positives + false_negatives. 

3021 denominator = sum_over_row + sum_over_col - true_positives 

3022 

3023 # The mean is only computed over classes that appear in the 

3024 # label or prediction tensor. If the denominator is 0, we need to 

3025 # ignore the class. 

3026 num_valid_entries = math_ops.reduce_sum( 

3027 math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype)) 

3028 

3029 iou = math_ops.div_no_nan(true_positives, denominator) 

3030 

3031 return math_ops.div_no_nan( 

3032 math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries) 

3033 

3034 def reset_state(self): 

3035 backend.set_value( 

3036 self.total_cm, np.zeros((self.num_classes, self.num_classes))) 

3037 

3038 def get_config(self): 

3039 config = {'num_classes': self.num_classes} 

3040 base_config = super(MeanIoU, self).get_config() 

3041 return dict(list(base_config.items()) + list(config.items())) 

3042 

3043 
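As a rough sketch of what `result()` computes from the accumulated confusion matrix (a NumPy re-derivation for illustration, not the module's own code):

```python
import numpy as np


def mean_iou_from_cm(cm):
  cm = np.asarray(cm, dtype=np.float64)
  sum_over_row = cm.sum(axis=0)     # predicted counts per class
  sum_over_col = cm.sum(axis=1)     # ground-truth counts per class
  true_positives = np.diag(cm)
  denominator = sum_over_row + sum_over_col - true_positives
  valid = denominator > 0           # classes absent from labels and predictions are ignored
  iou = np.where(valid, true_positives / np.where(valid, denominator, 1.0), 0.0)
  return iou.sum() / valid.sum()


# Confusion matrix from the doctest above: y_true=[0, 0, 1, 1], y_pred=[0, 1, 0, 1]
print(mean_iou_from_cm([[1, 1], [1, 1]]))  # ~0.3333
```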

3044@keras_export('keras.metrics.MeanTensor') 

3045class MeanTensor(Metric): 

3046 """Computes the element-wise (weighted) mean of the given tensors. 

3047 

3048 `MeanTensor` returns a tensor with the same shape as the input tensors. The 

3049 mean value is updated by keeping local variables `total` and `count`. The 

3050 `total` tracks the sum of the weighted values, and `count` stores the sum of 

3051 the weighted counts. 

3052 

3053 Args: 

3054 name: (Optional) string name of the metric instance. 

3055 dtype: (Optional) data type of the metric result. 

3056 shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor 

3057 of type int32. If not specified, the shape is inferred from the values at 

3058 the first call of update_state. 

3059 

3060 Standalone usage: 

3061 

3062 >>> m = tf.keras.metrics.MeanTensor() 

3063 >>> m.update_state([0, 1, 2, 3]) 

3064 >>> m.update_state([4, 5, 6, 7]) 

3065 >>> m.result().numpy() 

3066 array([2., 3., 4., 5.], dtype=float32) 

3067 

3068 >>> m.update_state([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1]) 

3069 >>> m.result().numpy() 

3070 array([2. , 3.6363635, 4.8 , 5.3333335], dtype=float32) 

3071 

3072 >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4)) 

3073 >>> m.result().numpy() 

3074 array([[0., 0., 0., 0.]]) 

3075 >>> m.update_state([[0, 1, 2, 3]]) 

3076 >>> m.update_state([[4, 5, 6, 7]]) 

3077 >>> m.result().numpy() 

3078 array([[2., 3., 4., 5.]]) 

3079 """ 

3080 

3081 def __init__(self, name='mean_tensor', dtype=None, shape=None): 

3082 super(MeanTensor, self).__init__(name=name, dtype=dtype) 

3083 self._shape = None 

3084 self._total = None 

3085 self._count = None 

3086 self._built = False 

3087 if shape is not None: 

3088 self._build(shape) 

3089 

3090 def _build(self, shape): 

3091 self._shape = tensor_shape.TensorShape(shape) 

3092 self._build_input_shape = self._shape 

3093 # Create new state variables 

3094 self._total = self.add_weight( 

3095 'total', shape=shape, initializer=init_ops.zeros_initializer) 

3096 self._count = self.add_weight( 

3097 'count', shape=shape, initializer=init_ops.zeros_initializer) 

3098 with ops.init_scope(): 

3099 if not context.executing_eagerly(): 

3100 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access 

3101 self._built = True 

3102 

3103 @property 

3104 def total(self): 

3105 return self._total if self._built else None 

3106 

3107 @property 

3108 def count(self): 

3109 return self._count if self._built else None 

3110 

3111 def update_state(self, values, sample_weight=None): 

3112 """Accumulates statistics for computing the element-wise mean. 

3113 

3114 Args: 

3115 values: Per-example value. 

3116 sample_weight: Optional weighting of each example. Defaults to 1. 

3117 

3118 Returns: 

3119 Update op. 

3120 """ 

3121 values = math_ops.cast(values, self._dtype) 

3122 if not self._built: 

3123 self._build(values.shape) 

3124 elif values.shape != self._shape: 

3125 raise ValueError('MeanTensor input values must always have the same ' 

3126 'shape. Expected shape (set during the first call): {}. ' 

3127 'Got: {}'.format(self._shape, values.shape)) 

3128 

3129 num_values = array_ops.ones_like(values) 

3130 if sample_weight is not None: 

3131 sample_weight = math_ops.cast(sample_weight, self._dtype) 

3132 

3133 # Update dimensions of weights to match with values if possible. 

3134 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( 

3135 values, sample_weight=sample_weight) 

3136 try: 

3137 # Broadcast weights if possible. 

3138 sample_weight = weights_broadcast_ops.broadcast_weights( 

3139 sample_weight, values) 

3140 except ValueError: 

3141 # Reduce values to same ndim as weight array 

3142 ndim = backend.ndim(values) 

3143 weight_ndim = backend.ndim(sample_weight) 

3144 values = math_ops.reduce_mean( 

3145 values, axis=list(range(weight_ndim, ndim))) 

3146 

3147 num_values = math_ops.multiply(num_values, sample_weight) 

3148 values = math_ops.multiply(values, sample_weight) 

3149 

3150 update_total_op = self._total.assign_add(values) 

3151 with ops.control_dependencies([update_total_op]): 

3152 return self._count.assign_add(num_values) 

3153 

3154 def result(self): 

3155 if not self._built: 

3156 raise ValueError( 

3157 'MeanTensor does not have any result yet. Please call the MeanTensor ' 

3158 'instance or use `.update_state(value)` before retrieving the result.' 

3159 ) 

3160 return math_ops.div_no_nan(self.total, self.count) 

3161 

3162 def reset_state(self): 

3163 if self._built: 

3164 backend.batch_set_value( 

3165 [(v, np.zeros(self._shape.as_list())) for v in self.variables]) 

3166 

3167 
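The element-wise bookkeeping described above (a running `total` and `count`) can be sketched in plain NumPy; this is an illustration of the arithmetic, not the actual implementation:

```python
import numpy as np


class MeanTensorSketch:
  """Element-wise running mean: total / count, both updated per call."""

  def __init__(self, shape):
    self.total = np.zeros(shape)
    self.count = np.zeros(shape)

  def update(self, values, sample_weight=None):
    values = np.asarray(values, np.float64)
    weights = (np.ones_like(values) if sample_weight is None
               else np.asarray(sample_weight, np.float64))
    self.total += values * weights  # weighted sum of values
    self.count += weights           # weighted number of values

  def result(self):
    # Safe division: zero where no values have been seen yet.
    return np.divide(self.total, self.count,
                     out=np.zeros_like(self.total), where=self.count != 0)


m = MeanTensorSketch(4)
m.update([0, 1, 2, 3])
m.update([4, 5, 6, 7])
print(m.result())                                         # [2. 3. 4. 5.]
m.update([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1])
print(m.result())                                         # [2. 3.636... 4.8 5.333...]
```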

3168@keras_export('keras.metrics.BinaryCrossentropy') 

3169class BinaryCrossentropy(MeanMetricWrapper): 

3170 """Computes the crossentropy metric between the labels and predictions. 

3171 

3172 This is the crossentropy metric class to be used when there are only two 

3173 label classes (0 and 1). 

3174 

3175 Args: 

3176 name: (Optional) string name of the metric instance. 

3177 dtype: (Optional) data type of the metric result. 

3178 from_logits: (Optional) Whether output is expected to be a logits tensor. 

3179 By default, we consider that output encodes a probability distribution. 

3180 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are 

3181 smoothed, meaning the confidence on label values is relaxed. 

3182 e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for 

3183 label `0` and `0.9` for label `1`. 

3184 

3185 Standalone usage: 

3186 

3187 >>> m = tf.keras.metrics.BinaryCrossentropy() 

3188 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 

3189 >>> m.result().numpy() 

3190 0.81492424 

3191 

3192 >>> m.reset_state() 

3193 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 

3194 ... sample_weight=[1, 0]) 

3195 >>> m.result().numpy() 

3196 0.9162905 

3197 

3198 Usage with `compile()` API: 

3199 

3200 ```python 

3201 model.compile( 

3202 optimizer='sgd', 

3203 loss='mse', 

3204 metrics=[tf.keras.metrics.BinaryCrossentropy()]) 

3205 ``` 

3206 """ 

3207 

3208 def __init__(self, 

3209 name='binary_crossentropy', 

3210 dtype=None, 

3211 from_logits=False, 

3212 label_smoothing=0): 

3213 super(BinaryCrossentropy, self).__init__( 

3214 binary_crossentropy, 

3215 name, 

3216 dtype=dtype, 

3217 from_logits=from_logits, 

3218 label_smoothing=label_smoothing) 

3219 

3220 
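The `label_smoothing` behaviour noted in the docstring amounts to the following transform of the binary targets before the crossentropy is computed (a sketch of the arithmetic only, not the library code):

```python
label_smoothing = 0.2
y_true = [0.0, 1.0]

# Smoothed labels: y * (1 - ls) + 0.5 * ls, so 0 -> 0.1 and 1 -> 0.9.
smoothed = [y * (1.0 - label_smoothing) + 0.5 * label_smoothing for y in y_true]
print(smoothed)  # [0.1, 0.9]
```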

3221@keras_export('keras.metrics.CategoricalCrossentropy') 

3222class CategoricalCrossentropy(MeanMetricWrapper): 

3223 """Computes the crossentropy metric between the labels and predictions. 

3224 

3225 This is the crossentropy metric class to be used when there are multiple 

3226 label classes (2 or more). Here we assume that labels are given as a `one_hot` 

3227 representation. E.g., when the label values are [2, 0, 1], 

3228 `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]. 

3229 

3230 Args: 

3231 name: (Optional) string name of the metric instance. 

3232 dtype: (Optional) data type of the metric result. 

3233 from_logits: (Optional) Whether output is expected to be a logits tensor. 

3234 By default, we consider that output encodes a probability distribution. 

3235 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are 

3236 smoothed, meaning the confidence on label values is relaxed. E.g., 

3237 `label_smoothing=0.2` means that we will use a value of `0.1` for label 

3238 `0` and `0.9` for label `1`. 

3239 

3240 Standalone usage: 

3241 

3242 >>> # EPSILON = 1e-7, y = y_true, y` = y_pred 

3243 >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON) 

3244 >>> # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] 

3245 >>> # xent = -sum(y * log(y'), axis = -1) 

3246 >>> # = -((log 0.95), (log 0.1)) 

3247 >>> # = [0.051, 2.302] 

3248 >>> # Reduced xent = (0.051 + 2.302) / 2 

3249 >>> m = tf.keras.metrics.CategoricalCrossentropy() 

3250 >>> m.update_state([[0, 1, 0], [0, 0, 1]], 

3251 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) 

3252 >>> m.result().numpy() 

3253 1.1769392 

3254 

3255 >>> m.reset_state() 

3256 >>> m.update_state([[0, 1, 0], [0, 0, 1]], 

3257 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], 

3258 ... sample_weight=tf.constant([0.3, 0.7])) 

3259 >>> m.result().numpy() 

3260 1.6271976 

3261 

3262 Usage with `compile()` API: 

3263 

3264 ```python 

3265 model.compile( 

3266 optimizer='sgd', 

3267 loss='mse', 

3268 metrics=[tf.keras.metrics.CategoricalCrossentropy()]) 

3269 ``` 

3270 """ 

3271 

3272 def __init__(self, 

3273 name='categorical_crossentropy', 

3274 dtype=None, 

3275 from_logits=False, 

3276 label_smoothing=0): 

3277 super(CategoricalCrossentropy, self).__init__( 

3278 categorical_crossentropy, 

3279 name, 

3280 dtype=dtype, 

3281 from_logits=from_logits, 

3282 label_smoothing=label_smoothing) 

3283 

3284 

3285@keras_export('keras.metrics.SparseCategoricalCrossentropy') 

3286class SparseCategoricalCrossentropy(MeanMetricWrapper): 

3287 """Computes the crossentropy metric between the labels and predictions. 

3288 

3289 Use this crossentropy metric when there are two or more label classes. 

3290 We expect labels to be provided as integers. If you want to provide labels 

3291 using `one-hot` representation, please use `CategoricalCrossentropy` metric. 

3292 There should be `# classes` floating point values per feature for `y_pred` 

3293 and a single floating point value per feature for `y_true`. 

3294 

3295 In the snippet below, there is a single floating point value per example for 

3296 `y_true` and `# classes` floating point values per example for `y_pred`. 

3297 The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is 

3298 `[batch_size, num_classes]`. 

3299 

3300 Args: 

3301 name: (Optional) string name of the metric instance. 

3302 dtype: (Optional) data type of the metric result. 

3303 from_logits: (Optional) Whether output is expected to be a logits tensor. 

3304 By default, we consider that output encodes a probability distribution. 

3305 axis: (Optional) Defaults to -1. The dimension along which the metric is 

3306 computed. 

3307 

3308 Standalone usage: 

3309 

3310 >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]] 

3311 >>> # logits = log(y_pred) 

3312 >>> # softmax = exp(logits) / sum(exp(logits), axis=-1) 

3313 >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] 

3314 >>> # xent = -sum(y * log(softmax), 1) 

3315 >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181], 

3316 >>> # [-2.3026, -0.2231, -2.3026]] 

3317 >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]] 

3318 >>> # xent = [0.0513, 2.3026] 

3319 >>> # Reduced xent = (0.0513 + 2.3026) / 2 

3320 >>> m = tf.keras.metrics.SparseCategoricalCrossentropy() 

3321 >>> m.update_state([1, 2], 

3322 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) 

3323 >>> m.result().numpy() 

3324 1.1769392 

3325 

3326 >>> m.reset_state() 

3327 >>> m.update_state([1, 2], 

3328 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], 

3329 ... sample_weight=tf.constant([0.3, 0.7])) 

3330 >>> m.result().numpy() 

3331 1.6271976 

3332 

3333 Usage with `compile()` API: 

3334 

3335 ```python 

3336 model.compile( 

3337 optimizer='sgd', 

3338 loss='mse', 

3339 metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()]) 

3340 ``` 

3341 """ 

3342 

3343 def __init__(self, 

3344 name='sparse_categorical_crossentropy', 

3345 dtype=None, 

3346 from_logits=False, 

3347 axis=-1): 

3348 super(SparseCategoricalCrossentropy, self).__init__( 

3349 sparse_categorical_crossentropy, 

3350 name, 

3351 dtype=dtype, 

3352 from_logits=from_logits, 

3353 axis=axis) 

3354 

3355 
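The commented walkthrough in the docstring above can be reproduced numerically. A NumPy cross-check, assuming a clipping epsilon of 1e-7 (this is not the module's own code path, which also renormalizes the clipped probabilities; the difference is negligible here):

```python
import numpy as np

EPS = 1e-7  # assumed clipping epsilon
y_true = [1, 2]
y_pred = np.clip(np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]), EPS, 1.0)

# Cross-entropy picks out -log(probability assigned to the true class).
xent = -np.log(y_pred[np.arange(len(y_true)), y_true])
print(xent)         # ~[0.0513, 2.3026]
print(xent.mean())  # ~1.1769
```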

3356class SumOverBatchSize(Reduce): 

3357 """Computes the weighted sum over batch size of the given values. 

3358 

3359 For example, if values is [1, 3, 5, 7] then the metric value is 4. 

3360 If the weights were specified as [1, 1, 0, 0] then the value would be 1. 

3361 

3362 This metric creates two variables, `total` and `count` that are used to 

3363 compute the average of `values`. This average is ultimately returned as sum 

3364 over batch size which is an idempotent operation that simply divides `total` 

3365 by `count`. 

3366 

3367 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 

3368 to mask values. 

3369 """ 

3370 

3371 def __init__(self, name='sum_over_batch_size', dtype=None): 

3372 super(SumOverBatchSize, self).__init__( 

3373 reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 

3374 name=name, 

3375 dtype=dtype) 

3376 

3377 
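The worked numbers in the docstring follow from dividing the weighted sum by the batch size (a plain-Python check of the arithmetic, not the reduction code itself):

```python
values = [1, 3, 5, 7]
weights = [1, 1, 0, 0]

print(sum(values) / len(values))                                   # 4.0
print(sum(v * w for v, w in zip(values, weights)) / len(values))   # 1.0
```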

3378class SumOverBatchSizeMetricWrapper(SumOverBatchSize): 

3379 """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric.""" 

3380 

3381 def __init__(self, fn, name=None, dtype=None, **kwargs): 

3382 """Creates a `SumOverBatchSizeMetricWrapper` instance. 

3383 

3384 Args: 

3385 fn: The metric function to wrap, with signature `fn(y_true, y_pred, 

3386 **kwargs)`. 

3387 name: (Optional) string name of the metric instance. 

3388 dtype: (Optional) data type of the metric result. 

3389 **kwargs: The keyword arguments that are passed on to `fn`. 

3390 """ 

3391 super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype) 

3392 self._fn = fn 

3393 self._fn_kwargs = kwargs 

3394 

3395 def update_state(self, y_true, y_pred, sample_weight=None): 

3396 y_true = math_ops.cast(y_true, self._dtype) 

3397 y_pred = math_ops.cast(y_pred, self._dtype) 

3398 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 

3399 y_pred, y_true) 

3400 

3401 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx()) 

3402 matches = ag_fn(y_true, y_pred, **self._fn_kwargs) 

3403 return super(SumOverBatchSizeMetricWrapper, self).update_state( 

3404 matches, sample_weight=sample_weight) 

3405 

3406 def get_config(self): 

3407 config = {} 

3408 for k, v in self._fn_kwargs.items(): 

3409 config[k] = backend.eval(v) if is_tensor_or_variable(v) else v 

3410 base_config = super(SumOverBatchSizeMetricWrapper, self).get_config() 

3411 return dict(list(base_config.items()) + list(config.items())) 

3412 

3413 

3414def accuracy(y_true, y_pred): 

3415 [y_pred, y_true], _ = \ 

3416 metrics_utils.ragged_assert_compatible_and_get_flat_values( 

3417 [y_pred, y_true]) 

3418 y_true.shape.assert_is_compatible_with(y_pred.shape) 

3419 if y_true.dtype != y_pred.dtype: 

3420 y_pred = math_ops.cast(y_pred, y_true.dtype) 

3421 return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx()) 

3422 

3423 

3424@keras_export('keras.metrics.binary_accuracy') 

3425@dispatch.add_dispatch_support 

3426def binary_accuracy(y_true, y_pred, threshold=0.5): 

3427 """Calculates how often predictions match binary labels. 

3428 

3429 Standalone usage: 

3430 >>> y_true = [[1], [1], [0], [0]] 

3431 >>> y_pred = [[1], [1], [0], [0]] 

3432 >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred) 

3433 >>> assert m.shape == (4,) 

3434 >>> m.numpy() 

3435 array([1., 1., 1., 1.], dtype=float32) 

3436 

3437 Args: 

3438 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 

3439 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 

3440 threshold: (Optional) Float representing the threshold for deciding whether 

3441 prediction values are 1 or 0. 

3442 

3443 Returns: 

3444 Binary accuracy values. shape = `[batch_size, d0, .. dN-1]` 

3445 """ 

3446 y_pred = tensor_conversion.convert_to_tensor_v2_with_dispatch(y_pred) 

3447 threshold = math_ops.cast(threshold, y_pred.dtype) 

3448 y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) 

3449 return backend.mean(math_ops.equal(y_true, y_pred), axis=-1) 

3450 

3451 
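A minimal NumPy equivalent of the thresholding performed above (an illustration, not the library code):

```python
import numpy as np


def binary_accuracy_np(y_true, y_pred, threshold=0.5):
  y_true = np.asarray(y_true, np.float32)
  # Binarize predictions at the threshold, then compare element-wise.
  y_pred = (np.asarray(y_pred, np.float32) > threshold).astype(np.float32)
  return (y_true == y_pred).astype(np.float32).mean(axis=-1)


print(binary_accuracy_np([[1], [1], [0], [0]], [[0.98], [0.7], [0.4], [0.1]]))
# -> [1. 1. 1. 1.]
```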

3452@keras_export('keras.metrics.categorical_accuracy') 

3453@dispatch.add_dispatch_support 

3454def categorical_accuracy(y_true, y_pred): 

3455 """Calculates how often predictions match one-hot labels. 

3456 

3457 Standalone usage: 

3458 >>> y_true = [[0, 0, 1], [0, 1, 0]] 

3459 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 

3460 >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 

3461 >>> assert m.shape == (2,) 

3462 >>> m.numpy() 

3463 array([0., 1.], dtype=float32) 

3464 

3465 You can provide logits of classes as `y_pred`, since argmax of 

3466 logits and probabilities is the same. 

3467 

3468 Args: 

3469 y_true: One-hot ground truth values. 

3470 y_pred: The prediction values. 

3471 

3472 Returns: 

3473 Categorical accuracy values. 

3474 """ 

3475 return math_ops.cast( 

3476 math_ops.equal( 

3477 math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)), 

3478 backend.floatx()) 

3479 

3480 

3481@keras_export('keras.metrics.sparse_categorical_accuracy') 

3482@dispatch.add_dispatch_support 

3483def sparse_categorical_accuracy(y_true, y_pred): 

3484 """Calculates how often predictions match integer labels. 

3485 

3486 Standalone usage: 

3487 >>> y_true = [2, 1] 

3488 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 

3489 >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred) 

3490 >>> assert m.shape == (2,) 

3491 >>> m.numpy() 

3492 array([0., 1.], dtype=float32) 

3493 

3494 You can provide logits of classes as `y_pred`, since argmax of 

3495 logits and probabilities is the same. 

3496 

3497 Args: 

3498 y_true: Integer ground truth values. 

3499 y_pred: The prediction values. 

3500 

3501 Returns: 

3502 Sparse categorical accuracy values. 

3503 """ 

3504 y_pred = tensor_conversion.convert_to_tensor_v2_with_dispatch(y_pred) 

3505 y_true = tensor_conversion.convert_to_tensor_v2_with_dispatch(y_true) 

3506 y_pred_rank = y_pred.shape.ndims 

3507 y_true_rank = y_true.shape.ndims 

3508 # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) 

3509 if (y_true_rank is not None) and (y_pred_rank is not None) and (len( 

3510 backend.int_shape(y_true)) == len(backend.int_shape(y_pred))): 

3511 y_true = array_ops.squeeze(y_true, [-1]) 

3512 y_pred = math_ops.argmax(y_pred, axis=-1) 

3513 

3514 # If the predicted output and actual output types don't match, force cast them 

3515 # to match. 

3516 if backend.dtype(y_pred) != backend.dtype(y_true): 

3517 y_pred = math_ops.cast(y_pred, backend.dtype(y_true)) 

3518 

3519 return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx()) 

3520 

3521 
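The rank handling above reduces to "squeeze a trailing axis from `y_true` when both ranks match, argmax `y_pred`, and compare"; a rough NumPy rendering of that logic (illustrative only):

```python
import numpy as np


def sparse_categorical_accuracy_np(y_true, y_pred):
  y_true = np.asarray(y_true)
  y_pred = np.asarray(y_pred)
  # If y_true is (num_samples, 1) while y_pred is (num_samples, num_classes),
  # drop the trailing axis so the comparison is element-wise per example.
  if y_true.ndim == y_pred.ndim:
    y_true = np.squeeze(y_true, axis=-1)
  matches = y_true == np.argmax(y_pred, axis=-1)
  return matches.astype(np.float32)


print(sparse_categorical_accuracy_np([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]))
# -> [0. 1.]
```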

3522@keras_export('keras.metrics.top_k_categorical_accuracy') 

3523@dispatch.add_dispatch_support 

3524def top_k_categorical_accuracy(y_true, y_pred, k=5): 

3525 """Computes how often targets are in the top `K` predictions. 

3526 

3527 Standalone usage: 

3528 >>> y_true = [[0, 0, 1], [0, 1, 0]] 

3529 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 

3530 >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3) 

3531 >>> assert m.shape == (2,) 

3532 >>> m.numpy() 

3533 array([1., 1.], dtype=float32) 

3534 

3535 Args: 

3536 y_true: The ground truth values. 

3537 y_pred: The prediction values. 

3538 k: (Optional) Number of top elements to look at for computing accuracy. 

3539 Defaults to 5. 

3540 

3541 Returns: 

3542 Top K categorical accuracy value. 

3543 """ 

3544 return math_ops.cast( 

3545 nn.in_top_k( 

3546 y_pred, math_ops.argmax(y_true, axis=-1), k), backend.floatx()) 

3547 

3548 
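`nn.in_top_k` asks whether the true class is among the `k` highest-scoring predictions. A rough NumPy equivalent for 2-D predictions (the TF op resolves ties slightly differently, so treat this as an illustration only):

```python
import numpy as np


def top_k_categorical_accuracy_np(y_true, y_pred, k=5):
  y_true = np.argmax(np.asarray(y_true), axis=-1)   # one-hot -> class index
  y_pred = np.asarray(y_pred, np.float64)
  top_k = np.argsort(-y_pred, axis=-1)[:, :k]       # indices of the k largest scores
  hits = np.any(top_k == y_true[:, None], axis=-1)
  return hits.astype(np.float32)


print(top_k_categorical_accuracy_np([[0, 0, 1], [0, 1, 0]],
                                     [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], k=3))
# -> [1. 1.]
```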

3549@keras_export('keras.metrics.sparse_top_k_categorical_accuracy') 

3550@dispatch.add_dispatch_support 

3551def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): 

3552 """Computes how often integer targets are in the top `K` predictions. 

3553 

3554 Standalone usage: 

3555 >>> y_true = [2, 1] 

3556 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 

3557 >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy( 

3558 ... y_true, y_pred, k=3) 

3559 >>> assert m.shape == (2,) 

3560 >>> m.numpy() 

3561 array([1., 1.], dtype=float32) 

3562 

3563 Args: 

3564 y_true: tensor of true targets. 

3565 y_pred: tensor of predicted targets. 

3566 k: (Optional) Number of top elements to look at for computing accuracy. 

3567 Defaults to 5. 

3568 

3569 Returns: 

3570 Sparse top K categorical accuracy value. 

3571 """ 

3572 y_pred_rank = tensor_conversion.convert_to_tensor_v2_with_dispatch( 

3573 y_pred 

3574 ).shape.ndims 

3575 y_true_rank = tensor_conversion.convert_to_tensor_v2_with_dispatch( 

3576 y_true 

3577 ).shape.ndims 

3578 # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,) 

3579 if (y_true_rank is not None) and (y_pred_rank is not None): 

3580 if y_pred_rank > 2: 

3581 y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) 

3582 if y_true_rank > 1: 

3583 y_true = array_ops.reshape(y_true, [-1]) 

3584 

3585 return math_ops.cast( 

3586 nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), backend.floatx()) 

3587 

3588 

3589def cosine_proximity(y_true, y_pred, axis=-1): 

3590 """Computes the cosine similarity between labels and predictions. 

3591 

3592 Args: 

3593 y_true: The ground truth values. 

3594 y_pred: The prediction values. 

3595 axis: (Optional) Defaults to -1. The dimension along which the cosine 

3596 similarity is computed. 

3597 

3598 Returns: 

3599 Cosine similarity value. 

3600 """ 

3601 y_true = nn.l2_normalize(y_true, axis=axis) 

3602 y_pred = nn.l2_normalize(y_pred, axis=axis) 

3603 return math_ops.reduce_sum(y_true * y_pred, axis=axis) 

3604 
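Cosine similarity here is just the dot product of L2-normalized vectors; a compact NumPy check (illustrative, and without the small epsilon the TF op uses to guard zero vectors):

```python
import numpy as np


def cosine_similarity_np(y_true, y_pred, axis=-1):
  y_true = np.asarray(y_true, np.float64)
  y_pred = np.asarray(y_pred, np.float64)
  # Normalize each vector to unit length, then take the element-wise dot product.
  y_true = y_true / np.linalg.norm(y_true, axis=axis, keepdims=True)
  y_pred = y_pred / np.linalg.norm(y_pred, axis=axis, keepdims=True)
  return np.sum(y_true * y_pred, axis=axis)


print(cosine_similarity_np([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]))  # [0. 1.]
```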

3605# Aliases 

3606 

3607acc = ACC = accuracy 

3608bce = BCE = binary_crossentropy 

3609mse = MSE = mean_squared_error 

3610mae = MAE = mean_absolute_error 

3611mape = MAPE = mean_absolute_percentage_error 

3612msle = MSLE = mean_squared_logarithmic_error 

3613cosine_similarity = cosine_proximity 

3614log_cosh = logcosh 

3615 

3616 

3617def clone_metric(metric): 

3618 """Returns a clone of the metric if stateful, otherwise returns it as is.""" 

3619 if isinstance(metric, Metric): 

3620 with ops.init_scope(): 

3621 return metric.__class__.from_config(metric.get_config()) 

3622 return metric 

3623 

3624 

3625def clone_metrics(metrics): 

3626 """Clones the given metric list/dict.""" 

3627 return nest.map_structure(clone_metric, metrics) 

3628 

3629 

3630@keras_export('keras.metrics.serialize') 

3631def serialize(metric): 

3632 """Serializes metric function or `Metric` instance. 

3633 

3634 Args: 

3635 metric: A Keras `Metric` instance or a metric function. 

3636 

3637 Returns: 

3638 Metric configuration dictionary. 

3639 """ 

3640 return serialize_keras_object(metric) 

3641 

3642 

3643@keras_export('keras.metrics.deserialize') 

3644def deserialize(config, custom_objects=None): 

3645 """Deserializes a serialized metric class/function instance. 

3646 

3647 Args: 

3648 config: Metric configuration. 

3649 custom_objects: Optional dictionary mapping names (strings) to custom 

3650 objects (classes and functions) to be considered during deserialization. 

3651 

3652 Returns: 

3653 A Keras `Metric` instance or a metric function. 

3654 """ 

3655 return deserialize_keras_object( 

3656 config, 

3657 module_objects=globals(), 

3658 custom_objects=custom_objects, 

3659 printable_module_name='metric function') 

3660 

3661 

3662@keras_export('keras.metrics.get') 

3663def get(identifier): 

3664 """Retrieves a Keras metric as a `function`/`Metric` class instance. 

3665 

3666 The `identifier` may be the string name of a metric function or class. 

3667 

3668 >>> metric = tf.keras.metrics.get("categorical_crossentropy") 

3669 >>> type(metric) 

3670 <class 'function'> 

3671 >>> metric = tf.keras.metrics.get("CategoricalCrossentropy") 

3672 >>> type(metric) 

3673 <class '...keras.metrics.CategoricalCrossentropy'> 

3674 

3675 You can also specify `config` of the metric to this function by passing a dict 

3676 containing `class_name` and `config` as an identifier. Also note that the 

3677 `class_name` must map to a `Metric` class. 

3678 

3679 >>> identifier = {"class_name": "CategoricalCrossentropy", 

3680 ... "config": {"from_logits": True}} 

3681 >>> metric = tf.keras.metrics.get(identifier) 

3682 >>> type(metric) 

3683 <class '...keras.metrics.CategoricalCrossentropy'> 

3684 

3685 Args: 

3686 identifier: A metric identifier: one of `None`, a string name of a metric 

3687 function/class, a metric configuration dictionary, a metric function, or 

3688 a metric class instance. 

3689 

3690 Returns: 

3691 A Keras metric as a `function`/ `Metric` class instance. 

3692 

3693 Raises: 

3694 ValueError: If `identifier` cannot be interpreted. 

3695 """ 

3696 if isinstance(identifier, dict): 

3697 return deserialize(identifier) 

3698 elif isinstance(identifier, str): 

3699 return deserialize(str(identifier)) 

3700 elif callable(identifier): 

3701 return identifier 

3702 else: 

3703 raise ValueError( 

3704 'Could not interpret metric function identifier: {}'.format(identifier)) 

3705 
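For completeness, the accepted identifier forms can be exercised as follows (a usage sketch; the exact returned types depend on the installed TensorFlow version):

```python
import tensorflow as tf

fn = tf.keras.metrics.get('mse')                      # string -> metric function
obj = tf.keras.metrics.get({'class_name': 'CategoricalCrossentropy',
                            'config': {'from_logits': True}})  # dict -> Metric instance
same = tf.keras.metrics.get(tf.keras.metrics.Mean())  # callable -> returned as-is
```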

3706 

3707def is_built_in(cls): 

3708 return cls.__module__ == Metric.__module__