Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/convert_phase.py: 59%
87 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for collecting TFLite metrics."""
17import collections
18import enum
19import functools
20from typing import Text
22from tensorflow.lite.python.metrics import converter_error_data_pb2
23from tensorflow.lite.python.metrics import metrics
class Component(enum.Enum):
  """Top-level stages of the TFLite converter pipeline.

  Each member's value is the same string as its name.
  """

  # Check the user's inputs, then load and optimize the TensorFlow model.
  PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

  # Translate the prepared TensorFlow model into the TFLite model format.
  CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

  # Apply quantization and sparsification to the TFLite model.
  OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
# Value type of `SubComponent` members: the subcomponent's name paired with
# the `Component` it belongs to (or None when unspecified).
SubComponentItem = collections.namedtuple("SubComponentItem", "name component")
class SubComponent(SubComponentItem, enum.Enum):
  """Converter subcomponents known to the Python side of the converter.

  Every member's value is a `SubComponentItem` that pairs the subcomponent's
  name with the `Component` it belongs to. Further subcomponents may exist
  only in C++ and therefore do not appear here.
  """

  def __str__(self):
    # Delegates to the `name` property below.
    return self.name

  @property
  def name(self):
    item = self.value
    return item.name

  @property
  def component(self):
    item = self.value
    return item.component

  # Placeholder used when no specific subcomponent is known.
  UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)

  # Validate the given inputs and parameters.
  VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
                                     Component.PREPARE_TF_MODEL)

  # Read a GraphDef out of a SavedModel.
  LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
                                      Component.PREPARE_TF_MODEL)

  # Turn a SavedModel into a frozen graph.
  FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Export a Keras model as a SavedModel.
  CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Export concrete functions as a SavedModel.
  CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Turn a Keras model into a frozen graph.
  FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Substitute constants for every variable in a ConcreteFunction.
  FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
                                              Component.PREPARE_TF_MODEL)

  # Apply grappler optimizations.
  OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
                                       Component.PREPARE_TF_MODEL)

  # Conversion through the legacy TOCO converter.
  CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
      "CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # GraphDef -> TFLite model.
  CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
                                      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # SavedModel -> TFLite model.
  CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
                                         Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Jax HLO -> TFLite model.
  CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
                                     Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Quantization through the deprecated quantizer.
  QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
      "QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)

  # Calibration.
  CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)

  # Quantization through MLIR.
  QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)

  # Sparsification through MLIR.
  SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)
class ConverterError(Exception):
  """Raised when an error occurs during model conversion.

  Attributes:
    errors: list of `converter_error_data_pb2.ConverterErrorData` records
      attached to this exception, either via `append_error` or derived from
      the message at construction time by `_parse_error_message`.
  """

  def __init__(self, message):
    super(ConverterError, self).__init__(message)
    self.errors = []
    self._parse_error_message(message)

  def append_error(self,
                   error_data: converter_error_data_pb2.ConverterErrorData):
    """Attaches one structured error record to this exception."""
    self.errors.append(error_data)

  def _parse_error_message(self, message):
    """If the message matches a pattern, assigns the associated error code.

    It is difficult to assign an error code to some errors on the MLIR side,
    e.g. errors thrown by components other than TFLite, or errors not emitted
    through mlir::emitError. This function tries to detect them from the error
    message and assign the corresponding error code. At most one record is
    appended: the first matching pattern wins.

    Args:
      message: The error message of this exception. Non-string messages are
        converted with `str()` before matching.
    """
    # `Exception` accepts any payload, but both the substring test below and
    # the proto's string field require text. Without this, e.g.
    # `ConverterError(None)` would raise TypeError from the constructor.
    if not isinstance(message, str):
      message = str(message)
    error_code_mapping = {
        "Failed to functionalize Control Flow V1 ops. Consider using Control "
        "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
        "tf/compat/v1/enable_control_flow_v2.":
            converter_error_data_pb2.ConverterErrorData
            .ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
    }
    for pattern, error_code in error_code_mapping.items():
      if pattern in message:
        error_data = converter_error_data_pb2.ConverterErrorData()
        error_data.error_message = message
        error_data.error_code = error_code
        self.append_error(error_data)
        return
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
  """The decorator to identify converter component and subcomponent.

  The returned decorator wraps a function so that any exception escaping it is
  reported through `metrics.TFLiteConverterMetrics().set_converter_error` and
  then re-raised.

  Args:
    component: Converter component; must be a `Component` member.
    subcomponent: Converter subcomponent; must be `SubComponent.UNSPECIFIED`
      or a `SubComponent` member whose `component` equals `component`.

  Returns:
    A decorator. Functions it wraps forward the result of the wrapped
    function.

  Raises:
    ValueError: if component or subcomponent is not a valid enum member, or
      if the subcomponent does not belong to the component.
  """
  # Use isinstance() rather than `in`: on Python 3.8-3.11 `x in Component`
  # raises TypeError (not the documented ValueError) for non-enum objects, and
  # on 3.12+ it also returns True for raw member values such as the string
  # "PREPARE_TF_MODEL", which would later break `component.value` inside the
  # error reporter and mask the real conversion error.
  if not isinstance(component, Component):
    raise ValueError("Given component name not found")
  if not isinstance(subcomponent, SubComponent):
    raise ValueError("Given subcomponent name not found")
  if (subcomponent != SubComponent.UNSPECIFIED and
      subcomponent.component != component):
    raise ValueError("component and subcomponent name don't match")

  def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
    """Tags `error_data` with this phase and records it in converter metrics."""
    # Always overwrites the component information, but only overwrites the
    # subcomponent if it is not available.
    error_data.component = component.value
    if not error_data.subcomponent:
      error_data.subcomponent = subcomponent.name
    tflite_metrics = metrics.TFLiteConverterMetrics()
    tflite_metrics.set_converter_error(error_data)

  def report_error_message(error_message: Text):
    """Wraps a plain message in a `ConverterErrorData` and reports it."""
    error_data = converter_error_data_pb2.ConverterErrorData()
    error_data.error_message = error_message
    report_error(error_data)

  def actual_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      try:
        return func(*args, **kwargs)
      except ConverterError as converter_error:
        # Prefer the structured error records; fall back to the message text.
        if converter_error.errors:
          for error_data in converter_error.errors:
            report_error(error_data)
        else:
          report_error_message(str(converter_error))
        # Re-throws the same exception object. `from None` sets its
        # __cause__ to None and suppresses its __context__ in tracebacks.
        raise converter_error from None
      except Exception as error:
        report_error_message(str(error))
        raise error from None  # Re-throws the exception (see note above).

    return wrapper

  return actual_decorator