Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/metrics_utils.py: 19%

307 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16"""Utils related to keras metrics.""" 

17 

18from enum import Enum 

19import functools 

20import weakref 

21import numpy as np 

22 

23from tensorflow.python.compat import compat 

24from tensorflow.python.distribute import distribute_lib 

25from tensorflow.python.framework import dtypes 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import tensor_conversion 

28from tensorflow.python.keras import backend 

29from tensorflow.python.keras.utils import losses_utils 

30from tensorflow.python.keras.utils import tf_utils 

31from tensorflow.python.keras.utils.generic_utils import to_list 

32from tensorflow.python.ops import array_ops 

33from tensorflow.python.ops import array_ops_stack 

34from tensorflow.python.ops import check_ops 

35from tensorflow.python.ops import clip_ops 

36from tensorflow.python.ops import control_flow_ops 

37from tensorflow.python.ops import gen_math_ops 

38from tensorflow.python.ops import math_ops 

39from tensorflow.python.ops import nn_ops 

40from tensorflow.python.ops import variables as variables_module 

41from tensorflow.python.ops import weights_broadcast_ops 

42from tensorflow.python.ops.parallel_for import control_flow_ops as parallel_control_flow_ops 

43from tensorflow.python.ops.ragged import ragged_tensor 

44from tensorflow.python.util import tf_decorator 

45 

46NEG_INF = -1e10 

47 

48 

49class Reduction(Enum): 

50 """Types of metrics reduction. 

51 

52 Contains the following values: 

53 

54 * `SUM`: Scalar sum of weighted values. 

55 * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by 

56 number of elements. 

57 * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights. 

58 """ 

59 SUM = 'sum' 

60 SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' 

61 WEIGHTED_MEAN = 'weighted_mean' 

62 
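As an illustration of how the three reduction modes above differ, here is a minimal sketch in plain NumPy (not the Keras implementation itself; the values are made up):

import numpy as np

values = np.array([1.0, 2.0, 4.0])
weights = np.array([1.0, 0.5, 0.5])
weighted = values * weights                           # [1.0, 1.0, 2.0]

sum_reduction = weighted.sum()                        # SUM -> 4.0
sum_over_batch_size = weighted.sum() / values.size    # SUM_OVER_BATCH_SIZE -> 4.0 / 3
weighted_mean = weighted.sum() / weights.sum()        # WEIGHTED_MEAN -> 4.0 / 2.0 = 2.0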

63 

64def update_state_wrapper(update_state_fn): 

65 """Decorator to wrap metric `update_state()` with `add_update()`. 

66 

67 Args: 

68 update_state_fn: function that accumulates metric statistics. 

69 

70 Returns: 

71 Decorated function that wraps `update_state_fn()` with `add_update()`. 

72 """ 

73 

74 def decorated(metric_obj, *args, **kwargs): 

75 """Decorated function with `add_update()`.""" 

76 strategy = distribute_lib.get_strategy() 

77 

78 for weight in metric_obj.weights: 

79 if (backend.is_tpu_strategy(strategy) and 

80 not strategy.extended.variable_created_in_scope(weight) 

81 and not distribute_lib.in_cross_replica_context()): 

82 raise ValueError( 

83 'Trying to run metric.update_state in replica context when ' 

84 'the metric was not created in TPUStrategy scope. ' 

85 'Make sure the keras Metric is created in TPUStrategy scope.') 

86 

87 with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs): 

88 update_op = update_state_fn(*args, **kwargs) 

89 if update_op is not None: # update_op will be None in eager execution. 

90 metric_obj.add_update(update_op) 

91 return update_op 

92 

93 return tf_decorator.make_decorator(update_state_fn, decorated) 

94 
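The decorator above follows the usual "wrap a method and record its side effect on the owner object" idiom. A minimal pure-Python sketch of the same pattern (no TensorFlow; `record_update` and `ToyMetric` are hypothetical names used only for illustration):

import functools

def record_update(update_fn):
  """Wraps `update_fn` so its return value is recorded on the metric object."""
  @functools.wraps(update_fn)
  def decorated(metric_obj, *args, **kwargs):
    update_op = update_fn(metric_obj, *args, **kwargs)
    if update_op is not None:              # stands in for the eager-mode None case
      metric_obj.recorded_updates.append(update_op)
    return update_op
  return decorated

class ToyMetric:
  def __init__(self):
    self.total = 0.0
    self.recorded_updates = []

  @record_update
  def update_state(self, value):
    self.total += value
    return ('assign_add', value)           # stands in for a graph update op

m = ToyMetric()
m.update_state(3.0)
assert m.total == 3.0 and m.recorded_updates == [('assign_add', 3.0)]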

95 

96def result_wrapper(result_fn): 

97 """Decorator to wrap metric `result()` function in `merge_call()`. 

98 

99 Result computation is an idempotent operation that simply calculates the 

100 metric value using the state variables. 

101 

102 If metric state variables are distributed across replicas/devices and 

103 `result()` is requested from the context of one device, this function wraps 

104 `result()` in a distribution strategy `merge_call()`. With this, 

105 the metric state variables will be aggregated across devices. 

106 

107 Args: 

108 result_fn: function that computes the metric result. 

109 

110 Returns: 

111 Decorated function that wraps `result_fn()` in distribution strategy 

112 `merge_call()`. 

113 """ 

114 

115 def decorated(metric_obj, *args): 

116 """Decorated function with merge_call.""" 

117 has_strategy = distribute_lib.has_strategy() 

118 replica_context = distribute_lib.get_replica_context() 

119 

120 # The purpose of using `merge_call` to call `result()` is to trigger cross 

121 # replica aggregation of metric state variables (SyncOnReadVariable). After 

122 # we introduced `variable_sync_on_read_context`, in principle there is no 

123 # need to use `merge_call` here. However the branch still exists because: 

124 # 

125 # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor 

126 # across replicas (achieved by `merge_call`). With 

127 `variable_sync_on_read_context` each replica gets its own tensors 

128 residing on the replica's device, thus breaking the assumption. 

129 2. Keras compile/fit creates a tf.function (a.k.a. the train_function) that returns 

130 the metric values of the first replica. With 

131 `variable_sync_on_read_context`, since each replica gets its own 

132 # tensors, the metric result tensors on the non-first replicas are not in 

133 # the return value of train_function, making TF graph optimizer prune the 

134 # branch that computes and aggregates those metric results. As a result, 

135 # if NCCL is used to do the aggregation, the program will hang because 

136 # NCCL ops are only launched on the non-pruned first replica. 

137 # 

138 # We condition on strategy.extended._use_merge_call() since we know if it is 

139 false, the program uses `jit_compile` to compile the replica function, meaning it is 

140 # not V1 training (hence #1 is okay), and no pruning will happen as 

141 # compiled functions are not inlined (hence #2 is okay). 

142 

143 if (not has_strategy or replica_context is None or 

144 not distribute_lib.get_strategy( 

145 ).extended._use_merge_call()): 

146 with distribute_lib.variable_sync_on_read_context(): 

147 raw_result = result_fn(*args) 

148 # Results need to be wrapped in a `tf.identity` op to ensure 

149 # correct execution order. 

150 if isinstance(raw_result, 

151 (ops.Tensor, variables_module.Variable, float, int)): 

152 result_t = array_ops.identity(raw_result) 

153 elif isinstance(raw_result, dict): 

154 result_t = { 

155 key: array_ops.identity(value) 

156 for key, value in raw_result.items() 

157 } 

158 else: 

159 try: 

160 result_t = array_ops.identity(raw_result) 

161 except (ValueError, TypeError): 

162 raise RuntimeError( 

163 'The output of `metric.result()` can only be a single ' 

164 'Tensor/Variable, or a dict of Tensors/Variables. ' 

165 'For metric %s, got result %s.' % (metric_obj.name, raw_result)) 

166 else: 

167 # TODO(psv): Test distribution of metrics using different distribution 

168 # strategies. 

169 

170 # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn 

171 # with distribution object as the first parameter. We create a wrapper 

172 # here so that the result function need not have that parameter. 

173 def merge_fn_wrapper(distribution, merge_fn, *args): 

174 # We will get a `PerReplica` merge function. Take the first one, as all 

175 # are identical copies of the function that we had passed below. 

176 result = distribution.experimental_local_results(merge_fn)[0](*args) 

177 

178 # Wrapping result in identity so that control dependency between 

179 # update_op from `update_state` and result works in case result returns 

180 # a tensor. 

181 return array_ops.identity(result) 

182 

183 # Wrapping result in merge_call. merge_call is used when we want to leave 

184 # replica mode and compute a value in cross replica mode. 

185 result_t = replica_context.merge_call( 

186 merge_fn_wrapper, args=(result_fn,) + args) 

187 

188 # We are saving the result op here to be used in train/test execution 

189 # functions. This basically gives the result op that was generated with a 

190 # control dep to the updates for these workflows. 

191 metric_obj._call_result = result_t 

192 return result_t 

193 

194 return tf_decorator.make_decorator(result_fn, decorated) 

195 

196 

197def weakmethod(method): 

198 """Creates a weak reference to the bound method.""" 

199 

200 cls = method.__self__.__class__ 

201 func = method.__func__ 

202 instance_ref = weakref.ref(method.__self__) 

203 

204 @functools.wraps(method) 

205 def inner(*args, **kwargs): 

206 return func.__get__(instance_ref(), cls)(*args, **kwargs) 

207 

208 del method 

209 return inner 

210 

211 

212def assert_thresholds_range(thresholds): 

213 if thresholds is not None: 

214 invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1] 

215 if invalid_thresholds: 

216 raise ValueError( 

217 'Threshold values must be in [0, 1]. Invalid values: {}'.format( 

218 invalid_thresholds)) 

219 

220 

221def parse_init_thresholds(thresholds, default_threshold=0.5): 

222 if thresholds is not None: 

223 assert_thresholds_range(to_list(thresholds)) 

224 thresholds = to_list(default_threshold if thresholds is None else thresholds) 

225 return thresholds 

226 
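For reference, the helpers above behave as follows (values shown for illustration):

parse_init_thresholds(None)        # -> [0.5], the default threshold
parse_init_thresholds(0.3)         # -> [0.3]
parse_init_thresholds([0.1, 0.9])  # -> [0.1, 0.9]
parse_init_thresholds([0.1, 1.5])  # raises ValueError: thresholds must be in [0, 1]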

227 

228class ConfusionMatrix(Enum): 

229 TRUE_POSITIVES = 'tp' 

230 FALSE_POSITIVES = 'fp' 

231 TRUE_NEGATIVES = 'tn' 

232 FALSE_NEGATIVES = 'fn' 

233 

234 

235class AUCCurve(Enum): 

236 """Type of AUC Curve (ROC or PR).""" 

237 ROC = 'ROC' 

238 PR = 'PR' 

239 

240 @staticmethod 

241 def from_str(key): 

242 if key in ('pr', 'PR'): 

243 return AUCCurve.PR 

244 elif key in ('roc', 'ROC'): 

245 return AUCCurve.ROC 

246 else: 

247 raise ValueError('Invalid AUC curve value "%s".' % key) 

248 

249 

250class AUCSummationMethod(Enum): 

251 """Type of AUC summation method. 

252 

253 (https://en.wikipedia.org/wiki/Riemann_sum) 

254 

255 Contains the following values: 

256 * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For 

257 `PR` curve, interpolates (true/false) positives but not the ratio that is 

258 precision (see Davis & Goadrich 2006 for details). 

259 * 'minoring': Applies left summation for increasing intervals and right 

260 summation for decreasing intervals. 

261 * 'majoring': Applies right summation for increasing intervals and left 

262 summation for decreasing intervals. 

263 """ 

264 INTERPOLATION = 'interpolation' 

265 MAJORING = 'majoring' 

266 MINORING = 'minoring' 

267 

268 @staticmethod 

269 def from_str(key): 

270 if key in ('interpolation', 'Interpolation'): 

271 return AUCSummationMethod.INTERPOLATION 

272 elif key in ('majoring', 'Majoring'): 

273 return AUCSummationMethod.MAJORING 

274 elif key in ('minoring', 'Minoring'): 

275 return AUCSummationMethod.MINORING 

276 else: 

277 raise ValueError('Invalid AUC summation method value "%s".' % key) 

278 
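As a rough illustration of the three summation schemes on a toy monotone curve, here is a sketch in plain NumPy (not the actual AUC implementation; the x/y values are made up):

import numpy as np

x = np.array([0.0, 0.5, 1.0])   # e.g. recall at each threshold
y = np.array([1.0, 0.8, 0.2])   # e.g. precision at each threshold
dx = np.diff(x)

interpolation = np.sum(dx * (y[:-1] + y[1:]) / 2.0)  # mid-point (trapezoid) rule -> 0.7
minoring = np.sum(dx * np.minimum(y[:-1], y[1:]))    # lower Riemann sum -> 0.5
majoring = np.sum(dx * np.maximum(y[:-1], y[1:]))    # upper Riemann sum -> 0.9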

279 

280def _update_confusion_matrix_variables_optimized( 

281 variables_to_update, 

282 y_true, 

283 y_pred, 

284 thresholds, 

285 multi_label=False, 

286 sample_weights=None, 

287 label_weights=None, 

288 thresholds_with_epsilon=False): 

289 """Update confusion matrix variables with memory efficient alternative. 

290 

291 Note that the thresholds need to be evenly distributed within the list, i.e., 

292 the difference between consecutive elements is the same. 

293 

294 To compute TP/FP/TN/FN, we are measuring a binary classifier 

295 C(t) = (predictions >= t) 

296 at each threshold 't'. So we have 

297 TP(t) = sum( C(t) * true_labels ) 

298 FP(t) = sum( C(t) * false_labels ) 

299 

300 But, computing C(t) requires computation for each t. To make it fast, 

301 observe that C(t) is a cumulative integral, and so if we have 

302 thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} 

303 where n = num_thresholds, and if we can compute the bucket function 

304 B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} ) 

305 then we get 

306 C(t_i) = sum( B(j), j >= i ) 

307 which is the reversed cumulative sum in tf.cumsum(). 

308 

309 We can compute B(i) efficiently by taking advantage of the fact that 

310 our thresholds are evenly distributed, in that 

311 width = 1.0 / (num_thresholds - 1) 

312 thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] 

313 Given a prediction value p, we can map it to its bucket by 

314 bucket_index(p) = floor( p * (num_thresholds - 1) ) 

315 so we can use tf.math.unsorted_segment_sum() to update the buckets in one 

316 pass. 

317 

318 Consider the following example: 

319 y_true = [0, 0, 1, 1] 

320 y_pred = [0.1, 0.5, 0.3, 0.9] 

321 thresholds = [0.0, 0.5, 1.0] 

322 num_buckets = 2 # [0.0, 1.0], (1.0, 2.0] 

323 bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets) 

324 = tf.math.floor([0.2, 1.0, 0.6, 1.8]) 

325 = [0, 0, 0, 1] 

326 # The meaning of this bucket is that if any of the labels is true, 

327 # then 1 will be added to the bucket with the corresponding index. 

328 # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0. If the 

329 # label for 1.8 is true, then 1 will be added to bucket 1. 

330 # 

331 # Note the second item "1.0" is floored to 0, since the value needs to be 

332 # strictly larger than the bucket lower bound. 

333 # In the implementation, we use tf.math.ceil() - 1 to achieve this. 

334 tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices, 

335 num_segments=num_thresholds) 

336 = [1, 1, 0] 

337 # For [1, 1, 0] here, it means there is 1 true value contributed by bucket 0, 

338 # and 1 value contributed by bucket 1. When we aggregate them together, 

339 # the result becomes [a + b + c, b + c, c], since larger thresholds will always 

340 # contribute to the value for smaller thresholds. 

341 true_positive = tf.math.cumsum(tp_bucket_value, reverse=True) 

342 = [2, 1, 0] 

343 

344 This implementation exhibits a run time and space complexity of O(T + N), 

345 where T is the number of thresholds and N is the size of predictions. 

346 Metrics that rely on the standard implementation instead exhibit a complexity of 

347 O(T * N). 

348 

349 Args: 

350 variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys 

351 and corresponding variables to update as values. 

352 y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be cast 

353 to `bool`. 

354 y_pred: A floating point `Tensor` of arbitrary shape and whose values are in 

355 the range `[0, 1]`. 

356 thresholds: A sorted floating point `Tensor` with values in `[0, 1]`. 

357 It needs to be evenly distributed (the difference between consecutive 

358 elements needs to be the same). 

359 multi_label: Optional boolean indicating whether multidimensional 

360 prediction/labels should be treated as multilabel responses, or flattened 

361 into a single label. When True, the values of `variables_to_update` must 

362 have a second dimension equal to the number of labels in y_true and 

363 y_pred, and those tensors must not be RaggedTensors. 

364 sample_weights: Optional `Tensor` whose rank is either 0, or the same rank 

365 as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions 

366 must be either `1`, or the same as the corresponding `y_true` dimension). 

367 label_weights: Optional tensor of non-negative weights for multilabel 

368 data. The weights are applied when calculating TP, FP, FN, and TN without 

369 explicit multilabel handling (i.e. when the data is to be flattened). 

370 thresholds_with_epsilon: Optional boolean indicating whether the leading and 

371 trailing thresholds have any epsilon added for floating point imprecision. 

372 It will change how we handle the leading and trailing buckets. 

373 

374 Returns: 

375 Update op. 

376 """ 

377 num_thresholds = thresholds.shape.as_list()[0] 

378 

379 if sample_weights is None: 

380 sample_weights = 1.0 

381 else: 

382 sample_weights = weights_broadcast_ops.broadcast_weights( 

383 math_ops.cast(sample_weights, dtype=y_pred.dtype), y_pred) 

384 if not multi_label: 

385 sample_weights = array_ops.reshape(sample_weights, [-1]) 

386 if label_weights is None: 

387 label_weights = 1.0 

388 else: 

389 label_weights = array_ops.expand_dims(label_weights, 0) 

390 label_weights = weights_broadcast_ops.broadcast_weights(label_weights, 

391 y_pred) 

392 if not multi_label: 

393 label_weights = array_ops.reshape(label_weights, [-1]) 

394 weights = math_ops.multiply(sample_weights, label_weights) 

395 

396 # We shouldn't need this, but just in case there are prediction values that 

397 # fall outside the range [0.0, 1.0]. 

398 y_pred = clip_ops.clip_by_value(y_pred, 

399 clip_value_min=0.0, clip_value_max=1.0) 

400 

401 y_true = math_ops.cast(math_ops.cast(y_true, dtypes.bool), y_true.dtype) 

402 if not multi_label: 

403 y_true = array_ops.reshape(y_true, [-1]) 

404 y_pred = array_ops.reshape(y_pred, [-1]) 

405 

406 true_labels = math_ops.multiply(y_true, weights) 

407 false_labels = math_ops.multiply((1.0 - y_true), weights) 

408 

409 # Compute the bucket indices for each prediction value. 

410 # Since the prediction value has to be strictly greater than the threshold 

411 # (e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the first bucket), 

412 # we have to use math.ceil(val) - 1 for the bucket index. 

413 bucket_indices = math_ops.ceil(y_pred * (num_thresholds - 1)) - 1 

414 

415 if thresholds_with_epsilon: 

416 # In this case, the first bucket should actually be taken into account, since 

417 # any prediction in [0.0, 1.0] is larger than the first 

418 # threshold. We change the bucket index from -1 to 0. 

419 bucket_indices = nn_ops.relu(bucket_indices) 

420 

421 bucket_indices = math_ops.cast(bucket_indices, dtypes.int32) 

422 

423 if multi_label: 

424 # We need to run the bucket segment sum for each of the label classes. In the 

425 # multi_label case, the rank of the label is 2. We first transpose it so 

426 # that the label dim becomes the first and we can run through them in parallel. 

427 true_labels = array_ops.transpose_v2(true_labels) 

428 false_labels = array_ops.transpose_v2(false_labels) 

429 bucket_indices = array_ops.transpose_v2(bucket_indices) 

430 

431 def gather_bucket(label_and_bucket_index): 

432 label, bucket_index = label_and_bucket_index[0], label_and_bucket_index[1] 

433 return math_ops.unsorted_segment_sum( 

434 data=label, segment_ids=bucket_index, num_segments=num_thresholds) 

435 tp_bucket_v = parallel_control_flow_ops.vectorized_map( 

436 gather_bucket, (true_labels, bucket_indices)) 

437 fp_bucket_v = parallel_control_flow_ops.vectorized_map( 

438 gather_bucket, (false_labels, bucket_indices)) 

439 tp = array_ops.transpose_v2( 

440 math_ops.cumsum(tp_bucket_v, reverse=True, axis=1)) 

441 fp = array_ops.transpose_v2( 

442 math_ops.cumsum(fp_bucket_v, reverse=True, axis=1)) 

443 else: 

444 tp_bucket_v = math_ops.unsorted_segment_sum( 

445 data=true_labels, segment_ids=bucket_indices, 

446 num_segments=num_thresholds) 

447 fp_bucket_v = math_ops.unsorted_segment_sum( 

448 data=false_labels, segment_ids=bucket_indices, 

449 num_segments=num_thresholds) 

450 tp = math_ops.cumsum(tp_bucket_v, reverse=True) 

451 fp = math_ops.cumsum(fp_bucket_v, reverse=True) 

452 

453 # fn = sum(true_labels) - tp 

454 # tn = sum(false_labels) - fp 

455 if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update or 

456 ConfusionMatrix.FALSE_NEGATIVES in variables_to_update): 

457 if multi_label: 

458 total_true_labels = math_ops.reduce_sum(true_labels, axis=1) 

459 total_false_labels = math_ops.reduce_sum(false_labels, axis=1) 

460 else: 

461 total_true_labels = math_ops.reduce_sum(true_labels) 

462 total_false_labels = math_ops.reduce_sum(false_labels) 

463 

464 update_ops = [] 

465 if ConfusionMatrix.TRUE_POSITIVES in variables_to_update: 

466 variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES] 

467 update_ops.append(variable.assign_add(tp)) 

468 if ConfusionMatrix.FALSE_POSITIVES in variables_to_update: 

469 variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES] 

470 update_ops.append(variable.assign_add(fp)) 

471 if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update: 

472 variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES] 

473 tn = total_false_labels - fp 

474 update_ops.append(variable.assign_add(tn)) 

475 if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update: 

476 variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES] 

477 fn = total_true_labels - tp 

478 update_ops.append(variable.assign_add(fn)) 

479 return control_flow_ops.group(update_ops) 

480 
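The bucket-and-reversed-cumsum trick described in the docstring above can be reproduced with plain NumPy. This is a sketch of the idea only (single label, no sample/label weights, no epsilon corner cases), not the TensorFlow implementation:

import numpy as np

y_true = np.array([0., 0., 1., 1.])
y_pred = np.array([0.1, 0.5, 0.3, 0.9])
num_thresholds = 3                               # thresholds = [0.0, 0.5, 1.0]

# Bucket each prediction; ceil(p * (T - 1)) - 1 sends values equal to a
# threshold into the lower bucket, matching the code above.
bucket = np.maximum(np.ceil(y_pred * (num_thresholds - 1)) - 1, 0).astype(int)

tp_bucket = np.bincount(bucket, weights=y_true, minlength=num_thresholds)
fp_bucket = np.bincount(bucket, weights=1.0 - y_true, minlength=num_thresholds)

tp = np.cumsum(tp_bucket[::-1])[::-1]            # reversed cumulative sum
fp = np.cumsum(fp_bucket[::-1])[::-1]
fn = y_true.sum() - tp
tn = (1.0 - y_true).sum() - fp
# For this toy input: tp == [2., 1., 0.] and fp == [2., 0., 0.]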

481 

482def is_evenly_distributed_thresholds(thresholds): 

483 """Check if the thresholds list is evenly distributed. 

484 

485 We could leverage evenly distributed thresholds to use less memory when 

486 calculating metrics like AUC, where each individual threshold needs to be 

487 evaluated. 

488 

489 Args: 

490 thresholds: A python list or tuple, or 1D numpy array whose values lie 

491 in [0, 1]. 

492 

493 Returns: 

494 boolean, whether the values in the inputs are evenly distributed. 

495 """ 

496 # Check the list value and see if it is evenly distributed. 

497 num_thresholds = len(thresholds) 

498 if num_thresholds < 3: 

499 return False 

500 even_thresholds = np.arange(num_thresholds, 

501 dtype=np.float32) / (num_thresholds - 1) 

502 return np.allclose(thresholds, even_thresholds, atol=backend.epsilon()) 

503 
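For example (illustrative inputs):

import numpy as np

is_evenly_distributed_thresholds(np.linspace(0.0, 1.0, num=5))  # True
is_evenly_distributed_thresholds(np.array([0.0, 0.1, 1.0]))     # False: uneven spacing
is_evenly_distributed_thresholds(np.array([0.0, 1.0]))          # False: fewer than 3 thresholds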

504 

505def update_confusion_matrix_variables(variables_to_update, 

506 y_true, 

507 y_pred, 

508 thresholds, 

509 top_k=None, 

510 class_id=None, 

511 sample_weight=None, 

512 multi_label=False, 

513 label_weights=None, 

514 thresholds_distributed_evenly=False): 

515 """Returns op to update the given confusion matrix variables. 

516 

517 For every pair of values in y_true and y_pred: 

518 

519 true_positives: y_true == True and y_pred > thresholds 

520 false_negatives: y_true == True and y_pred <= thresholds 

521 true_negatives: y_true == False and y_pred <= thresholds 

522 false_positives: y_true == False and y_pred > thresholds 

523 

524 The results will be weighted and added together. When multiple thresholds are 

525 provided, we will repeat the same for every threshold. 

526 

527 For estimation of these metrics over a stream of data, the function creates an 

528 `update_op` operation that updates the given variables. 

529 

530 If `sample_weight` is `None`, weights default to 1. 

531 Use weights of 0 to mask values. 

532 

533 Args: 

534 variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys 

535 and corresponding variables to update as values. 

536 y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`. 

537 y_pred: A floating point `Tensor` of arbitrary shape and whose values are in 

538 the range `[0, 1]`. 

539 thresholds: A float value, float tensor, python list, or tuple of float 

540 thresholds in `[0, 1]`, or NEG_INF (used when top_k is set). 

541 top_k: Optional int, indicates that the positive labels should be limited to 

542 the top k predictions. 

543 class_id: Optional int, limits the prediction and labels to the class 

544 specified by this argument. 

545 sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as 

546 `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must 

547 be either `1`, or the same as the corresponding `y_true` dimension). 

548 multi_label: Optional boolean indicating whether multidimensional 

549 prediction/labels should be treated as multilabel responses, or flattened 

550 into a single label. When True, the values of `variables_to_update` must 

551 have a second dimension equal to the number of labels in y_true and 

552 y_pred, and those tensors must not be RaggedTensors. 

553 label_weights: (optional) tensor of non-negative weights for multilabel 

554 data. The weights are applied when calculating TP, FP, FN, and TN without 

555 explicit multilabel handling (i.e. when the data is to be flattened). 

556 thresholds_distributed_evenly: Boolean, whether the thresholds are evenly 

557 distributed within the list. An optimized method will be used if this is 

558 the case. See _update_confusion_matrix_variables_optimized() for more 

559 details. 

560 

561 Returns: 

562 Update op. 

563 

564 Raises: 

565 ValueError: If `y_pred` and `y_true` have mismatched shapes, or if 

566 `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if 

567 `variables_to_update` contains invalid keys. 

568 """ 

569 if multi_label and label_weights is not None: 

570 raise ValueError('`label_weights` for multilabel data should be handled ' 

571 'outside of `update_confusion_matrix_variables` when ' 

572 '`multi_label` is True.') 

573 if variables_to_update is None: 

574 return 

575 if not any( 

576 key for key in variables_to_update if key in list(ConfusionMatrix)): 

577 raise ValueError( 

578 'Please provide at least one valid confusion matrix ' 

579 'variable to update. Valid variable key options are: "{}". ' 

580 'Received: "{}"'.format( 

581 list(ConfusionMatrix), variables_to_update.keys())) 

582 

583 variable_dtype = list(variables_to_update.values())[0].dtype 

584 

585 y_true = math_ops.cast(y_true, dtype=variable_dtype) 

586 y_pred = math_ops.cast(y_pred, dtype=variable_dtype) 

587 

588 if thresholds_distributed_evenly: 

589 # Check whether the thresholds have any leading or trailing epsilon added 

590 # for floating point imprecision. The leading and trailing thresholds will be 

591 # handled a bit differently as corner cases. 

592 # At this point, thresholds should be a list/array with more than 2 items, 

593 # whose values range over [0, 1]. See is_evenly_distributed_thresholds() for more 

594 # details. 

595 thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0 

596 

597 thresholds = tensor_conversion.convert_to_tensor_v2_with_dispatch( 

598 thresholds, dtype=variable_dtype 

599 ) 

600 num_thresholds = thresholds.shape.as_list()[0] 

601 

602 if multi_label: 

603 one_thresh = math_ops.equal( 

604 math_ops.cast(1, dtype=dtypes.int32), 

605 array_ops.rank(thresholds), 

606 name='one_set_of_thresholds_cond') 

607 else: 

608 [y_pred, 

609 y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], 

610 sample_weight) 

611 one_thresh = math_ops.cast(True, dtype=dtypes.bool) 

612 

613 invalid_keys = [ 

614 key for key in variables_to_update if key not in list(ConfusionMatrix) 

615 ] 

616 if invalid_keys: 

617 raise ValueError( 

618 'Invalid keys: {}. Valid variable key options are: "{}"'.format( 

619 invalid_keys, list(ConfusionMatrix))) 

620 

621 with ops.control_dependencies([ 

622 check_ops.assert_greater_equal( 

623 y_pred, 

624 math_ops.cast(0.0, dtype=y_pred.dtype), 

625 message='predictions must be >= 0'), 

626 check_ops.assert_less_equal( 

627 y_pred, 

628 math_ops.cast(1.0, dtype=y_pred.dtype), 

629 message='predictions must be <= 1') 

630 ]): 

631 if sample_weight is None: 

632 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 

633 y_pred, y_true) 

634 else: 

635 sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype) 

636 y_pred, y_true, sample_weight = ( 

637 losses_utils.squeeze_or_expand_dimensions( 

638 y_pred, y_true, sample_weight=sample_weight)) 

639 y_pred.shape.assert_is_compatible_with(y_true.shape) 

640 

641 if top_k is not None: 

642 y_pred = _filter_top_k(y_pred, top_k) 

643 if class_id is not None: 

644 y_true = y_true[..., class_id] 

645 y_pred = y_pred[..., class_id] 

646 

647 if thresholds_distributed_evenly and compat.forward_compatible(2021, 6, 8): 

648 # The new approach will take effect after 2021/6/8, to give enough time 

649 # for Brella release to pick up the new op tf.math.cumsum with float32. 

650 return _update_confusion_matrix_variables_optimized( 

651 variables_to_update, y_true, y_pred, thresholds, 

652 multi_label=multi_label, sample_weights=sample_weight, 

653 label_weights=label_weights, 

654 thresholds_with_epsilon=thresholds_with_epsilon) 

655 

656 pred_shape = array_ops.shape(y_pred) 

657 num_predictions = pred_shape[0] 

658 if y_pred.shape.ndims == 1: 

659 num_labels = 1 

660 else: 

661 num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0) 

662 thresh_label_tile = array_ops.where_v2(one_thresh, num_labels, 

663 array_ops.ones([], dtype=dtypes.int32)) 

664 

665 # Reshape predictions and labels, adding a dim for thresholding. 

666 if multi_label: 

667 predictions_extra_dim = array_ops.expand_dims(y_pred, 0) 

668 labels_extra_dim = array_ops.expand_dims( 

669 math_ops.cast(y_true, dtype=dtypes.bool), 0) 

670 else: 

671 # Flatten predictions and labels when not multilabel. 

672 predictions_extra_dim = array_ops.reshape(y_pred, [1, -1]) 

673 labels_extra_dim = array_ops.reshape( 

674 math_ops.cast(y_true, dtype=dtypes.bool), [1, -1]) 

675 

676 # Tile the thresholds for every prediction. 

677 if multi_label: 

678 thresh_pretile_shape = [num_thresholds, 1, -1] 

679 thresh_tiles = [1, num_predictions, thresh_label_tile] 

680 data_tiles = [num_thresholds, 1, 1] 

681 else: 

682 thresh_pretile_shape = [num_thresholds, -1] 

683 thresh_tiles = [1, num_predictions * num_labels] 

684 data_tiles = [num_thresholds, 1] 

685 

686 thresh_tiled = array_ops.tile( 

687 array_ops.reshape(thresholds, thresh_pretile_shape), 

688 array_ops_stack.stack(thresh_tiles)) 

689 

690 # Tile the predictions for every threshold. 

691 preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles) 

692 

693 # Compare predictions and threshold. 

694 pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled) 

695 

696 # Tile labels by number of thresholds 

697 label_is_pos = array_ops.tile(labels_extra_dim, data_tiles) 

698 

699 if sample_weight is not None: 

700 sample_weight = weights_broadcast_ops.broadcast_weights( 

701 math_ops.cast(sample_weight, dtype=variable_dtype), y_pred) 

702 weights_tiled = array_ops.tile( 

703 array_ops.reshape(sample_weight, thresh_tiles), data_tiles) 

704 else: 

705 weights_tiled = None 

706 

707 if label_weights is not None and not multi_label: 

708 label_weights = array_ops.expand_dims(label_weights, 0) 

709 label_weights = weights_broadcast_ops.broadcast_weights(label_weights, 

710 y_pred) 

711 label_weights_tiled = array_ops.tile( 

712 array_ops.reshape(label_weights, thresh_tiles), data_tiles) 

713 if weights_tiled is None: 

714 weights_tiled = label_weights_tiled 

715 else: 

716 weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled) 

717 

718 update_ops = [] 

719 

720 def weighted_assign_add(label, pred, weights, var): 

721 label_and_pred = math_ops.cast( 

722 math_ops.logical_and(label, pred), dtype=var.dtype) 

723 if weights is not None: 

724 label_and_pred *= math_ops.cast(weights, dtype=var.dtype) 

725 return var.assign_add(math_ops.reduce_sum(label_and_pred, 1)) 

726 

727 loop_vars = { 

728 ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos), 

729 } 

730 update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update 

731 update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update 

732 update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update 

733 

734 if update_fn or update_tn: 

735 pred_is_neg = math_ops.logical_not(pred_is_pos) 

736 loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg) 

737 

738 if update_fp or update_tn: 

739 label_is_neg = math_ops.logical_not(label_is_pos) 

740 loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos) 

741 if update_tn: 

742 loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg) 

743 

744 for matrix_cond, (label, pred) in loop_vars.items(): 

745 

746 if matrix_cond in variables_to_update: 

747 update_ops.append( 

748 weighted_assign_add(label, pred, weights_tiled, 

749 variables_to_update[matrix_cond])) 

750 

751 return control_flow_ops.group(update_ops) 

752 
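The tiled threshold/prediction comparison above amounts to the following per-threshold counting, shown as a small NumPy sketch (unweighted, single label, no top_k/class_id handling; illustrative only):

import numpy as np

y_true = np.array([1, 0, 1, 0], dtype=bool)
y_pred = np.array([0.9, 0.6, 0.4, 0.2])
thresholds = np.array([0.3, 0.5, 0.7])

# Shape [num_thresholds, num_predictions] after broadcasting, mirroring
# `thresh_tiled` and `preds_tiled` in the function above.
pred_is_pos = y_pred[None, :] > thresholds[:, None]
label_is_pos = np.broadcast_to(y_true, pred_is_pos.shape)

tp = np.sum(label_is_pos & pred_is_pos, axis=1)    # [2, 1, 1]
fp = np.sum(~label_is_pos & pred_is_pos, axis=1)   # [1, 1, 0]
fn = np.sum(label_is_pos & ~pred_is_pos, axis=1)   # [0, 1, 1]
tn = np.sum(~label_is_pos & ~pred_is_pos, axis=1)  # [1, 1, 2]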

753 

754def _filter_top_k(x, k): 

755 Filters top-k values in the last dim of x and sets the rest to NEG_INF. 

756 

757 Used for computing top-k prediction values in dense labels (which have the same 

758 shape as predictions) for recall and precision top-k metrics. 

759 

760 Args: 

761 x: tensor with any dimensions. 

762 k: the number of values to keep. 

763 

764 Returns: 

765 tensor with same shape and dtype as x. 

766 """ 

767 _, top_k_idx = nn_ops.top_k(x, k, sorted=False) 

768 top_k_mask = math_ops.reduce_sum( 

769 array_ops.one_hot(top_k_idx, array_ops.shape(x)[-1], axis=-1), axis=-2) 

770 return x * top_k_mask + NEG_INF * (1 - top_k_mask) 

771 
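The masking above behaves roughly like the following NumPy sketch (same idea, dense arrays instead of tensors; illustrative only):

import numpy as np

NEG_INF = -1e10
x = np.array([[0.1, 0.9, 0.3],
              [0.2, 0.1, 0.7]])
k = 1

# Build a 0/1 mask that keeps the top-k entries of each row.
top_k_idx = np.argsort(x, axis=-1)[:, -k:]
top_k_mask = np.zeros_like(x)
np.put_along_axis(top_k_mask, top_k_idx, 1.0, axis=-1)

filtered = x * top_k_mask + NEG_INF * (1 - top_k_mask)
# Row 0 keeps 0.9, row 1 keeps 0.7; every other entry becomes NEG_INF.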

772 

773def ragged_assert_compatible_and_get_flat_values(values, mask=None): 

774 """If ragged, it checks the compatibility and then returns the flat_values. 

775 

776 Note: If two tensors are dense, it does not check their compatibility. 

777 Note: Although two ragged tensors with different ragged ranks could have 

778 identical overall rank and dimension sizes and hence be compatible, 

779 we do not support those cases. 

780 Args: 

781 values: A list of potentially ragged tensors of the same ragged_rank. 

782 mask: A potentially ragged tensor of the same ragged_rank as the elements 

783 in `values`. 

784 

785 Returns: 

786 A tuple in which the first element is the list of tensors and the second 

787 is the mask tensor, i.e. ([values], mask). The mask and the elements in `values` 

788 are equal to the flat_values of the input arguments (if they were ragged). 

789 """ 

790 if isinstance(values, list): 

791 is_all_ragged = \ 

792 all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values) 

793 is_any_ragged = \ 

794 any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values) 

795 else: 

796 is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor) 

797 is_any_ragged = is_all_ragged 

798 if (is_all_ragged and 

799 ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))): 

800 to_be_stripped = False 

801 if not isinstance(values, list): 

802 values = [values] 

803 to_be_stripped = True 

804 

805 # NOTE: we leave the flat_values compatibility to 

806 # tf.TensorShape `assert_is_compatible_with` 

807 # check if both dynamic dimensions are equal and then use the flat_values. 

808 nested_row_split_list = [rt.nested_row_splits for rt in values] 

809 assertion_list = _assert_splits_match(nested_row_split_list) 

810 

811 # if both are ragged, sample_weights should also be ragged with the same dims. 

812 if isinstance(mask, ragged_tensor.RaggedTensor): 

813 assertion_list_for_mask = _assert_splits_match( 

814 [nested_row_split_list[0], mask.nested_row_splits]) 

815 with ops.control_dependencies(assertion_list_for_mask): 

816 mask = array_ops.expand_dims(mask.flat_values, -1) 

817 

818 # values has at least 1 element. 

819 flat_values = [] 

820 for value in values: 

821 with ops.control_dependencies(assertion_list): 

822 flat_values.append(array_ops.expand_dims(value.flat_values, -1)) 

823 

824 values = flat_values[0] if to_be_stripped else flat_values 

825 

826 elif is_any_ragged: 

827 raise TypeError('One of the inputs does not have acceptable types.') 

828 # values is empty, or values are not ragged and mask is ragged. 

829 elif isinstance(mask, ragged_tensor.RaggedTensor): 

830 raise TypeError('Ragged mask is not allowed with non-ragged inputs.') 

831 

832 return values, mask 

833 
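A rough usage sketch, assuming a TensorFlow runtime (the ragged values are made up for illustration):

import tensorflow as tf

rt_pred = tf.ragged.constant([[0.2, 0.8], [0.5]])
rt_true = tf.ragged.constant([[0.0, 1.0], [1.0]])

(flat_pred, flat_true), mask = ragged_assert_compatible_and_get_flat_values(
    [rt_pred, rt_true], mask=None)
# flat_pred and flat_true are dense [3, 1] tensors holding the flat_values;
# mask stays None because no sample weights were passed.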

834 

835def _assert_splits_match(nested_splits_lists): 

836 """Checks that the given splits lists are identical. 

837 

838 Performs static tests to ensure that the given splits lists are identical, 

839 and returns a list of control dependency op tensors that check that they are 

840 fully identical. 

841 

842 Args: 

843 nested_splits_lists: A list of nested_splits_lists, where each split_list is 

844 a list of `splits` tensors from a `RaggedTensor`, ordered from outermost 

845 ragged dimension to innermost ragged dimension. 

846 

847 Returns: 

848 A list of control dependency op tensors. 

849 Raises: 

850 ValueError: If the splits are not identical. 

851 """ 

852 error_msg = 'Inputs must have identical ragged splits' 

853 for splits_list in nested_splits_lists: 

854 if len(splits_list) != len(nested_splits_lists[0]): 

855 raise ValueError(error_msg) 

856 return [ 

857 check_ops.assert_equal(s1, s2, message=error_msg) # pylint: disable=g-complex-comprehension 

858 for splits_list in nested_splits_lists[1:] 

859 for (s1, s2) in zip(nested_splits_lists[0], splits_list) 

860 ]