Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/base_metric.py: 28%

322 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base Metric classes.""" 

16 

17import abc 

18import types 

19import warnings 

20 

21import numpy as np 

22import tensorflow.compat.v2 as tf 

23 

24from keras.src import backend 

25from keras.src.dtensor import dtensor_api as dtensor 

26from keras.src.dtensor import utils as dtensor_utils 

27from keras.src.engine import base_layer 

28from keras.src.engine import base_layer_utils 

29from keras.src.engine import keras_tensor 

30from keras.src.saving.legacy.saved_model import metric_serialization 

31from keras.src.utils import generic_utils 

32from keras.src.utils import losses_utils 

33from keras.src.utils import metrics_utils 

34from keras.src.utils import tf_utils 

35 

36# isort: off 

37from tensorflow.python.util.tf_export import keras_export 

38from tensorflow.tools.docs import doc_controls 

39 

40 

@keras_export("keras.metrics.Metric")
class Metric(base_layer.Layer, metaclass=abc.ABCMeta):
    """Encapsulates metric logic and state.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      **kwargs: Additional layer keyword arguments.

    Standalone usage:

    ```python
    m = SomeMetric(...)
    for input in ...:
      m.update_state(input)
    print('Final result: ', m.result().numpy())
    ```

    Usage with `compile()` API:

    ```python
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))

    model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
                  loss=tf.keras.losses.CategoricalCrossentropy(),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])

    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))

    dataset = tf.data.Dataset.from_tensor_slices((data, labels))
    dataset = dataset.batch(32)

    model.fit(dataset, epochs=10)
    ```

    To be implemented by subclasses:
    * `__init__()`: All state variables should be created in this method by
      calling `self.add_weight()` like: `self.var = self.add_weight(...)`
    * `update_state()`: Has all updates to the state variables like:
      self.var.assign_add(...).
    * `result()`: Computes and returns a scalar value or a dict of scalar values
      for the metric from the state variables.

    Example subclass implementation:

    ```python
    class BinaryTruePositives(tf.keras.metrics.Metric):

      def __init__(self, name='binary_true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

      def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
          sample_weight = tf.cast(sample_weight, self.dtype)
          sample_weight = tf.broadcast_to(sample_weight, values.shape)
          values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

      def result(self):
        return self.true_positives
    ```
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.stateful = True  # All metric layers are stateful.
        self.built = True
        if not base_layer_utils.v2_dtype_behavior_enabled():
            # We only do this when the V2 behavior is not enabled, as when it is
            # enabled, the dtype already defaults to floatx.
            self._dtype = (
                backend.floatx() if dtype is None else tf.as_dtype(dtype).name
            )

    def __new__(cls, *args, **kwargs):
        """Creates the instance and rebinds `update_state` and `result`.

        Both methods are replaced on the instance by wrapped versions built
        with `metrics_utils.update_state_wrapper` and
        `metrics_utils.result_wrapper`.
        """
        obj = super(Metric, cls).__new__(cls)

        # If `update_state` is not in eager/tf.function and it is not from a
        # built-in metric, wrap it in `tf.function`. This is so that users
        # writing custom metrics in v1 need not worry about control dependencies
        # and return ops.
        if base_layer_utils.is_in_eager_or_tf_function() or is_built_in(cls):
            obj_update_state = obj.update_state

            def update_state_fn(*args, **kwargs):
                control_status = tf.__internal__.autograph.control_status_ctx()
                ag_update_state = tf.__internal__.autograph.tf_convert(
                    obj_update_state, control_status
                )
                return ag_update_state(*args, **kwargs)

        else:
            if isinstance(obj.update_state, tf.__internal__.function.Function):
                update_state_fn = obj.update_state
            else:
                update_state_fn = tf.function(obj.update_state)

        obj.update_state = types.MethodType(
            metrics_utils.update_state_wrapper(update_state_fn), obj
        )

        obj_result = obj.result

        def result_fn(*args, **kwargs):
            control_status = tf.__internal__.autograph.control_status_ctx()
            ag_result = tf.__internal__.autograph.tf_convert(
                obj_result, control_status
            )
            return ag_result(*args, **kwargs)

        obj.result = types.MethodType(
            metrics_utils.result_wrapper(result_fn), obj
        )

        return obj

    def __call__(self, *args, **kwargs):
        """Accumulates statistics and then computes metric result value.

        Args:
          *args:
          **kwargs: A mini-batch of inputs to the Metric,
            passed on to `update_state()`.

        Returns:
          The metric value tensor.
        """

        def replica_local_fn(*args, **kwargs):
            """Updates the state of the metric in a replica-local context."""
            if any(
                isinstance(arg, keras_tensor.KerasTensor)
                for arg in tf.nest.flatten((args, kwargs))
            ):
                update_op = None
            else:
                update_op = self.update_state(*args, **kwargs)
            update_ops = []
            if update_op is not None:
                update_ops.append(update_op)
            with tf.control_dependencies(update_ops):
                result_t = self.result()

                # If the metric object returns a dictionary as a result, wrap
                # it with our custom dict object so we can attach the metric
                # object to it.
                if isinstance(result_t, dict):
                    result_t = _MetricDict(**result_t)

                # We are adding the metric object as metadata on the result
                # tensor. This is required when we want to use a metric with
                # `add_metric` API on a Model/Layer in graph mode. This metric
                # instance will later be used to reset variable state after each
                # epoch of training.
                # Example:
                #   model = Model()
                #   mean = Mean()
                #   model.add_metric(mean(values), name='mean')
                result_t._metric_obj = self
                return result_t

        # Imported locally rather than at module level, as in the original.
        from keras.src.distribute import (
            distributed_training_utils,
        )

        return distributed_training_utils.call_replica_local_fn(
            replica_local_fn, *args, **kwargs
        )

    def __str__(self):
        """Returns `ClassName(key=value,...)` built from `get_config()`."""
        args = ",".join(f"{k}={v}" for k, v in self.get_config().items())
        return f"{self.__class__.__name__}({args})"

    def __deepcopy__(self, memo=None):
        """Returns a copy rebuilt from the config, carrying over the weights.

        Args:
          memo: (Optional) memo dictionary supplied by `copy.deepcopy`. When
            given, the new instance is recorded in it under `self`.

        Returns:
          A new metric created with `from_config(get_config())`.

        Raises:
          NotImplementedError: If `from_config()`/`get_config()` raise it.
        """
        try:
            new_self = self.from_config(self.get_config())
        except NotImplementedError as e:
            # Chain the original error so the root cause stays visible.
            raise NotImplementedError(
                "Calling `__deepcopy__()` on a Keras metric "
                "requires the metric to be serializable, "
                "i.e. it should implement `get_config()`.\n\n"
                f"Error encountered during serialization: [{e}]"
            ) from e
        # Note that metrics don't implement `build()` so their variables
        # are readily available after instantiation.
        if self.weights:
            new_self.set_weights(self.get_weights())
        # `memo` defaults to None, so a direct `metric.__deepcopy__()` call
        # must not attempt to index it.
        if memo is not None:
            memo[self] = new_self
        return new_self

    @property
    def dtype(self):
        """The value stored in `self._dtype`."""
        return self._dtype

    def get_config(self):
        """Returns the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype}

    def reset_state(self):
        """Resets all of the metric state variables.

        This function is called between epochs/steps,
        when a metric is evaluated during training.
        """
        if not generic_utils.is_default(self.reset_states):
            # A subclass overrode the deprecated `reset_states()`; warn and
            # delegate to it.
            warnings.warn(
                "Metric %s implements a `reset_states()` method; rename it "
                'to `reset_state()` (without the final "s"). The name '
                "`reset_states()` has been deprecated to improve API "
                "consistency." % (self.__class__.__name__,),
                stacklevel=2,
            )
            return self.reset_states()
        else:
            backend.batch_set_value([(v, 0) for v in self.variables])

    @abc.abstractmethod
    def update_state(self, *args, **kwargs):
        """Accumulates statistics for the metric.

        Note: This function is executed as a graph function in graph mode.
        This means:
          a) Operations on the same resource are executed in textual order.
             This should make it easier to do things like add the updated
             value of a variable to another, for example.
          b) You don't need to worry about collecting the update ops to execute.
             All update ops added to the graph by this function will be
             executed.
          As a result, code should generally work the same way with graph or
          eager execution.

        Args:
          *args:
          **kwargs: A mini-batch of inputs to the Metric.
        """
        raise NotImplementedError("Must be implemented in subclasses.")

    def merge_state(self, metrics):
        """Merges the state from one or more metrics.

        This method can be used by distributed systems to merge the state
        computed by different metric instances. Typically the state will be
        stored in the form of the metric's weights. For example, a
        tf.keras.metrics.Mean metric contains a list of two weight values: a
        total and a count. If there were two instances of a
        tf.keras.metrics.Accuracy that each independently aggregated partial
        state for an overall accuracy calculation, these two metric's states
        could be combined as follows:

        >>> m1 = tf.keras.metrics.Accuracy()
        >>> _ = m1.update_state([[1], [2]], [[0], [2]])

        >>> m2 = tf.keras.metrics.Accuracy()
        >>> _ = m2.update_state([[3], [4]], [[3], [4]])

        >>> m2.merge_state([m1])
        >>> m2.result().numpy()
        0.75

        Args:
          metrics: an iterable of metrics. The metrics must have compatible
            state.

        Raises:
          ValueError: If the provided iterable does not contain metrics matching
            the metric's required specifications.
        """
        assign_add_ops = []
        for metric in metrics:
            # Compatibility is checked by number of weights only.
            if len(self.weights) != len(metric.weights):
                raise ValueError(
                    f"Metric {metric} is not compatible with {self}"
                )
            for weight, weight_to_add in zip(self.weights, metric.weights):
                assign_add_ops.append(weight.assign_add(weight_to_add))
        return assign_add_ops

    @abc.abstractmethod
    def result(self):
        """Computes and returns the scalar metric value tensor or a dict of
        scalars.

        Result computation is an idempotent operation that simply calculates the
        metric value using the state variables.

        Returns:
          A scalar tensor, or a dictionary of scalar tensors.
        """
        raise NotImplementedError("Must be implemented in subclasses.")

    ### For use by subclasses ###
    @doc_controls.for_subclass_implementers
    def add_weight(
        self,
        name,
        shape=(),
        aggregation=tf.VariableAggregation.SUM,
        synchronization=tf.VariableSynchronization.ON_READ,
        initializer=None,
        dtype=None,
    ):
        """Adds state variable. Only for use by subclasses."""
        if tf.distribute.has_strategy():
            strategy = tf.distribute.get_strategy()
        else:
            strategy = None

        additional_kwargs = {}

        # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
        if backend.is_tpu_strategy(strategy):
            synchronization = tf.VariableSynchronization.ON_WRITE
        if getattr(self, "_mesh", None) is not None:
            # When self._mesh is set, it means this metric is used for DTensor.
            additional_kwargs = {
                "layout": dtensor.Layout.replicated(
                    self._mesh, tf.TensorShape(shape).rank
                )
            }

        if tf_utils.in_local_vars_context():
            # Metrics created within a remotely-executed tf.function during
            # parameter server evaluation should use tf2 Variables, so that they
            # can be local variables that are freely usable and mutable within
            # the function, using the
            # `experimental_enable_variable_lifting=False` argument. This
            # supports a visitation guarantee for model evaluation.
            def local_v2_var_creator(
                initializer=None, dtype=None, shape=None, **kwargs
            ):
                init_val, var_dtype = base_layer_utils.infer_init_val_and_dtype(
                    initializer, dtype, shape
                )
                v1_only_args = ["use_resource", "collections"]
                for v1_arg in v1_only_args:
                    kwargs.pop(v1_arg, None)
                kwargs["experimental_enable_variable_lifting"] = False
                return tf.Variable(
                    initial_value=init_val,
                    dtype=var_dtype,
                    shape=shape,
                    **kwargs,
                )

            additional_kwargs["getter"] = local_v2_var_creator

        with tf_utils.maybe_init_scope(layer=self):
            return super().add_weight(
                name=name,
                shape=shape,
                dtype=self._dtype if dtype is None else dtype,
                trainable=False,
                initializer=initializer,
                collections=[],
                synchronization=synchronization,
                aggregation=aggregation,
                **additional_kwargs,
            )

    ### End: For use by subclasses ###

    @property
    def trainable_weights(self):
        # Overridden from Layer class to track submetric weights.
        if self.trainable:
            # Copy first: `+=` on the layer's own list would extend
            # `self._trainable_weights` in place on every property access.
            trainable_weights = list(self._trainable_weights)
            for m in self._metrics:
                trainable_weights += m.trainable_weights
            return self._dedup_weights(trainable_weights)
        else:
            return []

    @property
    def non_trainable_weights(self):
        # Overridden from Layer class to track submetric weights.
        if self.trainable:
            # Copy first so that `+=` below does not mutate
            # `self._non_trainable_weights` in place.
            non_trainable_weights = list(self._non_trainable_weights)
            for m in self._metrics:
                non_trainable_weights += m.non_trainable_weights
        else:
            # `+` already produces a fresh list here.
            non_trainable_weights = (
                self._non_trainable_weights + self._trainable_weights
            )
            for m in self._metrics:
                non_trainable_weights += m.weights
        return self._dedup_weights(non_trainable_weights)

    @property
    def _trackable_saved_model_saver(self):
        return metric_serialization.MetricSavedModelSaver(self)

    @generic_utils.default
    @doc_controls.do_not_generate_docs
    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        return self.reset_state()

448 

449 

class Reduce(Metric):
    """Metric base that folds incoming values into running accumulators.

    A `total` weight is always created. A `count` weight is created as well
    when the reduction is `SUM_OVER_BATCH_SIZE` or `WEIGHTED_MEAN`.

    Args:
      reduction: a `tf.keras.metrics.Reduction` enum value.
      name: string name of the metric instance.
      dtype: (Optional) data type of the metric result.
    """

    def __init__(self, reduction, name, dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.reduction = reduction
        self.total = self.add_weight("total", initializer="zeros")
        # Only mean-like reductions carry a denominator.
        needs_count = reduction in [
            metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
            metrics_utils.Reduction.WEIGHTED_MEAN,
        ]
        if needs_count:
            self.count = self.add_weight("count", initializer="zeros")

    def update_state(self, values, sample_weight=None):
        """Adds a batch of values to the running statistics.

        Args:
          values: Per-example value.
          sample_weight: Optional weighting of each example. Defaults to `1`.

        Returns:
          Update op.

        Raises:
          RuntimeError: If `values` cannot be cast to the metric dtype.
          NotImplementedError: If `self.reduction` is not one of the
            supported reductions.
        """
        flat_values, sample_weight = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [values], sample_weight
            )
        )
        [values] = flat_values

        try:
            values = tf.cast(values, self._dtype)
        except (ValueError, TypeError):
            dict_hint = (
                "To return a dict of values, implement a custom Metric "
                "subclass."
                if isinstance(values, dict)
                else ""
            )
            raise RuntimeError(
                "The output of a metric function can only be a single Tensor. "
                f"Received: {values}. " + dict_hint
            )

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)
            # Update dimensions of weights to match with values if possible.
            squeezed = losses_utils.squeeze_or_expand_dimensions(
                values, sample_weight=sample_weight
            )
            values, _, sample_weight = squeezed
            try:
                # Broadcast weights if possible.
                sample_weight = tf.__internal__.ops.broadcast_weights(
                    sample_weight, values
                )
            except ValueError:
                # Reduce values to same ndim as weight array
                values_rank = backend.ndim(values)
                weights_rank = backend.ndim(sample_weight)
                trailing_axes = list(range(weights_rank, values_rank))
                if self.reduction == metrics_utils.Reduction.SUM:
                    values = tf.reduce_sum(values, axis=trailing_axes)
                else:
                    values = tf.reduce_mean(values, axis=trailing_axes)
            values = tf.multiply(values, sample_weight)

        batch_total = tf.reduce_sum(values)
        with tf.control_dependencies([batch_total]):
            total_update = self.total.assign_add(batch_total)

        reduction = self.reduction

        # Exit early if the reduction doesn't have a denominator.
        if reduction == metrics_utils.Reduction.SUM:
            return total_update

        # Update `count` for reductions that require a denominator.
        if reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
            denominator = tf.cast(tf.size(values), self._dtype)
        elif reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
            if sample_weight is not None:
                denominator = tf.reduce_sum(sample_weight)
            else:
                denominator = tf.cast(tf.size(values), self._dtype)
        else:
            raise NotImplementedError(
                f'Reduction "{self.reduction}" not implemented. Expected '
                '"sum", "weighted_mean", or "sum_over_batch_size".'
            )

        with tf.control_dependencies([total_update]):
            return self.count.assign_add(denominator)

    def result(self):
        """Returns `total`, or `total / count` for mean-like reductions."""
        reduction = self.reduction
        if reduction == metrics_utils.Reduction.SUM:
            return tf.identity(self.total)
        if reduction in [
            metrics_utils.Reduction.WEIGHTED_MEAN,
            metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
        ]:
            return tf.math.divide_no_nan(self.total, self.count)
        raise NotImplementedError(
            f'Reduction "{self.reduction}" not implemented. Expected '
            '"sum", "weighted_mean", or "sum_over_batch_size".'
        )

564 

565 

@keras_export("keras.metrics.Sum")
class Sum(Reduce):
    """Computes the (weighted) sum of the given values.

    For example, if values is [1, 3, 5, 7] then the sum is 16.
    If the weights were specified as [1, 1, 0, 0] then the sum would be 4.

    A single variable, `total`, accumulates the sum of `values`, and that
    is what is ultimately returned as `sum`.

    When `sample_weight` is `None`, weights default to 1. A `sample_weight`
    of 0 masks the corresponding value.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Sum()
    >>> m.update_state([1, 3, 5, 7])
    >>> m.result().numpy()
    16.0

    Usage with `compile()` API:

    ```python
    model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
    model.compile(optimizer='sgd', loss='mse')
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="sum", dtype=None):
        sum_reduction = metrics_utils.Reduction.SUM
        super().__init__(reduction=sum_reduction, name=name, dtype=dtype)

603 

604 

@keras_export("keras.metrics.Mean")
class Mean(Reduce):
    """Computes the (weighted) mean of the given values.

    For example, if values is [1, 3, 5, 7] then the mean is 4.
    If the weights were specified as [1, 1, 0, 0] then the mean would be 2.

    Two variables, `total` and `count`, back this metric. The value returned
    as `mean` is an idempotent operation that simply divides `total` by
    `count`.

    When `sample_weight` is `None`, weights default to 1.
    A `sample_weight` of 0 masks the corresponding value.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Mean()
    >>> m.update_state([1, 3, 5, 7])
    >>> m.result().numpy()
    4.0
    >>> m.reset_state()
    >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
    >>> m.result().numpy()
    2.0

    Usage with `compile()` API:

    ```python
    model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
    model.compile(optimizer='sgd', loss='mse')
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="mean", dtype=None):
        mean_reduction = metrics_utils.Reduction.WEIGHTED_MEAN
        super().__init__(reduction=mean_reduction, name=name, dtype=dtype)

650 

651 

@keras_export("keras.metrics.MeanMetricWrapper")
class MeanMetricWrapper(Mean):
    """Wraps a stateless metric function with the Mean metric.

    Use this class to quickly turn a function into a mean metric. The
    function needs to have the signature `fn(y_true, y_pred)` and return a
    per-sample loss array. `MeanMetricWrapper.result()` will return
    the average metric value across all samples seen so far.

    For example:

    ```python
    def accuracy(y_true, y_pred):
      return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)

    accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)

    keras_model.compile(..., metrics=accuracy_metric)
    ```

    Args:
      fn: The metric function to wrap, with signature `fn(y_true, y_pred,
        **kwargs)`.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      **kwargs: Keyword arguments to pass on to `fn`.
    """

    @dtensor_utils.inject_mesh
    def __init__(self, fn, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        self._fn_kwargs = kwargs

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates metric statistics.

        `y_true` and `y_pred` should have the same shape.

        Args:
          y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
          y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
          sample_weight: Optional `sample_weight` acts as a
            coefficient for the metric. If a scalar is provided, then the metric
            is simply scaled by the given value. If `sample_weight` is a tensor
            of size `[batch_size]`, then the metric for each sample of the batch
            is rescaled by the corresponding element in the `sample_weight`
            vector. If the shape of `sample_weight` is `[batch_size, d0, ..
            dN-1]` (or can be broadcasted to this shape), then each metric
            element of `y_pred` is scaled by the corresponding value of
            `sample_weight`. (Note on `dN-1`: all metric functions reduce by 1
            dimension, usually the last axis (-1)).

        Returns:
          Update op.
        """
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)

        flat_pair, sample_weight = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [y_true, y_pred], sample_weight
            )
        )
        [y_true, y_pred] = flat_pair

        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )

        # Run the wrapped function under the current autograph status.
        autograph = tf.__internal__.autograph
        converted_fn = autograph.tf_convert(
            self._fn, autograph.control_status_ctx()
        )
        per_sample = converted_fn(y_true, y_pred, **self._fn_kwargs)

        valid_mask = losses_utils.get_mask(per_sample)
        sample_weight = losses_utils.apply_valid_mask(
            per_sample, sample_weight, valid_mask, self.reduction
        )
        return super().update_state(per_sample, sample_weight=sample_weight)

    def get_config(self):
        """Returns the base config merged with the stored `fn` kwargs."""
        config = {}
        for key, value in self._fn_kwargs.items():
            if tf_utils.is_tensor_or_variable(value):
                value = backend.eval(value)
            config[key] = value

        if type(self) is MeanMetricWrapper:
            # Only include function argument when the object is a
            # MeanMetricWrapper and not a subclass.
            config["fn"] = self._fn

        # Entries from `config` override same-named base entries.
        return {**super().get_config(), **config}

    @classmethod
    def from_config(cls, config):
        """Rebuilds an instance from `config`; the `fn` entry is popped."""
        from keras.src.metrics import get

        # Note that while MeanMetricWrapper itself isn't public, objects of this
        # class may be created and added to the model by calling model.compile.
        fn = config.pop("fn", None)
        if cls is not MeanMetricWrapper:
            return super(MeanMetricWrapper, cls).from_config(config)
        return cls(get(fn), **config)

754 

755 

@keras_export("keras.metrics.MeanTensor")
class MeanTensor(Metric):
    """Computes the element-wise (weighted) mean of the given tensors.

    `MeanTensor` returns a tensor with the same shape of the input tensors. The
    mean value is updated by keeping local variables `total` and `count`. The
    `total` tracks the sum of the weighted values, and `count` stores the sum of
    the weighted counts.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor
        of type int32. If not specified, the shape is inferred from the values
        at the first call of update_state.

    Standalone usage:

    >>> m = tf.keras.metrics.MeanTensor()
    >>> m.update_state([0, 1, 2, 3])
    >>> m.update_state([4, 5, 6, 7])
    >>> m.result().numpy()
    array([2., 3., 4., 5.], dtype=float32)

    >>> m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1])
    >>> m.result().numpy()
    array([2.       , 3.6363635, 4.8      , 5.3333335], dtype=float32)

    >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4))
    >>> m.result().numpy()
    array([[0., 0., 0., 0.]])
    >>> m.update_state([[0, 1, 2, 3]])
    >>> m.update_state([[4, 5, 6, 7]])
    >>> m.result().numpy()
    array([[2., 3., 4., 5.]])
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="mean_tensor", dtype=None, shape=None):
        super().__init__(name=name, dtype=dtype)
        self._shape = None
        self._total = None
        self._count = None
        self._built = False
        # State is created eagerly only when the shape is known up front.
        if shape is not None:
            self._build(shape)

    def _build(self, shape):
        """Creates the `total` and `count` variables for `shape`."""
        self._shape = tf.TensorShape(shape)
        self._build_input_shape = self._shape
        # Create new state variables
        self._total = self.add_weight(
            name="total", shape=shape, initializer="zeros"
        )
        self._count = self.add_weight(
            name="count", shape=shape, initializer="zeros"
        )
        with tf.init_scope():
            if not tf.executing_eagerly():
                backend._initialize_variables(backend._get_session())
        self._built = True

    @property
    def total(self):
        """The `total` variable, or `None` before the state is built."""
        if self._built:
            return self._total
        return None

    @property
    def count(self):
        """The `count` variable, or `None` before the state is built."""
        if self._built:
            return self._count
        return None

    def update_state(self, values, sample_weight=None):
        """Accumulates statistics for computing the element-wise mean.

        Args:
          values: Per-example value.
          sample_weight: Optional weighting of each example. Defaults to `1`.

        Returns:
          Update op.

        Raises:
          ValueError: If the shape of `values` differs from the shape the
            state was built with.
        """
        values = tf.cast(values, self._dtype)
        if not self._built:
            self._build(values.shape)
        elif values.shape != self._shape:
            raise ValueError(
                "MeanTensor input values must always have the same "
                "shape. Expected shape (set during the first call): "
                f"{self._shape}. "
                f"Got: {values.shape}."
            )

        element_counts = tf.ones_like(values)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)

            # Update dimensions of weights to match with values if possible.
            squeezed = losses_utils.squeeze_or_expand_dimensions(
                values, sample_weight=sample_weight
            )
            values, _, sample_weight = squeezed
            try:
                # Broadcast weights if possible.
                sample_weight = tf.__internal__.ops.broadcast_weights(
                    sample_weight, values
                )
            except ValueError:
                # Reduce values to same ndim as weight array
                values_rank = backend.ndim(values)
                weights_rank = backend.ndim(sample_weight)
                trailing_axes = list(range(weights_rank, values_rank))
                values = tf.reduce_mean(values, axis=trailing_axes)

            element_counts = tf.multiply(element_counts, sample_weight)
            values = tf.multiply(values, sample_weight)

        total_update = self._total.assign_add(values)
        with tf.control_dependencies([total_update]):
            return self._count.assign_add(element_counts)

    def result(self):
        """Returns `total / count`; raises `ValueError` before any build."""
        if self._built:
            return tf.math.divide_no_nan(self.total, self.count)
        raise ValueError(
            "MeanTensor does not have any value yet. Please call the "
            "MeanTensor instance or use `.update_state(value)` "
            "before retrieving the result."
        )

    def reset_state(self):
        """Zeroes every variable; does nothing before the state is built."""
        if not self._built:
            return
        zeroed = [(v, np.zeros(v.shape.as_list())) for v in self.variables]
        backend.batch_set_value(zeroed)

893 

894 

class SumOverBatchSize(Reduce):
    """Computes the weighted sum over batch size of the given values.

    For example, if values is [1, 3, 5, 7] then the metric value is 4.
    If the weights were specified as [1, 1, 0, 0] then the value would be 1.

    Two variables, `total` and `count`, back this metric. The value returned
    as sum over batch size is an idempotent operation that simply divides
    `total` by `count`.

    When `sample_weight` is `None`, weights default to 1. A `sample_weight`
    of 0 masks the corresponding value.
    """

    def __init__(self, name="sum_over_batch_size", dtype=None):
        batch_reduction = metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
        super().__init__(reduction=batch_reduction, name=name, dtype=dtype)

916 

917 

class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
    """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""

    def __init__(self, fn, name=None, dtype=None, **kwargs):
        """Creates a `SumOverBatchSizeMetricWrapper` instance.

        Args:
          fn: The metric function to wrap, with signature `fn(y_true, y_pred,
            **kwargs)`.
          name: (Optional) string name of the metric instance.
          dtype: (Optional) data type of the metric result.
          **kwargs: The keyword arguments that are passed on to `fn`.
        """
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        self._fn_kwargs = kwargs

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Applies the wrapped function and accumulates its output.

        Args:
          y_true: Ground truth values.
          y_pred: The predicted values.
          sample_weight: Optional weighting of each example.

        Returns:
          Update op.
        """
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )

        # Run the wrapped function under the current autograph status.
        autograph = tf.__internal__.autograph
        converted_fn = autograph.tf_convert(
            self._fn, autograph.control_status_ctx()
        )
        per_sample = converted_fn(y_true, y_pred, **self._fn_kwargs)

        valid_mask = losses_utils.get_mask(per_sample)
        sample_weight = losses_utils.apply_valid_mask(
            per_sample, sample_weight, valid_mask, self.reduction
        )
        return super().update_state(per_sample, sample_weight=sample_weight)

    def get_config(self):
        """Returns the base config merged with the stored `fn` kwargs."""
        config = {}
        for key, value in self._fn_kwargs.items():
            if tf_utils.is_tensor_or_variable(value):
                value = backend.eval(value)
            config[key] = value
        # Entries from `config` override same-named base entries.
        return {**super().get_config(), **config}

959 

960 

def clone_metric(metric):
    """Returns a clone of the metric if stateful, otherwise returns it as is."""
    if not isinstance(metric, Metric):
        return metric

    # Metrics created within a remotely-executed tf.function during
    # parameter server evaluation should not be lifted out of the graph by
    # `init_scope`. This way the metric variables can be local: freely
    # usable and mutable within the function. This supports a visitation
    # guarantee for model evaluation.
    if tf_utils.in_local_vars_context():
        return metric.__class__.from_config(metric.get_config())

    with tf.init_scope():
        return metric.__class__.from_config(metric.get_config())

975 

976 

def clone_metrics(metrics):
    """Applies `clone_metric` to every leaf of the given metric structure."""
    cloned = tf.nest.map_structure(clone_metric, metrics)
    return cloned

980 

981 

def is_built_in(cls):
    """Returns whether `cls.__module__` starts with `Metric`'s package path."""
    # Everything before the last dot of `Metric.__module__`; an empty string
    # when the module name contains no dot.
    package_prefix = Metric.__module__.rpartition(".")[0]
    return cls.__module__.startswith(package_prefix)

986 

987 

988class _MetricDict(dict): 

989 """Wrapper for returned dictionary of metrics.""" 

990 

991 def __init__(self, **kwargs): 

992 super().__init__(**kwargs) 

993 self._metric_obj = None 

994