Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/convert_phase.py: 59%

87 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for collecting TFLite metrics.""" 

16 

17import collections 

18import enum 

19import functools 

20from typing import Text 

21 

22from tensorflow.lite.python.metrics import converter_error_data_pb2 

23from tensorflow.lite.python.metrics import metrics 

24 

25 

class Component(enum.Enum):
  """Top-level stages of the TFLite converter.

  Each member's value is the same string as its name.
  """

  # Input validation plus preparation and optimization of the TensorFlow model.
  PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

  # Conversion into the TFLite model format.
  CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

  # Quantization and sparsification of the TFLite model.
  OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"

36 

37 

# Record type pairing a subcomponent's name with the component it belongs to.
SubComponentItem = collections.namedtuple(
    "SubComponentItem", ("name", "component"))

40 

41 

class SubComponent(SubComponentItem, enum.Enum):
  """Converter subcomponents known to the Python layer.

  Every member's value is a `SubComponentItem` holding the subcomponent name
  and the `Component` it belongs to. Only the subcomponents in Python are
  listed here; there might be more subcomponents defined in C++.
  """

  def __str__(self):
    # The string form is just the subcomponent name.
    return self.name

  @property
  def name(self):
    """Name stored in this member's `SubComponentItem` value."""
    return self.value.name

  @property
  def component(self):
    """Owning `Component` stored in this member's `SubComponentItem` value."""
    return self.value.component

  # Placeholder used when no subcomponent is specified; it has no component.
  UNSPECIFIED = SubComponentItem(name="UNSPECIFIED", component=None)

  # Validate the given input and parameters.
  VALIDATE_INPUTS = SubComponentItem(
      name="VALIDATE_INPUTS", component=Component.PREPARE_TF_MODEL)

  # Load a GraphDef from a SavedModel.
  LOAD_SAVED_MODEL = SubComponentItem(
      name="LOAD_SAVED_MODEL", component=Component.PREPARE_TF_MODEL)

  # Convert a SavedModel to a frozen graph.
  FREEZE_SAVED_MODEL = SubComponentItem(
      name="FREEZE_SAVED_MODEL", component=Component.PREPARE_TF_MODEL)

  # Save a Keras model as a SavedModel.
  CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
      name="CONVERT_KERAS_TO_SAVED_MODEL",
      component=Component.PREPARE_TF_MODEL)

  # Save concrete functions as a SavedModel.
  CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
      name="CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL",
      component=Component.PREPARE_TF_MODEL)

  # Convert a Keras model to a frozen graph.
  FREEZE_KERAS_MODEL = SubComponentItem(
      name="FREEZE_KERAS_MODEL", component=Component.PREPARE_TF_MODEL)

  # Replace all the variables with constants in a ConcreteFunction.
  FREEZE_CONCRETE_FUNCTION = SubComponentItem(
      name="FREEZE_CONCRETE_FUNCTION", component=Component.PREPARE_TF_MODEL)

  # Run grappler optimization.
  OPTIMIZE_TF_MODEL = SubComponentItem(
      name="OPTIMIZE_TF_MODEL", component=Component.PREPARE_TF_MODEL)

  # Convert using the old TOCO converter.
  CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
      name="CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
      component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a GraphDef to a TFLite model.
  CONVERT_GRAPHDEF = SubComponentItem(
      name="CONVERT_GRAPHDEF", component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a SavedModel to a TFLite model.
  CONVERT_SAVED_MODEL = SubComponentItem(
      name="CONVERT_SAVED_MODEL",
      component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a Jax HLO to a TFLite model.
  CONVERT_JAX_HLO = SubComponentItem(
      name="CONVERT_JAX_HLO", component=Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Quantize with the deprecated quantizer.
  QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
      name="QUANTIZE_USING_DEPRECATED_QUANTIZER",
      component=Component.OPTIMIZE_TFLITE_MODEL)

  # Run calibration.
  CALIBRATE = SubComponentItem(
      name="CALIBRATE", component=Component.OPTIMIZE_TFLITE_MODEL)

  # Quantize with MLIR.
  QUANTIZE = SubComponentItem(
      name="QUANTIZE", component=Component.OPTIMIZE_TFLITE_MODEL)

  # Sparsify with MLIR.
  SPARSIFY = SubComponentItem(
      name="SPARSIFY", component=Component.OPTIMIZE_TFLITE_MODEL)

124 

125 

class ConverterError(Exception):
  """Raised when an error occurs during model conversion.

  Attributes:
    errors: List of `ConverterErrorData` protos attached to this exception.
      Entries are added through `append_error`; the constructor adds one
      itself when the message matches a known pattern.
  """

  def __init__(self, message):
    """Creates the exception and derives error data from the message.

    Args:
      message: The error message. Normally a string; any other object is
        converted with `str()` before it is matched against known patterns.
    """
    super().__init__(message)
    self.errors = []
    self._parse_error_message(message)

  def append_error(
      self, error_data: "converter_error_data_pb2.ConverterErrorData"):
    """Attaches one more error proto to this exception.

    Args:
      error_data: The `ConverterErrorData` to append to `self.errors`.
    """
    self.errors.append(error_data)

  def _parse_error_message(self, message):
    """If the message matches a pattern, assigns the associated error code.

    It is difficult to assign an error code to some errors in MLIR side, Ex:
    errors thrown by other components than TFLite or not using mlir::emitError.
    This function tries to detect them by the error message and assigns the
    corresponding error code.

    Args:
      message: The error message of this exception.
    """
    # Substring matching needs a string: `pattern in message` raises TypeError
    # for objects such as None or a wrapped exception, which would replace the
    # real conversion failure with an unrelated error from this constructor.
    text = message if isinstance(message, str) else str(message)

    error_code_mapping = {
        "Failed to functionalize Control Flow V1 ops. Consider using Control "
        "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
        "tf/compat/v1/enable_control_flow_v2.":
            converter_error_data_pb2.ConverterErrorData
            .ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
    }
    for pattern, error_code in error_code_mapping.items():
      if pattern in text:
        error_data = converter_error_data_pb2.ConverterErrorData()
        error_data.error_message = text
        error_data.error_code = error_code
        self.append_error(error_data)
        # Only the first matching pattern is recorded.
        return

163 

164 

def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
  """The decorator to identify converter component and subcomponent.

  Any exception escaping the decorated function is reported through
  `metrics.TFLiteConverterMetrics().set_converter_error` tagged with the given
  component and subcomponent, and is then raised again.

  Args:
    component: Converter component, a member of `Component`.
    subcomponent: Converter subcomponent, a member of `SubComponent`.

  Returns:
    A decorator. The function it produces forwards the result from the
    wrapped function.

  Raises:
    ValueError: if component and subcomponent name is not valid.
  """
  # Use isinstance rather than `x not in EnumClass`: depending on the Python
  # version, Enum containment either raises TypeError for a non-Enum operand
  # or accepts a raw value that merely equals a member's value. Neither gives
  # the ValueError documented above.
  if not isinstance(component, Component):
    raise ValueError("Given component name not found")
  if not isinstance(subcomponent, SubComponent):
    raise ValueError("Given subcomponent name not found")
  if (subcomponent != SubComponent.UNSPECIFIED and
      subcomponent.component != component):
    raise ValueError("component and subcomponent name don't match")

  def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
    # Always overwrites the component information, but only overwrites the
    # subcomponent if it is not available.
    error_data.component = component.value
    if not error_data.subcomponent:
      error_data.subcomponent = subcomponent.name
    tflite_metrics = metrics.TFLiteConverterMetrics()
    tflite_metrics.set_converter_error(error_data)

  def report_error_message(error_message: Text):
    # Wraps a plain message in a fresh proto so it goes through the same
    # reporting path as structured errors.
    error_data = converter_error_data_pb2.ConverterErrorData()
    error_data.error_message = error_message
    report_error(error_data)

  def actual_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      try:
        return func(*args, **kwargs)
      except ConverterError as converter_error:
        # Prefer the structured errors attached to the exception; fall back to
        # its message when none were attached.
        if converter_error.errors:
          for error_data in converter_error.errors:
            report_error(error_data)
        else:
          report_error_message(str(converter_error))
        raise converter_error from None  # Re-throws the exception.
      except Exception as error:
        report_error_message(str(error))
        raise error from None  # Re-throws the exception.

    return wrapper

  return actual_decorator