Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/util.py: 13%

482 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Functions used by multiple converter files.""" 

16 

17import copy 

18import datetime 

19import sys 

20 

21from absl import logging 

22 

23import flatbuffers 

24from tensorflow.core.framework import graph_debug_info_pb2 

25from tensorflow.core.protobuf import config_pb2 as _config_pb2 

26from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 

27from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb 

28from tensorflow.lite.python import schema_py_generated as schema_fb 

29from tensorflow.lite.python import schema_util 

30from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util 

31from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs 

32from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes 

33from tensorflow.lite.tools import flatbuffer_utils 

34from tensorflow.python.eager import function 

35from tensorflow.python.framework import convert_to_constants as _convert_to_constants 

36from tensorflow.python.framework import dtypes 

37from tensorflow.python.framework import error_interpolation as _error_interpolation 

38from tensorflow.python.grappler import tf_optimizer 

39from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph 

40 

41# The field name of conversion metadata in the flatbuffer file. 

42CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA" 

43 

44# Keras functions used by TFLite 

45model_input_signature = _tflite_keras_util.model_input_signature 

46trace_model_call = _tflite_keras_util.trace_model_call 

47 

48# Jax functions used by TFLite 

49# pylint: disable=g-import-not-at-top 

50# pylint: disable=unused-import 

51try: 

52 from jax import xla_computation as _xla_computation 

53except ImportError: 

54 _xla_computation = None 

55# pylint: enable=g-import-not-at-top 

56# pylint: enable=unused-import 

57 

58# Defined as per TFLite schema 

59_MAP_TFLITE_ENUM_TO_TF_TYPES = { 

60 0: dtypes.float32, 

61 1: dtypes.float16, 

62 2: dtypes.int32, 

63 3: dtypes.uint8, 

64 4: dtypes.int64, 

65 5: dtypes.string, 

66 6: dtypes.bool, 

67 7: dtypes.int16, 

68 8: dtypes.complex64, 

69 9: dtypes.int8, 

70 10: dtypes.float64, 

71 11: dtypes.complex128, 

72 16: dtypes.uint32, 

73} 

74 

75_TFLITE_FILE_IDENTIFIER = b"TFL3" 

76 

77_MAP_QUANT_TO_IO_TYPES = { 

78 dtypes.int8: {dtypes.int8, dtypes.uint8}, 

79 dtypes.int16: {dtypes.int16}, 

80} 

81 

82 

83def _convert_tflite_enum_type_to_tf_type(tflite_enum_type): 

84 """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32). 

85 

86 Args: 

87 tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32) 

88 

89 Raises: 

90 ValueError: If an invalid tflite enum type is provided. 

91 

92 Returns: 

93 tf type (eg: tf.float32) 

94 """ 

95 tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type) 

96 if tf_type is None: 

97 raise ValueError( 

98 "Unsupported enum {}. The valid map of enum to tf types is : {}" 

99 .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES)) 

100 return tf_type 

101 

102 

103def get_tf_type_name(tf_type): 

104 """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32").""" 

105 return "tf." + tf_type.name if tf_type else None 

106 

107 

108def get_tensor_name(tensor): 

109 """Returns name of the input tensor. 

110 

111 Args: 

112 tensor: tf.Tensor 

113 

114 Returns: 

115 str 

116 """ 

117 parts = tensor.name.split(":") 

118 if len(parts) > 2: 

119 raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( 

120 len(parts) - 1)) 

121 

122 # To be consistent with the tensor naming scheme in tensorflow, we need 

123 # drop the ':0' suffix for the first tensor. 

124 if len(parts) > 1 and parts[1] != "0": 

125 return tensor.name 

126 return parts[0] 

127 

128 

129def get_tensors_from_tensor_names(graph, tensor_names): 

130 """Gets the Tensors associated with the `tensor_names` in the provided graph. 

131 

132 Args: 

133 graph: TensorFlow Graph. 

134 tensor_names: List of strings that represent names of tensors in the graph. 

135 

136 Returns: 

137 A list of Tensor objects in the same order the names are provided. 

138 

139 Raises: 

140 ValueError: 

141 tensor_names contains an invalid tensor name. 

142 """ 

143 # Get the list of all of the tensors. 

144 tensor_name_to_tensor = {} 

145 for op in graph.get_operations(): 

146 for tensor in op.values(): 

147 tensor_name_to_tensor[get_tensor_name(tensor)] = tensor 

148 

149 # Get the tensors associated with tensor_names. 

150 tensors = [] 

151 invalid_tensors = [] 

152 for name in tensor_names: 

153 if not isinstance(name, str): 

154 raise ValueError("Invalid type for a tensor name in the provided graph. " 

155 "Expected type for a tensor name is 'str', instead got " 

156 "type '{}' for tensor name '{}'".format( 

157 type(name), name)) 

158 

159 tensor = tensor_name_to_tensor.get(name) 

160 if tensor is None: 

161 invalid_tensors.append(name) 

162 else: 

163 tensors.append(tensor) 

164 

165 # Throw ValueError if any user input names are not valid tensors. 

166 if invalid_tensors: 

167 raise ValueError("Invalid tensors '{}' were found.".format( 

168 ",".join(invalid_tensors))) 

169 return tensors 

170 

171 

172def set_tensor_shapes(tensors, shapes): 

173 """Sets Tensor shape for each tensor if the shape is defined. 

174 

175 Args: 

176 tensors: TensorFlow ops.Tensor. 

177 shapes: Dict of strings representing input tensor names to list of 

178 integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). 

179 

180 Raises: 

181 ValueError: 

182 `shapes` contains an invalid tensor. 

183 `shapes` contains an invalid shape for a valid tensor. 

184 """ 

185 if shapes: 

186 tensor_names_to_tensor = { 

187 get_tensor_name(tensor): tensor for tensor in tensors 

188 } 

189 for name, shape in shapes.items(): 

190 if name not in tensor_names_to_tensor: 

191 raise ValueError("Invalid tensor \'{}\' found in tensor shapes " 

192 "map.".format(name)) 

193 if shape is not None: 

194 tensor = tensor_names_to_tensor[name] 

195 try: 

196 tensor.set_shape(shape) 

197 except ValueError as error: 

198 message = ("The shape of tensor '{0}' cannot be changed from {1} to " 

199 "{2}. {3}".format(name, tensor.shape, shape, str(error))) 

200 raise ValueError(message) 

201 

202 

203def get_grappler_config(optimizers_list): 

204 """Creates a tf.compat.v1.ConfigProto for configuring Grappler. 

205 

206 Args: 

207 optimizers_list: List of strings that represents the list of optimizers. 

208 

209 Returns: 

210 tf.ConfigProto. 

211 """ 

212 config = _config_pb2.ConfigProto() 

213 rewrite_options = config.graph_options.rewrite_options 

214 for optimizer in optimizers_list: 

215 rewrite_options.optimizers.append(optimizer) 

216 return config 

217 

218 

219def run_graph_optimizations(graph_def, 

220 input_arrays, 

221 output_arrays, 

222 config, 

223 graph=None): 

224 """Apply standard TensorFlow optimizations to the graph_def. 

225 

226 Args: 

227 graph_def: Frozen GraphDef to be optimized. 

228 input_arrays: List of arrays that are considered inputs of the graph. 

229 output_arrays: List of arrays that are considered outputs of the graph. 

230 config: tf.ConfigProto. 

231 graph: TensorFlow Graph. Required when Eager mode is enabled. (default None) 

232 

233 Returns: 

234 A new, optimized GraphDef. 

235 """ 

236 meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph) 

237 

238 signature = _meta_graph_pb2.SignatureDef() 

239 for array in input_arrays: 

240 signature.inputs[array.name].name = array.name 

241 signature.inputs[array.name].dtype = array.dtype.as_datatype_enum 

242 signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto()) 

243 

244 for array in output_arrays: 

245 signature.outputs[array.name].name = array.name 

246 signature.outputs[array.name].dtype = array.dtype.as_datatype_enum 

247 signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto()) 

248 

249 meta_graph.signature_def["not_used_key"].CopyFrom(signature) 

250 

251 # We need to add a collection called 'train_op' so that grappler 

252 # knows what the outputs are. 

253 fetch_collection = _meta_graph_pb2.CollectionDef() 

254 for array in input_arrays + output_arrays: 

255 fetch_collection.node_list.value.append(array.name) 

256 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 

257 

258 return tf_optimizer.OptimizeGraph(config, meta_graph) 

259 

260 

261def _convert_op_hints_if_present(sess, graph_def, output_tensors, 

262 hinted_outputs_nodes): 

263 if is_frozen_graph(sess): 

264 raise ValueError("Try to convert op hints, needs unfrozen graph.") 

265 output_arrays = [get_tensor_name(tensor) for tensor in output_tensors] 

266 graph_def = _convert_to_constants.convert_variables_to_constants( 

267 sess, graph_def, output_arrays + hinted_outputs_nodes) 

268 graph_def = convert_op_hints_to_stubs(graph_def=graph_def) 

269 return graph_def 

270 

271 

272def freeze_graph(sess, input_tensors, output_tensors): 

273 """Returns a frozen GraphDef. 

274 

275 Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the 

276 existing GraphDef is returned. The Grappler pass is only run on models that 

277 are frozen in order to inline the functions in the graph. 

278 If OpHints is present, it will try to convert the OpHint graph. 

279 

280 Args: 

281 sess: TensorFlow Session. 

282 input_tensors: List of input tensors. 

283 output_tensors: List of output tensors (only .name is used from this). 

284 

285 Returns: 

286 Frozen GraphDef. 

287 """ 

288 # Runs a Grappler pass in order to inline any functions in the graph. 

289 # Asides from inlining any simple function, Grappler will also try to lower 

290 # while loop into switch merge representation which is undesired for Ophints, 

291 # so we simply remove those attributes to prevent Grappler from doing so. 

292 graph_def = _convert_to_constants.disable_lower_using_switch_merge( 

293 sess.graph_def) 

294 config = get_grappler_config(["function"]) 

295 graph_def = run_graph_optimizations( 

296 graph_def, input_tensors, output_tensors, config, graph=sess.graph) 

297 

298 # If ophints are present, just convert them. 

299 hinted_outputs_nodes = find_all_hinted_output_nodes(sess) 

300 if hinted_outputs_nodes: 

301 return _convert_op_hints_if_present(sess, graph_def, output_tensors, 

302 hinted_outputs_nodes) 

303 

304 if not is_frozen_graph(sess): 

305 output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors] 

306 return _convert_to_constants.convert_variables_to_constants( 

307 sess, graph_def, output_node_names 

308 ) 

309 else: 

310 return sess.graph_def 

311 

312 

313def is_frozen_graph(sess): 

314 """Determines if the graph is frozen. 

315 

316 Determines if a graph has previously been frozen by checking for any 

317 operations of type Variable*. If variables are found, the graph is not frozen. 

318 

319 Args: 

320 sess: TensorFlow Session. 

321 

322 Returns: 

323 Bool. 

324 """ 

325 for op in sess.graph.get_operations(): 

326 if op.type.startswith("Variable") or op.type.endswith("VariableOp"): 

327 return False 

328 return True 

329 

330 

331def build_debug_info_func(original_graph): 

332 """Returns a method to retrieve the `GraphDebugInfo` from the original graph. 

333 

334 Args: 

335 original_graph: The original `Graph` containing all the op stack traces. 

336 

337 Returns: 

338 A function which retrieves the stack traces from the original graph and 

339 converts them to a `GraphDebugInfo` for a given set of nodes. 

340 """ 

341 

342 def f(original_nodes): 

343 """Function to create `GraphDebugInfo` for the given `original_nodes`.""" 

344 if not original_graph: 

345 return None 

346 # For the given nodes, gets all the op definitions in the original graph. 

347 useful_ops = [] 

348 for func, name in original_nodes: 

349 try: 

350 if not func: 

351 useful_ops.append((func, original_graph.get_operation_by_name(name))) 

352 else: 

353 sub_func = original_graph._get_function(func) # pylint: disable=protected-access 

354 if isinstance(sub_func, function.AtomicFunction): # pylint: disable=protected-access 

355 useful_ops.append( 

356 (func, sub_func.graph.get_operation_by_name(name))) 

357 else: 

358 sys.stderr.write( 

359 "Use '@tf.function' or '@defun' to decorate the function.\n") 

360 continue 

361 except KeyError: 

362 # New node created by graph optimizer. No stack trace from source code. 

363 continue 

364 # Convert all the op definitions to stack traces in terms of GraphDebugInfo. 

365 return _error_interpolation.create_graph_debug_info_def(useful_ops) 

366 

367 return f 

368 

369 

370def convert_debug_info_func(saved_debug_info): 

371 """Returns a method to retrieve the `GraphDebugInfo` from the original graph. 

372 

373 Args: 

374 saved_debug_info: The `GraphDebugInfo` containing all the debug info. 

375 

376 Returns: 

377 A function which retrieves the stack traces from the original graph and 

378 converts them to a `GraphDebugInfo` for a given set of nodes. 

379 """ 

380 

381 def f(original_nodes): 

382 """Function to create `GraphDebugInfo` for the given `original_nodes`.""" 

383 if not saved_debug_info: 

384 return None 

385 

386 output_debug_info = graph_debug_info_pb2.GraphDebugInfo() 

387 # All the files are copied over, so the index wouldn't be changed. 

388 output_debug_info.files[:] = saved_debug_info.files 

389 # We only copy over the debug info for the input nodes 

390 for func, node in original_nodes: 

391 debug_key = node + "@" + func 

392 output_debug_info.traces[debug_key].CopyFrom( 

393 saved_debug_info.traces[debug_key]) 

394 return output_debug_info 

395 

396 return f 

397 

398 

399def get_debug_info(nodes_to_debug_info_func, converted_graph): 

400 """Returns the debug info for the original nodes in the `converted_graph`. 

401 

402 Args: 

403 nodes_to_debug_info_func: The method to collect the op debug info for the 

404 nodes. 

405 converted_graph: A `GraphDef` after optimization and transformation. 

406 

407 Returns: 

408 `GraphDebugInfo` for all the original nodes in `converted_graph`. 

409 """ 

410 if not nodes_to_debug_info_func: 

411 return None 

412 

413 # Collect all the debug info nodes from the converted_graph 

414 original_nodes = set() 

415 for node in converted_graph.node: 

416 debug_nodes = node.experimental_debug_info.original_node_names 

417 debug_funcs = node.experimental_debug_info.original_func_names 

418 # If the `original_node_names` are empty, uses the node name directly. 

419 if not debug_nodes: 

420 original_nodes.add(("", node.name)) 

421 else: 

422 for i in range(len(debug_nodes)): 

423 debug_func = "" if i >= len(debug_funcs) else debug_funcs[i] 

424 original_nodes.add((debug_func, debug_nodes[i])) 

425 

426 # Convert the nodes to the debug info proto object. 

427 return nodes_to_debug_info_func(original_nodes) 

428 

429 

430def convert_bytes_to_c_source(data, 

431 array_name, 

432 max_line_width=80, 

433 include_guard=None, 

434 include_path=None, 

435 use_tensorflow_license=False): 

436 """Returns strings representing a C constant array containing `data`. 

437 

438 Args: 

439 data: Byte array that will be converted into a C constant. 

440 array_name: String to use as the variable name for the constant array. 

441 max_line_width: The longest line length, for formatting purposes. 

442 include_guard: Name to use for the include guard macro definition. 

443 include_path: Optional path to include in the source file. 

444 use_tensorflow_license: Whether to include the standard TensorFlow Apache2 

445 license in the generated files. 

446 

447 Returns: 

448 Text that can be compiled as a C source file to link in the data as a 

449 literal array of values. 

450 Text that can be used as a C header file to reference the literal array. 

451 """ 

452 

453 starting_pad = " " 

454 array_lines = [] 

455 array_line = starting_pad 

456 for value in bytearray(data): 

457 if (len(array_line) + 4) > max_line_width: 

458 array_lines.append(array_line + "\n") 

459 array_line = starting_pad 

460 array_line += " 0x%02x," % (value,) 

461 if len(array_line) > len(starting_pad): 

462 array_lines.append(array_line + "\n") 

463 array_values = "".join(array_lines) 

464 

465 if include_guard is None: 

466 include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_" 

467 

468 if include_path is not None: 

469 include_line = "#include \"{include_path}\"\n".format( 

470 include_path=include_path) 

471 else: 

472 include_line = "" 

473 

474 if use_tensorflow_license: 

475 license_text = """ 

476/* Copyright {year} The TensorFlow Authors. All Rights Reserved. 

477 

478Licensed under the Apache License, Version 2.0 (the "License"); 

479you may not use this file except in compliance with the License. 

480You may obtain a copy of the License at 

481 

482 http://www.apache.org/licenses/LICENSE-2.0 

483 

484Unless required by applicable law or agreed to in writing, software 

485distributed under the License is distributed on an "AS IS" BASIS, 

486WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

487See the License for the specific language governing permissions and 

488limitations under the License. 

489==============================================================================*/ 

490""".format(year=datetime.date.today().year) 

491 else: 

492 license_text = "" 

493 

494 source_template = """{license_text} 

495// This is a TensorFlow Lite model file that has been converted into a C data 

496// array using the tensorflow.lite.util.convert_bytes_to_c_source() function. 

497// This form is useful for compiling into a binary for devices that don't have a 

498// file system. 

499 

500{include_line} 

501// We need to keep the data array aligned on some architectures. 

502#ifdef __has_attribute 

503#define HAVE_ATTRIBUTE(x) __has_attribute(x) 

504#else 

505#define HAVE_ATTRIBUTE(x) 0 

506#endif 

507#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) 

508#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) 

509#else 

510#define DATA_ALIGN_ATTRIBUTE 

511#endif 

512 

513const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{ 

514{array_values}}}; 

515const int {array_name}_len = {array_length}; 

516""" 

517 

518 source_text = source_template.format( 

519 array_name=array_name, 

520 array_length=len(data), 

521 array_values=array_values, 

522 license_text=license_text, 

523 include_line=include_line) 

524 

525 header_template = """ 

526{license_text} 

527 

528// This is a TensorFlow Lite model file that has been converted into a C data 

529// array using the tensorflow.lite.util.convert_bytes_to_c_source() function. 

530// This form is useful for compiling into a binary for devices that don't have a 

531// file system. 

532 

533#ifndef {include_guard} 

534#define {include_guard} 

535 

536extern const unsigned char {array_name}[]; 

537extern const int {array_name}_len; 

538 

539#endif // {include_guard} 

540""" 

541 

542 header_text = header_template.format( 

543 array_name=array_name, 

544 include_guard=include_guard, 

545 license_text=license_text) 

546 

547 return source_text, header_text 

548 

549 

550def _convert_model_from_bytearray_to_object(model_bytearray): 

551 """Converts a tflite model from a bytearray into a parsable object.""" 

552 model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) 

553 model_object = schema_fb.ModelT.InitFromObj(model_object) 

554 model_object = copy.deepcopy(model_object) 

555 return model_object 

556 

557 

558def _convert_model_from_object_to_bytearray(model_object): 

559 """Converts a tflite model from a parsable object into a bytearray.""" 

560 # Initial size of the buffer, which will grow automatically if needed 

561 builder = flatbuffers.Builder(1024) 

562 model_offset = model_object.Pack(builder) 

563 builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) 

564 return bytes(builder.Output()) 

565 

566 

567def get_quantize_opcode_idx(model): 

568 """Returns the quantize op idx.""" 

569 quant_opcode_idxs = [] 

570 for idx, opcode in enumerate(model.operatorCodes): 

571 builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) 

572 if builtin_code == schema_fb.BuiltinOperator.QUANTIZE: 

573 quant_opcode_idxs.append(idx) 

574 return quant_opcode_idxs 

575 

576 

577def get_dequantize_opcode_idx(model): 

578 """Returns the quantize op idx.""" 

579 quant_opcode_idxs = [] 

580 for idx, opcode in enumerate(model.operatorCodes): 

581 builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) 

582 if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE: 

583 quant_opcode_idxs.append(idx) 

584 return quant_opcode_idxs 

585 

586 

587def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors): 

588 """Update the tensors in the SignatureDef's TensorMaps.""" 

589 for i in range(len(tensor_maps)): 

590 if tensor_maps[i].tensorIndex in map_old_to_new_tensors: 

591 tensor_maps[i].tensorIndex = ( 

592 map_old_to_new_tensors[tensor_maps[i].tensorIndex]) 

593 

594 

595def _remove_tensors_from_model(model, remove_tensors_idxs): 

596 """Remove tensors from model.""" 

597 if not remove_tensors_idxs: 

598 return 

599 if len(model.subgraphs) > 1: 

600 logging.info("Skipping the removal of dangled tensors since the model has " 

601 "multiple subgraphs and tensors can be used in the different " 

602 "subgraph(s)") 

603 return 

604 subgraph = model.subgraphs[0] 

605 tensors = subgraph.tensors 

606 operators = subgraph.operators 

607 

608 logging.debug("Removing tensors at indices : %s", remove_tensors_idxs) 

609 # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an 

610 # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]). 

611 if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs): 

612 logging.debug("Removing tensors only at the end of the tensor list") 

613 del tensors[min(remove_tensors_idxs):] 

614 else: 

615 logging.debug("Removing tensors requires updating the model") 

616 # Map the old tensor indices to new tensor indices 

617 d_old_to_new_tensors = {} 

618 left_shift_by = 0 

619 for idx in range(len(tensors)): 

620 if idx in remove_tensors_idxs: 

621 left_shift_by += 1 

622 else: 

623 d_old_to_new_tensors[idx] = idx - left_shift_by 

624 logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__()) 

625 # Update tensor indices referenced throughout the model 

626 def update_tensors(tensor_idxs): 

627 for i, ti in enumerate(tensor_idxs): 

628 tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1) 

629 update_tensors(subgraph.inputs) 

630 update_tensors(subgraph.outputs) 

631 for op in operators: 

632 update_tensors(op.inputs) 

633 update_tensors(op.outputs) 

634 if model.signatureDefs: 

635 signature_def = model.signatureDefs[0] 

636 _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors) 

637 _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors) 

638 # Delete the tensors 

639 for idx in sorted(remove_tensors_idxs, reverse=True): 

640 tensors.pop(idx) 

641 logging.debug("Removed tensors marked for deletion") 

642 

643 

644def _modify_model_input_type(model, inference_input_type=dtypes.float32): 

645 """Modify model input type.""" 

646 if inference_input_type == dtypes.float32: 

647 return 

648 

649 if not model.signatureDefs: 

650 _modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type) 

651 return 

652 

653 for signature_index, signature_def in enumerate(model.signatureDefs): 

654 _modify_model_input_type_per_subgraph(model, signature_def.subgraphIndex, 

655 signature_index, inference_input_type) 

656 

657 

658def _modify_model_input_type_per_subgraph(model, subgraph_index, 

659 signature_index, 

660 inference_input_type): 

661 """Modify model input type per subgraph.""" 

662 subgraph = model.subgraphs[subgraph_index] 

663 tensors = subgraph.tensors 

664 operators = subgraph.operators 

665 

666 # Find all quantize operators 

667 quant_opcode_idxs = get_quantize_opcode_idx(model) 

668 if operators and not quant_opcode_idxs: 

669 for input_idx in subgraph.inputs: 

670 input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type) 

671 if input_type == dtypes.float32: 

672 raise ValueError("Model input is not dequantized.") 

673 # None of the inputs have float32, then they must be int16, int8, or bool 

674 return 

675 

676 # Validate that the model input is quantized 

677 input_quant_ops = [] 

678 for op in operators: 

679 # Find operators that quantize model input 

680 if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs: 

681 float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] 

682 # If found, validate that the operator's input type is float 

683 float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type) 

684 if float_type != dtypes.float32: 

685 if float_type == inference_input_type: 

686 continue 

687 else: 

688 raise ValueError( 

689 "Initial model input type must be tf.float32. Expected type for " 

690 "tensor with name '{}' is tf.float32, instead type is {}".format( 

691 float_tensor.name, get_tf_type_name(float_type))) 

692 # If found, validate that the operator output is quantized and compatible 

693 # with the final model input type 

694 quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) 

695 if quant_type not in _MAP_QUANT_TO_IO_TYPES: 

696 raise ValueError( 

697 "Initial model input is not quantized. Expected type for " 

698 "tensor with name '{}' should be in {}, instead type is {}".format( 

699 quant_tensor.name, 

700 tuple(get_tf_type_name(t) for t in 

701 _MAP_QUANT_TO_IO_TYPES.keys()), 

702 get_tf_type_name(quant_type))) 

703 else: 

704 inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] 

705 if inference_input_type not in inference_io_types: 

706 raise ValueError( 

707 "Unsupported `inference_input_type` value. Expected to be in " 

708 "{}, instead got {}.".format( 

709 tuple(get_tf_type_name(t) for t in inference_io_types), 

710 get_tf_type_name(inference_input_type))) 

711 input_quant_ops.append(op) 

712 

713 if len(subgraph.inputs) != len(input_quant_ops): 

714 logging.warning( 

715 "For model inputs containing unsupported operations which cannot be " 

716 "quantized, the `inference_input_type` attribute will default to the " 

717 "original type." 

718 ) 

719 

720 # Modify model input type 

721 if inference_input_type == dtypes.uint8: 

722 # Change quant op (float to int8) to quant op (uint8 to int8) 

723 for op in input_quant_ops: 

724 int8_quantization = tensors[op.outputs[0]].quantization 

725 uint8_quantization = schema_fb.QuantizationParametersT() 

726 uint8_quantization.scale = [int8_quantization.scale[0]] 

727 uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] 

728 tensors[op.inputs[0]].quantization = uint8_quantization 

729 tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8 

730 elif inference_input_type in _MAP_QUANT_TO_IO_TYPES: 

731 # Remove the inputs and the quant operator 

732 remove_tensors_idxs = set() 

733 for op in input_quant_ops: 

734 subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0] 

735 if signature_index >= 0: 

736 signature_def = model.signatureDefs[signature_index] 

737 for i in range(len(signature_def.inputs)): 

738 if signature_def.inputs[i].tensorIndex == op.inputs[0]: 

739 signature_def.inputs[i].tensorIndex = op.outputs[0] 

740 remove_tensors_idxs.add(op.inputs[0]) 

741 operators.remove(op) 

742 # Remove tensors marked for deletion. 

743 _remove_tensors_from_model(model, remove_tensors_idxs) 

744 else: 

745 raise ValueError( 

746 "Unsupported `inference_input_type` value {}.".format( 

747 get_tf_type_name(inference_input_type))) 

748 

749 

750def _modify_model_output_type(model, inference_output_type=dtypes.float32): 

751 """Modify model output type.""" 

752 if inference_output_type == dtypes.float32: 

753 return 

754 

755 if not model.signatureDefs: 

756 _modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type) 

757 return 

758 

759 for signature_index, signature_def in enumerate(model.signatureDefs): 

760 _modify_model_output_type_per_subgraph(model, signature_def.subgraphIndex, 

761 signature_index, 

762 inference_output_type) 

763 

764 

765def _modify_model_output_type_per_subgraph(model, subgraph_index, 

766 signature_index, 

767 inference_output_type): 

768 """Modify model output type per subgraph.""" 

769 subgraph = model.subgraphs[subgraph_index] 

770 tensors = subgraph.tensors 

771 operators = subgraph.operators 

772 

773 # Find all dequantize operators 

774 dequant_opcode_idxs = get_dequantize_opcode_idx(model) 

775 if operators and not dequant_opcode_idxs: 

776 for output in subgraph.outputs: 

777 output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type) 

778 if output_type == dtypes.float32: 

779 raise ValueError("Model output is not dequantized.") 

780 # None of the outputs have float32, then they must be int16, int8, or bool 

781 return 

782 

783 # Validate that the model output is dequantized 

784 output_dequant_ops = [] 

785 for op in operators: 

786 # Find operators that dequantize model output 

787 if (op.opcodeIndex in dequant_opcode_idxs and 

788 op.outputs[0] in subgraph.outputs): 

789 # If found, validate that the operator's output type is float 

790 quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]] 

791 float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type) 

792 if float_type != dtypes.float32: 

793 if float_type == inference_output_type: 

794 continue 

795 else: 

796 raise ValueError( 

797 "Initial model output type must be tf.float32. Expected type for " 

798 "tensor with name '{}' is tf.float32, instead type is {}".format( 

799 float_tensor.name, get_tf_type_name(float_type))) 

800 # If found, validate that the operator input is quantized and compatible 

801 # with the final model output type 

802 quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type) 

803 if quant_type not in _MAP_QUANT_TO_IO_TYPES: 

804 raise ValueError( 

805 "Initial model output is not dequantized. Expected type for " 

806 "tensor with name '{}' should be in {}, instead type is {}".format( 

807 quant_tensor.name, 

808 tuple(get_tf_type_name(t) for t in 

809 _MAP_QUANT_TO_IO_TYPES.keys()), 

810 get_tf_type_name(quant_type))) 

811 else: 

812 inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type] 

813 if inference_output_type not in inference_io_types: 

814 raise ValueError( 

815 "Unsupported `inference_output_type` value. Expected to be in " 

816 "{}, instead got {}.".format( 

817 tuple(get_tf_type_name(t) for t in inference_io_types), 

818 get_tf_type_name(inference_output_type))) 

819 output_dequant_ops.append(op) 

820 

821 if len(subgraph.outputs) != len(output_dequant_ops): 

822 logging.warning( 

823 "For model outputs containing unsupported operations which cannot be " 

824 "quantized, the `inference_output_type` attribute will default to the " 

825 "original type." 

826 ) 

827 

828 # Modify model output type 

829 if inference_output_type == dtypes.uint8: 

830 # Find a quantize operator 

831 quant_opcode_idx = -1 

832 for idx, opcode in enumerate(model.operatorCodes): 

833 builtin_code = schema_util.get_builtin_code_from_operator_code(opcode) 

834 if builtin_code == schema_fb.BuiltinOperator.QUANTIZE: 

835 quant_opcode_idx = idx 

836 break 

837 # Create a quantize operator, if none exist 

838 if quant_opcode_idx == -1: 

839 quant_op = schema_fb.OperatorCodeT() 

840 quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE 

841 quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE 

842 model.operatorCodes.append(quant_op) 

843 quant_opcode_idx = len(model.operatorCodes) - 1 

844 # Change dequant op (int8 to float) to quant op (int8 to uint8) 

845 for op in output_dequant_ops: 

846 op.opcodeIndex = quant_opcode_idx 

847 int8_quantization = tensors[op.inputs[0]].quantization 

848 uint8_quantization = schema_fb.QuantizationParametersT() 

849 uint8_quantization.scale = [int8_quantization.scale[0]] 

850 uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] 

851 tensors[op.outputs[0]].quantization = uint8_quantization 

852 tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8 

853 elif inference_output_type in _MAP_QUANT_TO_IO_TYPES: 

854 # Remove the outputs and the dequant operator 

855 remove_tensors_idxs = set() 

856 for op in output_dequant_ops: 

857 subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0] 

858 if signature_index >= 0: 

859 signature_def = model.signatureDefs[signature_index] 

860 for i in range(len(signature_def.outputs)): 

861 if signature_def.outputs[i].tensorIndex == op.outputs[0]: 

862 signature_def.outputs[i].tensorIndex = op.inputs[0] 

863 remove_tensors_idxs.add(op.outputs[0]) 

864 operators.remove(op) 

865 # Remove tensors marked for deletion. 

866 _remove_tensors_from_model(model, remove_tensors_idxs) 

867 else: 

868 raise ValueError( 

869 "Unsupported `inference_output_type` value {}.".format( 

870 get_tf_type_name(inference_output_type))) 

871 

872 

873def _remove_redundant_quantize_ops(model): 

874 """Finds back to back quantize ops and remove the first quantize op.""" 

875 if not model.signatureDefs: 

876 _remove_redundant_quantize_ops_per_subgraph(model, 0, -1) 

877 return 

878 

879 for signature_index, signature_def in enumerate(model.signatureDefs): 

880 _remove_redundant_quantize_ops_per_subgraph(model, 

881 signature_def.subgraphIndex, 

882 signature_index) 

883 

884 

885def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index, 

886 signature_index): 

887 """Remove redundant quantize ops per subgraph.""" 

888 subgraph = model.subgraphs[subgraph_index] 

889 tensors = subgraph.tensors 

890 operators = subgraph.operators 

891 

892 # Find all quantize operators. 

893 quant_opcode_idxs = get_quantize_opcode_idx(model) 

894 dequant_opcode_idxs = get_dequantize_opcode_idx(model) 

895 

896 # Find all redundant quant tensors. 

897 all_quant_ops = [] 

898 redundant_quant_tensors = {} 

899 output_dequant_tensors = {} 

900 for op in operators: 

901 if op.opcodeIndex in quant_opcode_idxs: 

902 all_quant_ops.append(op) 

903 input_tensor = tensors[op.inputs[0]] 

904 output_tensor = tensors[op.outputs[0]] 

905 input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type) 

906 output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type) 

907 # This is a requantize op, so write down its input tensor index. 

908 if input_type != dtypes.float32 and output_type != dtypes.float32: 

909 redundant_quant_tensors[op.inputs[0]] = op 

910 if (op.opcodeIndex in dequant_opcode_idxs and 

911 op.outputs[0] in subgraph.outputs): 

912 output_dequant_tensors[op.inputs[0]] = op 

913 

914 # Remove all the quant ops which produce the redundant quant tensors. 

915 for op in all_quant_ops: 

916 output_tensor_idx = op.outputs[0] 

917 if output_tensor_idx in redundant_quant_tensors: 

918 requantize_op = redundant_quant_tensors[output_tensor_idx] 

919 if model.signatureDefs: 

920 signature_def = model.signatureDefs[0] 

921 for output in signature_def.outputs: 

922 if output.tensorIndex == op.outputs[0]: 

923 output.tensorIndex = op.inputs[0] 

924 # Reset the input of the requantize op to the float input 

925 requantize_op.inputs[0] = op.inputs[0] 

926 operators.remove(op) 

927 

928 # Remove all the quant ops which connect to the output dequant op. 

929 for op in all_quant_ops: 

930 output_tensor_idx = op.outputs[0] 

931 if output_tensor_idx in output_dequant_tensors: 

932 dequant_op = output_dequant_tensors[output_tensor_idx] 

933 subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0] 

934 if signature_index >= 0: 

935 signature_def = model.signatureDefs[signature_index] 

936 for output in signature_def.outputs: 

937 if output.tensorIndex == dequant_op.outputs[0]: 

938 output.tensorIndex = op.inputs[0] 

939 operators.remove(op) 

940 operators.remove(dequant_op) 

941 

942 

943def modify_model_io_type( 

944 model, inference_input_type=dtypes.float32, 

945 inference_output_type=dtypes.float32): 

946 """Modify the input/output type of a tflite model. 

947 

948 Args: 

949 model: A tflite model. 

950 inference_input_type: tf.DType representing modified input type. 

951 (default tf.float32. If model input is int8 quantized, it must be in 

952 {tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized, 

953 it must be in {tf.float32, tf.int16}, else it must be tf.float32) 

954 inference_output_type: tf.DType representing modified output type. 

955 (default tf.float32. If model output is int8 dequantized, it must be in 

956 {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized, 

957 it must be in {tf.float32, tf.int16}, else it must be tf.float32) 

958 Returns: 

959 A tflite model with modified input/output type. 

960 

961 Raises: 

962 ValueError: If `inference_input_type`/`inference_output_type` is unsupported 

963 or a supported integer type is specified for a model whose input/output is 

964 not quantized/dequantized. 

965 RuntimeError: If the modification was unsuccessful. 

966 

967 """ 

968 if (inference_input_type == dtypes.float32 and 

969 inference_output_type == dtypes.float32): 

970 return model 

971 

972 model_object = _convert_model_from_bytearray_to_object(model) 

973 

974 _modify_model_input_type(model_object, inference_input_type) 

975 

976 _modify_model_output_type(model_object, inference_output_type) 

977 

978 _remove_redundant_quantize_ops(model_object) 

979 

980 return _convert_model_from_object_to_bytearray(model_object) 

981 

982 

983def get_sparsity_modes(model_object): 

984 """Get sparsity modes used in a tflite model. 

985 

986 The sparsity modes are listed in conversion_metadata.fbs file. 

987 

988 Args: 

989 model_object: A tflite model in object form. 

990 

991 Returns: 

992 The list of sparsity modes used in the model. 

993 """ 

994 if not model_object or not model_object.metadata: 

995 return [] 

996 

997 result = set() 

998 for subgraph in model_object.subgraphs: 

999 for tensor in subgraph.tensors: 

1000 if not tensor.sparsity: 

1001 continue 

1002 

1003 # Block map is the list if indexes where the block size is larger than 1. 

1004 # So empty block map means it is random sparsity. 

1005 if not tensor.sparsity.blockMap: 

1006 result.add( 

1007 conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY) 

1008 else: 

1009 result.add( 

1010 conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY) 

1011 

1012 return list(result) 

1013 

1014 

1015def populate_conversion_metadata(model_object, metadata): 

1016 """Add or update conversion metadata to a tflite model. 

1017 

1018 Args: 

1019 model_object: A tflite model in object form. 

1020 metadata: The conversion metadata. 

1021 

1022 Returns: 

1023 A tflite model object with embedded conversion metadata. 

1024 """ 

1025 try: 

1026 metadata_builder = flatbuffers.Builder(0) 

1027 metadata_builder.Finish(metadata.Pack(metadata_builder)) 

1028 buffer_field = schema_fb.BufferT() 

1029 buffer_field.data = metadata_builder.Output() 

1030 

1031 if not model_object.metadata: 

1032 model_object.metadata = [] 

1033 else: 

1034 # Check if metadata has already been populated. 

1035 for meta in model_object.metadata: 

1036 if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME: 

1037 model_object.buffers[meta.buffer] = buffer_field 

1038 return model_object 

1039 

1040 if not model_object.buffers: 

1041 model_object.buffers = [] 

1042 model_object.buffers.append(buffer_field) 

1043 # Creates a new metadata field. 

1044 metadata_field = schema_fb.MetadataT() 

1045 metadata_field.name = CONVERSION_METADATA_FIELD_NAME 

1046 metadata_field.buffer = len(model_object.buffers) - 1 

1047 model_object.metadata.append(metadata_field) 

1048 

1049 return model_object 

1050 except Exception: # pylint: disable=broad-except 

1051 return model_object 

1052 

1053 

1054def get_conversion_metadata(model_buffer): 

1055 """Read conversion metadata from a tflite model. 

1056 

1057 Args: 

1058 model_buffer: A tflite model. 

1059 

1060 Returns: 

1061 The conversion metadata or None if it is not populated. 

1062 """ 

1063 model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer) 

1064 if not model_object or not model_object.metadata: 

1065 return None 

1066 

1067 for meta in model_object.metadata: 

1068 if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME: 

1069 metadata_buf = model_object.buffers[meta.buffer].data.tobytes() 

1070 return conversion_metadata_fb.ConversionMetadataT.InitFromObj( 

1071 conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata( 

1072 metadata_buf, 0)) 

1073 

1074 return None