Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/authoring/authoring.py: 22%

125 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TensorFlow Authoring tool package for TFLite compatibility. 

16 

17WARNING: The package is experimental and subject to change. 

18 

19This package provides a way to check TFLite compatibility at model authoring 

20time. 

21 

22Example: 

23 @tf.lite.experimental.authoring.compatible 

24 @tf.function(input_signature=[ 

25 tf.TensorSpec(shape=[None], dtype=tf.float32) 

26 ]) 

27 def f(x): 

28 return tf.cosh(x) 

29 

30 result = f(tf.constant([0.0])) 

31 

32 > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model 

33 > conversion for TensorFlow Lite. 

34 > Op: tf.Cosh 

35 > - tensorflow/python/framework/op_def_library.py:xxx 

36 > - tensorflow/python/ops/gen_math_ops.py:xxx 

37 > - simple_authoring.py:xxx 

38""" 

39import functools 

40 

41 

42# pylint: disable=g-import-not-at-top 

43from tensorflow.lite.python import convert 

44from tensorflow.lite.python import lite 

45from tensorflow.lite.python.metrics import converter_error_data_pb2 

46from tensorflow.python.util.tf_export import tf_export as _tf_export 

47 

48 

# Line prefixes that the legacy (string-only) ConverterError parser looks for.
_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "

# Headers prepended to the compatibility messages that get logged.
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"

# Call-site frames from this file onwards are dropped from the logged call
# stack.
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"

54 

55 

class CompatibilityError(Exception):
  """Signals that a TFLite compatibility check reported at least one issue."""

59 

60 

class _Compatible:
  """Decorator object that checks a `tf.function` for TFLite compatibility.

  Instances are created by `lite.experimental.authoring.compatible`. On the
  first call the wrapped function is run through the TFLite converter and any
  compatibility issues are logged; every call (including the first) then
  executes the wrapped function and returns its result.
  """

  def __init__(self,
               target,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               raise_exception=False):
    """Initialize the decorator object.

    Here is the description of the object variables.
    - _func : decorated function.
    - _obj_func : for class object, we need to use this object to provide
        the `self` instance as the first argument.
    - _verified : whether the compatibility is checked or not.
    - _log_messages : messages collected by the compatibility check.

    Args:
      target: decorated function.
      converter_target_spec : target_spec of TFLite converter parameter.
      converter_allow_custom_ops : allow_custom_ops of TFLite converter
        parameter.
      raise_exception : to raise an exception on compatibility issues.
        Users need to call get_compatibility_log() to check details.
    """
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops

  def __get__(self, instance, cls):
    """A Python descriptor interface.

    Binds the decorated function to `instance` so that methods receive their
    `self` argument, and returns this decorator object.
    """
    # NOTE(review): the bound function is stored on this shared decorator
    # object, so it always reflects the most recent attribute access.
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns decorated function object.

    For a class method, use self._obj_func to provide `self` instance.
    """
    if self._obj_func is not None:
      return self._obj_func
    return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls decorated function object.

    Also verifies if the function is compatible with TFLite. The check runs
    only until it has been attempted once; later calls skip it.

    Returns:
      An execution result of the decorated function.
    """
    if not self._verified:
      model = self._get_func()
      concrete_func = model.get_concrete_function(*args, **kwargs)
      converter = lite.TFLiteConverterV2.from_concrete_functions(
          [concrete_func], model)
      # Set provided converter parameters
      if self._converter_target_spec is not None:
        converter.target_spec = self._converter_target_spec
      if self._converter_allow_custom_ops is not None:
        converter.allow_custom_ops = self._converter_allow_custom_ops
      try:
        converter.convert()
      except convert.ConverterError as err:
        self._decode_error(err)
      finally:
        # Mark as verified even when decoding raised, so the conversion is
        # not retried on the next call.
        self._verified = True

    return self._get_func()(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  def _get_location_string(self, location):
    """Dump location of ConverterError.errors.location.

    Args:
      location: location field of a converter error entry.

    Returns:
      The call stack as a newline-joined string, one entry per call.
    """
    # The location type does not change per call, so compare it only once.
    is_callsite = (
        location.type ==
        converter_error_data_pb2.ConverterErrorData.CALLSITELOC)
    callstack = []
    for single_call in location.call:
      if not is_callsite:
        callstack.append(str(single_call))
        continue
      # Stop showing CallSite after func_graph.py which isn't meaningful.
      if _FUNC_GRAPH_SRC_PATH in single_call.source.filename:
        break
      callstack.append(
          f" - {single_call.source.filename}:{single_call.source.line}")
    return "\n".join(callstack)

  def _dump_error_details(self, ops, locations):
    """Dump the list of ops and locations.

    Args:
      ops: op names.
      locations: location entries, paired with `ops` by position.
    """
    for op_name, location in zip(ops, locations):
      callstack_dump = self._get_location_string(location)
      self._log(f"Op: {op_name}\n{callstack_dump}\n")

  @staticmethod
  def _custom_ops_message(ops_str):
    """Builds the error message for ops that need a custom operator."""
    return (
        f"{_AUTHORING_ERROR_HDR}: op '{ops_str}' is(are) not natively "
        "supported by TensorFlow Lite. You need to provide a custom "
        "operator. https://www.tensorflow.org/lite/guide/ops_custom")

  @staticmethod
  def _select_tf_ops_message(ops_str):
    """Builds the warning message for ops that need "Select TF Ops"."""
    return (
        f"{_AUTHORING_WARNING_HDR}: op '{ops_str}' require(s) \"Select TF "
        "Ops\" for model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select")

  def _decode_error_legacy(self, err):
    """Parses the given legacy ConverterError for OSS.

    Only lines of `str(err)` starting with the custom-op or TF-select-op
    header are turned into log messages; all other lines are ignored.
    """
    for line in str(err).splitlines():
      # Check custom op usage error.
      if line.startswith(_CUSTOM_OPS_HDR):
        custom_ops = line[len(_CUSTOM_OPS_HDR):]
        self._log(self._custom_ops_message(custom_ops))
      # Check TensorFlow op usage error.
      elif line.startswith(_TF_OPS_HDR):
        tf_ops = line[len(_TF_OPS_HDR):]
        self._log(self._select_tf_ops_message(tf_ops))

  def _decode_converter_error(self, err):
    """Parses the given ConverterError which has detailed error information.

    GPU-compatibility and uncategorized errors are logged as they are met;
    custom-op and Select-TF-op errors are collected and reported in one
    summary message each, followed by per-op location details.
    """
    error_codes = converter_error_data_pb2.ConverterErrorData
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    gpu_not_compatible_ops = []
    # Use a distinct loop name so the `err` argument is not rebound.
    for error_data in err.errors:
      code = error_data.error_code
      # Check custom op usage error.
      if code == error_codes.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error_data.operator.name)
        custom_ops_location.append(error_data.location)
      # Check TensorFlow op usage error.
      elif code == error_codes.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error_data.operator.name)
        tf_ops_location.append(error_data.location)
      # Check GPU delegate compatibility error.
      elif code == error_codes.ERROR_GPU_NOT_COMPATIBLE:
        gpu_not_compatible_ops.append(error_data.operator.name)
        # Log the first line of ConverterError.errors.error_message only
        # since the second line is "Error code: xxxx". An empty message has
        # no first line, so log an empty string instead of failing.
        message_lines = error_data.error_message.splitlines()
        self._log(message_lines[0] if message_lines else "")
        self._log(self._get_location_string(error_data.location) + "\n")
      else:
        # Log other errors.
        self._log(f"{_AUTHORING_ERROR_HDR}: {error_data.error_message}")
        self._log(self._get_location_string(error_data.location) + "\n")

    if custom_ops:
      custom_ops_str = ", ".join(sorted(custom_ops))
      self._log(self._custom_ops_message(custom_ops_str))
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      tf_ops_str = ", ".join(sorted(tf_ops))
      self._log(self._select_tf_ops_message(tf_ops_str))
      self._dump_error_details(tf_ops, tf_ops_location)

    if gpu_not_compatible_ops:
      not_compatible_ops_str = ", ".join(sorted(gpu_not_compatible_ops))
      err_string = (
          f"{_AUTHORING_WARNING_HDR}: op '{not_compatible_ops_str}' aren't "
          "compatible with TensorFlow Lite GPU delegate. "
          "https://www.tensorflow.org/lite/performance/gpu")
      self._log(err_string)

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Raises:
      CompatibilityError: when `raise_exception` was requested and at least
        one message has been logged.
    """
    if hasattr(err, "errors"):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      # Chain explicitly so the originating converter error stays attached.
      raise CompatibilityError(
          f"CompatibilityException at {repr(self._func)}") from err

  def _log(self, message):
    """Log and print authoring warning / error message."""
    self._log_messages.append(message)
    print(message)

  def get_compatibility_log(self):
    """Returns list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      The list of log messages by the recent compatibility check.
    Raises:
      RuntimeError: when the compatibility was NOT checked.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return self._log_messages

266 

267 

@_tf_export("lite.experimental.authoring.compatible")
def compatible(target=None, converter_target_spec=None, **kwargs):
  """Wraps `tf.function` into a callable function with TFLite compatibility checking.

  Works both as a bare decorator (`@compatible`) and as a parameterized one
  (`@compatible(converter_target_spec=...)`).

  Example:

  ```python
  @tf.lite.experimental.authoring.compatible
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.float32)
  ])
  def f(x):
      return tf.cosh(x)

  result = f(tf.constant([0.0]))
  # COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
  # conversion for TensorFlow Lite.
  # Op: tf.Cosh
  #   - tensorflow/python/framework/op_def_library.py:748
  #   - tensorflow/python/ops/gen_math_ops.py:2458
  #   - <stdin>:6
  ```

  WARNING: Experimental interface, subject to change.

  Args:
    target: A `tf.function` to decorate.
    converter_target_spec : target_spec of TFLite converter parameter.
    **kwargs: The keyword arguments of the decorator class _Compatible.

  Returns:
     A callable object of `tf.lite.experimental.authoring._Compatible`, or,
     when `target` is None, a function that builds one from the function it
     is given.
  """
  def wrapper(target):
    return _Compatible(target, converter_target_spec, **kwargs)

  # Without a target the caller used the parameterized form and needs the
  # decorator itself; otherwise decorate the given function right away.
  return wrapper if target is None else wrapper(target)