Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/metrics.py: 35%
836 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-classes-have-attributes
16# pylint: disable=g-doc-return-or-yield
17"""Built-in metrics."""
19import abc
20import types
21import warnings
23import numpy as np
25from tensorflow.python.autograph.core import ag_ctx
26from tensorflow.python.autograph.impl import api as autograph
27from tensorflow.python.distribute import distribute_lib
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_conversion
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.keras import activations
36from tensorflow.python.keras import backend
37from tensorflow.python.keras.engine import base_layer
38from tensorflow.python.keras.engine import base_layer_utils
39from tensorflow.python.keras.engine import keras_tensor
40from tensorflow.python.keras.losses import binary_crossentropy
41from tensorflow.python.keras.losses import categorical_crossentropy
42from tensorflow.python.keras.losses import categorical_hinge
43from tensorflow.python.keras.losses import hinge
44from tensorflow.python.keras.losses import kullback_leibler_divergence
45from tensorflow.python.keras.losses import logcosh
46from tensorflow.python.keras.losses import mean_absolute_error
47from tensorflow.python.keras.losses import mean_absolute_percentage_error
48from tensorflow.python.keras.losses import mean_squared_error
49from tensorflow.python.keras.losses import mean_squared_logarithmic_error
50from tensorflow.python.keras.losses import poisson
51from tensorflow.python.keras.losses import sparse_categorical_crossentropy
52from tensorflow.python.keras.losses import squared_hinge
53from tensorflow.python.keras.saving.saved_model import metric_serialization
54from tensorflow.python.keras.utils import generic_utils
55from tensorflow.python.keras.utils import losses_utils
56from tensorflow.python.keras.utils import metrics_utils
57from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
58from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
59from tensorflow.python.keras.utils.generic_utils import to_list
60from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
61from tensorflow.python.ops import array_ops
62from tensorflow.python.ops import check_ops
63from tensorflow.python.ops import confusion_matrix
64from tensorflow.python.ops import init_ops
65from tensorflow.python.ops import math_ops
66from tensorflow.python.ops import nn
67from tensorflow.python.ops import variables as variables_module
68from tensorflow.python.ops import weights_broadcast_ops
69from tensorflow.python.util import dispatch
70from tensorflow.python.util import nest
71from tensorflow.python.util.tf_export import keras_export
72from tensorflow.tools.docs import doc_controls
75@keras_export('keras.metrics.Metric')
76class Metric(base_layer.Layer, metaclass=abc.ABCMeta):
77 """Encapsulates metric logic and state.
79 Args:
80 name: (Optional) string name of the metric instance.
81 dtype: (Optional) data type of the metric result.
82     **kwargs: Additional layer keyword arguments.
84 Standalone usage:
86 ```python
87 m = SomeMetric(...)
88 for input in ...:
89 m.update_state(input)
90 print('Final result: ', m.result().numpy())
91 ```
93 Usage with `compile()` API:
95 ```python
96 model = tf.keras.Sequential()
97 model.add(tf.keras.layers.Dense(64, activation='relu'))
98 model.add(tf.keras.layers.Dense(64, activation='relu'))
99 model.add(tf.keras.layers.Dense(10, activation='softmax'))
101 model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
102 loss=tf.keras.losses.CategoricalCrossentropy(),
103 metrics=[tf.keras.metrics.CategoricalAccuracy()])
105 data = np.random.random((1000, 32))
106 labels = np.random.random((1000, 10))
108 dataset = tf.data.Dataset.from_tensor_slices((data, labels))
109 dataset = dataset.batch(32)
111 model.fit(dataset, epochs=10)
112 ```
114 To be implemented by subclasses:
115 * `__init__()`: All state variables should be created in this method by
116 calling `self.add_weight()` like: `self.var = self.add_weight(...)`
117 * `update_state()`: Has all updates to the state variables like:
118 self.var.assign_add(...).
119 * `result()`: Computes and returns a value for the metric
120 from the state variables.
122 Example subclass implementation:
124 ```python
125 class BinaryTruePositives(tf.keras.metrics.Metric):
127 def __init__(self, name='binary_true_positives', **kwargs):
128 super(BinaryTruePositives, self).__init__(name=name, **kwargs)
129 self.true_positives = self.add_weight(name='tp', initializer='zeros')
131 def update_state(self, y_true, y_pred, sample_weight=None):
132 y_true = tf.cast(y_true, tf.bool)
133 y_pred = tf.cast(y_pred, tf.bool)
135 values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
136 values = tf.cast(values, self.dtype)
137 if sample_weight is not None:
138 sample_weight = tf.cast(sample_weight, self.dtype)
139 sample_weight = tf.broadcast_to(sample_weight, values.shape)
140 values = tf.multiply(values, sample_weight)
141 self.true_positives.assign_add(tf.reduce_sum(values))
143 def result(self):
144 return self.true_positives
145 ```
146 """
148 def __init__(self, name=None, dtype=None, **kwargs):
149 super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
150 self.stateful = True # All metric layers are stateful.
151 self.built = True
152 if not base_layer_utils.v2_dtype_behavior_enabled():
153 # We only do this when the V2 behavior is not enabled, as when it is
154 # enabled, the dtype already defaults to floatx.
155 self._dtype = (backend.floatx() if dtype is None
156 else dtypes.as_dtype(dtype).name)
158 def __new__(cls, *args, **kwargs):
159 obj = super(Metric, cls).__new__(cls)
161 # If `update_state` is not in eager/tf.function and it is not from a
162 # built-in metric, wrap it in `tf.function`. This is so that users writing
163 # custom metrics in v1 need not worry about control dependencies and
164 # return ops.
165 if (base_layer_utils.is_in_eager_or_tf_function() or
166 is_built_in(cls)):
167 obj_update_state = obj.update_state
169 def update_state_fn(*args, **kwargs):
170 control_status = ag_ctx.control_status_ctx()
171 ag_update_state = autograph.tf_convert(obj_update_state, control_status)
172 return ag_update_state(*args, **kwargs)
173 else:
174 if isinstance(obj.update_state, def_function.Function):
175 update_state_fn = obj.update_state
176 else:
177 update_state_fn = def_function.function(obj.update_state)
179 obj.update_state = types.MethodType(
180 metrics_utils.update_state_wrapper(update_state_fn), obj)
182 obj_result = obj.result
184 def result_fn(*args, **kwargs):
185 control_status = ag_ctx.control_status_ctx()
186 ag_result = autograph.tf_convert(obj_result, control_status)
187 return ag_result(*args, **kwargs)
189 obj.result = types.MethodType(metrics_utils.result_wrapper(result_fn), obj)
191 return obj
193 def __call__(self, *args, **kwargs):
194 """Accumulates statistics and then computes metric result value.
196 Args:
197 *args:
198 **kwargs: A mini-batch of inputs to the Metric,
199 passed on to `update_state()`.
201 Returns:
202 The metric value tensor.
203 """
205 def replica_local_fn(*args, **kwargs):
206 """Updates the state of the metric in a replica-local context."""
207 if any(
208 isinstance(arg, keras_tensor.KerasTensor)
209 for arg in nest.flatten((args, kwargs))):
210 update_op = None
211 else:
212 update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable
213 update_ops = []
214 if update_op is not None:
215 update_ops.append(update_op)
216 with ops.control_dependencies(update_ops):
217 result_t = self.result() # pylint: disable=not-callable
219 # We are adding the metric object as metadata on the result tensor.
220 # This is required when we want to use a metric with `add_metric` API on
221 # a Model/Layer in graph mode. This metric instance will later be used
222 # to reset variable state after each epoch of training.
223 # Example:
224 # model = Model()
225 # mean = Mean()
226 # model.add_metric(mean(values), name='mean')
227 result_t._metric_obj = self # pylint: disable=protected-access
228 return result_t
230 from tensorflow.python.keras.distribute import distributed_training_utils # pylint:disable=g-import-not-at-top
231 return distributed_training_utils.call_replica_local_fn(
232 replica_local_fn, *args, **kwargs)
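  # Illustrative usage sketch (not part of the original source): calling a
  # metric instance both updates its state and returns the running result,
  # e.g.
  #   m = tf.keras.metrics.Mean()
  #   result = m([1, 3])  # roughly m.update_state([1, 3]) followed by m.result()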
234 @property
235 def dtype(self):
236 return self._dtype
238 def get_config(self):
239 """Returns the serializable config of the metric."""
240 return {'name': self.name, 'dtype': self.dtype}
242 def reset_state(self):
243 """Resets all of the metric state variables.
245 This function is called between epochs/steps,
246 when a metric is evaluated during training.
247 """
248 if not generic_utils.is_default(self.reset_states):
249 warnings.warn('Metric %s implements a `reset_states()` method; rename it '
250 'to `reset_state()` (without the final "s"). The name '
251 '`reset_states()` has been deprecated to improve API '
252 'consistency.' % (self.__class__.__name__,))
253 return self.reset_states()
254 else:
255 backend.batch_set_value([(v, 0) for v in self.variables])
257 @abc.abstractmethod
258 def update_state(self, *args, **kwargs):
259 """Accumulates statistics for the metric.
261 Note: This function is executed as a graph function in graph mode.
262 This means:
263 a) Operations on the same resource are executed in textual order.
264 This should make it easier to do things like add the updated
265 value of a variable to another, for example.
266 b) You don't need to worry about collecting the update ops to execute.
267 All update ops added to the graph by this function will be executed.
268 As a result, code should generally work the same way with graph or
269 eager execution.
271 Args:
272 *args:
273 **kwargs: A mini-batch of inputs to the Metric.
274 """
275 raise NotImplementedError('Must be implemented in subclasses.')
277 @abc.abstractmethod
278 def result(self):
279 """Computes and returns the metric value tensor.
281 Result computation is an idempotent operation that simply calculates the
282 metric value using the state variables.
283 """
284 raise NotImplementedError('Must be implemented in subclasses.')
286 ### For use by subclasses ###
287 @doc_controls.for_subclass_implementers
288 def add_weight(
289 self,
290 name,
291 shape=(),
292 aggregation=variables_module.VariableAggregation.SUM,
293 synchronization=variables_module.VariableSynchronization.ON_READ,
294 initializer=None,
295 dtype=None):
296 """Adds state variable. Only for use by subclasses."""
297 if distribute_lib.has_strategy():
298 strategy = distribute_lib.get_strategy()
299 else:
300 strategy = None
302 # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
303 if backend.is_tpu_strategy(strategy):
304 synchronization = variables_module.VariableSynchronization.ON_WRITE
306 with ops.init_scope():
307 return super(Metric, self).add_weight(
308 name=name,
309 shape=shape,
310 dtype=self._dtype if dtype is None else dtype,
311 trainable=False,
312 initializer=initializer,
313 collections=[],
314 synchronization=synchronization,
315 aggregation=aggregation)
317 ### End: For use by subclasses ###
319 @property
320 def trainable_weights(self):
321 # Overridden from Layer class to track submetric weights.
322 if self.trainable:
323 trainable_weights = self._trainable_weights
324 for m in self._metrics:
325 trainable_weights += m.trainable_weights
326 return self._dedup_weights(trainable_weights)
327 else:
328 return []
330 @property
331 def non_trainable_weights(self):
332 # Overridden from Layer class to track submetric weights.
333 if self.trainable:
334 non_trainable_weights = self._non_trainable_weights
335 for m in self._metrics:
336 non_trainable_weights += m.non_trainable_weights
337 else:
338 non_trainable_weights = (
339 self._non_trainable_weights + self._trainable_weights)
340 for m in self._metrics:
341 non_trainable_weights += m.weights
342 return self._dedup_weights(non_trainable_weights)
344 @property
345 def _trackable_saved_model_saver(self):
346 return metric_serialization.MetricSavedModelSaver(self)
348 @generic_utils.default
349 @doc_controls.do_not_generate_docs
350 def reset_states(self):
351 # Backwards compatibility alias of `reset_state`. New classes should
352 # only implement `reset_state`.
353 return self.reset_state()
356class Reduce(Metric):
357 """Encapsulates metrics that perform a reduce operation on the values.
359 Args:
360 reduction: a `tf.keras.metrics.Reduction` enum value.
361 name: string name of the metric instance.
362 dtype: (Optional) data type of the metric result.
363 """
365 def __init__(self, reduction, name, dtype=None):
366 super(Reduce, self).__init__(name=name, dtype=dtype)
367 self.reduction = reduction
368 self.total = self.add_weight(
369 'total', initializer=init_ops.zeros_initializer)
370 if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
371 metrics_utils.Reduction.WEIGHTED_MEAN]:
372 self.count = self.add_weight(
373 'count', initializer=init_ops.zeros_initializer)
375 def update_state(self, values, sample_weight=None):
376 """Accumulates statistics for computing the metric.
378 Args:
379 values: Per-example value.
380 sample_weight: Optional weighting of each example. Defaults to 1.
382 Returns:
383 Update op.
384 """
385 [values], sample_weight = \
386 metrics_utils.ragged_assert_compatible_and_get_flat_values(
387 [values], sample_weight)
388 try:
389 values = math_ops.cast(values, self._dtype)
390 except (ValueError, TypeError):
391 msg = ('The output of a metric function can only be a single Tensor. '
392 'Got: %s' % (values,))
393 if isinstance(values, dict):
394 msg += ('. To return a dict of values, implement a custom Metric '
395 'subclass.')
396 raise RuntimeError(msg)
397 if sample_weight is not None:
398 sample_weight = math_ops.cast(sample_weight, self._dtype)
399 # Update dimensions of weights to match with values if possible.
400 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
401 values, sample_weight=sample_weight)
402 try:
403 # Broadcast weights if possible.
404 sample_weight = weights_broadcast_ops.broadcast_weights(
405 sample_weight, values)
406 except ValueError:
407 # Reduce values to same ndim as weight array
408 ndim = backend.ndim(values)
409 weight_ndim = backend.ndim(sample_weight)
410 if self.reduction == metrics_utils.Reduction.SUM:
411 values = math_ops.reduce_sum(
412 values, axis=list(range(weight_ndim, ndim)))
413 else:
414 values = math_ops.reduce_mean(
415 values, axis=list(range(weight_ndim, ndim)))
416 values = math_ops.multiply(values, sample_weight)
418 value_sum = math_ops.reduce_sum(values)
419 with ops.control_dependencies([value_sum]):
420 update_total_op = self.total.assign_add(value_sum)
422 # Exit early if the reduction doesn't have a denominator.
423 if self.reduction == metrics_utils.Reduction.SUM:
424 return update_total_op
426 # Update `count` for reductions that require a denominator.
427 if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
428 num_values = math_ops.cast(array_ops.size(values), self._dtype)
429 elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
430 if sample_weight is None:
431 num_values = math_ops.cast(array_ops.size(values), self._dtype)
432 else:
433 num_values = math_ops.reduce_sum(sample_weight)
434 else:
435 raise NotImplementedError(
436 'reduction [%s] not implemented' % self.reduction)
438 with ops.control_dependencies([update_total_op]):
439 return self.count.assign_add(num_values)
441 def result(self):
442 if self.reduction == metrics_utils.Reduction.SUM:
443 return array_ops.identity(self.total)
444 elif self.reduction in [
445 metrics_utils.Reduction.WEIGHTED_MEAN,
446 metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
447 ]:
448 return math_ops.div_no_nan(self.total, self.count)
449 else:
450 raise NotImplementedError(
451 'reduction [%s] not implemented' % self.reduction)
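# Illustrative note (a sketch, not part of the original source): for values
# [1, 3] with sample_weight [1, 0.5], update_state() accumulates
# total += 1*1 + 3*0.5 = 2.5, and then:
#   SUM                 -> result = total = 2.5 (no `count` variable)
#   SUM_OVER_BATCH_SIZE -> count += 2   (number of values),  result = 1.25
#   WEIGHTED_MEAN       -> count += 1.5 (sum of the weights), result ~= 1.667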
454@keras_export('keras.metrics.Sum')
455class Sum(Reduce):
456 """Computes the (weighted) sum of the given values.
458 For example, if values is [1, 3, 5, 7] then the sum is 16.
459 If the weights were specified as [1, 1, 0, 0] then the sum would be 4.
461 This metric creates one variable, `total`, that is used to compute the sum of
462 `values`. This is ultimately returned as `sum`.
464 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0
465 to mask values.
467 Args:
468 name: (Optional) string name of the metric instance.
469 dtype: (Optional) data type of the metric result.
471 Standalone usage:
473 >>> m = tf.keras.metrics.Sum()
474 >>> m.update_state([1, 3, 5, 7])
475 >>> m.result().numpy()
476 16.0
478 Usage with `compile()` API:
480 ```python
481 model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
482 model.compile(optimizer='sgd', loss='mse')
483 ```
484 """
486 def __init__(self, name='sum', dtype=None):
487 super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
488 name=name, dtype=dtype)
491@keras_export('keras.metrics.Mean')
492class Mean(Reduce):
493 """Computes the (weighted) mean of the given values.
495 For example, if values is [1, 3, 5, 7] then the mean is 4.
496 If the weights were specified as [1, 1, 0, 0] then the mean would be 2.
498 This metric creates two variables, `total` and `count` that are used to
499 compute the average of `values`. This average is ultimately returned as `mean`
500 which is an idempotent operation that simply divides `total` by `count`.
502 If `sample_weight` is `None`, weights default to 1.
503 Use `sample_weight` of 0 to mask values.
505 Args:
506 name: (Optional) string name of the metric instance.
507 dtype: (Optional) data type of the metric result.
509 Standalone usage:
511 >>> m = tf.keras.metrics.Mean()
512 >>> m.update_state([1, 3, 5, 7])
513 >>> m.result().numpy()
514 4.0
515 >>> m.reset_state()
516 >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
517 >>> m.result().numpy()
518 2.0
520 Usage with `compile()` API:
522 ```python
523 model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
524 model.compile(optimizer='sgd', loss='mse')
525 ```
526 """
528 def __init__(self, name='mean', dtype=None):
529 super(Mean, self).__init__(
530 reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)
533@keras_export('keras.metrics.MeanRelativeError')
534class MeanRelativeError(Mean):
535 """Computes the mean relative error by normalizing with the given values.
537 This metric creates two local variables, `total` and `count` that are used to
538 compute the mean relative error. This is weighted by `sample_weight`, and
539 it is ultimately returned as `mean_relative_error`:
540 an idempotent operation that simply divides `total` by `count`.
542 If `sample_weight` is `None`, weights default to 1.
543 Use `sample_weight` of 0 to mask values.
545 Args:
546 normalizer: The normalizer values with same shape as predictions.
547 name: (Optional) string name of the metric instance.
548 dtype: (Optional) data type of the metric result.
550 Standalone usage:
552 >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3])
553 >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8])
555 >>> # metric = mean(|y_pred - y_true| / normalizer)
556 >>> # = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3])
557 >>> # = 5/4 = 1.25
558 >>> m.result().numpy()
559 1.25
561 Usage with `compile()` API:
563 ```python
564 model.compile(
565 optimizer='sgd',
566 loss='mse',
567 metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])])
568 ```
569 """
571 def __init__(self, normalizer, name=None, dtype=None):
572 super(MeanRelativeError, self).__init__(name=name, dtype=dtype)
573 normalizer = math_ops.cast(normalizer, self._dtype)
574 self.normalizer = normalizer
576 def update_state(self, y_true, y_pred, sample_weight=None):
577 """Accumulates metric statistics.
579 Args:
580 y_true: The ground truth values.
581 y_pred: The predicted values.
582 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
583 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
584 be broadcastable to `y_true`.
586 Returns:
587 Update op.
588 """
589 y_true = math_ops.cast(y_true, self._dtype)
590 y_pred = math_ops.cast(y_pred, self._dtype)
591 [y_pred, y_true], sample_weight = \
592 metrics_utils.ragged_assert_compatible_and_get_flat_values(
593 [y_pred, y_true], sample_weight)
594 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
595 y_pred, y_true)
597 y_pred, self.normalizer = losses_utils.remove_squeezable_dimensions(
598 y_pred, self.normalizer)
599 y_pred.shape.assert_is_compatible_with(y_true.shape)
600 relative_errors = math_ops.div_no_nan(
601 math_ops.abs(y_true - y_pred), self.normalizer)
603 return super(MeanRelativeError, self).update_state(
604 relative_errors, sample_weight=sample_weight)
606 def get_config(self):
607 n = self.normalizer
608 config = {'normalizer': backend.eval(n) if is_tensor_or_variable(n) else n}
609 base_config = super(MeanRelativeError, self).get_config()
610 return dict(list(base_config.items()) + list(config.items()))
613@keras_export('keras.metrics.MeanMetricWrapper')
614class MeanMetricWrapper(Mean):
615 """Wraps a stateless metric function with the Mean metric.
617 You could use this class to quickly build a mean metric from a function. The
618 function needs to have the signature `fn(y_true, y_pred)` and return a
619 per-sample loss array. `MeanMetricWrapper.result()` will return
620 the average metric value across all samples seen so far.
622 For example:
624 ```python
625 def accuracy(y_true, y_pred):
626 return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)
628 accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)
630 keras_model.compile(..., metrics=accuracy_metric)
631 ```
633 Args:
634 fn: The metric function to wrap, with signature `fn(y_true, y_pred,
635 **kwargs)`.
636 name: (Optional) string name of the metric instance.
637 dtype: (Optional) data type of the metric result.
638 **kwargs: Keyword arguments to pass on to `fn`.
639 """
641 def __init__(self, fn, name=None, dtype=None, **kwargs):
642 super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
643 self._fn = fn
644 self._fn_kwargs = kwargs
646 def update_state(self, y_true, y_pred, sample_weight=None):
647 """Accumulates metric statistics.
649 `y_true` and `y_pred` should have the same shape.
651 Args:
652 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
653 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
654 sample_weight: Optional `sample_weight` acts as a
655 coefficient for the metric. If a scalar is provided, then the metric is
656 simply scaled by the given value. If `sample_weight` is a tensor of size
657 `[batch_size]`, then the metric for each sample of the batch is rescaled
658 by the corresponding element in the `sample_weight` vector. If the shape
659 of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
660 to this shape), then each metric element of `y_pred` is scaled by the
661 corresponding value of `sample_weight`. (Note on `dN-1`: all metric
662 functions reduce by 1 dimension, usually the last axis (-1)).
664 Returns:
665 Update op.
666 """
667 y_true = math_ops.cast(y_true, self._dtype)
668 y_pred = math_ops.cast(y_pred, self._dtype)
669 [y_true, y_pred], sample_weight = (
670 metrics_utils.ragged_assert_compatible_and_get_flat_values(
671 [y_true, y_pred], sample_weight))
672 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
673 y_pred, y_true)
675 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
676 matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
677 return super(MeanMetricWrapper, self).update_state(
678 matches, sample_weight=sample_weight)
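  # Illustrative note (a sketch, not part of the original source):
  #   m = tf.keras.metrics.MeanMetricWrapper(fn=tf.keras.losses.mae)
  #   m.update_state([[0., 1.]], [[1., 1.]], sample_weight=[0.5])
  # `fn` reduces the last axis, giving a per-sample value of 0.5, which is then
  # weighted by 0.5: total = 0.25, count = 0.5, result = 0.5.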
680 def get_config(self):
681 config = {}
683 if type(self) is MeanMetricWrapper: # pylint: disable=unidiomatic-typecheck
684 # Only include function argument when the object is a MeanMetricWrapper
685 # and not a subclass.
686 config['fn'] = self._fn
688 for k, v in self._fn_kwargs.items():
689 config[k] = backend.eval(v) if is_tensor_or_variable(v) else v
690 base_config = super(MeanMetricWrapper, self).get_config()
691 return dict(list(base_config.items()) + list(config.items()))
693 @classmethod
694 def from_config(cls, config):
695 # Note that while MeanMetricWrapper itself isn't public, objects of this
696 # class may be created and added to the model by calling model.compile.
697 fn = config.pop('fn', None)
698 if cls is MeanMetricWrapper:
699 return cls(get(fn), **config)
700 return super(MeanMetricWrapper, cls).from_config(config)
703@keras_export('keras.metrics.Accuracy')
704class Accuracy(MeanMetricWrapper):
705 """Calculates how often predictions equal labels.
707 This metric creates two local variables, `total` and `count` that are used to
708 compute the frequency with which `y_pred` matches `y_true`. This frequency is
709 ultimately returned as `accuracy`: an idempotent operation that simply
710 divides `total` by `count`.
712 If `sample_weight` is `None`, weights default to 1.
713 Use `sample_weight` of 0 to mask values.
715 Args:
716 name: (Optional) string name of the metric instance.
717 dtype: (Optional) data type of the metric result.
719 Standalone usage:
721 >>> m = tf.keras.metrics.Accuracy()
722 >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
723 >>> m.result().numpy()
724 0.75
726 >>> m.reset_state()
727 >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],
728 ... sample_weight=[1, 1, 0, 0])
729 >>> m.result().numpy()
730 0.5
732 Usage with `compile()` API:
734 ```python
735 model.compile(optimizer='sgd',
736 loss='mse',
737 metrics=[tf.keras.metrics.Accuracy()])
738 ```
739 """
741 def __init__(self, name='accuracy', dtype=None):
742 super(Accuracy, self).__init__(accuracy, name, dtype=dtype)
745@keras_export('keras.metrics.BinaryAccuracy')
746class BinaryAccuracy(MeanMetricWrapper):
747 """Calculates how often predictions match binary labels.
749 This metric creates two local variables, `total` and `count` that are used to
750 compute the frequency with which `y_pred` matches `y_true`. This frequency is
751 ultimately returned as `binary accuracy`: an idempotent operation that simply
752 divides `total` by `count`.
754 If `sample_weight` is `None`, weights default to 1.
755 Use `sample_weight` of 0 to mask values.
757 Args:
758 name: (Optional) string name of the metric instance.
759 dtype: (Optional) data type of the metric result.
760 threshold: (Optional) Float representing the threshold for deciding
761 whether prediction values are 1 or 0.
763 Standalone usage:
765 >>> m = tf.keras.metrics.BinaryAccuracy()
766 >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
767 >>> m.result().numpy()
768 0.75
770 >>> m.reset_state()
771 >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
772 ... sample_weight=[1, 0, 0, 1])
773 >>> m.result().numpy()
774 0.5
776 Usage with `compile()` API:
778 ```python
779 model.compile(optimizer='sgd',
780 loss='mse',
781 metrics=[tf.keras.metrics.BinaryAccuracy()])
782 ```
783 """
785 def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
786 super(BinaryAccuracy, self).__init__(
787 binary_accuracy, name, dtype=dtype, threshold=threshold)
790@keras_export('keras.metrics.CategoricalAccuracy')
791class CategoricalAccuracy(MeanMetricWrapper):
792 """Calculates how often predictions match one-hot labels.
794 You can provide logits of classes as `y_pred`, since the argmax of
795 logits and probabilities is the same.
797 This metric creates two local variables, `total` and `count` that are used to
798 compute the frequency with which `y_pred` matches `y_true`. This frequency is
799 ultimately returned as `categorical accuracy`: an idempotent operation that
800 simply divides `total` by `count`.
802 `y_pred` and `y_true` should be passed in as vectors of probabilities, rather
803 than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.
805 If `sample_weight` is `None`, weights default to 1.
806 Use `sample_weight` of 0 to mask values.
808 Args:
809 name: (Optional) string name of the metric instance.
810 dtype: (Optional) data type of the metric result.
812 Standalone usage:
814 >>> m = tf.keras.metrics.CategoricalAccuracy()
815 >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
816 ... [0.05, 0.95, 0]])
817 >>> m.result().numpy()
818 0.5
820 >>> m.reset_state()
821 >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
822 ... [0.05, 0.95, 0]],
823 ... sample_weight=[0.7, 0.3])
824 >>> m.result().numpy()
825 0.3
827 Usage with `compile()` API:
829 ```python
830 model.compile(
831 optimizer='sgd',
832 loss='mse',
833 metrics=[tf.keras.metrics.CategoricalAccuracy()])
834 ```
835 """
837 def __init__(self, name='categorical_accuracy', dtype=None):
838 super(CategoricalAccuracy, self).__init__(
839 categorical_accuracy, name, dtype=dtype)
842@keras_export('keras.metrics.SparseCategoricalAccuracy')
843class SparseCategoricalAccuracy(MeanMetricWrapper):
844 """Calculates how often predictions match integer labels.
846 ```python
847 acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)))
848 ```
850 You can provide logits of classes as `y_pred`, since the argmax of
851 logits and probabilities is the same.
853 This metric creates two local variables, `total` and `count` that are used to
854 compute the frequency with which `y_pred` matches `y_true`. This frequency is
855 ultimately returned as `sparse categorical accuracy`: an idempotent operation
856 that simply divides `total` by `count`.
858 If `sample_weight` is `None`, weights default to 1.
859 Use `sample_weight` of 0 to mask values.
861 Args:
862 name: (Optional) string name of the metric instance.
863 dtype: (Optional) data type of the metric result.
865 Standalone usage:
867 >>> m = tf.keras.metrics.SparseCategoricalAccuracy()
868 >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
869 >>> m.result().numpy()
870 0.5
872 >>> m.reset_state()
873 >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
874 ... sample_weight=[0.7, 0.3])
875 >>> m.result().numpy()
876 0.3
878 Usage with `compile()` API:
880 ```python
881 model.compile(
882 optimizer='sgd',
883 loss='mse',
884 metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
885 ```
886 """
888 def __init__(self, name='sparse_categorical_accuracy', dtype=None):
889 super(SparseCategoricalAccuracy, self).__init__(
890 sparse_categorical_accuracy, name, dtype=dtype)
893@keras_export('keras.metrics.TopKCategoricalAccuracy')
894class TopKCategoricalAccuracy(MeanMetricWrapper):
895 """Computes how often targets are in the top `K` predictions.
897 Args:
898 k: (Optional) Number of top elements to look at for computing accuracy.
899 Defaults to 5.
900 name: (Optional) string name of the metric instance.
901 dtype: (Optional) data type of the metric result.
903 Standalone usage:
905 >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
906 >>> m.update_state([[0, 0, 1], [0, 1, 0]],
907 ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
908 >>> m.result().numpy()
909 0.5
911 >>> m.reset_state()
912 >>> m.update_state([[0, 0, 1], [0, 1, 0]],
913 ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
914 ... sample_weight=[0.7, 0.3])
915 >>> m.result().numpy()
916 0.3
918 Usage with `compile()` API:
920 ```python
921 model.compile(optimizer='sgd',
922 loss='mse',
923 metrics=[tf.keras.metrics.TopKCategoricalAccuracy()])
924 ```
925 """
927 def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
928 super(TopKCategoricalAccuracy, self).__init__(
929 top_k_categorical_accuracy, name, dtype=dtype, k=k)
932@keras_export('keras.metrics.SparseTopKCategoricalAccuracy')
933class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
934 """Computes how often integer targets are in the top `K` predictions.
936 Args:
937 k: (Optional) Number of top elements to look at for computing accuracy.
938 Defaults to 5.
939 name: (Optional) string name of the metric instance.
940 dtype: (Optional) data type of the metric result.
942 Standalone usage:
944 >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
945 >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
946 >>> m.result().numpy()
947 0.5
949 >>> m.reset_state()
950 >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
951 ... sample_weight=[0.7, 0.3])
952 >>> m.result().numpy()
953 0.3
955 Usage with `compile()` API:
957 ```python
958 model.compile(
959 optimizer='sgd',
960 loss='mse',
961 metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
962 ```
963 """
965 def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
966 super(SparseTopKCategoricalAccuracy, self).__init__(
967 sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)
970class _ConfusionMatrixConditionCount(Metric):
971 """Calculates the number of the given confusion matrix condition.
973 Args:
974 confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
975 thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple
976 of float threshold values in [0, 1]. A threshold is compared with
977 prediction values to determine the truth value of predictions (i.e., above
978 the threshold is `true`, below is `false`). One metric value is generated
979 for each threshold value.
980 name: (Optional) string name of the metric instance.
981 dtype: (Optional) data type of the metric result.
982 """
984 def __init__(self,
985 confusion_matrix_cond,
986 thresholds=None,
987 name=None,
988 dtype=None):
989 super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
990 self._confusion_matrix_cond = confusion_matrix_cond
991 self.init_thresholds = thresholds
992 self.thresholds = metrics_utils.parse_init_thresholds(
993 thresholds, default_threshold=0.5)
994 self._thresholds_distributed_evenly = (
995 metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
996 self.accumulator = self.add_weight(
997 'accumulator',
998 shape=(len(self.thresholds),),
999 initializer=init_ops.zeros_initializer)
1001 def update_state(self, y_true, y_pred, sample_weight=None):
1002 """Accumulates the metric statistics.
1004 Args:
1005 y_true: The ground truth values.
1006 y_pred: The predicted values.
1007 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1008 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1009 be broadcastable to `y_true`.
1011 Returns:
1012 Update op.
1013 """
1014 return metrics_utils.update_confusion_matrix_variables(
1015 {self._confusion_matrix_cond: self.accumulator},
1016 y_true,
1017 y_pred,
1018 thresholds=self.thresholds,
1019 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1020 sample_weight=sample_weight)
1022 def result(self):
1023 if len(self.thresholds) == 1:
1024 result = self.accumulator[0]
1025 else:
1026 result = self.accumulator
1027 return tensor_conversion.convert_to_tensor_v2_with_dispatch(result)
1029 def reset_state(self):
1030 num_thresholds = len(to_list(self.thresholds))
1031 backend.batch_set_value(
1032 [(v, np.zeros((num_thresholds,))) for v in self.variables])
1034 def get_config(self):
1035 config = {'thresholds': self.init_thresholds}
1036 base_config = super(_ConfusionMatrixConditionCount, self).get_config()
1037 return dict(list(base_config.items()) + list(config.items()))
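# Illustrative note (a sketch, not part of the original source): when a list
# of thresholds is given, `accumulator` holds one count per threshold, e.g.
#   m = tf.keras.metrics.FalsePositives(thresholds=[0.3, 0.7])
#   m.update_state([0, 0, 1], [0.4, 0.8, 0.9])
#   m.result()  # -> [2.0, 1.0]: two false positives above 0.3, one above 0.7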
1040@keras_export('keras.metrics.FalsePositives')
1041class FalsePositives(_ConfusionMatrixConditionCount):
1042 """Calculates the number of false positives.
1044 If `sample_weight` is given, calculates the sum of the weights of
1045 false positives. This metric creates one local variable, `accumulator`
1046 that is used to keep track of the number of false positives.
1048 If `sample_weight` is `None`, weights default to 1.
1049 Use `sample_weight` of 0 to mask values.
1051 Args:
1052 thresholds: (Optional) Defaults to 0.5. A float value or a python
1053 list/tuple of float threshold values in [0, 1]. A threshold is compared
1054 with prediction values to determine the truth value of predictions
1055 (i.e., above the threshold is `true`, below is `false`). One metric
1056 value is generated for each threshold value.
1057 name: (Optional) string name of the metric instance.
1058 dtype: (Optional) data type of the metric result.
1060 Standalone usage:
1062 >>> m = tf.keras.metrics.FalsePositives()
1063 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
1064 >>> m.result().numpy()
1065 2.0
1067 >>> m.reset_state()
1068 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1069 >>> m.result().numpy()
1070 1.0
1072 Usage with `compile()` API:
1074 ```python
1075 model.compile(optimizer='sgd',
1076 loss='mse',
1077 metrics=[tf.keras.metrics.FalsePositives()])
1078 ```
1079 """
1081 def __init__(self, thresholds=None, name=None, dtype=None):
1082 super(FalsePositives, self).__init__(
1083 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
1084 thresholds=thresholds,
1085 name=name,
1086 dtype=dtype)
1089@keras_export('keras.metrics.FalseNegatives')
1090class FalseNegatives(_ConfusionMatrixConditionCount):
1091 """Calculates the number of false negatives.
1093 If `sample_weight` is given, calculates the sum of the weights of
1094 false negatives. This metric creates one local variable, `accumulator`
1095 that is used to keep track of the number of false negatives.
1097 If `sample_weight` is `None`, weights default to 1.
1098 Use `sample_weight` of 0 to mask values.
1100 Args:
1101 thresholds: (Optional) Defaults to 0.5. A float value or a python
1102 list/tuple of float threshold values in [0, 1]. A threshold is compared
1103 with prediction values to determine the truth value of predictions
1104 (i.e., above the threshold is `true`, below is `false`). One metric
1105 value is generated for each threshold value.
1106 name: (Optional) string name of the metric instance.
1107 dtype: (Optional) data type of the metric result.
1109 Standalone usage:
1111 >>> m = tf.keras.metrics.FalseNegatives()
1112 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
1113 >>> m.result().numpy()
1114 2.0
1116 >>> m.reset_state()
1117 >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
1118 >>> m.result().numpy()
1119 1.0
1121 Usage with `compile()` API:
1123 ```python
1124 model.compile(optimizer='sgd',
1125 loss='mse',
1126 metrics=[tf.keras.metrics.FalseNegatives()])
1127 ```
1128 """
1130 def __init__(self, thresholds=None, name=None, dtype=None):
1131 super(FalseNegatives, self).__init__(
1132 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
1133 thresholds=thresholds,
1134 name=name,
1135 dtype=dtype)
1138@keras_export('keras.metrics.TrueNegatives')
1139class TrueNegatives(_ConfusionMatrixConditionCount):
1140 """Calculates the number of true negatives.
1142 If `sample_weight` is given, calculates the sum of the weights of
1143 true negatives. This metric creates one local variable, `accumulator`
1144 that is used to keep track of the number of true negatives.
1146 If `sample_weight` is `None`, weights default to 1.
1147 Use `sample_weight` of 0 to mask values.
1149 Args:
1150 thresholds: (Optional) Defaults to 0.5. A float value or a python
1151 list/tuple of float threshold values in [0, 1]. A threshold is compared
1152 with prediction values to determine the truth value of predictions
1153 (i.e., above the threshold is `true`, below is `false`). One metric
1154 value is generated for each threshold value.
1155 name: (Optional) string name of the metric instance.
1156 dtype: (Optional) data type of the metric result.
1158 Standalone usage:
1160 >>> m = tf.keras.metrics.TrueNegatives()
1161 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
1162 >>> m.result().numpy()
1163 2.0
1165 >>> m.reset_state()
1166 >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
1167 >>> m.result().numpy()
1168 1.0
1170 Usage with `compile()` API:
1172 ```python
1173 model.compile(optimizer='sgd',
1174 loss='mse',
1175 metrics=[tf.keras.metrics.TrueNegatives()])
1176 ```
1177 """
1179 def __init__(self, thresholds=None, name=None, dtype=None):
1180 super(TrueNegatives, self).__init__(
1181 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
1182 thresholds=thresholds,
1183 name=name,
1184 dtype=dtype)
1187@keras_export('keras.metrics.TruePositives')
1188class TruePositives(_ConfusionMatrixConditionCount):
1189 """Calculates the number of true positives.
1191 If `sample_weight` is given, calculates the sum of the weights of
1192 true positives. This metric creates one local variable, `accumulator`
1193 that is used to keep track of the number of true positives.
1195 If `sample_weight` is `None`, weights default to 1.
1196 Use `sample_weight` of 0 to mask values.
1198 Args:
1199 thresholds: (Optional) Defaults to 0.5. A float value or a python
1200 list/tuple of float threshold values in [0, 1]. A threshold is compared
1201 with prediction values to determine the truth value of predictions
1202 (i.e., above the threshold is `true`, below is `false`). One metric
1203 value is generated for each threshold value.
1204 name: (Optional) string name of the metric instance.
1205 dtype: (Optional) data type of the metric result.
1207 Standalone usage:
1209 >>> m = tf.keras.metrics.TruePositives()
1210 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1211 >>> m.result().numpy()
1212 2.0
1214 >>> m.reset_state()
1215 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1216 >>> m.result().numpy()
1217 1.0
1219 Usage with `compile()` API:
1221 ```python
1222 model.compile(optimizer='sgd',
1223 loss='mse',
1224 metrics=[tf.keras.metrics.TruePositives()])
1225 ```
1226 """
1228 def __init__(self, thresholds=None, name=None, dtype=None):
1229 super(TruePositives, self).__init__(
1230 confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
1231 thresholds=thresholds,
1232 name=name,
1233 dtype=dtype)
1236@keras_export('keras.metrics.Precision')
1237class Precision(Metric):
1238 """Computes the precision of the predictions with respect to the labels.
1240 The metric creates two local variables, `true_positives` and `false_positives`
1241 that are used to compute the precision. This value is ultimately returned as
1242 `precision`, an idempotent operation that simply divides `true_positives`
1243 by the sum of `true_positives` and `false_positives`.
1245 If `sample_weight` is `None`, weights default to 1.
1246 Use `sample_weight` of 0 to mask values.
1248 If `top_k` is set, we'll calculate precision as how often on average a class
1249 among the top-k classes with the highest predicted values of a batch entry is
1250 correct and can be found in the label for that entry.
1252 If `class_id` is specified, we calculate precision by considering only the
1253 entries in the batch for which `class_id` is above the threshold and/or in the
1254 top-k highest predictions, and computing the fraction of them for which
1255 `class_id` is indeed a correct label.
1257 Args:
1258 thresholds: (Optional) A float value or a python list/tuple of float
1259 threshold values in [0, 1]. A threshold is compared with prediction
1260 values to determine the truth value of predictions (i.e., above the
1261 threshold is `true`, below is `false`). One metric value is generated
1262 for each threshold value. If neither thresholds nor top_k are set, the
1263 default is to calculate precision with `thresholds=0.5`.
1264 top_k: (Optional) Unset by default. An int value specifying the top-k
1265 predictions to consider when calculating precision.
1266 class_id: (Optional) Integer class ID for which we want binary metrics.
1267 This must be in the half-open interval `[0, num_classes)`, where
1268 `num_classes` is the last dimension of predictions.
1269 name: (Optional) string name of the metric instance.
1270 dtype: (Optional) data type of the metric result.
1272 Standalone usage:
1274 >>> m = tf.keras.metrics.Precision()
1275 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1276 >>> m.result().numpy()
1277 0.6666667
1279 >>> m.reset_state()
1280 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1281 >>> m.result().numpy()
1282 1.0
1284 >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2]
1285 >>> m = tf.keras.metrics.Precision(top_k=2)
1286 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
1287 >>> m.result().numpy()
1288 0.0
1290 >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4]
1291 >>> m = tf.keras.metrics.Precision(top_k=4)
1292 >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
1293 >>> m.result().numpy()
1294 0.5
1296 Usage with `compile()` API:
1298 ```python
1299 model.compile(optimizer='sgd',
1300 loss='mse',
1301 metrics=[tf.keras.metrics.Precision()])
1302 ```
1303 """
1305 def __init__(self,
1306 thresholds=None,
1307 top_k=None,
1308 class_id=None,
1309 name=None,
1310 dtype=None):
1311 super(Precision, self).__init__(name=name, dtype=dtype)
1312 self.init_thresholds = thresholds
1313 self.top_k = top_k
1314 self.class_id = class_id
1316 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1317 self.thresholds = metrics_utils.parse_init_thresholds(
1318 thresholds, default_threshold=default_threshold)
1319 self._thresholds_distributed_evenly = (
1320 metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
1321 self.true_positives = self.add_weight(
1322 'true_positives',
1323 shape=(len(self.thresholds),),
1324 initializer=init_ops.zeros_initializer)
1325 self.false_positives = self.add_weight(
1326 'false_positives',
1327 shape=(len(self.thresholds),),
1328 initializer=init_ops.zeros_initializer)
1330 def update_state(self, y_true, y_pred, sample_weight=None):
1331 """Accumulates true positive and false positive statistics.
1333 Args:
1334 y_true: The ground truth values, with the same dimensions as `y_pred`.
1335 Will be cast to `bool`.
1336 y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1337 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1338 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1339 be broadcastable to `y_true`.
1341 Returns:
1342 Update op.
1343 """
1344 return metrics_utils.update_confusion_matrix_variables(
1345 {
1346 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1347 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives
1348 },
1349 y_true,
1350 y_pred,
1351 thresholds=self.thresholds,
1352 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1353 top_k=self.top_k,
1354 class_id=self.class_id,
1355 sample_weight=sample_weight)
1357 def result(self):
1358 result = math_ops.div_no_nan(self.true_positives,
1359 self.true_positives + self.false_positives)
1360 return result[0] if len(self.thresholds) == 1 else result
1362 def reset_state(self):
1363 num_thresholds = len(to_list(self.thresholds))
1364 backend.batch_set_value([(v, np.zeros((num_thresholds,)))
1365 for v in (self.true_positives,
1366 self.false_positives)])
1368 def get_config(self):
1369 config = {
1370 'thresholds': self.init_thresholds,
1371 'top_k': self.top_k,
1372 'class_id': self.class_id
1373 }
1374 base_config = super(Precision, self).get_config()
1375 return dict(list(base_config.items()) + list(config.items()))
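# Illustrative note (a sketch, not part of the original source): with a list
# of thresholds, result() returns one precision value per threshold, e.g.
#   m = tf.keras.metrics.Precision(thresholds=[0.4, 0.9])
#   m.update_state([0, 1, 1, 1], [0.6, 0.5, 0.95, 0.95])
#   m.result()  # -> [0.75, 1.0]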
1378@keras_export('keras.metrics.Recall')
1379class Recall(Metric):
1380 """Computes the recall of the predictions with respect to the labels.
1382 This metric creates two local variables, `true_positives` and
1383 `false_negatives`, that are used to compute the recall. This value is
1384 ultimately returned as `recall`, an idempotent operation that simply divides
1385 `true_positives` by the sum of `true_positives` and `false_negatives`.
1387 If `sample_weight` is `None`, weights default to 1.
1388 Use `sample_weight` of 0 to mask values.
1390 If `top_k` is set, recall will be computed as how often on average a class
1391 among the labels of a batch entry is in the top-k predictions.
1393 If `class_id` is specified, we calculate recall by considering only the
1394 entries in the batch for which `class_id` is in the label, and computing the
1395 fraction of them for which `class_id` is above the threshold and/or in the
1396 top-k predictions.
1398 Args:
1399 thresholds: (Optional) A float value or a python list/tuple of float
1400 threshold values in [0, 1]. A threshold is compared with prediction
1401 values to determine the truth value of predictions (i.e., above the
1402 threshold is `true`, below is `false`). One metric value is generated
1403 for each threshold value. If neither thresholds nor top_k are set, the
1404 default is to calculate recall with `thresholds=0.5`.
1405 top_k: (Optional) Unset by default. An int value specifying the top-k
1406 predictions to consider when calculating recall.
1407 class_id: (Optional) Integer class ID for which we want binary metrics.
1408 This must be in the half-open interval `[0, num_classes)`, where
1409 `num_classes` is the last dimension of predictions.
1410 name: (Optional) string name of the metric instance.
1411 dtype: (Optional) data type of the metric result.
1413 Standalone usage:
1415 >>> m = tf.keras.metrics.Recall()
1416 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1417 >>> m.result().numpy()
1418 0.6666667
1420 >>> m.reset_state()
1421 >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1422 >>> m.result().numpy()
1423 1.0
1425 Usage with `compile()` API:
1427 ```python
1428 model.compile(optimizer='sgd',
1429 loss='mse',
1430 metrics=[tf.keras.metrics.Recall()])
1431 ```
1432 """
1434 def __init__(self,
1435 thresholds=None,
1436 top_k=None,
1437 class_id=None,
1438 name=None,
1439 dtype=None):
1440 super(Recall, self).__init__(name=name, dtype=dtype)
1441 self.init_thresholds = thresholds
1442 self.top_k = top_k
1443 self.class_id = class_id
1445 default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1446 self.thresholds = metrics_utils.parse_init_thresholds(
1447 thresholds, default_threshold=default_threshold)
1448 self._thresholds_distributed_evenly = (
1449 metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
1450 self.true_positives = self.add_weight(
1451 'true_positives',
1452 shape=(len(self.thresholds),),
1453 initializer=init_ops.zeros_initializer)
1454 self.false_negatives = self.add_weight(
1455 'false_negatives',
1456 shape=(len(self.thresholds),),
1457 initializer=init_ops.zeros_initializer)
1459 def update_state(self, y_true, y_pred, sample_weight=None):
1460 """Accumulates true positive and false negative statistics.
1462 Args:
1463 y_true: The ground truth values, with the same dimensions as `y_pred`.
1464 Will be cast to `bool`.
1465 y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1466 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1467 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1468 be broadcastable to `y_true`.
1470 Returns:
1471 Update op.
1472 """
1473 return metrics_utils.update_confusion_matrix_variables(
1474 {
1475 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1476 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives
1477 },
1478 y_true,
1479 y_pred,
1480 thresholds=self.thresholds,
1481 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1482 top_k=self.top_k,
1483 class_id=self.class_id,
1484 sample_weight=sample_weight)
1486 def result(self):
1487 result = math_ops.div_no_nan(self.true_positives,
1488 self.true_positives + self.false_negatives)
1489 return result[0] if len(self.thresholds) == 1 else result
1491 def reset_state(self):
1492 num_thresholds = len(to_list(self.thresholds))
1493 backend.batch_set_value([(v, np.zeros((num_thresholds,)))
1494 for v in (self.true_positives,
1495 self.false_negatives)])
1497 def get_config(self):
1498 config = {
1499 'thresholds': self.init_thresholds,
1500 'top_k': self.top_k,
1501 'class_id': self.class_id
1502 }
1503 base_config = super(Recall, self).get_config()
1504 return dict(list(base_config.items()) + list(config.items()))
1507class SensitivitySpecificityBase(Metric, metaclass=abc.ABCMeta):
1508 """Abstract base class for computing sensitivity and specificity.
1510 For additional information about specificity and sensitivity, see
1511 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1512 """
1514 def __init__(self,
1515 value,
1516 num_thresholds=200,
1517 class_id=None,
1518 name=None,
1519 dtype=None):
1520 super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
1521 if num_thresholds <= 0:
1522 raise ValueError('`num_thresholds` must be > 0.')
1523 self.value = value
1524 self.class_id = class_id
1525 self.true_positives = self.add_weight(
1526 'true_positives',
1527 shape=(num_thresholds,),
1528 initializer=init_ops.zeros_initializer)
1529 self.true_negatives = self.add_weight(
1530 'true_negatives',
1531 shape=(num_thresholds,),
1532 initializer=init_ops.zeros_initializer)
1533 self.false_positives = self.add_weight(
1534 'false_positives',
1535 shape=(num_thresholds,),
1536 initializer=init_ops.zeros_initializer)
1537 self.false_negatives = self.add_weight(
1538 'false_negatives',
1539 shape=(num_thresholds,),
1540 initializer=init_ops.zeros_initializer)
1542 # Compute `num_thresholds` thresholds in [0, 1]
1543 if num_thresholds == 1:
1544 self.thresholds = [0.5]
1545 self._thresholds_distributed_evenly = False
1546 else:
1547 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
1548 for i in range(num_thresholds - 2)]
1549 self.thresholds = [0.0] + thresholds + [1.0]
1550 self._thresholds_distributed_evenly = True
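      # Illustrative note (not part of the original source): the thresholds are
      # evenly spaced over [0, 1]; e.g. num_thresholds=5 yields
      # [0.0, 0.25, 0.5, 0.75, 1.0].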
1552 def update_state(self, y_true, y_pred, sample_weight=None):
1553 """Accumulates confusion matrix statistics.
1555 Args:
1556 y_true: The ground truth values.
1557 y_pred: The predicted values.
1558 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1559 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1560 be broadcastable to `y_true`.
1562 Returns:
1563 Update op.
1564 """
1565 return metrics_utils.update_confusion_matrix_variables(
1566 {
1567 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1568 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
1569 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
1570 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
1571 },
1572 y_true,
1573 y_pred,
1574 thresholds=self.thresholds,
1575 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1576 class_id=self.class_id,
1577 sample_weight=sample_weight)
1579 def reset_state(self):
1580 num_thresholds = len(self.thresholds)
1581 confusion_matrix_variables = (self.true_positives, self.true_negatives,
1582 self.false_positives, self.false_negatives)
1583 backend.batch_set_value([
1584 (v, np.zeros((num_thresholds,))) for v in confusion_matrix_variables
1585 ])
1587 def get_config(self):
1588 config = {'class_id': self.class_id}
1589 base_config = super(SensitivitySpecificityBase, self).get_config()
1590 return dict(list(base_config.items()) + list(config.items()))
1592 def _find_max_under_constraint(self, constrained, dependent, predicate):
1593 """Returns the maximum of `dependent` that satisfies the constraint.
1595 Args:
1596 constrained: Over these values the constraint
1597 is specified. A rank-1 tensor.
1598 dependent: From these values the maximum that satisfies the
1599 constraint is selected. Values in this tensor and in
1600 `constrained` are linked by having the same threshold at each
1601 position, hence this tensor must have the same shape.
1602 predicate: A binary boolean functor to be applied to arguments
1603 `constrained` and `self.value`, e.g. `tf.greater`.
1605 Returns the maximal dependent value; if no value satisfies the constraint, returns 0.0.
1606 """
1607 feasible = array_ops.where_v2(predicate(constrained, self.value))
1608 feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
1609 max_dependent = math_ops.reduce_max(array_ops.gather(dependent, feasible))
1611 return array_ops.where_v2(feasible_exists, max_dependent, 0.0)
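# Editor's sketch with hypothetical values (not part of the original source):
# for constrained = [0.2, 0.6, 0.9], dependent = [0.8, 0.5, 0.1],
# self.value = 0.5 and predicate = math_ops.greater_equal, the feasible
# indices are [1, 2], so the result is max(dependent[[1, 2]]) = 0.5; if no
# entry satisfied the constraint, 0.0 would be returned instead.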
1614@keras_export('keras.metrics.SensitivityAtSpecificity')
1615class SensitivityAtSpecificity(SensitivitySpecificityBase):
1616 """Computes best sensitivity where specificity is >= specified value.
1618 That is, it computes the best sensitivity achievable at a given minimum specificity.
1620 `Sensitivity` measures the proportion of actual positives that are correctly
1621 identified as such (tp / (tp + fn)).
1622 `Specificity` measures the proportion of actual negatives that are correctly
1623 identified as such (tn / (tn + fp)).
1625 This metric creates four local variables, `true_positives`, `true_negatives`,
1626 `false_positives` and `false_negatives` that are used to compute the
1627 sensitivity at the given specificity. The threshold for the given specificity
1628 value is computed and used to evaluate the corresponding sensitivity.
1630 If `sample_weight` is `None`, weights default to 1.
1631 Use `sample_weight` of 0 to mask values.
1633 If `class_id` is specified, the metric is computed by restricting the
1634 predictions and labels to the `class_id` entry, i.e. sensitivity and
1635 specificity are evaluated for that class alone as a binary classification
1636 problem.
1638 For additional information about specificity and sensitivity, see
1639 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1641 Args:
1642 specificity: A scalar value in range `[0, 1]`.
1643 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1644 use for matching the given specificity.
1645 class_id: (Optional) Integer class ID for which we want binary metrics.
1646 This must be in the half-open interval `[0, num_classes)`, where
1647 `num_classes` is the last dimension of predictions.
1648 name: (Optional) string name of the metric instance.
1649 dtype: (Optional) data type of the metric result.
1651 Standalone usage:
1653 >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
1654 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1655 >>> m.result().numpy()
1656 0.5
1658 >>> m.reset_state()
1659 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1660 ... sample_weight=[1, 1, 2, 2, 1])
1661 >>> m.result().numpy()
1662 0.333333
1664 Usage with `compile()` API:
1666 ```python
1667 model.compile(
1668 optimizer='sgd',
1669 loss='mse',
1670 metrics=[tf.keras.metrics.SensitivityAtSpecificity()])
1671 ```
1672 """
1674 def __init__(self,
1675 specificity,
1676 num_thresholds=200,
1677 class_id=None,
1678 name=None,
1679 dtype=None):
1680 if specificity < 0 or specificity > 1:
1681 raise ValueError('`specificity` must be in the range [0, 1].')
1682 self.specificity = specificity
1683 self.num_thresholds = num_thresholds
1684 super(SensitivityAtSpecificity, self).__init__(
1685 specificity,
1686 num_thresholds=num_thresholds,
1687 class_id=class_id,
1688 name=name,
1689 dtype=dtype)
1691 def result(self):
1692 specificities = math_ops.div_no_nan(
1693 self.true_negatives, self.true_negatives + self.false_positives)
1694 sensitivities = math_ops.div_no_nan(
1695 self.true_positives, self.true_positives + self.false_negatives)
1696 return self._find_max_under_constraint(
1697 specificities, sensitivities, math_ops.greater_equal)
1699 def get_config(self):
1700 config = {
1701 'num_thresholds': self.num_thresholds,
1702 'specificity': self.specificity
1703 }
1704 base_config = super(SensitivityAtSpecificity, self).get_config()
1705 return dict(list(base_config.items()) + list(config.items()))
1708@keras_export('keras.metrics.SpecificityAtSensitivity')
1709class SpecificityAtSensitivity(SensitivitySpecificityBase):
1710 """Computes best specificity where sensitivity is >= specified value.
1712 `Sensitivity` measures the proportion of actual positives that are correctly
1713 identified as such (tp / (tp + fn)).
1714 `Specificity` measures the proportion of actual negatives that are correctly
1715 identified as such (tn / (tn + fp)).
1717 This metric creates four local variables, `true_positives`, `true_negatives`,
1718 `false_positives` and `false_negatives` that are used to compute the
1719 specificity at the given sensitivity. The threshold for the given sensitivity
1720 value is computed and used to evaluate the corresponding specificity.
1722 If `sample_weight` is `None`, weights default to 1.
1723 Use `sample_weight` of 0 to mask values.
1725 If `class_id` is specified, the metric is computed by restricting the
1726 predictions and labels to the `class_id` entry, i.e. specificity and
1727 sensitivity are evaluated for that class alone as a binary classification
1728 problem.
1730 For additional information about specificity and sensitivity, see
1731 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1733 Args:
1734 sensitivity: A scalar value in range `[0, 1]`.
1735 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1736 use for matching the given sensitivity.
1737 class_id: (Optional) Integer class ID for which we want binary metrics.
1738 This must be in the half-open interval `[0, num_classes)`, where
1739 `num_classes` is the last dimension of predictions.
1740 name: (Optional) string name of the metric instance.
1741 dtype: (Optional) data type of the metric result.
1743 Standalone usage:
1745 >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
1746 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1747 >>> m.result().numpy()
1748 0.66666667
1750 >>> m.reset_state()
1751 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1752 ... sample_weight=[1, 1, 2, 2, 2])
1753 >>> m.result().numpy()
1754 0.5
1756 Usage with `compile()` API:
1758 ```python
1759 model.compile(
1760 optimizer='sgd',
1761 loss='mse',
1762 metrics=[tf.keras.metrics.SpecificityAtSensitivity()])
1763 ```
1764 """
1766 def __init__(self,
1767 sensitivity,
1768 num_thresholds=200,
1769 class_id=None,
1770 name=None,
1771 dtype=None):
1772 if sensitivity < 0 or sensitivity > 1:
1773 raise ValueError('`sensitivity` must be in the range [0, 1].')
1774 self.sensitivity = sensitivity
1775 self.num_thresholds = num_thresholds
1776 super(SpecificityAtSensitivity, self).__init__(
1777 sensitivity,
1778 num_thresholds=num_thresholds,
1779 class_id=class_id,
1780 name=name,
1781 dtype=dtype)
1783 def result(self):
1784 sensitivities = math_ops.div_no_nan(
1785 self.true_positives, self.true_positives + self.false_negatives)
1786 specificities = math_ops.div_no_nan(
1787 self.true_negatives, self.true_negatives + self.false_positives)
1788 return self._find_max_under_constraint(
1789 sensitivities, specificities, math_ops.greater_equal)
1791 def get_config(self):
1792 config = {
1793 'num_thresholds': self.num_thresholds,
1794 'sensitivity': self.sensitivity
1795 }
1796 base_config = super(SpecificityAtSensitivity, self).get_config()
1797 return dict(list(base_config.items()) + list(config.items()))
1800@keras_export('keras.metrics.PrecisionAtRecall')
1801class PrecisionAtRecall(SensitivitySpecificityBase):
1802 """Computes best precision where recall is >= specified value.
1804 This metric creates four local variables, `true_positives`, `true_negatives`,
1805 `false_positives` and `false_negatives` that are used to compute the
1806 precision at the given recall. The threshold for the given recall
1807 value is computed and used to evaluate the corresponding precision.
1809 If `sample_weight` is `None`, weights default to 1.
1810 Use `sample_weight` of 0 to mask values.
1812 If `class_id` is specified, we calculate precision by considering only the
1813 entries in the batch for which `class_id` is above the threshold predictions,
1814 and computing the fraction of them for which `class_id` is indeed a correct
1815 label.
1817 Args:
1818 recall: A scalar value in range `[0, 1]`.
1819 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1820 use for matching the given recall.
1821 class_id: (Optional) Integer class ID for which we want binary metrics.
1822 This must be in the half-open interval `[0, num_classes)`, where
1823 `num_classes` is the last dimension of predictions.
1824 name: (Optional) string name of the metric instance.
1825 dtype: (Optional) data type of the metric result.
1827 Standalone usage:
1829 >>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
1830 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1831 >>> m.result().numpy()
1832 0.5
1834 >>> m.reset_state()
1835 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1836 ... sample_weight=[2, 2, 2, 1, 1])
1837 >>> m.result().numpy()
1838 0.33333333
1840 Usage with `compile()` API:
1842 ```python
1843 model.compile(
1844 optimizer='sgd',
1845 loss='mse',
1846 metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)])
1847 ```
1848 """
1850 def __init__(self,
1851 recall,
1852 num_thresholds=200,
1853 class_id=None,
1854 name=None,
1855 dtype=None):
1856 if recall < 0 or recall > 1:
1857 raise ValueError('`recall` must be in the range [0, 1].')
1858 self.recall = recall
1859 self.num_thresholds = num_thresholds
1860 super(PrecisionAtRecall, self).__init__(
1861 value=recall,
1862 num_thresholds=num_thresholds,
1863 class_id=class_id,
1864 name=name,
1865 dtype=dtype)
1867 def result(self):
1868 recalls = math_ops.div_no_nan(
1869 self.true_positives, self.true_positives + self.false_negatives)
1870 precisions = math_ops.div_no_nan(
1871 self.true_positives, self.true_positives + self.false_positives)
1872 return self._find_max_under_constraint(
1873 recalls, precisions, math_ops.greater_equal)
1875 def get_config(self):
1876 config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
1877 base_config = super(PrecisionAtRecall, self).get_config()
1878 return dict(list(base_config.items()) + list(config.items()))
1881@keras_export('keras.metrics.RecallAtPrecision')
1882class RecallAtPrecision(SensitivitySpecificityBase):
1883 """Computes best recall where precision is >= specified value.
1885 For a given score-label distribution, the required precision might not
1886 be achievable; in that case, 0.0 is returned as the recall.
1888 This metric creates four local variables, `true_positives`, `true_negatives`,
1889 `false_positives` and `false_negatives` that are used to compute the
1890 recall at the given precision. The threshold for the given precision
1891 value is computed and used to evaluate the corresponding recall.
1893 If `sample_weight` is `None`, weights default to 1.
1894 Use `sample_weight` of 0 to mask values.
1896 If `class_id` is specified, we calculate precision by considering only the
1897 entries in the batch for which `class_id` is above the threshold predictions,
1898 and computing the fraction of them for which `class_id` is indeed a correct
1899 label.
1901 Args:
1902 precision: A scalar value in range `[0, 1]`.
1903 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1904 use for matching the given precision.
1905 class_id: (Optional) Integer class ID for which we want binary metrics.
1906 This must be in the half-open interval `[0, num_classes)`, where
1907 `num_classes` is the last dimension of predictions.
1908 name: (Optional) string name of the metric instance.
1909 dtype: (Optional) data type of the metric result.
1911 Standalone usage:
1913 >>> m = tf.keras.metrics.RecallAtPrecision(0.8)
1914 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1915 >>> m.result().numpy()
1916 0.5
1918 >>> m.reset_state()
1919 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1920 ... sample_weight=[1, 0, 0, 1])
1921 >>> m.result().numpy()
1922 1.0
1924 Usage with `compile()` API:
1926 ```python
1927 model.compile(
1928 optimizer='sgd',
1929 loss='mse',
1930 metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)])
1931 ```
1932 """
1934 def __init__(self,
1935 precision,
1936 num_thresholds=200,
1937 class_id=None,
1938 name=None,
1939 dtype=None):
1940 if precision < 0 or precision > 1:
1941 raise ValueError('`precision` must be in the range [0, 1].')
1942 self.precision = precision
1943 self.num_thresholds = num_thresholds
1944 super(RecallAtPrecision, self).__init__(
1945 value=precision,
1946 num_thresholds=num_thresholds,
1947 class_id=class_id,
1948 name=name,
1949 dtype=dtype)
1951 def result(self):
1952 precisions = math_ops.div_no_nan(
1953 self.true_positives, self.true_positives + self.false_positives)
1954 recalls = math_ops.div_no_nan(
1955 self.true_positives, self.true_positives + self.false_negatives)
1956 return self._find_max_under_constraint(
1957 precisions, recalls, math_ops.greater_equal)
1959 def get_config(self):
1960 config = {'num_thresholds': self.num_thresholds,
1961 'precision': self.precision}
1962 base_config = super(RecallAtPrecision, self).get_config()
1963 return dict(list(base_config.items()) + list(config.items()))
1966@keras_export('keras.metrics.AUC')
1967class AUC(Metric):
1968 """Approximates the AUC (Area under the curve) of the ROC or PR curves.
1970 The AUC (Area under the curve) of the ROC (Receiver operating
1971 characteristic; default) or PR (Precision Recall) curve is a quality measure
1972 of binary classifiers. Unlike accuracy, and like cross-entropy
1973 losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
1975 This class approximates AUCs using a Riemann sum. During the metric
1976 accumulation phase, predictions are accumulated within predefined buckets
1977 by value. The AUC is then computed by interpolating per-bucket averages. These
1978 buckets define the evaluated operational points.
1980 This metric creates four local variables, `true_positives`, `true_negatives`,
1981 `false_positives` and `false_negatives` that are used to compute the AUC.
1982 To discretize the AUC curve, a linearly spaced set of thresholds is used to
1983 compute pairs of recall and precision values. The area under the ROC curve is
1984 therefore computed using the recall values as heights over the false positive
1985 rate, while the area under the PR curve is computed using the precision values
1986 as heights over the recall.
1988 This value is ultimately returned as `auc`, an idempotent operation that
1989 computes the area under a discretized curve of precision versus recall values
1990 (computed using the aforementioned variables). The `num_thresholds` variable
1991 controls the degree of discretization with larger numbers of thresholds more
1992 closely approximating the true AUC. The quality of the approximation may vary
1993 dramatically depending on `num_thresholds`. The `thresholds` parameter can be
1994 used to manually specify thresholds which split the predictions more evenly.
1996 For a best approximation of the real AUC, `predictions` should be distributed
1997 approximately uniformly in the range [0, 1] (if `from_logits=False`). The
1998 quality of the AUC approximation may be poor if this is not the case. Setting
1999 `summation_method` to 'minoring' or 'majoring' can help quantify the error in
2000 the approximation by providing a lower or an upper bound estimate of the AUC.
2002 If `sample_weight` is `None`, weights default to 1.
2003 Use `sample_weight` of 0 to mask values.
2005 Args:
2006 num_thresholds: (Optional) Defaults to 200. The number of thresholds to
2007 use when discretizing the roc curve. Values must be > 1.
2008 curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
2009 [default] or 'PR' for the Precision-Recall-curve.
2010 summation_method: (Optional) Specifies the [Riemann summation method](
2011 https://en.wikipedia.org/wiki/Riemann_sum) used.
2012 'interpolation' (default) applies mid-point summation scheme for `ROC`.
2013 For PR-AUC, interpolates (true/false) positives but not the ratio that
2014 is precision (see Davis & Goadrich 2006 for details);
2015 'minoring' applies left summation
2016 for increasing intervals and right summation for decreasing intervals;
2017 'majoring' does the opposite.
2018 name: (Optional) string name of the metric instance.
2019 dtype: (Optional) data type of the metric result.
2020 thresholds: (Optional) A list of floating point values to use as the
2021 thresholds for discretizing the curve. If set, the `num_thresholds`
2022 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
2023 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
2024 be automatically included with these to correctly handle predictions
2025 equal to exactly 0 or 1.
2026 multi_label: boolean indicating whether multilabel data should be
2027 treated as such, wherein AUC is computed separately for each label and
2028 then averaged across labels, or (when False) if the data should be
2029 flattened into a single label before AUC computation. In the latter
2030 case, when multilabel data is passed to AUC, each label-prediction pair
2031 is treated as an individual data point. Should be set to False for
2032 multi-class data.
2033 num_labels: (Optional) The number of labels, used when `multi_label` is
2034 True. If `num_labels` is not specified, then state variables get created
2035 on the first call to `update_state`.
2036 label_weights: (Optional) list, array, or tensor of non-negative weights
2037 used to compute AUCs for multilabel data. When `multi_label` is True,
2038 the weights are applied to the individual label AUCs when they are
2039 averaged to produce the multi-label AUC. When it's False, they are used
2040 to weight the individual label predictions in computing the confusion
2041 matrix on the flattened data. Note that this is unlike class_weights in
2042 that class_weights weights the example depending on the value of its
2043 label, whereas label_weights depends only on the index of that label
2044 before flattening; therefore `label_weights` should not be used for
2045 multi-class data.
2046 from_logits: boolean indicating whether the predictions (`y_pred` in
2047 `update_state`) are probabilities or sigmoid logits. As a rule of thumb,
2048 when using a keras loss, the `from_logits` constructor argument of the
2049 loss should match the AUC `from_logits` constructor argument.
2051 Standalone usage:
2053 >>> m = tf.keras.metrics.AUC(num_thresholds=3)
2054 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
2055 >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
2056 >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
2057 >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
2058 >>> # auc = ((((1+0.5)/2)*(1-0)) + (((0.5+0)/2)*(0-0))) = 0.75
2059 >>> m.result().numpy()
2060 0.75
2062 >>> m.reset_state()
2063 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
2064 ... sample_weight=[1, 0, 0, 1])
2065 >>> m.result().numpy()
2066 1.0
2068 Usage with `compile()` API:
2070 ```python
2071 # Reports the AUC of a model outputting a probability.
2072 model.compile(optimizer='sgd',
2073 loss=tf.keras.losses.BinaryCrossentropy(),
2074 metrics=[tf.keras.metrics.AUC()])
2076 # Reports the AUC of a model outputting a logit.
2077 model.compile(optimizer='sgd',
2078 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
2079 metrics=[tf.keras.metrics.AUC(from_logits=True)])
2080 ```
2081 """
2083 def __init__(self,
2084 num_thresholds=200,
2085 curve='ROC',
2086 summation_method='interpolation',
2087 name=None,
2088 dtype=None,
2089 thresholds=None,
2090 multi_label=False,
2091 num_labels=None,
2092 label_weights=None,
2093 from_logits=False):
2094 # Validate configurations.
2095 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
2096 metrics_utils.AUCCurve):
2097 raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
2098 curve, list(metrics_utils.AUCCurve)))
2099 if isinstance(
2100 summation_method,
2101 metrics_utils.AUCSummationMethod) and summation_method not in list(
2102 metrics_utils.AUCSummationMethod):
2103 raise ValueError(
2104 'Invalid summation method: "{}". Valid options are: "{}"'.format(
2105 summation_method, list(metrics_utils.AUCSummationMethod)))
2107 # Update properties.
2108 if thresholds is not None:
2109 # If specified, use the supplied thresholds.
2110 self.num_thresholds = len(thresholds) + 2
2111 thresholds = sorted(thresholds)
2112 self._thresholds_distributed_evenly = (
2113 metrics_utils.is_evenly_distributed_thresholds(
2114 np.array([0.0] + thresholds + [1.0])))
2115 else:
2116 if num_thresholds <= 1:
2117 raise ValueError('`num_thresholds` must be > 1.')
2119 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
2120 # (0, 1).
2121 self.num_thresholds = num_thresholds
2122 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
2123 for i in range(num_thresholds - 2)]
2124 self._thresholds_distributed_evenly = True
2126 # Add an endpoint "threshold" below zero and above one for either
2127 # threshold method to account for floating point imprecisions.
2128 self._thresholds = np.array([0.0 - backend.epsilon()] + thresholds +
2129 [1.0 + backend.epsilon()])
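# Editor's note (illustrative, not part of the original source): passing
# thresholds=[0.25, 0.5, 0.75] to the constructor gives num_thresholds == 5
# and self._thresholds == [0 - eps, 0.25, 0.5, 0.75, 1 + eps], where eps is
# backend.epsilon() (about 1e-7 by default), so predictions of exactly 0.0 or
# 1.0 still fall inside the outermost buckets.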
2131 if isinstance(curve, metrics_utils.AUCCurve):
2132 self.curve = curve
2133 else:
2134 self.curve = metrics_utils.AUCCurve.from_str(curve)
2135 if isinstance(summation_method, metrics_utils.AUCSummationMethod):
2136 self.summation_method = summation_method
2137 else:
2138 self.summation_method = metrics_utils.AUCSummationMethod.from_str(
2139 summation_method)
2140 super(AUC, self).__init__(name=name, dtype=dtype)
2142 # Handle multilabel arguments.
2143 self.multi_label = multi_label
2144 if label_weights is not None:
2145 label_weights = constant_op.constant(label_weights, dtype=self.dtype)
2146 checks = [
2147 check_ops.assert_non_negative(
2148 label_weights,
2149 message='All values of `label_weights` must be non-negative.')
2150 ]
2151 with ops.control_dependencies(checks):
2152 self.label_weights = label_weights
2154 else:
2155 self.label_weights = None
2157 self._from_logits = from_logits
2159 self._built = False
2160 if self.multi_label:
2161 if num_labels:
2162 shape = tensor_shape.TensorShape([None, num_labels])
2163 self._build(shape)
2164 else:
2165 if num_labels:
2166 raise ValueError(
2167 '`num_labels` is needed only when `multi_label` is True.')
2168 self._build(None)
2170 @property
2171 def thresholds(self):
2172 """The thresholds used for evaluating AUC."""
2173 return list(self._thresholds)
2175 def _build(self, shape):
2176 """Initialize TP, FP, TN, and FN tensors, given the shape of the data."""
2177 if self.multi_label:
2178 if shape.ndims != 2:
2179 raise ValueError('`y_true` must have rank=2 when `multi_label` is '
2180 'True. Found rank %s.' % shape.ndims)
2181 self._num_labels = shape[1]
2182 variable_shape = tensor_shape.TensorShape(
2183 [tensor_shape.Dimension(self.num_thresholds), self._num_labels])
2185 else:
2186 variable_shape = tensor_shape.TensorShape(
2187 [tensor_shape.Dimension(self.num_thresholds)])
2188 self._build_input_shape = shape
2189 # Create metric variables
2190 self.true_positives = self.add_weight(
2191 'true_positives',
2192 shape=variable_shape,
2193 initializer=init_ops.zeros_initializer)
2194 self.true_negatives = self.add_weight(
2195 'true_negatives',
2196 shape=variable_shape,
2197 initializer=init_ops.zeros_initializer)
2198 self.false_positives = self.add_weight(
2199 'false_positives',
2200 shape=variable_shape,
2201 initializer=init_ops.zeros_initializer)
2202 self.false_negatives = self.add_weight(
2203 'false_negatives',
2204 shape=variable_shape,
2205 initializer=init_ops.zeros_initializer)
2207 if self.multi_label:
2208 with ops.init_scope():
2209 # This should only be necessary for handling v1 behavior. In v2, AUC
2210 # should be initialized outside of any tf.functions, and therefore in
2211 # eager mode.
2212 if not context.executing_eagerly():
2213 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access
2215 self._built = True
2217 def update_state(self, y_true, y_pred, sample_weight=None):
2218 """Accumulates confusion matrix statistics.
2220 Args:
2221 y_true: The ground truth values.
2222 y_pred: The predicted values.
2223 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2224 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2225 be broadcastable to `y_true`.
2227 Returns:
2228 Update op.
2229 """
2230 deps = []
2231 if not self._built:
2232 self._build(tensor_shape.TensorShape(y_pred.shape))
2234 if self.multi_label or (self.label_weights is not None):
2235 # y_true should have shape (number of examples, number of labels).
2236 shapes = [
2237 (y_true, ('N', 'L'))
2238 ]
2239 if self.multi_label:
2240 # TP, TN, FP, and FN should all have shape
2241 # (number of thresholds, number of labels).
2242 shapes.extend([(self.true_positives, ('T', 'L')),
2243 (self.true_negatives, ('T', 'L')),
2244 (self.false_positives, ('T', 'L')),
2245 (self.false_negatives, ('T', 'L'))])
2246 if self.label_weights is not None:
2247 # label_weights should be of length equal to the number of labels.
2248 shapes.append((self.label_weights, ('L',)))
2249 deps = [
2250 check_ops.assert_shapes(
2251 shapes, message='Number of labels is not consistent.')
2252 ]
2254 # Only forward label_weights to update_confusion_matrix_variables when
2255 # multi_label is False. Otherwise the averaging of individual label AUCs is
2256 # handled in AUC.result
2257 label_weights = None if self.multi_label else self.label_weights
2259 if self._from_logits:
2260 y_pred = activations.sigmoid(y_pred)
2262 with ops.control_dependencies(deps):
2263 return metrics_utils.update_confusion_matrix_variables(
2264 {
2265 metrics_utils.ConfusionMatrix.TRUE_POSITIVES:
2266 self.true_positives,
2267 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES:
2268 self.true_negatives,
2269 metrics_utils.ConfusionMatrix.FALSE_POSITIVES:
2270 self.false_positives,
2271 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES:
2272 self.false_negatives,
2273 },
2274 y_true,
2275 y_pred,
2276 self._thresholds,
2277 thresholds_distributed_evenly=self._thresholds_distributed_evenly,
2278 sample_weight=sample_weight,
2279 multi_label=self.multi_label,
2280 label_weights=label_weights)
2282 def interpolate_pr_auc(self):
2283 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
2285 https://www.biostat.wisc.edu/~page/rocpr.pdf
2287 Note here we derive & use a closed formula not present in the paper
2288 as follows:
2290 Precision = TP / (TP + FP) = TP / P
2292 Modeling all of TP (true positive), FP (false positive) and their sum
2293 P = TP + FP (predicted positive) as varying linearly within each interval
2294 [A, B] between successive thresholds, we get
2296 Precision slope = dTP / dP
2297 = (TP_B - TP_A) / (P_B - P_A)
2298 = (TP - TP_A) / (P - P_A)
2299 Precision = (TP_A + slope * (P - P_A)) / P
2301 The area within the interval is (slope / total_pos_weight) times
2303 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
2304 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
2306 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
2308 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
2310 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
2312 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
2314 where dTP == TP_B - TP_A.
2316 Note that when P_A == 0 the above calculation simplifies into
2318 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
2320 which is really equivalent to imputing constant precision throughout the
2321 first bucket having >0 true positives.
2323 Returns:
2324 pr_auc: an approximation of the area under the P-R curve.
2325 """
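# Editor's note mapping the derivation above to the code below (editorial
# addition, not part of the original source): dtp is the per-interval change
# in true positives, p = TP + FP is the predicted-positive count at each
# threshold, dp is its per-interval change, prec_slope is dTP / dP, intercept
# is TP - slope * P at an interval endpoint, safe_p_ratio guards the
# log(P_B / P_A) term when an endpoint has zero predicted positives, and
# pr_auc_increment is each interval's area divided by the total positive
# count (TP + FN).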
2326 dtp = self.true_positives[:self.num_thresholds -
2327 1] - self.true_positives[1:]
2328 p = self.true_positives + self.false_positives
2329 dp = p[:self.num_thresholds - 1] - p[1:]
2330 prec_slope = math_ops.div_no_nan(
2331 dtp, math_ops.maximum(dp, 0), name='prec_slope')
2332 intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:])
2334 safe_p_ratio = array_ops.where(
2335 math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0),
2336 math_ops.div_no_nan(
2337 p[:self.num_thresholds - 1],
2338 math_ops.maximum(p[1:], 0),
2339 name='recall_relative_ratio'),
2340 array_ops.ones_like(p[1:]))
2342 pr_auc_increment = math_ops.div_no_nan(
2343 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
2344 math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 0),
2345 name='pr_auc_increment')
2347 if self.multi_label:
2348 by_label_auc = math_ops.reduce_sum(
2349 pr_auc_increment, name=self.name + '_by_label', axis=0)
2350 if self.label_weights is None:
2351 # Evenly weighted average of the label AUCs.
2352 return math_ops.reduce_mean(by_label_auc, name=self.name)
2353 else:
2354 # Weighted average of the label AUCs.
2355 return math_ops.div_no_nan(
2356 math_ops.reduce_sum(
2357 math_ops.multiply(by_label_auc, self.label_weights)),
2358 math_ops.reduce_sum(self.label_weights),
2359 name=self.name)
2360 else:
2361 return math_ops.reduce_sum(pr_auc_increment, name='interpolate_pr_auc')
2363 def result(self):
2364 if (self.curve == metrics_utils.AUCCurve.PR and
2365 self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION
2366 ):
2367 # This use case is different and is handled separately.
2368 return self.interpolate_pr_auc()
2370 # Set `x` and `y` values for the curves based on `curve` config.
2371 recall = math_ops.div_no_nan(self.true_positives,
2372 self.true_positives + self.false_negatives)
2373 if self.curve == metrics_utils.AUCCurve.ROC:
2374 fp_rate = math_ops.div_no_nan(self.false_positives,
2375 self.false_positives + self.true_negatives)
2376 x = fp_rate
2377 y = recall
2378 else: # curve == 'PR'.
2379 precision = math_ops.div_no_nan(
2380 self.true_positives, self.true_positives + self.false_positives)
2381 x = recall
2382 y = precision
2384 # Find the rectangle heights based on `summation_method`.
2385 if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
2386 # Note: the case ('PR', 'interpolation') has been handled above.
2387 heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
2388 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
2389 heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:])
2390 else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
2391 heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:])
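# Editor's example with hypothetical values (not part of the original
# source): for adjacent curve heights y_A = 0.8 and y_B = 0.6, the rectangle
# height used is 0.7 under 'interpolation' (mid-point), 0.6 under 'minoring'
# and 0.8 under 'majoring'.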
2393 # Sum up the areas of all the rectangles.
2394 if self.multi_label:
2395 riemann_terms = math_ops.multiply(x[:self.num_thresholds - 1] - x[1:],
2396 heights)
2397 by_label_auc = math_ops.reduce_sum(
2398 riemann_terms, name=self.name + '_by_label', axis=0)
2400 if self.label_weights is None:
2401 # Unweighted average of the label AUCs.
2402 return math_ops.reduce_mean(by_label_auc, name=self.name)
2403 else:
2404 # Weighted average of the label AUCs.
2405 return math_ops.div_no_nan(
2406 math_ops.reduce_sum(
2407 math_ops.multiply(by_label_auc, self.label_weights)),
2408 math_ops.reduce_sum(self.label_weights),
2409 name=self.name)
2410 else:
2411 return math_ops.reduce_sum(
2412 math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights),
2413 name=self.name)
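# Editor's trace of the class docstring example above (illustrative, not part
# of the original source): with num_thresholds=3, tp=[2, 1, 0], fp=[2, 0, 0],
# fn=[0, 1, 2] and tn=[0, 2, 2], the ROC branch gives x = fp_rate = [1, 0, 0]
# and y = recall = [1, 0.5, 0]; interpolated heights are [0.75, 0.25],
# interval widths are [1, 0], and the summed area is 0.75, matching
# m.result() in the docstring.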
2415 def reset_state(self):
2416 if self._built:
2417 confusion_matrix_variables = (self.true_positives, self.true_negatives,
2418 self.false_positives, self.false_negatives)
2419 if self.multi_label:
2420 backend.batch_set_value(
2421 [(v, np.zeros((self.num_thresholds, self._num_labels)))
2422 for v in confusion_matrix_variables])
2423 else:
2424 backend.batch_set_value([(v, np.zeros((self.num_thresholds,)))
2425 for v in confusion_matrix_variables])
2427 def get_config(self):
2428 if is_tensor_or_variable(self.label_weights):
2429 label_weights = backend.eval(self.label_weights)
2430 else:
2431 label_weights = self.label_weights
2432 config = {
2433 'num_thresholds': self.num_thresholds,
2434 'curve': self.curve.value,
2435 'summation_method': self.summation_method.value,
2436 # We remove the endpoint thresholds as an inverse of how the thresholds
2437 # were initialized. This ensures that a metric initialized from this
2438 # config has the same thresholds.
2439 'thresholds': self.thresholds[1:-1],
2440 'multi_label': self.multi_label,
2441 'label_weights': label_weights
2442 }
2443 base_config = super(AUC, self).get_config()
2444 return dict(list(base_config.items()) + list(config.items()))
2447@keras_export('keras.metrics.CosineSimilarity')
2448class CosineSimilarity(MeanMetricWrapper):
2449 """Computes the cosine similarity between the labels and predictions.
2451 `cosine similarity = (a . b) / ||a|| ||b||`
2453 See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity).
2455 This metric keeps the average cosine similarity between `predictions` and
2456 `labels` over a stream of data.
2458 Args:
2459 name: (Optional) string name of the metric instance.
2460 dtype: (Optional) data type of the metric result.
2461 axis: (Optional) Defaults to -1. The dimension along which the cosine
2462 similarity is computed.
2464 Standalone usage:
2466 >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
2467 >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
2468 >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
2469 >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
2470 >>> # = ((0. + 0.) + (0.5 + 0.5)) / 2
2471 >>> m = tf.keras.metrics.CosineSimilarity(axis=1)
2472 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])
2473 >>> m.result().numpy()
2474 0.49999997
2476 >>> m.reset_state()
2477 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]],
2478 ... sample_weight=[0.3, 0.7])
2479 >>> m.result().numpy()
2480 0.6999999
2482 Usage with `compile()` API:
2484 ```python
2485 model.compile(
2486 optimizer='sgd',
2487 loss='mse',
2488 metrics=[tf.keras.metrics.CosineSimilarity(axis=1)])
2489 ```
2490 """
2492 def __init__(self, name='cosine_similarity', dtype=None, axis=-1):
2493 super(CosineSimilarity, self).__init__(
2494 cosine_similarity, name, dtype=dtype, axis=axis)
2497@keras_export('keras.metrics.MeanAbsoluteError')
2498class MeanAbsoluteError(MeanMetricWrapper):
2499 """Computes the mean absolute error between the labels and predictions.
2501 Args:
2502 name: (Optional) string name of the metric instance.
2503 dtype: (Optional) data type of the metric result.
2505 Standalone usage:
2507 >>> m = tf.keras.metrics.MeanAbsoluteError()
2508 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2509 >>> m.result().numpy()
2510 0.25
2512 >>> m.reset_state()
2513 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2514 ... sample_weight=[1, 0])
2515 >>> m.result().numpy()
2516 0.5
2518 Usage with `compile()` API:
2520 ```python
2521 model.compile(
2522 optimizer='sgd',
2523 loss='mse',
2524 metrics=[tf.keras.metrics.MeanAbsoluteError()])
2525 ```
2526 """
2528 def __init__(self, name='mean_absolute_error', dtype=None):
2529 super(MeanAbsoluteError, self).__init__(
2530 mean_absolute_error, name, dtype=dtype)
2533@keras_export('keras.metrics.MeanAbsolutePercentageError')
2534class MeanAbsolutePercentageError(MeanMetricWrapper):
2535 """Computes the mean absolute percentage error between `y_true` and `y_pred`.
2537 Args:
2538 name: (Optional) string name of the metric instance.
2539 dtype: (Optional) data type of the metric result.
2541 Standalone usage:
2543 >>> m = tf.keras.metrics.MeanAbsolutePercentageError()
2544 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2545 >>> m.result().numpy()
2546 250000000.0
2548 >>> m.reset_state()
2549 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2550 ... sample_weight=[1, 0])
2551 >>> m.result().numpy()
2552 500000000.0
2554 Usage with `compile()` API:
2556 ```python
2557 model.compile(
2558 optimizer='sgd',
2559 loss='mse',
2560 metrics=[tf.keras.metrics.MeanAbsolutePercentageError()])
2561 ```
2562 """
2564 def __init__(self, name='mean_absolute_percentage_error', dtype=None):
2565 super(MeanAbsolutePercentageError, self).__init__(
2566 mean_absolute_percentage_error, name, dtype=dtype)
2569@keras_export('keras.metrics.MeanSquaredError')
2570class MeanSquaredError(MeanMetricWrapper):
2571 """Computes the mean squared error between `y_true` and `y_pred`.
2573 Args:
2574 name: (Optional) string name of the metric instance.
2575 dtype: (Optional) data type of the metric result.
2577 Standalone usage:
2579 >>> m = tf.keras.metrics.MeanSquaredError()
2580 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2581 >>> m.result().numpy()
2582 0.25
2584 >>> m.reset_state()
2585 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2586 ... sample_weight=[1, 0])
2587 >>> m.result().numpy()
2588 0.5
2590 Usage with `compile()` API:
2592 ```python
2593 model.compile(
2594 optimizer='sgd',
2595 loss='mse',
2596 metrics=[tf.keras.metrics.MeanSquaredError()])
2597 ```
2598 """
2600 def __init__(self, name='mean_squared_error', dtype=None):
2601 super(MeanSquaredError, self).__init__(
2602 mean_squared_error, name, dtype=dtype)
2605@keras_export('keras.metrics.MeanSquaredLogarithmicError')
2606class MeanSquaredLogarithmicError(MeanMetricWrapper):
2607 """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
2609 Args:
2610 name: (Optional) string name of the metric instance.
2611 dtype: (Optional) data type of the metric result.
2613 Standalone usage:
2615 >>> m = tf.keras.metrics.MeanSquaredLogarithmicError()
2616 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2617 >>> m.result().numpy()
2618 0.12011322
2620 >>> m.reset_state()
2621 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2622 ... sample_weight=[1, 0])
2623 >>> m.result().numpy()
2624 0.24022643
2626 Usage with `compile()` API:
2628 ```python
2629 model.compile(
2630 optimizer='sgd',
2631 loss='mse',
2632 metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()])
2633 ```
2634 """
2636 def __init__(self, name='mean_squared_logarithmic_error', dtype=None):
2637 super(MeanSquaredLogarithmicError, self).__init__(
2638 mean_squared_logarithmic_error, name, dtype=dtype)
2641@keras_export('keras.metrics.Hinge')
2642class Hinge(MeanMetricWrapper):
2643 """Computes the hinge metric between `y_true` and `y_pred`.
2645 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2646 provided, we will convert them to -1 or 1.
2648 Args:
2649 name: (Optional) string name of the metric instance.
2650 dtype: (Optional) data type of the metric result.
2652 Standalone usage:
2654 >>> m = tf.keras.metrics.Hinge()
2655 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2656 >>> m.result().numpy()
2657 1.3
2659 >>> m.reset_state()
2660 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2661 ... sample_weight=[1, 0])
2662 >>> m.result().numpy()
2663 1.1
2665 Usage with `compile()` API:
2667 ```python
2668 model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()])
2669 ```
2670 """
2672 def __init__(self, name='hinge', dtype=None):
2673 super(Hinge, self).__init__(hinge, name, dtype=dtype)
2676@keras_export('keras.metrics.SquaredHinge')
2677class SquaredHinge(MeanMetricWrapper):
2678 """Computes the squared hinge metric between `y_true` and `y_pred`.
2680 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2681 provided, we will convert them to -1 or 1.
2683 Args:
2684 name: (Optional) string name of the metric instance.
2685 dtype: (Optional) data type of the metric result.
2687 Standalone usage:
2689 >>> m = tf.keras.metrics.SquaredHinge()
2690 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2691 >>> m.result().numpy()
2692 1.86
2694 >>> m.reset_state()
2695 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2696 ... sample_weight=[1, 0])
2697 >>> m.result().numpy()
2698 1.46
2700 Usage with `compile()` API:
2702 ```python
2703 model.compile(
2704 optimizer='sgd',
2705 loss='mse',
2706 metrics=[tf.keras.metrics.SquaredHinge()])
2707 ```
2708 """
2710 def __init__(self, name='squared_hinge', dtype=None):
2711 super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype)
2714@keras_export('keras.metrics.CategoricalHinge')
2715class CategoricalHinge(MeanMetricWrapper):
2716 """Computes the categorical hinge metric between `y_true` and `y_pred`.
2718 Args:
2719 name: (Optional) string name of the metric instance.
2720 dtype: (Optional) data type of the metric result.
2722 Standalone usage:
2724 >>> m = tf.keras.metrics.CategoricalHinge()
2725 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2726 >>> m.result().numpy()
2727 1.4000001
2729 >>> m.reset_state()
2730 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2731 ... sample_weight=[1, 0])
2732 >>> m.result().numpy()
2733 1.2
2735 Usage with `compile()` API:
2737 ```python
2738 model.compile(
2739 optimizer='sgd',
2740 loss='mse',
2741 metrics=[tf.keras.metrics.CategoricalHinge()])
2742 ```
2743 """
2745 def __init__(self, name='categorical_hinge', dtype=None):
2746 super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype)
2749@keras_export('keras.metrics.RootMeanSquaredError')
2750class RootMeanSquaredError(Mean):
2751 """Computes root mean squared error metric between `y_true` and `y_pred`.
2753 Standalone usage:
2755 >>> m = tf.keras.metrics.RootMeanSquaredError()
2756 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2757 >>> m.result().numpy()
2758 0.5
2760 >>> m.reset_state()
2761 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2762 ... sample_weight=[1, 0])
2763 >>> m.result().numpy()
2764 0.70710677
2766 Usage with `compile()` API:
2768 ```python
2769 model.compile(
2770 optimizer='sgd',
2771 loss='mse',
2772 metrics=[tf.keras.metrics.RootMeanSquaredError()])
2773 ```
2774 """
2776 def __init__(self, name='root_mean_squared_error', dtype=None):
2777 super(RootMeanSquaredError, self).__init__(name, dtype=dtype)
2779 def update_state(self, y_true, y_pred, sample_weight=None):
2780 """Accumulates root mean squared error statistics.
2782 Args:
2783 y_true: The ground truth values.
2784 y_pred: The predicted values.
2785 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2786 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2787 be broadcastable to `y_true`.
2789 Returns:
2790 Update op.
2791 """
2792 y_true = math_ops.cast(y_true, self._dtype)
2793 y_pred = math_ops.cast(y_pred, self._dtype)
2794 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
2795 y_pred, y_true)
2796 error_sq = math_ops.squared_difference(y_pred, y_true)
2797 return super(RootMeanSquaredError, self).update_state(
2798 error_sq, sample_weight=sample_weight)
2800 def result(self):
2801 return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count))
2804@keras_export('keras.metrics.LogCoshError')
2805class LogCoshError(MeanMetricWrapper):
2806 """Computes the logarithm of the hyperbolic cosine of the prediction error.
2808 `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true).
2810 Args:
2811 name: (Optional) string name of the metric instance.
2812 dtype: (Optional) data type of the metric result.
2814 Standalone usage:
2816 >>> m = tf.keras.metrics.LogCoshError()
2817 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2818 >>> m.result().numpy()
2819 0.10844523
2821 >>> m.reset_state()
2822 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2823 ... sample_weight=[1, 0])
2824 >>> m.result().numpy()
2825 0.21689045
2827 Usage with `compile()` API:
2829 ```python
2830 model.compile(optimizer='sgd',
2831 loss='mse',
2832 metrics=[tf.keras.metrics.LogCoshError()])
2833 ```
2834 """
2836 def __init__(self, name='logcosh', dtype=None):
2837 super(LogCoshError, self).__init__(logcosh, name, dtype=dtype)
2840@keras_export('keras.metrics.Poisson')
2841class Poisson(MeanMetricWrapper):
2842 """Computes the Poisson metric between `y_true` and `y_pred`.
2844 `metric = y_pred - y_true * log(y_pred)`
2846 Args:
2847 name: (Optional) string name of the metric instance.
2848 dtype: (Optional) data type of the metric result.
2850 Standalone usage:
2852 >>> m = tf.keras.metrics.Poisson()
2853 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2854 >>> m.result().numpy()
2855 0.49999997
2857 >>> m.reset_state()
2858 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2859 ... sample_weight=[1, 0])
2860 >>> m.result().numpy()
2861 0.99999994
2863 Usage with `compile()` API:
2865 ```python
2866 model.compile(optimizer='sgd',
2867 loss='mse',
2868 metrics=[tf.keras.metrics.Poisson()])
2869 ```
2870 """
2872 def __init__(self, name='poisson', dtype=None):
2873 super(Poisson, self).__init__(poisson, name, dtype=dtype)
2876@keras_export('keras.metrics.KLDivergence')
2877class KLDivergence(MeanMetricWrapper):
2878 """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`.
2880 `metric = y_true * log(y_true / y_pred)`
2882 Args:
2883 name: (Optional) string name of the metric instance.
2884 dtype: (Optional) data type of the metric result.
2886 Standalone usage:
2888 >>> m = tf.keras.metrics.KLDivergence()
2889 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2890 >>> m.result().numpy()
2891 0.45814306
2893 >>> m.reset_state()
2894 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2895 ... sample_weight=[1, 0])
2896 >>> m.result().numpy()
2897 0.9162892
2899 Usage with `compile()` API:
2901 ```python
2902 model.compile(optimizer='sgd',
2903 loss='mse',
2904 metrics=[tf.keras.metrics.KLDivergence()])
2905 ```
2906 """
2908 def __init__(self, name='kullback_leibler_divergence', dtype=None):
2909 super(KLDivergence, self).__init__(
2910 kullback_leibler_divergence, name, dtype=dtype)
2913@keras_export('keras.metrics.MeanIoU')
2914class MeanIoU(Metric):
2915 """Computes the mean Intersection-Over-Union metric.
2917 Mean Intersection-Over-Union is a common evaluation metric for semantic image
2918 segmentation, which first computes the IOU for each semantic class and then
2919 computes the average over classes. IOU is defined as follows:
2920 IOU = true_positive / (true_positive + false_positive + false_negative).
2921 The predictions are accumulated in a confusion matrix, weighted by
2922 `sample_weight` and the metric is then calculated from it.
2924 If `sample_weight` is `None`, weights default to 1.
2925 Use `sample_weight` of 0 to mask values.
2927 Args:
2928 num_classes: The possible number of labels the prediction task can have.
2929 This value must be provided, since a confusion matrix of dimension =
2930 [num_classes, num_classes] will be allocated.
2931 name: (Optional) string name of the metric instance.
2932 dtype: (Optional) data type of the metric result.
2934 Standalone usage:
2936 >>> # cm = [[1, 1],
2937 >>> # [1, 1]]
2938 >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
2939 >>> # iou = true_positives / (sum_row + sum_col - true_positives)
2940 >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
2941 >>> m = tf.keras.metrics.MeanIoU(num_classes=2)
2942 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
2943 >>> m.result().numpy()
2944 0.33333334
2946 >>> m.reset_state()
2947 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
2948 ... sample_weight=[0.3, 0.3, 0.3, 0.1])
2949 >>> m.result().numpy()
2950 0.23809525
2952 Usage with `compile()` API:
2954 ```python
2955 model.compile(
2956 optimizer='sgd',
2957 loss='mse',
2958 metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
2959 ```
2960 """
2962 def __init__(self, num_classes, name=None, dtype=None):
2963 super(MeanIoU, self).__init__(name=name, dtype=dtype)
2964 self.num_classes = num_classes
2966 # Variable to accumulate the predictions in the confusion matrix.
2967 self.total_cm = self.add_weight(
2968 'total_confusion_matrix',
2969 shape=(num_classes, num_classes),
2970 initializer=init_ops.zeros_initializer)
2972 def update_state(self, y_true, y_pred, sample_weight=None):
2973 """Accumulates the confusion matrix statistics.
2975 Args:
2976 y_true: The ground truth values.
2977 y_pred: The predicted values.
2978 sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2979 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2980 be broadcastable to `y_true`.
2982 Returns:
2983 Update op.
2984 """
2986 y_true = math_ops.cast(y_true, self._dtype)
2987 y_pred = math_ops.cast(y_pred, self._dtype)
2989 # Flatten the input if its rank > 1.
2990 if y_pred.shape.ndims > 1:
2991 y_pred = array_ops.reshape(y_pred, [-1])
2993 if y_true.shape.ndims > 1:
2994 y_true = array_ops.reshape(y_true, [-1])
2996 if sample_weight is not None:
2997 sample_weight = math_ops.cast(sample_weight, self._dtype)
2998 if sample_weight.shape.ndims > 1:
2999 sample_weight = array_ops.reshape(sample_weight, [-1])
3001 # Accumulate the prediction to current confusion matrix.
3002 current_cm = confusion_matrix.confusion_matrix(
3003 y_true,
3004 y_pred,
3005 self.num_classes,
3006 weights=sample_weight,
3007 dtype=self._dtype)
3008 return self.total_cm.assign_add(current_cm)
3010 def result(self):
3011 """Compute the mean intersection-over-union via the confusion matrix."""
3012 sum_over_row = math_ops.cast(
3013 math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
3014 sum_over_col = math_ops.cast(
3015 math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
3016 true_positives = math_ops.cast(
3017 array_ops.tensor_diag_part(self.total_cm), dtype=self._dtype)
3019 # sum_over_row + sum_over_col =
3020 # 2 * true_positives + false_positives + false_negatives.
3021 denominator = sum_over_row + sum_over_col - true_positives
3023 # The mean is only computed over classes that appear in the
3024 # label or prediction tensor. If the denominator is 0, we need to
3025 # ignore the class.
3026 num_valid_entries = math_ops.reduce_sum(
3027 math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype))
3029 iou = math_ops.div_no_nan(true_positives, denominator)
3031 return math_ops.div_no_nan(
3032 math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries)
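# Editor's trace of the class docstring example (illustrative, not part of
# the original source): for total_cm = [[1, 1], [1, 1]], sum_over_row and
# sum_over_col are both [2, 2], true_positives is [1, 1], the denominator is
# [3, 3], the per-class iou is [1/3, 1/3], and the mean over the two valid
# classes is ~0.333, matching the docstring result.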
3034 def reset_state(self):
3035 backend.set_value(
3036 self.total_cm, np.zeros((self.num_classes, self.num_classes)))
3038 def get_config(self):
3039 config = {'num_classes': self.num_classes}
3040 base_config = super(MeanIoU, self).get_config()
3041 return dict(list(base_config.items()) + list(config.items()))
3044@keras_export('keras.metrics.MeanTensor')
3045class MeanTensor(Metric):
3046 """Computes the element-wise (weighted) mean of the given tensors.
3048 `MeanTensor` returns a tensor with the same shape as the input tensors. The
3049 mean value is updated by keeping local variables `total` and `count`. The
3050 `total` tracks the sum of the weighted values, and `count` stores the sum of
3051 the weighted counts.
3053 Args:
3054 name: (Optional) string name of the metric instance.
3055 dtype: (Optional) data type of the metric result.
3056 shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor
3057 of type int32. If not specified, the shape is inferred from the values at
3058 the first call of update_state.
3060 Standalone usage:
3062 >>> m = tf.keras.metrics.MeanTensor()
3063 >>> m.update_state([0, 1, 2, 3])
3064 >>> m.update_state([4, 5, 6, 7])
3065 >>> m.result().numpy()
3066 array([2., 3., 4., 5.], dtype=float32)
3068 >>> m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1])
3069 >>> m.result().numpy()
3070 array([2. , 3.6363635, 4.8 , 5.3333335], dtype=float32)
3072 >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4))
3073 >>> m.result().numpy()
3074 array([[0., 0., 0., 0.]])
3075 >>> m.update_state([[0, 1, 2, 3]])
3076 >>> m.update_state([[4, 5, 6, 7]])
3077 >>> m.result().numpy()
3078 array([[2., 3., 4., 5.]])
3079 """
3081 def __init__(self, name='mean_tensor', dtype=None, shape=None):
3082 super(MeanTensor, self).__init__(name=name, dtype=dtype)
3083 self._shape = None
3084 self._total = None
3085 self._count = None
3086 self._built = False
3087 if shape is not None:
3088 self._build(shape)
3090 def _build(self, shape):
3091 self._shape = tensor_shape.TensorShape(shape)
3092 self._build_input_shape = self._shape
3093 # Create new state variables
3094 self._total = self.add_weight(
3095 'total', shape=shape, initializer=init_ops.zeros_initializer)
3096 self._count = self.add_weight(
3097 'count', shape=shape, initializer=init_ops.zeros_initializer)
3098 with ops.init_scope():
3099 if not context.executing_eagerly():
3100 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access
3101 self._built = True
3103 @property
3104 def total(self):
3105 return self._total if self._built else None
3107 @property
3108 def count(self):
3109 return self._count if self._built else None
3111 def update_state(self, values, sample_weight=None):
3112 """Accumulates statistics for computing the element-wise mean.
3114 Args:
3115 values: Per-example value.
3116 sample_weight: Optional weighting of each example. Defaults to 1.
3118 Returns:
3119 Update op.
3120 """
3121 values = math_ops.cast(values, self._dtype)
3122 if not self._built:
3123 self._build(values.shape)
3124 elif values.shape != self._shape:
3125 raise ValueError('MeanTensor input values must always have the same '
3126 'shape. Expected shape (set during the first call): {}. '
3127 'Got: {}'.format(self._shape, values.shape))
3129 num_values = array_ops.ones_like(values)
3130 if sample_weight is not None:
3131 sample_weight = math_ops.cast(sample_weight, self._dtype)
3133 # Update dimensions of weights to match with values if possible.
3134 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
3135 values, sample_weight=sample_weight)
3136 try:
3137 # Broadcast weights if possible.
3138 sample_weight = weights_broadcast_ops.broadcast_weights(
3139 sample_weight, values)
3140 except ValueError:
3141 # Reduce values to same ndim as weight array
3142 ndim = backend.ndim(values)
3143 weight_ndim = backend.ndim(sample_weight)
3144 values = math_ops.reduce_mean(
3145 values, axis=list(range(weight_ndim, ndim)))
3147 num_values = math_ops.multiply(num_values, sample_weight)
3148 values = math_ops.multiply(values, sample_weight)
3150 update_total_op = self._total.assign_add(values)
3151 with ops.control_dependencies([update_total_op]):
3152 return self._count.assign_add(num_values)
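# Editor's trace of the class docstring example (illustrative, not part of
# the original source): after update_state([0, 1, 2, 3]) and
# update_state([4, 5, 6, 7]), total = [4, 6, 8, 10] and count = [2, 2, 2, 2];
# the weighted call update_state([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1])
# adds values * weights = [0, 2, 4, 6] to total and the weights to count,
# giving total / count = [2.0, 3.636..., 4.8, 5.333...] as in the docstring.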
3154 def result(self):
3155 if not self._built:
3156 raise ValueError(
3157 'MeanTensor does not have any result yet. Please call the MeanTensor '
3158 'instance or use `.update_state(value)` before retrieving the result.'
3159 )
3160 return math_ops.div_no_nan(self.total, self.count)
3162 def reset_state(self):
3163 if self._built:
3164 backend.batch_set_value(
3165 [(v, np.zeros(self._shape.as_list())) for v in self.variables])
3168@keras_export('keras.metrics.BinaryCrossentropy')
3169class BinaryCrossentropy(MeanMetricWrapper):
3170 """Computes the crossentropy metric between the labels and predictions.
3172 This is the crossentropy metric class to be used when there are only two
3173 label classes (0 and 1).
3175 Args:
3176 name: (Optional) string name of the metric instance.
3177 dtype: (Optional) data type of the metric result.
3178 from_logits: (Optional) Whether output is expected to be a logits tensor.
3179 By default, we consider that output encodes a probability distribution.
3180 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
3181 smoothed, meaning the confidence on label values is relaxed.
3182 e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for
3183 label `0` and `0.9` for label `1`.
3185 Standalone usage:
3187 >>> m = tf.keras.metrics.BinaryCrossentropy()
3188 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
3189 >>> m.result().numpy()
3190 0.81492424
3192 >>> m.reset_state()
3193 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
3194 ... sample_weight=[1, 0])
3195 >>> m.result().numpy()
3196 0.9162905
3198 Usage with `compile()` API:
3200 ```python
3201 model.compile(
3202 optimizer='sgd',
3203 loss='mse',
3204 metrics=[tf.keras.metrics.BinaryCrossentropy()])
3205 ```
3206 """
3208 def __init__(self,
3209 name='binary_crossentropy',
3210 dtype=None,
3211 from_logits=False,
3212 label_smoothing=0):
3213 super(BinaryCrossentropy, self).__init__(
3214 binary_crossentropy,
3215 name,
3216 dtype=dtype,
3217 from_logits=from_logits,
3218 label_smoothing=label_smoothing)
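# Illustrative sketch (not part of the original source): how `label_smoothing`
# relaxes binary targets before the crossentropy is computed, assuming the
# usual binary smoothing rule
# `y_smooth = y_true * (1 - label_smoothing) + 0.5 * label_smoothing`.
# The helper name is hypothetical.
def _demo_binary_label_smoothing(label_smoothing=0.2):
  import tensorflow as tf
  y_true = tf.constant([0., 1.])
  # With label_smoothing=0.2 this maps label 0 -> 0.1 and label 1 -> 0.9.
  return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing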
3221@keras_export('keras.metrics.CategoricalCrossentropy')
3222class CategoricalCrossentropy(MeanMetricWrapper):
3223 """Computes the crossentropy metric between the labels and predictions.
3225 This is the crossentropy metric class to be used when there are multiple
3226 label classes (2 or more). Here we assume that labels are given in a `one_hot`
3227 representation, e.g., when label values are [2, 0, 1],
3228 `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]].
3230 Args:
3231 name: (Optional) string name of the metric instance.
3232 dtype: (Optional) data type of the metric result.
3233 from_logits: (Optional) Whether output is expected to be a logits tensor.
3234 By default, we consider that output encodes a probability distribution.
3235 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
3236 smoothed, meaning the confidence on label values is relaxed. e.g.
3237 `label_smoothing=0.2` means that we will use a value of `0.1` for label
3238 `0` and `0.9` for label `1`.
3240 Standalone usage:
3242 >>> # EPSILON = 1e-7, y = y_true, y' = y_pred
3243 >>> # y' = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
3244 >>> # y' = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
3245 >>> # xent = -sum(y * log(y'), axis = -1)
3246 >>> # = [-log(0.95), -log(0.1)]
3247 >>> # = [0.051, 2.302]
3248 >>> # Reduced xent = (0.051 + 2.302) / 2
3249 >>> m = tf.keras.metrics.CategoricalCrossentropy()
3250 >>> m.update_state([[0, 1, 0], [0, 0, 1]],
3251 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
3252 >>> m.result().numpy()
3253 1.1769392
3255 >>> m.reset_state()
3256 >>> m.update_state([[0, 1, 0], [0, 0, 1]],
3257 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
3258 ... sample_weight=tf.constant([0.3, 0.7]))
3259 >>> m.result().numpy()
3260 1.6271976
3262 Usage with `compile()` API:
3264 ```python
3265 model.compile(
3266 optimizer='sgd',
3267 loss='mse',
3268 metrics=[tf.keras.metrics.CategoricalCrossentropy()])
3269 ```
3270 """
3272 def __init__(self,
3273 name='categorical_crossentropy',
3274 dtype=None,
3275 from_logits=False,
3276 label_smoothing=0):
3277 super(CategoricalCrossentropy, self).__init__(
3278 categorical_crossentropy,
3279 name,
3280 dtype=dtype,
3281 from_logits=from_logits,
3282 label_smoothing=label_smoothing)
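# Illustrative sketch (not part of the original source): for one-hot targets,
# label smoothing is typically spread across all classes, assuming the rule
# `y_smooth = y_true * (1 - label_smoothing) + label_smoothing / num_classes`.
# The helper name is hypothetical.
def _demo_categorical_label_smoothing(label_smoothing=0.2):
  import tensorflow as tf
  y_true = tf.constant([[0., 1., 0.]])
  num_classes = 3
  # With label_smoothing=0.2: approximately [0.0667, 0.8667, 0.0667].
  return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes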
3285@keras_export('keras.metrics.SparseCategoricalCrossentropy')
3286class SparseCategoricalCrossentropy(MeanMetricWrapper):
3287 """Computes the crossentropy metric between the labels and predictions.
3289 Use this crossentropy metric when there are two or more label classes.
3290 We expect labels to be provided as integers. If you want to provide labels
3291 using a `one-hot` representation, please use the `CategoricalCrossentropy` metric.
3292 There should be `# classes` floating point values per feature for `y_pred`
3293 and a single floating point value per feature for `y_true`.
3295 In the snippet below, there is a single floating point value per example for
3296 `y_true` and `# classes` floating point values per example for `y_pred`.
3297 The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
3298 `[batch_size, num_classes]`.
3300 Args:
3301 name: (Optional) string name of the metric instance.
3302 dtype: (Optional) data type of the metric result.
3303 from_logits: (Optional) Whether output is expected to be a logits tensor.
3304 By default, we consider that output encodes a probability distribution.
3305 axis: (Optional) Defaults to -1. The dimension along which the metric is
3306 computed.
3308 Standalone usage:
3310 >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
3311 >>> # logits = log(y_pred)
3312 >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)
3313 >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
3314 >>> # xent = -sum(y * log(softmax), 1)
3315 >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],
3316 >>> # [-2.3026, -0.2231, -2.3026]]
3317 >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
3318 >>> # xent = [0.0513, 2.3026]
3319 >>> # Reduced xent = (0.0513 + 2.3026) / 2
3320 >>> m = tf.keras.metrics.SparseCategoricalCrossentropy()
3321 >>> m.update_state([1, 2],
3322 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
3323 >>> m.result().numpy()
3324 1.1769392
3326 >>> m.reset_state()
3327 >>> m.update_state([1, 2],
3328 ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
3329 ... sample_weight=tf.constant([0.3, 0.7]))
3330 >>> m.result().numpy()
3331 1.6271976
3333 Usage with `compile()` API:
3335 ```python
3336 model.compile(
3337 optimizer='sgd',
3338 loss='mse',
3339 metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()])
3340 ```
3341 """
3343 def __init__(self,
3344 name='sparse_categorical_crossentropy',
3345 dtype=None,
3346 from_logits=False,
3347 axis=-1):
3348 super(SparseCategoricalCrossentropy, self).__init__(
3349 sparse_categorical_crossentropy,
3350 name,
3351 dtype=dtype,
3352 from_logits=from_logits,
3353 axis=axis)
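# Illustrative sketch (not part of the original source): sparse integer labels
# are just a compact encoding of the one-hot labels used by
# `CategoricalCrossentropy`; both metrics report the same value here
# (1.1769392 in the docstring examples above). The helper name is hypothetical.
def _demo_sparse_vs_one_hot():
  import tensorflow as tf
  y_pred = [[0.05, 0.95, 0.], [0.1, 0.8, 0.1]]
  sparse = tf.keras.metrics.SparseCategoricalCrossentropy()
  sparse.update_state([1, 2], y_pred)
  dense = tf.keras.metrics.CategoricalCrossentropy()
  dense.update_state(tf.one_hot([1, 2], depth=3), y_pred)
  return sparse.result(), dense.result()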
3356class SumOverBatchSize(Reduce):
3357 """Computes the weighted sum over batch size of the given values.
3359 For example, if values is [1, 3, 5, 7] then the metric value is 4.
3360 If the weights were specified as [1, 1, 0, 0] then the value would be 1.
3362 This metric creates two variables, `total` and `count`, that are used to
3363 compute the average of `values`. This average is ultimately returned as the
3364 sum over batch size, which is an idempotent operation that simply divides
3365 `total` by `count`.
3367 If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0
3368 to mask values.
3369 """
3371 def __init__(self, name='sum_over_batch_size', dtype=None):
3372 super(SumOverBatchSize, self).__init__(
3373 reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
3374 name=name,
3375 dtype=dtype)
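# Illustrative sketch (not part of the original source): the arithmetic from
# the docstring above. The weighted sum of the values is divided by the number
# of values (the batch size), not by the sum of the weights. The helper name is
# hypothetical.
def _demo_sum_over_batch_size():
  m = SumOverBatchSize()
  m.update_state([1., 3., 5., 7.], sample_weight=[1., 1., 0., 0.])
  # (1 + 3 + 0 + 0) / 4 = 1.0
  return m.result()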
3378class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
3379 """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""
3381 def __init__(self, fn, name=None, dtype=None, **kwargs):
3382 """Creates a `SumOverBatchSizeMetricWrapper` instance.
3384 Args:
3385 fn: The metric function to wrap, with signature `fn(y_true, y_pred,
3386 **kwargs)`.
3387 name: (Optional) string name of the metric instance.
3388 dtype: (Optional) data type of the metric result.
3389 **kwargs: The keyword arguments that are passed on to `fn`.
3390 """
3391 super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype)
3392 self._fn = fn
3393 self._fn_kwargs = kwargs
3395 def update_state(self, y_true, y_pred, sample_weight=None):
3396 y_true = math_ops.cast(y_true, self._dtype)
3397 y_pred = math_ops.cast(y_pred, self._dtype)
3398 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
3399 y_pred, y_true)
3401 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
3402 matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
3403 return super(SumOverBatchSizeMetricWrapper, self).update_state(
3404 matches, sample_weight=sample_weight)
3406 def get_config(self):
3407 config = {}
3408 for k, v in self._fn_kwargs.items():
3409 config[k] = backend.eval(v) if is_tensor_or_variable(v) else v
3410 base_config = super(SumOverBatchSizeMetricWrapper, self).get_config()
3411 return dict(list(base_config.items()) + list(config.items()))
3414def accuracy(y_true, y_pred):
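# (Descriptive note, not in the original source.) Element-wise exact-match
# accuracy: casts `y_pred` to `y_true`'s dtype and returns 1.0 where the two
# are equal and 0.0 elsewhere, in the backend float type.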
3415 [y_pred, y_true], _ = \
3416 metrics_utils.ragged_assert_compatible_and_get_flat_values(
3417 [y_pred, y_true])
3418 y_true.shape.assert_is_compatible_with(y_pred.shape)
3419 if y_true.dtype != y_pred.dtype:
3420 y_pred = math_ops.cast(y_pred, y_true.dtype)
3421 return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx())
3424@keras_export('keras.metrics.binary_accuracy')
3425@dispatch.add_dispatch_support
3426def binary_accuracy(y_true, y_pred, threshold=0.5):
3427 """Calculates how often predictions match binary labels.
3429 Standalone usage:
3430 >>> y_true = [[1], [1], [0], [0]]
3431 >>> y_pred = [[1], [1], [0], [0]]
3432 >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred)
3433 >>> assert m.shape == (4,)
3434 >>> m.numpy()
3435 array([1., 1., 1., 1.], dtype=float32)
3437 Args:
3438 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
3439 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
3440 threshold: (Optional) Float representing the threshold for deciding whether
3441 prediction values are 1 or 0.
3443 Returns:
3444 Binary accuracy values. shape = `[batch_size, d0, .. dN-1]`
3445 """
3446 y_pred = tensor_conversion.convert_to_tensor_v2_with_dispatch(y_pred)
3447 threshold = math_ops.cast(threshold, y_pred.dtype)
3448 y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
3449 return backend.mean(math_ops.equal(y_true, y_pred), axis=-1)
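# Illustrative sketch (not part of the original source): the effect of
# `threshold`. Predictions are binarized by `y_pred > threshold` before being
# compared with the labels. The helper name is hypothetical.
def _demo_binary_accuracy_threshold():
  import tensorflow as tf
  y_true = [[1.], [0.]]
  y_pred = [[0.6], [0.3]]
  default = tf.keras.metrics.binary_accuracy(y_true, y_pred)  # [1., 1.]
  stricter = tf.keras.metrics.binary_accuracy(y_true, y_pred, threshold=0.7)  # [0., 1.]
  return default, stricter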
3452@keras_export('keras.metrics.categorical_accuracy')
3453@dispatch.add_dispatch_support
3454def categorical_accuracy(y_true, y_pred):
3455 """Calculates how often predictions match one-hot labels.
3457 Standalone usage:
3458 >>> y_true = [[0, 0, 1], [0, 1, 0]]
3459 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3460 >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
3461 >>> assert m.shape == (2,)
3462 >>> m.numpy()
3463 array([0., 1.], dtype=float32)
3465 You can provide logits of classes as `y_pred`, since argmax of
3466 logits and probabilities are the same.
3468 Args:
3469 y_true: One-hot ground truth values.
3470 y_pred: The prediction values.
3472 Returns:
3473 Categorical accuracy values.
3474 """
3475 return math_ops.cast(
3476 math_ops.equal(
3477 math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
3478 backend.floatx())
3481@keras_export('keras.metrics.sparse_categorical_accuracy')
3482@dispatch.add_dispatch_support
3483def sparse_categorical_accuracy(y_true, y_pred):
3484 """Calculates how often predictions match integer labels.
3486 Standalone usage:
3487 >>> y_true = [2, 1]
3488 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3489 >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
3490 >>> assert m.shape == (2,)
3491 >>> m.numpy()
3492 array([0., 1.], dtype=float32)
3494 You can provide logits of classes as `y_pred`, since argmax of
3495 logits and probabilities are the same.
3497 Args:
3498 y_true: Integer ground truth values.
3499 y_pred: The prediction values.
3501 Returns:
3502 Sparse categorical accuracy values.
3503 """
3504 y_pred = tensor_conversion.convert_to_tensor_v2_with_dispatch(y_pred)
3505 y_true = tensor_conversion.convert_to_tensor_v2_with_dispatch(y_true)
3506 y_pred_rank = y_pred.shape.ndims
3507 y_true_rank = y_true.shape.ndims
3508 # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
3509 if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
3510 backend.int_shape(y_true)) == len(backend.int_shape(y_pred))):
3511 y_true = array_ops.squeeze(y_true, [-1])
3512 y_pred = math_ops.argmax(y_pred, axis=-1)
3514 # If the predicted output and actual output types don't match, force cast them
3515 # to match.
3516 if backend.dtype(y_pred) != backend.dtype(y_true):
3517 y_pred = math_ops.cast(y_pred, backend.dtype(y_true))
3519 return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx())
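# Illustrative sketch (not part of the original source): the squeeze branch
# above means labels of shape (num_samples, 1) behave the same as labels of
# shape (num_samples,). The helper name is hypothetical.
def _demo_sparse_accuracy_label_shapes():
  import tensorflow as tf
  y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0.]]
  flat = tf.keras.metrics.sparse_categorical_accuracy([2, 1], y_pred)
  column = tf.keras.metrics.sparse_categorical_accuracy([[2], [1]], y_pred)
  # Both evaluate to [0., 1.].
  return flat, column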
3522@keras_export('keras.metrics.top_k_categorical_accuracy')
3523@dispatch.add_dispatch_support
3524def top_k_categorical_accuracy(y_true, y_pred, k=5):
3525 """Computes how often targets are in the top `K` predictions.
3527 Standalone usage:
3528 >>> y_true = [[0, 0, 1], [0, 1, 0]]
3529 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3530 >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)
3531 >>> assert m.shape == (2,)
3532 >>> m.numpy()
3533 array([1., 1.], dtype=float32)
3535 Args:
3536 y_true: The ground truth values.
3537 y_pred: The prediction values.
3538 k: (Optional) Number of top elements to look at for computing accuracy.
3539 Defaults to 5.
3541 Returns:
3542 Top K categorical accuracy value.
3543 """
3544 return math_ops.cast(
3545 nn.in_top_k(
3546 y_pred, math_ops.argmax(y_true, axis=-1), k), backend.floatx())
3549@keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
3550@dispatch.add_dispatch_support
3551def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
3552 """Computes how often integer targets are in the top `K` predictions.
3554 Standalone usage:
3555 >>> y_true = [2, 1]
3556 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3557 >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy(
3558 ... y_true, y_pred, k=3)
3559 >>> assert m.shape == (2,)
3560 >>> m.numpy()
3561 array([1., 1.], dtype=float32)
3563 Args:
3564 y_true: tensor of true targets.
3565 y_pred: tensor of predicted targets.
3566 k: (Optional) Number of top elements to look at for computing accuracy.
3567 Defaults to 5.
3569 Returns:
3570 Sparse top K categorical accuracy value.
3571 """
3572 y_pred_rank = tensor_conversion.convert_to_tensor_v2_with_dispatch(
3573 y_pred
3574 ).shape.ndims
3575 y_true_rank = tensor_conversion.convert_to_tensor_v2_with_dispatch(
3576 y_true
3577 ).shape.ndims
3578 # Flatten y_pred to (num_samples, num_classes) and y_true to (num_samples,)
3579 if (y_true_rank is not None) and (y_pred_rank is not None):
3580 if y_pred_rank > 2:
3581 y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]])
3582 if y_true_rank > 1:
3583 y_true = array_ops.reshape(y_true, [-1])
3585 return math_ops.cast(
3586 nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), backend.floatx())
3589def cosine_proximity(y_true, y_pred, axis=-1):
3590 """Computes the cosine similarity between labels and predictions.
3592 Args:
3593 y_true: The ground truth values.
3594 y_pred: The prediction values.
3595 axis: (Optional) Defaults to -1. The dimension along which the cosine
3596 similarity is computed.
3598 Returns:
3599 Cosine similarity value.
3600 """
3601 y_true = nn.l2_normalize(y_true, axis=axis)
3602 y_pred = nn.l2_normalize(y_pred, axis=axis)
3603 return math_ops.reduce_sum(y_true * y_pred, axis=axis)
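# Illustrative sketch (not part of the original source): cosine similarity is
# the dot product of the L2-normalized vectors, so it lies in [-1, 1];
# orthogonal vectors give 0 and identical directions give 1. The helper name is
# hypothetical.
def _demo_cosine_proximity():
  import tensorflow as tf
  y_true = tf.constant([[0., 1.], [1., 1.]])
  y_pred = tf.constant([[1., 1.], [1., 1.]])
  # Expected: approximately [0.70710677, 1.].
  return cosine_proximity(y_true, y_pred)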
3605# Aliases
3607acc = ACC = accuracy
3608bce = BCE = binary_crossentropy
3609mse = MSE = mean_squared_error
3610mae = MAE = mean_absolute_error
3611mape = MAPE = mean_absolute_percentage_error
3612msle = MSLE = mean_squared_logarithmic_error
3613cosine_similarity = cosine_proximity
3614log_cosh = logcosh
3617def clone_metric(metric):
3618 """Returns a clone of the metric if stateful, otherwise returns it as is."""
3619 if isinstance(metric, Metric):
3620 with ops.init_scope():
3621 return metric.__class__.from_config(metric.get_config())
3622 return metric
3625def clone_metrics(metrics):
3626 """Clones the given metric list/dict."""
3627 return nest.map_structure(clone_metric, metrics)
3630@keras_export('keras.metrics.serialize')
3631def serialize(metric):
3632 """Serializes metric function or `Metric` instance.
3634 Args:
3635 metric: A Keras `Metric` instance or a metric function.
3637 Returns:
3638 Metric configuration dictionary.
3639 """
3640 return serialize_keras_object(metric)
3643@keras_export('keras.metrics.deserialize')
3644def deserialize(config, custom_objects=None):
3645 """Deserializes a serialized metric class/function instance.
3647 Args:
3648 config: Metric configuration.
3649 custom_objects: Optional dictionary mapping names (strings) to custom
3650 objects (classes and functions) to be considered during deserialization.
3652 Returns:
3653 A Keras `Metric` instance or a metric function.
3654 """
3655 return deserialize_keras_object(
3656 config,
3657 module_objects=globals(),
3658 custom_objects=custom_objects,
3659 printable_module_name='metric function')
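# Illustrative sketch (not part of the original source): `serialize` and
# `deserialize` round-trip a metric through its configuration dictionary,
# e.g. when saving and restoring compiled models. The helper name is
# hypothetical.
def _demo_metric_serialization_round_trip():
  import tensorflow as tf
  metric = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)
  config = tf.keras.metrics.serialize(metric)  # {'class_name': ..., 'config': ...}
  restored = tf.keras.metrics.deserialize(config)
  return restored.get_config()['from_logits']  # True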
3662@keras_export('keras.metrics.get')
3663def get(identifier):
3664 """Retrieves a Keras metric as a `function`/`Metric` class instance.
3666 The `identifier` may be the string name of a metric function or class.
3668 >>> metric = tf.keras.metrics.get("categorical_crossentropy")
3669 >>> type(metric)
3670 <class 'function'>
3671 >>> metric = tf.keras.metrics.get("CategoricalCrossentropy")
3672 >>> type(metric)
3673 <class '...keras.metrics.CategoricalCrossentropy'>
3675 You can also specify the `config` of the metric to this function by passing a
3676 dict containing `class_name` and `config` as an identifier. Also note that the
3677 `class_name` must map to a `Metric` class.
3679 >>> identifier = {"class_name": "CategoricalCrossentropy",
3680 ... "config": {"from_logits": True}}
3681 >>> metric = tf.keras.metrics.get(identifier)
3682 >>> type(metric)
3683 <class '...keras.metrics.CategoricalCrossentropy'>
3685 Args:
3686 identifier: A metric identifier. One of: the string name of a metric
3687 function or class, a metric configuration dictionary, a metric function,
3688 or a metric class instance.
3690 Returns:
3691 A Keras metric as a `function`/`Metric` class instance.
3693 Raises:
3694 ValueError: If `identifier` cannot be interpreted.
3695 """
3696 if isinstance(identifier, dict):
3697 return deserialize(identifier)
3698 elif isinstance(identifier, str):
3699 return deserialize(str(identifier))
3700 elif callable(identifier):
3701 return identifier
3702 else:
3703 raise ValueError(
3704 'Could not interpret metric function identifier: {}'.format(identifier))
3707def is_built_in(cls):
3708 return cls.__module__ == Metric.__module__