Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/lite.py: 24%

1014 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""TensorFlow Lite tooling helper functionality.""" 

16 

17import enum 

18import functools 

19import pprint 

20import shutil 

21import sys 

22import tempfile 

23import time 

24import warnings 

25 

26from absl import logging 

27 

28from google.protobuf import text_format as _text_format 

29from google.protobuf.message import DecodeError 

30from tensorflow.core.framework import graph_pb2 as _graph_pb2 

31from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import 

32from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metdata_fb 

33from tensorflow.lite.python import lite_constants as constants 

34from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef 

35from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays 

36from tensorflow.lite.python.convert import convert_jax_hlo as _convert_jax_hlo 

37from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model 

38from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import 

39from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers 

40from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize 

41from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify 

42from tensorflow.lite.python.convert import OpsSet 

43from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import 

44from tensorflow.lite.python.convert_phase import Component 

45from tensorflow.lite.python.convert_phase import convert_phase 

46from tensorflow.lite.python.convert_phase import SubComponent 

47from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model 

48from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import 

49from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import 

50from tensorflow.lite.python.interpreter import OpResolverType # pylint: disable=unused-import 

51from tensorflow.lite.python.metrics import metrics 

52from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import 

53from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted 

54from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import 

55from tensorflow.lite.python.optimize import calibrator as _calibrator 

56from tensorflow.lite.python.util import _xla_computation 

57from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func 

58from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func 

59from tensorflow.lite.python.util import freeze_graph as _freeze_graph 

60from tensorflow.lite.python.util import get_debug_info as _get_debug_info 

61from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config 

62from tensorflow.lite.python.util import get_sparsity_modes as _get_sparsity_modes 

63from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name 

64from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names 

65from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name 

66from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph 

67from tensorflow.lite.python.util import model_input_signature as _model_input_signature 

68from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type 

69from tensorflow.lite.python.util import populate_conversion_metadata as _populate_conversion_metadata 

70from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations 

71from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes 

72from tensorflow.lite.python.util import trace_model_call as _trace_model_call 

73from tensorflow.lite.tools import flatbuffer_utils 

74from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugger # pylint: disable=unused-import 

75from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugOptions # pylint: disable=unused-import 

76from tensorflow.python import saved_model as _saved_model 

77from tensorflow.python.client import session as _session 

78from tensorflow.python.eager import context 

79from tensorflow.python.eager import def_function as _def_function 

80from tensorflow.python.eager import function as _function 

81from tensorflow.python.framework import byte_swap_tensor as bst 

82from tensorflow.python.framework import convert_to_constants as _convert_to_constants 

83from tensorflow.python.framework import dtypes as _dtypes 

84from tensorflow.python.framework import ops as _ops 

85from tensorflow.python.framework import versions 

86from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError 

87from tensorflow.python.framework.importer import import_graph_def as _import_graph_def 

88from tensorflow.python.platform import gfile 

89from tensorflow.python.saved_model import loader_impl as _loader_impl 

90from tensorflow.python.saved_model import save_options as _save_options 

91from tensorflow.python.saved_model import signature_constants as _signature_constants 

92from tensorflow.python.saved_model import tag_constants as _tag_constants 

93from tensorflow.python.saved_model.load import load as _load 

94from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info 

95from tensorflow.python.util import deprecation as _deprecation 

96from tensorflow.python.util import keras_deps 

97from tensorflow.python.util.tf_export import tf_export as _tf_export 

98 

99 

100@_tf_export("lite.Optimize") 

101class Optimize(enum.Enum): 

102 """Enum defining the optimizations to apply when generating a tflite model. 

103 

104 DEFAULT 

105 The default optimization strategy that enables post-training quantization. 

106 The type of post-training quantization that will be used is dependent on 

107 the other converter options supplied. Refer to the 

108 [documentation](/lite/performance/post_training_quantization) for further 

109 information on the types available and how to use them. 

110 

111 OPTIMIZE_FOR_SIZE 

112 Deprecated. Does the same as DEFAULT. 

113 

114 OPTIMIZE_FOR_LATENCY 

115 Deprecated. Does the same as DEFAULT. 

116 

117 EXPERIMENTAL_SPARSITY 

118 Experimental flag, subject to change. 

119 

120 Enable optimization by taking advantage of the sparse model weights 

121 trained with pruning. 

122 

123 The converter will inspect the sparsity pattern of the model weights and 

124 do its best to improve size and latency. 

125 The flag can be used alone to optimize float32 models with sparse weights. 

126 It can also be used together with the DEFAULT optimization mode to 

127 optimize quantized models with sparse weights. 

128 """ 

129 

130 # Default optimization strategy that quantizes model weights. Enhanced 

131 # optimizations are gained by providing a representative dataset that 

132 # quantizes biases and activations as well. 

133 # Converter will do its best to reduce size and latency, while minimizing 

134 # the loss in accuracy. 

135 DEFAULT = "DEFAULT" 

136 

137 # Deprecated. Does the same as DEFAULT. 

138 OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE" 

139 

140 # Deprecated. Does the same as DEFAULT. 

141 OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY" 

142 

143 # Experimental flag, subject to change. 

144 # Enable optimization by taking advantage of the sparse model weights trained 

145 # with pruning. 

146 # 

147 # The converter will inspect the sparsity pattern of the model weights and do 

148 # its best to improve size and latency. 

149 # The flag can be used alone to optimize float32 models with sparse weights. 

150 # It can also be used together with the DEFAULT optimization mode to optimize 

151 # quantized models with sparse weights. 

152 # TODO(b/161560631): Add log message when this optimization is applied. 

153 EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY" 

154 

155 def __str__(self): 

156 return str(self.value) 

157 

158 

159# TODO(b/198099651): move converter implementation out of lite.py 

160@_tf_export("lite.RepresentativeDataset") 

161class RepresentativeDataset: 

162 """Representative dataset used to optimize the model. 

163 

164 This wraps a generator function that provides a small dataset to calibrate or

165 estimate the range, i.e., (min, max) of all floating-point arrays in the model

166 (such as model input, activation outputs of intermediate layers, and model 

167 output) for quantization. Usually, this is a small subset of a few hundred 

168 samples randomly chosen, in no particular order, from the training or 

169 evaluation dataset. 

170 """ 

171 

172 def __init__(self, input_gen): 

173 """Creates a representative dataset. 

174 

175 Args: 

176 input_gen: A generator function that generates input samples for the model 

177 and has the same order, type and shape as the inputs to the model. 

178 Usually, this is a small subset of a few hundred samples randomly 

179 chosen, in no particular order, from the training or evaluation dataset. 

180 """ 

181 self.input_gen = input_gen 

182 

183 

184@_tf_export("lite.TargetSpec") 

185class TargetSpec: 

186 """Specification of target device used to optimize the model. 

187 

188 Attributes: 

189 supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet` 

190 options, where each option represents a set of operators supported by the 

191 target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})

192 supported_types: Set of `tf.dtypes.DType` data types supported on the target 

193 device. If initialized, optimization might be driven by the smallest type 

194 in this set. (default set()) 

195 experimental_select_user_tf_ops: Experimental flag, subject to change. Set 

196 of user's TensorFlow operators' names that are required in the TensorFlow 

197 Lite runtime. These ops will be exported as select TensorFlow ops in the 

198 model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is 

199 an advanced feature that should only be used if the client is using TF ops 

200 that may not be linked in by default with the TF ops that are provided 

201 when using the SELECT_TF_OPS path. The client is responsible for linking 

202 these ops into the target runtime. 

203 experimental_supported_backends: Experimental flag, subject to change. Set 

204 containing names of supported backends. Currently only "GPU" is supported, 

205 more options will be available later. 

206 """ 

207 

208 def __init__( 

209 self, 

210 supported_ops=None, 

211 supported_types=None, 

212 experimental_select_user_tf_ops=None, 

213 experimental_supported_backends=None, 

214 ): 

215 if supported_ops is None: 

216 supported_ops = {OpsSet.TFLITE_BUILTINS} 

217 self.supported_ops = supported_ops 

218 if supported_types is None: 

219 supported_types = set() 

220 self.supported_types = supported_types 

221 if experimental_select_user_tf_ops is None: 

222 experimental_select_user_tf_ops = set() 

223 self.experimental_select_user_tf_ops = experimental_select_user_tf_ops 

224 self.experimental_supported_backends = experimental_supported_backends 

225 self._experimental_custom_op_registerers = [] 

226 # Hint for the supported accumulation type used for inference. Typically 

227 # used for fp16 post-training quantization, where some models can use fp16 

228 # accumulators instead of the typical fp32 type. 

229 # TODO(b/188185962): Provide full API and authoring support for 

230 # reduced precision accumulation types. 

231 self._experimental_supported_accumulation_type = None 

232 

233 

234class QuantizationMode: 

235 """QuantizationMode determines the quantization type from user options.""" 

236 

237 def __init__( 

238 self, 

239 optimizations, 

240 target_spec, 

241 representative_dataset, 

242 graph_def, 

243 disable_per_channel=False, 

244 experimental_new_dynamic_range_quantizer=False, 

245 experimental_low_bit_qat=False, 

246 full_integer_quantization_bias_type=None, 

247 experimental_mlir_variable_quantization=False, 

248 ): 

249 self._optimizations = optimizations 

250 for deprecated_optimization in [ 

251 Optimize.OPTIMIZE_FOR_SIZE, 

252 Optimize.OPTIMIZE_FOR_LATENCY, 

253 ]: 

254 if deprecated_optimization in self._optimizations: 

255 logging.warning( 

256 ( 

257 "Optimization option %s is deprecated, please use" 

258 " optimizations=[Optimize.DEFAULT] instead." 

259 ), 

260 deprecated_optimization, 

261 ) 

262 

263 self._target_spec = target_spec 

264 self._representative_dataset = representative_dataset 

265 self._graph_def = graph_def 

266 

267 self._validate_int8_required() 

268 self._disable_per_channel = disable_per_channel 

269 

270 self._enable_new_dynamic_range_quantizer = ( 

271 experimental_new_dynamic_range_quantizer 

272 ) 

273 # Allow training with lower than 8 bit weights to be converted 

274 # to constants with trained scale. 

275 self._experimental_low_bit_qat = experimental_low_bit_qat 

276 

277 self._full_integer_quantization_bias_type = ( 

278 full_integer_quantization_bias_type 

279 ) 

280 self._validate_full_integer_quantization_bias_type() 

281 

282 self.enable_mlir_variable_quantization = ( 

283 experimental_mlir_variable_quantization 

284 ) 

285 

286 def is_post_training_int8_only_quantization(self): 

287 return ( 

288 self.is_any_optimization_enabled() 

289 and self._representative_dataset is not None 

290 and not self._is_int16x8_target_required() 

291 and not self.is_allow_float() 

292 and self._is_int8_target_required() 

293 ) 

294 

295 def is_post_training_int8_quantization_with_float_fallback(self): 

296 return ( 

297 self.is_any_optimization_enabled() 

298 and self._representative_dataset is not None 

299 and not self._is_int16x8_target_required() 

300 and self.is_allow_float() 

301 and self._smallest_supported_type() == _dtypes.int8 

302 ) 

303 

304 def is_post_training_int8_quantization(self): 

305 return ( 

306 self.is_post_training_int8_only_quantization() 

307 or self.is_post_training_int8_quantization_with_float_fallback() 

308 ) 

309 

310 def is_post_training_int16x8_only_quantization(self): 

311 return ( 

312 self.is_any_optimization_enabled() 

313 and self._representative_dataset is not None 

314 and self._is_int16x8_target_required() 

315 and not self.is_allow_float() 

316 ) 

317 

318 def is_post_training_int16x8_quantization_with_float_fallback(self): 

319 return ( 

320 self.is_any_optimization_enabled() 

321 and self._representative_dataset is not None 

322 and self._is_int16x8_target_required() 

323 and self.is_allow_float() 

324 ) 

325 

326 def is_post_training_int16x8_quantization(self): 

327 return ( 

328 self.is_post_training_int16x8_only_quantization() 

329 or self.is_post_training_int16x8_quantization_with_float_fallback() 

330 ) 

331 

332 def is_post_training_integer_quantization(self): 

333 return ( 

334 self.is_post_training_int8_quantization() 

335 or self.is_post_training_int16x8_quantization() 

336 ) 

337 

338 def is_low_bit_quantize_aware_training(self): 

339 return ( 

340 self.is_any_optimization_enabled() 

341 and self.is_quantization_aware_trained_model() 

342 and self._experimental_low_bit_qat 

343 ) 

344 

345 def is_quantization_aware_training(self): 

346 return ( 

347 self.is_any_optimization_enabled() 

348 and self.is_quantization_aware_trained_model() 

349 and not self.is_low_bit_quantize_aware_training() 

350 ) 

351 

352 def is_integer_quantization(self): 

353 return ( 

354 self.is_post_training_integer_quantization() 

355 or self.is_quantization_aware_training() 

356 or self.is_low_bit_quantize_aware_training() 

357 ) 

358 

359 def is_post_training_dynamic_range_quantization(self): 

360 # Post-training dynamic range quantization is only enabled if neither

361 # post-training int8 quantization nor training-time quantization was done.

362 return ( 

363 self.is_any_optimization_enabled() 

364 and self._representative_dataset is None 

365 and not self.is_quantization_aware_trained_model() 

366 and self._smallest_supported_type() == _dtypes.int8 

367 ) 

368 

369 def is_post_training_float16_quantization(self): 

370 return ( 

371 self.is_any_optimization_enabled() 

372 and self._smallest_supported_type().size == 2 

373 and _dtypes.float16 in self._target_spec.supported_types 

374 ) 

375 

376 def is_bfloat16_quantization(self): 

377 return ( 

378 self.is_any_optimization_enabled() 

379 and self._smallest_supported_type().size == 2 

380 and _dtypes.bfloat16 in self._target_spec.supported_types 

381 ) 

382 

383 def activations_type(self): 

384 if self.is_integer_quantization(): 

385 if self._is_int16x8_target_required(): 

386 return _dtypes.int16 

387 else: 

388 return _dtypes.int8 

389 else: 

390 return _dtypes.float32 

391 

392 def bias_type(self): 

393 if self._full_integer_quantization_bias_type: 

394 return self._full_integer_quantization_bias_type 

395 

396 if self.activations_type() == _dtypes.int16: 

397 return _dtypes.int64 

398 elif self.activations_type() == _dtypes.int8: 

399 return _dtypes.int32 

400 else: 

401 return _dtypes.float32 

402 

403 def converter_flags(self, inference_ty=None, inference_input_ty=None): 

404 """Flags to the converter.""" 

405 

406 if self.is_integer_quantization(): 

407 is_low_bit_qat = self.is_low_bit_quantize_aware_training() 

408 return { 

409 "inference_type": ( 

410 inference_ty 

411 if inference_ty is not None 

412 else self.activations_type() 

413 ), 

414 "inference_input_type": _dtypes.float32, 

415 "post_training_quantize": False, # disable dynamic range quantization 

416 "quantize_to_float16": False, # disable float16 quantization 

417 "disable_infer_tensor_range": is_low_bit_qat, 

418 "use_fake_quant_num_bits": is_low_bit_qat, 

419 "enable_mlir_variable_quantization": ( 

420 self.enable_mlir_variable_quantization 

421 ), 

422 } 

423 elif self.is_post_training_dynamic_range_quantization(): 

424 return { 

425 "inference_type": _dtypes.float32, 

426 "inference_input_type": _dtypes.float32, 

427 "post_training_quantize": True, # enable dynamic range quantization 

428 "quantize_to_float16": False, # disable float16 quantization 

429 # experimental: disable per-channel (per-axis) quantization. 

430 "disable_per_channel_quantization": self._disable_per_channel, 

431 "enable_mlir_dynamic_range_quantizer": ( 

432 self._enable_new_dynamic_range_quantizer 

433 ), 

434 "enable_mlir_variable_quantization": ( 

435 self.enable_mlir_variable_quantization 

436 ), 

437 } 

438 elif self.is_post_training_float16_quantization(): 

439 return { 

440 "inference_type": _dtypes.float32, 

441 "inference_input_type": _dtypes.float32, 

442 "post_training_quantize": True, 

443 "quantize_to_float16": True, # enable float16 quantization 

444 # pylint: disable=protected-access 

445 "accumulation_type": ( 

446 self._target_spec._experimental_supported_accumulation_type 

447 ), 

448 # pylint: enable=protected-access 

449 "allow_bfloat16": self.is_bfloat16_quantization(), 

450 "enable_mlir_dynamic_range_quantizer": ( 

451 self._enable_new_dynamic_range_quantizer 

452 ), 

453 "enable_mlir_variable_quantization": ( 

454 self.enable_mlir_variable_quantization 

455 ), 

456 } 

457 else: 

458 # Note this might still trigger (uint8) quantization to be compatible with 

459 # the old converter. 

460 return { 

461 "inference_type": ( 

462 inference_ty if inference_ty is not None else _dtypes.float32 

463 ), 

464 "inference_input_type": inference_input_ty, 

465 "post_training_quantize": False, # enable dynamic range quantization 

466 "quantize_to_float16": False, # disable float16 quantization 

467 "allow_bfloat16": self.is_bfloat16_quantization(), 

468 } 

469 

470 # Below are helpers for the above functions. 

471 

472 def _validate_int8_required(self): 

473 """Int8 mode requires certain parameters to exist and be compatible.""" 

474 if not self._is_int8_target_required(): 

475 return 

476 

477 # Validate the target_spec attribute.

478 if set(self._target_spec.supported_ops) == { 

479 OpsSet.TFLITE_BUILTINS_INT8 

480 } and not ( 

481 set(self._target_spec.supported_types) == set() 

482 or set(self._target_spec.supported_types) == {_dtypes.int8} 

483 ): 

484 raise ValueError( 

485 "As full integer quantization has been enabled by setting " 

486 "`target_spec.supported_ops`={tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, " 

487 "thus `target_spec.supported_types` should be left uninitizalized " 

488 "or set to {tf.int8}." 

489 ) 

490 if set(self._target_spec.supported_types) == {_dtypes.int8}: 

491 self._target_spec.supported_ops = {OpsSet.TFLITE_BUILTINS_INT8} 

492 

493 # Check if representative_dataset is specified. 

494 if ( 

495 not self._representative_dataset 

496 and not self.is_quantization_aware_training() 

497 ): 

498 raise ValueError( 

499 "For full integer quantization, a " 

500 "`representative_dataset` must be specified." 

501 ) 

502 

503 # Update representative dataset to the expected format.

504 if self._representative_dataset: 

505 if not isinstance(self._representative_dataset, RepresentativeDataset): 

506 self._representative_dataset = RepresentativeDataset( 

507 self._representative_dataset 

508 ) 

509 

510 def _validate_full_integer_quantization_bias_type(self): 

511 """Validates bias type for full interger quantization.""" 

512 bias_type = self._full_integer_quantization_bias_type 

513 if not bias_type: 

514 return 

515 

516 if self.activations_type() == _dtypes.float32: 

517 raise ValueError( 

518 "`full_integer_quantization_bias_type` is only supported for full" 

519 " integer quantization." 

520 ) 

521 

522 if self.activations_type() == _dtypes.int8 and bias_type != _dtypes.int32: 

523 raise ValueError( 

524 "Expected bias type to be `dtypes.int32` for Int8Quant. " 

525 f"Current setting bias type: {bias_type}" 

526 ) 

527 

528 if ( 

529 self.activations_type() == _dtypes.int16 

530 and bias_type != _dtypes.int32 

531 and bias_type != _dtypes.int64 

532 ): 

533 raise ValueError( 

534 "Expected bias type to be `dtypes.int32` or `dtypes.int64` for " 

535 f"Int16Quant. Current setting bias type: {bias_type}" 

536 ) 

537 

538 def _is_int8_target_required(self): 

539 return ( 

540 OpsSet.TFLITE_BUILTINS_INT8 in set(self._target_spec.supported_ops) 

541 ) or (set(self._target_spec.supported_types) == set([_dtypes.int8])) 

542 

543 def _is_int16x8_target_required(self): 

544 return ( 

545 OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 

546 in set(self._target_spec.supported_ops) 

547 ) 

548 

549 def is_allow_float(self): 

550 return (OpsSet.TFLITE_BUILTINS in set(self._target_spec.supported_ops)) or ( 

551 OpsSet.SELECT_TF_OPS in set(self._target_spec.supported_ops) 

552 ) 

553 

554 def is_any_optimization_enabled(self): 

555 return bool( 

556 set(self._optimizations).intersection([ 

557 Optimize.OPTIMIZE_FOR_LATENCY, 

558 Optimize.OPTIMIZE_FOR_SIZE, 

559 Optimize.DEFAULT, 

560 ]) 

561 ) 

562 

563 def _smallest_supported_type(self): 

564 if self._target_spec.supported_types: 

565 return min(self._target_spec.supported_types, key=lambda x: x.size) 

566 else: 

567 # The default smallest supported type is INT8. 

568 return _dtypes.int8 

569 

570 def is_quantization_aware_trained_model(self): 

571 """Checks if the graph contains any training-time quantization ops.""" 

572 training_quant_ops = frozenset({ 

573 "FakeQuantWithMinMaxVars", 

574 "FakeQuantWithMinMaxVarsPerChannel", 

575 "FakeQuantWithMinMaxArgs", 

576 "QuantizeAndDequantizeV2", 

577 "QuantizeAndDequantizeV3", 

578 }) 

579 

580 if self._graph_def: 

581 for node_def in self._graph_def.node: 

582 if node_def.op in training_quant_ops: 

583 return True 

584 for function in self._graph_def.library.function: 

585 for node_def in function.node_def: 

586 if node_def.op in training_quant_ops: 

587 return True 

588 return False 
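
# Example (not part of lite.py): a hedged illustration of how the
# QuantizationMode predicates above classify user options. QuantizationMode is
# an internal helper; end users normally only set converter attributes, and
# the function name below is purely illustrative.
def _example_quantization_mode():
  mode = QuantizationMode(
      optimizations={Optimize.DEFAULT},
      target_spec=TargetSpec(supported_types={_dtypes.float16}),
      representative_dataset=None,
      graph_def=None,
  )
  # Optimize.DEFAULT plus float16 in supported_types, with no representative
  # dataset, resolves to post-training float16 quantization.
  assert mode.is_post_training_float16_quantization()
  assert not mode.is_post_training_integer_quantization()
  assert mode.activations_type() == _dtypes.float32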

589 

590 

591class TFLiteConverterBase: 

592 """Converter subclass to share functionality between V1 and V2 converters.""" 

593 

594 # Stores the original model type temporarily to transmit the information 

595 # from the factory class methods to TFLiteConverterBase init function. 

596 _original_model_type = conversion_metdata_fb.ModelType.NONE 

597 

598 def __init__(self): 

599 self.optimizations = set() 

600 self.representative_dataset = None 

601 self.target_spec = TargetSpec() 

602 self.allow_custom_ops = False 

603 self.experimental_new_converter = True 

604 self.experimental_new_quantizer = True 

605 self.experimental_enable_resource_variables = True 

606 self._experimental_calibrate_only = False 

607 self._experimental_sparsify_model = False 

608 self._experimental_disable_per_channel = False 

609 self._debug_info = None # contains the stack traces of all the original 

610 # nodes in the `GraphDef` passed to the converter.

611 self.saved_model_dir = None 

612 self._saved_model_tags = None 

613 self._saved_model_version = 0 

614 self._saved_model_exported_names = [] 

615 self._tflite_metrics = metrics.TFLiteConverterMetrics() 

616 self._collected_converter_params = {} 

617 self._experimental_disable_batchmatmul_unfold = False 

618 self._experimental_lower_tensor_list_ops = True 

619 self._experimental_default_to_single_batch_in_tensor_list_ops = False 

620 self._experimental_unfold_large_splat_constant = False 

621 self._experimental_tf_quantization_mode = None 

622 # If unset, the bias type defaults to int32, except for 16x8 quantization,

623 # where int64 is used by default to prevent overflow.

624 self._experimental_full_integer_quantization_bias_type = None 

625 # Provides specs for quantization, whether preset or custom. 

626 self._experimental_quantization_options = None 

627 # Initializes conversion metadata. 

628 self.exclude_conversion_metadata = False 

629 self._metadata = conversion_metdata_fb.ConversionMetadataT() 

630 self._metadata.environment = conversion_metdata_fb.EnvironmentT() 

631 self._metadata.options = conversion_metdata_fb.ConversionOptionsT() 

632 self._metadata.environment.tensorflowVersion = versions.__version__ 

633 self._metadata.environment.modelType = self._get_original_model_type() 

634 self._experimental_enable_dynamic_update_slice = False 

635 self._experimental_preserve_assert_op = False 

636 self._experimental_guarantee_all_funcs_one_use = False 

637 

638 # When the value is true, the MLIR quantizer triggers dynamic range

639 # quantization in MLIR instead of the old quantizer. Used only if 

640 # experimental_new_quantizer is on. 

641 self.experimental_new_dynamic_range_quantizer = True 

642 # Experimental flag to enable low-bit QAT in 8 bit. 

643 self._experimental_low_bit_qat = False 

644 # Experimental flag to add all TF ops (including custom TF ops) to the 

645 # converted model as flex ops. 

646 self._experimental_allow_all_select_tf_ops = False 

647 

648 self._experimental_variable_quantization = False 

649 self._experimental_disable_fuse_mul_and_fc = False 

650 

651 def _grappler_config(self, optimizers=None): 

652 """Creates a tf.compat.v1.ConfigProto for configuring Grappler. 

653 

654 Args: 

655 optimizers: List of strings that represents the list of optimizers. 

656 

657 Returns: 

658 tf.ConfigProto. 

659 """ 

660 if not optimizers: 

661 optimizers = [] 

662 # MLIR converter will take care of constant folding instead of grappler. 

663 if not self.experimental_new_converter: 

664 optimizers.append("constfold") 

665 

666 is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set( 

667 self.target_spec.supported_ops 

668 ) 

669 if is_only_flex_enabled: 

670 # The layout optimizer turns NHWC to NCHW. This provides performance

671 # optimizations when Flex mode is enabled. However, this is not compatible 

672 # with builtin ops. 

673 optimizers.append("layout") 

674 return _get_grappler_config(optimizers) 

675 

676 def _quantize( 

677 self, 

678 result, 

679 input_type, 

680 output_type, 

681 activations_type, 

682 bias_type, 

683 allow_float, 

684 enable_variable_quantization, 

685 ): 

686 """Quantize the model.""" 

687 # pylint: disable=protected-access 

688 custom_op_registerers_by_name = [ 

689 x 

690 for x in self.target_spec._experimental_custom_op_registerers 

691 if isinstance(x, str) 

692 ] 

693 custom_op_registerers_by_func = [ 

694 x 

695 for x in self.target_spec._experimental_custom_op_registerers 

696 if not isinstance(x, str) 

697 ] 

698 # pylint: enable=protected-access 

699 if not isinstance(self.representative_dataset, RepresentativeDataset): 

700 self.representative_dataset = RepresentativeDataset( 

701 self.representative_dataset 

702 ) 

703 

704 # Add intermediate tensors to the model if needed. 

705 result = _calibrator.add_intermediate_tensors(result) 

706 calibrate_quantize = _calibrator.Calibrator( 

707 result, custom_op_registerers_by_name, custom_op_registerers_by_func 

708 ) 

709 if self._experimental_calibrate_only or self.experimental_new_quantizer: 

710 calibrated = calibrate_quantize.calibrate( 

711 self.representative_dataset.input_gen 

712 ) 

713 

714 if self._experimental_calibrate_only: 

715 return calibrated 

716 elif self.experimental_new_quantizer and ( 

717 activations_type != _dtypes.int16 

718 ): 

719 # TODO(b/175659372): remove the activations_type restriction and enable 

720 # it for all the activation types. 

721 return _mlir_quantize( 

722 calibrated, 

723 self._experimental_disable_per_channel, 

724 input_data_type=input_type, 

725 output_data_type=output_type, 

726 enable_variable_quantization=enable_variable_quantization, 

727 ) 

728 else: 

729 return calibrate_quantize.calibrate_and_quantize( 

730 self.representative_dataset.input_gen, 

731 input_type, 

732 output_type, 

733 allow_float, 

734 activations_type, 

735 bias_type, 

736 disable_per_channel=self._experimental_disable_per_channel, 

737 ) 

738 

739 def _is_unknown_shapes_allowed(self): 

740 # Unknown dimensions are only allowed with the new converter. 

741 return self.experimental_new_converter 

742 

743 def _get_base_converter_args(self): 

744 """Returns the base converter args. 

745 

746 Returns: 

747 {key str: val} 

748 """ 

749 args = { 

750 "input_format": constants.TENSORFLOW_GRAPHDEF, 

751 "allow_custom_ops": self.allow_custom_ops, 

752 "debug_info": self._debug_info, 

753 "target_ops": self.target_spec.supported_ops, 

754 "enable_mlir_converter": self.experimental_new_converter, 

755 "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops, 

756 "supported_backends": self.target_spec.experimental_supported_backends, 

757 "unfold_batchmatmul": not self._experimental_disable_batchmatmul_unfold, 

758 "lower_tensor_list_ops": self._experimental_lower_tensor_list_ops, 

759 "unfold_large_splat_constant": ( 

760 self._experimental_unfold_large_splat_constant 

761 ), 

762 "default_to_single_batch_in_tensor_list_ops": ( 

763 self._experimental_default_to_single_batch_in_tensor_list_ops 

764 ), 

765 "tf_quantization_mode": self._experimental_tf_quantization_mode, 

766 "experimental_enable_resource_variables": ( 

767 self.experimental_enable_resource_variables 

768 ), 

769 "enable_dynamic_update_slice": ( 

770 self._experimental_enable_dynamic_update_slice 

771 ), 

772 "preserve_assert_op": self._experimental_preserve_assert_op, 

773 "guarantee_all_funcs_one_use": ( 

774 self._experimental_guarantee_all_funcs_one_use 

775 ), 

776 "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops, 

777 "disable_fuse_mul_and_fc": self._experimental_disable_fuse_mul_and_fc, 

778 "quantization_options": self._experimental_quantization_options, 

779 } 

780 

781 if self.saved_model_dir: 

782 args.update({ 

783 "saved_model_dir": self.saved_model_dir, 

784 "saved_model_version": self._saved_model_version, 

785 "saved_model_tags": self._saved_model_tags, 

786 "saved_model_exported_names": self._saved_model_exported_names, 

787 }) 

788 

789 return args 

790 

791 def _contains_function_with_implements_attr(self, saved_model_proto): 

792 meta_graph = saved_model_proto.meta_graphs[0] 

793 for function in meta_graph.graph_def.library.function: 

794 if function.attr.get("_implements", None) or function.attr.get( 

795 "api_implements", None 

796 ): 

797 return True 

798 return False 

799 

800 def _parse_saved_model_args(self, always_enable_saved_model_import=False): 

801 """Parses SavedModel arguments from the given Keras/RNN SavedModel. 

802 

803 Args: 

804 always_enable_saved_model_import: Bool. When the value is true, it enables 

805 the MLIR saved model import path regardless of the other conditions.

806 """ 

807 if not self.experimental_new_converter: 

808 self.saved_model_dir = None 

809 return 

810 if self.saved_model_dir: 

811 try: 

812 saved_model_proto, _ = _parse_saved_model_with_debug_info( 

813 self.saved_model_dir 

814 ) 

815 except OSError: 

816 # If it fails to read the given saved model, it will fall back to the 

817 # frozen graph def path. 

818 self.saved_model_dir = None 

819 return 

820 if ( 

821 not always_enable_saved_model_import 

822 and not self._contains_function_with_implements_attr( 

823 saved_model_proto 

824 ) 

825 ): 

826 self.saved_model_dir = None 

827 return 

828 

829 if not self._saved_model_exported_names: 

830 self._saved_model_exported_names = [] 

831 self._saved_model_version = saved_model_proto.saved_model_schema_version 

832 if self._saved_model_version == 0: 

833 self.saved_model_dir = None 

834 logging.warning("SavedModel schema version is zero.") 

835 return 

836 if self._saved_model_version not in [1, 2]: 

837 raise ValueError( 

838 "SavedModel file format({0}) is not supported".format( 

839 self._saved_model_version 

840 ) 

841 ) 

842 

843 def _sparsify_model(self): 

844 return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations 

845 

846 def _increase_conversion_attempt_metric(self): 

847 self._tflite_metrics.increase_counter_converter_attempt() 

848 

849 def _increase_conversion_success_metric(self): 

850 self._tflite_metrics.increase_counter_converter_success() 

851 

852 @classmethod 

853 def _set_original_model_type(cls, model_type): 

854 """Stores the original model type.""" 

855 if model_type == conversion_metdata_fb.ModelType.NONE: 

856 raise ValueError("The original model type should be specified.") 

857 cls._original_model_type = model_type 

858 

859 def _get_original_model_type(self): 

860 """One-time getter to return original model type and set it to NONE.""" 

861 model_type = TFLiteConverterBase._original_model_type 

862 TFLiteConverterBase._original_model_type = ( 

863 conversion_metdata_fb.ModelType.NONE 

864 ) 

865 return model_type 

866 

867 def _save_conversion_params_metric( 

868 self, graph_def=None, inference_type=None, inference_input_type=None 

869 ): 

870 """Set conversion parameter metrics.""" 

871 converter_kwargs = self._collected_converter_params 

872 converter_kwargs.update(self._get_base_converter_args()) 

873 

874 # Optimization parameters. 

875 quant_mode = QuantizationMode( 

876 self.optimizations, 

877 self.target_spec, 

878 self.representative_dataset, 

879 graph_def, 

880 self._experimental_disable_per_channel, 

881 self.experimental_new_dynamic_range_quantizer, 

882 self._experimental_low_bit_qat, 

883 self._experimental_full_integer_quantization_bias_type, 

884 self._experimental_variable_quantization, 

885 ) 

886 converter_kwargs.update({ 

887 "tf_version": self._metadata.environment.tensorflowVersion, 

888 "api_version": self._metadata.environment.apiVersion, 

889 "original_model_format": self._metadata.environment.modelType, 

890 "optimization_default": quant_mode.is_any_optimization_enabled(), 

891 "optimization_post_training_dynamic_range": ( 

892 quant_mode.is_post_training_dynamic_range_quantization() 

893 ), 

894 "optimization_post_training_float16": ( 

895 quant_mode.is_post_training_float16_quantization() 

896 ), 

897 "optimization_post_training_integer_quantize": ( 

898 quant_mode.is_post_training_integer_quantization() 

899 ), 

900 "optimization_qat": quant_mode.is_quantization_aware_training(), 

901 "optimization_low_bit_qat": ( 

902 quant_mode.is_low_bit_quantize_aware_training() 

903 ), 

904 "optimization_sparsify": self._sparsify_model(), 

905 "activations_type": quant_mode.activations_type(), 

906 }) 

907 converter_kwargs.update( 

908 quant_mode.converter_flags(inference_type, inference_input_type) 

909 ) 

910 

911 # pylint: disable=protected-access 

912 if self.target_spec._experimental_supported_accumulation_type: 

913 converter_kwargs.update( 

914 { 

915 "accumulation_type": ( 

916 self.target_spec._experimental_supported_accumulation_type 

917 ) 

918 } 

919 ) 

920 # pylint: enable=protected-access 

921 

922 def format_element(elem): 

923 if isinstance(elem, enum.Enum): 

924 return str(elem.value) 

925 return pprint.pformat(elem) 

926 

927 def format_param(param): 

928 if isinstance(param, (list, tuple, set)): 

929 if not param: 

930 return "None" # Return None if empty. 

931 string_list = [format_element(x) for x in param] 

932 return ",".join(sorted(string_list)) 

933 return format_element(param) 

934 

935 for key, value in converter_kwargs.items(): 

936 self._tflite_metrics.set_converter_param(key, format_param(value)) 

937 self._tflite_metrics.set_export_required() 

938 

939 # Set conversion option metadata. 

940 self._metadata.options.allowCustomOps = self.allow_custom_ops 

941 self._metadata.options.enableSelectTfOps = ( 

942 OpsSet.SELECT_TF_OPS in self.target_spec.supported_ops 

943 ) 

944 self._metadata.options.forceSelectTfOps = set( 

945 [OpsSet.SELECT_TF_OPS] 

946 ) == set(self.target_spec.supported_ops) 

947 self._metadata.options.modelOptimizationModes = [] 

948 

949 if quant_mode.is_post_training_float16_quantization(): 

950 self._metadata.options.modelOptimizationModes.append( 

951 conversion_metdata_fb.ModelOptimizationMode.PTQ_FLOAT16 

952 ) 

953 

954 if quant_mode.is_post_training_dynamic_range_quantization(): 

955 self._metadata.options.modelOptimizationModes.append( 

956 conversion_metdata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE 

957 ) 

958 

959 if quant_mode.is_post_training_int8_quantization(): 

960 self._metadata.options.modelOptimizationModes.append( 

961 conversion_metdata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER 

962 ) 

963 

964 if quant_mode.is_post_training_int16x8_quantization(): 

965 self._metadata.options.modelOptimizationModes.append( 

966 conversion_metdata_fb.ModelOptimizationMode.PTQ_INT16 

967 ) 

968 

969 if quant_mode.is_quantization_aware_training(): 

970 self._metadata.options.modelOptimizationModes.append( 

971 conversion_metdata_fb.ModelOptimizationMode.QUANTIZATION_AWARE_TRAINING 

972 ) 

973 

974 def _set_conversion_latency_metric(self, value): 

975 self._tflite_metrics.set_converter_latency(value) 

976 

977 @convert_phase(Component.OPTIMIZE_TFLITE_MODEL) 

978 def _optimize_tflite_model(self, model, quant_mode, quant_io=True): 

979 """Apply optimizations on a TFLite model.""" 

980 

981 if quant_mode.is_integer_quantization(): 

982 in_type, out_type = self.inference_input_type, self.inference_output_type 

983 

984 if quant_mode.is_post_training_integer_quantization(): 

985 q_in_type = in_type if in_type and quant_io else _dtypes.float32 

986 q_out_type = out_type if out_type and quant_io else _dtypes.float32 

987 q_activations_type = quant_mode.activations_type() 

988 q_bias_type = quant_mode.bias_type() 

989 q_allow_float = quant_mode.is_allow_float() 

990 q_variable_quantization = quant_mode.enable_mlir_variable_quantization 

991 model = self._quantize( 

992 model, 

993 q_in_type, 

994 q_out_type, 

995 q_activations_type, 

996 q_bias_type, 

997 q_allow_float, 

998 q_variable_quantization, 

999 ) 

1000 

1001 m_in_type = in_type if in_type else _dtypes.float32 

1002 m_out_type = out_type if out_type else _dtypes.float32 

1003 # Skip updating model io types if MLIR quantizer already takes care of it 

1004 if not ( 

1005 quant_mode.is_post_training_integer_quantization() 

1006 and self.experimental_new_quantizer 

1007 and quant_io 

1008 and (m_in_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32]) 

1009 and (m_out_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32]) 

1010 ): 

1011 model = _modify_model_io_type(model, m_in_type, m_out_type) 

1012 

1013 if self._sparsify_model(): 

1014 model = _mlir_sparsify(model) 

1015 

1016 try: 

1017 model = _deduplicate_readonly_buffers(model) 

1018 except Exception: # pylint: disable=broad-except 

1019 # Skip buffer deduplication when flatbuffer library is not ready to be 

1020 # utilized. 

1021 logging.warning( 

1022 "Buffer deduplication procedure will be skipped when flatbuffer " 

1023 "library is not properly loaded" 

1024 ) 

1025 

1026 return model 

1027 

1028 def _convert_and_export_metrics(self, convert_func, *args, **kwargs): 

1029 """Wraps around convert function to export metrics. 

1030 

1031 Args: 

1032 convert_func: The convert function to wrap. 

1033 *args: Positional arguments of the convert function. 

1034 **kwargs: The keyword arguments of the convert function. 

1035 

1036 Returns: 

1037 The converted model, with conversion metadata populated unless excluded.

1038 """ 

1039 self._increase_conversion_attempt_metric() 

1040 self._save_conversion_params_metric() 

1041 start_time = time.process_time() 

1042 result = convert_func(self, *args, **kwargs) 

1043 elapsed_time_ms = (time.process_time() - start_time) * 1000 

1044 if result: 

1045 self._increase_conversion_success_metric() 

1046 self._set_conversion_latency_metric(round(elapsed_time_ms)) 

1047 self._tflite_metrics.export_metrics() 

1048 if self.exclude_conversion_metadata: 

1049 return result 

1050 model_object = flatbuffer_utils.convert_bytearray_to_object(result) 

1051 # Populates the conversion metadata. 

1052 # TODO(b/202090541): Collects sparsity block size information. 

1053 sparsity_modes = _get_sparsity_modes(model_object) 

1054 self._metadata.options.modelOptimizationModes.extend(sparsity_modes) 

1055 model_object = _populate_conversion_metadata(model_object, self._metadata) 

1056 return flatbuffer_utils.convert_object_to_bytearray(model_object) 

1057 

1058 

1059def _export_metrics(convert_func): 

1060 """The decorator around convert function to export metrics.""" 

1061 

1062 @functools.wraps(convert_func) 

1063 def wrapper(self, *args, **kwargs): 

1064 # pylint: disable=protected-access 

1065 return self._convert_and_export_metrics(convert_func, *args, **kwargs) 

1066 # pylint: enable=protected-access 

1067 

1068 return wrapper 

1069 

1070 

1071class TFLiteConverterBaseV2(TFLiteConverterBase): 

1072 """Converter subclass to share functionality between V2 converters.""" 

1073 

1074 def __init__(self): 

1075 """Constructor for TFLiteConverter.""" 

1076 super(TFLiteConverterBaseV2, self).__init__() 

1077 self.inference_input_type = _dtypes.float32 

1078 self.inference_output_type = _dtypes.float32 

1079 self._metadata.environment.apiVersion = 2 

1080 

1081 def _validate_inference_input_output_types(self, quant_mode): 

1082 """Validate inference_input_type and inference_output_type flags.""" 

1083 default_types = [_dtypes.float32] 

1084 # We support integer input/output for integer quantized models only. 

1085 if quant_mode.is_integer_quantization(): 

1086 if quant_mode.is_post_training_int16x8_quantization(): 

1087 all_types = default_types + [_dtypes.int16] 

1088 else: 

1089 all_types = default_types + [_dtypes.int8, _dtypes.uint8] 

1090 if ( 

1091 self.inference_input_type not in all_types 

1092 or self.inference_output_type not in all_types 

1093 ): 

1094 all_types_names = ["tf." + t.name for t in all_types] 

1095 raise ValueError( 

1096 "The inference_input_type and inference_output_type " 

1097 "must be in {}.".format(all_types_names) 

1098 ) 

1099 elif ( 

1100 self.inference_input_type not in default_types 

1101 or self.inference_output_type not in default_types 

1102 ): 

1103 raise ValueError( 

1104 "The inference_input_type and inference_output_type " 

1105 "must be tf.float32." 

1106 ) 

1107 

1108 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL) 

1109 def _load_saved_model(self, saved_model_dir, saved_model_tags): 

1110 """Load graph_def from saved model with the default serving signature key. 

1111 

1112 Args: 

1113 saved_model_dir: Directory of the SavedModel. 

1114 saved_model_tags: Set of tags identifying the MetaGraphDef within the 

1115 SavedModel to analyze. 

1116 

1117 Returns: 

1118 graph_def: The loaded GraphDef. 

1119 input_tensors: List of input tensors. 

1120 output_tensors: List of output tensors. 

1121 """ 

1122 graph = _ops.Graph() 

1123 saved_model = _loader_impl.SavedModelLoader(saved_model_dir) 

1124 saved_model.load_graph(graph, tags=saved_model_tags) 

1125 meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags) 

1126 graph_def = meta_graph.graph_def 

1127 signature_def = meta_graph.signature_def[ 

1128 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 

1129 ] 

1130 input_tensors = [ 

1131 graph.get_tensor_by_name(signature_def.inputs[key].name) 

1132 for key in signature_def.inputs 

1133 ] 

1134 output_tensors = [ 

1135 graph.get_tensor_by_name(signature_def.outputs[key].name) 

1136 for key in signature_def.outputs 

1137 ] 

1138 return graph_def, input_tensors, output_tensors 

1139 

1140 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS) 

1141 def _validate_inputs(self, graph_def, input_tensors): 

1142 """Validate the input parameters. 

1143 

1144 Args: 

1145 graph_def: The TensorFlow GraphDef. 

1146 input_tensors: List of input tensors. 

1147 Raises:

1148 ValueError: Input shape is not specified. Invalid quantization parameters. 

1149 """ 

1150 # Update conversion params with graph_def. 

1151 self._save_conversion_params_metric(graph_def) 

1152 self._quant_mode = QuantizationMode( 

1153 self.optimizations, 

1154 self.target_spec, 

1155 self.representative_dataset, 

1156 graph_def, 

1157 self._experimental_disable_per_channel, 

1158 self.experimental_new_dynamic_range_quantizer, 

1159 self._experimental_low_bit_qat, 

1160 self._experimental_full_integer_quantization_bias_type, 

1161 self._experimental_variable_quantization, 

1162 ) 

1163 self._validate_inference_input_output_types(self._quant_mode) 

1164 

1165 if not self._is_unknown_shapes_allowed(): 

1166 # Checks dimensions in input tensor. 

1167 for tensor in input_tensors: 

1168 # Note that shape_list might be empty for scalar shapes. 

1169 shape_list = tensor.shape.as_list() 

1170 if None in shape_list[1:]: 

1171 raise ValueError( 

1172 "None is only supported in the 1st dimension. Tensor '{0}' has " 

1173 "invalid shape '{1}'.".format( 

1174 _get_tensor_name(tensor), shape_list 

1175 ) 

1176 ) 

1177 elif shape_list and shape_list[0] is None: 

1178 # Set the batch size to 1 if undefined. 

1179 shape = tensor.shape.as_list() 

1180 shape[0] = 1 

1181 tensor.set_shape(shape) 

1182 

1183 if self._trackable_obj is None or not hasattr( 

1184 self._trackable_obj, "graph_debug_info" 

1185 ): 

1186 self._debug_info = _get_debug_info( 

1187 _build_debug_info_func(self._funcs[0].graph), graph_def 

1188 ) 

1189 else: 

1190 self._debug_info = _get_debug_info( 

1191 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 

1192 graph_def, 

1193 ) 

1194 

1195 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL) 

1196 def _optimize_tf_model( 

1197 self, graph_def, input_tensors, output_tensors, frozen_func 

1198 ): 

1199 """Run a Grappler pass to optimize the TensorFlow graph. 

1200 

1201 Args: 

1202 graph_def: Frozen GraphDef to be optimized. 

1203 input_tensors: List of input tensors. 

1204 output_tensors: List of output tensors. 

1205 frozen_func: TensorFlow Graph. 

1206 

1207 Returns: 

1208 The optimized TensorFlow graph. 

1209 """ 

1210 grappler_config = self._grappler_config() 

1211 # Skip running grappler when there are no optimizers to run. If not, 

1212 # grappler will run with the default optimizer set, which can lead to

1213 # unexpected behavior.

1214 if grappler_config.graph_options.rewrite_options.optimizers: 

1215 graph_def = _run_graph_optimizations( 

1216 graph_def, 

1217 input_tensors, 

1218 output_tensors, 

1219 config=grappler_config, 

1220 graph=frozen_func.graph, 

1221 ) 

1222 return graph_def 

1223 

1224 def _convert_from_saved_model(self, graph_def): 

1225 """Helper method that converts saved model. 

1226 

1227 Args: 

1228 graph_def: GraphDef object for the model, used only for stats. 

1229 

1230 Returns: 

1231 The converted TFLite model. 

1232 """ 

1233 # Update conversion params with graph_def. 

1234 self._save_conversion_params_metric(graph_def) 

1235 # Get quantization options and do some sanity checks. 

1236 quant_mode = QuantizationMode( 

1237 self.optimizations, 

1238 self.target_spec, 

1239 self.representative_dataset, 

1240 graph_def, 

1241 self._experimental_disable_per_channel, 

1242 self.experimental_new_dynamic_range_quantizer, 

1243 self._experimental_low_bit_qat, 

1244 self._experimental_full_integer_quantization_bias_type, 

1245 self._experimental_variable_quantization, 

1246 ) 

1247 self._validate_inference_input_output_types(quant_mode) 

1248 converter_kwargs = { 

1249 "enable_tflite_resource_variables": ( 

1250 self.experimental_enable_resource_variables 

1251 ) 

1252 } 

1253 converter_kwargs.update(self._get_base_converter_args()) 

1254 converter_kwargs.update(quant_mode.converter_flags()) 

1255 

1256 result = _convert_saved_model(**converter_kwargs) 

1257 return self._optimize_tflite_model( 

1258 result, quant_mode, quant_io=self.experimental_new_quantizer 

1259 ) 

1260 

1261 def convert(self, graph_def, input_tensors, output_tensors): 

1262 """Converts a TensorFlow GraphDef based on instance variables. 

1263 

1264 Args: 

1265 graph_def: Frozen TensorFlow GraphDef. 

1266 input_tensors: List of input tensors. 

1267 output_tensors: List of output tensors. 

1268 

1269 Returns: 

1270 The converted data in serialized format. 

1271 

1272 Raises: 

1273 ValueError: 

1274 No concrete function is specified.

1275 Multiple concrete functions are specified. 

1276 Input shape is not specified. 

1277 Invalid quantization parameters. 

1278 """ 

1279 self._validate_inputs(graph_def, input_tensors) 

1280 converter_kwargs = self._get_base_converter_args() 

1281 converter_kwargs.update(self._quant_mode.converter_flags()) 

1282 if not self.experimental_new_converter: 

1283 logging.warning( 

1284 "Please consider switching to the new converter by setting " 

1285 "experimental_new_converter=True. " 

1286 "The old converter is deprecated." 

1287 ) 

1288 else: 

1289 logging.info( 

1290 "Using new converter: If you encounter a problem " 

1291 "please file a bug. You can opt-out " 

1292 "by setting experimental_new_converter=False" 

1293 ) 

1294 

1295 # Converts model. 

1296 result = _convert_graphdef( 

1297 input_data=graph_def, 

1298 input_tensors=input_tensors, 

1299 output_tensors=output_tensors, 

1300 **converter_kwargs, 

1301 ) 

1302 

1303 return self._optimize_tflite_model( 

1304 result, self._quant_mode, quant_io=self.experimental_new_quantizer 

1305 ) 

1306 

1307 

1308class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): 

1309 """Converts the given SavedModel into TensorFlow Lite model. 

1310 

1311 Attributes: 

1312 saved_model_dir: Directory of the SavedModel. 

1313 """ 

1314 

1315 def __init__( 

1316 self, 

1317 saved_model_dir, 

1318 saved_model_tags=None, 

1319 saved_model_exported_names=None, 

1320 trackable_obj=None, 

1321 ): 

1322 """Constructor for TFLiteConverter. 

1323 

1324 Args: 

1325 saved_model_dir: Directory of the SavedModel. 

1326 saved_model_tags: Set of tags identifying the MetaGraphDef within the 

1327 SavedModel to analyze. All tags in the tag set must be present. (default 

1328 {tf.saved_model.SERVING}). 

1329 saved_model_exported_names: Names to be exported when the saved model 

1330 import path is on. 

1331 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 

1332 reference to this object needs to be maintained so that Variables do not 

1333 get garbage collected since functions have a weak reference to 

1334 Variables. This is only required when the tf.AutoTrackable object is not 

1335 maintained by the user (e.g. `from_saved_model`). 

1336 """ 

1337 super(TFLiteSavedModelConverterV2, self).__init__() 

1338 self.saved_model_dir = saved_model_dir 

1339 self._saved_model_tags = saved_model_tags 

1340 self._saved_model_exported_names = saved_model_exported_names 

1341 self._trackable_obj = trackable_obj 

1342 self._parse_saved_model_args(always_enable_saved_model_import=True) 

1343 

1344 @_export_metrics 

1345 def convert(self): 

1346 """Converts a TensorFlow GraphDef based on instance variables. 

1347 

1348 Returns: 

1349 The converted data in serialized format. 

1350 

1351 Raises: 

1352 ValueError: 

1353 No concrete function is specified.

1354 Multiple concrete functions are specified. 

1355 Input shape is not specified. 

1356 Invalid quantization parameters. 

1357 """ 

1358 graph_def, input_tensors, output_tensors = self._load_saved_model( 

1359 self.saved_model_dir, self._saved_model_tags 

1360 ) 

1361 # If we can't use the saved model importer, then fall back

1362 # to the frozen graph conversion path.

1363 if self.saved_model_dir is None or not self.experimental_new_converter: 

1364 graph_def, _, _, _ = _freeze_saved_model( 

1365 self.saved_model_dir, 

1366 None, 

1367 None, 

1368 None, 

1369 self._saved_model_tags, 

1370 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, 

1371 ) 

1372 # We make sure to clear the saved_model_dir as there is some 

1373 # legacy code down in the caller that checks this. 

1374 # TODO(b/162537905): Clean these indirect dependencies. 

1375 self.saved_model_dir = None 

1376 return super(TFLiteSavedModelConverterV2, self).convert( 

1377 graph_def, input_tensors, output_tensors 

1378 ) 

1379 

1380 if self._trackable_obj is None: 

1381 self._debug_info = _get_debug_info( 

1382 _build_debug_info_func(self._funcs[0].graph), graph_def 

1383 ) 

1384 else: 

1385 self._debug_info = _get_debug_info( 

1386 _convert_debug_info_func(self._trackable_obj.graph_debug_info), 

1387 graph_def, 

1388 ) 

1389 

1390 return self._convert_from_saved_model(graph_def) 
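
# Example (not part of lite.py): a minimal sketch of the public entry point,
# tf.lite.TFLiteConverter.from_saved_model, whose SavedModel import path is
# handled by this class. `saved_model_dir` is a placeholder for an existing
# SavedModel directory, and the function name is illustrative only.
def _example_convert_saved_model(saved_model_dir):
  import tensorflow as tf  # illustrative only

  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_model = converter.convert()
  with open("model.tflite", "wb") as f:
    f.write(tflite_model)
  return tflite_model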

1391 

1392 

1393class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): 

1394 """Converts the given Keras model into TensorFlow Lite model.""" 

1395 

1396 def __init__(self, keras_model, trackable_obj=None): 

1397 """Constructor for TFLiteConverter. 

1398 

1399 Args: 

1400 keras_model: tf.Keras.Model. 

1401 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 

1402 reference to this object needs to be maintained so that Variables do not 

1403 get garbage collected since functions have a weak reference to 

1404 Variables. This is only required when the tf.AutoTrackable object is not 

1405 maintained by the user (e.g. `from_saved_model`). 

1406 """ 

1407 super(TFLiteKerasModelConverterV2, self).__init__() 

1408 self._keras_model = keras_model 

1409 self._trackable_obj = trackable_obj 

1410 self.experimental_lower_to_saved_model = True 

1411 

1412 @convert_phase( 

1413 Component.PREPARE_TF_MODEL, SubComponent.CONVERT_KERAS_TO_SAVED_MODEL 

1414 ) 

1415 def _convert_keras_to_saved_model(self, output_dir): 

1416 """Save Keras model to the SavedModel format. 

1417 

1418 Args: 

1419 output_dir: The output directory to save the SavedModel. 

1420 

1421 Returns: 

1422 graph_def: The frozen GraphDef. 

1423 input_tensors: List of input tensors. 

1424 output_tensors: List of output tensors. 

1425 """ 

1426 try: 

1427 _saved_model.save( 

1428 self._keras_model, 

1429 output_dir, 

1430 options=_save_options.SaveOptions(save_debug_info=True), 

1431 ) 

1432 except Exception: # pylint: disable=broad-except 

1433 # If saving the given Keras model as a SavedModel fails, fall back to

1434 # the original Keras model conversion pipeline.

1435 return None, None, None 

1436 self.saved_model_dir = output_dir 

1437 self._saved_model_tags = set([_tag_constants.SERVING]) 

1438 self._saved_model_exported_names = [ 

1439 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 

1440 ] 

1441 self._parse_saved_model_args( 

1442 always_enable_saved_model_import=self.experimental_lower_to_saved_model 

1443 ) 

1444 if self.saved_model_dir: 

1445 graph_def, input_tensors, output_tensors = self._load_saved_model( 

1446 self.saved_model_dir, self._saved_model_tags 

1447 ) 

1448 self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags) 

1449 return graph_def, input_tensors, output_tensors 

1450 return None, None, None 

1451 

1452 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL) 

1453 def _freeze_keras_model(self): 

1454 """Freeze Keras model to frozen graph. 

1455 

1456 Returns: 

1457 graph_def: The frozen GraphDef. 

1458 input_tensors: List of input tensors. 

1459 output_tensors: List of output tensors. 

1460 frozen_func: The frozen ConcreteFunction. 

1461 """ 

1462 input_signature = None 

1463 # If the model's call is not a `tf.function`, then we need to first get its 

1464 # input signature from `model_input_signature` method. We can't directly 

1465 # call `trace_model_call` because otherwise the batch dimension is set 

1466 # to None. 

1467 # Once we have better support for dynamic shapes, we can remove this. 

1468 if not isinstance(self._keras_model.call, _def_function.Function): 

1469 # Passing `keep_original_batch_size=True` ensures that we get an input

1470 # signature including the batch dimension specified by the user.

1471 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF 

1472 input_signature = _model_input_signature( 

1473 self._keras_model, keep_original_batch_size=True 

1474 ) 

1475 

1476 # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF 

1477 func = _trace_model_call(self._keras_model, input_signature) 

1478 concrete_func = func.get_concrete_function() 

1479 self._funcs = [concrete_func] 

1480 

1481 frozen_func, graph_def = ( 

1482 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 

1483 self._funcs[0], lower_control_flow=False 

1484 ) 

1485 ) 

1486 

1487 input_tensors = [ 

1488 tensor 

1489 for tensor in frozen_func.inputs 

1490 if tensor.dtype != _dtypes.resource 

1491 ] 

1492 output_tensors = frozen_func.outputs 

1493 return graph_def, input_tensors, output_tensors, frozen_func 

1494 

1495 def _convert_as_saved_model(self): 

1496 """Converts a Keras model as a saved model. 

1497 

1498 Returns: 

1499 The converted data in serialized format. 

1500 """ 

1501 temp_dir = tempfile.mkdtemp() 

1502 try: 

1503 graph_def, input_tensors, output_tensors = ( 

1504 self._convert_keras_to_saved_model(temp_dir) 

1505 ) 

1506 if self.saved_model_dir: 

1507 return super(TFLiteKerasModelConverterV2, self).convert( 

1508 graph_def, input_tensors, output_tensors 

1509 ) 

1510 finally: 

1511 shutil.rmtree(temp_dir, True) 

1512 

1513 @_export_metrics 

1514 def convert(self): 

1515 """Converts a keras model based on instance variables. 

1516 

1517 Returns: 

1518 The converted data in serialized format. 

1519 

1520 Raises: 

1521 ValueError: 

1522 Multiple concrete functions are specified. 

1523 Input shape is not specified. 

1524 Invalid quantization parameters. 

1525 """ 

1526 saved_model_convert_result = self._convert_as_saved_model() 

1527 if saved_model_convert_result: 

1528 return saved_model_convert_result 

1529 

1530 graph_def, input_tensors, output_tensors, frozen_func = ( 

1531 self._freeze_keras_model() 

1532 ) 

1533 

1534 graph_def = self._optimize_tf_model( 

1535 graph_def, input_tensors, output_tensors, frozen_func 

1536 ) 

1537 

1538 return super(TFLiteKerasModelConverterV2, self).convert( 

1539 graph_def, input_tensors, output_tensors 

1540 ) 

1541 

1542 
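# NOTE: a minimal usage sketch, not part of the original source. The Keras
# converter first tries the SavedModel-based path and silently falls back to
# freezing the model when saving fails; the toy model below is a placeholder.
#
#   import tensorflow as tf
#
#   model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
#   converter = tf.lite.TFLiteConverter.from_keras_model(model)
#   tflite_model = converter.convert()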

1543class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2): 

1544 """Converts the given frozen graph into TensorFlow Lite model.""" 

1545 

1546 def __init__(self, funcs, trackable_obj=None): 

1547 """Constructor for TFLiteConverter. 

1548 

1549 Args: 

1550 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 

1551 duplicate elements. 

1552 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 

1553 reference to this object needs to be maintained so that Variables do not 

1554 get garbage collected since functions have a weak reference to 

1555 Variables. This is only required when the tf.AutoTrackable object is not 

1556 maintained by the user (e.g. `from_saved_model`). 

1557 """ 

1558 super(TFLiteFrozenGraphConverterV2, self).__init__() 

1559 self._funcs = funcs 

1560 self._trackable_obj = trackable_obj 

1561 self.experimental_lower_to_saved_model = True 

1562 

1563 @convert_phase( 

1564 Component.PREPARE_TF_MODEL, SubComponent.FREEZE_CONCRETE_FUNCTION 

1565 ) 

1566 def _freeze_concrete_function(self): 

1567 """Convert the given ConcreteFunction to frozen graph. 

1568 

1569 Returns: 

1570 graph_def: The frozen GraphDef. 

1571 input_tensors: List of input tensors. 

1572 output_tensors: List of output tensors. 

1573 frozen_func: The frozen ConcreteFunction. 

1574 

1575 Raises: 

1576 ValueError: none or multiple ConcreteFunctions provided. 

1577 """ 

1578 # TODO(b/130297984): Add support for converting multiple function. 

1579 

1580 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 

1581 raise ValueError("No ConcreteFunction is specified.") 

1582 

1583 if len(self._funcs) > 1: 

1584 raise ValueError( 

1585 "This converter can only convert a single " 

1586 "ConcreteFunction. Converting multiple functions is " 

1587 "under development." 

1588 ) 

1589 

1590 frozen_func, graph_def = ( 

1591 _convert_to_constants.convert_variables_to_constants_v2_as_graph( 

1592 self._funcs[0], lower_control_flow=False 

1593 ) 

1594 ) 

1595 

1596 input_tensors = [ 

1597 tensor 

1598 for tensor in frozen_func.inputs 

1599 if tensor.dtype != _dtypes.resource 

1600 ] 

1601 output_tensors = frozen_func.outputs 

1602 return graph_def, input_tensors, output_tensors, frozen_func 

1603 

1604 @convert_phase( 

1605 Component.PREPARE_TF_MODEL, 

1606 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL, 

1607 ) 

1608 def _convert_concrete_functions_to_saved_model(self, output_dir): 

1609 """Save concrete functions to the SavedModel format. 

1610 

1611 Args: 

1612 output_dir: The output directory to save the SavedModel. 

1613 

1614 Returns: 

1615 graph_def: The frozen GraphDef. 

1616 input_tensors: List of input tensors. 

1617 output_tensors: List of output tensors. 

1618 """ 

1619 if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test 

1620 raise ValueError("No ConcreteFunction is specified.") 

1621 

1622 if not self.experimental_lower_to_saved_model: 

1623 return None, None, None 

1624 

1625 # Without a provided trackable object, the given concrete functions cannot

1626 # be serialized in the SavedModel format. Also, when the trackable object

1627 # is a function, use the original concrete function conversion pipeline.

1628 if not self._trackable_obj or isinstance( 

1629 self._trackable_obj, 

1630 (_function.ConcreteFunction, _def_function.Function), 

1631 ): 

1632 return None, None, None 

1633 

1634 signatures = {} 

1635 signature_keys = [] 

1636 try: 

1637 if len(self._funcs) == 1: 

1638 signatures[_signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = ( 

1639 self._funcs[0] 

1640 ) 

1641 signature_keys = [ 

1642 _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 

1643 ] 

1644 else: 

1645 for func in self._funcs: 

1646 signatures[func.graph.name] = func 

1647 signature_keys.append(func.graph.name) 

1648 

1649 _saved_model.save( 

1650 self._trackable_obj, 

1651 output_dir, 

1652 signatures=signatures, 

1653 options=_save_options.SaveOptions(save_debug_info=True), 

1654 ) 

1655 except Exception: # pylint: disable=broad-except 

1656 # If saving the given concrete function as a SavedModel fails,

1657 # fall back to the original concrete function conversion pipeline.

1658 return None, None, None 

1659 

1660 self.saved_model_dir = output_dir 

1661 self._saved_model_tags = set([_tag_constants.SERVING]) 

1662 self._saved_model_exported_names = signature_keys 

1663 self._parse_saved_model_args(always_enable_saved_model_import=True) 

1664 if self.saved_model_dir: 

1665 graph_def, input_tensors, output_tensors = self._load_saved_model( 

1666 self.saved_model_dir, self._saved_model_tags 

1667 ) 

1668 self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags) 

1669 return graph_def, input_tensors, output_tensors 

1670 return None, None, None 

1671 

1672 def _convert_as_saved_model(self): 

1673 """Converts the given concrete functions as a saved model format. 

1674 

1675 Returns: 

1676 The converted data in serialized format. 

1677 """ 

1678 temp_dir = tempfile.mkdtemp() 

1679 try: 

1680 graph_def, input_tensors, _ = ( 

1681 self._convert_concrete_functions_to_saved_model(temp_dir) 

1682 ) 

1683 if self.saved_model_dir: 

1684 self._validate_inputs(graph_def, input_tensors) 

1685 return self._convert_from_saved_model(graph_def) 

1686 finally: 

1687 shutil.rmtree(temp_dir, True) 

1688 return None 

1689 

1690 @_export_metrics 

1691 def convert(self): 

1692 """Converts a TensorFlow GraphDef based on instance variables. 

1693 

1694 Returns: 

1695 The converted data in serialized format. 

1696 

1697 Raises: 

1698 ValueError: 

1699 No concrete function is specified.

1700 Multiple concrete functions are specified. 

1701 Input shape is not specified. 

1702 Invalid quantization parameters. 

1703 """ 

1704 if self.experimental_lower_to_saved_model: 

1705 saved_model_convert_result = self._convert_as_saved_model() 

1706 if saved_model_convert_result: 

1707 return saved_model_convert_result 

1708 

1709 graph_def, input_tensors, output_tensors, frozen_func = ( 

1710 self._freeze_concrete_function() 

1711 ) 

1712 

1713 graph_def = self._optimize_tf_model( 

1714 graph_def, input_tensors, output_tensors, frozen_func 

1715 ) 

1716 

1717 return super(TFLiteFrozenGraphConverterV2, self).convert( 

1718 graph_def, input_tensors, output_tensors 

1719 ) 

1720 

1721 

1722class TFLiteJaxConverterV2(TFLiteConverterBaseV2): 

1723 """Converts the given jax model into TensorFlow Lite model.""" 

1724 

1725 def __init__(self, serving_funcs, inputs): 

1726 """Constructor for TFLiteConverter. 

1727 

1728 Args: 

1729 serving_funcs: A list of serving functions of the Jax module; the

1730 model params should already be inlined (e.g., `serving_func =

1731 functools.partial(model, params=params)`).

1732 inputs: Array of input tensor placeholder tuples, e.g. created with

1733 `jnp.zeros`, wrapped in an array like "[('input1', input1), ('input2',

1734 input2)]".

1735 

1736 Jax functions are polymorphic, for example: 

1737 

1738 ```python 

1739 def add(a, b): 

1740 return a + b 

1741 ``` 

1742 

1743 It will yield different computations if different input signatures are

1744 passed in: passing `add(10.0, 20.0)` yields a scalar `add`, while passing

1745 `add(np.random.random((100, 1)), np.random.random((100, 100)))` yields a

1746 broadcasting add. The converter needs this input information to trace the

1747 function and convert the model properly, so it's important to pass in the

1748 desired `input placeholders` with the correct input shape/type.

1749 

1750 In the converted tflite model, the function name will default to "main",

1751 the output names will be the traced outputs, and the output ordering shall

1752 match the serving function.

1753 """ # fmt: skip 

1754 

1755 super(TFLiteJaxConverterV2, self).__init__() 

1756 self._serving_funcs = serving_funcs 

1757 self._inputs = inputs 

1758 
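  # NOTE: a minimal sketch, not part of the original source, of the `inputs`
  # layout described above; `prediction_fn`, `params`, and the shapes are
  # placeholder assumptions.
  #
  #   import functools
  #   import jax.numpy as jnp
  #
  #   serving_func = functools.partial(prediction_fn, params=params)
  #   converter = TFLiteJaxConverterV2(
  #       [serving_func],
  #       [[("input1", jnp.zeros((1, 28, 28, 1))),
  #         ("input2", jnp.zeros((1, 10)))]])
  #   tflite_model = converter.convert()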

1759 @_export_metrics 

1760 def convert(self): 

1761 """Converts a Jax serving func based on instance variables. 

1762 

1763 Returns: 

1764 The converted data in serialized format. 

1765 

1766 Raises: 

1767 ImportError: 

1768 If cannot import the xla_computation from jax. 

1769 ValueError: 

1770 No serving function is specified. 

1771 Input tensors are not specified. 

1772 The truth value of an array with more than one element is ambiguous. 

1773 Failed to convert the given Jax function to hlo. 

1774 """ 

1775 if not _xla_computation: 

1776 raise ImportError("Cannot import xla_computation from jax.") 

1777 

1778 if not self._serving_funcs: 

1779 raise ValueError("No serving func is specified.") 

1780 

1781 if not self._inputs: 

1782 raise ValueError("Input tensors are not specified.") 

1783 

1784 if len(self._inputs) != len(self._serving_funcs): 

1785 msg = ( 

1786 "Input tensor mapping len {} does not match serving func len {}." 

1787 .format(len(self._inputs), len(self._serving_funcs)) 

1788 ) 

1789 raise ValueError(msg) 

1790 

1791 if not isinstance(self._inputs, (tuple, list)): 

1792 raise ValueError( 

1793 "Input tensors should be pass in a tuple list wrapped in an array." 

1794 ) 

1795 

1796 # TODO(b/197690428): Support multiple functions. 

1797 # Currently only support one serving function. 

1798 if len(self._serving_funcs) > 1: 

1799 raise ValueError("Currently only support single serving function.") 

1800 

1801 if not isinstance(self._inputs[0], (tuple, list)): 

1802 raise ValueError("The input placeholders are not a dictionary.") 

1803 

1804 input_names = [] 

1805 ordered_inputs = [] 

1806 for input_name, tensor in self._inputs[0]: 

1807 input_names.append(input_name) 

1808 ordered_inputs.append(tensor) 

1809 

1810 try: 

1811 xla_computation = _xla_computation(self._serving_funcs[0], backend="cpu")

1812 hlo_proto = xla_computation(

1813 *ordered_inputs 

1814 ).as_serialized_hlo_module_proto() 

1815 except Exception: # pylint: disable=broad-except 

1816 raise ValueError("Failed to convert the given Jax function to hlo.") 

1817 

1818 # We need to set the hlo proto, and here we use serialized proto format 

1819 # since it's more compact. 

1820 converter_kwargs = { 

1821 "input_content": hlo_proto, 

1822 "input_names": input_names, 

1823 "is_proto_format": True, 

1824 } 

1825 converter_kwargs.update(self._get_base_converter_args()) 

1826 

1827 # Get quantization options and do some checks. 

1828 quant_mode = QuantizationMode( 

1829 self.optimizations, self.target_spec, self.representative_dataset, None 

1830 ) 

1831 self._validate_inference_input_output_types(quant_mode) 

1832 converter_kwargs.update(quant_mode.converter_flags()) 

1833 result = _convert_jax_hlo(**converter_kwargs) 

1834 

1835 return self._optimize_tflite_model( 

1836 result, quant_mode, quant_io=self.experimental_new_quantizer 

1837 ) 

1838 

1839 

1840@_tf_export("lite.TFLiteConverter", v1=[]) 

1841class TFLiteConverterV2(TFLiteFrozenGraphConverterV2): 

1842 """Converts a TensorFlow model into TensorFlow Lite model. 

1843 

1844 Attributes: 

1845 optimizations: Experimental flag, subject to change. Set of optimizations to 

1846 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a 

1847 set of values of type `tf.lite.Optimize`) 

1848 representative_dataset: A generator function used for integer quantization 

1849 where each generated sample has the same order, type and shape as the 

1850 inputs to the model. Usually, this is a small subset of a few hundred 

1851 samples randomly chosen, in no particular order, from the training or 

1852 evaluation dataset. This is an optional attribute, but required for full 

1853 integer quantization, i.e, if `tf.int8` is the only supported type in 

1854 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`. 

1855 (default None) 

1856 target_spec: Experimental flag, subject to change. Specifications of target 

1857 device, including supported ops set, supported types and a set of user's 

1858 defined TensorFlow operators required in the TensorFlow Lite runtime. 

1859 Refer to `tf.lite.TargetSpec`. 

1860 inference_input_type: Data type of the input layer. Note that integer types 

1861 (tf.int8 and tf.uint8) are currently only supported for post training 

1862 integer quantization and quantization aware training. (default tf.float32, 

1863 must be in {tf.float32, tf.int8, tf.uint8}) 

1864 inference_output_type: Data type of the output layer. Note that integer 

1865 types (tf.int8 and tf.uint8) are currently only supported for post 

1866 training integer quantization and quantization aware training. (default 

1867 tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 

1868 allow_custom_ops: Boolean indicating whether to allow custom operations. 

1869 When False, any unknown operation is an error. When True, custom ops are 

1870 created for any op that is unknown. The developer needs to provide these 

1871 to the TensorFlow Lite runtime with a custom resolver. (default False) 

1872 exclude_conversion_metadata: Whether not to embed the conversion metadata 

1873 into the converted model. (default False) 

1874 experimental_new_converter: Experimental flag, subject to change. Enables 

1875 MLIR-based conversion. (default True) 

1876 experimental_new_quantizer: Experimental flag, subject to change. Enables 

1877 MLIR-based quantization conversion instead of Flatbuffer-based conversion. 

1878 (default True) 

1879 experimental_enable_resource_variables: Experimental flag, subject to 

1880 change. Enables [resource 

1881 variables](https://tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables) 

1882 to be converted by this converter. This is only allowed if the 

1883 from_saved_model interface is used. (default True) 

1884 

1885 Example usage: 

1886 

1887 ```python 

1888 # Converting a SavedModel to a TensorFlow Lite model. 

1889 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) 

1890 tflite_model = converter.convert() 

1891 

1892 # Converting a tf.Keras model to a TensorFlow Lite model. 

1893 converter = tf.lite.TFLiteConverter.from_keras_model(model) 

1894 tflite_model = converter.convert() 

1895 

1896 # Converting ConcreteFunctions to a TensorFlow Lite model. 

1897 converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model) 

1898 tflite_model = converter.convert() 

1899 

1900 # Converting a Jax model to a TensorFlow Lite model. 

1901 converter = tf.lite.TFLiteConverter.experimental_from_jax( 

1902 [func], [[ ('input1', input1), ('input2', input2)]]) 

1903 tflite_model = converter.convert() 

1904 ``` 

1905 """ # fmt: skip 

1906 

1907 # pylint: disable=useless-super-delegation 

1908 def __init__(self, funcs, trackable_obj=None): 

1909 """Constructor for TFLiteConverter. 

1910 

1911 Args: 

1912 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 

1913 duplicate elements. 

1914 trackable_obj: tf.AutoTrackable object associated with `funcs`. A 

1915 reference to this object needs to be maintained so that Variables do not 

1916 get garbage collected since functions have a weak reference to 

1917 Variables. This is only required when the tf.AutoTrackable object is not 

1918 maintained by the user (e.g. `from_saved_model`). 

1919 """ 

1920 super(TFLiteConverterV2, self).__init__(funcs, trackable_obj) 

1921 

1922 @classmethod 

1923 def from_concrete_functions(cls, funcs, trackable_obj=None): 

1924 """Creates a TFLiteConverter object from ConcreteFunctions. 

1925 

1926 Args: 

1927 funcs: List of TensorFlow ConcreteFunctions. The list should not contain 

1928 duplicate elements. Currently converter can only convert a single 

1929 ConcreteFunction. Converting multiple functions is under development. 

1930 trackable_obj: An `AutoTrackable` object (typically `tf.Module`)

1931 associated with `funcs`. A reference to this object needs to be 

1932 maintained so that Variables do not get garbage collected since 

1933 functions have a weak reference to Variables. 

1934 

1935 Returns: 

1936 TFLiteConverter object. 

1937 

1938 Raises: 

1939 Invalid input type. 

1940 """ 

1941 # pylint: disable=protected-access 

1942 TFLiteConverterBase._set_original_model_type( 

1943 conversion_metdata_fb.ModelType.TF_CONCRETE_FUNCTIONS 

1944 ) 

1945 # pylint: enable=protected-access 

1946 if trackable_obj is None: 

1947 logging.warning( 

1948 "Please consider providing the trackable_obj argument in the " 

1949 "from_concrete_functions. Providing without the trackable_obj " 

1950 "argument is deprecated and it will use the deprecated conversion " 

1951 "path." 

1952 ) 

1953 for func in funcs: 

1954 if not isinstance(func, _function.ConcreteFunction): 

1955 message = "This function takes in a list of ConcreteFunction." 

1956 if isinstance(func, _def_function.Function): 

1957 message += ( 

1958 " To get the ConcreteFunction from a Function," 

1959 " call get_concrete_function." 

1960 ) 

1961 raise ValueError(message) 

1962 return cls(funcs, trackable_obj) 

1963 
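  # NOTE: a minimal usage sketch, not part of the original source; the module
  # and input signature below are placeholder assumptions.
  #
  #   import tensorflow as tf
  #
  #   class Squarer(tf.Module):
  #     @tf.function(input_signature=[tf.TensorSpec(shape=[1, 3], dtype=tf.float32)])
  #     def square(self, x):
  #       return x * x
  #
  #   module = Squarer()
  #   func = module.square.get_concrete_function()
  #   # Passing `module` as trackable_obj avoids the deprecated conversion path.
  #   converter = tf.lite.TFLiteConverter.from_concrete_functions([func], module)
  #   tflite_model = converter.convert()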

1964 @classmethod 

1965 def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None): 

1966 """Creates a TFLiteConverter object from a SavedModel directory. 

1967 

1968 Args: 

1969 saved_model_dir: SavedModel directory to convert. 

1970 signature_keys: List of keys identifying SignatureDef containing inputs 

1971 and outputs. Elements should not be duplicated. By default the 

1972 `signatures` attribute of the MetaGraphdef is used. (default 

1973 saved_model.signatures) 

1974 tags: Set of tags identifying the MetaGraphDef within the SavedModel to 

1975 analyze. All tags in the tag set must be present. (default 

1976 {tf.saved_model.SERVING} or {'serve'}) 

1977 

1978 Returns: 

1979 TFLiteConverter object. 

1980 

1981 Raises: 

1982 Invalid signature keys. 

1983 """ 

1984 # pylint: disable=protected-access 

1985 TFLiteConverterBase._set_original_model_type( 

1986 conversion_metdata_fb.ModelType.TF_SAVED_MODEL 

1987 ) 

1988 # pylint: enable=protected-access 

1989 # When run without eager enabled, this will return the legacy 

1990 # TFLiteConverter. 

1991 if not context.executing_eagerly(): 

1992 signature_key = None 

1993 if signature_keys: 

1994 if len(signature_keys) != 1: 

1995 raise ValueError("Only support a single signature key.") 

1996 else: 

1997 signature_key = signature_keys[0] 

1998 logging.warning( 

1999 "Invoking the TF1 implementation of TFLiteConverter " 

2000 "because eager is disabled. Consider enabling eager." 

2001 ) 

2002 return TFLiteConverter.from_saved_model( 

2003 saved_model_dir, signature_key=signature_key, tag_set=tags 

2004 ) 

2005 

2006 # Ensures any graphs created in Eager mode are able to run. This is required 

2007 # in order to create a tf.estimator.Exporter that exports a TFLite model. 

2008 if tags is None: 

2009 tags = set([_tag_constants.SERVING]) 

2010 

2011 with context.eager_mode(): 

2012 saved_model = _load(saved_model_dir, tags) 

2013 if not signature_keys: 

2014 signature_keys = saved_model.signatures 

2015 

2016 if not signature_keys: 

2017 raise ValueError("Only support at least one signature key.") 

2018 

2019 # Distinguishes SavedModel artifacts created by `model.export` 

2020 # from SavedModel created by `model.save`/`tf.saved_model.save`. 

2021 if ( 

2022 len(signature_keys) > 1 

2023 and hasattr(saved_model, "serve") # `model.export` default endpoint 

2024 and not hasattr(saved_model, "_default_save_signature") 

2025 # `_default_save_signature` does not exist for `model.export` artifacts. 

2026 ): 

2027 # Default `serve` endpoint for `model.export` should be copied 

2028 # to `serving_default` to prevent issues in TF Lite serving. 

2029 saved_model.serving_default = saved_model.serve 

2030 delattr(saved_model, "serve") 

2031 signature_keys = ["serving_default"] 

2032 

2033 funcs = [] 

2034 for key in signature_keys: 

2035 if key not in saved_model.signatures: 

2036 raise ValueError( 

2037 "Invalid signature key '{}' found. Valid keys are '{}'.".format( 

2038 key, ",".join(saved_model.signatures) 

2039 ) 

2040 ) 

2041 funcs.append(saved_model.signatures[key]) 

2042 

2043 saved_model_converter = TFLiteSavedModelConverterV2( 

2044 saved_model_dir, tags, signature_keys, saved_model 

2045 ) 

2046 if saved_model_converter.saved_model_dir: 

2047 return saved_model_converter 

2048 

2049 return cls(funcs, saved_model) 

2050 
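  # NOTE: a minimal usage sketch, not part of the original source; the
  # directory, signature key, and tag below are placeholder assumptions.
  #
  #   import tensorflow as tf
  #
  #   converter = tf.lite.TFLiteConverter.from_saved_model(
  #       "/tmp/my_saved_model",
  #       signature_keys=["serving_default"],
  #       tags={tf.saved_model.SERVING})
  #   tflite_model = converter.convert()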

2051 @classmethod 

2052 def from_keras_model(cls, model): 

2053 """Creates a TFLiteConverter object from a Keras model. 

2054 

2055 Args: 

2056 model: tf.keras.Model

2057 

2058 Returns: 

2059 TFLiteConverter object. 

2060 """ 

2061 # pylint: disable=protected-access 

2062 TFLiteConverterBase._set_original_model_type( 

2063 conversion_metdata_fb.ModelType.KERAS_MODEL 

2064 ) 

2065 # pylint: enable=protected-access 

2066 return TFLiteKerasModelConverterV2(model) 

2067 

2068 @classmethod 

2069 def experimental_from_jax(cls, serving_funcs, inputs): 

2070 # Experimental API, subject to changes. 

2071 # TODO(b/197690428): Currently only support single function. 

2072 """Creates a TFLiteConverter object from a Jax model with its inputs. 

2073 

2074 Args: 

2075 serving_funcs: An array of Jax functions with all the weights applied

2076 already.

2077 inputs: An array of Jax input placeholder tuple lists, e.g., created with

2078 jnp.zeros(INPUT_SHAPE). Each tuple list should correspond to a

2079 serving function.

2080 

2081 Returns: 

2082 TFLiteConverter object. 

2083 """ 

2084 # pylint: disable=protected-access 

2085 TFLiteConverterBase._set_original_model_type( 

2086 conversion_metdata_fb.ModelType.JAX 

2087 ) 

2088 # pylint: enable=protected-access 

2089 return TFLiteJaxConverterV2(serving_funcs, inputs) 

2090 

2091 # pylint: disable=useless-super-delegation 

2092 def convert(self): 

2093 """Converts a TensorFlow GraphDef based on instance variables. 

2094 

2095 Returns: 

2096 The converted data in serialized format. 

2097 

2098 Raises: 

2099 ValueError: 

2100 No concrete function is specified.

2101 Multiple concrete functions are specified. 

2102 Input shape is not specified. 

2103 Invalid quantization parameters. 

2104 """ 

2105 return super(TFLiteConverterV2, self).convert() 

2106 

2107 

2108class TFLiteConverterBaseV1(TFLiteConverterBase): 

2109 """Converter subclass to share functionality between V1 converters.""" 

2110 

2111 def __init__(self, experimental_debug_info_func): 

2112 """Constructor for TFLiteConverter. 

2113 

2114 Args: 

2115 experimental_debug_info_func: An experimental function to retrieve the 

2116 graph debug info for a set of nodes from the `graph_def`. 

2117 """ 

2118 super(TFLiteConverterBaseV1, self).__init__() 

2119 self.inference_type = _dtypes.float32 

2120 self.inference_input_type = None 

2121 self.inference_output_type = None 

2122 self.output_format = constants.TFLITE 

2123 self.quantized_input_stats = {} 

2124 self.default_ranges_stats = None 

2125 self.drop_control_dependency = True 

2126 self.reorder_across_fake_quant = False 

2127 self.change_concat_input_ranges = False 

2128 self.dump_graphviz_dir = None 

2129 self.dump_graphviz_video = False 

2130 self.conversion_summary_dir = None 

2131 self._debug_info_func = experimental_debug_info_func 

2132 self._metadata.environment.apiVersion = 1 

2133 

2134 def __setattr__(self, name, value): 

2135 if name == "post_training_quantize": 

2136 warnings.warn( 

2137 "Property %s is deprecated, " 

2138 "please use optimizations=[Optimize.DEFAULT]" 

2139 " instead." % name 

2140 ) 

2141 if value: 

2142 self.optimizations = [Optimize.DEFAULT] 

2143 else: 

2144 self.optimizations = [] 

2145 return 

2146 if name == "target_ops": 

2147 warnings.warn( 

2148 "Property %s is deprecated, please use " 

2149 "target_spec.supported_ops instead." % name 

2150 ) 

2151 self.target_spec.supported_ops = value 

2152 return 

2153 object.__setattr__(self, name, value) 

2154 

2155 def __getattribute__(self, name): 

2156 if name == "post_training_quantize": 

2157 warnings.warn( 

2158 "Property %s is deprecated, " 

2159 "please use optimizations=[Optimize.DEFAULT]" 

2160 " instead." % name 

2161 ) 

2162 return Optimize.DEFAULT in set(self.optimizations) 

2163 if name == "target_ops": 

2164 warnings.warn( 

2165 "Property %s is deprecated, please use " 

2166 "target_spec.supported_ops instead." % name 

2167 ) 

2168 return self.target_spec.supported_ops 

2169 return object.__getattribute__(self, name) 

2170 
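  # NOTE: a minimal sketch, not part of the original source, of the property
  # aliasing implemented above; `converter` is a placeholder V1 converter
  # instance.
  #
  #   converter.post_training_quantize = True
  #   # ...is equivalent to:
  #   converter.optimizations = [tf.lite.Optimize.DEFAULT]
  #
  #   converter.target_ops = {tf.lite.OpsSet.TFLITE_BUILTINS}
  #   # ...is equivalent to:
  #   converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS}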

2171 def _validate_quantized_input_stats(self, converter_kwargs, quant_mode): 

2172 """Ensure the `quantized_input_stats` flag is provided if required.""" 

2173 

2174 quantized_types = frozenset({_dtypes.int8, _dtypes.uint8}) 

2175 

2176 requires_quantized_input_stats = ( 

2177 converter_kwargs["inference_type"] in quantized_types 

2178 or converter_kwargs["inference_input_type"] in quantized_types 

2179 ) and not quant_mode.is_post_training_integer_quantization() 

2180 

2181 if ( 

2182 requires_quantized_input_stats 

2183 and not converter_kwargs["quantized_input_stats"] 

2184 ): 

2185 raise ValueError( 

2186 "The `quantized_input_stats` flag must be defined when either " 

2187 "`inference_type` flag or `inference_input_type` flag is set to " 

2188 "tf.int8 or tf.uint8. Currently, `inference_type={}` and " 

2189 "`inference_input_type={}`.".format( 

2190 _get_tf_type_name(converter_kwargs["inference_type"]), 

2191 _get_tf_type_name(converter_kwargs["inference_input_type"]), 

2192 ) 

2193 ) 

2194 
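  # NOTE: a minimal sketch, not part of the original source, of a configuration
  # that satisfies the check above; the tensor name and the (mean, std) values
  # are placeholder assumptions.
  #
  #   converter.inference_type = tf.uint8
  #   # Mean and standard deviation of the training data for input "input_1".
  #   converter.quantized_input_stats = {"input_1": (127.5, 127.5)}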

2195 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS) 

2196 def _validate_inputs(self, input_tensors, quantized_input_stats): 

2197 """Validate input parameters. 

2198 

2199 Args: 

2200 input_tensors: List of input tensors. 

2201 quantized_input_stats: Map of input tensor names to a tuple of floats 

2202 representing the mean and standard deviation of the training data. 

2203 

2204 Raises: 

2205 ValueError: 

2206 Input shape is not specified. 

2207 Quantization input stats is required but not provided. 

2208 """ 

2209 

2210 if not self._is_unknown_shapes_allowed() and self._has_valid_tensors(): 

2211 # Checks dimensions in input tensor. 

2212 for tensor in input_tensors: 

2213 shape = tensor.shape 

2214 if not shape: 

2215 raise ValueError( 

2216 "Provide an input shape for input array '{0}'.".format( 

2217 _get_tensor_name(tensor) 

2218 ) 

2219 ) 

2220 # Note that shape_list might be empty for scalar shapes. 

2221 shape_list = shape.as_list() 

2222 if None in shape_list[1:]: 

2223 raise ValueError( 

2224 "None is only supported in the 1st dimension. Tensor '{0}' has " 

2225 "invalid shape '{1}'.".format( 

2226 _get_tensor_name(tensor), shape_list 

2227 ) 

2228 ) 

2229 elif shape_list and shape_list[0] is None: 

2230 self._set_batch_size(batch_size=1) 

2231 

2232 # Get quantization stats. Ensures there is one stat per name if the stats 

2233 # are specified. 

2234 if quantized_input_stats: 

2235 self._quantized_stats = [] 

2236 invalid_stats = [] 

2237 for name in self.get_input_arrays(): 

2238 if name in quantized_input_stats: 

2239 self._quantized_stats.append(quantized_input_stats[name]) 

2240 else: 

2241 invalid_stats.append(name) 

2242 

2243 if invalid_stats: 

2244 raise ValueError( 

2245 "Quantization input stats are not available for input " 

2246 "tensors '{0}'.".format(",".join(invalid_stats)) 

2247 ) 

2248 else: 

2249 self._quantized_stats = None 

2250 

2251 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL) 

2252 def _optimize_tf_model( 

2253 self, graph_def, input_tensors, output_tensors, quant_mode 

2254 ): 

2255 """Run a Grappler pass to optimize the TensorFlow graph. 

2256 

2257 Args: 

2258 graph_def: Frozen GraphDef to be optimized. 

2259 input_tensors: List of input tensors. 

2260 output_tensors: List of output tensors. 

2261 quant_mode: the quantization mode. 

2262 

2263 Returns: 

2264 The optimized TensorFlow graph. 

2265 """ 

2266 # Disable grappler constant folding if there are training quant ops. 

2267 if self.saved_model_dir or quant_mode.is_quantization_aware_trained_model(): 

2268 return graph_def 

2269 

2270 try: 

2271 # TODO(b/150163103): Merge `disabling lower using switch merge' calls. 

2272 # Grappler will also try to lower while loop into switch merge 

2273 # representation which is undesired for Ophints, so we simply remove 

2274 # those attributes to prevent Grappler from doing so. 

2275 graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def) 

2276 # Run function inlining optimization to ensure any models generated 

2277 # through the from_frozen_graph path have been inlined. 

2278 optimized_graph = _run_graph_optimizations( 

2279 graph, 

2280 input_tensors, 

2281 output_tensors, 

2282 config=self._grappler_config(["function"]), 

2283 ) 

2284 return optimized_graph 

2285 except Exception: # pylint: disable=broad-except 

2286 return graph_def 

2287 

2288 def convert(self): 

2289 """Converts a TensorFlow GraphDef based on instance variables. 

2290 

2291 Returns: 

2292 The converted data in serialized format. Either a TFLite Flatbuffer or a 

2293 Graphviz graph depending on value in `output_format`. 

2294 

2295 Raises: 

2296 ValueError: 

2297 Input shape is not specified. 

2298 None value for dimension in input_tensor. 

2299 """ 

2300 self._validate_inputs(self._input_tensors, self.quantized_input_stats) 

2301 

2302 quant_mode = QuantizationMode( 

2303 self.optimizations, 

2304 self.target_spec, 

2305 self.representative_dataset, 

2306 self._graph_def, 

2307 self._experimental_disable_per_channel, 

2308 self.experimental_new_dynamic_range_quantizer, 

2309 self._experimental_low_bit_qat, 

2310 self._experimental_full_integer_quantization_bias_type, 

2311 self._experimental_variable_quantization, 

2312 ) 

2313 

2314 optimized_graph = self._optimize_tf_model( 

2315 self._graph_def, self._input_tensors, self._output_tensors, quant_mode 

2316 ) 

2317 

2318 self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph) 

2319 

2320 converter_kwargs = self._get_base_converter_args() 

2321 converter_kwargs.update( 

2322 quant_mode.converter_flags( 

2323 self.inference_type, self.inference_input_type 

2324 ) 

2325 ) 

2326 converter_kwargs.update({ 

2327 "output_format": self.output_format, 

2328 "quantized_input_stats": self._quantized_stats, 

2329 "default_ranges_stats": self.default_ranges_stats, 

2330 "drop_control_dependency": self.drop_control_dependency, 

2331 "reorder_across_fake_quant": self.reorder_across_fake_quant, 

2332 "change_concat_input_ranges": self.change_concat_input_ranges, 

2333 "dump_graphviz_dir": self.dump_graphviz_dir, 

2334 "dump_graphviz_video": self.dump_graphviz_video, 

2335 "conversion_summary_dir": self.conversion_summary_dir, 

2336 }) 

2337 

2338 self._validate_quantized_input_stats(converter_kwargs, quant_mode) 

2339 if not self.experimental_new_converter: 

2340 logging.warning( 

2341 "Please consider switching to the new converter by setting " 

2342 "experimental_new_converter=True. " 

2343 "The old converter is deprecated." 

2344 ) 

2345 else: 

2346 logging.info( 

2347 "Using experimental converter: If you encountered a problem " 

2348 "please file a bug. You can opt-out " 

2349 "by setting experimental_new_converter=False" 

2350 ) 

2351 # Converts model. 

2352 if self._has_valid_tensors(): 

2353 result = _convert_graphdef( 

2354 input_data=optimized_graph, 

2355 input_tensors=self._input_tensors, 

2356 output_tensors=self._output_tensors, 

2357 **converter_kwargs, 

2358 ) 

2359 else: 

2360 result = _convert_graphdef_with_arrays( 

2361 input_data=optimized_graph, 

2362 input_arrays_with_shape=self._input_arrays_with_shape, 

2363 output_arrays=self._output_arrays, 

2364 control_output_arrays=self._control_output_arrays, 

2365 **converter_kwargs, 

2366 ) 

2367 

2368 return self._optimize_tflite_model( 

2369 result, quant_mode, quant_io=self.experimental_new_quantizer 

2370 ) 

2371 

2372 def get_input_arrays(self): 

2373 """Returns a list of the names of the input tensors. 

2374 

2375 Returns: 

2376 List of strings. 

2377 """ 

2378 if self._has_valid_tensors(): 

2379 return [_get_tensor_name(tensor) for tensor in self._input_tensors] 

2380 else: 

2381 return [name for name, _ in self._input_arrays_with_shape] 

2382 

2383 def _has_valid_tensors(self): 

2384 """Checks if the input and output tensors have been initialized. 

2385 

2386 Returns: 

2387 Bool. 

2388 """ 

2389 return self._input_tensors is not None and self._output_tensors 

2390 

2391 def _set_batch_size(self, batch_size): 

2392 """Sets the first dimension of the input tensor to `batch_size`. 

2393 

2394 Args: 

2395 batch_size: Batch size for the model. Replaces the first dimension of an 

2396 input size array if undefined. (default 1) 

2397 

2398 Raises: 

2399 ValueError: input_tensor is not defined. 

2400 """ 

2401 if not self._has_valid_tensors(): 

2402 raise ValueError( 

2403 "The batch size cannot be set for this model. Please " 

2404 "use input_shapes parameter." 

2405 ) 

2406 

2407 for tensor in self._input_tensors: 

2408 shape = tensor.shape.as_list() 

2409 if shape[0] is None: 

2410 shape[0] = batch_size 

2411 tensor.set_shape(shape) 

2412 

2413 def _is_unknown_shapes_allowed(self): 

2414 # Ophint Converted nodes will need the shapes to be known. 

2415 if _is_ophint_converted(self._graph_def): 

2416 return False 

2417 

2418 if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed(): 

2419 return False 

2420 

2421 # `conversion_summary_dir` calls the old converter. Unknown shapes are only 

2422 # supported by the MLIR converter. 

2423 if self.conversion_summary_dir: 

2424 logging.warning( 

2425 "`conversion_summary_dir` does not work with unknown shapes. " 

2426 "Graphs with unknown shapes might be different than when this flag " 

2427 "is disabled." 

2428 ) 

2429 return False 

2430 return True 

2431 

2432 def _save_conversion_params_metric(self): 

2433 self._collected_converter_params.update({ 

2434 "output_format": self.output_format, 

2435 "default_ranges_stats": self.default_ranges_stats, 

2436 "drop_control_dependency": self.drop_control_dependency, 

2437 "reorder_across_fake_quant": self.reorder_across_fake_quant, 

2438 "change_concat_input_ranges": self.change_concat_input_ranges, 

2439 "dump_graphviz_dir": self.dump_graphviz_dir, 

2440 "dump_graphviz_video": self.dump_graphviz_video, 

2441 "conversion_summary_dir": self.conversion_summary_dir, 

2442 }) 

2443 super(TFLiteConverterBaseV1, self)._save_conversion_params_metric( 

2444 self._graph_def, self.inference_type, self.inference_input_type 

2445 ) 

2446 

2447 

2448class TFLiteSavedModelConverter(TFLiteConverterBaseV1): 

2449 """Converts the given SavedModel into TensorFlow Lite model. 

2450 

2451 Attributes: 

2452 saved_model_dir: Directory of the SavedModel. 

2453 """ 

2454 

2455 def __init__( 

2456 self, 

2457 saved_model_dir, 

2458 saved_model_tags, 

2459 saved_model_exported_names, 

2460 experimental_debug_info_func=None, 

2461 ): 

2462 """Constructor for TFLiteConverter. 

2463 

2464 Args: 

2465 saved_model_dir: Directory of the SavedModel. 

2466 saved_model_tags: Set of tags identifying the MetaGraphDef within the 

2467 SavedModel to analyze. All tags in the tag set must be present. (default 

2468 {tf.saved_model.SERVING}). 

2469 saved_model_exported_names: Names to be exported when the SavedModel

2470 import path is enabled.

2471 experimental_debug_info_func: An experimental function to retrieve the 

2472 graph debug info for a set of nodes from the `graph_def`. 

2473 

2474 Raises: 

2475 ValueError: Invalid arguments. 

2476 """ 

2477 super(TFLiteSavedModelConverter, self).__init__( 

2478 experimental_debug_info_func 

2479 ) 

2480 self.saved_model_dir = saved_model_dir 

2481 self._saved_model_tags = saved_model_tags 

2482 self._saved_model_exported_names = saved_model_exported_names 

2483 

2484 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 

2485 

2486 if len(self._saved_model_exported_names) != 1: 

2487 raise ValueError("Only support a single signature key.") 

2488 

2489 signature_key = self._saved_model_exported_names[0] 

2490 

2491 result = _freeze_saved_model( 

2492 self.saved_model_dir, 

2493 None, 

2494 None, 

2495 None, 

2496 self._saved_model_tags, 

2497 signature_key, 

2498 ) 

2499 self._graph_def = result[0] 

2500 self._input_tensors = result[1] 

2501 self._output_tensors = result[2] 

2502 self._parse_saved_model_args() 

2503 

2504 @_export_metrics 

2505 def convert(self): 

2506 """Converts a TensorFlow GraphDef based on instance variables. 

2507 

2508 Note that in the converted TensorFlow Lite model, the input tensor's order 

2509 might be changed each time `convert` is called. To access input tensor 

2510 information, please consider using the `SignatureRunner` API 

2511 (`interpreter.get_signature_runner`). 

2512 

2513 Returns: 

2514 The converted data in serialized format. Either a TFLite Flatbuffer or a 

2515 Graphviz graph depending on value in `output_format`. 

2516 

2517 Raises: 

2518 ValueError: 

2519 Input shape is not specified. 

2520 None value for dimension in input_tensor. 

2521 """ 

2522 return super(TFLiteSavedModelConverter, self).convert() 

2523 

2524 

2525class TFLiteKerasModelConverter(TFLiteConverterBaseV1): 

2526 """Converts the given SavedModel into TensorFlow Lite model.""" 

2527 

2528 def __init__( 

2529 self, 

2530 model_file, 

2531 input_arrays=None, 

2532 input_shapes=None, 

2533 output_arrays=None, 

2534 custom_objects=None, 

2535 ): 

2536 """Constructor for TFLiteConverter. 

2537 

2538 Args: 

2539 model_file: Full filepath of HDF5 file containing the tf.keras model. 

2540 input_arrays: List of input tensors to freeze graph with. Uses input 

2541 arrays from SignatureDef when none are provided. (default None) 

2542 input_shapes: Dict of strings representing input tensor names to list of 

2543 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 

2544 Automatically determined when input shapes is None (e.g., {"foo" : 

2545 None}). (default None) 

2546 output_arrays: List of output tensors to freeze graph with. Uses output 

2547 arrays from SignatureDef when none are provided. (default None) 

2548 custom_objects: Dict mapping names (strings) to custom classes or 

2549 functions to be considered during model deserialization. (default None) 

2550 

2551 Raises: 

2552 ValueError: Invalid arguments. 

2553 """ 

2554 super(TFLiteKerasModelConverter, self).__init__( 

2555 experimental_debug_info_func=None 

2556 ) 

2557 # Handles Keras when Eager mode is enabled. 

2558 if context.executing_eagerly(): 

2559 if input_arrays or output_arrays: 

2560 raise ValueError( 

2561 "`input_arrays` and `output_arrays` are unsupported " 

2562 "with Eager mode. If your model requires any of these " 

2563 "parameters, please use disable_eager_execution()." 

2564 ) 

2565 

2566 keras_model = keras_deps.get_load_model_function()( 

2567 model_file, custom_objects 

2568 ) 

2569 function = _trace_model_call(keras_model) 

2570 concrete_func = function.get_concrete_function() 

2571 

2572 frozen_func = _convert_to_constants.convert_variables_to_constants_v2( 

2573 concrete_func, lower_control_flow=False 

2574 ) 

2575 _set_tensor_shapes(frozen_func.inputs, input_shapes) 

2576 self._keras_model = keras_model 

2577 self._graph_def = frozen_func.graph.as_graph_def() 

2578 self._input_tensors = frozen_func.inputs 

2579 self._output_tensors = frozen_func.outputs 

2580 self._debug_info_func = _build_debug_info_func(frozen_func.graph) 

2581 return 

2582 

2583 # Handles Keras when Eager mode is disabled. 

2584 keras_deps.get_clear_session_function()() 

2585 keras_model = keras_deps.get_load_model_function()( 

2586 model_file, custom_objects 

2587 ) 

2588 sess = keras_deps.get_get_session_function()() 

2589 

2590 # Get input and output tensors. 

2591 if input_arrays: 

2592 input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) 

2593 else: 

2594 input_tensors = keras_model.inputs 

2595 

2596 if output_arrays: 

2597 output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) 

2598 else: 

2599 output_tensors = keras_model.outputs 

2600 _set_tensor_shapes(input_tensors, input_shapes) 

2601 

2602 graph_def = _freeze_graph(sess, input_tensors, output_tensors) 

2603 self._keras_model = keras_model 

2604 self._graph_def = graph_def 

2605 self._input_tensors = input_tensors 

2606 self._output_tensors = output_tensors 

2607 self._debug_info_func = _build_debug_info_func(sess.graph) 

2608 

2609 @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL) 

2610 def _freeze_keras_model(self, output_dir): 

2611 """Save Keras model to Saved Model format. 

2612 

2613 Args: 

2614 output_dir: The output directory to save the SavedModel. 

2615 """ 

2616 try: 

2617 self._keras_model.save(output_dir, save_format="tf") 

2618 except Exception: # pylint: disable=broad-except 

2619 # When storing the given keras model to a saved model is failed, let's 

2620 # use original keras model conversion pipeline. 

2621 return None 

2622 tag_set = set([_tag_constants.SERVING]) 

2623 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 

2624 graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model( 

2625 output_dir, None, None, None, tag_set, signature_key 

2626 ) 

2627 

2628 self.saved_model_dir = output_dir 

2629 self._saved_model_tags = tag_set 

2630 self._saved_model_exported_names = [signature_key] 

2631 self._parse_saved_model_args() 

2632 if self.saved_model_dir: 

2633 self._graph_def = graph_def 

2634 self._input_tensors = input_tensors 

2635 self._output_tensors = output_tensors 

2636 self._debug_info_func = _build_debug_info_func(sess_graph) 

2637 

2638 def _convert_as_saved_model(self): 

2639 """Converts a Keras model as a saved model. 

2640 

2641 Returns: 

2642 The converted data in serialized format. 

2643 """ 

2644 temp_dir = tempfile.mkdtemp() 

2645 try: 

2646 self._freeze_keras_model(temp_dir) 

2647 if self.saved_model_dir: 

2648 return super(TFLiteKerasModelConverter, self).convert() 

2649 finally: 

2650 shutil.rmtree(temp_dir, True) 

2651 

2652 @_export_metrics 

2653 def convert(self): 

2654 """Converts a Keras model based on instance variables. 

2655 

2656 Returns: 

2657 The converted data in serialized format. Either a TFLite Flatbuffer or a 

2658 Graphviz graph depending on value in `output_format`. 

2659 

2660 Raises: 

2661 ValueError: 

2662 Input shape is not specified. 

2663 None value for dimension in input_tensor. 

2664 """ 

2665 saved_model_convert_result = self._convert_as_saved_model() 

2666 if saved_model_convert_result: 

2667 return saved_model_convert_result 

2668 

2669 return super(TFLiteKerasModelConverter, self).convert() 

2670 

2671 

2672class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1): 

2673 """Converts the given frozen graph def into TensorFlow Lite model.""" 

2674 

2675 def __init__( 

2676 self, 

2677 graph_def, 

2678 input_tensors, 

2679 output_tensors, 

2680 input_arrays_with_shape=None, 

2681 output_arrays=None, 

2682 experimental_debug_info_func=None, 

2683 ): 

2684 """Constructor for TFLiteConverter. 

2685 

2686 Args: 

2687 graph_def: Frozen TensorFlow GraphDef. 

2688 input_tensors: List of input tensors. Type and shape are computed using 

2689 `foo.shape` and `foo.dtype`. 

2690 output_tensors: List of output tensors (only .name is used from this). 

2691 input_arrays_with_shape: Tuple of strings representing input tensor names 

2692 and list of integers representing input shapes (e.g., [("foo", [1, 16, 

2693 16, 3])]). Use only when graph cannot be loaded into TensorFlow and when 

2694 `input_tensors` and `output_tensors` are None. (default None) 

2695 output_arrays: List of output tensors to freeze graph with. Use only when 

2696 graph cannot be loaded into TensorFlow and when `input_tensors` and 

2697 `output_tensors` are None. (default None) 

2698 experimental_debug_info_func: An experimental function to retrieve the 

2699 graph debug info for a set of nodes from the `graph_def`. 

2700 

2701 Raises: 

2702 ValueError: Invalid arguments. 

2703 """ 

2704 super(TFLiteFrozenGraphConverter, self).__init__( 

2705 experimental_debug_info_func 

2706 ) 

2707 self._graph_def = graph_def 

2708 self._input_tensors = input_tensors 

2709 self._output_tensors = output_tensors 

2710 self._control_output_arrays = None 

2711 

2712 # Attributes are used by models that cannot be loaded into TensorFlow. 

2713 if not self._has_valid_tensors(): 

2714 self._input_arrays_with_shape = input_arrays_with_shape 

2715 self._output_arrays = output_arrays 

2716 

2717 if input_tensors is not None and input_arrays_with_shape is not None: 

2718 logging.warning( 

2719 "input_arrays_with_shape will be ignored when both the " 

2720 "given input_tensors and input_arrays_with_shape are not " 

2721 "None." 

2722 ) 

2723 

2724 if output_tensors is not None and output_arrays is not None: 

2725 logging.warning( 

2726 "output_arrays will be ignored when both the given " 

2727 "output_tensors and output_arrays are not None." 

2728 ) 

2729 

2730 @_export_metrics 

2731 def convert(self): 

2732 """Converts a TensorFlow GraphDef based on instance variables. 

2733 

2734 Returns: 

2735 The converted data in serialized format. Either a TFLite Flatbuffer or a 

2736 Graphviz graph depending on value in `output_format`. 

2737 

2738 Raises: 

2739 ValueError: 

2740 Input shape is not specified. 

2741 None value for dimension in input_tensor. 

2742 """ 

2743 if not self._has_valid_tensors(): 

2744 if not self._input_arrays_with_shape or not ( 

2745 self._output_arrays or self._control_output_arrays 

2746 ): 

2747 raise ValueError( 

2748 "If input_tensors and output_tensors are None, both " 

2749 "input_arrays_with_shape and output_arrays|control_output_arrays " 

2750 "must be defined." 

2751 ) 

2752 return super(TFLiteFrozenGraphConverter, self).convert() 

2753 

2754 

2755@_tf_export(v1=["lite.TFLiteConverter"]) 

2756class TFLiteConverter(TFLiteFrozenGraphConverter): 

2757 """Convert a TensorFlow model into `output_format`. 

2758 

2759 This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras 

2760 model into either a TFLite FlatBuffer or graph visualization. 

2761 

2762 Attributes: 

2763 optimizations: Experimental flag, subject to change. Set of optimizations to 

2764 apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a 

2765 set of values of type `tf.lite.Optimize`) 

2766 representative_dataset: A generator function used for integer quantization 

2767 where each generated sample has the same order, type and shape as the 

2768 inputs to the model. Usually, this is a small subset of a few hundred 

2769 samples randomly chosen, in no particular order, from the training or 

2770 evaluation dataset. This is an optional attribute, but required for full 

2771 integer quantization, i.e, if `tf.int8` is the only supported type in 

2772 `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`. 

2773 (default None) 

2774 target_spec: Experimental flag, subject to change. Specifications of target 

2775 device, including supported ops set, supported types and a set of user's 

2776 defined TensorFlow operators required in the TensorFlow Lite runtime. 

2777 Refer to `tf.lite.TargetSpec`. 

2778 inference_type: Data type of numeric arrays, excluding the input layer. 

2779 (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 

2780 inference_input_type: Data type of the numeric arrays in the input layer. If 

2781 `inference_input_type` is in {tf.int8, tf.uint8}, then 

2782 `quantized_input_stats` must be provided. (default is the value assigned 

2783 to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8}) 

2784 inference_output_type: Data type of the numeric arrays in the output layer. 

2785 (default is the value assigned to `inference_type`, must be in 

2786 {tf.float32, tf.int8, tf.uint8}) 

2787 quantized_input_stats: Map of input tensor names to a tuple of floats 

2788 representing the mean and standard deviation of the training data. (e.g., 

2789 {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8 or 

2790 tf.uint8. (default None) 

2791 default_ranges_stats: Tuple of integers (min, max) representing range values 

2792 for all numeric arrays without a specified range. Intended for 

2793 experimenting with quantization via "dummy quantization". (default None) 

2794 allow_custom_ops: Boolean indicating whether to allow custom operations. 

2795 When False any unknown operation is an error. When True, custom ops are 

2796 created for any op that is unknown. The developer will need to provide 

2797 these to the TensorFlow Lite runtime with a custom resolver. (default 

2798 False) 

2799 drop_control_dependency: Boolean indicating whether to drop control 

2800 dependencies silently. This is due to TFLite not supporting control 

2801 dependencies. (default True) 

2802 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant 

2803 nodes in unexpected locations. Used when the location of the FakeQuant 

2804 nodes is preventing graph transformations necessary to convert the graph. 

2805 Results in a graph that differs from the quantized training graph, 

2806 potentially causing differing arithmetic behavior. (default False) 

2807 change_concat_input_ranges: Boolean to change behavior of min/max ranges for 

2808 inputs and outputs of the concat operator for quantized models. Changes 

2809 the ranges of concat operator overlap when true. (default False) 

2810 output_format: Output file format. (default 

2811 tf.compat.v1.lite.constants.TFLITE, must be in 

2812 {tf.compat.v1.lite.constants.TFLITE, 

2813 tf.compat.v1.lite.constants.GRAPHVIZ_DOT}) 

2814 dump_graphviz_dir: Full filepath of folder to dump the graphs at various 

2815 stages of processing GraphViz .dot files. Preferred over

2816 `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` so that the

2817 output file remains a valid TFLite model. (default None)

2818 dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot 

2819 files after every graph transformation. Requires the `dump_graphviz_dir` 

2820 flag to be specified. (default False) 

2821 conversion_summary_dir: Full path of the directory to store conversion logs. 

2822 (default None) 

2823 exclude_conversion_metadata: Whether not to embed the conversion metadata 

2824 into the converted model. (default False) 

2825 target_ops: Deprecated. Please use `target_spec.supported_ops` instead. 

2826 post_training_quantize: Deprecated. Please use `optimizations` instead and 

2827 set it to `{tf.lite.Optimize.DEFAULT}`. (default False) 

2828 experimental_new_converter: Experimental flag, subject to change. Enables 

2829 MLIR-based conversion. (default True) 

2830 experimental_new_quantizer: Experimental flag, subject to change. Enables 

2831 MLIR-based quantization conversion instead of Flatbuffer-based conversion. 

(default True)

  Example usage:

  ```python
  # Converting a GraphDef from session.
  converter = tf.compat.v1.lite.TFLiteConverter.from_session(
      sess, in_tensors, out_tensors)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

  # Converting a GraphDef from file.
  converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

  # Converting a SavedModel.
  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
      saved_model_dir)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

  # Converting a tf.keras model.
  converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
      keras_model)
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)
  ```
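  A further illustrative sketch (not from the original docstring) combining the
  quantization-related attributes described above; the input tensor name and
  the (mean, std) statistics below are placeholders:

  ```python
  # Assumes `sess`, `in_tensors`, `out_tensors` come from a
  # quantization-aware training graph.
  converter = tf.compat.v1.lite.TFLiteConverter.from_session(
      sess, in_tensors, out_tensors)
  converter.inference_type = tf.uint8
  converter.inference_input_type = tf.uint8
  # Required because inference_input_type is tf.uint8: map each input
  # tensor name to the (mean, std) of the training data.
  converter.quantized_input_stats = {"input": (127.5, 127.5)}
  # Fallback (min, max) range for arrays with no recorded range
  # ("dummy quantization"); for experimentation only.
  converter.default_ranges_stats = (0, 6)
  tflite_model = converter.convert()
  ```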

2847 """ 

2848 

2849 # pylint: disable=useless-super-delegation 

2850 def __init__( 

2851 self, 

2852 graph_def, 

2853 input_tensors, 

2854 output_tensors, 

2855 input_arrays_with_shape=None, 

2856 output_arrays=None, 

2857 experimental_debug_info_func=None, 

2858 ): 

2859 """Constructor for TFLiteConverter. 

2860 

2861 Args: 

2862 graph_def: Frozen TensorFlow GraphDef. 

2863 input_tensors: List of input tensors. Type and shape are computed using 

2864 `foo.shape` and `foo.dtype`. 

2865 output_tensors: List of output tensors (only .name is used from this). 

2866 input_arrays_with_shape: Tuple of strings representing input tensor names 

2867 and list of integers representing input shapes (e.g., [("foo", [1, 16, 

2868 16, 3])]). Use only when the graph cannot be loaded into TensorFlow and when 

2869 `input_tensors` and `output_tensors` are None. (default None) 

2870 output_arrays: List of output tensors to freeze graph with. Use only when 

2871 graph cannot be loaded into TensorFlow and when `input_tensors` and 

2872 `output_tensors` are None. (default None) 

2873 experimental_debug_info_func: An experimental function to retrieve the 

2874 graph debug info for a set of nodes from the `graph_def`. 

2875 

2876 Raises: 

2877 ValueError: Invalid arguments. 
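  Example (an illustrative sketch, not part of the original docstring; the
  tensor names and shapes are placeholders for a graph that cannot be loaded
  into TensorFlow, e.g. one containing custom TFLite ops):

  ```python
  converter = TFLiteConverter(
      graph_def,
      input_tensors=None,
      output_tensors=None,
      input_arrays_with_shape=[("input", [1, 224, 224, 3])],
      output_arrays=["output"],
  )
  converter.allow_custom_ops = True
  tflite_model = converter.convert()
  ```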

2878 """ 

2879 super(TFLiteConverter, self).__init__( 

2880 graph_def, 

2881 input_tensors, 

2882 output_tensors, 

2883 input_arrays_with_shape, 

2884 output_arrays, 

2885 experimental_debug_info_func, 

2886 ) 

2887 

2888 @classmethod 

2889 def from_session(cls, sess, input_tensors, output_tensors): 

2890 """Creates a TFLiteConverter class from a TensorFlow Session. 

2891 

2892 Args: 

2893 sess: TensorFlow Session. 

2894 input_tensors: List of input tensors. Type and shape are computed using 

2895 `foo.shape` and `foo.dtype`. 

2896 output_tensors: List of output tensors (only .name is used from this). 

2897 

2898 Returns: 

2899 TFLiteConverter class. 
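  Example usage (an illustrative sketch; the trivial `x + 1` graph is a
  placeholder and `disable_eager_execution` is only needed when running
  under TF2):

  ```python
  import tensorflow.compat.v1 as tf

  tf.disable_eager_execution()
  with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=[1, 4], name="x")
    y = tf.add(x, 1.0, name="y")
    converter = tf.lite.TFLiteConverter.from_session(sess, [x], [y])
    tflite_model = converter.convert()
  ```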

2900 """ 

2901 # pylint: disable=protected-access 

2902 TFLiteConverterBase._set_original_model_type( 

2903 conversion_metdata_fb.ModelType.TF_SESSION 

2904 ) 

2905 # pylint: enable=protected-access 

2906 graph_def = _freeze_graph(sess, input_tensors, output_tensors) 

2907 return cls( 

2908 graph_def, 

2909 input_tensors, 

2910 output_tensors, 

2911 experimental_debug_info_func=_build_debug_info_func(sess.graph), 

2912 ) 

2913 

2914 @classmethod 

2915 def from_frozen_graph( 

2916 cls, graph_def_file, input_arrays, output_arrays, input_shapes=None 

2917 ): 

2918 """Creates a TFLiteConverter class from a file containing a frozen GraphDef. 

2919 

2920 Args: 

2921 graph_def_file: Full filepath of file containing frozen GraphDef. 

2922 input_arrays: List of input tensors to freeze graph with. 

2923 output_arrays: List of output tensors to freeze graph with. 

2924 input_shapes: Dict of strings representing input tensor names to list of 

2925 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 

2926 Automatically determined when `input_shapes` is None (e.g., {"foo" : 

2927 None}). (default None) 

2928 

2929 Returns: 

2930 TFLiteConverter class. 

2931 

2932 Raises: 

2933 IOError: 

2934 File not found. 

2935 Unable to parse input file. 

2936 ValueError: 

2937 The graph is not frozen. 

2938 input_arrays or output_arrays contains an invalid tensor name. 

2939 input_shapes is not correctly defined when required 
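  Example usage (an illustrative sketch; the file path, tensor names and
  shapes are placeholders):

  ```python
  converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
      graph_def_file="/tmp/frozen_graph.pb",
      input_arrays=["input"],
      output_arrays=["output"],
      input_shapes={"input": [1, 224, 224, 3]},
  )
  tflite_model = converter.convert()
  with open("/tmp/converted_model.tflite", "wb") as f:
    f.write(tflite_model)
  ```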

2940 """ 

2941 # pylint: disable=protected-access 

2942 TFLiteConverterBase._set_original_model_type( 

2943 conversion_metdata_fb.ModelType.TF_GRAPH_DEF 

2944 ) 

2945 # pylint: enable=protected-access 

2946 with _ops.Graph().as_default(): 

2947 with _session.Session() as sess: 

2948 # Read GraphDef from file. 

2949 if not gfile.Exists(graph_def_file): 

2950 raise IOError("File '{0}' does not exist.".format(graph_def_file)) 

2951 with gfile.GFile(graph_def_file, "rb") as f: 

2952 file_content = f.read() 

2953 

2954 try: 

2955 graph_def = _graph_pb2.GraphDef() 

2956 graph_def.ParseFromString(file_content) 

2957 except (_text_format.ParseError, DecodeError): 

2958 try: 

2959 print("Ignore 'tcmalloc: large alloc' warnings.") 

2960 

2961 if not isinstance(file_content, str): 

2962 file_content = file_content.decode("utf-8") 

2963 graph_def = _graph_pb2.GraphDef() 

2964 _text_format.Merge(file_content, graph_def) 

2965 except (_text_format.ParseError, DecodeError): 

2966 raise IOError( 

2967 "Unable to parse input file '{}'.".format(graph_def_file) 

2968 ) 

2969 

2970 if sys.byteorder == "big": 

2971 bst.swap_tensor_content_in_graph_node(graph_def, "little", "big") 

2972 

2973 # Handles models with custom TFLite ops that cannot be resolved in 

2974 # TensorFlow. 

2975 load_model_in_session = True 

2976 try: 

2977 _import_graph_def(graph_def, name="") 

2978 except _NotFoundError: 

2979 load_model_in_session = False 

2980 

2981 if load_model_in_session: 

2982 # Check if graph is frozen. 

2983 if not _is_frozen_graph(sess): 

2984 raise ValueError("Please freeze the graph using freeze_graph.py.") 

2985 

2986 # Get input and output tensors. 

2987 input_tensors = _get_tensors_from_tensor_names( 

2988 sess.graph, input_arrays 

2989 ) 

2990 output_tensors = _get_tensors_from_tensor_names( 

2991 sess.graph, output_arrays 

2992 ) 

2993 _set_tensor_shapes(input_tensors, input_shapes) 

2994 

2995 return cls(sess.graph_def, input_tensors, output_tensors) 

2996 else: 

2997 if not input_shapes: 

2998 raise ValueError("input_shapes must be defined for this model.") 

2999 if set(input_arrays) != set(input_shapes.keys()): 

3000 raise ValueError( 

3001 "input_shapes must contain a value for each item " 

3002 "in input_arrays." 

3003 ) 

3004 

3005 input_arrays_with_shape = [ 

3006 (name, input_shapes[name]) for name in input_arrays 

3007 ] 

3008 return cls( 

3009 graph_def, 

3010 input_tensors=None, 

3011 output_tensors=None, 

3012 input_arrays_with_shape=input_arrays_with_shape, 

3013 output_arrays=output_arrays, 

3014 ) 

3015 

3016 @classmethod 

3017 def from_saved_model( 

3018 cls, 

3019 saved_model_dir, 

3020 input_arrays=None, 

3021 input_shapes=None, 

3022 output_arrays=None, 

3023 tag_set=None, 

3024 signature_key=None, 

3025 ): 

3026 """Creates a TFLiteConverter class from a SavedModel. 

3027 

3028 Args: 

3029 saved_model_dir: SavedModel directory to convert. 

3030 input_arrays: List of input tensors to freeze graph with. Uses input 

3031 arrays from SignatureDef when none are provided. (default None) 

3032 input_shapes: Dict of strings representing input tensor names to list of 

3033 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 

3034 Automatically determined when `input_shapes` is None (e.g., {"foo" : 

3035 None}). (default None) 

3036 output_arrays: List of output tensors to freeze graph with. Uses output 

3037 arrays from SignatureDef when none are provided. (default None) 

3038 tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to 

3039 analyze. All tags in the tag set must be present. (default 

3040 {tf.saved_model.SERVING}) 

3041 signature_key: Key identifying SignatureDef containing inputs and outputs. 

3042 (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 

3043 

3044 Returns: 

3045 TFLiteConverter class. 
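  Example usage (an illustrative sketch; the SavedModel directory is a
  placeholder and the tag/signature values shown are the defaults):

  ```python
  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
      "/tmp/saved_model",
      tag_set={tf.saved_model.SERVING},
      signature_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
  )
  tflite_model = converter.convert()
  ```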

3046 """ 

3047 # pylint: disable=protected-access 

3048 TFLiteConverterBase._set_original_model_type( 

3049 conversion_metdata_fb.ModelType.TF_SAVED_MODEL 

3050 ) 

3051 # pylint: enable=protected-access 

3052 if tag_set is None: 

3053 tag_set = set([_tag_constants.SERVING]) 

3054 if signature_key is None: 

3055 signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 

3056 

3057 saved_model_converter = TFLiteSavedModelConverter( 

3058 saved_model_dir, tag_set, [signature_key] 

3059 ) 

3060 if saved_model_converter.saved_model_dir: 

3061 return saved_model_converter 

3062 

3063 result = _freeze_saved_model( 

3064 saved_model_dir, 

3065 input_arrays, 

3066 input_shapes, 

3067 output_arrays, 

3068 tag_set, 

3069 signature_key, 

3070 ) 

3071 

3072 return cls( 

3073 graph_def=result[0], 

3074 input_tensors=result[1], 

3075 output_tensors=result[2], 

3076 experimental_debug_info_func=_build_debug_info_func(result[3]), 

3077 ) 

3078 

3079 @classmethod 

3080 def from_keras_model_file( 

3081 cls, 

3082 model_file, 

3083 input_arrays=None, 

3084 input_shapes=None, 

3085 output_arrays=None, 

3086 custom_objects=None, 

3087 ): 

3088 """Creates a TFLiteConverter class from a tf.keras model file. 

3089 

3090 Args: 

3091 model_file: Full filepath of HDF5 file containing the tf.keras model. 

3092 input_arrays: List of input tensors to freeze graph with. Uses the 

3093 model's input arrays when none are provided. (default None) 

3094 input_shapes: Dict of strings representing input tensor names to list of 

3095 integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). 

3096 Automatically determined when `input_shapes` is None (e.g., {"foo" : 

3097 None}). (default None) 

3098 output_arrays: List of output tensors to freeze graph with. Uses the 

3099 model's output arrays when none are provided. (default None) 

3100 custom_objects: Dict mapping names (strings) to custom classes or 

3101 functions to be considered during model deserialization. (default None) 

3102 

3103 Returns: 

3104 TFLiteConverter class. 
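  Example usage (an illustrative sketch; the HDF5 path is a placeholder):

  ```python
  converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
      "/tmp/keras_model.h5")
  tflite_model = converter.convert()
  with open("/tmp/converted_model.tflite", "wb") as f:
    f.write(tflite_model)
  ```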

3105 """ 

3106 # pylint: disable=protected-access 

3107 TFLiteConverterBase._set_original_model_type( 

3108 conversion_metdata_fb.ModelType.KERAS_MODEL 

3109 ) 

3110 # pylint: enable=protected-access 

3111 return TFLiteKerasModelConverter( 

3112 model_file, input_arrays, input_shapes, output_arrays, custom_objects 

3113 ) 

3114 

3115 # pylint: disable=useless-super-delegation 

3116 def convert(self): 

3117 """Converts a TensorFlow GraphDef based on instance variables. 

3118 

3119 Returns: 

3120 The converted data in serialized format. Either a TFLite Flatbuffer or a 

3121 Graphviz graph, depending on the value of `output_format`. 

3122 

3123 Raises: 

3124 ValueError: 

3125 Input shape is not specified. 

3126 None value for dimension in input_tensor. 

3127 """ 

3128 return super(TFLiteConverter, self).convert() 

3129 

3130 

3131@_tf_export(v1=["lite.TocoConverter"]) 

3132class TocoConverter: 

3133 """Convert a TensorFlow model into `output_format`. 

3134 

3135 This class has been deprecated. Please use `lite.TFLiteConverter` instead. 

3136 """ 

3137 

3138 @classmethod 

3139 @_deprecation.deprecated( 

3140 None, "Use `lite.TFLiteConverter.from_session` instead." 

3141 ) 

3142 def from_session(cls, sess, input_tensors, output_tensors): 

3143 """Creates a TocoConverter class from a TensorFlow Session.""" 

3144 return TFLiteConverter.from_session(sess, input_tensors, output_tensors) 

3145 

3146 @classmethod 

3147 @_deprecation.deprecated( 

3148 None, "Use `lite.TFLiteConverter.from_frozen_graph` instead." 

3149 ) 

3150 def from_frozen_graph( 

3151 cls, graph_def_file, input_arrays, output_arrays, input_shapes=None 

3152 ): 

3153 """Creates a TocoConverter class from a file containing a frozen graph.""" 

3154 return TFLiteConverter.from_frozen_graph( 

3155 graph_def_file, input_arrays, output_arrays, input_shapes 

3156 ) 

3157 

3158 @classmethod 

3159 @_deprecation.deprecated( 

3160 None, "Use `lite.TFLiteConverter.from_saved_model` instead." 

3161 ) 

3162 def from_saved_model( 

3163 cls, 

3164 saved_model_dir, 

3165 input_arrays=None, 

3166 input_shapes=None, 

3167 output_arrays=None, 

3168 tag_set=None, 

3169 signature_key=None, 

3170 ): 

3171 """Creates a TocoConverter class from a SavedModel.""" 

3172 return TFLiteConverter.from_saved_model( 

3173 saved_model_dir, 

3174 input_arrays, 

3175 input_shapes, 

3176 output_arrays, 

3177 tag_set, 

3178 signature_key, 

3179 ) 

3180 

3181 @classmethod 

3182 @_deprecation.deprecated( 

3183 None, "Use `lite.TFLiteConverter.from_keras_model_file` instead." 

3184 ) 

3185 def from_keras_model_file( 

3186 cls, model_file, input_arrays=None, input_shapes=None, output_arrays=None 

3187 ): 

3188 """Creates a TocoConverter class from a tf.keras model file.""" 

3189 return TFLiteConverter.from_keras_model_file( 

3190 model_file, input_arrays, input_shapes, output_arrays 

3191 )