# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""

import copy
import datetime
import sys

from absl import logging

import flatbuffers
from tensorflow.core.framework import graph_debug_info_pb2
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph

# The field name of conversion metadata in the flatbuffer file.
CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA"

# Keras functions used by TFLite
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call

# Jax functions used by TFLite
# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import
try:
  from jax import xla_computation as _xla_computation
except ImportError:
  _xla_computation = None
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import

# Defined as per TFLite schema
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is : {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type
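

# Illustrative sketch (not part of the original module): how the enum-to-dtype
# helper above behaves. The enum values follow the TFLite schema, so 0 maps to
# tf.float32 and 9 maps to tf.int8; unknown values raise ValueError. The
# `_example_` function name is hypothetical.
def _example_enum_to_tf_type():
  """Hypothetical usage sketch for _convert_tflite_enum_type_to_tf_type."""
  assert _convert_tflite_enum_type_to_tf_type(0) == dtypes.float32
  assert _convert_tflite_enum_type_to_tf_type(9) == dtypes.int8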


def get_tf_type_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name if tf_type else None


def get_tensor_name(tensor):
  """Returns name of the input tensor.

  Args:
    tensor: tf.Tensor

  Returns:
    str
  """
  parts = tensor.name.split(":")
  if len(parts) > 2:
    raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
        len(parts) - 1))

  # To be consistent with the tensor naming scheme in tensorflow, we need to
  # drop the ':0' suffix for the first tensor.
  if len(parts) > 1 and parts[1] != "0":
    return tensor.name
  return parts[0]
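

# Illustrative sketch (not part of the original module): get_tensor_name drops a
# trailing ":0" but keeps any other output index. The stand-in class below is
# hypothetical and only mimics the `.name` attribute the helper reads.
class _ExampleFakeTensor:

  def __init__(self, name):
    self.name = name


def _example_get_tensor_name():
  """Hypothetical usage sketch for get_tensor_name."""
  assert get_tensor_name(_ExampleFakeTensor("input:0")) == "input"
  assert get_tensor_name(_ExampleFakeTensor("split:1")) == "split:1"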


def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Get the list of all of the tensors.
  tensor_name_to_tensor = {}
  for op in graph.get_operations():
    for tensor in op.values():
      tensor_name_to_tensor[get_tensor_name(tensor)] = tensor

  # Get the tensors associated with tensor_names.
  tensors = []
  invalid_tensors = []
  for name in tensor_names:
    if not isinstance(name, str):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))

    tensor = tensor_name_to_tensor.get(name)
    if tensor is None:
      invalid_tensors.append(name)
    else:
      tensors.append(tensor)

  # Throw ValueError if any user input names are not valid tensors.
  if invalid_tensors:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(invalid_tensors)))
  return tensors


def set_tensor_shapes(tensors, shapes):
  """Sets Tensor shape for each tensor if the shape is defined.

  Args:
    tensors: TensorFlow ops.Tensor.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          raise ValueError(message)
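

# Illustrative sketch (not part of the original module): set_tensor_shapes keys
# the `shapes` dict by get_tensor_name(), so a tensor named "input:0" is
# addressed as "input". The stand-in class below is hypothetical and only mimics
# the attributes the helper touches.
class _ExampleShapedTensor:

  def __init__(self, name):
    self.name = name
    self.shape = None

  def set_shape(self, shape):
    self.shape = shape


def _example_set_tensor_shapes():
  """Hypothetical usage sketch for set_tensor_shapes."""
  tensor = _ExampleShapedTensor("input:0")
  set_tensor_shapes([tensor], {"input": [1, 224, 224, 3]})
  assert tensor.shape == [1, 224, 224, 3]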


def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings that represents the list of optimizers.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  for optimizer in optimizers_list:
    rewrite_options.optimizers.append(optimizer)
  return config
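

# Illustrative sketch (not part of the original module): freeze_graph() below
# uses get_grappler_config(["function"]); other Grappler optimizer names such as
# "constfold" can be appended the same way. The function name is hypothetical.
def _example_get_grappler_config():
  """Hypothetical usage sketch for get_grappler_config."""
  config = get_grappler_config(["function", "constfold"])
  # The names end up in the rewriter's ordered optimizer list.
  return list(config.graph_options.rewrite_options.optimizers)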


def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
  graph_def = _convert_to_constants.convert_variables_to_constants(
      sess, graph_def, output_arrays + hinted_outputs_nodes)
  graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  return graph_def


def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
  existing GraphDef is returned. The Grappler pass is only run on models that
  are frozen in order to inline the functions in the graph.
  If OpHints is present, it will try to convert the OpHint graph.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining any simple function, Grappler will also try to lower
  # while loops into switch-merge representation, which is undesired for
  # OpHints, so we simply remove those attributes to prevent Grappler from
  # doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, config, graph=sess.graph)

  # If ophints are present, just convert them.
  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_outputs_nodes)

  if not is_frozen_graph(sess):
    output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
    return _convert_to_constants.convert_variables_to_constants(
        sess, graph_def, output_node_names
    )
  else:
    return sess.graph_def


def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True


def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function.AtomicFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f


def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the index wouldn't be changed.
    output_debug_info.files[:] = saved_debug_info.files
    # We only copy over the debug info for the input nodes
    for func, node in original_nodes:
      debug_key = node + "@" + func
      output_debug_info.traces[debug_key].CopyFrom(
          saved_debug_info.traces[debug_key])
    return output_debug_info

  return f


def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Collect all the debug info nodes from the converted_graph
  original_nodes = set()
  for node in converted_graph.node:
    debug_nodes = node.experimental_debug_info.original_node_names
    debug_funcs = node.experimental_debug_info.original_func_names
    # If the `original_node_names` are empty, uses the node name directly.
    if not debug_nodes:
      original_nodes.add(("", node.name))
    else:
      for i in range(len(debug_nodes)):
        debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
        original_nodes.add((debug_func, debug_nodes[i]))

  # Convert the nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)


def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """

  starting_pad = " "
  array_lines = []
  array_line = starting_pad
  for value in bytearray(data):
    if (len(array_line) + 4) > max_line_width:
      array_lines.append(array_line + "\n")
      array_line = starting_pad
    array_line += " 0x%02x," % (value,)
  if len(array_line) > len(starting_pad):
    array_lines.append(array_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)
  else:
    include_line = ""

  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
  else:
    license_text = ""

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text
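

# Illustrative sketch (not part of the original module): turning a (fabricated)
# model buffer into a C source/header pair. Real callers pass serialized .tflite
# bytes; "g_model" and the include path below are example values only.
def _example_convert_bytes_to_c_source():
  """Hypothetical usage sketch for convert_bytes_to_c_source."""
  fake_model_bytes = b"\x00\x01\x02\x03"
  source, header = convert_bytes_to_c_source(
      data=fake_model_bytes,
      array_name="g_model",
      include_guard="EXAMPLE_G_MODEL_DATA_H_",
      include_path="example/g_model_data.h",
      use_tensorflow_license=False)
  # `source` defines `const unsigned char g_model[]` plus `g_model_len`;
  # `header` declares both behind the EXAMPLE_G_MODEL_DATA_H_ include guard.
  return source, header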


def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())
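

# Illustrative sketch (not part of the original module): modify_model_io_type()
# below relies on this helper pair to round-trip a flatbuffer. The function name
# and the `model_bytes` argument are hypothetical.
def _example_flatbuffer_round_trip(model_bytes):
  """Hypothetical usage sketch for the bytearray <-> object converters."""
  model_object = _convert_model_from_bytearray_to_object(model_bytes)
  # The mutable object form exposes e.g. model_object.subgraphs and
  # model_object.operatorCodes for in-place edits before re-serialization.
  return _convert_model_from_object_to_bytearray(model_object)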


def get_quantize_opcode_idx(model):
  """Returns the quantize op idx."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def get_dequantize_opcode_idx(model):
  """Returns the dequantize op idx."""
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      quant_opcode_idxs.append(idx)
  return quant_opcode_idxs


def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
  """Update the tensors in the SignatureDef's TensorMaps."""
  for i in range(len(tensor_maps)):
    if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
      tensor_maps[i].tensorIndex = (
          map_old_to_new_tensors[tensor_maps[i].tensorIndex])


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    logging.info("Skipping the removal of dangled tensors since the model has "
                 "multiple subgraphs and tensors can be used in the different "
                 "subgraph(s)")
    return
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
    # Update tensor indices referenced throughout the model
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")


def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""
  if inference_input_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_input_type_per_subgraph(model, signature_def.subgraphIndex,
                                          signature_index, inference_input_type)


def _modify_model_input_type_per_subgraph(model, subgraph_index,
                                          signature_index,
                                          inference_input_type):
  """Modify model input type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not quantized.")
    # If none of the inputs are float32, they must be int16, int8, or bool
    return

  # Validate that the model input is quantized
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type."
    )

  # Modify model input type
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))


def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type."""
  if inference_output_type == dtypes.float32:
    return

  if not model.signatureDefs:
    _modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _modify_model_output_type_per_subgraph(model, signature_def.subgraphIndex,
                                           signature_index,
                                           inference_output_type)


def _modify_model_output_type_per_subgraph(model, subgraph_index,
                                           signature_index,
                                           inference_output_type):
  """Modify model output type per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # If none of the outputs are float32, they must be int16, int8, or bool
    return

  # Validate that the model output is dequantized
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type."
    )

  # Modify model output type
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8)
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))


def _remove_redundant_quantize_ops(model):
  """Finds back-to-back quantize ops and removes the first quantize op."""
  if not model.signatureDefs:
    _remove_redundant_quantize_ops_per_subgraph(model, 0, -1)
    return

  for signature_index, signature_def in enumerate(model.signatureDefs):
    _remove_redundant_quantize_ops_per_subgraph(model,
                                                signature_def.subgraphIndex,
                                                signature_index)


def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index,
                                                signature_index):
  """Remove redundant quantize ops per subgraph."""
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      # Reset the input of the requantize op to the float input
      requantize_op.inputs[0] = op.inputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)


def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modify the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing modified input type.
      (default tf.float32. If model input is int8 quantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model input is int16 quantized,
      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
    inference_output_type: tf.DType representing modified output type.
      (default tf.float32. If model output is int8 dequantized, it must be in
      {tf.float32, tf.int8, tf.uint8}, else if model output is int16
      dequantized, it must be in {tf.float32, tf.int16}, else it must be
      tf.float32)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is unsupported
      or a supported integer type is specified for a model whose input/output is
      not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    return model

  model_object = _convert_model_from_bytearray_to_object(model)

  _modify_model_input_type(model_object, inference_input_type)

  _modify_model_output_type(model_object, inference_output_type)

  _remove_redundant_quantize_ops(model_object)

  return _convert_model_from_object_to_bytearray(model_object)
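

# Illustrative sketch (not part of the original module): a typical
# post-conversion call that rewrites a full-integer quantized model so its
# interface tensors are uint8. The function name and `quantized_model_bytes`
# argument are hypothetical.
def _example_modify_model_io_type(quantized_model_bytes):
  """Hypothetical usage sketch for modify_model_io_type."""
  return modify_model_io_type(
      quantized_model_bytes,
      inference_input_type=dtypes.uint8,
      inference_output_type=dtypes.uint8)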


def get_sparsity_modes(model_object):
  """Get sparsity modes used in a tflite model.

  The sparsity modes are listed in conversion_metadata.fbs file.

  Args:
    model_object: A tflite model in object form.

  Returns:
    The list of sparsity modes used in the model.
  """
  if not model_object or not model_object.metadata:
    return []

  result = set()
  for subgraph in model_object.subgraphs:
    for tensor in subgraph.tensors:
      if not tensor.sparsity:
        continue

      # Block map is the list of indexes where the block size is larger than 1.
      # So an empty block map means it is random sparsity.
      if not tensor.sparsity.blockMap:
        result.add(
            conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY)
      else:
        result.add(
            conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY)

  return list(result)


def populate_conversion_metadata(model_object, metadata):
  """Add or update conversion metadata to a tflite model.

  Args:
    model_object: A tflite model in object form.
    metadata: The conversion metadata.

  Returns:
    A tflite model object with embedded conversion metadata.
  """
  try:
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(metadata.Pack(metadata_builder))
    buffer_field = schema_fb.BufferT()
    buffer_field.data = metadata_builder.Output()

    if not model_object.metadata:
      model_object.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model_object.metadata:
        if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
          model_object.buffers[meta.buffer] = buffer_field
          return model_object

    if not model_object.buffers:
      model_object.buffers = []
    model_object.buffers.append(buffer_field)
    # Creates a new metadata field.
    metadata_field = schema_fb.MetadataT()
    metadata_field.name = CONVERSION_METADATA_FIELD_NAME
    metadata_field.buffer = len(model_object.buffers) - 1
    model_object.metadata.append(metadata_field)

    return model_object
  except Exception:  # pylint: disable=broad-except
    return model_object


def get_conversion_metadata(model_buffer):
  """Read conversion metadata from a tflite model.

  Args:
    model_buffer: A tflite model.

  Returns:
    The conversion metadata or None if it is not populated.
  """
  model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer)
  if not model_object or not model_object.metadata:
    return None

  for meta in model_object.metadata:
    if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
      metadata_buf = model_object.buffers[meta.buffer].data.tobytes()
      return conversion_metadata_fb.ConversionMetadataT.InitFromObj(
          conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata(
              metadata_buf, 0))

  return None
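

# Illustrative sketch (not part of the original module): reading back the
# conversion metadata that populate_conversion_metadata() embeds. The function
# name and `model_bytes` argument are hypothetical.
def _example_read_conversion_metadata(model_bytes):
  """Hypothetical usage sketch for get_conversion_metadata."""
  metadata = get_conversion_metadata(model_bytes)
  # None means the model was produced without conversion metadata; otherwise the
  # object mirrors the tables defined in conversion_metadata.fbs.
  return metadata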