# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Weighted kappa loss."""

from typing import Optional

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils.types import Number


@tf.keras.utils.register_keras_serializable(package="Addons")
class WeightedKappaLoss(tf.keras.losses.Loss):
27 r"""Implements the Weighted Kappa loss function.
29 Weighted Kappa loss was introduced in the
30 [Weighted kappa loss function for multi-class classification
31 of ordinal data in deep learning]
32 (https://www.sciencedirect.com/science/article/abs/pii/S0167865517301666).
33 Weighted Kappa is widely used in Ordinal Classification Problems.
34 The loss value lies in $ [-\infty, \log 2] $, where $ \log 2 $
35 means the random prediction.

    Usage:

    >>> kappa_loss = tfa.losses.WeightedKappaLoss(num_classes=4)
    >>> y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
    ...                       [1, 0, 0, 0], [0, 0, 0, 1]])
    >>> y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
    ...                       [0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
    >>> loss = kappa_loss(y_true, y_pred)
    >>> loss
    <tf.Tensor: shape=(), dtype=float32, numpy=-1.1611925>

    Usage with the `tf.keras` API:

    >>> model = tf.keras.Model()
    >>> model.compile('sgd', loss=tfa.losses.WeightedKappaLoss(num_classes=4))

    The model's outputs should be softmax probabilities over the ordinal
    classes. To weight individual samples, multiply the outputs by the
    sample weights.
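
    A fuller sketch (the input and layer sizes below are illustrative only):

    >>> inputs = tf.keras.Input(shape=(8,))
    >>> outputs = tf.keras.layers.Dense(4, activation="softmax")(inputs)
    >>> model = tf.keras.Model(inputs, outputs)
    >>> model.compile("sgd", loss=tfa.losses.WeightedKappaLoss(num_classes=4))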
57 """

    @typechecked
    def __init__(
        self,
        num_classes: int,
        weightage: Optional[str] = "quadratic",
        name: Optional[str] = "cohen_kappa_loss",
        epsilon: Optional[Number] = 1e-6,
        reduction: str = tf.keras.losses.Reduction.NONE,
    ):
68 r"""Creates a `WeightedKappaLoss` instance.
70 Args:
71 num_classes: Number of unique classes in your dataset.
72 weightage: (Optional) Weighting to be considered for calculating
73 kappa statistics. A valid value is one of
74 ['linear', 'quadratic']. Defaults to 'quadratic'.
75 name: (Optional) String name of the metric instance.
76 epsilon: (Optional) increment to avoid log zero,
77 so the loss will be $ \log(1 - k + \epsilon) $, where $ k $ lies
78 in $ [-1, 1] $. Defaults to 1e-6.
79 Raises:
80 ValueError: If the value passed for `weightage` is invalid
81 i.e. not any one of ['linear', 'quadratic']
82 """
        super().__init__(name=name, reduction=reduction)

        if weightage not in ("linear", "quadratic"):
            raise ValueError("Unknown kappa weighting type.")

        self.weightage = weightage
        self.num_classes = num_classes
        self.epsilon = epsilon or tf.keras.backend.epsilon()
        label_vec = tf.range(num_classes, dtype=tf.keras.backend.floatx())
        self.row_label_vec = tf.reshape(label_vec, [1, num_classes])
        self.col_label_vec = tf.reshape(label_vec, [num_classes, 1])
        col_mat = tf.tile(self.col_label_vec, [1, num_classes])
        row_mat = tf.tile(self.row_label_vec, [num_classes, 1])
        if weightage == "linear":
            self.weight_mat = tf.abs(col_mat - row_mat)
        else:
            self.weight_mat = (col_mat - row_mat) ** 2
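        # For illustration (the values follow directly from the construction
        # above), with num_classes=3 the weight matrix is:
        #   linear:    [[0, 1, 2],     quadratic: [[0, 1, 4],
        #               [1, 0, 1],                 [1, 0, 1],
        #               [2, 1, 0]]                 [4, 1, 0]]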

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, dtype=self.col_label_vec.dtype)
        y_pred = tf.cast(y_pred, dtype=self.weight_mat.dtype)
        batch_size = tf.shape(y_true)[0]
        # `y_true` is one-hot, so this matmul recovers each sample's class
        # index as a float column vector of shape [batch_size, 1].
        cat_labels = tf.matmul(y_true, self.col_label_vec)
        cat_label_mat = tf.tile(cat_labels, [1, self.num_classes])
        row_label_mat = tf.tile(self.row_label_vec, [batch_size, 1])
        if self.weightage == "linear":
            weight = tf.abs(cat_label_mat - row_label_mat)
        else:
            weight = (cat_label_mat - row_label_mat) ** 2
        # Observed weighted disagreement between predictions and true labels.
        numerator = tf.reduce_sum(weight * y_pred)
        # Expected weighted disagreement under independence of the label and
        # prediction marginals, normalized by the batch size.
        label_dist = tf.reduce_sum(y_true, axis=0, keepdims=True)
        pred_dist = tf.reduce_sum(y_pred, axis=0, keepdims=True)
        w_pred_dist = tf.matmul(self.weight_mat, pred_dist, transpose_b=True)
        denominator = tf.reduce_sum(tf.matmul(label_dist, w_pred_dist))
        denominator /= tf.cast(batch_size, dtype=denominator.dtype)
        loss = tf.math.divide_no_nan(numerator, denominator)
        return tf.math.log(loss + self.epsilon)
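
    # In equation form (a restatement of `call` above; N is the batch size,
    # c_n the true class of sample n, p[n, j] the predicted probability of
    # class j for sample n, and w the weight matrix built in `__init__`):
    #
    #   numerator   = sum_n sum_j w[c_n, j] * p[n, j]
    #   denominator = (1 / N) * sum_i sum_j w[i, j] * (sum_n y[n, i]) * (sum_n p[n, j])
    #   loss        = log(numerator / denominator + epsilon)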

    def get_config(self):
        config = {
            "num_classes": self.num_classes,
            "weightage": self.weightage,
            "epsilon": self.epsilon,
        }
        base_config = super().get_config()
        return {**base_config, **config}
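

# A minimal serialization sketch (illustrative only; it assumes this module has
# been imported so the "Addons" registration above is active):
#
#   loss = WeightedKappaLoss(num_classes=4, weightage="linear")
#   config = tf.keras.losses.serialize(loss)
#   restored = tf.keras.losses.deserialize(config)
#   assert isinstance(restored, WeightedKappaLoss)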