# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Utils related to keras metrics."""
18import functools
19import weakref
20from enum import Enum
22import numpy as np
23import tensorflow.compat.v2 as tf
25from keras.src import backend
26from keras.src.utils import losses_utils
27from keras.src.utils import tf_utils
28from keras.src.utils.generic_utils import to_list
30NEG_INF = -1e10


class Reduction(Enum):
    """Types of metrics reduction.

    Contains the following values:

    * `SUM`: Scalar sum of weighted values.
    * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by
        number of elements.
    * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of
        weights.
    """

    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
    WEIGHTED_MEAN = "weighted_mean"
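
# Worked example (illustrative, not from the original source): with values
# [1., 2., 3.] and weights [1., 1., 2.], the weighted values are
# [1., 2., 6.], and the reductions above give:
#
#     SUM                 -> 1. + 2. + 6. = 9.
#     SUM_OVER_BATCH_SIZE -> 9. / 3 = 3.
#     WEIGHTED_MEAN       -> 9. / (1. + 1. + 2.) = 2.25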


def update_state_wrapper(update_state_fn):
    """Decorator to wrap metric `update_state()` with `add_update()`.

    Args:
        update_state_fn: function that accumulates metric statistics.

    Returns:
        Decorated function that wraps `update_state_fn()` with `add_update()`.
    """

    def decorated(metric_obj, *args, **kwargs):
        """Decorated function with `add_update()`."""
        strategy = tf.distribute.get_strategy()

        for weight in metric_obj.weights:
            if (
                backend.is_tpu_strategy(strategy)
                and not strategy.extended.variable_created_in_scope(weight)
                and not tf.distribute.in_cross_replica_context()
            ):
                raise ValueError(
                    "Trying to run metric.update_state in replica context "
                    "when the metric was not created in TPUStrategy scope. "
                    "Make sure the Keras Metric is created in TPUStrategy "
                    "scope."
                )

        with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
            update_op = update_state_fn(*args, **kwargs)
        if update_op is not None:  # update_op will be None in eager execution.
            metric_obj.add_update(update_op)
        return update_op

    return tf.__internal__.decorator.make_decorator(
        update_state_fn, decorated
    )


def result_wrapper(result_fn):
    """Decorator to wrap metric `result()` function in `merge_call()`.

    Result computation is an idempotent operation that simply calculates the
    metric value using the state variables.

    If metric state variables are distributed across replicas/devices and
    `result()` is requested from the context of one device, this function
    wraps `result()` in a distribution strategy `merge_call()`. With this,
    the metric state variables will be aggregated across devices.

    Args:
        result_fn: function that computes the metric result.

    Returns:
        Decorated function that wraps `result_fn()` in distribution strategy
        `merge_call()`.
    """

    def decorated(metric_obj, *args):
        """Decorated function with merge_call."""
        replica_context = tf.distribute.get_replica_context()

        # The purpose of using `merge_call` to call `result()` is to trigger
        # cross replica aggregation of metric state variables
        # (SyncOnReadVariable). After we introduced
        # `variable_sync_on_read_context`, in principle there is no need to
        # use `merge_call` here. However the branch still exists because:
        #
        # 1. Keras V1 training code sometimes assumes `result_t` is the same
        #    tensor across replicas (achieved by `merge_call`). With
        #    `variable_sync_on_read_context` each replica gets its own
        #    tensors residing on the replica's device, thus breaking the
        #    assumption.
        # 2. Keras compile/fit creates a tf.function (a.k.a. train_function)
        #    that returns the metric values of the first replica. With
        #    `variable_sync_on_read_context`, since each replica gets its own
        #    tensors, the metric result tensors on the non-first replicas are
        #    not in the return value of train_function, making the TF graph
        #    optimizer prune the branch that computes and aggregates those
        #    metric results. As a result, if NCCL is used to do the
        #    aggregation, the program will hang because NCCL ops are only
        #    launched on the non-pruned first replica.
        #
        # We condition on strategy_supports_no_merge_call() since we know if
        # it is True, the program uses `jit_compile` to compile the replica
        # fn, meaning it is not V1 training (hence #1 is okay), and no
        # pruning will happen as compiled functions are not inlined (hence #2
        # is okay).
        if (
            replica_context is None
            or tf.__internal__.distribute.strategy_supports_no_merge_call()
        ):
            with tf.__internal__.distribute.variable_sync_on_read_context():
                raw_result = result_fn(*args)
                # Results need to be wrapped in a `tf.identity` op to ensure
                # correct execution order.
                if isinstance(
                    raw_result, (tf.Tensor, tf.Variable, float, int)
                ):
                    result_t = tf.identity(raw_result)
                elif isinstance(raw_result, dict):
                    result_t = tf.nest.map_structure(tf.identity, raw_result)
                else:
                    try:
                        result_t = tf.identity(raw_result)
                    except (ValueError, TypeError):
                        raise RuntimeError(
                            "The output of `metric.result()` can only be a "
                            "single Tensor/Variable, or a dict of "
                            "Tensors/Variables. "
                            f"For metric {metric_obj.name}, "
                            f"got result {raw_result}."
                        )
        else:
            # TODO(psv): Test distribution of metrics using different
            # distribution strategies.

            # Creating a wrapper for merge_fn. merge_call invokes the given
            # merge_fn with the distribution object as the first parameter.
            # We create a wrapper here so that the result function need not
            # have that parameter.
            def merge_fn_wrapper(distribution, merge_fn, *args):
                # We will get a `PerReplica` merge function. Taking the first
                # one, as all are identical copies of the function that we
                # had passed below.
                result = distribution.experimental_local_results(merge_fn)[0](
                    *args
                )

                # Wrapping result in identity so that the control dependency
                # between update_op from `update_state` and result works in
                # case result returns a tensor.
                return tf.nest.map_structure(tf.identity, result)

            # Wrapping result in merge_call. merge_call is used when we want
            # to leave replica mode and compute a value in cross replica
            # mode.
            result_t = replica_context.merge_call(
                merge_fn_wrapper, args=(result_fn,) + args
            )

        # We are saving the result op here to be used in train/test execution
        # functions. This basically gives the result op that was generated
        # with a control dep to the updates for these workflows.
        metric_obj._call_result = result_t
        return result_t

    return tf.__internal__.decorator.make_decorator(result_fn, decorated)


def weakmethod(method):
    """Creates a weak reference to the bound method."""

    # Note: bound methods expose `__self__`/`__func__` in Python 3; the
    # `im_self`/`im_func`/`im_class` attributes existed only in Python 2.
    cls = type(method.__self__)
    func = method.__func__
    instance_ref = weakref.ref(method.__self__)

    @functools.wraps(method)
    def inner(*args, **kwargs):
        return func.__get__(instance_ref(), cls)(*args, **kwargs)

    del method
    return inner


def assert_thresholds_range(thresholds):
    if thresholds is not None:
        invalid_thresholds = [
            t for t in thresholds if t is None or t < 0 or t > 1
        ]
        if invalid_thresholds:
            raise ValueError(
                "Threshold values must be in [0, 1]. "
                f"Received: {invalid_thresholds}"
            )


def parse_init_thresholds(thresholds, default_threshold=0.5):
    if thresholds is not None:
        assert_thresholds_range(to_list(thresholds))
    thresholds = to_list(
        default_threshold if thresholds is None else thresholds
    )
    return thresholds
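
# Illustrative usage sketch (not part of the original module):
#
#     >>> parse_init_thresholds(None)  # falls back to default_threshold
#     [0.5]
#     >>> parse_init_thresholds(0.3)
#     [0.3]
#     >>> parse_init_thresholds([0.1, 0.9])
#     [0.1, 0.9]
#     >>> parse_init_thresholds([1.5])  # outside [0, 1]
#     Traceback (most recent call last):
#         ...
#     ValueError: Threshold values must be in [0, 1]. Received: [1.5]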


class ConfusionMatrix(Enum):
    TRUE_POSITIVES = "tp"
    FALSE_POSITIVES = "fp"
    TRUE_NEGATIVES = "tn"
    FALSE_NEGATIVES = "fn"


class AUCCurve(Enum):
    """Type of AUC Curve (ROC or PR)."""

    ROC = "ROC"
    PR = "PR"

    @staticmethod
    def from_str(key):
        if key in ("pr", "PR"):
            return AUCCurve.PR
        elif key in ("roc", "ROC"):
            return AUCCurve.ROC
        else:
            raise ValueError(
                f'Invalid AUC curve value: "{key}". '
                'Expected values are ["PR", "ROC"]'
            )
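
# Illustrative usage sketch (not part of the original module):
#
#     >>> AUCCurve.from_str("pr")
#     <AUCCurve.PR: 'PR'>
#     >>> AUCCurve.from_str("ROC")
#     <AUCCurve.ROC: 'ROC'>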


class AUCSummationMethod(Enum):
    """Type of AUC summation method.

    See https://en.wikipedia.org/wiki/Riemann_sum.

    Contains the following values:
    * 'interpolation': Applies mid-point summation scheme for `ROC` curve.
        For `PR` curve, interpolates (true/false) positives but not the ratio
        that is precision (see Davis & Goadrich 2006 for details).
    * 'minoring': Applies left summation for increasing intervals and right
        summation for decreasing intervals.
    * 'majoring': Applies right summation for increasing intervals and left
        summation for decreasing intervals.
    """

    INTERPOLATION = "interpolation"
    MAJORING = "majoring"
    MINORING = "minoring"

    @staticmethod
    def from_str(key):
        if key in ("interpolation", "Interpolation"):
            return AUCSummationMethod.INTERPOLATION
        elif key in ("majoring", "Majoring"):
            return AUCSummationMethod.MAJORING
        elif key in ("minoring", "Minoring"):
            return AUCSummationMethod.MINORING
        else:
            raise ValueError(
                f'Invalid AUC summation method value: "{key}". '
                'Expected values are ["interpolation", "majoring", '
                '"minoring"]'
            )


def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with a memory-efficient alternative.

    Note that the thresholds need to be evenly distributed within the list,
    i.e., the difference between consecutive elements should be the same.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But, computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in tf.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use tf.math.unsorted_segment_sum() to update the buckets in one
    pass.

    Consider the following example:
      y_true = [0, 0, 1, 1]
      y_pred = [0.1, 0.5, 0.3, 0.9]
      thresholds = [0.0, 0.5, 1.0]
      num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
      bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                           = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                           = [0, 0, 0, 1]
      # The meaning of this bucket is that if any of the labels is true,
      # then 1 will be added to the corresponding bucket with the index.
      # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0.
      # If the label for 1.8 is true, then 1 will be added to bucket 1.
      #
      # Note the second item "1.0" is floored to 0, since the value needs to
      # be strictly larger than the bucket lower bound.
      # In the implementation, we use tf.math.ceil() - 1 to achieve this.
      tp_bucket_value = tf.math.unsorted_segment_sum(
          true_labels, bucket_indices, num_segments=num_thresholds
      )               = [1, 1, 0]
      # For [1, 1, 0] here, it means there is 1 true value contributed by
      # bucket 0, and 1 true value contributed by bucket 1. When we aggregate
      # them together, the result becomes [a + b + c, b + c, c], since larger
      # thresholds always contribute to the value for smaller thresholds.
      true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                    = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on the standard implementation instead exhibit a
    complexity of O(T * N).

    Args:
        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
            keys and corresponding variables to update as values.
        y_true: A floating point `Tensor` whose shape matches `y_pred`. Will
            be cast to `bool`.
        y_pred: A floating point `Tensor` of arbitrary shape and whose values
            are in the range `[0, 1]`.
        thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
            It needs to be evenly distributed (the difference between
            consecutive elements needs to be the same).
        multi_label: Optional boolean indicating whether multidimensional
            prediction/labels should be treated as multilabel responses, or
            flattened into a single label. When True, the values of
            `variables_to_update` must have a second dimension equal to the
            number of labels in y_true and y_pred, and those tensors must not
            be RaggedTensors.
        sample_weights: Optional `Tensor` whose rank is either 0, or the same
            rank as `y_true`, and must be broadcastable to `y_true` (i.e.,
            all dimensions must be either `1`, or the same as the
            corresponding `y_true` dimension).
        label_weights: Optional tensor of non-negative weights for multilabel
            data. The weights are applied when calculating TP, FP, FN, and TN
            without explicit multilabel handling (i.e. when the data is to be
            flattened).
        thresholds_with_epsilon: Optional boolean indicating whether the
            leading and trailing thresholds have any epsilon added for
            floating point imprecisions. It changes how we handle the leading
            and trailing buckets.

    Returns:
        Update op.
    """
    num_thresholds = thresholds.shape.as_list()[0]

    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weights, dtype=y_pred.dtype), y_pred
        )
        if not multi_label:
            sample_weights = tf.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        if not multi_label:
            label_weights = tf.reshape(label_weights, [-1])
    weights = tf.cast(tf.multiply(sample_weights, label_weights), y_true.dtype)

    # We shouldn't need this, but in case there are prediction values that
    # are out of the range of [0.0, 1.0].
    y_pred = tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    y_true = tf.cast(tf.cast(y_true, tf.bool), y_true.dtype)
    if not multi_label:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

    true_labels = tf.multiply(y_true, weights)
    false_labels = tf.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since the prediction value has to be strictly greater than the
    # threshold (e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs
    # to the first bucket), we have to use math.ceil(val) - 1 for the bucket.
    bucket_indices = tf.math.ceil(y_pred * (num_thresholds - 1)) - 1

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually be taken into
        # account, since any prediction in [0.0, 1.0] is larger than the
        # first threshold. We change the bucket value from -1 to 0.
        bucket_indices = tf.nn.relu(bucket_indices)

    bucket_indices = tf.cast(bucket_indices, tf.int32)

    if multi_label:
        # We need to run the bucket segment sum for each of the label
        # classes. In the multi_label case, the rank of the label is 2. We
        # first transpose it so that the label dim becomes the first, and we
        # can run through the labels in parallel.
        true_labels = tf.transpose(true_labels)
        false_labels = tf.transpose(false_labels)
        bucket_indices = tf.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            label, bucket_index = (
                label_and_bucket_index[0],
                label_and_bucket_index[1],
            )
            return tf.math.unsorted_segment_sum(
                data=label,
                segment_ids=bucket_index,
                num_segments=num_thresholds,
            )

        tp_bucket_v = tf.vectorized_map(
            gather_bucket, (true_labels, bucket_indices), warn=False
        )
        fp_bucket_v = tf.vectorized_map(
            gather_bucket, (false_labels, bucket_indices), warn=False
        )
        tp = tf.transpose(tf.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = tf.transpose(tf.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        tp_bucket_v = tf.math.unsorted_segment_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = tf.math.unsorted_segment_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = tf.cumsum(tp_bucket_v, reverse=True)
        fp = tf.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    if (
        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
    ):
        if multi_label:
            total_true_labels = tf.reduce_sum(true_labels, axis=1)
            total_false_labels = tf.reduce_sum(false_labels, axis=1)
        else:
            total_true_labels = tf.reduce_sum(true_labels)
            total_false_labels = tf.reduce_sum(false_labels)

    update_ops = []
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        update_ops.append(variable.assign_add(tp))
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        update_ops.append(variable.assign_add(fp))
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        update_ops.append(variable.assign_add(tn))
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        update_ops.append(variable.assign_add(fn))
    return tf.group(update_ops)
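
# A NumPy re-derivation of the docstring example above (an illustrative
# sketch only; the function itself uses the TF ops shown in its body):
#
#     y_true = np.array([0., 0., 1., 1.])
#     y_pred = np.array([0.1, 0.5, 0.3, 0.9])
#     num_thresholds = 3  # thresholds = [0.0, 0.5, 1.0]
#     buckets = np.ceil(y_pred * (num_thresholds - 1)) - 1  # [0., 0., 0., 1.]
#     tp_bucket = np.zeros(num_thresholds)
#     np.add.at(tp_bucket, buckets.astype(int), y_true)     # [1., 1., 0.]
#     tp = np.cumsum(tp_bucket[::-1])[::-1]                 # [2., 1., 0.]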


def is_evenly_distributed_thresholds(thresholds):
    """Check if the thresholds list is evenly distributed.

    We could leverage evenly distributed thresholds to use less memory when
    calculating metrics like AUC where each individual threshold needs to be
    evaluated.

    Args:
        thresholds: A python list or tuple, or a 1D numpy array whose values
            are in the range [0, 1].

    Returns:
        boolean, whether the values in the input are evenly distributed.
    """
    # Check the list values and see if they are evenly distributed.
    num_thresholds = len(thresholds)
    if num_thresholds < 3:
        return False
    even_thresholds = np.arange(num_thresholds, dtype=np.float32) / (
        num_thresholds - 1
    )
    return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
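
# Illustrative usage sketch (not part of the original module):
#
#     >>> is_evenly_distributed_thresholds(np.array([0.0, 0.5, 1.0]))
#     True
#     >>> is_evenly_distributed_thresholds(np.array([0.0, 0.1, 1.0]))
#     False
#     >>> is_evenly_distributed_thresholds(np.array([0.0, 1.0]))  # < 3 items
#     False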


def update_confusion_matrix_variables(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    top_k=None,
    class_id=None,
    sample_weight=None,
    multi_label=False,
    label_weights=None,
    thresholds_distributed_evenly=False,
):
    """Returns op to update the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positives: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positives: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, the function
    creates an `update_op` operation that updates the given variables.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    Args:
        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
            keys and corresponding variables to update as values.
        y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to
            `bool`.
        y_pred: A floating point `Tensor` of arbitrary shape and whose values
            are in the range `[0, 1]`.
        thresholds: A float value, float tensor, python list, or tuple of
            float thresholds in `[0, 1]`, or NEG_INF (used when top_k is
            set).
        top_k: Optional int, indicates that the positive labels should be
            limited to the top k predictions.
        class_id: Optional int, limits the prediction and labels to the class
            specified by this argument.
        sample_weight: Optional `Tensor` whose rank is either 0, or the same
            rank as `y_true`, and must be broadcastable to `y_true` (i.e.,
            all dimensions must be either `1`, or the same as the
            corresponding `y_true` dimension).
        multi_label: Optional boolean indicating whether multidimensional
            prediction/labels should be treated as multilabel responses, or
            flattened into a single label. When True, the values of
            `variables_to_update` must have a second dimension equal to the
            number of labels in y_true and y_pred, and those tensors must not
            be RaggedTensors.
        label_weights: (optional) tensor of non-negative weights for
            multilabel data. The weights are applied when calculating TP, FP,
            FN, and TN without explicit multilabel handling (i.e. when the
            data is to be flattened).
        thresholds_distributed_evenly: Boolean, whether the thresholds are
            evenly distributed within the list. An optimized method will be
            used if this is the case. See
            _update_confusion_matrix_variables_optimized() for more details.

    Returns:
        Update op.

    Raises:
        ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
            `sample_weight` is not `None` and its shape doesn't match
            `y_pred`, or if `variables_to_update` contains invalid keys.
    """
    if multi_label and label_weights is not None:
        raise ValueError(
            "`label_weights` for multilabel data should be handled "
            "outside of `update_confusion_matrix_variables` when "
            "`multi_label` is True."
        )
    if variables_to_update is None:
        return
    if not any(
        key for key in variables_to_update if key in list(ConfusionMatrix)
    ):
        raise ValueError(
            "Please provide at least one valid confusion matrix "
            "variable to update. Valid variable key options are: "
            f'"{list(ConfusionMatrix)}". '
            f'Received: "{variables_to_update.keys()}"'
        )

    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = tf.cast(y_true, dtype=variable_dtype)
    y_pred = tf.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds have any leading or trailing epsilon
        # added for floating point imprecision. The leading and trailing
        # thresholds will be handled a bit differently as the corner case. At
        # this point, thresholds should be a list/array with more than 2
        # items, with values in the range [0, 1]. See
        # is_evenly_distributed_thresholds() for more details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = tf.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = thresholds.shape.as_list()[0]

    if multi_label:
        one_thresh = tf.equal(
            tf.cast(1, dtype=tf.int32),
            tf.rank(thresholds),
            name="one_set_of_thresholds_cond",
        )
    else:
        [y_pred, y_true], _ = ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight
        )
        one_thresh = tf.cast(True, dtype=tf.bool)

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
        )

    if sample_weight is None:
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )
    else:
        sample_weight = tf.cast(sample_weight, dtype=variable_dtype)
        (
            y_pred,
            y_true,
            sample_weight,
        ) = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true, sample_weight=sample_weight
        )
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        # Preserve dimension to match with sample_weight.
        y_true = y_true[..., class_id, None]
        y_pred = y_pred[..., class_id, None]

    if thresholds_distributed_evenly:
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon,
        )

    pred_shape = tf.shape(y_pred)
    num_predictions = pred_shape[0]
    if y_pred.shape.ndims == 1:
        num_labels = 1
    else:
        num_labels = tf.math.reduce_prod(pred_shape[1:], axis=0)
    thresh_label_tile = tf.where(
        one_thresh, num_labels, tf.ones([], dtype=tf.int32)
    )

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = tf.expand_dims(y_pred, 0)
        labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype=tf.bool), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = tf.reshape(y_pred, [1, -1])
        labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = tf.tile(
        tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles)
    )

    # Tile the predictions for every threshold.
    preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds.
    label_is_pos = tf.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weight, dtype=variable_dtype), y_pred
        )
        weights_tiled = tf.tile(
            tf.reshape(sample_weight, thresh_tiles), data_tiles
        )
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        label_weights_tiled = tf.tile(
            tf.reshape(label_weights, thresh_tiles), data_tiles
        )
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = tf.multiply(weights_tiled, label_weights_tiled)

    update_ops = []

    def weighted_assign_add(label, pred, weights, var):
        label_and_pred = tf.cast(tf.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= tf.cast(weights, dtype=var.dtype)
        return var.assign_add(tf.reduce_sum(label_and_pred, 1))

    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = tf.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (
            label_is_pos,
            pred_is_neg,
        )

    if update_fp or update_tn:
        label_is_neg = tf.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (
            label_is_neg,
            pred_is_pos,
        )
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg,
                pred_is_neg,
            )

    for matrix_cond, (label, pred) in loop_vars.items():
        if matrix_cond in variables_to_update:
            update_ops.append(
                weighted_assign_add(
                    label, pred, weights_tiled, variables_to_update[matrix_cond]
                )
            )

    return tf.group(update_ops)
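
# Minimal usage sketch (illustrative only; assumes eager TF2 execution and a
# single threshold, so each variable has shape [num_thresholds] == [1]):
#
#     tp = tf.Variable(tf.zeros([1]))
#     fp = tf.Variable(tf.zeros([1]))
#     update_confusion_matrix_variables(
#         {
#             ConfusionMatrix.TRUE_POSITIVES: tp,
#             ConfusionMatrix.FALSE_POSITIVES: fp,
#         },
#         y_true=tf.constant([0.0, 1.0, 1.0, 0.0]),
#         y_pred=tf.constant([0.2, 0.8, 0.4, 0.9]),
#         thresholds=[0.5],
#     )
#     # tp -> [1.]  (only 0.8 is > 0.5 with a positive label)
#     # fp -> [1.]  (only 0.9 is > 0.5 with a negative label)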


def _filter_top_k(x, k):
    """Filters top-k values in the last dim of x; sets the rest to NEG_INF.

    Used for computing top-k prediction values in dense labels (which have
    the same shape as predictions) for recall and precision top-k metrics.

    Args:
        x: tensor with any dimensions.
        k: the number of values to keep.

    Returns:
        tensor with the same shape and dtype as x.
    """
    _, top_k_idx = tf.math.top_k(x, k, sorted=False)
    top_k_mask = tf.reduce_sum(
        tf.one_hot(top_k_idx, tf.shape(x)[-1], axis=-1), axis=-2
    )
    return x * top_k_mask + NEG_INF * (1 - top_k_mask)
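
# Illustrative sketch (not part of the original module): keeping the top 2 of
# 4 scores and pushing the rest to NEG_INF:
#
#     x = tf.constant([[0.1, 0.9, 0.4, 0.6]])
#     _filter_top_k(x, k=2)
#     # -> [[-1e10, 0.9, -1e10, 0.6]] (filtered entries become NEG_INF)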


def ragged_assert_compatible_and_get_flat_values(values, mask=None):
    """If ragged, checks compatibility and then returns the flat_values.

    Note: If two tensors are dense, it does not check their compatibility.
    Note: Although two ragged tensors with different ragged ranks could have
        identical overall rank and dimension sizes and hence be compatible,
        we do not support those cases.

    Args:
        values: A list of potentially ragged tensors of the same ragged_rank.
        mask: A potentially ragged tensor of the same ragged_rank as the
            elements in values.

    Returns:
        A tuple in which the first element is the list of tensors and the
        second is the mask tensor: ([values], mask). The mask and the
        elements in values are equal to the flat_values of the input
        arguments (if they were ragged).
    """
    if isinstance(values, list):
        is_all_ragged = all(isinstance(rt, tf.RaggedTensor) for rt in values)
        is_any_ragged = any(isinstance(rt, tf.RaggedTensor) for rt in values)
    else:
        is_all_ragged = isinstance(values, tf.RaggedTensor)
        is_any_ragged = is_all_ragged
    if is_all_ragged and ((mask is None) or isinstance(mask, tf.RaggedTensor)):
        to_be_stripped = False
        if not isinstance(values, list):
            values = [values]
            to_be_stripped = True

        # NOTE: we leave the flat_values compatibility to
        # tf.TensorShape `assert_is_compatible_with` check if both dynamic
        # dimensions are equal and then use the flat_values.
        nested_row_split_list = [rt.nested_row_splits for rt in values]
        assertion_list = _assert_splits_match(nested_row_split_list)

        # If both are ragged, sample_weights also should be ragged with the
        # same dims.
        if isinstance(mask, tf.RaggedTensor):
            assertion_list_for_mask = _assert_splits_match(
                [nested_row_split_list[0], mask.nested_row_splits]
            )
            with tf.control_dependencies(assertion_list_for_mask):
                mask = tf.expand_dims(mask.flat_values, -1)

        # values has at least 1 element.
        flat_values = []
        for value in values:
            with tf.control_dependencies(assertion_list):
                flat_values.append(tf.expand_dims(value.flat_values, -1))

        values = flat_values[0] if to_be_stripped else flat_values

    elif is_any_ragged:
        raise TypeError(
            "Some of the inputs are not tf.RaggedTensor. "
            f"Input received: {values}"
        )
    # values is empty, or values are not ragged and mask is ragged.
    elif isinstance(mask, tf.RaggedTensor):
        raise TypeError(
            "Ragged mask is not allowed with non-ragged inputs. "
            f"Input received: {values}, mask received: {mask}"
        )

    return values, mask
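
# Illustrative sketch (not part of the original module): a ragged input is
# reduced to its flat_values with a trailing dim added:
#
#     rt = tf.ragged.constant([[0.2, 0.8], [0.4]])
#     values, mask = ragged_assert_compatible_and_get_flat_values([rt])
#     # values[0] -> [[0.2], [0.8], [0.4]] (shape [3, 1]); mask -> None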


def _assert_splits_match(nested_splits_lists):
    """Checks that the given splits lists are identical.

    Performs static tests to ensure that the given splits lists are
    identical, and returns a list of control dependency op tensors that check
    that they are fully identical.

    Args:
        nested_splits_lists: A list of nested_splits_lists, where each
            splits_list is a list of `splits` tensors from a `RaggedTensor`,
            ordered from outermost ragged dimension to innermost ragged
            dimension.

    Returns:
        A list of control dependency op tensors.

    Raises:
        ValueError: If the splits are not identical.
    """
    error_msg = (
        "Inputs must have identical ragged splits. "
        f"Input received: {nested_splits_lists}"
    )
    for splits_list in nested_splits_lists:
        if len(splits_list) != len(nested_splits_lists[0]):
            raise ValueError(error_msg)
    return [
        tf.debugging.assert_equal(s1, s2, message=error_msg)
        for splits_list in nested_splits_lists[1:]
        for (s1, s2) in zip(nested_splits_lists[0], splits_list)
    ]


def binary_matches(y_true, y_pred, threshold=0.5):
    """Creates a float Tensor: 1 for a label-prediction match, 0 for a
    mismatch.

    Args:
        y_true: Ground truth values, of shape (batch_size, d0, ..., dN).
        y_pred: The predicted values, of shape (batch_size, d0, ..., dN).
        threshold: (Optional) Float representing the threshold for deciding
            whether prediction values are 1 or 0.

    Returns:
        Binary matches, of shape (batch_size, d0, ..., dN).
    """
    y_pred = tf.convert_to_tensor(y_pred)
    threshold = tf.cast(threshold, y_pred.dtype)
    y_pred = tf.cast(y_pred > threshold, y_pred.dtype)
    return tf.cast(tf.equal(y_true, y_pred), backend.floatx())
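
# Illustrative usage sketch (not part of the original module):
#
#     binary_matches(
#         tf.constant([1.0, 0.0, 1.0]), tf.constant([0.8, 0.2, 0.3])
#     )
#     # -> [1., 1., 0.] (0.3 is not > 0.5, so it thresholds to 0 vs label 1)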


def sparse_categorical_matches(y_true, y_pred):
    """Creates a float Tensor: 1.0 for a label-prediction match, 0.0 for a
    mismatch.

    You can provide logits of classes as `y_pred`, since the argmax of
    logits and probabilities is the same.

    Args:
        y_true: Integer ground truth values.
        y_pred: The prediction values.

    Returns:
        Match tensor: 1.0 for label-prediction match, 0.0 for mismatch.
    """
    reshape_matches = False
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    y_true_org_shape = tf.shape(y_true)
    y_pred_rank = y_pred.shape.ndims
    y_true_rank = y_true.shape.ndims

    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,).
    if (
        (y_true_rank is not None)
        and (y_pred_rank is not None)
        and (len(backend.int_shape(y_true)) == len(backend.int_shape(y_pred)))
    ):
        y_true = tf.squeeze(y_true, [-1])
        reshape_matches = True
    y_pred = tf.math.argmax(y_pred, axis=-1)

    # If the predicted output and actual output types don't match, force cast
    # them to match.
    if backend.dtype(y_pred) != backend.dtype(y_true):
        y_pred = tf.cast(y_pred, backend.dtype(y_true))
    matches = tf.cast(tf.equal(y_true, y_pred), backend.floatx())
    if reshape_matches:
        matches = tf.reshape(matches, shape=y_true_org_shape)
    return matches
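
# Illustrative usage sketch (not part of the original module):
#
#     y_true = tf.constant([2, 1])
#     y_pred = tf.constant([[0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
#     sparse_categorical_matches(y_true, y_pred)
#     # -> [1., 1.] (row argmaxes are [2, 1], matching y_true)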


def sparse_top_k_categorical_matches(y_true, y_pred, k=5):
    """Creates a float Tensor: 1.0 for a label-in-top-K-predictions match,
    0.0 for a mismatch.

    Args:
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
        k: (Optional) Number of top elements to look at for computing
            accuracy. Defaults to 5.

    Returns:
        Match tensor: 1.0 for label-prediction match, 0.0 for mismatch.
    """
    reshape_matches = False
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)
    y_true_rank = y_true.shape.ndims
    y_pred_rank = y_pred.shape.ndims
    y_true_org_shape = tf.shape(y_true)

    # Flatten y_pred to (batch_size, num_samples) and y_true to
    # (num_samples,).
    if (y_true_rank is not None) and (y_pred_rank is not None):
        if y_pred_rank > 2:
            y_pred = tf.reshape(y_pred, [-1, y_pred.shape[-1]])
        if y_true_rank > 1:
            reshape_matches = True
            y_true = tf.reshape(y_true, [-1])

    matches = tf.cast(
        tf.math.in_top_k(
            predictions=y_pred, targets=tf.cast(y_true, "int32"), k=k
        ),
        dtype=backend.floatx(),
    )

    # The returned matches are expected to have the same shape as the y_true
    # input.
    if reshape_matches:
        return tf.reshape(matches, shape=y_true_org_shape)

    return matches
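
# Illustrative usage sketch (not part of the original module):
#
#     y_true = tf.constant([1, 2])
#     y_pred = tf.constant([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3]])
#     sparse_top_k_categorical_matches(y_true, y_pred, k=2)
#     # -> [1., 1.] (each true class is within the top-2 predictions)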