Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/metrics/utils.py: 37%
35 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for metrics."""
17import numpy as np
18import tensorflow as tf
19from tensorflow_addons.utils.types import AcceptableDTypes
21from typeguard import typechecked
22from typing import Optional, Callable
class MeanMetricWrapper(tf.keras.metrics.Mean):
    """Wraps a stateless metric function with the Mean metric.

    Each call to `update_state` evaluates the wrapped function on the batch
    and feeds its result into the running mean kept by
    `tf.keras.metrics.Mean`.
    """

    @typechecked
    def __init__(
        self,
        fn: Callable,
        name: Optional[str] = None,
        dtype: AcceptableDTypes = None,
        **kwargs,
    ):
        """Creates a `MeanMetricWrapper` instance.

        Args:
          fn: The metric function to wrap, with signature
            `fn(y_true, y_pred, **kwargs)`.
          name: (Optional) string name of the metric instance.
          dtype: (Optional) data type of the metric result.
          **kwargs: The keyword arguments that are passed on to `fn`.
        """
        super().__init__(name=name, dtype=dtype)
        self._fn = fn
        # Kept so they are forwarded to `fn` on every update and reported by
        # `get_config`; they are NOT passed to the base `Mean` constructor.
        self._fn_kwargs = kwargs

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates metric statistics.

        `y_true` and `y_pred` should have the same shape.

        Args:
          y_true: The ground truth values.
          y_pred: The predicted values.
          sample_weight: Optional weighting of each example. Defaults to 1.
            Can be a `Tensor` whose rank is either 0, or the same rank as
            `y_true`, and must be broadcastable to `y_true`.

        Returns:
          Update op.
        """
        # Both inputs are cast to the metric's dtype before calling `fn`.
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        # TODO: Add checks for ragged tensors and dimensions:
        # `ragged_assert_compatible_and_get_flat_values`
        # and `squeeze_or_expand_dimensions`
        matches = self._fn(y_true, y_pred, **self._fn_kwargs)
        return super().update_state(matches, sample_weight=sample_weight)

    def get_config(self):
        """Returns the config of the metric.

        The base `Mean` config is merged with the keyword arguments that are
        forwarded to `fn`; on a key collision the `fn` keyword argument wins.
        `fn` itself is not included.
        """
        base_config = super().get_config()
        # Unpacking already builds a fresh dict, so no intermediate copy of
        # `self._fn_kwargs` is needed.
        return {**base_config, **self._fn_kwargs}
def _get_model(metric, num_output):
    """Builds, compiles and briefly fits a small model that uses `metric`.

    Intended as a test helper for API compatibility with `tf.keras` models.
    `metric` is passed to `compile` alongside `"acc"`. It is then exercised
    by one epoch of `fit` on random inputs of shape `(10, 3)` and random
    labels of shape `(10, num_output)`.

    Args:
      metric: The metric to pass to `Model.compile` via `metrics=[...]`.
      num_output: Number of units in the softmax output layer (and number of
        label columns generated).

    Returns:
      The fitted `tf.keras.Sequential` model, so callers can inspect it.
    """
    # Test API compatibility with tf.keras Model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(64, activation="relu"))
    model.add(tf.keras.layers.Dense(num_output, activation="softmax"))
    model.compile(
        optimizer="adam", loss="categorical_crossentropy", metrics=["acc", metric]
    )

    data = np.random.random((10, 3))
    labels = np.random.random((10, num_output))
    model.fit(data, labels, epochs=1, batch_size=5, verbose=0)
    return model
def sample_weight_shape_match(v, sample_weight):
    """Expands `sample_weight` into a tensor matching the shape of `v`.

    Args:
      v: The value whose shape the weights should match. It must expose
        `.shape` when `sample_weight` has exactly one element.
      sample_weight: None, a value with exactly one element (a scalar, or a
        list/array/tensor holding a single element), or a full weight array.

    Returns:
      `tf.ones_like(v)` when `sample_weight` is None. A tensor of `v`'s shape
      filled with the single weight when `sample_weight` has one element.
      Otherwise `sample_weight` converted to a tensor as-is.
    """
    if sample_weight is None:
        return tf.ones_like(v)
    if np.size(sample_weight) == 1:
        # `tf.fill` only accepts a 0-D fill value. A size-1 list or array such
        # as `[0.5]` passes the size check above but would be rejected, so
        # collapse it to rank 0 first. For plain scalars this is a no-op that
        # keeps the same dtype conversion as passing the value directly.
        return tf.fill(v.shape, tf.reshape(sample_weight, []))
    return tf.convert_to_tensor(sample_weight)