Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/tensorrt/utils.py: 17%

138 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""

import collections
import os
import re

from packaging import version

from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes


def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
  """Modifies rewriter_config to disable all non-TRT optimizations."""
  off = rewriter_config_pb2.RewriterConfig.OFF

  rewriter_config.arithmetic_optimization = off
  rewriter_config.auto_mixed_precision = off
  rewriter_config.auto_parallel.enable = False
  rewriter_config.constant_folding = off
  rewriter_config.debug_stripper = off
  rewriter_config.dependency_optimization = off
  # This one needs to be ON to allow TF-TRT.
  rewriter_config.disable_meta_optimizer = False
  rewriter_config.disable_model_pruning = True
  rewriter_config.function_optimization = off
  rewriter_config.implementation_selector = off
  rewriter_config.layout_optimizer = off
  rewriter_config.loop_optimization = off
  rewriter_config.memory_optimization = (
      rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
  rewriter_config.min_graph_nodes = -1
  rewriter_config.pin_to_host_optimization = off
  rewriter_config.remapping = off
  rewriter_config.scoped_allocator_optimization = off
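

# Illustrative sketch (not part of the original module): one way the helper
# above might be wired into a Grappler session config. The function name and
# the local `config_pb2` import are assumptions added only for demonstration.
def _example_trt_only_session_config():
  """Returns a ConfigProto whose non-TRT Grappler optimizers are disabled."""
  from tensorflow.core.protobuf import config_pb2  # Local import for the sketch.

  session_config = config_pb2.ConfigProto()
  disable_non_trt_optimizers_in_rewriter_config(
      session_config.graph_options.rewrite_options)
  return session_config
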

def version_tuple_to_string(ver_tuple):
  """Converts a 3-field version tuple, e.g. (8, 2, 0), into a string "8.2.0"."""
  assert isinstance(ver_tuple, tuple)
  assert len(ver_tuple) == 3

  ver_tuple = [str(x) for x in ver_tuple]
  return ".".join(ver_tuple)


def _is_tensorrt_version_greater_equal(trt_ver, target_ver):
  trt_ver = version.Version(version_tuple_to_string(trt_ver))
  target_ver = version.Version(version_tuple_to_string(target_ver))

  return trt_ver >= target_ver


def is_linked_tensorrt_version_greater_equal(major, minor=0, patch=0):
  ver = _pywrap_py_utils.get_linked_tensorrt_version()
  return _is_tensorrt_version_greater_equal(ver, (major, minor, patch))


def is_loaded_tensorrt_version_greater_equal(major, minor=0, patch=0):
  ver = _pywrap_py_utils.get_loaded_tensorrt_version()
  return _is_tensorrt_version_greater_equal(ver, (major, minor, patch))
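

# Illustrative sketch (not part of the original module): how the version
# helpers above compose. The function name is hypothetical and the threshold
# (8, 2, 0) is an arbitrary value chosen for the example.
def _example_require_tensorrt_8_2():
  """Returns True when both the linked and loaded TensorRT are >= 8.2.0."""
  return (is_linked_tensorrt_version_greater_equal(8, 2, 0) and
          is_loaded_tensorrt_version_greater_equal(8, 2, 0))
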

def is_experimental_feature_activated(feature_name):
  """Determines if a TF-TRT experimental feature is enabled.

  This helper function checks if an experimental feature was enabled using
  the environment variable `TF_TRT_EXPERIMENTAL_FEATURES=feature_1,feature_2`.

  Args:
    feature_name: Name of the feature being tested for activation.
  """

  return (feature_name
          in os.environ.get("TF_TRT_EXPERIMENTAL_FEATURES",
                            default="").split(","))
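

# Illustrative sketch (not part of the original module): the feature name
# below is a hypothetical value used only to show how
# `is_experimental_feature_activated` reads the comma-separated list.
def _example_check_feature_flag():
  """Returns True if `disable_graph_freezing` is listed in the env variable."""
  # E.g. when running with: TF_TRT_EXPERIMENTAL_FEATURES=disable_graph_freezing
  return is_experimental_feature_activated("disable_graph_freezing")
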

def _convert_dtype_id_to_str(dtype):
  """Helper function to convert a dtype id to a corresponding string name."""
  if isinstance(dtype, int):
    return dtypes._TYPE_TO_STRING[dtype]
  else:
    return [dtypes._TYPE_TO_STRING[d] for d in dtype]


def get_node_compute_dtype(node):
  """Returns the compute DType of a GraphDef Node."""
  # Note: order matters. By default, a TF node's compute dtype is stored under
  # the `T` attribute, unless the node is a `TRTEngineOp`, `Cast`, or
  # `Placeholder`, whose dedicated keys are checked first below.
  for type_key in [
      "precision_mode",  # TRTEngineOp
      "DstT",  # Cast Nodes
      "dtype",  # Placeholder
      "T",  # Everything Else
  ]:
    try:
      precision_val = node.attr[type_key]
      if type_key == "precision_mode":
        precision_val = precision_val.s.decode("utf-8")
        if precision_val == "":
          continue
        if precision_val == "FP32":
          return "float32"
        elif precision_val == "FP16":
          return "float16"
        elif precision_val == "INT8":
          return "int8"
        else:
          return "unknown"
      else:
        return _convert_dtype_id_to_str(precision_val.type)
    except Exception:
      continue
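

# Illustrative sketch (not part of the original module): tallies the compute
# dtypes of every node in a GraphDef using `get_node_compute_dtype`. The
# function name is hypothetical and shown only for demonstration.
def _example_count_compute_dtypes(graphdef):
  """Returns a dict mapping compute dtype name -> number of nodes using it."""
  dtype_counts = collections.defaultdict(int)
  for node in graphdef.node:
    dtype_counts[get_node_compute_dtype(node)] += 1
  return dict(dtype_counts)
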

def get_node_io_shapes(node, key):
  """Returns the input/output shapes of a GraphDef Node."""
  out_shape = []
  for shape in node.attr[key].list.shape:
    out_shape.append([dim.size for dim in shape.dim])
  return out_shape


def get_trtengineop_io_dtypes(node, key):
  """Returns the input/output dtypes of a TRTEngineOp."""
  return _convert_dtype_id_to_str(node.attr[key].list.type)


def get_trtengineop_io_nodes_count(node, key):
  """Returns the number of input/output nodes of a TRTEngineOp."""
  return len(node.attr[key].list.type)


def get_trtengineop_node_op_count(graphdef, node_name):
  """Counts the number of nodes and OP types of a given TRTEngineOp."""
  ops_in_engine = collections.defaultdict(int)
  # Guard against the case where no `<node_name>_native_segment` function is
  # found in the graph library, which would otherwise leave `node_count` unbound.
  node_count = 0
  for func in graphdef.library.function:
    if f"{node_name}_native_segment" == func.signature.name:
      node_count = len(func.node_def)
      for node in func.node_def:
        ops_in_engine[node.op] += 1
      break
  return node_count, ops_in_engine
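

# Illustrative sketch (not part of the original module): summarizes a single
# TRTEngineOp node with the helpers above. The attribute keys "InT" and "OutT"
# follow the usual TRTEngineOp naming but should be treated as assumptions
# here; the function name is hypothetical.
def _example_summarize_trtengineop(graphdef, node):
  """Returns a small summary dict for a `TRTEngineOp` node."""
  node_count, ops_in_engine = get_trtengineop_node_op_count(graphdef, node.name)
  return {
      "input_dtypes": get_trtengineop_io_dtypes(node, "InT"),
      "output_dtypes": get_trtengineop_io_dtypes(node, "OutT"),
      "num_inputs": get_trtengineop_io_nodes_count(node, "InT"),
      "num_outputs": get_trtengineop_io_nodes_count(node, "OutT"),
      "node_count": node_count,
      "ops_in_engine": dict(ops_in_engine),
  }
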

class DTypeIndex(dict):
  """Helper class to create an index of dtypes with incremental values."""

  def get_dtype_index(self, dtype):
    if dtype not in self:
      self[dtype] = len(self) + 1
    return self[dtype]
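

# Illustrative sketch (not part of the original module): `DTypeIndex` hands out
# stable, 1-based indices in first-seen order, which the GraphViz export below
# uses as colorscheme indices. The function name is hypothetical.
def _example_dtype_index_usage():
  """Shows that repeated dtypes keep their originally assigned index."""
  index = DTypeIndex()
  assert index.get_dtype_index("float32") == 1
  assert index.get_dtype_index("float16") == 2
  assert index.get_dtype_index("float32") == 1  # Already indexed: unchanged.
  return index
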

def draw_graphdef_as_graphviz(graphdef, dot_output_filename):
  """Exports a GraphDef to GraphViz format.

  - Step 1: Draw each node of the compute GraphDef.
  - Step 2: Create a legend node for each compute dtype collected in Step 1.
  - Step 3: Create invisible links to properly align the legend with the graph.

  Each drawn node shows:
  - Op Type
  - Compute Dtype
  - Compute Device
  """

  dtype_index = DTypeIndex()

  with open(dot_output_filename, "w") as f:
    print("digraph tftrt_converted_graph {", file=f)

    print("  graph [fontsize=10 fontname=\"Verdana\"];", file=f)
    # ColorScheme Documentation: https://graphviz.org/doc/info/colors.html
    print(
        "  node [style=filled height=0.55 colorscheme=set312 shape=box];",
        file=f)

    # Step 1: Parse the graph and draw the OPs one by one.
    print("\n  subgraph tensorflow_graph {", file=f)
    print("    node [width=1.35];", file=f)
    nodes_with_no_inputs = []
    for node in graphdef.node:
      output_name = node.name

      node_precision = get_node_compute_dtype(node)
      color_idx = dtype_index.get_dtype_index(node_precision)

      device_key = node.device.split("/")[-1]
      if not device_key:
        device_key = "device:Unspecified"

      if node.op == "TRTEngineOp":
        node_count, _ = get_trtengineop_node_op_count(graphdef, output_name)
        node_label = f"{output_name} [{node_count}]"
      else:
        node_label = f"{node.op}"

      # Note: the double space before <br/> is necessary for formatting.
      node_label = f"<b>{node_label}</b>  <br/><i>{device_key}</i>"

      print(
          f"    \"{output_name}\" [label=<{node_label}> "
          f"fillcolor={color_idx}];",
          file=f)

      if len(node.input):
        for input_full_name in node.input:
          parts = input_full_name.split(":")
          input_name = re.sub(r"^\^", "", parts[0])
          print(f"    \"{input_name}\" -> \"{output_name}\";", file=f)
      else:
        nodes_with_no_inputs.append(output_name)
    print("  }", file=f)

    # Step 2: Create a legend node for each dtype collected during Step 1.
    print("\n  subgraph cluster_legend {", file=f)
    print("    label=\"Compute Dtype Legend\";", file=f)
    print("    margin=\"30\";", file=f)
    print("    node [width=2];", file=f)

    for dtype, color_idx in dtype_index.items():
      print(
          f"    {dtype} [fillcolor={color_idx} label=<<b>{dtype}</b>>];",
          file=f)

    print("  }", file=f)

    # Step 3: Alignment of the legend with the graph.
    print("\n  edge[style=\"invisible\", dir=\"none\"];", file=f)
    for dtype in dtype_index.keys():
      for node_name in nodes_with_no_inputs:
        print(f"  \"{dtype}\" -> \"{node_name}\"", file=f)

    print("}", file=f)

  print("\n===================================================================")
  print(f"Graph Visualization Exported to: `{dot_output_filename}`.")
  print("We recommend using https://edotor.net/ to visualize the .dot file.")
  print("You can also use the `graphviz` utility to convert it to PNG format:")
  print("  - `sudo apt install -y graphviz`")
  print("  - `dot -Tpng <input_filename>.dot -o <output_filename>.png`")
  print("===================================================================\n")
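

# Illustrative sketch (not part of the original module): end-to-end usage of
# `draw_graphdef_as_graphviz` on a GraphDef. The output path and the function
# name are hypothetical values used only for demonstration.
def _example_export_graph_to_dot(graphdef):
  """Writes `graphdef` to `/tmp/tftrt_graph.dot` for visual inspection."""
  draw_graphdef_as_graphviz(
      graphdef=graphdef, dot_output_filename="/tmp/tftrt_graph.dot")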