Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/serialization.py: 54%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

106 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7import warnings 

8 

# Public API of this module: only the serializer registry is exported.
__all__ = ["registry"]

12 

13import typing 

14from typing import Any, Protocol, TypeVar 

15 

16import google.protobuf.json_format 

17import google.protobuf.message 

18import google.protobuf.text_format 

19 

20import onnx 

21 

22if typing.TYPE_CHECKING: 

23 from collections.abc import Collection 

24 

# Type variable bound to protobuf messages. It lets serializer methods declare
# that they hand back the same message type they were given.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)

# Text encoding applied whenever a text-based format is converted to or from bytes.
_ENCODING = "utf-8"

28 

29 

class ProtoSerializer(Protocol):
    """Structural interface for turning in-memory protobuf messages into a serialized form and back.

    An implementation names the format it handles and the file extensions that
    map to that format. It also provides a matching pair of
    ``serialize_proto`` / ``deserialize_proto`` methods.
    """

    # Name of the handled format, for example "protobuf".
    supported_format: str
    # File extensions associated with the format. Each one includes the
    # leading dot, for example frozenset({".onnx", ".pb"}).
    file_extensions: Collection[str]

    # These methods are deliberately called serialize_proto / deserialize_proto
    # rather than plain serialize / deserialize. That leaves room for future
    # protocols that (de)serialize the ONNX in-memory IR, so one class can
    # implement both kinds of protocol without a name clash.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Convert an in-memory proto into its serialized representation."""

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse serialized data into an in-memory proto and return it."""

49 

50 

51class _Registry: 

52 def __init__(self) -> None: 

53 self._serializers: dict[str, ProtoSerializer] = {} 

54 # A mapping from file extension to format 

55 self._extension_to_format: dict[str, str] = {} 

56 

57 def register(self, serializer: ProtoSerializer) -> None: 

58 self._serializers[serializer.supported_format] = serializer 

59 self._extension_to_format.update( 

60 dict.fromkeys(serializer.file_extensions, serializer.supported_format) 

61 ) 

62 

63 def get(self, fmt: str) -> ProtoSerializer: 

64 """Get a serializer for a format. 

65 

66 Args: 

67 fmt: The format to get a serializer for. 

68 

69 Returns: 

70 ProtoSerializer: The serializer for the format. 

71 

72 Raises: 

73 ValueError: If the format is not supported. 

74 """ 

75 try: 

76 return self._serializers[fmt] 

77 except KeyError: 

78 raise ValueError( 

79 f"Unsupported format: '{fmt}'. Supported formats are: {self._serializers.keys()}" 

80 ) from None 

81 

82 def get_format_from_file_extension(self, file_extension: str) -> str | None: 

83 """Get the corresponding format from a file extension. 

84 

85 Args: 

86 file_extension: The file extension to get a format for. 

87 

88 Returns: 

89 The format for the file extension, or None if not found. 

90 """ 

91 return self._extension_to_format.get(file_extension) 

92 

93 

class _ProtobufSerializer(ProtoSerializer):
    """Serializer for the binary protobuf wire format (.onnx and .pb files)."""

    supported_format = "protobuf"
    file_extensions = frozenset({".onnx", ".pb"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Return the binary wire encoding of ``proto``.

        Raises:
            TypeError: If ``proto`` has no callable ``SerializeToString``.
            ValueError: If serialization fails. When the message's byte size is
                at or above ``onnx.checker.MAXIMUM_PROTOBUF``, the error is
                replaced by one suggesting ``save_as_external_data``.
        """
        to_string = getattr(proto, "SerializeToString", None)
        if not callable(to_string):
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            return to_string()
        except ValueError as err:
            too_large = proto.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
            if not too_large:
                # Unrelated failure: propagate the original error unchanged.
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from err

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse binary wire data into ``proto`` and return it.

        Raises:
            TypeError: If ``serialized`` is not ``bytes``.
            google.protobuf.message.DecodeError: If the parser reports a byte
                count different from the input length.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        # ParseFromString is treated as possibly returning None, so only a
        # concrete count is compared against the input length.
        consumed = typing.cast("int | None", proto.ParseFromString(serialized))
        expected = len(serialized)
        if consumed is None or consumed == expected:
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {expected}"
        )

127 

128 

class _TextProtoSerializer(ProtoSerializer):
    """Serializer for the protobuf text format (.txtpb, .textproto, .prototxt, .pbtxt)."""

    supported_format = "textproto"
    file_extensions = frozenset({".txtpb", ".textproto", ".prototxt", ".pbtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` in protobuf text format and encode it with ``_ENCODING``."""
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse protobuf text format into ``proto`` and return the result.

        ``serialized`` may be ``str`` or ``bytes``. Bytes are decoded with
        ``_ENCODING`` first.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(text, proto)

148 

149 

class _JsonSerializer(ProtoSerializer):
    """Serializer for the protobuf JSON mapping (.json and .onnxjson files)."""

    supported_format = "json"
    file_extensions = frozenset({".json", ".onnxjson"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as JSON and encode it with ``_ENCODING``.

        Field names are kept as declared in the .proto schema, because
        ``preserving_proto_field_name=True`` is passed.
        """
        as_json = google.protobuf.json_format.MessageToJson(
            proto,
            preserving_proto_field_name=True,
        )
        return as_json.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON into ``proto`` and return the result.

        ``serialized`` may be ``str`` or ``bytes``. Bytes are decoded with
        ``_ENCODING`` first.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.json_format.Parse(text, proto)

171 

172 

class _TextualSerializer(ProtoSerializer):
    """Serialize and deserialize the ONNX textual representation.

    Deserialization emits an "experimental" warning. It supports only
    ModelProto, GraphProto, FunctionProto and NodeProto targets.
    """

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt", ".onnxtext"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``onnx.printer.to_text`` and encode it with ``_ENCODING``."""
        text = onnx.printer.to_text(proto)  # type: ignore[arg-type]
        return text.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse ONNX textual syntax into ``proto`` and return it.

        The parser routine is chosen from the type of ``proto``. The parsed
        message is copied into ``proto``, so the caller's message is filled in
        place like it is by the other serializers, and ``proto`` is returned.

        Args:
            serialized: The text, as ``str`` or as ``bytes`` in ``_ENCODING``.
            proto: The message to fill. Its type selects the parser.

        Returns:
            ``proto``, now holding the parsed content.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
            ValueError: If ``proto`` is not a ModelProto, GraphProto,
                FunctionProto or NodeProto.
        """
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        text = serialized.decode(_ENCODING) if isinstance(serialized, bytes) else serialized
        # The onnx.parser functions take only the text and return a message;
        # they do not fill the one passed in, so the result is copied below.
        parsed: google.protobuf.message.Message
        if isinstance(proto, onnx.ModelProto):
            parsed = onnx.parser.parse_model(text)
        elif isinstance(proto, onnx.GraphProto):
            parsed = onnx.parser.parse_graph(text)
        elif isinstance(proto, onnx.FunctionProto):
            parsed = onnx.parser.parse_function(text)
        elif isinstance(proto, onnx.NodeProto):
            parsed = onnx.parser.parse_node(text)
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")
        proto.CopyFrom(parsed)  # type: ignore[arg-type]
        return proto

205 

206 

# Module-level registry, populated at import time with the built-in
# serializers in a fixed order: binary protobuf, text proto, JSON, ONNX textual.
registry = _Registry()
for _default_serializer in (
    _ProtobufSerializer(),
    _TextProtoSerializer(),
    _JsonSerializer(),
    _TextualSerializer(),
):
    registry.register(_default_serializer)
# Remove the loop variable so it does not linger as a module attribute.
del _default_serializer