Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/tensorrt/utils.py: 17%
138 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Exposes the Python wrapper conversion to trt_graph."""
17import collections
18import os
19import re
21from packaging import version
23from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils
24from tensorflow.core.protobuf import rewriter_config_pb2
25from tensorflow.python.framework import dtypes
def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
  """Turns off every non-TRT Grappler optimization in `rewriter_config`.

  The meta optimizer is explicitly kept enabled, since it needs to be ON
  for TF-TRT to run.

  Args:
    rewriter_config: a `RewriterConfig` proto; it is modified in place.
  """
  config_cls = rewriter_config_pb2.RewriterConfig

  # Optimizer toggles that are simply switched to OFF.
  toggles_to_disable = (
      "arithmetic_optimization",
      "auto_mixed_precision",
      "constant_folding",
      "debug_stripper",
      "dependency_optimization",
      "function_optimization",
      "implementation_selector",
      "layout_optimizer",
      "loop_optimization",
      "pin_to_host_optimization",
      "remapping",
      "scoped_allocator_optimization",
      "shape_optimization",
  )
  for field_name in toggles_to_disable:
    setattr(rewriter_config, field_name, config_cls.OFF)

  # Settings that are not plain OFF toggles.
  rewriter_config.auto_parallel.enable = False
  # This one needs to be ON (i.e. not disabled) to allow TF-TRT.
  rewriter_config.disable_meta_optimizer = False
  rewriter_config.disable_model_pruning = True
  rewriter_config.memory_optimization = config_cls.NO_MEM_OPT
  rewriter_config.min_graph_nodes = -1
def version_tuple_to_string(ver_tuple):
  """Formats a 3-element version tuple as a dotted string.

  For example, `(8, 2, 1)` becomes `"8.2.1"`. Each element is converted with
  `str()`.

  Args:
    ver_tuple: a tuple of exactly three elements (checked with `assert`).

  Returns:
    The elements joined with ".".
  """
  assert isinstance(ver_tuple, tuple)
  assert len(ver_tuple) == 3
  return ".".join(map(str, ver_tuple))
def _is_tensorrt_version_greater_equal(trt_ver, target_ver):
  """Returns True if version tuple `trt_ver` is at least `target_ver`.

  Both tuples are formatted with `version_tuple_to_string` and compared as
  `packaging.version.Version` objects.
  """
  current, required = (
      version.Version(version_tuple_to_string(ver))
      for ver in (trt_ver, target_ver))
  return required <= current
def is_linked_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Checks the version from `get_linked_tensorrt_version()` against a target.

  Returns True if that version is >= `(major, minor, patch)`.
  """
  required = (major, minor, patch)
  return _is_tensorrt_version_greater_equal(
      _pywrap_py_utils.get_linked_tensorrt_version(), required)
def is_loaded_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Checks the version from `get_loaded_tensorrt_version()` against a target.

  Returns True if that version is >= `(major, minor, patch)`.
  """
  required = (major, minor, patch)
  return _is_tensorrt_version_greater_equal(
      _pywrap_py_utils.get_loaded_tensorrt_version(), required)
def is_experimental_feature_activated(feature_name):
  """Determines if a TF-TRT experimental feature is enabled.

  This helper function checks if an experimental feature was enabled using
  the environment variable `TF_TRT_EXPERIMENTAL_FEATURES=feature_1,feature_2`.
  Whitespace around each listed name is ignored. Empty entries never match,
  e.g. those from an unset variable or a trailing comma.

  Args:
    feature_name: Name of the feature being tested for activation.

  Returns:
    True if `feature_name` is listed in `TF_TRT_EXPERIMENTAL_FEATURES`,
    False otherwise.
  """
  raw_value = os.environ.get("TF_TRT_EXPERIMENTAL_FEATURES", "")
  # Without the emptiness filter, "".split(",") == [""] would report the
  # empty feature name as enabled when the variable is unset.
  enabled_features = {
      name.strip() for name in raw_value.split(",") if name.strip()
  }
  return feature_name in enabled_features
def _convert_dtype_id_to_str(dtype):
  """Maps a dtype id, or an iterable of dtype ids, to dtype name string(s).

  Ids are looked up in `dtypes._TYPE_TO_STRING`. An `int` yields a single
  string; any other input is iterated and yields a list of strings.
  """
  type_names = dtypes._TYPE_TO_STRING  # pylint: disable=protected-access
  if not isinstance(dtype, int):
    return [type_names[type_id] for type_id in dtype]
  return type_names[dtype]
def get_node_compute_dtype(node):
  """Returns the compute DType of a GraphDef Node as a string.

  Attrs are probed in order. By default a TF node's compute dtype is stored
  under `T`, except for TRTEngineOp (`precision_mode`), Cast (`DstT`) and
  Placeholder (`dtype`) nodes.

  Args:
    node: a GraphDef `NodeDef`. It is not modified.

  Returns:
    For a non-empty `precision_mode`: "float32", "float16" or "int8" for
    FP32/FP16/INT8, and "unknown" for any other mode. Otherwise, the dtype
    name of the first probed attr that yields one. None if no probed attr
    yields a dtype.
  """
  precision_to_dtype = {"FP32": "float32", "FP16": "float16", "INT8": "int8"}
  for type_key in (
      "precision_mode",  # TRTEngineOp
      "DstT",  # Cast Nodes
      "dtype",  # Placeholder
      "T",  # Everything Else
  ):
    # Test membership before indexing: indexing a protobuf map with a missing
    # key inserts a default entry, which would silently mutate `node`.
    if type_key not in node.attr:
      continue
    attr_value = node.attr[type_key]
    try:
      if type_key == "precision_mode":
        precision = attr_value.s.decode("utf-8")
        if not precision:
          continue
        return precision_to_dtype.get(precision, "unknown")
      return _convert_dtype_id_to_str(attr_value.type)
    except (KeyError, UnicodeDecodeError):
      # Presumably an unset/unknown `type` id, or undecodable bytes: fall
      # through to the next candidate attr, as the original best-effort did.
      continue
  return None
def get_node_io_shapes(node, key):
  """Returns the input/output shapes of a GraphDef Node.

  Args:
    node: a GraphDef `NodeDef`. It is not modified.
    key: name of a list attr holding shapes.

  Returns:
    One list of dimension sizes per shape in `node.attr[key].list.shape`.
    An empty list if `node` has no `key` attr.
  """
  # Test membership before indexing: indexing a protobuf map with a missing
  # key inserts a default entry, which would silently mutate `node`.
  if key not in node.attr:
    return []
  return [[dim.size for dim in shape.dim]
          for shape in node.attr[key].list.shape]
def get_trtengineop_io_dtypes(node, key):
  """Returns the input/output dtypes of a TRTEngineOp.

  Args:
    node: a TRTEngineOp `NodeDef`. It is not modified.
    key: name of a list attr holding dtype ids.

  Returns:
    A list of dtype name strings. An empty list if `node` has no `key` attr.
  """
  # Test membership before indexing: indexing a protobuf map with a missing
  # key inserts a default entry, which would silently mutate `node`.
  if key not in node.attr:
    return []
  return _convert_dtype_id_to_str(node.attr[key].list.type)
def get_trtengineop_io_nodes_count(node, key):
  """Returns the number of input/output nodes of a TRTEngineOp.

  This is the length of the `key` list attr's `type` entries.

  Args:
    node: a TRTEngineOp `NodeDef`. It is not modified.
    key: name of a list attr holding dtype ids.

  Returns:
    The number of entries, or 0 if `node` has no `key` attr.
  """
  # Test membership before indexing: indexing a protobuf map with a missing
  # key inserts a default entry, which would silently mutate `node`.
  if key not in node.attr:
    return 0
  return len(node.attr[key].list.type)
def get_trtengineop_node_op_count(graphdef, node_name):
  """Counts the number of nodes and OP types of a given TRTEngineOp.

  Searches `graphdef.library.function` for the function whose signature name
  is `f"{node_name}_native_segment"`.

  Args:
    graphdef: the `GraphDef` containing the TRTEngineOp.
    node_name: name of the TRTEngineOp node.

  Returns:
    A `(node_count, ops_in_engine)` tuple. `node_count` is the number of nodes
    in the segment function. `ops_in_engine` is a `defaultdict(int)` mapping
    op type to occurrence count. Returns `(0, <empty defaultdict>)` when no
    matching function exists; previously this case raised UnboundLocalError.
  """
  segment_name = f"{node_name}_native_segment"
  node_count = 0
  ops_in_engine = collections.defaultdict(int)
  for func in graphdef.library.function:
    if func.signature.name == segment_name:
      node_count = len(func.node_def)
      for node in func.node_def:
        ops_in_engine[node.op] += 1
      break
  return node_count, ops_in_engine
class DTypeIndex(dict):
  """A dict that hands out consecutive 1-based indices to dtypes.

  The first dtype seen gets 1, the next new one gets 2, and so on. Repeated
  lookups return the index assigned on first use.
  """

  def get_dtype_index(self, dtype):
    """Returns the index of `dtype`, registering it as `len(self) + 1` if new."""
    return self.setdefault(dtype, len(self) + 1)
def draw_graphdef_as_graphviz(graphdef, dot_output_filename):
  """Exports a GraphDef to GraphViz format.

  - Step 1: Drawing Each Node of the compute GraphDef.
  - Step 2: Create nodes for each collected dtype in the graph.
  - Step 3: Creating invisible links to align properly the legend.

  Each node consequently mentions:
    - Op Type (for a TRTEngineOp: its name and the number of nodes in its
      native segment)
    - Compute Device
  and its fill color encodes the Compute Dtype (see the legend).

  Args:
    graphdef: the `GraphDef` to draw.
    dot_output_filename: path of the `.dot` file to write (overwritten).
  """
  import html  # Local import: only needed to escape HTML-like label text.

  dtype_index = DTypeIndex()
  # The `set312` color scheme only defines colors 1..12; wrap larger dtype
  # indices so every emitted fillcolor stays valid.
  num_scheme_colors = 12

  def _fill_color(color_idx):
    """Maps a 1-based dtype index onto a valid `set312` color number."""
    return (color_idx - 1) % num_scheme_colors + 1

  def _html_text(value):
    """Escapes text for use inside a Graphviz HTML-like label."""
    return html.escape(str(value), quote=False)

  with open(dot_output_filename, "w", encoding="utf-8") as f:
    print("digraph tftrt_converted_graph {", file=f)

    print(" graph [fontsize=10 fontname=\"Verdana\"];", file=f)
    # ColorScheme Documentation: https://graphviz.org/doc/info/colors.html
    print(
        " node [style=filled height=0.55 colorscheme=set312 shape=box];",
        file=f)

    # Step 1: Parsing the graph and drawing OPs one by one.
    print("\n subgraph tensorflow_graph {", file=f)
    print(" node [width=1.35];", file=f)
    nodes_with_no_inputs = []
    for node in graphdef.node:
      output_name = node.name

      node_precision = get_node_compute_dtype(node)
      color_idx = dtype_index.get_dtype_index(node_precision)

      device_key = node.device.split("/")[-1] or "device:Unspecified"

      if node.op == "TRTEngineOp":
        node_count, _ = get_trtengineop_node_op_count(graphdef, output_name)
        node_label = f"{output_name} [{node_count}]"
      else:
        node_label = f"{node.op}"

      # Note: double space before <br/> is necessary for formatting.
      node_label = (f"<b>{_html_text(node_label)}</b>  <br/>"
                    f"<i>{_html_text(device_key)}</i>")

      print(
          f" \"{output_name}\" [label=<{node_label}> "
          f"fillcolor={_fill_color(color_idx)}];",
          file=f)

      if node.input:
        for input_full_name in node.input:
          # Drop the output index (":N") and control-dependency marker ("^").
          input_name = re.sub(r"^\^", "", input_full_name.split(":")[0])
          print(f" \"{input_name}\" -> \"{output_name}\";", file=f)
      else:
        nodes_with_no_inputs.append(output_name)
    print(" }", file=f)

    # Step 2: Creating the DType Nodes previously found in Step 1.
    print("\n subgraph cluster_legend {", file=f)
    print(" label=\"Compute Dtype Legend\";", file=f)
    print(" margin=\"30\";", file=f)
    print(" node [width=2];", file=f)

    for dtype, color_idx in dtype_index.items():
      # Quoted so the ID is identical to the quoted IDs used in Step 3.
      print(
          f" \"{dtype}\" [fillcolor={_fill_color(color_idx)} "
          f"label=<<b>{_html_text(dtype)}</b>>];",
          file=f)

    print(" }", file=f)

    # Step 3: Alignment of the legend with the graph.
    # `invis` is the documented Graphviz style keyword for hidden edges.
    print("\n edge[style=\"invis\", dir=\"none\"];", file=f)
    for dtype in dtype_index:
      for node_name in nodes_with_no_inputs:
        print(f" \"{dtype}\" -> \"{node_name}\"", file=f)

    print("}", file=f)

  print("\n===================================================================")
  print(f"Graph Visualization Exported to: `{dot_output_filename}`.")
  print("We recommend using https://edotor.net/ to visualize the .dot file.")
  print("You can also use `graphviz` utility to convert them to PNG format:")
  print(" - `sudo apt install -y graphviz`")
  print(" - `dot -Tpng <input_filename>.dot -o <output_filename>.png`")
  print("===================================================================\n")