Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/tools/flatbuffer_utils.py: 19%
145 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for FlatBuffers.
17All functions that are commonly used to work with FlatBuffers.
19Refer to the tensorflow lite flatbuffer schema here:
20tensorflow/lite/schema/schema.fbs
22"""
24import copy
25import random
26import re
27import struct
28import sys
30import flatbuffers
31from tensorflow.lite.python import schema_py_generated as schema_fb
32from tensorflow.lite.python import schema_util
33from tensorflow.python.platform import gfile
35_TFLITE_FILE_IDENTIFIER = b'TFL3'
def convert_bytearray_to_object(model_bytearray):
  """Builds a tflite model object from its serialized flatbuffer bytes.

  Args:
    model_bytearray: The serialized tflite flatbuffer, read starting at
      offset 0.

  Returns:
    The object produced by `schema_fb.ModelT.InitFromObj` for the root model.
  """
  root_model = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  unpacked_model = schema_fb.ModelT.InitFromObj(root_model)
  return unpacked_model
def read_model(input_tflite_file):
  """Loads a tflite file from disk and returns it as a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file.

  Raises:
    RuntimeError: If nothing exists at input_tflite_file.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  file_exists = gfile.Exists(input_tflite_file)
  if not file_exists:
    raise RuntimeError('Input file not found at %r\n' % input_tflite_file)

  with gfile.GFile(input_tflite_file, 'rb') as tflite_file:
    serialized = bytearray(tflite_file.read())

  parsed_model = convert_bytearray_to_object(serialized)

  # On big-endian hosts the buffers are swapped from 'little' to 'big' after
  # loading.
  running_on_big_endian = sys.byteorder == 'big'
  if running_on_big_endian:
    byte_swap_tflite_model_obj(parsed_model, 'little', 'big')
  return parsed_model
def read_model_with_mutable_tensors(input_tflite_file):
  """Loads a tflite file and returns a deep copy whose tensors are mutable.

  This is read_model() followed by copy.deepcopy(): read_model() returns an
  object with immutable tensors, while the deep copy returned here can be
  modified in place.

  Args:
    input_tflite_file: Full path name to the input tflite file.

  Raises:
    RuntimeError: If nothing exists at input_tflite_file.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  loaded_model = read_model(input_tflite_file)
  return copy.deepcopy(loaded_model)
def convert_object_to_bytearray(model_object):
  """Serializes a tflite model object into immutable `bytes`.

  Args:
    model_object: A tflite model object exposing `Pack(builder)`.

  Returns:
    The finished flatbuffer as `bytes`, tagged with the TFLite file
    identifier.
  """
  # Only a starting capacity: the builder grows its buffer when needed.
  initial_capacity = 1024
  fb_builder = flatbuffers.Builder(initial_capacity)
  root_offset = model_object.Pack(fb_builder)
  fb_builder.Finish(root_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(fb_builder.Output())
def write_model(model_object, output_tflite_file):
  """Serializes a tflite model object and writes it to the output file.

  Args:
    model_object: A tflite model as a python object.
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  model_to_write = model_object
  if sys.byteorder == 'big':
    # Swap a private copy so the caller's object keeps its in-memory layout.
    model_to_write = copy.deepcopy(model_object)
    byte_swap_tflite_model_obj(model_to_write, 'big', 'little')

  serialized = convert_object_to_bytearray(model_to_write)
  with gfile.GFile(output_tflite_file, 'wb') as tflite_file:
    tflite_file.write(serialized)
def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  We remove the following strings:
  (find strings by searching ":string" in the tensorflow lite flatbuffer schema)
  1. Model description
  2. SubGraph name
  3. Tensor names
  We retain OperatorCode custom_code and Metadata name.

  The signature defs are dropped as well, since they are useless once the
  names are gone.

  Args:
    model: The model from which to remove nonessential strings. It is modified
      in place. A model whose `subgraphs` is None, or a subgraph whose
      `tensors` is None, is accepted and simply has nothing to strip there.
  """
  model.description = None
  # `or ()` lets an absent (None) vector be treated like an empty one.
  for subgraph in model.subgraphs or ():
    subgraph.name = None
    for tensor in subgraph.tensors or ():
      tensor.name = None
  # We clear all signature_def structure, since without names it is useless.
  model.signatureDefs = None
def type_to_name(tensor_type):
  """Converts a numerical enum to a readable tensor type.

  Args:
    tensor_type: A value of the `schema_fb.TensorType` enum.

  Returns:
    The name of the matching `schema_fb.TensorType` member, or None if no
    member has that value.
  """
  for name, value in vars(schema_fb.TensorType).items():
    # A class namespace also holds interpreter-provided entries such as
    # `__module__` and `__doc__`; skip them so that e.g. a None input cannot
    # be reported as '__doc__'.
    if name.startswith('_'):
      continue
    if value == tensor_type:
      return name
  return None
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model.

  Every non-empty buffer (except index 0 and the skipped ones) is overwritten
  in place. Buffers feeding a FLOAT* tensor receive valid floats drawn from
  [-0.5, 0.5]; all other buffers receive uniformly random bytes.

  Args:
    model: The model in which to randomize weights. Buffer data is expected to
      be a writable array exposing `.size` and item assignment.
    random_seed: The input to the random number generator (default value is 0).
    buffers_to_skip: The list of buffer indices to skip. The weights in these
      buffers are left unmodified.
  """

  # The input to the random seed generator. The default value is 0.
  random.seed(random_seed)

  # Parse model buffers which store the model weights
  buffers = model.buffers
  buffer_ids = range(1, len(buffers))  # ignore index 0 as it's always None
  if buffers_to_skip is not None:
    buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip]

  # Map each buffer index to the type name of a tensor that uses it as an
  # operator input.
  buffer_types = {}
  for graph in model.subgraphs or ():
    for op in graph.operators or ():
      if op.inputs is None:
        # Only this operator lacks inputs; keep scanning the remaining ones.
        continue
      for input_idx in op.inputs:
        tensor = graph.tensors[input_idx]
        buffer_types[tensor.buffer] = type_to_name(tensor.type)

  # struct format code per float tensor type, so each element is packed with
  # its real width.
  float_formats = {'FLOAT16': 'e', 'FLOAT32': 'f', 'FLOAT64': 'd'}

  for i in buffer_ids:
    buffer_i_data = buffers[i].data
    buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size
    if buffer_i_size == 0:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or uint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times its length of type ubyte (or uint8).
    # For floats, we need to generate a valid float and then pack it into
    # the raw bytes in place.
    # An unknown buffer, or a type that could not be named, is treated as
    # plain bytes.
    buffer_type = buffer_types.get(i) or 'INT8'
    if buffer_type.startswith('FLOAT'):
      format_code = float_formats.get(buffer_type, 'f')
      item_size = struct.calcsize(format_code)
      # Stop before a trailing partial element so pack_into never writes past
      # the end of the buffer.
      for offset in range(0, buffer_i_size - item_size + 1, item_size):
        value = random.uniform(-0.5, 0.5)  # See http://b/152324470#comment2
        struct.pack_into(format_code, buffer_i_data, offset, value)
    else:
      for j in range(buffer_i_size):
        buffer_i_data[j] = random.randint(0, 255)
def rename_custom_ops(model, map_custom_op_renames):
  """Applies a rename table to the custom op codes of a model, in place.

  Operator codes without a custom code, or whose custom code is not a key of
  the table, are left untouched.

  Args:
    model: The input tflite model.
    map_custom_op_renames: A mapping from old to new custom op names.
  """
  for operator_code in model.operatorCodes:
    if not operator_code.customCode:
      continue
    current_name = operator_code.customCode.decode('ascii')
    if current_name not in map_custom_op_renames:
      continue
    replacement = map_custom_op_renames[current_name]
    operator_code.customCode = replacement.encode('ascii')
def opcode_to_name(model, op_code):
  """Resolves an operator-code index to the name of its builtin operator.

  Args:
    model: The input tflite model.
    op_code: Index into `model.operatorCodes` of the entry to resolve.

  Returns:
    A string containing the human readable op name, or None if not resolvable.
  """
  operator_code = model.operatorCodes[op_code]
  # The larger of the two code fields is the one looked up.
  resolved_code = max(
      operator_code.builtinCode, operator_code.deprecatedBuiltinCode
  )
  matching_names = (
      name
      for name, value in vars(schema_fb.BuiltinOperator).items()
      if value == resolved_code
  )
  return next(matching_names, None)
def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Only lines that start (after optional non-word characters) with a `0x`
  literal are read; every other line is ignored. Whitespace around the
  comma-separated values, including a blank after a trailing comma, is
  tolerated.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file does not exist or cannot be opened.
    ValueError: If a value on a matched line is not a valid byte in hex.

  Returns:
    A bytes object corresponding to the input cc file array.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  with open(input_cc_file) as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)

      if values_match is None:
        continue

      # Match in the parentheses (hex array only)
      list_text = values_match.group(1)

      # Extract hex values (text) from the line
      # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
      # Tokens are stripped first so that whitespace-only leftovers (such as
      # the blank after a trailing ", ") are dropped instead of reaching int().
      stripped_tokens = (token.strip() for token in list_text.split(','))

      # Convert from hex text to integer byte values
      model_bytearray.extend(
          int(token, base=16) for token in stripped_tokens if token
      )

  return bytes(model_bytearray)
def xxd_output_to_object(input_cc_file):
  """Parses the byte array of an xxd-dumped C++ source file into a model object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file cannot be opened.

  Returns:
    A python object corresponding to the tflite model embedded in the file.
  """
  serialized_model = xxd_output_to_bytes(input_cc_file)
  model_object = convert_bytearray_to_object(serialized_model)
  return model_object
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
  """Re-encodes `buffer.data` chunk by chunk from one byte order to another.

  Args:
    buffer: Object whose `data` attribute holds the raw bytes; the attribute is
      replaced by a new `bytes` value.
    chunksize: Width in bytes of each element to convert.
    from_endiness: Byte order ('little' or 'big') the chunks are read with.
    to_endiness: Byte order ('little' or 'big') the chunks are written with.
  """
  raw = buffer.data
  converted_chunks = []
  for start in range(0, len(raw), chunksize):
    # Decode the chunk as an integer in the source order, then re-encode it
    # in the destination order.
    as_int = int.from_bytes(raw[start : start + chunksize], from_endiness)
    converted_chunks.append(as_int.to_bytes(chunksize, to_endiness))
  buffer.data = b''.join(converted_chunks)
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
  """Byte swaps the buffers field in a TFLite model.

  Each buffer referenced by a tensor of a 16-, 32- or 64-bit type is swapped
  exactly once, in place, using the element width of the first such tensor
  that refers to it. Buffers used only by tensors of other types are left
  untouched. A None model, a None `subgraphs` vector or a None `tensors`
  vector is treated as having nothing to swap.

  Args:
    model: TFLite model object of from_endiness format.
    from_endiness: The original endianness format of the buffers in model.
    to_endiness: The destined endianness format of the buffers in model.
  """
  if model is None:
    return
  tensor_type = schema_fb.TensorType
  # Swap width in bytes for every tensor type whose buffer needs swapping.
  # Complex types are swapped per component: COMPLEX64 as 4-byte chunks and
  # COMPLEX128 as 8-byte chunks.
  swap_width_by_type = {
      tensor_type.FLOAT16: 2,
      tensor_type.INT16: 2,
      tensor_type.UINT16: 2,
      tensor_type.FLOAT32: 4,
      tensor_type.INT32: 4,
      tensor_type.COMPLEX64: 4,
      tensor_type.UINT32: 4,
      tensor_type.INT64: 8,
      tensor_type.FLOAT64: 8,
      tensor_type.COMPLEX128: 8,
      tensor_type.UINT64: 8,
  }
  # Buffers may be shared between tensors; remember which ones are done so no
  # buffer is swapped twice (which would undo the swap).
  swapped_buffers = set()
  buffers = model.buffers
  for subgraph in model.subgraphs or ():
    for tensor in subgraph.tensors or ():
      buffer_idx = tensor.buffer
      # Index 0 and out-of-range indices are never swapped.
      if not 0 < buffer_idx < len(buffers):
        continue
      if buffer_idx in swapped_buffers:
        continue
      if buffers[buffer_idx].data is None:
        continue
      width = swap_width_by_type.get(tensor.type)
      if width is None:
        continue
      byte_swap_buffer_content(
          buffers[buffer_idx], width, from_endiness, to_endiness
      )
      swapped_buffers.add(buffer_idx)
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
  """Returns a re-serialized copy of a flatbuffer with its buffers byte swapped.

  Args:
    tflite_model: TFLite flatbuffer in a byte array, or None.
    from_endiness: The original endianness format of the buffers in
      tflite_model.
    to_endiness: The destined endianness format of the buffers in tflite_model.

  Returns:
    TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
    format, or None when tflite_model is None.
  """
  if tflite_model is None:
    return None
  # Round trip: bytes -> object, swap the object's buffers, object -> bytes.
  model_object = convert_bytearray_to_object(tflite_model)
  byte_swap_tflite_model_obj(model_object, from_endiness, to_endiness)
  swapped_bytes = convert_object_to_bytearray(model_object)
  return swapped_bytes
def count_resource_variables(model):
  """Counts the distinct shared names used by VAR_HANDLE ops in a model.

  Args:
    model: the input tflite model, either as bytearray or object.

  Returns:
    An integer number representing the number of unique resource variables.
  """
  if not isinstance(model, schema_fb.ModelT):
    # Anything that is not already a model object is parsed as serialized
    # bytes.
    model = convert_bytearray_to_object(model)

  var_handle_code = schema_fb.BuiltinOperator.VAR_HANDLE
  shared_names = set()
  for subgraph in model.subgraphs:
    operators = subgraph.operators
    if operators is None:
      continue
    for operator in operators:
      operator_code = model.operatorCodes[operator.opcodeIndex]
      resolved_code = schema_util.get_builtin_code_from_operator_code(
          operator_code
      )
      if resolved_code != var_handle_code:
        continue
      shared_names.add(operator.builtinOptions.sharedName)
  return len(shared_names)