# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utils related to keras metrics."""

from enum import Enum
import functools
import weakref

import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as parallel_control_flow_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import tf_decorator

NEG_INF = -1e10


class Reduction(Enum):
  """Types of metrics reduction.

  Contains the following values:

  * `SUM`: Scalar sum of weighted values.
  * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by
      number of elements.
  * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights.
  """
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'
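
# For concreteness (illustrative numbers, not from the original module):
# reducing values [1, 2] with weights [0.3, 0.7] gives
#   SUM                 -> 1 * 0.3 + 2 * 0.7 = 1.7
#   SUM_OVER_BATCH_SIZE -> 1.7 / 2 = 0.85
#   WEIGHTED_MEAN       -> 1.7 / (0.3 + 0.7) = 1.7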


def update_state_wrapper(update_state_fn):
  """Decorator to wrap metric `update_state()` with `add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that wraps `update_state_fn()` with `add_update()`.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""
    strategy = distribute_lib.get_strategy()

    for weight in metric_obj.weights:
      if (backend.is_tpu_strategy(strategy) and
          not strategy.extended.variable_created_in_scope(weight)
          and not distribute_lib.in_cross_replica_context()):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUStrategy scope.')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op)
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)
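
# How this wrapper is wired up: in Keras, `Metric.__new__` replaces the
# subclass's `update_state` with the decorated version, roughly as follows
# (a sketch paraphrasing keras.metrics, not code from this module):
#
#   obj.update_state = types.MethodType(
#       update_state_wrapper(obj.update_state), obj)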


def result_wrapper(result_fn):
  """Decorator to wrap metric `result()` function in `merge_call()`.

  Result computation is an idempotent operation that simply calculates the
  metric value using the state variables.

  If metric state variables are distributed across replicas/devices and
  `result()` is requested from the context of one device, this function wraps
  `result()` in a distribution strategy `merge_call()`. With this,
  the metric state variables will be aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    has_strategy = distribute_lib.has_strategy()
    replica_context = distribute_lib.get_replica_context()

    # The purpose of using `merge_call` to call `result()` is to trigger cross
    # replica aggregation of metric state variables (SyncOnReadVariable). After
    # we introduced `variable_sync_on_read_context`, in principle there is no
    # need to use `merge_call` here. However the branch still exists because:
    #
    # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor
    #    across replicas (achieved by `merge_call`). With
    #    `variable_sync_on_read_context` each replica gets its own tensors
    #    residing on the replica's device, thus breaking the assumption.
    # 2. Keras `compile/fit` creates a tf.function (a.k.a. train_function) that
    #    returns the metric values of the first replica. With
    #    `variable_sync_on_read_context` since each replica gets its own
    #    tensors, the metric result tensors on the non-first replicas are not
    #    in the return value of train_function, making the TF graph optimizer
    #    prune the branch that computes and aggregates those metric results.
    #    As a result, if NCCL is used to do the aggregation, the program will
    #    hang because NCCL ops are only launched on the non-pruned first
    #    replica.
    #
    # We condition on strategy.extended._use_merge_call() since we know if it
    # is false, the program uses `jit_compile` to compile the replica fn,
    # meaning it is not V1 training (hence #1 is okay), and no pruning will
    # happen as compiled functions are not inlined (hence #2 is okay).

    if (not has_strategy or replica_context is None or
        not distribute_lib.get_strategy().extended._use_merge_call()):
      with distribute_lib.variable_sync_on_read_context():
        raw_result = result_fn(*args)
        # Results need to be wrapped in a `tf.identity` op to ensure
        # correct execution order.
        if isinstance(raw_result,
                      (ops.Tensor, variables_module.Variable, float, int)):
          result_t = array_ops.identity(raw_result)
        elif isinstance(raw_result, dict):
          result_t = {
              key: array_ops.identity(value)
              for key, value in raw_result.items()
          }
        else:
          try:
            result_t = array_ops.identity(raw_result)
          except (ValueError, TypeError):
            raise RuntimeError(
                'The output of `metric.result()` can only be a single '
                'Tensor/Variable, or a dict of Tensors/Variables. '
                'For metric %s, got result %s.' % (metric_obj.name,
                                                   raw_result))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get a `PerReplica` merge function. Taking the first one as
        # all are identical copies of the function that we had passed below.
        result = distribution.experimental_local_results(merge_fn)[0](*args)

        # Wrapping result in identity so that the control dependency between
        # the update_op from `update_state` and result works in case result
        # returns a tensor.
        return array_ops.identity(result)

      # Wrapping result in merge_call. merge_call is used when we want to
      # leave replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)

    # We are saving the result op here to be used in train/test execution
    # functions. This basically gives the result op that was generated with a
    # control dep to the updates for these workflows.
    metric_obj._call_result = result_t
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)


def weakmethod(method):
  """Creates a weak reference to the bound method."""

  cls = method.__self__.__class__
  func = method.__func__
  instance_ref = weakref.ref(method.__self__)

  @functools.wraps(method)
  def inner(*args, **kwargs):
    return func.__get__(instance_ref(), cls)(*args, **kwargs)

  del method
  return inner
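
# A minimal usage sketch; `Counter` is a hypothetical class added for
# illustration and is not part of the original module:
#
#   class Counter:
#     def __init__(self):
#       self.n = 0
#
#     def bump(self):
#       self.n += 1
#
#   c = Counter()
#   bump = weakmethod(c.bump)  # does not keep `c` alive
#   bump()                     # increments c.n while `c` is still referenced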


def assert_thresholds_range(thresholds):
  if thresholds is not None:
    invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1]
    if invalid_thresholds:
      raise ValueError(
          'Threshold values must be in [0, 1]. Invalid values: {}'.format(
              invalid_thresholds))


def parse_init_thresholds(thresholds, default_threshold=0.5):
  if thresholds is not None:
    assert_thresholds_range(to_list(thresholds))
  thresholds = to_list(default_threshold if thresholds is None else thresholds)
  return thresholds
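
# For illustration (assumed inputs, not from the original module):
#   parse_init_thresholds(None)        # -> [0.5], the default threshold
#   parse_init_thresholds(0.3)         # -> [0.3]
#   parse_init_thresholds([0.1, 0.9])  # -> [0.1, 0.9]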


class ConfusionMatrix(Enum):
  TRUE_POSITIVES = 'tp'
  FALSE_POSITIVES = 'fp'
  TRUE_NEGATIVES = 'tn'
  FALSE_NEGATIVES = 'fn'


class AUCCurve(Enum):
  """Type of AUC Curve (ROC or PR)."""
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    if key in ('pr', 'PR'):
      return AUCCurve.PR
    elif key in ('roc', 'ROC'):
      return AUCCurve.ROC
    else:
      raise ValueError('Invalid AUC curve value "%s".' % key)


class AUCSummationMethod(Enum):
  """Type of AUC summation method.

  See https://en.wikipedia.org/wiki/Riemann_sum

  Contains the following values:

  * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
      `PR` curve, interpolates (true/false) positives but not the ratio that
      is precision (see Davis & Goadrich 2006 for details).
  * 'minoring': Applies left summation for increasing intervals and right
      summation for decreasing intervals.
  * 'majoring': Applies right summation for increasing intervals and left
      summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    if key in ('interpolation', 'Interpolation'):
      return AUCSummationMethod.INTERPOLATION
    elif key in ('majoring', 'Majoring'):
      return AUCSummationMethod.MAJORING
    elif key in ('minoring', 'Minoring'):
      return AUCSummationMethod.MINORING
    else:
      raise ValueError('Invalid AUC summation method value "%s".' % key)


def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False):
  """Update confusion matrix variables with memory efficient alternative.

  Note that the thresholds need to be evenly distributed within the list,
  i.e., the diff between consecutive elements is the same.

  To compute TP/FP/TN/FN, we are measuring a binary classifier
    C(t) = (predictions >= t)
  at each threshold 't'. So we have
    TP(t) = sum( C(t) * true_labels )
    FP(t) = sum( C(t) * false_labels )

  But, computing C(t) requires computation for each t. To make it fast,
  observe that C(t) is a cumulative integral, and so if we have
    thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
  where n = num_thresholds, and if we can compute the bucket function
    B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
  then we get
    C(t_i) = sum( B(j), j >= i )
  which is the reversed cumulative sum in tf.cumsum().

  We can compute B(i) efficiently by taking advantage of the fact that
  our thresholds are evenly distributed, in that
    width = 1.0 / (num_thresholds - 1)
    thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
  Given a prediction value p, we can map it to its bucket by
    bucket_index(p) = floor( p * (num_thresholds - 1) )
  so we can use tf.math.unsorted_segment_sum() to update the buckets in one
  pass.

  Consider the following example:
    y_true = [0, 0, 1, 1]
    y_pred = [0.1, 0.5, 0.3, 0.9]
    thresholds = [0.0, 0.5, 1.0]
    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
    bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                         = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                         = [0, 0, 0, 1]
    # The meaning of this bucket is that if any of the labels is true,
    # then 1 will be added to the corresponding bucket with the index.
    # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0. If
    # the label for 1.8 is true, then 1 will be added to bucket 1.
    #
    # Note the second item "1.0" is floored to 0, since the value needs to be
    # strictly larger than the bucket lower bound.
    # In the implementation, we use tf.math.ceil() - 1 to achieve this.
    tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
                                                   num_segments=num_thresholds)
                    = [1, 1, 0]
    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket
    # 0, and 1 value contributed by bucket 1. When we aggregate them together,
    # the result becomes [a + b + c, b + c, c], since larger thresholds always
    # contribute to the value for smaller thresholds.
    true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                  = [2, 1, 0]

  This implementation exhibits a run time and space complexity of O(T + N),
  where T is the number of thresholds and N is the size of predictions.
  Metrics that rely on the standard implementation instead exhibit a
  complexity of O(T * N).

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
      cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are
      in the range `[0, 1]`.
    thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
      It needs to be evenly distributed (the diff between consecutive
      elements needs to be the same).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or
      flattened into a single label. When True, the values of
      `variables_to_update` must have a second dimension equal to the number
      of labels in y_true and y_pred, and those tensors must not be
      RaggedTensors.
    sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
      as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `y_true`
      dimension).
    label_weights: Optional tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN
      without explicit multilabel handling (i.e. when the data is to be
      flattened).
    thresholds_with_epsilon: Optional boolean indicating whether the leading
      and trailing thresholds have any epsilon added for floating point
      imprecision. It changes how we handle the leading and trailing buckets.

  Returns:
    Update op.
  """
  num_thresholds = thresholds.shape.as_list()[0]

  if sample_weights is None:
    sample_weights = 1.0
  else:
    sample_weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weights, dtype=y_pred.dtype), y_pred)
    if not multi_label:
      sample_weights = array_ops.reshape(sample_weights, [-1])
  if label_weights is None:
    label_weights = 1.0
  else:
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    if not multi_label:
      label_weights = array_ops.reshape(label_weights, [-1])
  weights = math_ops.multiply(sample_weights, label_weights)

  # We shouldn't need this, but in case there are prediction values outside
  # the range [0.0, 1.0].
  y_pred = clip_ops.clip_by_value(y_pred,
                                  clip_value_min=0.0, clip_value_max=1.0)

  y_true = math_ops.cast(math_ops.cast(y_true, dtypes.bool), y_true.dtype)
  if not multi_label:
    y_true = array_ops.reshape(y_true, [-1])
    y_pred = array_ops.reshape(y_pred, [-1])

  true_labels = math_ops.multiply(y_true, weights)
  false_labels = math_ops.multiply((1.0 - y_true), weights)

  # Compute the bucket indices for each prediction value.
  # Since a prediction value has to be strictly greater than the threshold
  # (e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the first
  # bucket), we have to use math.ceil(val) - 1 for the bucket.
  bucket_indices = math_ops.ceil(y_pred * (num_thresholds - 1)) - 1

  if thresholds_with_epsilon:
    # In this case, the first bucket should actually be taken into account
    # since any prediction in [0.0, 1.0] is larger than the first threshold.
    # We change the bucket value from -1 to 0.
    bucket_indices = nn_ops.relu(bucket_indices)

  bucket_indices = math_ops.cast(bucket_indices, dtypes.int32)

  if multi_label:
    # We need to run the bucket segment sum for each label class. In the
    # multi_label case, the rank of the label is 2. We first transpose it so
    # that the label dim becomes the first one and we can run through the
    # labels in parallel.
    true_labels = array_ops.transpose_v2(true_labels)
    false_labels = array_ops.transpose_v2(false_labels)
    bucket_indices = array_ops.transpose_v2(bucket_indices)

    def gather_bucket(label_and_bucket_index):
      label, bucket_index = label_and_bucket_index[0], label_and_bucket_index[1]
      return math_ops.unsorted_segment_sum(
          data=label, segment_ids=bucket_index, num_segments=num_thresholds)
    tp_bucket_v = parallel_control_flow_ops.vectorized_map(
        gather_bucket, (true_labels, bucket_indices))
    fp_bucket_v = parallel_control_flow_ops.vectorized_map(
        gather_bucket, (false_labels, bucket_indices))
    tp = array_ops.transpose_v2(
        math_ops.cumsum(tp_bucket_v, reverse=True, axis=1))
    fp = array_ops.transpose_v2(
        math_ops.cumsum(fp_bucket_v, reverse=True, axis=1))
  else:
    tp_bucket_v = math_ops.unsorted_segment_sum(
        data=true_labels, segment_ids=bucket_indices,
        num_segments=num_thresholds)
    fp_bucket_v = math_ops.unsorted_segment_sum(
        data=false_labels, segment_ids=bucket_indices,
        num_segments=num_thresholds)
    tp = math_ops.cumsum(tp_bucket_v, reverse=True)
    fp = math_ops.cumsum(fp_bucket_v, reverse=True)

  # fn = sum(true_labels) - tp
  # tn = sum(false_labels) - fp
  if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update or
      ConfusionMatrix.FALSE_NEGATIVES in variables_to_update):
    if multi_label:
      total_true_labels = math_ops.reduce_sum(true_labels, axis=1)
      total_false_labels = math_ops.reduce_sum(false_labels, axis=1)
    else:
      total_true_labels = math_ops.reduce_sum(true_labels)
      total_false_labels = math_ops.reduce_sum(false_labels)

  update_ops = []
  if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
    update_ops.append(variable.assign_add(tp))
  if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
    update_ops.append(variable.assign_add(fp))
  if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
    tn = total_false_labels - fp
    update_ops.append(variable.assign_add(tn))
  if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
    fn = total_true_labels - tp
    update_ops.append(variable.assign_add(fn))
  return control_flow_ops.group(update_ops)
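

# A minimal, self-contained NumPy sketch of the bucket-and-reversed-cumsum
# trick documented above, reproducing the docstring example. The helper name
# `_demo_bucket_cumsum_trick` is an illustrative addition, not part of the
# original module.
def _demo_bucket_cumsum_trick():
  """Recomputes the docstring example of the optimized TP computation."""
  y_true = np.array([0., 0., 1., 1.])
  y_pred = np.array([0.1, 0.5, 0.3, 0.9])
  num_thresholds = 3  # thresholds = [0.0, 0.5, 1.0]
  # ceil(p * (n - 1)) - 1 so a prediction equal to a threshold falls into the
  # lower bucket (predictions must be strictly greater than the threshold).
  # Note: a prediction of exactly 0.0 would map to bucket -1; the real
  # implementation handles that via relu when thresholds carry an epsilon.
  bucket_indices = (np.ceil(y_pred * (num_thresholds - 1)) - 1).astype(int)
  # Per-bucket weighted counts of true labels, the B(i) from the docstring.
  tp_bucket = np.bincount(bucket_indices, weights=y_true,
                          minlength=num_thresholds)  # [1., 1., 0.]
  # Reversed cumulative sum gives C(t_i) = sum(B(j), j >= i).
  tp = np.cumsum(tp_bucket[::-1])[::-1]              # [2., 1., 0.]
  return tp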


def is_evenly_distributed_thresholds(thresholds):
  """Check if the thresholds list is evenly distributed.

  We could leverage evenly distributed thresholds to use less memory when
  calculating metrics like AUC where each individual threshold needs to be
  evaluated.

  Args:
    thresholds: A python list or tuple, or 1D numpy array whose values range
      in [0, 1].

  Returns:
    boolean, whether the values in the inputs are evenly distributed.
  """
  # Check the list values and see if they are evenly distributed.
  num_thresholds = len(thresholds)
  if num_thresholds < 3:
    return False
  even_thresholds = np.arange(num_thresholds,
                              dtype=np.float32) / (num_thresholds - 1)
  return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
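
# For illustration (assumed inputs, not from the original module):
#   is_evenly_distributed_thresholds([0.0, 0.25, 0.5, 0.75, 1.0])  # True
#   is_evenly_distributed_thresholds([0.0, 0.1, 0.5, 1.0])         # False
#   is_evenly_distributed_thresholds([0.0, 1.0])                   # False, < 3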


def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None,
                                      thresholds_distributed_evenly=False):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positives: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positives: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds
  are provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates
  an `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are
      in the range `[0, 1]`.
    thresholds: A float value, float tensor, python list, or tuple of float
      thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited
      to the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
      as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `y_true`
      dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or
      flattened into a single label. When True, the values of
      `variables_to_update` must have a second dimension equal to the number
      of labels in y_true and y_pred, and those tensors must not be
      RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN
      without explicit multilabel handling (i.e. when the data is to be
      flattened).
    thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
      distributed within the list. An optimized method will be used if this is
      the case. See _update_confusion_matrix_variables_optimized() for more
      details.

  Returns:
    Update op.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or
      if `variables_to_update` contains invalid keys.
  """
  if multi_label and label_weights is not None:
    raise ValueError('`label_weights` for multilabel data should be handled '
                     'outside of `update_confusion_matrix_variables` when '
                     '`multi_label` is True.')
  if variables_to_update is None:
    return
  if not any(
      key for key in variables_to_update if key in list(ConfusionMatrix)):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(
            list(ConfusionMatrix), variables_to_update.keys()))

  variable_dtype = list(variables_to_update.values())[0].dtype

  y_true = math_ops.cast(y_true, dtype=variable_dtype)
  y_pred = math_ops.cast(y_pred, dtype=variable_dtype)

  if thresholds_distributed_evenly:
    # Check whether the thresholds have any leading or trailing epsilon added
    # for floating point imprecision. The leading and trailing thresholds will
    # be handled a bit differently as corner cases.
    # At this point, thresholds should be a list/array with more than 2 items,
    # and ranged between [0, 1]. See is_evenly_distributed_thresholds() for
    # more details.
    thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

  thresholds = tensor_conversion.convert_to_tensor_v2_with_dispatch(
      thresholds, dtype=variable_dtype
  )
  num_thresholds = thresholds.shape.as_list()[0]

  if multi_label:
    one_thresh = math_ops.equal(
        math_ops.cast(1, dtype=dtypes.int32),
        array_ops.rank(thresholds),
        name='one_set_of_thresholds_cond')
  else:
    [y_pred, y_true], _ = ragged_assert_compatible_and_get_flat_values(
        [y_pred, y_true], sample_weight)
    one_thresh = math_ops.cast(True, dtype=dtypes.bool)

  invalid_keys = [
      key for key in variables_to_update if key not in list(ConfusionMatrix)
  ]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, list(ConfusionMatrix)))

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]):
    if sample_weight is None:
      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
          y_pred, y_true)
    else:
      sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype)
      y_pred, y_true, sample_weight = (
          losses_utils.squeeze_or_expand_dimensions(
              y_pred, y_true, sample_weight=sample_weight))
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  if thresholds_distributed_evenly and compat.forward_compatible(2021, 6, 8):
    # The new approach will take effect after 2021/6/8, to give enough time
    # for the Brella release to pick up the new op tf.math.cumsum with
    # float32.
    return _update_confusion_matrix_variables_optimized(
        variables_to_update, y_true, y_pred, thresholds,
        multi_label=multi_label, sample_weights=sample_weight,
        label_weights=label_weights,
        thresholds_with_epsilon=thresholds_with_epsilon)

  pred_shape = array_ops.shape(y_pred)
  num_predictions = pred_shape[0]
  if y_pred.shape.ndims == 1:
    num_labels = 1
  else:
    num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0)
  thresh_label_tile = array_ops.where_v2(one_thresh, num_labels,
                                         array_ops.ones([], dtype=dtypes.int32))

  # Reshape predictions and labels, adding a dim for thresholding.
  if multi_label:
    predictions_extra_dim = array_ops.expand_dims(y_pred, 0)
    labels_extra_dim = array_ops.expand_dims(
        math_ops.cast(y_true, dtype=dtypes.bool), 0)
  else:
    # Flatten predictions and labels when not multilabel.
    predictions_extra_dim = array_ops.reshape(y_pred, [1, -1])
    labels_extra_dim = array_ops.reshape(
        math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])

  # Tile the thresholds for every prediction.
  if multi_label:
    thresh_pretile_shape = [num_thresholds, 1, -1]
    thresh_tiles = [1, num_predictions, thresh_label_tile]
    data_tiles = [num_thresholds, 1, 1]
  else:
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

  thresh_tiled = array_ops.tile(
      array_ops.reshape(thresholds, thresh_pretile_shape),
      array_ops_stack.stack(thresh_tiles))

  # Tile the predictions for every threshold.
  preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles)

  # Compare predictions and threshold.
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Tile labels by the number of thresholds.
  label_is_pos = array_ops.tile(labels_extra_dim, data_tiles)

  if sample_weight is not None:
    sample_weight = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
  else:
    weights_tiled = None

  if label_weights is not None and not multi_label:
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    label_weights_tiled = array_ops.tile(
        array_ops.reshape(label_weights, thresh_tiles), data_tiles)
    if weights_tiled is None:
      weights_tiled = label_weights_tiled
    else:
      weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled)

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    label_and_pred = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=var.dtype)
    if weights is not None:
      label_and_pred *= math_ops.cast(weights, dtype=var.dtype)
    return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))

  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  for matrix_cond, (label, pred) in loop_vars.items():

    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))

  return control_flow_ops.group(update_ops)
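

# A minimal NumPy sketch of the single-threshold counting rules listed in the
# docstring above. The helper name `_demo_confusion_counts` and its sample
# values are illustrative additions, not part of the original module.
def _demo_confusion_counts():
  """Confusion counts for one threshold, following the documented rules."""
  y_true = np.array([0, 0, 1, 1], dtype=bool)
  y_pred = np.array([0.1, 0.6, 0.3, 0.9])
  threshold = 0.5
  pred_is_pos = y_pred > threshold     # [False, True, False, True]
  tp = np.sum(y_true & pred_is_pos)    # 1 (the 0.9 prediction)
  fn = np.sum(y_true & ~pred_is_pos)   # 1 (the 0.3 prediction)
  fp = np.sum(~y_true & pred_is_pos)   # 1 (the 0.6 prediction)
  tn = np.sum(~y_true & ~pred_is_pos)  # 1 (the 0.1 prediction)
  return tp, fn, fp, tn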


def _filter_top_k(x, k):
  """Filters top-k values in the last dim of x and sets the rest to NEG_INF.

  Used for computing top-k prediction values in dense labels (which have the
  same shape as predictions) for recall and precision top-k metrics.

  Args:
    x: tensor with any dimensions.
    k: the number of values to keep.

  Returns:
    tensor with same shape and dtype as x.
  """
  _, top_k_idx = nn_ops.top_k(x, k, sorted=False)
  top_k_mask = math_ops.reduce_sum(
      array_ops.one_hot(top_k_idx, array_ops.shape(x)[-1], axis=-1), axis=-2)
  return x * top_k_mask + NEG_INF * (1 - top_k_mask)
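

# A NumPy analogue of the one-hot masking trick used by `_filter_top_k`,
# shown for a 1-D input. The helper name `_demo_filter_top_k` and its sample
# values are illustrative additions, not part of the original module.
def _demo_filter_top_k():
  """Keeps the top-2 values of a vector and pushes the rest to NEG_INF."""
  x = np.array([0.2, 0.9, 0.1, 0.7])
  k = 2
  top_k_idx = np.argpartition(x, -k)[-k:]  # indices of the top-k, unsorted
  mask = np.zeros_like(x)
  mask[top_k_idx] = 1.0                    # [0., 1., 0., 1.]
  return x * mask + NEG_INF * (1 - mask)   # [-1e10, 0.9, -1e10, 0.7]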


def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """If ragged, it checks the compatibility and then returns the flat_values.

  Note: If two tensors are dense, it does not check their compatibility.
  Note: Although two ragged tensors with different ragged ranks could have
  identical overall rank and dimension sizes and hence be compatible,
  we do not support those cases.

  Args:
    values: A list of potentially ragged tensors of the same ragged_rank.
    mask: A potentially ragged tensor of the same ragged_rank as the elements
      in `values`.

  Returns:
    A tuple `([values], mask)` in which the first element is the list of
    tensors and the second is the mask tensor. The mask and the elements in
    `values` are equal to the flat_values of the input arguments (if they
    were ragged).
  """
  if isinstance(values, list):
    is_all_ragged = all(
        isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
    is_any_ragged = any(
        isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
  else:
    is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
    is_any_ragged = is_all_ragged
  if (is_all_ragged and
      ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
    to_be_stripped = False
    if not isinstance(values, list):
      values = [values]
      to_be_stripped = True

    # NOTE: we leave the flat_values compatibility to
    # tf.TensorShape `assert_is_compatible_with` checks; if both dynamic
    # dimensions are equal, then we use the flat_values.
    nested_row_split_list = [rt.nested_row_splits for rt in values]
    assertion_list = _assert_splits_match(nested_row_split_list)

    # If both are ragged, sample_weights should also be ragged with the same
    # dims.
    if isinstance(mask, ragged_tensor.RaggedTensor):
      assertion_list_for_mask = _assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      with ops.control_dependencies(assertion_list_for_mask):
        mask = array_ops.expand_dims(mask.flat_values, -1)

    # values has at least 1 element.
    flat_values = []
    for value in values:
      with ops.control_dependencies(assertion_list):
        flat_values.append(array_ops.expand_dims(value.flat_values, -1))

    values = flat_values[0] if to_be_stripped else flat_values

  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # values is non-ragged (or empty) but the mask is ragged.
  elif isinstance(mask, ragged_tensor.RaggedTensor):
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask
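
# For illustration (assumed input, not from the original module): with
#   rt = tf.ragged.constant([[1., 2.], [3.]])
# ragged_assert_compatible_and_get_flat_values([rt, rt]) returns two dense
# tensors of shape [3, 1] (the expanded flat_values [1., 2., 3.]) along with
# the mask (here None) unchanged.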


def _assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  Performs static tests to ensure that the given splits lists are identical,
  and returns a list of control dependency op tensors that check that they are
  fully identical.

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list
      is a list of `splits` tensors from a `RaggedTensor`, ordered from
      outermost ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors.

  Raises:
    ValueError: If the splits are not identical.
  """
  error_msg = 'Inputs must have identical ragged splits'
  for splits_list in nested_splits_lists:
    if len(splits_list) != len(nested_splits_lists[0]):
      raise ValueError(error_msg)
  return [
      check_ops.assert_equal(s1, s2, message=error_msg)  # pylint: disable=g-complex-comprehension
      for splits_list in nested_splits_lists[1:]
      for (s1, s2) in zip(nested_splits_lists[0], splits_list)
  ]