Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/hinge_metrics.py: 86%
22 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Hinge metrics."""
17from keras.src.dtensor import utils as dtensor_utils
18from keras.src.losses import categorical_hinge
19from keras.src.losses import hinge
20from keras.src.losses import squared_hinge
21from keras.src.metrics import base_metric
23# isort: off
24from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.metrics.Hinge")
class Hinge(base_metric.MeanMetricWrapper):
    """Hinge metric: running mean of the `hinge` loss over `y_true`/`y_pred`.

    Every call to `update_state` evaluates the `hinge` loss function (from
    `keras.src.losses`) on the batch. The metric accumulates an (optionally
    `sample_weight`-weighted) mean of those values across calls.

    Labels in `y_true` should be -1 or 1. Binary labels (0 or 1) are also
    accepted and are converted to -1 or 1.

    Args:
        name: (Optional) string name of the metric instance. Defaults to
            `"hinge"`.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Hinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.3

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.1

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()])
    ```
    """

    # NOTE(review): `inject_mesh` presumably lets callers pass an optional
    # DTensor `mesh` kwarg; confirm against keras.src.dtensor.utils.
    @dtensor_utils.inject_mesh
    def __init__(self, name="hinge", dtype=None):
        # Delegate everything to MeanMetricWrapper; `hinge` is the per-sample
        # function whose values are averaged.
        super().__init__(hinge, name=name, dtype=dtype)
@keras_export("keras.metrics.SquaredHinge")
class SquaredHinge(base_metric.MeanMetricWrapper):
    """Squared-hinge metric: running mean of the `squared_hinge` loss.

    Each `update_state` call evaluates the `squared_hinge` loss function (from
    `keras.src.losses`) on the batch. The metric keeps an (optionally
    `sample_weight`-weighted) mean of those values.

    Labels in `y_true` should be -1 or 1. Binary labels (0 or 1) are also
    accepted and are converted to -1 or 1.

    Args:
        name: (Optional) string name of the metric instance. Defaults to
            `"squared_hinge"`.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.SquaredHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.86

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.46

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.SquaredHinge()])
    ```
    """

    # NOTE(review): `inject_mesh` presumably lets callers pass an optional
    # DTensor `mesh` kwarg; confirm against keras.src.dtensor.utils.
    @dtensor_utils.inject_mesh
    def __init__(self, name="squared_hinge", dtype=None):
        # `squared_hinge` is the per-sample function that MeanMetricWrapper
        # evaluates and averages.
        super().__init__(squared_hinge, name=name, dtype=dtype)
@keras_export("keras.metrics.CategoricalHinge")
class CategoricalHinge(base_metric.MeanMetricWrapper):
    """Categorical-hinge metric: running mean of the `categorical_hinge` loss.

    On every `update_state` call, the `categorical_hinge` loss function (from
    `keras.src.losses`) is applied to the batch. The resulting values are folded
    into an (optionally `sample_weight`-weighted) mean.

    Args:
        name: (Optional) string name of the metric instance. Defaults to
            `"categorical_hinge"`.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.CategoricalHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.4000001

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.2

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.CategoricalHinge()])
    ```
    """

    # NOTE(review): `inject_mesh` presumably lets callers pass an optional
    # DTensor `mesh` kwarg; confirm against keras.src.dtensor.utils.
    @dtensor_utils.inject_mesh
    def __init__(self, name="categorical_hinge", dtype=None):
        # MeanMetricWrapper averages the output of `categorical_hinge`.
        super().__init__(categorical_hinge, name=name, dtype=dtype)