Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/__init__.py: 27%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
# Public API of the ``onnx`` package, assembled group by group. The ``+=``
# form is one of the ``__all__`` idioms static type checkers understand.
__all__ = [
    # Constants
    "ONNX_ML",
    "IR_VERSION",
    "IR_VERSION_2017_10_10",
    "IR_VERSION_2017_10_30",
    "IR_VERSION_2017_11_3",
    "IR_VERSION_2019_1_22",
    "IR_VERSION_2019_3_18",
    "IR_VERSION_2019_9_19",
    "IR_VERSION_2020_5_8",
    "IR_VERSION_2021_7_30",
    "IR_VERSION_2023_5_5",
    "IR_VERSION_2024_3_25",
    "EXPERIMENTAL",
    "STABLE",
]
# Modules
__all__ += [
    "checker",
    "compose",
    "defs",
    "gen_proto",
    "helper",
    "numpy_helper",
    "parser",
    "printer",
    "shape_inference",
    "utils",
    "version_converter",
]
# Proto classes
__all__ += [
    "AttributeProto",
    "DeviceConfigurationProto",
    "FunctionProto",
    "GraphProto",
    "IntIntListEntryProto",
    "MapProto",
    "ModelProto",
    "NodeDeviceConfigurationProto",
    "NodeProto",
    "OperatorProto",
    "OperatorSetIdProto",
    "OperatorSetProto",
    "OperatorStatus",
    "OptionalProto",
    "SequenceProto",
    "SimpleShardedDimProto",
    "ShardedDimProto",
    "ShardingSpecProto",
    "SparseTensorProto",
    "StringStringEntryProto",
    "TensorAnnotation",
    "TensorProto",
    "TensorShapeProto",
    "TrainingInfoProto",
    "TypeProto",
    "ValueInfoProto",
    "Version",
]
# Utility functions
__all__ += [
    "convert_model_to_external_data",
    "load_external_data_for_model",
    "load_model_from_string",
    "load_model",
    "load_tensor_from_string",
    "load_tensor",
    "save_model",
    "save_tensor",
    "write_external_data_tensors",
]
74# isort:skip_file
76import os
77import typing
78from typing import IO, Literal
81from onnx import serialization
82from onnx.onnx_cpp2py_export import ONNX_ML
83from onnx.external_data_helper import (
84 load_external_data_for_model,
85 write_external_data_tensors,
86 convert_model_to_external_data,
87)
88from onnx.onnx_pb import (
89 AttributeProto,
90 DeviceConfigurationProto,
91 EXPERIMENTAL,
92 FunctionProto,
93 GraphProto,
94 IntIntListEntryProto,
95 IR_VERSION,
96 IR_VERSION_2017_10_10,
97 IR_VERSION_2017_10_30,
98 IR_VERSION_2017_11_3,
99 IR_VERSION_2019_1_22,
100 IR_VERSION_2019_3_18,
101 IR_VERSION_2019_9_19,
102 IR_VERSION_2020_5_8,
103 IR_VERSION_2021_7_30,
104 IR_VERSION_2023_5_5,
105 IR_VERSION_2024_3_25,
106 ModelProto,
107 NodeDeviceConfigurationProto,
108 NodeProto,
109 OperatorSetIdProto,
110 OperatorStatus,
111 STABLE,
112 SimpleShardedDimProto,
113 ShardedDimProto,
114 ShardingSpecProto,
115 SparseTensorProto,
116 StringStringEntryProto,
117 TensorAnnotation,
118 TensorProto,
119 TensorShapeProto,
120 TrainingInfoProto,
121 TypeProto,
122 ValueInfoProto,
123 Version,
124)
125from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto
126from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto
127import importlib.metadata
129# Import common subpackages so they're available when you 'import onnx'
130from onnx import (
131 checker,
132 compose,
133 defs,
134 gen_proto,
135 helper,
136 numpy_helper,
137 parser,
138 printer,
139 shape_inference,
140 utils,
141 version_converter,
142)
144if typing.TYPE_CHECKING:
145 from collections.abc import Sequence
147try:
148 __version__ = importlib.metadata.version("onnx")
149except importlib.metadata.PackageNotFoundError:
150 try:
151 __version__ = importlib.metadata.version("onnx-weekly")
152 except importlib.metadata.PackageNotFoundError:
153 __version__ = "unknown"
155# Supported model formats that can be loaded from and saved to
156# The literals are formats with built-in support. But we also allow users to
157# register their own formats. So we allow str as well.
158_SupportedFormat = Literal["protobuf", "textproto", "onnxtxt", "json"] | str # noqa: PYI051
159# Default serialization format
160_DEFAULT_FORMAT = "protobuf"
def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes:
    """Return everything readable from ``f`` as bytes.

    ``f`` is either an object exposing a callable ``read`` (which is invoked
    once with no arguments), or a path that is opened in binary mode and read
    in full.
    """
    reader = getattr(f, "read", None)
    if callable(reader):
        return reader()
    path = typing.cast("str | os.PathLike", f)
    with open(path, "rb") as stream:
        return stream.read()
def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None:
    """Write ``content`` to ``f``.

    When ``f`` has a callable ``write`` it is called once with ``content`` and
    ``f`` is left open. Otherwise ``f`` is treated as a path and opened in
    ``"wb"`` mode (created or truncated) for the write.
    """
    writer = getattr(f, "write", None)
    if callable(writer):
        writer(content)
        return
    path = typing.cast("str | os.PathLike", f)
    with open(path, "wb") as sink:
        sink.write(content)
def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None:
    """Return the absolute path that ``f`` refers to, or ``None``.

    Args:
        f: A path (``str`` or ``os.PathLike``), a file-like object, or ``None``.

    Returns:
        - For a path, the absolute form of that path.
        - For a file-like object, the absolute form of its ``name`` attribute,
          but only when that attribute is a ``str``, ``bytes`` or
          ``os.PathLike``. ``bytes`` names are decoded with ``os.fsdecode``,
          so a ``str`` is returned.
        - ``None`` otherwise. This covers ``None``, objects without a
          ``name`` (e.g. in-memory buffers), and objects whose ``name`` is
          not a path.
    """
    if isinstance(f, (str, os.PathLike)):
        return os.path.abspath(f)
    name = getattr(f, "name", None)
    # Only a path-like ``name`` is usable. A file opened from a descriptor
    # (``open(fd)``) reports the integer descriptor as its name, and a name
    # may also be ``None``. Passing either to ``os.path.abspath`` raises
    # TypeError, so such objects are treated like nameless buffers.
    if isinstance(name, (str, bytes, os.PathLike)):
        return os.fsdecode(os.path.abspath(name))
    return None
def _get_serializer(
    fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None
) -> serialization.ProtoSerializer:
    """Look up the serializer for a format in the serialization registry.

    The format is resolved in this order:

    1. ``fmt``, when it is not ``None``.
    2. Otherwise, the format registered for the file extension of ``f``, when
       a path can be determined for ``f``.
    3. Otherwise, ``_DEFAULT_FORMAT``.
    """
    registry = serialization.registry
    if fmt is None:
        resolved_path = _get_file_path(f)
        if resolved_path is not None:
            extension = os.path.splitext(resolved_path)[1]
            fmt = registry.get_format_from_file_extension(extension)
        # Nothing matched (or no path was available): fall back to protobuf.
        fmt = fmt or _DEFAULT_FORMAT
        assert fmt is not None
    return registry.get(fmt)
def load_model(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    load_external_data: bool = True,
) -> ModelProto:
    """Load a serialized ModelProto into memory.

    Args:
        f: A file-like object with a ``read`` method, or a ``str``/``os.PathLike``
            naming the file to read.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is known for ``f``. When that is not
            possible, 'protobuf' is used. The encoding is assumed to be "utf-8"
            when the format is a text format.
        load_external_data: Whether to load external tensor data. When True and
            a path is known for ``f``, external data is loaded relative to the
            directory that contains the model file. Otherwise, call
            :func:`load_external_data_for_model` yourself with the directory
            that holds the external data.

    Returns:
        The loaded in-memory ModelProto.
    """
    serializer = _get_serializer(format, f)
    model = serializer.deserialize_proto(_load_bytes(f), ModelProto())
    if not load_external_data:
        return model

    model_filepath = _get_file_path(f)
    if model_filepath:
        load_external_data_for_model(model, os.path.dirname(model_filepath))
    return model
def load_tensor(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> TensorProto:
    """Load a serialized TensorProto into memory.

    Args:
        f: A file-like object with a ``read`` method, or a ``str``/``os.PathLike``
            naming the file to read.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is known for ``f``. When that is not
            possible, 'protobuf' is used. The encoding is assumed to be "utf-8"
            when the format is a text format.

    Returns:
        The loaded in-memory TensorProto.
    """
    serializer = _get_serializer(format, f)
    payload = _load_bytes(f)
    return serializer.deserialize_proto(payload, TensorProto())
def load_model_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> ModelProto:
    """Load a ModelProto from its serialized form held in memory.

    Args:
        s: The serialized ModelProto, as ``bytes`` or ``str``.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file involved, so unlike :func:`load_model` the format
            is never inferred from a file extension. The encoding is assumed
            to be "utf-8" when the format is a text format.

    Returns:
        The loaded in-memory ModelProto.
    """
    return _get_serializer(format).deserialize_proto(s, ModelProto())
def load_tensor_from_string(
    s: bytes,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> TensorProto:
    """Load a TensorProto from its serialized form held in memory.

    Args:
        s: The serialized TensorProto as ``bytes``.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file involved, so unlike :func:`load_tensor` the format
            is never inferred from a file extension. The encoding is assumed
            to be "utf-8" when the format is a text format.

    Returns:
        The loaded in-memory TensorProto.
    """
    return _get_serializer(format).deserialize_proto(s, TensorProto())
def save_model(
    proto: ModelProto | bytes,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Save a model to ``f``, optionally storing tensors with raw data as external data.

    Args:
        proto: The model to save, either as an in-memory ModelProto or as
            ``bytes``. ``bytes`` are always parsed with the default (binary
            protobuf) format before saving, whatever ``format`` is.
        f: A file-like object with a ``write`` method, or a string or
            path-like object naming the file to write.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is known for ``f``. When that is not
            possible, 'protobuf' is used. The encoding is assumed to be "utf-8"
            when the format is a text format.
        save_as_external_data: If true, save tensors to external file(s).
        all_tensors_to_one_file: Effective only if save_as_external_data is True.
            If true, save all tensors to the one external file given by
            ``location``. If false, save each tensor to a file named after
            the tensor.
        location: Effective only if save_as_external_data is True.
            The external file that all tensors are saved to, as a path
            relative to the model path. If not specified, the model name is
            used.
        size_threshold: Effective only if save_as_external_data is True.
            Only tensors whose data is at least ``size_threshold`` bytes are
            converted to external data. Set it to 0 to convert every tensor
            with raw data.
        convert_attribute: Effective only if save_as_external_data is True.
            If true, convert all tensors to external data. If false, convert
            only non-attribute tensors.

    Note:
        External data files are written only when a path can be determined
        for ``f``, i.e. ``f`` is a path or a file object with a path-like
        ``name``.
    """
    if isinstance(proto, bytes):
        proto = _get_serializer(_DEFAULT_FORMAT).deserialize_proto(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(
            proto, all_tensors_to_one_file, location, size_threshold, convert_attribute
        )

    model_filepath = _get_file_path(f)
    # This runs whenever a path is known, not only when save_as_external_data
    # is set, presumably so tensors already marked as external are flushed
    # too. NOTE(review): for a name-less file object with
    # save_as_external_data=True, nothing is written to disk here; confirm
    # what convert_model_to_external_data leaves in the serialized model.
    if model_filepath is not None:
        basepath = os.path.dirname(model_filepath)
        proto = write_external_data_tensors(proto, basepath)

    serialized = _get_serializer(format, model_filepath).serialize_proto(proto)
    _save_bytes(serialized, f)
def save_tensor(
    proto: TensorProto,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> None:
    """Serialize a TensorProto and write it to ``f``.

    Args:
        proto: The in-memory TensorProto to save.
        f: A file-like object with a ``write`` method, or a string or
            path-like object naming the file to write.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is known for ``f``. When that is not
            possible, 'protobuf' is used. The encoding is assumed to be "utf-8"
            when the format is a text format.
    """
    serializer = _get_serializer(format, f)
    _save_bytes(serializer.serialize_proto(proto), f)
# Backward-compatible aliases for the older, shorter function names.
load, load_from_string, save = load_model, load_model_from_string, save_model
def _model_proto_repr(self: ModelProto) -> str:
    """Build a compact ``repr`` for a ModelProto.

    ``ir_version`` is always shown. These fields follow, in this order, only
    when truthy: ``opset_import`` (as a ``{domain: version}`` dict),
    ``domain``, ``producer_name``, ``producer_version``, ``graph`` (via its
    own ``repr``) and ``functions`` (as a count only).
    """
    parts = [f"ir_version={self.ir_version}"]
    if self.opset_import:
        parts.append(f"opset_import={_operator_set_protos_repr(self.opset_import)}")
    if self.domain:
        parts.append(f"domain='{self.domain}'")
    if self.producer_name:
        parts.append(f"producer_name='{self.producer_name}'")
    if self.producer_version:
        parts.append(f"producer_version='{self.producer_version}'")
    if self.graph:
        parts.append(f"graph={self.graph!r}")
    if self.functions:
        parts.append(f"functions=<{len(self.functions)} functions>")
    return f"ModelProto({', '.join(parts)})"
def _graph_proto_repr(self: GraphProto) -> str:
    """Build a compact ``repr`` for a GraphProto.

    Shows the quoted graph name. Each non-empty repeated field is then shown
    as an element count rather than its contents, in the order input, output,
    initializer, node, value_info.
    """
    pieces = [f"'{self.name}'"]
    for field, noun in (
        ("input", "inputs"),
        ("output", "outputs"),
        ("initializer", "initializers"),
        ("node", "nodes"),
        ("value_info", "value_info"),
    ):
        items = getattr(self, field)
        if items:
            pieces.append(f"{field}=<{len(items)} {noun}>")
    return f"GraphProto({', '.join(pieces)})"
def _function_proto_repr(self: FunctionProto) -> str:
    """Build a compact ``repr`` for a FunctionProto.

    Shows the quoted function name, then only the truthy fields in this order:
    ``domain``, ``overload``, ``opset_import`` (as a ``{domain: version}``
    dict), ``input`` and ``output`` (as counts), ``attribute`` (formatted
    as-is) and ``node`` (as a count).
    """
    pieces = [f"'{self.name}'"]
    if self.domain:
        pieces.append(f"domain='{self.domain}'")
    if self.overload:
        pieces.append(f"overload='{self.overload}'")
    if self.opset_import:
        pieces.append(f"opset_import={_operator_set_protos_repr(self.opset_import)}")
    if self.input:
        pieces.append(f"input=<{len(self.input)} inputs>")
    if self.output:
        pieces.append(f"output=<{len(self.output)} outputs>")
    if self.attribute:
        pieces.append(f"attribute={self.attribute}")
    if self.node:
        pieces.append(f"node=<{len(self.node)} nodes>")
    return f"FunctionProto({', '.join(pieces)})"
def _operator_set_protos_repr(protos: Sequence[OperatorSetIdProto]) -> str:
    """Render opset imports as the ``repr`` of a ``{domain: version}`` dict.

    When a domain appears more than once, its last version wins, and the
    domain keeps the position of its first occurrence.
    """
    by_domain = {}
    for opset in protos:
        by_domain[opset.domain] = opset.version
    return repr(by_domain)
# Swap in compact ``__repr__`` implementations for these proto classes. They
# summarize repeated fields by count instead of printing every element.
for _proto_cls, _repr_fn in (
    (ModelProto, _model_proto_repr),
    (GraphProto, _graph_proto_repr),
    (FunctionProto, _function_proto_repr),
):
    _proto_cls.__repr__ = _repr_fn  # type: ignore[method-assign,assignment]
del _proto_cls, _repr_fn