Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py: 17%

650 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""Exposes the Python wrapper conversion to trt_graph.""" 

16 

17import collections 

18from functools import partial # pylint: disable=g-importing-member 

19import os 

20import platform 

21import sys 

22import tempfile 

23 

24import numpy as np 

25import six as _six 

26 

27from tensorflow.core.framework import variable_pb2 

28from tensorflow.core.protobuf import config_pb2 

29from tensorflow.core.protobuf import meta_graph_pb2 

30from tensorflow.core.protobuf import rewriter_config_pb2 

31from tensorflow.python.client import session 

32from tensorflow.python.compiler.tensorrt import utils as trt_utils 

33from tensorflow.python.eager import context 

34from tensorflow.python.eager import wrap_function 

35from tensorflow.python.framework import convert_to_constants 

36from tensorflow.python.framework import dtypes 

37from tensorflow.python.framework import errors 

38from tensorflow.python.framework import importer 

39from tensorflow.python.framework import ops 

40from tensorflow.python.grappler import tf_optimizer 

41from tensorflow.python.ops import array_ops 

42from tensorflow.python.ops import gen_resource_variable_ops 

43from tensorflow.python.platform import tf_logging as logging 

44from tensorflow.python.saved_model import builder 

45from tensorflow.python.saved_model import load 

46from tensorflow.python.saved_model import loader 

47from tensorflow.python.saved_model import save 

48from tensorflow.python.saved_model import signature_constants 

49from tensorflow.python.saved_model import tag_constants 

50from tensorflow.python.trackable import asset 

51from tensorflow.python.trackable import autotrackable 

52from tensorflow.python.trackable import resource 

53from tensorflow.python.training import saver 

54from tensorflow.python.util import deprecation 

55from tensorflow.python.util import nest 

56from tensorflow.python.util.lazy_loader import LazyLoader 

57from tensorflow.python.util.tf_export import tf_export 

58 

59# Lazily load the op, since it's not available in cpu-only builds. Importing

60# it at the top will cause tests that import TF-TRT to fail when they are

61# built and run without CUDA/GPU.

62gen_trt_ops = LazyLoader( 

63 "gen_trt_ops", globals(), 

64 "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops") 

65 

66_pywrap_py_utils = LazyLoader( 

67 "_pywrap_py_utils", globals(), 

68 "tensorflow.compiler.tf2tensorrt._pywrap_py_utils") 

69 

70# Register TRT ops in Python, so that when users import this module they can

71# execute a TRT-converted graph without calling any of the methods in this 

72# module. 

73# 

74# This will call register_op_list() in 

75# tensorflow/python/framework/op_def_registry.py, but it doesn't register 

76# the op or the op kernel in the C++ runtime.

77try: 

78 gen_trt_ops.trt_engine_op # pylint: disable=pointless-statement 

79except AttributeError: 

80 pass 

81 

82 

83def _to_bytes(s): 

84 """Encode s if it is a sequence of chars.""" 

85 if isinstance(s, _six.text_type): 

86 return s.encode("utf-8", errors="surrogateescape") 

87 return s 

88 

89 

90def _to_string(s): 

91 """Decode s if it is a sequence of bytes.""" 

92 if isinstance(s, _six.binary_type): 

93 return s.decode("utf-8") 

94 return s 

95 

96 

97class TrtPrecisionMode(object): 

98 FP32 = "FP32" 

99 FP16 = "FP16" 

100 INT8 = "INT8" 

101 

102 @staticmethod 

103 def supported_precision_modes(): 

104 precisions = [ 

105 TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8 

106 ] 

107 return precisions + [p.lower() for p in precisions] 

108 

109 

110# Use a large enough number as the default max_workspace_size for TRT engines,

111# so that the default yields reasonable performance.

112# For TRT >= 8.4, the recommendation is MAX_INT. 

113if (_pywrap_py_utils.is_tensorrt_enabled() and 

114 trt_utils.is_loaded_tensorrt_version_greater_equal(8, 4, 0)): 

115 # We must use `sys.maxsize - 512` to avoid overflow during casting. 

116 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = sys.maxsize - 512 

117else: 

118 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 # 1,073,741,824 

119 

120PROFILE_STRATEGY_RANGE = "Range" 

121PROFILE_STRATEGY_OPTIMAL = "Optimal" 

122PROFILE_STRATEGY_RANGE_OPTIMAL = "Range+Optimal" 

123PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE = "ImplicitBatchModeCompatible" 

124 

125 

126def supported_profile_strategies(): 

127 return [ 

128 PROFILE_STRATEGY_RANGE, PROFILE_STRATEGY_OPTIMAL, 

129 PROFILE_STRATEGY_RANGE_OPTIMAL, 

130 PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE 

131 ] 

132 

133 

134@tf_export("experimental.tensorrt.ConversionParams", v1=[]) 

135class TrtConversionParams( 

136 collections.namedtuple("TrtConversionParams", [ 

137 "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", 

138 "maximum_cached_engines", "use_calibration", "allow_build_at_runtime" 

139 ])): 

140 """Parameters that are used for TF-TRT conversion. 

141 

142 Fields: 

143 max_workspace_size_bytes: the maximum GPU temporary memory that the TRT 

144 engine can use at execution time. This corresponds to the 

145 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 

146 precision_mode: one of the strings in 

147 TrtPrecisionMode.supported_precision_modes(). 

148 minimum_segment_size: the minimum number of nodes required for a subgraph 

149 to be replaced by TRTEngineOp. 

150 maximum_cached_engines: max number of cached TRT engines for dynamic TRT 

151 ops. Created TRT engines for a dynamic dimension are cached. If the 

152 number of cached engines is already at max but none of them supports the 

153 input shapes, the TRTEngineOp will fall back to run the original TF 

154 subgraph that corresponds to the TRTEngineOp. 

155 use_calibration: this argument is ignored if precision_mode is not INT8. 

156 If set to True, a calibration graph will be created to calibrate the 

157 missing ranges. The calibration graph must be converted to an inference 

158 graph by running calibration with calibrate(). If set to False, 

159 quantization nodes will be expected for every tensor in the graph 

160 (excluding those which will be fused). If a range is missing, an error 

161 will occur. Please note that accuracy may be negatively affected if 

162 there is a mismatch between which tensors TRT quantizes and which 

163 tensors were trained with fake quantization. 

164 allow_build_at_runtime: whether to allow building TensorRT engines at 

165 runtime. If no prebuilt TensorRT engine that can handle the given inputs 

166 is found at runtime, a new TensorRT engine is built if 

167 allow_build_at_runtime=True; otherwise native TF is used. 
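
A minimal usage sketch (the values shown are illustrative only, not
recommendations):

```python
params = tf.experimental.tensorrt.ConversionParams(
    precision_mode="FP16",
    maximum_cached_engines=16)
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="my_dir", conversion_params=params)
```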

168 """ 

169 

170 def __new__(cls, 

171 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 

172 precision_mode=TrtPrecisionMode.FP32, 

173 minimum_segment_size=3, 

174 maximum_cached_engines=1, 

175 use_calibration=True, 

176 allow_build_at_runtime=True): 

177 return super(TrtConversionParams, 

178 cls).__new__(cls, max_workspace_size_bytes, precision_mode, 

179 minimum_segment_size, maximum_cached_engines, 

180 use_calibration, allow_build_at_runtime) 

181 

182 

183DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams() 

184 

185_TRT_ENGINE_OP_NAME = "TRTEngineOp" 

186 

187 

188def _check_conversion_params(conversion_params, is_v2=False): 

189 """Validate the provided TrtConversionParams. 

190 

191 Args: 

192 conversion_params: a TrtConversionParams instance. 

193 is_v2: whether we're getting a RewriterConfig for TF 2.0. 

194 

195 Raises: 

196 TypeError: if any of the parameters are of unexpected type. 

197 ValueError: if any of the parameters are of unexpected value. 

198 """ 

199 supported_precision_modes = TrtPrecisionMode.supported_precision_modes() 

200 if conversion_params.precision_mode not in supported_precision_modes: 

201 raise ValueError( 

202 ("precision mode '{}' is not supported." 

203 "It should be one of {}").format(conversion_params.precision_mode, 

204 supported_precision_modes)) 

205 if (conversion_params.minimum_segment_size <= 0 and 

206 conversion_params.minimum_segment_size != -1): 

207 raise ValueError("minimum segment size should be positive or -1 " 

208 "(to disable main graph conversion).") 

209 

210 

211def _check_trt_version_compatibility(): 

212 """Check compatibility of TensorRT version. 

213 

214 Raises: 

215 RuntimeError: if the TensorRT library version is incompatible. 

216 """ 

217 

218 if not _pywrap_py_utils.is_tensorrt_enabled(): 

219 logging.error( 

220 "Tensorflow needs to be built with TensorRT support enabled to allow " 

221 "TF-TRT to operate.") 

222 

223 raise RuntimeError("Tensorflow has not been built with TensorRT support.") 

224 

225 if platform.system() == "Windows": 

226 logging.warn( 

227 "Windows support is provided experimentally. No guarantee is made " 

228 "regarding functionality or engineering support. Use at your own risk.") 

229 

230 linked_version = _pywrap_py_utils.get_linked_tensorrt_version() 

231 loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version() 

232 

233 logging.info("Linked TensorRT version: %s", str(linked_version)) 

234 logging.info("Loaded TensorRT version: %s", str(loaded_version)) 

235 

236 def raise_trt_version_deprecated(version_type, trt_version): 

237 assert version_type in [ 

238 "linked", "loaded" 

239 ], ("Incorrect value received for version_type: %s. Accepted: ['linked', " 

240 "'loaded']") % version_type 

241 

242 logging.error( 

243 "The {version_type} version of TensorRT: `{trt_version}` has now " 

244 "been removed. Please upgrade to TensorRT 7 or more recent.".format( 

245 version_type=version_type, 

246 trt_version=trt_utils.version_tuple_to_string(trt_version))) 

247 

248 raise RuntimeError("Incompatible %s TensorRT versions" % version_type) 

249 

250 if not trt_utils.is_linked_tensorrt_version_greater_equal(7, 0, 0): 

251 raise_trt_version_deprecated("linked", linked_version) 

252 

253 if not trt_utils.is_loaded_tensorrt_version_greater_equal(7, 0, 0): 

254 raise_trt_version_deprecated("loaded", loaded_version) 

255 

256 if (loaded_version[0] != linked_version[0] or 

257 not trt_utils.is_loaded_tensorrt_version_greater_equal(*linked_version)): 

258 logging.error( 

259 "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. A few " 

260 "requirements must be met:\n" 

261 "\t-It is required to use the same major version of TensorRT during " 

262 "compilation and runtime.\n" 

263 "\t-TensorRT does not support forward compatibility. The loaded " 

264 "version has to be equal or more recent than the linked version.", 

265 trt_utils.version_tuple_to_string(loaded_version), 

266 trt_utils.version_tuple_to_string(linked_version)) 

267 raise RuntimeError("Incompatible TensorRT major version") 

268 

269 elif loaded_version != linked_version: 

270 logging.info( 

271 "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. This is " 

272 "supported because TensorRT minor/patch upgrades are backward " 

273 "compatible.", trt_utils.version_tuple_to_string(loaded_version), 

274 trt_utils.version_tuple_to_string(linked_version)) 

275 

276 

277def _get_tensorrt_rewriter_config(conversion_params, 

278 is_dynamic_op=None, 

279 max_batch_size=None, 

280 is_v2=False, 

281 disable_non_trt_optimizers=False, 

282 use_implicit_batch=True, 

283 profile_strategy=PROFILE_STRATEGY_RANGE): 

284 """Returns a RewriterConfig proto for TRT transformation. 

285 

286 Args: 

287 conversion_params: a TrtConversionParams instance. 

288 is_dynamic_op: whether to use dynamic engines. 

289 max_batch_size: maximum batch size for static engines. 

290 is_v2: whether we're getting a RewriterConfig for TF 2.0. 

291 disable_non_trt_optimizers: Turn off all default Grappler optimizers. 

292 use_implicit_batch: Whether to use implicit batch or explicit batch. 

293 profile_strategy: dynamic shape optimization profile strategy. 

294 

295 Returns: 

296 A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. 

297 

298 Raises: 

299 TypeError: if any of the parameters are of unexpected type. 

300 ValueError: if any of the parameters are of unexpected value. 

301 """ 

302 _check_conversion_params(conversion_params, is_v2=is_v2) 

303 if is_v2 and is_dynamic_op is not None and not is_dynamic_op: 

304 raise ValueError("is_dynamic_op is either None or True for TF2") 

305 if not is_v2 and is_dynamic_op is None: 

306 raise ValueError("is_dynamic_op can't be None for TF1") 

307 

308 if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None: 

309 raise ValueError("max_batch_size has to be None for TF2" 

310 " or when is_dynamic_op == True in TF1") 

311 if is_dynamic_op is not None and not is_dynamic_op and not isinstance( 

312 max_batch_size, int): 

313 raise ValueError( 

314 "max_batch_size has to be an integer for is_dynamic_op==False in TF1") 

315 rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() 

316 # Disable the Grappler Remapper so that it does not create fused ops that 

317 # may not be beneficial to TF-TRT and are not supported by TF-TRT. 

318 rewriter_config_with_trt.remapping = False 

319 

320 # Prevent folding of Const->QDQ chains. 

321 rewriter_config_with_trt. \ 

322 experimental_disable_folding_quantization_emulation = ( 

323 trt_utils.is_linked_tensorrt_version_greater_equal(8, 0, 0) or 

324 trt_utils.is_loaded_tensorrt_version_greater_equal(8, 0, 0)) 

325 

326 if not disable_non_trt_optimizers: 

327 rewriter_config_with_trt.optimizers.extend([ 

328 "pruning", "debug_stripper", "layout", "dependency", "constfold", 

329 "common_subgraph_elimination" 

330 ]) 

331 

332 rewriter_config_with_trt.meta_optimizer_iterations = ( 

333 rewriter_config_pb2.RewriterConfig.ONE) 

334 optimizer = rewriter_config_with_trt.custom_optimizers.add() 

335 

336 if not disable_non_trt_optimizers: 

337 # Add a constfold optimizer to cleanup the unused Const nodes. 

338 rewriter_config_with_trt.custom_optimizers.add().name = "constfold" 

339 

340 optimizer.name = "TensorRTOptimizer" 

341 optimizer.parameter_map[ 

342 "minimum_segment_size"].i = conversion_params.minimum_segment_size 

343 optimizer.parameter_map["max_workspace_size_bytes"].i = ( 

344 conversion_params.max_workspace_size_bytes) 

345 optimizer.parameter_map["precision_mode"].s = _to_bytes( 

346 conversion_params.precision_mode) 

347 optimizer.parameter_map[ 

348 "maximum_cached_engines"].i = conversion_params.maximum_cached_engines 

349 optimizer.parameter_map[ 

350 "use_calibration"].b = conversion_params.use_calibration 

351 optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op 

352 optimizer.parameter_map[ 

353 "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime 

354 if max_batch_size is not None: 

355 optimizer.parameter_map["max_batch_size"].i = max_batch_size 

356 optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch 

357 # While we accept case-insensitive strings from the users, we only pass 

358 # lowercase strings to the TF-TRT converter. 

359 if not use_implicit_batch: 

360 optimizer.parameter_map["profile_strategy"].s = _to_bytes( 

361 profile_strategy.lower()) 

362 

363 # Disabling optimizers should happen after defining the TF-TRT grappler pass, 

364 # otherwise the template can overwrite the disablement. 

365 if disable_non_trt_optimizers: 

366 trt_utils.disable_non_trt_optimizers_in_rewriter_config( 

367 rewriter_config_with_trt) 

368 

369 return rewriter_config_with_trt 

370 

371 

372@deprecation.deprecated( 

373 None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.") 

374def get_tensorrt_rewriter_config(conversion_params, 

375 is_dynamic_op=None, 

376 max_batch_size=None, 

377 is_v2=False, 

378 disable_non_trt_optimizers=False): 

379 return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op, 

380 max_batch_size, is_v2, 

381 disable_non_trt_optimizers) 

382 

383 

384 # Remove all scope prefixes in the node name. In TF 2.0, the same concrete 

385 # function can be initialized multiple times with different prefixes, and 

386 # this will result in the same TRTEngineOp being initialized multiple times, 

387 # each with its own cache, producing duplicate TRT engines. 

388# TODO(laigd): this may be caused by the fact that TRTEngineOp is not 

389# stateful, need to investigate. 

390# TODO(laigd): we rely on the fact that all functions are fully inlined 

391# before TF-TRT optimizer is called, as otherwise it may generate the same 

392# name when optimizing a different function graph. Fix this. 
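# For example (illustrative node name, not taken from a real graph): a node
# named "StatefulPartitionedCall/TRTEngineOp_0" canonicalizes to "TRTEngineOp_0".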

393def _get_canonical_engine_name(name): 

394 return name.split("/")[-1] 

395 

396 

397class TrtGraphConverter(object): 

398 """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels. 

399 

400 To run the conversion without quantization calibration (e.g. for FP32/FP16 

401 precision modes): 

402 

403 ```python 

404 converter = TrtGraphConverter( 

405 input_saved_model_dir="my_dir", 

406 precision_mode=TrtPrecisionMode.FP16) 

407 converted_graph_def = converter.convert() 

408 converter.save(output_saved_model_dir) 

409 ``` 

410 

411 To run the conversion with quantization calibration: 

412 

413 ```python 

414 converter = TrtGraphConverter( 

415 input_saved_model_dir="my_dir", 

416 precision_mode=TrtPrecisionMode.INT8) 

417 converter.convert() 

418 

419 # Run calibration 10 times. 

420 converted_graph_def = converter.calibrate( 

421 fetch_names=['output:0'], 

422 num_runs=10, 

423 feed_dict_fn=lambda: {'input:0': my_next_data()}) 

424 

425 converter.save(output_saved_model_dir) 

426 ``` 

427 """ 

428 

429 def __init__(self, 

430 input_saved_model_dir=None, 

431 input_saved_model_tags=None, 

432 input_saved_model_signature_key=None, 

433 input_graph_def=None, 

434 nodes_denylist=None, 

435 max_batch_size=1, 

436 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 

437 precision_mode=TrtPrecisionMode.FP32, 

438 minimum_segment_size=3, 

439 is_dynamic_op=False, 

440 maximum_cached_engines=1, 

441 use_calibration=True): 

442 """Initializes the converter. 

443 

444 Args: 

445 input_saved_model_dir: the directory to load the SavedModel which contains 

446 the input graph to transform. Used only when input_graph_def is None. 

447 input_saved_model_tags: list of tags to load the SavedModel. 

448 input_saved_model_signature_key: the key of the signature to optimize the 

449 graph for. 

450 input_graph_def: a GraphDef object containing a model to be transformed. 

451 If set to None, the graph will be read from the SavedModel loaded from 

452 input_saved_model_dir. 

453 nodes_denylist: list of node names to prevent the converter from touching. 

454 max_batch_size: max size for the input batch. 

455 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT 

456 engine can use at execution time. This corresponds to the 

457 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 

458 precision_mode: one of TrtPrecisionMode.supported_precision_modes(). 

459 minimum_segment_size: the minimum number of nodes required for a subgraph 

460 to be replaced by TRTEngineOp. 

461 is_dynamic_op: whether to generate dynamic TRT ops which will build the 

462 TRT network and engine at run time. 

463 maximum_cached_engines: max number of cached TRT engines in dynamic TRT 

464 ops. If the number of cached engines is already at max but none of them 

465 can serve the input, the TRTEngineOp will fall back to run the TF 

466 function from which the TRTEngineOp was created. 

467 use_calibration: this argument is ignored if precision_mode is not INT8. 

468 If set to True, a calibration graph will be created to calibrate the 

469 missing ranges. The calibration graph must be converted to an inference 

470 graph by running calibration with calibrate(). If set to False, 

471 quantization nodes will be expected for every tensor in the graph 

472 (excluding those which will be fused). If a range is missing, an error 

473 will occur. Please note that accuracy may be negatively affected if 

474 there is a mismatch between which tensors TRT quantizes and which 

475 tensors were trained with fake quantization. 

476 

477 Raises: 

478 ValueError: if the combination of the parameters is invalid. 

479 RuntimeError: if this class is used in TF 2.0. 

480 """ 

481 if context.executing_eagerly(): 

482 raise RuntimeError( 

483 "Please use tf.experimental.tensorrt.Converter in TF 2.0.") 

484 

485 if input_graph_def and input_saved_model_dir: 

486 raise ValueError( 

487 "Can only specify one of input_graph_def and input_saved_model_dir") 

488 if not input_graph_def and not input_saved_model_dir: 

489 raise ValueError("Must specify one of input_graph_def and " 

490 "input_saved_model_dir") 

491 _check_trt_version_compatibility() 

492 

493 self._input_graph_def = input_graph_def 

494 self._nodes_denylist = nodes_denylist 

495 

496 self._input_saved_model_dir = input_saved_model_dir 

497 self._converted = False 

498 self._grappler_meta_graph_def = None 

499 

500 self._input_saved_model_tags = ( 

501 input_saved_model_tags or [tag_constants.SERVING]) 

502 self._input_saved_model_signature_key = ( 

503 input_saved_model_signature_key or 

504 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 

505 

506 # For calibration usage. 

507 self._calibration_graph = None 

508 self._calibration_data_collected = False 

509 self._need_calibration = ( 

510 ((precision_mode == TrtPrecisionMode.INT8) or 

511 (precision_mode == TrtPrecisionMode.INT8.lower())) and use_calibration) 

512 if self._need_calibration and not is_dynamic_op: 

513 logging.warn( 

514 "INT8 precision mode with calibration is supported with " 

515 "dynamic TRT ops only. Disregarding is_dynamic_op parameter.") 

516 is_dynamic_op = True 

517 

518 self._is_dynamic_op = is_dynamic_op 

519 if is_dynamic_op: 

520 self._max_batch_size = None 

521 if max_batch_size is not None: 

522 logging.warn("When is_dynamic_op==True max_batch_size should be None") 

523 else: 

524 if not isinstance(max_batch_size, int): 

525 raise ValueError("When is_dynamic_op==False max_batch_size should be " 

526 "an integer") 

527 self._max_batch_size = max_batch_size 

528 

529 self._conversion_params = TrtConversionParams( 

530 max_workspace_size_bytes=max_workspace_size_bytes, 

531 precision_mode=precision_mode, 

532 minimum_segment_size=minimum_segment_size, 

533 maximum_cached_engines=maximum_cached_engines, 

534 use_calibration=use_calibration, 

535 allow_build_at_runtime=True) 

536 _check_conversion_params(self._conversion_params) 

537 

538 self._test_only_disable_non_trt_optimizers = False 

539 

540 def _run_conversion(self): 

541 """Run Grappler's OptimizeGraph() tool to convert the graph.""" 

542 # Create custom ConfigProto for Grappler. 

543 grappler_session_config = config_pb2.ConfigProto() 

544 custom_rewriter_config = _get_tensorrt_rewriter_config( 

545 conversion_params=self._conversion_params, 

546 is_dynamic_op=self._is_dynamic_op, 

547 max_batch_size=self._max_batch_size, 

548 disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, 

549 use_implicit_batch=True) 

550 grappler_session_config.graph_options.rewrite_options.CopyFrom( 

551 custom_rewriter_config) 

552 

553 # Run Grappler. 

554 self._converted_graph_def = tf_optimizer.OptimizeGraph( 

555 grappler_session_config, 

556 self._grappler_meta_graph_def, 

557 graph_id=b"tf_graph") 

558 self._converted = True 

559 

560 def _add_nodes_denylist(self): 

561 if self._nodes_denylist: 

562 collection_def = self._grappler_meta_graph_def.collection_def["train_op"] 

563 denylist = collection_def.node_list.value 

564 for i in self._nodes_denylist: 

565 if isinstance(i, ops.Tensor): 

566 denylist.append(_to_bytes(i.name)) 

567 else: 

568 denylist.append(_to_bytes(i)) 

569 

570 def _convert_graph_def(self): 

571 """Convert the input GraphDef.""" 

572 graph = ops.Graph() 

573 with graph.as_default(): 

574 importer.import_graph_def(self._input_graph_def, name="") 

575 self._grappler_meta_graph_def = saver.export_meta_graph( 

576 graph_def=graph.as_graph_def(add_shapes=True), graph=graph) 

577 self._add_nodes_denylist() 

578 

579 self._run_conversion() 

580 

581 def _collections_to_keep(self, collection_keys): 

582 # TODO(laigd): currently we use the collection key to filter out 

583 # collections that depend on variable ops, but this may miss some 

584 # other user-defined collections. A better way would be to use 

585 # CollectionDef::NodeList for the filtering. 

586 collections_to_remove = ( 

587 ops.GraphKeys._VARIABLE_COLLECTIONS + [ 

588 ops.GraphKeys.TRAIN_OP, ops.GraphKeys.WHILE_CONTEXT, 

589 ops.GraphKeys.COND_CONTEXT 

590 ]) 

591 return [key for key in collection_keys if key not in collections_to_remove] 

592 

593 def _convert_saved_model(self): 

594 """Convert the input SavedModel.""" 

595 graph = ops.Graph() 

596 with session.Session(graph=graph) as sess: 

597 input_meta_graph_def = loader.load(sess, self._input_saved_model_tags, 

598 self._input_saved_model_dir) 

599 input_signature_def = input_meta_graph_def.signature_def[ 

600 self._input_saved_model_signature_key] 

601 

602 def _gather_names(tensor_info): 

603 """Get the node names from a TensorInfo.""" 

604 return {tensor_info[key].name.split(":")[0] for key in tensor_info} 

605 

606 # Get input and outputs from all SignatureDef. 

607 output_node_names = _gather_names(input_signature_def.inputs).union( 

608 _gather_names(input_signature_def.outputs)) 

609 

610 # Preserve nodes in collection 

611 for collection_key in self._collections_to_keep( 

612 input_meta_graph_def.collection_def): 

613 for op in sess.graph.get_collection(collection_key): 

614 if isinstance(op, ops.Operation): 

615 output_node_names.add(op.name.split(":")[0]) 

616 

617 # Freeze the variables in the SavedModel graph and copy the frozen 

618 # graph over. 

619 frozen_graph_def = convert_to_constants.convert_variables_to_constants( 

620 sess, sess.graph.as_graph_def(add_shapes=True), 

621 list(output_node_names)) 

622 self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef() 

623 self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def) 

624 

625 # Copy the collections that are not variables. 

626 for collection_key in self._collections_to_keep( 

627 input_meta_graph_def.collection_def): 

628 self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom( 

629 input_meta_graph_def.collection_def[collection_key]) 

630 

631 self._add_nodes_denylist() 

632 

633 # Copy other information. 

634 self._grappler_meta_graph_def.meta_info_def.CopyFrom( 

635 input_meta_graph_def.meta_info_def) 

636 self._grappler_meta_graph_def.signature_def[ 

637 self._input_saved_model_signature_key].CopyFrom(input_signature_def) 

638 # TODO(laigd): maybe add back AssetFileDef. 

639 

640 self._run_conversion() 

641 

642 def convert(self): 

643 """Run the TF-TRT conversion. 

644 

645 Returns: 

646 The converted GraphDef for TF 1.x. 

647 """ 

648 assert not self._converted 

649 if self._input_graph_def: 

650 self._convert_graph_def() 

651 else: 

652 self._convert_saved_model() 

653 return self._converted_graph_def 

654 

655 def calibrate(self, 

656 fetch_names, 

657 num_runs, 

658 feed_dict_fn=None, 

659 input_map_fn=None): 

660 """Run the calibration and return the calibrated GraphDef. 

661 

662 Args: 

663 fetch_names: a list of output tensor names to fetch during calibration. 

664 num_runs: number of runs of the graph during calibration. 

665 feed_dict_fn: a function that returns a dictionary mapping input names (as 

666 strings) in the GraphDef to be calibrated to values (e.g. Python list, 

667 numpy arrays, etc). One and only one of `feed_dict_fn` and 

668 `input_map_fn` should be specified. 

669 input_map_fn: a function that returns a dictionary mapping input names (as 

670 strings) in the GraphDef to be calibrated to Tensor objects. The values 

671 of the named input tensors in the GraphDef to be calibrated will be 

672 re-mapped to the respective `Tensor` values during calibration. One and 

673 only one of `feed_dict_fn` and `input_map_fn` should be specified. 

674 

675 Raises: 

676 ValueError: if the input combination is invalid. 

677 RuntimeError: if this method is called in eager mode. 

678 

679 Returns: 

680 The GraphDef after the calibration. 

681 """ 

682 assert self._converted 

683 assert self._need_calibration 

684 assert not self._calibration_data_collected 

685 

686 if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and 

687 not input_map_fn): 

688 raise ValueError( 

689 "Should specify one and only one of feed_dict_fn and input_map_fn.") 

690 

691 if input_map_fn: 

692 for k, v in input_map_fn().items(): 

693 if not isinstance(k, str): 

694 raise ValueError("Keys of input_map_fn must be of type str") 

695 if not isinstance(v, ops.Tensor): 

696 raise ValueError("Values of input_map_fn must be of type tf.Tensor") 

697 

698 self._calibration_graph = ops.Graph() 

699 with self._calibration_graph.as_default(): 

700 fetches = importer.import_graph_def( 

701 self._converted_graph_def, 

702 input_map=input_map_fn() if input_map_fn else None, 

703 return_elements=fetch_names, 

704 name="") 

705 

706 calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig() 

707 if self._test_only_disable_non_trt_optimizers: 

708 trt_utils.disable_non_trt_optimizers_in_rewriter_config( 

709 calibrate_rewriter_cfg) 

710 

711 # Set allow_soft_placement=True to run the graph for calibration so that 

712 # ops that are supported by TensorRT but don't have a GPU implementation are 

713 # allowed to execute on CPU. 

714 calibrate_config = config_pb2.ConfigProto( 

715 allow_soft_placement=True, 

716 graph_options=config_pb2.GraphOptions( 

717 rewrite_options=calibrate_rewriter_cfg)) 

718 

719 with session.Session( 

720 graph=self._calibration_graph, 

721 config=calibrate_config) as calibration_sess: 

722 for _ in range(num_runs): 

723 calibration_sess.run( 

724 fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None) 

725 

726 # Maps device name to the corresponding get_calibration_data. 

727 # 

728 # TODO(laigd): a better way would be to use calibration_sess to list 

729 # all the devices, add one get_calibration_data for each device, and 

730 # fetch each such op for every resource until it's found. This can work 

731 # even when the device of the TRTEngineOp is empty or not fully specified. 

732 device_to_get_resource_op_map = {} 

733 

734 with self._calibration_graph.as_default(): 

735 resource_name_input = array_ops.placeholder(dtypes.string) 

736 

737 for node in self._converted_graph_def.node: 

738 if node.op == _TRT_ENGINE_OP_NAME: 

739 # Adds the get_calibration_data op for the device if not done 

740 # before. We only add one such op for each device. 

741 # TODO(laigd): What if the device is empty????? 

742 if node.device not in device_to_get_resource_op_map: 

743 with self._calibration_graph.device(node.device): 

744 serialized_resources_output = ( 

745 gen_trt_ops.get_calibration_data_op(resource_name_input)) 

746 device_to_get_resource_op_map[node.device] = ( 

747 serialized_resources_output) 

748 

749 # Get the calibration resource. 

750 calibration_result = calibration_sess.run( 

751 device_to_get_resource_op_map[node.device], 

752 feed_dict={ 

753 resource_name_input: _get_canonical_engine_name(node.name) 

754 }) 

755 node.attr["calibration_data"].s = calibration_result 

756 

757 self._calibration_data_collected = True 

758 

759 return self._converted_graph_def 

760 

761 def save(self, output_saved_model_dir): 

762 """Save the converted graph as a SavedModel. 

763 

764 Args: 

765 output_saved_model_dir: construct a SavedModel using the converted 

766 GraphDef and save it to the specified directory. This option only works 

767 when the input graph is loaded from a SavedModel, i.e. when 

768 input_saved_model_dir is specified and input_graph_def is None in 

769 __init__(). 

770 

771 Raises: 

772 ValueError: if the input to the converter is a GraphDef instead of a 

773 SavedModel. 

774 """ 

775 assert self._converted 

776 if self._need_calibration: 

777 assert self._calibration_data_collected 

778 if self._input_graph_def: 

779 raise ValueError( 

780 "Not able to save to a SavedModel since input is a GraphDef") 

781 

782 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys): 

783 """Restores collections that we need to keep.""" 

784 scope = "" 

785 for key in collection_keys: 

786 collection_def = src_meta_graph_def.collection_def[key] 

787 kind = collection_def.WhichOneof("kind") 

788 if kind is None: 

789 logging.error( 

790 "Cannot identify data type for collection %s. Skipping.", key) 

791 continue 

792 from_proto = ops.get_from_proto_function(key) 

793 if from_proto and kind == "bytes_list": 

794 proto_type = ops.get_collection_proto_type(key) 

795 # It is assumed that there are no Variables Keys in collections 

796 for value in collection_def.bytes_list.value: 

797 proto = proto_type() 

798 proto.ParseFromString(value) 

799 try: 

800 new_value = from_proto(proto, import_scope=scope) 

801 except: 

802 continue 

803 dest_graph.add_to_collection(key, new_value) 

804 else: 

805 field = getattr(collection_def, kind) 

806 if kind == "node_list": 

807 for value in field.value: 

808 name = ops.prepend_name_scope(value, scope) 

809 # Since the graph has been optimized, the node may no longer 

810 # exist. 

811 try: 

812 col_op = dest_graph.as_graph_element(name) 

813 except (TypeError, ValueError, KeyError): 

814 continue 

815 dest_graph.add_to_collection(key, col_op) 

816 elif kind == "int64_list": 

817 # NOTE(opensource): This force conversion is to work around the 

818 # fact that Python2 distinguishes between int and long, while 

819 # Python3 has only int. 

820 for value in field.value: 

821 dest_graph.add_to_collection(key, int(value)) 

822 else: 

823 for value in field.value: 

824 dest_graph.add_to_collection(key, 

825 ops.prepend_name_scope(value, scope)) 

826 

827 # Write the transformed graphdef as SavedModel. 

828 saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) 

829 with ops.Graph().as_default(): 

830 importer.import_graph_def(self._converted_graph_def, name="") 

831 _restore_collections( 

832 ops.get_default_graph(), self._grappler_meta_graph_def, 

833 self._collections_to_keep( 

834 self._grappler_meta_graph_def.collection_def)) 

835 # We don't use any specific converter here. 

836 with session.Session() as sess: 

837 saved_model_builder.add_meta_graph_and_variables( 

838 sess, 

839 self._input_saved_model_tags, 

840 signature_def_map=self._grappler_meta_graph_def.signature_def) 

841 # Ignore other meta graphs from the input SavedModel. 

842 saved_model_builder.save() 

843 

844def _get_resource_handle(name, device): 

845 with ops.device(device): 

846 return gen_trt_ops.create_trt_resource_handle(resource_name=name) 

847 

848 

849def _remove_native_segments(input_func): 

850 """Remove native segments from the input TF-TRT Converted Function. 

851 

852 Args: 

853 input_func: the concrete function containing native segment nodes. The 

854 transformed output func will not contain any native segment nodes. All the 

855 TRTEngineOp references will be deleted and reset to a default empty func. 

856 """ 

857 input_graph_def = input_func.graph.as_graph_def() 

858 # Delete the native segment function referenced by each TRTEngineOp node. 

859 nodes_deleted = 0 

860 for func_id in reversed(range(len(input_graph_def.library.function))): 

861 f = input_graph_def.library.function[func_id] 

862 if "native_segment" in f.signature.name: 

863 nodes_deleted += 1 

864 while context.context().has_function(f.signature.name): 

865 context.context().remove_function(f.signature.name) 

866 del input_graph_def.library.function[func_id] 

867 

868 logging.info( 

869 "Found and deleted native segments from " 

870 f"{nodes_deleted} TRTEngineOp nodes." 

871 ) 

872 

873 # Delete the references to `<EngineName>_native_segment`s. 

874 # This keeps the TRTEngineOp constructor from looking for native segment 

875 # handles while constructing the graph for inference. 

876 for node in input_graph_def.node: 

877 if node.op == "TRTEngineOp": 

878 del node.attr["segment_func"] 

879 for func in input_graph_def.library.function: 

880 for node in func.node_def: 

881 if node.op == "TRTEngineOp": 

882 del node.attr["segment_func"] 

883 # Reconstruct the converted_func with the new graph 

884 new_func = _construct_function_from_graph_def(input_func, input_graph_def) 

885 

886 return new_func 

887 

888 

889class _TRTEngineResource(resource.TrackableResource): 

890 """Class to track the serialized engines resource.""" 

891 

892 def __init__(self, 

893 resource_name, 

894 filename, 

895 maximum_cached_engines, 

896 device="GPU"): 

897 super(_TRTEngineResource, self).__init__(device=device) 

898 self._resource_name = resource_name 

899 # Track the serialized engine file in the SavedModel. 

900 self._filename = self._track_trackable( 

901 asset.Asset(filename), "_serialized_trt_resource_filename") 

902 self._maximum_cached_engines = maximum_cached_engines 

903 

904 def _create_resource(self): 

905 return _get_resource_handle(self._resource_name, self._resource_device) 

906 

907 def _initialize(self): 

908 gen_trt_ops.initialize_trt_resource( 

909 self.resource_handle, 

910 self._filename, 

911 max_cached_engines_count=self._maximum_cached_engines) 

912 

913 def _destroy_resource(self): 

914 handle = _get_resource_handle(self._resource_name, self._resource_device) 

915 with ops.device(self._resource_device): 

916 gen_resource_variable_ops.destroy_resource_op( 

917 handle, ignore_lookup_error=True) 

918 

919 

920def _print_row(fields, positions, print_fn): 

921 """Prints a row.""" 

922 line = "" 

923 for i, field in enumerate(fields): 

924 field = str(field) 

925 end_line_pos = positions[i] 

926 if i > 0: 

927 line = line + " " 

928 line = "{0:{min_length}}".format(line + field, min_length=end_line_pos) 

929 

930 if len(line) > end_line_pos: 

931 line = line[:(end_line_pos - 4)] + " ..." 

932 

933 print_fn(line) 

934 

935 

936def _construct_function_from_graph_def(func, graph_def, frozen_func=None): 

937 """Rebuild function from graph_def.""" 

938 if frozen_func is None: 

939 frozen_func = func 

940 

941 # If a function is converted, then the TF context contains the original 

942 # function while the converted_graph_def contains the converted function. 

943 # Remove the original function from the TF context in this case. 

944 for f in graph_def.library.function: 

945 while context.context().has_function(f.signature.name): 

946 context.context().remove_function(f.signature.name) 

947 

948 captures = { 

949 c[1].name.split(":")[0]: c[0] 

950 for c in frozen_func.graph.captures 

951 } 

952 new_func = wrap_function.function_from_graph_def( 

953 graph_def, [tensor.name for tensor in frozen_func.inputs], 

954 [tensor.name for tensor in frozen_func.outputs], captures) 

955 new_func.graph.structured_outputs = nest.pack_sequence_as( 

956 func.graph.structured_outputs, new_func.graph.structured_outputs) 

957 

958 # Copy structured input signature from original function (used during 

959 # serialization) 

960 new_func.graph.structured_input_signature = (func.structured_input_signature) 

961 

962 return new_func 

963 

964 

965def _apply_inlining(func): 

966 """Apply an inlining optimization to the function's graph definition.""" 

967 graph_def = func.graph.as_graph_def() 

968 

969 # In some cases, a secondary implementation of the function (e.g. for GPU) is 

970 # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in 

971 # TF2 produces a CuDNN-based RNN for GPU). 

972 # This function is supposed to inline all function calls, but "api_implements" 

973 # prevents this from happening. Removing the attribute solves the problem. 

974 # To learn more about "api_implements", see: 

975 # tensorflow/core/grappler/optimizers/implementation_selector.h 

976 for function in graph_def.library.function: 

977 if "api_implements" in function.attr: 

978 del function.attr["api_implements"] 

979 

980 meta_graph = saver.export_meta_graph(graph_def=graph_def, graph=func.graph) 

981 

982 # Clear the initializer_name for the variables collections, since it is not 

983 # needed once the model is saved to a SavedModel. 

984 for name in [ 

985 "variables", "model_variables", "trainable_variables", "local_variables" 

986 ]: 

987 raw_list = [] 

988 for raw in meta_graph.collection_def["variables"].bytes_list.value: 

989 variable = variable_pb2.VariableDef() 

990 variable.ParseFromString(raw) 

991 variable.ClearField("initializer_name") 

992 raw_list.append(variable.SerializeToString()) 

993 meta_graph.collection_def[name].bytes_list.value[:] = raw_list 

994 

995 # Add a collection 'train_op' so that Grappler knows the outputs. 

996 fetch_collection = meta_graph_pb2.CollectionDef() 

997 for array in func.inputs + func.outputs: 

998 fetch_collection.node_list.value.append(array.name) 

999 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 

1000 

1001 # Initialize RewriterConfig with everything disabled except function inlining. 

1002 config = config_pb2.ConfigProto() 

1003 rewrite_options = config.graph_options.rewrite_options 

1004 rewrite_options.min_graph_nodes = -1 # do not skip small graphs 

1005 rewrite_options.optimizers.append("function") 

1006 

1007 new_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph) 

1008 

1009 return new_graph_def 

1010 

1011 

1012def _annotate_variable_ops(func, graph_def): 

1013 """Annotates variable operations with custom `_shape` attribute. 

1014 

1015 This is required for the converters and shape inference. The graph 

1016 definition is modified in-place. 

1017 

1018 Args: 

1019 func: Function represented by the graph definition. 

1020 graph_def: Graph definition to be annotated in-place. 

1021 

1022 Raises: 

1023 RuntimeError: if some shapes cannot be annotated. 

1024 """ 

1025 ph_shape_map = {} 

1026 for ph, var in zip(func.graph.internal_captures, func.variables): 

1027 ph_shape_map[ph.name] = var.shape 

1028 # Construct a mapping of node names to nodes 

1029 name_to_node = {node.name: node for node in graph_def.node} 

1030 # Go through all the ReadVariableOp nodes in the graph def 

1031 for node in graph_def.node: 

1032 if node.op == "ReadVariableOp" or node.op == "ResourceGather": 

1033 node_ = node 

1034 # Go up the chain of identities to find a placeholder 

1035 while name_to_node[node_.input[0]].op == "Identity": 

1036 node_ = name_to_node[node_.input[0]] 

1037 ph_name = node_.input[0] + ":0" 

1038 if ph_name in ph_shape_map: 

1039 shape = ph_shape_map[ph_name] 

1040 node.attr["_shape"].shape.CopyFrom(shape.as_proto()) 

1041 else: 

1042 raise RuntimeError( 

1043 "Not found in the function captures: {}".format(ph_name)) 

1044 

1045 

1046def _save_calibration_table(node): 

1047 try: 

1048 calibration_table = gen_trt_ops.get_calibration_data_op( 

1049 _get_canonical_engine_name(node.name)) 

1050 node.attr["calibration_data"].s = calibration_table.numpy() 

1051 except (errors.UnknownError, errors.NotFoundError): 

1052 logging.warning("Calibration error for %s", node.name) 

1053 

1054 

1055def _convert_to_tensor(inp): 

1056 try: 

1057 if isinstance(inp, dict): 

1058 args = [] 

1059 kwargs = {k: ops.convert_to_tensor(v) for k, v in inp.items()} 

1060 else: 

1061 kwargs = {} 

1062 if isinstance(inp, (list, tuple)): 

1063 args = map(ops.convert_to_tensor, inp) 

1064 else: 

1065 args = [ops.convert_to_tensor(inp)] 

1066 except: 

1067 error_msg = "Failed to convert input to tensor." 

1068 logging.error(error_msg + "\ninp = `{0}`\n".format(inp)) 

1069 raise RuntimeError(error_msg) 

1070 

1071 return args, kwargs 

1072 

1073 

1074@tf_export("experimental.tensorrt.Converter", v1=[]) 

1075class TrtGraphConverterV2(object): 

1076 """An offline converter for TF-TRT transformation for TF 2.0 SavedModels. 

1077 

1078 Windows support is provided experimentally. No guarantee is made regarding 

1079 functionality or engineering support. Use at your own risk. 

1080 

1081 There are several ways to run the conversion: 

1082 

1083 1. FP32/FP16 precision 

1084 

1085 ```python 

1086 params = tf.experimental.tensorrt.ConversionParams( 

1087 precision_mode='FP16') 

1088 converter = tf.experimental.tensorrt.Converter( 

1089 input_saved_model_dir="my_dir", conversion_params=params) 

1090 converter.convert() 

1091 converter.save(output_saved_model_dir) 

1092 ``` 

1093 

1094 In this case, no TRT engines will be built or saved in the converted 

1095 SavedModel. But if input data is available during conversion, we can still 

1096 build and save the TRT engines to reduce the cost during inference (see 

1097 option 2 below). 

1098 

1099 2. FP32/FP16 precision with pre-built engines 

1100 

1101 ```python 

1102 params = tf.experimental.tensorrt.ConversionParams( 

1103 precision_mode='FP16', 

1104 # Set this to a large enough number so it can cache all the engines. 

1105 maximum_cached_engines=16) 

1106 converter = tf.experimental.tensorrt.Converter( 

1107 input_saved_model_dir="my_dir", conversion_params=params) 

1108 converter.convert() 

1109 

1110 # Define a generator function that yields input data, and use it to execute 

1111 # the graph to build TRT engines. 

1112 def my_input_fn(): 

1113 for _ in range(num_runs): 

1114 inp1, inp2 = ... 

1115 yield inp1, inp2 

1116 

1117 converter.build(input_fn=my_input_fn) # Generate corresponding TRT engines 

1118 converter.save(output_saved_model_dir) # Generated engines will be saved. 

1119 ``` 

1120 

1121 In this way, one engine will be built/saved for each unique input shape of 

1122 the TRTEngineOp. This is good for applications that cannot afford building 

1123 engines during inference but have access to input data that is similar to 

1124 the one used in production (for example, that has the same input shapes). 

1125 Also, the generated TRT engines are platform dependent, so we need to run 

1126 `build()` in an environment that is similar to production (e.g. with the 

1127 same type of GPU). 

1128 

1129 3. INT8 precision and calibration with pre-built engines 

1130 

1131 ```python 

1132 params = tf.experimental.tensorrt.ConversionParams( 

1133 precision_mode='INT8', 

1134 # Currently only one INT8 engine is supported in this mode. 

1135 maximum_cached_engines=1, 

1136 use_calibration=True) 

1137 converter = tf.experimental.tensorrt.Converter( 

1138 input_saved_model_dir="my_dir", conversion_params=params) 

1139 

1140 # Define a generator function that yields input data, and run INT8 

1141 # calibration with the data. All input data should have the same shape. 

1142 # At the end of convert(), the calibration stats (e.g. range information) 

1143 # will be saved and can be used to generate more TRT engines with different 

1144 # shapes. Also, one TRT engine will be generated (with the same shape as 

1145 # the calibration data) and saved for later use. 

1146 def my_calibration_input_fn(): 

1147 for _ in range(num_runs): 

1148 inp1, inp2 = ... 

1149 yield inp1, inp2 

1150 

1151 converter.convert(calibration_input_fn=my_calibration_input_fn) 

1152 

1153 # (Optional) Generate more TRT engines offline (same as the previous 

1154 # option), to avoid the cost of generating them during inference. 

1155 def my_input_fn(): 

1156 for _ in range(num_runs): 

1157 inp1, inp2 = ... 

1158 yield inp1, inp2 

1159 converter.build(input_fn=my_input_fn) 

1160 

1161 # Save the converted model together with the built TRT engines. 

1162 converter.save(output_saved_model_dir) 

1163 ``` 

1164 4. To use dynamic shape, we need to call the build method with an input 

1165 function to generate profiles. This step is similar to the INT8 calibration 

1166 step described above. The converter also needs to be created with 

1167 use_dynamic_shape=True and one of the following profile_strategies for 

1168 creating profiles based on the inputs produced by the input function: 

1169 * `Range`: create one profile that works for inputs with dimension values 

1170 in the range of [min_dims, max_dims] where min_dims and max_dims are 

1171 derived from the provided inputs. 

1172 * `Optimal`: create one profile for each input. The profile only works for 

1173 inputs with the same dimensions as the input it is created for. The GPU 

1174 engine will be run with optimal performance with such inputs. 

1175 * `Range+Optimal`: create the profiles for both `Range` and `Optimal`. 
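
A minimal sketch of this workflow (reusing the illustrative `my_input_fn`
generator from the examples above; parameter values are illustrative only):

```python
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="my_dir",
    precision_mode='FP16',
    use_dynamic_shape=True,
    dynamic_shape_profile_strategy='Optimal')
converter.convert()
converter.build(input_fn=my_input_fn)  # Creates profiles and TRT engines.
converter.save(output_saved_model_dir)
```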

1176 """ 

1177 

1178 def _verify_profile_strategy(self, strategy): 

1179 supported_strategies = [s.lower() for s in supported_profile_strategies()] 

1180 if strategy.lower() not in supported_strategies: 

1181 raise ValueError( 

1182 ("profile_strategy '{}' is not supported. It should be one of {}" 

1183 ).format(strategy, supported_profile_strategies())) 

1184 if strategy == "ImplicitBatchModeCompatible": 

1185 logging.warn( 

1186 "ImplicitBatchModeCompatible strategy is deprecated, and" 

1187 " using it may result in errors during engine building. Please" 

1188 " consider using a different profile strategy.") 

1189 

1190 @deprecation.deprecated_args(None, 

1191 "Use individual converter parameters instead", 

1192 "conversion_params") 

1193 def __init__(self, 

1194 input_saved_model_dir=None, 

1195 input_saved_model_tags=None, 

1196 input_saved_model_signature_key=None, 

1197 use_dynamic_shape=None, 

1198 dynamic_shape_profile_strategy=None, 

1199 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 

1200 precision_mode=TrtPrecisionMode.FP32, 

1201 minimum_segment_size=3, 

1202 maximum_cached_engines=1, 

1203 use_calibration=True, 

1204 allow_build_at_runtime=True, 

1205 conversion_params=None): 

1206 """Initialize the converter. 

1207 

1208 Args: 

1209 input_saved_model_dir: the directory to load the SavedModel which contains 

1210 the input graph to transform. Required. 

1211 input_saved_model_tags: list of tags to load the SavedModel. 

1212 input_saved_model_signature_key: the key of the signature to optimize the 

1213 graph for. 

1214 use_dynamic_shape: whether to enable dynamic shape support. None is 

1215 equivalent to False in the current implementation. 

1216 dynamic_shape_profile_strategy: one of the strings in 

1217 supported_profile_strategies(). None is equivalent to Range in the 

1218 current implementation. 

1219 max_workspace_size_bytes: the maximum GPU temporary memory that the TRT 

1220 engine can use at execution time. This corresponds to the 

1221 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 

1222 precision_mode: one of the strings in 

1223 TrtPrecisionMode.supported_precision_modes(). 

1224 minimum_segment_size: the minimum number of nodes required for a subgraph 

1225 to be replaced by TRTEngineOp. 

1226 maximum_cached_engines: max number of cached TRT engines for dynamic TRT 

1227 ops. Created TRT engines for a dynamic dimension are cached. If the 

1228 number of cached engines is already at max but none of them supports the 

1229 input shapes, the TRTEngineOp will fall back to run the original TF 

1230 subgraph that corresponds to the TRTEngineOp. 

1231 use_calibration: this argument is ignored if precision_mode is not INT8. 

1232 If set to True, a calibration graph will be created to calibrate the 

1233 missing ranges. The calibration graph must be converted to an inference 

1234 graph by running calibration with calibrate(). If set to False, 

1235 quantization nodes will be expected for every tensor in the graph 

1236 (excluding those which will be fused). If a range is missing, an error 

1237 will occur. Please note that accuracy may be negatively affected if 

1238 there is a mismatch between which tensors TRT quantizes and which 

1239 tensors were trained with fake quantization. 

1240 allow_build_at_runtime: whether to allow building TensorRT engines at 

1241 runtime. If no prebuilt TensorRT engine that can handle the given inputs 

1242 is found at runtime, a new TensorRT engine is built if 

1243 allow_build_at_runtime=True; otherwise native TF is used. 

1244 conversion_params: a TrtConversionParams instance (deprecated). 

1245 

1246 Raises: 

1247 ValueError: if the combination of the parameters is invalid. 

1248 """ 

1249 assert context.executing_eagerly() 

1250 if conversion_params is None: 

1251 conversion_params = TrtConversionParams( 

1252 max_workspace_size_bytes=max_workspace_size_bytes, 

1253 precision_mode=precision_mode, 

1254 minimum_segment_size=minimum_segment_size, 

1255 maximum_cached_engines=maximum_cached_engines, 

1256 use_calibration=use_calibration, 

1257 allow_build_at_runtime=allow_build_at_runtime) 

1258 

1259 _check_trt_version_compatibility() 

1260 _check_conversion_params(conversion_params, is_v2=True) 

1261 

1262 self._conversion_params = conversion_params 

1263 self._input_saved_model_dir = input_saved_model_dir 

1264 self._input_saved_model_tags = ( 

1265 input_saved_model_tags or [tag_constants.SERVING]) 

1266 self._input_saved_model_signature_key = ( 

1267 input_saved_model_signature_key or 

1268 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 

1269 self.freeze = not trt_utils.is_experimental_feature_activated( 

1270 "disable_graph_freezing") 

1271 

1272 self._need_calibration = (( 

1273 (conversion_params.precision_mode == TrtPrecisionMode.INT8) or 

1274 (conversion_params.precision_mode == TrtPrecisionMode.INT8.lower())) and 

1275 conversion_params.use_calibration) 

1276 

1277 self._calibration_input_fn = None 

1278 

1279 self._converted = False 

1280 self._device = None 

1281 self._build_called_once = False 

1282 self._calibrated = False 

1283 

1284 if use_dynamic_shape is None: 

1285 self._use_dynamic_shape = False 

1286 else: 

1287 self._use_dynamic_shape = use_dynamic_shape 

1288 

1289 if not self.freeze and not self._use_dynamic_shape: 

1290 logging.warn( 

1291 "Disabling graph freezing is only possible in dynamic shape mode." 

1292 " The graph will be frozen.") 

1293 self.freeze = True 

1294 

1295 self._profile_strategy = "Unknown" 

1296 if self._use_dynamic_shape: 

1297 if dynamic_shape_profile_strategy is None: 

1298 self._profile_strategy = PROFILE_STRATEGY_RANGE 

1299 else: 

1300 self._verify_profile_strategy(dynamic_shape_profile_strategy) 

1301 self._profile_strategy = dynamic_shape_profile_strategy 

1302 

1303 # Fields to support TF-TRT testing and shouldn't be used for other purpose. 

1304 self._test_only_disable_non_trt_optimizers = False 

1305 

1306 def _need_trt_profiles(self): 

1307 return self._use_dynamic_shape 

1308 

1309 def _run_conversion(self, meta_graph_def): 

1310 """Run Grappler's OptimizeGraph() tool to convert the graph. 

1311 

1312 Args: 

1313 meta_graph_def: the MetaGraphDef instance to run the optimizations on. 

1314 

1315 Returns: 

1316 The optimized GraphDef. 

1317 """ 

1318 grappler_session_config = config_pb2.ConfigProto() 

1319 # Always set `allow_build_at_runtime` for offline TensorRT engine building. 

1320 custom_rewriter_config = _get_tensorrt_rewriter_config( 

1321 conversion_params=self._conversion_params._replace( 

1322 allow_build_at_runtime=True), 

1323 is_dynamic_op=True, 

1324 max_batch_size=None, 

1325 disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, 

1326 use_implicit_batch=not self._use_dynamic_shape, 

1327 profile_strategy=self._profile_strategy) 

1328 grappler_session_config.graph_options.rewrite_options.CopyFrom( 

1329 custom_rewriter_config) 

1330 return tf_optimizer.OptimizeGraph( 

1331 grappler_session_config, meta_graph_def, graph_id=b"tf_graph") 

1332 

1333 def _for_each_trt_node(self, graph_def, fn): 

1334 """Helper method to manipulate all TRTEngineOps in a GraphDef.""" 

1335 for node in graph_def.node: 

1336 if node.op == _TRT_ENGINE_OP_NAME: 

1337 fn(node) 

1338 for func in graph_def.library.function: 

1339 for node in func.node_def: 

1340 if node.op == _TRT_ENGINE_OP_NAME: 

1341 fn(node) 

1342 

1343 def _execute_calibration(self, calibration_input_fn): 

1344 """Run INT8 calibration with the provided input generator function.""" 

1345 for inp in calibration_input_fn(): 

1346 args, kwargs = _convert_to_tensor(inp) 

1347 self._converted_func(*args, **kwargs) 

1348 

1349 self._for_each_trt_node(self._converted_graph_def, _save_calibration_table) 

1350 

1351 # Rebuild the function since calibration has changed the graph. 

1352 self._converted_func = _construct_function_from_graph_def( 

1353 self._converted_func, self._converted_graph_def) 

1354 self._calibrated = True 

1355 

1356 # TODO(laigd): provide a utility function to optimize a ConcreteFunction and 

1357 # use it here (b/124792963). 

1358 def convert(self, calibration_input_fn=None): 

1359 """Convert the input SavedModel in 2.0 format. 

1360 

1361 Args: 

1362 calibration_input_fn: a generator function that yields input data as a 

1363 list or tuple or dict, which will be used to execute the converted 

1364 signature for calibration. All the returned input data should have the 

1365 same shape. Example: `def input_fn(): yield input1, input2, input3` 

1366 

1367 If dynamic_shape_mode==False (or if the graph has static input shapes), 

1368 then we run calibration and build the calibrated engine during 

1369 conversion. 

1370 

1371 If dynamic_shape_mode==True (and the graph has any unknown input 

1372 shape), then the reference to calibration_input_fn is stored, and the 

1373 calibration is actually performed when we build the engine (see 

1374 build()). 

1375 

1376 Raises: 

1377 ValueError: if the input combination is invalid. 

1378 

1379 Returns: 

1380 The TF-TRT converted Function. 

1381 """ 

1382 assert not self._converted 

1383 

1384 # Create an empty tensor to determine the currently requested device. 

1385 device_requested = array_ops.zeros([]).device 

1386 

1387 if "gpu" not in device_requested.lower(): 

1388 raise ValueError(f"Specified device is not a GPU: {device_requested}") 

1389 

1390 if "gpu:0" not in device_requested.lower(): 

1391 self._device = device_requested 

1392 logging.info(f"Placing imported graph from " 

1393 f"`{self._input_saved_model_dir}` on device: {self._device}") 

1394 

1395 if (self._need_calibration and not calibration_input_fn): 

1396 raise ValueError("Should specify calibration_input_fn because INT8 " 

1397 "calibration is needed") 

1398 if (not self._need_calibration and calibration_input_fn): 

1399 raise ValueError("Should not specify calibration_input_fn because INT8 " 

1400 "calibration is not needed") 

1401 

1402 self._saved_model = load.load(self._input_saved_model_dir, 

1403 self._input_saved_model_tags) 

1404 func = self._saved_model.signatures[self._input_saved_model_signature_key] 

1405 if self.freeze: 

1406 frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) 

1407 else: 

1408 inlined_graph_def = _apply_inlining(func) 

1409 _annotate_variable_ops(func, inlined_graph_def) 

1410 frozen_func = _construct_function_from_graph_def(func, inlined_graph_def) 

1411 frozen_graph_def = frozen_func.graph.as_graph_def() 

1412 

1413 # Clear any prior device assignments 

1414 logging.info("Clearing prior device assignments in loaded saved model") 

1415 for node in frozen_graph_def.node: 

1416 node.device = "" 

1417 

1418 if self._device is None: 

1419 grappler_meta_graph_def = saver.export_meta_graph( 

1420 graph_def=frozen_graph_def, graph=frozen_func.graph) 

1421 else: 

1422 with ops.Graph().as_default() as graph, ops.device(self._device): 

1423 importer.import_graph_def(frozen_graph_def, name="") 

1424 grappler_meta_graph_def = saver.export_meta_graph( 

1425 graph_def=graph.as_graph_def(), graph=graph) 

1426 

1427 # Add a collection 'train_op' so that Grappler knows the outputs. 

1428 fetch_collection = meta_graph_pb2.CollectionDef() 

1429 for array in frozen_func.inputs + frozen_func.outputs: 

1430 fetch_collection.node_list.value.append(array.name) 

1431 grappler_meta_graph_def.collection_def["train_op"].CopyFrom( 

1432 fetch_collection) 

1433 

1434 # Run TRT optimizer in Grappler to convert the graph. 

1435 self._converted_graph_def = self._run_conversion(grappler_meta_graph_def) 

1436 self._converted_func = _construct_function_from_graph_def( 

1437 func, self._converted_graph_def, frozen_func) 

1438 

1439 if self._need_calibration: 

1440 # Execute calibration here only if not in dynamic shape mode. 

1441 if not self._need_trt_profiles(): 

1442 self._execute_calibration(calibration_input_fn) 

1443 else: 

1444 self._calibration_input_fn = calibration_input_fn 

1445 

1446 self._converted = True 

1447 

1448 graphviz_path = os.environ.get("TF_TRT_EXPORT_GRAPH_VIZ_PATH", default=None) 

1449 if graphviz_path is not None: 

1450 try: 

1451 trt_utils.draw_graphdef_as_graphviz( 

1452 graphdef=self._converted_func.graph.as_graph_def(add_shapes=True), 

1453 dot_output_filename=graphviz_path) 

1454 except Exception as e: 

1455 logging.error("An Exception occurred during the export of the graph " 

1456 f"visualization: {e}") 

1457 

1458 return self._converted_func 

1459 
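  # A minimal sketch of convert() with INT8 calibration, assuming a converter
  # created with precision_mode=TrtPrecisionMode.INT8 and use_calibration=True,
  # and a signature taking one float32 image tensor (shapes are hypothetical).
  # All calibration batches share the same shape, as required above. Setting the
  # TF_TRT_EXPORT_GRAPH_VIZ_PATH environment variable beforehand optionally dumps
  # a Graphviz rendering of the converted graph.
  #
  #   import os
  #   import numpy as np
  #
  #   os.environ["TF_TRT_EXPORT_GRAPH_VIZ_PATH"] = "/tmp/converted_graph.dot"
  #
  #   def calibration_input_fn():
  #     for _ in range(10):
  #       yield (np.random.random((8, 224, 224, 3)).astype(np.float32),)
  #
  #   trt_func = converter.convert(calibration_input_fn=calibration_input_fn)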

1460 def build(self, input_fn): 

1461 """Run inference with converted graph in order to build TensorRT engines. 

1462 

1463 If the conversion requires INT8 calibration, then a reference to the 

1464 calibration function was stored during the call to convert(). Calibration 

1465 will be performed while we build the TensorRT engines. 

1466 

1467 Args: 

1468 input_fn: a generator function that provides the input data as a single 

1469 array, a list or tuple of arrays, or a dict, which will be used 

1470 to execute the converted signature to generate TRT engines. 

1471 Example 1: 

1472 `def input_fn(): 

1473 # Let's assume a network with 1 input tensor. 

1474 # We generate 2 sets of dummy input data: 

1475 input_shapes = [(1, 16), # 1st shape 

1476 (2, 32)] # 2nd shape 

1477 for shapes in input_shapes: 

1478 # yield an input tensor 

1479 yield np.zeros(shapes).astype(np.float32)` 

1480 

1481 Example 2: 

1482 `def input_fn(): 

1483 # Let's assume a network with 2 input tensors. 

1484 # We generate 3 sets of dummy input data: 

1485 input_shapes = [[(1, 16), (2, 16)], # 1st input list 

1486 [(2, 32), (4, 32)], # 2nd list of two tensors 

1487 [(4, 32), (8, 32)]] # 3rd input list 

1488 for shapes in input_shapes: 

1489 # return a list of input tensors 

1490 yield [np.zeros(x).astype(np.float32) for x in shapes]` 

1491 

1492 Raises: 

1493 NotImplementedError: build() is already called. 

1494 RuntimeError: the input_fn is None. 

1495 """ 

1496 if self._build_called_once: 

1497 raise NotImplementedError("build() is already called. It is not " 

1498 "supported to call build() more than once.") 

1499 if not input_fn: 

1500 raise RuntimeError("input_fn is None. Method build() needs input_fn " 

1501 "to be specified in order to build TensorRT engines") 

1502 if not self._converted: 

1503 raise RuntimeError("Need to call convert() before build()") 

1504 if (self._need_calibration and not self._calibrated and 

1505 self._calibration_input_fn is None): 

1506 raise RuntimeError("Need to provide the calibration_input_fn arg while " 

1507 "calling convert().") 

1508 

1509 def _set_profile_generation_mode(value, node): 

1510 node.attr["_profile_generation_mode"].b = value 

1511 

1512 if self._need_trt_profiles(): 

1513 # Enable profile generation. 

1514 self._for_each_trt_node(self._converted_graph_def, 

1515 partial(_set_profile_generation_mode, True)) 

1516 # Profile generation is enabled using the _profile_generation_mode 

1517 # attribute of the TRTEngineOps. We need to rebuild the function to 

1518 # change this attribute. 

1519 func = _construct_function_from_graph_def(self._converted_func, 

1520 self._converted_graph_def) 

1521 else: 

1522 func = self._converted_func 

1523 

1524 first_input = None 

1525 # Run inference: 

1526 # Builds TRT engines if self._need_trt_profiles is False. 

1527 # Builds TRT optimization profiles if self._need_trt_profiles is True. 

1528 for inp in input_fn(): 

1529 if first_input is None: 

1530 first_input = inp 

1531 args, kwargs = _convert_to_tensor(inp) 

1532 func(*args, **kwargs) 

1533 

1534 if self._need_trt_profiles(): 

1535 # Disable profile generation. 

1536 self._for_each_trt_node(self._converted_graph_def, 

1537 partial(_set_profile_generation_mode, False)) 

1538 

1539 # Run calibration if required; this would have been skipped in 

1540 # the convert step. 

1541 if self._need_calibration and not self._calibrated: 

1542 self._execute_calibration(self._calibration_input_fn) 

1543 # calibration also builds the engine 

1544 else: 

1545 # Use the first input in explicit batch mode to build TensorRT engines 

1546 # after generating all the profiles. The first input is used, but any 

1547 # input would do, because the engine is determined by the shapes 

1548 # collected in the profiles rather than by the shape of this particular 

1549 # input. 

1550 args, kwargs = _convert_to_tensor(first_input) 

1551 self._converted_func(*args, **kwargs) 

1552 

1553 self._build_called_once = True 

1554 
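  # A minimal sketch of build(), assuming dynamic shape mode and a single-input
  # signature (shapes are hypothetical). Each yielded input contributes to the
  # TRT optimization profiles; the engines are built once all inputs have been
  # consumed.
  #
  #   def input_fn():
  #     for batch_size in (1, 8, 32):
  #       yield (np.zeros((batch_size, 224, 224, 3), dtype=np.float32),)
  #
  #   converter.build(input_fn=input_fn)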

1555 def save(self, 

1556 output_saved_model_dir, 

1557 save_gpu_specific_engines=True, 

1558 options=None): 

1559 """Save the converted SavedModel. 

1560 

1561 Args: 

1562 output_saved_model_dir: directory to save the converted SavedModel. 

1563 save_gpu_specific_engines: whether to save TRT engines that have been 

1564 built. When True, all engines are saved and when False, the engines 

1565 are not saved and will be rebuilt at inference time. By using 

1566 save_gpu_specific_engines=False after doing INT8 calibration, inference 

1567 can be done on GPUs other than the one the model was calibrated and 

1568 saved on. 

1569 options: `tf.saved_model.SaveOptions` object for configuring save options. 

1570 Raises: 

1571 RuntimeError: if the needed calibration hasn't been done. 

1572 """ 

1573 assert self._converted 

1574 

1575 # 'remove_native_segments': setting this value to True removes native segments 

1576 # associated with each TRT engine. This option can be used to reduce the size 

1577 # of the converted model. Please note that a converted model without native 

1578 # segments can't be used for collecting profiles, building or re-converting. 

1579 # The reduced model can only be used for inference when no native segments 

1580 # are required for computation. When remove_native_segments flag is set to 

1581 # True, the converted_graph_def needs to be reduced before saved_model 

1582 # function serialization. 

1583 if trt_utils.is_experimental_feature_activated("remove_native_segments"): 

1584 logging.info( 

1585 "'remove_native_segments' experimental feature is enabled" 

1586 " during saving of converted SavedModel." 

1587 ) 

1588 self._converted_func = _remove_native_segments(self._converted_func) 

1589 self._converted_graph_def = self._converted_func.graph.as_graph_def() 

1590 

1591 if self._need_calibration and not self._calibrated: 

1592 raise RuntimeError("A model that requires INT8 calibration has to be " 

1593 "built before saving it. Call build() to build and " 

1594 "calibrate the TensorRT engines.") 

1595 # Serialize the TRT engines in the cache if any, and create trackable 

1596 # resource to track them. 

1597 engine_asset_dir = tempfile.mkdtemp() 

1598 resource_map = {} 

1599 

1600 def _serialize_and_track_engine(node): 

1601 """Serialize TRT engines in the cache and track them.""" 

1602 # Don't dump the same cache twice. 

1603 canonical_engine_name = _get_canonical_engine_name(node.name) 

1604 if canonical_engine_name in resource_map: 

1605 return 

1606 

1607 filename = os.path.join(engine_asset_dir, 

1608 "trt-serialized-engine." + canonical_engine_name) 

1609 

1610 try: 

1611 gen_trt_ops.serialize_trt_resource( 

1612 resource_name=canonical_engine_name, 

1613 filename=filename, 

1614 delete_resource=True, 

1615 save_gpu_specific_engines=save_gpu_specific_engines) 

1616 except errors.NotFoundError: 

1617 logging.info( 

1618 "Could not find %s in TF-TRT cache. " 

1619 "This can happen if build() is not called, " 

1620 "which means TensorRT engines will be built " 

1621 "and cached at runtime.", canonical_engine_name) 

1622 return 

1623 

1624 # TODO(laigd): add an option for the user to choose the device. 

1625 resource_map[canonical_engine_name] = _TRTEngineResource( 

1626 canonical_engine_name, filename, 

1627 self._conversion_params.maximum_cached_engines) 

1628 

1629 self._for_each_trt_node(self._converted_graph_def, 

1630 _serialize_and_track_engine) 

1631 # If the graph is frozen, tracked variables are not needed by the converted model. 

1632 trackable = autotrackable.AutoTrackable( 

1633 ) if self.freeze else self._saved_model 

1634 trackable.trt_engine_resources = resource_map 

1635 

1636 # Set allow_build_at_runtime=False if asked by user. 

1637 # 

1638 # This attribute is set here because build() needs it to be True in order to 

1639 # build engines. 

1640 if not self._conversion_params.allow_build_at_runtime: 

1641 

1642 def _reset_allow_build_at_runtime(node): 

1643 node.attr["_allow_build_at_runtime"].b = False 

1644 

1645 self._for_each_trt_node(self._converted_graph_def, 

1646 _reset_allow_build_at_runtime) 

1647 # Rebuild the function since a node attribute changed above 

1648 reset_converted_func = wrap_function.function_from_graph_def( 

1649 self._converted_graph_def, 

1650 [tensor.name for tensor in self._converted_func.inputs], 

1651 [tensor.name for tensor in self._converted_func.outputs]) 

1652 reset_converted_func.graph.structured_outputs = nest.pack_sequence_as( 

1653 self._converted_func.graph.structured_outputs, 

1654 reset_converted_func.graph.structured_outputs) 

1655 reset_converted_func.graph.structured_input_signature = ( 

1656 self._converted_func.structured_input_signature) 

1657 self._converted_func = reset_converted_func 

1658 

1659 # Rewrite the signature map using the optimized ConcreteFunction. 

1660 signatures = {self._input_saved_model_signature_key: self._converted_func} 

1661 save.save(trackable, output_saved_model_dir, signatures, options=options) 

1662 
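  # A minimal sketch of save() followed by reloading the converted SavedModel for
  # inference. Paths are hypothetical; "serving_default" is the usual default
  # signature key, and the input shape below is illustrative.
  #
  #   converter.save("/tmp/resnet_trt_saved_model")
  #
  #   import tensorflow as tf
  #   reloaded = tf.saved_model.load("/tmp/resnet_trt_saved_model")
  #   infer = reloaded.signatures["serving_default"]
  #   outputs = infer(tf.zeros((1, 224, 224, 3), dtype=tf.float32))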

1663 def summary(self, line_length=160, detailed=True, print_fn=None): 

1664 """This method describes the results of the conversion by TF-TRT. 

1665 

1666 It includes information such as the name of the engine, the number of nodes 

1667 per engine, the input and output dtype, along with the input shape of each 

1668 TRTEngineOp. 

1669 

1670 Args: 

1671 line_length: Line length used when printing the summary on the console. 

1672 Must be at least 160 characters. 

1673 detailed: Whether or not to show the nodes inside each TRTEngineOp. 

1674 print_fn: Print function to use. Defaults to `print`. It will be called on 

1675 each line of the summary. You can set it to a custom function in order 

1676 to capture the string summary. 

1677 

1678 Raises: 

1679 RuntimeError: if the graph is not converted. 

1680 """ 

1681 if not self._converted: 

1682 raise RuntimeError( 

1683 f"Impossible to call `{self.__class__.__name__}.summary()` before " 

1684 f"calling {self.__class__.__name__}.convert()`.") 

1685 

1686 if line_length < 160: 

1687 raise ValueError(f"Invalid `line_length` value has been received: " 

1688 f"{line_length}. Minimum: 160.") 

1689 

1690 if print_fn is None: 

1691 print_fn = print 

1692 

1693 # positions are percentage of `line_length`. positions[i]+1 is the starting 

1694 # position for (i+1)th field. We also make sure that the last char printed 

1695 # for each field is a space. 

1696 columns = [ 

1697 # (column name, column size in % of line) 

1698 ("TRTEngineOP Name", .20), # 20% 

1699 ("Device", .09), # 29% 

1700 ("# Nodes", .05), # 34% 

1701 ("# Inputs", .09), # 43% 

1702 ("# Outputs", .09), # 52% 

1703 ("Input DTypes", .12), # 64% 

1704 ("Output Dtypes", .12), # 76% 

1705 ("Input Shapes", .12), # 88% 

1706 ("Output Shapes", .12) # 100% 

1707 ] 

1708 

1709 positions = [int(line_length * p) for _, p in columns] 

1710 positions = np.cumsum(positions).tolist() 

1711 headers = [h for h, _ in columns] 

1712 

1713 _print_row(headers, positions, print_fn=print_fn) 

1714 print_fn("=" * line_length) 

1715 

1716 n_engines = 0 

1717 n_ops_converted = 0 

1718 n_ops_not_converted = 0 

1719 

1720 graphdef = self._converted_func.graph.as_graph_def(add_shapes=True) 

1721 

1722 trtengineops_dict = dict() 

1723 for node in graphdef.node: 

1724 if node.op != "TRTEngineOp": 

1725 n_ops_not_converted += 1 

1726 continue 

1727 else: 

1728 trtengineops_dict[node.name] = node 

1729 n_engines += 1 

1730 

1731 for name, node in sorted(trtengineops_dict.items()): 

1732 node_device = node.device.split("/")[-1] 

1733 in_shapes = trt_utils.get_node_io_shapes(node, "input_shapes") 

1734 out_shapes = trt_utils.get_node_io_shapes(node, "_output_shapes") 

1735 in_dtypes = trt_utils.get_trtengineop_io_dtypes(node, "InT") 

1736 out_dtypes = trt_utils.get_trtengineop_io_dtypes(node, "OutT") 

1737 in_nodes_count = trt_utils.get_trtengineop_io_nodes_count(node, "InT") 

1738 out_nodes_count = trt_utils.get_trtengineop_io_nodes_count(node, "OutT") 

1739 node_count, converted_ops_dict = trt_utils.get_trtengineop_node_op_count( 

1740 graphdef, name) 

1741 

1742 n_ops_converted += node_count 

1743 

1744 if n_engines != 1: 

1745 print_fn(f"\n{'-'*40}\n") 

1746 

1747 _print_row( 

1748 fields=[ 

1749 name, node_device, node_count, in_nodes_count, out_nodes_count, 

1750 in_dtypes, out_dtypes, in_shapes, out_shapes 

1751 ], 

1752 positions=positions, 

1753 print_fn=print_fn) 

1754 

1755 if detailed: 

1756 print_fn() 

1757 for key, value in sorted(dict(converted_ops_dict).items()): 

1758 print_fn(f"\t- {key}: {value}x") 

1759 

1760 print_fn(f"\n{'='*line_length}") 

1761 print_fn(f"[*] Total number of TensorRT engines: {n_engines}") 

1762 total_ops = n_ops_not_converted + n_ops_converted 

1763 conversion_ratio = n_ops_converted / total_ops * 100 

1764 print_fn(f"[*] % of OPs Converted: {conversion_ratio:.2f}% " 

1765 f"[{n_ops_converted}/{total_ops}]\n") 

1766 
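  # A minimal sketch of capturing the summary instead of printing it, assuming an
  # already-converted converter instance. With detailed=False every print_fn call
  # receives a single string, so list.append can be passed directly.
  #
  #   lines = []
  #   converter.summary(line_length=160, detailed=False, print_fn=lines.append)
  #   report = "\n".join(lines)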

1767 

1768# TODO(laigd): use TrtConversionParams here. 

1769def create_inference_graph( 

1770 input_graph_def, 

1771 outputs, 

1772 max_batch_size=1, 

1773 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 

1774 precision_mode=TrtPrecisionMode.FP32, 

1775 minimum_segment_size=3, 

1776 is_dynamic_op=False, 

1777 maximum_cached_engines=1, 

1778 input_saved_model_dir=None, 

1779 input_saved_model_tags=None, 

1780 input_saved_model_signature_key=None, 

1781 output_saved_model_dir=None): 

1782 """Python wrapper for the TRT transformation. 

1783 

1784 Args: 

1785 input_graph_def: a GraphDef object containing a model to be transformed. If 

1786 set to None, the graph will be read from the SavedModel loaded from 

1787 input_saved_model_dir. 

1788 outputs: list of tensors or node names for the model outputs. Only used when 

1789 input_graph_def is not None. 

1790 max_batch_size: max size for the input batch. 

1791 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT 

1792 engine can use at execution time. This corresponds to the 'workspaceSize' 

1793 parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 

1794 precision_mode: one of TrtPrecisionMode.supported_precision_modes(). 

1795 minimum_segment_size: the minimum number of nodes required for a subgraph to 

1796 be replaced by TRTEngineOp. 

1797 is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT 

1798 network and engine at run time. 

1799 maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. 

1800 If the number of cached engines is already at max but none of them can 

1801 serve the input, the TRTEngineOp will fall back to run the TF function 

1802 from which the TRTEngineOp was created. 

1803 input_saved_model_dir: the directory to load the SavedModel which contains 

1804 the input graph to transform. Used only when input_graph_def is None. 

1805 input_saved_model_tags: list of tags to load the SavedModel. 

1806 input_saved_model_signature_key: the key of the signature to optimize the 

1807 graph for. 

1808 output_saved_model_dir: if not None, construct a SavedModel using the 

1809 returned GraphDef and save it to the specified directory. This option only 

1810 works when the input graph is loaded from a SavedModel, i.e. when 

1811 input_saved_model_dir is specified and input_graph_def is None. 

1812 

1813 Returns: 

1814 A GraphDef transformed from input_graph_def (or the SavedModel graph def 

1815 loaded from input_saved_model_dir, if input_graph_def is not present), where 

1816 all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF 

1817 function is added for each of the subgraphs. 

1818 

1819 If is_dynamic_op is True, each TRTEngineOp will contain a serialized 

1820 subgraph GraphDef, which will be converted to a TRT engine at execution time 

1821 and the TRT engine will be cached for future usage. A new TRT engine will be 

1822 created whenever none of the cached engines match the input shapes. If 

1823 it fails to execute the TRT engine or the number of cached engines reaches 

1824 maximum_cached_engines, the op will fall back to call the corresponding TF 

1825 function. 

1826 

1827 If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT 

1828 engine created from the corresponding subgraph. No more engines will be 

1829 created on the fly, and the op will fall back to call the corresponding TF 

1830 function when it fails to execute the engine. 

1831 

1832 Raises: 

1833 ValueError: if the combination of the parameters is invalid. 

1834 """ 

1835 trt_converter = TrtGraphConverter( 

1836 input_saved_model_dir=input_saved_model_dir, 

1837 input_saved_model_tags=input_saved_model_tags, 

1838 input_saved_model_signature_key=input_saved_model_signature_key, 

1839 input_graph_def=input_graph_def, 

1840 nodes_denylist=outputs, 

1841 max_batch_size=max_batch_size, 

1842 max_workspace_size_bytes=max_workspace_size_bytes, 

1843 precision_mode=precision_mode, 

1844 minimum_segment_size=minimum_segment_size, 

1845 is_dynamic_op=is_dynamic_op, 

1846 maximum_cached_engines=maximum_cached_engines, 

1847 use_calibration=False) 

1848 converted_graph_def = trt_converter.convert() 

1849 if output_saved_model_dir: 

1850 trt_converter.save(output_saved_model_dir) 

1851 return converted_graph_def
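

# A minimal sketch of the legacy TF1-style entry point above, assuming a frozen
# GraphDef named graph_def is already available and that "logits" is its output
# node (both are hypothetical).
#
#   trt_graph_def = create_inference_graph(
#       input_graph_def=graph_def,
#       outputs=["logits"],
#       max_batch_size=8,
#       precision_mode=TrtPrecisionMode.FP16,
#       is_dynamic_op=True)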