Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py: 17%
650 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Exposes the Python wrapper conversion to trt_graph."""
17import collections
18from functools import partial # pylint: disable=g-importing-member
19import os
20import platform
21import sys
22import tempfile
24import numpy as np
25import six as _six
27from tensorflow.core.framework import variable_pb2
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.core.protobuf import rewriter_config_pb2
31from tensorflow.python.client import session
32from tensorflow.python.compiler.tensorrt import utils as trt_utils
33from tensorflow.python.eager import context
34from tensorflow.python.eager import wrap_function
35from tensorflow.python.framework import convert_to_constants
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import importer
39from tensorflow.python.framework import ops
40from tensorflow.python.grappler import tf_optimizer
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import gen_resource_variable_ops
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.saved_model import builder
45from tensorflow.python.saved_model import load
46from tensorflow.python.saved_model import loader
47from tensorflow.python.saved_model import save
48from tensorflow.python.saved_model import signature_constants
49from tensorflow.python.saved_model import tag_constants
50from tensorflow.python.trackable import asset
51from tensorflow.python.trackable import autotrackable
52from tensorflow.python.trackable import resource
53from tensorflow.python.training import saver
54from tensorflow.python.util import deprecation
55from tensorflow.python.util import nest
56from tensorflow.python.util.lazy_loader import LazyLoader
57from tensorflow.python.util.tf_export import tf_export
59# Lazily load the op, since it's not available in cpu-only builds. Importing
60# this at top will cause tests that imports TF-TRT fail when they're built
61# and run without CUDA/GPU.
62gen_trt_ops = LazyLoader(
63 "gen_trt_ops", globals(),
64 "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
66_pywrap_py_utils = LazyLoader(
67 "_pywrap_py_utils", globals(),
68 "tensorflow.compiler.tf2tensorrt._pywrap_py_utils")
70# Register TRT ops in python, so that when users import this module they can
71# execute a TRT-converted graph without calling any of the methods in this
72# module.
73#
74# This will call register_op_list() in
75# tensorflow/python/framework/op_def_registry.py, but it doesn't register
76# the op or the op kernel in C++ runtime.
77try:
78 gen_trt_ops.trt_engine_op # pylint: disable=pointless-statement
79except AttributeError:
80 pass
83def _to_bytes(s):
84 """Encode s if it is a sequence of chars."""
85 if isinstance(s, _six.text_type):
86 return s.encode("utf-8", errors="surrogateescape")
87 return s
90def _to_string(s):
91 """Decode s if it is a sequence of bytes."""
92 if isinstance(s, _six.binary_type):
93 return s.decode("utf-8")
94 return s
97class TrtPrecisionMode(object):
98 FP32 = "FP32"
99 FP16 = "FP16"
100 INT8 = "INT8"
102 @staticmethod
103 def supported_precision_modes():
104 precisions = [
105 TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8
106 ]
107 return precisions + [p.lower() for p in precisions]
110# Use a large enough number as the default max_workspace_size for TRT engines,
111# so it can produce reasonable performance results with the default.
112# For TRT >= 8.4, the recommendation is MAX_INT.
113if (_pywrap_py_utils.is_tensorrt_enabled() and
114 trt_utils.is_loaded_tensorrt_version_greater_equal(8, 4, 0)):
115 # We must use `sys.maxsize - 512` to avoid overflow during casting.
116 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = sys.maxsize - 512
117else:
118 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 # 1,073,741,824
120PROFILE_STRATEGY_RANGE = "Range"
121PROFILE_STRATEGY_OPTIMAL = "Optimal"
122PROFILE_STRATEGY_RANGE_OPTIMAL = "Range+Optimal"
123PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE = "ImplicitBatchModeCompatible"
126def supported_profile_strategies():
127 return [
128 PROFILE_STRATEGY_RANGE, PROFILE_STRATEGY_OPTIMAL,
129 PROFILE_STRATEGY_RANGE_OPTIMAL,
130 PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE
131 ]
134@tf_export("experimental.tensorrt.ConversionParams", v1=[])
135class TrtConversionParams(
136 collections.namedtuple("TrtConversionParams", [
137 "max_workspace_size_bytes", "precision_mode", "minimum_segment_size",
138 "maximum_cached_engines", "use_calibration", "allow_build_at_runtime"
139 ])):
140 """Parameters that are used for TF-TRT conversion.
142 Fields:
143 max_workspace_size_bytes: the maximum GPU temporary memory that the TRT
144 engine can use at execution time. This corresponds to the
145 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
146 precision_mode: one of the strings in
147 TrtPrecisionMode.supported_precision_modes().
148 minimum_segment_size: the minimum number of nodes required for a subgraph
149 to be replaced by TRTEngineOp.
150 maximum_cached_engines: max number of cached TRT engines for dynamic TRT
151 ops. Created TRT engines for a dynamic dimension are cached. If the
152 number of cached engines is already at max but none of them supports the
153 input shapes, the TRTEngineOp will fall back to run the original TF
154 subgraph that corresponds to the TRTEngineOp.
155 use_calibration: this argument is ignored if precision_mode is not INT8.
156 If set to True, a calibration graph will be created to calibrate the
157 missing ranges. The calibration graph must be converted to an inference
158 graph by running calibration with calibrate(). If set to False,
159 quantization nodes will be expected for every tensor in the graph
160 (excluding those which will be fused). If a range is missing, an error
161 will occur. Please note that accuracy may be negatively affected if
162 there is a mismatch between which tensors TRT quantizes and which
163 tensors were trained with fake quantization.
164 allow_build_at_runtime: whether to allow building TensorRT engines during
165 runtime if no prebuilt TensorRT engine can be found that can handle the
166 given inputs during runtime, then a new TensorRT engine is built at
167 runtime if allow_build_at_runtime=True, and otherwise native TF is used.
168 """
170 def __new__(cls,
171 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
172 precision_mode=TrtPrecisionMode.FP32,
173 minimum_segment_size=3,
174 maximum_cached_engines=1,
175 use_calibration=True,
176 allow_build_at_runtime=True):
177 return super(TrtConversionParams,
178 cls).__new__(cls, max_workspace_size_bytes, precision_mode,
179 minimum_segment_size, maximum_cached_engines,
180 use_calibration, allow_build_at_runtime)
183DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
185_TRT_ENGINE_OP_NAME = "TRTEngineOp"
188def _check_conversion_params(conversion_params, is_v2=False):
189 """Validate the provided TrtConversionParams.
191 Args:
192 conversion_params: a TrtConversionParams instance.
193 is_v2: whether we're getting a RewriterConfig for TF 2.0.
195 Raises:
196 TypeError: if any of the parameters are of unexpected type.
197 ValueError: if any of the parameters are of unexpected value.
198 """
199 supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
200 if conversion_params.precision_mode not in supported_precision_modes:
201 raise ValueError(
202 ("precision mode '{}' is not supported."
203 "It should be one of {}").format(conversion_params.precision_mode,
204 supported_precision_modes))
205 if (conversion_params.minimum_segment_size <= 0 and
206 conversion_params.minimum_segment_size != -1):
207 raise ValueError("minimum segment size should be positive or -1 "
208 "(to disable main graph conversion).")
211def _check_trt_version_compatibility():
212 """Check compatibility of TensorRT version.
214 Raises:
215 RuntimeError: if the TensorRT library version is incompatible.
216 """
218 if not _pywrap_py_utils.is_tensorrt_enabled():
219 logging.error(
220 "Tensorflow needs to be built with TensorRT support enabled to allow "
221 "TF-TRT to operate.")
223 raise RuntimeError("Tensorflow has not been built with TensorRT support.")
225 if platform.system() == "Windows":
226 logging.warn(
227 "Windows support is provided experimentally. No guarantee is made "
228 "regarding functionality or engineering support. Use at your own risk.")
230 linked_version = _pywrap_py_utils.get_linked_tensorrt_version()
231 loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version()
233 logging.info("Linked TensorRT version: %s", str(linked_version))
234 logging.info("Loaded TensorRT version: %s", str(loaded_version))
236 def raise_trt_version_deprecated(version_type, trt_version):
237 assert version_type in [
238 "linked", "loaded"
239 ], ("Incorrect value received for version_type: %s. Accepted: ['linked', "
240 "'loaded']") % version_type
242 logging.error(
243 "The {version_type} version of TensorRT: `{trt_version}` has now "
244 "been removed. Please upgrade to TensorRT 7 or more recent.".format(
245 version_type=version_type,
246 trt_version=trt_utils.version_tuple_to_string(trt_version)))
248 raise RuntimeError("Incompatible %s TensorRT versions" % version_type)
250 if not trt_utils.is_linked_tensorrt_version_greater_equal(7, 0, 0):
251 raise_trt_version_deprecated("linked", linked_version)
253 if not trt_utils.is_loaded_tensorrt_version_greater_equal(7, 0, 0):
254 raise_trt_version_deprecated("loaded", loaded_version)
256 if (loaded_version[0] != linked_version[0] or
257 not trt_utils.is_loaded_tensorrt_version_greater_equal(*linked_version)):
258 logging.error(
259 "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. A few "
260 "requirements must be met:\n"
261 "\t-It is required to use the same major version of TensorRT during "
262 "compilation and runtime.\n"
263 "\t-TensorRT does not support forward compatibility. The loaded "
264 "version has to be equal or more recent than the linked version.",
265 trt_utils.version_tuple_to_string(loaded_version),
266 trt_utils.version_tuple_to_string(linked_version))
267 raise RuntimeError("Incompatible TensorRT major version")
269 elif loaded_version != linked_version:
270 logging.info(
271 "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. This is "
272 "supported because TensorRT minor/patch upgrades are backward "
273 "compatible.", trt_utils.version_tuple_to_string(loaded_version),
274 trt_utils.version_tuple_to_string(linked_version))
277def _get_tensorrt_rewriter_config(conversion_params,
278 is_dynamic_op=None,
279 max_batch_size=None,
280 is_v2=False,
281 disable_non_trt_optimizers=False,
282 use_implicit_batch=True,
283 profile_strategy=PROFILE_STRATEGY_RANGE):
284 """Returns a RewriterConfig proto for TRT transformation.
286 Args:
287 conversion_params: a TrtConversionParams instance.
288 is_dynamic_op: whether to use dynamic engines.
289 max_batch_size: maximum batch size for static engines.
290 is_v2: whether we're getting a RewriterConfig for TF 2.0.
291 disable_non_trt_optimizers: Turn off all default Grappler optimizers.
292 use_implicit_batch: Whether to use implicit batch or explicit batch.
293 profile_strategy: dynamic shape optimization profile strategy.
295 Returns:
296 A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
298 Raises:
299 TypeError: if any of the parameters are of unexpected type.
300 ValueError: if any of the parameters are of unexpected value.
301 """
302 _check_conversion_params(conversion_params, is_v2=is_v2)
303 if is_v2 and is_dynamic_op is not None and not is_dynamic_op:
304 raise ValueError("is_dynamic_op is either None or True for TF2")
305 if not is_v2 and is_dynamic_op is None:
306 raise ValueError("is_dynamic_op can't be None for TF1")
308 if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None:
309 raise ValueError("max_batch_size has to be None for TF2"
310 " or when is_dynamic_op == True in TF1")
311 if is_dynamic_op is not None and not is_dynamic_op and not isinstance(
312 max_batch_size, int):
313 raise ValueError(
314 "max_batch_size has to be an integer for is_dynamic_op==False in TF1")
315 rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
316 # Disable Grappler Remapper to avoid that fused OPs that may not be
317 # beneficial to TF-TRT and are not supported by TF-TRT.
318 rewriter_config_with_trt.remapping = False
320 # Prevent folding of Const->QDQ chains.
321 rewriter_config_with_trt. \
322 experimental_disable_folding_quantization_emulation = (
323 trt_utils.is_linked_tensorrt_version_greater_equal(8, 0, 0) or
324 trt_utils.is_loaded_tensorrt_version_greater_equal(8, 0, 0))
326 if not disable_non_trt_optimizers:
327 rewriter_config_with_trt.optimizers.extend([
328 "pruning", "debug_stripper", "layout", "dependency", "constfold",
329 "common_subgraph_elimination"
330 ])
332 rewriter_config_with_trt.meta_optimizer_iterations = (
333 rewriter_config_pb2.RewriterConfig.ONE)
334 optimizer = rewriter_config_with_trt.custom_optimizers.add()
336 if not disable_non_trt_optimizers:
337 # Add a constfold optimizer to cleanup the unused Const nodes.
338 rewriter_config_with_trt.custom_optimizers.add().name = "constfold"
340 optimizer.name = "TensorRTOptimizer"
341 optimizer.parameter_map[
342 "minimum_segment_size"].i = conversion_params.minimum_segment_size
343 optimizer.parameter_map["max_workspace_size_bytes"].i = (
344 conversion_params.max_workspace_size_bytes)
345 optimizer.parameter_map["precision_mode"].s = _to_bytes(
346 conversion_params.precision_mode)
347 optimizer.parameter_map[
348 "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
349 optimizer.parameter_map[
350 "use_calibration"].b = conversion_params.use_calibration
351 optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
352 optimizer.parameter_map[
353 "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
354 if max_batch_size is not None:
355 optimizer.parameter_map["max_batch_size"].i = max_batch_size
356 optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
357 # While we accept case insensitive strings from the users, we only pass the
358 # strings in lower cases to TF-TRT converter.
359 if not use_implicit_batch:
360 optimizer.parameter_map["profile_strategy"].s = _to_bytes(
361 profile_strategy.lower())
363 # Disabling optimizers should happen after defining the TF-TRT grappler pass
364 # otherwise the template can overwrite the disablement.
365 if disable_non_trt_optimizers:
366 trt_utils.disable_non_trt_optimizers_in_rewriter_config(
367 rewriter_config_with_trt)
369 return rewriter_config_with_trt
372@deprecation.deprecated(
373 None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.")
374def get_tensorrt_rewriter_config(conversion_params,
375 is_dynamic_op=None,
376 max_batch_size=None,
377 is_v2=False,
378 disable_non_trt_optimizers=False):
379 return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op,
380 max_batch_size, is_v2,
381 disable_non_trt_optimizers)
384# Remove all scope prefixes in the node name. In TF 2.0, the same concrete
385# function can be initialized multiple times with different prefixes, and
386# this will result in the same TRTEngineOp being initialized multiple times
387# with different cache and duplicate TRT engines.
388# TODO(laigd): this may be caused by the fact that TRTEngineOp is not
389# stateful, need to investigate.
390# TODO(laigd): we rely on the fact that all functions are fully inlined
391# before TF-TRT optimizer is called, as otherwise it may generate the same
392# name when optimizing a different function graph. Fix this.
393def _get_canonical_engine_name(name):
394 return name.split("/")[-1]
397class TrtGraphConverter(object):
398 """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels.
400 To run the conversion without quantization calibration (e.g. for FP32/FP16
401 precision modes):
403 ```python
404 converter = TrtGraphConverter(
405 input_saved_model_dir="my_dir",
406 precision_mode=TrtPrecisionMode.FP16)
407 converted_graph_def = converter.convert()
408 converter.save(output_saved_model_dir)
409 ```
411 To run the conversion with quantization calibration:
413 ```python
414 converter = TrtGraphConverter(
415 input_saved_model_dir="my_dir",
416 precision_mode=TrtPrecisionMode.INT8)
417 converter.convert()
419 # Run calibration 10 times.
420 converted_graph_def = converter.calibrate(
421 fetch_names=['output:0'],
422 num_runs=10,
423 feed_dict_fn=lambda: {'input:0': my_next_data()})
425 converter.save(output_saved_model_dir)
426 ```
427 """
429 def __init__(self,
430 input_saved_model_dir=None,
431 input_saved_model_tags=None,
432 input_saved_model_signature_key=None,
433 input_graph_def=None,
434 nodes_denylist=None,
435 max_batch_size=1,
436 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
437 precision_mode=TrtPrecisionMode.FP32,
438 minimum_segment_size=3,
439 is_dynamic_op=False,
440 maximum_cached_engines=1,
441 use_calibration=True):
442 """Initializes the converter.
444 Args:
445 input_saved_model_dir: the directory to load the SavedModel which contains
446 the input graph to transforms. Used only when input_graph_def is None.
447 input_saved_model_tags: list of tags to load the SavedModel.
448 input_saved_model_signature_key: the key of the signature to optimize the
449 graph for.
450 input_graph_def: a GraphDef object containing a model to be transformed.
451 If set to None, the graph will be read from the SavedModel loaded from
452 input_saved_model_dir.
453 nodes_denylist: list of node names to prevent the converter from touching.
454 max_batch_size: max size for the input batch.
455 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
456 engine can use at execution time. This corresponds to the
457 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
458 precision_mode: one of TrtPrecisionMode.supported_precision_modes().
459 minimum_segment_size: the minimum number of nodes required for a subgraph
460 to be replaced by TRTEngineOp.
461 is_dynamic_op: whether to generate dynamic TRT ops which will build the
462 TRT network and engine at run time.
463 maximum_cached_engines: max number of cached TRT engines in dynamic TRT
464 ops. If the number of cached engines is already at max but none of them
465 can serve the input, the TRTEngineOp will fall back to run the TF
466 function based on which the TRTEngineOp is created.
467 use_calibration: this argument is ignored if precision_mode is not INT8.
468 If set to True, a calibration graph will be created to calibrate the
469 missing ranges. The calibration graph must be converted to an inference
470 graph by running calibration with calibrate(). If set to False,
471 quantization nodes will be expected for every tensor in the graph
472 (excluding those which will be fused). If a range is missing, an error
473 will occur. Please note that accuracy may be negatively affected if
474 there is a mismatch between which tensors TRT quantizes and which
475 tensors were trained with fake quantization.
477 Raises:
478 ValueError: if the combination of the parameters is invalid.
479 RuntimeError: if this class is used in TF 2.0.
480 """
481 if context.executing_eagerly():
482 raise RuntimeError(
483 "Please use tf.experimental.tensorrt.Converter in TF 2.0.")
485 if input_graph_def and input_saved_model_dir:
486 raise ValueError(
487 "Can only specify one of input_graph_def and input_saved_model_dir")
488 if not input_graph_def and not input_saved_model_dir:
489 raise ValueError("Must specify one of input_graph_def and "
490 "input_saved_model_dir")
491 _check_trt_version_compatibility()
493 self._input_graph_def = input_graph_def
494 self._nodes_denylist = nodes_denylist
496 self._input_saved_model_dir = input_saved_model_dir
497 self._converted = False
498 self._grappler_meta_graph_def = None
500 self._input_saved_model_tags = (
501 input_saved_model_tags or [tag_constants.SERVING])
502 self._input_saved_model_signature_key = (
503 input_saved_model_signature_key or
504 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
506 # For calibration usage.
507 self._calibration_graph = None
508 self._calibration_data_collected = False
509 self._need_calibration = (
510 ((precision_mode == TrtPrecisionMode.INT8) or
511 (precision_mode == TrtPrecisionMode.INT8.lower())) and use_calibration)
512 if self._need_calibration and not is_dynamic_op:
513 logging.warn(
514 "INT8 precision mode with calibration is supported with "
515 "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
516 is_dynamic_op = True
518 self._is_dynamic_op = is_dynamic_op
519 if is_dynamic_op:
520 self._max_batch_size = None
521 if max_batch_size is not None:
522 logging.warn("When is_dynamic_op==True max_batch_size should be None")
523 else:
524 if not isinstance(max_batch_size, int):
525 raise ValueError("When is_dynamic_op==False max_batch_size should be "
526 "an integer")
527 self._max_batch_size = max_batch_size
529 self._conversion_params = TrtConversionParams(
530 max_workspace_size_bytes=max_workspace_size_bytes,
531 precision_mode=precision_mode,
532 minimum_segment_size=minimum_segment_size,
533 maximum_cached_engines=maximum_cached_engines,
534 use_calibration=use_calibration,
535 allow_build_at_runtime=True)
536 _check_conversion_params(self._conversion_params)
538 self._test_only_disable_non_trt_optimizers = False
540 def _run_conversion(self):
541 """Run Grappler's OptimizeGraph() tool to convert the graph."""
542 # Create custom ConfigProto for Grappler.
543 grappler_session_config = config_pb2.ConfigProto()
544 custom_rewriter_config = _get_tensorrt_rewriter_config(
545 conversion_params=self._conversion_params,
546 is_dynamic_op=self._is_dynamic_op,
547 max_batch_size=self._max_batch_size,
548 disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
549 use_implicit_batch=True)
550 grappler_session_config.graph_options.rewrite_options.CopyFrom(
551 custom_rewriter_config)
553 # Run Grappler.
554 self._converted_graph_def = tf_optimizer.OptimizeGraph(
555 grappler_session_config,
556 self._grappler_meta_graph_def,
557 graph_id=b"tf_graph")
558 self._converted = True
560 def _add_nodes_denylist(self):
561 if self._nodes_denylist:
562 collection_def = self._grappler_meta_graph_def.collection_def["train_op"]
563 denylist = collection_def.node_list.value
564 for i in self._nodes_denylist:
565 if isinstance(i, ops.Tensor):
566 denylist.append(_to_bytes(i.name))
567 else:
568 denylist.append(_to_bytes(i))
570 def _convert_graph_def(self):
571 """Convert the input GraphDef."""
572 graph = ops.Graph()
573 with graph.as_default():
574 importer.import_graph_def(self._input_graph_def, name="")
575 self._grappler_meta_graph_def = saver.export_meta_graph(
576 graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
577 self._add_nodes_denylist()
579 self._run_conversion()
581 def _collections_to_keep(self, collection_keys):
582 # TODO(laigd): currently we use the collection key to filter out
583 # collections that depend on variable ops, but this may miss some
584 # other user-defined collections. A better way would be to use
585 # CollectionDef::NodeList for the filtering.
586 collections_to_remove = (
587 ops.GraphKeys._VARIABLE_COLLECTIONS + [
588 ops.GraphKeys.TRAIN_OP, ops.GraphKeys.WHILE_CONTEXT,
589 ops.GraphKeys.COND_CONTEXT
590 ])
591 return [key for key in collection_keys if key not in collections_to_remove]
593 def _convert_saved_model(self):
594 """Convert the input SavedModel."""
595 graph = ops.Graph()
596 with session.Session(graph=graph) as sess:
597 input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
598 self._input_saved_model_dir)
599 input_signature_def = input_meta_graph_def.signature_def[
600 self._input_saved_model_signature_key]
602 def _gather_names(tensor_info):
603 """Get the node names from a TensorInfo."""
604 return {tensor_info[key].name.split(":")[0] for key in tensor_info}
606 # Get input and outputs from all SignatureDef.
607 output_node_names = _gather_names(input_signature_def.inputs).union(
608 _gather_names(input_signature_def.outputs))
610 # Preserve nodes in collection
611 for collection_key in self._collections_to_keep(
612 input_meta_graph_def.collection_def):
613 for op in sess.graph.get_collection(collection_key):
614 if isinstance(op, ops.Operation):
615 output_node_names.add(op.name.split(":")[0])
617 # Freeze the variables in the SavedModel graph and copy the frozen
618 # graph over.
619 frozen_graph_def = convert_to_constants.convert_variables_to_constants(
620 sess, sess.graph.as_graph_def(add_shapes=True),
621 list(output_node_names))
622 self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
623 self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
625 # Copy the collections that are not variables.
626 for collection_key in self._collections_to_keep(
627 input_meta_graph_def.collection_def):
628 self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom(
629 input_meta_graph_def.collection_def[collection_key])
631 self._add_nodes_denylist()
633 # Copy other information.
634 self._grappler_meta_graph_def.meta_info_def.CopyFrom(
635 input_meta_graph_def.meta_info_def)
636 self._grappler_meta_graph_def.signature_def[
637 self._input_saved_model_signature_key].CopyFrom(input_signature_def)
638 # TODO(laigd): maybe add back AssetFileDef.
640 self._run_conversion()
642 def convert(self):
643 """Run the TF-TRT conversion.
645 Returns:
646 The converted GraphDef for TF 1.x.
647 """
648 assert not self._converted
649 if self._input_graph_def:
650 self._convert_graph_def()
651 else:
652 self._convert_saved_model()
653 return self._converted_graph_def
655 def calibrate(self,
656 fetch_names,
657 num_runs,
658 feed_dict_fn=None,
659 input_map_fn=None):
660 """Run the calibration and return the calibrated GraphDef.
662 Args:
663 fetch_names: a list of output tensor name to fetch during calibration.
664 num_runs: number of runs of the graph during calibration.
665 feed_dict_fn: a function that returns a dictionary mapping input names (as
666 strings) in the GraphDef to be calibrated to values (e.g. Python list,
667 numpy arrays, etc). One and only one of `feed_dict_fn` and
668 `input_map_fn` should be specified.
669 input_map_fn: a function that returns a dictionary mapping input names (as
670 strings) in the GraphDef to be calibrated to Tensor objects. The values
671 of the named input tensors in the GraphDef to be calibrated will be
672 re-mapped to the respective `Tensor` values during calibration. One and
673 only one of `feed_dict_fn` and `input_map_fn` should be specified.
675 Raises:
676 ValueError: if the input combination is invalid.
677 RuntimeError: if this method is called in eager mode.
679 Returns:
680 The GraphDef after the calibration.
681 """
682 assert self._converted
683 assert self._need_calibration
684 assert not self._calibration_data_collected
686 if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
687 not input_map_fn):
688 raise ValueError(
689 "Should specify one and only one of feed_dict_fn and input_map_fn.")
691 if input_map_fn:
692 for k, v in input_map_fn().items():
693 if not isinstance(k, str):
694 raise ValueError("Keys of input_map_fn must be of type str")
695 if not isinstance(v, ops.Tensor):
696 raise ValueError("Values of input_map_fn must be of type tf.Tensor")
698 self._calibration_graph = ops.Graph()
699 with self._calibration_graph.as_default():
700 fetches = importer.import_graph_def(
701 self._converted_graph_def,
702 input_map=input_map_fn() if input_map_fn else None,
703 return_elements=fetch_names,
704 name="")
706 calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
707 if self._test_only_disable_non_trt_optimizers:
708 trt_utils.disable_non_trt_optimizers_in_rewriter_config(
709 calibrate_rewriter_cfg)
711 # Set allow_soft_placement=True to run the graph for calibration so that
712 # OPs supported by TensorRT but don't have a GPU implementation are allowed
713 # to execute on CPU.
714 calibrate_config = config_pb2.ConfigProto(
715 allow_soft_placement=True,
716 graph_options=config_pb2.GraphOptions(
717 rewrite_options=calibrate_rewriter_cfg))
719 with session.Session(
720 graph=self._calibration_graph,
721 config=calibrate_config) as calibration_sess:
722 for _ in range(num_runs):
723 calibration_sess.run(
724 fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)
726 # Maps device name to the corresponding get_calibration_data.
727 #
728 # TODO(laigd): a better way would be to use calibration_sess to list
729 # all the devices, add one get_calibration_data for each device, and
730 # fetch each such op for every resource until its found. This can work
731 # even when the device of the TRTEngineOp is empty or not fully specified.
732 device_to_get_resource_op_map = {}
734 with self._calibration_graph.as_default():
735 resource_name_input = array_ops.placeholder(dtypes.string)
737 for node in self._converted_graph_def.node:
738 if node.op == _TRT_ENGINE_OP_NAME:
739 # Adds the get_calibration_data op for the device if not done
740 # before. We only add one such op for each device.
741 # TODO(laigd): What if the device is empty?????
742 if node.device not in device_to_get_resource_op_map:
743 with self._calibration_graph.device(node.device):
744 serialized_resources_output = (
745 gen_trt_ops.get_calibration_data_op(resource_name_input))
746 device_to_get_resource_op_map[node.device] = (
747 serialized_resources_output)
749 # Get the calibration resource.
750 calibration_result = calibration_sess.run(
751 device_to_get_resource_op_map[node.device],
752 feed_dict={
753 resource_name_input: _get_canonical_engine_name(node.name)
754 })
755 node.attr["calibration_data"].s = calibration_result
757 self._calibration_data_collected = True
759 return self._converted_graph_def
761 def save(self, output_saved_model_dir):
762 """Save the converted graph as a SavedModel.
764 Args:
765 output_saved_model_dir: construct a SavedModel using the converted
766 GraphDef and save it to the specified directory. This option only works
767 when the input graph is loaded from a SavedModel, i.e. when
768 input_saved_model_dir is specified and input_graph_def is None in
769 __init__().
771 Raises:
772 ValueError: if the input to the converter is a GraphDef instead of a
773 SavedModel.
774 """
775 assert self._converted
776 if self._need_calibration:
777 assert self._calibration_data_collected
778 if self._input_graph_def:
779 raise ValueError(
780 "Not able to save to a SavedModel since input is a GraphDef")
782 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
783 """Restores collections that we need to keep."""
784 scope = ""
785 for key in collection_keys:
786 collection_def = src_meta_graph_def.collection_def[key]
787 kind = collection_def.WhichOneof("kind")
788 if kind is None:
789 logging.error(
790 "Cannot identify data type for collection %s. Skipping.", key)
791 continue
792 from_proto = ops.get_from_proto_function(key)
793 if from_proto and kind == "bytes_list":
794 proto_type = ops.get_collection_proto_type(key)
795 # It is assumed that there are no Variables Keys in collections
796 for value in collection_def.bytes_list.value:
797 proto = proto_type()
798 proto.ParseFromString(value)
799 try:
800 new_value = from_proto(proto, import_scope=scope)
801 except:
802 continue
803 dest_graph.add_to_collection(key, new_value)
804 else:
805 field = getattr(collection_def, kind)
806 if kind == "node_list":
807 for value in field.value:
808 name = ops.prepend_name_scope(value, scope)
809 # Since the graph has been optimized, the node may no longer
810 # exists
811 try:
812 col_op = dest_graph.as_graph_element(name)
813 except (TypeError, ValueError, KeyError):
814 continue
815 dest_graph.add_to_collection(key, col_op)
816 elif kind == "int64_list":
817 # NOTE(opensource): This force conversion is to work around the
818 # fact that Python2 distinguishes between int and long, while
819 # Python3 has only int.
820 for value in field.value:
821 dest_graph.add_to_collection(key, int(value))
822 else:
823 for value in field.value:
824 dest_graph.add_to_collection(key,
825 ops.prepend_name_scope(value, scope))
827 # Write the transformed graphdef as SavedModel.
828 saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
829 with ops.Graph().as_default():
830 importer.import_graph_def(self._converted_graph_def, name="")
831 _restore_collections(
832 ops.get_default_graph(), self._grappler_meta_graph_def,
833 self._collections_to_keep(
834 self._grappler_meta_graph_def.collection_def))
835 # We don't use any specific converter here.
836 with session.Session() as sess:
837 saved_model_builder.add_meta_graph_and_variables(
838 sess,
839 self._input_saved_model_tags,
840 signature_def_map=self._grappler_meta_graph_def.signature_def)
841 # Ignore other meta graphs from the input SavedModel.
842 saved_model_builder.save()
844def _get_resource_handle(name, device):
845 with ops.device(device):
846 return gen_trt_ops.create_trt_resource_handle(resource_name=name)
849def _remove_native_segments(input_func):
850 """Remove native segments from the input TF-TRT Converted Function.
852 Args:
853 input_func: provide the concrete function with native segment nodes. The
854 transformed output func will not contain any native segment nodes. All the
855 TRTEngineOp references will be deleted and reset to default empty func.
856 """
857 input_graph_def = input_func.graph.as_graph_def()
858 # Deleting the Native Segment node in each TRTEngineOp node.
859 nodes_deleted = 0
860 for func_id in reversed(range(len(input_graph_def.library.function))):
861 f = input_graph_def.library.function[func_id]
862 if "native_segment" in f.signature.name:
863 nodes_deleted += 1
864 while context.context().has_function(f.signature.name):
865 context.context().remove_function(f.signature.name)
866 del input_graph_def.library.function[func_id]
868 logging.info(
869 "Found and deleted native segments from "
870 f"{nodes_deleted} TRTEngineOp nodes."
871 )
873 # Deleting the references to `<EngineName>_native_segment`s.
874 # This helps TRTEngineOp constructor to not look for native segment handles
875 # during construction of graph for inference.
876 for node in input_graph_def.node:
877 if node.op == "TRTEngineOp":
878 del node.attr["segment_func"]
879 for func in input_graph_def.library.function:
880 for node in func.node_def:
881 if node.op == "TRTEngineOp":
882 del node.attr["segment_func"]
883 # Reconstruct the converted_func with the new graph
884 new_func = _construct_function_from_graph_def(input_func, input_graph_def)
886 return new_func
889class _TRTEngineResource(resource.TrackableResource):
890 """Class to track the serialized engines resource."""
892 def __init__(self,
893 resource_name,
894 filename,
895 maximum_cached_engines,
896 device="GPU"):
897 super(_TRTEngineResource, self).__init__(device=device)
898 self._resource_name = resource_name
899 # Track the serialized engine file in the SavedModel.
900 self._filename = self._track_trackable(
901 asset.Asset(filename), "_serialized_trt_resource_filename")
902 self._maximum_cached_engines = maximum_cached_engines
904 def _create_resource(self):
905 return _get_resource_handle(self._resource_name, self._resource_device)
907 def _initialize(self):
908 gen_trt_ops.initialize_trt_resource(
909 self.resource_handle,
910 self._filename,
911 max_cached_engines_count=self._maximum_cached_engines)
913 def _destroy_resource(self):
914 handle = _get_resource_handle(self._resource_name, self._resource_device)
915 with ops.device(self._resource_device):
916 gen_resource_variable_ops.destroy_resource_op(
917 handle, ignore_lookup_error=True)
920def _print_row(fields, positions, print_fn):
921 """Prints a row."""
922 line = ""
923 for i, field in enumerate(fields):
924 field = str(field)
925 end_line_pos = positions[i]
926 if i > 0:
927 line = line + " "
928 line = "{0:{min_length}}".format(line + field, min_length=end_line_pos)
930 if len(line) > end_line_pos:
931 line = line[:(end_line_pos - 4)] + " ..."
933 print_fn(line)
936def _construct_function_from_graph_def(func, graph_def, frozen_func=None):
937 """Rebuild function from graph_def."""
938 if frozen_func is None:
939 frozen_func = func
941 # If a function is converted, then the TF context contains the original
942 # function while the converted_graph_def contains the converted function.
943 # Remove the original function from the TF context in this case.
944 for f in graph_def.library.function:
945 while context.context().has_function(f.signature.name):
946 context.context().remove_function(f.signature.name)
948 captures = {
949 c[1].name.split(":")[0]: c[0]
950 for c in frozen_func.graph.captures
951 }
952 new_func = wrap_function.function_from_graph_def(
953 graph_def, [tensor.name for tensor in frozen_func.inputs],
954 [tensor.name for tensor in frozen_func.outputs], captures)
955 new_func.graph.structured_outputs = nest.pack_sequence_as(
956 func.graph.structured_outputs, new_func.graph.structured_outputs)
958 # Copy structured input signature from original function (used during
959 # serialization)
960 new_func.graph.structured_input_signature = (func.structured_input_signature)
962 return new_func
965def _apply_inlining(func):
966 """Apply an inlining optimization to the function's graph definition."""
967 graph_def = func.graph.as_graph_def()
969 # In some cases, a secondary implementation of the function (e.g. for GPU) is
970 # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in
971 # TF2 produces a CuDNN-based RNN for GPU).
972 # This function suppose to inline all functions calls, but "api_implements"
973 # prevents this from happening. Removing the attribute solves the problem.
974 # To learn more about "api_implements", see:
975 # tensorflow/core/grappler/optimizers/implementation_selector.h
976 for function in graph_def.library.function:
977 if "api_implements" in function.attr:
978 del function.attr["api_implements"]
980 meta_graph = saver.export_meta_graph(graph_def=graph_def, graph=func.graph)
982 # Clear the initializer_name for the variables collections, since they are not
983 # needed after saved to saved_model.
984 for name in [
985 "variables", "model_variables", "trainable_variables", "local_variables"
986 ]:
987 raw_list = []
988 for raw in meta_graph.collection_def["variables"].bytes_list.value:
989 variable = variable_pb2.VariableDef()
990 variable.ParseFromString(raw)
991 variable.ClearField("initializer_name")
992 raw_list.append(variable.SerializeToString())
993 meta_graph.collection_def[name].bytes_list.value[:] = raw_list
995 # Add a collection 'train_op' so that Grappler knows the outputs.
996 fetch_collection = meta_graph_pb2.CollectionDef()
997 for array in func.inputs + func.outputs:
998 fetch_collection.node_list.value.append(array.name)
999 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
1001 # Initialize RewriterConfig with everything disabled except function inlining.
1002 config = config_pb2.ConfigProto()
1003 rewrite_options = config.graph_options.rewrite_options
1004 rewrite_options.min_graph_nodes = -1 # do not skip small graphs
1005 rewrite_options.optimizers.append("function")
1007 new_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)
1009 return new_graph_def
1012def _annotate_variable_ops(func, graph_def):
1013 """Annotates variable operations with custom `_shape` attribute.
1015 This is required for the converters and shape inference. The graph
1016 definition is modified in-place.
1018 Args:
1019 func: Function represented by the graph definition.
1020 graph_def: Graph definition to be annotated in-place.
1022 Raises:
1023 RuntimeError: if some shapes cannot be annotated.
1024 """
1025 ph_shape_map = {}
1026 for ph, var in zip(func.graph.internal_captures, func.variables):
1027 ph_shape_map[ph.name] = var.shape
1028 # Construct a mapping of node names to nodes
1029 name_to_node = {node.name: node for node in graph_def.node}
1030 # Go through all the ReadVariableOp nodes in the graph def
1031 for node in graph_def.node:
1032 if node.op == "ReadVariableOp" or node.op == "ResourceGather":
1033 node_ = node
1034 # Go up the chain of identities to find a placeholder
1035 while name_to_node[node_.input[0]].op == "Identity":
1036 node_ = name_to_node[node_.input[0]]
1037 ph_name = node_.input[0] + ":0"
1038 if ph_name in ph_shape_map:
1039 shape = ph_shape_map[ph_name]
1040 node.attr["_shape"].shape.CopyFrom(shape.as_proto())
1041 else:
1042 raise RuntimeError(
1043 "Not found in the function captures: {}".format(ph_name))
1046def _save_calibration_table(node):
1047 try:
1048 calibration_table = gen_trt_ops.get_calibration_data_op(
1049 _get_canonical_engine_name(node.name))
1050 node.attr["calibration_data"].s = calibration_table.numpy()
1051 except (errors.UnknownError, errors.NotFoundError):
1052 logging.warning("Warning calibration error for %s", node.name)
1055def _convert_to_tensor(inp):
1056 try:
1057 if isinstance(inp, dict):
1058 args = []
1059 kwargs = {k: ops.convert_to_tensor(v) for k, v in inp.items()}
1060 else:
1061 kwargs = {}
1062 if isinstance(inp, (list, tuple)):
1063 args = map(ops.convert_to_tensor, inp)
1064 else:
1065 args = [ops.convert_to_tensor(inp)]
1066 except:
1067 error_msg = "Failed to convert input to tensor."
1068 logging.error(error_msg + "\ninp = `{0}`\n".format(inp))
1069 raise RuntimeError(error_msg)
1071 return args, kwargs
1074@tf_export("experimental.tensorrt.Converter", v1=[])
1075class TrtGraphConverterV2(object):
1076 """An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
1078 Windows support is provided experimentally. No guarantee is made regarding
1079 functionality or engineering support. Use at your own risk.
1081 There are several ways to run the conversion:
1083 1. FP32/FP16 precision
1085 ```python
1086 params = tf.experimental.tensorrt.ConversionParams(
1087 precision_mode='FP16')
1088 converter = tf.experimental.tensorrt.Converter(
1089 input_saved_model_dir="my_dir", conversion_params=params)
1090 converter.convert()
1091 converter.save(output_saved_model_dir)
1092 ```
1094 In this case, no TRT engines will be built or saved in the converted
1095 SavedModel. But if input data is available during conversion, we can still
1096 build and save the TRT engines to reduce the cost during inference (see
1097 option 2 below).
1099 2. FP32/FP16 precision with pre-built engines
1101 ```python
1102 params = tf.experimental.tensorrt.ConversionParams(
1103 precision_mode='FP16',
1104 # Set this to a large enough number so it can cache all the engines.
1105 maximum_cached_engines=16)
1106 converter = tf.experimental.tensorrt.Converter(
1107 input_saved_model_dir="my_dir", conversion_params=params)
1108 converter.convert()
1110 # Define a generator function that yields input data, and use it to execute
1111 # the graph to build TRT engines.
1112 def my_input_fn():
1113 for _ in range(num_runs):
1114 inp1, inp2 = ...
1115 yield inp1, inp2
1117 converter.build(input_fn=my_input_fn) # Generate corresponding TRT engines
1118 converter.save(output_saved_model_dir) # Generated engines will be saved.
1119 ```
1121 In this way, one engine will be built/saved for each unique input shapes of
1122 the TRTEngineOp. This is good for applications that cannot afford building
1123 engines during inference but have access to input data that is similar to
1124 the one used in production (for example, that has the same input shapes).
1125 Also, the generated TRT engines is platform dependent, so we need to run
1126 `build()` in an environment that is similar to production (e.g. with
1127 same type of GPU).
1129 3. INT8 precision and calibration with pre-built engines
1131 ```python
1132 params = tf.experimental.tensorrt.ConversionParams(
1133 precision_mode='INT8',
1134 # Currently only one INT8 engine is supported in this mode.
1135 maximum_cached_engines=1,
1136 use_calibration=True)
1137 converter = tf.experimental.tensorrt.Converter(
1138 input_saved_model_dir="my_dir", conversion_params=params)
1140 # Define a generator function that yields input data, and run INT8
1141 # calibration with the data. All input data should have the same shape.
1142 # At the end of convert(), the calibration stats (e.g. range information)
1143 # will be saved and can be used to generate more TRT engines with different
1144 # shapes. Also, one TRT engine will be generated (with the same shape as
1145 # the calibration data) for save later.
1146 def my_calibration_input_fn():
1147 for _ in range(num_runs):
1148 inp1, inp2 = ...
1149 yield inp1, inp2
1151 converter.convert(calibration_input_fn=my_calibration_input_fn)
1153 # (Optional) Generate more TRT engines offline (same as the previous
1154 # option), to avoid the cost of generating them during inference.
1155 def my_input_fn():
1156 for _ in range(num_runs):
1157 inp1, inp2 = ...
1158 yield inp1, inp2
1159 converter.build(input_fn=my_input_fn)
1161 # Save the TRT engine and the engines.
1162 converter.save(output_saved_model_dir)
1163 ```
1164 4. To use dynamic shape, we need to call the build method with an input
1165 function to generate profiles. This step is similar to the INT8 calibration
1166 step described above. The converter also needs to be created with
1167 use_dynamic_shape=True and one of the following profile_strategies for
1168 creating profiles based on the inputs produced by the input function:
1169 * `Range`: create one profile that works for inputs with dimension values
1170 in the range of [min_dims, max_dims] where min_dims and max_dims are
1171 derived from the provided inputs.
1172 * `Optimal`: create one profile for each input. The profile only works for
1173 inputs with the same dimensions as the input it is created for. The GPU
1174 engine will be run with optimal performance with such inputs.
1175 * `Range+Optimal`: create the profiles for both `Range` and `Optimal`.
1176 """
1178 def _verify_profile_strategy(self, strategy):
1179 supported_strategies = [s.lower() for s in supported_profile_strategies()]
1180 if strategy.lower() not in supported_strategies:
1181 raise ValueError(
1182 ("profile_strategy '{}' is not supported. It should be one of {}"
1183 ).format(strategy, supported_profile_strategies()))
1184 if strategy == "ImplicitBatchModeCompatible":
1185 logging.warn(
1186 "ImplicitBatchModeCompatible strategy is deprecated, and"
1187 " using it may result in errors during engine building. Please"
1188 " consider using a different profile strategy.")
1190 @deprecation.deprecated_args(None,
1191 "Use individual converter parameters instead",
1192 "conversion_params")
1193 def __init__(self,
1194 input_saved_model_dir=None,
1195 input_saved_model_tags=None,
1196 input_saved_model_signature_key=None,
1197 use_dynamic_shape=None,
1198 dynamic_shape_profile_strategy=None,
1199 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
1200 precision_mode=TrtPrecisionMode.FP32,
1201 minimum_segment_size=3,
1202 maximum_cached_engines=1,
1203 use_calibration=True,
1204 allow_build_at_runtime=True,
1205 conversion_params=None):
1206 """Initialize the converter.
1208 Args:
1209 input_saved_model_dir: the directory to load the SavedModel which contains
1210 the input graph to transforms. Required.
1211 input_saved_model_tags: list of tags to load the SavedModel.
1212 input_saved_model_signature_key: the key of the signature to optimize the
1213 graph for.
1214 use_dynamic_shape: whether to enable dynamic shape support. None is
1215 equivalent to False in the current implementation.
1216 dynamic_shape_profile_strategy: one of the strings in
1217 supported_profile_strategies(). None is equivalent to Range in the
1218 current implementation.
1219 max_workspace_size_bytes: the maximum GPU temporary memory that the TRT
1220 engine can use at execution time. This corresponds to the
1221 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
1222 precision_mode: one of the strings in
1223 TrtPrecisionMode.supported_precision_modes().
1224 minimum_segment_size: the minimum number of nodes required for a subgraph
1225 to be replaced by TRTEngineOp.
1226 maximum_cached_engines: max number of cached TRT engines for dynamic TRT
1227 ops. Created TRT engines for a dynamic dimension are cached. If the
1228 number of cached engines is already at max but none of them supports the
1229 input shapes, the TRTEngineOp will fall back to run the original TF
1230 subgraph that corresponds to the TRTEngineOp.
1231 use_calibration: this argument is ignored if precision_mode is not INT8.
1232 If set to True, a calibration graph will be created to calibrate the
1233 missing ranges. The calibration graph must be converted to an inference
1234 graph by running calibration with calibrate(). If set to False,
1235 quantization nodes will be expected for every tensor in the graph
1236 (excluding those which will be fused). If a range is missing, an error
1237 will occur. Please note that accuracy may be negatively affected if
1238 there is a mismatch between which tensors TRT quantizes and which
1239 tensors were trained with fake quantization.
1240 allow_build_at_runtime: whether to allow building TensorRT engines during
1241 runtime if no prebuilt TensorRT engine can be found that can handle the
1242 given inputs during runtime, then a new TensorRT engine is built at
1243 runtime if allow_build_at_runtime=True, and otherwise native TF is used.
1244 conversion_params: a TrtConversionParams instance (deprecated).
1246 Raises:
1247 ValueError: if the combination of the parameters is invalid.
1248 """
1249 assert context.executing_eagerly()
1250 if conversion_params is None:
1251 conversion_params = TrtConversionParams(
1252 max_workspace_size_bytes=max_workspace_size_bytes,
1253 precision_mode=precision_mode,
1254 minimum_segment_size=minimum_segment_size,
1255 maximum_cached_engines=maximum_cached_engines,
1256 use_calibration=use_calibration,
1257 allow_build_at_runtime=allow_build_at_runtime)
1259 _check_trt_version_compatibility()
1260 _check_conversion_params(conversion_params, is_v2=True)
1262 self._conversion_params = conversion_params
1263 self._input_saved_model_dir = input_saved_model_dir
1264 self._input_saved_model_tags = (
1265 input_saved_model_tags or [tag_constants.SERVING])
1266 self._input_saved_model_signature_key = (
1267 input_saved_model_signature_key or
1268 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
1269 self.freeze = not trt_utils.is_experimental_feature_activated(
1270 "disable_graph_freezing")
1272 self._need_calibration = ((
1273 (conversion_params.precision_mode == TrtPrecisionMode.INT8) or
1274 (conversion_params.precision_mode == TrtPrecisionMode.INT8.lower())) and
1275 conversion_params.use_calibration)
1277 self._calibration_input_fn = None
1279 self._converted = False
1280 self._device = None
1281 self._build_called_once = False
1282 self._calibrated = False
1284 if use_dynamic_shape is None:
1285 self._use_dynamic_shape = False
1286 else:
1287 self._use_dynamic_shape = use_dynamic_shape
1289 if not self.freeze and not self._use_dynamic_shape:
1290 logging.warn(
1291 "Disabling graph freezing is only possible in dynamic shape mode."
1292 " The graph will be frozen.")
1293 self.freeze = True
1295 self._profile_strategy = "Unknown"
1296 if self._use_dynamic_shape:
1297 if dynamic_shape_profile_strategy is None:
1298 self._profile_strategy = PROFILE_STRATEGY_RANGE
1299 else:
1300 self._verify_profile_strategy(dynamic_shape_profile_strategy)
1301 self._profile_strategy = dynamic_shape_profile_strategy
1303 # Fields to support TF-TRT testing and shouldn't be used for other purpose.
1304 self._test_only_disable_non_trt_optimizers = False
1306 def _need_trt_profiles(self):
1307 return self._use_dynamic_shape
1309 def _run_conversion(self, meta_graph_def):
1310 """Run Grappler's OptimizeGraph() tool to convert the graph.
1312 Args:
1313 meta_graph_def: the MetaGraphDef instance to run the optimizations on.
1315 Returns:
1316 The optimized GraphDef.
1317 """
1318 grappler_session_config = config_pb2.ConfigProto()
1319 # Always set `allow_build_at_runtime` for offline TensorRT engine building.
1320 custom_rewriter_config = _get_tensorrt_rewriter_config(
1321 conversion_params=self._conversion_params._replace(
1322 allow_build_at_runtime=True),
1323 is_dynamic_op=True,
1324 max_batch_size=None,
1325 disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
1326 use_implicit_batch=not self._use_dynamic_shape,
1327 profile_strategy=self._profile_strategy)
1328 grappler_session_config.graph_options.rewrite_options.CopyFrom(
1329 custom_rewriter_config)
1330 return tf_optimizer.OptimizeGraph(
1331 grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
1333 def _for_each_trt_node(self, graph_def, fn):
1334 """Helper method to manipulate all TRTEngineOps in a GraphDef."""
1335 for node in graph_def.node:
1336 if node.op == _TRT_ENGINE_OP_NAME:
1337 fn(node)
1338 for func in graph_def.library.function:
1339 for node in func.node_def:
1340 if node.op == _TRT_ENGINE_OP_NAME:
1341 fn(node)
1343 def _execute_calibration(self, calibration_input_fn):
1344 """Run INT8 calibration with the provided input generator function."""
1345 for inp in calibration_input_fn():
1346 args, kwargs = _convert_to_tensor(inp)
1347 self._converted_func(*args, **kwargs)
1349 self._for_each_trt_node(self._converted_graph_def, _save_calibration_table)
1351 # Rebuild the function since calibration has changed the graph.
1352 self._converted_func = _construct_function_from_graph_def(
1353 self._converted_func, self._converted_graph_def)
1354 self._calibrated = True
1356 # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
1357 # use it here (b/124792963).
1358 def convert(self, calibration_input_fn=None):
1359 """Convert the input SavedModel in 2.0 format.
1361 Args:
1362 calibration_input_fn: a generator function that yields input data as a
1363 list or tuple or dict, which will be used to execute the converted
1364 signature for calibration. All the returned input data should have the
1365 same shape. Example: `def input_fn(): yield input1, input2, input3`
1367 If dynamic_shape_mode==False, (or if the graph has static input shapes)
1368 then we run calibration and build the calibrated engine during
1369 conversion.
1371 If dynamic_shape_mode==True (and the graph has any unknown input
1372 shape), then the reference to calibration_input_fn is stored, and the
1373 calibration is actually performed when we build the engine (see
1374 build()).
1376 Raises:
1377 ValueError: if the input combination is invalid.
1379 Returns:
1380 The TF-TRT converted Function.
1381 """
1382 assert not self._converted
1384 # Creating an empty tensor to fetch queried device
1385 device_requested = array_ops.zeros([]).device
1387 if "gpu" not in device_requested.lower():
1388 raise ValueError(f"Specified device is not a GPU: {device_requested}")
1390 if "gpu:0" not in device_requested.lower():
1391 self._device = device_requested
1392 logging.info(f"Placing imported graph from "
1393 f"`{self._input_saved_model_dir}` on device: {self._device}")
1395 if (self._need_calibration and not calibration_input_fn):
1396 raise ValueError("Should specify calibration_input_fn because INT8 "
1397 "calibration is needed")
1398 if (not self._need_calibration and calibration_input_fn):
1399 raise ValueError("Should not specify calibration_input_fn because INT8 "
1400 "calibration is not needed")
1402 self._saved_model = load.load(self._input_saved_model_dir,
1403 self._input_saved_model_tags)
1404 func = self._saved_model.signatures[self._input_saved_model_signature_key]
1405 if self.freeze:
1406 frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
1407 else:
1408 inlined_graph_def = _apply_inlining(func)
1409 _annotate_variable_ops(func, inlined_graph_def)
1410 frozen_func = _construct_function_from_graph_def(func, inlined_graph_def)
1411 frozen_graph_def = frozen_func.graph.as_graph_def()
1413 # Clear any prior device assignments
1414 logging.info("Clearing prior device assignments in loaded saved model")
1415 for node in frozen_graph_def.node:
1416 node.device = ""
1418 if self._device is None:
1419 grappler_meta_graph_def = saver.export_meta_graph(
1420 graph_def=frozen_graph_def, graph=frozen_func.graph)
1421 else:
1422 with ops.Graph().as_default() as graph, ops.device(self._device):
1423 importer.import_graph_def(frozen_graph_def, name="")
1424 grappler_meta_graph_def = saver.export_meta_graph(
1425 graph_def=graph.as_graph_def(), graph=graph)
1427 # Add a collection 'train_op' so that Grappler knows the outputs.
1428 fetch_collection = meta_graph_pb2.CollectionDef()
1429 for array in frozen_func.inputs + frozen_func.outputs:
1430 fetch_collection.node_list.value.append(array.name)
1431 grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
1432 fetch_collection)
1434 # Run TRT optimizer in Grappler to convert the graph.
1435 self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
1436 self._converted_func = _construct_function_from_graph_def(
1437 func, self._converted_graph_def, frozen_func)
1439 if self._need_calibration:
1440 # Execute calibration here only if not in dynamic shape mode.
1441 if not self._need_trt_profiles():
1442 self._execute_calibration(calibration_input_fn)
1443 else:
1444 self._calibration_input_fn = calibration_input_fn
1446 self._converted = True
1448 graphviz_path = os.environ.get("TF_TRT_EXPORT_GRAPH_VIZ_PATH", default=None)
1449 if graphviz_path is not None:
1450 try:
1451 trt_utils.draw_graphdef_as_graphviz(
1452 graphdef=self._converted_func.graph.as_graph_def(add_shapes=True),
1453 dot_output_filename=graphviz_path)
1454 except Exception as e:
1455 logging.error("An Exception occured during the export of the graph "
1456 f"visualization: {e}")
1458 return self._converted_func
1460 def build(self, input_fn):
1461 """Run inference with converted graph in order to build TensorRT engines.
1463 If the conversion requires INT8 calibration, then a reference to the
1464 calibration function was stored during the call to convert(). Calibration
1465 will be performed while we build the TensorRT engines.
1467 Args:
1468 input_fn: a generator function that provides the input data as a single
1469 array, OR a list or tuple of the arrays OR a dict, which will be used
1470 to execute the converted signature to generate TRT engines.
1471 Example 1:
1472 `def input_fn():
1473 # Let's assume a network with 1 input tensor.
1474 # We generate 2 sets of dummy input data:
1475 input_shapes = [(1, 16), # 1st shape
1476 (2, 32)] # 2nd shape
1477 for shapes in input_shapes:
1478 # return an input tensor
1479 yield np.zeros(shape).astype(np.float32)'
1481 Example 2:
1482 `def input_fn():
1483 # Let's assume a network with 2 input tensors.
1484 # We generate 3 sets of dummy input data:
1485 input_shapes = [[(1, 16), (2, 16)], # 1st input list
1486 [(2, 32), (4, 32)], # 2nd list of two tensors
1487 [(4, 32), (8, 32)]] # 3rd input list
1488 for shapes in input_shapes:
1489 # return a list of input tensors
1490 yield [np.zeros(x).astype(np.float32) for x in shapes]`
1492 Raises:
1493 NotImplementedError: build() is already called.
1494 RuntimeError: the input_fx is None.
1495 """
1496 if self._build_called_once:
1497 raise NotImplementedError("build() is already called. It is not "
1498 "supported to call build() more than once.")
1499 if not input_fn:
1500 raise RuntimeError("input_fn is None. Method build() needs input_fn "
1501 "to be specified in order to build TensorRT engines")
1502 if not self._converted:
1503 raise RuntimeError("Need to call convert() before build()")
1504 if (self._need_calibration and not self._calibrated and
1505 self._calibration_input_fn is None):
1506 raise RuntimeError("Need to provide the calibration_input_fn arg while "
1507 "calling convert().")
1509 def _set_profile_generation_mode(value, node):
1510 node.attr["_profile_generation_mode"].b = value
1512 if self._need_trt_profiles():
1513 # Enable profile generation.
1514 self._for_each_trt_node(self._converted_graph_def,
1515 partial(_set_profile_generation_mode, True))
1516 # Profile generation is enabled using the _profile_generation_mode
1517 # attribute of the TRTEngineOps. We need to rebuild the function to
1518 # change this attribute.
1519 func = _construct_function_from_graph_def(self._converted_func,
1520 self._converted_graph_def)
1521 else:
1522 func = self._converted_func
1524 first_input = None
1525 # Run inference:
1526 # Builds TRT engines if self._need_trt_profiles is False.
1527 # Builds TRT optimization profiles if self._need_trt_profiles is True.
1528 for inp in input_fn():
1529 if first_input is None:
1530 first_input = inp
1531 args, kwargs = _convert_to_tensor(inp)
1532 func(*args, **kwargs)
1534 if self._need_trt_profiles():
1535 # Disable profile generation.
1536 self._for_each_trt_node(self._converted_graph_def,
1537 partial(_set_profile_generation_mode, False))
1539 # Run calibration if required, this would have been skipped in
1540 # the convert step
1541 if self._need_calibration and not self._calibrated:
1542 self._execute_calibration(self._calibration_input_fn)
1543 # calibration also builds the engine
1544 else:
1545 # Use the first input in explicit batch mode to build TensorRT engines
1546 # after generating all the profiles. The first input is used but any of
1547 # the inputs can be used because the shape of this input does not
1548 # determine the engine and instead the shapes collected in profiles
1549 # determine the engine.
1550 args, kwargs = _convert_to_tensor(first_input)
1551 self._converted_func(*args, **kwargs)
1553 self._build_called_once = True
1555 def save(self,
1556 output_saved_model_dir,
1557 save_gpu_specific_engines=True,
1558 options=None):
1559 """Save the converted SavedModel.
1561 Args:
1562 output_saved_model_dir: directory to saved the converted SavedModel.
1563 save_gpu_specific_engines: whether to save TRT engines that have been
1564 built. When True, all engines are saved and when False, the engines
1565 are not saved and will be rebuilt at inference time. By using
1566 save_gpu_specific_engines=False after doing INT8 calibration, inference
1567 can be done on different GPUs than the GPU that the model was calibrated
1568 and saved on.
1569 options: `tf.saved_model.SaveOptions` object for configuring save options.
1570 Raises:
1571 RuntimeError: if the needed calibration hasn't been done.
1572 """
1573 assert self._converted
1575 # 'remove_native_segments': setting this value to True removes native segments
1576 # associated with each TRT engine. This option can be used to reduce the size
1577 # of the converted model. Please note that a converted model without native
1578 # segments can't be used for collecting profiles, building or re-converting.
1579 # The reduced model can only be used for inference when no native segments
1580 # are required for computation. When remove_native_segments flag is set to
1581 # True, the converted_graph_def needs to be reduced before saved_model
1582 # function serialization.
1583 if trt_utils.is_experimental_feature_activated("remove_native_segments"):
1584 logging.info(
1585 "'remove_native_segments' experimental feature is enabled"
1586 " during saving of converted SavedModel."
1587 )
1588 self._converted_func = _remove_native_segments(self._converted_func)
1589 self._converted_graph_def = self._converted_func.graph.as_graph_def()
1591 if self._need_calibration and not self._calibrated:
1592 raise RuntimeError("A model that requires INT8 calibration has to be "
1593 "built before saving it. Call build() to build and "
1594 "calibrate the TensorRT engines.")
1595 # Serialize the TRT engines in the cache if any, and create trackable
1596 # resource to track them.
1597 engine_asset_dir = tempfile.mkdtemp()
1598 resource_map = {}
1600 def _serialize_and_track_engine(node):
1601 """Serialize TRT engines in the cache and track them."""
1602 # Don't dump the same cache twice.
1603 canonical_engine_name = _get_canonical_engine_name(node.name)
1604 if canonical_engine_name in resource_map:
1605 return
1607 filename = os.path.join(engine_asset_dir,
1608 "trt-serialized-engine." + canonical_engine_name)
1610 try:
1611 gen_trt_ops.serialize_trt_resource(
1612 resource_name=canonical_engine_name,
1613 filename=filename,
1614 delete_resource=True,
1615 save_gpu_specific_engines=save_gpu_specific_engines)
1616 except errors.NotFoundError:
1617 logging.info(
1618 "Could not find %s in TF-TRT cache. "
1619 "This can happen if build() is not called, "
1620 "which means TensorRT engines will be built "
1621 "and cached at runtime.", canonical_engine_name)
1622 return
1624 # TODO(laigd): add an option for the user to choose the device.
1625 resource_map[canonical_engine_name] = _TRTEngineResource(
1626 canonical_engine_name, filename,
1627 self._conversion_params.maximum_cached_engines)
1629 self._for_each_trt_node(self._converted_graph_def,
1630 _serialize_and_track_engine)
1631 # If the graph is frozen, tracked variables are not needed by the converted model.
1632 trackable = autotrackable.AutoTrackable(
1633 ) if self.freeze else self._saved_model
1634 trackable.trt_engine_resources = resource_map
1636 # Set allow_build_at_runtime=False if asked by user.
1637 #
1638 # This attribute is set here because build() needs it to be True in order to
1639 # build engines.
1640 if not self._conversion_params.allow_build_at_runtime:
1642 def _reset_allow_build_at_runtime(node):
1643 node.attr["_allow_build_at_runtime"].b = False
1645 self._for_each_trt_node(self._converted_graph_def,
1646 _reset_allow_build_at_runtime)
1647 # Rebuild the function since a node attribute changed above
1648 reset_converted_func = wrap_function.function_from_graph_def(
1649 self._converted_graph_def,
1650 [tensor.name for tensor in self._converted_func.inputs],
1651 [tensor.name for tensor in self._converted_func.outputs])
1652 reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
1653 self._converted_func.graph.structured_outputs,
1654 reset_converted_func.graph.structured_outputs)
1655 reset_converted_func.graph.structured_input_signature = (
1656 self._converted_func.structured_input_signature)
1657 self._converted_func = reset_converted_func
1659 # Rewrite the signature map using the optimized ConcreteFunction.
1660 signatures = {self._input_saved_model_signature_key: self._converted_func}
1661 save.save(trackable, output_saved_model_dir, signatures, options=options)
1663 def summary(self, line_length=160, detailed=True, print_fn=None):
1664 """This method describes the results of the conversion by TF-TRT.
1666 It includes information such as the name of the engine, the number of nodes
1667 per engine, the input and output dtype, along with the input shape of each
1668 TRTEngineOp.
1670 Args:
1671 line_length: Default line length when printing on the console. Minimum 160
1672 characters long.
1673 detailed: Whether or not to show the nodes inside each TRTEngineOp.
1674 print_fn: Print function to use. Defaults to `print`. It will be called on
1675 each line of the summary. You can set it to a custom function in order
1676 to capture the string summary.
1678 Raises:
1679 RuntimeError: if the graph is not converted.
1680 """
1681 if not self._converted:
1682 raise RuntimeError(
1683 f"Impossible to call `{self.__class__.__name__}.summary()` before "
1684 f"calling {self.__class__.__name__}.convert()`.")
1686 if line_length < 160:
1687 raise ValueError(f"Invalid `line_length` value has been received: "
1688 f"{line_length}. Minimum: 160.")
1690 if print_fn is None:
1691 print_fn = print
1693 # positions are percentage of `line_length`. positions[i]+1 is the starting
1694 # position for (i+1)th field. We also make sure that the last char printed
1695 # for each field is a space.
1696 columns = [
1697 # (column name, column size in % of line)
1698 ("TRTEngineOP Name", .20), # 20%
1699 ("Device", .09), # 29%
1700 ("# Nodes", .05), # 34%
1701 ("# Inputs", .09), # 43%
1702 ("# Outputs", .09), # 52%
1703 ("Input DTypes", .12), # 64%
1704 ("Output Dtypes", .12), # 76%
1705 ("Input Shapes", .12), # 88%
1706 ("Output Shapes", .12) # 100%
1707 ]
1709 positions = [int(line_length * p) for _, p in columns]
1710 positions = np.cumsum(positions).tolist()
1711 headers = [h for h, _ in columns]
1713 _print_row(headers, positions, print_fn=print_fn)
1714 print_fn("=" * line_length)
1716 n_engines = 0
1717 n_ops_converted = 0
1718 n_ops_not_converted = 0
1720 graphdef = self._converted_func.graph.as_graph_def(add_shapes=True)
1722 trtengineops_dict = dict()
1723 for node in graphdef.node:
1724 if node.op != "TRTEngineOp":
1725 n_ops_not_converted += 1
1726 continue
1727 else:
1728 trtengineops_dict[node.name] = node
1729 n_engines += 1
1731 for name, node in sorted(trtengineops_dict.items()):
1732 node_device = node.device.split("/")[-1]
1733 in_shapes = trt_utils.get_node_io_shapes(node, "input_shapes")
1734 out_shapes = trt_utils.get_node_io_shapes(node, "_output_shapes")
1735 in_dtypes = trt_utils.get_trtengineop_io_dtypes(node, "InT")
1736 out_dtypes = trt_utils.get_trtengineop_io_dtypes(node, "OutT")
1737 in_nodes_count = trt_utils.get_trtengineop_io_nodes_count(node, "InT")
1738 out_nodes_count = trt_utils.get_trtengineop_io_nodes_count(node, "OutT")
1739 node_count, converted_ops_dict = trt_utils.get_trtengineop_node_op_count(
1740 graphdef, name)
1742 n_ops_converted += node_count
1744 if n_engines != 1:
1745 print_fn(f"\n{'-'*40}\n")
1747 _print_row(
1748 fields=[
1749 name, node_device, node_count, in_nodes_count, out_nodes_count,
1750 in_dtypes, out_dtypes, in_shapes, out_shapes
1751 ],
1752 positions=positions,
1753 print_fn=print_fn)
1755 if detailed:
1756 print_fn()
1757 for key, value in sorted(dict(converted_ops_dict).items()):
1758 print_fn(f"\t- {key}: {value}x")
1760 print_fn(f"\n{'='*line_length}")
1761 print_fn(f"[*] Total number of TensorRT engines: {n_engines}")
1762 total_ops = n_ops_not_converted + n_ops_converted
1763 conversion_ratio = n_ops_converted / total_ops * 100
1764 print_fn(f"[*] % of OPs Converted: {conversion_ratio:.2f}% "
1765 f"[{n_ops_converted}/{total_ops}]\n")
1768# TODO(laigd): use TrtConversionParams here.
1769def create_inference_graph(
1770 input_graph_def,
1771 outputs,
1772 max_batch_size=1,
1773 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
1774 precision_mode=TrtPrecisionMode.FP32,
1775 minimum_segment_size=3,
1776 is_dynamic_op=False,
1777 maximum_cached_engines=1,
1778 input_saved_model_dir=None,
1779 input_saved_model_tags=None,
1780 input_saved_model_signature_key=None,
1781 output_saved_model_dir=None):
1782 """Python wrapper for the TRT transformation.
1784 Args:
1785 input_graph_def: a GraphDef object containing a model to be transformed. If
1786 set to None, the graph will be read from the SavedModel loaded from
1787 input_saved_model_dir.
1788 outputs: list of tensors or node names for the model outputs. Only used when
1789 input_graph_def is not None.
1790 max_batch_size: max size for the input batch.
1791 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
1792 engine can use at execution time. This corresponds to the 'workspaceSize'
1793 parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
1794 precision_mode: one of TrtPrecisionMode.supported_precision_modes().
1795 minimum_segment_size: the minimum number of nodes required for a subgraph to
1796 be replaced by TRTEngineOp.
1797 is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
1798 network and engine at run time.
1799 maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
1800 If the number of cached engines is already at max but none of them can
1801 serve the input, the TRTEngineOp will fall back to run the TF function
1802 based on which the TRTEngineOp is created.
1803 input_saved_model_dir: the directory to load the SavedModel which contains
1804 the input graph to transforms. Used only when input_graph_def is None.
1805 input_saved_model_tags: list of tags to load the SavedModel.
1806 input_saved_model_signature_key: the key of the signature to optimize the
1807 graph for.
1808 output_saved_model_dir: if not None, construct a SavedModel using the
1809 returned GraphDef and save it to the specified directory. This option only
1810 works when the input graph is loaded from a SavedModel, i.e. when
1811 input_saved_model_dir is specified and input_graph_def is None.
1813 Returns:
1814 A GraphDef transformed from input_graph_def (or the SavedModel graph def
1815 loaded from input_saved_model_dir, if input_graph_def is not present), where
1816 all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
1817 function is added for each of the subgraphs.
1819 If is_dynamic_op is True, each TRTEngineOp will contain a serialized
1820 subgraph GraphDef, which will be converted to a TRT engine at execution time
1821 and the TRT engine will be cached for future usage. A new TRT engine will be
1822 created each time when none of the cached engines match the input shapes. If
1823 it fails to execute the TRT engine or the number of cached engines reaches
1824 maximum_cached_engines, the op will fall back to call the corresponding TF
1825 function.
1827 If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
1828 engine created from the corresponding subgraph. No more engines will be
1829 created on the fly, and the op will fall back to call the corresponding TF
1830 function when it fails to execute the engine.
1832 Raises:
1833 ValueError: if the combination of the parameters is invalid.
1834 """
1835 trt_converter = TrtGraphConverter(
1836 input_saved_model_dir=input_saved_model_dir,
1837 input_saved_model_tags=input_saved_model_tags,
1838 input_saved_model_signature_key=input_saved_model_signature_key,
1839 input_graph_def=input_graph_def,
1840 nodes_denylist=outputs,
1841 max_batch_size=max_batch_size,
1842 max_workspace_size_bytes=max_workspace_size_bytes,
1843 precision_mode=precision_mode,
1844 minimum_segment_size=minimum_segment_size,
1845 is_dynamic_op=is_dynamic_op,
1846 maximum_cached_engines=maximum_cached_engines,
1847 use_calibration=False)
1848 converted_graph_def = trt_converter.convert()
1849 if output_saved_model_dir:
1850 trt_converter.save(output_saved_model_dir)
1851 return converted_graph_def