Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/utils.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

147 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

6import os 

7import tarfile 

8from collections import deque 

9from typing import TYPE_CHECKING 

10 

11import onnx.checker 

12import onnx.helper 

13import onnx.shape_inference 

14 

15if TYPE_CHECKING: 

16 from onnx.onnx_pb import ( 

17 FunctionProto, 

18 ModelProto, 

19 NodeProto, 

20 TensorProto, 

21 ValueInfoProto, 

22 ) 

23 

24 

class Extractor:
    """Build a sub-model of ``model`` bounded by given input/output tensor names.

    The graph of ``model`` is indexed once at construction time. ``extract_model``
    then assembles a new ``ModelProto`` containing only the nodes, initializers,
    value infos and model-local functions needed to compute the requested
    outputs from the requested inputs.
    """

    def __init__(self, model: ModelProto) -> None:
        self.model = model
        self.graph = self.model.graph
        # Name -> TensorProto for every top-level initializer.
        self.initializers: dict[str, TensorProto] = self._build_name2obj_dict(
            self.graph.initializer
        )
        # Name -> ValueInfoProto for intermediate values...
        self.value_infos: dict[str, ValueInfoProto] = self._build_name2obj_dict(
            self.graph.value_info
        )
        # ...plus graph inputs and outputs, which value_info does not list.
        self.value_infos.update(self._build_name2obj_dict(self.graph.input))
        self.value_infos.update(self._build_name2obj_dict(self.graph.output))
        # Tensor name -> index (in self.graph.node) of the node producing it.
        self.outmap: dict[str, int] = self._build_output_dict(self.graph)

    @staticmethod
    def _build_name2obj_dict(objs) -> dict:
        """Map each object's ``name`` to the object (a later duplicate wins)."""
        return {obj.name: obj for obj in objs}

    @staticmethod
    def _build_output_dict(graph) -> dict[str, int]:
        """Map every non-empty node output name to the index of its producer.

        Raises:
            ValueError: if the same output name is produced more than once,
                which would make the producer of that tensor ambiguous.
        """
        output_to_index: dict[str, int] = {}
        for index, node in enumerate(graph.node):
            for output_name in node.output:
                if output_name == "":
                    # Empty names mark omitted optional outputs.
                    continue
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if output_name in output_to_index:
                    raise ValueError(
                        f"Tensor name {output_name!r} is produced by more than "
                        "one node output in the graph."
                    )
                output_to_index[output_name] = index
        return output_to_index

    @staticmethod
    def _subgraph_input_names(node: NodeProto) -> list[str]:
        """Return the names referenced by subgraphs nested in ``node``'s attributes.

        Control-flow nodes (e.g. If, Loop, Scan) hold subgraphs as GRAPH /
        GRAPHS attributes. Those subgraphs may consume values from the enclosing
        graph without the values appearing in ``node.input``. This collects the
        non-empty input names of every nested node, recursively, plus the
        subgraphs' output names. Names defined locally inside a subgraph are
        also returned; callers only use them as lookup keys into main-graph
        tables.
        """
        names: list[str] = []
        pending = [node]
        while pending:
            current = pending.pop()
            for attr in current.attribute:
                subgraphs = list(attr.graphs)
                if attr.HasField("g"):
                    subgraphs.append(attr.g)
                for subgraph in subgraphs:
                    names.extend(output.name for output in subgraph.output)
                    for sub_node in subgraph.node:
                        names.extend(name for name in sub_node.input if name)
                        pending.append(sub_node)
        return names

    def _collect_new_io(self, io_names_to_extract: list[str]) -> list[ValueInfoProto]:
        """Return the ValueInfoProto of each requested name, in the given order.

        Raises:
            ValueError: listing every name that has no entry in ``self.value_infos``.
        """
        missing_names = [
            name for name in io_names_to_extract if name not in self.value_infos
        ]
        if missing_names:
            raise ValueError(
                f"The following names were not found in value_infos: {', '.join(missing_names)}"
            )
        return [self.value_infos[name] for name in io_names_to_extract]

    def _dfs_search_reachable_nodes(
        self,
        node_output_name: str,
        graph_input_names: set[str],
        reachable: set[int],
    ) -> None:
        """Add to ``reachable`` the indexes of nodes an output depends on.

        The search walks producer nodes backwards from ``node_output_name``
        and stops at any name in ``graph_input_names``.

        Arguments:
            node_output_name (str): The name of the output
            graph_input_names (set of string): The names of all inputs of the graph
            reachable (set of int): The set of indexes to reachable nodes in
                ``self.graph.node``; updated in place.
        """
        stack = [node_output_name]
        while stack:
            current_output_name = stack.pop()
            # Finish the search at the sub-model's inputs.
            if current_output_name in graph_input_names:
                continue
            index = self.outmap.get(current_output_name)
            # Names without a producer (initializers, graph inputs) and
            # already-visited nodes end this path.
            if index is None or index in reachable:
                continue
            reachable.add(index)
            node = self.graph.node[index]
            stack.extend(name for name in node.input if name != "")
            # Subgraph attributes can consume outer-scope values implicitly;
            # their producers are dependencies too.
            stack.extend(self._subgraph_input_names(node))

    def _collect_reachable_nodes(
        self,
        input_names: list[str],
        output_names: list[str],
    ) -> list[NodeProto]:
        """Return the nodes needed to compute ``output_names`` from ``input_names``."""
        _input_names = set(input_names)
        reachable: set[int] = set()
        for name in output_names:
            self._dfs_search_reachable_nodes(name, _input_names, reachable)
        # Emitting in original index order keeps the result topologically
        # sorted, assuming graph.node is (as ONNX requires).
        return [self.graph.node[index] for index in sorted(reachable)]

    def _collect_referred_local_functions(
        self,
        nodes: list[NodeProto],
    ) -> list[FunctionProto]:
        """Return the model-local functions transitively called by ``nodes``.

        A node may call a local function. A function body, or a subgraph
        attribute of any node, may in turn call further functions, so those
        nodes are scanned as well. Functions are identified by
        (domain, name, overload), matching a calling node's
        (domain, op_type, overload).
        """
        function_map: dict[tuple[str, str, str], FunctionProto] = {
            (function.domain, function.name, function.overload): function
            for function in self.model.functions
        }
        referred_local_functions: list[FunctionProto] = []
        queue = deque(nodes)
        while queue:
            node = queue.popleft()
            # pop() so each function is collected and expanded only once.
            function = function_map.pop(
                (node.domain, node.op_type, node.overload), None
            )
            if function is not None:
                referred_local_functions.append(function)
                queue.extend(function.node)
            # Nodes nested in subgraph attributes may call functions too.
            for attr in node.attribute:
                if attr.HasField("g"):
                    queue.extend(attr.g.node)
                for subgraph in attr.graphs:
                    queue.extend(subgraph.node)
        # Discovery order: callers precede the functions they call.
        return referred_local_functions

    def _collect_reachable_tensors(
        self,
        nodes: list[NodeProto],
    ) -> tuple[list[TensorProto], list[ValueInfoProto]]:
        """Return the initializers and value infos referenced by ``nodes``.

        Names consumed implicitly by subgraph attributes are included.
        The original graph order of both lists is preserved.

        Raises:
            ValueError: if the graph has sparse initializers or quantization
                annotations, which are not supported.
        """
        all_tensors_names: set[str] = set()
        for node in nodes:
            all_tensors_names.update(node.input)
            all_tensors_names.update(node.output)
            all_tensors_names.update(self._subgraph_input_names(node))
        initializer = [
            self.initializers[t] for t in self.initializers if t in all_tensors_names
        ]
        value_info = [
            self.value_infos[t] for t in self.value_infos if t in all_tensors_names
        ]
        len_sparse_initializer = len(self.graph.sparse_initializer)
        if len_sparse_initializer != 0:
            raise ValueError(
                f"len_sparse_initializer is {len_sparse_initializer}, it must be 0."
            )
        len_quantization_annotation = len(self.graph.quantization_annotation)
        if len_quantization_annotation != 0:
            raise ValueError(
                f"len_quantization_annotation is {len_quantization_annotation}, it must be 0."
            )
        return initializer, value_info

    def _make_model(
        self,
        nodes: list[NodeProto],
        inputs: list[ValueInfoProto],
        outputs: list[ValueInfoProto],
        initializer: list[TensorProto],
        value_info: list[ValueInfoProto],
        local_functions: list[FunctionProto],
    ) -> ModelProto:
        """Assemble the extracted ModelProto.

        The original model's ir_version and opset imports are reused.
        """
        name = "Extracted from {" + self.graph.name + "}"
        graph = onnx.helper.make_graph(
            nodes, name, inputs, outputs, initializer=initializer, value_info=value_info
        )
        meta = {
            "ir_version": self.model.ir_version,
            "opset_imports": self.model.opset_import,
            "producer_name": "onnx.utils.extract_model",
            "functions": local_functions,
        }
        return onnx.helper.make_model(graph, **meta)

    def extract_model(
        self,
        input_names: list[str],
        output_names: list[str],
    ) -> ModelProto:
        """Return the sub-model bounded by ``input_names`` and ``output_names``.

        Raises:
            ValueError: if any boundary name has no value info, or if the graph
                uses unsupported sparse initializers or quantization annotations.
        """
        inputs = self._collect_new_io(input_names)
        outputs = self._collect_new_io(output_names)
        nodes = self._collect_reachable_nodes(input_names, output_names)
        initializer, value_info = self._collect_reachable_tensors(nodes)
        local_functions = self._collect_referred_local_functions(nodes)
        return self._make_model(
            nodes, inputs, outputs, initializer, value_info, local_functions
        )

192 

193 

def extract_model(
    input_path: str | os.PathLike,
    output_path: str | os.PathLike,
    input_names: list[str],
    output_names: list[str],
    check_model: bool = True,
    infer_shapes: bool = True,
) -> None:
    """Extracts sub-model from an ONNX model.

    The sub-model is defined by the names of the input and output tensors *exactly*.

    Note: For control-flow operators, e.g. If and Loop, the _boundary of sub-model_,
    which is defined by the input and output tensors, should not _cut through_ the
    subgraph that is connected to the _main graph_ as attributes of these operators.

    Note: When the extracted model size is larger than 2GB, the extra data will be saved in "output_path.data".

    Arguments:
        input_path (str | os.PathLike): The path to original ONNX model.
        output_path (str | os.PathLike): The path to save the extracted ONNX model.
        input_names (list of string): The names of the input tensors to be extracted.
        output_names (list of string): The names of the output tensors to be extracted.
        check_model (bool): Whether to run model checker on the original model and the extracted model.
        infer_shapes (bool): Whether to infer the shapes of the original model.

    Raises:
        ValueError: if ``input_path`` does not exist, ``output_path`` is empty,
            or either name list is empty or contains duplicates.
    """
    if not os.path.exists(input_path):
        raise ValueError(f"Invalid input model path: {input_path}")
    if not output_path:
        raise ValueError("Output model path shall not be empty!")
    if not input_names:
        raise ValueError("Input tensor names shall not be empty!")
    if not output_names:
        raise ValueError("Output tensor names shall not be empty!")

    if len(input_names) != len(set(input_names)):
        raise ValueError("Duplicate names found in the input tensor names.")
    if len(output_names) != len(set(output_names)):
        raise ValueError("Duplicate names found in the output tensor names.")

    if check_model:
        onnx.checker.check_model(input_path)

    # External data referenced by the original model is stored relative to the
    # directory of input_path.
    base_dir = os.path.dirname(input_path)
    if infer_shapes and os.path.getsize(input_path) > onnx.checker.MAXIMUM_PROTOBUF:
        # Run inference file-to-file. The inferred model is written to output_path
        # and is overwritten by the extracted model below.
        onnx.shape_inference.infer_shapes_path(input_path, output_path)
        # The inferred copy keeps the original relative external-data locations.
        # Resolve them against the input directory rather than output_path's,
        # which may differ.
        model = onnx.load(output_path, load_external_data=False)
        onnx.load_external_data_for_model(model, base_dir)
    elif infer_shapes:
        model = onnx.load(input_path, load_external_data=False)
        model = onnx.shape_inference.infer_shapes(model)
        onnx.load_external_data_for_model(model, base_dir)
    else:
        model = onnx.load(input_path)

    e = Extractor(model)
    extracted = e.extract_model(input_names, output_names)

    if extracted.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
        # Protobuf cannot serialize a single message above the 2GB limit, so
        # tensor data goes to a sidecar file next to output_path.
        location = os.path.basename(output_path) + ".data"
        onnx.save(extracted, output_path, save_as_external_data=True, location=location)
    else:
        onnx.save(extracted, output_path)

    if check_model:
        onnx.checker.check_model(output_path)

259 

260 

261def _tar_members_filter( 

262 tar: tarfile.TarFile, base: str | os.PathLike 

263) -> list[tarfile.TarInfo]: 

264 """Check that the content of ``tar`` will be extracted safely 

265 

266 Args: 

267 tar: The tarball file 

268 base: The directory where the tarball will be extracted 

269 

270 Returns: 

271 list of tarball members 

272 """ 

273 result = [] 

274 abs_base = os.path.abspath(base) 

275 for member in tar: 

276 member_path = os.path.join(base, member.name) 

277 abs_member = os.path.abspath(member_path) 

278 try: 

279 is_within_base = os.path.commonpath([abs_base, abs_member]) == abs_base 

280 except ValueError: 

281 is_within_base = False 

282 if not is_within_base: 

283 raise RuntimeError( 

284 f"The tarball member {member_path} in downloading model contains " 

285 f"directory traversal sequence which may contain harmful payload." 

286 ) 

287 if member.issym() or member.islnk(): 

288 raise RuntimeError( 

289 f"The tarball member {member_path} in downloading model contains " 

290 f"symbolic links which may contain harmful payload." 

291 ) 

292 result.append(member) 

293 return result 

294 

295 

296def _extract_model_safe( 

297 model_tar_path: str | os.PathLike, local_model_with_data_dir_path: str | os.PathLike 

298) -> None: 

299 """Safely extracts a tar file to a specified directory. 

300 

301 This function ensures that the extraction process mitigates against 

302 directory traversal vulnerabilities by validating or sanitizing paths 

303 within the tar file. It also provides compatibility for different versions 

304 of the tarfile module by checking for the availability of certain attributes 

305 or methods before invoking them. 

306 

307 Args: 

308 model_tar_path: The path to the tar file to be extracted. 

309 local_model_with_data_dir_path: The directory path where the tar file 

310 contents will be extracted to. 

311 """ 

312 with tarfile.open(model_tar_path) as model_with_data_zipped: 

313 # Mitigate tarball directory traversal risks 

314 if hasattr(tarfile, "data_filter"): 

315 model_with_data_zipped.extractall( 

316 path=local_model_with_data_dir_path, filter="data" 

317 ) 

318 else: 

319 model_with_data_zipped.extractall( # noqa: S202 

320 path=local_model_with_data_dir_path, 

321 members=_tar_members_filter( 

322 model_with_data_zipped, local_model_with_data_dir_path 

323 ), 

324 )