Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/plugins/pr_curve/summary.py: 12%


# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Precision--recall curves and TensorFlow operations to create them.

NOTE: This module is in beta, and its API is subject to change, but the
data that it stores to disk will be supported forever.
"""


import numpy as np

from tensorboard.plugins.pr_curve import metadata


# A value that we use as the minimum value during division of counts to prevent
# division by 0. 1.0 does not work: Certain weights could cause counts below 1.
_MINIMUM_COUNT = 1e-7

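# For illustration only (not part of the library's logic): clamping the
# denominator at _MINIMUM_COUNT only affects empty buckets. For example, if
# tp = 0.0 and fp = 0.0 at some threshold, precision becomes 0.0 / 1e-7 == 0.0
# instead of the NaN that 0.0 / 0.0 would produce, whereas clamping at 1.0
# would also distort thresholds whose weighted counts are legitimately below 1.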
# The default number of thresholds.
_DEFAULT_NUM_THRESHOLDS = 201


def op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a PR curve summary op for a single binary classifier.

    Computes true/false positive/negative values for the given `predictions`
    against the ground truth `labels`, against a list of evenly distributed
    threshold values in `[0, 1]` of length `num_thresholds`.

    Each number in `predictions`, a float in `[0, 1]`, is compared with its
    corresponding boolean label in `labels`, and counts as a single tp/fp/tn/fn
    value at each threshold. Each count is then multiplied by `weights`, which
    can be used to reweight certain values or, more commonly, to mask values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values. A Tensor of `bool` values with arbitrary
        shape.
      predictions: A float32 `Tensor` whose values are in the range `[0, 1]`.
        Dimensions must match those of `labels`.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be `>= 2`. This value should be a
        constant integer value, not a Tensor that stores an integer.
      weights: Optional float32 `Tensor`. Individual counts are multiplied by
        this value. This tensor must be either the same shape as or
        broadcastable to the `labels` tensor.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collection keys. The new
        summary op is added to these collections. Defaults to
        `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. The float32 tensor
      produced by the summary operation is of dimension (6, num_thresholds).
      The first dimension (of length 6) is of the order: true positives,
      false positives, true negatives, false negatives, precision, recall.
    """

    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    dtype = predictions.dtype

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tf.assert_type(labels, tf.bool)
        # We cast to float to ensure we have 0.0 or 1.0.
        f_labels = tf.cast(labels, dtype)
        # Ensure predictions are all in range [0.0, 1.0].
        predictions = tf.minimum(1.0, tf.maximum(0.0, predictions))
        # Get weighted true/false labels.
        true_labels = f_labels * weights
        false_labels = (1.0 - f_labels) * weights

        # Before we begin, flatten predictions.
        predictions = tf.reshape(predictions, [-1])

        # Shape the labels so they are broadcast-able for later multiplication.
        true_labels = tf.reshape(true_labels, [-1, 1])
        false_labels = tf.reshape(false_labels, [-1, 1])

        # To compute TP/FP/TN/FN, we are measuring a binary classifier
        #   C(t) = (predictions >= t)
        # at each threshold 't'. So we have
        #   TP(t) = sum( C(t) * true_labels )
        #   FP(t) = sum( C(t) * false_labels )
        #
        # But, computing C(t) requires computation for each t. To make it fast,
        # observe that C(t) is a cumulative integral, and so if we have
        #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
        # where n = num_thresholds, and if we can compute the bucket function
        #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
        # then we get
        #   C(t_i) = sum( B(j), j >= i )
        # which is the reversed cumulative sum in tf.cumsum().
        #
        # We can compute B(i) efficiently by taking advantage of the fact that
        # our thresholds are evenly distributed, in that
        #   width = 1.0 / (num_thresholds - 1)
        #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
        # Given a prediction value p, we can map it to its bucket by
        #   bucket_index(p) = floor( p * (num_thresholds - 1) )
        # so we can fill all buckets in a single pass with a one-hot encoding
        # followed by a reduce_sum (a scatter-add-style update).

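        # Worked example (illustrative only, not executed here): with
        # num_thresholds = 5, the thresholds are [0.0, 0.25, 0.5, 0.75, 1.0].
        # Suppose predictions = [0.1, 0.6, 0.6, 0.9] and
        # labels = [False, True, True, True] with unit weights. Then
        #   bucket_index = floor(p * 4)       -> [0, 2, 2, 3]
        #   tp_buckets                         = [0, 0, 2, 1, 0]
        #   fp_buckets                         = [1, 0, 0, 0, 0]
        #   tp = reverse cumsum(tp_buckets)    = [3, 3, 3, 1, 0]
        #   fp = reverse cumsum(fp_buckets)    = [1, 0, 0, 0, 0]
        # which matches counting (predictions >= t) directly at each threshold.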
        # Compute the bucket indices for each prediction value.
        bucket_indices = tf.cast(
            tf.floor(predictions * (num_thresholds - 1)), tf.int32
        )

        # Bucket predictions.
        tp_buckets = tf.reduce_sum(
            input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
            * true_labels,
            axis=0,
        )
        fp_buckets = tf.reduce_sum(
            input_tensor=tf.one_hot(bucket_indices, depth=num_thresholds)
            * false_labels,
            axis=0,
        )

        # Set up the cumulative sums to compute the actual metrics.
        tp = tf.cumsum(tp_buckets, reverse=True, name="tp")
        fp = tf.cumsum(fp_buckets, reverse=True, name="fp")
        # fn = sum(true_labels) - tp
        #    = sum(tp_buckets) - tp
        #    = tp[0] - tp
        # Similarly,
        # tn = fp[0] - fp
        tn = fp[0] - fp
        fn = tp[0] - tp

        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        return _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )

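# Example usage of `op` (an illustrative sketch, not part of this module's
# API; `label_tensor` and `score_tensor` are hypothetical placeholders for a
# bool Tensor and a float32 Tensor with values in [0, 1]):
#
#   pr_summary = op(
#       name="my_classifier",
#       labels=label_tensor,
#       predictions=score_tensor,
#       num_thresholds=201,
#   )
#   # The returned op can then be evaluated in a session and the resulting
#   # serialized summary written to an event file with a TF1 summary
#   # FileWriter, as with any other summary op.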
def pb(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      labels: The ground truth values. A bool numpy array.
      predictions: A float32 numpy array whose values are in the range `[0, 1]`.
        Dimensions must match those of `labels`.
      num_thresholds: Optional number of thresholds, evenly distributed in
        `[0, 1]`, to compute PR metrics for. When provided, should be an int of
        value at least 2. Defaults to 201.
      weights: Optional float or float32 numpy array. Individual counts are
        multiplied by this value. This array must be either the same shape as
        or broadcastable to the `labels` numpy array.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
        Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
        Markdown is supported. Defaults to empty.

    Returns:
      A summary protobuf containing the PR curve data as a float32 tensor of
      dimension (6, num_thresholds), in the order: true positives, false
      positives, true negatives, false negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf  # noqa: F401

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(float)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    return raw_data_pb(
        name,
        true_positive_counts=tp,
        false_positive_counts=fp,
        true_negative_counts=tn,
        false_negative_counts=fn,
        precision=precision,
        recall=recall,
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
    )

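# Example usage of `pb` (an illustrative sketch; the arrays below are made-up
# placeholders, not part of this module):
#
#   labels = np.array([False, True, True, True])
#   predictions = np.array([0.1, 0.6, 0.6, 0.9], dtype=np.float32)
#   summary_proto = pb("my_classifier", labels, predictions, num_thresholds=5)
#   # `summary_proto` is a Summary protobuf; it can be passed to a writer's
#   # add_summary() method to record it in an event file.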
def streaming_op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    metrics_collections=None,
    updates_collections=None,
    display_name=None,
    description=None,
):
    """Computes a precision-recall curve summary across batches of data.

    This function is similar to op() above, but can be used to compute the PR
    curve across multiple batches of labels and predictions, in the same style
    as the metrics found in tf.metrics.

    This function creates multiple local variables for storing true positives,
    true negatives, etc. accumulated over each batch of data, and uses these
    local variables for computing the final PR curve summary. These variables
    can be updated with the returned update_op.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values, a `Tensor` whose dimensions must match
        `predictions`. Will be cast to `bool`.
      predictions: A floating point `Tensor` of arbitrary shape and whose values
        are in the range `[0, 1]`.
      num_thresholds: The number of evenly spaced thresholds to generate for
        computing the PR curve. Defaults to 201.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `labels`, and must be broadcastable to `labels` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `labels`
        dimension).
      metrics_collections: An optional list of collections that the PR curve
        summary tensor should be added to.
      updates_collections: An optional list of collections that `update_op`
        should be added to.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      pr_curve: A string `Tensor` containing a single value: the
        serialized PR curve Tensor summary. The summary contains a
        float32 `Tensor` of dimension (6, num_thresholds). The first
        dimension (of length 6) is of the order: true positives, false
        positives, true negatives, false negatives, precision, recall.
      update_op: An operation that updates the summary with the latest data.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    thresholds = [i / float(num_thresholds - 1) for i in range(num_thresholds)]

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tp, update_tp = tf.metrics.true_positives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        fp, update_fp = tf.metrics.false_positives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        tn, update_tn = tf.metrics.true_negatives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        fn, update_fn = tf.metrics.false_negatives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )

        def compute_summary(tp, fp, tn, fn, collections):
            precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
            recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

            return _create_tensor_summary(
                name,
                tp,
                fp,
                tn,
                fn,
                precision,
                recall,
                num_thresholds,
                display_name,
                description,
                collections,
            )

        pr_curve = compute_summary(tp, fp, tn, fn, metrics_collections)
        update_op = tf.group(update_tp, update_fp, update_tn, update_fn)
        if updates_collections:
            for collection in updates_collections:
                tf.add_to_collection(collection, update_op)

        return pr_curve, update_op

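# Example usage of `streaming_op` (an illustrative sketch for TF1 graph mode;
# `label_tensor` and `score_tensor` are hypothetical placeholders):
#
#   pr_curve, update_op = streaming_op(
#       name="my_classifier",
#       labels=label_tensor,
#       predictions=score_tensor,
#   )
#   # The underlying accumulators are TF local variables, so initialize them
#   # (e.g., run tf.local_variables_initializer()), run `update_op` once per
#   # batch, and evaluate `pr_curve` after the final batch to obtain the
#   # accumulated summary.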
def raw_data_op(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create an op that collects data for visualizing PR curves.

    Unlike the op above, this one avoids computing precision, recall, and the
    intermediate counts. Instead, it accepts those tensors as arguments and
    relies on the caller to ensure that the calculations are correct (and the
    counts yield the provided precision and recall values).

    This op is useful when a caller seeks to compute precision and recall
    differently but still use the PR curves plugin.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      true_positive_counts: A rank-1 tensor of true positive counts. Must
        contain `num_thresholds` elements and be castable to float32. Values
        correspond to thresholds that increase from left to right (from 0 to 1).
      false_positive_counts: A rank-1 tensor of false positive counts. Must
        contain `num_thresholds` elements and be castable to float32. Values
        correspond to thresholds that increase from left to right (from 0 to 1).
      true_negative_counts: A rank-1 tensor of true negative counts. Must
        contain `num_thresholds` elements and be castable to float32. Values
        correspond to thresholds that increase from left to right (from 0 to 1).
      false_negative_counts: A rank-1 tensor of false negative counts. Must
        contain `num_thresholds` elements and be castable to float32. Values
        correspond to thresholds that increase from left to right (from 0 to 1).
      precision: A rank-1 tensor of precision values. Must contain
        `num_thresholds` elements and be castable to float32. Values correspond
        to thresholds that increase from left to right (from 0 to 1).
      recall: A rank-1 tensor of recall values. Must contain `num_thresholds`
        elements and be castable to float32. Values correspond to thresholds
        that increase from left to right (from 0 to 1).
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be `>= 2`. This value should be a
        constant integer value, not a Tensor that stores an integer.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collection keys. The new
        summary op is added to these collections. Defaults to
        `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. See docs for the `op`
      method for details on the float32 tensor produced by this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    with tf.name_scope(
        name,
        values=[
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        ],
    ):
        return _create_tensor_summary(
            name,
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )

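# Example usage of `raw_data_op` (an illustrative sketch; `tp`, `fp`, `tn`,
# `fn`, `prec`, and `rec` are hypothetical rank-1 tensors, one value per
# threshold, computed by the caller):
#
#   summary = raw_data_op(
#       name="my_classifier",
#       true_positive_counts=tp,
#       false_positive_counts=fp,
#       true_negative_counts=tn,
#       false_negative_counts=fn,
#       precision=prec,
#       recall=rec,
#       num_thresholds=201,
#   )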
def raw_data_pb(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf from raw data values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      true_positive_counts: A rank-1 numpy array of true positive counts. Must
        contain `num_thresholds` elements and be castable to float32.
      false_positive_counts: A rank-1 numpy array of false positive counts. Must
        contain `num_thresholds` elements and be castable to float32.
      true_negative_counts: A rank-1 numpy array of true negative counts. Must
        contain `num_thresholds` elements and be castable to float32.
      false_negative_counts: A rank-1 numpy array of false negative counts. Must
        contain `num_thresholds` elements and be castable to float32.
      precision: A rank-1 numpy array of precision values. Must contain
        `num_thresholds` elements and be castable to float32.
      recall: A rank-1 numpy array of recall values. Must contain
        `num_thresholds` elements and be castable to float32.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be an int `>= 2`.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
        Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
        Markdown is supported. Defaults to empty.

    Returns:
      A summary protobuf containing the PR curve data. See docs for the `op`
      method for details on the float32 tensor it stores.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if display_name is None:
        display_name = name
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name if display_name is not None else name,
        description=description or "",
        num_thresholds=num_thresholds,
    )
    tf_summary_metadata = tf.SummaryMetadata.FromString(
        summary_metadata.SerializeToString()
    )
    summary = tf.Summary()
    data = np.stack(
        (
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        )
    )
    tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32)
    summary.value.add(
        tag="%s/pr_curves" % name, metadata=tf_summary_metadata, tensor=tensor
    )
    return summary

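# Example usage of `raw_data_pb` (an illustrative sketch; the counts are made
# up but internally consistent: at each threshold tp + fn == 3 and
# fp + tn == 1):
#
#   summary_proto = raw_data_pb(
#       "my_classifier",
#       true_positive_counts=np.array([3.0, 3.0, 3.0, 1.0, 0.0]),
#       false_positive_counts=np.array([1.0, 0.0, 0.0, 0.0, 0.0]),
#       true_negative_counts=np.array([0.0, 1.0, 1.0, 1.0, 1.0]),
#       false_negative_counts=np.array([0.0, 0.0, 0.0, 2.0, 3.0]),
#       precision=np.array([0.75, 1.0, 1.0, 1.0, 0.0]),
#       recall=np.array([1.0, 1.0, 1.0, 1.0 / 3.0, 0.0]),
#       num_thresholds=5,
#   )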
def _create_tensor_summary(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """A private helper method for generating a tensor summary.

    We use a helper method instead of having `op` directly call `raw_data_op`
    to prevent the scope of `raw_data_op` from being embedded within `op`.

    Arguments are the same as for raw_data_op.

    Returns:
      A tensor summary that collects data for PR curves.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    # Store the number of thresholds within the summary metadata because
    # that value is constant for all pr curve summaries with the same tag.
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name if display_name is not None else name,
        description=description or "",
        num_thresholds=num_thresholds,
    )

    # Store values within a tensor. We store them in the order:
    # true positives, false positives, true negatives, false
    # negatives, precision, and recall.
    combined_data = tf.stack(
        [
            tf.cast(true_positive_counts, tf.float32),
            tf.cast(false_positive_counts, tf.float32),
            tf.cast(true_negative_counts, tf.float32),
            tf.cast(false_negative_counts, tf.float32),
            tf.cast(precision, tf.float32),
            tf.cast(recall, tf.float32),
        ]
    )

    return tf.summary.tensor_summary(
        name="pr_curves",
        tensor=combined_data,
        collections=collections,
        summary_metadata=summary_metadata,
    )