# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Metric classes."""

import abc
import types
import warnings

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.dtensor import dtensor_api as dtensor
from keras.src.dtensor import utils as dtensor_utils
from keras.src.engine import base_layer
from keras.src.engine import base_layer_utils
from keras.src.engine import keras_tensor
from keras.src.saving.legacy.saved_model import metric_serialization
from keras.src.utils import generic_utils
from keras.src.utils import losses_utils
from keras.src.utils import metrics_utils
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export("keras.metrics.Metric")
class Metric(base_layer.Layer, metaclass=abc.ABCMeta):
    """Encapsulates metric logic and state.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      **kwargs: Additional layer keyword arguments.

    Standalone usage:

    ```python
    m = SomeMetric(...)
    for input in ...:
      m.update_state(input)
    print('Final result: ', m.result().numpy())
    ```

    Usage with `compile()` API:

    ```python
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))

    model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
                  loss=tf.keras.losses.CategoricalCrossentropy(),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])

    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))

    dataset = tf.data.Dataset.from_tensor_slices((data, labels))
    dataset = dataset.batch(32)

    model.fit(dataset, epochs=10)
    ```

    To be implemented by subclasses:
    * `__init__()`: All state variables should be created in this method by
      calling `self.add_weight()` like: `self.var = self.add_weight(...)`
    * `update_state()`: Has all updates to the state variables like:
      `self.var.assign_add(...)`.
    * `result()`: Computes and returns a scalar value or a dict of scalar
      values for the metric from the state variables.

    Example subclass implementation:

    ```python
    class BinaryTruePositives(tf.keras.metrics.Metric):

      def __init__(self, name='binary_true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

      def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)

        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
          sample_weight = tf.cast(sample_weight, self.dtype)
          sample_weight = tf.broadcast_to(sample_weight, values.shape)
          values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

      def result(self):
        return self.true_positives
    ```
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.stateful = True  # All metric layers are stateful.
        self.built = True
        if not base_layer_utils.v2_dtype_behavior_enabled():
            # We only do this when the V2 behavior is not enabled, as when it
            # is enabled, the dtype already defaults to floatx.
            self._dtype = (
                backend.floatx() if dtype is None else tf.as_dtype(dtype).name
            )

    def __new__(cls, *args, **kwargs):
        obj = super(Metric, cls).__new__(cls)

        # If `update_state` is not in eager/tf.function and it is not from a
        # built-in metric, wrap it in `tf.function`. This is so that users
        # writing custom metrics in v1 need not worry about control
        # dependencies and return ops.
        if base_layer_utils.is_in_eager_or_tf_function() or is_built_in(cls):
            obj_update_state = obj.update_state

            def update_state_fn(*args, **kwargs):
                control_status = tf.__internal__.autograph.control_status_ctx()
                ag_update_state = tf.__internal__.autograph.tf_convert(
                    obj_update_state, control_status
                )
                return ag_update_state(*args, **kwargs)

        else:
            if isinstance(obj.update_state, tf.__internal__.function.Function):
                update_state_fn = obj.update_state
            else:
                update_state_fn = tf.function(obj.update_state)

        obj.update_state = types.MethodType(
            metrics_utils.update_state_wrapper(update_state_fn), obj
        )

        obj_result = obj.result

        def result_fn(*args, **kwargs):
            control_status = tf.__internal__.autograph.control_status_ctx()
            ag_result = tf.__internal__.autograph.tf_convert(
                obj_result, control_status
            )
            return ag_result(*args, **kwargs)

        obj.result = types.MethodType(
            metrics_utils.result_wrapper(result_fn), obj
        )

        return obj

    def __call__(self, *args, **kwargs):
        """Accumulates statistics and then computes metric result value.

        Args:
          *args:
          **kwargs: A mini-batch of inputs to the Metric,
            passed on to `update_state()`.

        Returns:
          The metric value tensor.
        """

        def replica_local_fn(*args, **kwargs):
            """Updates the state of the metric in a replica-local context."""
            if any(
                isinstance(arg, keras_tensor.KerasTensor)
                for arg in tf.nest.flatten((args, kwargs))
            ):
                update_op = None
            else:
                update_op = self.update_state(*args, **kwargs)
            update_ops = []
            if update_op is not None:
                update_ops.append(update_op)
            with tf.control_dependencies(update_ops):
                result_t = self.result()

                # If the metric object returns a dictionary as a result, wrap
                # it with our custom dict object so we can attach the metric
                # object to it.
                if isinstance(result_t, dict):
                    result_t = _MetricDict(**result_t)

                # We are adding the metric object as metadata on the result
                # tensor. This is required when we want to use a metric with
                # `add_metric` API on a Model/Layer in graph mode. This
                # metric instance will later be used to reset variable state
                # after each epoch of training.
                # Example:
                #   model = Model()
                #   mean = Mean()
                #   model.add_metric(mean(values), name='mean')
                result_t._metric_obj = self
                return result_t

        from keras.src.distribute import (
            distributed_training_utils,
        )

        return distributed_training_utils.call_replica_local_fn(
            replica_local_fn, *args, **kwargs
        )

    def __str__(self):
        args = ",".join(f"{k}={v}" for k, v in self.get_config().items())
        return f"{self.__class__.__name__}({args})"
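    # The string form above is driven by `get_config()`: for example,
    # `str(Mean())` would read roughly "Mean(name=mean,dtype=float32)"
    # (the exact dtype depends on `backend.floatx()`).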

    def __deepcopy__(self, memo=None):
        try:
            new_self = self.from_config(self.get_config())
        except NotImplementedError as e:
            raise NotImplementedError(
                "Calling `__deepcopy__()` on a Keras metric "
                "requires the metric to be serializable, "
                "i.e. it should implement `get_config()`.\n\n"
                f"Error encountered during serialization: [{e}]"
            )
        # Note that metrics don't implement `build()` so their variables
        # are readily available after instantiation.
        if self.weights:
            new_self.set_weights(self.get_weights())
        memo[self] = new_self
        return new_self

    @property
    def dtype(self):
        return self._dtype

    def get_config(self):
        """Returns the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype}

    def reset_state(self):
        """Resets all of the metric state variables.

        This function is called between epochs/steps,
        when a metric is evaluated during training.
        """
        if not generic_utils.is_default(self.reset_states):
            warnings.warn(
                "Metric %s implements a `reset_states()` method; rename it "
                'to `reset_state()` (without the final "s"). The name '
                "`reset_states()` has been deprecated to improve API "
                "consistency." % (self.__class__.__name__,),
                stacklevel=2,
            )
            return self.reset_states()
        else:
            backend.batch_set_value([(v, 0) for v in self.variables])
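            # A single batched assignment zeroes every state variable at
            # once; for a `Mean`, for instance, both `total` and `count`
            # return to 0.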

    @abc.abstractmethod
    def update_state(self, *args, **kwargs):
        """Accumulates statistics for the metric.

        Note: This function is executed as a graph function in graph mode.
        This means:
          a) Operations on the same resource are executed in textual order.
             This should make it easier to do things like add the updated
             value of a variable to another, for example.
          b) You don't need to worry about collecting the update ops to
             execute. All update ops added to the graph by this function
             will be executed.
          As a result, code should generally work the same way with graph or
          eager execution.

        Args:
          *args:
          **kwargs: A mini-batch of inputs to the Metric.
        """
        raise NotImplementedError("Must be implemented in subclasses.")

    def merge_state(self, metrics):
        """Merges the state from one or more metrics.

        This method can be used by distributed systems to merge the state
        computed by different metric instances. Typically the state will be
        stored in the form of the metric's weights. For example, a
        tf.keras.metrics.Mean metric contains a list of two weight values: a
        total and a count. If there were two instances of a
        tf.keras.metrics.Accuracy that each independently aggregated partial
        state for an overall accuracy calculation, these two metrics' states
        could be combined as follows:

        >>> m1 = tf.keras.metrics.Accuracy()
        >>> _ = m1.update_state([[1], [2]], [[0], [2]])

        >>> m2 = tf.keras.metrics.Accuracy()
        >>> _ = m2.update_state([[3], [4]], [[3], [4]])

        >>> m2.merge_state([m1])
        >>> m2.result().numpy()
        0.75

        Args:
          metrics: an iterable of metrics. The metrics must have compatible
            state.

        Raises:
          ValueError: If the provided iterable does not contain metrics
            matching the metric's required specifications.
        """
        assign_add_ops = []
        for metric in metrics:
            if len(self.weights) != len(metric.weights):
                raise ValueError(
                    f"Metric {metric} is not compatible with {self}"
                )
            for weight, weight_to_add in zip(self.weights, metric.weights):
                assign_add_ops.append(weight.assign_add(weight_to_add))
        return assign_add_ops
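    # The pairwise `assign_add` above assumes compatible metrics keep their
    # weights in the same order and with the same shapes, e.g. two `Mean`
    # instances each holding `[total, count]`.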

    @abc.abstractmethod
    def result(self):
        """Computes and returns the scalar metric value tensor or a dict of
        scalars.

        Result computation is an idempotent operation that simply calculates
        the metric value using the state variables.

        Returns:
          A scalar tensor, or a dictionary of scalar tensors.
        """
        raise NotImplementedError("Must be implemented in subclasses.")

    ### For use by subclasses ###
    @doc_controls.for_subclass_implementers
    def add_weight(
        self,
        name,
        shape=(),
        aggregation=tf.VariableAggregation.SUM,
        synchronization=tf.VariableSynchronization.ON_READ,
        initializer=None,
        dtype=None,
    ):
        """Adds state variable. Only for use by subclasses."""
        if tf.distribute.has_strategy():
            strategy = tf.distribute.get_strategy()
        else:
            strategy = None

        additional_kwargs = {}

        # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
        if backend.is_tpu_strategy(strategy):
            synchronization = tf.VariableSynchronization.ON_WRITE
        if getattr(self, "_mesh", None) is not None:
            # When self._mesh is set, it means this metric is used for DTensor.
            additional_kwargs = {
                "layout": dtensor.Layout.replicated(
                    self._mesh, tf.TensorShape(shape).rank
                )
            }

        if tf_utils.in_local_vars_context():
            # Metrics created within a remotely-executed tf.function during
            # parameter server evaluation should use tf2 Variables, so that
            # they can be local variables that are freely usable and mutable
            # within the function, using the
            # `experimental_enable_variable_lifting=False` argument. This
            # supports a visitation guarantee for model evaluation.
            def local_v2_var_creator(
                initializer=None, dtype=None, shape=None, **kwargs
            ):
                init_val, var_dtype = base_layer_utils.infer_init_val_and_dtype(
                    initializer, dtype, shape
                )
                v1_only_args = ["use_resource", "collections"]
                for v1_arg in v1_only_args:
                    kwargs.pop(v1_arg, None)
                kwargs["experimental_enable_variable_lifting"] = False
                return tf.Variable(
                    initial_value=init_val,
                    dtype=var_dtype,
                    shape=shape,
                    **kwargs,
                )

            additional_kwargs["getter"] = local_v2_var_creator

        with tf_utils.maybe_init_scope(layer=self):
            return super().add_weight(
                name=name,
                shape=shape,
                dtype=self._dtype if dtype is None else dtype,
                trainable=False,
                initializer=initializer,
                collections=[],
                synchronization=synchronization,
                aggregation=aggregation,
                **additional_kwargs,
            )

    ### End: For use by subclasses ###

    @property
    def trainable_weights(self):
        # Overridden from Layer class to track submetric weights.
        if self.trainable:
            trainable_weights = self._trainable_weights
            for m in self._metrics:
                trainable_weights += m.trainable_weights
            return self._dedup_weights(trainable_weights)
        else:
            return []
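    # Note that `add_weight` always passes `trainable=False`, so a plain
    # metric's own variables live in `_non_trainable_weights`; the overrides
    # here and below chiefly matter when a metric tracks submetrics.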

    @property
    def non_trainable_weights(self):
        # Overridden from Layer class to track submetric weights.
        if self.trainable:
            non_trainable_weights = self._non_trainable_weights
            for m in self._metrics:
                non_trainable_weights += m.non_trainable_weights
        else:
            non_trainable_weights = (
                self._non_trainable_weights + self._trainable_weights
            )
            for m in self._metrics:
                non_trainable_weights += m.weights
        return self._dedup_weights(non_trainable_weights)

    @property
    def _trackable_saved_model_saver(self):
        return metric_serialization.MetricSavedModelSaver(self)

    @generic_utils.default
    @doc_controls.do_not_generate_docs
    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        return self.reset_state()


class Reduce(Metric):
    """Encapsulates metrics that perform a reduce operation on the values.

    Args:
      reduction: a `tf.keras.metrics.Reduction` enum value.
      name: string name of the metric instance.
      dtype: (Optional) data type of the metric result.
    """

    def __init__(self, reduction, name, dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.reduction = reduction
        self.total = self.add_weight("total", initializer="zeros")
        if reduction in [
            metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
            metrics_utils.Reduction.WEIGHTED_MEAN,
        ]:
            self.count = self.add_weight("count", initializer="zeros")

    def update_state(self, values, sample_weight=None):
        """Accumulates statistics for computing the metric.

        Args:
          values: Per-example value.
          sample_weight: Optional weighting of each example. Defaults to `1`.

        Returns:
          Update op.
        """
        [
            values
        ], sample_weight = metrics_utils.ragged_assert_compatible_and_get_flat_values(  # noqa: E501
            [values], sample_weight
        )
        try:
            values = tf.cast(values, self._dtype)
        except (ValueError, TypeError):
            msg = (
                "The output of a metric function can only be a single Tensor. "
                f"Received: {values}. "
            )
            if isinstance(values, dict):
                msg += (
                    "To return a dict of values, implement a custom Metric "
                    "subclass."
                )
            raise RuntimeError(msg)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)
            # Update dimensions of weights to match with values if possible.
            (
                values,
                _,
                sample_weight,
            ) = losses_utils.squeeze_or_expand_dimensions(
                values, sample_weight=sample_weight
            )
            try:
                # Broadcast weights if possible.
                sample_weight = tf.__internal__.ops.broadcast_weights(
                    sample_weight, values
                )
            except ValueError:
                # Reduce values to same ndim as weight array
                ndim = backend.ndim(values)
                weight_ndim = backend.ndim(sample_weight)
                if self.reduction == metrics_utils.Reduction.SUM:
                    values = tf.reduce_sum(
                        values, axis=list(range(weight_ndim, ndim))
                    )
                else:
                    values = tf.reduce_mean(
                        values, axis=list(range(weight_ndim, ndim))
                    )
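            # e.g. `values` of shape (32, 10) with `sample_weight` of shape
            # (32,) cannot be broadcast, so `values` was first reduced over
            # axis 1 above, leaving shape (32,).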
            values = tf.multiply(values, sample_weight)

        value_sum = tf.reduce_sum(values)
        with tf.control_dependencies([value_sum]):
            update_total_op = self.total.assign_add(value_sum)

        # Exit early if the reduction doesn't have a denominator.
        if self.reduction == metrics_utils.Reduction.SUM:
            return update_total_op

        # Update `count` for reductions that require a denominator.
        if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
            num_values = tf.cast(tf.size(values), self._dtype)
        elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
            if sample_weight is None:
                num_values = tf.cast(tf.size(values), self._dtype)
            else:
                num_values = tf.reduce_sum(sample_weight)
        else:
            raise NotImplementedError(
                f'Reduction "{self.reduction}" not implemented. Expected '
                '"sum", "weighted_mean", or "sum_over_batch_size".'
            )

        with tf.control_dependencies([update_total_op]):
            return self.count.assign_add(num_values)

    def result(self):
        if self.reduction == metrics_utils.Reduction.SUM:
            return tf.identity(self.total)
        elif self.reduction in [
            metrics_utils.Reduction.WEIGHTED_MEAN,
            metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
        ]:
            return tf.math.divide_no_nan(self.total, self.count)
        else:
            raise NotImplementedError(
                f'Reduction "{self.reduction}" not implemented. Expected '
                '"sum", "weighted_mean", or "sum_over_batch_size".'
            )
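    # Worked example for WEIGHTED_MEAN: after
    # `update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])`, `total` is
    # 1*1 + 3*1 + 5*0 + 7*0 = 4 and `count` is the weight sum 2, so
    # `result()` is 4 / 2 = 2.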


@keras_export("keras.metrics.Sum")
class Sum(Reduce):
    """Computes the (weighted) sum of the given values.

    For example, if values is [1, 3, 5, 7] then the sum is 16.
    If the weights were specified as [1, 1, 0, 0] then the sum would be 4.

    This metric creates one variable, `total`, that is used to compute the
    sum of `values`. This is ultimately returned as `sum`.

    If `sample_weight` is `None`, weights default to 1. Use `sample_weight`
    of 0 to mask values.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Sum()
    >>> m.update_state([1, 3, 5, 7])
    >>> m.result().numpy()
    16.0

    Usage with `compile()` API:

    ```python
    model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
    model.compile(optimizer='sgd', loss='mse')
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="sum", dtype=None):
        super().__init__(
            reduction=metrics_utils.Reduction.SUM, name=name, dtype=dtype
        )


@keras_export("keras.metrics.Mean")
class Mean(Reduce):
    """Computes the (weighted) mean of the given values.

    For example, if values is [1, 3, 5, 7] then the mean is 4.
    If the weights were specified as [1, 1, 0, 0] then the mean would be 2.

    This metric creates two variables, `total` and `count` that are used to
    compute the average of `values`. This average is ultimately returned as
    `mean` which is an idempotent operation that simply divides `total` by
    `count`.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Mean()
    >>> m.update_state([1, 3, 5, 7])
    >>> m.result().numpy()
    4.0
    >>> m.reset_state()
    >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
    >>> m.result().numpy()
    2.0

    Usage with `compile()` API:

    ```python
    model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
    model.compile(optimizer='sgd', loss='mse')
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="mean", dtype=None):
        super().__init__(
            reduction=metrics_utils.Reduction.WEIGHTED_MEAN,
            name=name,
            dtype=dtype,
        )


@keras_export("keras.metrics.MeanMetricWrapper")
class MeanMetricWrapper(Mean):
    """Wraps a stateless metric function with the Mean metric.

    You could use this class to quickly build a mean metric from a function.
    The function needs to have the signature `fn(y_true, y_pred)` and return
    a per-sample loss array. `MeanMetricWrapper.result()` will return
    the average metric value across all samples seen so far.

    For example:

    ```python
    def accuracy(y_true, y_pred):
      return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)

    accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)

    keras_model.compile(..., metrics=accuracy_metric)
    ```

    Args:
      fn: The metric function to wrap, with signature `fn(y_true, y_pred,
        **kwargs)`.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      **kwargs: Keyword arguments to pass on to `fn`.
    """

    @dtensor_utils.inject_mesh
    def __init__(self, fn, name=None, dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        self._fn_kwargs = kwargs

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates metric statistics.

        `y_true` and `y_pred` should have the same shape.

        Args:
          y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
          y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
          sample_weight: Optional `sample_weight` acts as a
            coefficient for the metric. If a scalar is provided, then the
            metric is simply scaled by the given value. If `sample_weight`
            is a tensor of size `[batch_size]`, then the metric for each
            sample of the batch is rescaled by the corresponding element in
            the `sample_weight` vector. If the shape of `sample_weight` is
            `[batch_size, d0, .. dN-1]` (or can be broadcasted to this
            shape), then each metric element of `y_pred` is scaled by the
            corresponding value of `sample_weight`. (Note on `dN-1`: all
            metric functions reduce by 1 dimension, usually the last axis
            (-1)).

        Returns:
          Update op.
        """
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        [
            y_true,
            y_pred,
        ], sample_weight = metrics_utils.ragged_assert_compatible_and_get_flat_values(  # noqa: E501
            [y_true, y_pred], sample_weight
        )
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )
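
        # Run the wrapped `fn` through AutoGraph so that plain-Python
        # control flow inside a user-supplied function also works under
        # graph execution.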
        ag_fn = tf.__internal__.autograph.tf_convert(
            self._fn, tf.__internal__.autograph.control_status_ctx()
        )
        matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
        mask = losses_utils.get_mask(matches)
        sample_weight = losses_utils.apply_valid_mask(
            matches, sample_weight, mask, self.reduction
        )
        return super().update_state(matches, sample_weight=sample_weight)

    def get_config(self):
        config = {
            k: backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v
            for k, v in self._fn_kwargs.items()
        }

        if type(self) is MeanMetricWrapper:
            # Only include function argument when the object is a
            # MeanMetricWrapper and not a subclass.
            config["fn"] = self._fn

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        from keras.src.metrics import get

        # Note that while MeanMetricWrapper itself isn't public, objects of
        # this class may be created and added to the model by calling
        # model.compile.
        fn = config.pop("fn", None)
        if cls is MeanMetricWrapper:
            return cls(get(fn), **config)
        return super(MeanMetricWrapper, cls).from_config(config)
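    # Since `get(fn)` resolves identifiers, `fn` in a config may presumably
    # be either a callable or a registered string name; a round trip such as
    # `MeanMetricWrapper.from_config(m.get_config())` rebuilds the wrapper
    # with fresh state.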


@keras_export("keras.metrics.MeanTensor")
class MeanTensor(Metric):
    """Computes the element-wise (weighted) mean of the given tensors.

    `MeanTensor` returns a tensor with the same shape as the input tensors.
    The mean value is updated by keeping local variables `total` and `count`.
    The `total` tracks the sum of the weighted values, and `count` stores
    the sum of the weighted counts.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
      shape: (Optional) A list of integers, a tuple of integers, or a 1-D
        Tensor of type int32. If not specified, the shape is inferred from
        the values at the first call of update_state.

    Standalone usage:

    >>> m = tf.keras.metrics.MeanTensor()
    >>> m.update_state([0, 1, 2, 3])
    >>> m.update_state([4, 5, 6, 7])
    >>> m.result().numpy()
    array([2., 3., 4., 5.], dtype=float32)

    >>> m.update_state([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1])
    >>> m.result().numpy()
    array([2.       , 3.6363635, 4.8      , 5.3333335], dtype=float32)

    >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4))
    >>> m.result().numpy()
    array([[0., 0., 0., 0.]])
    >>> m.update_state([[0, 1, 2, 3]])
    >>> m.update_state([[4, 5, 6, 7]])
    >>> m.result().numpy()
    array([[2., 3., 4., 5.]])
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="mean_tensor", dtype=None, shape=None):
        super().__init__(name=name, dtype=dtype)
        self._shape = None
        self._total = None
        self._count = None
        self._built = False
        if shape is not None:
            self._build(shape)

    def _build(self, shape):
        self._shape = tf.TensorShape(shape)
        self._build_input_shape = self._shape
        # Create new state variables
        self._total = self.add_weight(
            name="total", shape=shape, initializer="zeros"
        )
        self._count = self.add_weight(
            name="count", shape=shape, initializer="zeros"
        )
        with tf.init_scope():
            if not tf.executing_eagerly():
                backend._initialize_variables(backend._get_session())
        self._built = True

    @property
    def total(self):
        return self._total if self._built else None

    @property
    def count(self):
        return self._count if self._built else None

    def update_state(self, values, sample_weight=None):
        """Accumulates statistics for computing the element-wise mean.

        Args:
          values: Per-example value.
          sample_weight: Optional weighting of each example. Defaults to `1`.

        Returns:
          Update op.
        """
        values = tf.cast(values, self._dtype)
        if not self._built:
            self._build(values.shape)
        elif values.shape != self._shape:
            raise ValueError(
                "MeanTensor input values must always have the same "
                "shape. Expected shape (set during the first call): "
                f"{self._shape}. "
                f"Got: {values.shape}."
            )

        num_values = tf.ones_like(values)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self._dtype)

            # Update dimensions of weights to match with values if possible.
            (
                values,
                _,
                sample_weight,
            ) = losses_utils.squeeze_or_expand_dimensions(
                values, sample_weight=sample_weight
            )
            try:
                # Broadcast weights if possible.
                sample_weight = tf.__internal__.ops.broadcast_weights(
                    sample_weight, values
                )
            except ValueError:
                # Reduce values to same ndim as weight array
                ndim = backend.ndim(values)
                weight_ndim = backend.ndim(sample_weight)
                values = tf.reduce_mean(
                    values, axis=list(range(weight_ndim, ndim))
                )

            num_values = tf.multiply(num_values, sample_weight)
            values = tf.multiply(values, sample_weight)

        update_total_op = self._total.assign_add(values)
        with tf.control_dependencies([update_total_op]):
            return self._count.assign_add(num_values)

    def result(self):
        if not self._built:
            raise ValueError(
                "MeanTensor does not have any value yet. Please call the "
                "MeanTensor instance or use `.update_state(value)` "
                "before retrieving the result."
            )
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        if self._built:
            backend.batch_set_value(
                [(v, np.zeros(v.shape.as_list())) for v in self.variables]
            )


class SumOverBatchSize(Reduce):
    """Computes the weighted sum over batch size of the given values.

    For example, if values is [1, 3, 5, 7] then the metric value is 4.
    If the weights were specified as [1, 1, 0, 0] then the value would be 1.

    This metric creates two variables, `total` and `count` that are used to
    compute the average of `values`. This average is ultimately returned as
    sum over batch size which is an idempotent operation that simply divides
    `total` by `count`.

    If `sample_weight` is `None`, weights default to 1. Use `sample_weight`
    of 0 to mask values.
    """

    def __init__(self, name="sum_over_batch_size", dtype=None):
        super().__init__(
            reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
            name=name,
            dtype=dtype,
        )


class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
    """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""

    def __init__(self, fn, name=None, dtype=None, **kwargs):
        """Creates a `SumOverBatchSizeMetricWrapper` instance.

        Args:
          fn: The metric function to wrap, with signature `fn(y_true,
            y_pred, **kwargs)`.
          name: (Optional) string name of the metric instance.
          dtype: (Optional) data type of the metric result.
          **kwargs: The keyword arguments that are passed on to `fn`.
        """
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        self._fn_kwargs = kwargs

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true
        )

        ag_fn = tf.__internal__.autograph.tf_convert(
            self._fn, tf.__internal__.autograph.control_status_ctx()
        )
        matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
        mask = losses_utils.get_mask(matches)
        sample_weight = losses_utils.apply_valid_mask(
            matches, sample_weight, mask, self.reduction
        )
        return super().update_state(matches, sample_weight=sample_weight)

    def get_config(self):
        config = {
            k: backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v
            for k, v in self._fn_kwargs.items()
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


def clone_metric(metric):
    """Returns a clone of the metric if stateful, otherwise returns it as is."""
    if isinstance(metric, Metric):
        # Metrics created within a remotely-executed tf.function during
        # parameter server evaluation should not be lifted out of the graph
        # by `init_scope`. This way the metric variables can be local:
        # freely usable and mutable within the function. This supports a
        # visitation guarantee for model evaluation.
        if tf_utils.in_local_vars_context():
            return metric.__class__.from_config(metric.get_config())
        else:
            with tf.init_scope():
                return metric.__class__.from_config(metric.get_config())
    return metric
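# Since cloning goes through `from_config`, a clone shares the original's
# configuration but starts with freshly created (zero-initialized) state
# variables rather than aliasing the original's.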


def clone_metrics(metrics):
    """Clones the given metric list/dict."""
    return tf.nest.map_structure(clone_metric, metrics)


def is_built_in(cls):
    return cls.__module__.startswith(
        ".".join(Metric.__module__.split(".")[:-1])
    )
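# i.e. a class counts as built-in when it lives in the same package as
# `Metric` itself: with `Metric.__module__` set to
# "keras.src.metrics.base_metric", the prefix checked is "keras.src.metrics".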


class _MetricDict(dict):
    """Wrapper for returned dictionary of metrics."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._metric_obj = None