Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/metrics/py_metric.py: 32%

40 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The Keras Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for Python-based metrics""" 

16 

17import types 

18 

19import tensorflow.compat.v2 as tf 

20from tensorflow.python.util.tf_export import keras_export 

21 

22from keras.src.metrics import base_metric 

23 

24 

@keras_export("keras.metrics.experimental.PyMetric", v1=[])
class PyMetric(base_metric.Metric):
    """Metric which runs in Python, compiled outside of the TensorFlow graph.

    Args:
      name: (Optional) string name of the PyMetric instance.
      dtype: (Optional) data type of the PyMetric result.
      **kwargs: Additional layer keyword arguments.

    Usage of `PyMetric` is generally identical to `keras.metrics.Metric`.
    It can be used in isolation, or in tandem with the `compile()` API. For more
    information about the usage of `PyMetric`, see `keras.metrics.Metric`.

    Unlike regular metrics, `PyMetric` instances are outside-compiled
    with respect to the TensorFlow graph during training or evaluation.
    They have access to the same
    inputs of a standard in-graph metric, but they run in a Python interpreter
    on the host CPU. Any data stored in a `PyMetric` is located on the main
    memory of the host CPU, and any TensorFlow ops used in a PyMetric are
    run eagerly on the host CPU.

    As a result, `PyMetric` instances are generally not as performant
    as in-graph metrics, and should only be used in cases where computing
    the metric inside of the TensorFlow graph is either impossible
    or prohibitively expensive.

    **Note:** Due to the use of `tf.py_function`, PyMetrics
    are incompatible with XLA and therefore TPUs.

    Methods to be implemented by subclasses:

    * `update_state()`: Handles updates to internal state variables.
    * `result()`: Computes and returns a scalar value for the metric from
      the state variables.
    * `reset_state()`: Resets the Python state of the metric. It is also
      called once from `__init__()` to create that state.

    This subclass implementation is similar to that of `keras.metrics.Metric`,
    with three notable differences:

    * Inputs to `update_state()` in a `PyMetric` are eager tensors, and both
      `update_state()` and `result()` run outside of the TensorFlow graph,
      executing any TensorFlow ops eagerly.
    * `reset_state()` is also called at initialization time to initialize the
      Python state of the metric.
    * `result()` can only return a single scalar. It does not support returning
      a dictionary of results like `keras.metrics.Metric`.

    Example subclass implementation using sklearn's Jaccard Score:

    ```python
    from sklearn.metrics import jaccard_score
    import tensorflow as tf

    class JaccardScore(tf.keras.metrics.experimental.PyMetric):

      def __init__(self, name='jaccard_score', **kwargs):
        super().__init__(name=name, **kwargs)

      def update_state(self, y_true, y_pred, sample_weight=None):
        self.jaccard_sum += jaccard_score(y_true, y_pred, average="macro")
        self.count += 1

      def reset_state(self):
        self.jaccard_sum = 0.
        self.count = 0.

      def result(self):
        return self.jaccard_sum / self.count
    ```
    """

    def __init__(self, name=None, dtype=None, **kwargs):
        """Initializes the base metric, then the subclass's Python state.

        Args:
          name: (Optional) string name of the PyMetric instance.
          dtype: (Optional) data type of the PyMetric result.
          **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(name=name, dtype=dtype, **kwargs)
        # The Python-side state lives in plain attributes created by the
        # subclass's `reset_state()`, so it must run once before first use.
        self.reset_state()

    def __new__(cls, *args, **kwargs):
        """Creates the instance and wraps its methods in `tf.py_function`.

        The subclass's `update_state()` and `result()` are captured and
        replaced, on the instance, by wrappers that invoke them through
        `tf.py_function` under a `/cpu:0` device scope.

        Args:
          *args: Unused here; accepted so construction arguments pass through.
          **kwargs: Unused here; accepted for the same reason.

        Returns:
          The new instance with `update_state` and `result` rebound.
        """
        # `super(base_metric.Metric, cls)` resolves `__new__` starting *after*
        # `base_metric.Metric` in the MRO, so `Metric.__new__` is not run.
        # NOTE(review): presumably this avoids Metric's own in-graph wrapping
        # of these methods; confirm against `base_metric.Metric.__new__`.
        obj = super(base_metric.Metric, cls).__new__(cls)

        # Wrap the update_state function in a py_function and scope it to /cpu:0
        # Capture the bound method now, before it is shadowed below.
        obj_update_state = obj.update_state

        def update_state_on_cpu(y_true, y_pred, sample_weight=None):
            with tf.device("/cpu:0"):
                return obj_update_state(y_true, y_pred, sample_weight)

        obj.update_state_on_cpu = update_state_on_cpu

        def update_state_fn(self, y_true, y_pred, sample_weight=None):
            # `sample_weight` is only forwarded when it was provided; when
            # omitted, `update_state_on_cpu` falls back to its `None` default.
            eager_inputs = [y_true, y_pred]
            if sample_weight is not None:
                eager_inputs.append(sample_weight)
            # `Tout=[]`: the update declares no tensor outputs.
            return tf.py_function(
                func=self.update_state_on_cpu, inp=eager_inputs, Tout=[]
            )

        obj.update_state = types.MethodType(update_state_fn, obj)

        # Wrap the result function in a py_function and scope it to /cpu:0
        # Capture the bound method now, before it is shadowed below.
        obj_result = obj.result

        def result_on_host_cpu():
            with tf.device("/cpu:0"):
                return obj_result()

        obj.result_on_host_cpu = result_on_host_cpu

        def result_fn(self):
            # A single `Tout` dtype is declared, which is why `result()` is
            # documented as returning one scalar rather than a dict.
            # `self` is the instance this function is bound to just below.
            return tf.py_function(
                self.result_on_host_cpu, inp=[], Tout=self.dtype
            )

        obj.result = types.MethodType(result_fn, obj)

        return obj

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates statistics for the metric.

        **Note:** This function is executed outside of the TensorFlow graph
        on the CPU host.

        This means:

        a) Inputs are eager tensors.
        b) Any TensorFlow ops run in this method are run eagerly.
        c) Any Tensors created are allocated to the CPU's main memory.

        Args:
          y_true: Target output
          y_pred: Predicted output
          sample_weight: (Optional) weights for the individual samples in
            `y_true` and `y_pred`

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses should implement `update_state`")

    def merge_state(self, metrics):
        """Merges the state from one or more metrics.

        `PyMetric` instances that intend to support merging state must override
        this method, as the default implementation
        in `keras.metrics.Metric` does not apply to `PyMetric`.

        Args:
          metrics: The metrics whose state should be merged into this one.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses should implement `merge_state`")

    def reset_state(self):
        """Resets all of the metric state variables.

        This function is called between epochs when a metric is evaluated during
        training. It's also called when the metric is initialized.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses should implement `reset_state`")

    def result(self):
        """Computes and returns the scalar metric value.

        **Note:** This function is executed outside of the TensorFlow graph
        on the CPU host. This means any TensorFlow ops run in this method
        are run eagerly.

        Result computation is an idempotent operation that simply calculates the
        metric value using the state variables.

        Returns:
          A Python scalar.

        Raises:
          NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses should implement `result`")

192