Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/tools/flatbuffer_utils.py: 19%

145 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility functions for FlatBuffers. 

16 

17All functions that are commonly used to work with FlatBuffers. 

18 

19Refer to the tensorflow lite flatbuffer schema here: 

20tensorflow/lite/schema/schema.fbs 

21 

22""" 

23 

24import copy 

25import random 

26import re 

27import struct 

28import sys 

29 

30import flatbuffers 

31from tensorflow.lite.python import schema_py_generated as schema_fb 

32from tensorflow.lite.python import schema_util 

33from tensorflow.python.platform import gfile 

34 

# File identifier handed to `Builder.Finish` when a model is serialized.
_TFLITE_FILE_IDENTIFIER = b'TFL3'

36 

37 

def convert_bytearray_to_object(model_bytearray):
  """Builds a python model object from a serialized tflite flatbuffer.

  Args:
    model_bytearray: The serialized tflite model; its root table is read
      starting at offset 0.

  Returns:
    The `schema_fb.ModelT` object initialized from the flatbuffer root.
  """
  flatbuffer_root = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  unpacked_model = schema_fb.ModelT.InitFromObj(flatbuffer_root)
  return unpacked_model

42 

43 

def read_model(input_tflite_file):
  """Loads a tflite file and returns it as a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file.

  Raises:
    RuntimeError: If nothing exists at input_tflite_file.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  if not gfile.Exists(input_tflite_file):
    raise RuntimeError('Input file not found at %r\n' % input_tflite_file)

  with gfile.GFile(input_tflite_file, 'rb') as tflite_file:
    raw_contents = tflite_file.read()

  parsed_model = convert_bytearray_to_object(bytearray(raw_contents))
  # On big-endian hosts the buffer contents are swapped from little-endian
  # to big-endian after parsing.
  if sys.byteorder == 'big':
    byte_swap_tflite_model_obj(parsed_model, 'little', 'big')
  return parsed_model

65 

66 

def read_model_with_mutable_tensors(input_tflite_file):
  """Loads a tflite file as a python object whose tensors can be modified.

  Behaves like read_model(), except that the result is a deep copy of what
  read_model() returns, which makes the tensors mutable (read_model() returns
  an object with immutable tensors).

  Args:
    input_tflite_file: Full path name to the input tflite file.

  Raises:
    RuntimeError: If nothing exists at input_tflite_file.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  immutable_model = read_model(input_tflite_file)
  return copy.deepcopy(immutable_model)

84 

85 

def convert_object_to_bytearray(model_object):
  """Serializes a tflite model object into an immutable `bytes` flatbuffer.

  Args:
    model_object: A tflite model as a python object (must provide `Pack`).

  Returns:
    The serialized model as `bytes`, finished with the TFLite file
    identifier.
  """
  # 1024 is only the starting capacity; the builder grows automatically.
  fb_builder = flatbuffers.Builder(1024)
  root_offset = model_object.Pack(fb_builder)
  fb_builder.Finish(root_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(fb_builder.Output())

94 

95 

def write_model(model_object, output_tflite_file):
  """Serializes a tflite model object and writes it to the output file.

  Args:
    model_object: A tflite model as a python object.
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  model_to_serialize = model_object
  if sys.byteorder == 'big':
    # Swap a deep copy (big-endian to little-endian) so the caller's object
    # is left untouched.
    model_to_serialize = copy.deepcopy(model_object)
    byte_swap_tflite_model_obj(model_to_serialize, 'big', 'little')

  serialized_model = convert_object_to_bytearray(model_to_serialize)
  with gfile.GFile(output_tflite_file, 'wb') as tflite_file:
    tflite_file.write(serialized_model)

112 

113 

def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  The model is modified in place. The following strings are removed
  (find strings by searching ":string" in the tensorflow lite flatbuffer
  schema):
    1. Model description
    2. SubGraph name
    3. Tensor names
  OperatorCode custom_code and Metadata name are retained.

  A model without subgraphs, or a subgraph without tensors (the field is
  None in that case), is handled without error.

  Args:
    model: The model from which to remove nonessential strings.
  """
  model.description = None
  # Absent vectors are represented as None, so fall back to an empty tuple.
  for subgraph in model.subgraphs or ():
    subgraph.name = None
    for tensor in subgraph.tensors or ():
      tensor.name = None
  # We clear all signature_def structure, since without names it is useless.
  model.signatureDefs = None

135 

136 

def type_to_name(tensor_type):
  """Looks up the readable name of a numerical tensor type enum.

  Args:
    tensor_type: The numerical value to look up among the attributes of
      `schema_fb.TensorType`.

  Returns:
    The name of the first attribute whose value equals tensor_type, or None
    when there is no such attribute.
  """
  matching_names = (
      attr_name
      for attr_name, attr_value in schema_fb.TensorType.__dict__.items()
      if attr_value == tensor_type
  )
  return next(matching_names, None)

143 

144 

def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model, in place.

  Every non-empty buffer except buffer 0 and those listed in buffers_to_skip
  is overwritten with random content. Buffers used as operator inputs with a
  FLOAT* tensor type are filled with floats drawn from [-0.5, 0.5]; all other
  buffers are filled with random bytes.

  Note: this seeds the module-level `random` generator with random_seed.

  Args:
    model: The model in which to randomize weights.
    random_seed: The input to the random number generator (default value is 0).
    buffers_to_skip: The list of buffer indices to skip. The weights in these
      buffers are left unmodified.
  """
  # The input to the random seed generator. The default value is 0.
  random.seed(random_seed)

  # Parse model buffers which store the model weights.
  buffers = model.buffers
  # A set makes the per-buffer skip test O(1).
  skipped = frozenset(buffers_to_skip) if buffers_to_skip is not None else ()

  # Map each buffer index to the type name of a tensor that reads from it.
  buffer_types = {}
  for graph in model.subgraphs or ():
    for op in graph.operators or ():
      if op.inputs is None:
        # This operator contributes no type information, but the remaining
        # operators of the subgraph still have to be inspected.
        continue
      for input_idx in op.inputs:
        if input_idx < 0:
          # Negative indices mark omitted optional inputs (see schema.fbs);
          # they must not index the tensor list from the end.
          continue
        tensor = graph.tensors[input_idx]
        buffer_types[tensor.buffer] = type_to_name(tensor.type)

  # struct format code for each float tensor type name.
  float_formats = {'FLOAT16': 'e', 'FLOAT32': 'f', 'FLOAT64': 'd'}

  for i in range(1, len(buffers)):  # ignore index 0 as it's always None
    if i in skipped:
      continue
    buffer_i_data = buffers[i].data
    buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size
    if buffer_i_size == 0:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or uint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times its length of type ubyte (or uint8).
    # For floats, we need to generate a valid float and then pack it into
    # the raw bytes in place.
    # Buffers with an unknown or unresolvable type are treated as raw bytes.
    buffer_type = buffer_types.get(i) or 'INT8'
    if buffer_type.startswith('FLOAT'):
      format_code = float_formats.get(buffer_type, 'f')
      item_size = struct.calcsize(format_code)
      # Stop before a trailing partial element so pack_into never writes
      # past the end of the buffer.
      for offset in range(0, buffer_i_size - item_size + 1, item_size):
        value = random.uniform(-0.5, 0.5)  # See http://b/152324470#comment2
        struct.pack_into(format_code, buffer_i_data, offset, value)
    else:
      for j in range(buffer_i_size):
        buffer_i_data[j] = random.randint(0, 255)

194 

195 

def rename_custom_ops(model, map_custom_op_renames):
  """Renames custom ops in place according to the given mapping.

  Operator codes without a custom code, and custom codes that are not keys
  of the mapping, are left unchanged.

  Args:
    model: The input tflite model.
    map_custom_op_renames: A mapping from old to new custom op names.
  """
  for operator_code in model.operatorCodes:
    if not operator_code.customCode:
      continue
    # Custom codes are stored as bytes; the mapping is keyed by str.
    current_name = operator_code.customCode.decode('ascii')
    if current_name not in map_custom_op_renames:
      continue
    replacement = map_custom_op_renames[current_name]
    operator_code.customCode = replacement.encode('ascii')

208 

209 

def opcode_to_name(model, op_code):
  """Resolves an operator code index to its human readable builtin name.

  Args:
    model: The input tflite model.
    op_code: The index into model.operatorCodes to resolve to a readable name.

  Returns:
    A string containing the human readable op name, or None if not resolvable.
  """
  operator_code = model.operatorCodes[op_code]
  # The larger of the current and the deprecated builtin code is looked up.
  resolved_code = max(
      operator_code.builtinCode, operator_code.deprecatedBuiltinCode
  )
  matching_names = (
      attr_name
      for attr_name, attr_value in vars(schema_fb.BuiltinOperator).items()
      if attr_value == resolved_code
  )
  return next(matching_names, None)

226 

227 

def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Only lines that start (after optional non-word characters) with a hex
  literal are parsed; every other line, such as the array declaration or
  the length definition, is ignored.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file does not exist or cannot be opened.
    ValueError: If a token on a hex line is not a valid hexadecimal byte.

  Returns:
    A `bytes` object corresponding to the input cc file array.
  """
  # Match hex values in the string with comma as separator
  pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  model_bytearray = bytearray()

  with open(input_cc_file) as file_handle:
    for line in file_handle:
      values_match = pattern.match(line)

      if values_match is None:
        continue

      # Match in the parentheses (hex array only)
      # e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
      list_text = values_match.group(1)

      # Strip each token before filtering so that whitespace-only leftovers
      # (for example after a trailing ", ") are dropped instead of being
      # handed to int().
      stripped_tokens = (token.strip() for token in list_text.split(','))
      model_bytearray.extend(
          int(token, base=16) for token in stripped_tokens if token
      )

  return bytes(model_bytearray)

265 

266 

def xxd_output_to_object(input_cc_file):
  """Parses the byte array in an xxd-dumped C++ source file into an object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    IOError: If input_cc_file cannot be opened.

  Returns:
    A python object corresponding to the tflite model embedded in the file.
  """
  # Extract the raw model bytes first, then unpack them into an object.
  return convert_bytearray_to_object(xxd_output_to_bytes(input_cc_file))

282 

283 

def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
  """Re-encodes buffer.data chunk by chunk from one byte order to another.

  Args:
    buffer: Object whose `data` attribute holds the bytes to convert; the
      attribute is replaced with a new `bytes` object.
    chunksize: Number of bytes per chunk.
    from_endiness: Byte order ('little' or 'big') the chunks are read in.
    to_endiness: Byte order ('little' or 'big') the chunks are written in.
  """
  raw_data = buffer.data
  converted_chunks = []
  for start in range(0, len(raw_data), chunksize):
    # Read the chunk as an integer in the source byte order and write it
    # back out in the destination byte order.
    chunk_value = int.from_bytes(raw_data[start : start + chunksize],
                                 from_endiness)
    converted_chunks.append(chunk_value.to_bytes(chunksize, to_endiness))
  buffer.data = b''.join(converted_chunks)

298 

299 

def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
  """Byte swaps the buffers field in a TFLite model, in place.

  Each constant buffer referenced by a tensor of a 16-, 32- or 64-bit type
  is swapped exactly once, using the element width of the first such tensor
  that references it. Buffer 0, out-of-range buffer indices, buffers without
  data and tensors of any other type are left untouched.

  Args:
    model: TFLite model object of from_endiness format. None is a no-op.
    from_endiness: The original endianness format of the buffers in model.
    to_endiness: The destined endianness format of the buffers in model.
  """
  if model is None:
    return

  tensor_type = schema_fb.TensorType
  # Width in bytes of the unit that is swapped for each tensor type. The
  # complex types use the width of one of their two components.
  swap_width_by_type = {
      tensor_type.FLOAT16: 2,
      tensor_type.INT16: 2,
      tensor_type.UINT16: 2,
      tensor_type.FLOAT32: 4,
      tensor_type.INT32: 4,
      tensor_type.COMPLEX64: 4,
      tensor_type.UINT32: 4,
      tensor_type.INT64: 8,
      tensor_type.FLOAT64: 8,
      tensor_type.COMPLEX128: 8,
      tensor_type.UINT64: 8,
  }

  # Indices of buffers already swapped; several tensors may share a buffer,
  # and swapping it twice would undo the conversion. A set keeps the
  # membership test O(1) for models with many tensors.
  swapped_buffers = set()
  buffers = model.buffers
  # Absent vectors are represented as None, so fall back to an empty tuple.
  for subgraph in model.subgraphs or ():
    for tensor in subgraph.tensors or ():
      buffer_idx = tensor.buffer
      if not 0 < buffer_idx < len(buffers):
        continue
      if buffer_idx in swapped_buffers:
        continue
      if buffers[buffer_idx].data is None:
        continue
      swap_width = swap_width_by_type.get(tensor.type)
      if swap_width is None:
        # Not a multi-byte type handled here; a later tensor sharing this
        # buffer may still trigger the swap.
        continue
      byte_swap_buffer_content(
          buffers[buffer_idx], swap_width, from_endiness, to_endiness
      )
      swapped_buffers.add(buffer_idx)

352 

353 

def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
  """Returns a re-serialized copy of a model with its buffers byte swapped.

  Args:
    tflite_model: TFLite flatbuffer in a byte array.
    from_endiness: The original endianness format of the buffers in
      tflite_model.
    to_endiness: The destined endianness format of the buffers in tflite_model.

  Returns:
    TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
    format, or None when tflite_model is None.
  """
  if tflite_model is None:
    return None

  # Unpack into an object, swap its constant buffers as per their data
  # types, then serialize the object back into a flatbuffer.
  unpacked_model = convert_bytearray_to_object(tflite_model)
  byte_swap_tflite_model_obj(unpacked_model, from_endiness, to_endiness)
  return convert_object_to_bytearray(unpacked_model)

377 

378 

def count_resource_variables(model):
  """Counts the distinct shared names of VAR_HANDLE operators in a model.

  Args:
    model: the input tflite model, either as bytearray or object.

  Returns:
    An integer number representing the number of unique resource variables.
  """
  if not isinstance(model, schema_fb.ModelT):
    model = convert_bytearray_to_object(model)

  shared_names = set()
  for subgraph in model.subgraphs:
    operators = subgraph.operators
    if operators is None:
      # Subgraph without operators: nothing to count.
      continue
    for operator in operators:
      operator_code = model.operatorCodes[operator.opcodeIndex]
      builtin_code = schema_util.get_builtin_code_from_operator_code(
          operator_code)
      if builtin_code != schema_fb.BuiltinOperator.VAR_HANDLE:
        continue
      shared_names.add(operator.builtinOptions.sharedName)
  return len(shared_names)