Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/metrics_utils.py: 14% (327 statements)

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utils related to keras metrics."""

import functools
import weakref
from enum import Enum

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.utils import losses_utils
from keras.src.utils import tf_utils
from keras.src.utils.generic_utils import to_list

NEG_INF = -1e10

class Reduction(Enum):
    """Types of metrics reduction.

    Contains the following values:

    * `SUM`: Scalar sum of weighted values.
    * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by
        number of elements.
    * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights.
    """

    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
    WEIGHTED_MEAN = "weighted_mean"

def update_state_wrapper(update_state_fn):
    """Decorator to wrap metric `update_state()` with `add_update()`.

    Args:
      update_state_fn: function that accumulates metric statistics.

    Returns:
      Decorated function that wraps `update_state_fn()` with `add_update()`.
    """

    def decorated(metric_obj, *args, **kwargs):
        """Decorated function with `add_update()`."""
        strategy = tf.distribute.get_strategy()

        for weight in metric_obj.weights:
            if (
                backend.is_tpu_strategy(strategy)
                and not strategy.extended.variable_created_in_scope(weight)
                and not tf.distribute.in_cross_replica_context()
            ):
                raise ValueError(
                    "Trying to run metric.update_state in replica context "
                    "when the metric was not created in TPUStrategy scope. "
                    "Make sure the keras Metric is created in TPUStrategy "
                    "scope."
                )

        with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
            update_op = update_state_fn(*args, **kwargs)
        if update_op is not None:  # update_op will be None in eager execution.
            metric_obj.add_update(update_op)
        return update_op

    return tf.__internal__.decorator.make_decorator(update_state_fn, decorated)
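
# Illustrative sketch (not part of the original module): wiring the wrapper
# onto a bound method, loosely mirroring how `keras.metrics.Metric` applies
# it. `_DemoMetric` and `_accumulate` are hypothetical stand-ins exposing
# only the attributes `decorated` touches (`weights` and `add_update`).
def _example_update_state_wrapper():
    import types

    class _DemoMetric:
        weights = ()  # nothing for the TPU-scope check to validate

        def add_update(self, op):
            # Only reached for graph-mode update ops.
            self.last_update_op = op

        def _accumulate(self, value):
            self.total = getattr(self, "total", 0) + value

    demo = _DemoMetric()
    demo.update_state = types.MethodType(
        update_state_wrapper(demo._accumulate), demo
    )
    demo.update_state(3)  # eager mode: `_accumulate` runs, update op is None
    assert demo.total == 3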

def result_wrapper(result_fn):
    """Decorator to wrap metric `result()` function in `merge_call()`.

    Result computation is an idempotent operation that simply calculates the
    metric value using the state variables.

    If metric state variables are distributed across replicas/devices and
    `result()` is requested from the context of one device, this function
    wraps `result()` in a distribution strategy `merge_call()`. With this,
    the metric state variables will be aggregated across devices.

    Args:
      result_fn: function that computes the metric result.

    Returns:
      Decorated function that wraps `result_fn()` in distribution strategy
      `merge_call()`.
    """

    def decorated(metric_obj, *args):
        """Decorated function with merge_call."""
        replica_context = tf.distribute.get_replica_context()

        # The purpose of using `merge_call` to call `result()` is to trigger
        # cross replica aggregation of metric state variables
        # (SyncOnReadVariable). After we introduced
        # `variable_sync_on_read_context`, in principle there is no need to
        # use `merge_call` here. However the branch still exists because:
        #
        # 1. Keras V1 training code sometimes assumes `result_t` is the same
        #    tensor across replicas (achieved by `merge_call`). With
        #    `variable_sync_on_read_context` each replica gets its own
        #    tensors residing on the replica's device, thus breaking the
        #    assumption.
        # 2. Keras `compile/fit` creates a tf.function (a.k.a.
        #    train_function) that returns the metric values of the first
        #    replica. With `variable_sync_on_read_context`, since each
        #    replica gets its own tensors, the metric result tensors on the
        #    non-first replicas are not in the return value of
        #    train_function, making the TF graph optimizer prune the branch
        #    that computes and aggregates those metric results. As a result,
        #    if NCCL is used to do the aggregation, the program will hang
        #    because NCCL ops are only launched on the non-pruned first
        #    replica.
        #
        # We condition on strategy_supports_no_merge_call() since we know if
        # it is True, the program uses `jit_compile` to compile the replica
        # fn, meaning it is not V1 training (hence #1 is okay), and no
        # pruning will happen as compiled functions are not inlined (hence
        # #2 is okay).
        if (
            replica_context is None
            or tf.__internal__.distribute.strategy_supports_no_merge_call()
        ):
            with tf.__internal__.distribute.variable_sync_on_read_context():
                raw_result = result_fn(*args)
                # Results need to be wrapped in a `tf.identity` op to ensure
                # correct execution order.
                if isinstance(
                    raw_result, (tf.Tensor, tf.Variable, float, int)
                ):
                    result_t = tf.identity(raw_result)
                elif isinstance(raw_result, dict):
                    result_t = tf.nest.map_structure(tf.identity, raw_result)
                else:
                    try:
                        result_t = tf.identity(raw_result)
                    except (ValueError, TypeError):
                        raise RuntimeError(
                            "The output of `metric.result()` can only be a "
                            "single Tensor/Variable, or a dict of "
                            "Tensors/Variables. "
                            f"For metric {metric_obj.name}, "
                            f"got result {raw_result}."
                        )
        else:
            # TODO(psv): Test distribution of metrics using different
            # distribution strategies.

            # Creating a wrapper for merge_fn. merge_call invokes the given
            # merge_fn with the distribution object as the first parameter.
            # We create a wrapper here so that the result function need not
            # have that parameter.
            def merge_fn_wrapper(distribution, merge_fn, *args):
                # We will get a `PerReplica` merge function. Taking the first
                # one, as all are identical copies of the function that we
                # had passed below.
                result = distribution.experimental_local_results(merge_fn)[0](
                    *args
                )

                # Wrapping result in identity so that the control dependency
                # between the update_op from `update_state` and result works
                # in case result returns a tensor.
                return tf.nest.map_structure(tf.identity, result)

            # Wrapping result in merge_call. merge_call is used when we want
            # to leave replica mode and compute a value in cross replica
            # mode.
            result_t = replica_context.merge_call(
                merge_fn_wrapper, args=(result_fn,) + args
            )

        # We are saving the result op here to be used in train/test execution
        # functions. This basically gives the result op that was generated
        # with a control dep to the updates for these workflows.
        metric_obj._call_result = result_t
        return result_t

    return tf.__internal__.decorator.make_decorator(result_fn, decorated)

def weakmethod(method):
    """Creates a weak reference to the bound method."""

    # Introspect the bound method via its Python 3 attributes.
    cls = method.__self__.__class__
    func = method.__func__
    instance_ref = weakref.ref(method.__self__)

    @functools.wraps(method)
    def inner(*args, **kwargs):
        return func.__get__(instance_ref(), cls)(*args, **kwargs)

    del method
    return inner

def assert_thresholds_range(thresholds):
    if thresholds is not None:
        invalid_thresholds = [
            t for t in thresholds if t is None or t < 0 or t > 1
        ]
        if invalid_thresholds:
            raise ValueError(
                "Threshold values must be in [0, 1]. "
                f"Received: {invalid_thresholds}"
            )

def parse_init_thresholds(thresholds, default_threshold=0.5):
    if thresholds is not None:
        assert_thresholds_range(to_list(thresholds))
    thresholds = to_list(
        default_threshold if thresholds is None else thresholds
    )
    return thresholds
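
# Illustrative example (not part of the original module): expected behavior
# of the two helpers above, assuming the default threshold of 0.5.
def _example_parse_init_thresholds():
    assert parse_init_thresholds(None) == [0.5]  # falls back to the default
    assert parse_init_thresholds(0.3) == [0.3]  # scalar is wrapped in a list
    assert parse_init_thresholds([0.1, 0.9]) == [0.1, 0.9]
    try:
        parse_init_thresholds([0.5, 1.5])  # 1.5 is outside [0, 1]
    except ValueError:
        pass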

class ConfusionMatrix(Enum):
    TRUE_POSITIVES = "tp"
    FALSE_POSITIVES = "fp"
    TRUE_NEGATIVES = "tn"
    FALSE_NEGATIVES = "fn"

class AUCCurve(Enum):
    """Type of AUC Curve (ROC or PR)."""

    ROC = "ROC"
    PR = "PR"

    @staticmethod
    def from_str(key):
        if key in ("pr", "PR"):
            return AUCCurve.PR
        elif key in ("roc", "ROC"):
            return AUCCurve.ROC
        else:
            raise ValueError(
                f'Invalid AUC curve value: "{key}". '
                'Expected values are ["PR", "ROC"]'
            )

class AUCSummationMethod(Enum):
    """Type of AUC summation method.

    See https://en.wikipedia.org/wiki/Riemann_sum.

    Contains the following values:
    * 'interpolation': Applies mid-point summation scheme for `ROC` curve.
        For `PR` curve, interpolates (true/false) positives but not the ratio
        that is precision (see Davis & Goadrich 2006 for details).
    * 'minoring': Applies left summation for increasing intervals and right
        summation for decreasing intervals.
    * 'majoring': Applies right summation for increasing intervals and left
        summation for decreasing intervals.
    """

    INTERPOLATION = "interpolation"
    MAJORING = "majoring"
    MINORING = "minoring"

    @staticmethod
    def from_str(key):
        if key in ("interpolation", "Interpolation"):
            return AUCSummationMethod.INTERPOLATION
        elif key in ("majoring", "Majoring"):
            return AUCSummationMethod.MAJORING
        elif key in ("minoring", "Minoring"):
            return AUCSummationMethod.MINORING
        else:
            raise ValueError(
                f'Invalid AUC summation method value: "{key}". '
                'Expected values are ["interpolation", "majoring", "minoring"]'
            )

def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False,
):
    """Update confusion matrix variables with a memory-efficient alternative.

    Note that the thresholds need to be evenly distributed within the list,
    i.e., the diff between consecutive elements is the same.

    To compute TP/FP/TN/FN, we are measuring a binary classifier
      C(t) = (predictions >= t)
    at each threshold 't'. So we have
      TP(t) = sum( C(t) * true_labels )
      FP(t) = sum( C(t) * false_labels )

    But computing C(t) requires computation for each t. To make it fast,
    observe that C(t) is a cumulative integral, and so if we have
      thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1}
    where n = num_thresholds, and if we can compute the bucket function
      B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
    then we get
      C(t_i) = sum( B(j), j >= i )
    which is the reversed cumulative sum in tf.cumsum().

    We can compute B(i) efficiently by taking advantage of the fact that
    our thresholds are evenly distributed, in that
      width = 1.0 / (num_thresholds - 1)
      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
    Given a prediction value p, we can map it to its bucket by
      bucket_index(p) = floor( p * (num_thresholds - 1) )
    so we can use tf.math.unsorted_segment_sum() to update the buckets in one
    pass.

    Consider the following example:
      y_true = [0, 0, 1, 1]
      y_pred = [0.1, 0.5, 0.3, 0.9]
      thresholds = [0.0, 0.5, 1.0]
      num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
      bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                           = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                           = [0, 0, 0, 1]
      # The meaning of this bucket is that if any of the labels is true,
      # then 1 will be added to the corresponding bucket with the index.
      # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0.
      # If the label for 1.8 is true, then 1 will be added to bucket 1.
      #
      # Note the second item "1.0" is floored to 0, since the value needs to
      # be strictly larger than the bucket lower bound.
      # In the implementation, we use tf.math.ceil() - 1 to achieve this.
      tp_bucket_value = tf.math.unsorted_segment_sum(
          true_labels, bucket_indices, num_segments=num_thresholds)
                      = [1, 1, 0]
      # For [1, 1, 0] here, it means there is 1 true value contributed by
      # bucket 0, and 1 value contributed by bucket 1. When we aggregate them
      # together, the result becomes [a + b + c, b + c, c], since larger
      # thresholds always contribute to the value for smaller thresholds.
      true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                    = [2, 1, 0]

    This implementation exhibits a run time and space complexity of O(T + N),
    where T is the number of thresholds and N is the size of predictions.
    Metrics that rely on the standard implementation instead exhibit a
    complexity of O(T * N).

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
        keys and corresponding variables to update as values.
      y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be
        cast to `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values
        are in the range `[0, 1]`.
      thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
        It needs to be evenly distributed (the diff between consecutive
        elements needs to be the same).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the
        number of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      sample_weights: Optional `Tensor` whose rank is either 0, or the same
        rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
        dimensions must be either `1`, or the same as the corresponding
        `y_true` dimension).
      label_weights: Optional tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_with_epsilon: Optional boolean indicating whether the
        leading and trailing thresholds have any epsilon added for floating
        point imprecisions. It will change how we handle the leading and
        trailing buckets.

    Returns:
      Update op.
    """
    num_thresholds = thresholds.shape.as_list()[0]

    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weights, dtype=y_pred.dtype), y_pred
        )
        if not multi_label:
            sample_weights = tf.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        if not multi_label:
            label_weights = tf.reshape(label_weights, [-1])
    weights = tf.cast(tf.multiply(sample_weights, label_weights), y_true.dtype)

    # We shouldn't need this, but in case there are prediction values out of
    # the range of [0.0, 1.0].
    y_pred = tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    y_true = tf.cast(tf.cast(y_true, tf.bool), y_true.dtype)
    if not multi_label:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

    true_labels = tf.multiply(y_true, weights)
    false_labels = tf.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since a prediction value has to be strictly greater than the threshold
    # (e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the
    # first bucket), we have to use math.ceil(val) - 1 for the bucket.
    bucket_indices = tf.math.ceil(y_pred * (num_thresholds - 1)) - 1

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually be taken into
        # account, since any prediction in [0.0, 1.0] is larger than the
        # first threshold. We change the bucket value from -1 to 0.
        bucket_indices = tf.nn.relu(bucket_indices)

    bucket_indices = tf.cast(bucket_indices, tf.int32)

    if multi_label:
        # We need to run the bucket segment sum for each label class. In the
        # multi_label case, the rank of the label is 2. We first transpose it
        # so that the label dim becomes the first one and we can run the
        # classes in parallel.
        true_labels = tf.transpose(true_labels)
        false_labels = tf.transpose(false_labels)
        bucket_indices = tf.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            label, bucket_index = (
                label_and_bucket_index[0],
                label_and_bucket_index[1],
            )
            return tf.math.unsorted_segment_sum(
                data=label,
                segment_ids=bucket_index,
                num_segments=num_thresholds,
            )

        tp_bucket_v = tf.vectorized_map(
            gather_bucket, (true_labels, bucket_indices), warn=False
        )
        fp_bucket_v = tf.vectorized_map(
            gather_bucket, (false_labels, bucket_indices), warn=False
        )
        tp = tf.transpose(tf.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = tf.transpose(tf.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        tp_bucket_v = tf.math.unsorted_segment_sum(
            data=true_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        fp_bucket_v = tf.math.unsorted_segment_sum(
            data=false_labels,
            segment_ids=bucket_indices,
            num_segments=num_thresholds,
        )
        tp = tf.cumsum(tp_bucket_v, reverse=True)
        fp = tf.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    if (
        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
    ):
        if multi_label:
            total_true_labels = tf.reduce_sum(true_labels, axis=1)
            total_false_labels = tf.reduce_sum(false_labels, axis=1)
        else:
            total_true_labels = tf.reduce_sum(true_labels)
            total_false_labels = tf.reduce_sum(false_labels)

    update_ops = []
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        update_ops.append(variable.assign_add(tp))
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        update_ops.append(variable.assign_add(fp))
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        update_ops.append(variable.assign_add(tn))
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        update_ops.append(variable.assign_add(fn))
    return tf.group(update_ops)
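
# Runnable version of the worked example from the docstring above
# (illustrative only, not part of the original module): the bucketing trick
# that turns per-threshold counting into one unsorted_segment_sum plus a
# reversed cumulative sum.
def _example_confusion_matrix_bucketing():
    y_true = tf.constant([0.0, 0.0, 1.0, 1.0])
    y_pred = tf.constant([0.1, 0.5, 0.3, 0.9])
    num_thresholds = 3  # thresholds = [0.0, 0.5, 1.0]

    # ceil(p * (num_thresholds - 1)) - 1 maps 0.5 to bucket 0, since a
    # prediction must be strictly greater than a threshold to pass it.
    bucket_indices = tf.cast(
        tf.math.ceil(y_pred * (num_thresholds - 1)) - 1, tf.int32
    )  # -> [0, 0, 0, 1]
    tp_bucket = tf.math.unsorted_segment_sum(
        y_true, bucket_indices, num_segments=num_thresholds
    )  # -> [1., 1., 0.]
    tp = tf.cumsum(tp_bucket, reverse=True)  # -> [2., 1., 0.]
    return tp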

def is_evenly_distributed_thresholds(thresholds):
    """Check if the thresholds list is evenly distributed.

    We could leverage evenly distributed thresholds to use less memory when
    calculating metrics like AUC where each individual threshold needs to be
    evaluated.

    Args:
      thresholds: A python list or tuple, or 1D numpy array whose values are
        in the range [0, 1].

    Returns:
      boolean, whether the values in the inputs are evenly distributed.
    """
    # Check the list values and see if they are evenly distributed.
    num_thresholds = len(thresholds)
    if num_thresholds < 3:
        return False
    even_thresholds = np.arange(num_thresholds, dtype=np.float32) / (
        num_thresholds - 1
    )
    return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
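
# Illustrative example (not part of the original module): evenly spaced
# thresholds qualify for the optimized path; uneven or too-short lists do
# not.
def _example_is_evenly_distributed_thresholds():
    assert is_evenly_distributed_thresholds(np.array([0.0, 0.5, 1.0]))
    assert not is_evenly_distributed_thresholds(np.array([0.0, 0.3, 1.0]))
    assert not is_evenly_distributed_thresholds(np.array([0.0, 1.0]))  # < 3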

def update_confusion_matrix_variables(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    top_k=None,
    class_id=None,
    sample_weight=None,
    multi_label=False,
    label_weights=None,
    thresholds_distributed_evenly=False,
):
    """Returns op to update the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positives: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positives: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, the function
    creates an `update_op` operation that updates the given variables.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    Args:
      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid
        keys and corresponding variables to update as values.
      y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to
        `bool`.
      y_pred: A floating point `Tensor` of arbitrary shape and whose values
        are in the range `[0, 1]`.
      thresholds: A float value, float tensor, python list, or tuple of float
        thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
      top_k: Optional int, indicates that the positive labels should be
        limited to the top k predictions.
      class_id: Optional int, limits the prediction and labels to the class
        specified by this argument.
      sample_weight: Optional `Tensor` whose rank is either 0, or the same
        rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
        dimensions must be either `1`, or the same as the corresponding
        `y_true` dimension).
      multi_label: Optional boolean indicating whether multidimensional
        prediction/labels should be treated as multilabel responses, or
        flattened into a single label. When True, the values of
        `variables_to_update` must have a second dimension equal to the
        number of labels in y_true and y_pred, and those tensors must not be
        RaggedTensors.
      label_weights: (optional) tensor of non-negative weights for multilabel
        data. The weights are applied when calculating TP, FP, FN, and TN
        without explicit multilabel handling (i.e. when the data is to be
        flattened).
      thresholds_distributed_evenly: Boolean, whether the thresholds are
        evenly distributed within the list. An optimized method will be used
        if this is the case. See
        _update_confusion_matrix_variables_optimized() for more details.

    Returns:
      Update op.

    Raises:
      ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
        `sample_weight` is not `None` and its shape doesn't match `y_pred`,
        or if `variables_to_update` contains invalid keys.
    """
    if multi_label and label_weights is not None:
        raise ValueError(
            "`label_weights` for multilabel data should be handled "
            "outside of `update_confusion_matrix_variables` when "
            "`multi_label` is True."
        )
    if variables_to_update is None:
        return
    if not any(
        key for key in variables_to_update if key in list(ConfusionMatrix)
    ):
        raise ValueError(
            "Please provide at least one valid confusion matrix "
            "variable to update. Valid variable key options are: "
            f'"{list(ConfusionMatrix)}". '
            f'Received: "{variables_to_update.keys()}"'
        )

    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = tf.cast(y_true, dtype=variable_dtype)
    y_pred = tf.cast(y_pred, dtype=variable_dtype)

    if thresholds_distributed_evenly:
        # Check whether the thresholds have any leading or trailing epsilon
        # added for floating point imprecision. The leading and trailing
        # thresholds will be handled a bit differently as corner cases. At
        # this point, thresholds should be a list/array with more than 2
        # items, and ranged between [0, 1]. See
        # is_evenly_distributed_thresholds() for more details.
        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

    thresholds = tf.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = thresholds.shape.as_list()[0]

    if multi_label:
        one_thresh = tf.equal(
            tf.cast(1, dtype=tf.int32),
            tf.rank(thresholds),
            name="one_set_of_thresholds_cond",
        )
    else:
        [y_pred, y_true], _ = ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight
        )
        one_thresh = tf.cast(True, dtype=tf.bool)

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            f'Invalid keys: "{invalid_keys}". '
            f'Valid variable key options are: "{list(ConfusionMatrix)}"'
        )

    if sample_weight is None:
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )
    else:
        sample_weight = tf.cast(sample_weight, dtype=variable_dtype)
        (
            y_pred,
            y_true,
            sample_weight,
        ) = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true, sample_weight=sample_weight
        )
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        # Preserve dimension to match with sample_weight
        y_true = y_true[..., class_id, None]
        y_pred = y_pred[..., class_id, None]

    if thresholds_distributed_evenly:
        return _update_confusion_matrix_variables_optimized(
            variables_to_update,
            y_true,
            y_pred,
            thresholds,
            multi_label=multi_label,
            sample_weights=sample_weight,
            label_weights=label_weights,
            thresholds_with_epsilon=thresholds_with_epsilon,
        )

    pred_shape = tf.shape(y_pred)
    num_predictions = pred_shape[0]
    if y_pred.shape.ndims == 1:
        num_labels = 1
    else:
        num_labels = tf.math.reduce_prod(pred_shape[1:], axis=0)
    thresh_label_tile = tf.where(
        one_thresh, num_labels, tf.ones([], dtype=tf.int32)
    )

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = tf.expand_dims(y_pred, 0)
        labels_extra_dim = tf.expand_dims(tf.cast(y_true, dtype=tf.bool), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = tf.reshape(y_pred, [1, -1])
        labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = tf.tile(
        tf.reshape(thresholds, thresh_pretile_shape), tf.stack(thresh_tiles)
    )

    # Tile the predictions for every threshold.
    preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = tf.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weight, dtype=variable_dtype), y_pred
        )
        weights_tiled = tf.tile(
            tf.reshape(sample_weight, thresh_tiles), data_tiles
        )
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred
        )
        label_weights_tiled = tf.tile(
            tf.reshape(label_weights, thresh_tiles), data_tiles
        )
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = tf.multiply(weights_tiled, label_weights_tiled)

    update_ops = []

    def weighted_assign_add(label, pred, weights, var):
        label_and_pred = tf.cast(tf.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= tf.cast(weights, dtype=var.dtype)
        return var.assign_add(tf.reduce_sum(label_and_pred, 1))

    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = tf.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = tf.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg,
                pred_is_neg,
            )

    for matrix_cond, (label, pred) in loop_vars.items():
        if matrix_cond in variables_to_update:
            update_ops.append(
                weighted_assign_add(
                    label, pred, weights_tiled, variables_to_update[matrix_cond]
                )
            )

    return tf.group(update_ops)
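
# Usage sketch (illustrative, not part of the original module; assumes eager
# execution): updating true/false positive counters for three thresholds.
# With y_pred compared strictly against each threshold, the expected totals
# are tp = [2, 1, 0] and fp = [2, 0, 0].
def _example_update_confusion_matrix_variables():
    tp = tf.Variable(tf.zeros([3]))
    fp = tf.Variable(tf.zeros([3]))
    update_confusion_matrix_variables(
        {
            ConfusionMatrix.TRUE_POSITIVES: tp,
            ConfusionMatrix.FALSE_POSITIVES: fp,
        },
        y_true=tf.constant([0.0, 0.0, 1.0, 1.0]),
        y_pred=tf.constant([0.1, 0.5, 0.3, 0.9]),
        thresholds=[0.0, 0.5, 1.0],
    )
    return tp, fp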

def _filter_top_k(x, k):
    """Filters top-k values in the last dim of x and sets the rest to NEG_INF.

    Used for computing top-k prediction values in dense labels (which have
    the same shape as predictions) for recall and precision top-k metrics.

    Args:
      x: tensor with any dimensions.
      k: the number of values to keep.

    Returns:
      tensor with same shape and dtype as x.
    """
    _, top_k_idx = tf.math.top_k(x, k, sorted=False)
    top_k_mask = tf.reduce_sum(
        tf.one_hot(top_k_idx, tf.shape(x)[-1], axis=-1), axis=-2
    )
    return x * top_k_mask + NEG_INF * (1 - top_k_mask)
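
# Illustrative example (not part of the original module): keep the top 2
# scores per row and push the rest to NEG_INF so they can never pass a
# threshold in [0, 1].
def _example_filter_top_k():
    x = tf.constant([[0.2, 0.9, 0.5, 0.1]])
    filtered = _filter_top_k(x, k=2)
    # -> [[NEG_INF, 0.9, 0.5, NEG_INF]]: the kept entries are unchanged,
    #    the dropped ones become -1e10.
    return filtered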

def ragged_assert_compatible_and_get_flat_values(values, mask=None):
    """If ragged, it checks the compatibility and then returns the
    flat_values.

    Note: If two tensors are dense, it does not check their compatibility.
    Note: Although two ragged tensors with different ragged ranks could have
          identical overall rank and dimension sizes and hence be compatible,
          we do not support those cases.

    Args:
      values: A list of potentially ragged tensors of the same ragged_rank.
      mask: A potentially ragged tensor of the same ragged_rank as the
        elements in values.

    Returns:
      A tuple in which the first element is the list of tensors and the
      second is the mask tensor. ([Values], mask). The mask and the elements
      in values are equal to the flat_values of the input arguments (if they
      were ragged).
    """
    if isinstance(values, list):
        is_all_ragged = all(isinstance(rt, tf.RaggedTensor) for rt in values)
        is_any_ragged = any(isinstance(rt, tf.RaggedTensor) for rt in values)
    else:
        is_all_ragged = isinstance(values, tf.RaggedTensor)
        is_any_ragged = is_all_ragged
    if is_all_ragged and ((mask is None) or isinstance(mask, tf.RaggedTensor)):
        to_be_stripped = False
        if not isinstance(values, list):
            values = [values]
            to_be_stripped = True

        # NOTE: we leave the flat_values compatibility to
        # tf.TensorShape `assert_is_compatible_with` check if both dynamic
        # dimensions are equal and then use the flat_values.
        nested_row_split_list = [rt.nested_row_splits for rt in values]
        assertion_list = _assert_splits_match(nested_row_split_list)

        # If both are ragged, sample_weights also should be ragged with the
        # same dims.
        if isinstance(mask, tf.RaggedTensor):
            assertion_list_for_mask = _assert_splits_match(
                [nested_row_split_list[0], mask.nested_row_splits]
            )
            with tf.control_dependencies(assertion_list_for_mask):
                mask = tf.expand_dims(mask.flat_values, -1)

        # values has at least 1 element.
        flat_values = []
        for value in values:
            with tf.control_dependencies(assertion_list):
                flat_values.append(tf.expand_dims(value.flat_values, -1))

        values = flat_values[0] if to_be_stripped else flat_values

    elif is_any_ragged:
        raise TypeError(
            "Some of the inputs are not tf.RaggedTensor. "
            f"Input received: {values}"
        )
    # values is empty or values are not ragged and mask is ragged.
    elif isinstance(mask, tf.RaggedTensor):
        raise TypeError(
            "Ragged mask is not allowed with non-ragged inputs. "
            f"Input received: {values}, mask received: {mask}"
        )

    return values, mask
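
# Illustrative example (not part of the original module): compatible ragged
# inputs are flattened to dense tensors; the ragged mask must share the same
# row splits as the values.
def _example_ragged_flat_values():
    values = tf.ragged.constant([[0.2, 0.8], [0.5]])
    mask = tf.ragged.constant([[1.0, 1.0], [0.0]])
    [flat_values], flat_mask = ragged_assert_compatible_and_get_flat_values(
        [values], mask
    )
    # Both outputs are dense tensors of shape [3, 1].
    return flat_values, flat_mask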

def _assert_splits_match(nested_splits_lists):
    """Checks that the given splits lists are identical.

    Performs static tests to ensure that the given splits lists are
    identical, and returns a list of control dependency op tensors that check
    that they are fully identical.

    Args:
      nested_splits_lists: A list of nested_splits_lists, where each
        split_list is a list of `splits` tensors from a `RaggedTensor`,
        ordered from outermost ragged dimension to innermost ragged
        dimension.

    Returns:
      A list of control dependency op tensors.
    Raises:
      ValueError: If the splits are not identical.
    """
    error_msg = (
        "Inputs must have identical ragged splits. "
        f"Input received: {nested_splits_lists}"
    )
    for splits_list in nested_splits_lists:
        if len(splits_list) != len(nested_splits_lists[0]):
            raise ValueError(error_msg)
    return [
        tf.debugging.assert_equal(s1, s2, message=error_msg)
        for splits_list in nested_splits_lists[1:]
        for (s1, s2) in zip(nested_splits_lists[0], splits_list)
    ]

def binary_matches(y_true, y_pred, threshold=0.5):
    """Creates a float Tensor, 1.0 for label-prediction match, 0.0 for
    mismatch.

    Args:
      y_true: Ground truth values, of shape (batch_size, d0, .. dN).
      y_pred: The predicted values, of shape (batch_size, d0, .. dN).
      threshold: (Optional) Float representing the threshold for deciding
        whether prediction values are 1 or 0.

    Returns:
      Binary matches, of shape (batch_size, d0, .. dN).
    """
    y_pred = tf.convert_to_tensor(y_pred)
    threshold = tf.cast(threshold, y_pred.dtype)
    y_pred = tf.cast(y_pred > threshold, y_pred.dtype)
    return tf.cast(tf.equal(y_true, y_pred), backend.floatx())
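
# Illustrative example (not part of the original module): with the default
# threshold of 0.5, predictions are binarized before the elementwise
# comparison.
def _example_binary_matches():
    y_true = tf.constant([[1.0, 0.0, 1.0]])
    y_pred = tf.constant([[0.9, 0.4, 0.2]])
    matches = binary_matches(y_true, y_pred)
    # -> [[1., 1., 0.]]: 0.9 -> 1 (match), 0.4 -> 0 (match), 0.2 -> 0
    #    (mismatch against the true label 1).
    return matches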

def sparse_categorical_matches(y_true, y_pred):
    """Creates float Tensor, 1.0 for label-prediction match, 0.0 for
    mismatch.

    You can provide logits of classes as `y_pred`, since the argmax of
    logits and probabilities is the same.

    Args:
      y_true: Integer ground truth values.
      y_pred: The prediction values.

    Returns:
      Match tensor: 1.0 for label-prediction match, 0.0 for mismatch.
    """
    reshape_matches = False
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    y_true_org_shape = tf.shape(y_true)
    y_pred_rank = y_pred.shape.ndims
    y_true_rank = y_true.shape.ndims

    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
    if (
        (y_true_rank is not None)
        and (y_pred_rank is not None)
        and (len(backend.int_shape(y_true)) == len(backend.int_shape(y_pred)))
    ):
        y_true = tf.squeeze(y_true, [-1])
        reshape_matches = True
    y_pred = tf.math.argmax(y_pred, axis=-1)

    # If the predicted output and actual output types don't match, force cast
    # them to match.
    if backend.dtype(y_pred) != backend.dtype(y_true):
        y_pred = tf.cast(y_pred, backend.dtype(y_true))
    matches = tf.cast(tf.equal(y_true, y_pred), backend.floatx())
    if reshape_matches:
        matches = tf.reshape(matches, shape=y_true_org_shape)
    return matches
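
# Illustrative example (not part of the original module): integer labels
# against class scores; the argmax of each row is compared with the label.
def _example_sparse_categorical_matches():
    y_true = tf.constant([1, 2])
    y_pred = tf.constant([[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]])
    matches = sparse_categorical_matches(y_true, y_pred)
    # -> [1., 0.]: argmax rows are [1, 0], matching only the first label.
    return matches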

def sparse_top_k_categorical_matches(y_true, y_pred, k=5):
    """Creates float Tensor, 1.0 for label-top_k prediction match, 0.0 for
    mismatch.

    Args:
      y_true: tensor of true targets.
      y_pred: tensor of predicted targets.
      k: (Optional) Number of top elements to look at for computing accuracy.
        Defaults to 5.

    Returns:
      Match tensor: 1.0 for label-prediction match, 0.0 for mismatch.
    """
    reshape_matches = False
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)
    y_true_rank = y_true.shape.ndims
    y_pred_rank = y_pred.shape.ndims
    y_true_org_shape = tf.shape(y_true)

    # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
    if (y_true_rank is not None) and (y_pred_rank is not None):
        if y_pred_rank > 2:
            y_pred = tf.reshape(y_pred, [-1, y_pred.shape[-1]])
        if y_true_rank > 1:
            reshape_matches = True
            y_true = tf.reshape(y_true, [-1])

    matches = tf.cast(
        tf.math.in_top_k(
            predictions=y_pred, targets=tf.cast(y_true, "int32"), k=k
        ),
        dtype=backend.floatx(),
    )

    # The returned matches are expected to have the same shape as the y_true
    # input.
    if reshape_matches:
        return tf.reshape(matches, shape=y_true_org_shape)

    return matches
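
# Illustrative example (not part of the original module): a label counts as
# a match if it falls inside the top-k predictions for its row.
def _example_sparse_top_k_categorical_matches():
    y_true = tf.constant([2, 1])
    y_pred = tf.constant([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
    matches = sparse_top_k_categorical_matches(y_true, y_pred, k=2)
    # -> [1., 0.]: class 2 is in the top 2 of row 0; class 1 is not in the
    #    top 2 of row 1.
    return matches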