Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/convert.py: 20%

343 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Converts a frozen graph into a TFLite FlatBuffer.""" 

16 

17import distutils.spawn 

18import enum 

19import hashlib 

20import os as _os 

21import platform as _platform 

22import subprocess as _subprocess 

23import tempfile as _tempfile 

24from typing import Optional 

25import warnings 

26 

27from tensorflow.compiler.mlir.quantization.stablehlo import quantization_options_pb2 as quant_opts_pb2 

28from tensorflow.lite.python import lite_constants 

29from tensorflow.lite.python import util 

30from tensorflow.lite.python import wrap_toco 

31from tensorflow.lite.python.convert_phase import Component 

32from tensorflow.lite.python.convert_phase import convert_phase 

33from tensorflow.lite.python.convert_phase import ConverterError 

34from tensorflow.lite.python.convert_phase import SubComponent 

35from tensorflow.lite.python.metrics import converter_error_data_pb2 

36from tensorflow.lite.python.metrics.wrapper import metrics_wrapper as _metrics_wrapper 

37from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2 

38from tensorflow.lite.toco import toco_flags_pb2 as _conversion_flags_pb2 

39from tensorflow.lite.toco import types_pb2 as _types_pb2 

40from tensorflow.lite.tools import flatbuffer_utils 

41from tensorflow.python.framework import dtypes 

42from tensorflow.python.framework import tensor_shape 

43from tensorflow.python.platform import resource_loader as _resource_loader 

44from tensorflow.python.util import deprecation 

45from tensorflow.python.util.tf_export import tf_export as _tf_export 

46 

47 

48def _is_quantized_input_stats_required( 

49 conversion_flags: _conversion_flags_pb2.TocoFlags, 

50) -> bool: 

51 """Checks if the `quantized_input_stats` flag is required for conversion. 

52 

53 Args: 

54 conversion_flags: A protocol buffer describing the conversion process. 

55 

56 Returns: 

57 True, if the `inference_type` or the `inference_input_type` is a quantized 

58 type and it is not post training quantization, else False. 

59 """ 

60 quantized_inference_types = [ 

61 _types_pb2.QUANTIZED_UINT8, 

62 _types_pb2.QUANTIZED_INT8, 

63 ] 

64 return ( 

65 conversion_flags.inference_type in quantized_inference_types 

66 or conversion_flags.inference_input_type in quantized_inference_types 

67 ) and not conversion_flags.post_training_quantize 

68 
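# --- Editorial example (added; not part of the original module). A minimal
# sketch of how this helper reads the conversion flags; every field used here
# appears in the function above.
def _example_quantized_stats_check():
  flags = _conversion_flags_pb2.TocoFlags()
  flags.inference_type = _types_pb2.QUANTIZED_UINT8
  flags.post_training_quantize = False
  print(_is_quantized_input_stats_required(flags))  # True: quantized inference type
  flags.post_training_quantize = True
  print(_is_quantized_input_stats_required(flags))  # False: post-training quantization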

69 

70def convert_tensor_tf_type_to_tflite_type( 

71 tf_type: dtypes.DType, usage: str = "" 

72) -> _types_pb2.IODataType: 

73 """Convert tensor type from tf type to tflite type. 

74 

75 Args: 

76 tf_type: TensorFlow type. 

77 usage: Text describing the reason for invoking this function. 

78 

79 Raises: 

80 ValueError: If `tf_type` is unsupported. 

81 

82 Returns: 

83 tflite_type: TFLite type. Refer to lite/toco/types.proto. 

84 """ 

85 mapping = { 

86 dtypes.float16: _types_pb2.FLOAT16, 

87 dtypes.float32: _types_pb2.FLOAT, 

88 dtypes.float64: _types_pb2.FLOAT64, 

89 dtypes.int8: _types_pb2.INT8, 

90 dtypes.int16: _types_pb2.INT16, 

91 dtypes.uint16: _types_pb2.UINT16, 

92 dtypes.int32: _types_pb2.INT32, 

93 dtypes.int64: _types_pb2.INT64, 

94 dtypes.uint8: _types_pb2.UINT8, 

95 dtypes.uint32: _types_pb2.UINT32, 

96 dtypes.uint64: _types_pb2.UINT64, 

97 dtypes.string: _types_pb2.STRING, 

98 dtypes.bool: _types_pb2.BOOL, 

99 dtypes.complex64: _types_pb2.COMPLEX64, 

100 dtypes.complex128: _types_pb2.COMPLEX128, 

101 } 

102 tflite_type = mapping.get(tf_type) 

103 if tflite_type is None: 

104 raise ValueError( 

105 "Unsupported TensorFlow type `{0}` provided for the {1}".format( 

106 tf_type, usage 

107 ) 

108 ) 

109 return tflite_type 

110 
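# --- Editorial example (added; not part of the original module). Illustrates
# the dtype mapping above; the `usage` argument only decorates the error text.
def _example_tensor_type_mapping():
  assert convert_tensor_tf_type_to_tflite_type(dtypes.float32) == _types_pb2.FLOAT
  assert convert_tensor_tf_type_to_tflite_type(dtypes.int8) == _types_pb2.INT8
  try:
    convert_tensor_tf_type_to_tflite_type(dtypes.variant, usage="input type")
  except ValueError as e:
    print(e)  # Unsupported TensorFlow type ... provided for the input type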

111 

112# Only a few restricted tensor types are allowed for explicitly setting 

113# inference/input/output types. 

114def convert_inference_tf_type_to_tflite_type( 

115 tf_type: dtypes.DType, usage: str = "" 

116) -> _types_pb2.IODataType: 

117 """Convert inference type from tf type to tflite type. 

118 

119 Args: 

120 tf_type: TensorFlow type. 

121 usage: Text describing the reason for invoking this function. 

122 

123 Raises: 

124 ValueError: If `tf_type` is unsupported. 

125 

126 Returns: 

127 tflite_type: TFLite type. Refer to lite/toco/types.proto. 

128 """ 

129 mapping = { 

130 dtypes.float32: _types_pb2.FLOAT, 

131 dtypes.uint8: _types_pb2.QUANTIZED_UINT8, 

132 dtypes.int8: _types_pb2.QUANTIZED_INT8, 

133 dtypes.int16: _types_pb2.QUANTIZED_INT16, 

134 } 

135 tflite_type = mapping.get(tf_type) 

136 if tflite_type is None: 

137 raise ValueError( 

138 "Unsupported TensorFlow type `{0}` provided for the {1}".format( 

139 tf_type, usage 

140 ) 

141 ) 

142 return tflite_type 

143 
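# --- Editorial example (added; not part of the original module). Unlike the
# tensor-type mapping above, this mapping only accepts the four types listed,
# so a type such as tf.float64 is rejected even though it is a valid tensor type.
def _example_inference_type_mapping():
  assert (
      convert_inference_tf_type_to_tflite_type(dtypes.int8)
      == _types_pb2.QUANTIZED_INT8
  )
  try:
    convert_inference_tf_type_to_tflite_type(dtypes.float64, usage="inference_type flag")
  except ValueError:
    pass  # float64 is not a supported inference type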

144 

145# Find the deprecated conversion binary using the resource loader when running

146# from bazel; otherwise we are in a pip install where console_scripts already provides the tool.

147if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY: 

148 _deprecated_conversion_binary = "" 

149else: 

150 _deprecated_conversion_binary = _resource_loader.get_path_to_datafile( 

151 "../toco/python/toco_from_protos" 

152 ) 

153 if not _os.path.exists(_deprecated_conversion_binary): 

154 _deprecated_conversion_binary = "toco_from_protos" 

155 

156 

157def _try_convert_to_unicode(output): 

158 if output is None: 

159 return "" 

160 

161 if isinstance(output, bytes): 

162 try: 

163 return output.decode("utf-8") 

164 except UnicodeDecodeError: 

165 pass 

166 return output 

167 

168 

169@_tf_export("lite.OpsSet") 

170class OpsSet(enum.Enum): 

171 """Enum class defining the sets of ops available to generate TFLite models. 

172 

173 WARNING: Experimental interface, subject to change. 

174 """ 

175 

176 # Convert model using TensorFlow Lite builtin ops. 

177 TFLITE_BUILTINS = "TFLITE_BUILTINS" 

178 

179 # Convert model using TensorFlow ops. Not all TensorFlow ops are available. 

180 # WARNING: Experimental interface, subject to change. 

181 SELECT_TF_OPS = "SELECT_TF_OPS" 

182 

183 # Convert model using only TensorFlow Lite quantized int8 operations. 

184 # Specifying this will throw an error for operations that do not yet have 

185 # quantized implementations. 

186 TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8" 

187 

188 # Convert model using only TensorFlow Lite operations with quantized int8 

189 # weights, int16 activations and int64 bias. 

190 # Specifying this will throw an error for operations that do not yet have 

191 # quantized implementations. 

192 # This quantization mode may be used in models for super-resolution, 

193 # audio signal processing or image de-noising. It improves accuracy 

194 # significantly, but only slightly increases the model size. 

195 # WARNING: These ops are currently experimental and have not yet been 

196 # finalized. 

197 # They are only compatible with CPU execution, and have not been optimized for 

198 # production. 

199 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = ( 

200 "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8" 

201 ) 

202 

203 # Convert model using only StableHLO ops.

204 # This option cannot be combined with other OpsSets.

205 # The feature is in early development.

206 # The code to execute StableHLO ops in the runtime has yet to be implemented,

207 # and the serialization format is not yet stabilized.

208 

209 EXPERIMENTAL_STABLEHLO_OPS = "EXPERIMENTAL_STABLEHLO_OPS" 

210 

211 def __str__(self): 

212 return str(self.value) 

213 

214 @staticmethod 

215 def get_options(): 

216 """Returns a list of OpsSet options as a list of strings.""" 

217 return [str(option) for option in list(OpsSet)] 

218 
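# --- Editorial example (added; not part of the original module). Shows how
# OpsSet members stringify and how a target-ops set is typically expressed.
def _example_ops_set():
  print(str(OpsSet.TFLITE_BUILTINS))   # "TFLITE_BUILTINS"
  print(OpsSet.get_options())          # every op-set name as a string
  target_ops = {OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS}
  print(OpsSet.SELECT_TF_OPS in target_ops)  # True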

219 

220@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.QUANTIZE) 

221def mlir_quantize( 

222 input_data_str, 

223 disable_per_channel=False, 

224 fully_quantize=False, 

225 inference_type=_types_pb2.QUANTIZED_INT8, 

226 input_data_type=dtypes.float32, 

227 output_data_type=dtypes.float32, 

228 enable_numeric_verify=False, 

229 enable_whole_model_verify=False, 

230 denylisted_ops=None, 

231 denylisted_nodes=None, 

232 enable_variable_quantization=False, 

233): 

234 """Quantize `input_data_str` with calibration results. 

235 

236 Args: 

237 input_data_str: Input data in serialized form (e.g. a TFLITE model with 

238 calibration results). 

239 disable_per_channel: Bool indicating whether to do per-channel or per-tensor

240 quantization.

241 fully_quantize: Bool indicating whether to fully quantize the model. Besides

242 the model body, the inputs/outputs will be quantized as well.

243 inference_type: Data type for the activations. The default value is int8. 

244 input_data_type: Data type for the inputs. The default value is float32. 

245 output_data_type: Data type for the outputs. The default value is float32. 

246 enable_numeric_verify: Experimental. Subject to change. Bool indicating 

247 whether to add NumericVerify ops into the debug mode quantized model. 

248 enable_whole_model_verify: Experimental. Subject to change. Bool indicating 

249 whether to add verification for layer by layer, or on whole model. When 

250 disabled (per-layer) float and quantized ops will be run from same input 

251 (output of previous quantized layer). When enabled, float and quantized 

252 ops will run with respective float and quantized output of previous ops. 

253 denylisted_ops: Experimental. Subject to change. Set of ops to denylist. 

254 denylisted_nodes: Experimental. Subject to change. Set of nodes to denylist.

255 enable_variable_quantization: Experimental. Subject to change. Bool 

256 indicating whether to enable quantization of the residual variables 

257 remaining after the variable freezing pass. 

258 

259 Returns: 

260 Quantized model in serialized form (e.g. a TFLITE model) with floating-point 

261 inputs and outputs. 

262 """ 

263 return wrap_toco.wrapped_experimental_mlir_quantize( 

264 input_data_str, 

265 disable_per_channel, 

266 fully_quantize, 

267 inference_type, 

268 convert_tensor_tf_type_to_tflite_type(input_data_type), 

269 convert_tensor_tf_type_to_tflite_type(output_data_type), 

270 enable_numeric_verify, 

271 enable_whole_model_verify, 

272 denylisted_ops, 

273 denylisted_nodes, 

274 enable_variable_quantization, 

275 ) 

276 
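# --- Editorial example (added; not part of the original module). A hypothetical
# call sketch: `calibrated_model` is assumed to be a serialized TFLite
# flatbuffer that already contains calibration results.
def _example_mlir_quantize(calibrated_model):
  return mlir_quantize(
      calibrated_model,
      fully_quantize=True,                    # also quantize inputs/outputs
      inference_type=_types_pb2.QUANTIZED_INT8,
      input_data_type=dtypes.int8,
      output_data_type=dtypes.int8,
  )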

277 

278@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.SPARSIFY) 

279def mlir_sparsify(input_data_str): 

280 """Sparsify `input_data_str` to encode sparse tensor with proper format. 

281 

282 Args: 

283 input_data_str: Input data in serialized form (e.g. a TFLITE model). 

284 

285 Returns: 

286 Sparsified model in serialized form (e.g. a TFLITE model). 

287 """ 

288 return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str) 

289 

290 

291def register_custom_opdefs(custom_opdefs_list): 

292 """Register the given custom opdefs to the TensorFlow global op registry. 

293 

294 Args: 

295 custom_opdefs_list: String representing the custom ops OpDefs that are 

296 included in the GraphDef. 

297 

298 Returns: 

299 True if the registration is successfully completed. 

300 """ 

301 return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list) 

302 
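# --- Editorial example (added; not part of the original module). A hypothetical
# registration sketch: the OpDef text below is an assumed custom op with two
# float inputs and one float output.
def _example_register_custom_opdefs():
  custom_opdef = (
      "name: 'CustomAdd' "
      "input_arg: { name: 'x' type: DT_FLOAT } "
      "input_arg: { name: 'y' type: DT_FLOAT } "
      "output_arg: { name: 'z' type: DT_FLOAT }"
  )
  return register_custom_opdefs([custom_opdef])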

303 

304def convert( 

305 model_flags: _model_flags_pb2.ModelFlags, 

306 conversion_flags: _conversion_flags_pb2.TocoFlags, 

307 input_data_str: Optional[str] = None, 

308 debug_info_str: Optional[str] = None, 

309 enable_mlir_converter: bool = True, 

310): 

311 """Converts `input_data_str` to a TFLite model. 

312 

313 Args: 

314 model_flags: Proto describing model properties, see `model_flags.proto`. 

315 conversion_flags: Proto describing conversion properties, see 

316 `toco/toco_flags.proto`. 

317 input_data_str: Input data in serialized form (e.g. a GraphDef is common, or

318 it can be HLO text or proto).

319 debug_info_str: Serialized `GraphDebugInfo` proto describing logging 

320 information. 

321 enable_mlir_converter: Enables MLIR-based conversion. 

322 

323 Returns: 

324 Converted model in serialized form (e.g. a TFLITE model is common). 

325 Raises: 

326 ConverterError: When conversion fails in TFLiteConverter, usually due to 

327 ops not being supported. 

328 RuntimeError: When conversion fails, an exception is raised with the error 

329 message embedded. 

330 """ 

331 # Historically, deprecated conversion failures would trigger a crash, so we 

332 # attempt to run the converter out-of-process. The current MLIR conversion 

333 # pipeline surfaces errors instead, and can be safely run in-process. 

334 if enable_mlir_converter or not _deprecated_conversion_binary: 

335 try: 

336 return wrap_toco.wrapped_toco_convert( 

337 model_flags.SerializeToString(), 

338 conversion_flags.SerializeToString(), 

339 input_data_str, 

340 debug_info_str, 

341 enable_mlir_converter, 

342 ) 

343 except Exception as e: 

344 converter_error = ConverterError(str(e)) 

345 

346 for error_data in _metrics_wrapper.retrieve_collected_errors(): 

347 converter_error.append_error(error_data) 

348 # Occasionally we encounter a case where an unsupported

349 # `StatefulPartitionedCallOp` is not inlined and remains in the final

350 # IR. If this occurs we can set `guarantee_all_funcs_one_use` and retry.

351 # This makes the converter copy function definitions that are called by

352 # multiple StatefulPartitionedCall ops, thus allowing them to be properly

353 # inlined.

354 if ( 

355 error_data.error_code 

356 == converter_error_data_pb2.ConverterErrorData.ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR 

357 and not conversion_flags.guarantee_all_funcs_one_use 

358 ): 

359 conversion_flags.guarantee_all_funcs_one_use = True 

360 return convert( 

361 model_flags, 

362 conversion_flags, 

363 input_data_str, 

364 debug_info_str, 

365 enable_mlir_converter, 

366 ) 

367 raise converter_error 

368 

369 return _run_deprecated_conversion_binary( 

370 model_flags.SerializeToString(), 

371 conversion_flags.SerializeToString(), 

372 input_data_str, 

373 debug_info_str, 

374 ) 

375 
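# --- Editorial example (added; not part of the original module). An
# illustrative call pattern for `convert`; `graph_def` is an assumed frozen
# GraphDef, and the flag builders are the ones defined later in this module.
def _example_convert(graph_def):
  model_flags = build_model_flags()
  conversion_flags = build_conversion_flags(inference_type=dtypes.float32)
  return convert(
      model_flags,
      conversion_flags,
      input_data_str=graph_def.SerializeToString(),
      debug_info_str=None,
      enable_mlir_converter=True,
  )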

376 

377@convert_phase( 

378 Component.CONVERT_TF_TO_TFLITE_MODEL, 

379 SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER, 

380) 

381def _run_deprecated_conversion_binary( 

382 model_flags_str, conversion_flags_str, input_data_str, debug_info_str=None 

383): 

384 """Convert `input_data_str` using deprecated conversion binary. 

385 

386 Args: 

387 model_flags_str: Serialized proto describing model properties, see 

388 `model_flags.proto`. 

389 conversion_flags_str: Serialized proto describing TFLite converter 

390 properties, see `toco/toco_flags.proto`. 

391 input_data_str: Input data in serialized form (e.g. a graphdef is common) 

392 debug_info_str: Serialized `GraphDebugInfo` proto describing logging 

393 information. (default None) 

394 

395 Returns: 

396 Converted model in serialized form (e.g. a TFLITE model is common). 

397 Raises: 

398 ConverterError: When the deprecated conversion binary cannot be found.

399 RuntimeError: When conversion fails, an exception is raised with the error 

400 message embedded. 

401 """ 

402 if distutils.spawn.find_executable(_deprecated_conversion_binary) is None: 

403 raise ConverterError("""Could not find `toco_from_protos` binary, make sure 

404your virtualenv bin directory or pip local bin directory is in your path. 

405In particular, if you have installed TensorFlow with --user, make sure you 

406add the install directory to your path. 

407 

408For example: 

409Linux: export PATH=$PATH:~/.local/bin/ 

410Mac: export PATH=$PATH:~/Library/Python/<version#>/bin 

411 

412Alternatively, use virtualenv.""")

413 # Windows and TemporaryFile are not that useful together, 

414 # since you cannot have two readers/writers. So we have to 

415 # make the temporaries and close and delete them explicitly. 

416 conversion_filename, model_filename, input_filename, output_filename = ( 

417 None, 

418 None, 

419 None, 

420 None, 

421 ) 

422 try: 

423 # Build all input files 

424 with _tempfile.NamedTemporaryFile( 

425 delete=False 

426 ) as fp_conversion, _tempfile.NamedTemporaryFile( 

427 delete=False 

428 ) as fp_model, _tempfile.NamedTemporaryFile( 

429 delete=False 

430 ) as fp_input, _tempfile.NamedTemporaryFile( 

431 delete=False 

432 ) as fp_debug: 

433 conversion_filename = fp_conversion.name 

434 input_filename = fp_input.name 

435 model_filename = fp_model.name 

436 debug_filename = fp_debug.name 

437 

438 fp_model.write(model_flags_str) 

439 fp_conversion.write(conversion_flags_str) 

440 fp_input.write(input_data_str) 

441 debug_info_str = debug_info_str if debug_info_str else "" 

442 # If debug_info_str is a str (rather than bytes), then the call to

443 # fp_debug.write(debug_info_str) will fail with the following error 

444 # 

445 # TypeError: a bytes-like object is required, not 'str' 

446 # 

447 # Some of the subtests within the "convert_test" unit-test fail 

448 # with the error shown above. So watch out for that scenario and 

449 # convert debug_info_str to bytes where needed 

450 if not isinstance(debug_info_str, bytes): 

451 fp_debug.write(debug_info_str.encode("utf-8")) 

452 else: 

453 fp_debug.write(debug_info_str) 

454 

455 # Reserve an output file 

456 with _tempfile.NamedTemporaryFile(delete=False) as fp: 

457 output_filename = fp.name 

458 

459 # Run 

460 cmd = [ 

461 _deprecated_conversion_binary, 

462 model_filename, 

463 conversion_filename, 

464 input_filename, 

465 output_filename, 

466 "--debug_proto_file={}".format(debug_filename), 

467 ] 

468 cmdline = " ".join(cmd) 

469 is_windows = _platform.system() == "Windows" 

470 proc = _subprocess.Popen( 

471 cmdline, 

472 shell=True, 

473 stdout=_subprocess.PIPE, 

474 stderr=_subprocess.STDOUT, 

475 close_fds=not is_windows, 

476 ) 

477 stdout, stderr = proc.communicate() 

478 exitcode = proc.returncode 

479 if exitcode == 0: 

480 with open(output_filename, "rb") as fp: 

481 return fp.read() 

482 else: 

483 stdout = _try_convert_to_unicode(stdout) 

484 stderr = _try_convert_to_unicode(stderr) 

485 raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr)) 

486 finally: 

487 # Must manually cleanup files. 

488 for filename in [ 

489 conversion_filename, 

490 input_filename, 

491 model_filename, 

492 output_filename, 

493 ]: 

494 try: 

495 _os.unlink(filename) 

496 except (OSError, TypeError): 

497 pass 

498 

499 

500def build_model_flags( 

501 change_concat_input_ranges=False, 

502 allow_nonexistent_arrays=False, 

503 saved_model_dir=None, 

504 saved_model_version=0, 

505 saved_model_tags=None, 

506 saved_model_exported_names=None, 

507 **_ 

508): 

509 """Builds the model flags object from params. 

510 

511 Args: 

512 change_concat_input_ranges: Boolean to change behavior of min/max ranges for

513 inputs and outputs of the concat operator for quantized models. Changes

514 the ranges of the concat operator to overlap when true. (default False)

515 allow_nonexistent_arrays: Allow specifying array names that don't exist or 

516 are unused in the final graph. (default False) 

517 saved_model_dir: Filepath of the SavedModel to be converted. This value

518 will be non-empty only when the SavedModel import path is used.

519 Otherwise, the GraphDef-based conversion will be processed.

520 saved_model_version: SavedModel file format version of the SavedModel file

521 to be converted. This value will be set only when the SavedModel import

522 path is used.

523 saved_model_tags: Set of string SavedModel tags, formatted as a

524 comma-separated value. This value will be set only when the SavedModel

525 import path is used.

526 saved_model_exported_names: Names to be exported (default: export all) when

527 the SavedModel import path is used. This value will be set only when the

528 SavedModel import path is used.

529 

530 Returns: 

531 model_flags: protocol buffer describing the model. 

532 """ 

533 model_flags = _model_flags_pb2.ModelFlags() 

534 model_flags.change_concat_input_ranges = change_concat_input_ranges 

535 model_flags.allow_nonexistent_arrays = allow_nonexistent_arrays 

536 if saved_model_dir: 

537 model_flags.saved_model_dir = saved_model_dir 

538 model_flags.saved_model_version = saved_model_version 

539 if saved_model_tags: 

540 model_flags.saved_model_tags.extend(saved_model_tags) 

541 if saved_model_exported_names: 

542 model_flags.saved_model_exported_names.extend(saved_model_exported_names) 

543 return model_flags 

544 
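# --- Editorial example (added; not part of the original module). A sketch of
# building model flags for the SavedModel import path; the directory, tag and
# signature names are placeholder assumptions.
def _example_build_model_flags():
  return build_model_flags(
      saved_model_dir="/tmp/my_saved_model",            # hypothetical path
      saved_model_version=2,
      saved_model_tags=["serve"],
      saved_model_exported_names=["serving_default"],
  )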

545 

546def build_conversion_flags( 

547 inference_type=dtypes.float32, 

548 inference_input_type=None, 

549 input_format=lite_constants.TENSORFLOW_GRAPHDEF, 

550 output_format=lite_constants.TFLITE, 

551 default_ranges_stats=None, 

552 drop_control_dependency=True, 

553 reorder_across_fake_quant=False, 

554 allow_custom_ops=False, 

555 post_training_quantize=False, 

556 quantize_to_float16=False, 

557 dump_graphviz_dir=None, 

558 dump_graphviz_video=False, 

559 target_ops=None, 

560 conversion_summary_dir=None, 

561 select_user_tf_ops=None, 

562 allow_all_select_tf_ops=False, 

563 enable_tflite_resource_variables=True, 

564 unfold_batchmatmul=True, 

565 lower_tensor_list_ops=True, 

566 default_to_single_batch_in_tensor_list_ops=False, 

567 accumulation_type=None, 

568 allow_bfloat16=False, 

569 unfold_large_splat_constant=False, 

570 supported_backends=None, 

571 disable_per_channel_quantization=False, 

572 enable_mlir_dynamic_range_quantizer=False, 

573 tf_quantization_mode=None, 

574 disable_infer_tensor_range=False, 

575 use_fake_quant_num_bits=False, 

576 enable_dynamic_update_slice=False, 

577 preserve_assert_op=False, 

578 guarantee_all_funcs_one_use=False, 

579 enable_mlir_variable_quantization=False, 

580 disable_fuse_mul_and_fc=False, 

581 quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None, 

582 **_ 

583): 

584 """Builds protocol buffer describing a conversion of a model. 

585 

586 Typically this is to convert from TensorFlow GraphDef to TFLite, in which 

587 case the default `input_format` and `output_format` are sufficient. 

588 

589 Args: 

590 inference_type: Data type of numeric arrays, excluding the input layer. 

591 (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8}) 

592 inference_input_type: Data type of the numeric arrays in the input layer. If 

593 `inference_input_type` is in {tf.int8, tf.uint8}, then 

594 `quantized_input_stats` must be provided. (default is the value assigned 

595 to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8}) 

596 input_format: Type of data to read. (default TENSORFLOW_GRAPHDEF, must be in 

597 {TENSORFLOW_GRAPHDEF}) 

598 output_format: Output file format. (default TFLITE, must be in {TFLITE, 

599 GRAPHVIZ_DOT}) 

600 default_ranges_stats: Tuple of integers representing (min, max) range values 

601 for all arrays without a specified range. Intended for experimenting with 

602 quantization via "dummy quantization". (default None) 

603 drop_control_dependency: Boolean indicating whether to drop control 

604 dependencies silently. This is due to TFLite not supporting control 

605 dependencies. (default True) 

606 reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant 

607 nodes in unexpected locations. Used when the location of the FakeQuant 

608 nodes is preventing graph transformations necessary to convert the graph. 

609 Results in a graph that differs from the quantized training graph, 

610 potentially causing differing arithmetic behavior. (default False) 

611 allow_custom_ops: Boolean indicating whether to allow custom operations. 

612 When false any unknown operation is an error. When true, custom ops are 

613 created for any op that is unknown. The developer will need to provide 

614 these to the TensorFlow Lite runtime with a custom resolver. (default 

615 False) 

616 post_training_quantize: Boolean indicating whether to quantize the weights 

617 of the converted float model. Model size will be reduced and there will be 

618 latency improvements (at the cost of accuracy). (default False) If

619 quantization_options is set, all other quantization args will be ignored.

620 quantize_to_float16: Boolean indicating whether to convert float buffers to 

621 float16. (default False) 

622 dump_graphviz_dir: Full filepath of the folder to dump GraphViz .dot files

623 of the graph at various stages of processing. Preferred over

624 --output_format=GRAPHVIZ_DOT in order to keep the requirements of the

625 output file. (default None)

626 dump_graphviz_video: Boolean indicating whether to dump the graph after 

627 every graph transformation. (default False) 

628 target_ops: Experimental flag, subject to change. Set of OpsSet options 

629 indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS])) 

630 conversion_summary_dir: A string, the path to the generated conversion logs. 

631 select_user_tf_ops: List of user-defined TensorFlow ops that need to be

632 supported in the TensorFlow Lite runtime. These ops will be supported as

633 select TensorFlow ops.

634 allow_all_select_tf_ops: If True, automatically add all TF ops (including 

635 custom TF ops) to the converted model as flex ops. 

636 enable_tflite_resource_variables: Experimental flag, subject to change.

637 Enables conversion of resource variables. (default True)

638 unfold_batchmatmul: Whether to unfold tf.BatchMatMul to a set of 

639 tfl.fully_connected ops. If not, translate to tfl.batch_matmul. 

640 lower_tensor_list_ops: Whether to lower tensor list ops to builtin ops. If 

641 not, use Flex tensor list ops. 

642 default_to_single_batch_in_tensor_list_ops: Whether to force batch size

643 one when the tensor list ops have an unspecified batch size.

644 accumulation_type: Data type of the accumulators in quantized inference. 

645 Typically used for float16 quantization and is either fp16 or fp32. 

646 allow_bfloat16: Whether the converted model supports reduced precision 

647 inference with the bfloat16 type. 

648 unfold_large_splat_constant: Whether to unfold large splat constant tensors 

649 in the flatbuffer model to reduce size. 

650 supported_backends: List of TFLite backends against which to check

651 compatibility.

652 disable_per_channel_quantization: Disable per-channel quantized weights for 

653 dynamic range quantization. Only per-tensor quantization will be used. 

654 enable_mlir_dynamic_range_quantizer: Enable MLIR dynamic range quantization. 

655 If False, the old converter dynamic range quantizer is used. 

656 tf_quantization_mode: Indicates the mode of TF Quantization when the output 

657 model is used for TF Quantization. 

658 disable_infer_tensor_range: Disable inferring tensor ranges.

659 use_fake_quant_num_bits: Allow quantization parameters to be calculated from

660 the num_bits attribute.

661 enable_dynamic_update_slice: Enable to convert to DynamicUpdateSlice op. 

662 (default: False). 

663 preserve_assert_op: Whether to preserve `TF::AssertOp` (default: False). 

664 guarantee_all_funcs_one_use: Whether to clone functions so that each 

665 function only has a single use. This option will be helpful if the 

666 conversion fails when the `PartitionedCall` or `StatefulPartitionedCall` 

667 can't be properly inlined (default: False). 

668 enable_mlir_variable_quantization: Enable MLIR variable quantization. There 

669 is a variable freezing pass, but some variables may not be fully frozen by 

670 it. This flag enables quantization of those residual variables in the MLIR 

671 graph. 

672 disable_fuse_mul_and_fc: Disable fusing input multiplication with

673 fully-connected operations. Useful when quantizing weights.

674 quantization_options: Config to indicate quantization options for each

675 component (e.g. weight, bias, activation). This can be a preset method or

676 a custom method, and allows finer, modular control. This option will

677 override any other existing quantization flags. We plan on gradually

678 migrating all quantization-related specs into this option.

679 

680 Returns: 

681 conversion_flags: protocol buffer describing the conversion process. 

682 Raises: 

683 ValueError: If the input tensor type is unknown.

684 """ 

685 conversion_flags = _conversion_flags_pb2.TocoFlags() 

686 conversion_flags.inference_type = convert_inference_tf_type_to_tflite_type( 

687 inference_type, usage="inference_type flag" 

688 ) 

689 if inference_input_type: 

690 conversion_flags.inference_input_type = ( 

691 convert_inference_tf_type_to_tflite_type( 

692 inference_input_type, usage="inference_input_type flag" 

693 ) 

694 ) 

695 else: 

696 conversion_flags.inference_input_type = conversion_flags.inference_type 

697 conversion_flags.input_format = input_format 

698 conversion_flags.output_format = output_format 

699 if default_ranges_stats: 

700 conversion_flags.default_ranges_min = default_ranges_stats[0] 

701 conversion_flags.default_ranges_max = default_ranges_stats[1] 

702 conversion_flags.drop_control_dependency = drop_control_dependency 

703 conversion_flags.reorder_across_fake_quant = reorder_across_fake_quant 

704 conversion_flags.allow_custom_ops = allow_custom_ops 

705 conversion_flags.post_training_quantize = post_training_quantize 

706 conversion_flags.quantize_to_float16 = quantize_to_float16 

707 if dump_graphviz_dir: 

708 conversion_flags.dump_graphviz_dir = dump_graphviz_dir 

709 conversion_flags.dump_graphviz_include_video = dump_graphviz_video 

710 if target_ops: 

711 if OpsSet.SELECT_TF_OPS in target_ops: 

712 conversion_flags.enable_select_tf_ops = True 

713 if set(target_ops) == {OpsSet.SELECT_TF_OPS}: 

714 conversion_flags.force_select_tf_ops = True 

715 if OpsSet.EXPERIMENTAL_STABLEHLO_OPS in target_ops: 

716 conversion_flags.convert_to_stablehlo = True 

717 if OpsSet.EXPERIMENTAL_STABLEHLO_OPS in target_ops and len(target_ops) > 1: 

718 raise ValueError( 

719 "The StableHLO ops set cannot be specified together with other ops sets"

720 ) 

721 if conversion_summary_dir: 

722 conversion_flags.conversion_summary_dir = conversion_summary_dir 

723 if select_user_tf_ops: 

724 conversion_flags.select_user_tf_ops.extend(select_user_tf_ops) 

725 conversion_flags.allow_all_select_tf_ops = allow_all_select_tf_ops 

726 conversion_flags.enable_tflite_resource_variables = ( 

727 enable_tflite_resource_variables 

728 ) 

729 conversion_flags.unfold_batchmatmul = unfold_batchmatmul 

730 conversion_flags.lower_tensor_list_ops = lower_tensor_list_ops 

731 conversion_flags.default_to_single_batch_in_tensor_list_ops = ( 

732 default_to_single_batch_in_tensor_list_ops 

733 ) 

734 if accumulation_type: 

735 conversion_flags.accumulation_type = convert_tensor_tf_type_to_tflite_type( 

736 accumulation_type, usage="accumulation_type flag" 

737 ) 

738 conversion_flags.allow_bfloat16 = allow_bfloat16 

739 conversion_flags.unfold_large_splat_constant = unfold_large_splat_constant 

740 if supported_backends: 

741 conversion_flags.supported_backends.extend(supported_backends) 

742 conversion_flags.disable_per_channel_quantization = ( 

743 disable_per_channel_quantization 

744 ) 

745 conversion_flags.enable_mlir_dynamic_range_quantizer = ( 

746 enable_mlir_dynamic_range_quantizer 

747 ) 

748 conversion_flags.enable_dynamic_update_slice = enable_dynamic_update_slice 

749 conversion_flags.preserve_assert_op = preserve_assert_op 

750 conversion_flags.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use 

751 if tf_quantization_mode: 

752 conversion_flags.tf_quantization_mode = tf_quantization_mode 

753 conversion_flags.disable_infer_tensor_range = disable_infer_tensor_range 

754 conversion_flags.use_fake_quant_num_bits = use_fake_quant_num_bits 

755 conversion_flags.enable_mlir_variable_quantization = ( 

756 enable_mlir_variable_quantization 

757 ) 

758 conversion_flags.disable_fuse_mul_and_fc = disable_fuse_mul_and_fc 

759 if quantization_options: 

760 conversion_flags.quantization_options.CopyFrom(quantization_options) 

761 return conversion_flags 

762 
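# --- Editorial example (added; not part of the original module). A sketch of a
# common flag combination: float inference, post-training weight quantization,
# and select TF ops allowed as a fallback.
def _example_build_conversion_flags():
  return build_conversion_flags(
      inference_type=dtypes.float32,
      post_training_quantize=True,
      target_ops={OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS},
  )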

763 

764@convert_phase( 

765 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_GRAPHDEF 

766) 

767def convert_graphdef_with_arrays( 

768 input_data, 

769 input_arrays_with_shape, 

770 output_arrays, 

771 control_output_arrays, 

772 **kwargs 

773): 

774 """Convert a frozen GraphDef that can't be loaded in TF. 

775 

776 Conversion can be customized by providing arguments that are forwarded to 

777 `build_model_flags` and `build_conversion_flags` (see documentation). 

778 

779 Args: 

780 input_data: Input data (i.e. often `sess.graph_def`).

781 input_arrays_with_shape: List of tuples pairing an input tensor name with

782 its shape as a list of integers (e.g., [("foo", [1, 16,

783 16, 3])]). Use only when graph cannot be loaded into TensorFlow and when

784 `input_tensors` is None.

785 output_arrays: List of output tensors to freeze graph with. Use only when 

786 graph cannot be loaded into TensorFlow and when `output_tensors` is None. 

787 control_output_arrays: Control output node names. This is used when 

788 converting a Graph with no output tensors. For example, if the graph's 

789 last operation is a Print op, just specify that op's name in this field. 

790 This can be used together with the `output_arrays` parameter. 

791 **kwargs: See `build_model_flags` and `build_conversion_flags`. 

792 

793 Returns: 

794 The converted data. For example if TFLite was the destination, then 

795 this will be a tflite flatbuffer in a bytes array. 

796 

797 Raises: 

798 Defined in `build_conversion_flags`. 

799 """ 

800 model_flags = build_model_flags(**kwargs) 

801 conversion_flags = build_conversion_flags(**kwargs) 

802 enable_mlir_converter = kwargs.get("enable_mlir_converter", True) 

803 quantized_input_stats = kwargs.get("quantized_input_stats", None) 

804 

805 for idx, (name, shape) in enumerate(input_arrays_with_shape): 

806 input_array = model_flags.input_arrays.add() 

807 if _is_quantized_input_stats_required(conversion_flags): 

808 if quantized_input_stats: 

809 input_array.mean_value, input_array.std_value = quantized_input_stats[ 

810 idx 

811 ] 

812 else: 

813 raise ValueError( 

814 "The `quantized_input_stats` flag must be defined when either " 

815 "`inference_type` flag or `inference_input_type` flag is set to " 

816 "tf.int8 or tf.uint8." 

817 ) 

818 input_array.name = name 

819 input_array.shape.dims.extend(list(map(int, shape))) 

820 

821 if output_arrays: 

822 for name in output_arrays: 

823 model_flags.output_arrays.append(name) 

824 if control_output_arrays: 

825 for name in control_output_arrays: 

826 model_flags.control_output_arrays.append(name) 

827 

828 data = convert( 

829 model_flags, 

830 conversion_flags, 

831 input_data.SerializeToString(), 

832 debug_info_str=None, 

833 enable_mlir_converter=enable_mlir_converter, 

834 ) 

835 return data 

836 
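# --- Editorial example (added; not part of the original module). A hypothetical
# call sketch: `graph_def` is an assumed frozen GraphDef whose input tensor is
# named "input" and whose output tensor is named "output".
def _example_convert_graphdef_with_arrays(graph_def):
  return convert_graphdef_with_arrays(
      graph_def,
      input_arrays_with_shape=[("input", [1, 224, 224, 3])],
      output_arrays=["output"],
      control_output_arrays=None,
      inference_type=dtypes.float32,
  )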

837 

838@convert_phase( 

839 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_GRAPHDEF 

840) 

841def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs): 

842 """Convert a frozen GraphDef model using the TF Lite converter. 

843 

844 Conversion can be customized by providing arguments that are forwarded to 

845 `build_model_flags` and `build_conversion_flags` (see documentation). 

846 

847 Args: 

848 input_data: Input data (i.e. often `sess.graph_def`).

849 input_tensors: List of input tensors. Type and shape are computed using 

850 `foo.shape` and `foo.dtype`. 

851 output_tensors: List of output tensors (only .name is used from this). 

852 **kwargs: See `build_model_flags` and `build_conversion_flags`. 

853 

854 Returns: 

855 The converted data. For example if TFLite was the destination, then 

856 this will be a tflite flatbuffer in a bytes array. 

857 

858 Raises: 

859 Defined in `build_conversion_flags`. 

860 """ 

861 model_flags = build_model_flags(**kwargs) 

862 conversion_flags = build_conversion_flags(**kwargs) 

863 saved_model_dir = kwargs.get("saved_model_dir", None) 

864 input_shapes = kwargs.get("input_shapes", None) 

865 enable_mlir_converter = kwargs.get("enable_mlir_converter", True) 

866 quantized_input_stats = kwargs.get("quantized_input_stats", None) 

867 debug_info = kwargs.get("debug_info", None) 

868 

869 for idx, input_tensor in enumerate(input_tensors): 

870 input_array = model_flags.input_arrays.add() 

871 if saved_model_dir: 

872 input_array.name = input_tensor.name 

873 else: 

874 input_array.name = util.get_tensor_name(input_tensor) 

875 input_array.data_type = convert_tensor_tf_type_to_tflite_type( 

876 input_tensor.dtype, usage="input type of the TensorFlow model" 

877 ) 

878 

879 if _is_quantized_input_stats_required(conversion_flags): 

880 if quantized_input_stats: 

881 input_array.mean_value, input_array.std_value = quantized_input_stats[ 

882 idx 

883 ] 

884 else: 

885 # We should ideally raise an error here, but we don't as it would break 

886 # several models/projects that depend on this workflow. 

887 warnings.warn( 

888 "Statistics for quantized inputs were expected, but not " 

889 "specified; continuing anyway." 

890 ) 

891 

892 if input_shapes is None: 

893 shape = input_tensor.shape 

894 else: 

895 shape = input_shapes[idx] 

896 

897 if shape.rank is not None: 

898 # Create shapes with -1 for unknown dimensions. 

899 dims = [] 

900 for dim in shape: 

901 if dim is None or ( 

902 isinstance(dim, tensor_shape.Dimension) and dim.value is None 

903 ): 

904 dims.append(-1) 

905 else: 

906 dims.append(int(dim)) 

907 input_array.shape.dims.extend(dims) 

908 input_array.shape.unknown_rank = False 

909 else: 

910 input_array.shape.unknown_rank = True 

911 

912 for output_tensor in output_tensors: 

913 if saved_model_dir: 

914 model_flags.output_arrays.append(output_tensor.name) 

915 else: 

916 model_flags.output_arrays.append(util.get_tensor_name(output_tensor)) 

917 

918 data = convert( 

919 model_flags, 

920 conversion_flags, 

921 input_data.SerializeToString(), 

922 debug_info_str=debug_info.SerializeToString() if debug_info else None, 

923 enable_mlir_converter=enable_mlir_converter, 

924 ) 

925 return data 

926 
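# --- Editorial example (added; not part of the original module). A hypothetical
# sketch: `sess` is an assumed tf.compat.v1.Session whose graph contains
# `in_tensor` and `out_tensor`.
def _example_convert_graphdef(sess, in_tensor, out_tensor):
  return convert_graphdef(
      sess.graph_def,
      input_tensors=[in_tensor],
      output_tensors=[out_tensor],
      enable_mlir_converter=True,
  )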

927 

928@convert_phase( 

929 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_SAVED_MODEL 

930) 

931def convert_saved_model(**kwargs): 

932 """Converts a SavedModel using the TF Lite converter."""

933 model_flags = build_model_flags(**kwargs) 

934 conversion_flags = build_conversion_flags(**kwargs) 

935 data = convert( 

936 model_flags, 

937 conversion_flags, 

938 input_data_str=None, 

939 debug_info_str=None, 

940 enable_mlir_converter=True, 

941 ) 

942 return data 

943 
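# --- Editorial example (added; not part of the original module). An
# illustrative call: the directory, tag, and signature name are assumptions.
def _example_convert_saved_model():
  return convert_saved_model(
      saved_model_dir="/tmp/my_saved_model",            # hypothetical path
      saved_model_version=1,
      saved_model_tags=["serve"],
      saved_model_exported_names=["serving_default"],
  )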

944 

945@convert_phase( 

946 Component.CONVERT_TF_TO_TFLITE_MODEL, SubComponent.CONVERT_JAX_HLO 

947) 

948def convert_jax_hlo(input_content, input_names, is_proto_format, **kwargs): 

949 """Converts a JAX HLO-based model using the TFLite converter."""

950 model_flags = _model_flags_pb2.ModelFlags() 

951 model_flags.use_hlo_import = True 

952 if is_proto_format: 

953 model_flags.hlo_file_type = _model_flags_pb2.ModelFlags.HLO_PROTO 

954 else: 

955 model_flags.hlo_file_type = _model_flags_pb2.ModelFlags.HLO_TEXT 

956 

957 # Build input names. 

958 for input_name in input_names: 

959 input_array = model_flags.input_arrays.add() 

960 input_array.name = input_name 

961 

962 conversion_flags = build_conversion_flags(**kwargs) 

963 data = convert( 

964 model_flags, 

965 conversion_flags, 

966 input_data_str=input_content, 

967 debug_info_str=None, 

968 enable_mlir_converter=True, 

969 ) 

970 return data 

971 
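# --- Editorial example (added; not part of the original module). A hypothetical
# sketch: `hlo_text` is an assumed textual HLO dump of a JAX computation whose
# single input is named "x".
def _example_convert_jax_hlo(hlo_text):
  return convert_jax_hlo(
      hlo_text,
      input_names=["x"],
      is_proto_format=False,   # HLO text rather than a serialized HLO proto
  )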

972 

973@_tf_export(v1=["lite.toco_convert"]) 

974@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.") 

975def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): 

976 """Convert a TensorFlow GraphDef to TFLite. 

977 

978 This function is deprecated. Please use `tf.lite.TFLiteConverter` API instead. 

979 Conversion can be customized by providing arguments that are forwarded to 

980 `build_model_flags` and `build_conversion_flags` (see documentation for 

981 details). 

982 Args: 

983 input_data: Input data (i.e. often `sess.graph_def`). 

984 input_tensors: List of input tensors. Type and shape are computed using 

985 `foo.shape` and `foo.dtype`. 

986 output_tensors: List of output tensors (only .name is used from this). 

987 *args: See `build_model_flags` and `build_conversion_flags`. 

988 **kwargs: See `build_model_flags` and `build_conversion_flags`. 

989 

990 Returns: 

991 The converted TensorFlow Lite model in a bytes array. 

992 

993 Raises: 

994 Defined in `convert`. 

995 """ 

996 kwargs["enable_mlir_converter"] = kwargs.get("enable_mlir_converter", False) 

997 return convert_graphdef( 

998 input_data, input_tensors, output_tensors, *args, **kwargs 

999 ) 

1000 

1001 

1002def deduplicate_readonly_buffers(tflite_model): 

1003 """Generates a new model byte array after deduplicating readonly buffers. 

1004 

1005 This function should be invoked after running the model optimization

1006 toolkit, since the toolkit assumes that each tensor object owns its

1007 buffer separately.

1008 

1009 Args: 

1010 tflite_model: TFLite flatbuffer in a byte array to be deduplicated. 

1011 

1012 Returns: 

1013 TFLite flatbuffer in a bytes array, processed with the deduplication method. 

1014 """ 

1015 # Load TFLite Flatbuffer byte array into an object. 

1016 model = flatbuffer_utils.convert_bytearray_to_object(tflite_model) 

1017 

1018 # Get all the read-only buffers, which can be modified without causing any 

1019 # issue in the graph invocation stage. 

1020 read_only_buffer_indices = set() 

1021 for subgraph in model.subgraphs: 

1022 # To get all the read-only buffers: 

1023 # (1) Get all read-only input tensors. 

1024 # (2) Discard intermediate or output tensors. 

1025 # (3) Discard the subgraph's input/output tensors. 

1026 # (4) Gather the buffers of the read-only input tensors. 

1027 

1028 # (1) Get read-only input tensors. 

1029 read_only_input_tensor_indices = set() 

1030 for op in subgraph.operators: 

1031 if op.inputs is None: 

1032 continue 

1033 for i, input_tensor_idx in enumerate(op.inputs): 

1034 # Ignore mutable tensors. 

1035 if op.mutatingVariableInputs is not None: 

1036 # Ignore invalid tensors. 

1037 if ( 

1038 i < len(op.mutatingVariableInputs) 

1039 and op.mutatingVariableInputs[i] 

1040 ): 

1041 continue 

1042 # Ignore variable tensors. 

1043 if subgraph.tensors[input_tensor_idx].isVariable: 

1044 continue 

1045 read_only_input_tensor_indices.add(input_tensor_idx) 

1046 

1047 # (2) Discard intermediate or output tensors. 

1048 for op in subgraph.operators: 

1049 if op.outputs is not None: 

1050 for output_tensor_idx in op.outputs: 

1051 read_only_input_tensor_indices.discard(output_tensor_idx) 

1052 if op.intermediates is not None: 

1053 for intermediate_tensor_idx in op.intermediates: 

1054 read_only_input_tensor_indices.discard(intermediate_tensor_idx) 

1055 

1056 # (3) Discard the subgraph's input and output tensors. 

1057 if subgraph.inputs is not None: 

1058 for input_tensor_idx in subgraph.inputs: 

1059 read_only_input_tensor_indices.discard(input_tensor_idx) 

1060 if subgraph.outputs is not None: 

1061 for output_tensor_idx in subgraph.outputs: 

1062 read_only_input_tensor_indices.discard(output_tensor_idx) 

1063 

1064 # (4) Gather the buffers of the read-only input tensors. 

1065 for tensor_idx in read_only_input_tensor_indices: 

1066 read_only_buffer_indices.add(subgraph.tensors[tensor_idx].buffer) 

1067 

1068 # Ignore invalid negative index or zero-sized buffers. 

1069 for buffer_idx in read_only_buffer_indices.copy(): 

1070 if buffer_idx < 0 or ( 

1071 model.buffers[buffer_idx].data is None 

1072 or isinstance(model.buffers[buffer_idx].data, list) 

1073 or model.buffers[buffer_idx].data.size == 0 

1074 ): 

1075 read_only_buffer_indices.discard(buffer_idx) 

1076 

1077 class BufferIndex: 

1078 """A class to store index, size, hash of the buffers in TFLite model.""" 

1079 

1080 def __init__(self, idx, size, hash_value): 

1081 self.idx = idx 

1082 self.size = size 

1083 self.hash_value = hash_value 

1084 

1085 read_only_buffers = list( 

1086 map( 

1087 lambda index: BufferIndex( # pylint: disable=g-long-lambda 

1088 index, 

1089 model.buffers[index].data.size, 

1090 hashlib.md5(model.buffers[index].data.data.tobytes()).hexdigest(), 

1091 ), 

1092 read_only_buffer_indices, 

1093 ) 

1094 ) 

1095 

1096 # Sort read_only_buffers by buffer size & hash in descending order. 

1097 read_only_buffers = sorted( 

1098 read_only_buffers, 

1099 key=lambda buffer: (buffer.size, buffer.hash_value), 

1100 reverse=True, 

1101 ) 

1102 

1103 # Create a map of duplicate buffers (same size and same type). 

1104 # e.g.: In [1, 2, 3, 4, 5, 6], if (1, 4, 6) and (2, 5) are each groups of buffer

1105 # indices of the same size and type, then the map would be {4:1, 6:1, 5:2}.

1106 duplicate_buffer_map = {} 

1107 for i, buffer_i in enumerate(read_only_buffers): 

1108 # This buffer is a duplicate. 

1109 if buffer_i.idx in duplicate_buffer_map: 

1110 continue 

1111 # This buffer is unique. Scan rest of the list to find duplicates 

1112 # of this buffer and mark them accordingly. 

1113 for buffer_j in read_only_buffers[i + 1 :]: 

1114 if buffer_j.idx in duplicate_buffer_map: 

1115 continue 

1116 if buffer_i.size != buffer_j.size: 

1117 break 

1118 if buffer_i.hash_value != buffer_j.hash_value: 

1119 continue 

1120 # Found duplicate. Nullify j-th buffer and use i-th buffer instead. 

1121 duplicate_buffer_map[buffer_j.idx] = buffer_i.idx 

1122 

1123 # Make the duplicated tensors use the single shared buffer index. 

1124 for subgraph in model.subgraphs: 

1125 for op in subgraph.operators: 

1126 if op.inputs is None: 

1127 continue 

1128 for input_tensor in op.inputs: 

1129 buffer_idx = subgraph.tensors[input_tensor].buffer 

1130 if buffer_idx in duplicate_buffer_map: 

1131 subgraph.tensors[input_tensor].buffer = duplicate_buffer_map[ 

1132 buffer_idx 

1133 ] 

1134 

1135 # Nullify the unused buffers. 

1136 for idx in duplicate_buffer_map: 

1137 model.buffers[idx].data = None 

1138 

1139 # Return a TFLite flatbuffer as a byte array. 

1140 return flatbuffer_utils.convert_object_to_bytearray(model)
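# --- Editorial example (added; not part of the original module). A sketch of
# post-processing a converted model: `tflite_model_bytes` is an assumed TFLite
# flatbuffer (e.g. the return value of `convert_saved_model`).
def _example_deduplicate_readonly_buffers(tflite_model_bytes):
  deduped = deduplicate_readonly_buffers(tflite_model_bytes)
  return bytes(deduped)  # flatbuffer byte array ready to be written to disk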