# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export



def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  If running in a `DistributionStrategy` context, the variable will be
  "sync on read". This means:

  * The returned object will be a container with separate variables
    per replica of the model.

  * When writing to the variable, e.g. using `assign_add` in a metric
    update, the update will be applied to the variable local to the
    replica.

  * To get a metric's result value, we need to sum the variable values
    across the replicas before computing the final answer. Furthermore,
    the final answer should be computed once instead of in every
    replica. Both of these are accomplished by running the computation
    of the final result value inside
    `distribute_lib.get_replica_context().merge_call(fn)`.
    Inside the `merge_call()`, ops are only added to the graph once
    and access to a sync on read variable in a computation returns
    the sum across all replicas.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """
  # Note that synchronization "ON_READ" implies trainable=False.
  return variable_v1.VariableV1(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      synchronization=variables.VariableSynchronization.ON_READ,
      aggregation=variables.VariableAggregation.SUM,
      name=name)


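# Editor's note: the following is an illustrative sketch added for this
# write-up, not part of the original module. It shows the pattern the
# streaming metrics below build on: allocate zero-initialized accumulator
# state with `metric_variable`, then fold each batch in with `assign_add`.
# The helper name `_example_streaming_sum` is hypothetical.
def _example_streaming_sum(values):
  """Returns (total, update_op) for a running sum of `values` (graph mode)."""
  total = metric_variable([], dtypes.float32, name='example_streaming_sum')
  # Each session run of `update_op` adds the current batch's sum to `total`;
  # reading `total` under a DistributionStrategy sums across replicas.
  update_op = state_ops.assign_add(
      total, math_ops.reduce_sum(math_ops.cast(values, dtypes.float32)))
  return total, update_op

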

def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return cond.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return cond.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = cond.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights


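# Editor's note: an illustrative shape walkthrough added for this write-up
# (not original TF code), assuming graph mode. The helper above reconciles
# ranks so that elementwise ops broadcast cleanly:
#
#   preds = array_ops.zeros([8, 1])   # rank 2
#   labs = array_ops.zeros([8])       # rank 1
#   w = array_ops.ones([8, 1])        # rank 2
#   preds, labs, w = _remove_squeezable_dimensions(preds, labs, w)
#   # preds, labs and w now all have shape [8]; a scalar (rank-0) weight
#   # would have been returned unchanged.

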

def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return cond.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            f'Unexpected labels shape {labels.get_shape()} for predictions '
            f'shape {predictions.get_shape()}. Predictions rank should be the '
            'same rank as labels rank or labels rank plus one.')

    # Otherwise, use dynamic shape.
    return cond.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)



def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)


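# Editor's note: a minimal sketch added for this write-up (not original TF
# code) of the guarded division used throughout this module: `div_no_nan`
# returns 0 whenever the denominator is 0, so streaming ratios such as
# total/count never produce NaN or Inf.
def _example_safe_scalar_div():
  zero = ops.convert_to_tensor(0.0, dtypes.float64)
  one = ops.convert_to_tensor(1.0, dtypes.float64)
  # Both results evaluate to 0.0 when run in a session.
  return (_safe_scalar_div(zero, zero, name='zero_over_zero'),
          _safe_scalar_div(one, zero, name='one_over_zero'))

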

def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op


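# Editor's note: an illustrative example added for this write-up (not
# original TF code), assuming graph mode. With three classes, labels
# [0, 1, 2, 1] and predictions [0, 2, 2, 1], one run of `update_op` leaves
# `total_cm` as:
#   [[1., 0., 0.],
#    [0., 1., 1.],
#    [0., 0., 1.]]
# i.e. rows index the true label and columns the predicted label.

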

def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas."""
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # pylint: disable=protected-access
      if distribution.extended._outer_control_flow_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        distribution.extended._outer_control_flow_context.Enter()
        metric_value = metric_value_fn(distribution, *a)
        distribution.extended._outer_control_flow_context.Exit()
      # pylint: enable=protected-access
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribute_lib.get_replica_context().merge_call(
      fn, args=args)



@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables

      appropriately and whose value matches `mean`.


  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.mean` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Mean` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Mean` object, you can first call the
  `update_state()` method to record the new values, and then call the
  `result()` method to get the mean eagerly. You can also attach it to a
  Keras model with the `add_metric` method. Please refer to the [migration
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to TF2

  Before:

  ```python
  mean, update_op = tf.compat.v1.metrics.mean(
      values=values,
      weights=weights,
      metrics_collections=metrics_collections,

      updates_collections=updates_collections,

      name=name)
  ```

  After:

  ```python
  m = tf.keras.metrics.Mean(
      name=name)

  m.update_state(
      values=values,
      sample_weight=weights)

  mean = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `values`              | `values`        | In `update_state()` method |
  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   values = [1, 2, 3]
  ...   mean, update_op = tf.compat.v1.metrics.mean(values)
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> sess.run(update_op)
  >>> sess.run(mean)
  2.0


  After:

  >>> m = tf.keras.metrics.Mean()
  >>> m.update_state([1, 2, 3])
  >>> m.result().numpy()
  2.0

  ```python
  # Used within Keras model
  model.add_metric(tf.keras.metrics.Mean()(values))
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op


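# Editor's note: a worked example added for this write-up (not original TF
# code). With values [1., 2., 3., 4.] and weights [1., 1., 0., 0.], one run
# of `update_op` gives total = 1*1 + 2*1 = 3.0 and count = 1 + 1 = 2.0, so
# the reported mean is 3.0 / 2.0 = 1.5; the zero weights mask out the last
# two values entirely.

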

@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.accuracy` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Accuracy` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Accuracy` object, you can first call the
  `update_state()` method to record the prediction/labels, and then call the
  `result()` method to get the accuracy eagerly. You can also attach it to a
  Keras model when calling the `compile` method. Please refer to [this
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  accuracy, update_op = tf.compat.v1.metrics.accuracy(
      labels=labels,
      predictions=predictions,
      weights=weights,
      metrics_collections=metrics_collections,

      updates_collections=updates_collections,

      name=name)
  ```

  After:

  ```python
  m = tf.keras.metrics.Accuracy(
      name=name,
      dtype=None)

  m.update_state(
      y_true=labels,
      y_pred=predictions,
      sample_weight=weights)

  accuracy = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |

  | `labels`              | `y_true`        | In `update_state()` method |
  | `predictions`         | `y_pred`        | In `update_state()` method |

  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():

  ...   labels = [1, 2, 3]
  ...   predictions = [0, 2, 3]
  ...   acc, acc_op = tf.compat.v1.metrics.accuracy(labels, predictions)

  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> print(sess.run([acc, acc_op]))
  [0.0, 0.66667]


  After:

  >>> m = tf.keras.metrics.Accuracy()
  >>> m.update_state([1, 2, 3], [0, 2, 3])
  >>> m.result().numpy()
  0.66667

  ```python
  # Used within Keras model
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Accuracy()])
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.cast(
      math_ops.equal(predictions, labels), dtypes.float32)
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')


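# Editor's note: a worked example added for this write-up (not original TF
# code). With labels [1, 0, 1, 1], predictions [1, 1, 1, 0] and weights
# [1., 1., 1., 0.], `is_correct` is [1., 0., 1., 0.]; the weighted update
# gives total = 2.0 and count = 3.0, so the streamed accuracy is 2/3.

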

def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,

      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
      `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
      `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError(f'Invalid key: {include}')

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops_stack.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops


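# Editor's note: a worked example added for this write-up (not original TF
# code). With thresholds=[0.5], labels [False, True, False] and predictions
# [0.2, 0.7, 0.6], one run of the update ops leaves
#   tp = [1.]  (0.7 > 0.5 with a True label)
#   fp = [1.]  (0.6 > 0.5 with a False label)
#   tn = [1.]  (0.2 <= 0.5 with a False label)
#   fn = [0.]
# with one entry per threshold in each variable.

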

def _aggregate_variable(v, collections):
  f = lambda distribution, value: distribution.extended.read_var(value)
  return _aggregate_across_replicas(collections, f, v)



@tf_export(v1=['metrics.auc'])
@deprecated(None,
            'The value of AUC returned by this may race with the update so '
            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal',
        thresholds=None):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall

  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.


  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation

  by providing a lower or upper bound estimate of the AUC. The `thresholds`

  parameter can be used to manually specify thresholds which split the
  predictions more evenly.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used
      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
      applies the trapezoidal rule; 'careful_interpolation', a variant of it
      differing only by a more correct interpolation scheme for PR-AUC -
      interpolating (true/false) positives but not the ratio that is precision;
      'minoring' that applies left summation for increasing intervals and right
      summation for decreasing intervals; 'majoring' that does the opposite.
      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
      (to be deprecated soon) as it applies the same method for ROC, and a
      better one (see Davis & Goadrich 2006 for details) for the PR curve.
    thresholds: An optional list of floating point values to use as the
      thresholds for discretizing the curve. If set, the `num_thresholds`
      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
      automatically included with these to correctly handle predictions equal to
      exactly 0 or 1.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError(f'Curve must be either ROC or PR. Curve {curve} is '
                       'unknown.')

    kepsilon = 1e-7  # To account for floating point imprecisions.
    if thresholds is not None:
      # If specified, use the supplied thresholds.
      thresholds = sorted(thresholds)
      num_thresholds = len(thresholds) + 2
    else:
      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
      # (0, 1).
      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                    for i in range(num_thresholds - 2)]

    # Add an endpoint "threshold" below zero and above one for either threshold
    # method.
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of (Davis et al., 2006).

      Note here we derive & use a closed formula not present in the paper
      - as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
        slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts

      Returns:
        pr_auc: an approximation of the area under the P-R curve.

      References:
        The Relationship Between Precision-Recall and ROC Curves:
          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
      """
      dtp = tp[:num_thresholds - 1] - tp[1:]
      p = tp + fp
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError(f'Invalid summation_method: {summation_method} '
                         'summation_method should be \'trapezoidal\', '
                         '\'careful_interpolation\', \'minoring\', or '
                         '\'majoring\'.')

    # sum up the areas of all the trapeziums
    def compute_auc_value(_, values):
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')

    auc_value = _aggregate_across_replicas(
        metrics_collections, compute_auc_value, values)
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op


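# Editor's note: a small worked example added for this write-up (not original
# TF code) of the trapezoidal summation above. If two adjacent thresholds give
# ROC points (x, y) of (0.5, 1.0) and (0.25, 0.75), their contribution is
#   (0.5 - 0.25) * (1.0 + 0.75) / 2 = 0.21875,
# and `auc` is the sum of such trapeziums over all adjacent threshold pairs.

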

@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced

  sum of `weights`.


  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')


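# Editor's note: a worked example added for this write-up (not original TF
# code). With labels [1., 2., 4.] and predictions [1., 3., 7.], the absolute
# errors are [0., 1., 3.]; one update gives total = 4.0 and count = 3.0, so
# mean_absolute_error = 4/3.

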

@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, axis=[
          dim,
      ], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


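# Editor's note: a worked example added for this write-up (not original TF
# code). For unit vectors along dim=1, labels [[1., 0.]] and predictions
# [[0., 1.]] have dot product 0, so mean_distance = 1 - 0 = 1.0; identical
# unit vectors have dot product 1 and give distance 0.0.

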

@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.

    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.cast(labels, dtypes.int64)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.cast(
        math_ops.equal(predictions, labels), dtypes.float32)

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.cast(weights, dtypes.float32)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)

    def compute_mean_accuracy(_, count, total):
      per_class_accuracy = math_ops.div_no_nan(
          count, math_ops.maximum(total, 0), name=None)
      mean_accuracy_v = math_ops.reduce_mean(
          per_class_accuracy, name='mean_accuracy')
      return mean_accuracy_v

    mean_accuracy_v = _aggregate_across_replicas(
        metrics_collections, compute_mean_accuracy, count, total)

    update_op = math_ops.div_no_nan(
        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op


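# Editor's note: a worked example added for this write-up (not original TF
# code). With num_classes=2, labels [0, 0, 1] and predictions [0, 1, 1]:
# class 0 is seen twice with one correct (accuracy 0.5) and class 1 once with
# one correct (accuracy 1.0), so mean_per_class_accuracy = (0.5 + 1.0) / 2
# = 0.75.

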

1282@tf_export(v1=['metrics.mean_iou']) 

1283def mean_iou(labels, 

1284 predictions, 

1285 num_classes, 

1286 weights=None, 

1287 metrics_collections=None, 

1288 updates_collections=None, 

1289 name=None): 

1290 """Calculate per-step mean Intersection-Over-Union (mIOU). 

1291 

1292 Mean Intersection-Over-Union is a common evaluation metric for 

1293 semantic image segmentation, which first computes the IOU for each 

1294 semantic class and then computes the average over classes. 

1295 IOU is defined as follows: 

1296 IOU = true_positive / (true_positive + false_positive + false_negative). 

1297 The predictions are accumulated in a confusion matrix, weighted by `weights`, 

1298 and mIOU is then calculated from it. 

1299 

1300 For estimation of the metric over a stream of data, the function creates an 

1301 `update_op` operation that updates these variables and returns the `mean_iou`. 

1302 

1303 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1304 

1305 Args: 

1306 labels: A `Tensor` of ground truth labels with shape [batch size] and of 

1307 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 

1308 predictions: A `Tensor` of prediction results for semantic labels, whose 

1309 shape is [batch size] and type `int32` or `int64`. The tensor will be 

1310 flattened if its rank > 1. 

1311 num_classes: The possible number of labels the prediction task can 

1312 have. This value must be provided, since a confusion matrix of 

1313 dimension = [num_classes, num_classes] will be allocated. 

1314 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1315 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1316 be either `1`, or the same as the corresponding `labels` dimension). 

1317 metrics_collections: An optional list of collections that `mean_iou` 

1318 should be added to. 

1319 updates_collections: An optional list of collections `update_op` should be 

1320 added to. 

1321 name: An optional variable_scope name. 

1322 

1323 Returns: 

1324 mean_iou: A `Tensor` representing the mean intersection-over-union. 

1325 update_op: An operation that increments the confusion matrix. 

1326 

1327 Raises: 

1328 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1329 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1330 either `metrics_collections` or `updates_collections` are not a list or 

1331 tuple. 

1332 RuntimeError: If eager execution is enabled. 

1333 """ 

1334 if context.executing_eagerly(): 

1335 raise RuntimeError('tf.metrics.mean_iou is not supported when ' 

1336 'eager execution is enabled.') 

1337 

1338 with variable_scope.variable_scope(name, 'mean_iou', 

1339 (predictions, labels, weights)): 

1340 # Check if shape is compatible. 

1341 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 

1342 

1343 total_cm, update_op = _streaming_confusion_matrix(labels, predictions, 

1344 num_classes, weights) 

1345 

1346 def compute_mean_iou(_, total_cm): 

1347 """Compute the mean intersection-over-union via the confusion matrix.""" 

1348 sum_over_row = math_ops.cast( 

1349 math_ops.reduce_sum(total_cm, 0), dtypes.float32) 

1350 sum_over_col = math_ops.cast( 

1351 math_ops.reduce_sum(total_cm, 1), dtypes.float32) 

1352 cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32) 

1353 denominator = sum_over_row + sum_over_col - cm_diag 

1354 

1355 # The mean is only computed over classes that appear in the 

1356 # label or prediction tensor. If the denominator is 0, we need to 

1357 # ignore the class. 

1358 num_valid_entries = math_ops.reduce_sum( 

1359 math_ops.cast( 

1360 math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) 

1361 

1362 # If the value of the denominator is 0, set it to 1 to avoid 

1363 # zero division. 

1364 denominator = array_ops.where( 

1365 math_ops.greater(denominator, 0), denominator, 

1366 array_ops.ones_like(denominator)) 

1367 iou = math_ops.divide(cm_diag, denominator) 

1368 

1369 # If the number of valid entries is 0 (no classes) we return 0. 

1370 result = array_ops.where( 

1371 math_ops.greater(num_valid_entries, 0), 

1372 math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0) 

1373 return result 

1374 

1375 # TODO(priyag): Use outside_compilation if in TPU context. 

1376 mean_iou_v = _aggregate_across_replicas( 

1377 metrics_collections, compute_mean_iou, total_cm) 

1378 

1379 if updates_collections: 

1380 ops.add_to_collections(updates_collections, update_op) 

1381 

1382 return mean_iou_v, update_op 

1383 
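# Hedged usage sketch (editor's illustration): one update step of
# `mean_iou`; TF1 setup (compat.v1 import, disabled eager, a `sess`) as in
# the first sketch above.
#
#   labels = tf.constant([0, 1, 1, 0])
#   predictions = tf.constant([0, 1, 0, 0])
#   miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=2)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)        # accumulates the 2x2 confusion matrix
#   print(sess.run(miou))      # IoU_0 = 2/3, IoU_1 = 1/2 -> mean 0.5833...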

1384 

1385@tf_export(v1=['metrics.mean_relative_error']) 

1386def mean_relative_error(labels, 

1387 predictions, 

1388 normalizer, 

1389 weights=None, 

1390 metrics_collections=None, 

1391 updates_collections=None, 

1392 name=None): 

1393 """Computes the mean relative error by normalizing with the given values. 

1394 

1395 The `mean_relative_error` function creates two local variables, 

1396 `total` and `count` that are used to compute the mean relative absolute error. 

1397 This average is weighted by `weights`, and it is ultimately returned as 

1398 `mean_relative_error`: an idempotent operation that simply divides `total` by 

1399 `count`. 

1400 

1401 For estimation of the metric over a stream of data, the function creates an 

1402 `update_op` operation that updates these variables and returns the 

1403 `mean_relative_error`. Internally, a `relative_errors` operation divides the 

1404 absolute value of the differences between `predictions` and `labels` by the 

1405 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 

1406 product of `weights` and `relative_errors`, and it increments `count` with the 

1407 reduced sum of `weights`. 

1408 

1409 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1410 

1411 Args: 

1412 labels: A `Tensor` of the same shape as `predictions`. 

1413 predictions: A `Tensor` of arbitrary shape. 

1414 normalizer: A `Tensor` of the same shape as `predictions`. 

1415 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1416 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1417 be either `1`, or the same as the corresponding `labels` dimension). 

1418 metrics_collections: An optional list of collections that 

1419 `mean_relative_error` should be added to. 

1420 updates_collections: An optional list of collections that `update_op` should 

1421 be added to. 

1422 name: An optional variable_scope name. 

1423 

1424 Returns: 

1425 mean_relative_error: A `Tensor` representing the current mean, the value of 

1426 `total` divided by `count`. 

1427 update_op: An operation that increments the `total` and `count` variables 

1428 appropriately and whose value matches `mean_relative_error`. 

1429 

1430 Raises: 

1431 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1432 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1433 either `metrics_collections` or `updates_collections` are not a list or 

1434 tuple. 

1435 RuntimeError: If eager execution is enabled. 

1436 """ 

1437 if context.executing_eagerly(): 

1438 raise RuntimeError('tf.metrics.mean_relative_error is not supported when ' 

1439 'eager execution is enabled.') 

1440 

1441 predictions, labels, weights = _remove_squeezable_dimensions( 

1442 predictions=predictions, labels=labels, weights=weights) 

1443 

1444 predictions, normalizer = confusion_matrix.remove_squeezable_dimensions( 

1445 predictions, normalizer) 

1446 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 

1447 relative_errors = array_ops.where( 

1448 math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels), 

1449 math_ops.divide(math_ops.abs(labels - predictions), normalizer)) 

1450 return mean(relative_errors, weights, metrics_collections, 

1451 updates_collections, name or 'mean_relative_error') 

1452 
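# Hedged usage sketch (editor's illustration; TF1 session setup as above).
# Note how a zero normalizer forces that entry's relative error to 0 rather
# than dividing by zero.
#
#   labels = tf.constant([2., 4., 0.])
#   predictions = tf.constant([1., 5., 3.])
#   err, update_op = tf.metrics.mean_relative_error(
#       labels, predictions, normalizer=labels)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(err))  # (0.5 + 0.25 + 0.0) / 3 = 0.25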

1453 

1454@tf_export(v1=['metrics.mean_squared_error']) 

1455def mean_squared_error(labels, 

1456 predictions, 

1457 weights=None, 

1458 metrics_collections=None, 

1459 updates_collections=None, 

1460 name=None): 

1461 """Computes the mean squared error between the labels and predictions. 

1462 

1463 The `mean_squared_error` function creates two local variables, 

1464 `total` and `count` that are used to compute the mean squared error. 

1465 This average is weighted by `weights`, and it is ultimately returned as 

1466 `mean_squared_error`: an idempotent operation that simply divides `total` by 

1467 `count`. 

1468 

1469 For estimation of the metric over a stream of data, the function creates an 

1470 `update_op` operation that updates these variables and returns the 

1471 `mean_squared_error`. Internally, a `squared_error` operation computes the 

1472 element-wise square of the difference between `predictions` and `labels`. Then 

1473 `update_op` increments `total` with the reduced sum of the product of 

1474 `weights` and `squared_error`, and it increments `count` with the reduced sum 

1475 of `weights`. 

1476 

1477 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1478 

1479 Args: 

1480 labels: A `Tensor` of the same shape as `predictions`. 

1481 predictions: A `Tensor` of arbitrary shape. 

1482 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1483 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1484 be either `1`, or the same as the corresponding `labels` dimension). 

1485 metrics_collections: An optional list of collections that 

1486 `mean_squared_error` should be added to. 

1487 updates_collections: An optional list of collections that `update_op` should 

1488 be added to. 

1489 name: An optional variable_scope name. 

1490 

1491 Returns: 

1492 mean_squared_error: A `Tensor` representing the current mean, the value of 

1493 `total` divided by `count`. 

1494 update_op: An operation that increments the `total` and `count` variables 

1495 appropriately and whose value matches `mean_squared_error`. 

1496 

1497 Raises: 

1498 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1499 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1500 either `metrics_collections` or `updates_collections` are not a list or 

1501 tuple. 

1502 RuntimeError: If eager execution is enabled. 

1503 """ 

1504 if context.executing_eagerly(): 

1505 raise RuntimeError('tf.metrics.mean_squared_error is not supported when ' 

1506 'eager execution is enabled.') 

1507 

1508 predictions, labels, weights = _remove_squeezable_dimensions( 

1509 predictions=predictions, labels=labels, weights=weights) 

1510 squared_error = math_ops.squared_difference(labels, predictions) 

1511 return mean(squared_error, weights, metrics_collections, updates_collections, 

1512 name or 'mean_squared_error') 

1513 
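# Hedged usage sketch (editor's illustration; TF1 session setup as above):
#
#   labels = tf.constant([1., 2., 3.])
#   predictions = tf.constant([1., 3., 5.])
#   mse, update_op = tf.metrics.mean_squared_error(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(mse))  # (0 + 1 + 4) / 3 = 1.6667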

1514 

1515@tf_export(v1=['metrics.mean_tensor']) 

1516def mean_tensor(values, 

1517 weights=None, 

1518 metrics_collections=None, 

1519 updates_collections=None, 

1520 name=None): 

1521 """Computes the element-wise (weighted) mean of the given tensors. 

1522 

1523 In contrast to the `mean` function, which returns a scalar mean, this 

1524 function returns an average tensor with the same shape as the input 

1525 tensors. 

1526 

1527 The `mean_tensor` function creates two local variables, 

1528 `total_tensor` and `count_tensor` that are used to compute the average of 

1529 `values`. This average is ultimately returned as `mean` which is an idempotent 

1530 operation that simply divides `total` by `count`. 

1531 

1532 For estimation of the metric over a stream of data, the function creates an 

1533 `update_op` operation that updates these variables and returns the `mean`. 

1534 `update_op` increments `total` with the reduced sum of the product of `values` 

1535 and `weights`, and it increments `count` with the reduced sum of `weights`. 

1536 

1537 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1538 

1539 Args: 

1540 values: A `Tensor` of arbitrary dimensions. 

1541 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1542 `values`, and must be broadcastable to `values` (i.e., all dimensions must 

1543 be either `1`, or the same as the corresponding `values` dimension). 

1544 metrics_collections: An optional list of collections that `mean` 

1545 should be added to. 

1546 updates_collections: An optional list of collections that `update_op` 

1547 should be added to. 

1548 name: An optional variable_scope name. 

1549 

1550 Returns: 

1551 mean: A float `Tensor` representing the current mean, the value of `total` 

1552 divided by `count`. 

1553 update_op: An operation that increments the `total` and `count` variables 

1554 appropriately and whose value matches `mean`. 

1555 

1556 Raises: 

1557 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 

1558 or if either `metrics_collections` or `updates_collections` are not a list 

1559 or tuple. 

1560 RuntimeError: If eager execution is enabled. 

1561 """ 

1562 if context.executing_eagerly(): 

1563 raise RuntimeError('tf.metrics.mean_tensor is not supported when ' 

1564 'eager execution is enabled.') 

1565 

1566 with variable_scope.variable_scope(name, 'mean', (values, weights)): 

1567 values = math_ops.cast(values, dtypes.float32) 

1568 total = metric_variable( 

1569 values.get_shape(), dtypes.float32, name='total_tensor') 

1570 count = metric_variable( 

1571 values.get_shape(), dtypes.float32, name='count_tensor') 

1572 

1573 num_values = array_ops.ones_like(values) 

1574 if weights is not None: 

1575 values, _, weights = _remove_squeezable_dimensions( 

1576 predictions=values, labels=None, weights=weights) 

1577 weights = weights_broadcast_ops.broadcast_weights( 

1578 math_ops.cast(weights, dtypes.float32), values) 

1579 values = math_ops.multiply(values, weights) 

1580 num_values = math_ops.multiply(num_values, weights) 

1581 

1582 update_total_op = state_ops.assign_add(total, values) 

1583 with ops.control_dependencies([values]): 

1584 update_count_op = state_ops.assign_add(count, num_values) 

1585 

1586 compute_mean = lambda _, t, c: math_ops.div_no_nan( # pylint: disable=g-long-lambda 

1587 t, math_ops.maximum(c, 0), name='value') 

1588 

1589 mean_t = _aggregate_across_replicas( 

1590 metrics_collections, compute_mean, total, count) 

1591 

1592 update_op = math_ops.div_no_nan( 

1593 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op') 

1594 if updates_collections: 

1595 ops.add_to_collections(updates_collections, update_op) 

1596 

1597 return mean_t, update_op 

1598 
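# Hedged usage sketch (editor's illustration): `mean_tensor` keeps a running
# element-wise mean across update steps. The input shape must be fully
# defined, since the accumulator variables mirror it. TF1 setup as above.
#
#   v = tf.placeholder(tf.float32, shape=[2])
#   mean_t, update_op = tf.metrics.mean_tensor(v)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op, feed_dict={v: [1., 2.]})
#   sess.run(update_op, feed_dict={v: [3., 4.]})
#   print(sess.run(mean_t))  # [2., 3.]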

1599 

1600@tf_export(v1=['metrics.percentage_below']) 

1601def percentage_below(values, 

1602 threshold, 

1603 weights=None, 

1604 metrics_collections=None, 

1605 updates_collections=None, 

1606 name=None): 

1607 """Computes the percentage of values less than the given threshold. 

1608 

1609 The `percentage_below` function creates two local variables, 

1610 `total` and `count` that are used to compute the percentage of `values` that 

1611 fall below `threshold`. This rate is weighted by `weights`, and it is 

1612 ultimately returned as `percentage` which is an idempotent operation that 

1613 simply divides `total` by `count`. 

1614 

1615 For estimation of the metric over a stream of data, the function creates an 

1616 `update_op` operation that updates these variables and returns the 

1617 `percentage`. 

1618 

1619 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1620 

1621 Args: 

1622 values: A numeric `Tensor` of arbitrary size. 

1623 threshold: A scalar threshold. 

1624 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1625 `values`, and must be broadcastable to `values` (i.e., all dimensions must 

1626 be either `1`, or the same as the corresponding `values` dimension). 

1627 metrics_collections: An optional list of collections that the metric 

1628 value variable should be added to. 

1629 updates_collections: An optional list of collections that the metric update 

1630 ops should be added to. 

1631 name: An optional variable_scope name. 

1632 

1633 Returns: 

1634 percentage: A `Tensor` representing the current mean, the value of `total` 

1635 divided by `count`. 

1636 update_op: An operation that increments the `total` and `count` variables 

1637 appropriately. 

1638 

1639 Raises: 

1640 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 

1641 or if either `metrics_collections` or `updates_collections` are not a list 

1642 or tuple. 

1643 RuntimeError: If eager execution is enabled. 

1644 """ 

1645 if context.executing_eagerly(): 

1646 raise RuntimeError('tf.metrics.percentage_below is not supported when ' 

1647 'eager execution is enabled.') 

1648 

1649 is_below_threshold = math_ops.cast( 

1650 math_ops.less(values, threshold), dtypes.float32) 

1651 return mean(is_below_threshold, weights, metrics_collections, 

1652 updates_collections, name or 'percentage_below_threshold') 

1653 
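# Hedged usage sketch (editor's illustration; TF1 setup as above). The
# comparison is strict, so values equal to the threshold do not count.
#
#   values = tf.constant([1., 2., 3., 4.])
#   pct, update_op = tf.metrics.percentage_below(values, threshold=3.)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(pct))  # 2 of 4 values are < 3 -> 0.5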

1654 

1655def _count_condition(values, 

1656 weights=None, 

1657 metrics_collections=None, 

1658 updates_collections=None): 

1659 """Sums the weights of cases where the given values are True. 

1660 

1661 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1662 

1663 Args: 

1664 values: A `bool` `Tensor` of arbitrary size. 

1665 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1666 `values`, and must be broadcastable to `values` (i.e., all dimensions must 

1667 be either `1`, or the same as the corresponding `values` dimension). 

1668 metrics_collections: An optional list of collections that the metric 

1669 value variable should be added to. 

1670 updates_collections: An optional list of collections that the metric update 

1671 ops should be added to. 

1672 

1673 Returns: 

1674 value_tensor: A `Tensor` representing the current value of the metric. 

1675 update_op: An operation that accumulates the error from a batch of data. 

1676 

1677 Raises: 

1678 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 

1679 or if either `metrics_collections` or `updates_collections` are not a list 

1680 or tuple. 

1681 """ 

1682 check_ops.assert_type(values, dtypes.bool) 

1683 count = metric_variable([], dtypes.float32, name='count') 

1684 

1685 values = math_ops.cast(values, dtypes.float32) 

1686 if weights is not None: 

1687 with ops.control_dependencies((check_ops.assert_rank_in( 

1688 weights, (0, array_ops.rank(values))),)): 

1689 weights = math_ops.cast(weights, dtypes.float32) 

1690 values = math_ops.multiply(values, weights) 

1691 

1692 value_tensor = _aggregate_variable(count, metrics_collections) 

1693 

1694 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 

1695 if updates_collections: 

1696 ops.add_to_collections(updates_collections, update_op) 

1697 

1698 return value_tensor, update_op 

1699 

1700 

1701@tf_export(v1=['metrics.false_negatives']) 

1702def false_negatives(labels, 

1703 predictions, 

1704 weights=None, 

1705 metrics_collections=None, 

1706 updates_collections=None, 

1707 name=None): 

1708 """Computes the total number of false negatives. 

1709 

1710 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1711 

1712 Args: 

1713 labels: The ground truth values, a `Tensor` whose dimensions must match 

1714 `predictions`. Will be cast to `bool`. 

1715 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 

1716 be cast to `bool`. 

1717 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1718 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1719 be either `1`, or the same as the corresponding `labels` dimension). 

1720 metrics_collections: An optional list of collections that the metric 

1721 value variable should be added to. 

1722 updates_collections: An optional list of collections that the metric update 

1723 ops should be added to. 

1724 name: An optional variable_scope name. 

1725 

1726 Returns: 

1727 value_tensor: A `Tensor` representing the current value of the metric. 

1728 update_op: An operation that accumulates the error from a batch of data. 

1729 

1730 Raises: 

1731 ValueError: If `weights` is not `None` and its shape doesn't match `predictions`, 

1732 or if either `metrics_collections` or `updates_collections` are not a list 

1733 or tuple. 

1734 RuntimeError: If eager execution is enabled. 

1735 """ 

1736 if context.executing_eagerly(): 

1737 raise RuntimeError('tf.metrics.false_negatives is not supported when ' 

1738 'eager execution is enabled.') 

1739 

1740 with variable_scope.variable_scope(name, 'false_negatives', 

1741 (predictions, labels, weights)): 

1742 

1743 predictions, labels, weights = _remove_squeezable_dimensions( 

1744 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 

1745 labels=math_ops.cast(labels, dtype=dtypes.bool), 

1746 weights=weights) 

1747 is_false_negative = math_ops.logical_and( 

1748 math_ops.equal(labels, True), math_ops.equal(predictions, False)) 

1749 return _count_condition(is_false_negative, weights, metrics_collections, 

1750 updates_collections) 

1751 
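# Hedged usage sketch (editor's illustration; TF1 setup as above). A false
# negative is a True label paired with a False prediction.
#
#   labels = tf.constant([True, True, False, False])
#   predictions = tf.constant([False, True, False, True])
#   fn, update_op = tf.metrics.false_negatives(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(fn))  # only index 0 is a false negative -> 1.0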

1752 

1753@tf_export(v1=['metrics.false_negatives_at_thresholds']) 

1754def false_negatives_at_thresholds(labels, 

1755 predictions, 

1756 thresholds, 

1757 weights=None, 

1758 metrics_collections=None, 

1759 updates_collections=None, 

1760 name=None): 

1761 """Computes false negatives at provided threshold values. 

1762 

1763 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1764 

1765 Args: 

1766 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 

1767 `bool`. 

1768 predictions: A floating point `Tensor` of arbitrary shape and whose values 

1769 are in the range `[0, 1]`. 

1770 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 

1771 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1772 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1773 be either `1`, or the same as the corresponding `labels` dimension). 

1774 metrics_collections: An optional list of collections that `false_negatives` 

1775 should be added to. 

1776 updates_collections: An optional list of collections that `update_op` should 

1777 be added to. 

1778 name: An optional variable_scope name. 

1779 

1780 Returns: 

1781 false_negatives: A float `Tensor` of shape `[len(thresholds)]`. 

1782 update_op: An operation that updates the `false_negatives` variable and 

1783 returns its current value. 

1784 

1785 Raises: 

1786 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1787 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1788 either `metrics_collections` or `updates_collections` are not a list or 

1789 tuple. 

1790 RuntimeError: If eager execution is enabled. 

1791 """ 

1792 if context.executing_eagerly(): 

1793 raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' 

1794 'supported when eager execution is enabled.') 

1795 

1796 with variable_scope.variable_scope(name, 'false_negatives', 

1797 (predictions, labels, weights)): 

1798 values, update_ops = _confusion_matrix_at_thresholds( 

1799 labels, predictions, thresholds, weights=weights, includes=('fn',)) 

1800 

1801 fn_value = _aggregate_variable(values['fn'], metrics_collections) 

1802 

1803 if updates_collections: 

1804 ops.add_to_collections(updates_collections, update_ops['fn']) 

1805 

1806 return fn_value, update_ops['fn'] 

1807 
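# Hedged usage sketch (editor's illustration; TF1 setup as above). A
# prediction counts as positive only when it exceeds the threshold, so
# raising the threshold can only grow the false-negative count.
#
#   labels = tf.constant([True, True, False])
#   predictions = tf.constant([0.2, 0.9, 0.1])
#   fn, update_op = tf.metrics.false_negatives_at_thresholds(
#       labels, predictions, thresholds=[0.5, 0.95])
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(fn))  # [1., 2.]: at 0.95 both True labels are missed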

1808 

1809@tf_export(v1=['metrics.false_positives']) 

1810def false_positives(labels, 

1811 predictions, 

1812 weights=None, 

1813 metrics_collections=None, 

1814 updates_collections=None, 

1815 name=None): 

1816 """Sum the weights of false positives. 

1817 

1818 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1819 

1820 Args: 

1821 labels: The ground truth values, a `Tensor` whose dimensions must match 

1822 `predictions`. Will be cast to `bool`. 

1823 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 

1824 be cast to `bool`. 

1825 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1826 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1827 be either `1`, or the same as the corresponding `labels` dimension). 

1828 metrics_collections: An optional list of collections that the metric 

1829 value variable should be added to. 

1830 updates_collections: An optional list of collections that the metric update 

1831 ops should be added to. 

1832 name: An optional variable_scope name. 

1833 

1834 Returns: 

1835 value_tensor: A `Tensor` representing the current value of the metric. 

1836 update_op: An operation that accumulates the error from a batch of data. 

1837 

1838 Raises: 

1839 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1840 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1841 either `metrics_collections` or `updates_collections` are not a list or 

1842 tuple. 

1843 RuntimeError: If eager execution is enabled. 

1844 """ 

1845 if context.executing_eagerly(): 

1846 raise RuntimeError('tf.metrics.false_positives is not supported when ' 

1847 'eager execution is enabled.') 

1848 

1849 with variable_scope.variable_scope(name, 'false_positives', 

1850 (predictions, labels, weights)): 

1851 

1852 predictions, labels, weights = _remove_squeezable_dimensions( 

1853 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 

1854 labels=math_ops.cast(labels, dtype=dtypes.bool), 

1855 weights=weights) 

1856 is_false_positive = math_ops.logical_and( 

1857 math_ops.equal(labels, False), math_ops.equal(predictions, True)) 

1858 return _count_condition(is_false_positive, weights, metrics_collections, 

1859 updates_collections) 

1860 
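# Hedged usage sketch (editor's illustration; TF1 setup as above). A false
# positive is a False label paired with a True prediction.
#
#   labels = tf.constant([False, False, True])
#   predictions = tf.constant([True, False, True])
#   fp, update_op = tf.metrics.false_positives(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(fp))  # only index 0 is a false positive -> 1.0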

1861 

1862@tf_export(v1=['metrics.false_positives_at_thresholds']) 

1863def false_positives_at_thresholds(labels, 

1864 predictions, 

1865 thresholds, 

1866 weights=None, 

1867 metrics_collections=None, 

1868 updates_collections=None, 

1869 name=None): 

1870 """Computes false positives at provided threshold values. 

1871 

1872 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1873 

1874 Args: 

1875 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 

1876 `bool`. 

1877 predictions: A floating point `Tensor` of arbitrary shape and whose values 

1878 are in the range `[0, 1]`. 

1879 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 

1880 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1881 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1882 be either `1`, or the same as the corresponding `labels` dimension). 

1883 metrics_collections: An optional list of collections that `false_positives` 

1884 should be added to. 

1885 updates_collections: An optional list of collections that `update_op` should 

1886 be added to. 

1887 name: An optional variable_scope name. 

1888 

1889 Returns: 

1890 false_positives: A float `Tensor` of shape `[len(thresholds)]`. 

1891 update_op: An operation that updates the `false_positives` variable and 

1892 returns its current value. 

1893 

1894 Raises: 

1895 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1896 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1897 either `metrics_collections` or `updates_collections` are not a list or 

1898 tuple. 

1899 RuntimeError: If eager execution is enabled. 

1900 """ 

1901 if context.executing_eagerly(): 

1902 raise RuntimeError('tf.metrics.false_positives_at_thresholds is not ' 

1903 'supported when eager execution is enabled.') 

1904 

1905 with variable_scope.variable_scope(name, 'false_positives', 

1906 (predictions, labels, weights)): 

1907 values, update_ops = _confusion_matrix_at_thresholds( 

1908 labels, predictions, thresholds, weights=weights, includes=('fp',)) 

1909 

1910 fp_value = _aggregate_variable(values['fp'], metrics_collections) 

1911 

1912 if updates_collections: 

1913 ops.add_to_collections(updates_collections, update_ops['fp']) 

1914 

1915 return fp_value, update_ops['fp'] 

1916 

1917 

1918@tf_export(v1=['metrics.true_negatives']) 

1919def true_negatives(labels, 

1920 predictions, 

1921 weights=None, 

1922 metrics_collections=None, 

1923 updates_collections=None, 

1924 name=None): 

1925 """Sum the weights of true_negatives. 

1926 

1927 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1928 

1929 Args: 

1930 labels: The ground truth values, a `Tensor` whose dimensions must match 

1931 `predictions`. Will be cast to `bool`. 

1932 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 

1933 be cast to `bool`. 

1934 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1935 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1936 be either `1`, or the same as the corresponding `labels` dimension). 

1937 metrics_collections: An optional list of collections that the metric 

1938 value variable should be added to. 

1939 updates_collections: An optional list of collections that the metric update 

1940 ops should be added to. 

1941 name: An optional variable_scope name. 

1942 

1943 Returns: 

1944 value_tensor: A `Tensor` representing the current value of the metric. 

1945 update_op: An operation that accumulates the error from a batch of data. 

1946 

1947 Raises: 

1948 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

1949 `weights` is not `None` and its shape doesn't match `predictions`, or if 

1950 either `metrics_collections` or `updates_collections` are not a list or 

1951 tuple. 

1952 RuntimeError: If eager execution is enabled. 

1953 """ 

1954 if context.executing_eagerly(): 

1955 raise RuntimeError('tf.metrics.true_negatives is not ' 

1956 'supported when eager execution is enabled.') 

1957 

1958 with variable_scope.variable_scope(name, 'true_negatives', 

1959 (predictions, labels, weights)): 

1960 

1961 predictions, labels, weights = _remove_squeezable_dimensions( 

1962 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 

1963 labels=math_ops.cast(labels, dtype=dtypes.bool), 

1964 weights=weights) 

1965 is_true_negative = math_ops.logical_and( 

1966 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 

1967 return _count_condition(is_true_negative, weights, metrics_collections, 

1968 updates_collections) 

1969 

1970 

1971@tf_export(v1=['metrics.true_negatives_at_thresholds']) 

1972def true_negatives_at_thresholds(labels, 

1973 predictions, 

1974 thresholds, 

1975 weights=None, 

1976 metrics_collections=None, 

1977 updates_collections=None, 

1978 name=None): 

1979 """Computes true negatives at provided threshold values. 

1980 

1981 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

1982 

1983 Args: 

1984 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 

1985 `bool`. 

1986 predictions: A floating point `Tensor` of arbitrary shape and whose values 

1987 are in the range `[0, 1]`. 

1988 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 

1989 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

1990 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

1991 be either `1`, or the same as the corresponding `labels` dimension). 

1992 metrics_collections: An optional list of collections that `true_negatives` 

1993 should be added to. 

1994 updates_collections: An optional list of collections that `update_op` should 

1995 be added to. 

1996 name: An optional variable_scope name. 

1997 

1998 Returns: 

1999 true_negatives: A float `Tensor` of shape `[len(thresholds)]`. 

2000 update_op: An operation that updates the `true_negatives` variable and 

2001 returns its current value. 

2002 

2003 Raises: 

2004 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2005 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2006 either `metrics_collections` or `updates_collections` are not a list or 

2007 tuple. 

2008 RuntimeError: If eager execution is enabled. 

2009 """ 

2010 if context.executing_eagerly(): 

2011 raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' 

2012 'supported when eager execution is enabled.') 

2013 

2014 with variable_scope.variable_scope(name, 'true_negatives', 

2015 (predictions, labels, weights)): 

2016 values, update_ops = _confusion_matrix_at_thresholds( 

2017 labels, predictions, thresholds, weights=weights, includes=('tn',)) 

2018 

2019 tn_value = _aggregate_variable(values['tn'], metrics_collections) 

2020 

2021 if updates_collections: 

2022 ops.add_to_collections(updates_collections, update_ops['tn']) 

2023 

2024 return tn_value, update_ops['tn'] 

2025 

2026 

2027@tf_export(v1=['metrics.true_positives']) 

2028def true_positives(labels, 

2029 predictions, 

2030 weights=None, 

2031 metrics_collections=None, 

2032 updates_collections=None, 

2033 name=None): 

2034 """Sum the weights of true_positives. 

2035 

2036 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2037 

2038 Args: 

2039 labels: The ground truth values, a `Tensor` whose dimensions must match 

2040 `predictions`. Will be cast to `bool`. 

2041 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 

2042 be cast to `bool`. 

2043 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2044 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2045 be either `1`, or the same as the corresponding `labels` dimension). 

2046 metrics_collections: An optional list of collections that the metric 

2047 value variable should be added to. 

2048 updates_collections: An optional list of collections that the metric update 

2049 ops should be added to. 

2050 name: An optional variable_scope name. 

2051 

2052 Returns: 

2053 value_tensor: A `Tensor` representing the current value of the metric. 

2054 update_op: An operation that accumulates the error from a batch of data. 

2055 

2056 Raises: 

2057 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2058 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2059 either `metrics_collections` or `updates_collections` are not a list or 

2060 tuple. 

2061 RuntimeError: If eager execution is enabled. 

2062 """ 

2063 if context.executing_eagerly(): 

2064 raise RuntimeError('tf.metrics.true_positives is not ' 

2065 'supported when eager execution is enabled.') 

2066 

2067 with variable_scope.variable_scope(name, 'true_positives', 

2068 (predictions, labels, weights)): 

2069 

2070 predictions, labels, weights = _remove_squeezable_dimensions( 

2071 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 

2072 labels=math_ops.cast(labels, dtype=dtypes.bool), 

2073 weights=weights) 

2074 is_true_positive = math_ops.logical_and( 

2075 math_ops.equal(labels, True), math_ops.equal(predictions, True)) 

2076 return _count_condition(is_true_positive, weights, metrics_collections, 

2077 updates_collections) 

2078 
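# Hedged usage sketch (editor's illustration; TF1 setup as above):
#
#   labels = tf.constant([True, False, True])
#   predictions = tf.constant([True, True, True])
#   tp, update_op = tf.metrics.true_positives(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(tp))  # indices 0 and 2 are true positives -> 2.0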

2079 

2080@tf_export(v1=['metrics.true_positives_at_thresholds']) 

2081def true_positives_at_thresholds(labels, 

2082 predictions, 

2083 thresholds, 

2084 weights=None, 

2085 metrics_collections=None, 

2086 updates_collections=None, 

2087 name=None): 

2088 """Computes true positives at provided threshold values. 

2089 

2090 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2091 

2092 Args: 

2093 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 

2094 `bool`. 

2095 predictions: A floating point `Tensor` of arbitrary shape and whose values 

2096 are in the range `[0, 1]`. 

2097 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 

2098 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2099 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2100 be either `1`, or the same as the corresponding `labels` dimension). 

2101 metrics_collections: An optional list of collections that `true_positives` 

2102 should be added to. 

2103 updates_collections: An optional list of collections that `update_op` should 

2104 be added to. 

2105 name: An optional variable_scope name. 

2106 

2107 Returns: 

2108 true_positives: A float `Tensor` of shape `[len(thresholds)]`. 

2109 update_op: An operation that updates the `true_positives` variable and 

2110 returns its current value. 

2111 

2112 Raises: 

2113 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2114 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2115 either `metrics_collections` or `updates_collections` are not a list or 

2116 tuple. 

2117 RuntimeError: If eager execution is enabled. 

2118 """ 

2119 if context.executing_eagerly(): 

2120 raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' 

2121 'supported when eager execution is enabled.') 

2122 

2123 with variable_scope.variable_scope(name, 'true_positives', 

2124 (predictions, labels, weights)): 

2125 values, update_ops = _confusion_matrix_at_thresholds( 

2126 labels, predictions, thresholds, weights=weights, includes=('tp',)) 

2127 

2128 tp_value = _aggregate_variable(values['tp'], metrics_collections) 

2129 

2130 if updates_collections: 

2131 ops.add_to_collections(updates_collections, update_ops['tp']) 

2132 

2133 return tp_value, update_ops['tp'] 

2134 

2135 

2136@tf_export(v1=['metrics.precision']) 

2137def precision(labels, 

2138 predictions, 

2139 weights=None, 

2140 metrics_collections=None, 

2141 updates_collections=None, 

2142 name=None): 

2143 """Computes the precision of the predictions with respect to the labels. 

2144 

2145 The `precision` function creates two local variables, 

2146 `true_positives` and `false_positives`, that are used to compute the 

2147 precision. This value is ultimately returned as `precision`, an idempotent 

2148 operation that simply divides `true_positives` by the sum of `true_positives` 

2149 and `false_positives`. 

2150 

2151 For estimation of the metric over a stream of data, the function creates an 

2152 `update_op` operation that updates these variables and returns the 

2153 `precision`. `update_op` weights each prediction by the corresponding value in 

2154 `weights`. 

2155 

2156 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2157 

2158 Args: 

2159 labels: The ground truth values, a `Tensor` whose dimensions must match 

2160 `predictions`. Will be cast to `bool`. 

2161 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 

2162 be cast to `bool`. 

2163 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2164 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2165 be either `1`, or the same as the corresponding `labels` dimension). 

2166 metrics_collections: An optional list of collections that `precision` should 

2167 be added to. 

2168 updates_collections: An optional list of collections that `update_op` should 

2169 be added to. 

2170 name: An optional variable_scope name. 

2171 

2172 Returns: 

2173 precision: Scalar float `Tensor` with the value of `true_positives` 

2174 divided by the sum of `true_positives` and `false_positives`. 

2175 update_op: `Operation` that increments `true_positives` and 

2176 `false_positives` variables appropriately and whose value matches 

2177 `precision`. 

2178 

2179 Raises: 

2180 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2181 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2182 either `metrics_collections` or `updates_collections` are not a list or 

2183 tuple. 

2184 RuntimeError: If eager execution is enabled. 

2185 """ 

2186 if context.executing_eagerly(): 

2187 raise RuntimeError('tf.metrics.precision is not ' 

2188 'supported when eager execution is enabled.') 

2189 

2190 with variable_scope.variable_scope(name, 'precision', 

2191 (predictions, labels, weights)): 

2192 

2193 predictions, labels, weights = _remove_squeezable_dimensions( 

2194 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 

2195 labels=math_ops.cast(labels, dtype=dtypes.bool), 

2196 weights=weights) 

2197 

2198 true_p, true_positives_update_op = true_positives( 

2199 labels, 

2200 predictions, 

2201 weights, 

2202 metrics_collections=None, 

2203 updates_collections=None, 

2204 name=None) 

2205 false_p, false_positives_update_op = false_positives( 

2206 labels, 

2207 predictions, 

2208 weights, 

2209 metrics_collections=None, 

2210 updates_collections=None, 

2211 name=None) 

2212 

2213 def compute_precision(tp, fp, name): 

2214 return array_ops.where( 

2215 math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name) 

2216 

2217 def once_across_replicas(_, true_p, false_p): 

2218 return compute_precision(true_p, false_p, 'value') 

2219 

2220 p = _aggregate_across_replicas(metrics_collections, once_across_replicas, 

2221 true_p, false_p) 

2222 

2223 update_op = compute_precision(true_positives_update_op, 

2224 false_positives_update_op, 'update_op') 

2225 if updates_collections: 

2226 ops.add_to_collections(updates_collections, update_op) 

2227 

2228 return p, update_op 

2229 
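# Hedged usage sketch (editor's illustration; TF1 setup as above):
#
#   labels = tf.constant([True, False, True, True])
#   predictions = tf.constant([True, True, False, True])
#   prec, update_op = tf.metrics.precision(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(prec))  # 2 TP / (2 TP + 1 FP) = 0.6667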

2230 

2231@tf_export(v1=['metrics.precision_at_thresholds']) 

2232def precision_at_thresholds(labels, 

2233 predictions, 

2234 thresholds, 

2235 weights=None, 

2236 metrics_collections=None, 

2237 updates_collections=None, 

2238 name=None): 

2239 """Computes precision values for different `thresholds` on `predictions`. 

2240 

2241 The `precision_at_thresholds` function creates four local variables, 

2242 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 

2243 for various values of thresholds. `precision[i]` is defined as the total 

2244 weight of values in `predictions` above `thresholds[i]` whose corresponding 

2245 entry in `labels` is `True`, divided by the total weight of values in 

2246 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 

2247 false_positives[i])`). 

2248 

2249 For estimation of the metric over a stream of data, the function creates an 

2250 `update_op` operation that updates these variables and returns the 

2251 `precision`. 

2252 

2253 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2254 

2255 Args: 

2256 labels: The ground truth values, a `Tensor` whose dimensions must match 

2257 `predictions`. Will be cast to `bool`. 

2258 predictions: A floating point `Tensor` of arbitrary shape and whose values 

2259 are in the range `[0, 1]`. 

2260 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 

2261 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2262 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2263 be either `1`, or the same as the corresponding `labels` dimension). 

2264 metrics_collections: An optional list of collections that `auc` should be 

2265 added to. 

2266 updates_collections: An optional list of collections that `update_op` should 

2267 be added to. 

2268 name: An optional variable_scope name. 

2269 

2270 Returns: 

2271 precision: A float `Tensor` of shape `[len(thresholds)]`. 

2272 update_op: An operation that increments the `true_positives`, 

2273 `true_negatives`, `false_positives` and `false_negatives` variables that 

2274 are used in the computation of `precision`. 

2275 

2276 Raises: 

2277 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2278 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2279 either `metrics_collections` or `updates_collections` are not a list or 

2280 tuple. 

2281 RuntimeError: If eager execution is enabled. 

2282 """ 

2283 if context.executing_eagerly(): 

2284 raise RuntimeError('tf.metrics.precision_at_thresholds is not ' 

2285 'supported when eager execution is enabled.') 

2286 

2287 with variable_scope.variable_scope(name, 'precision_at_thresholds', 

2288 (predictions, labels, weights)): 

2289 values, update_ops = _confusion_matrix_at_thresholds( 

2290 labels, predictions, thresholds, weights, includes=('tp', 'fp')) 

2291 

2292 # Avoid division by zero. 

2293 epsilon = 1e-7 

2294 

2295 def compute_precision(tp, fp, name): 

2296 return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name) 

2297 

2298 def precision_across_replicas(_, values): 

2299 return compute_precision(values['tp'], values['fp'], 'value') 

2300 

2301 prec = _aggregate_across_replicas( 

2302 metrics_collections, precision_across_replicas, values) 

2303 

2304 update_op = compute_precision(update_ops['tp'], update_ops['fp'], 

2305 'update_op') 

2306 if updates_collections: 

2307 ops.add_to_collections(updates_collections, update_op) 

2308 

2309 return prec, update_op 

2310 
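# Hedged usage sketch (editor's illustration; TF1 setup as above). Unlike
# `precision`, division here is stabilized with a small epsilon rather than
# a zero-denominator branch, so results sit just below the exact ratio.
#
#   labels = tf.constant([True, False, True])
#   predictions = tf.constant([0.9, 0.6, 0.1])
#   prec, update_op = tf.metrics.precision_at_thresholds(
#       labels, predictions, thresholds=[0.5])
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(prec))  # tp=[1.], fp=[1.] above 0.5 -> ~[0.5]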

2311 

2312@tf_export(v1=['metrics.recall']) 

2313def recall(labels, 

2314 predictions, 

2315 weights=None, 

2316 metrics_collections=None, 

2317 updates_collections=None, 

2318 name=None): 

2319 """Computes the recall of the predictions with respect to the labels. 

2320 

2321 The `recall` function creates two local variables, `true_positives` 

2322 and `false_negatives`, that are used to compute the recall. This value is 

2323 ultimately returned as `recall`, an idempotent operation that simply divides 

2324 `true_positives` by the sum of `true_positives` and `false_negatives`. 

2325 

2326 For estimation of the metric over a stream of data, the function creates an 

2327 `update_op` that updates these variables and returns the `recall`. `update_op` 

2328 weights each prediction by the corresponding value in `weights`. 

2329 

2330 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2331 

2332 Args: 

2333 labels: The ground truth values, a `Tensor` whose dimensions must match 

2334 `predictions`. Will be cast to `bool`. 

2335 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 

2336 be cast to `bool`. 

2337 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2338 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2339 be either `1`, or the same as the corresponding `labels` dimension). 

2340 metrics_collections: An optional list of collections that `recall` should 

2341 be added to. 

2342 updates_collections: An optional list of collections that `update_op` should 

2343 be added to. 

2344 name: An optional variable_scope name. 

2345 

2346 Returns: 

2347 recall: Scalar float `Tensor` with the value of `true_positives` divided 

2348 by the sum of `true_positives` and `false_negatives`. 

2349 update_op: `Operation` that increments `true_positives` and 

2350 `false_negatives` variables appropriately and whose value matches 

2351 `recall`. 

2352 

2353 Raises: 

2354 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2355 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2356 either `metrics_collections` or `updates_collections` are not a list or 

2357 tuple. 

2358 RuntimeError: If eager execution is enabled. 

2359 """ 

2360 if context.executing_eagerly(): 

2361 raise RuntimeError('tf.metrics.recall is not ' 

2362 'supported when eager execution is enabled.') 

2363 

2364 with variable_scope.variable_scope(name, 'recall', 

2365 (predictions, labels, weights)): 

2366 predictions, labels, weights = _remove_squeezable_dimensions( 

2367 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 

2368 labels=math_ops.cast(labels, dtype=dtypes.bool), 

2369 weights=weights) 

2370 

2371 true_p, true_positives_update_op = true_positives( 

2372 labels, 

2373 predictions, 

2374 weights, 

2375 metrics_collections=None, 

2376 updates_collections=None, 

2377 name=None) 

2378 false_n, false_negatives_update_op = false_negatives( 

2379 labels, 

2380 predictions, 

2381 weights, 

2382 metrics_collections=None, 

2383 updates_collections=None, 

2384 name=None) 

2385 

2386 def compute_recall(true_p, false_n, name): 

2387 return array_ops.where( 

2388 math_ops.greater(true_p + false_n, 0), 

2389 math_ops.divide(true_p, true_p + false_n), 0, name) 

2390 

2391 def once_across_replicas(_, true_p, false_n): 

2392 return compute_recall(true_p, false_n, 'value') 

2393 

2394 rec = _aggregate_across_replicas( 

2395 metrics_collections, once_across_replicas, true_p, false_n) 

2396 

2397 update_op = compute_recall(true_positives_update_op, 

2398 false_negatives_update_op, 'update_op') 

2399 if updates_collections: 

2400 ops.add_to_collections(updates_collections, update_op) 

2401 

2402 return rec, update_op 

2403 
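# Hedged usage sketch (editor's illustration; TF1 setup as above):
#
#   labels = tf.constant([True, True, True, False])
#   predictions = tf.constant([True, False, True, True])
#   rec, update_op = tf.metrics.recall(labels, predictions)
#   sess.run(tf.local_variables_initializer())
#   sess.run(update_op)
#   print(sess.run(rec))  # 2 TP / (2 TP + 1 FN) = 0.6667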

2404 

2405def _at_k_name(name, k=None, class_id=None): 

2406 if k is not None: 

2407 name = '%s_at_%d' % (name, k) 

2408 else: 

2409 name = '%s_at_k' % name 

2410 if class_id is not None: 

2411 name = '%s_class%d' % (name, class_id) 

2412 return name 

2413 
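# Name construction examples for `_at_k_name` above (editor's note):
#   _at_k_name('recall', k=5)              -> 'recall_at_5'
#   _at_k_name('recall', k=5, class_id=2)  -> 'recall_at_5_class2'
#   _at_k_name('recall')                   -> 'recall_at_k'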

2414 

2415def _select_class_id(ids, selected_id): 

2416 """Filter all but `selected_id` out of `ids`. 

2417 

2418 Args: 

2419 ids: `int64` `Tensor` or `SparseTensor` of IDs. 

2420 selected_id: Int id to select. 

2421 

2422 Returns: 

2423 `SparseTensor` of same dimensions as `ids`. This contains only the entries 

2424 equal to `selected_id`. 

2425 """ 

2426 ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids) 

2427 if isinstance(ids, sparse_tensor.SparseTensor): 

2428 return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values, 

2429 selected_id)) 

2430 

2431 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 

2432 # tf.equal and tf.reduce_any? 

2433 

2434 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 

2435 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 

2436 ids_last_dim = array_ops.size(ids_shape) - 1 

2437 filled_selected_id_shape = math_ops.reduced_shape(ids_shape, 

2438 array_ops.reshape( 

2439 ids_last_dim, [1])) 

2440 

2441 # Intersect `ids` with the selected ID. 

2442 filled_selected_id = array_ops.fill(filled_selected_id_shape, 

2443 math_ops.cast(selected_id, dtypes.int64)) 

2444 result = sets.set_intersection(filled_selected_id, ids) 

2445 return sparse_tensor.SparseTensor( 

2446 indices=result.indices, values=result.values, dense_shape=ids_shape) 

2447 

2448 

2449def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 

2450 """If class ID is specified, filter all other classes. 

2451 

2452 Args: 

2453 labels: `int64` `Tensor` or `SparseTensor` with shape 

2454 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

2455 target classes for the associated prediction. Commonly, N=1 and `labels` 

2456 has shape [batch_size, num_labels]. [D1, ... DN] must match 

2457 `predictions_idx`. 

2458 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 

2459 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 

2460 [batch size, k]. 

2461 selected_id: Int id to select. 

2462 

2463 Returns: 

2464 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 

2465 """ 

2466 if selected_id is None: 

2467 return labels, predictions_idx 

2468 return (_select_class_id(labels, selected_id), 

2469 _select_class_id(predictions_idx, selected_id)) 

2470 

2471 

2472def _sparse_true_positive_at_k(labels, 

2473 predictions_idx, 

2474 class_id=None, 

2475 weights=None, 

2476 name=None): 

2477 """Calculates true positives for recall@k and precision@k. 

2478 

2479 If `class_id` is specified, calculate binary true positives for `class_id` 

2480 only. 

2481 If `class_id` is not specified, calculate metrics for `k` predicted vs 

2482 `n` label classes, where `n` is the 2nd dimension of `labels`. 

2483 

2484 Args: 

2485 labels: `int64` `Tensor` or `SparseTensor` with shape 

2486 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

2487 target classes for the associated prediction. Commonly, N=1 and `labels` 

2488 has shape [batch_size, num_labels]. [D1, ... DN] must match 

2489 `predictions_idx`. 

2490 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 

2491 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 

2492 match `labels`. 

2493 class_id: Class for which we want binary metrics. 

2494 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

2495 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

2496 dimensions must be either `1`, or the same as the corresponding `labels` 

2497 dimension). 

2498 name: Name of operation. 

2499 

2500 Returns: 

2501 A [D1, ... DN] `Tensor` of true positive counts. 

2502 """ 

2503 with ops.name_scope(name, 'true_positives', 

2504 (predictions_idx, labels, weights)): 

2505 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 

2506 class_id) 

2507 tp = sets.set_size(sets.set_intersection(predictions_idx, labels)) 

2508 tp = math_ops.cast(tp, dtypes.float64) 

2509 if weights is not None: 

2510 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 

2511 weights, tp),)): 

2512 weights = math_ops.cast(weights, dtypes.float64) 

2513 tp = math_ops.multiply(tp, weights) 

2514 return tp 

2515 
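# Hedged sketch (editor's illustration) of the internal helper above: each
# row's true-positive count is the size of the set intersection between the
# label classes and the top-k predicted classes.
#
#   labels = tf.constant([[0, 1]], dtype=tf.int64)            # one batch entry
#   predictions_idx = tf.constant([[1, 2]], dtype=tf.int64)   # its top-2 ids
#   tp = _sparse_true_positive_at_k(labels, predictions_idx)
#   # evaluates to [1.0]: only class 1 is both labeled and predicted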

2516 

2517def _streaming_sparse_true_positive_at_k(labels, 

2518 predictions_idx, 

2519 k=None, 

2520 class_id=None, 

2521 weights=None, 

2522 name=None): 

2523 """Calculates weighted per step true positives for recall@k and precision@k. 

2524 

2525 If `class_id` is specified, calculate binary true positives for `class_id` 

2526 only. 

2527 If `class_id` is not specified, calculate metrics for `k` predicted vs 

2528 `n` label classes, where `n` is the 2nd dimension of `labels`. 

2529 

2530 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2531 

2532 Args: 

2533 labels: `int64` `Tensor` or `SparseTensor` with shape 

2534 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

2535 target classes for the associated prediction. Commonly, N=1 and `labels` 

2536 has shape [batch_size, num_labels]. [D1, ... DN] must match 

2537 `predictions_idx`. 

2538 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 

2539 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 

2540 match `labels`. 

2541 k: Integer, k for @k metric. This is only used for default op name. 

2542 class_id: Class for which we want binary metrics. 

2543 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

2544 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

2545 dimensions must be either `1`, or the same as the corresponding `labels` 

2546 dimension). 

2547 name: Name of new variable, and namespace for other dependent ops. 

2548 

2549 Returns: 

2550 A tuple of `Variable` and update `Operation`. 

2551 

2552 Raises: 

2553 ValueError: If `weights` is not `None` and has an incompatible shape. 

2554 """ 

2555 with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id), 

2556 (predictions_idx, labels, weights)) as scope: 

2557 tp = _sparse_true_positive_at_k( 

2558 predictions_idx=predictions_idx, 

2559 labels=labels, 

2560 class_id=class_id, 

2561 weights=weights) 

2562 batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64) 

2563 

2564 var = metric_variable([], dtypes.float64, name=scope) 

2565 return var, state_ops.assign_add(var, batch_total_tp, name='update') 

2566 

2567 

2568def _sparse_false_negative_at_k(labels, 

2569 predictions_idx, 

2570 class_id=None, 

2571 weights=None): 

2572 """Calculates false negatives for recall@k. 

2573 

2574 If `class_id` is specified, calculate binary true positives for `class_id` 

2575 only. 

2576 If `class_id` is not specified, calculate metrics for `k` predicted vs 

2577 `n` label classes, where `n` is the 2nd dimension of `labels`. 

2578 

2579 Args: 

2580 labels: `int64` `Tensor` or `SparseTensor` with shape 

2581 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

2582 target classes for the associated prediction. Commonly, N=1 and `labels` 

2583 has shape [batch_size, num_labels]. [D1, ... DN] must match 

2584 `predictions_idx`. 

2585 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 

2586 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 

2587 match `labels`. 

2588 class_id: Class for which we want binary metrics. 

2589 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

2590 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

2591 dimensions must be either `1`, or the same as the corresponding `labels` 

2592 dimension). 

2593 

2594 Returns: 

2595 A [D1, ... DN] `Tensor` of false negative counts. 

2596 """ 

2597 with ops.name_scope(None, 'false_negatives', 

2598 (predictions_idx, labels, weights)): 

2599 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 

2600 class_id) 

2601 fn = sets.set_size( 

2602 sets.set_difference(predictions_idx, labels, aminusb=False)) 

2603 fn = math_ops.cast(fn, dtypes.float64) 

2604 if weights is not None: 

2605 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 

2606 weights, fn),)): 

2607 weights = math_ops.cast(weights, dtypes.float64) 

2608 fn = math_ops.multiply(fn, weights) 

2609 return fn 

2610 

2611 

2612def _streaming_sparse_false_negative_at_k(labels, 

2613 predictions_idx, 

2614 k, 

2615 class_id=None, 

2616 weights=None, 

2617 name=None): 

2618 """Calculates weighted per step false negatives for recall@k. 

2619 

2620 If `class_id` is specified, calculate binary true positives for `class_id` 

2621 only. 

2622 If `class_id` is not specified, calculate metrics for `k` predicted vs 

2623 `n` label classes, where `n` is the 2nd dimension of `labels`. 

2624 

2625 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2626 

2627 Args: 

2628 labels: `int64` `Tensor` or `SparseTensor` with shape 

2629 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

2630 target classes for the associated prediction. Commonly, N=1 and `labels` 

2631 has shape [batch_size, num_labels]. [D1, ... DN] must match 

2632 `predictions_idx`. 

2633 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 

2634 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 

2635 match `labels`. 

2636    k: Integer, k for @k metric. This is only used for the default op name.

2637 class_id: Class for which we want binary metrics. 

2638 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

2639 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

2640 dimensions must be either `1`, or the same as the corresponding `labels` 

2641 dimension). 

2642 name: Name of new variable, and namespace for other dependent ops. 

2643 

2644 Returns: 

2645 A tuple of `Variable` and update `Operation`. 

2646 

2647 Raises: 

2648 ValueError: If `weights` is not `None` and has an incompatible shape. 

2649 """ 

2650 with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id), 

2651 (predictions_idx, labels, weights)) as scope: 

2652 fn = _sparse_false_negative_at_k( 

2653 predictions_idx=predictions_idx, 

2654 labels=labels, 

2655 class_id=class_id, 

2656 weights=weights) 

2657 batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64) 

2658 

2659 var = metric_variable([], dtypes.float64, name=scope) 

2660 return var, state_ops.assign_add(var, batch_total_fn, name='update') 

2661 

2662 

2663@tf_export(v1=['metrics.recall_at_k']) 

2664def recall_at_k(labels, 

2665 predictions, 

2666 k, 

2667 class_id=None, 

2668 weights=None, 

2669 metrics_collections=None, 

2670 updates_collections=None, 

2671 name=None): 

2672 """Computes recall@k of the predictions with respect to sparse labels. 

2673 

2674 If `class_id` is specified, we calculate recall by considering only the 

2675 entries in the batch for which `class_id` is in the label, and computing 

2676 the fraction of them for which `class_id` is in the top-k `predictions`. 

2677 If `class_id` is not specified, we'll calculate recall as how often on 

2678 average a class among the labels of a batch entry is in the top-k 

2679 `predictions`. 

2680 

2681  `recall_at_k` creates two local variables,

2682 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 

2683 the recall_at_k frequency. This frequency is ultimately returned as 

2684 `recall_at_<k>`: an idempotent operation that simply divides 

2685 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 

2686 `false_negative_at_<k>`). 

2687 

2688 For estimation of the metric over a stream of data, the function creates an 

2689 `update_op` operation that updates these variables and returns the 

2690 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 

2691 indicating the top `k` `predictions`. Set operations applied to `top_k` and 

2692 `labels` calculate the true positives and false negatives weighted by 

2693 `weights`. Then `update_op` increments `true_positive_at_<k>` and 

2694 `false_negative_at_<k>` using these values. 

2695 

2696 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2697 

2698 Args: 

2699 labels: `int64` `Tensor` or `SparseTensor` with shape 

2700 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

2701 num_labels=1. N >= 1 and num_labels is the number of target classes for 

2702 the associated prediction. Commonly, N=1 and `labels` has shape 

2703 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 

2704 should be in range [0, num_classes), where num_classes is the last 

2705 dimension of `predictions`. Values outside this range always count 

2706 towards `false_negative_at_<k>`. 

2707 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 

2708 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 

2709 The final dimension contains the logit values for each class. [D1, ... DN] 

2710 must match `labels`. 

2711 k: Integer, k for @k metric. 

2712 class_id: Integer class ID for which we want binary metrics. This should be 

2713 in range [0, num_classes), where num_classes is the last dimension of 

2714 `predictions`. If class_id is outside this range, the method returns NAN. 

2715 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

2716 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

2717 dimensions must be either `1`, or the same as the corresponding `labels` 

2718 dimension). 

2719 metrics_collections: An optional list of collections that values should 

2720 be added to. 

2721 updates_collections: An optional list of collections that updates should 

2722 be added to. 

2723 name: Name of new update operation, and namespace for other dependent ops. 

2724 

2725 Returns: 

2726 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 

2727 by the sum of `true_positives` and `false_negatives`. 

2728 update_op: `Operation` that increments `true_positives` and 

2729 `false_negatives` variables appropriately, and whose value matches 

2730 `recall`. 

2731 

2732 Raises: 

2733 ValueError: If `weights` is not `None` and its shape doesn't match 

2734 `predictions`, or if either `metrics_collections` or `updates_collections` 

2735 are not a list or tuple. 

2736 RuntimeError: If eager execution is enabled. 

2737 """ 

2738 if context.executing_eagerly(): 

2739 raise RuntimeError('tf.metrics.recall_at_k is not ' 

2740 'supported when eager execution is enabled.') 

2741 

2742 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 

2743 (predictions, labels, weights)) as scope: 

2744 _, top_k_idx = nn.top_k(predictions, k) 

2745 return recall_at_top_k( 

2746 labels=labels, 

2747 predictions_idx=top_k_idx, 

2748 k=k, 

2749 class_id=class_id, 

2750 weights=weights, 

2751 metrics_collections=metrics_collections, 

2752 updates_collections=updates_collections, 

2753 name=scope) 
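
# Editorial illustration, not part of the original module: driving
# `recall_at_k` in TF1 graph mode (eager execution would raise, per the check
# above). The labels and logits are made up.
import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

ex_labels = tf1.constant([[0], [2]], dtype=tf1.int64)  # one true class per row
ex_logits = tf1.constant([[0.9, 0.1, 0.3],             # row 0: top-2 = {0, 2}
                          [0.2, 0.8, 0.1]])            # row 1: top-2 = {1, 0}
ex_recall, ex_update = tf1.metrics.recall_at_k(ex_labels, ex_logits, k=2)

with tf1.Session() as sess:
  sess.run(tf1.local_variables_initializer())  # metric variables are local
  sess.run(ex_update)
  print(sess.run(ex_recall))  # 0.5 -> row 0's label is in its top-2, row 1's is not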

2754 

2755 

2756@tf_export(v1=['metrics.recall_at_top_k']) 

2757def recall_at_top_k(labels, 

2758 predictions_idx, 

2759 k=None, 

2760 class_id=None, 

2761 weights=None, 

2762 metrics_collections=None, 

2763 updates_collections=None, 

2764 name=None): 

2765 """Computes recall@k of top-k predictions with respect to sparse labels. 

2766 

2767 Differs from `recall_at_k` in that predictions must be in the form of top `k` 

2768 class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` 

2769 for more details. 

2770 

2771 Args: 

2772 labels: `int64` `Tensor` or `SparseTensor` with shape 

2773 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

2774 num_labels=1. N >= 1 and num_labels is the number of target classes for 

2775 the associated prediction. Commonly, N=1 and `labels` has shape 

2776 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 

2777 should be in range [0, num_classes), where num_classes is the last 

2778 dimension of `predictions`. Values outside this range always count 

2779 towards `false_negative_at_<k>`. 

2780 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 

2781 Commonly, N=1 and predictions has shape [batch size, k]. The final 

2782 dimension contains the top `k` predicted class indices. [D1, ... DN] must 

2783 match `labels`. 

2784 k: Integer, k for @k metric. Only used for the default op name. 

2785 class_id: Integer class ID for which we want binary metrics. This should be 

2786 in range [0, num_classes), where num_classes is the last dimension of 

2787 `predictions`. If class_id is outside this range, the method returns NAN. 

2788 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

2789 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

2790 dimensions must be either `1`, or the same as the corresponding `labels` 

2791 dimension). 

2792 metrics_collections: An optional list of collections that values should 

2793 be added to. 

2794 updates_collections: An optional list of collections that updates should 

2795 be added to. 

2796 name: Name of new update operation, and namespace for other dependent ops. 

2797 

2798 Returns: 

2799 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 

2800 by the sum of `true_positives` and `false_negatives`. 

2801 update_op: `Operation` that increments `true_positives` and 

2802 `false_negatives` variables appropriately, and whose value matches 

2803 `recall`. 

2804 

2805 Raises: 

2806 ValueError: If `weights` is not `None` and its shape doesn't match 

2807 `predictions`, or if either `metrics_collections` or `updates_collections` 

2808 are not a list or tuple. 

2809 """ 

2810 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 

2811 (predictions_idx, labels, weights)) as scope: 

2812 labels = _maybe_expand_labels(labels, predictions_idx) 

2813 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 

2814 tp, tp_update = _streaming_sparse_true_positive_at_k( 

2815 predictions_idx=top_k_idx, 

2816 labels=labels, 

2817 k=k, 

2818 class_id=class_id, 

2819 weights=weights) 

2820 fn, fn_update = _streaming_sparse_false_negative_at_k( 

2821 predictions_idx=top_k_idx, 

2822 labels=labels, 

2823 k=k, 

2824 class_id=class_id, 

2825 weights=weights) 

2826 

2827 def compute_recall(_, tp, fn): 

2828 return math_ops.divide(tp, math_ops.add(tp, fn), name=scope) 

2829 

2830 metric = _aggregate_across_replicas( 

2831 metrics_collections, compute_recall, tp, fn) 

2832 

2833 update = math_ops.divide( 

2834 tp_update, math_ops.add(tp_update, fn_update), name='update') 

2835 if updates_collections: 

2836 ops.add_to_collections(updates_collections, update) 

2837 return metric, update 

2838 

2839 

2840@tf_export(v1=['metrics.recall_at_thresholds']) 

2841def recall_at_thresholds(labels, 

2842 predictions, 

2843 thresholds, 

2844 weights=None, 

2845 metrics_collections=None, 

2846 updates_collections=None, 

2847 name=None): 

2848 """Computes various recall values for different `thresholds` on `predictions`. 

2849 

2850 The `recall_at_thresholds` function creates four local variables, 

2851 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 

2852 for various values of thresholds. `recall[i]` is defined as the total weight 

2853 of values in `predictions` above `thresholds[i]` whose corresponding entry in 

2854 `labels` is `True`, divided by the total weight of `True` values in `labels` 

2855 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 

2856 

2857 For estimation of the metric over a stream of data, the function creates an 

2858 `update_op` operation that updates these variables and returns the `recall`. 

2859 

2860 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2861 

2862 Args: 

2863 labels: The ground truth values, a `Tensor` whose dimensions must match 

2864 `predictions`. Will be cast to `bool`. 

2865 predictions: A floating point `Tensor` of arbitrary shape and whose values 

2866 are in the range `[0, 1]`. 

2867 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 

2868 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2869 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2870 be either `1`, or the same as the corresponding `labels` dimension). 

2871 metrics_collections: An optional list of collections that `recall` should be 

2872 added to. 

2873 updates_collections: An optional list of collections that `update_op` should 

2874 be added to. 

2875 name: An optional variable_scope name. 

2876 

2877 Returns: 

2878 recall: A float `Tensor` of shape `[len(thresholds)]`. 

2879 update_op: An operation that increments the `true_positives`, 

2880 `true_negatives`, `false_positives` and `false_negatives` variables that 

2881 are used in the computation of `recall`. 

2882 

2883 Raises: 

2884 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2885 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2886 either `metrics_collections` or `updates_collections` are not a list or 

2887 tuple. 

2888 RuntimeError: If eager execution is enabled. 

2889 """ 

2890 if context.executing_eagerly(): 

2891 raise RuntimeError('tf.metrics.recall_at_thresholds is not ' 

2892 'supported when eager execution is enabled.') 

2893 

2894 with variable_scope.variable_scope(name, 'recall_at_thresholds', 

2895 (predictions, labels, weights)): 

2896 values, update_ops = _confusion_matrix_at_thresholds( 

2897 labels, predictions, thresholds, weights, includes=('tp', 'fn')) 

2898 

2899 # Avoid division by zero. 

2900 epsilon = 1e-7 

2901 

2902 def compute_recall(tp, fn, name): 

2903 return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name) 

2904 

2905 def recall_across_replicas(_, values): 

2906 return compute_recall(values['tp'], values['fn'], 'value') 

2907 

2908 rec = _aggregate_across_replicas( 

2909 metrics_collections, recall_across_replicas, values) 

2910 

2911 update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') 

2912 if updates_collections: 

2913 ops.add_to_collections(updates_collections, update_op) 

2914 

2915 return rec, update_op 
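
# Editorial illustration, not part of the original module: the per-threshold
# recall definition above, checked in plain NumPy with made-up values.
import numpy as np

ex_labels = np.array([True, True, False, True])
ex_predictions = np.array([0.9, 0.4, 0.8, 0.1])
for t in [0.25, 0.5, 0.75]:
  tp = np.sum((ex_predictions > t) & ex_labels)   # positives predicted above t
  fn = np.sum((ex_predictions <= t) & ex_labels)  # positives predicted at/below t
  print(t, tp / (tp + fn))  # 0.25 -> 0.667, 0.5 -> 0.333, 0.75 -> 0.333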

2916 

2917 

2918@tf_export(v1=['metrics.root_mean_squared_error']) 

2919def root_mean_squared_error(labels, 

2920 predictions, 

2921 weights=None, 

2922 metrics_collections=None, 

2923 updates_collections=None, 

2924 name=None): 

2925 """Computes the root mean squared error between the labels and predictions. 

2926 

2927 The `root_mean_squared_error` function creates two local variables, 

2928 `total` and `count` that are used to compute the root mean squared error. 

2929 This average is weighted by `weights`, and it is ultimately returned as 

2930 `root_mean_squared_error`: an idempotent operation that takes the square root 

2931 of the division of `total` by `count`. 

2932 

2933 For estimation of the metric over a stream of data, the function creates an 

2934 `update_op` operation that updates these variables and returns the 

2935 `root_mean_squared_error`. Internally, a `squared_error` operation computes 

2936 the element-wise square of the difference between `predictions` and `labels`. 

2937 Then `update_op` increments `total` with the reduced sum of the product of 

2938 `weights` and `squared_error`, and it increments `count` with the reduced sum 

2939 of `weights`. 

2940 

2941 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

2942 

2943 Args: 

2944 labels: A `Tensor` of the same shape as `predictions`. 

2945 predictions: A `Tensor` of arbitrary shape. 

2946 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

2947 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

2948 be either `1`, or the same as the corresponding `labels` dimension). 

2949 metrics_collections: An optional list of collections that 

2950 `root_mean_squared_error` should be added to. 

2951 updates_collections: An optional list of collections that `update_op` should 

2952 be added to. 

2953 name: An optional variable_scope name. 

2954 

2955 Returns: 

2956    root_mean_squared_error: A `Tensor` representing the current root mean

2957      squared error: the square root of `total` divided by `count`.

2958 update_op: An operation that increments the `total` and `count` variables 

2959 appropriately and whose value matches `root_mean_squared_error`. 

2960 

2961 Raises: 

2962 ValueError: If `predictions` and `labels` have mismatched shapes, or if 

2963 `weights` is not `None` and its shape doesn't match `predictions`, or if 

2964 either `metrics_collections` or `updates_collections` are not a list or 

2965 tuple. 

2966 RuntimeError: If eager execution is enabled. 

2967 """ 

2968 if context.executing_eagerly(): 

2969 raise RuntimeError('tf.metrics.root_mean_squared_error is not ' 

2970 'supported when eager execution is enabled.') 

2971 

2972 predictions, labels, weights = _remove_squeezable_dimensions( 

2973 predictions=predictions, labels=labels, weights=weights) 

2974 mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, 

2975 None, name or 

2976 'root_mean_squared_error') 

2977 

2978 once_across_replicas = lambda _, mse: math_ops.sqrt(mse) 

2979 rmse = _aggregate_across_replicas( 

2980 metrics_collections, once_across_replicas, mse) 

2981 

2982 update_rmse_op = math_ops.sqrt(update_mse_op) 

2983 if updates_collections: 

2984 ops.add_to_collections(updates_collections, update_rmse_op) 

2985 

2986 return rmse, update_rmse_op 
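
# Editorial illustration, not part of the original module: the streaming RMSE
# above reduces to sqrt(sum(w * (p - l)**2) / sum(w)); a NumPy check with
# made-up values.
import numpy as np

ex_labels = np.array([1.0, 2.0, 3.0])
ex_predictions = np.array([1.5, 2.0, 2.0])
ex_weights = np.array([1.0, 1.0, 2.0])
total = np.sum(ex_weights * (ex_predictions - ex_labels) ** 2)  # added to `total`
count = np.sum(ex_weights)                                      # added to `count`
print(np.sqrt(total / count))  # 0.75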

2987 

2988 

2989@tf_export(v1=['metrics.sensitivity_at_specificity']) 

2990def sensitivity_at_specificity(labels, 

2991 predictions, 

2992 specificity, 

2993 weights=None, 

2994 num_thresholds=200, 

2995 metrics_collections=None, 

2996 updates_collections=None, 

2997 name=None): 

2998 """Computes the specificity at a given sensitivity. 

2999 

3000 The `sensitivity_at_specificity` function creates four local 

3001 variables, `true_positives`, `true_negatives`, `false_positives` and 

3002 `false_negatives` that are used to compute the sensitivity at the given 

3003 specificity value. The threshold for the given specificity value is computed 

3004 and used to evaluate the corresponding sensitivity. 

3005 

3006 For estimation of the metric over a stream of data, the function creates an 

3007 `update_op` operation that updates these variables and returns the 

3008 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 

3009 `false_positives` and `false_negatives` counts with the weight of each case 

3010 found in the `predictions` and `labels`. 

3011 

3012 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

3013 

3014 For additional information about specificity and sensitivity, see the 

3015 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 

3016 

3017 Args: 

3018 labels: The ground truth values, a `Tensor` whose dimensions must match 

3019 `predictions`. Will be cast to `bool`. 

3020 predictions: A floating point `Tensor` of arbitrary shape and whose values 

3021 are in the range `[0, 1]`. 

3022 specificity: A scalar value in range `[0, 1]`. 

3023 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

3024 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

3025 be either `1`, or the same as the corresponding `labels` dimension). 

3026 num_thresholds: The number of thresholds to use for matching the given 

3027 specificity. 

3028 metrics_collections: An optional list of collections that `sensitivity` 

3029 should be added to. 

3030 updates_collections: An optional list of collections that `update_op` should 

3031 be added to. 

3032 name: An optional variable_scope name. 

3033 

3034 Returns: 

3035 sensitivity: A scalar `Tensor` representing the sensitivity at the given 

3036 `specificity` value. 

3037 update_op: An operation that increments the `true_positives`, 

3038 `true_negatives`, `false_positives` and `false_negatives` variables 

3039 appropriately and whose value matches `sensitivity`. 

3040 

3041 Raises: 

3042 ValueError: If `predictions` and `labels` have mismatched shapes, if 

3043 `weights` is not `None` and its shape doesn't match `predictions`, or if 

3044 `specificity` is not between 0 and 1, or if either `metrics_collections` 

3045 or `updates_collections` are not a list or tuple. 

3046 RuntimeError: If eager execution is enabled. 

3047 """ 

3048 if context.executing_eagerly(): 

3049 raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' 

3050 'supported when eager execution is enabled.') 

3051 

3052 if specificity < 0 or specificity > 1: 

3053 raise ValueError('`specificity` must be in the range [0, 1]. Currently, ' 

3054                     f'`specificity` is {specificity}.')

3055 

3056 with variable_scope.variable_scope(name, 'sensitivity_at_specificity', 

3057 (predictions, labels, weights)): 

3058 kepsilon = 1e-7 # to account for floating point imprecisions 

3059 thresholds = [ 

3060 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 

3061 ] 

3062 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 

3063 

3064 values, update_ops = _confusion_matrix_at_thresholds( 

3065 labels, predictions, thresholds, weights) 

3066 

3067 def compute_sensitivity_at_specificity(tp, tn, fp, fn, name): 

3068 specificities = math_ops.divide(tn, tn + fp + kepsilon) 

3069 tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0) 

3070 tf_index = math_ops.cast(tf_index, dtypes.int32) 

3071 

3072 # Now, we have the implicit threshold, so compute the sensitivity: 

3073 return math_ops.divide(tp[tf_index], 

3074 tp[tf_index] + fn[tf_index] + kepsilon, name) 

3075 

3076 def sensitivity_across_replicas(_, values): 

3077 return compute_sensitivity_at_specificity( 

3078 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 

3079 

3080 sensitivity = _aggregate_across_replicas( 

3081 metrics_collections, sensitivity_across_replicas, values) 

3082 

3083 update_op = compute_sensitivity_at_specificity( 

3084 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 

3085 'update_op') 

3086 if updates_collections: 

3087 ops.add_to_collections(updates_collections, update_op) 

3088 

3089 return sensitivity, update_op 
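
# Editorial illustration, not part of the original module: the selection rule
# inside `compute_sensitivity_at_specificity`, in NumPy. Given confusion
# counts at each threshold (all made up), pick the threshold whose
# specificity is closest to the target and report the sensitivity there.
import numpy as np

ex_tp = np.array([9., 7., 4., 1.])
ex_fn = np.array([1., 3., 6., 9.])
ex_tn = np.array([2., 5., 8., 10.])
ex_fp = np.array([8., 5., 2., 0.])
target, eps = 0.75, 1e-7

specificities = ex_tn / (ex_tn + ex_fp + eps)  # [0.2, 0.5, 0.8, 1.0]
i = np.argmin(np.abs(specificities - target))  # index 2
print(ex_tp[i] / (ex_tp[i] + ex_fn[i] + eps))  # 0.4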

3090 

3091 

3092def _expand_and_tile(tensor, multiple, dim=0, name=None): 

3093 """Slice `tensor` shape in 2, then tile along the sliced dimension. 

3094 

3095 A new dimension is inserted in shape of `tensor` before `dim`, then values are 

3096 tiled `multiple` times along the new dimension. 

3097 

3098 Args: 

3099 tensor: Input `Tensor` or `SparseTensor`. 

3100 multiple: Integer, number of times to tile. 

3101 dim: Integer, dimension along which to tile. 

3102 name: Name of operation. 

3103 

3104 Returns: 

3105 `Tensor` result of expanding and tiling `tensor`. 

3106 

3107 Raises: 

3108 ValueError: if `multiple` is less than 1, or `dim` is not in 

3109 `[-rank(tensor), rank(tensor)]`. 

3110 """ 

3111 if multiple < 1: 

3112 raise ValueError(f'Invalid argument multiple={multiple} for ' 

3113 'expand_and_tile call. `multiple` must be an integer > 0') 

3114 with ops.name_scope(name, 'expand_and_tile', 

3115 (tensor, multiple, dim)) as scope: 

3116 # Sparse. 

3117 tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor) 

3118 if isinstance(tensor, sparse_tensor.SparseTensor): 

3119 if dim < 0: 

3120 expand_dims = array_ops.reshape( 

3121 array_ops.size(tensor.dense_shape) + dim, [1]) 

3122 else: 

3123 expand_dims = [dim] 

3124 expanded_shape = array_ops.concat( 

3125 (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1], 

3126 array_ops.slice(tensor.dense_shape, expand_dims, [-1])), 

3127 0, 

3128 name='expanded_shape') 

3129 expanded = sparse_ops.sparse_reshape( 

3130 tensor, shape=expanded_shape, name='expand') 

3131 if multiple == 1: 

3132 return expanded 

3133 return sparse_ops.sparse_concat( 

3134 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 

3135 

3136 # Dense. 

3137 expanded = array_ops.expand_dims( 

3138 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 

3139 if multiple == 1: 

3140 return expanded 

3141 ones = array_ops.ones_like(array_ops.shape(tensor)) 

3142 tile_multiples = array_ops.concat( 

3143 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples') 

3144 return array_ops.tile(expanded, tile_multiples, name=scope) 
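
# Editorial illustration, not part of the original module: the dense branch
# above on a concrete input. For dim=-1 a new axis is inserted before the
# last one and tiled `multiple` times, as `labels_per_k` does further below.
import tensorflow as tf

ex_t = tf.reshape(tf.range(6), [2, 3])
ex_expanded = tf.expand_dims(ex_t, -2)       # shape [2, 1, 3]
ex_tiled = tf.tile(ex_expanded, [1, 4, 1])   # shape [2, 4, 3]; rows repeated 4x
print(ex_tiled.shape)  # (2, 4, 3)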

3145 

3146 

3147def _num_relevant(labels, k): 

3148 """Computes number of relevant values for each row in labels. 

3149 

3150 For labels with shape [D1, ... DN, num_labels], this is the minimum of 

3151 `num_labels` and `k`. 

3152 

3153 Args: 

3154 labels: `int64` `Tensor` or `SparseTensor` with shape 

3155 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

3156 target classes for the associated prediction. Commonly, N=1 and `labels` 

3157 has shape [batch_size, num_labels]. 

3158 k: Integer, k for @k metric. 

3159 

3160 Returns: 

3161 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 

3162 relevant values for that row. 

3163 

3164 Raises: 

3165 ValueError: if inputs have invalid dtypes or values. 

3166 """ 

3167 if k < 1: 

3168    raise ValueError(f'Invalid k={k}. `k` should be >= 1.')

3169 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 

3170 # For SparseTensor, calculate separate count for each row. 

3171 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 

3172 if isinstance(labels, sparse_tensor.SparseTensor): 

3173 return math_ops.minimum(sets.set_size(labels), k, name=scope) 

3174 

3175 # The relevant values for each (d1, ... dN) is the minimum of k and the 

3176 # number of labels along the last dimension that are non-negative. 

3177 num_labels = math_ops.reduce_sum( 

3178 array_ops.where_v2(math_ops.greater_equal(labels, 0), 

3179 array_ops.ones_like(labels), 

3180 array_ops.zeros_like(labels)), 

3181 axis=-1) 

3182 return math_ops.minimum(num_labels, k, name=scope) 
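
# Editorial illustration, not part of the original module: for dense labels,
# `_num_relevant` is min(#non-negative labels per row, k). NumPy, made-up
# values; -1 marks padding.
import numpy as np

ex_labels = np.array([[0, 4, -1],   # 2 valid labels
                      [1, 2, 3]])   # 3 valid labels
k = 2
num_labels = np.sum(ex_labels >= 0, axis=-1)
print(np.minimum(num_labels, k))  # [2 2]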

3183 

3184 

3185def _sparse_average_precision_at_top_k(labels, predictions_idx): 

3186 """Computes average precision@k of predictions with respect to sparse labels. 

3187 

3188 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 

3189 for each row is: 

3190 

3191 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 

3192 

3193 A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`, 

3194 `labels`, and the result `Tensors`. In the common case, this is [batch_size]. 

3195 Each row of the results contains the average precision for that row. 

3196 

3197 Args: 

3198 labels: `int64` `Tensor` or `SparseTensor` with shape 

3199 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

3200 num_labels=1. N >= 1 and num_labels is the number of target classes for 

3201 the associated prediction. Commonly, N=1 and `labels` has shape 

3202 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 

3203 Values should be non-negative. Negative values are ignored. 

3204 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 

3205 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 

3206 dimension must be set and contains the top `k` predicted class indices. 

3207 [D1, ... DN] must match `labels`. Values should be in range 

3208 [0, num_classes). 

3209 

3210 Returns: 

3211 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 

3212 precision for that row. 

3213 

3214 Raises: 

3215 ValueError: if the last dimension of predictions_idx is not set. 

3216 """ 

3217 with ops.name_scope(None, 'average_precision', 

3218 (predictions_idx, labels)) as scope: 

3219 predictions_idx = math_ops.cast( 

3220 predictions_idx, dtypes.int64, name='predictions_idx') 

3221 if predictions_idx.get_shape().ndims == 0: 

3222 raise ValueError('The rank of `predictions_idx` must be at least 1.') 

3223 k = predictions_idx.get_shape().as_list()[-1] 

3224 if k is None: 

3225 raise ValueError('The last dimension of predictions_idx must be set. ' 

3226 'Currently, it is None.') 

3227 labels = _maybe_expand_labels(labels, predictions_idx) 

3228 

3229 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 

3230 # prediction for each k, so we can calculate separate true positive values 

3231 # for each k. 

3232 predictions_idx_per_k = array_ops.expand_dims( 

3233 predictions_idx, -1, name='predictions_idx_per_k') 

3234 

3235 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 

3236 labels_per_k = _expand_and_tile( 

3237 labels, multiple=k, dim=-1, name='labels_per_k') 

3238 

3239 # The following tensors are all of shape [D1, ... DN, k], containing values 

3240 # per row, per k value. 

3241 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 

3242 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 

3243 # the formula above. 

3244 # `tp_per_k` (int32) - True positive counts. 

3245 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 

3246 # the precision denominator. 

3247 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 

3248 # term from the formula above. 

3249 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 

3250 # precisions at all k for which relevance indicator is true. 

3251 relevant_per_k = _sparse_true_positive_at_k( 

3252 labels_per_k, predictions_idx_per_k, name='relevant_per_k') 

3253 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 

3254 retrieved_per_k = math_ops.cumsum( 

3255 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 

3256 precision_per_k = math_ops.divide( 

3257 math_ops.cast(tp_per_k, dtypes.float64), 

3258 math_ops.cast(retrieved_per_k, dtypes.float64), 

3259 name='precision_per_k') 

3260 relevant_precision_per_k = math_ops.multiply( 

3261 precision_per_k, 

3262 math_ops.cast(relevant_per_k, dtypes.float64), 

3263 name='relevant_precision_per_k') 

3264 

3265 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 

3266 precision_sum = math_ops.reduce_sum( 

3267 relevant_precision_per_k, axis=(-1,), name='precision_sum') 

3268 

3269 # Divide by number of relevant items to get average precision. These are 

3270 # the "num_relevant_items" and "AveP" terms from the formula above. 

3271 num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64) 

3272 return math_ops.divide(precision_sum, num_relevant_items, name=scope) 
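
# Editorial illustration, not part of the original module: the AveP formula
# above computed by hand for one row, in NumPy with made-up values.
import numpy as np

ex_labels = {1, 4}              # true classes for the row
ex_predictions_idx = [4, 0, 1]  # top-k=3 predicted classes, best first
relevant = np.array([int(p in ex_labels) for p in ex_predictions_idx])  # [1 0 1]
tp = np.cumsum(relevant)                                    # [1 1 2]
precision = tp / np.arange(1, len(ex_predictions_idx) + 1)  # [1.0, 0.5, 0.667]
print(np.sum(precision * relevant) / min(len(ex_labels), 3))  # (1 + 2/3) / 2 = 0.833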

3273 

3274 

3275def _streaming_sparse_average_precision_at_top_k(labels, 

3276 predictions_idx, 

3277 weights=None, 

3278 metrics_collections=None, 

3279 updates_collections=None, 

3280 name=None): 

3281 """Computes average precision@k of predictions with respect to sparse labels. 

3282 

3283  `_streaming_sparse_average_precision_at_top_k` creates two local variables,

3284 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 

3285 are used to compute the frequency. This frequency is ultimately returned as 

3286 `average_precision_at_<k>`: an idempotent operation that simply divides 

3287 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 

3288 

3289 For estimation of the metric over a stream of data, the function creates an 

3290 `update_op` operation that updates these variables and returns the 

3291 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 

3292 the true positives and false positives weighted by `weights`. Then `update_op` 

3293 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these 

3294 values. 

3295 

3296 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

3297 

3298 Args: 

3299 labels: `int64` `Tensor` or `SparseTensor` with shape 

3300 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

3301 num_labels=1. N >= 1 and num_labels is the number of target classes for 

3302 the associated prediction. Commonly, N=1 and `labels` has shape 

3303 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 

3304 Values should be non-negative. Negative values are ignored. 

3305 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 

3306 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 

3307 dimension contains the top `k` predicted class indices. [D1, ... DN] must 

3308 match `labels`. Values should be in range [0, num_classes). 

3309 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

3310 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

3311 dimensions must be either `1`, or the same as the corresponding `labels` 

3312 dimension). 

3313 metrics_collections: An optional list of collections that values should 

3314 be added to. 

3315 updates_collections: An optional list of collections that updates should 

3316 be added to. 

3317 name: Name of new update operation, and namespace for other dependent ops. 

3318 

3319 Returns: 

3320 mean_average_precision: Scalar `float64` `Tensor` with the mean average 

3321 precision values. 

3322 update: `Operation` that increments variables appropriately, and whose 

3323 value matches `metric`. 

3324 """ 

3325 with ops.name_scope(name, 'average_precision_at_top_k', 

3326 (predictions_idx, labels, weights)) as scope: 

3327 # Calculate per-example average precision, and apply weights. 

3328 average_precision = _sparse_average_precision_at_top_k( 

3329 predictions_idx=predictions_idx, labels=labels) 

3330 if weights is not None: 

3331 weights = weights_broadcast_ops.broadcast_weights( 

3332 math_ops.cast(weights, dtypes.float64), average_precision) 

3333 average_precision = math_ops.multiply(average_precision, weights) 

3334 

3335 # Create accumulation variables and update ops for max average precision and 

3336 # total average precision. 

3337 with ops.name_scope(None, 'max', (average_precision,)) as max_scope: 

3338 # `max` is the max possible precision. Since max for any row is 1.0: 

3339 # - For the unweighted case, this is just the number of rows. 

3340 # - For the weighted case, it's the sum of the weights broadcast across 

3341 # `average_precision` rows. 

3342 max_var = metric_variable([], dtypes.float64, name=max_scope) 

3343 if weights is None: 

3344 batch_max = math_ops.cast( 

3345 array_ops.size(average_precision, name='batch_max'), dtypes.float64) 

3346 else: 

3347 batch_max = math_ops.reduce_sum(weights, name='batch_max') 

3348 max_update = state_ops.assign_add(max_var, batch_max, name='update') 

3349 with ops.name_scope(None, 'total', (average_precision,)) as total_scope: 

3350 total_var = metric_variable([], dtypes.float64, name=total_scope) 

3351 batch_total = math_ops.reduce_sum(average_precision, name='batch_total') 

3352 total_update = state_ops.assign_add(total_var, batch_total, name='update') 

3353 

3354 # Divide total by max to get mean, for both vars and the update ops. 

3355 def precision_across_replicas(_, total_var, max_var): 

3356 return _safe_scalar_div(total_var, max_var, name='mean') 

3357 

3358 mean_average_precision = _aggregate_across_replicas( 

3359 metrics_collections, precision_across_replicas, total_var, max_var) 

3360 

3361 update = _safe_scalar_div(total_update, max_update, name=scope) 

3362 if updates_collections: 

3363 ops.add_to_collections(updates_collections, update) 

3364 

3365 return mean_average_precision, update 

3366 

3367 

3368def _clean_out_of_range_indices(labels, num_classes): 

3369 """Replaces large out-of-range labels by small out-of-range labels. 

3370 

3371 Replaces any value in `labels` that is greater or equal to `num_classes` by 

3372  -1. This is done conditionally, for efficiency, in case there are no such values.

3373 

3374 Args: 

3375 labels: `int64` `Tensor` or `SparseTensor`. 

3376 num_classes: `int64` scalar `Tensor`. 

3377 Returns: 

3378 An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater 

3379 or equal to num_classes replaced by -1. 

3380 """ 

3381 

3382 def _labels_is_sparse(): 

3383 """Returns true is `labels` is a sparse tensor.""" 

3384 return isinstance(labels, (sparse_tensor.SparseTensor, 

3385 sparse_tensor.SparseTensorValue)) 

3386 

3387 def _clean_out_of_range(values): 

3388 """Replaces by -1 any large out-of-range `values`.""" 

3389 return array_ops.where_v2(math_ops.greater_equal(values, num_classes), 

3390 -1 * array_ops.ones_like(values), values) 

3391 

3392 def _clean_labels_out_of_range(): 

3393 """Replaces by -1 ane large out-of-range values in `labels`.""" 

3394 if _labels_is_sparse(): 

3395 return type(labels)(indices=labels.indices, 

3396 values=_clean_out_of_range(labels.values), 

3397 dense_shape=labels.dense_shape) 

3398 else: 

3399 return _clean_out_of_range(labels) 

3400 

3401 max_labels = math_ops.reduce_max( 

3402 labels.values if _labels_is_sparse() else labels) 

3403 return cond.cond( 

3404 math_ops.greater_equal(max_labels, num_classes), 

3405 _clean_labels_out_of_range, 

3406 lambda: labels) 
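
# Editorial illustration, not part of the original module: the dense case of
# the cleanup above in NumPy. Labels >= num_classes become -1, which
# downstream code treats as ignorable.
import numpy as np

ex_labels = np.array([[0, 7], [2, 1]])
print(np.where(ex_labels >= 5, -1, ex_labels))  # [[ 0 -1], [ 2  1]]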

3407 

3408 

3409@tf_export(v1=['metrics.sparse_average_precision_at_k']) 

3410@deprecated(None, 'Use average_precision_at_k instead') 

3411def sparse_average_precision_at_k(labels, 

3412 predictions, 

3413 k, 

3414 weights=None, 

3415 metrics_collections=None, 

3416 updates_collections=None, 

3417 name=None): 

3418 """Renamed to `average_precision_at_k`, please use that method instead.""" 

3419 return average_precision_at_k( 

3420 labels=labels, 

3421 predictions=predictions, 

3422 k=k, 

3423 weights=weights, 

3424 metrics_collections=metrics_collections, 

3425 updates_collections=updates_collections, 

3426 name=name) 

3427 

3428 

3429@tf_export(v1=['metrics.average_precision_at_k']) 

3430def average_precision_at_k(labels, 

3431 predictions, 

3432 k, 

3433 weights=None, 

3434 metrics_collections=None, 

3435 updates_collections=None, 

3436 name=None): 

3437 """Computes average precision@k of predictions with respect to sparse labels. 

3438 

3439 `average_precision_at_k` creates two local variables, 

3440 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 

3441 are used to compute the frequency. This frequency is ultimately returned as 

3442 `average_precision_at_<k>`: an idempotent operation that simply divides 

3443 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 

3444 

3445 For estimation of the metric over a stream of data, the function creates an 

3446 `update_op` operation that updates these variables and returns the 

3447 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 

3448 indicating the top `k` `predictions`. Set operations applied to `top_k` and 

3449 `labels` calculate the true positives and false positives weighted by 

3450 `weights`. Then `update_op` increments `true_positive_at_<k>` and 

3451 `false_positive_at_<k>` using these values. 

3452 

3453 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

3454 

3455 Args: 

3456 labels: `int64` `Tensor` or `SparseTensor` with shape 

3457 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

3458 num_labels=1. N >= 1 and num_labels is the number of target classes for 

3459 the associated prediction. Commonly, N=1 and `labels` has shape 

3460 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 

3461 should be in range [0, num_classes), where num_classes is the last 

3462 dimension of `predictions`. Values outside this range are ignored. 

3463 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 

3464 N >= 1. Commonly, N=1 and `predictions` has shape 

3465 [batch size, num_classes]. The final dimension contains the logit values 

3466 for each class. [D1, ... DN] must match `labels`. 

3467 k: Integer, k for @k metric. This will calculate an average precision for 

3468 range `[1,k]`, as documented above. 

3469 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

3470 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

3471 dimensions must be either `1`, or the same as the corresponding `labels` 

3472 dimension). 

3473 metrics_collections: An optional list of collections that values should 

3474 be added to. 

3475 updates_collections: An optional list of collections that updates should 

3476 be added to. 

3477 name: Name of new update operation, and namespace for other dependent ops. 

3478 

3479 Returns: 

3480 mean_average_precision: Scalar `float64` `Tensor` with the mean average 

3481 precision values. 

3482 update: `Operation` that increments variables appropriately, and whose 

3483 value matches `metric`. 

3484 

3485 Raises: 

3486 ValueError: if k is invalid. 

3487 RuntimeError: If eager execution is enabled. 

3488 """ 

3489 if context.executing_eagerly(): 

3490    raise RuntimeError('tf.metrics.average_precision_at_k is not '

3491 'supported when eager execution is enabled.') 

3492 

3493 if k < 1: 

3494 raise ValueError(f'Invalid k={k}. `k` should be >= 1.') 

3495 with ops.name_scope(name, _at_k_name('average_precision', k), 

3496 (predictions, labels, weights)) as scope: 

3497 # Calculate top k indices to produce [D1, ... DN, k] tensor. 

3498 _, predictions_idx = nn.top_k(predictions, k) 

3499 # The documentation states that labels should be in [0, ..., num_classes), 

3500 # but num_classes is lost when predictions_idx replaces predictions. 

3501 # For conformity with the documentation, any label >= num_classes, which is 

3502 # ignored, is replaced by -1. 

3503 labels = _clean_out_of_range_indices( 

3504 labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64)) 

3505 return _streaming_sparse_average_precision_at_top_k( 

3506 labels=labels, 

3507 predictions_idx=predictions_idx, 

3508 weights=weights, 

3509 metrics_collections=metrics_collections, 

3510 updates_collections=updates_collections, 

3511 name=scope) 

3512 

3513 

3514def _sparse_false_positive_at_k(labels, 

3515 predictions_idx, 

3516 class_id=None, 

3517 weights=None): 

3518 """Calculates false positives for precision@k. 

3519 

3520  If `class_id` is specified, calculate binary false positives for `class_id`

3521 only. 

3522 If `class_id` is not specified, calculate metrics for `k` predicted vs 

3523  `n` label classes, where `n` is the 2nd dimension of `labels`.

3524 

3525 Args: 

3526 labels: `int64` `Tensor` or `SparseTensor` with shape 

3527 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

3528 target classes for the associated prediction. Commonly, N=1 and `labels` 

3529 has shape [batch_size, num_labels]. [D1, ... DN] must match 

3530 `predictions_idx`. 

3531 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 

3532 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 

3533 match `labels`. 

3534 class_id: Class for which we want binary metrics. 

3535 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

3536 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

3537 dimensions must be either `1`, or the same as the corresponding `labels` 

3538 dimension). 

3539 

3540 Returns: 

3541 A [D1, ... DN] `Tensor` of false positive counts. 

3542 """ 

3543 with ops.name_scope(None, 'false_positives', 

3544 (predictions_idx, labels, weights)): 

3545 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 

3546 class_id) 

3547 fp = sets.set_size( 

3548 sets.set_difference(predictions_idx, labels, aminusb=True)) 

3549 fp = math_ops.cast(fp, dtypes.float64) 

3550 if weights is not None: 

3551 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 

3552 weights, fp),)): 

3553 weights = math_ops.cast(weights, dtypes.float64) 

3554 fp = math_ops.multiply(fp, weights) 

3555 return fp 

3556 

3557 

3558def _streaming_sparse_false_positive_at_k(labels, 

3559 predictions_idx, 

3560 k=None, 

3561 class_id=None, 

3562 weights=None, 

3563 name=None): 

3564 """Calculates weighted per step false positives for precision@k. 

3565 

3566  If `class_id` is specified, calculate binary false positives for `class_id`

3567 only. 

3568 If `class_id` is not specified, calculate metrics for `k` predicted vs 

3569 `n` label classes, where `n` is the 2nd dimension of `labels`. 

3570 

3571 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

3572 

3573 Args: 

3574 labels: `int64` `Tensor` or `SparseTensor` with shape 

3575 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 

3576 target classes for the associated prediction. Commonly, N=1 and `labels` 

3577 has shape [batch_size, num_labels]. [D1, ... DN] must match 

3578 `predictions_idx`. 

3579 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 

3580 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 

3581 match `labels`. 

3582    k: Integer, k for @k metric. This is only used for the default op name.

3583 class_id: Class for which we want binary metrics. 

3584 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

3585 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

3586 dimensions must be either `1`, or the same as the corresponding `labels` 

3587 dimension). 

3588 name: Name of new variable, and namespace for other dependent ops. 

3589 

3590 Returns: 

3591 A tuple of `Variable` and update `Operation`. 

3592 

3593 Raises: 

3594 ValueError: If `weights` is not `None` and has an incompatible shape. 

3595 """ 

3596 with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id), 

3597 (predictions_idx, labels, weights)) as scope: 

3598 fp = _sparse_false_positive_at_k( 

3599 predictions_idx=predictions_idx, 

3600 labels=labels, 

3601 class_id=class_id, 

3602 weights=weights) 

3603 batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64) 

3604 

3605 var = metric_variable([], dtypes.float64, name=scope) 

3606 return var, state_ops.assign_add(var, batch_total_fp, name='update') 

3607 

3608 

3609@tf_export(v1=['metrics.precision_at_top_k']) 

3610def precision_at_top_k(labels, 

3611 predictions_idx, 

3612 k=None, 

3613 class_id=None, 

3614 weights=None, 

3615 metrics_collections=None, 

3616 updates_collections=None, 

3617 name=None): 

3618 """Computes precision@k of the predictions with respect to sparse labels. 

3619 

3620 Differs from `sparse_precision_at_k` in that predictions must be in the form 

3621 of top `k` class indices, whereas `sparse_precision_at_k` expects logits. 

3622 Refer to `sparse_precision_at_k` for more details. 

3623 

3624 Args: 

3625 labels: `int64` `Tensor` or `SparseTensor` with shape 

3626 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

3627 num_labels=1. N >= 1 and num_labels is the number of target classes for 

3628 the associated prediction. Commonly, N=1 and `labels` has shape 

3629 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 

3630 should be in range [0, num_classes), where num_classes is the last 

3631 dimension of `predictions`. Values outside this range are ignored. 

3632 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where 

3633 N >= 1. Commonly, N=1 and predictions has shape [batch size, k]. 

3634 The final dimension contains the top `k` predicted class indices. 

3635 [D1, ... DN] must match `labels`. 

3636 k: Integer, k for @k metric. Only used for the default op name. 

3637 class_id: Integer class ID for which we want binary metrics. This should be 

3638      in range [0, num_classes), where num_classes is the last dimension of

3639 `predictions`. If `class_id` is outside this range, the method returns 

3640 NAN. 

3641 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

3642 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

3643 dimensions must be either `1`, or the same as the corresponding `labels` 

3644 dimension). 

3645 metrics_collections: An optional list of collections that values should 

3646 be added to. 

3647 updates_collections: An optional list of collections that updates should 

3648 be added to. 

3649 name: Name of new update operation, and namespace for other dependent ops. 

3650 

3651 Returns: 

3652 precision: Scalar `float64` `Tensor` with the value of `true_positives` 

3653 divided by the sum of `true_positives` and `false_positives`. 

3654 update_op: `Operation` that increments `true_positives` and 

3655 `false_positives` variables appropriately, and whose value matches 

3656 `precision`. 

3657 

3658 Raises: 

3659 ValueError: If `weights` is not `None` and its shape doesn't match 

3660 `predictions`, or if either `metrics_collections` or `updates_collections` 

3661 are not a list or tuple. 

3662 RuntimeError: If eager execution is enabled. 

3663 """ 

3664 if context.executing_eagerly(): 

3665 raise RuntimeError('tf.metrics.precision_at_top_k is not ' 

3666 'supported when eager execution is enabled.') 

3667 

3668 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 

3669 (predictions_idx, labels, weights)) as scope: 

3670 labels = _maybe_expand_labels(labels, predictions_idx) 

3671 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 

3672 tp, tp_update = _streaming_sparse_true_positive_at_k( 

3673 predictions_idx=top_k_idx, 

3674 labels=labels, 

3675 k=k, 

3676 class_id=class_id, 

3677 weights=weights) 

3678 fp, fp_update = _streaming_sparse_false_positive_at_k( 

3679 predictions_idx=top_k_idx, 

3680 labels=labels, 

3681 k=k, 

3682 class_id=class_id, 

3683 weights=weights) 

3684 

3685 def precision_across_replicas(_, tp, fp): 

3686 return math_ops.divide(tp, math_ops.add(tp, fp), name=scope) 

3687 

3688 metric = _aggregate_across_replicas( 

3689 metrics_collections, precision_across_replicas, tp, fp) 

3690 

3691 update = math_ops.divide( 

3692 tp_update, math_ops.add(tp_update, fp_update), name='update') 

3693 if updates_collections: 

3694 ops.add_to_collections(updates_collections, update) 

3695 return metric, update 

3696 

3697 

3698@tf_export(v1=['metrics.sparse_precision_at_k']) 

3699@deprecated(None, 'Use precision_at_k instead') 

3700def sparse_precision_at_k(labels, 

3701 predictions, 

3702 k, 

3703 class_id=None, 

3704 weights=None, 

3705 metrics_collections=None, 

3706 updates_collections=None, 

3707 name=None): 

3708 """Renamed to `precision_at_k`, please use that method instead.""" 

3709 return precision_at_k( 

3710 labels=labels, 

3711 predictions=predictions, 

3712 k=k, 

3713 class_id=class_id, 

3714 weights=weights, 

3715 metrics_collections=metrics_collections, 

3716 updates_collections=updates_collections, 

3717 name=name) 

3718 

3719 

3720@tf_export(v1=['metrics.precision_at_k']) 

3721def precision_at_k(labels, 

3722 predictions, 

3723 k, 

3724 class_id=None, 

3725 weights=None, 

3726 metrics_collections=None, 

3727 updates_collections=None, 

3728 name=None): 

3729 """Computes precision@k of the predictions with respect to sparse labels. 

3730 

3731 If `class_id` is specified, we calculate precision by considering only the 

3732 entries in the batch for which `class_id` is in the top-k highest 

3733 `predictions`, and computing the fraction of them for which `class_id` is 

3734 indeed a correct label. 

3735 If `class_id` is not specified, we'll calculate precision as how often on 

3736 average a class among the top-k classes with the highest predicted values 

3737 of a batch entry is correct and can be found in the label for that entry. 

3738 

3739 `precision_at_k` creates two local variables, 

3740 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 

3741 the precision@k frequency. This frequency is ultimately returned as 

3742 `precision_at_<k>`: an idempotent operation that simply divides 

3743 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 

3744 `false_positive_at_<k>`). 

3745 

3746 For estimation of the metric over a stream of data, the function creates an 

3747 `update_op` operation that updates these variables and returns the 

3748 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 

3749 indicating the top `k` `predictions`. Set operations applied to `top_k` and 

3750 `labels` calculate the true positives and false positives weighted by 

3751 `weights`. Then `update_op` increments `true_positive_at_<k>` and 

3752 `false_positive_at_<k>` using these values. 

3753 

3754 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

3755 

3756 Args: 

3757 labels: `int64` `Tensor` or `SparseTensor` with shape 

3758 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 

3759 num_labels=1. N >= 1 and num_labels is the number of target classes for 

3760 the associated prediction. Commonly, N=1 and `labels` has shape 

3761 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 

3762 should be in range [0, num_classes), where num_classes is the last 

3763 dimension of `predictions`. Values outside this range are ignored. 

3764 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 

3765 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 

3766 The final dimension contains the logit values for each class. [D1, ... DN] 

3767 must match `labels`. 

3768 k: Integer, k for @k metric. 

3769 class_id: Integer class ID for which we want binary metrics. This should be 

3770      in range [0, num_classes), where num_classes is the last dimension of

3771 `predictions`. If `class_id` is outside this range, the method returns 

3772 NAN. 

3773 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 

3774 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 

3775 dimensions must be either `1`, or the same as the corresponding `labels` 

3776 dimension). 

3777 metrics_collections: An optional list of collections that values should 

3778 be added to. 

3779 updates_collections: An optional list of collections that updates should 

3780 be added to. 

3781 name: Name of new update operation, and namespace for other dependent ops. 

3782 

3783 Returns: 

3784 precision: Scalar `float64` `Tensor` with the value of `true_positives` 

3785 divided by the sum of `true_positives` and `false_positives`. 

3786 update_op: `Operation` that increments `true_positives` and 

3787 `false_positives` variables appropriately, and whose value matches 

3788 `precision`. 

3789 

3790 Raises: 

3791 ValueError: If `weights` is not `None` and its shape doesn't match 

3792 `predictions`, or if either `metrics_collections` or `updates_collections` 

3793 are not a list or tuple. 

3794 RuntimeError: If eager execution is enabled. 

3795 """ 

3796 if context.executing_eagerly(): 

3797    raise RuntimeError('tf.metrics.precision_at_k is not '

3798 'supported when eager execution is enabled.') 

3799 

3800 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 

3801 (predictions, labels, weights)) as scope: 

3802 _, top_k_idx = nn.top_k(predictions, k) 

3803 return precision_at_top_k( 

3804 labels=labels, 

3805 predictions_idx=top_k_idx, 

3806 k=k, 

3807 class_id=class_id, 

3808 weights=weights, 

3809 metrics_collections=metrics_collections, 

3810 updates_collections=updates_collections, 

3811 name=scope) 
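
# Editorial illustration, not part of the original module: `precision_at_k`
# driven the same way as `recall_at_k` above, in TF1 graph mode with made-up
# values. With k=2 there are 4 predictions in total; 3 of them appear in the
# corresponding labels.
import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

ex_labels = tf1.constant([[0, 2], [1, 2]], dtype=tf1.int64)
ex_logits = tf1.constant([[0.9, 0.1, 0.8],   # top-2 = {0, 2}: both correct
                          [0.3, 0.6, 0.1]])  # top-2 = {1, 0}: one correct
ex_precision, ex_update = tf1.metrics.precision_at_k(ex_labels, ex_logits, k=2)

with tf1.Session() as sess:
  sess.run(tf1.local_variables_initializer())
  sess.run(ex_update)
  print(sess.run(ex_precision))  # 0.75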

3812 

3813 

3814@tf_export(v1=['metrics.specificity_at_sensitivity']) 

3815def specificity_at_sensitivity(labels, 

3816 predictions, 

3817 sensitivity, 

3818 weights=None, 

3819 num_thresholds=200, 

3820 metrics_collections=None, 

3821 updates_collections=None, 

3822 name=None): 

3823 """Computes the specificity at a given sensitivity. 

3824 

3825 The `specificity_at_sensitivity` function creates four local 

3826 variables, `true_positives`, `true_negatives`, `false_positives` and 

3827 `false_negatives` that are used to compute the specificity at the given 

3828 sensitivity value. The threshold for the given sensitivity value is computed 

3829 and used to evaluate the corresponding specificity. 

3830 

3831 For estimation of the metric over a stream of data, the function creates an 

3832 `update_op` operation that updates these variables and returns the 

3833 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 

3834 `false_positives` and `false_negatives` counts with the weight of each case 

3835 found in the `predictions` and `labels`. 

3836 

3837 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 

3838 

3839 For additional information about specificity and sensitivity, see the 

3840 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 

3841 

3842 Args: 

3843 labels: The ground truth values, a `Tensor` whose dimensions must match 

3844 `predictions`. Will be cast to `bool`. 

3845 predictions: A floating point `Tensor` of arbitrary shape and whose values 

3846 are in the range `[0, 1]`. 

3847 sensitivity: A scalar value in range `[0, 1]`. 

3848 weights: Optional `Tensor` whose rank is either 0, or the same rank as 

3849 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 

3850 be either `1`, or the same as the corresponding `labels` dimension). 

3851 num_thresholds: The number of thresholds to use for matching the given 

3852 sensitivity. 

3853 metrics_collections: An optional list of collections that `specificity` 

3854 should be added to. 

3855 updates_collections: An optional list of collections that `update_op` should 

3856 be added to. 

3857 name: An optional variable_scope name. 

3858 

3859 Returns: 

3860 specificity: A scalar `Tensor` representing the specificity at the given 

3861 `sensitivity` value. 

3862 update_op: An operation that increments the `true_positives`, 

3863 `true_negatives`, `false_positives` and `false_negatives` variables 

3864 appropriately and whose value matches `specificity`. 

3865 


  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` is not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
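  # Editor's note (not in the original source): in confusion-matrix terms,
  # sensitivity = TP / (TP + FN) (the true positive rate) and
  # specificity = TN / (TN + FP) (the true negative rate). This function
  # scans `num_thresholds` candidate thresholds and reports the specificity
  # at the threshold whose sensitivity is closest to the requested value.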

  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1]. Currently, '
                     f'`sensitivity` is {sensitivity}.')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
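    # Editor's note (not in the original source): for num_thresholds=5 the
    # comprehension above yields [0.25, 0.5, 0.75], so the final list is
    # [-1e-07, 0.25, 0.5, 0.75, 1.0 - 1e-07]: evenly spaced thresholds with
    # epsilon-shifted endpoints, so that predictions of exactly 0.0 and 1.0
    # land strictly above the lowest and highest thresholds respectively.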

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

3891 

3892 def compute_specificity_at_sensitivity(tp, tn, fp, fn, name): 

3893 """Computes the specificity at the given sensitivity. 

3894 

3895 Args: 

3896 tp: True positives. 

3897 tn: True negatives. 

3898 fp: False positives. 

3899 fn: False negatives. 

3900 name: The name of the operation. 

3901 

3902 Returns: 

3903 The specificity using the aggregated values. 

3904 """ 

3905 sensitivities = math_ops.divide(tp, tp + fn + kepsilon) 

3906 

3907 # We'll need to use this trick until tf.argmax allows us to specify 

3908 # whether we should use the first or last index in case of ties. 

3909 min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity)) 

3910 indices_at_minval = math_ops.equal( 

3911 math_ops.abs(sensitivities - sensitivity), min_val) 

3912 indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64) 

3913 indices_at_minval = math_ops.cumsum(indices_at_minval) 

3914 tf_index = math_ops.argmax(indices_at_minval, 0) 

3915 tf_index = math_ops.cast(tf_index, dtypes.int32) 
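
      # Editor's note (not in the original source): the cumsum/argmax idiom
      # above selects the *last* threshold whose sensitivity is closest to
      # the target. E.g. if |sensitivities - sensitivity| = [0.3, 0.1, 0.1,
      # 0.2], the equality mask is [0, 1, 1, 0], its cumsum is [0, 1, 2, 2],
      # and argmax (which in practice returns the first occurrence of the
      # maximum) yields index 2, the last tied position.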

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.divide(tn[tf_index],
                             tn[tf_index] + fp[tf_index] + kepsilon, name)

    def specificity_across_replicas(_, values):
      return compute_specificity_at_sensitivity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    specificity = _aggregate_across_replicas(
        metrics_collections, specificity_across_replicas, values)

    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
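
# Example (editor's sketch, not part of the original source): assuming the
# deprecated `tf.compat.v1.metrics.specificity_at_sensitivity` endpoint
# named in the error message above, a minimal graph-mode usage might look
# like:
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   labels = tf.constant([0, 0, 1, 1])
#   predictions = tf.constant([0.1, 0.6, 0.4, 0.9])
#   spec, update_op = tf.metrics.specificity_at_sensitivity(
#       labels, predictions, sensitivity=0.5)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(spec))  # specificity at the threshold whose
#                            # sensitivity is nearest 0.5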