Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/lite.py: 24%
1014 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Lite tooling helper functionality."""
17import enum
18import functools
19import pprint
20import shutil
21import sys
22import tempfile
23import time
24import warnings
26from absl import logging
28from google.protobuf import text_format as _text_format
29from google.protobuf.message import DecodeError
30from tensorflow.core.framework import graph_pb2 as _graph_pb2
31from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import
32from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metdata_fb
33from tensorflow.lite.python import lite_constants as constants
34from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef
35from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays
36from tensorflow.lite.python.convert import convert_jax_hlo as _convert_jax_hlo
37from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
38from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
39from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers
40from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
41from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
42from tensorflow.lite.python.convert import OpsSet
43from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import
44from tensorflow.lite.python.convert_phase import Component
45from tensorflow.lite.python.convert_phase import convert_phase
46from tensorflow.lite.python.convert_phase import SubComponent
47from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
48from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
49from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
50from tensorflow.lite.python.interpreter import OpResolverType # pylint: disable=unused-import
51from tensorflow.lite.python.metrics import metrics
52from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
53from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
54from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
55from tensorflow.lite.python.optimize import calibrator as _calibrator
56from tensorflow.lite.python.util import _xla_computation
57from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
58from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
59from tensorflow.lite.python.util import freeze_graph as _freeze_graph
60from tensorflow.lite.python.util import get_debug_info as _get_debug_info
61from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
62from tensorflow.lite.python.util import get_sparsity_modes as _get_sparsity_modes
63from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
64from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
65from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
66from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
67from tensorflow.lite.python.util import model_input_signature as _model_input_signature
68from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
69from tensorflow.lite.python.util import populate_conversion_metadata as _populate_conversion_metadata
70from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
71from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
72from tensorflow.lite.python.util import trace_model_call as _trace_model_call
73from tensorflow.lite.tools import flatbuffer_utils
74from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugger # pylint: disable=unused-import
75from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugOptions # pylint: disable=unused-import
76from tensorflow.python import saved_model as _saved_model
77from tensorflow.python.client import session as _session
78from tensorflow.python.eager import context
79from tensorflow.python.eager import def_function as _def_function
80from tensorflow.python.eager import function as _function
81from tensorflow.python.framework import byte_swap_tensor as bst
82from tensorflow.python.framework import convert_to_constants as _convert_to_constants
83from tensorflow.python.framework import dtypes as _dtypes
84from tensorflow.python.framework import ops as _ops
85from tensorflow.python.framework import versions
86from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
87from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
88from tensorflow.python.platform import gfile
89from tensorflow.python.saved_model import loader_impl as _loader_impl
90from tensorflow.python.saved_model import save_options as _save_options
91from tensorflow.python.saved_model import signature_constants as _signature_constants
92from tensorflow.python.saved_model import tag_constants as _tag_constants
93from tensorflow.python.saved_model.load import load as _load
94from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
95from tensorflow.python.util import deprecation as _deprecation
96from tensorflow.python.util import keras_deps
97from tensorflow.python.util.tf_export import tf_export as _tf_export
100@_tf_export("lite.Optimize")
101class Optimize(enum.Enum):
102 """Enum defining the optimizations to apply when generating a tflite model.
104 DEFAULT
105 The default optimization strategy that enables post-training quantization.
106 The type of post-training quantization that will be used is dependent on
107 the other converter options supplied. Refer to the
108 [documentation](/lite/performance/post_training_quantization) for further
109 information on the types available and how to use them.
111 OPTIMIZE_FOR_SIZE
112 Deprecated. Does the same as DEFAULT.
114 OPTIMIZE_FOR_LATENCY
115 Deprecated. Does the same as DEFAULT.
117 EXPERIMENTAL_SPARSITY
118 Experimental flag, subject to change.
120 Enable optimization by taking advantage of the sparse model weights
121 trained with pruning.
123 The converter will inspect the sparsity pattern of the model weights and
124 do its best to improve size and latency.
125 The flag can be used alone to optimize float32 models with sparse weights.
126 It can also be used together with the DEFAULT optimization mode to
127 optimize quantized models with sparse weights.
128 """
130 # Default optimization strategy that quantizes model weights. Enhanced
131 # optimizations are gained by providing a representative dataset that
132 # quantizes biases and activations as well.
133 # Converter will do its best to reduce size and latency, while minimizing
134 # the loss in accuracy.
135 DEFAULT = "DEFAULT"
137 # Deprecated. Does the same as DEFAULT.
138 OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
140 # Deprecated. Does the same as DEFAULT.
141 OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
143 # Experimental flag, subject to change.
144 # Enable optimization by taking advantage of the sparse model weights trained
145 # with pruning.
146 #
147 # The converter will inspect the sparsity pattern of the model weights and do
148 # its best to improve size and latency.
149 # The flag can be used alone to optimize float32 models with sparse weights.
150 # It can also be used together with the DEFAULT optimization mode to optimize
151 # quantized models with sparse weights.
152 # TODO(b/161560631): Add log message when this optimization is applied.
153 EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"
155 def __str__(self):
156 return str(self.value)
159# TODO(b/198099651): move converter implementation out of lite.py
160@_tf_export("lite.RepresentativeDataset")
161class RepresentativeDataset:
162 """Representative dataset used to optimize the model.
164 This is a generator function that provides a small dataset to calibrate or
165 estimate the range, i.e, (min, max) of all floating-point arrays in the model
166 (such as model input, activation outputs of intermediate layers, and model
167 output) for quantization. Usually, this is a small subset of a few hundred
168 samples randomly chosen, in no particular order, from the training or
169 evaluation dataset.
170 """
172 def __init__(self, input_gen):
173 """Creates a representative dataset.
175 Args:
176 input_gen: A generator function that generates input samples for the model
177 and has the same order, type and shape as the inputs to the model.
178 Usually, this is a small subset of a few hundred samples randomly
179 chosen, in no particular order, from the training or evaluation dataset.
180 """
181 self.input_gen = input_gen
184@_tf_export("lite.TargetSpec")
185class TargetSpec:
186 """Specification of target device used to optimize the model.
188 Attributes:
189 supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
190 options, where each option represents a set of operators supported by the
191 target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS}))
192 supported_types: Set of `tf.dtypes.DType` data types supported on the target
193 device. If initialized, optimization might be driven by the smallest type
194 in this set. (default set())
195 experimental_select_user_tf_ops: Experimental flag, subject to change. Set
196 of user's TensorFlow operators' names that are required in the TensorFlow
197 Lite runtime. These ops will be exported as select TensorFlow ops in the
198 model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
199 an advanced feature that should only be used if the client is using TF ops
200 that may not be linked in by default with the TF ops that are provided
201 when using the SELECT_TF_OPS path. The client is responsible for linking
202 these ops into the target runtime.
203 experimental_supported_backends: Experimental flag, subject to change. Set
204 containing names of supported backends. Currently only "GPU" is supported,
205 more options will be available later.
206 """
208 def __init__(
209 self,
210 supported_ops=None,
211 supported_types=None,
212 experimental_select_user_tf_ops=None,
213 experimental_supported_backends=None,
214 ):
215 if supported_ops is None:
216 supported_ops = {OpsSet.TFLITE_BUILTINS}
217 self.supported_ops = supported_ops
218 if supported_types is None:
219 supported_types = set()
220 self.supported_types = supported_types
221 if experimental_select_user_tf_ops is None:
222 experimental_select_user_tf_ops = set()
223 self.experimental_select_user_tf_ops = experimental_select_user_tf_ops
224 self.experimental_supported_backends = experimental_supported_backends
225 self._experimental_custom_op_registerers = []
226 # Hint for the supported accumulation type used for inference. Typically
227 # used for fp16 post-training quantization, where some models can use fp16
228 # accumulators instead of the typical fp32 type.
229 # TODO(b/188185962): Provide full API and authoring support for
230 # reduced precision accumulation types.
231 self._experimental_supported_accumulation_type = None
234class QuantizationMode:
235 """QuantizationMode determines the quantization type from user options."""
237 def __init__(
238 self,
239 optimizations,
240 target_spec,
241 representative_dataset,
242 graph_def,
243 disable_per_channel=False,
244 experimental_new_dynamic_range_quantizer=False,
245 experimental_low_bit_qat=False,
246 full_integer_quantization_bias_type=None,
247 experimental_mlir_variable_quantization=False,
248 ):
249 self._optimizations = optimizations
250 for deprecated_optimization in [
251 Optimize.OPTIMIZE_FOR_SIZE,
252 Optimize.OPTIMIZE_FOR_LATENCY,
253 ]:
254 if deprecated_optimization in self._optimizations:
255 logging.warning(
256 (
257 "Optimization option %s is deprecated, please use"
258 " optimizations=[Optimize.DEFAULT] instead."
259 ),
260 deprecated_optimization,
261 )
263 self._target_spec = target_spec
264 self._representative_dataset = representative_dataset
265 self._graph_def = graph_def
267 self._validate_int8_required()
268 self._disable_per_channel = disable_per_channel
270 self._enable_new_dynamic_range_quantizer = (
271 experimental_new_dynamic_range_quantizer
272 )
273 # Allow training with lower than 8 bit weights to be converted
274 # to constants with trained scale.
275 self._experimental_low_bit_qat = experimental_low_bit_qat
277 self._full_integer_quantization_bias_type = (
278 full_integer_quantization_bias_type
279 )
280 self._validate_full_integer_quantization_bias_type()
282 self.enable_mlir_variable_quantization = (
283 experimental_mlir_variable_quantization
284 )
286 def is_post_training_int8_only_quantization(self):
287 return (
288 self.is_any_optimization_enabled()
289 and self._representative_dataset is not None
290 and not self._is_int16x8_target_required()
291 and not self.is_allow_float()
292 and self._is_int8_target_required()
293 )
295 def is_post_training_int8_quantization_with_float_fallback(self):
296 return (
297 self.is_any_optimization_enabled()
298 and self._representative_dataset is not None
299 and not self._is_int16x8_target_required()
300 and self.is_allow_float()
301 and self._smallest_supported_type() == _dtypes.int8
302 )
304 def is_post_training_int8_quantization(self):
305 return (
306 self.is_post_training_int8_only_quantization()
307 or self.is_post_training_int8_quantization_with_float_fallback()
308 )
310 def is_post_training_int16x8_only_quantization(self):
311 return (
312 self.is_any_optimization_enabled()
313 and self._representative_dataset is not None
314 and self._is_int16x8_target_required()
315 and not self.is_allow_float()
316 )
318 def is_post_training_int16x8_quantization_with_float_fallback(self):
319 return (
320 self.is_any_optimization_enabled()
321 and self._representative_dataset is not None
322 and self._is_int16x8_target_required()
323 and self.is_allow_float()
324 )
326 def is_post_training_int16x8_quantization(self):
327 return (
328 self.is_post_training_int16x8_only_quantization()
329 or self.is_post_training_int16x8_quantization_with_float_fallback()
330 )
332 def is_post_training_integer_quantization(self):
333 return (
334 self.is_post_training_int8_quantization()
335 or self.is_post_training_int16x8_quantization()
336 )
338 def is_low_bit_quantize_aware_training(self):
339 return (
340 self.is_any_optimization_enabled()
341 and self.is_quantization_aware_trained_model()
342 and self._experimental_low_bit_qat
343 )
345 def is_quantization_aware_training(self):
346 return (
347 self.is_any_optimization_enabled()
348 and self.is_quantization_aware_trained_model()
349 and not self.is_low_bit_quantize_aware_training()
350 )
352 def is_integer_quantization(self):
353 return (
354 self.is_post_training_integer_quantization()
355 or self.is_quantization_aware_training()
356 or self.is_low_bit_quantize_aware_training()
357 )
359 def is_post_training_dynamic_range_quantization(self):
360 # Post-training dynamic range quantization is only enabled if post-training
361 # int8 quantization and training time quantization was not done.
362 return (
363 self.is_any_optimization_enabled()
364 and self._representative_dataset is None
365 and not self.is_quantization_aware_trained_model()
366 and self._smallest_supported_type() == _dtypes.int8
367 )
369 def is_post_training_float16_quantization(self):
370 return (
371 self.is_any_optimization_enabled()
372 and self._smallest_supported_type().size == 2
373 and _dtypes.float16 in self._target_spec.supported_types
374 )
376 def is_bfloat16_quantization(self):
377 return (
378 self.is_any_optimization_enabled()
379 and self._smallest_supported_type().size == 2
380 and _dtypes.bfloat16 in self._target_spec.supported_types
381 )
383 def activations_type(self):
384 if self.is_integer_quantization():
385 if self._is_int16x8_target_required():
386 return _dtypes.int16
387 else:
388 return _dtypes.int8
389 else:
390 return _dtypes.float32
392 def bias_type(self):
393 if self._full_integer_quantization_bias_type:
394 return self._full_integer_quantization_bias_type
396 if self.activations_type() == _dtypes.int16:
397 return _dtypes.int64
398 elif self.activations_type() == _dtypes.int8:
399 return _dtypes.int32
400 else:
401 return _dtypes.float32
403 def converter_flags(self, inference_ty=None, inference_input_ty=None):
404 """Flags to the converter."""
406 if self.is_integer_quantization():
407 is_low_bit_qat = self.is_low_bit_quantize_aware_training()
408 return {
409 "inference_type": (
410 inference_ty
411 if inference_ty is not None
412 else self.activations_type()
413 ),
414 "inference_input_type": _dtypes.float32,
415 "post_training_quantize": False, # disable dynamic range quantization
416 "quantize_to_float16": False, # disable float16 quantization
417 "disable_infer_tensor_range": is_low_bit_qat,
418 "use_fake_quant_num_bits": is_low_bit_qat,
419 "enable_mlir_variable_quantization": (
420 self.enable_mlir_variable_quantization
421 ),
422 }
423 elif self.is_post_training_dynamic_range_quantization():
424 return {
425 "inference_type": _dtypes.float32,
426 "inference_input_type": _dtypes.float32,
427 "post_training_quantize": True, # enable dynamic range quantization
428 "quantize_to_float16": False, # disable float16 quantization
429 # experimental: disable per-channel (per-axis) quantization.
430 "disable_per_channel_quantization": self._disable_per_channel,
431 "enable_mlir_dynamic_range_quantizer": (
432 self._enable_new_dynamic_range_quantizer
433 ),
434 "enable_mlir_variable_quantization": (
435 self.enable_mlir_variable_quantization
436 ),
437 }
438 elif self.is_post_training_float16_quantization():
439 return {
440 "inference_type": _dtypes.float32,
441 "inference_input_type": _dtypes.float32,
442 "post_training_quantize": True,
443 "quantize_to_float16": True, # enable float16 quantization
444 # pylint: disable=protected-access
445 "accumulation_type": (
446 self._target_spec._experimental_supported_accumulation_type
447 ),
448 # pylint: enable=protected-access
449 "allow_bfloat16": self.is_bfloat16_quantization(),
450 "enable_mlir_dynamic_range_quantizer": (
451 self._enable_new_dynamic_range_quantizer
452 ),
453 "enable_mlir_variable_quantization": (
454 self.enable_mlir_variable_quantization
455 ),
456 }
457 else:
458 # Note this might still trigger (uint8) quantization to be compatible with
459 # the old converter.
460 return {
461 "inference_type": (
462 inference_ty if inference_ty is not None else _dtypes.float32
463 ),
464 "inference_input_type": inference_input_ty,
465 "post_training_quantize": False, # enable dynamic range quantization
466 "quantize_to_float16": False, # disable float16 quantization
467 "allow_bfloat16": self.is_bfloat16_quantization(),
468 }
470 # Below are helpers for the above functions.
472 def _validate_int8_required(self):
473 """Int8 mode requires certain parameters to exist and be compatible."""
474 if not self._is_int8_target_required():
475 return
477 # Validate target_spec attibute.
478 if set(self._target_spec.supported_ops) == {
479 OpsSet.TFLITE_BUILTINS_INT8
480 } and not (
481 set(self._target_spec.supported_types) == set()
482 or set(self._target_spec.supported_types) == {_dtypes.int8}
483 ):
484 raise ValueError(
485 "As full integer quantization has been enabled by setting "
486 "`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, "
487 "thus `target_spec.supported_types` should be left uninitizalized "
488 "or set to {tf.int8}."
489 )
490 if set(self._target_spec.supported_types) == {_dtypes.int8}:
491 self._target_spec.supported_ops = {OpsSet.TFLITE_BUILTINS_INT8}
493 # Check if representative_dataset is specified.
494 if (
495 not self._representative_dataset
496 and not self.is_quantization_aware_training()
497 ):
498 raise ValueError(
499 "For full integer quantization, a "
500 "`representative_dataset` must be specified."
501 )
503 # Update represenative dataset to the expected format.
504 if self._representative_dataset:
505 if not isinstance(self._representative_dataset, RepresentativeDataset):
506 self._representative_dataset = RepresentativeDataset(
507 self._representative_dataset
508 )
510 def _validate_full_integer_quantization_bias_type(self):
511 """Validates bias type for full interger quantization."""
512 bias_type = self._full_integer_quantization_bias_type
513 if not bias_type:
514 return
516 if self.activations_type() == _dtypes.float32:
517 raise ValueError(
518 "`full_integer_quantization_bias_type` is only supported for full"
519 " integer quantization."
520 )
522 if self.activations_type() == _dtypes.int8 and bias_type != _dtypes.int32:
523 raise ValueError(
524 "Expected bias type to be `dtypes.int32` for Int8Quant. "
525 f"Current setting bias type: {bias_type}"
526 )
528 if (
529 self.activations_type() == _dtypes.int16
530 and bias_type != _dtypes.int32
531 and bias_type != _dtypes.int64
532 ):
533 raise ValueError(
534 "Expected bias type to be `dtypes.int32` or `dtypes.int64` for "
535 f"Int16Quant. Current setting bias type: {bias_type}"
536 )
538 def _is_int8_target_required(self):
539 return (
540 OpsSet.TFLITE_BUILTINS_INT8 in set(self._target_spec.supported_ops)
541 ) or (set(self._target_spec.supported_types) == set([_dtypes.int8]))
543 def _is_int16x8_target_required(self):
544 return (
545 OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
546 in set(self._target_spec.supported_ops)
547 )
549 def is_allow_float(self):
550 return (OpsSet.TFLITE_BUILTINS in set(self._target_spec.supported_ops)) or (
551 OpsSet.SELECT_TF_OPS in set(self._target_spec.supported_ops)
552 )
554 def is_any_optimization_enabled(self):
555 return bool(
556 set(self._optimizations).intersection([
557 Optimize.OPTIMIZE_FOR_LATENCY,
558 Optimize.OPTIMIZE_FOR_SIZE,
559 Optimize.DEFAULT,
560 ])
561 )
563 def _smallest_supported_type(self):
564 if self._target_spec.supported_types:
565 return min(self._target_spec.supported_types, key=lambda x: x.size)
566 else:
567 # The default smallest supported type is INT8.
568 return _dtypes.int8
570 def is_quantization_aware_trained_model(self):
571 """Checks if the graph contains any training-time quantization ops."""
572 training_quant_ops = frozenset({
573 "FakeQuantWithMinMaxVars",
574 "FakeQuantWithMinMaxVarsPerChannel",
575 "FakeQuantWithMinMaxArgs",
576 "QuantizeAndDequantizeV2",
577 "QuantizeAndDequantizeV3",
578 })
580 if self._graph_def:
581 for node_def in self._graph_def.node:
582 if node_def.op in training_quant_ops:
583 return True
584 for function in self._graph_def.library.function:
585 for node_def in function.node_def:
586 if node_def.op in training_quant_ops:
587 return True
588 return False
591class TFLiteConverterBase:
592 """Converter subclass to share functionality between V1 and V2 converters."""
594 # Stores the original model type temporarily to transmit the information
595 # from the factory class methods to TFLiteConverterBase init function.
596 _original_model_type = conversion_metdata_fb.ModelType.NONE
598 def __init__(self):
599 self.optimizations = set()
600 self.representative_dataset = None
601 self.target_spec = TargetSpec()
602 self.allow_custom_ops = False
603 self.experimental_new_converter = True
604 self.experimental_new_quantizer = True
605 self.experimental_enable_resource_variables = True
606 self._experimental_calibrate_only = False
607 self._experimental_sparsify_model = False
608 self._experimental_disable_per_channel = False
609 self._debug_info = None # contains the stack traces of all the original
610 # nodes in the `GraphDef` to the converter.
611 self.saved_model_dir = None
612 self._saved_model_tags = None
613 self._saved_model_version = 0
614 self._saved_model_exported_names = []
615 self._tflite_metrics = metrics.TFLiteConverterMetrics()
616 self._collected_converter_params = {}
617 self._experimental_disable_batchmatmul_unfold = False
618 self._experimental_lower_tensor_list_ops = True
619 self._experimental_default_to_single_batch_in_tensor_list_ops = False
620 self._experimental_unfold_large_splat_constant = False
621 self._experimental_tf_quantization_mode = None
622 # If unset, bias:int32 is by default except 16x8 quant.
623 # For 16x8 quant, bias:int64 is used to prevent any overflow by default.
624 self._experimental_full_integer_quantization_bias_type = None
625 # Provides specs for quantization, whether preset or custom.
626 self._experimental_quantization_options = None
627 # Initializes conversion metadata.
628 self.exclude_conversion_metadata = False
629 self._metadata = conversion_metdata_fb.ConversionMetadataT()
630 self._metadata.environment = conversion_metdata_fb.EnvironmentT()
631 self._metadata.options = conversion_metdata_fb.ConversionOptionsT()
632 self._metadata.environment.tensorflowVersion = versions.__version__
633 self._metadata.environment.modelType = self._get_original_model_type()
634 self._experimental_enable_dynamic_update_slice = False
635 self._experimental_preserve_assert_op = False
636 self._experimental_guarantee_all_funcs_one_use = False
638 # When the value is true, the MLIR quantantizer triggers dynamic range
639 # quantization in MLIR instead of the old quantizer. Used only if
640 # experimental_new_quantizer is on.
641 self.experimental_new_dynamic_range_quantizer = True
642 # Experimental flag to enable low-bit QAT in 8 bit.
643 self._experimental_low_bit_qat = False
644 # Experimental flag to add all TF ops (including custom TF ops) to the
645 # converted model as flex ops.
646 self._experimental_allow_all_select_tf_ops = False
648 self._experimental_variable_quantization = False
649 self._experimental_disable_fuse_mul_and_fc = False
651 def _grappler_config(self, optimizers=None):
652 """Creates a tf.compat.v1.ConfigProto for configuring Grappler.
654 Args:
655 optimizers: List of strings that represents the list of optimizers.
657 Returns:
658 tf.ConfigProto.
659 """
660 if not optimizers:
661 optimizers = []
662 # MLIR converter will take care of constant folding instead of grappler.
663 if not self.experimental_new_converter:
664 optimizers.append("constfold")
666 is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set(
667 self.target_spec.supported_ops
668 )
669 if is_only_flex_enabled:
670 # The layout optimizer turns NHCW to NCHW. This provides performance
671 # optimizations when Flex mode is enabled. However, this is not compatible
672 # with builtin ops.
673 optimizers.append("layout")
674 return _get_grappler_config(optimizers)
676 def _quantize(
677 self,
678 result,
679 input_type,
680 output_type,
681 activations_type,
682 bias_type,
683 allow_float,
684 enable_variable_quantization,
685 ):
686 """Quantize the model."""
687 # pylint: disable=protected-access
688 custom_op_registerers_by_name = [
689 x
690 for x in self.target_spec._experimental_custom_op_registerers
691 if isinstance(x, str)
692 ]
693 custom_op_registerers_by_func = [
694 x
695 for x in self.target_spec._experimental_custom_op_registerers
696 if not isinstance(x, str)
697 ]
698 # pylint: enable=protected-access
699 if not isinstance(self.representative_dataset, RepresentativeDataset):
700 self.representative_dataset = RepresentativeDataset(
701 self.representative_dataset
702 )
704 # Add intermediate tensors to the model if needed.
705 result = _calibrator.add_intermediate_tensors(result)
706 calibrate_quantize = _calibrator.Calibrator(
707 result, custom_op_registerers_by_name, custom_op_registerers_by_func
708 )
709 if self._experimental_calibrate_only or self.experimental_new_quantizer:
710 calibrated = calibrate_quantize.calibrate(
711 self.representative_dataset.input_gen
712 )
714 if self._experimental_calibrate_only:
715 return calibrated
716 elif self.experimental_new_quantizer and (
717 activations_type != _dtypes.int16
718 ):
719 # TODO(b/175659372): remove the activations_type restriction and enable
720 # it for all the activation types.
721 return _mlir_quantize(
722 calibrated,
723 self._experimental_disable_per_channel,
724 input_data_type=input_type,
725 output_data_type=output_type,
726 enable_variable_quantization=enable_variable_quantization,
727 )
728 else:
729 return calibrate_quantize.calibrate_and_quantize(
730 self.representative_dataset.input_gen,
731 input_type,
732 output_type,
733 allow_float,
734 activations_type,
735 bias_type,
736 disable_per_channel=self._experimental_disable_per_channel,
737 )
739 def _is_unknown_shapes_allowed(self):
740 # Unknown dimensions are only allowed with the new converter.
741 return self.experimental_new_converter
743 def _get_base_converter_args(self):
744 """Returns the base converter args.
746 Returns:
747 {key str: val}
748 """
749 args = {
750 "input_format": constants.TENSORFLOW_GRAPHDEF,
751 "allow_custom_ops": self.allow_custom_ops,
752 "debug_info": self._debug_info,
753 "target_ops": self.target_spec.supported_ops,
754 "enable_mlir_converter": self.experimental_new_converter,
755 "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops,
756 "supported_backends": self.target_spec.experimental_supported_backends,
757 "unfold_batchmatmul": not self._experimental_disable_batchmatmul_unfold,
758 "lower_tensor_list_ops": self._experimental_lower_tensor_list_ops,
759 "unfold_large_splat_constant": (
760 self._experimental_unfold_large_splat_constant
761 ),
762 "default_to_single_batch_in_tensor_list_ops": (
763 self._experimental_default_to_single_batch_in_tensor_list_ops
764 ),
765 "tf_quantization_mode": self._experimental_tf_quantization_mode,
766 "experimental_enable_resource_variables": (
767 self.experimental_enable_resource_variables
768 ),
769 "enable_dynamic_update_slice": (
770 self._experimental_enable_dynamic_update_slice
771 ),
772 "preserve_assert_op": self._experimental_preserve_assert_op,
773 "guarantee_all_funcs_one_use": (
774 self._experimental_guarantee_all_funcs_one_use
775 ),
776 "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops,
777 "disable_fuse_mul_and_fc": self._experimental_disable_fuse_mul_and_fc,
778 "quantization_options": self._experimental_quantization_options,
779 }
781 if self.saved_model_dir:
782 args.update({
783 "saved_model_dir": self.saved_model_dir,
784 "saved_model_version": self._saved_model_version,
785 "saved_model_tags": self._saved_model_tags,
786 "saved_model_exported_names": self._saved_model_exported_names,
787 })
789 return args
791 def _contains_function_with_implements_attr(self, saved_model_proto):
792 meta_graph = saved_model_proto.meta_graphs[0]
793 for function in meta_graph.graph_def.library.function:
794 if function.attr.get("_implements", None) or function.attr.get(
795 "api_implements", None
796 ):
797 return True
798 return False
800 def _parse_saved_model_args(self, always_enable_saved_model_import=False):
801 """Parses SavedModel arguments from the given Keras/RNN SavedModel.
803 Args:
804 always_enable_saved_model_import: Bool. When the value is true, it enables
805 MLIR saved model import path regardless of checking the conditions.
806 """
807 if not self.experimental_new_converter:
808 self.saved_model_dir = None
809 return
810 if self.saved_model_dir:
811 try:
812 saved_model_proto, _ = _parse_saved_model_with_debug_info(
813 self.saved_model_dir
814 )
815 except OSError:
816 # If it fails to read the given saved model, it will fall back to the
817 # frozen graph def path.
818 self.saved_model_dir = None
819 return
820 if (
821 not always_enable_saved_model_import
822 and not self._contains_function_with_implements_attr(
823 saved_model_proto
824 )
825 ):
826 self.saved_model_dir = None
827 return
829 if not self._saved_model_exported_names:
830 self._saved_model_exported_names = []
831 self._saved_model_version = saved_model_proto.saved_model_schema_version
832 if self._saved_model_version == 0:
833 self.saved_model_dir = None
834 logging.warning("SavedModel schema version is zero.")
835 return
836 if self._saved_model_version not in [1, 2]:
837 raise ValueError(
838 "SavedModel file format({0}) is not supported".format(
839 self._saved_model_version
840 )
841 )
843 def _sparsify_model(self):
844 return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations
846 def _increase_conversion_attempt_metric(self):
847 self._tflite_metrics.increase_counter_converter_attempt()
849 def _increase_conversion_success_metric(self):
850 self._tflite_metrics.increase_counter_converter_success()
852 @classmethod
853 def _set_original_model_type(cls, model_type):
854 """Stores the original model type."""
855 if model_type == conversion_metdata_fb.ModelType.NONE:
856 raise ValueError("The original model type should be specified.")
857 cls._original_model_type = model_type
859 def _get_original_model_type(self):
860 """One-time getter to return original model type and set it to NONE."""
861 model_type = TFLiteConverterBase._original_model_type
862 TFLiteConverterBase._original_model_type = (
863 conversion_metdata_fb.ModelType.NONE
864 )
865 return model_type
867 def _save_conversion_params_metric(
868 self, graph_def=None, inference_type=None, inference_input_type=None
869 ):
870 """Set conversion parameter metrics."""
871 converter_kwargs = self._collected_converter_params
872 converter_kwargs.update(self._get_base_converter_args())
874 # Optimization parameters.
875 quant_mode = QuantizationMode(
876 self.optimizations,
877 self.target_spec,
878 self.representative_dataset,
879 graph_def,
880 self._experimental_disable_per_channel,
881 self.experimental_new_dynamic_range_quantizer,
882 self._experimental_low_bit_qat,
883 self._experimental_full_integer_quantization_bias_type,
884 self._experimental_variable_quantization,
885 )
886 converter_kwargs.update({
887 "tf_version": self._metadata.environment.tensorflowVersion,
888 "api_version": self._metadata.environment.apiVersion,
889 "original_model_format": self._metadata.environment.modelType,
890 "optimization_default": quant_mode.is_any_optimization_enabled(),
891 "optimization_post_training_dynamic_range": (
892 quant_mode.is_post_training_dynamic_range_quantization()
893 ),
894 "optimization_post_training_float16": (
895 quant_mode.is_post_training_float16_quantization()
896 ),
897 "optimization_post_training_integer_quantize": (
898 quant_mode.is_post_training_integer_quantization()
899 ),
900 "optimization_qat": quant_mode.is_quantization_aware_training(),
901 "optimization_low_bit_qat": (
902 quant_mode.is_low_bit_quantize_aware_training()
903 ),
904 "optimization_sparsify": self._sparsify_model(),
905 "activations_type": quant_mode.activations_type(),
906 })
907 converter_kwargs.update(
908 quant_mode.converter_flags(inference_type, inference_input_type)
909 )
911 # pylint: disable=protected-access
912 if self.target_spec._experimental_supported_accumulation_type:
913 converter_kwargs.update(
914 {
915 "accumulation_type": (
916 self.target_spec._experimental_supported_accumulation_type
917 )
918 }
919 )
920 # pylint: enable=protected-access
922 def format_element(elem):
923 if isinstance(elem, enum.Enum):
924 return str(elem.value)
925 return pprint.pformat(elem)
927 def format_param(param):
928 if isinstance(param, (list, tuple, set)):
929 if not param:
930 return "None" # Return None if empty.
931 string_list = [format_element(x) for x in param]
932 return ",".join(sorted(string_list))
933 return format_element(param)
935 for key, value in converter_kwargs.items():
936 self._tflite_metrics.set_converter_param(key, format_param(value))
937 self._tflite_metrics.set_export_required()
939 # Set conversion option metadata.
940 self._metadata.options.allowCustomOps = self.allow_custom_ops
941 self._metadata.options.enableSelectTfOps = (
942 OpsSet.SELECT_TF_OPS in self.target_spec.supported_ops
943 )
944 self._metadata.options.forceSelectTfOps = set(
945 [OpsSet.SELECT_TF_OPS]
946 ) == set(self.target_spec.supported_ops)
947 self._metadata.options.modelOptimizationModes = []
949 if quant_mode.is_post_training_float16_quantization():
950 self._metadata.options.modelOptimizationModes.append(
951 conversion_metdata_fb.ModelOptimizationMode.PTQ_FLOAT16
952 )
954 if quant_mode.is_post_training_dynamic_range_quantization():
955 self._metadata.options.modelOptimizationModes.append(
956 conversion_metdata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE
957 )
959 if quant_mode.is_post_training_int8_quantization():
960 self._metadata.options.modelOptimizationModes.append(
961 conversion_metdata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER
962 )
964 if quant_mode.is_post_training_int16x8_quantization():
965 self._metadata.options.modelOptimizationModes.append(
966 conversion_metdata_fb.ModelOptimizationMode.PTQ_INT16
967 )
969 if quant_mode.is_quantization_aware_training():
970 self._metadata.options.modelOptimizationModes.append(
971 conversion_metdata_fb.ModelOptimizationMode.QUANTIZATION_AWARE_TRAINING
972 )
974 def _set_conversion_latency_metric(self, value):
975 self._tflite_metrics.set_converter_latency(value)
977 @convert_phase(Component.OPTIMIZE_TFLITE_MODEL)
978 def _optimize_tflite_model(self, model, quant_mode, quant_io=True):
979 """Apply optimizations on a TFLite model."""
981 if quant_mode.is_integer_quantization():
982 in_type, out_type = self.inference_input_type, self.inference_output_type
984 if quant_mode.is_post_training_integer_quantization():
985 q_in_type = in_type if in_type and quant_io else _dtypes.float32
986 q_out_type = out_type if out_type and quant_io else _dtypes.float32
987 q_activations_type = quant_mode.activations_type()
988 q_bias_type = quant_mode.bias_type()
989 q_allow_float = quant_mode.is_allow_float()
990 q_variable_quantization = quant_mode.enable_mlir_variable_quantization
991 model = self._quantize(
992 model,
993 q_in_type,
994 q_out_type,
995 q_activations_type,
996 q_bias_type,
997 q_allow_float,
998 q_variable_quantization,
999 )
1001 m_in_type = in_type if in_type else _dtypes.float32
1002 m_out_type = out_type if out_type else _dtypes.float32
1003 # Skip updating model io types if MLIR quantizer already takes care of it
1004 if not (
1005 quant_mode.is_post_training_integer_quantization()
1006 and self.experimental_new_quantizer
1007 and quant_io
1008 and (m_in_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32])
1009 and (m_out_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32])
1010 ):
1011 model = _modify_model_io_type(model, m_in_type, m_out_type)
1013 if self._sparsify_model():
1014 model = _mlir_sparsify(model)
1016 try:
1017 model = _deduplicate_readonly_buffers(model)
1018 except Exception: # pylint: disable=broad-except
1019 # Skip buffer deduplication when flatbuffer library is not ready to be
1020 # utilized.
1021 logging.warning(
1022 "Buffer deduplication procedure will be skipped when flatbuffer "
1023 "library is not properly loaded"
1024 )
1026 return model
1028 def _convert_and_export_metrics(self, convert_func, *args, **kwargs):
1029 """Wraps around convert function to export metrics.
1031 Args:
1032 convert_func: The convert function to wrap.
1033 *args: Positional arguments of the convert function.
1034 **kwargs: The keyword arguments of the convert function.
1036 Returns:
1037 The decorator to wrap the convert function.
1038 """
1039 self._increase_conversion_attempt_metric()
1040 self._save_conversion_params_metric()
1041 start_time = time.process_time()
1042 result = convert_func(self, *args, **kwargs)
1043 elapsed_time_ms = (time.process_time() - start_time) * 1000
1044 if result:
1045 self._increase_conversion_success_metric()
1046 self._set_conversion_latency_metric(round(elapsed_time_ms))
1047 self._tflite_metrics.export_metrics()
1048 if self.exclude_conversion_metadata:
1049 return result
1050 model_object = flatbuffer_utils.convert_bytearray_to_object(result)
1051 # Populates the conversion metadata.
1052 # TODO(b/202090541): Collects sparsity block size information.
1053 sparsity_modes = _get_sparsity_modes(model_object)
1054 self._metadata.options.modelOptimizationModes.extend(sparsity_modes)
1055 model_object = _populate_conversion_metadata(model_object, self._metadata)
1056 return flatbuffer_utils.convert_object_to_bytearray(model_object)
1059def _export_metrics(convert_func):
1060 """The decorator around convert function to export metrics."""
1062 @functools.wraps(convert_func)
1063 def wrapper(self, *args, **kwargs):
1064 # pylint: disable=protected-access
1065 return self._convert_and_export_metrics(convert_func, *args, **kwargs)
1066 # pylint: enable=protected-access
1068 return wrapper
1071class TFLiteConverterBaseV2(TFLiteConverterBase):
1072 """Converter subclass to share functionality between V2 converters."""
1074 def __init__(self):
1075 """Constructor for TFLiteConverter."""
1076 super(TFLiteConverterBaseV2, self).__init__()
1077 self.inference_input_type = _dtypes.float32
1078 self.inference_output_type = _dtypes.float32
1079 self._metadata.environment.apiVersion = 2
1081 def _validate_inference_input_output_types(self, quant_mode):
1082 """Validate inference_input_type and inference_output_type flags."""
1083 default_types = [_dtypes.float32]
1084 # We support integer input/output for integer quantized models only.
1085 if quant_mode.is_integer_quantization():
1086 if quant_mode.is_post_training_int16x8_quantization():
1087 all_types = default_types + [_dtypes.int16]
1088 else:
1089 all_types = default_types + [_dtypes.int8, _dtypes.uint8]
1090 if (
1091 self.inference_input_type not in all_types
1092 or self.inference_output_type not in all_types
1093 ):
1094 all_types_names = ["tf." + t.name for t in all_types]
1095 raise ValueError(
1096 "The inference_input_type and inference_output_type "
1097 "must be in {}.".format(all_types_names)
1098 )
1099 elif (
1100 self.inference_input_type not in default_types
1101 or self.inference_output_type not in default_types
1102 ):
1103 raise ValueError(
1104 "The inference_input_type and inference_output_type "
1105 "must be tf.float32."
1106 )
1108 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL)
1109 def _load_saved_model(self, saved_model_dir, saved_model_tags):
1110 """Load graph_def from saved model with the default serving signature key.
1112 Args:
1113 saved_model_dir: Directory of the SavedModel.
1114 saved_model_tags: Set of tags identifying the MetaGraphDef within the
1115 SavedModel to analyze.
1117 Returns:
1118 graph_def: The loaded GraphDef.
1119 input_tensors: List of input tensors.
1120 output_tensors: List of output tensors.
1121 """
1122 graph = _ops.Graph()
1123 saved_model = _loader_impl.SavedModelLoader(saved_model_dir)
1124 saved_model.load_graph(graph, tags=saved_model_tags)
1125 meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags)
1126 graph_def = meta_graph.graph_def
1127 signature_def = meta_graph.signature_def[
1128 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1129 ]
1130 input_tensors = [
1131 graph.get_tensor_by_name(signature_def.inputs[key].name)
1132 for key in signature_def.inputs
1133 ]
1134 output_tensors = [
1135 graph.get_tensor_by_name(signature_def.outputs[key].name)
1136 for key in signature_def.outputs
1137 ]
1138 return graph_def, input_tensors, output_tensors
1140 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
1141 def _validate_inputs(self, graph_def, input_tensors):
1142 """Validate the input parameters.
1144 Args:
1145 graph_def: The TensorFlow GraphDef.
1146 input_tensors: List of input tensors.
1147 Raise:
1148 ValueError: Input shape is not specified. Invalid quantization parameters.
1149 """
1150 # Update conversion params with graph_def.
1151 self._save_conversion_params_metric(graph_def)
1152 self._quant_mode = QuantizationMode(
1153 self.optimizations,
1154 self.target_spec,
1155 self.representative_dataset,
1156 graph_def,
1157 self._experimental_disable_per_channel,
1158 self.experimental_new_dynamic_range_quantizer,
1159 self._experimental_low_bit_qat,
1160 self._experimental_full_integer_quantization_bias_type,
1161 self._experimental_variable_quantization,
1162 )
1163 self._validate_inference_input_output_types(self._quant_mode)
1165 if not self._is_unknown_shapes_allowed():
1166 # Checks dimensions in input tensor.
1167 for tensor in input_tensors:
1168 # Note that shape_list might be empty for scalar shapes.
1169 shape_list = tensor.shape.as_list()
1170 if None in shape_list[1:]:
1171 raise ValueError(
1172 "None is only supported in the 1st dimension. Tensor '{0}' has "
1173 "invalid shape '{1}'.".format(
1174 _get_tensor_name(tensor), shape_list
1175 )
1176 )
1177 elif shape_list and shape_list[0] is None:
1178 # Set the batch size to 1 if undefined.
1179 shape = tensor.shape.as_list()
1180 shape[0] = 1
1181 tensor.set_shape(shape)
1183 if self._trackable_obj is None or not hasattr(
1184 self._trackable_obj, "graph_debug_info"
1185 ):
1186 self._debug_info = _get_debug_info(
1187 _build_debug_info_func(self._funcs[0].graph), graph_def
1188 )
1189 else:
1190 self._debug_info = _get_debug_info(
1191 _convert_debug_info_func(self._trackable_obj.graph_debug_info),
1192 graph_def,
1193 )
1195 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
1196 def _optimize_tf_model(
1197 self, graph_def, input_tensors, output_tensors, frozen_func
1198 ):
1199 """Run a Grappler pass to optimize the TensorFlow graph.
1201 Args:
1202 graph_def: Frozen GraphDef to be optimized.
1203 input_tensors: List of input tensors.
1204 output_tensors: List of output tensors.
1205 frozen_func: TensorFlow Graph.
1207 Returns:
1208 The optimized TensorFlow graph.
1209 """
1210 grappler_config = self._grappler_config()
1211 # Skip running grappler when there are no optimizers to run. If not,
1212 # grappler will run with the default optimizer set and it will lead to
1213 # causing an unexpected behavior.
1214 if grappler_config.graph_options.rewrite_options.optimizers:
1215 graph_def = _run_graph_optimizations(
1216 graph_def,
1217 input_tensors,
1218 output_tensors,
1219 config=grappler_config,
1220 graph=frozen_func.graph,
1221 )
1222 return graph_def
1224 def _convert_from_saved_model(self, graph_def):
1225 """Helper method that converts saved model.
1227 Args:
1228 graph_def: GraphDef object for the model, used only for stats.
1230 Returns:
1231 The converted TFLite model.
1232 """
1233 # Update conversion params with graph_def.
1234 self._save_conversion_params_metric(graph_def)
1235 # Get quantization options and do some sanity checks.
1236 quant_mode = QuantizationMode(
1237 self.optimizations,
1238 self.target_spec,
1239 self.representative_dataset,
1240 graph_def,
1241 self._experimental_disable_per_channel,
1242 self.experimental_new_dynamic_range_quantizer,
1243 self._experimental_low_bit_qat,
1244 self._experimental_full_integer_quantization_bias_type,
1245 self._experimental_variable_quantization,
1246 )
1247 self._validate_inference_input_output_types(quant_mode)
1248 converter_kwargs = {
1249 "enable_tflite_resource_variables": (
1250 self.experimental_enable_resource_variables
1251 )
1252 }
1253 converter_kwargs.update(self._get_base_converter_args())
1254 converter_kwargs.update(quant_mode.converter_flags())
1256 result = _convert_saved_model(**converter_kwargs)
1257 return self._optimize_tflite_model(
1258 result, quant_mode, quant_io=self.experimental_new_quantizer
1259 )
1261 def convert(self, graph_def, input_tensors, output_tensors):
1262 """Converts a TensorFlow GraphDef based on instance variables.
1264 Args:
1265 graph_def: Frozen TensorFlow GraphDef.
1266 input_tensors: List of input tensors.
1267 output_tensors: List of output tensors.
1269 Returns:
1270 The converted data in serialized format.
1272 Raises:
1273 ValueError:
1274 No concrete functions is specified.
1275 Multiple concrete functions are specified.
1276 Input shape is not specified.
1277 Invalid quantization parameters.
1278 """
1279 self._validate_inputs(graph_def, input_tensors)
1280 converter_kwargs = self._get_base_converter_args()
1281 converter_kwargs.update(self._quant_mode.converter_flags())
1282 if not self.experimental_new_converter:
1283 logging.warning(
1284 "Please consider switching to the new converter by setting "
1285 "experimental_new_converter=True. "
1286 "The old converter is deprecated."
1287 )
1288 else:
1289 logging.info(
1290 "Using new converter: If you encounter a problem "
1291 "please file a bug. You can opt-out "
1292 "by setting experimental_new_converter=False"
1293 )
1295 # Converts model.
1296 result = _convert_graphdef(
1297 input_data=graph_def,
1298 input_tensors=input_tensors,
1299 output_tensors=output_tensors,
1300 **converter_kwargs,
1301 )
1303 return self._optimize_tflite_model(
1304 result, self._quant_mode, quant_io=self.experimental_new_quantizer
1305 )
1308class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
1309 """Converts the given SavedModel into TensorFlow Lite model.
1311 Attributes:
1312 saved_model_dir: Directory of the SavedModel.
1313 """
1315 def __init__(
1316 self,
1317 saved_model_dir,
1318 saved_model_tags=None,
1319 saved_model_exported_names=None,
1320 trackable_obj=None,
1321 ):
1322 """Constructor for TFLiteConverter.
1324 Args:
1325 saved_model_dir: Directory of the SavedModel.
1326 saved_model_tags: Set of tags identifying the MetaGraphDef within the
1327 SavedModel to analyze. All tags in the tag set must be present. (default
1328 {tf.saved_model.SERVING}).
1329 saved_model_exported_names: Names to be exported when the saved model
1330 import path is on.
1331 trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1332 reference to this object needs to be maintained so that Variables do not
1333 get garbage collected since functions have a weak reference to
1334 Variables. This is only required when the tf.AutoTrackable object is not
1335 maintained by the user (e.g. `from_saved_model`).
1336 """
1337 super(TFLiteSavedModelConverterV2, self).__init__()
1338 self.saved_model_dir = saved_model_dir
1339 self._saved_model_tags = saved_model_tags
1340 self._saved_model_exported_names = saved_model_exported_names
1341 self._trackable_obj = trackable_obj
1342 self._parse_saved_model_args(always_enable_saved_model_import=True)
1344 @_export_metrics
1345 def convert(self):
1346 """Converts a TensorFlow GraphDef based on instance variables.
1348 Returns:
1349 The converted data in serialized format.
1351 Raises:
1352 ValueError:
1353 No concrete functions is specified.
1354 Multiple concrete functions are specified.
1355 Input shape is not specified.
1356 Invalid quantization parameters.
1357 """
1358 graph_def, input_tensors, output_tensors = self._load_saved_model(
1359 self.saved_model_dir, self._saved_model_tags
1360 )
1361 # If we can't use saved model importer, then fallback
1362 # to frozen graph conversion path.
1363 if self.saved_model_dir is None or not self.experimental_new_converter:
1364 graph_def, _, _, _ = _freeze_saved_model(
1365 self.saved_model_dir,
1366 None,
1367 None,
1368 None,
1369 self._saved_model_tags,
1370 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
1371 )
1372 # We make sure to clear the saved_model_dir as there is some
1373 # legacy code down in the caller that checks this.
1374 # TODO(b/162537905): Clean these indirect dependencies.
1375 self.saved_model_dir = None
1376 return super(TFLiteSavedModelConverterV2, self).convert(
1377 graph_def, input_tensors, output_tensors
1378 )
1380 if self._trackable_obj is None:
1381 self._debug_info = _get_debug_info(
1382 _build_debug_info_func(self._funcs[0].graph), graph_def
1383 )
1384 else:
1385 self._debug_info = _get_debug_info(
1386 _convert_debug_info_func(self._trackable_obj.graph_debug_info),
1387 graph_def,
1388 )
1390 return self._convert_from_saved_model(graph_def)
1393class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
1394 """Converts the given Keras model into TensorFlow Lite model."""
1396 def __init__(self, keras_model, trackable_obj=None):
1397 """Constructor for TFLiteConverter.
1399 Args:
1400 keras_model: tf.Keras.Model.
1401 trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1402 reference to this object needs to be maintained so that Variables do not
1403 get garbage collected since functions have a weak reference to
1404 Variables. This is only required when the tf.AutoTrackable object is not
1405 maintained by the user (e.g. `from_saved_model`).
1406 """
1407 super(TFLiteKerasModelConverterV2, self).__init__()
1408 self._keras_model = keras_model
1409 self._trackable_obj = trackable_obj
1410 self.experimental_lower_to_saved_model = True
1412 @convert_phase(
1413 Component.PREPARE_TF_MODEL, SubComponent.CONVERT_KERAS_TO_SAVED_MODEL
1414 )
1415 def _convert_keras_to_saved_model(self, output_dir):
1416 """Save Keras model to the SavedModel format.
1418 Args:
1419 output_dir: The output directory to save the SavedModel.
1421 Returns:
1422 graph_def: The frozen GraphDef.
1423 input_tensors: List of input tensors.
1424 output_tensors: List of output tensors.
1425 """
1426 try:
1427 _saved_model.save(
1428 self._keras_model,
1429 output_dir,
1430 options=_save_options.SaveOptions(save_debug_info=True),
1431 )
1432 except Exception: # pylint: disable=broad-except
1433 # When storing the given keras model to a saved model is failed, let's
1434 # use original keras model conversion pipeline.
1435 return None, None, None
1436 self.saved_model_dir = output_dir
1437 self._saved_model_tags = set([_tag_constants.SERVING])
1438 self._saved_model_exported_names = [
1439 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1440 ]
1441 self._parse_saved_model_args(
1442 always_enable_saved_model_import=self.experimental_lower_to_saved_model
1443 )
1444 if self.saved_model_dir:
1445 graph_def, input_tensors, output_tensors = self._load_saved_model(
1446 self.saved_model_dir, self._saved_model_tags
1447 )
1448 self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
1449 return graph_def, input_tensors, output_tensors
1450 return None, None, None
1452 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
1453 def _freeze_keras_model(self):
1454 """Freeze Keras model to frozen graph.
1456 Returns:
1457 graph_def: The frozen GraphDef.
1458 input_tensors: List of input tensors.
1459 output_tensors: List of output tensors.
1460 frozen_func: The frozen ConcreteFunction.
1461 """
1462 input_signature = None
1463 # If the model's call is not a `tf.function`, then we need to first get its
1464 # input signature from `model_input_signature` method. We can't directly
1465 # call `trace_model_call` because otherwise the batch dimension is set
1466 # to None.
1467 # Once we have better support for dynamic shapes, we can remove this.
1468 if not isinstance(self._keras_model.call, _def_function.Function):
1469 # Pass `keep_original_batch_size=True` will ensure that we get an input
1470 # signature including the batch dimension specified by the user.
1471 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
1472 input_signature = _model_input_signature(
1473 self._keras_model, keep_original_batch_size=True
1474 )
1476 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
1477 func = _trace_model_call(self._keras_model, input_signature)
1478 concrete_func = func.get_concrete_function()
1479 self._funcs = [concrete_func]
1481 frozen_func, graph_def = (
1482 _convert_to_constants.convert_variables_to_constants_v2_as_graph(
1483 self._funcs[0], lower_control_flow=False
1484 )
1485 )
1487 input_tensors = [
1488 tensor
1489 for tensor in frozen_func.inputs
1490 if tensor.dtype != _dtypes.resource
1491 ]
1492 output_tensors = frozen_func.outputs
1493 return graph_def, input_tensors, output_tensors, frozen_func
1495 def _convert_as_saved_model(self):
1496 """Converts a Keras model as a saved model.
1498 Returns:
1499 The converted data in serialized format.
1500 """
1501 temp_dir = tempfile.mkdtemp()
1502 try:
1503 graph_def, input_tensors, output_tensors = (
1504 self._convert_keras_to_saved_model(temp_dir)
1505 )
1506 if self.saved_model_dir:
1507 return super(TFLiteKerasModelConverterV2, self).convert(
1508 graph_def, input_tensors, output_tensors
1509 )
1510 finally:
1511 shutil.rmtree(temp_dir, True)
1513 @_export_metrics
1514 def convert(self):
1515 """Converts a keras model based on instance variables.
1517 Returns:
1518 The converted data in serialized format.
1520 Raises:
1521 ValueError:
1522 Multiple concrete functions are specified.
1523 Input shape is not specified.
1524 Invalid quantization parameters.
1525 """
1526 saved_model_convert_result = self._convert_as_saved_model()
1527 if saved_model_convert_result:
1528 return saved_model_convert_result
1530 graph_def, input_tensors, output_tensors, frozen_func = (
1531 self._freeze_keras_model()
1532 )
1534 graph_def = self._optimize_tf_model(
1535 graph_def, input_tensors, output_tensors, frozen_func
1536 )
1538 return super(TFLiteKerasModelConverterV2, self).convert(
1539 graph_def, input_tensors, output_tensors
1540 )
1543class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
1544 """Converts the given frozen graph into TensorFlow Lite model."""
1546 def __init__(self, funcs, trackable_obj=None):
1547 """Constructor for TFLiteConverter.
1549 Args:
1550 funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1551 duplicate elements.
1552 trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1553 reference to this object needs to be maintained so that Variables do not
1554 get garbage collected since functions have a weak reference to
1555 Variables. This is only required when the tf.AutoTrackable object is not
1556 maintained by the user (e.g. `from_saved_model`).
1557 """
1558 super(TFLiteFrozenGraphConverterV2, self).__init__()
1559 self._funcs = funcs
1560 self._trackable_obj = trackable_obj
1561 self.experimental_lower_to_saved_model = True
1563 @convert_phase(
1564 Component.PREPARE_TF_MODEL, SubComponent.FREEZE_CONCRETE_FUNCTION
1565 )
1566 def _freeze_concrete_function(self):
1567 """Convert the given ConcreteFunction to frozen graph.
1569 Returns:
1570 graph_def: The frozen GraphDef.
1571 input_tensors: List of input tensors.
1572 output_tensors: List of output tensors.
1573 frozen_func: The frozen ConcreteFunction.
1575 Raises:
1576 ValueError: none or multiple ConcreteFunctions provided.
1577 """
1578 # TODO(b/130297984): Add support for converting multiple function.
1580 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test
1581 raise ValueError("No ConcreteFunction is specified.")
1583 if len(self._funcs) > 1:
1584 raise ValueError(
1585 "This converter can only convert a single "
1586 "ConcreteFunction. Converting multiple functions is "
1587 "under development."
1588 )
1590 frozen_func, graph_def = (
1591 _convert_to_constants.convert_variables_to_constants_v2_as_graph(
1592 self._funcs[0], lower_control_flow=False
1593 )
1594 )
1596 input_tensors = [
1597 tensor
1598 for tensor in frozen_func.inputs
1599 if tensor.dtype != _dtypes.resource
1600 ]
1601 output_tensors = frozen_func.outputs
1602 return graph_def, input_tensors, output_tensors, frozen_func
1604 @convert_phase(
1605 Component.PREPARE_TF_MODEL,
1606 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL,
1607 )
1608 def _convert_concrete_functions_to_saved_model(self, output_dir):
1609 """Save concrete functions to the SavedModel format.
1611 Args:
1612 output_dir: The output directory to save the SavedModel.
1614 Returns:
1615 graph_def: The frozen GraphDef.
1616 input_tensors: List of input tensors.
1617 output_tensors: List of output tensors.
1618 """
1619 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test
1620 raise ValueError("No ConcreteFunction is specified.")
1622 if not self.experimental_lower_to_saved_model:
1623 return None, None, None
1625 # Without the provided trackable obj, it is not able to serialize the given
1626 # concrete functions as a saved model format. Also when trackable obj is
1627 # a function, use the original concrete function conversion pipline.
1628 if not self._trackable_obj or isinstance(
1629 self._trackable_obj,
1630 (_function.ConcreteFunction, _def_function.Function),
1631 ):
1632 return None, None, None
1634 signatures = {}
1635 signature_keys = []
1636 try:
1637 if len(self._funcs) == 1:
1638 signatures[_signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
1639 self._funcs[0]
1640 )
1641 signature_keys = [
1642 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1643 ]
1644 else:
1645 for func in self._funcs:
1646 signatures[func.graph.name] = func
1647 signature_keys.append(func.graph.name)
1649 _saved_model.save(
1650 self._trackable_obj,
1651 output_dir,
1652 signatures=signatures,
1653 options=_save_options.SaveOptions(save_debug_info=True),
1654 )
1655 except Exception: # pylint: disable=broad-except
1656 # When storing the given concrete function to a saved model is failed,
1657 # let's use original concrete function conversion pipeline.
1658 return None, None, None
1660 self.saved_model_dir = output_dir
1661 self._saved_model_tags = set([_tag_constants.SERVING])
1662 self._saved_model_exported_names = signature_keys
1663 self._parse_saved_model_args(always_enable_saved_model_import=True)
1664 if self.saved_model_dir:
1665 graph_def, input_tensors, output_tensors = self._load_saved_model(
1666 self.saved_model_dir, self._saved_model_tags
1667 )
1668 self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
1669 return graph_def, input_tensors, output_tensors
1670 return None, None, None
1672 def _convert_as_saved_model(self):
1673 """Converts the given concrete functions as a saved model format.
1675 Returns:
1676 The converted data in serialized format.
1677 """
1678 temp_dir = tempfile.mkdtemp()
1679 try:
1680 graph_def, input_tensors, _ = (
1681 self._convert_concrete_functions_to_saved_model(temp_dir)
1682 )
1683 if self.saved_model_dir:
1684 self._validate_inputs(graph_def, input_tensors)
1685 return self._convert_from_saved_model(graph_def)
1686 finally:
1687 shutil.rmtree(temp_dir, True)
1688 return None
1690 @_export_metrics
1691 def convert(self):
1692 """Converts a TensorFlow GraphDef based on instance variables.
1694 Returns:
1695 The converted data in serialized format.
1697 Raises:
1698 ValueError:
1699 No concrete functions is specified.
1700 Multiple concrete functions are specified.
1701 Input shape is not specified.
1702 Invalid quantization parameters.
1703 """
1704 if self.experimental_lower_to_saved_model:
1705 saved_model_convert_result = self._convert_as_saved_model()
1706 if saved_model_convert_result:
1707 return saved_model_convert_result
1709 graph_def, input_tensors, output_tensors, frozen_func = (
1710 self._freeze_concrete_function()
1711 )
1713 graph_def = self._optimize_tf_model(
1714 graph_def, input_tensors, output_tensors, frozen_func
1715 )
1717 return super(TFLiteFrozenGraphConverterV2, self).convert(
1718 graph_def, input_tensors, output_tensors
1719 )
1722class TFLiteJaxConverterV2(TFLiteConverterBaseV2):
1723 """Converts the given jax model into TensorFlow Lite model."""
1725 def __init__(self, serving_funcs, inputs):
1726 """Constructor for TFLiteConverter.
1728 Args:
1729 serving_funcs: A list functions of the serving func of the jax module, the
1730 model params should already be inlined. (e.g., `serving_func =
1731 functools.partial(model, params=params)`)
1732 inputs: Array of input tensor placeholders tuple,s like `jnp.zeros`. For
1733 example, wrapped in an array like "[('input1', input1), ('input2',
1734 input2)]]".
1736 Jax functions are polymorphic, for example:
1738 ```python
1739 def add(a, b):
1740 return a + b
1741 ```
1743 Will yield different computations if different input signatures are passed
1744 in: Pass `add(10.0, 20.0)` will yield a scalar `add` while pass
1745 `add(np.random((100, 1)), np.random(100, 100))` will yield a broadcasting
1746 add. We will need the input information to do tracing for the converter
1747 to properly convert the model. So it's important to pass in the desired
1748 `input placeholders` with the correct input shape/type.
1750 In the converted tflite model, the function name will be default to "main",
1751 the output names will be the traced outputs. The output ordering shall
1752 match the serving function.
1753 """ # fmt: skip
1755 super(TFLiteJaxConverterV2, self).__init__()
1756 self._serving_funcs = serving_funcs
1757 self._inputs = inputs
1759 @_export_metrics
1760 def convert(self):
1761 """Converts a Jax serving func based on instance variables.
1763 Returns:
1764 The converted data in serialized format.
1766 Raises:
1767 ImportError:
1768 If cannot import the xla_computation from jax.
1769 ValueError:
1770 No serving function is specified.
1771 Input tensors are not specified.
1772 The truth value of an array with more than one element is ambiguous.
1773 Failed to convert the given Jax function to hlo.
1774 """
1775 if not _xla_computation:
1776 raise ImportError("Cannot import xla_computation from jax.")
1778 if not self._serving_funcs:
1779 raise ValueError("No serving func is specified.")
1781 if not self._inputs:
1782 raise ValueError("Input tensors are not specified.")
1784 if len(self._inputs) != len(self._serving_funcs):
1785 msg = (
1786 "Input tensor mapping len {} does not match serving func len {}."
1787 .format(len(self._inputs), len(self._serving_funcs))
1788 )
1789 raise ValueError(msg)
1791 if not isinstance(self._inputs, (tuple, list)):
1792 raise ValueError(
1793 "Input tensors should be pass in a tuple list wrapped in an array."
1794 )
1796 # TODO(b/197690428): Support multiple functions.
1797 # Currently only support one serving function.
1798 if len(self._serving_funcs) > 1:
1799 raise ValueError("Currently only support single serving function.")
1801 if not isinstance(self._inputs[0], (tuple, list)):
1802 raise ValueError("The input placeholders are not a dictionary.")
1804 input_names = []
1805 ordered_inputs = []
1806 for input_name, tensor in self._inputs[0]:
1807 input_names.append(input_name)
1808 ordered_inputs.append(tensor)
1810 try:
1811 xla_compuation = _xla_computation(self._serving_funcs[0], backend="cpu")
1812 hlo_proto = xla_compuation(
1813 *ordered_inputs
1814 ).as_serialized_hlo_module_proto()
1815 except Exception: # pylint: disable=broad-except
1816 raise ValueError("Failed to convert the given Jax function to hlo.")
1818 # We need to set the hlo proto, and here we use serialized proto format
1819 # since it's more compact.
1820 converter_kwargs = {
1821 "input_content": hlo_proto,
1822 "input_names": input_names,
1823 "is_proto_format": True,
1824 }
1825 converter_kwargs.update(self._get_base_converter_args())
1827 # Get quantization options and do some checks.
1828 quant_mode = QuantizationMode(
1829 self.optimizations, self.target_spec, self.representative_dataset, None
1830 )
1831 self._validate_inference_input_output_types(quant_mode)
1832 converter_kwargs.update(quant_mode.converter_flags())
1833 result = _convert_jax_hlo(**converter_kwargs)
1835 return self._optimize_tflite_model(
1836 result, quant_mode, quant_io=self.experimental_new_quantizer
1837 )
1840@_tf_export("lite.TFLiteConverter", v1=[])
1841class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
1842 """Converts a TensorFlow model into TensorFlow Lite model.
1844 Attributes:
1845 optimizations: Experimental flag, subject to change. Set of optimizations to
1846 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
1847 set of values of type `tf.lite.Optimize`)
1848 representative_dataset: A generator function used for integer quantization
1849 where each generated sample has the same order, type and shape as the
1850 inputs to the model. Usually, this is a small subset of a few hundred
1851 samples randomly chosen, in no particular order, from the training or
1852 evaluation dataset. This is an optional attribute, but required for full
1853 integer quantization, i.e, if `tf.int8` is the only supported type in
1854 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
1855 (default None)
1856 target_spec: Experimental flag, subject to change. Specifications of target
1857 device, including supported ops set, supported types and a set of user's
1858 defined TensorFlow operators required in the TensorFlow Lite runtime.
1859 Refer to `tf.lite.TargetSpec`.
1860 inference_input_type: Data type of the input layer. Note that integer types
1861 (tf.int8 and tf.uint8) are currently only supported for post training
1862 integer quantization and quantization aware training. (default tf.float32,
1863 must be in {tf.float32, tf.int8, tf.uint8})
1864 inference_output_type: Data type of the output layer. Note that integer
1865 types (tf.int8 and tf.uint8) are currently only supported for post
1866 training integer quantization and quantization aware training. (default
1867 tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
1868 allow_custom_ops: Boolean indicating whether to allow custom operations.
1869 When False, any unknown operation is an error. When True, custom ops are
1870 created for any op that is unknown. The developer needs to provide these
1871 to the TensorFlow Lite runtime with a custom resolver. (default False)
1872 exclude_conversion_metadata: Whether not to embed the conversion metadata
1873 into the converted model. (default False)
1874 experimental_new_converter: Experimental flag, subject to change. Enables
1875 MLIR-based conversion. (default True)
1876 experimental_new_quantizer: Experimental flag, subject to change. Enables
1877 MLIR-based quantization conversion instead of Flatbuffer-based conversion.
1878 (default True)
1879 experimental_enable_resource_variables: Experimental flag, subject to
1880 change. Enables [resource
1881 variables](https://tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables)
1882 to be converted by this converter. This is only allowed if the
1883 from_saved_model interface is used. (default True)
1885 Example usage:
1887 ```python
1888 # Converting a SavedModel to a TensorFlow Lite model.
1889 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
1890 tflite_model = converter.convert()
1892 # Converting a tf.Keras model to a TensorFlow Lite model.
1893 converter = tf.lite.TFLiteConverter.from_keras_model(model)
1894 tflite_model = converter.convert()
1896 # Converting ConcreteFunctions to a TensorFlow Lite model.
1897 converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
1898 tflite_model = converter.convert()
1900 # Converting a Jax model to a TensorFlow Lite model.
1901 converter = tf.lite.TFLiteConverter.experimental_from_jax(
1902 [func], [[ ('input1', input1), ('input2', input2)]])
1903 tflite_model = converter.convert()
1904 ```
1905 """ # fmt: skip
1907 # pylint: disable=useless-super-delegation
1908 def __init__(self, funcs, trackable_obj=None):
1909 """Constructor for TFLiteConverter.
1911 Args:
1912 funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1913 duplicate elements.
1914 trackable_obj: tf.AutoTrackable object associated with `funcs`. A
1915 reference to this object needs to be maintained so that Variables do not
1916 get garbage collected since functions have a weak reference to
1917 Variables. This is only required when the tf.AutoTrackable object is not
1918 maintained by the user (e.g. `from_saved_model`).
1919 """
1920 super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)
1922 @classmethod
1923 def from_concrete_functions(cls, funcs, trackable_obj=None):
1924 """Creates a TFLiteConverter object from ConcreteFunctions.
1926 Args:
1927 funcs: List of TensorFlow ConcreteFunctions. The list should not contain
1928 duplicate elements. Currently converter can only convert a single
1929 ConcreteFunction. Converting multiple functions is under development.
1930 trackable_obj: An `AutoTrackable` object (typically `tf.module`)
1931 associated with `funcs`. A reference to this object needs to be
1932 maintained so that Variables do not get garbage collected since
1933 functions have a weak reference to Variables.
1935 Returns:
1936 TFLiteConverter object.
1938 Raises:
1939 Invalid input type.
1940 """
1941 # pylint: disable=protected-access
1942 TFLiteConverterBase._set_original_model_type(
1943 conversion_metdata_fb.ModelType.TF_CONCRETE_FUNCTIONS
1944 )
1945 # pylint: enable=protected-access
1946 if trackable_obj is None:
1947 logging.warning(
1948 "Please consider providing the trackable_obj argument in the "
1949 "from_concrete_functions. Providing without the trackable_obj "
1950 "argument is deprecated and it will use the deprecated conversion "
1951 "path."
1952 )
1953 for func in funcs:
1954 if not isinstance(func, _function.ConcreteFunction):
1955 message = "This function takes in a list of ConcreteFunction."
1956 if isinstance(func, _def_function.Function):
1957 message += (
1958 " To get the ConcreteFunction from a Function,"
1959 " call get_concrete_function."
1960 )
1961 raise ValueError(message)
1962 return cls(funcs, trackable_obj)
1964 @classmethod
1965 def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
1966 """Creates a TFLiteConverter object from a SavedModel directory.
1968 Args:
1969 saved_model_dir: SavedModel directory to convert.
1970 signature_keys: List of keys identifying SignatureDef containing inputs
1971 and outputs. Elements should not be duplicated. By default the
1972 `signatures` attribute of the MetaGraphdef is used. (default
1973 saved_model.signatures)
1974 tags: Set of tags identifying the MetaGraphDef within the SavedModel to
1975 analyze. All tags in the tag set must be present. (default
1976 {tf.saved_model.SERVING} or {'serve'})
1978 Returns:
1979 TFLiteConverter object.
1981 Raises:
1982 Invalid signature keys.
1983 """
1984 # pylint: disable=protected-access
1985 TFLiteConverterBase._set_original_model_type(
1986 conversion_metdata_fb.ModelType.TF_SAVED_MODEL
1987 )
1988 # pylint: enable=protected-access
1989 # When run without eager enabled, this will return the legacy
1990 # TFLiteConverter.
1991 if not context.executing_eagerly():
1992 signature_key = None
1993 if signature_keys:
1994 if len(signature_keys) != 1:
1995 raise ValueError("Only support a single signature key.")
1996 else:
1997 signature_key = signature_keys[0]
1998 logging.warning(
1999 "Invoking the TF1 implementation of TFLiteConverter "
2000 "because eager is disabled. Consider enabling eager."
2001 )
2002 return TFLiteConverter.from_saved_model(
2003 saved_model_dir, signature_key=signature_key, tag_set=tags
2004 )
2006 # Ensures any graphs created in Eager mode are able to run. This is required
2007 # in order to create a tf.estimator.Exporter that exports a TFLite model.
2008 if tags is None:
2009 tags = set([_tag_constants.SERVING])
2011 with context.eager_mode():
2012 saved_model = _load(saved_model_dir, tags)
2013 if not signature_keys:
2014 signature_keys = saved_model.signatures
2016 if not signature_keys:
2017 raise ValueError("Only support at least one signature key.")
2019 # Distinguishes SavedModel artifacts created by `model.export`
2020 # from SavedModel created by `model.save`/`tf.saved_model.save`.
2021 if (
2022 len(signature_keys) > 1
2023 and hasattr(saved_model, "serve") # `model.export` default endpoint
2024 and not hasattr(saved_model, "_default_save_signature")
2025 # `_default_save_signature` does not exist for `model.export` artifacts.
2026 ):
2027 # Default `serve` endpoint for `model.export` should be copied
2028 # to `serving_default` to prevent issues in TF Lite serving.
2029 saved_model.serving_default = saved_model.serve
2030 delattr(saved_model, "serve")
2031 signature_keys = ["serving_default"]
2033 funcs = []
2034 for key in signature_keys:
2035 if key not in saved_model.signatures:
2036 raise ValueError(
2037 "Invalid signature key '{}' found. Valid keys are '{}'.".format(
2038 key, ",".join(saved_model.signatures)
2039 )
2040 )
2041 funcs.append(saved_model.signatures[key])
2043 saved_model_converter = TFLiteSavedModelConverterV2(
2044 saved_model_dir, tags, signature_keys, saved_model
2045 )
2046 if saved_model_converter.saved_model_dir:
2047 return saved_model_converter
2049 return cls(funcs, saved_model)
2051 @classmethod
2052 def from_keras_model(cls, model):
2053 """Creates a TFLiteConverter object from a Keras model.
2055 Args:
2056 model: tf.Keras.Model
2058 Returns:
2059 TFLiteConverter object.
2060 """
2061 # pylint: disable=protected-access
2062 TFLiteConverterBase._set_original_model_type(
2063 conversion_metdata_fb.ModelType.KERAS_MODEL
2064 )
2065 # pylint: enable=protected-access
2066 return TFLiteKerasModelConverterV2(model)
2068 @classmethod
2069 def experimental_from_jax(cls, serving_funcs, inputs):
2070 # Experimental API, subject to changes.
2071 # TODO(b/197690428): Currently only support single function.
2072 """Creates a TFLiteConverter object from a Jax model with its inputs.
2074 Args:
2075 serving_funcs: A array of Jax functions with all the weights applied
2076 already.
2077 inputs: A array of Jax input placeholders tuples list, e.g.,
2078 jnp.zeros(INPUT_SHAPE). Each tuple list should correspond with the
2079 serving function.
2081 Returns:
2082 TFLiteConverter object.
2083 """
2084 # pylint: disable=protected-access
2085 TFLiteConverterBase._set_original_model_type(
2086 conversion_metdata_fb.ModelType.JAX
2087 )
2088 # pylint: enable=protected-access
2089 return TFLiteJaxConverterV2(serving_funcs, inputs)
2091 # pylint: disable=useless-super-delegation
2092 def convert(self):
2093 """Converts a TensorFlow GraphDef based on instance variables.
2095 Returns:
2096 The converted data in serialized format.
2098 Raises:
2099 ValueError:
2100 No concrete functions is specified.
2101 Multiple concrete functions are specified.
2102 Input shape is not specified.
2103 Invalid quantization parameters.
2104 """
2105 return super(TFLiteConverterV2, self).convert()
2108class TFLiteConverterBaseV1(TFLiteConverterBase):
2109 """Converter subclass to share functionality between V1 converters."""
2111 def __init__(self, experimental_debug_info_func):
2112 """Constructor for TFLiteConverter.
2114 Args:
2115 experimental_debug_info_func: An experimental function to retrieve the
2116 graph debug info for a set of nodes from the `graph_def`.
2117 """
2118 super(TFLiteConverterBaseV1, self).__init__()
2119 self.inference_type = _dtypes.float32
2120 self.inference_input_type = None
2121 self.inference_output_type = None
2122 self.output_format = constants.TFLITE
2123 self.quantized_input_stats = {}
2124 self.default_ranges_stats = None
2125 self.drop_control_dependency = True
2126 self.reorder_across_fake_quant = False
2127 self.change_concat_input_ranges = False
2128 self.dump_graphviz_dir = None
2129 self.dump_graphviz_video = False
2130 self.conversion_summary_dir = None
2131 self._debug_info_func = experimental_debug_info_func
2132 self._metadata.environment.apiVersion = 1
2134 def __setattr__(self, name, value):
2135 if name == "post_training_quantize":
2136 warnings.warn(
2137 "Property %s is deprecated, "
2138 "please use optimizations=[Optimize.DEFAULT]"
2139 " instead." % name
2140 )
2141 if value:
2142 self.optimizations = [Optimize.DEFAULT]
2143 else:
2144 self.optimizations = []
2145 return
2146 if name == "target_ops":
2147 warnings.warn(
2148 "Property %s is deprecated, please use "
2149 "target_spec.supported_ops instead." % name
2150 )
2151 self.target_spec.supported_ops = value
2152 return
2153 object.__setattr__(self, name, value)
2155 def __getattribute__(self, name):
2156 if name == "post_training_quantize":
2157 warnings.warn(
2158 "Property %s is deprecated, "
2159 "please use optimizations=[Optimize.DEFAULT]"
2160 " instead." % name
2161 )
2162 return Optimize.DEFAULT in set(self.optimizations)
2163 if name == "target_ops":
2164 warnings.warn(
2165 "Property %s is deprecated, please use "
2166 "target_spec.supported_ops instead." % name
2167 )
2168 return self.target_spec.supported_ops
2169 return object.__getattribute__(self, name)
2171 def _validate_quantized_input_stats(self, converter_kwargs, quant_mode):
2172 """Ensure the `quantized_input_stats` flag is provided if required."""
2174 quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})
2176 requires_quantized_input_stats = (
2177 converter_kwargs["inference_type"] in quantized_types
2178 or converter_kwargs["inference_input_type"] in quantized_types
2179 ) and not quant_mode.is_post_training_integer_quantization()
2181 if (
2182 requires_quantized_input_stats
2183 and not converter_kwargs["quantized_input_stats"]
2184 ):
2185 raise ValueError(
2186 "The `quantized_input_stats` flag must be defined when either "
2187 "`inference_type` flag or `inference_input_type` flag is set to "
2188 "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
2189 "`inference_input_type={}`.".format(
2190 _get_tf_type_name(converter_kwargs["inference_type"]),
2191 _get_tf_type_name(converter_kwargs["inference_input_type"]),
2192 )
2193 )
2195 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
2196 def _validate_inputs(self, input_tensors, quantized_input_stats):
2197 """Validate input parameters.
2199 Args:
2200 input_tensors: List of input tensors.
2201 quantized_input_stats: Map of input tensor names to a tuple of floats
2202 representing the mean and standard deviation of the training data.
2204 Raises:
2205 ValueError:
2206 Input shape is not specified.
2207 Quantization input stats is required but not provided.
2208 """
2210 if not self._is_unknown_shapes_allowed() and self._has_valid_tensors():
2211 # Checks dimensions in input tensor.
2212 for tensor in input_tensors:
2213 shape = tensor.shape
2214 if not shape:
2215 raise ValueError(
2216 "Provide an input shape for input array '{0}'.".format(
2217 _get_tensor_name(tensor)
2218 )
2219 )
2220 # Note that shape_list might be empty for scalar shapes.
2221 shape_list = shape.as_list()
2222 if None in shape_list[1:]:
2223 raise ValueError(
2224 "None is only supported in the 1st dimension. Tensor '{0}' has "
2225 "invalid shape '{1}'.".format(
2226 _get_tensor_name(tensor), shape_list
2227 )
2228 )
2229 elif shape_list and shape_list[0] is None:
2230 self._set_batch_size(batch_size=1)
2232 # Get quantization stats. Ensures there is one stat per name if the stats
2233 # are specified.
2234 if quantized_input_stats:
2235 self._quantized_stats = []
2236 invalid_stats = []
2237 for name in self.get_input_arrays():
2238 if name in quantized_input_stats:
2239 self._quantized_stats.append(quantized_input_stats[name])
2240 else:
2241 invalid_stats.append(name)
2243 if invalid_stats:
2244 raise ValueError(
2245 "Quantization input stats are not available for input "
2246 "tensors '{0}'.".format(",".join(invalid_stats))
2247 )
2248 else:
2249 self._quantized_stats = None
2251 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
2252 def _optimize_tf_model(
2253 self, graph_def, input_tensors, output_tensors, quant_mode
2254 ):
2255 """Run a Grappler pass to optimize the TensorFlow graph.
2257 Args:
2258 graph_def: Frozen GraphDef to be optimized.
2259 input_tensors: List of input tensors.
2260 output_tensors: List of output tensors.
2261 quant_mode: the quantization mode.
2263 Returns:
2264 The optimized TensorFlow graph.
2265 """
2266 # Disable grappler constant folding if there are training quant ops.
2267 if self.saved_model_dir or quant_mode.is_quantization_aware_trained_model():
2268 return graph_def
2270 try:
2271 # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
2272 # Grappler will also try to lower while loop into switch merge
2273 # representation which is undesired for Ophints, so we simply remove
2274 # those attributes to prevent Grappler from doing so.
2275 graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def)
2276 # Run function inlining optimization to ensure any models generated
2277 # through the from_frozen_graph path have been inlined.
2278 optimized_graph = _run_graph_optimizations(
2279 graph,
2280 input_tensors,
2281 output_tensors,
2282 config=self._grappler_config(["function"]),
2283 )
2284 return optimized_graph
2285 except Exception: # pylint: disable=broad-except
2286 return graph_def
2288 def convert(self):
2289 """Converts a TensorFlow GraphDef based on instance variables.
2291 Returns:
2292 The converted data in serialized format. Either a TFLite Flatbuffer or a
2293 Graphviz graph depending on value in `output_format`.
2295 Raises:
2296 ValueError:
2297 Input shape is not specified.
2298 None value for dimension in input_tensor.
2299 """
2300 self._validate_inputs(self._input_tensors, self.quantized_input_stats)
2302 quant_mode = QuantizationMode(
2303 self.optimizations,
2304 self.target_spec,
2305 self.representative_dataset,
2306 self._graph_def,
2307 self._experimental_disable_per_channel,
2308 self.experimental_new_dynamic_range_quantizer,
2309 self._experimental_low_bit_qat,
2310 self._experimental_full_integer_quantization_bias_type,
2311 self._experimental_variable_quantization,
2312 )
2314 optimized_graph = self._optimize_tf_model(
2315 self._graph_def, self._input_tensors, self._output_tensors, quant_mode
2316 )
2318 self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
2320 converter_kwargs = self._get_base_converter_args()
2321 converter_kwargs.update(
2322 quant_mode.converter_flags(
2323 self.inference_type, self.inference_input_type
2324 )
2325 )
2326 converter_kwargs.update({
2327 "output_format": self.output_format,
2328 "quantized_input_stats": self._quantized_stats,
2329 "default_ranges_stats": self.default_ranges_stats,
2330 "drop_control_dependency": self.drop_control_dependency,
2331 "reorder_across_fake_quant": self.reorder_across_fake_quant,
2332 "change_concat_input_ranges": self.change_concat_input_ranges,
2333 "dump_graphviz_dir": self.dump_graphviz_dir,
2334 "dump_graphviz_video": self.dump_graphviz_video,
2335 "conversion_summary_dir": self.conversion_summary_dir,
2336 })
2338 self._validate_quantized_input_stats(converter_kwargs, quant_mode)
2339 if not self.experimental_new_converter:
2340 logging.warning(
2341 "Please consider switching to the new converter by setting "
2342 "experimental_new_converter=True. "
2343 "The old converter is deprecated."
2344 )
2345 else:
2346 logging.info(
2347 "Using experimental converter: If you encountered a problem "
2348 "please file a bug. You can opt-out "
2349 "by setting experimental_new_converter=False"
2350 )
2351 # Converts model.
2352 if self._has_valid_tensors():
2353 result = _convert_graphdef(
2354 input_data=optimized_graph,
2355 input_tensors=self._input_tensors,
2356 output_tensors=self._output_tensors,
2357 **converter_kwargs,
2358 )
2359 else:
2360 result = _convert_graphdef_with_arrays(
2361 input_data=optimized_graph,
2362 input_arrays_with_shape=self._input_arrays_with_shape,
2363 output_arrays=self._output_arrays,
2364 control_output_arrays=self._control_output_arrays,
2365 **converter_kwargs,
2366 )
2368 return self._optimize_tflite_model(
2369 result, quant_mode, quant_io=self.experimental_new_quantizer
2370 )
2372 def get_input_arrays(self):
2373 """Returns a list of the names of the input tensors.
2375 Returns:
2376 List of strings.
2377 """
2378 if self._has_valid_tensors():
2379 return [_get_tensor_name(tensor) for tensor in self._input_tensors]
2380 else:
2381 return [name for name, _ in self._input_arrays_with_shape]
2383 def _has_valid_tensors(self):
2384 """Checks if the input and output tensors have been initialized.
2386 Returns:
2387 Bool.
2388 """
2389 return self._input_tensors is not None and self._output_tensors
2391 def _set_batch_size(self, batch_size):
2392 """Sets the first dimension of the input tensor to `batch_size`.
2394 Args:
2395 batch_size: Batch size for the model. Replaces the first dimension of an
2396 input size array if undefined. (default 1)
2398 Raises:
2399 ValueError: input_tensor is not defined.
2400 """
2401 if not self._has_valid_tensors():
2402 raise ValueError(
2403 "The batch size cannot be set for this model. Please "
2404 "use input_shapes parameter."
2405 )
2407 for tensor in self._input_tensors:
2408 shape = tensor.shape.as_list()
2409 if shape[0] is None:
2410 shape[0] = batch_size
2411 tensor.set_shape(shape)
2413 def _is_unknown_shapes_allowed(self):
2414 # Ophint Converted nodes will need the shapes to be known.
2415 if _is_ophint_converted(self._graph_def):
2416 return False
2418 if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
2419 return False
2421 # `conversion_summary_dir` calls the old converter. Unknown shapes are only
2422 # supported by the MLIR converter.
2423 if self.conversion_summary_dir:
2424 logging.warning(
2425 "`conversion_summary_dir` does not work with unknown shapes. "
2426 "Graphs with unknown shapes might be different than when this flag "
2427 "is disabled."
2428 )
2429 return False
2430 return True
2432 def _save_conversion_params_metric(self):
2433 self._collected_converter_params.update({
2434 "output_format": self.output_format,
2435 "default_ranges_stats": self.default_ranges_stats,
2436 "drop_control_dependency": self.drop_control_dependency,
2437 "reorder_across_fake_quant": self.reorder_across_fake_quant,
2438 "change_concat_input_ranges": self.change_concat_input_ranges,
2439 "dump_graphviz_dir": self.dump_graphviz_dir,
2440 "dump_graphviz_video": self.dump_graphviz_video,
2441 "conversion_summary_dir": self.conversion_summary_dir,
2442 })
2443 super(TFLiteConverterBaseV1, self)._save_conversion_params_metric(
2444 self._graph_def, self.inference_type, self.inference_input_type
2445 )
2448class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
2449 """Converts the given SavedModel into TensorFlow Lite model.
2451 Attributes:
2452 saved_model_dir: Directory of the SavedModel.
2453 """
2455 def __init__(
2456 self,
2457 saved_model_dir,
2458 saved_model_tags,
2459 saved_model_exported_names,
2460 experimental_debug_info_func=None,
2461 ):
2462 """Constructor for TFLiteConverter.
2464 Args:
2465 saved_model_dir: Directory of the SavedModel.
2466 saved_model_tags: Set of tags identifying the MetaGraphDef within the
2467 SavedModel to analyze. All tags in the tag set must be present. (default
2468 {tf.saved_model.SERVING}).
2469 saved_model_exported_names: Names to be exported when the saved model
2470 import path is on.
2471 experimental_debug_info_func: An experimental function to retrieve the
2472 graph debug info for a set of nodes from the `graph_def`.
2474 Raises:
2475 ValueError: Invalid arguments.
2476 """
2477 super(TFLiteSavedModelConverter, self).__init__(
2478 experimental_debug_info_func
2479 )
2480 self.saved_model_dir = saved_model_dir
2481 self._saved_model_tags = saved_model_tags
2482 self._saved_model_exported_names = saved_model_exported_names
2484 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
2486 if len(self._saved_model_exported_names) != 1:
2487 raise ValueError("Only support a single signature key.")
2489 signature_key = self._saved_model_exported_names[0]
2491 result = _freeze_saved_model(
2492 self.saved_model_dir,
2493 None,
2494 None,
2495 None,
2496 self._saved_model_tags,
2497 signature_key,
2498 )
2499 self._graph_def = result[0]
2500 self._input_tensors = result[1]
2501 self._output_tensors = result[2]
2502 self._parse_saved_model_args()
2504 @_export_metrics
2505 def convert(self):
2506 """Converts a TensorFlow GraphDef based on instance variables.
2508 Note that in the converted TensorFlow Lite model, the input tensor's order
2509 might be changed each time `convert` is called. To access input tensor
2510 information, please consider using the `SignatureRunner` API
2511 (`interpreter.get_signature_runner`).
2513 Returns:
2514 The converted data in serialized format. Either a TFLite Flatbuffer or a
2515 Graphviz graph depending on value in `output_format`.
2517 Raises:
2518 ValueError:
2519 Input shape is not specified.
2520 None value for dimension in input_tensor.
2521 """
2522 return super(TFLiteSavedModelConverter, self).convert()
2525class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
2526 """Converts the given SavedModel into TensorFlow Lite model."""
2528 def __init__(
2529 self,
2530 model_file,
2531 input_arrays=None,
2532 input_shapes=None,
2533 output_arrays=None,
2534 custom_objects=None,
2535 ):
2536 """Constructor for TFLiteConverter.
2538 Args:
2539 model_file: Full filepath of HDF5 file containing the tf.keras model.
2540 input_arrays: List of input tensors to freeze graph with. Uses input
2541 arrays from SignatureDef when none are provided. (default None)
2542 input_shapes: Dict of strings representing input tensor names to list of
2543 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
2544 Automatically determined when input shapes is None (e.g., {"foo" :
2545 None}). (default None)
2546 output_arrays: List of output tensors to freeze graph with. Uses output
2547 arrays from SignatureDef when none are provided. (default None)
2548 custom_objects: Dict mapping names (strings) to custom classes or
2549 functions to be considered during model deserialization. (default None)
2551 Raises:
2552 ValueError: Invalid arguments.
2553 """
2554 super(TFLiteKerasModelConverter, self).__init__(
2555 experimental_debug_info_func=None
2556 )
2557 # Handles Keras when Eager mode is enabled.
2558 if context.executing_eagerly():
2559 if input_arrays or output_arrays:
2560 raise ValueError(
2561 "`input_arrays` and `output_arrays` are unsupported "
2562 "with Eager mode. If your model requires any of these "
2563 "parameters, please use disable_eager_execution()."
2564 )
2566 keras_model = keras_deps.get_load_model_function()(
2567 model_file, custom_objects
2568 )
2569 function = _trace_model_call(keras_model)
2570 concrete_func = function.get_concrete_function()
2572 frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
2573 concrete_func, lower_control_flow=False
2574 )
2575 _set_tensor_shapes(frozen_func.inputs, input_shapes)
2576 self._keras_model = keras_model
2577 self._graph_def = frozen_func.graph.as_graph_def()
2578 self._input_tensors = frozen_func.inputs
2579 self._output_tensors = frozen_func.outputs
2580 self._debug_info_func = _build_debug_info_func(frozen_func.graph)
2581 return
2583 # Handles Keras when Eager mode is disabled.
2584 keras_deps.get_clear_session_function()()
2585 keras_model = keras_deps.get_load_model_function()(
2586 model_file, custom_objects
2587 )
2588 sess = keras_deps.get_get_session_function()()
2590 # Get input and output tensors.
2591 if input_arrays:
2592 input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
2593 else:
2594 input_tensors = keras_model.inputs
2596 if output_arrays:
2597 output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
2598 else:
2599 output_tensors = keras_model.outputs
2600 _set_tensor_shapes(input_tensors, input_shapes)
2602 graph_def = _freeze_graph(sess, input_tensors, output_tensors)
2603 self._keras_model = keras_model
2604 self._graph_def = graph_def
2605 self._input_tensors = input_tensors
2606 self._output_tensors = output_tensors
2607 self._debug_info_func = _build_debug_info_func(sess.graph)
2609 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
2610 def _freeze_keras_model(self, output_dir):
2611 """Save Keras model to Saved Model format.
2613 Args:
2614 output_dir: The output directory to save the SavedModel.
2615 """
2616 try:
2617 self._keras_model.save(output_dir, save_format="tf")
2618 except Exception: # pylint: disable=broad-except
2619 # When storing the given keras model to a saved model is failed, let's
2620 # use original keras model conversion pipeline.
2621 return None
2622 tag_set = set([_tag_constants.SERVING])
2623 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
2624 graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model(
2625 output_dir, None, None, None, tag_set, signature_key
2626 )
2628 self.saved_model_dir = output_dir
2629 self._saved_model_tags = tag_set
2630 self._saved_model_exported_names = [signature_key]
2631 self._parse_saved_model_args()
2632 if self.saved_model_dir:
2633 self._graph_def = graph_def
2634 self._input_tensors = input_tensors
2635 self._output_tensors = output_tensors
2636 self._debug_info_func = _build_debug_info_func(sess_graph)
2638 def _convert_as_saved_model(self):
2639 """Converts a Keras model as a saved model.
2641 Returns:
2642 The converted data in serialized format.
2643 """
2644 temp_dir = tempfile.mkdtemp()
2645 try:
2646 self._freeze_keras_model(temp_dir)
2647 if self.saved_model_dir:
2648 return super(TFLiteKerasModelConverter, self).convert()
2649 finally:
2650 shutil.rmtree(temp_dir, True)
2652 @_export_metrics
2653 def convert(self):
2654 """Converts a Keras model based on instance variables.
2656 Returns:
2657 The converted data in serialized format. Either a TFLite Flatbuffer or a
2658 Graphviz graph depending on value in `output_format`.
2660 Raises:
2661 ValueError:
2662 Input shape is not specified.
2663 None value for dimension in input_tensor.
2664 """
2665 saved_model_convert_result = self._convert_as_saved_model()
2666 if saved_model_convert_result:
2667 return saved_model_convert_result
2669 return super(TFLiteKerasModelConverter, self).convert()
2672class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
2673 """Converts the given frozen graph def into TensorFlow Lite model."""
2675 def __init__(
2676 self,
2677 graph_def,
2678 input_tensors,
2679 output_tensors,
2680 input_arrays_with_shape=None,
2681 output_arrays=None,
2682 experimental_debug_info_func=None,
2683 ):
2684 """Constructor for TFLiteConverter.
2686 Args:
2687 graph_def: Frozen TensorFlow GraphDef.
2688 input_tensors: List of input tensors. Type and shape are computed using
2689 `foo.shape` and `foo.dtype`.
2690 output_tensors: List of output tensors (only .name is used from this).
2691 input_arrays_with_shape: Tuple of strings representing input tensor names
2692 and list of integers representing input shapes (e.g., [("foo", [1, 16,
2693 16, 3])]). Use only when graph cannot be loaded into TensorFlow and when
2694 `input_tensors` and `output_tensors` are None. (default None)
2695 output_arrays: List of output tensors to freeze graph with. Use only when
2696 graph cannot be loaded into TensorFlow and when `input_tensors` and
2697 `output_tensors` are None. (default None)
2698 experimental_debug_info_func: An experimental function to retrieve the
2699 graph debug info for a set of nodes from the `graph_def`.
2701 Raises:
2702 ValueError: Invalid arguments.
2703 """
2704 super(TFLiteFrozenGraphConverter, self).__init__(
2705 experimental_debug_info_func
2706 )
2707 self._graph_def = graph_def
2708 self._input_tensors = input_tensors
2709 self._output_tensors = output_tensors
2710 self._control_output_arrays = None
2712 # Attributes are used by models that cannot be loaded into TensorFlow.
2713 if not self._has_valid_tensors():
2714 self._input_arrays_with_shape = input_arrays_with_shape
2715 self._output_arrays = output_arrays
2717 if input_tensors is not None and input_arrays_with_shape is not None:
2718 logging.warning(
2719 "input_arrays_with_shape will be ignored when both the "
2720 "given input_tensors and input_arrays_with_shape are not "
2721 "None."
2722 )
2724 if output_tensors is not None and output_arrays is not None:
2725 logging.warning(
2726 "output_arrays will be ignored when both the given "
2727 "output_tensors and output_arrays are not None."
2728 )
2730 @_export_metrics
2731 def convert(self):
2732 """Converts a TensorFlow GraphDef based on instance variables.
2734 Returns:
2735 The converted data in serialized format. Either a TFLite Flatbuffer or a
2736 Graphviz graph depending on value in `output_format`.
2738 Raises:
2739 ValueError:
2740 Input shape is not specified.
2741 None value for dimension in input_tensor.
2742 """
2743 if not self._has_valid_tensors():
2744 if not self._input_arrays_with_shape or not (
2745 self._output_arrays or self._control_output_arrays
2746 ):
2747 raise ValueError(
2748 "If input_tensors and output_tensors are None, both "
2749 "input_arrays_with_shape and output_arrays|control_output_arrays "
2750 "must be defined."
2751 )
2752 return super(TFLiteFrozenGraphConverter, self).convert()
2755@_tf_export(v1=["lite.TFLiteConverter"])
2756class TFLiteConverter(TFLiteFrozenGraphConverter):
2757 """Convert a TensorFlow model into `output_format`.
2759 This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
2760 model into either a TFLite FlatBuffer or graph visualization.
2762 Attributes:
2763 optimizations: Experimental flag, subject to change. Set of optimizations to
2764 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
2765 set of values of type `tf.lite.Optimize`)
2766 representative_dataset: A generator function used for integer quantization
2767 where each generated sample has the same order, type and shape as the
2768 inputs to the model. Usually, this is a small subset of a few hundred
2769 samples randomly chosen, in no particular order, from the training or
2770 evaluation dataset. This is an optional attribute, but required for full
2771 integer quantization, i.e, if `tf.int8` is the only supported type in
2772 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
2773 (default None)
2774 target_spec: Experimental flag, subject to change. Specifications of target
2775 device, including supported ops set, supported types and a set of user's
2776 defined TensorFlow operators required in the TensorFlow Lite runtime.
2777 Refer to `tf.lite.TargetSpec`.
2778 inference_type: Data type of numeric arrays, excluding the input layer.
2779 (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
2780 inference_input_type: Data type of the numeric arrays in the input layer. If
2781 `inference_input_type` is in {tf.int8, tf.uint8}, then
2782 `quantized_input_stats` must be provided. (default is the value assigned
2783 to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
2784 inference_output_type: Data type of the numeric arrays in the output layer.
2785 (default is the value assigned to `inference_type`, must be in
2786 {tf.float32, tf.int8, tf.uint8})
2787 quantized_input_stats: Map of input tensor names to a tuple of floats
2788 representing the mean and standard deviation of the training data. (e.g.,
2789 {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8 or
2790 tf.uint8. (default None)
2791 default_ranges_stats: Tuple of integers (min, max) representing range values
2792 for all numeric arrays without a specified range. Intended for
2793 experimenting with quantization via "dummy quantization". (default None)
2794 allow_custom_ops: Boolean indicating whether to allow custom operations.
2795 When False any unknown operation is an error. When True, custom ops are
2796 created for any op that is unknown. The developer will need to provide
2797 these to the TensorFlow Lite runtime with a custom resolver. (default
2798 False)
2799 drop_control_dependency: Boolean indicating whether to drop control
2800 dependencies silently. This is due to TFLite not supporting control
2801 dependencies. (default True)
2802 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
2803 nodes in unexpected locations. Used when the location of the FakeQuant
2804 nodes is preventing graph transformations necessary to convert the graph.
2805 Results in a graph that differs from the quantized training graph,
2806 potentially causing differing arithmetic behavior. (default False)
2807 change_concat_input_ranges: Boolean to change behavior of min/max ranges for
2808 inputs and outputs of the concat operator for quantized models. Changes
2809 the ranges of concat operator overlap when true. (default False)
2810 output_format: Output file format. (default
2811 tf.compat.v1.lite.constants.TFLITE, must be in
2812 {tf.compat.v1.lite.constants.TFLITE,
2813 tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
2814 dump_graphviz_dir: Full filepath of folder to dump the graphs at various
2815 stages of processing GraphViz .dot files. Preferred over
2816 `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
2817 the requirements of the output file. (default None)
2818 dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
2819 files after every graph transformation. Requires the `dump_graphviz_dir`
2820 flag to be specified. (default False)
2821 conversion_summary_dir: Full path of the directory to store conversion logs.
2822 (default None)
2823 exclude_conversion_metadata: Whether not to embed the conversion metadata
2824 into the converted model. (default False)
2825 target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
2826 post_training_quantize: Deprecated. Please use `optimizations` instead and
2827 set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
2828 experimental_new_converter: Experimental flag, subject to change. Enables
2829 MLIR-based conversion. (default True)
2830 experimental_new_quantizer: Experimental flag, subject to change. Enables
2831 MLIR-based quantization conversion instead of Flatbuffer-based conversion.
2832 (default True) Example usage: ```python # Converting a GraphDef from
2833 session. converter = tf.compat.v1.lite.TFLiteConverter.from_session( sess,
2834 in_tensors, out_tensors) tflite_model = converter.convert()
2835 open("converted_model.tflite", "wb").write(tflite_model) # Converting a
2836 GraphDef from file. converter =
2837 tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( graph_def_file,
2838 input_arrays, output_arrays) tflite_model = converter.convert()
2839 open("converted_model.tflite", "wb").write(tflite_model) # Converting a
2840 SavedModel. converter =
2841 tf.compat.v1.lite.TFLiteConverter.from_saved_model( saved_model_dir)
2842 tflite_model = converter.convert() open("converted_model.tflite",
2843 "wb").write(tflite_model) # Converting a tf.keras model. converter =
2844 tf.compat.v1.lite.TFLiteConverter.from_keras_model_file( keras_model)
2845 tflite_model = converter.convert() open("converted_model.tflite",
2846 "wb").write(tflite_model) ```
2847 """
2849 # pylint: disable=useless-super-delegation
2850 def __init__(
2851 self,
2852 graph_def,
2853 input_tensors,
2854 output_tensors,
2855 input_arrays_with_shape=None,
2856 output_arrays=None,
2857 experimental_debug_info_func=None,
2858 ):
2859 """Constructor for TFLiteConverter.
2861 Args:
2862 graph_def: Frozen TensorFlow GraphDef.
2863 input_tensors: List of input tensors. Type and shape are computed using
2864 `foo.shape` and `foo.dtype`.
2865 output_tensors: List of output tensors (only .name is used from this).
2866 input_arrays_with_shape: Tuple of strings representing input tensor names
2867 and list of integers representing input shapes (e.g., [("foo" : [1, 16,
2868 16, 3])]). Use only when graph cannot be loaded into TensorFlow and when
2869 `input_tensors` and `output_tensors` are None. (default None)
2870 output_arrays: List of output tensors to freeze graph with. Use only when
2871 graph cannot be loaded into TensorFlow and when `input_tensors` and
2872 `output_tensors` are None. (default None)
2873 experimental_debug_info_func: An experimental function to retrieve the
2874 graph debug info for a set of nodes from the `graph_def`.
2876 Raises:
2877 ValueError: Invalid arguments.
2878 """
2879 super(TFLiteConverter, self).__init__(
2880 graph_def,
2881 input_tensors,
2882 output_tensors,
2883 input_arrays_with_shape,
2884 output_arrays,
2885 experimental_debug_info_func,
2886 )
2888 @classmethod
2889 def from_session(cls, sess, input_tensors, output_tensors):
2890 """Creates a TFLiteConverter class from a TensorFlow Session.
2892 Args:
2893 sess: TensorFlow Session.
2894 input_tensors: List of input tensors. Type and shape are computed using
2895 `foo.shape` and `foo.dtype`.
2896 output_tensors: List of output tensors (only .name is used from this).
2898 Returns:
2899 TFLiteConverter class.
2900 """
2901 # pylint: disable=protected-access
2902 TFLiteConverterBase._set_original_model_type(
2903 conversion_metdata_fb.ModelType.TF_SESSION
2904 )
2905 # pylint: enable=protected-access
2906 graph_def = _freeze_graph(sess, input_tensors, output_tensors)
2907 return cls(
2908 graph_def,
2909 input_tensors,
2910 output_tensors,
2911 experimental_debug_info_func=_build_debug_info_func(sess.graph),
2912 )
2914 @classmethod
2915 def from_frozen_graph(
2916 cls, graph_def_file, input_arrays, output_arrays, input_shapes=None
2917 ):
2918 """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
2920 Args:
2921 graph_def_file: Full filepath of file containing frozen GraphDef.
2922 input_arrays: List of input tensors to freeze graph with.
2923 output_arrays: List of output tensors to freeze graph with.
2924 input_shapes: Dict of strings representing input tensor names to list of
2925 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
2926 Automatically determined when input shapes is None (e.g., {"foo" :
2927 None}). (default None)
2929 Returns:
2930 TFLiteConverter class.
2932 Raises:
2933 IOError:
2934 File not found.
2935 Unable to parse input file.
2936 ValueError:
2937 The graph is not frozen.
2938 input_arrays or output_arrays contains an invalid tensor name.
2939 input_shapes is not correctly defined when required
2940 """
2941 # pylint: disable=protected-access
2942 TFLiteConverterBase._set_original_model_type(
2943 conversion_metdata_fb.ModelType.TF_GRAPH_DEF
2944 )
2945 # pylint: enable=protected-access
2946 with _ops.Graph().as_default():
2947 with _session.Session() as sess:
2948 # Read GraphDef from file.
2949 if not gfile.Exists(graph_def_file):
2950 raise IOError("File '{0}' does not exist.".format(graph_def_file))
2951 with gfile.GFile(graph_def_file, "rb") as f:
2952 file_content = f.read()
2954 try:
2955 graph_def = _graph_pb2.GraphDef()
2956 graph_def.ParseFromString(file_content)
2957 except (_text_format.ParseError, DecodeError):
2958 try:
2959 print("Ignore 'tcmalloc: large alloc' warnings.")
2961 if not isinstance(file_content, str):
2962 file_content = file_content.decode("utf-8")
2963 graph_def = _graph_pb2.GraphDef()
2964 _text_format.Merge(file_content, graph_def)
2965 except (_text_format.ParseError, DecodeError):
2966 raise IOError(
2967 "Unable to parse input file '{}'.".format(graph_def_file)
2968 )
2970 if sys.byteorder == "big":
2971 bst.swap_tensor_content_in_graph_node(graph_def, "little", "big")
2973 # Handles models with custom TFLite ops that cannot be resolved in
2974 # TensorFlow.
2975 load_model_in_session = True
2976 try:
2977 _import_graph_def(graph_def, name="")
2978 except _NotFoundError:
2979 load_model_in_session = False
2981 if load_model_in_session:
2982 # Check if graph is frozen.
2983 if not _is_frozen_graph(sess):
2984 raise ValueError("Please freeze the graph using freeze_graph.py.")
2986 # Get input and output tensors.
2987 input_tensors = _get_tensors_from_tensor_names(
2988 sess.graph, input_arrays
2989 )
2990 output_tensors = _get_tensors_from_tensor_names(
2991 sess.graph, output_arrays
2992 )
2993 _set_tensor_shapes(input_tensors, input_shapes)
2995 return cls(sess.graph_def, input_tensors, output_tensors)
2996 else:
2997 if not input_shapes:
2998 raise ValueError("input_shapes must be defined for this model.")
2999 if set(input_arrays) != set(input_shapes.keys()):
3000 raise ValueError(
3001 "input_shapes must contain a value for each item "
3002 "in input_array."
3003 )
3005 input_arrays_with_shape = [
3006 (name, input_shapes[name]) for name in input_arrays
3007 ]
3008 return cls(
3009 graph_def,
3010 input_tensors=None,
3011 output_tensors=None,
3012 input_arrays_with_shape=input_arrays_with_shape,
3013 output_arrays=output_arrays,
3014 )
3016 @classmethod
3017 def from_saved_model(
3018 cls,
3019 saved_model_dir,
3020 input_arrays=None,
3021 input_shapes=None,
3022 output_arrays=None,
3023 tag_set=None,
3024 signature_key=None,
3025 ):
3026 """Creates a TFLiteConverter class from a SavedModel.
3028 Args:
3029 saved_model_dir: SavedModel directory to convert.
3030 input_arrays: List of input tensors to freeze graph with. Uses input
3031 arrays from SignatureDef when none are provided. (default None)
3032 input_shapes: Dict of strings representing input tensor names to list of
3033 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
3034 Automatically determined when input shapes is None (e.g., {"foo" :
3035 None}). (default None)
3036 output_arrays: List of output tensors to freeze graph with. Uses output
3037 arrays from SignatureDef when none are provided. (default None)
3038 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
3039 analyze. All tags in the tag set must be present. (default
3040 {tf.saved_model.SERVING})
3041 signature_key: Key identifying SignatureDef containing inputs and outputs.
3042 (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
3044 Returns:
3045 TFLiteConverter class.
3046 """
3047 # pylint: disable=protected-access
3048 TFLiteConverterBase._set_original_model_type(
3049 conversion_metdata_fb.ModelType.TF_SAVED_MODEL
3050 )
3051 # pylint: enable=protected-access
3052 if tag_set is None:
3053 tag_set = set([_tag_constants.SERVING])
3054 if signature_key is None:
3055 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
3057 saved_model_converter = TFLiteSavedModelConverter(
3058 saved_model_dir, tag_set, [signature_key]
3059 )
3060 if saved_model_converter.saved_model_dir:
3061 return saved_model_converter
3063 result = _freeze_saved_model(
3064 saved_model_dir,
3065 input_arrays,
3066 input_shapes,
3067 output_arrays,
3068 tag_set,
3069 signature_key,
3070 )
3072 return cls(
3073 graph_def=result[0],
3074 input_tensors=result[1],
3075 output_tensors=result[2],
3076 experimental_debug_info_func=_build_debug_info_func(result[3]),
3077 )
3079 @classmethod
3080 def from_keras_model_file(
3081 cls,
3082 model_file,
3083 input_arrays=None,
3084 input_shapes=None,
3085 output_arrays=None,
3086 custom_objects=None,
3087 ):
3088 """Creates a TFLiteConverter class from a tf.keras model file.
3090 Args:
3091 model_file: Full filepath of HDF5 file containing the tf.keras model.
3092 input_arrays: List of input tensors to freeze graph with. Uses input
3093 arrays from SignatureDef when none are provided. (default None)
3094 input_shapes: Dict of strings representing input tensor names to list of
3095 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
3096 Automatically determined when input shapes is None (e.g., {"foo" :
3097 None}). (default None)
3098 output_arrays: List of output tensors to freeze graph with. Uses output
3099 arrays from SignatureDef when none are provided. (default None)
3100 custom_objects: Dict mapping names (strings) to custom classes or
3101 functions to be considered during model deserialization. (default None)
3103 Returns:
3104 TFLiteConverter class.
3105 """
3106 # pylint: disable=protected-access
3107 TFLiteConverterBase._set_original_model_type(
3108 conversion_metdata_fb.ModelType.KERAS_MODEL
3109 )
3110 # pylint: enable=protected-access
3111 return TFLiteKerasModelConverter(
3112 model_file, input_arrays, input_shapes, output_arrays, custom_objects
3113 )
3115 # pylint: disable=useless-super-delegation
3116 def convert(self):
3117 """Converts a TensorFlow GraphDef based on instance variables.
3119 Returns:
3120 The converted data in serialized format. Either a TFLite Flatbuffer or a
3121 Graphviz graph depending on value in `output_format`.
3123 Raises:
3124 ValueError:
3125 Input shape is not specified.
3126 None value for dimension in input_tensor.
3127 """
3128 return super(TFLiteConverter, self).convert()
3131@_tf_export(v1=["lite.TocoConverter"])
3132class TocoConverter:
3133 """Convert a TensorFlow model into `output_format`.
3135 This class has been deprecated. Please use `lite.TFLiteConverter` instead.
3136 """
3138 @classmethod
3139 @_deprecation.deprecated(
3140 None, "Use `lite.TFLiteConverter.from_session` instead."
3141 )
3142 def from_session(cls, sess, input_tensors, output_tensors):
3143 """Creates a TocoConverter class from a TensorFlow Session."""
3144 return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
3146 @classmethod
3147 @_deprecation.deprecated(
3148 None, "Use `lite.TFLiteConverter.from_frozen_graph` instead."
3149 )
3150 def from_frozen_graph(
3151 cls, graph_def_file, input_arrays, output_arrays, input_shapes=None
3152 ):
3153 """Creates a TocoConverter class from a file containing a frozen graph."""
3154 return TFLiteConverter.from_frozen_graph(
3155 graph_def_file, input_arrays, output_arrays, input_shapes
3156 )
3158 @classmethod
3159 @_deprecation.deprecated(
3160 None, "Use `lite.TFLiteConverter.from_saved_model` instead."
3161 )
3162 def from_saved_model(
3163 cls,
3164 saved_model_dir,
3165 input_arrays=None,
3166 input_shapes=None,
3167 output_arrays=None,
3168 tag_set=None,
3169 signature_key=None,
3170 ):
3171 """Creates a TocoConverter class from a SavedModel."""
3172 return TFLiteConverter.from_saved_model(
3173 saved_model_dir,
3174 input_arrays,
3175 input_shapes,
3176 output_arrays,
3177 tag_set,
3178 signature_key,
3179 )
3181 @classmethod
3182 @_deprecation.deprecated(
3183 None, "Use `lite.TFLiteConverter.from_keras_model_file` instead."
3184 )
3185 def from_keras_model_file(
3186 cls, model_file, input_arrays=None, input_shapes=None, output_arrays=None
3187 ):
3188 """Creates a TocoConverter class from a tf.keras model file."""
3189 return TFLiteConverter.from_keras_model_file(
3190 model_file, input_arrays, input_shapes, output_arrays
3191 )