Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/py_metric.py: 32%
40 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The Keras Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for Python-based metrics"""
17import types
19import tensorflow.compat.v2 as tf
20from tensorflow.python.util.tf_export import keras_export
22from keras.src.metrics import base_metric
@keras_export("keras.metrics.experimental.PyMetric", v1=[])
class PyMetric(base_metric.Metric):
    """Metric which runs in Python, compiled outside of the TensorFlow graph.

    Args:
        name: (Optional) string name of the PyMetric instance.
        dtype: (Optional) data type of the PyMetric result.
        **kwargs: Additional layer keyword arguments.

    Usage of `PyMetric` is generally identical to `keras.metrics.Metric`.
    It can be used in isolation, or in tandem with the `compile()` API. For more
    information about the usage of `PyMetric`, see `keras.metrics.Metric`.

    Unlike regular metrics, `PyMetric` instances are outside-compiled
    with respect to the TensorFlow graph during training or evaluation.
    They have access to the same inputs as a standard in-graph metric, but
    they run in a Python interpreter on the host CPU. Any data stored in a
    `PyMetric` is located in the main memory of the host CPU, and any
    TensorFlow ops used in a `PyMetric` are run eagerly on the host CPU.

    As a result, `PyMetric` instances are generally not as performant
    as in-graph metrics, and should only be used in cases where computing
    the metric inside of the TensorFlow graph is either impossible
    or prohibitively expensive.

    **Note:** Due to the use of `tf.py_function`, PyMetrics
    are incompatible with XLA and therefore TPUs.

    Methods to be implemented by subclasses:

    * `update_state()`: Handles updates to the internal Python state.
    * `result()`: Computes and returns a single scalar value for the metric
      from the state.
    * `reset_state()`: Resets the Python state of the metric. It is called
      between epochs and once at initialization time.

    This subclass implementation is similar to that of `keras.metrics.Metric`,
    with three notable differences:

    * Inputs to `update_state()` in a `PyMetric` are eager tensors, and both
      `update_state()` and `result()` run outside of the TensorFlow graph,
      executing any TensorFlow ops eagerly.
    * `reset_state()` is also called at initialization time to initialize the
      Python state of the metric.
    * `result()` can only return a single scalar. It does not support returning
      a dictionary of results like `keras.metrics.Metric`.

    Example subclass implementation using sklearn's Jaccard Score:

    ```python
    from sklearn.metrics import jaccard_score
    import tensorflow as tf

    class JaccardScore(tf.keras.metrics.experimental.PyMetric):

      def __init__(self, name='jaccard_score', **kwargs):
        super().__init__(name=name, **kwargs)

      def update_state(self, y_true, y_pred, sample_weight=None):
        self.jaccard_sum += jaccard_score(
            y_true, y_pred, average="macro", sample_weight=sample_weight
        )
        self.count += 1

      def reset_state(self):
        self.jaccard_sum = 0.
        self.count = 0.

      def result(self):
        return self.jaccard_sum / self.count
    ```
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        """Initializes the metric and its Python state.

        Args:
            name: (Optional) string name of the metric instance.
            dtype: (Optional) data type of the metric result.
            **kwargs: Additional layer keyword arguments, forwarded to
                `keras.metrics.Metric`.
        """
        super().__init__(name=name, dtype=dtype, **kwargs)
        # Subclasses keep their state in plain Python attributes, so
        # `reset_state()` is called here to create that state up front.
        self.reset_state()

    def __new__(cls, *args, **kwargs):
        """Creates the instance and wraps `update_state`/`result`.

        Both methods are replaced on the instance with wrappers that run the
        subclass implementation through `tf.py_function` under the
        "/cpu:0" device scope.
        """
        # Calls the parent of `base_metric.Metric` directly, skipping
        # `Metric.__new__`. NOTE(review): presumably this avoids Metric's own
        # wrapping of `update_state`/`result` being applied on top of the
        # py_function wrappers below; confirm against base_metric.
        # Constructor arguments are not passed here; `__init__` consumes them.
        obj = super(base_metric.Metric, cls).__new__(cls)

        # Wrap the update_state function in a py_function and scope it to
        # /cpu:0. The bound method is captured before it is replaced below,
        # so the wrapper calls the subclass's own implementation.
        obj_update_state = obj.update_state

        def update_state_on_cpu(y_true, y_pred, sample_weight=None):
            with tf.device("/cpu:0"):
                return obj_update_state(y_true, y_pred, sample_weight)

        obj.update_state_on_cpu = update_state_on_cpu

        def update_state_fn(self, y_true, y_pred, sample_weight=None):
            # `tf.py_function` only accepts tensors as inputs, so the optional
            # sample weights are appended only when given. When omitted, the
            # wrapped function receives its default `sample_weight=None`.
            eager_inputs = [y_true, y_pred]
            if sample_weight is not None:
                eager_inputs.append(sample_weight)
            # NOTE(review): Tout=[] declares no outputs, so the subclass's
            # update_state is presumably expected to return None; confirm.
            return tf.py_function(
                func=self.update_state_on_cpu, inp=eager_inputs, Tout=[]
            )

        obj.update_state = types.MethodType(update_state_fn, obj)

        # Wrap the result function in a py_function and scope it to /cpu:0.
        obj_result = obj.result

        def result_on_host_cpu():
            with tf.device("/cpu:0"):
                return obj_result()

        obj.result_on_host_cpu = result_on_host_cpu

        def result_fn(self):
            # `obj.dtype` is read at call time, i.e. after `__init__` has set
            # the metric's dtype. The output is a single tensor of that dtype,
            # which is why `result()` must return a single scalar.
            return tf.py_function(
                self.result_on_host_cpu, inp=[], Tout=obj.dtype
            )

        obj.result = types.MethodType(result_fn, obj)

        return obj

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates statistics for the metric.

        **Note:** This function is executed outside of the TensorFlow graph
        on the CPU host.

        This means:

        a) Inputs are eager tensors.
        b) Any TensorFlow ops run in this method are run eagerly.
        c) Any Tensors created are allocated to the CPU's main memory.

        Args:
            y_true: Target output
            y_pred: Predicted output
            sample_weight: (Optional) weights for the individual samples in
                `y_true` and `y_pred`
        """
        raise NotImplementedError("Subclasses should implement `update_state`")

    def merge_state(self, metrics):
        """Merges the state from one or more metrics.

        `PyMetric` instances that intend to support merging state must override
        this method, as the default implementation
        in `keras.metrics.Metric` does not apply to `PyMetric`.
        """
        raise NotImplementedError("Subclasses should implement `merge_state`")

    def reset_state(self):
        """Resets all of the metric state.

        This function is called between epochs when a metric is evaluated during
        training. It's also called when the metric is initialized.
        """
        raise NotImplementedError("Subclasses should implement `reset_state`")

    def result(self):
        """Computes and returns the scalar metric value.

        **Note:** This function is executed outside of the TensorFlow graph
        on the CPU host. This means any TensorFlow ops run in this method
        are run eagerly.

        Result computation is an idempotent operation that simply calculates the
        metric value using the state.

        Returns:
            A Python scalar.
        """
        raise NotImplementedError("Subclasses should implement `result`")