Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/hinge_metrics.py: 86%

22 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Hinge metrics.""" 

16 

17from keras.src.dtensor import utils as dtensor_utils 

18from keras.src.losses import categorical_hinge 

19from keras.src.losses import hinge 

20from keras.src.losses import squared_hinge 

21from keras.src.metrics import base_metric 

22 

23# isort: off 

24from tensorflow.python.util.tf_export import keras_export 

25 

26 

@keras_export("keras.metrics.Hinge")
class Hinge(base_metric.MeanMetricWrapper):
    """Tracks the (optionally sample-weighted) hinge metric for `y_pred`.

    Labels in `y_true` should be -1 or 1. Binary labels (0 or 1) are
    accepted as well and are converted to -1 or 1 before evaluation.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.Hinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.3

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.1

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="hinge", dtype=None):
        # The wrapper receives the `hinge` loss function as the callable it
        # evaluates on each update; name and dtype are forwarded unchanged.
        super().__init__(hinge, name=name, dtype=dtype)

62 

63 

@keras_export("keras.metrics.SquaredHinge")
class SquaredHinge(base_metric.MeanMetricWrapper):
    """Tracks the (optionally sample-weighted) squared hinge metric.

    Labels in `y_true` should be -1 or 1. Binary labels (0 or 1) are
    accepted as well and are converted to -1 or 1 before evaluation.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.SquaredHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.86

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.46

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.SquaredHinge()])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="squared_hinge", dtype=None):
        # The wrapper receives the `squared_hinge` loss function as the
        # callable it evaluates on each update; name and dtype pass through.
        super().__init__(squared_hinge, name=name, dtype=dtype)

101 

102 

@keras_export("keras.metrics.CategoricalHinge")
class CategoricalHinge(base_metric.MeanMetricWrapper):
    """Tracks the (optionally sample-weighted) categorical hinge metric.

    Args:
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = tf.keras.metrics.CategoricalHinge()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
    >>> m.result().numpy()
    1.4000001

    >>> m.reset_state()
    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
    ...                sample_weight=[1, 0])
    >>> m.result().numpy()
    1.2

    Usage with `compile()` API:

    ```python
    model.compile(
        optimizer='sgd',
        loss='mse',
        metrics=[tf.keras.metrics.CategoricalHinge()])
    ```
    """

    @dtensor_utils.inject_mesh
    def __init__(self, name="categorical_hinge", dtype=None):
        # The wrapper receives the `categorical_hinge` loss function as the
        # callable it evaluates on each update; name and dtype pass through.
        super().__init__(categorical_hinge, name=name, dtype=dtype)

137