1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts a frozen graph into a TFLite FlatBuffer."""
17import distutils.spawn
18import enum
19import hashlib
20import os as _os
21import platform as _platform
22import subprocess as _subprocess
23import tempfile as _tempfile
24from typing import Optional
25import warnings
27from tensorflow.compiler.mlir.quantization.stablehlo import quantization_options_pb2 as quant_opts_pb2
28from tensorflow.lite.python import lite_constants
29from tensorflow.lite.python import util
30from tensorflow.lite.python import wrap_toco
31from tensorflow.lite.python.convert_phase import Component
32from tensorflow.lite.python.convert_phase import convert_phase
33from tensorflow.lite.python.convert_phase import ConverterError
34from tensorflow.lite.python.convert_phase import SubComponent
35from tensorflow.lite.python.metrics import converter_error_data_pb2
36from tensorflow.lite.python.metrics.wrapper import metrics_wrapper as _metrics_wrapper
37from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
38from tensorflow.lite.toco import toco_flags_pb2 as _conversion_flags_pb2
39from tensorflow.lite.toco import types_pb2 as _types_pb2
40from tensorflow.lite.tools import flatbuffer_utils
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.platform import resource_loader as _resource_loader
44from tensorflow.python.util import deprecation
45from tensorflow.python.util.tf_export import tf_export as _tf_export
48def _is_quantized_input_stats_required(
49 conversion_flags: _conversion_flags_pb2.TocoFlags,
50) -> bool:
51 """Checks if the `quantized_input_stats` flag is required for conversion.
53 Args:
54 conversion_flags: A protocol buffer describing the conversion process.
56 Returns:
57 True, if the `inference_type` or the `inference_input_type` is a quantized
58 type and post-training quantization is not enabled, else False.
59 """
60 quantized_inference_types = [
61 _types_pb2.QUANTIZED_UINT8,
62 _types_pb2.QUANTIZED_INT8,
63 ]
64 return (
65 conversion_flags.inference_type in quantized_inference_types
66 or conversion_flags.inference_input_type in quantized_inference_types
67 ) and not conversion_flags.post_training_quantize
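# A minimal illustrative sketch (not from the original module) of how this
# predicate behaves, using only the protos imported above: a quantized
# inference type requires input stats unless post-training quantization is on.
_example_flags = _conversion_flags_pb2.TocoFlags()
_example_flags.inference_type = _types_pb2.QUANTIZED_UINT8
assert _is_quantized_input_stats_required(_example_flags)
_example_flags.post_training_quantize = True
assert not _is_quantized_input_stats_required(_example_flags)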
70def convert_tensor_tf_type_to_tflite_type(
71 tf_type: dtypes.DType, usage: str = ""
72) -> _types_pb2.IODataType:
73 """Convert tensor type from tf type to tflite type.
75 Args:
76 tf_type: TensorFlow type.
77 usage: Text describing the reason for invoking this function.
79 Raises:
80 ValueError: If `tf_type` is unsupported.
82 Returns:
83 tflite_type: TFLite type. Refer to lite/toco/types.proto.
84 """
85 mapping = {
86 dtypes.float16: _types_pb2.FLOAT16,
87 dtypes.float32: _types_pb2.FLOAT,
88 dtypes.float64: _types_pb2.FLOAT64,
89 dtypes.int8: _types_pb2.INT8,
90 dtypes.int16: _types_pb2.INT16,
91 dtypes.uint16: _types_pb2.UINT16,
92 dtypes.int32: _types_pb2.INT32,
93 dtypes.int64: _types_pb2.INT64,
94 dtypes.uint8: _types_pb2.UINT8,
95 dtypes.uint32: _types_pb2.UINT32,
96 dtypes.uint64: _types_pb2.UINT64,
97 dtypes.string: _types_pb2.STRING,
98 dtypes.bool: _types_pb2.BOOL,
99 dtypes.complex64: _types_pb2.COMPLEX64,
100 dtypes.complex128: _types_pb2.COMPLEX128,
101 }
102 tflite_type = mapping.get(tf_type)
103 if tflite_type is None:
104 raise ValueError(
105 "Unsupported TensorFlow type `{0}` provided for the {1}".format(
106 tf_type, usage
107 )
108 )
109 return tflite_type
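# Illustrative sketch (not from the original module): known dtypes map to the
# corresponding IODataType enum value, and unsupported dtypes raise ValueError.
assert (
    convert_tensor_tf_type_to_tflite_type(dtypes.float32, usage="example")
    == _types_pb2.FLOAT
)
try:
  convert_tensor_tf_type_to_tflite_type(dtypes.variant, usage="example")
except ValueError:
  pass  # dtypes.variant is not in the mapping, so a ValueError is expected.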
112# Only a few restricted tensor types are allowed for explicitly setting
113# inference/input/output types.
114def convert_inference_tf_type_to_tflite_type(
115 tf_type: dtypes.DType, usage: str = ""
116) -> _types_pb2.IODataType:
117 """Convert inference type from tf type to tflite type.
119 Args:
120 tf_type: TensorFlow type.
121 usage: Text describing the reason for invoking this function.
123 Raises:
124 ValueError: If `tf_type` is unsupported.
126 Returns:
127 tflite_type: TFLite type. Refer to lite/toco/types.proto.
128 """
129 mapping = {
130 dtypes.float32: _types_pb2.FLOAT,
131 dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
132 dtypes.int8: _types_pb2.QUANTIZED_INT8,
133 dtypes.int16: _types_pb2.QUANTIZED_INT16,
134 }
135 tflite_type = mapping.get(tf_type)
136 if tflite_type is None:
137 raise ValueError(
138 "Unsupported TensorFlow type `{0}` provided for the {1}".format(
139 tf_type, usage
140 )
141 )
142 return tflite_type
145# Find the deprecated conversion binary using the resource loader if running
146# from bazel; otherwise we are in a pip installation where console_scripts already has the tool.
147if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
148 _deprecated_conversion_binary = ""
149else:
150 _deprecated_conversion_binary = _resource_loader.get_path_to_datafile(
151 "../toco/python/toco_from_protos"
152 )
153 if not _os.path.exists(_deprecated_conversion_binary):
154 _deprecated_conversion_binary = "toco_from_protos"
157def _try_convert_to_unicode(output):
158 if output is None:
159 return ""
161 if isinstance(output, bytes):
162 try:
163 return output.decode("utf-8")
164 except UnicodeDecodeError:
165 pass
166 return output
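# Illustrative sketch (not from the original module): the helper above
# normalizes subprocess output. None becomes "", valid UTF-8 bytes are decoded,
# and undecodable bytes are passed through unchanged.
assert _try_convert_to_unicode(None) == ""
assert _try_convert_to_unicode(b"conversion failed") == "conversion failed"
assert _try_convert_to_unicode(b"\xff\xfe") == b"\xff\xfe"  # not valid UTF-8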
169@_tf_export("lite.OpsSet")
170class OpsSet(enum.Enum):
171 """Enum class defining the sets of ops available to generate TFLite models.
173 WARNING: Experimental interface, subject to change.
174 """
176 # Convert model using TensorFlow Lite builtin ops.
177 TFLITE_BUILTINS = "TFLITE_BUILTINS"
179 # Convert model using TensorFlow ops. Not all TensorFlow ops are available.
180 # WARNING: Experimental interface, subject to change.
181 SELECT_TF_OPS = "SELECT_TF_OPS"
183 # Convert model using only TensorFlow Lite quantized int8 operations.
184 # Specifying this will throw an error for operations that do not yet have
185 # quantized implementations.
186 TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
188 # Convert model using only TensorFlow Lite operations with quantized int8
189 # weights, int16 activations and int64 bias.
190 # Specifying this will throw an error for operations that do not yet have
191 # quantized implementations.
192 # This quantization mode may be used in models for super-resolution,
193 # audio signal processing or image de-noising. It improves accuracy
194 # significantly, but only slightly increases the model size.
195 # WARNING: These ops are currently experimental and have not yet been
196 # finalized.
197 # They are only compatible with CPU execution, and have not been optimized for
198 # production.
199 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = (
200 "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8"
201 )
203 # Convert model using only StableHLO ops.
204 # This option cannot be combined with other OpsSets.
205 # The feature is in early development.
206 # The runtime support for executing StableHLO ops is still to be implemented,
207 # and the serialization format is not yet stabilized.
209 EXPERIMENTAL_STABLEHLO_OPS = "EXPERIMENTAL_STABLEHLO_OPS"
211 def __str__(self):
212 return str(self.value)
214 @staticmethod
215 def get_options():
216 """Returns a list of OpsSet options as a list of strings."""
217 return [str(option) for option in list(OpsSet)]
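# Illustrative sketch (not from the original module): the enum above is what
# the `target_ops` argument of `build_conversion_flags` expects, and its string
# form is what appears in user-facing flags and documentation.
assert "TFLITE_BUILTINS" in OpsSet.get_options()
assert str(OpsSet.SELECT_TF_OPS) == "SELECT_TF_OPS"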
220@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.QUANTIZE)
221def mlir_quantize(
222 input_data_str,
223 disable_per_channel=False,
224 fully_quantize=False,
225 inference_type=_types_pb2.QUANTIZED_INT8,
226 input_data_type=dtypes.float32,
227 output_data_type=dtypes.float32,
228 enable_numeric_verify=False,
229 enable_whole_model_verify=False,
230 denylisted_ops=None,
231 denylisted_nodes=None,
232 enable_variable_quantization=False,
233):
234 """Quantize `input_data_str` with calibration results.
236 Args:
237 input_data_str: Input data in serialized form (e.g. a TFLITE model with
238 calibration results).
239 disable_per_channel: Bool indicating whether to do per-channel or per-tensor
240 quantization.
241 fully_quantize: Bool indicating whether to fully quantize the model. Besides
242 model body, the input/output will be quantized as well.
243 inference_type: Data type for the activations. The default value is int8.
244 input_data_type: Data type for the inputs. The default value is float32.
245 output_data_type: Data type for the outputs. The default value is float32.
246 enable_numeric_verify: Experimental. Subject to change. Bool indicating
247 whether to add NumericVerify ops into the debug mode quantized model.
248 enable_whole_model_verify: Experimental. Subject to change. Bool indicating
249 whether to verify layer by layer or the whole model. When disabled
250 (per-layer), float and quantized ops are run from the same input (the
251 output of the previous quantized layer). When enabled, float and quantized
252 ops run on the respective float and quantized outputs of the previous ops.
253 denylisted_ops: Experimental. Subject to change. Set of ops to denylist.
254 denylisted_nodes: Experimental. Subject to change. Set of nodes to denylist.
255 enable_variable_quantization: Experimental. Subject to change. Bool
256 indicating whether to enable quantization of the residual variables
257 remaining after the variable freezing pass.
259 Returns:
260 Quantized model in serialized form (e.g. a TFLITE model) with floating-point
261 inputs and outputs.
262 """
263 return wrap_toco.wrapped_experimental_mlir_quantize(
264 input_data_str,
265 disable_per_channel,
266 fully_quantize,
267 inference_type,
268 convert_tensor_tf_type_to_tflite_type(input_data_type),
269 convert_tensor_tf_type_to_tflite_type(output_data_type),
270 enable_numeric_verify,
271 enable_whole_model_verify,
272 denylisted_ops,
273 denylisted_nodes,
274 enable_variable_quantization,
275 )
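# A minimal sketch (not from the original module) of calling `mlir_quantize`,
# assuming `calibrated_model` is a hypothetical bytes object holding a TFLite
# flatbuffer that already contains calibration statistics.
def _example_full_integer_quantize(calibrated_model: bytes) -> bytes:
  return mlir_quantize(
      calibrated_model,
      fully_quantize=True,  # quantize the model inputs/outputs as well
      inference_type=_types_pb2.QUANTIZED_INT8,
      input_data_type=dtypes.int8,
      output_data_type=dtypes.int8,
  )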
278@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.SPARSIFY)
279def mlir_sparsify(input_data_str):
280 """Sparsify `input_data_str` to encode sparse tensor with proper format.
282 Args:
283 input_data_str: Input data in serialized form (e.g. a TFLITE model).
285 Returns:
286 Sparsified model in serialized form (e.g. a TFLITE model).
287 """
288 return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str)
291def register_custom_opdefs(custom_opdefs_list):
292 """Register the given custom opdefs to the TensorFlow global op registry.
294 Args:
295 custom_opdefs_list: String representing the custom ops OpDefs that are
296 included in the GraphDef.
298 Returns:
299 True if the registration is successfully completed.
300 """
301 return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)
304def convert(
305 model_flags: _model_flags_pb2.ModelFlags,
306 conversion_flags: _conversion_flags_pb2.TocoFlags,
307 input_data_str: Optional[str] = None,
308 debug_info_str: Optional[str] = None,
309 enable_mlir_converter: bool = True,
310):
311 """Converts `input_data_str` to a TFLite model.
313 Args:
314 model_flags: Proto describing model properties, see `model_flags.proto`.
315 conversion_flags: Proto describing conversion properties, see
316 `toco/toco_flags.proto`.
317 input_data_str: Input data in serialized form (e.g. a graphdef is common, or
318 it can be hlo text or proto)
319 debug_info_str: Serialized `GraphDebugInfo` proto describing logging
320 information.
321 enable_mlir_converter: Enables MLIR-based conversion.
323 Returns:
324 Converted model in serialized form (e.g. a TFLITE model is common).
325 Raises:
326 ConverterError: When conversion fails in TFLiteConverter, usually due to
327 ops not being supported.
328 RuntimeError: When conversion fails, an exception is raised with the error
329 message embedded.
330 """
331 # Historically, deprecated conversion failures would trigger a crash, so we
332 # attempt to run the converter out-of-process. The current MLIR conversion
333 # pipeline surfaces errors instead, and can be safely run in-process.
334 if enable_mlir_converter or not _deprecated_conversion_binary:
335 try:
336 return wrap_toco.wrapped_toco_convert(
337 model_flags.SerializeToString(),
338 conversion_flags.SerializeToString(),
339 input_data_str,
340 debug_info_str,
341 enable_mlir_converter,
342 )
343 except Exception as e:
344 converter_error = ConverterError(str(e))
346 for error_data in _metrics_wrapper.retrieve_collected_errors():
347 converter_error.append_error(error_data)
348 # Occasionally we encounter the case where an unsupported
349 # `StatefulPartitionedCallOp` is not inlined and remains in the final
350 # IR. If this occurs we can set `guarantee_all_funcs_one_use` and retry.
351 # This makes the converter copy the definitions of functions called by
352 # multiple StatefulPartitionedCall ops, thus allowing them to be properly
353 # inlined.
354 if (
355 error_data.error_code
356 == converter_error_data_pb2.ConverterErrorData.ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR
357 and not conversion_flags.guarantee_all_funcs_one_use
358 ):
359 conversion_flags.guarantee_all_funcs_one_use = True
360 return convert(
361 model_flags,
362 conversion_flags,
363 input_data_str,
364 debug_info_str,
365 enable_mlir_converter,
366 )
367 raise converter_error
369 return _run_deprecated_conversion_binary(
370 model_flags.SerializeToString(),
371 conversion_flags.SerializeToString(),
372 input_data_str,
373 debug_info_str,
374 )
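# A minimal sketch (not from the original module) of driving `convert` directly
# with a frozen GraphDef, assuming `graph_def` is a hypothetical
# tf.compat.v1.GraphDef that has already been frozen. A real conversion would
# also populate model_flags.input_arrays / output_arrays; most callers should
# prefer `convert_graphdef` below, which does that for them.
def _example_convert_graph_def(graph_def):
  model_flags = build_model_flags()
  conversion_flags = build_conversion_flags()
  return convert(
      model_flags,
      conversion_flags,
      input_data_str=graph_def.SerializeToString(),
      debug_info_str=None,
      enable_mlir_converter=True,
  )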
377@convert_phase(
378 Component.CONVERT_TF_TO_TFLITE_MODEL,
379 SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER,
380)
381def _run_deprecated_conversion_binary(
382 model_flags_str, conversion_flags_str, input_data_str, debug_info_str=None
383):
384 """Convert `input_data_str` using deprecated conversion binary.
386 Args:
387 model_flags_str: Serialized proto describing model properties, see
388 `model_flags.proto`.
389 conversion_flags_str: Serialized proto describing TFLite converter
390 properties, see `toco/toco_flags.proto`.
391 input_data_str: Input data in serialized form (e.g. a graphdef is common)
392 debug_info_str: Serialized `GraphDebugInfo` proto describing logging
393 information. (default None)
395 Returns:
396 Converted model in serialized form (e.g. a TFLITE model is common).
397 Raises:
398 ConverterError: When cannot find the deprecated conversion binary.
399 RuntimeError: When conversion fails, an exception is raised with the error
400 message embedded.
401 """
402 if distutils.spawn.find_executable(_deprecated_conversion_binary) is None:
403 raise ConverterError("""Could not find `toco_from_protos` binary, make sure
404your virtualenv bin directory or pip local bin directory is in your path.
405In particular, if you have installed TensorFlow with --user, make sure you
406add the install directory to your path.
408For example:
409Linux: export PATH=$PATH:~/.local/bin/
410Mac: export PATH=$PATH:~/Library/Python/<version#>/bin
412Alternatively, use virtualenv.""")
413 # Windows and TemporaryFile are not that useful together,
414 # since you cannot have two readers/writers. So we have to
415 # make the temporaries and close and delete them explicitly.
416 conversion_filename, model_filename, input_filename, output_filename = (
417 None,
418 None,
419 None,
420 None,
421 )
422 try:
423 # Build all input files
424 with _tempfile.NamedTemporaryFile(
425 delete=False
426 ) as fp_conversion, _tempfile.NamedTemporaryFile(
427 delete=False
428 ) as fp_model, _tempfile.NamedTemporaryFile(
429 delete=False
430 ) as fp_input, _tempfile.NamedTemporaryFile(
431 delete=False
432 ) as fp_debug:
433 conversion_filename = fp_conversion.name
434 input_filename = fp_input.name
435 model_filename = fp_model.name
436 debug_filename = fp_debug.name
438 fp_model.write(model_flags_str)
439 fp_conversion.write(conversion_flags_str)
440 fp_input.write(input_data_str)
441 debug_info_str = debug_info_str if debug_info_str else ""
442 # If debug_info_str is a Python str rather than bytes, the call to
443 # fp_debug.write(debug_info_str) will fail with the following error:
444 #
445 # TypeError: a bytes-like object is required, not 'str'
446 #
447 # Some of the subtests within the "convert_test" unit test fail
448 # with the error shown above, so watch out for that scenario and
449 # convert debug_info_str to bytes where needed.
450 if not isinstance(debug_info_str, bytes):
451 fp_debug.write(debug_info_str.encode("utf-8"))
452 else:
453 fp_debug.write(debug_info_str)
455 # Reserve an output file
456 with _tempfile.NamedTemporaryFile(delete=False) as fp:
457 output_filename = fp.name
459 # Run
460 cmd = [
461 _deprecated_conversion_binary,
462 model_filename,
463 conversion_filename,
464 input_filename,
465 output_filename,
466 "--debug_proto_file={}".format(debug_filename),
467 ]
468 cmdline = " ".join(cmd)
469 is_windows = _platform.system() == "Windows"
470 proc = _subprocess.Popen(
471 cmdline,
472 shell=True,
473 stdout=_subprocess.PIPE,
474 stderr=_subprocess.STDOUT,
475 close_fds=not is_windows,
476 )
477 stdout, stderr = proc.communicate()
478 exitcode = proc.returncode
479 if exitcode == 0:
480 with open(output_filename, "rb") as fp:
481 return fp.read()
482 else:
483 stdout = _try_convert_to_unicode(stdout)
484 stderr = _try_convert_to_unicode(stderr)
485 raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
486 finally:
487 # Must manually cleanup files.
488 for filename in [
489 conversion_filename,
490 input_filename,
491 model_filename,
492 output_filename,
493 ]:
494 try:
495 _os.unlink(filename)
496 except (OSError, TypeError):
497 pass
500def build_model_flags(
501 change_concat_input_ranges=False,
502 allow_nonexistent_arrays=False,
503 saved_model_dir=None,
504 saved_model_version=0,
505 saved_model_tags=None,
506 saved_model_exported_names=None,
507 **_
508):
509 """Builds the model flags object from params.
511 Args:
512 change_concat_input_ranges: Boolean to change behavior of min/max ranges for
513 inputs and outputs of the concat operator for quantized models. Changes
514 the ranges of concat operator overlap when true. (default False)
515 allow_nonexistent_arrays: Allow specifying array names that don't exist or
516 are unused in the final graph. (default False)
517 saved_model_dir: Filepath of the saved model to be converted. This value
518 will be non-empty only when the saved model import path will be used.
519 Otherwise, the GraphDef-based conversion will be used.
520 saved_model_version: SavedModel file format version of the saved model file
521 to be converted. This value will be set only when the SavedModel import
522 path will be used.
523 saved_model_tags: Set of string saved model tags, formatted in the
524 comma-separated value. This value will be set only when the SavedModel
525 import path will be used.
526 saved_model_exported_names: Names to be exported (default: export all) when
527 the saved model import path is on. This value will be set only when the
528 SavedModel import path will be used.
530 Returns:
531 model_flags: protocol buffer describing the model.
532 """
533 model_flags = _model_flags_pb2.ModelFlags()
534 model_flags.change_concat_input_ranges = change_concat_input_ranges
535 model_flags.allow_nonexistent_arrays = allow_nonexistent_arrays
536 if saved_model_dir:
537 model_flags.saved_model_dir = saved_model_dir
538 model_flags.saved_model_version = saved_model_version
539 if saved_model_tags:
540 model_flags.saved_model_tags.extend(saved_model_tags)
541 if saved_model_exported_names:
542 model_flags.saved_model_exported_names.extend(saved_model_exported_names)
543 return model_flags
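# Illustrative sketch (not from the original module) of the flags produced for
# a SavedModel conversion; "/tmp/saved_model" is a hypothetical path.
_example_model_flags = build_model_flags(
    saved_model_dir="/tmp/saved_model",
    saved_model_version=2,
    saved_model_tags=["serve"],
    saved_model_exported_names=["serving_default"],
)
assert _example_model_flags.saved_model_dir == "/tmp/saved_model"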
546def build_conversion_flags(
547 inference_type=dtypes.float32,
548 inference_input_type=None,
549 input_format=lite_constants.TENSORFLOW_GRAPHDEF,
550 output_format=lite_constants.TFLITE,
551 default_ranges_stats=None,
552 drop_control_dependency=True,
553 reorder_across_fake_quant=False,
554 allow_custom_ops=False,
555 post_training_quantize=False,
556 quantize_to_float16=False,
557 dump_graphviz_dir=None,
558 dump_graphviz_video=False,
559 target_ops=None,
560 conversion_summary_dir=None,
561 select_user_tf_ops=None,
562 allow_all_select_tf_ops=False,
563 enable_tflite_resource_variables=True,
564 unfold_batchmatmul=True,
565 lower_tensor_list_ops=True,
566 default_to_single_batch_in_tensor_list_ops=False,
567 accumulation_type=None,
568 allow_bfloat16=False,
569 unfold_large_splat_constant=False,
570 supported_backends=None,
571 disable_per_channel_quantization=False,
572 enable_mlir_dynamic_range_quantizer=False,
573 tf_quantization_mode=None,
574 disable_infer_tensor_range=False,
575 use_fake_quant_num_bits=False,
576 enable_dynamic_update_slice=False,
577 preserve_assert_op=False,
578 guarantee_all_funcs_one_use=False,
579 enable_mlir_variable_quantization=False,
580 disable_fuse_mul_and_fc=False,
581 quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None,
582 **_
583):
584 """Builds protocol buffer describing a conversion of a model.
586 Typically this is to convert from TensorFlow GraphDef to TFLite, in which
587 case the default `input_format` and `output_format` are sufficient.
589 Args:
590 inference_type: Data type of numeric arrays, excluding the input layer.
591 (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
592 inference_input_type: Data type of the numeric arrays in the input layer. If
593 `inference_input_type` is in {tf.int8, tf.uint8}, then
594 `quantized_input_stats` must be provided. (default is the value assigned
595 to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
596 input_format: Type of data to read. (default TENSORFLOW_GRAPHDEF, must be in
597 {TENSORFLOW_GRAPHDEF})
598 output_format: Output file format. (default TFLITE, must be in {TFLITE,
599 GRAPHVIZ_DOT})
600 default_ranges_stats: Tuple of integers representing (min, max) range values
601 for all arrays without a specified range. Intended for experimenting with
602 quantization via "dummy quantization". (default None)
603 drop_control_dependency: Boolean indicating whether to drop control
604 dependencies silently. This is due to TFLite not supporting control
605 dependencies. (default True)
606 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
607 nodes in unexpected locations. Used when the location of the FakeQuant
608 nodes is preventing graph transformations necessary to convert the graph.
609 Results in a graph that differs from the quantized training graph,
610 potentially causing differing arithmetic behavior. (default False)
611 allow_custom_ops: Boolean indicating whether to allow custom operations.
612 When false any unknown operation is an error. When true, custom ops are
613 created for any op that is unknown. The developer will need to provide
614 these to the TensorFlow Lite runtime with a custom resolver. (default
615 False)
616 post_training_quantize: Boolean indicating whether to quantize the weights
617 of the converted float model. Model size will be reduced and there will be
618 latency improvements (at the cost of accuracy). (default False) If
619 quantization_options is set, all quantization args will be ignored.
620 quantize_to_float16: Boolean indicating whether to convert float buffers to
621 float16. (default False)
622 dump_graphviz_dir: Full filepath of folder to dump the graphs at various
623 stages of processing GraphViz .dot files. Preferred over
624 --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
625 output file. (default None)
626 dump_graphviz_video: Boolean indicating whether to dump the graph after
627 every graph transformation. (default False)
628 target_ops: Experimental flag, subject to change. Set of OpsSet options
629 indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS]))
630 conversion_summary_dir: A string, the path to the generated conversion logs.
631 select_user_tf_ops: List of user-defined TensorFlow ops that need to be
632 supported in the TensorFlow Lite runtime. These ops will be supported as
633 select TensorFlow ops.
634 allow_all_select_tf_ops: If True, automatically add all TF ops (including
635 custom TF ops) to the converted model as flex ops.
636 enable_tflite_resource_variables: Experimental flag, subject to change.
637 Enables conversion of resource variables. (default True)
638 unfold_batchmatmul: Whether to unfold tf.BatchMatMul to a set of
639 tfl.fully_connected ops. If not, translate to tfl.batch_matmul.
640 lower_tensor_list_ops: Whether to lower tensor list ops to builtin ops. If
641 not, use Flex tensor list ops.
642 default_to_single_batch_in_tensor_list_ops: Whether to force batch size
643 one when a tensor list op has an unspecified batch size.
644 accumulation_type: Data type of the accumulators in quantized inference.
645 Typically used for float16 quantization and is either fp16 or fp32.
646 allow_bfloat16: Whether the converted model supports reduced precision
647 inference with the bfloat16 type.
648 unfold_large_splat_constant: Whether to unfold large splat constant tensors
649 in the flatbuffer model to reduce size.
650 supported_backends: List of TFLite backends against which to check
651 compatibility.
652 disable_per_channel_quantization: Disable per-channel quantized weights for
653 dynamic range quantization. Only per-tensor quantization will be used.
654 enable_mlir_dynamic_range_quantizer: Enable MLIR dynamic range quantization.
655 If False, the old converter dynamic range quantizer is used.
656 tf_quantization_mode: Indicates the mode of TF Quantization when the output
657 model is used for TF Quantization.
658 disable_infer_tensor_range: Disable inferring tensor ranges.
659 use_fake_quant_num_bits: Allow quantization parameters to be calculated from
660 num_bits attribute.
661 enable_dynamic_update_slice: Enable to convert to DynamicUpdateSlice op.
662 (default: False).
663 preserve_assert_op: Whether to preserve `TF::AssertOp` (default: False).
664 guarantee_all_funcs_one_use: Whether to clone functions so that each
665 function only has a single use. This option will be helpful if the
666 conversion fails when the `PartitionedCall` or `StatefulPartitionedCall`
667 can't be properly inlined (default: False).
668 enable_mlir_variable_quantization: Enable MLIR variable quantization. There
669 is a variable freezing pass, but some variables may not be fully frozen by
670 it. This flag enables quantization of those residual variables in the MLIR
671 graph.
672 disable_fuse_mul_and_fc: Disable fusing input multiplication with
673 fully connected operations. Useful when quantizing weights.
674 quantization_options: Config to indicate quantization options for each
675 component (e.g. weight, bias, activation). This can be a preset method or
676 a custom method, and allows finer, modular control. This option will
677 override any other existing quantization flags. We plan on gradually
678 migrating all quantization-related specs into this option.
680 Returns:
681 conversion_flags: protocol buffer describing the conversion process.
682 Raises:
683 ValueError, if the input tensor type is unknown.
684 """
685 conversion_flags = _conversion_flags_pb2.TocoFlags()
686 conversion_flags.inference_type = convert_inference_tf_type_to_tflite_type(
687 inference_type, usage="inference_type flag"
688 )
689 if inference_input_type:
690 conversion_flags.inference_input_type = (
691 convert_inference_tf_type_to_tflite_type(
692 inference_input_type, usage="inference_input_type flag"
693 )
694 )
695 else:
696 conversion_flags.inference_input_type = conversion_flags.inference_type
697 conversion_flags.input_format = input_format
698 conversion_flags.output_format = output_format
699 if default_ranges_stats:
700 conversion_flags.default_ranges_min = default_ranges_stats[0]
701 conversion_flags.default_ranges_max = default_ranges_stats[1]
702 conversion_flags.drop_control_dependency = drop_control_dependency
703 conversion_flags.reorder_across_fake_quant = reorder_across_fake_quant
704 conversion_flags.allow_custom_ops = allow_custom_ops
705 conversion_flags.post_training_quantize = post_training_quantize
706 conversion_flags.quantize_to_float16 = quantize_to_float16
707 if dump_graphviz_dir:
708 conversion_flags.dump_graphviz_dir = dump_graphviz_dir
709 conversion_flags.dump_graphviz_include_video = dump_graphviz_video
710 if target_ops:
711 if OpsSet.SELECT_TF_OPS in target_ops:
712 conversion_flags.enable_select_tf_ops = True
713 if set(target_ops) == {OpsSet.SELECT_TF_OPS}:
714 conversion_flags.force_select_tf_ops = True
715 if OpsSet.EXPERIMENTAL_STABLEHLO_OPS in target_ops:
716 conversion_flags.convert_to_stablehlo = True
717 if OpsSet.EXPERIMENTAL_STABLEHLO_OPS in target_ops and len(target_ops) > 1:
718 raise ValueError(
719 "StableHLO Ops set can not be specified with other Ops set together"
720 )
721 if conversion_summary_dir:
722 conversion_flags.conversion_summary_dir = conversion_summary_dir
723 if select_user_tf_ops:
724 conversion_flags.select_user_tf_ops.extend(select_user_tf_ops)
725 conversion_flags.allow_all_select_tf_ops = allow_all_select_tf_ops
726 conversion_flags.enable_tflite_resource_variables = (
727 enable_tflite_resource_variables
728 )
729 conversion_flags.unfold_batchmatmul = unfold_batchmatmul
730 conversion_flags.lower_tensor_list_ops = lower_tensor_list_ops
731 conversion_flags.default_to_single_batch_in_tensor_list_ops = (
732 default_to_single_batch_in_tensor_list_ops
733 )
734 if accumulation_type:
735 conversion_flags.accumulation_type = convert_tensor_tf_type_to_tflite_type(
736 accumulation_type, usage="accumulation_type flag"
737 )
738 conversion_flags.allow_bfloat16 = allow_bfloat16
739 conversion_flags.unfold_large_splat_constant = unfold_large_splat_constant
740 if supported_backends:
741 conversion_flags.supported_backends.extend(supported_backends)
742 conversion_flags.disable_per_channel_quantization = (
743 disable_per_channel_quantization
744 )
745 conversion_flags.enable_mlir_dynamic_range_quantizer = (
746 enable_mlir_dynamic_range_quantizer
747 )
748 conversion_flags.enable_dynamic_update_slice = enable_dynamic_update_slice
749 conversion_flags.preserve_assert_op = preserve_assert_op
750 conversion_flags.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use
751 if tf_quantization_mode:
752 conversion_flags.tf_quantization_mode = tf_quantization_mode
753 conversion_flags.disable_infer_tensor_range = disable_infer_tensor_range
754 conversion_flags.use_fake_quant_num_bits = use_fake_quant_num_bits
755 conversion_flags.enable_mlir_variable_quantization = (
756 enable_mlir_variable_quantization
757 )
758 conversion_flags.disable_fuse_mul_and_fc = disable_fuse_mul_and_fc
759 if quantization_options:
760 conversion_flags.quantization_options.CopyFrom(quantization_options)
761 return conversion_flags
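# Illustrative sketch (not from the original module) of a typical flag set:
# float32 inference, builtin plus select TF ops, and dynamic-range
# post-training quantization.
_example_conversion_flags = build_conversion_flags(
    inference_type=dtypes.float32,
    target_ops=set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]),
    post_training_quantize=True,
)
assert _example_conversion_flags.enable_select_tf_ops
assert not _example_conversion_flags.force_select_tf_ops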
764@convert_phase(
765 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_GRAPHDEF
766)
767def convert_graphdef_with_arrays(
768 input_data,
769 input_arrays_with_shape,
770 output_arrays,
771 control_output_arrays,
772 **kwargs
773):
774 """Convert a frozen GraphDef that can't be loaded in TF.
776 Conversion can be customized by providing arguments that are forwarded to
777 `build_model_flags` and `build_conversion_flags` (see documentation).
779 Args:
780 input_data: Input data (i.e. often `sess.graph_def`).
781 input_arrays_with_shape: List of tuples pairing an input tensor name with a
782 list of integers representing its shape (e.g., [("foo", [1, 16,
783 16, 3])]). Use only when the graph cannot be loaded into TensorFlow and
784 when `input_tensors` is None.
785 output_arrays: List of output tensors to freeze graph with. Use only when
786 graph cannot be loaded into TensorFlow and when `output_tensors` is None.
787 control_output_arrays: Control output node names. This is used when
788 converting a Graph with no output tensors. For example, if the graph's
789 last operation is a Print op, just specify that op's name in this field.
790 This can be used together with the `output_arrays` parameter.
791 **kwargs: See `build_model_flags` and `build_conversion_flags`.
793 Returns:
794 The converted data. For example if TFLite was the destination, then
795 this will be a tflite flatbuffer in a bytes array.
797 Raises:
798 Defined in `build_conversion_flags`.
799 """
800 model_flags = build_model_flags(**kwargs)
801 conversion_flags = build_conversion_flags(**kwargs)
802 enable_mlir_converter = kwargs.get("enable_mlir_converter", True)
803 quantized_input_stats = kwargs.get("quantized_input_stats", None)
805 for idx, (name, shape) in enumerate(input_arrays_with_shape):
806 input_array = model_flags.input_arrays.add()
807 if _is_quantized_input_stats_required(conversion_flags):
808 if quantized_input_stats:
809 input_array.mean_value, input_array.std_value = quantized_input_stats[
810 idx
811 ]
812 else:
813 raise ValueError(
814 "The `quantized_input_stats` flag must be defined when either "
815 "`inference_type` flag or `inference_input_type` flag is set to "
816 "tf.int8 or tf.uint8."
817 )
818 input_array.name = name
819 input_array.shape.dims.extend(list(map(int, shape)))
821 if output_arrays:
822 for name in output_arrays:
823 model_flags.output_arrays.append(name)
824 if control_output_arrays:
825 for name in control_output_arrays:
826 model_flags.control_output_arrays.append(name)
828 data = convert(
829 model_flags,
830 conversion_flags,
831 input_data.SerializeToString(),
832 debug_info_str=None,
833 enable_mlir_converter=enable_mlir_converter,
834 )
835 return data
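# A minimal sketch (not from the original module) of converting a frozen
# GraphDef that TensorFlow cannot load, assuming `frozen_graph_def` is a
# hypothetical GraphDef whose graph feeds one input named "input" of shape
# [1, 224, 224, 3] and produces a tensor named "output" (both names are
# hypothetical).
def _example_convert_unloadable_graphdef(frozen_graph_def) -> bytes:
  return convert_graphdef_with_arrays(
      frozen_graph_def,
      input_arrays_with_shape=[("input", [1, 224, 224, 3])],
      output_arrays=["output"],
      control_output_arrays=None,
  )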
838@convert_phase(
839 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_GRAPHDEF
840)
841def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs):
842 """Convert a frozen GraphDef model using the TF Lite converter.
844 Conversion can be customized by providing arguments that are forwarded to
845 `build_model_flags` and `build_conversion_flags` (see documentation).
847 Args:
848 input_data: Input data (i.e. often `sess.graph_def`).
849 input_tensors: List of input tensors. Type and shape are computed using
850 `foo.shape` and `foo.dtype`.
851 output_tensors: List of output tensors (only .name is used from this).
852 **kwargs: See `build_model_flags` and `build_conversion_flags`.
854 Returns:
855 The converted data. For example if TFLite was the destination, then
856 this will be a tflite flatbuffer in a bytes array.
858 Raises:
859 Defined in `build_conversion_flags`.
860 """
861 model_flags = build_model_flags(**kwargs)
862 conversion_flags = build_conversion_flags(**kwargs)
863 saved_model_dir = kwargs.get("saved_model_dir", None)
864 input_shapes = kwargs.get("input_shapes", None)
865 enable_mlir_converter = kwargs.get("enable_mlir_converter", True)
866 quantized_input_stats = kwargs.get("quantized_input_stats", None)
867 debug_info = kwargs.get("debug_info", None)
869 for idx, input_tensor in enumerate(input_tensors):
870 input_array = model_flags.input_arrays.add()
871 if saved_model_dir:
872 input_array.name = input_tensor.name
873 else:
874 input_array.name = util.get_tensor_name(input_tensor)
875 input_array.data_type = convert_tensor_tf_type_to_tflite_type(
876 input_tensor.dtype, usage="input type of the TensorFlow model"
877 )
879 if _is_quantized_input_stats_required(conversion_flags):
880 if quantized_input_stats:
881 input_array.mean_value, input_array.std_value = quantized_input_stats[
882 idx
883 ]
884 else:
885 # We should ideally raise an error here, but we don't as it would break
886 # several models/projects that depend on this workflow.
887 warnings.warn(
888 "Statistics for quantized inputs were expected, but not "
889 "specified; continuing anyway."
890 )
892 if input_shapes is None:
893 shape = input_tensor.shape
894 else:
895 shape = input_shapes[idx]
897 if shape.rank is not None:
898 # Create shapes with -1 for unknown dimensions.
899 dims = []
900 for dim in shape:
901 if dim is None or (
902 isinstance(dim, tensor_shape.Dimension) and dim.value is None
903 ):
904 dims.append(-1)
905 else:
906 dims.append(int(dim))
907 input_array.shape.dims.extend(dims)
908 input_array.shape.unknown_rank = False
909 else:
910 input_array.shape.unknown_rank = True
912 for output_tensor in output_tensors:
913 if saved_model_dir:
914 model_flags.output_arrays.append(output_tensor.name)
915 else:
916 model_flags.output_arrays.append(util.get_tensor_name(output_tensor))
918 data = convert(
919 model_flags,
920 conversion_flags,
921 input_data.SerializeToString(),
922 debug_info_str=debug_info.SerializeToString() if debug_info else None,
923 enable_mlir_converter=enable_mlir_converter,
924 )
925 return data
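# A minimal sketch (not from the original module) of the usual GraphDef entry
# point, assuming a hypothetical frozen TF1-style session `sess` with tensor
# handles `in_tensor` and `out_tensor`.
def _example_convert_session_graph(sess, in_tensor, out_tensor) -> bytes:
  return convert_graphdef(
      sess.graph_def,
      input_tensors=[in_tensor],
      output_tensors=[out_tensor],
      inference_type=dtypes.float32,
  )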
928@convert_phase(
929 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_SAVED_MODEL
930)
931def convert_saved_model(**kwargs):
932 """Converts a SavedModel using TF Lite converter."""
933 model_flags = build_model_flags(**kwargs)
934 conversion_flags = build_conversion_flags(**kwargs)
935 data = convert(
936 model_flags,
937 conversion_flags,
938 input_data_str=None,
939 debug_info_str=None,
940 enable_mlir_converter=True,
941 )
942 return data
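# A minimal sketch (not from the original module) of a SavedModel conversion
# driven by the kwargs consumed by `build_model_flags` and
# `build_conversion_flags`; "/tmp/saved_model" is a hypothetical path.
def _example_convert_saved_model() -> bytes:
  return convert_saved_model(
      saved_model_dir="/tmp/saved_model",
      saved_model_version=2,
      saved_model_tags=["serve"],
      saved_model_exported_names=["serving_default"],
  )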
945@convert_phase(
946 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_JAX_HLO
947)
948def convert_jax_hlo(input_content, input_names, is_proto_format, **kwargs):
949 """Converts a Jax hlo-based model using TFLite converter."""
950 model_flags = _model_flags_pb2.ModelFlags()
951 model_flags.use_hlo_import = True
952 if is_proto_format:
953 model_flags.hlo_file_type = _model_flags_pb2.ModelFlags.HLO_PROTO
954 else:
955 model_flags.hlo_file_type = _model_flags_pb2.ModelFlags.HLO_TEXT
957 # Build input names.
958 for input_name in input_names:
959 input_array = model_flags.input_arrays.add()
960 input_array.name = input_name
962 conversion_flags = build_conversion_flags(**kwargs)
963 data = convert(
964 model_flags,
965 conversion_flags,
966 input_data_str=input_content,
967 debug_info_str=None,
968 enable_mlir_converter=True,
969 )
970 return data
973@_tf_export(v1=["lite.toco_convert"])
974@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
975def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
976 """Convert a TensorFlow GraphDef to TFLite.
978 This function is deprecated. Please use `tf.lite.TFLiteConverter` API instead.
979 Conversion can be customized by providing arguments that are forwarded to
980 `build_model_flags` and `build_conversion_flags` (see documentation for
981 details).
982 Args:
983 input_data: Input data (i.e. often `sess.graph_def`).
984 input_tensors: List of input tensors. Type and shape are computed using
985 `foo.shape` and `foo.dtype`.
986 output_tensors: List of output tensors (only .name is used from this).
987 *args: See `build_model_flags` and `build_conversion_flags`.
988 **kwargs: See `build_model_flags` and `build_conversion_flags`.
990 Returns:
991 The converted TensorFlow Lite model in a bytes array.
993 Raises:
994 Defined in `convert`.
995 """
996 kwargs["enable_mlir_converter"] = kwargs.get("enable_mlir_converter", False)
997 return convert_graphdef(
998 input_data, input_tensors, output_tensors, *args, **kwargs
999 )
1002def deduplicate_readonly_buffers(tflite_model):
1003 """Generates a new model byte array after deduplicating readonly buffers.
1005 This function should be invoked after the model optimization toolkit. The
1006 model optimization toolkit assumes that each tensor object owns its
1007 buffer separately.
1009 Args:
1010 tflite_model: TFLite flatbuffer in a byte array to be deduplicated.
1012 Returns:
1013 TFLite flatbuffer in a bytes array, processed with the deduplication method.
1014 """
1015 # Load TFLite Flatbuffer byte array into an object.
1016 model = flatbuffer_utils.convert_bytearray_to_object(tflite_model)
1018 # Get all the read-only buffers, which can be modified without causing any
1019 # issue in the graph invocation stage.
1020 read_only_buffer_indices = set()
1021 for subgraph in model.subgraphs:
1022 # To get all the read-only buffers:
1023 # (1) Get all read-only input tensors.
1024 # (2) Discard intermediate or output tensors.
1025 # (3) Discard the subgraph's input/output tensors.
1026 # (4) Gather the buffers of the read-only input tensors.
1028 # (1) Get read-only input tensors.
1029 read_only_input_tensor_indices = set()
1030 for op in subgraph.operators:
1031 if op.inputs is None:
1032 continue
1033 for i, input_tensor_idx in enumerate(op.inputs):
1034 # Ignore mutable tensors.
1035 if op.mutatingVariableInputs is not None:
1036 # Ignore invalid tensors.
1037 if (
1038 i < len(op.mutatingVariableInputs)
1039 and op.mutatingVariableInputs[i]
1040 ):
1041 continue
1042 # Ignore variable tensors.
1043 if subgraph.tensors[input_tensor_idx].isVariable:
1044 continue
1045 read_only_input_tensor_indices.add(input_tensor_idx)
1047 # (2) Discard intermediate or output tensors.
1048 for op in subgraph.operators:
1049 if op.outputs is not None:
1050 for output_tensor_idx in op.outputs:
1051 read_only_input_tensor_indices.discard(output_tensor_idx)
1052 if op.intermediates is not None:
1053 for intermediate_tensor_idx in op.intermediates:
1054 read_only_input_tensor_indices.discard(intermediate_tensor_idx)
1056 # (3) Discard the subgraph's input and output tensors.
1057 if subgraph.inputs is not None:
1058 for input_tensor_idx in subgraph.inputs:
1059 read_only_input_tensor_indices.discard(input_tensor_idx)
1060 if subgraph.outputs is not None:
1061 for output_tensor_idx in subgraph.outputs:
1062 read_only_input_tensor_indices.discard(output_tensor_idx)
1064 # (4) Gather the buffers of the read-only input tensors.
1065 for tensor_idx in read_only_input_tensor_indices:
1066 read_only_buffer_indices.add(subgraph.tensors[tensor_idx].buffer)
1068 # Ignore invalid negative index or zero-sized buffers.
1069 for buffer_idx in read_only_buffer_indices.copy():
1070 if buffer_idx < 0 or (
1071 model.buffers[buffer_idx].data is None
1072 or isinstance(model.buffers[buffer_idx].data, list)
1073 or model.buffers[buffer_idx].data.size == 0
1074 ):
1075 read_only_buffer_indices.discard(buffer_idx)
1077 class BufferIndex:
1078 """A class to store index, size, hash of the buffers in TFLite model."""
1080 def __init__(self, idx, size, hash_value):
1081 self.idx = idx
1082 self.size = size
1083 self.hash_value = hash_value
1085 read_only_buffers = list(
1086 map(
1087 lambda index: BufferIndex( # pylint: disable=g-long-lambda
1088 index,
1089 model.buffers[index].data.size,
1090 hashlib.md5(model.buffers[index].data.data.tobytes()).hexdigest(),
1091 ),
1092 read_only_buffer_indices,
1093 )
1094 )
1096 # Sort read_only_buffers by buffer size & hash in descending order.
1097 read_only_buffers = sorted(
1098 read_only_buffers,
1099 key=lambda buffer: (buffer.size, buffer.hash_value),
1100 reverse=True,
1101 )
1103 # Create a map of duplicate buffers (same size and same content hash).
1104 # E.g. in [1, 2, 3, 4, 5, 6], if (1, 4, 6) and (2, 5) are groups of buffer
1105 # indices with the same size and hash, then the map would be {4: 1, 6: 1, 5: 2}.
1106 duplicate_buffer_map = {}
1107 for i, buffer_i in enumerate(read_only_buffers):
1108 # This buffer is a duplicate.
1109 if buffer_i.idx in duplicate_buffer_map:
1110 continue
1111 # This buffer is unique. Scan rest of the list to find duplicates
1112 # of this buffer and mark them accordingly.
1113 for buffer_j in read_only_buffers[i + 1 :]:
1114 if buffer_j.idx in duplicate_buffer_map:
1115 continue
1116 if buffer_i.size != buffer_j.size:
1117 break
1118 if buffer_i.hash_value != buffer_j.hash_value:
1119 continue
1120 # Found duplicate. Nullify j-th buffer and use i-th buffer instead.
1121 duplicate_buffer_map[buffer_j.idx] = buffer_i.idx
1123 # Make the duplicated tensors use the single shared buffer index.
1124 for subgraph in model.subgraphs:
1125 for op in subgraph.operators:
1126 if op.inputs is None:
1127 continue
1128 for input_tensor in op.inputs:
1129 buffer_idx = subgraph.tensors[input_tensor].buffer
1130 if buffer_idx in duplicate_buffer_map:
1131 subgraph.tensors[input_tensor].buffer = duplicate_buffer_map[
1132 buffer_idx
1133 ]
1135 # Nullify the unused buffers.
1136 for idx in duplicate_buffer_map:
1137 model.buffers[idx].data = None
1139 # Return a TFLite flatbuffer as a byte array.
1140 return flatbuffer_utils.convert_object_to_bytearray(model)
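# A minimal sketch (not from the original module) of where buffer
# deduplication fits: it runs on a finished TFLite flatbuffer (for example
# after quantization), assuming `tflite_model` is a hypothetical flatbuffer
# held in a bytes object.
def _example_postprocess(tflite_model: bytes) -> bytes:
  deduped = deduplicate_readonly_buffers(tflite_model)
  return bytes(deduped)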