Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/utils.py: 17%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6import os
7import tarfile
8from collections import deque
9from typing import TYPE_CHECKING
11import onnx.checker
12import onnx.helper
13import onnx.shape_inference
15if TYPE_CHECKING:
16 from onnx.onnx_pb import (
17 FunctionProto,
18 ModelProto,
19 NodeProto,
20 TensorProto,
21 ValueInfoProto,
22 )
class Extractor:
    """Build a sub-model of an ONNX model bounded by named input/output tensors.

    On construction the model's graph is indexed three ways:

    * ``initializers``: initializer name -> ``TensorProto``
    * ``value_infos``: tensor name -> ``ValueInfoProto`` (intermediate values
      plus the graph's declared inputs and outputs)
    * ``outmap``: node output name -> index of the producing node in
      ``graph.node``
    """

    def __init__(self, model: ModelProto) -> None:
        self.model = model
        self.graph = self.model.graph
        self.initializers: dict[str, TensorProto] = self._build_name2obj_dict(
            self.graph.initializer
        )
        self.value_infos: dict[str, ValueInfoProto] = self._build_name2obj_dict(
            self.graph.value_info
        )
        # Add input and output values (not included in the value_info for intermediate values)
        self.value_infos.update(self._build_name2obj_dict(self.graph.input))
        self.value_infos.update(self._build_name2obj_dict(self.graph.output))
        self.outmap: dict[str, int] = self._build_output_dict(self.graph)

    @staticmethod
    def _build_name2obj_dict(objs) -> dict:
        """Index ``objs`` by their ``name`` attribute (later duplicates win)."""
        return {obj.name: obj for obj in objs}

    @staticmethod
    def _build_output_dict(graph) -> dict[str, int]:
        """Map every non-empty node output name to the index of its producer.

        Raises:
            AssertionError: If two nodes (or one node twice) declare the same
                output name.
        """
        output_to_index: dict[str, int] = {}
        for index, node in enumerate(graph.node):
            for output_name in node.output:
                if output_name == "":
                    # An empty name marks an omitted optional output.
                    continue
                if output_name in output_to_index:
                    # Raised explicitly (rather than via ``assert``) so the
                    # uniqueness check still runs under ``python -O``; the
                    # exception type is kept for backward compatibility.
                    raise AssertionError(
                        f"Output name {output_name!r} is produced by more than one node "
                        f"(indexes {output_to_index[output_name]} and {index})."
                    )
                output_to_index[output_name] = index
        return output_to_index

    def _collect_new_io(self, io_names_to_extract: list[str]) -> list[ValueInfoProto]:
        """Return the value infos for ``io_names_to_extract``, in the given order.

        Raises:
            ValueError: If any requested name is unknown to ``self.value_infos``.
        """
        # Validate that all names exist in self.value_infos
        missing_names = [
            name for name in io_names_to_extract if name not in self.value_infos
        ]
        if missing_names:
            raise ValueError(
                f"The following names were not found in value_infos: {', '.join(missing_names)}"
            )
        return [self.value_infos[name] for name in io_names_to_extract]

    def _dfs_search_reachable_nodes(
        self,
        node_output_name: str,
        graph_input_names: set[str],
        reachable: set[int],
    ) -> None:
        """Helper function to find nodes which are connected to an output

        Walks backwards from ``node_output_name`` through producing nodes,
        stopping at any name in ``graph_input_names``. Names that no node
        produces (e.g. initializers) simply end that branch of the walk.

        Arguments:
            node_output_name (str): The name of the output
            graph_input_names (set of string): The names of all inputs of the graph
            reachable (set of int): The set of indexes (into ``self.graph.node``)
                of reachable nodes; updated in place
        """
        stack = [node_output_name]
        while stack:
            current_output_name = stack.pop()
            # finish search at inputs
            if current_output_name in graph_input_names:
                continue
            # find the node producing this tensor, if any
            index = self.outmap.get(current_output_name)
            if index is None or index in reachable:
                continue
            reachable.add(index)
            # continue the walk through that node's (non-omitted) inputs
            stack.extend(
                input_name
                for input_name in self.graph.node[index].input
                if input_name != ""
            )

    def _collect_reachable_nodes(
        self,
        input_names: list[str],
        output_names: list[str],
    ) -> list[NodeProto]:
        """Return the nodes needed to compute ``output_names`` from ``input_names``.

        The nodes are returned in their original ``graph.node`` order.
        """
        _input_names = set(input_names)
        reachable: set[int] = set()
        for name in output_names:
            self._dfs_search_reachable_nodes(name, _input_names, reachable)
        # needs to be topologically sorted: ordering by original index keeps
        # whatever ordering the source graph already had.
        return [self.graph.node[index] for index in sorted(reachable)]

    def _collect_referred_local_functions(
        self,
        nodes: list[NodeProto],
    ) -> list[FunctionProto]:
        """Return the model-local functions used, directly or transitively, by ``nodes``.

        Each function is returned at most once, in breadth-first discovery order.
        """
        # a node in a model graph may refer a function.
        # a function contains nodes, some of which may in turn refer a function.
        # we need to find functions referred by graph nodes and
        # by nodes used to define functions.
        function_map: dict[tuple[str, str], FunctionProto] = {
            (function.name, function.domain): function
            for function in self.model.functions
        }
        referred_local_functions: list[FunctionProto] = []
        queue = deque(nodes)
        while queue:
            node = queue.popleft()
            # check if the node is a function op; popping the entry guarantees
            # every function is collected (and expanded) only once.
            function = function_map.pop((node.op_type, node.domain), None)
            if function is not None:
                referred_local_functions.append(function)
                queue.extend(function.node)
        # NOTE(review): an earlier comment here said the result "needs to be
        # topologically sorted", but no sorting is performed -- confirm whether
        # consumers rely on any particular order.
        return referred_local_functions

    def _collect_reachable_tensors(
        self,
        nodes: list[NodeProto],
    ) -> tuple[list[TensorProto], list[ValueInfoProto]]:
        """Return the initializers and value infos referenced by ``nodes``.

        Raises:
            ValueError: If the source graph has sparse initializers or
                quantization annotations, neither of which is supported.
        """
        # Reject unsupported graph features before doing any work.
        len_sparse_initializer = len(self.graph.sparse_initializer)
        if len_sparse_initializer != 0:
            raise ValueError(
                f"len_sparse_initializer is {len_sparse_initializer}, it must be 0."
            )
        len_quantization_annotation = len(self.graph.quantization_annotation)
        if len_quantization_annotation != 0:
            raise ValueError(
                f"len_quantization_annotation is {len_quantization_annotation}, it must be 0."
            )

        all_tensors_names: set[str] = set()
        for node in nodes:
            all_tensors_names.update(node.input)
            all_tensors_names.update(node.output)

        # Iterating the dicts (not the name set) keeps a deterministic order.
        initializer = [
            tensor
            for name, tensor in self.initializers.items()
            if name in all_tensors_names
        ]
        value_info = [
            info
            for name, info in self.value_infos.items()
            if name in all_tensors_names
        ]
        return initializer, value_info

    def _make_model(
        self,
        nodes: list[NodeProto],
        inputs: list[ValueInfoProto],
        outputs: list[ValueInfoProto],
        initializer: list[TensorProto],
        value_info: list[ValueInfoProto],
        local_functions: list[FunctionProto],
    ) -> ModelProto:
        """Assemble a new model from the collected pieces.

        The IR version and opset imports are copied from the source model.
        """
        name = "Extracted from {" + self.graph.name + "}"
        graph = onnx.helper.make_graph(
            nodes, name, inputs, outputs, initializer=initializer, value_info=value_info
        )
        meta = {
            "ir_version": self.model.ir_version,
            "opset_imports": self.model.opset_import,
            "producer_name": "onnx.utils.extract_model",
            "functions": local_functions,
        }
        return onnx.helper.make_model(graph, **meta)

    def extract_model(
        self,
        input_names: list[str],
        output_names: list[str],
    ) -> ModelProto:
        """Extract the sub-model delimited by ``input_names`` and ``output_names``.

        Raises:
            ValueError: If a requested name is unknown, or the source graph
                uses sparse initializers or quantization annotations.
        """
        inputs = self._collect_new_io(input_names)
        outputs = self._collect_new_io(output_names)
        nodes = self._collect_reachable_nodes(input_names, output_names)
        initializer, value_info = self._collect_reachable_tensors(nodes)
        local_functions = self._collect_referred_local_functions(nodes)
        return self._make_model(
            nodes, inputs, outputs, initializer, value_info, local_functions
        )
def extract_model(
    input_path: str | os.PathLike,
    output_path: str | os.PathLike,
    input_names: list[str],
    output_names: list[str],
    check_model: bool = True,
    infer_shapes: bool = True,
) -> None:
    """Extracts sub-model from an ONNX model.

    The sub-model is defined by the names of the input and output tensors *exactly*.

    Note: For control-flow operators, e.g. If and Loop, the _boundary of sub-model_,
    which is defined by the input and output tensors, should not _cut through_ the
    subgraph that is connected to the _main graph_ as attributes of these operators.

    Note: When the extracted model size is larger than 2GB, the extra data will be saved in "output_path.data".

    Arguments:
        input_path (str | os.PathLike): The path to the original ONNX model.
        output_path (str | os.PathLike): The path where the extracted ONNX model is saved.
        input_names (list of string): The names of the input tensors bounding the sub-model.
        output_names (list of string): The names of the output tensors bounding the sub-model.
        check_model (bool): Whether to run the model checker on the original model and the extracted model.
        infer_shapes (bool): Whether to infer the shapes of the original model.

    Raises:
        ValueError: If the input path does not exist, a required argument is
            empty, or a list of tensor names contains duplicates.
    """
    # ---- argument validation (order of checks is significant) ----
    if not os.path.exists(input_path):
        raise ValueError(f"Invalid input model path: {input_path}")

    must_be_non_empty = (
        (output_path, "Output model path shall not be empty!"),
        (input_names, "Input tensor names shall not be empty!"),
        (output_names, "Output tensor names shall not be empty!"),
    )
    for value, complaint in must_be_non_empty:
        if not value:
            raise ValueError(complaint)

    must_be_unique = (
        (input_names, "Duplicate names found in the input tensor names."),
        (output_names, "Duplicate names found in the output tensor names."),
    )
    for names, complaint in must_be_unique:
        if len(names) != len(set(names)):
            raise ValueError(complaint)

    if check_model:
        onnx.checker.check_model(input_path)

    # ---- load the source model, optionally with inferred shapes ----
    if not infer_shapes:
        model = onnx.load(input_path)
    elif os.path.getsize(input_path) > onnx.checker.MAXIMUM_PROTOBUF:
        # Too large for in-memory inference: run the path-based variant,
        # which writes its result to ``output_path``, then read that back.
        onnx.shape_inference.infer_shapes_path(input_path, output_path)
        model = onnx.load(output_path)
    else:
        # Infer shapes without the external tensor data attached, then load
        # that data from the directory holding the input model.
        model = onnx.shape_inference.infer_shapes(
            onnx.load(input_path, load_external_data=False)
        )
        onnx.load_external_data_for_model(model, os.path.dirname(input_path))

    # ---- extract and save ----
    extracted = Extractor(model).extract_model(input_names, output_names)

    save_options = {}
    if extracted.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
        # Spill tensor data into a side file next to the output model.
        save_options["save_as_external_data"] = True
        save_options["location"] = os.path.basename(output_path) + ".data"
    onnx.save(extracted, output_path, **save_options)

    if check_model:
        onnx.checker.check_model(output_path)
def _tar_members_filter(
    tar: tarfile.TarFile, base: str | os.PathLike
) -> list[tarfile.TarInfo]:
    """Vet every member of ``tar`` before it is extracted under ``base``.

    A member is rejected if its resolved path falls outside ``base`` or if it
    is a symbolic or hard link.

    Args:
        tar: The tarball file
        base: The directory where the tarball will be extracted

    Returns:
        list of tarball members, in archive order

    Raises:
        RuntimeError: On the first member that fails either check.
    """
    root = os.path.abspath(base)
    accepted = []
    for member in tar:
        member_path = os.path.join(base, member.name)
        resolved = os.path.abspath(member_path)
        try:
            escapes_root = os.path.commonpath([root, resolved]) != root
        except ValueError:
            # commonpath cannot relate the two paths at all -> treat as outside.
            escapes_root = True
        if escapes_root:
            raise RuntimeError(
                f"The tarball member {member_path} in downloading model contains "
                f"directory traversal sequence which may contain harmful payload."
            )
        if member.issym() or member.islnk():
            raise RuntimeError(
                f"The tarball member {member_path} in downloading model contains "
                f"symbolic links which may contain harmful payload."
            )
        accepted.append(member)
    return accepted
296def _extract_model_safe(
297 model_tar_path: str | os.PathLike, local_model_with_data_dir_path: str | os.PathLike
298) -> None:
299 """Safely extracts a tar file to a specified directory.
301 This function ensures that the extraction process mitigates against
302 directory traversal vulnerabilities by validating or sanitizing paths
303 within the tar file. It also provides compatibility for different versions
304 of the tarfile module by checking for the availability of certain attributes
305 or methods before invoking them.
307 Args:
308 model_tar_path: The path to the tar file to be extracted.
309 local_model_with_data_dir_path: The directory path where the tar file
310 contents will be extracted to.
311 """
312 with tarfile.open(model_tar_path) as model_with_data_zipped:
313 # Mitigate tarball directory traversal risks
314 if hasattr(tarfile, "data_filter"):
315 model_with_data_zipped.extractall(
316 path=local_model_with_data_dir_path, filter="data"
317 )
318 else:
319 model_with_data_zipped.extractall( # noqa: S202
320 path=local_model_with_data_dir_path,
321 members=_tar_members_filter(
322 model_with_data_zipped, local_model_with_data_dir_path
323 ),
324 )