Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/authoring/authoring.py: 22%
125 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Authoring tool package for TFLite compatibility.
17WARNING: The package is experimental and subject to change.
19This package provides a way to check TFLite compatibility at model authoring
20time.
22Example:
23 @tf.lite.experimental.authoring.compatible
24 @tf.function(input_signature=[
25 tf.TensorSpec(shape=[None], dtype=tf.float32)
26 ])
27 def f(x):
28 return tf.cosh(x)
30 result = f(tf.constant([0.0]))
32 > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
33 > conversion for TensorFlow Lite.
34 > Op: tf.Cosh
35 > - tensorflow/python/framework/op_def_library.py:xxx
36 > - tensorflow/python/ops/gen_math_ops.py:xxx
37 > - simple_authoring.py:xxx
38"""
39import functools
42# pylint: disable=g-import-not-at-top
43from tensorflow.lite.python import convert
44from tensorflow.lite.python import lite
45from tensorflow.lite.python.metrics import converter_error_data_pb2
46from tensorflow.python.util.tf_export import tf_export as _tf_export
# Prefixes of the plain-text lines in a legacy ConverterError message that
# list the offending ops.
_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "
# Prefixes of the compatibility messages logged for the user.
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"
# Call-site frames from this file onward are not shown in location dumps.
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"
class CompatibilityError(Exception):
  """Exception signalling that a TFLite compatibility issue was detected."""
class _Compatible:
  """Decorator object that checks the TFLite compatibility of a `tf.function`.

  Instances are created by `tf.lite.experimental.authoring.compatible`. On the
  first call, the decorated function is converted with the TFLite converter.
  Any compatibility problems reported by the converter are logged, and they
  are optionally raised as `CompatibilityError`.
  """

  def __init__(self,
               target,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               raise_exception=False):
    """Initializes the decorator object.

    Instance variables:
      - _func: the decorated function.
      - _obj_func: for a method accessed through an object, the function bound
        to that object, so that the object is passed as the first argument.
      - _verified: whether the compatibility check has already run.
      - _log_messages: messages logged by the compatibility check.

    Args:
      target: The decorated function.
      converter_target_spec: `target_spec` parameter for the TFLite converter.
        The converter's own value is kept when this is None.
      converter_allow_custom_ops: `allow_custom_ops` parameter for the TFLite
        converter. The converter's own value is kept when this is None.
      raise_exception: Whether to raise `CompatibilityError` on compatibility
        issues. Users need to call `get_compatibility_log()` to see details.
    """
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops

  def __get__(self, instance, cls):
    """Descriptor hook that binds the decorated function to `instance`.

    NOTE(review): the bound function is stored on this single decorator object
    and `self` is returned. The most recently accessed instance is therefore
    the one used by subsequent calls.
    """
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns the bound function when one is set, else the decorated one."""
    if self._obj_func is not None:
      return self._obj_func
    return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls the decorated function after checking its TFLite compatibility.

    The check runs only on the first call. A concrete function is obtained for
    the given arguments and converted with the TFLite converter, and converter
    errors are decoded into compatibility messages.

    Returns:
      The result of calling the decorated function.

    Raises:
      CompatibilityError: if `raise_exception` was set and the check logged
        compatibility messages.
    """
    if not self._verified:
      model = self._get_func()
      concrete_func = model.get_concrete_function(*args, **kwargs)
      converter = lite.TFLiteConverterV2.from_concrete_functions(
          [concrete_func], model)
      # Only override converter parameters that were explicitly provided.
      if self._converter_target_spec is not None:
        converter.target_spec = self._converter_target_spec
      if self._converter_allow_custom_ops is not None:
        converter.allow_custom_ops = self._converter_allow_custom_ops
      try:
        converter.convert()
      except convert.ConverterError as err:
        self._decode_error(err)
      finally:
        # Mark as verified even when decoding raises, so the check runs once.
        self._verified = True

    return self._get_func()(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  def _get_location_string(self, location):
    """Formats the call stack of a ConverterError location for logging.

    Args:
      location: The `location` field of a ConverterErrorData entry.

    Returns:
      A newline-separated string with one line per reported call.
    """
    is_callsite = (
        location.type == converter_error_data_pb2.ConverterErrorData.CALLSITELOC)
    callstack = []
    for single_call in location.call:
      if not is_callsite:
        callstack.append(str(single_call))
        continue
      # Stop at func_graph.py: frames from there on are not meaningful to the
      # user.
      if _FUNC_GRAPH_SRC_PATH in single_call.source.filename:
        break
      callstack.append(
          f" - {single_call.source.filename}:{single_call.source.line}")
    return "\n".join(callstack)

  def _dump_error_details(self, ops, locations):
    """Logs each op name followed by the call stack of its location.

    Args:
      ops: Op names.
      locations: Locations paired element-wise with `ops`.
    """
    for op, location in zip(ops, locations):
      callstack_dump = self._get_location_string(location)
      self._log(f"Op: {op}\n{callstack_dump}\n")

  @staticmethod
  def _custom_ops_message(ops_str):
    """Returns the error message for ops that need a custom implementation."""
    return (
        f"{_AUTHORING_ERROR_HDR}: op '{ops_str}' is(are) not natively "
        "supported by TensorFlow Lite. You need to provide a custom "
        "operator. https://www.tensorflow.org/lite/guide/ops_custom")

  @staticmethod
  def _select_tf_ops_message(ops_str):
    """Returns the warning message for ops that need "Select TF Ops"."""
    return (
        f"{_AUTHORING_WARNING_HDR}: op '{ops_str}' require(s) \"Select TF "
        "Ops\" for model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select")

  def _decode_error_legacy(self, err):
    """Parses a plain-text (legacy OSS) ConverterError message.

    Lines starting with the custom-op or Select-TF-op headers are turned into
    compatibility messages. All other lines are ignored.

    Args:
      err: The ConverterError raised by the converter.
    """
    for line in str(err).splitlines():
      if line.startswith(_CUSTOM_OPS_HDR):
        self._log(self._custom_ops_message(line[len(_CUSTOM_OPS_HDR):]))
      elif line.startswith(_TF_OPS_HDR):
        self._log(self._select_tf_ops_message(line[len(_TF_OPS_HDR):]))

  def _decode_converter_error(self, err):
    """Parses a ConverterError that carries structured `errors` entries.

    Custom-op and Select-TF-op errors are aggregated into one message each,
    followed by per-op locations. GPU-incompatibility errors and all other
    errors are logged individually with their locations, and the
    GPU-incompatible ops are then summarized in one warning.

    Args:
      err: The ConverterError raised by the converter.
    """
    error_data_cls = converter_error_data_pb2.ConverterErrorData
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    gpu_not_compatible_ops = []
    # Named `error_data` (not `err`) so the argument is not shadowed.
    for error_data in err.errors:
      if error_data.error_code == error_data_cls.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error_data.operator.name)
        custom_ops_location.append(error_data.location)
      elif error_data.error_code == error_data_cls.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error_data.operator.name)
        tf_ops_location.append(error_data.location)
      elif error_data.error_code == error_data_cls.ERROR_GPU_NOT_COMPATIBLE:
        gpu_not_compatible_ops.append(error_data.operator.name)
        # Log only the first line of the message, since the second line is
        # "Error code: xxxx". An empty message would make `[0]` raise
        # IndexError, so an empty line is logged instead.
        message_lines = error_data.error_message.splitlines()
        self._log(message_lines[0] if message_lines else "")
        self._log(self._get_location_string(error_data.location) + "\n")
      else:
        # Log any other error together with its location.
        self._log(f"{_AUTHORING_ERROR_HDR}: {error_data.error_message}")
        self._log(self._get_location_string(error_data.location) + "\n")

    if custom_ops:
      self._log(self._custom_ops_message(", ".join(sorted(custom_ops))))
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      self._log(self._select_tf_ops_message(", ".join(sorted(tf_ops))))
      self._dump_error_details(tf_ops, tf_ops_location)

    if gpu_not_compatible_ops:
      not_compatible_ops_str = ", ".join(sorted(gpu_not_compatible_ops))
      self._log(
          f"{_AUTHORING_WARNING_HDR}: op '{not_compatible_ops_str}' aren't "
          "compatible with TensorFlow Lite GPU delegate. "
          "https://www.tensorflow.org/lite/performance/gpu")

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Structured error data (`err.errors`) is used when it is present and
    non-empty. Otherwise the plain-text message is parsed, so that an empty
    `errors` field does not silently drop the diagnostics.

    Args:
      err: The ConverterError raised by the converter.

    Raises:
      CompatibilityError: if `raise_exception` was set and at least one
        compatibility message has been logged.
    """
    if getattr(err, "errors", None):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      raise CompatibilityError(
          f"CompatibilityException at {repr(self._func)}") from err

  def _log(self, message):
    """Records an authoring warning/error message and prints it."""
    self._log_messages.append(message)
    print(message)

  def get_compatibility_log(self):
    """Returns the list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      The list of log messages from the most recent compatibility check.

    Raises:
      RuntimeError: when the compatibility has NOT been checked yet.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return self._log_messages
@_tf_export("lite.experimental.authoring.compatible")
def compatible(target=None, converter_target_spec=None, **kwargs):
  """Wraps a `tf.function` so that its TFLite compatibility gets checked.

  The decorator can be applied bare (`@compatible`) or with arguments
  (`@compatible(converter_target_spec=...)`).

  Example:

  ```python
  @tf.lite.experimental.authoring.compatible
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.float32)
  ])
  def f(x):
    return tf.cosh(x)

  result = f(tf.constant([0.0]))
  # COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
  # conversion for TensorFlow Lite.
  # Op: tf.Cosh
  # - tensorflow/python/framework/op_def_library.py:748
  # - tensorflow/python/ops/gen_math_ops.py:2458
  # - <stdin>:6
  ```

  WARNING: Experimental interface, subject to change.

  Args:
    target: The `tf.function` to decorate. When None, a decorator is returned
      instead, so the remaining arguments can be supplied first.
    converter_target_spec: `target_spec` passed to the TFLite converter.
    **kwargs: Further keyword arguments forwarded to `_Compatible`, such as
      `converter_allow_custom_ops` or `raise_exception`.

  Returns:
    A `tf.lite.experimental.authoring._Compatible` callable wrapping `target`.
    When `target` is None, the return value is a function that builds one.
  """

  def _build(target):
    return _Compatible(target, converter_target_spec, **kwargs)

  if target is not None:
    return _build(target)
  return _build