Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/interpreter.py: 25%

223 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python TF-Lite interpreter.""" 

16import ctypes 

17import enum 

18import os 

19import platform 

20import sys 

21 

22import numpy as np 

23 

# pylint: disable=g-import-not-at-top
# Select the native wrapper, metrics module and export decorator based on
# where this file lives: when its path (without extension) ends in
# `tflite_runtime/interpreter`, it belongs to the standalone tflite_runtime
# package; otherwise it is part of the tensorflow package.
if not os.path.splitext(__file__)[0].endswith(
    os.path.join('tflite_runtime', 'interpreter')):
  # This file is part of tensorflow package.
  from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tensorflow.lite.python.metrics import metrics
  from tensorflow.python.util.tf_export import tf_export as _tf_export
else:
  # This file is part of tflite_runtime package.
  from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tflite_runtime import metrics_portable as metrics

  def _tf_export(*x, **kwargs):
    # Stand-in for tf_export in tflite_runtime: the export names and options
    # are discarded and an identity decorator is returned, so decorated
    # classes/functions are left unchanged.
    del x, kwargs
    return lambda x: x


# pylint: enable=g-import-not-at-top

42 

43 

class Delegate:
  """Python wrapper class to manage TfLiteDelegate objects.

  The shared library is expected to have two functions,
  tflite_plugin_create_delegate and tflite_plugin_destroy_delegate,
  which should implement the API specified in
  tensorflow/lite/delegates/external/external_delegate_interface.h.
  """

  def __init__(self, library, options=None):
    """Loads delegate from the shared library.

    Args:
      library: Shared library name.
      options: Dictionary of options that are required to load the delegate. All
        keys and values in the dictionary should be serializable. Each key and
        value is converted with str() and UTF-8 encoded before being passed to
        the plugin. Consult the documentation of the specific delegate for
        required and legal options. (default None)

    Raises:
      RuntimeError: This is raised if the Python implementation is not CPython.
      ValueError: If tflite_plugin_create_delegate returns NULL. The message is
        the text the plugin reported through its error callback.
    """
    # __del__ reads these attributes. Initialize them first so that __del__ is
    # safe even when construction fails part-way through (unsupported Python,
    # library load error, plugin failure).
    self._library = None
    self._delegate_ptr = None

    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
    if platform.python_implementation() != 'CPython':
      raise RuntimeError('Delegates are currently only supported in CPython '
                         'due to missing immediate reference counting.')

    self._library = ctypes.pydll.LoadLibrary(library)
    self._library.tflite_plugin_create_delegate.argtypes = [
        ctypes.POINTER(ctypes.c_char_p),
        ctypes.POINTER(ctypes.c_char_p), ctypes.c_int,
        ctypes.CFUNCTYPE(None, ctypes.c_char_p)
    ]
    # The return type is really 'TfLiteDelegate*', but 'void*' is close enough.
    # With a c_void_p restype, ctypes hands back None for a NULL pointer.
    self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p

    # Convert the options from a dictionary to lists of char pointers.
    options = options or {}
    options_keys = (ctypes.c_char_p * len(options))()
    options_values = (ctypes.c_char_p * len(options))()
    for idx, (key, value) in enumerate(options.items()):
      options_keys[idx] = str(key).encode('utf-8')
      options_values[idx] = str(value).encode('utf-8')

    class ErrorMessageCapture:
      """Accumulates the error text the plugin reports via the callback."""

      def __init__(self):
        self.message = ''

      def report(self, x):
        self.message += x if isinstance(x, str) else x.decode('utf-8')

    capture = ErrorMessageCapture()
    # The callback object must stay referenced while native code may call it;
    # the local keeps it alive for the duration of the create call below.
    error_capturer_cb = ctypes.CFUNCTYPE(None, ctypes.c_char_p)(capture.report)
    # Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer.
    self._delegate_ptr = self._library.tflite_plugin_create_delegate(
        options_keys, options_values, len(options), error_capturer_cb)
    if self._delegate_ptr is None:
      raise ValueError(capture.message)

  def __del__(self):
    """Destroys the native delegate, at most once.

    __del__ may run on a partially constructed object (if __init__ raised) and
    may be invoked more than once, so the attributes are read defensively. The
    plugin's destroy function is only called for a delegate that was actually
    created, never with a NULL pointer.
    """
    library = getattr(self, '_library', None)
    delegate_ptr = getattr(self, '_delegate_ptr', None)
    if library is not None and delegate_ptr is not None:
      library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
      library.tflite_plugin_destroy_delegate(delegate_ptr)
    # Clearing the library marks the delegate as destroyed for later calls.
    self._library = None

  def _get_native_delegate_pointer(self):
    """Returns the native TfLiteDelegate pointer.

    It is not safe to copy this pointer because it needs to be freed.

    Returns:
      TfLiteDelegate *
    """
    return self._delegate_ptr

123 

124 

@_tf_export('lite.experimental.load_delegate')
def load_delegate(library, options=None):
  """Returns loaded Delegate object.

  Example usage:

  ```
  import tensorflow as tf

  try:
    delegate = tf.lite.experimental.load_delegate('delegate.so')
  except ValueError:
    delegate = None  # Fall back to CPU.

  if delegate:
    interpreter = tf.lite.Interpreter(
        model_path='model.tflite',
        experimental_delegates=[delegate])
  else:
    interpreter = tf.lite.Interpreter(model_path='model.tflite')
  ```

  This is typically used to leverage EdgeTPU for running TensorFlow Lite models.
  For more information see: https://coral.ai/docs/edgetpu/tflite-python/

  Args:
    library: Name of shared library containing the
      [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
    options: Dictionary of options that are required to load the delegate. All
      keys and values in the dictionary should be convertible to str. Consult
      the documentation of the specific delegate for required and legal options.
      (default None)

  Returns:
    Delegate object.

  Raises:
    ValueError: Delegate failed to load. The original ValueError is chained as
      the cause.
    RuntimeError: If delegate loading is used on unsupported platform.
  """
  try:
    delegate = Delegate(library, options)
  except ValueError as e:
    # Chain the original error so the plugin-level traceback is preserved.
    raise ValueError('Failed to load delegate from {}\n{}'.format(
        library, str(e))) from e
  return delegate

171 

172 

class SignatureRunner:
  """SignatureRunner class for running TFLite models using SignatureDef.

  This class should be instantiated through TFLite Interpreter only using
  get_signature_runner method on Interpreter.
  Example,
  signature = interpreter.get_signature_runner("my_signature")
  result = signature(input_1=my_input_1, input_2=my_input_2)
  print(result["my_output"])
  print(result["my_second_output"])
  All names used are this specific SignatureDef names.

  Notes:
    No other function on this object or on the interpreter provided should be
    called while this object call has not finished.
  """

  def __init__(self, interpreter=None, signature_key=None):
    """Constructor.

    Args:
      interpreter: Interpreter object that is already initialized with the
        requested model.
      signature_key: SignatureDef key to be used.

    Raises:
      ValueError: If `interpreter` or `signature_key` is missing, or if the
        model has no SignatureDef named `signature_key`.
    """
    if not interpreter:
      raise ValueError('None interpreter provided.')
    if not signature_key:
      raise ValueError('None signature_key provided.')
    self._interpreter = interpreter
    self._interpreter_wrapper = interpreter._interpreter  # pylint: disable=protected-access
    self._signature_key = signature_key
    signature_defs = interpreter._get_full_signature_list()  # pylint: disable=protected-access
    if signature_key not in signature_defs:
      raise ValueError('Invalid signature_key provided.')
    self._signature_def = signature_defs[signature_key]
    # Both map a SignatureDef input/output name to the tensor index used with
    # this signature's subgraph.
    self._outputs = self._signature_def['outputs'].items()
    self._inputs = self._signature_def['inputs']

    self._subgraph_index = (
        self._interpreter_wrapper.GetSubgraphIndexFromSignature(
            self._signature_key))

  def __call__(self, **kwargs):
    """Runs the SignatureDef given the provided inputs in arguments.

    Args:
      **kwargs: key,value for inputs to the model. Key is the SignatureDef input
        name. Value is a numpy array (or array-like) with the value.

    Returns:
      dictionary of the results from the model invoke.
      Key in the dictionary is SignatureDef output name.
      Value is the result tensor value returned by the interpreter wrapper.

    Raises:
      ValueError: If the number of inputs or any input name does not match the
        SignatureDef. All names are validated before any tensor is resized.
    """

    if len(kwargs) != len(self._inputs):
      raise ValueError(
          'Invalid number of inputs provided for running a SignatureDef, '
          'expected %s vs provided %s' % (len(self._inputs), len(kwargs)))

    # Validate every name before touching the interpreter so that a bad call
    # cannot leave some input tensors resized and others not.
    for input_name in kwargs:
      if input_name not in self._inputs:
        raise ValueError('Invalid Input name (%s) for SignatureDef' %
                         input_name)

    # Resize input tensors. np.shape() uses `.shape` when the value has one and
    # otherwise falls back to the shape of np.asarray(value).
    for input_name, value in kwargs.items():
      self._interpreter_wrapper.ResizeInputTensor(
          self._inputs[input_name], np.array(np.shape(value), dtype=np.int32),
          False, self._subgraph_index)
    # Allocate tensors.
    self._interpreter_wrapper.AllocateTensors(self._subgraph_index)
    # Set the input values.
    for input_name, value in kwargs.items():
      self._interpreter_wrapper.SetTensor(self._inputs[input_name], value,
                                          self._subgraph_index)

    self._interpreter_wrapper.Invoke(self._subgraph_index)
    return {
        output_name: self._interpreter_wrapper.GetTensor(
            output_index, self._subgraph_index)
        for output_name, output_index in self._outputs
    }

  def get_input_details(self):
    """Gets input tensor details.

    Returns:
      A dictionary from input name to tensor details where each item is a
      dictionary with details about an input tensor. Each dictionary contains
      the following fields that describe the tensor:

      + `name`: The tensor name.
      + `index`: The tensor index in the interpreter.
      + `shape`: The shape of the tensor.
      + `shape_signature`: Same as `shape` for models with known/fixed shapes.
        If any dimension sizes are unknown, they are indicated with `-1`.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization`: Deprecated, use `quantization_parameters`. This field
          only works for per-tensor quantization, whereas
          `quantization_parameters` works in all cases.
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.
      + `sparsity_parameters`: A dictionary of parameters used to encode a
        sparse tensor. This is empty if the tensor is dense.
    """
    return {
        input_name: self._interpreter._get_tensor_details(  # pylint: disable=protected-access
            tensor_index, self._subgraph_index)
        for input_name, tensor_index in self._inputs.items()
    }

  def get_output_details(self):
    """Gets output tensor details.

    Returns:
      A dictionary from output name to tensor details where each item is a
      dictionary with details about an output tensor. The dictionary contains
      the same fields as described for `get_input_details()`.
    """
    return {
        output_name: self._interpreter._get_tensor_details(  # pylint: disable=protected-access
            tensor_index, self._subgraph_index)
        for output_name, tensor_index in self._outputs
    }

301 

302 

@_tf_export('lite.experimental.OpResolverType')
@enum.unique
class OpResolverType(enum.Enum):
  """Different types of op resolvers for TensorFlow Lite.

  * `AUTO`: Indicates the op resolver that is chosen by default in TfLite
     Python, which is the "BUILTIN" as described below.
  * `BUILTIN`: Indicates the op resolver for built-in ops with optimized kernel
    implementation.
  * `BUILTIN_REF`: Indicates the op resolver for built-in ops with reference
    kernel implementation. It's generally used for testing and debugging.
  * `BUILTIN_WITHOUT_DEFAULT_DELEGATES`: Indicates the op resolver for
    built-in ops with optimized kernel implementation, but it will disable
    the application of default TfLite delegates (like the XNNPACK delegate) to
    the model graph. Generally this should not be used unless there are issues
    with the default configuration.

  Note that the member values below are Python-side identifiers only; the
  integer passed to the native wrapper is produced by _get_op_resolver_id.
  """
  # Corresponds to an op resolver chosen by default in TfLite Python.
  AUTO = 0

  # Corresponds to tflite::ops::builtin::BuiltinOpResolver in C++.
  BUILTIN = 1

  # Corresponds to tflite::ops::builtin::BuiltinRefOpResolver in C++.
  BUILTIN_REF = 2

  # Corresponds to
  # tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates in C++.
  BUILTIN_WITHOUT_DEFAULT_DELEGATES = 3

332 

333 

def _get_op_resolver_id(op_resolver_type=OpResolverType.AUTO):
  """Translates an OpResolverType into the integer id used by the wrapper.

  The ids must stay in sync with the op resolver ids defined in
  interpreter_wrapper/interpreter_wrapper.cc.

  Args:
    op_resolver_type: The OpResolverType member to translate.

  Returns:
    The integer resolver id, or None when `op_resolver_type` is not one of the
    known members.
  """
  # AUTO and BUILTIN deliberately share the same native resolver id.
  id_for_type = dict.fromkeys((OpResolverType.AUTO, OpResolverType.BUILTIN), 1)
  id_for_type[OpResolverType.BUILTIN_REF] = 2
  id_for_type[OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES] = 3
  return id_for_type.get(op_resolver_type)

346 

347 

348@_tf_export('lite.Interpreter') 

349class Interpreter: 

350 """Interpreter interface for running TensorFlow Lite models. 

351 

352 Models obtained from `TfLiteConverter` can be run in Python with 

353 `Interpreter`. 

354 

355 As an example, lets generate a simple Keras model and convert it to TFLite 

356 (`TfLiteConverter` also supports other input formats with `from_saved_model` 

357 and `from_concrete_function`) 

358 

359 >>> x = np.array([[1.], [2.]]) 

360 >>> y = np.array([[2.], [4.]]) 

361 >>> model = tf.keras.models.Sequential([ 

362 ... tf.keras.layers.Dropout(0.2), 

363 ... tf.keras.layers.Dense(units=1, input_shape=[1]) 

364 ... ]) 

365 >>> model.compile(optimizer='sgd', loss='mean_squared_error') 

366 >>> model.fit(x, y, epochs=1) 

367 >>> converter = tf.lite.TFLiteConverter.from_keras_model(model) 

368 >>> tflite_model = converter.convert() 

369 

370 `tflite_model` can be saved to a file and loaded later, or directly into the 

371 `Interpreter`. Since TensorFlow Lite pre-plans tensor allocations to optimize 

372 inference, the user needs to call `allocate_tensors()` before any inference. 

373 

374 >>> interpreter = tf.lite.Interpreter(model_content=tflite_model) 

375 >>> interpreter.allocate_tensors() # Needed before execution! 

376 

377 Sample execution: 

378 

379 >>> output = interpreter.get_output_details()[0] # Model has single output. 

380 >>> input = interpreter.get_input_details()[0] # Model has single input. 

381 >>> input_data = tf.constant(1., shape=[1, 1]) 

382 >>> interpreter.set_tensor(input['index'], input_data) 

383 >>> interpreter.invoke() 

384 >>> interpreter.get_tensor(output['index']).shape 

385 (1, 1) 

386 

387 Use `get_signature_runner()` for a more user-friendly inference API. 

388 """ 

389 

390 def __init__( 

391 self, 

392 model_path=None, 

393 model_content=None, 

394 experimental_delegates=None, 

395 num_threads=None, 

396 experimental_op_resolver_type=OpResolverType.AUTO, 

397 experimental_preserve_all_tensors=False, 

398 experimental_disable_delegate_clustering=False, 

399 ): 

400 """Constructor. 

401 

402 Args: 

403 model_path: Path to TF-Lite Flatbuffer file. 

404 model_content: Content of model. 

405 experimental_delegates: Experimental. Subject to change. List of 

406 [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) 

407 objects returned by lite.load_delegate(). 

408 num_threads: Sets the number of threads used by the interpreter and 

409 available to CPU kernels. If not set, the interpreter will use an 

410 implementation-dependent default number of threads. Currently, only a 

411 subset of kernels, such as conv, support multi-threading. num_threads 

412 should be >= -1. Setting num_threads to 0 has the effect to disable 

413 multithreading, which is equivalent to setting num_threads to 1. If set 

414 to the value -1, the number of threads used will be 

415 implementation-defined and platform-dependent. 

416 experimental_op_resolver_type: The op resolver used by the interpreter. It 

417 must be an instance of OpResolverType. By default, we use the built-in 

418 op resolver which corresponds to tflite::ops::builtin::BuiltinOpResolver 

419 in C++. 

420 experimental_preserve_all_tensors: If true, then intermediate tensors used 

421 during computation are preserved for inspection, and if the passed op 

422 resolver type is AUTO or BUILTIN, the type will be changed to 

423 BUILTIN_WITHOUT_DEFAULT_DELEGATES so that no Tensorflow Lite default 

424 delegates are applied. If false, getting intermediate tensors could 

425 result in undefined values or None, especially when the graph is 

426 successfully modified by the Tensorflow Lite default delegate. 

427 experimental_disable_delegate_clustering: If true, don't perform delegate 

428 clustering during delegate graph partitioning phase. Disabling delegate 

429 clustering will make the execution order of ops respect the 

430 explicitly-inserted control dependencies in the graph (inserted via 

431 `with tf.control_dependencies()`) since the TF Lite converter will drop 

432 control dependencies by default. Most users shouldn't turn this flag to 

433 True if they don't insert explicit control dependencies or the graph 

434 execution order is expected. For automatically inserted control 

435 dependencies (with `tf.Variable`, `tf.Print` etc), the user doesn't need 

436 to turn this flag to True since they are respected by default. Note that 

437 this flag is currently experimental, and it might be removed/updated if 

438 the TF Lite converter doesn't drop such control dependencies in the 

439 model. Default is False. 

440 

441 Raises: 

442 ValueError: If the interpreter was unable to create. 

443 """ 

444 if not hasattr(self, '_custom_op_registerers'): 

445 self._custom_op_registerers = [] 

446 

447 actual_resolver_type = experimental_op_resolver_type 

448 if experimental_preserve_all_tensors and ( 

449 experimental_op_resolver_type == OpResolverType.AUTO or 

450 experimental_op_resolver_type == OpResolverType.BUILTIN): 

451 actual_resolver_type = OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES 

452 op_resolver_id = _get_op_resolver_id(actual_resolver_type) 

453 if op_resolver_id is None: 

454 raise ValueError('Unrecognized passed in op resolver type: {}'.format( 

455 experimental_op_resolver_type)) 

456 

457 if model_path and not model_content: 

458 custom_op_registerers_by_name = [ 

459 x for x in self._custom_op_registerers if isinstance(x, str) 

460 ] 

461 custom_op_registerers_by_func = [ 

462 x for x in self._custom_op_registerers if not isinstance(x, str) 

463 ] 

464 self._interpreter = _interpreter_wrapper.CreateWrapperFromFile( 

465 model_path, 

466 op_resolver_id, 

467 custom_op_registerers_by_name, 

468 custom_op_registerers_by_func, 

469 experimental_preserve_all_tensors, 

470 experimental_disable_delegate_clustering, 

471 ) 

472 if not self._interpreter: 

473 raise ValueError('Failed to open {}'.format(model_path)) 

474 elif model_content and not model_path: 

475 custom_op_registerers_by_name = [ 

476 x for x in self._custom_op_registerers if isinstance(x, str) 

477 ] 

478 custom_op_registerers_by_func = [ 

479 x for x in self._custom_op_registerers if not isinstance(x, str) 

480 ] 

481 # Take a reference, so the pointer remains valid. 

482 # Since python strings are immutable then PyString_XX functions 

483 # will always return the same pointer. 

484 self._model_content = model_content 

485 self._interpreter = _interpreter_wrapper.CreateWrapperFromBuffer( 

486 model_content, 

487 op_resolver_id, 

488 custom_op_registerers_by_name, 

489 custom_op_registerers_by_func, 

490 experimental_preserve_all_tensors, 

491 experimental_disable_delegate_clustering, 

492 ) 

493 elif not model_content and not model_path: 

494 raise ValueError('`model_path` or `model_content` must be specified.') 

495 else: 

496 raise ValueError('Can\'t both provide `model_path` and `model_content`') 

497 

498 if num_threads is not None: 

499 if not isinstance(num_threads, int): 

500 raise ValueError('type of num_threads should be int') 

501 if num_threads < 1: 

502 raise ValueError('num_threads should >= 1') 

503 self._interpreter.SetNumThreads(num_threads) 

504 

505 # Each delegate is a wrapper that owns the delegates that have been loaded 

506 # as plugins. The interpreter wrapper will be using them, but we need to 

507 # hold them in a list so that the lifetime is preserved at least as long as 

508 # the interpreter wrapper. 

509 self._delegates = [] 

510 if experimental_delegates: 

511 self._delegates = experimental_delegates 

512 for delegate in self._delegates: 

513 self._interpreter.ModifyGraphWithDelegate( 

514 delegate._get_native_delegate_pointer()) # pylint: disable=protected-access 

515 self._signature_defs = self.get_signature_list() 

516 

517 self._metrics = metrics.TFLiteMetrics() 

518 self._metrics.increase_counter_interpreter_creation() 

519 

  def __del__(self):
    """Releases the native interpreter before the delegates it uses."""
    # Must make sure the interpreter is destroyed before things that
    # are used by it like the delegates. NOTE this only works on CPython
    # probably.
    # On CPython, rebinding the attribute to None releases this object's
    # reference to the wrapper right away. The wrapper is freed then unless
    # something else (such as a numpy array based on it, see _safe_to_run)
    # still holds it. Only after that are the delegate references dropped.
    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes().
    self._interpreter = None
    self._delegates = None

528 

529 def allocate_tensors(self): 

530 self._ensure_safe() 

531 return self._interpreter.AllocateTensors() 

532 

  def _safe_to_run(self):
    """Returns true if no numpy arrays reference the interpreter's buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this environment is the only _interpreter, then the ref count should be
    # 2 (1 in self and 1 in temporary of sys.getrefcount).
    # sys.getrefcount is CPython-specific. Do not bind self._interpreter to a
    # local before this call: the extra reference would raise the count to 3
    # and the check would always report "unsafe".
    return sys.getrefcount(self._interpreter) == 2

544 

545 def _ensure_safe(self): 

546 """Makes sure no numpy arrays pointing to internal buffers are active. 

547 

548 This should be called from any function that will call a function on 

549 _interpreter that may reallocate memory e.g. invoke(), ... 

550 

551 Raises: 

552 RuntimeError: If there exist numpy objects pointing to internal memory 

553 then we throw. 

554 """ 

555 if not self._safe_to_run(): 

556 raise RuntimeError("""There is at least 1 reference to internal data 

557 in the interpreter in the form of a numpy array or slice. Be sure to 

558 only hold the function returned from tensor() if you are using raw 

559 data access.""") 

560 

561 # Experimental and subject to change 

562 def _get_op_details(self, op_index): 

563 """Gets a dictionary with arrays of ids for tensors involved with an op. 

564 

565 Args: 

566 op_index: Operation/node index of node to query. 

567 

568 Returns: 

569 a dictionary containing the index, op name, and arrays with lists of the 

570 indices for the inputs and outputs of the op/node. 

571 """ 

572 op_index = int(op_index) 

573 op_name = self._interpreter.NodeName(op_index) 

574 op_inputs = self._interpreter.NodeInputs(op_index) 

575 op_outputs = self._interpreter.NodeOutputs(op_index) 

576 

577 details = { 

578 'index': op_index, 

579 'op_name': op_name, 

580 'inputs': op_inputs, 

581 'outputs': op_outputs, 

582 } 

583 

584 return details 

585 

586 def _get_tensor_details(self, tensor_index, subgraph_index): 

587 """Gets tensor details. 

588 

589 Args: 

590 tensor_index: Tensor index of tensor to query. 

591 subgraph_index: Index of the subgraph. 

592 

593 Returns: 

594 A dictionary containing the following fields of the tensor: 

595 'name': The tensor name. 

596 'index': The tensor index in the interpreter. 

597 'shape': The shape of the tensor. 

598 'quantization': Deprecated, use 'quantization_parameters'. This field 

599 only works for per-tensor quantization, whereas 

600 'quantization_parameters' works in all cases. 

601 'quantization_parameters': The parameters used to quantize the tensor: 

602 'scales': List of scales (one if per-tensor quantization) 

603 'zero_points': List of zero_points (one if per-tensor quantization) 

604 'quantized_dimension': Specifies the dimension of per-axis 

605 quantization, in the case of multiple scales/zero_points. 

606 

607 Raises: 

608 ValueError: If tensor_index is invalid. 

609 """ 

610 tensor_index = int(tensor_index) 

611 subgraph_index = int(subgraph_index) 

612 tensor_name = self._interpreter.TensorName(tensor_index, subgraph_index) 

613 tensor_size = self._interpreter.TensorSize(tensor_index, subgraph_index) 

614 tensor_size_signature = self._interpreter.TensorSizeSignature( 

615 tensor_index, subgraph_index) 

616 tensor_type = self._interpreter.TensorType(tensor_index, subgraph_index) 

617 tensor_quantization = self._interpreter.TensorQuantization( 

618 tensor_index, subgraph_index) 

619 tensor_quantization_params = self._interpreter.TensorQuantizationParameters( 

620 tensor_index, subgraph_index) 

621 tensor_sparsity_params = self._interpreter.TensorSparsityParameters( 

622 tensor_index, subgraph_index) 

623 

624 if not tensor_type: 

625 raise ValueError('Could not get tensor details') 

626 

627 details = { 

628 'name': tensor_name, 

629 'index': tensor_index, 

630 'shape': tensor_size, 

631 'shape_signature': tensor_size_signature, 

632 'dtype': tensor_type, 

633 'quantization': tensor_quantization, 

634 'quantization_parameters': { 

635 'scales': tensor_quantization_params[0], 

636 'zero_points': tensor_quantization_params[1], 

637 'quantized_dimension': tensor_quantization_params[2], 

638 }, 

639 'sparsity_parameters': tensor_sparsity_params 

640 } 

641 

642 return details 

643 

644 # Experimental and subject to change 

645 def _get_ops_details(self): 

646 """Gets op details for every node. 

647 

648 Returns: 

649 A list of dictionaries containing arrays with lists of tensor ids for 

650 tensors involved in the op. 

651 """ 

652 return [ 

653 self._get_op_details(idx) for idx in range(self._interpreter.NumNodes()) 

654 ] 

655 

656 def get_tensor_details(self): 

657 """Gets tensor details for every tensor with valid tensor details. 

658 

659 Tensors where required information about the tensor is not found are not 

660 added to the list. This includes temporary tensors without a name. 

661 

662 Returns: 

663 A list of dictionaries containing tensor information. 

664 """ 

665 tensor_details = [] 

666 for idx in range(self._interpreter.NumTensors(0)): 

667 try: 

668 tensor_details.append(self._get_tensor_details(idx, subgraph_index=0)) 

669 except ValueError: 

670 pass 

671 return tensor_details 

672 

673 def get_input_details(self): 

674 """Gets model input tensor details. 

675 

676 Returns: 

677 A list in which each item is a dictionary with details about 

678 an input tensor. Each dictionary contains the following fields 

679 that describe the tensor: 

680 

681 + `name`: The tensor name. 

682 + `index`: The tensor index in the interpreter. 

683 + `shape`: The shape of the tensor. 

684 + `shape_signature`: Same as `shape` for models with known/fixed shapes. 

685 If any dimension sizes are unknown, they are indicated with `-1`. 

686 + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`). 

687 + `quantization`: Deprecated, use `quantization_parameters`. This field 

688 only works for per-tensor quantization, whereas 

689 `quantization_parameters` works in all cases. 

690 + `quantization_parameters`: A dictionary of parameters used to quantize 

691 the tensor: 

692 ~ `scales`: List of scales (one if per-tensor quantization). 

693 ~ `zero_points`: List of zero_points (one if per-tensor quantization). 

694 ~ `quantized_dimension`: Specifies the dimension of per-axis 

695 quantization, in the case of multiple scales/zero_points. 

696 + `sparsity_parameters`: A dictionary of parameters used to encode a 

697 sparse tensor. This is empty if the tensor is dense. 

698 """ 

699 return [ 

700 self._get_tensor_details(i, subgraph_index=0) 

701 for i in self._interpreter.InputIndices() 

702 ] 

703 

704 def set_tensor(self, tensor_index, value): 

705 """Sets the value of the input tensor. 

706 

707 Note this copies data in `value`. 

708 

709 If you want to avoid copying, you can use the `tensor()` function to get a 

710 numpy buffer pointing to the input buffer in the tflite interpreter. 

711 

712 Args: 

713 tensor_index: Tensor index of tensor to set. This value can be gotten from 

714 the 'index' field in get_input_details. 

715 value: Value of tensor to set. 

716 

717 Raises: 

718 ValueError: If the interpreter could not set the tensor. 

719 """ 

720 self._interpreter.SetTensor(tensor_index, value) 

721 

722 def resize_tensor_input(self, input_index, tensor_size, strict=False): 

723 """Resizes an input tensor. 

724 

725 Args: 

726 input_index: Tensor index of input to set. This value can be gotten from 

727 the 'index' field in get_input_details. 

728 tensor_size: The tensor_shape to resize the input to. 

729 strict: Only unknown dimensions can be resized when `strict` is True. 

730 Unknown dimensions are indicated as `-1` in the `shape_signature` 

731 attribute of a given tensor. (default False) 

732 

733 Raises: 

734 ValueError: If the interpreter could not resize the input tensor. 

735 

736 Usage: 

737 ``` 

738 interpreter = Interpreter(model_content=tflite_model) 

739 interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3]) 

740 interpreter.allocate_tensors() 

741 interpreter.set_tensor(0, test_images) 

742 interpreter.invoke() 

743 ``` 

744 """ 

745 self._ensure_safe() 

746 # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size 

747 # parameter. 

748 tensor_size = np.array(tensor_size, dtype=np.int32) 

749 self._interpreter.ResizeInputTensor(input_index, tensor_size, strict) 

750 

751 def get_output_details(self): 

752 """Gets model output tensor details. 

753 

754 Returns: 

755 A list in which each item is a dictionary with details about 

756 an output tensor. The dictionary contains the same fields as 

757 described for `get_input_details()`. 

758 """ 

759 return [ 

760 self._get_tensor_details(i, subgraph_index=0) 

761 for i in self._interpreter.OutputIndices() 

762 ] 

763 

def get_signature_list(self):
  """Gets the SignatureDefs in the model, keeping only tensor names.

  Example,
  ```
  signatures = interpreter.get_signature_list()
  print(signatures)

  # {
  #   'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']}
  # }
  ```

  Using the names in the signature list you can then get a callable from
  get_signature_runner().

  Returns:
    A dictionary keyed on the SignatureDef method name. Each value is a
    dictionary holding the list of input names under 'inputs' and the list of
    output names under 'outputs'. Any other keys present in the native
    wrapper's SignatureDef entries are carried over unchanged.
  """
  full_signature_defs = self._interpreter.GetSignatureDefs()
  # Build new dictionaries instead of overwriting entries of the structure
  # returned by the native wrapper: only the names (the keys of the
  # 'inputs'/'outputs' mappings) are exposed here, and the wrapper's result is
  # left untouched.
  return {
      signature_key: {
          **signature_def,
          'inputs': list(signature_def['inputs']),
          'outputs': list(signature_def['outputs']),
      } for signature_key, signature_def in full_signature_defs.items()
  }

790 

791 def _get_full_signature_list(self): 

792 """Gets list of SignatureDefs in the model. 

793 

794 Example, 

795 ``` 

796 signatures = interpreter._get_full_signature_list() 

797 print(signatures) 

798 

799 # { 

800 # 'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}} 

801 # } 

802 

803 Then using the names in the signature list you can get a callable from 

804 get_signature_runner(). 

805 ``` 

806 

807 Returns: 

808 A list of SignatureDef details in a dictionary structure. 

809 It is keyed on the SignatureDef method name, and the value holds 

810 dictionary of inputs and outputs. 

811 """ 

812 return self._interpreter.GetSignatureDefs() 

813 

def get_signature_runner(self, signature_key=None):
  """Gets a callable for inference of a specific SignatureDef.

  Example usage,
  ```
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  fn = interpreter.get_signature_runner('div_with_remainder')
  output = fn(x=np.array([3]), y=np.array([2]))
  print(output)
  # {
  #   'quotient': array([1.], dtype=float32)
  #   'remainder': array([1.], dtype=float32)
  # }
  ```

  None can be passed for signature_key only if the model has exactly one
  SignatureDef.

  All names used are the names defined by this specific SignatureDef.

  Args:
    signature_key: Signature key for the SignatureDef. It can be None if and
      only if the model has a single SignatureDef. Default value is None.

  Returns:
    A `SignatureRunner` that can be called to run inference for the
    SignatureDef selected by `signature_key`. The callable takes keyword
    arguments corresponding to the SignatureDef's inputs, whose values should
    be numpy arrays, and returns a dictionary mapping output names to numpy
    values of the computed results.

  Raises:
    ValueError: If `signature_key` is None while the model does not have
      exactly one SignatureDef, or if the passed `signature_key` is invalid.
  """
  if signature_key is None:
    signature_count = len(self._signature_defs)
    if signature_count != 1:
      raise ValueError(
          'SignatureDef signature_key is None and model has {0} Signatures. '
          'None is only allowed when the model has 1 SignatureDef'.format(
              signature_count))
    # Exactly one SignatureDef exists, so its key is the only possible default.
    (signature_key,) = self._signature_defs
  return SignatureRunner(interpreter=self, signature_key=signature_key)

860 

def get_tensor(self, tensor_index, subgraph_index=0):
  """Gets a copy of the value of the output tensor.

  If you wish to avoid the copy, use `tensor()` instead. This function cannot
  be used to read intermediate results.

  Args:
    tensor_index: Tensor index of the tensor to get. This value can be taken
      from the 'index' field in get_output_details.
    subgraph_index: Index of the subgraph to fetch the tensor from. Defaults
      to 0, which means the primary subgraph.

  Returns:
    A numpy array.
  """
  tensor_copy = self._interpreter.GetTensor(tensor_index, subgraph_index)
  return tensor_copy

877 

def tensor(self, tensor_index):
  """Returns a function that gives a numpy view of the current tensor buffer.

  This allows reading and writing to these tensors without copies. This more
  closely mirrors the C++ Interpreter class interface's tensor() member, hence
  the name. Be careful not to hold these output references through calls to
  `allocate_tensors()` and `invoke()`. This function cannot be used to read
  intermediate results.

  Usage:

  ```
  interpreter.allocate_tensors()
  input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
  output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
  for i in range(10):
    input().fill(3.)
    interpreter.invoke()
    print("inference %s" % output())
  ```

  Notice how this function avoids making a numpy array directly. This is
  because it is important not to hold actual numpy views to the data longer
  than necessary. If you do, then the interpreter can no longer be invoked,
  because it is possible the interpreter would resize and invalidate the
  referenced tensors. The NumPy API doesn't allow any mutability of the
  underlying buffers.

  WRONG:

  ```
  input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
  output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
  interpreter.allocate_tensors()  # This will throw RuntimeError
  for i in range(10):
    input.fill(3.)
    interpreter.invoke()  # This will throw RuntimeError since input and
                          # output still hold numpy views of the tensors.
  ```

  Args:
    tensor_index: Tensor index of the tensor to get. This value can be taken
      from the 'index' field in get_output_details.

  Returns:
    A function that can return a new numpy array pointing to the internal
    TFLite tensor state at any point. It is safe to hold the function forever,
    but it is not safe to hold the numpy array forever.
  """

  def _current_view():
    # Looked up on every call, so each call builds a fresh view rather than
    # reusing one created earlier.
    return self._interpreter.tensor(self._interpreter, tensor_index)

  return _current_view

927 

def invoke(self):
  """Runs inference with the interpreter.

  Set the input sizes, allocate tensors and fill in the input values before
  calling this. This function releases the GIL, so heavy computation can be
  done in the background while the Python interpreter continues. However, no
  other method of this object should be called until invoke() has returned.

  Raises:
    ValueError: If the underlying interpreter fails.
  """
  self._ensure_safe()
  native_interpreter = self._interpreter
  native_interpreter.Invoke()

942 

def reset_all_variables(self):
  """Calls the native wrapper's `ResetVariableTensors()`.

  Judging by its name, the wrapper call resets the model's variable tensors
  (presumably to their initial values; confirm against the wrapper).

  Returns:
    Whatever the underlying wrapper's `ResetVariableTensors()` returns.
  """
  reset_variables = self._interpreter.ResetVariableTensors
  return reset_variables()

945 

946 # Experimental and subject to change. 

947 def _native_handle(self): 

948 """Returns a pointer to the underlying tflite::Interpreter instance. 

949 

950 This allows extending tflite.Interpreter's functionality in a custom C++ 

951 function. Consider how that may work in a custom pybind wrapper: 

952 

953 m.def("SomeNewFeature", ([](py::object handle) { 

954 auto* interpreter = 

955 reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>()); 

956 ... 

957 })) 

958 

959 and corresponding Python call: 

960 

961 SomeNewFeature(interpreter.native_handle()) 

962 

963 Note: This approach is fragile. Users must guarantee the C++ extension build 

964 is consistent with the tflite.Interpreter's underlying C++ build. 

965 """ 

966 return self._interpreter.interpreter() 

967 

968 

class InterpreterWithCustomOps(Interpreter):
  """Interpreter interface for TensorFlow Lite Models that accepts custom ops.

  The interface provided by this class is experimental and therefore not
  exposed as part of the public API.

  Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
  by providing the names of functions (or the functions themselves) that take
  a pointer to an op resolver and add a custom op; see `__init__` for the
  exact form expected.
  """

  def __init__(self, custom_op_registerers=None, **kwargs):
    """Constructor.

    Args:
      custom_op_registerers: List of str (symbol names) or functions that take
        a pointer to a MutableOpResolver and register a custom op. When
        passing functions, use a pybind function that takes a uintptr_t that
        can be recast as a pointer to a MutableOpResolver. Any falsy value
        (e.g. None or an empty list) is stored as an empty list.
      **kwargs: Additional arguments passed to Interpreter.

    Raises:
      ValueError: If the interpreter could not be created.
    """
    # Assigned before the base constructor runs; presumably
    # `Interpreter.__init__` reads this attribute while creating the native
    # interpreter (TODO confirm), so the ordering here matters.
    self._custom_op_registerers = (
        custom_op_registerers if custom_op_registerers else [])
    super().__init__(**kwargs)