Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/serialization.py: 54%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
5from __future__ import annotations
7import warnings
# Public API of this module: only the serializer registry is exported.
__all__ = ["registry"]
13import typing
14from typing import Any, Protocol, TypeVar
16import google.protobuf.json_format
17import google.protobuf.message
18import google.protobuf.text_format
20import onnx
22if typing.TYPE_CHECKING:
23 from collections.abc import Collection
# Generic protobuf message type: lets deserialize_proto be typed as returning
# the same concrete message type that it was handed.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)

# Text encoding used whenever a text-based format converts between str and bytes.
_ENCODING = "utf-8"
class ProtoSerializer(Protocol):
    """Interface for converting in-memory protobuf messages to and from a serialized form.

    This is a ``typing.Protocol``: implementations may subclass it explicitly
    (as the built-in serializers in this module do) or simply provide matching
    attributes and methods.
    """

    # Name of the format this serializer handles, e.g. "protobuf".
    supported_format: str
    # File extensions associated with the format, each including the leading
    # dot, e.g. frozenset({".onnx", ".pb"}).
    file_extensions: Collection[str]

    # The methods are named serialize_proto / deserialize_proto rather than the
    # generic serialize / deserialize on purpose: a future protocol may cover
    # serializing the ONNX in-memory IR, and one class should be able to
    # implement both protocols side by side.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Convert the in-memory message *proto* into its serialized representation."""

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse *serialized* into an in-memory message of the same type as *proto* and return it."""
51class _Registry:
52 def __init__(self) -> None:
53 self._serializers: dict[str, ProtoSerializer] = {}
54 # A mapping from file extension to format
55 self._extension_to_format: dict[str, str] = {}
57 def register(self, serializer: ProtoSerializer) -> None:
58 self._serializers[serializer.supported_format] = serializer
59 self._extension_to_format.update(
60 dict.fromkeys(serializer.file_extensions, serializer.supported_format)
61 )
63 def get(self, fmt: str) -> ProtoSerializer:
64 """Get a serializer for a format.
66 Args:
67 fmt: The format to get a serializer for.
69 Returns:
70 ProtoSerializer: The serializer for the format.
72 Raises:
73 ValueError: If the format is not supported.
74 """
75 try:
76 return self._serializers[fmt]
77 except KeyError:
78 raise ValueError(
79 f"Unsupported format: '{fmt}'. Supported formats are: {self._serializers.keys()}"
80 ) from None
82 def get_format_from_file_extension(self, file_extension: str) -> str | None:
83 """Get the corresponding format from a file extension.
85 Args:
86 file_extension: The file extension to get a format for.
88 Returns:
89 The format for the file extension, or None if not found.
90 """
91 return self._extension_to_format.get(file_extension)
class _ProtobufSerializer(ProtoSerializer):
    """Binary protobuf wire-format serializer (``.onnx`` / ``.pb`` files)."""

    supported_format = "protobuf"
    file_extensions = frozenset({".onnx", ".pb"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Return the binary encoding produced by ``proto.SerializeToString()``.

        Raises:
            TypeError: If *proto* has no callable ``SerializeToString``.
            ValueError: If serialization fails. When ``proto.ByteSize()`` is at
                least ``onnx.checker.MAXIMUM_PROTOBUF``, the error is replaced
                with one explaining the 2 GB limit and suggesting external data.
        """
        to_bytes = getattr(proto, "SerializeToString", None)
        if not callable(to_bytes):
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            return to_bytes()
        except ValueError as exc:
            if proto.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
                # Not a size problem: propagate the original error unchanged.
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from exc

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse binary *serialized* data into *proto* and return *proto*.

        Raises:
            TypeError: If *serialized* is not ``bytes``.
            google.protobuf.message.DecodeError: If the parser reports having
                consumed a byte count different from ``len(serialized)``.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        total = len(serialized)
        # Some protobuf backends return the number of bytes consumed, others None.
        consumed = typing.cast("int | None", proto.ParseFromString(serialized))
        if consumed is None or consumed == total:
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {total}"
        )
class _TextProtoSerializer(ProtoSerializer):
    """Serializer for the human-readable protobuf text format."""

    supported_format = "textproto"
    file_extensions = frozenset({".txtpb", ".textproto", ".prototxt", ".pbtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render *proto* in protobuf text format and encode it with ``_ENCODING``."""
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse protobuf text-format data into *proto* and return the result.

        Args:
            serialized: The text, as ``str`` or as ``bytes`` encoded with ``_ENCODING``.
            proto: The message instance to parse into.

        Raises:
            TypeError: If *serialized* is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(text, proto)
class _JsonSerializer(ProtoSerializer):
    """Serializer for the protobuf JSON mapping."""

    supported_format = "json"
    file_extensions = frozenset({".json", ".onnxjson"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render *proto* as JSON and encode it with ``_ENCODING``.

        ``preserving_proto_field_name=True`` keeps the field names from the
        .proto schema instead of converting them to lowerCamelCase.
        """
        as_json = google.protobuf.json_format.MessageToJson(
            proto,
            preserving_proto_field_name=True,
        )
        return as_json.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON data into *proto* and return the result.

        Args:
            serialized: The JSON text, as ``str`` or as ``bytes`` encoded with ``_ENCODING``.
            proto: The message instance to parse into.

        Raises:
            TypeError: If *serialized* is neither ``bytes`` nor ``str``.
        """
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        text = serialized.decode(_ENCODING) if isinstance(serialized, bytes) else serialized
        return google.protobuf.json_format.Parse(text, proto)
class _TextualSerializer(ProtoSerializer):
    """Serialize and deserialize the ONNX textual representation.

    Output is produced by ``onnx.printer.to_text``; input is parsed with the
    matching ``onnx.parser.parse_*`` function.
    """

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt", ".onnxtext"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render *proto* with ``onnx.printer.to_text`` and encode it with ``_ENCODING``."""
        text = onnx.printer.to_text(proto)  # type: ignore[arg-type]
        return text.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse ONNX textual syntax into *proto* and return *proto*.

        The parser is selected from the type of *proto*. The parsed message is
        copied into *proto*, so the instance passed in is populated, as with the
        other serializers in this module, and is also returned.

        Args:
            serialized: The text, as ``str`` or as ``bytes`` encoded with ``_ENCODING``.
            proto: A ModelProto, GraphProto, FunctionProto or NodeProto. It selects
                the parser and is overwritten with the parsed result.

        Returns:
            *proto*, now holding the parsed content.

        Raises:
            TypeError: If *serialized* is neither ``bytes`` nor ``str``.
            ValueError: If *proto* is not one of the supported message types.
        """
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        # Warn only once the input is known to be of a usable type.
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        text = serialized.decode(_ENCODING) if isinstance(serialized, bytes) else serialized

        # (message type, parser) pairs, checked in order.
        parsers = (
            (onnx.ModelProto, onnx.parser.parse_model),
            (onnx.GraphProto, onnx.parser.parse_graph),
            (onnx.FunctionProto, onnx.parser.parse_function),
            (onnx.NodeProto, onnx.parser.parse_node),
        )
        for proto_type, parse in parsers:
            if isinstance(proto, proto_type):
                # The parsers build a new message. Copy it into the caller's
                # instance so *proto* is populated in place, matching the other
                # serializers, which parse directly into *proto*.
                proto.CopyFrom(parse(text))
                return proto
        raise ValueError(f"Unsupported proto type: {type(proto)}")
# Build the module-level registry and register the built-in serializers,
# in this order.
registry = _Registry()
for _serializer_type in (
    _ProtobufSerializer,
    _TextProtoSerializer,
    _JsonSerializer,
    _TextualSerializer,
):
    registry.register(_serializer_type())
# Keep the loop variable out of the module namespace.
del _serializer_type