# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  If running in a `DistributionStrategy` context, the variable will be
  "sync on read". This means:

  * The returned object will be a container with separate variables
    per replica of the model.

  * When writing to the variable, e.g. using `assign_add` in a metric
    update, the update will be applied to the variable local to the
    replica.

  * To get a metric's result value, we need to sum the variable values
    across the replicas before computing the final answer. Furthermore,
    the final answer should be computed once instead of in every
    replica. Both of these are accomplished by running the computation
    of the final result value inside
    `distribute_lib.get_replica_context().merge_call(fn)`.
    Inside the `merge_call()`, ops are only added to the graph once
    and access to a sync on read variable in a computation returns
    the sum across all replicas.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """
  # Note that synchronization "ON_READ" implies trainable=False.
  return variable_v1.VariableV1(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      synchronization=variables.VariableSynchronization.ON_READ,
      aggregation=variables.VariableAggregation.SUM,
      name=name)
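
# Illustrative sketch (not part of the original module): `metric_variable` is
# the building block for every streaming metric below. The snippet is
# hypothetical and assumes TF1-style graph mode; the created variable lands in
# the LOCAL_VARIABLES and METRIC_VARIABLES collections, so it is reset by
# `tf.compat.v1.local_variables_initializer()` rather than the global
# initializer.
#
#   g = ops.Graph()
#   with g.as_default():
#     total = metric_variable([], dtypes.float32, name='total')
#     assert total in g.get_collection(ops.GraphKeys.METRIC_VARIABLES)
#     assert total in g.get_collection(ops.GraphKeys.LOCAL_VARIABLES)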


def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return cond.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return cond.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = cond.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights
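
# Illustrative sketch (not part of the original module) of the rank juggling
# above, graph mode assumed. With predictions of shape [2, 1] and labels of
# shape [2], the extra size-1 dimension on predictions is squeezed away, and a
# scalar weight passes through untouched:
#
#   predictions = ops.convert_to_tensor([[0.2], [0.8]])  # shape [2, 1]
#   labels = ops.convert_to_tensor([0.0, 1.0])           # shape [2]
#   p, l, w = _remove_squeezable_dimensions(predictions, labels, weights=2.0)
#   # p and l now both have shape [2]; w stays a rank-0 tensor.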


def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return cond.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            f'Unexpected labels shape {labels.get_shape()} for predictions '
            f'shape {predictions.get_shape()}. Predictions rank should be the '
            'same rank as labels rank or labels rank plus one.')

    # Otherwise, use dynamic shape.
    return cond.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
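
# Illustrative sketch (not part of the original module), hypothetical shapes,
# graph mode assumed: for dense labels of shape [batch] and predictions of
# shape [batch, num_classes], the helper adds a trailing size-1 dimension so
# downstream ops see labels of shape [batch, num_labels].
#
#   labels = ops.convert_to_tensor([1, 0])                # shape [2]
#   predictions = ops.convert_to_tensor([[0.1, 0.9],
#                                        [0.8, 0.2]])     # shape [2, 2]
#   expanded = _maybe_expand_labels(labels, predictions)  # shape [2, 1]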


def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)
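
# Illustrative sketch (not part of the original module): `math_ops.div_no_nan`
# is what provides the 0-for-0 behavior, e.g.
#
#   num = ops.convert_to_tensor(0.0, dtype=dtypes.float64)
#   den = ops.convert_to_tensor(0.0, dtype=dtypes.float64)
#   ratio = _safe_scalar_div(num, den, name='ratio')  # evaluates to 0.0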


def _streaming_confusion_matrix(labels, predictions, num_classes,
                                weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op
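
# Illustrative sketch (not part of the original module): each run of
# `update_op` folds one batch into `total_cm`. Worked out by hand for
# num_classes=2, labels [0, 1, 1] and predictions [0, 1, 0], a single update
# (after `tf.compat.v1.local_variables_initializer()`) leaves
#
#   total_cm == [[1., 0.],
#                [1., 1.]]
#
#   total_cm, update_op = _streaming_confusion_matrix(
#       labels=[0, 1, 1], predictions=[0, 1, 0], num_classes=2)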


def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas."""
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # pylint: disable=protected-access
      if distribution.extended._outer_control_flow_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        distribution.extended._outer_control_flow_context.Enter()
        metric_value = metric_value_fn(distribution, *a)
        distribution.extended._outer_control_flow_context.Exit()
      # pylint: enable=protected-access
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribute_lib.get_replica_context().merge_call(
      fn, args=args)


@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.mean` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Mean` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Mean` object, you can first call the
  `update_state()` method to record the new values, and then call the
  `result()` method to get the mean eagerly. You can also attach it to a
  Keras model with the `add_metric` method. Please refer to the [migration
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to TF2

  Before:

  ```python
  mean, update_op = tf.compat.v1.metrics.mean(
    values=values,
    weights=weights,
    metrics_collections=metrics_collections,
    updates_collections=updates_collections,
    name=name)
  ```

  After:

  ```python
  m = tf.keras.metrics.Mean(
    name=name)

  m.update_state(
    values=values,
    sample_weight=weights)

  mean = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `values`              | `values`        | In `update_state()` method |
  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   values = [1, 2, 3]
  ...   mean, update_op = tf.compat.v1.metrics.mean(values)
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> sess.run(update_op)
  >>> sess.run(mean)
  2.0

  After:

  >>> m = tf.keras.metrics.Mean()
  >>> m.update_state([1, 2, 3])
  >>> m.result().numpy()
  2.0

  ```python
  # Used within Keras model
  model.add_metric(tf.keras.metrics.Mean()(values))
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op


@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.accuracy` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Accuracy` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Accuracy` object, you can first call the
  `update_state()` method to record the prediction/labels, and then call the
  `result()` method to get the accuracy eagerly. You can also attach it to a
  Keras model when calling the `compile` method. Please refer to [this
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  accuracy, update_op = tf.compat.v1.metrics.accuracy(
    labels=labels,
    predictions=predictions,
    weights=weights,
    metrics_collections=metrics_collections,
    updates_collections=updates_collections,
    name=name)
  ```

  After:

  ```python
  m = tf.keras.metrics.Accuracy(
    name=name,
    dtype=None)

  m.update_state(
    y_true=labels,
    y_pred=predictions,
    sample_weight=weights)

  accuracy = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `labels`              | `y_true`        | In `update_state()` method |
  | `predictions`         | `y_pred`        | In `update_state()` method |
  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   logits = [1, 2, 3]
  ...   labels = [0, 2, 3]
  ...   acc, acc_op = tf.compat.v1.metrics.accuracy(labels, logits)
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> print(sess.run([acc, acc_op]))
  [0.0, 0.66667]

  After:

  >>> m = tf.keras.metrics.Accuracy()
  >>> m.update_state([1, 2, 3], [0, 2, 3])
  >>> m.result().numpy()
  0.66667

  ```python
  # Used within Keras model
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Accuracy()])
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.cast(
      math_ops.equal(predictions, labels), dtypes.float32)
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
      `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
      `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError(f'Invalid key: {include}')

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops_stack.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops
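
# Illustrative sketch (not part of the original module), worked out by hand:
# with a single threshold of 0.5, predictions [0.2, 0.7] and labels
# [False, True], one run of the update ops yields tp=[1], fn=[0], tn=[1],
# fp=[0] (0.7 > 0.5 with a True label is the one true positive; 0.2 <= 0.5
# with a False label is the one true negative).
#
#   values, update_ops = _confusion_matrix_at_thresholds(
#       labels=[False, True], predictions=[0.2, 0.7], thresholds=[0.5])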


def _aggregate_variable(v, collections):
  f = lambda distribution, value: distribution.extended.read_var(value)
  return _aggregate_across_replicas(collections, f, v)


@tf_export(v1=['metrics.auc'])
@deprecated(None,
            'The value of AUC returned by this may race with the update so '
            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal',
        thresholds=None):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing lower or upper bound estimate of the AUC. The `thresholds`
  parameter can be used to manually specify thresholds which split the
  predictions more evenly.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used
      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
      applies the trapezoidal rule; 'careful_interpolation', a variant of it
      differing only by a more correct interpolation scheme for PR-AUC -
      interpolating (true/false) positives but not the ratio that is precision;
      'minoring' that applies left summation for increasing intervals and right
      summation for decreasing intervals; 'majoring' that does the opposite.
      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
      (to be deprecated soon) as it applies the same method for ROC, and a
      better one (see Davis & Goadrich 2006 for details) for the PR curve.
    thresholds: An optional list of floating point values to use as the
      thresholds for discretizing the curve. If set, the `num_thresholds`
      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
      automatically included with these to correctly handle predictions equal to
      exactly 0 or 1.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError(f'Curve must be either ROC or PR. Curve {curve} is '
                       'unknown.')

    kepsilon = 1e-7  # To account for floating point imprecisions.
    if thresholds is not None:
      # If specified, use the supplied thresholds.
      thresholds = sorted(thresholds)
      num_thresholds = len(thresholds) + 2
    else:
      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
      # (0, 1).
      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                    for i in range(num_thresholds - 2)]

    # Add an endpoint "threshold" below zero and above one for either threshold
    # method.
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of (Davis et al., 2006).

      Note here we derive & use a closed formula not present in the paper
      - as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
        slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts

      Returns:
        pr_auc: an approximation of the area under the P-R curve.

      References:
        The Relationship Between Precision-Recall and ROC Curves:
          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
      """
      dtp = tp[:num_thresholds - 1] - tp[1:]
      p = tp + fp
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError(f'Invalid summation_method: {summation_method} '
                         'summation_method should be \'trapezoidal\', '
                         '\'careful_interpolation\', \'minoring\', or '
                         '\'majoring\'.')

    # sum up the areas of all the trapeziums
    def compute_auc_value(_, values):
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')

    auc_value = _aggregate_across_replicas(
        metrics_collections, compute_auc_value, values)
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op
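
# Illustrative sketch (not part of the original module), assuming TF1 graph
# mode: streaming ROC-AUC over one toy batch. The exact pairwise AUC here is
# 0.75; the 200-threshold Riemann approximation comes out close to that.
#
#   g = tf.Graph()
#   with g.as_default():
#     labels = tf.constant([0, 0, 1, 1])
#     predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
#     auc_value, auc_op = tf.compat.v1.metrics.auc(labels, predictions)
#     init = tf.compat.v1.local_variables_initializer()
#   with tf.compat.v1.Session(graph=g) as sess:
#     sess.run(init)
#     sess.run(auc_op)
#     print(sess.run(auc_value))  # ~0.75 for this toy data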


@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')
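
# Illustrative sketch (not part of the original module), TF1 graph mode
# assumed; the expected value is worked out by hand.
#
#   g = tf.Graph()
#   with g.as_default():
#     labels = tf.constant([1.0, 2.0, 4.0])
#     predictions = tf.constant([1.0, 3.0, 2.0])
#     mae, mae_op = tf.compat.v1.metrics.mean_absolute_error(
#         labels, predictions)
#   with tf.compat.v1.Session(graph=g) as sess:
#     sess.run(tf.compat.v1.local_variables_initializer())
#     sess.run(mae_op)
#     print(sess.run(mae))  # (0 + 1 + 2) / 3 = 1.0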


@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, axis=[
          dim,
      ], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op
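
# Illustrative sketch (not part of the original module): note the inputs are
# expected to be unit-normalized along `dim`, since the op above only computes
# 1 - mean(sum(predictions * labels, dim)). For identical unit vectors the
# distance is 0:
#
#   labels = tf.constant([[1.0, 0.0]])
#   predictions = tf.constant([[1.0, 0.0]])
#   dist, dist_op = tf.compat.v1.metrics.mean_cosine_distance(
#       labels, predictions, dim=1)
#   # After local-variable init and running dist_op, dist evaluates to 0.0.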


@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.cast(labels, dtypes.int64)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.cast(
        math_ops.equal(predictions, labels), dtypes.float32)

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.cast(weights, dtypes.float32)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)

    def compute_mean_accuracy(_, count, total):
      per_class_accuracy = math_ops.div_no_nan(
          count, math_ops.maximum(total, 0), name=None)
      mean_accuracy_v = math_ops.reduce_mean(
          per_class_accuracy, name='mean_accuracy')
      return mean_accuracy_v

    mean_accuracy_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_accuracy, count, total)

    update_op = math_ops.div_no_nan(
        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op
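
# Illustrative sketch (not part of the original module), worked out by hand:
# with labels [0, 0, 1] and predictions [0, 1, 1], class 0 is right 1 of 2
# times and class 1 is right 1 of 1 time, so the mean per-class accuracy is
# (0.5 + 1.0) / 2 = 0.75 after one update in a TF1 session.
#
#   acc, acc_op = tf.compat.v1.metrics.mean_per_class_accuracy(
#       labels=tf.constant([0, 0, 1]),
#       predictions=tf.constant([0, 1, 1]),
#       num_classes=2)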


@tf_export(v1=['metrics.mean_iou'])
def mean_iou(labels,
             predictions,
             num_classes,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_iou',
                                     (predictions, labels, weights)):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                      num_classes, weights)

    def compute_mean_iou(_, total_cm):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.cast(
          math_ops.reduce_sum(total_cm, 0), dtypes.float32)
      sum_over_col = math_ops.cast(
          math_ops.reduce_sum(total_cm, 1), dtypes.float32)
      cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
      denominator = sum_over_row + sum_over_col - cm_diag

      # The mean is only computed over classes that appear in the
      # label or prediction tensor. If the denominator is 0, we need to
      # ignore the class.
      num_valid_entries = math_ops.reduce_sum(
          math_ops.cast(
              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = array_ops.where(
          math_ops.greater(denominator, 0), denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.divide(cm_diag, denominator)

      # If the number of valid entries is 0 (no classes) we return 0.
      result = array_ops.where(
          math_ops.greater(num_valid_entries, 0),
          math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
      return result

    # TODO(priyag): Use outside_compilation if in TPU context.
    mean_iou_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_iou, total_cm)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou_v, update_op
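
# Illustrative sketch (not part of the original module), worked out by hand:
# labels [0, 0, 1] and predictions [0, 1, 1] give the confusion matrix
# [[1, 1], [0, 1]], so IOU(class 0) = 1 / (1 + 1 + 0) = 0.5 and
# IOU(class 1) = 1 / (1 + 0 + 1) = 0.5, for a mean IOU of 0.5.
#
#   miou, miou_op = tf.compat.v1.metrics.mean_iou(
#       labels=tf.constant([0, 0, 1]),
#       predictions=tf.constant([0, 1, 1]),
#       num_classes=2)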
1385@tf_export(v1=['metrics.mean_relative_error'])
1386def mean_relative_error(labels,
1387 predictions,
1388 normalizer,
1389 weights=None,
1390 metrics_collections=None,
1391 updates_collections=None,
1392 name=None):
1393 """Computes the mean relative error by normalizing with the given values.
1395 The `mean_relative_error` function creates two local variables,
1396 `total` and `count` that are used to compute the mean relative absolute error.
1397 This average is weighted by `weights`, and it is ultimately returned as
1398 `mean_relative_error`: an idempotent operation that simply divides `total` by
1399 `count`.
1401 For estimation of the metric over a stream of data, the function creates an
1402 `update_op` operation that updates these variables and returns the
1403 `mean_reative_error`. Internally, a `relative_errors` operation divides the
1404 absolute value of the differences between `predictions` and `labels` by the
1405 `normalizer`. Then `update_op` increments `total` with the reduced sum of the
1406 product of `weights` and `relative_errors`, and it increments `count` with the
1407 reduced sum of `weights`.
1409 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1411 Args:
1412 labels: A `Tensor` of the same shape as `predictions`.
1413 predictions: A `Tensor` of arbitrary shape.
1414 normalizer: A `Tensor` of the same shape as `predictions`.
1415 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1416 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1417 be either `1`, or the same as the corresponding `labels` dimension).
1418 metrics_collections: An optional list of collections that
1419 `mean_relative_error` should be added to.
1420 updates_collections: An optional list of collections that `update_op` should
1421 be added to.
1422 name: An optional variable_scope name.
1424 Returns:
1425 mean_relative_error: A `Tensor` representing the current mean, the value of
1426 `total` divided by `count`.
1427 update_op: An operation that increments the `total` and `count` variables
1428 appropriately and whose value matches `mean_relative_error`.
1430 Raises:
1431 ValueError: If `predictions` and `labels` have mismatched shapes, or if
1432 `weights` is not `None` and its shape doesn't match `predictions`, or if
1433 either `metrics_collections` or `updates_collections` are not a list or
1434 tuple.
1435 RuntimeError: If eager execution is enabled.
1436 """
1437 if context.executing_eagerly():
1438 raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
1439 'eager execution is enabled.')
1441 predictions, labels, weights = _remove_squeezable_dimensions(
1442 predictions=predictions, labels=labels, weights=weights)
1444 predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
1445 predictions, normalizer)
1446 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
1447 relative_errors = array_ops.where(
1448 math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
1449 math_ops.divide(math_ops.abs(labels - predictions), normalizer))
1450 return mean(relative_errors, weights, metrics_collections,
1451 updates_collections, name or 'mean_relative_error')
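# Usage sketch (illustrative values): normalizing by the labels themselves
# yields the classic relative error |prediction - label| / label.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([1.0, 2.0, 4.0])
predictions = tf.constant([1.1, 1.8, 5.0])
mre, update_op = tf.metrics.mean_relative_error(labels, predictions,
                                                normalizer=labels)
with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(mre))  # mean of [0.1, 0.1, 0.25] = 0.15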
1454@tf_export(v1=['metrics.mean_squared_error'])
1455def mean_squared_error(labels,
1456 predictions,
1457 weights=None,
1458 metrics_collections=None,
1459 updates_collections=None,
1460 name=None):
1461 """Computes the mean squared error between the labels and predictions.
1463 The `mean_squared_error` function creates two local variables,
1464 `total` and `count` that are used to compute the mean squared error.
1465 This average is weighted by `weights`, and it is ultimately returned as
1466 `mean_squared_error`: an idempotent operation that simply divides `total` by
1467 `count`.
1469 For estimation of the metric over a stream of data, the function creates an
1470 `update_op` operation that updates these variables and returns the
1471 `mean_squared_error`. Internally, a `squared_error` operation computes the
1472 element-wise square of the difference between `predictions` and `labels`. Then
1473 `update_op` increments `total` with the reduced sum of the product of
1474 `weights` and `squared_error`, and it increments `count` with the reduced sum
1475 of `weights`.
1477 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1479 Args:
1480 labels: A `Tensor` of the same shape as `predictions`.
1481 predictions: A `Tensor` of arbitrary shape.
1482 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1483 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1484 be either `1`, or the same as the corresponding `labels` dimension).
1485 metrics_collections: An optional list of collections that
1486 `mean_squared_error` should be added to.
1487 updates_collections: An optional list of collections that `update_op` should
1488 be added to.
1489 name: An optional variable_scope name.
1491 Returns:
1492 mean_squared_error: A `Tensor` representing the current mean, the value of
1493 `total` divided by `count`.
1494 update_op: An operation that increments the `total` and `count` variables
1495 appropriately and whose value matches `mean_squared_error`.
1497 Raises:
1498 ValueError: If `predictions` and `labels` have mismatched shapes, or if
1499 `weights` is not `None` and its shape doesn't match `predictions`, or if
1500 either `metrics_collections` or `updates_collections` are not a list or
1501 tuple.
1502 RuntimeError: If eager execution is enabled.
1503 """
1504 if context.executing_eagerly():
1505 raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
1506 'eager execution is enabled.')
1508 predictions, labels, weights = _remove_squeezable_dimensions(
1509 predictions=predictions, labels=labels, weights=weights)
1510 squared_error = math_ops.squared_difference(labels, predictions)
1511 return mean(squared_error, weights, metrics_collections, updates_collections,
1512 name or 'mean_squared_error')
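# Usage sketch (illustrative values):
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([0.0, 1.0, 2.0])
predictions = tf.constant([0.0, 2.0, 4.0])
mse, update_op = tf.metrics.mean_squared_error(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(mse))  # mean of [0, 1, 4] = 1.6666...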
1515@tf_export(v1=['metrics.mean_tensor'])
1516def mean_tensor(values,
1517 weights=None,
1518 metrics_collections=None,
1519 updates_collections=None,
1520 name=None):
1521 """Computes the element-wise (weighted) mean of the given tensors.
1523 In contrast to the `mean` function, which returns a scalar with the
1524 mean, this function returns an average tensor with the same shape as the
1525 input tensors.
1527 The `mean_tensor` function creates two local variables,
1528 `total_tensor` and `count_tensor` that are used to compute the average of
1529 `values`. This average is ultimately returned as `mean` which is an idempotent
1530 operation that simply divides `total` by `count`.
1532 For estimation of the metric over a stream of data, the function creates an
1533 `update_op` operation that updates these variables and returns the `mean`.
1534 `update_op` increments `total` with the reduced sum of the product of `values`
1535 and `weights`, and it increments `count` with the reduced sum of `weights`.
1537 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1539 Args:
1540 values: A `Tensor` of arbitrary dimensions.
1541 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1542 `values`, and must be broadcastable to `values` (i.e., all dimensions must
1543 be either `1`, or the same as the corresponding `values` dimension).
1544 metrics_collections: An optional list of collections that `mean`
1545 should be added to.
1546 updates_collections: An optional list of collections that `update_op`
1547 should be added to.
1548 name: An optional variable_scope name.
1550 Returns:
1551 mean: A float `Tensor` representing the current mean, the value of `total`
1552 divided by `count`.
1553 update_op: An operation that increments the `total` and `count` variables
1554 appropriately and whose value matches `mean`.
1556 Raises:
1557 ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1558 or if either `metrics_collections` or `updates_collections` are not a list
1559 or tuple.
1560 RuntimeError: If eager execution is enabled.
1561 """
1562 if context.executing_eagerly():
1563 raise RuntimeError('tf.metrics.mean_tensor is not supported when '
1564 'eager execution is enabled.')
1566 with variable_scope.variable_scope(name, 'mean', (values, weights)):
1567 values = math_ops.cast(values, dtypes.float32)
1568 total = metric_variable(
1569 values.get_shape(), dtypes.float32, name='total_tensor')
1570 count = metric_variable(
1571 values.get_shape(), dtypes.float32, name='count_tensor')
1573 num_values = array_ops.ones_like(values)
1574 if weights is not None:
1575 values, _, weights = _remove_squeezable_dimensions(
1576 predictions=values, labels=None, weights=weights)
1577 weights = weights_broadcast_ops.broadcast_weights(
1578 math_ops.cast(weights, dtypes.float32), values)
1579 values = math_ops.multiply(values, weights)
1580 num_values = math_ops.multiply(num_values, weights)
1582 update_total_op = state_ops.assign_add(total, values)
1583 with ops.control_dependencies([values]):
1584 update_count_op = state_ops.assign_add(count, num_values)
1586 compute_mean = lambda _, t, c: math_ops.div_no_nan( # pylint: disable=g-long-lambda
1587 t, math_ops.maximum(c, 0), name='value')
1589 mean_t = _aggregate_across_replicas(
1590 metrics_collections, compute_mean, total, count)
1592 update_op = math_ops.div_no_nan(
1593 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
1594 if updates_collections:
1595 ops.add_to_collections(updates_collections, update_op)
1597 return mean_t, update_op
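# Usage sketch (illustrative values): two updates show the element-wise
# streaming behaviour; the result keeps the shape of `values` rather than
# reducing to a scalar.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

values = tf.placeholder(tf.float32, shape=[2])
mean_t, update_op = tf.metrics.mean_tensor(values)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op, feed_dict={values: [0.0, 2.0]})
  sess.run(update_op, feed_dict={values: [4.0, 6.0]})
  print(sess.run(mean_t))  # element-wise running mean: [2.0, 4.0]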
1600@tf_export(v1=['metrics.percentage_below'])
1601def percentage_below(values,
1602 threshold,
1603 weights=None,
1604 metrics_collections=None,
1605 updates_collections=None,
1606 name=None):
1607 """Computes the percentage of values less than the given threshold.
1609 The `percentage_below` function creates two local variables,
1610 `total` and `count` that are used to compute the percentage of `values` that
1611 fall below `threshold`. This rate is weighted by `weights`, and it is
1612 ultimately returned as `percentage` which is an idempotent operation that
1613 simply divides `total` by `count`.
1615 For estimation of the metric over a stream of data, the function creates an
1616 `update_op` operation that updates these variables and returns the
1617 `percentage`.
1619 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1621 Args:
1622 values: A numeric `Tensor` of arbitrary size.
1623 threshold: A scalar threshold.
1624 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1625 `values`, and must be broadcastable to `values` (i.e., all dimensions must
1626 be either `1`, or the same as the corresponding `values` dimension).
1627 metrics_collections: An optional list of collections that the metric
1628 value variable should be added to.
1629 updates_collections: An optional list of collections that the metric update
1630 ops should be added to.
1631 name: An optional variable_scope name.
1633 Returns:
1634 percentage: A `Tensor` representing the current mean, the value of `total`
1635 divided by `count`.
1636 update_op: An operation that increments the `total` and `count` variables
1637 appropriately.
1639 Raises:
1640 ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1641 or if either `metrics_collections` or `updates_collections` are not a list
1642 or tuple.
1643 RuntimeError: If eager execution is enabled.
1644 """
1645 if context.executing_eagerly():
1646 raise RuntimeError('tf.metrics.percentage_below is not supported when '
1647 'eager execution is enabled.')
1649 is_below_threshold = math_ops.cast(
1650 math_ops.less(values, threshold), dtypes.float32)
1651 return mean(is_below_threshold, weights, metrics_collections,
1652 updates_collections, name or 'percentage_below_threshold')
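# Usage sketch (illustrative values):
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

values = tf.constant([1.0, 4.0, 6.0, 8.0])
pct, update_op = tf.metrics.percentage_below(values, threshold=5.0)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(pct))  # 2 of the 4 values fall below 5.0 -> 0.5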
1655def _count_condition(values,
1656 weights=None,
1657 metrics_collections=None,
1658 updates_collections=None):
1659 """Sums the weights of cases where the given values are True.
1661 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1663 Args:
1664 values: A `bool` `Tensor` of arbitrary size.
1665 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1666 `values`, and must be broadcastable to `values` (i.e., all dimensions must
1667 be either `1`, or the same as the corresponding `values` dimension).
1668 metrics_collections: An optional list of collections that the metric
1669 value variable should be added to.
1670 updates_collections: An optional list of collections that the metric update
1671 ops should be added to.
1673 Returns:
1674 value_tensor: A `Tensor` representing the current value of the metric.
1675 update_op: An operation that accumulates the error from a batch of data.
1677 Raises:
1678 ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1679 or if either `metrics_collections` or `updates_collections` are not a list
1680 or tuple.
1681 """
1682 check_ops.assert_type(values, dtypes.bool)
1683 count = metric_variable([], dtypes.float32, name='count')
1685 values = math_ops.cast(values, dtypes.float32)
1686 if weights is not None:
1687 with ops.control_dependencies((check_ops.assert_rank_in(
1688 weights, (0, array_ops.rank(values))),)):
1689 weights = math_ops.cast(weights, dtypes.float32)
1690 values = math_ops.multiply(values, weights)
1692 value_tensor = _aggregate_variable(count, metrics_collections)
1694 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
1695 if updates_collections:
1696 ops.add_to_collections(updates_collections, update_op)
1698 return value_tensor, update_op
1701@tf_export(v1=['metrics.false_negatives'])
1702def false_negatives(labels,
1703 predictions,
1704 weights=None,
1705 metrics_collections=None,
1706 updates_collections=None,
1707 name=None):
1708 """Computes the total number of false negatives.
1710 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1712 Args:
1713 labels: The ground truth values, a `Tensor` whose dimensions must match
1714 `predictions`. Will be cast to `bool`.
1715 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1716 be cast to `bool`.
1717 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1718 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1719 be either `1`, or the same as the corresponding `labels` dimension).
1720 metrics_collections: An optional list of collections that the metric
1721 value variable should be added to.
1722 updates_collections: An optional list of collections that the metric update
1723 ops should be added to.
1724 name: An optional variable_scope name.
1726 Returns:
1727 value_tensor: A `Tensor` representing the current value of the metric.
1728 update_op: An operation that accumulates the error from a batch of data.
1730 Raises:
1731 ValueError: If `weights` is not `None` and its shape doesn't match
1732 `predictions`, or if either `metrics_collections` or `updates_collections`
1733 are not a list or tuple.
1734 RuntimeError: If eager execution is enabled.
1735 """
1736 if context.executing_eagerly():
1737 raise RuntimeError('tf.metrics.false_negatives is not supported when '
1738 'eager execution is enabled.')
1740 with variable_scope.variable_scope(name, 'false_negatives',
1741 (predictions, labels, weights)):
1743 predictions, labels, weights = _remove_squeezable_dimensions(
1744 predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1745 labels=math_ops.cast(labels, dtype=dtypes.bool),
1746 weights=weights)
1747 is_false_negative = math_ops.logical_and(
1748 math_ops.equal(labels, True), math_ops.equal(predictions, False))
1749 return _count_condition(is_false_negative, weights, metrics_collections,
1750 updates_collections)
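# Usage sketch (illustrative values); `false_positives`, `true_negatives` and
# `true_positives` below follow exactly the same pattern.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([True, True, False, False])
predictions = tf.constant([False, True, False, True])
fn, update_op = tf.metrics.false_negatives(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(fn))  # one (label=True, prediction=False) case -> 1.0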
1753@tf_export(v1=['metrics.false_negatives_at_thresholds'])
1754def false_negatives_at_thresholds(labels,
1755 predictions,
1756 thresholds,
1757 weights=None,
1758 metrics_collections=None,
1759 updates_collections=None,
1760 name=None):
1761 """Computes false negatives at provided threshold values.
1763 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1765 Args:
1766 labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1767 `bool`.
1768 predictions: A floating point `Tensor` of arbitrary shape and whose values
1769 are in the range `[0, 1]`.
1770 thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1771 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1772 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1773 be either `1`, or the same as the corresponding `labels` dimension).
1774 metrics_collections: An optional list of collections that `false_negatives`
1775 should be added to.
1776 updates_collections: An optional list of collections that `update_op` should
1777 be added to.
1778 name: An optional variable_scope name.
1780 Returns:
1781 false_negatives: A float `Tensor` of shape `[len(thresholds)]`.
1782 update_op: An operation that updates the `false_negatives` variable and
1783 returns its current value.
1785 Raises:
1786 ValueError: If `predictions` and `labels` have mismatched shapes, or if
1787 `weights` is not `None` and its shape doesn't match `predictions`, or if
1788 either `metrics_collections` or `updates_collections` are not a list or
1789 tuple.
1790 RuntimeError: If eager execution is enabled.
1791 """
1792 if context.executing_eagerly():
1793 raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
1794 'supported when eager execution is enabled.')
1796 with variable_scope.variable_scope(name, 'false_negatives',
1797 (predictions, labels, weights)):
1798 values, update_ops = _confusion_matrix_at_thresholds(
1799 labels, predictions, thresholds, weights=weights, includes=('fn',))
1801 fn_value = _aggregate_variable(values['fn'], metrics_collections)
1803 if updates_collections:
1804 ops.add_to_collections(updates_collections, update_ops['fn'])
1806 return fn_value, update_ops['fn']
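# Usage sketch (illustrative values): one accumulated count per threshold; the
# *_at_thresholds variants for fp/tn/tp behave analogously.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([False, True, True])
predictions = tf.constant([0.2, 0.4, 0.8])
fn, update_op = tf.metrics.false_negatives_at_thresholds(
    labels, predictions, thresholds=[0.3, 0.6])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(fn))  # [0.0, 1.0]: at threshold 0.6 the 0.4 positive is missed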
1809@tf_export(v1=['metrics.false_positives'])
1810def false_positives(labels,
1811 predictions,
1812 weights=None,
1813 metrics_collections=None,
1814 updates_collections=None,
1815 name=None):
1816 """Sum the weights of false positives.
1818 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1820 Args:
1821 labels: The ground truth values, a `Tensor` whose dimensions must match
1822 `predictions`. Will be cast to `bool`.
1823 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1824 be cast to `bool`.
1825 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1826 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1827 be either `1`, or the same as the corresponding `labels` dimension).
1828 metrics_collections: An optional list of collections that the metric
1829 value variable should be added to.
1830 updates_collections: An optional list of collections that the metric update
1831 ops should be added to.
1832 name: An optional variable_scope name.
1834 Returns:
1835 value_tensor: A `Tensor` representing the current value of the metric.
1836 update_op: An operation that accumulates the error from a batch of data.
1838 Raises:
1839 ValueError: If `predictions` and `labels` have mismatched shapes, or if
1840 `weights` is not `None` and its shape doesn't match `predictions`, or if
1841 either `metrics_collections` or `updates_collections` are not a list or
1842 tuple.
1843 RuntimeError: If eager execution is enabled.
1844 """
1845 if context.executing_eagerly():
1846 raise RuntimeError('tf.metrics.false_positives is not supported when '
1847 'eager execution is enabled.')
1849 with variable_scope.variable_scope(name, 'false_positives',
1850 (predictions, labels, weights)):
1852 predictions, labels, weights = _remove_squeezable_dimensions(
1853 predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1854 labels=math_ops.cast(labels, dtype=dtypes.bool),
1855 weights=weights)
1856 is_false_positive = math_ops.logical_and(
1857 math_ops.equal(labels, False), math_ops.equal(predictions, True))
1858 return _count_condition(is_false_positive, weights, metrics_collections,
1859 updates_collections)
1862@tf_export(v1=['metrics.false_positives_at_thresholds'])
1863def false_positives_at_thresholds(labels,
1864 predictions,
1865 thresholds,
1866 weights=None,
1867 metrics_collections=None,
1868 updates_collections=None,
1869 name=None):
1870 """Computes false positives at provided threshold values.
1872 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1874 Args:
1875 labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1876 `bool`.
1877 predictions: A floating point `Tensor` of arbitrary shape and whose values
1878 are in the range `[0, 1]`.
1879 thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1880 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1881 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1882 be either `1`, or the same as the corresponding `labels` dimension).
1883 metrics_collections: An optional list of collections that `false_positives`
1884 should be added to.
1885 updates_collections: An optional list of collections that `update_op` should
1886 be added to.
1887 name: An optional variable_scope name.
1889 Returns:
1890 false_positives: A float `Tensor` of shape `[len(thresholds)]`.
1891 update_op: An operation that updates the `false_positives` variable and
1892 returns its current value.
1894 Raises:
1895 ValueError: If `predictions` and `labels` have mismatched shapes, or if
1896 `weights` is not `None` and its shape doesn't match `predictions`, or if
1897 either `metrics_collections` or `updates_collections` are not a list or
1898 tuple.
1899 RuntimeError: If eager execution is enabled.
1900 """
1901 if context.executing_eagerly():
1902 raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
1903 'supported when eager execution is enabled.')
1905 with variable_scope.variable_scope(name, 'false_positives',
1906 (predictions, labels, weights)):
1907 values, update_ops = _confusion_matrix_at_thresholds(
1908 labels, predictions, thresholds, weights=weights, includes=('fp',))
1910 fp_value = _aggregate_variable(values['fp'], metrics_collections)
1912 if updates_collections:
1913 ops.add_to_collections(updates_collections, update_ops['fp'])
1915 return fp_value, update_ops['fp']
1918@tf_export(v1=['metrics.true_negatives'])
1919def true_negatives(labels,
1920 predictions,
1921 weights=None,
1922 metrics_collections=None,
1923 updates_collections=None,
1924 name=None):
1925 """Sum the weights of true_negatives.
1927 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1929 Args:
1930 labels: The ground truth values, a `Tensor` whose dimensions must match
1931 `predictions`. Will be cast to `bool`.
1932 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1933 be cast to `bool`.
1934 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1935 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1936 be either `1`, or the same as the corresponding `labels` dimension).
1937 metrics_collections: An optional list of collections that the metric
1938 value variable should be added to.
1939 updates_collections: An optional list of collections that the metric update
1940 ops should be added to.
1941 name: An optional variable_scope name.
1943 Returns:
1944 value_tensor: A `Tensor` representing the current value of the metric.
1945 update_op: An operation that accumulates the error from a batch of data.
1947 Raises:
1948 ValueError: If `predictions` and `labels` have mismatched shapes, or if
1949 `weights` is not `None` and its shape doesn't match `predictions`, or if
1950 either `metrics_collections` or `updates_collections` are not a list or
1951 tuple.
1952 RuntimeError: If eager execution is enabled.
1953 """
1954 if context.executing_eagerly():
1955 raise RuntimeError('tf.metrics.true_negatives is not '
1956 'supported when eager execution is enabled.')
1958 with variable_scope.variable_scope(name, 'true_negatives',
1959 (predictions, labels, weights)):
1961 predictions, labels, weights = _remove_squeezable_dimensions(
1962 predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1963 labels=math_ops.cast(labels, dtype=dtypes.bool),
1964 weights=weights)
1965 is_true_negative = math_ops.logical_and(
1966 math_ops.equal(labels, False), math_ops.equal(predictions, False))
1967 return _count_condition(is_true_negative, weights, metrics_collections,
1968 updates_collections)
1971@tf_export(v1=['metrics.true_negatives_at_thresholds'])
1972def true_negatives_at_thresholds(labels,
1973 predictions,
1974 thresholds,
1975 weights=None,
1976 metrics_collections=None,
1977 updates_collections=None,
1978 name=None):
1979 """Computes true negatives at provided threshold values.
1981 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1983 Args:
1984 labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1985 `bool`.
1986 predictions: A floating point `Tensor` of arbitrary shape and whose values
1987 are in the range `[0, 1]`.
1988 thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1989 weights: Optional `Tensor` whose rank is either 0, or the same rank as
1990 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1991 be either `1`, or the same as the corresponding `labels` dimension).
1992 metrics_collections: An optional list of collections that `true_negatives`
1993 should be added to.
1994 updates_collections: An optional list of collections that `update_op` should
1995 be added to.
1996 name: An optional variable_scope name.
1998 Returns:
1999 true_negatives: A float `Tensor` of shape `[len(thresholds)]`.
2000 update_op: An operation that updates the `true_negatives` variable and
2001 returns its current value.
2003 Raises:
2004 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2005 `weights` is not `None` and its shape doesn't match `predictions`, or if
2006 either `metrics_collections` or `updates_collections` are not a list or
2007 tuple.
2008 RuntimeError: If eager execution is enabled.
2009 """
2010 if context.executing_eagerly():
2011 raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
2012 'supported when eager execution is enabled.')
2014 with variable_scope.variable_scope(name, 'true_negatives',
2015 (predictions, labels, weights)):
2016 values, update_ops = _confusion_matrix_at_thresholds(
2017 labels, predictions, thresholds, weights=weights, includes=('tn',))
2019 tn_value = _aggregate_variable(values['tn'], metrics_collections)
2021 if updates_collections:
2022 ops.add_to_collections(updates_collections, update_ops['tn'])
2024 return tn_value, update_ops['tn']
2027@tf_export(v1=['metrics.true_positives'])
2028def true_positives(labels,
2029 predictions,
2030 weights=None,
2031 metrics_collections=None,
2032 updates_collections=None,
2033 name=None):
2034 """Sum the weights of true_positives.
2036 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2038 Args:
2039 labels: The ground truth values, a `Tensor` whose dimensions must match
2040 `predictions`. Will be cast to `bool`.
2041 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2042 be cast to `bool`.
2043 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2044 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2045 be either `1`, or the same as the corresponding `labels` dimension).
2046 metrics_collections: An optional list of collections that the metric
2047 value variable should be added to.
2048 updates_collections: An optional list of collections that the metric update
2049 ops should be added to.
2050 name: An optional variable_scope name.
2052 Returns:
2053 value_tensor: A `Tensor` representing the current value of the metric.
2054 update_op: An operation that accumulates the error from a batch of data.
2056 Raises:
2057 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2058 `weights` is not `None` and its shape doesn't match `predictions`, or if
2059 either `metrics_collections` or `updates_collections` are not a list or
2060 tuple.
2061 RuntimeError: If eager execution is enabled.
2062 """
2063 if context.executing_eagerly():
2064 raise RuntimeError('tf.metrics.true_positives is not '
2065 'supported when eager execution is enabled.')
2067 with variable_scope.variable_scope(name, 'true_positives',
2068 (predictions, labels, weights)):
2070 predictions, labels, weights = _remove_squeezable_dimensions(
2071 predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2072 labels=math_ops.cast(labels, dtype=dtypes.bool),
2073 weights=weights)
2074 is_true_positive = math_ops.logical_and(
2075 math_ops.equal(labels, True), math_ops.equal(predictions, True))
2076 return _count_condition(is_true_positive, weights, metrics_collections,
2077 updates_collections)
2080@tf_export(v1=['metrics.true_positives_at_thresholds'])
2081def true_positives_at_thresholds(labels,
2082 predictions,
2083 thresholds,
2084 weights=None,
2085 metrics_collections=None,
2086 updates_collections=None,
2087 name=None):
2088 """Computes true positives at provided threshold values.
2090 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2092 Args:
2093 labels: A `Tensor` whose shape matches `predictions`. Will be cast to
2094 `bool`.
2095 predictions: A floating point `Tensor` of arbitrary shape and whose values
2096 are in the range `[0, 1]`.
2097 thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2098 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2099 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2100 be either `1`, or the same as the corresponding `labels` dimension).
2101 metrics_collections: An optional list of collections that `true_positives`
2102 should be added to.
2103 updates_collections: An optional list of collections that `update_op` should
2104 be added to.
2105 name: An optional variable_scope name.
2107 Returns:
2108 true_positives: A float `Tensor` of shape `[len(thresholds)]`.
2109 update_op: An operation that updates the `true_positives` variable and
2110 returns its current value.
2112 Raises:
2113 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2114 `weights` is not `None` and its shape doesn't match `predictions`, or if
2115 either `metrics_collections` or `updates_collections` are not a list or
2116 tuple.
2117 RuntimeError: If eager execution is enabled.
2118 """
2119 if context.executing_eagerly():
2120 raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
2121 'supported when eager execution is enabled.')
2123 with variable_scope.variable_scope(name, 'true_positives',
2124 (predictions, labels, weights)):
2125 values, update_ops = _confusion_matrix_at_thresholds(
2126 labels, predictions, thresholds, weights=weights, includes=('tp',))
2128 tp_value = _aggregate_variable(values['tp'], metrics_collections)
2130 if updates_collections:
2131 ops.add_to_collections(updates_collections, update_ops['tp'])
2133 return tp_value, update_ops['tp']
2136@tf_export(v1=['metrics.precision'])
2137def precision(labels,
2138 predictions,
2139 weights=None,
2140 metrics_collections=None,
2141 updates_collections=None,
2142 name=None):
2143 """Computes the precision of the predictions with respect to the labels.
2145 The `precision` function creates two local variables,
2146 `true_positives` and `false_positives`, that are used to compute the
2147 precision. This value is ultimately returned as `precision`, an idempotent
2148 operation that simply divides `true_positives` by the sum of `true_positives`
2149 and `false_positives`.
2151 For estimation of the metric over a stream of data, the function creates an
2152 `update_op` operation that updates these variables and returns the
2153 `precision`. `update_op` weights each prediction by the corresponding value in
2154 `weights`.
2156 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2158 Args:
2159 labels: The ground truth values, a `Tensor` whose dimensions must match
2160 `predictions`. Will be cast to `bool`.
2161 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2162 be cast to `bool`.
2163 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2164 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2165 be either `1`, or the same as the corresponding `labels` dimension).
2166 metrics_collections: An optional list of collections that `precision` should
2167 be added to.
2168 updates_collections: An optional list of collections that `update_op` should
2169 be added to.
2170 name: An optional variable_scope name.
2172 Returns:
2173 precision: Scalar float `Tensor` with the value of `true_positives`
2174 divided by the sum of `true_positives` and `false_positives`.
2175 update_op: `Operation` that increments `true_positives` and
2176 `false_positives` variables appropriately and whose value matches
2177 `precision`.
2179 Raises:
2180 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2181 `weights` is not `None` and its shape doesn't match `predictions`, or if
2182 either `metrics_collections` or `updates_collections` are not a list or
2183 tuple.
2184 RuntimeError: If eager execution is enabled.
2185 """
2186 if context.executing_eagerly():
2187 raise RuntimeError('tf.metrics.precision is not '
2188 'supported when eager execution is enabled.')
2190 with variable_scope.variable_scope(name, 'precision',
2191 (predictions, labels, weights)):
2193 predictions, labels, weights = _remove_squeezable_dimensions(
2194 predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2195 labels=math_ops.cast(labels, dtype=dtypes.bool),
2196 weights=weights)
2198 true_p, true_positives_update_op = true_positives(
2199 labels,
2200 predictions,
2201 weights,
2202 metrics_collections=None,
2203 updates_collections=None,
2204 name=None)
2205 false_p, false_positives_update_op = false_positives(
2206 labels,
2207 predictions,
2208 weights,
2209 metrics_collections=None,
2210 updates_collections=None,
2211 name=None)
2213 def compute_precision(tp, fp, name):
2214 return array_ops.where(
2215 math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name)
2217 def once_across_replicas(_, true_p, false_p):
2218 return compute_precision(true_p, false_p, 'value')
2220 p = _aggregate_across_replicas(metrics_collections, once_across_replicas,
2221 true_p, false_p)
2223 update_op = compute_precision(true_positives_update_op,
2224 false_positives_update_op, 'update_op')
2225 if updates_collections:
2226 ops.add_to_collections(updates_collections, update_op)
2228 return p, update_op
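# Usage sketch (illustrative values):
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([True, True, False, False])
predictions = tf.constant([True, False, True, True])
prec, update_op = tf.metrics.precision(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(prec))  # tp=1, fp=2 -> 1/3 = 0.3333...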
2231@tf_export(v1=['metrics.precision_at_thresholds'])
2232def precision_at_thresholds(labels,
2233 predictions,
2234 thresholds,
2235 weights=None,
2236 metrics_collections=None,
2237 updates_collections=None,
2238 name=None):
2239 """Computes precision values for different `thresholds` on `predictions`.
2241 The `precision_at_thresholds` function creates four local variables,
2242 `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
2243 for various values of thresholds. `precision[i]` is defined as the total
2244 weight of values in `predictions` above `thresholds[i]` whose corresponding
2245 entry in `labels` is `True`, divided by the total weight of values in
2246 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
2247 false_positives[i])`).
2249 For estimation of the metric over a stream of data, the function creates an
2250 `update_op` operation that updates these variables and returns the
2251 `precision`.
2253 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2255 Args:
2256 labels: The ground truth values, a `Tensor` whose dimensions must match
2257 `predictions`. Will be cast to `bool`.
2258 predictions: A floating point `Tensor` of arbitrary shape and whose values
2259 are in the range `[0, 1]`.
2260 thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2261 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2262 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2263 be either `1`, or the same as the corresponding `labels` dimension).
2264 metrics_collections: An optional list of collections that `precision`
2265 should be added to.
2266 updates_collections: An optional list of collections that `update_op` should
2267 be added to.
2268 name: An optional variable_scope name.
2270 Returns:
2271 precision: A float `Tensor` of shape `[len(thresholds)]`.
2272 update_op: An operation that increments the `true_positives`,
2273 `true_negatives`, `false_positives` and `false_negatives` variables that
2274 are used in the computation of `precision`.
2276 Raises:
2277 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2278 `weights` is not `None` and its shape doesn't match `predictions`, or if
2279 either `metrics_collections` or `updates_collections` are not a list or
2280 tuple.
2281 RuntimeError: If eager execution is enabled.
2282 """
2283 if context.executing_eagerly():
2284 raise RuntimeError('tf.metrics.precision_at_thresholds is not '
2285 'supported when eager execution is enabled.')
2287 with variable_scope.variable_scope(name, 'precision_at_thresholds',
2288 (predictions, labels, weights)):
2289 values, update_ops = _confusion_matrix_at_thresholds(
2290 labels, predictions, thresholds, weights, includes=('tp', 'fp'))
2292 # Avoid division by zero.
2293 epsilon = 1e-7
2295 def compute_precision(tp, fp, name):
2296 return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name)
2298 def precision_across_replicas(_, values):
2299 return compute_precision(values['tp'], values['fp'], 'value')
2301 prec = _aggregate_across_replicas(
2302 metrics_collections, precision_across_replicas, values)
2304 update_op = compute_precision(update_ops['tp'], update_ops['fp'],
2305 'update_op')
2306 if updates_collections:
2307 ops.add_to_collections(updates_collections, update_op)
2309 return prec, update_op
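# Usage sketch (illustrative values); note that the epsilon above makes each
# result sit slightly below the exact ratio.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([False, True, True])
predictions = tf.constant([0.2, 0.6, 0.8])
prec, update_op = tf.metrics.precision_at_thresholds(
    labels, predictions, thresholds=[0.1, 0.5])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(prec))  # ~[0.6667, 1.0]: all three predictions count as
                         # positive at 0.1 (tp=2, fp=1); only the two true
                         # positives clear 0.5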
2312@tf_export(v1=['metrics.recall'])
2313def recall(labels,
2314 predictions,
2315 weights=None,
2316 metrics_collections=None,
2317 updates_collections=None,
2318 name=None):
2319 """Computes the recall of the predictions with respect to the labels.
2321 The `recall` function creates two local variables, `true_positives`
2322 and `false_negatives`, that are used to compute the recall. This value is
2323 ultimately returned as `recall`, an idempotent operation that simply divides
2324 `true_positives` by the sum of `true_positives` and `false_negatives`.
2326 For estimation of the metric over a stream of data, the function creates an
2327 `update_op` that updates these variables and returns the `recall`. `update_op`
2328 weights each prediction by the corresponding value in `weights`.
2330 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2332 Args:
2333 labels: The ground truth values, a `Tensor` whose dimensions must match
2334 `predictions`. Will be cast to `bool`.
2335 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2336 be cast to `bool`.
2337 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2338 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2339 be either `1`, or the same as the corresponding `labels` dimension).
2340 metrics_collections: An optional list of collections that `recall` should
2341 be added to.
2342 updates_collections: An optional list of collections that `update_op` should
2343 be added to.
2344 name: An optional variable_scope name.
2346 Returns:
2347 recall: Scalar float `Tensor` with the value of `true_positives` divided
2348 by the sum of `true_positives` and `false_negatives`.
2349 update_op: `Operation` that increments `true_positives` and
2350 `false_negatives` variables appropriately and whose value matches
2351 `recall`.
2353 Raises:
2354 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2355 `weights` is not `None` and its shape doesn't match `predictions`, or if
2356 either `metrics_collections` or `updates_collections` are not a list or
2357 tuple.
2358 RuntimeError: If eager execution is enabled.
2359 """
2360 if context.executing_eagerly():
2361 raise RuntimeError('tf.metrics.recall is not supported when '
2362 'eager execution is enabled.')
2364 with variable_scope.variable_scope(name, 'recall',
2365 (predictions, labels, weights)):
2366 predictions, labels, weights = _remove_squeezable_dimensions(
2367 predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2368 labels=math_ops.cast(labels, dtype=dtypes.bool),
2369 weights=weights)
2371 true_p, true_positives_update_op = true_positives(
2372 labels,
2373 predictions,
2374 weights,
2375 metrics_collections=None,
2376 updates_collections=None,
2377 name=None)
2378 false_n, false_negatives_update_op = false_negatives(
2379 labels,
2380 predictions,
2381 weights,
2382 metrics_collections=None,
2383 updates_collections=None,
2384 name=None)
2386 def compute_recall(true_p, false_n, name):
2387 return array_ops.where(
2388 math_ops.greater(true_p + false_n, 0),
2389 math_ops.divide(true_p, true_p + false_n), 0, name)
2391 def once_across_replicas(_, true_p, false_n):
2392 return compute_recall(true_p, false_n, 'value')
2394 rec = _aggregate_across_replicas(
2395 metrics_collections, once_across_replicas, true_p, false_n)
2397 update_op = compute_recall(true_positives_update_op,
2398 false_negatives_update_op, 'update_op')
2399 if updates_collections:
2400 ops.add_to_collections(updates_collections, update_op)
2402 return rec, update_op
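# Usage sketch (same illustrative batch as the `precision` example above):
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([True, True, False, False])
predictions = tf.constant([True, False, True, True])
rec, update_op = tf.metrics.recall(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(rec))  # tp=1, fn=1 -> 1/2 = 0.5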
2405def _at_k_name(name, k=None, class_id=None):
2406 if k is not None:
2407 name = '%s_at_%d' % (name, k)
2408 else:
2409 name = '%s_at_k' % name
2410 if class_id is not None:
2411 name = '%s_class%d' % (name, class_id)
2412 return name
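# Illustrative expectations for the private helper above (read off the code,
# not part of the module):
assert _at_k_name('recall', k=5) == 'recall_at_5'
assert _at_k_name('recall', k=5, class_id=3) == 'recall_at_5_class3'
assert _at_k_name('precision') == 'precision_at_k'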
2415def _select_class_id(ids, selected_id):
2416 """Filter all but `selected_id` out of `ids`.
2418 Args:
2419 ids: `int64` `Tensor` or `SparseTensor` of IDs.
2420 selected_id: Int id to select.
2422 Returns:
2423 `SparseTensor` of same dimensions as `ids`. This contains only the entries
2424 equal to `selected_id`.
2425 """
2426 ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
2427 if isinstance(ids, sparse_tensor.SparseTensor):
2428 return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
2429 selected_id))
2431 # TODO(ptucker): Make this more efficient, maybe add a sparse version of
2432 # tf.equal and tf.reduce_any?
2434 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
2435 ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
2436 ids_last_dim = array_ops.size(ids_shape) - 1
2437 filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
2438 array_ops.reshape(
2439 ids_last_dim, [1]))
2441 # Intersect `ids` with the selected ID.
2442 filled_selected_id = array_ops.fill(filled_selected_id_shape,
2443 math_ops.cast(selected_id, dtypes.int64))
2444 result = sets.set_intersection(filled_selected_id, ids)
2445 return sparse_tensor.SparseTensor(
2446 indices=result.indices, values=result.values, dense_shape=ids_shape)
2449def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
2450 """If class ID is specified, filter all other classes.
2452 Args:
2453 labels: `int64` `Tensor` or `SparseTensor` with shape
2454 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2455 target classes for the associated prediction. Commonly, N=1 and `labels`
2456 has shape [batch_size, num_labels]. [D1, ... DN] must match
2457 `predictions_idx`.
2458 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
2459 where N >= 1. Commonly, N=1 and `predictions_idx` has shape
2460 [batch size, k].
2461 selected_id: Int id to select.
2463 Returns:
2464 Tuple of `labels` and `predictions_idx`, possibly with classes removed.
2465 """
2466 if selected_id is None:
2467 return labels, predictions_idx
2468 return (_select_class_id(labels, selected_id),
2469 _select_class_id(predictions_idx, selected_id))
2472def _sparse_true_positive_at_k(labels,
2473 predictions_idx,
2474 class_id=None,
2475 weights=None,
2476 name=None):
2477 """Calculates true positives for recall@k and precision@k.
2479 If `class_id` is specified, calculate binary true positives for `class_id`
2480 only.
2481 If `class_id` is not specified, calculate metrics for `k` predicted vs
2482 `n` label classes, where `n` is the 2nd dimension of `labels`.
2484 Args:
2485 labels: `int64` `Tensor` or `SparseTensor` with shape
2486 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2487 target classes for the associated prediction. Commonly, N=1 and `labels`
2488 has shape [batch_size, num_labels]. [D1, ... DN] must match
2489 `predictions_idx`.
2490 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2491 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2492 match `labels`.
2493 class_id: Class for which we want binary metrics.
2494 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2495 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2496 dimensions must be either `1`, or the same as the corresponding `labels`
2497 dimension).
2498 name: Name of operation.
2500 Returns:
2501 A [D1, ... DN] `Tensor` of true positive counts.
2502 """
2503 with ops.name_scope(name, 'true_positives',
2504 (predictions_idx, labels, weights)):
2505 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2506 class_id)
2507 tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
2508 tp = math_ops.cast(tp, dtypes.float64)
2509 if weights is not None:
2510 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2511 weights, tp),)):
2512 weights = math_ops.cast(weights, dtypes.float64)
2513 tp = math_ops.multiply(tp, weights)
2514 return tp
2517def _streaming_sparse_true_positive_at_k(labels,
2518 predictions_idx,
2519 k=None,
2520 class_id=None,
2521 weights=None,
2522 name=None):
2523 """Calculates weighted per step true positives for recall@k and precision@k.
2525 If `class_id` is specified, calculate binary true positives for `class_id`
2526 only.
2527 If `class_id` is not specified, calculate metrics for `k` predicted vs
2528 `n` label classes, where `n` is the 2nd dimension of `labels`.
2530 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2532 Args:
2533 labels: `int64` `Tensor` or `SparseTensor` with shape
2534 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2535 target classes for the associated prediction. Commonly, N=1 and `labels`
2536 has shape [batch_size, num_labels]. [D1, ... DN] must match
2537 `predictions_idx`.
2538 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2539 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2540 match `labels`.
2541 k: Integer, k for @k metric. This is only used for default op name.
2542 class_id: Class for which we want binary metrics.
2543 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2544 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2545 dimensions must be either `1`, or the same as the corresponding `labels`
2546 dimension).
2547 name: Name of new variable, and namespace for other dependent ops.
2549 Returns:
2550 A tuple of `Variable` and update `Operation`.
2552 Raises:
2553 ValueError: If `weights` is not `None` and has an incompatible shape.
2554 """
2555 with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
2556 (predictions_idx, labels, weights)) as scope:
2557 tp = _sparse_true_positive_at_k(
2558 predictions_idx=predictions_idx,
2559 labels=labels,
2560 class_id=class_id,
2561 weights=weights)
2562 batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64)
2564 var = metric_variable([], dtypes.float64, name=scope)
2565 return var, state_ops.assign_add(var, batch_total_tp, name='update')
2568def _sparse_false_negative_at_k(labels,
2569 predictions_idx,
2570 class_id=None,
2571 weights=None):
2572 """Calculates false negatives for recall@k.
2574 If `class_id` is specified, calculate binary false negatives for `class_id`
2575 only.
2576 If `class_id` is not specified, calculate metrics for `k` predicted vs
2577 `n` label classes, where `n` is the 2nd dimension of `labels`.
2579 Args:
2580 labels: `int64` `Tensor` or `SparseTensor` with shape
2581 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2582 target classes for the associated prediction. Commonly, N=1 and `labels`
2583 has shape [batch_size, num_labels]. [D1, ... DN] must match
2584 `predictions_idx`.
2585 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2586 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2587 match `labels`.
2588 class_id: Class for which we want binary metrics.
2589 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2590 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2591 dimensions must be either `1`, or the same as the corresponding `labels`
2592 dimension).
2594 Returns:
2595 A [D1, ... DN] `Tensor` of false negative counts.
2596 """
2597 with ops.name_scope(None, 'false_negatives',
2598 (predictions_idx, labels, weights)):
2599 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2600 class_id)
2601 fn = sets.set_size(
2602 sets.set_difference(predictions_idx, labels, aminusb=False))
2603 fn = math_ops.cast(fn, dtypes.float64)
2604 if weights is not None:
2605 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2606 weights, fn),)):
2607 weights = math_ops.cast(weights, dtypes.float64)
2608 fn = math_ops.multiply(fn, weights)
2609 return fn
2612def _streaming_sparse_false_negative_at_k(labels,
2613 predictions_idx,
2614 k,
2615 class_id=None,
2616 weights=None,
2617 name=None):
2618 """Calculates weighted per step false negatives for recall@k.
2620 If `class_id` is specified, calculate binary false negatives for `class_id`
2621 only.
2622 If `class_id` is not specified, calculate metrics for `k` predicted vs
2623 `n` label classes, where `n` is the 2nd dimension of `labels`.
2625 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2627 Args:
2628 labels: `int64` `Tensor` or `SparseTensor` with shape
2629 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2630 target classes for the associated prediction. Commonly, N=1 and `labels`
2631 has shape [batch_size, num_labels]. [D1, ... DN] must match
2632 `predictions_idx`.
2633 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2634 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2635 match `labels`.
2636 k: Integer, k for @k metric. This is only used for default op name.
2637 class_id: Class for which we want binary metrics.
2638 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2639 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2640 dimensions must be either `1`, or the same as the corresponding `labels`
2641 dimension).
2642 name: Name of new variable, and namespace for other dependent ops.
2644 Returns:
2645 A tuple of `Variable` and update `Operation`.
2647 Raises:
2648 ValueError: If `weights` is not `None` and has an incompatible shape.
2649 """
2650 with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
2651 (predictions_idx, labels, weights)) as scope:
2652 fn = _sparse_false_negative_at_k(
2653 predictions_idx=predictions_idx,
2654 labels=labels,
2655 class_id=class_id,
2656 weights=weights)
2657 batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64)
2659 var = metric_variable([], dtypes.float64, name=scope)
2660 return var, state_ops.assign_add(var, batch_total_fn, name='update')
2663@tf_export(v1=['metrics.recall_at_k'])
2664def recall_at_k(labels,
2665 predictions,
2666 k,
2667 class_id=None,
2668 weights=None,
2669 metrics_collections=None,
2670 updates_collections=None,
2671 name=None):
2672 """Computes recall@k of the predictions with respect to sparse labels.
2674 If `class_id` is specified, we calculate recall by considering only the
2675 entries in the batch for which `class_id` is in the label, and computing
2676 the fraction of them for which `class_id` is in the top-k `predictions`.
2677 If `class_id` is not specified, we'll calculate recall as how often on
2678 average a class among the labels of a batch entry is in the top-k
2679 `predictions`.
2681 `sparse_recall_at_k` creates two local variables,
2682 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
2683 the recall_at_k frequency. This frequency is ultimately returned as
2684 `recall_at_<k>`: an idempotent operation that simply divides
2685 `true_positive_at_<k>` by total (`true_positive_at_<k>` +
2686 `false_negative_at_<k>`).
2688 For estimation of the metric over a stream of data, the function creates an
2689 `update_op` operation that updates these variables and returns the
2690 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2691 indicating the top `k` `predictions`. Set operations applied to `top_k` and
2692 `labels` calculate the true positives and false negatives weighted by
2693 `weights`. Then `update_op` increments `true_positive_at_<k>` and
2694 `false_negative_at_<k>` using these values.
2696 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2698 Args:
2699 labels: `int64` `Tensor` or `SparseTensor` with shape
2700 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2701 num_labels=1. N >= 1 and num_labels is the number of target classes for
2702 the associated prediction. Commonly, N=1 and `labels` has shape
2703 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2704 should be in range [0, num_classes), where num_classes is the last
2705 dimension of `predictions`. Values outside this range always count
2706 towards `false_negative_at_<k>`.
2707 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2708 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
2709 The final dimension contains the logit values for each class. [D1, ... DN]
2710 must match `labels`.
2711 k: Integer, k for @k metric.
2712 class_id: Integer class ID for which we want binary metrics. This should be
2713 in range [0, num_classes), where num_classes is the last dimension of
2714 `predictions`. If class_id is outside this range, the method returns NAN.
2715 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2716 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2717 dimensions must be either `1`, or the same as the corresponding `labels`
2718 dimension).
2719 metrics_collections: An optional list of collections that values should
2720 be added to.
2721 updates_collections: An optional list of collections that updates should
2722 be added to.
2723 name: Name of new update operation, and namespace for other dependent ops.
2725 Returns:
2726 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2727 by the sum of `true_positives` and `false_negatives`.
2728 update_op: `Operation` that increments `true_positives` and
2729 `false_negatives` variables appropriately, and whose value matches
2730 `recall`.
2732 Raises:
2733 ValueError: If `weights` is not `None` and its shape doesn't match
2734 `predictions`, or if either `metrics_collections` or `updates_collections`
2735 are not a list or tuple.
2736 RuntimeError: If eager execution is enabled.
2737 """
2738 if context.executing_eagerly():
2739 raise RuntimeError('tf.metrics.recall_at_k is not '
2740 'supported when eager execution is enabled.')
2742 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2743 (predictions, labels, weights)) as scope:
2744 _, top_k_idx = nn.top_k(predictions, k)
2745 return recall_at_top_k(
2746 labels=labels,
2747 predictions_idx=top_k_idx,
2748 k=k,
2749 class_id=class_id,
2750 weights=weights,
2751 metrics_collections=metrics_collections,
2752 updates_collections=updates_collections,
2753 name=scope)
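# Usage sketch (illustrative values): `labels` holds one int64 target class
# per example, `predictions` holds per-class logits.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

labels = tf.constant([[3], [1]], dtype=tf.int64)
logits = tf.constant([[0.1, 0.9, 0.3, 0.8],
                      [0.9, 0.1, 0.8, 0.2]])
rec, update_op = tf.metrics.recall_at_k(labels, logits, k=2)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(rec))  # top-2 contains label 3 but not label 1 -> 0.5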
2756@tf_export(v1=['metrics.recall_at_top_k'])
2757def recall_at_top_k(labels,
2758 predictions_idx,
2759 k=None,
2760 class_id=None,
2761 weights=None,
2762 metrics_collections=None,
2763 updates_collections=None,
2764 name=None):
2765 """Computes recall@k of top-k predictions with respect to sparse labels.
2767 Differs from `recall_at_k` in that predictions must be in the form of top `k`
2768 class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
2769 for more details.
2771 Args:
2772 labels: `int64` `Tensor` or `SparseTensor` with shape
2773 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2774 num_labels=1. N >= 1 and num_labels is the number of target classes for
2775 the associated prediction. Commonly, N=1 and `labels` has shape
2776 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2777 should be in range [0, num_classes), where num_classes is the last
2778 dimension of `predictions`. Values outside this range always count
2779 towards `false_negative_at_<k>`.
2780 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
2781 Commonly, N=1 and `predictions_idx` has shape [batch_size, k]. The final
2782 dimension contains the top `k` predicted class indices. [D1, ... DN] must
2783 match `labels`.
2784 k: Integer, k for @k metric. Only used for the default op name.
2785 class_id: Integer class ID for which we want binary metrics. This should be
2786 in range [0, num_classes), where num_classes is the last dimension of
2787 `predictions`. If class_id is outside this range, the method returns NAN.
2788 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2789 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2790 dimensions must be either `1`, or the same as the corresponding `labels`
2791 dimension).
2792 metrics_collections: An optional list of collections that values should
2793 be added to.
2794 updates_collections: An optional list of collections that updates should
2795 be added to.
2796 name: Name of new update operation, and namespace for other dependent ops.
2798 Returns:
2799 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2800 by the sum of `true_positives` and `false_negatives`.
2801 update_op: `Operation` that increments `true_positives` and
2802 `false_negatives` variables appropriately, and whose value matches
2803 `recall`.
2805 Raises:
2806 ValueError: If `weights` is not `None` and its shape doesn't match
2807 `predictions`, or if either `metrics_collections` or `updates_collections`
2808 are not a list or tuple.
2809 """
2810 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2811 (predictions_idx, labels, weights)) as scope:
2812 labels = _maybe_expand_labels(labels, predictions_idx)
2813 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
2814 tp, tp_update = _streaming_sparse_true_positive_at_k(
2815 predictions_idx=top_k_idx,
2816 labels=labels,
2817 k=k,
2818 class_id=class_id,
2819 weights=weights)
2820 fn, fn_update = _streaming_sparse_false_negative_at_k(
2821 predictions_idx=top_k_idx,
2822 labels=labels,
2823 k=k,
2824 class_id=class_id,
2825 weights=weights)
2827 def compute_recall(_, tp, fn):
2828 return math_ops.divide(tp, math_ops.add(tp, fn), name=scope)
2830 metric = _aggregate_across_replicas(
2831 metrics_collections, compute_recall, tp, fn)
2833 update = math_ops.divide(
2834 tp_update, math_ops.add(tp_update, fn_update), name='update')
2835 if updates_collections:
2836 ops.add_to_collections(updates_collections, update)
2837 return metric, update
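
# A short sketch of `recall_at_top_k` with precomputed top-k indices (same
# TF1 graph-mode assumptions as the sketch above; numbers are illustrative):
#
#   labels = tf.constant([[1, 2]], dtype=tf.int64)            # relevant classes
#   predictions_idx = tf.constant([[0, 2]], dtype=tf.int64)   # top-2 indices
#   recall, update_op = tf.metrics.recall_at_top_k(
#       labels, predictions_idx, k=2)
#   # After running `update_op`: tp = |{0, 2} & {1, 2}| = 1 and fn = 1, so
#   # `recall` evaluates to 0.5.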
2840@tf_export(v1=['metrics.recall_at_thresholds'])
2841def recall_at_thresholds(labels,
2842 predictions,
2843 thresholds,
2844 weights=None,
2845 metrics_collections=None,
2846 updates_collections=None,
2847 name=None):
2848 """Computes various recall values for different `thresholds` on `predictions`.
2850 The `recall_at_thresholds` function creates two local variables,
2851 `true_positives` and `false_negatives`, for various values of thresholds.
2852 `recall[i]` is defined as the total weight
2853 of values in `predictions` above `thresholds[i]` whose corresponding entry in
2854 `labels` is `True`, divided by the total weight of `True` values in `labels`
2855 (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
2857 For estimation of the metric over a stream of data, the function creates an
2858 `update_op` operation that updates these variables and returns the `recall`.
2860 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2862 Args:
2863 labels: The ground truth values, a `Tensor` whose dimensions must match
2864 `predictions`. Will be cast to `bool`.
2865 predictions: A floating point `Tensor` of arbitrary shape and whose values
2866 are in the range `[0, 1]`.
2867 thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2868 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2869 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2870 be either `1`, or the same as the corresponding `labels` dimension).
2871 metrics_collections: An optional list of collections that `recall` should be
2872 added to.
2873 updates_collections: An optional list of collections that `update_op` should
2874 be added to.
2875 name: An optional variable_scope name.
2877 Returns:
2878 recall: A float `Tensor` of shape `[len(thresholds)]`.
2879 update_op: An operation that increments the `true_positives` and
2880 `false_negatives` variables that
2881 are used in the computation of `recall`.
2883 Raises:
2884 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2885 `weights` is not `None` and its shape doesn't match `predictions`, or if
2886 either `metrics_collections` or `updates_collections` are not a list or
2887 tuple.
2888 RuntimeError: If eager execution is enabled.
2889 """
2890 if context.executing_eagerly():
2891 raise RuntimeError('tf.metrics.recall_at_thresholds is not '
2892 'supported when eager execution is enabled.')
2894 with variable_scope.variable_scope(name, 'recall_at_thresholds',
2895 (predictions, labels, weights)):
2896 values, update_ops = _confusion_matrix_at_thresholds(
2897 labels, predictions, thresholds, weights, includes=('tp', 'fn'))
2899 # Avoid division by zero.
2900 epsilon = 1e-7
2902 def compute_recall(tp, fn, name):
2903 return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name)
2905 def recall_across_replicas(_, values):
2906 return compute_recall(values['tp'], values['fn'], 'value')
2908 rec = _aggregate_across_replicas(
2909 metrics_collections, recall_across_replicas, values)
2911 update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
2912 if updates_collections:
2913 ops.add_to_collections(updates_collections, update_op)
2915 return rec, update_op
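
# A minimal sketch of `recall_at_thresholds` (assuming TF1 graph mode as
# above; the threshold list and values are illustrative):
#
#   labels = tf.constant([True, False, True, False])
#   preds = tf.constant([0.9, 0.3, 0.6, 0.1])
#   recall, update_op = tf.metrics.recall_at_thresholds(
#       labels, preds, thresholds=[0.5])
#   # After `update_op` runs once: both positive labels score above 0.5
#   # (tp=2, fn=0), so `recall` evaluates to ~[1.0] (modulo the epsilon
#   # guard above).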
2918@tf_export(v1=['metrics.root_mean_squared_error'])
2919def root_mean_squared_error(labels,
2920 predictions,
2921 weights=None,
2922 metrics_collections=None,
2923 updates_collections=None,
2924 name=None):
2925 """Computes the root mean squared error between the labels and predictions.
2927 The `root_mean_squared_error` function creates two local variables,
2928 `total` and `count` that are used to compute the root mean squared error.
2929 This average is weighted by `weights`, and it is ultimately returned as
2930 `root_mean_squared_error`: an idempotent operation that takes the square root
2931 of the division of `total` by `count`.
2933 For estimation of the metric over a stream of data, the function creates an
2934 `update_op` operation that updates these variables and returns the
2935 `root_mean_squared_error`. Internally, a `squared_error` operation computes
2936 the element-wise square of the difference between `predictions` and `labels`.
2937 Then `update_op` increments `total` with the reduced sum of the product of
2938 `weights` and `squared_error`, and it increments `count` with the reduced sum
2939 of `weights`.
2941 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2943 Args:
2944 labels: A `Tensor` of the same shape as `predictions`.
2945 predictions: A `Tensor` of arbitrary shape.
2946 weights: Optional `Tensor` whose rank is either 0, or the same rank as
2947 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2948 be either `1`, or the same as the corresponding `labels` dimension).
2949 metrics_collections: An optional list of collections that
2950 `root_mean_squared_error` should be added to.
2951 updates_collections: An optional list of collections that `update_op` should
2952 be added to.
2953 name: An optional variable_scope name.
2955 Returns:
2956 root_mean_squared_error: A `Tensor` representing the current mean, the value
2957 of `total` divided by `count`.
2958 update_op: An operation that increments the `total` and `count` variables
2959 appropriately and whose value matches `root_mean_squared_error`.
2961 Raises:
2962 ValueError: If `predictions` and `labels` have mismatched shapes, or if
2963 `weights` is not `None` and its shape doesn't match `predictions`, or if
2964 either `metrics_collections` or `updates_collections` are not a list or
2965 tuple.
2966 RuntimeError: If eager execution is enabled.
2967 """
2968 if context.executing_eagerly():
2969 raise RuntimeError('tf.metrics.root_mean_squared_error is not '
2970 'supported when eager execution is enabled.')
2972 predictions, labels, weights = _remove_squeezable_dimensions(
2973 predictions=predictions, labels=labels, weights=weights)
2974 mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
2975 None, name or
2976 'root_mean_squared_error')
2978 once_across_replicas = lambda _, mse: math_ops.sqrt(mse)
2979 rmse = _aggregate_across_replicas(
2980 metrics_collections, once_across_replicas, mse)
2982 update_rmse_op = math_ops.sqrt(update_mse_op)
2983 if updates_collections:
2984 ops.add_to_collections(updates_collections, update_rmse_op)
2986 return rmse, update_rmse_op
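
# A worked sketch for `root_mean_squared_error` (TF1 graph-mode assumption,
# as above; values are illustrative):
#
#   labels = tf.constant([0.0, 0.0])
#   preds = tf.constant([1.0, 3.0])
#   rmse, update_op = tf.metrics.root_mean_squared_error(labels, preds)
#   # One `update_op` run gives mse = (1 + 9) / 2 = 5, so `rmse` evaluates
#   # to sqrt(5) ~= 2.236.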
2989@tf_export(v1=['metrics.sensitivity_at_specificity'])
2990def sensitivity_at_specificity(labels,
2991 predictions,
2992 specificity,
2993 weights=None,
2994 num_thresholds=200,
2995 metrics_collections=None,
2996 updates_collections=None,
2997 name=None):
2998 """Computes the specificity at a given sensitivity.
3000 The `sensitivity_at_specificity` function creates four local
3001 variables, `true_positives`, `true_negatives`, `false_positives` and
3002 `false_negatives` that are used to compute the sensitivity at the given
3003 specificity value. The threshold for the given specificity value is computed
3004 and used to evaluate the corresponding sensitivity.
3006 For estimation of the metric over a stream of data, the function creates an
3007 `update_op` operation that updates these variables and returns the
3008 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
3009 `false_positives` and `false_negatives` counts with the weight of each case
3010 found in the `predictions` and `labels`.
3012 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3014 For additional information about specificity and sensitivity, see the
3015 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
3017 Args:
3018 labels: The ground truth values, a `Tensor` whose dimensions must match
3019 `predictions`. Will be cast to `bool`.
3020 predictions: A floating point `Tensor` of arbitrary shape and whose values
3021 are in the range `[0, 1]`.
3022 specificity: A scalar value in range `[0, 1]`.
3023 weights: Optional `Tensor` whose rank is either 0, or the same rank as
3024 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
3025 be either `1`, or the same as the corresponding `labels` dimension).
3026 num_thresholds: The number of thresholds to use for matching the given
3027 specificity.
3028 metrics_collections: An optional list of collections that `sensitivity`
3029 should be added to.
3030 updates_collections: An optional list of collections that `update_op` should
3031 be added to.
3032 name: An optional variable_scope name.
3034 Returns:
3035 sensitivity: A scalar `Tensor` representing the sensitivity at the given
3036 `specificity` value.
3037 update_op: An operation that increments the `true_positives`,
3038 `true_negatives`, `false_positives` and `false_negatives` variables
3039 appropriately and whose value matches `sensitivity`.
3041 Raises:
3042 ValueError: If `predictions` and `labels` have mismatched shapes, if
3043 `weights` is not `None` and its shape doesn't match `predictions`, or if
3044 `specificity` is not between 0 and 1, or if either `metrics_collections`
3045 or `updates_collections` are not a list or tuple.
3046 RuntimeError: If eager execution is enabled.
3047 """
3048 if context.executing_eagerly():
3049 raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
3050 'supported when eager execution is enabled.')
3052 if specificity < 0 or specificity > 1:
3053 raise ValueError('`specificity` must be in the range [0, 1]. Currently, '
3054                     f'`specificity` is {specificity}.')
3056 with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
3057 (predictions, labels, weights)):
3058 kepsilon = 1e-7 # to account for floating point imprecisions
3059 thresholds = [
3060 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
3061 ]
3062 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
3064 values, update_ops = _confusion_matrix_at_thresholds(
3065 labels, predictions, thresholds, weights)
3067 def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
3068 specificities = math_ops.divide(tn, tn + fp + kepsilon)
3069 tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
3070 tf_index = math_ops.cast(tf_index, dtypes.int32)
3072 # Now, we have the implicit threshold, so compute the sensitivity:
3073 return math_ops.divide(tp[tf_index],
3074 tp[tf_index] + fn[tf_index] + kepsilon, name)
3076 def sensitivity_across_replicas(_, values):
3077 return compute_sensitivity_at_specificity(
3078 values['tp'], values['tn'], values['fp'], values['fn'], 'value')
3080 sensitivity = _aggregate_across_replicas(
3081 metrics_collections, sensitivity_across_replicas, values)
3083 update_op = compute_sensitivity_at_specificity(
3084 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
3085 'update_op')
3086 if updates_collections:
3087 ops.add_to_collections(updates_collections, update_op)
3089 return sensitivity, update_op
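
# A hedged usage sketch for `sensitivity_at_specificity` (TF1 graph mode;
# the data below is illustrative, and the exact result depends on the
# threshold grid built above):
#
#   labels = tf.constant([0, 0, 1, 1])
#   preds = tf.constant([0.1, 0.4, 0.35, 0.8])
#   sens, update_op = tf.metrics.sensitivity_at_specificity(
#       labels, preds, specificity=0.5, num_thresholds=200)
#   # `update_op` accumulates tp/tn/fp/fn at every threshold; `sens` then
#   # reports tp / (tp + fn) at the threshold whose specificity is closest
#   # to 0.5.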
3092def _expand_and_tile(tensor, multiple, dim=0, name=None):
3093 """Slice `tensor` shape in 2, then tile along the sliced dimension.
3095 A new dimension is inserted in shape of `tensor` before `dim`, then values are
3096 tiled `multiple` times along the new dimension.
3098 Args:
3099 tensor: Input `Tensor` or `SparseTensor`.
3100 multiple: Integer, number of times to tile.
3101 dim: Integer, dimension along which to tile.
3102 name: Name of operation.
3104 Returns:
3105 `Tensor` result of expanding and tiling `tensor`.
3107 Raises:
3108 ValueError: if `multiple` is less than 1, or `dim` is not in
3109 `[-rank(tensor), rank(tensor)]`.
3110 """
3111 if multiple < 1:
3112 raise ValueError(f'Invalid argument multiple={multiple} for '
3113 'expand_and_tile call. `multiple` must be an integer > 0')
3114 with ops.name_scope(name, 'expand_and_tile',
3115 (tensor, multiple, dim)) as scope:
3116 # Sparse.
3117 tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
3118 if isinstance(tensor, sparse_tensor.SparseTensor):
3119 if dim < 0:
3120 expand_dims = array_ops.reshape(
3121 array_ops.size(tensor.dense_shape) + dim, [1])
3122 else:
3123 expand_dims = [dim]
3124 expanded_shape = array_ops.concat(
3125 (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
3126 array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
3127 0,
3128 name='expanded_shape')
3129 expanded = sparse_ops.sparse_reshape(
3130 tensor, shape=expanded_shape, name='expand')
3131 if multiple == 1:
3132 return expanded
3133 return sparse_ops.sparse_concat(
3134 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
3136 # Dense.
3137 expanded = array_ops.expand_dims(
3138 tensor, dim if (dim >= 0) else (dim - 1), name='expand')
3139 if multiple == 1:
3140 return expanded
3141 ones = array_ops.ones_like(array_ops.shape(tensor))
3142 tile_multiples = array_ops.concat(
3143 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
3144 return array_ops.tile(expanded, tile_multiples, name=scope)
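
# A small dense-branch example for `_expand_and_tile` (illustrative; the
# shapes follow directly from the expand/tile logic above):
#
#   x = tf.constant([[1, 2, 3], [4, 5, 6]])     # shape [2, 3]
#   y = _expand_and_tile(x, multiple=2, dim=1)  # shape [2, 2, 3]
#   # expand_dims gives [2, 1, 3]; tiling with multiples [1, 2, 1] repeats
#   # each row, so y[0] == [[1, 2, 3], [1, 2, 3]].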
3147def _num_relevant(labels, k):
3148 """Computes number of relevant values for each row in labels.
3150 For labels with shape [D1, ... DN, num_labels], this is the minimum of
3151 `num_labels` and `k`.
3153 Args:
3154 labels: `int64` `Tensor` or `SparseTensor` with shape
3155 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3156 target classes for the associated prediction. Commonly, N=1 and `labels`
3157 has shape [batch_size, num_labels].
3158 k: Integer, k for @k metric.
3160 Returns:
3161 Integer `Tensor` of shape [D1, ... DN], where each value is the number of
3162 relevant values for that row.
3164 Raises:
3165 ValueError: if inputs have invalid dtypes or values.
3166 """
3167 if k < 1:
3168    raise ValueError(f'Invalid k={k}. `k` must be >= 1.')
3169 with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
3170 # For SparseTensor, calculate separate count for each row.
3171 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
3172 if isinstance(labels, sparse_tensor.SparseTensor):
3173 return math_ops.minimum(sets.set_size(labels), k, name=scope)
3175    # The number of relevant values for each (d1, ... dN) is the minimum of k
3176    # and the number of labels along the last dimension that are non-negative.
3177 num_labels = math_ops.reduce_sum(
3178 array_ops.where_v2(math_ops.greater_equal(labels, 0),
3179 array_ops.ones_like(labels),
3180 array_ops.zeros_like(labels)),
3181 axis=-1)
3182 return math_ops.minimum(num_labels, k, name=scope)
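
# A worked example for `_num_relevant` with a dense `labels` tensor
# (illustrative): negative entries act as padding and do not count as
# relevant.
#
#   labels = tf.constant([[3, 7, -1], [4, -1, -1]], dtype=tf.int64)
#   n = _num_relevant(labels, k=2)
#   # Non-negative label counts per row are [2, 1]; capped at k=2, this
#   # evaluates to [2, 1].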
3185def _sparse_average_precision_at_top_k(labels, predictions_idx):
3186 """Computes average precision@k of predictions with respect to sparse labels.
3188 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
3189 for each row is:
3191 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
3193 A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
3194 `labels`, and the result `Tensors`. In the common case, this is [batch_size].
3195 Each row of the results contains the average precision for that row.
3197 Args:
3198 labels: `int64` `Tensor` or `SparseTensor` with shape
3199 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3200 num_labels=1. N >= 1 and num_labels is the number of target classes for
3201 the associated prediction. Commonly, N=1 and `labels` has shape
3202 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3203 Values should be non-negative. Negative values are ignored.
3204 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3205 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3206 dimension must be set and contains the top `k` predicted class indices.
3207 [D1, ... DN] must match `labels`. Values should be in range
3208 [0, num_classes).
3210 Returns:
3211 `float64` `Tensor` of shape [D1, ... DN], where each value is the average
3212 precision for that row.
3214 Raises:
3215 ValueError: if the last dimension of predictions_idx is not set.
3216 """
3217 with ops.name_scope(None, 'average_precision',
3218 (predictions_idx, labels)) as scope:
3219 predictions_idx = math_ops.cast(
3220 predictions_idx, dtypes.int64, name='predictions_idx')
3221 if predictions_idx.get_shape().ndims == 0:
3222 raise ValueError('The rank of `predictions_idx` must be at least 1.')
3223 k = predictions_idx.get_shape().as_list()[-1]
3224 if k is None:
3225 raise ValueError('The last dimension of predictions_idx must be set. '
3226 'Currently, it is None.')
3227 labels = _maybe_expand_labels(labels, predictions_idx)
3229 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
3230 # prediction for each k, so we can calculate separate true positive values
3231 # for each k.
3232 predictions_idx_per_k = array_ops.expand_dims(
3233 predictions_idx, -1, name='predictions_idx_per_k')
3235 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
3236 labels_per_k = _expand_and_tile(
3237 labels, multiple=k, dim=-1, name='labels_per_k')
3239 # The following tensors are all of shape [D1, ... DN, k], containing values
3240 # per row, per k value.
3241 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
3242 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from
3243 # the formula above.
3244 # `tp_per_k` (int32) - True positive counts.
3245 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
3246 # the precision denominator.
3247 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
3248 # term from the formula above.
3249 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
3250 # precisions at all k for which relevance indicator is true.
3251 relevant_per_k = _sparse_true_positive_at_k(
3252 labels_per_k, predictions_idx_per_k, name='relevant_per_k')
3253 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
3254 retrieved_per_k = math_ops.cumsum(
3255 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
3256 precision_per_k = math_ops.divide(
3257 math_ops.cast(tp_per_k, dtypes.float64),
3258 math_ops.cast(retrieved_per_k, dtypes.float64),
3259 name='precision_per_k')
3260 relevant_precision_per_k = math_ops.multiply(
3261 precision_per_k,
3262 math_ops.cast(relevant_per_k, dtypes.float64),
3263 name='relevant_precision_per_k')
3265 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
3266 precision_sum = math_ops.reduce_sum(
3267 relevant_precision_per_k, axis=(-1,), name='precision_sum')
3269 # Divide by number of relevant items to get average precision. These are
3270 # the "num_relevant_items" and "AveP" terms from the formula above.
3271 num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
3272 return math_ops.divide(precision_sum, num_relevant_items, name=scope)
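
# A worked instance of the AveP formula above (illustrative): for one row
# with predictions_idx = [1, 2, 3] (k=3) and the single relevant label {2}:
#
#   rel_i = [0, 1, 0]
#   P_i   = [0/1, 1/2, 1/3]
#   AveP  = (0 + 1/2 + 0) / min(num_labels=1, k=3) = 0.5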
3275def _streaming_sparse_average_precision_at_top_k(labels,
3276 predictions_idx,
3277 weights=None,
3278 metrics_collections=None,
3279 updates_collections=None,
3280 name=None):
3281 """Computes average precision@k of predictions with respect to sparse labels.
3283 `sparse_average_precision_at_top_k` creates two local variables,
3284 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3285 are used to compute the metric. This value is ultimately returned as
3286 `average_precision_at_<k>`: an idempotent operation that simply divides
3287 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3289 For estimation of the metric over a stream of data, the function creates an
3290 `update_op` operation that updates these variables and returns the
3291 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
3292 the true positives and false positives weighted by `weights`. Then `update_op`
3293 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
3294 values.
3296 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3298 Args:
3299 labels: `int64` `Tensor` or `SparseTensor` with shape
3300 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3301 num_labels=1. N >= 1 and num_labels is the number of target classes for
3302 the associated prediction. Commonly, N=1 and `labels` has shape
3303 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3304 Values should be non-negative. Negative values are ignored.
3305 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3306 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3307 dimension contains the top `k` predicted class indices. [D1, ... DN] must
3308 match `labels`. Values should be in range [0, num_classes).
3309 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3310 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3311 dimensions must be either `1`, or the same as the corresponding `labels`
3312 dimension).
3313 metrics_collections: An optional list of collections that values should
3314 be added to.
3315 updates_collections: An optional list of collections that updates should
3316 be added to.
3317 name: Name of new update operation, and namespace for other dependent ops.
3319 Returns:
3320 mean_average_precision: Scalar `float64` `Tensor` with the mean average
3321 precision values.
3322 update: `Operation` that increments variables appropriately, and whose
3323 value matches `metric`.
3324 """
3325 with ops.name_scope(name, 'average_precision_at_top_k',
3326 (predictions_idx, labels, weights)) as scope:
3327 # Calculate per-example average precision, and apply weights.
3328 average_precision = _sparse_average_precision_at_top_k(
3329 predictions_idx=predictions_idx, labels=labels)
3330 if weights is not None:
3331 weights = weights_broadcast_ops.broadcast_weights(
3332 math_ops.cast(weights, dtypes.float64), average_precision)
3333 average_precision = math_ops.multiply(average_precision, weights)
3335 # Create accumulation variables and update ops for max average precision and
3336 # total average precision.
3337 with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
3338 # `max` is the max possible precision. Since max for any row is 1.0:
3339 # - For the unweighted case, this is just the number of rows.
3340 # - For the weighted case, it's the sum of the weights broadcast across
3341 # `average_precision` rows.
3342 max_var = metric_variable([], dtypes.float64, name=max_scope)
3343 if weights is None:
3344 batch_max = math_ops.cast(
3345 array_ops.size(average_precision, name='batch_max'), dtypes.float64)
3346 else:
3347 batch_max = math_ops.reduce_sum(weights, name='batch_max')
3348 max_update = state_ops.assign_add(max_var, batch_max, name='update')
3349 with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
3350 total_var = metric_variable([], dtypes.float64, name=total_scope)
3351 batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
3352 total_update = state_ops.assign_add(total_var, batch_total, name='update')
3354 # Divide total by max to get mean, for both vars and the update ops.
3355 def precision_across_replicas(_, total_var, max_var):
3356 return _safe_scalar_div(total_var, max_var, name='mean')
3358 mean_average_precision = _aggregate_across_replicas(
3359 metrics_collections, precision_across_replicas, total_var, max_var)
3361 update = _safe_scalar_div(total_update, max_update, name=scope)
3362 if updates_collections:
3363 ops.add_to_collections(updates_collections, update)
3365 return mean_average_precision, update
3368def _clean_out_of_range_indices(labels, num_classes):
3369 """Replaces large out-of-range labels by small out-of-range labels.
3371 Replaces with -1 any value in `labels` that is greater than or equal to
3372 `num_classes`. This is done conditionally, for efficiency, in case there are no such values.
3374 Args:
3375 labels: `int64` `Tensor` or `SparseTensor`.
3376 num_classes: `int64` scalar `Tensor`.
3377 Returns:
3378 An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater
3379 or equal to num_classes replaced by -1.
3380 """
3382 def _labels_is_sparse():
3383 """Returns true is `labels` is a sparse tensor."""
3384 return isinstance(labels, (sparse_tensor.SparseTensor,
3385 sparse_tensor.SparseTensorValue))
3387 def _clean_out_of_range(values):
3388 """Replaces by -1 any large out-of-range `values`."""
3389 return array_ops.where_v2(math_ops.greater_equal(values, num_classes),
3390 -1 * array_ops.ones_like(values), values)
3392 def _clean_labels_out_of_range():
3393 """Replaces by -1 ane large out-of-range values in `labels`."""
3394 if _labels_is_sparse():
3395 return type(labels)(indices=labels.indices,
3396 values=_clean_out_of_range(labels.values),
3397 dense_shape=labels.dense_shape)
3398 else:
3399 return _clean_out_of_range(labels)
3401 max_labels = math_ops.reduce_max(
3402 labels.values if _labels_is_sparse() else labels)
3403 return cond.cond(
3404 math_ops.greater_equal(max_labels, num_classes),
3405 _clean_labels_out_of_range,
3406 lambda: labels)
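
# A small example of `_clean_out_of_range_indices` (illustrative): with
# num_classes = 4,
#
#   labels = tf.constant([[1, 9], [2, 3]], dtype=tf.int64)
#   cleaned = _clean_out_of_range_indices(labels, tf.constant(4, tf.int64))
#   # cleaned == [[1, -1], [2, 3]]: the out-of-range 9 becomes -1, which the
#   # downstream @k metrics then ignore.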
3409@tf_export(v1=['metrics.sparse_average_precision_at_k'])
3410@deprecated(None, 'Use average_precision_at_k instead')
3411def sparse_average_precision_at_k(labels,
3412 predictions,
3413 k,
3414 weights=None,
3415 metrics_collections=None,
3416 updates_collections=None,
3417 name=None):
3418 """Renamed to `average_precision_at_k`, please use that method instead."""
3419 return average_precision_at_k(
3420 labels=labels,
3421 predictions=predictions,
3422 k=k,
3423 weights=weights,
3424 metrics_collections=metrics_collections,
3425 updates_collections=updates_collections,
3426 name=name)
3429@tf_export(v1=['metrics.average_precision_at_k'])
3430def average_precision_at_k(labels,
3431 predictions,
3432 k,
3433 weights=None,
3434 metrics_collections=None,
3435 updates_collections=None,
3436 name=None):
3437 """Computes average precision@k of predictions with respect to sparse labels.
3439 `average_precision_at_k` creates two local variables,
3440 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3441 are used to compute the metric. This value is ultimately returned as
3442 `average_precision_at_<k>`: an idempotent operation that simply divides
3443 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3445 For estimation of the metric over a stream of data, the function creates an
3446 `update_op` operation that updates these variables and returns the
3447 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
3448 indicating the top `k` `predictions`. Set operations applied to `top_k` and
3449 `labels` calculate the true positives and false positives weighted by
3450 `weights`. Then `update_op` increments `true_positive_at_<k>` and
3451 `false_positive_at_<k>` using these values.
3453 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3455 Args:
3456 labels: `int64` `Tensor` or `SparseTensor` with shape
3457 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3458 num_labels=1. N >= 1 and num_labels is the number of target classes for
3459 the associated prediction. Commonly, N=1 and `labels` has shape
3460 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3461 should be in range [0, num_classes), where num_classes is the last
3462 dimension of `predictions`. Values outside this range are ignored.
3463 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3464 N >= 1. Commonly, N=1 and `predictions` has shape
3465 [batch size, num_classes]. The final dimension contains the logit values
3466 for each class. [D1, ... DN] must match `labels`.
3467 k: Integer, k for @k metric. This will calculate an average precision for
3468 range `[1,k]`, as documented above.
3469 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3470 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3471 dimensions must be either `1`, or the same as the corresponding `labels`
3472 dimension).
3473 metrics_collections: An optional list of collections that values should
3474 be added to.
3475 updates_collections: An optional list of collections that updates should
3476 be added to.
3477 name: Name of new update operation, and namespace for other dependent ops.
3479 Returns:
3480 mean_average_precision: Scalar `float64` `Tensor` with the mean average
3481 precision values.
3482 update: `Operation` that increments variables appropriately, and whose
3483 value matches `metric`.
3485 Raises:
3486 ValueError: if k is invalid.
3487 RuntimeError: If eager execution is enabled.
3488 """
3489 if context.executing_eagerly():
3490    raise RuntimeError('tf.metrics.average_precision_at_k is not '
3491 'supported when eager execution is enabled.')
3493 if k < 1:
3494 raise ValueError(f'Invalid k={k}. `k` should be >= 1.')
3495 with ops.name_scope(name, _at_k_name('average_precision', k),
3496 (predictions, labels, weights)) as scope:
3497 # Calculate top k indices to produce [D1, ... DN, k] tensor.
3498 _, predictions_idx = nn.top_k(predictions, k)
3499 # The documentation states that labels should be in [0, ..., num_classes),
3500 # but num_classes is lost when predictions_idx replaces predictions.
3501 # For conformity with the documentation, any label >= num_classes, which is
3502 # ignored, is replaced by -1.
3503 labels = _clean_out_of_range_indices(
3504 labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64))
3505 return _streaming_sparse_average_precision_at_top_k(
3506 labels=labels,
3507 predictions_idx=predictions_idx,
3508 weights=weights,
3509 metrics_collections=metrics_collections,
3510 updates_collections=updates_collections,
3511 name=scope)
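
# A minimal end-to-end sketch of `average_precision_at_k` (TF1 graph-mode
# assumption, as above; numbers are illustrative):
#
#   labels = tf.constant([[2]], dtype=tf.int64)
#   logits = tf.constant([[0.3, 0.6, 0.1]])
#   ap, update_op = tf.metrics.average_precision_at_k(labels, logits, k=3)
#   # Top-3 indices are [1, 0, 2]; the single relevant class 2 appears at
#   # position 3, so AveP = (1/3) / 1 and `ap` evaluates to ~0.333.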
3514def _sparse_false_positive_at_k(labels,
3515 predictions_idx,
3516 class_id=None,
3517 weights=None):
3518 """Calculates false positives for precision@k.
3520 If `class_id` is specified, calculate binary false positives for `class_id`
3521 only.
3522 If `class_id` is not specified, calculate metrics for `k` predicted vs
3523 `n` label classes, where `n` is the 2nd dimension of `labels`.
3525 Args:
3526 labels: `int64` `Tensor` or `SparseTensor` with shape
3527 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3528 target classes for the associated prediction. Commonly, N=1 and `labels`
3529 has shape [batch_size, num_labels]. [D1, ... DN] must match
3530 `predictions_idx`.
3531 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3532 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3533 match `labels`.
3534 class_id: Class for which we want binary metrics.
3535 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3536 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3537 dimensions must be either `1`, or the same as the corresponding `labels`
3538 dimension).
3540 Returns:
3541 A [D1, ... DN] `Tensor` of false positive counts.
3542 """
3543 with ops.name_scope(None, 'false_positives',
3544 (predictions_idx, labels, weights)):
3545 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
3546 class_id)
3547 fp = sets.set_size(
3548 sets.set_difference(predictions_idx, labels, aminusb=True))
3549 fp = math_ops.cast(fp, dtypes.float64)
3550 if weights is not None:
3551 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
3552 weights, fp),)):
3553 weights = math_ops.cast(weights, dtypes.float64)
3554 fp = math_ops.multiply(fp, weights)
3555 return fp
3558def _streaming_sparse_false_positive_at_k(labels,
3559 predictions_idx,
3560 k=None,
3561 class_id=None,
3562 weights=None,
3563 name=None):
3564 """Calculates weighted per step false positives for precision@k.
3566 If `class_id` is specified, calculate binary false positives for `class_id`
3567 only.
3568 If `class_id` is not specified, calculate metrics for `k` predicted vs
3569 `n` label classes, where `n` is the 2nd dimension of `labels`.
3571 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3573 Args:
3574 labels: `int64` `Tensor` or `SparseTensor` with shape
3575 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3576 target classes for the associated prediction. Commonly, N=1 and `labels`
3577 has shape [batch_size, num_labels]. [D1, ... DN] must match
3578 `predictions_idx`.
3579 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3580 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3581 match `labels`.
3582 k: Integer, k for @k metric. This is only used for default op name.
3583 class_id: Class for which we want binary metrics.
3584 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3585 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3586 dimensions must be either `1`, or the same as the corresponding `labels`
3587 dimension).
3588 name: Name of new variable, and namespace for other dependent ops.
3590 Returns:
3591 A tuple of `Variable` and update `Operation`.
3593 Raises:
3594 ValueError: If `weights` is not `None` and has an incompatible shape.
3595 """
3596 with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
3597 (predictions_idx, labels, weights)) as scope:
3598 fp = _sparse_false_positive_at_k(
3599 predictions_idx=predictions_idx,
3600 labels=labels,
3601 class_id=class_id,
3602 weights=weights)
3603 batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64)
3605 var = metric_variable([], dtypes.float64, name=scope)
3606 return var, state_ops.assign_add(var, batch_total_fp, name='update')
3609@tf_export(v1=['metrics.precision_at_top_k'])
3610def precision_at_top_k(labels,
3611 predictions_idx,
3612 k=None,
3613 class_id=None,
3614 weights=None,
3615 metrics_collections=None,
3616 updates_collections=None,
3617 name=None):
3618 """Computes precision@k of the predictions with respect to sparse labels.
3620 Differs from `sparse_precision_at_k` in that predictions must be in the form
3621 of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
3622 Refer to `sparse_precision_at_k` for more details.
3624 Args:
3625 labels: `int64` `Tensor` or `SparseTensor` with shape
3626 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3627 num_labels=1. N >= 1 and num_labels is the number of target classes for
3628 the associated prediction. Commonly, N=1 and `labels` has shape
3629 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3630 should be in range [0, num_classes), where num_classes is the last
3631 dimension of `predictions`. Values outside this range are ignored.
3632 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
3633    N >= 1. Commonly, N=1 and `predictions_idx` has shape [batch_size, k].
3634 The final dimension contains the top `k` predicted class indices.
3635 [D1, ... DN] must match `labels`.
3636 k: Integer, k for @k metric. Only used for the default op name.
3637 class_id: Integer class ID for which we want binary metrics. This should be
3638    in range [0, num_classes), where num_classes is the last dimension of
3639 `predictions`. If `class_id` is outside this range, the method returns
3640 NAN.
3641 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3642 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3643 dimensions must be either `1`, or the same as the corresponding `labels`
3644 dimension).
3645 metrics_collections: An optional list of collections that values should
3646 be added to.
3647 updates_collections: An optional list of collections that updates should
3648 be added to.
3649 name: Name of new update operation, and namespace for other dependent ops.
3651 Returns:
3652 precision: Scalar `float64` `Tensor` with the value of `true_positives`
3653 divided by the sum of `true_positives` and `false_positives`.
3654 update_op: `Operation` that increments `true_positives` and
3655 `false_positives` variables appropriately, and whose value matches
3656 `precision`.
3658 Raises:
3659 ValueError: If `weights` is not `None` and its shape doesn't match
3660 `predictions`, or if either `metrics_collections` or `updates_collections`
3661 are not a list or tuple.
3662 RuntimeError: If eager execution is enabled.
3663 """
3664 if context.executing_eagerly():
3665 raise RuntimeError('tf.metrics.precision_at_top_k is not '
3666 'supported when eager execution is enabled.')
3668 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
3669 (predictions_idx, labels, weights)) as scope:
3670 labels = _maybe_expand_labels(labels, predictions_idx)
3671 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
3672 tp, tp_update = _streaming_sparse_true_positive_at_k(
3673 predictions_idx=top_k_idx,
3674 labels=labels,
3675 k=k,
3676 class_id=class_id,
3677 weights=weights)
3678 fp, fp_update = _streaming_sparse_false_positive_at_k(
3679 predictions_idx=top_k_idx,
3680 labels=labels,
3681 k=k,
3682 class_id=class_id,
3683 weights=weights)
3685 def precision_across_replicas(_, tp, fp):
3686 return math_ops.divide(tp, math_ops.add(tp, fp), name=scope)
3688 metric = _aggregate_across_replicas(
3689 metrics_collections, precision_across_replicas, tp, fp)
3691 update = math_ops.divide(
3692 tp_update, math_ops.add(tp_update, fp_update), name='update')
3693 if updates_collections:
3694 ops.add_to_collections(updates_collections, update)
3695 return metric, update
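
# A short sketch for `precision_at_top_k` (TF1 graph mode, as above;
# illustrative values):
#
#   labels = tf.constant([[2]], dtype=tf.int64)
#   predictions_idx = tf.constant([[0, 2]], dtype=tf.int64)
#   prec, update_op = tf.metrics.precision_at_top_k(
#       labels, predictions_idx, k=2)
#   # tp = |{0, 2} & {2}| = 1 and fp = 1, so `prec` evaluates to 0.5.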
3698@tf_export(v1=['metrics.sparse_precision_at_k'])
3699@deprecated(None, 'Use precision_at_k instead')
3700def sparse_precision_at_k(labels,
3701 predictions,
3702 k,
3703 class_id=None,
3704 weights=None,
3705 metrics_collections=None,
3706 updates_collections=None,
3707 name=None):
3708 """Renamed to `precision_at_k`, please use that method instead."""
3709 return precision_at_k(
3710 labels=labels,
3711 predictions=predictions,
3712 k=k,
3713 class_id=class_id,
3714 weights=weights,
3715 metrics_collections=metrics_collections,
3716 updates_collections=updates_collections,
3717 name=name)
3720@tf_export(v1=['metrics.precision_at_k'])
3721def precision_at_k(labels,
3722 predictions,
3723 k,
3724 class_id=None,
3725 weights=None,
3726 metrics_collections=None,
3727 updates_collections=None,
3728 name=None):
3729 """Computes precision@k of the predictions with respect to sparse labels.
3731 If `class_id` is specified, we calculate precision by considering only the
3732 entries in the batch for which `class_id` is in the top-k highest
3733 `predictions`, and computing the fraction of them for which `class_id` is
3734 indeed a correct label.
3735 If `class_id` is not specified, we'll calculate precision as how often on
3736 average a class among the top-k classes with the highest predicted values
3737 of a batch entry is correct and can be found in the label for that entry.
3739 `precision_at_k` creates two local variables,
3740 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
3741 the precision@k frequency. This frequency is ultimately returned as
3742 `precision_at_<k>`: an idempotent operation that simply divides
3743 `true_positive_at_<k>` by total (`true_positive_at_<k>` +
3744 `false_positive_at_<k>`).
3746 For estimation of the metric over a stream of data, the function creates an
3747 `update_op` operation that updates these variables and returns the
3748 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
3749 indicating the top `k` `predictions`. Set operations applied to `top_k` and
3750 `labels` calculate the true positives and false positives weighted by
3751 `weights`. Then `update_op` increments `true_positive_at_<k>` and
3752 `false_positive_at_<k>` using these values.
3754 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3756 Args:
3757 labels: `int64` `Tensor` or `SparseTensor` with shape
3758 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3759 num_labels=1. N >= 1 and num_labels is the number of target classes for
3760 the associated prediction. Commonly, N=1 and `labels` has shape
3761 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3762 should be in range [0, num_classes), where num_classes is the last
3763 dimension of `predictions`. Values outside this range are ignored.
3764 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3765 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
3766 The final dimension contains the logit values for each class. [D1, ... DN]
3767 must match `labels`.
3768 k: Integer, k for @k metric.
3769 class_id: Integer class ID for which we want binary metrics. This should be
3770    in range [0, num_classes), where num_classes is the last dimension of
3771 `predictions`. If `class_id` is outside this range, the method returns
3772 NAN.
3773 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3774 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3775 dimensions must be either `1`, or the same as the corresponding `labels`
3776 dimension).
3777 metrics_collections: An optional list of collections that values should
3778 be added to.
3779 updates_collections: An optional list of collections that updates should
3780 be added to.
3781 name: Name of new update operation, and namespace for other dependent ops.
3783 Returns:
3784 precision: Scalar `float64` `Tensor` with the value of `true_positives`
3785 divided by the sum of `true_positives` and `false_positives`.
3786 update_op: `Operation` that increments `true_positives` and
3787 `false_positives` variables appropriately, and whose value matches
3788 `precision`.
3790 Raises:
3791 ValueError: If `weights` is not `None` and its shape doesn't match
3792 `predictions`, or if either `metrics_collections` or `updates_collections`
3793 are not a list or tuple.
3794 RuntimeError: If eager execution is enabled.
3795 """
3796 if context.executing_eagerly():
3797    raise RuntimeError('tf.metrics.precision_at_k is not '
3798 'supported when eager execution is enabled.')
3800 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
3801 (predictions, labels, weights)) as scope:
3802 _, top_k_idx = nn.top_k(predictions, k)
3803 return precision_at_top_k(
3804 labels=labels,
3805 predictions_idx=top_k_idx,
3806 k=k,
3807 class_id=class_id,
3808 weights=weights,
3809 metrics_collections=metrics_collections,
3810 updates_collections=updates_collections,
3811 name=scope)
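
# And the logits-based variant, `precision_at_k` (same TF1 assumptions;
# illustrative values):
#
#   labels = tf.constant([[1]], dtype=tf.int64)
#   logits = tf.constant([[0.1, 0.6, 0.3]])
#   prec, update_op = tf.metrics.precision_at_k(labels, logits, k=2)
#   # Top-2 indices are [1, 2]; only class 1 is a true label, so `prec`
#   # evaluates to 0.5.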
3814@tf_export(v1=['metrics.specificity_at_sensitivity'])
3815def specificity_at_sensitivity(labels,
3816 predictions,
3817 sensitivity,
3818 weights=None,
3819 num_thresholds=200,
3820 metrics_collections=None,
3821 updates_collections=None,
3822 name=None):
3823 """Computes the specificity at a given sensitivity.
3825 The `specificity_at_sensitivity` function creates four local
3826 variables, `true_positives`, `true_negatives`, `false_positives` and
3827 `false_negatives` that are used to compute the specificity at the given
3828 sensitivity value. The threshold for the given sensitivity value is computed
3829 and used to evaluate the corresponding specificity.
3831 For estimation of the metric over a stream of data, the function creates an
3832 `update_op` operation that updates these variables and returns the
3833 `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
3834 `false_positives` and `false_negatives` counts with the weight of each case
3835 found in the `predictions` and `labels`.
3837 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3839 For additional information about specificity and sensitivity, see the
3840 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
3842 Args:
3843 labels: The ground truth values, a `Tensor` whose dimensions must match
3844 `predictions`. Will be cast to `bool`.
3845 predictions: A floating point `Tensor` of arbitrary shape and whose values
3846 are in the range `[0, 1]`.
3847 sensitivity: A scalar value in range `[0, 1]`.
3848 weights: Optional `Tensor` whose rank is either 0, or the same rank as
3849 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
3850 be either `1`, or the same as the corresponding `labels` dimension).
3851 num_thresholds: The number of thresholds to use for matching the given
3852 sensitivity.
3853 metrics_collections: An optional list of collections that `specificity`
3854 should be added to.
3855 updates_collections: An optional list of collections that `update_op` should
3856 be added to.
3857 name: An optional variable_scope name.
3859 Returns:
3860 specificity: A scalar `Tensor` representing the specificity at the given
3861 `sensitivity` value.
3862 update_op: An operation that increments the `true_positives`,
3863 `true_negatives`, `false_positives` and `false_negatives` variables
3864 appropriately and whose value matches `specificity`.
3866 Raises:
3867 ValueError: If `predictions` and `labels` have mismatched shapes, if
3868 `weights` is not `None` and its shape doesn't match `predictions`, or if
3869 `sensitivity` is not between 0 and 1, or if either `metrics_collections`
3870 or `updates_collections` are not a list or tuple.
3871 RuntimeError: If eager execution is enabled.
3872 """
3873 if context.executing_eagerly():
3874 raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
3875 'supported when eager execution is enabled.')
3877 if sensitivity < 0 or sensitivity > 1:
3878 raise ValueError('`sensitivity` must be in the range [0, 1]. Currently, '
3879 f'`sensitivity` is {sensitivity}.')
3881 with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
3882 (predictions, labels, weights)):
3883 kepsilon = 1e-7 # to account for floating point imprecisions
3884 thresholds = [
3885 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
3886 ]
3887 thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
3889 values, update_ops = _confusion_matrix_at_thresholds(
3890 labels, predictions, thresholds, weights)
3892 def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
3893 """Computes the specificity at the given sensitivity.
3895 Args:
3896 tp: True positives.
3897 tn: True negatives.
3898 fp: False positives.
3899 fn: False negatives.
3900 name: The name of the operation.
3902 Returns:
3903 The specificity using the aggregated values.
3904 """
3905 sensitivities = math_ops.divide(tp, tp + fn + kepsilon)
3907 # We'll need to use this trick until tf.argmax allows us to specify
3908 # whether we should use the first or last index in case of ties.
3909 min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
3910 indices_at_minval = math_ops.equal(
3911 math_ops.abs(sensitivities - sensitivity), min_val)
3912 indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64)
3913 indices_at_minval = math_ops.cumsum(indices_at_minval)
3914 tf_index = math_ops.argmax(indices_at_minval, 0)
3915 tf_index = math_ops.cast(tf_index, dtypes.int32)
3917 # Now, we have the implicit threshold, so compute the specificity:
3918 return math_ops.divide(tn[tf_index],
3919 tn[tf_index] + fp[tf_index] + kepsilon, name)
3921 def specificity_across_replicas(_, values):
3922 return compute_specificity_at_sensitivity(
3923 values['tp'], values['tn'], values['fp'], values['fn'], 'value')
3925 specificity = _aggregate_across_replicas(
3926 metrics_collections, specificity_across_replicas, values)
3928 update_op = compute_specificity_at_sensitivity(
3929 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
3930 'update_op')
3931 if updates_collections:
3932 ops.add_to_collections(updates_collections, update_op)
3934 return specificity, update_op
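
# A closing sketch for `specificity_at_sensitivity`, mirroring the
# sensitivity_at_specificity example above (TF1 graph mode; illustrative):
#
#   labels = tf.constant([0, 0, 1, 1])
#   preds = tf.constant([0.1, 0.4, 0.35, 0.8])
#   spec, update_op = tf.metrics.specificity_at_sensitivity(
#       labels, preds, sensitivity=0.5, num_thresholds=200)
#   # `spec` reports tn / (tn + fp) at the threshold whose sensitivity is
#   # closest to 0.5, chosen via the tie-breaking cumsum/argmax trick above.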