Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/kappa_loss.py: 24%

45 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements Weighted kappa loss.""" 

16 

17from typing import Optional 

18 

19import tensorflow as tf 

20from typeguard import typechecked 

21 

22from tensorflow_addons.utils.types import Number 

23 

24 

@tf.keras.utils.register_keras_serializable(package="Addons")
class WeightedKappaLoss(tf.keras.losses.Loss):
    r"""Implements the Weighted Kappa loss function.

    Weighted Kappa loss was introduced in the
    [Weighted kappa loss function for multi-class classification
    of ordinal data in deep learning]
    (https://www.sciencedirect.com/science/article/abs/pii/S0167865517301666).
    Weighted Kappa is widely used in Ordinal Classification Problems.

    The loss is $ \log(1 - k + \epsilon) $, where $ k $ is the weighted
    kappa of the batch. For $ k \in [-1, 1] $ the value therefore lies in
    $ [\log \epsilon, \log(2 + \epsilon)] $: perfect agreement ($ k = 1 $)
    gives $ \log \epsilon $, a prediction no better than chance
    ($ k = 0 $) gives a value close to $ 0 $, and $ \log 2 $ is approached
    only for total disagreement ($ k = -1 $).

    The loss is computed over the whole batch at once, so `call` returns a
    single scalar rather than one value per sample.

    Usage:

    >>> kappa_loss = tfa.losses.WeightedKappaLoss(num_classes=4)
    >>> y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
    ...                       [1, 0, 0, 0], [0, 0, 0, 1]])
    >>> y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
    ...                       [0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
    >>> loss = kappa_loss(y_true, y_pred)
    >>> loss
    <tf.Tensor: shape=(), dtype=float32, numpy=-1.1611925>

    Usage with `tf.keras` API:

    >>> model = tf.keras.Model()
    >>> model.compile('sgd', loss=tfa.losses.WeightedKappaLoss(num_classes=4))

    <... outputs should be softmax results
    if you want to weight the samples, just multiply the outputs
    by the sample weight ...>

    """

    @typechecked
    def __init__(
        self,
        num_classes: int,
        weightage: Optional[str] = "quadratic",
        name: Optional[str] = "cohen_kappa_loss",
        epsilon: Optional[Number] = 1e-6,
        reduction: str = tf.keras.losses.Reduction.NONE,
    ):
        r"""Creates a `WeightedKappaLoss` instance.

        Args:
          num_classes: Number of unique classes in your dataset.
          weightage: (Optional) Weighting to be considered for calculating
            kappa statistics. A valid value is one of
            ['linear', 'quadratic']. Defaults to 'quadratic'.
          name: (Optional) String name of the loss instance.
          epsilon: (Optional) increment to avoid log zero,
            so the loss will be $ \log(1 - k + \epsilon) $, where $ k $ lies
            in $ [-1, 1] $. Defaults to 1e-6. A falsy value (`None` or `0`)
            falls back to `tf.keras.backend.epsilon()`.
          reduction: (Optional) Reduction passed through to
            `tf.keras.losses.Loss`. Defaults to
            `tf.keras.losses.Reduction.NONE`; `call` already returns a
            scalar for the whole batch.
        Raises:
          ValueError: If the value passed for `weightage` is invalid
            i.e. not any one of ['linear', 'quadratic']
        """

        super().__init__(name=name, reduction=reduction)

        if weightage not in ("linear", "quadratic"):
            raise ValueError("Unknown kappa weighting type.")

        self.weightage = weightage
        self.num_classes = num_classes
        # `or` deliberately maps both None and 0 to the backend epsilon so
        # that the logarithm in `call` never receives exactly zero.
        self.epsilon = epsilon or tf.keras.backend.epsilon()

        # Class indices 0..num_classes-1 as a row and as a column vector.
        label_vec = tf.range(num_classes, dtype=tf.keras.backend.floatx())
        self.row_label_vec = tf.reshape(label_vec, [1, num_classes])
        self.col_label_vec = tf.reshape(label_vec, [num_classes, 1])

        # weight_mat[i, j] is the penalty for predicting class j when the
        # label is class i; broadcasting [C, 1] - [1, C] yields the full
        # [C, C] matrix of index differences.
        self.weight_mat = self._penalty(self.col_label_vec - self.row_label_vec)

    def _penalty(self, diff):
        """Maps class-index differences to linear or quadratic penalties."""
        if self.weightage == "linear":
            return tf.abs(diff)
        return diff ** 2

    def call(self, y_true, y_pred):
        """Computes the weighted kappa loss for one batch.

        Args:
          y_true: Labels of shape `[batch_size, num_classes]` (one-hot
            encoded in the class usage example).
          y_pred: Predicted class scores of shape
            `[batch_size, num_classes]`.

        Returns:
          A scalar tensor holding `log(numerator / denominator + epsilon)`.
          When the denominator is zero the ratio is taken as zero, so the
          result is `log(epsilon)`.
        """
        y_true = tf.cast(y_true, dtype=self.col_label_vec.dtype)
        y_pred = tf.cast(y_pred, dtype=self.weight_mat.dtype)
        batch_size = tf.shape(y_true)[0]

        # Label index of every sample, shape [batch_size, 1].
        cat_labels = tf.matmul(y_true, self.col_label_vec)
        # Per-sample penalty for each candidate class; broadcasting
        # [batch_size, 1] - [1, num_classes] gives [batch_size, num_classes].
        weight = self._penalty(cat_labels - self.row_label_vec)
        # Observed disagreement: penalties weighted by the predictions.
        numerator = tf.reduce_sum(weight * y_pred)

        # Disagreement expected from the marginal label and prediction
        # distributions, normalised by the batch size.
        label_dist = tf.reduce_sum(y_true, axis=0, keepdims=True)
        pred_dist = tf.reduce_sum(y_pred, axis=0, keepdims=True)
        w_pred_dist = tf.matmul(self.weight_mat, pred_dist, transpose_b=True)
        denominator = tf.reduce_sum(tf.matmul(label_dist, w_pred_dist))
        denominator /= tf.cast(batch_size, dtype=denominator.dtype)

        # divide_no_nan returns 0 instead of NaN/Inf for a zero denominator.
        loss = tf.math.divide_no_nan(numerator, denominator)
        return tf.math.log(loss + self.epsilon)

    def get_config(self):
        """Returns the base loss config extended with this loss' arguments."""
        config = {
            "num_classes": self.num_classes,
            "weightage": self.weightage,
            "epsilon": self.epsilon,
        }
        base_config = super().get_config()
        return {**base_config, **config}