Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/__init__.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

151 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5from __future__ import annotations 

6 

7__all__ = [ 

8 # Constants 

9 "ONNX_ML", 

10 "IR_VERSION", 

11 "IR_VERSION_2017_10_10", 

12 "IR_VERSION_2017_10_30", 

13 "IR_VERSION_2017_11_3", 

14 "IR_VERSION_2019_1_22", 

15 "IR_VERSION_2019_3_18", 

16 "IR_VERSION_2019_9_19", 

17 "IR_VERSION_2020_5_8", 

18 "IR_VERSION_2021_7_30", 

19 "IR_VERSION_2023_5_5", 

20 "IR_VERSION_2024_3_25", 

21 "EXPERIMENTAL", 

22 "STABLE", 

23 # Modules 

24 "checker", 

25 "compose", 

26 "defs", 

27 "gen_proto", 

28 "helper", 

29 "numpy_helper", 

30 "parser", 

31 "printer", 

32 "shape_inference", 

33 "utils", 

34 "version_converter", 

35 # Proto classes 

36 "AttributeProto", 

37 "DeviceConfigurationProto", 

38 "FunctionProto", 

39 "GraphProto", 

40 "IntIntListEntryProto", 

41 "MapProto", 

42 "ModelProto", 

43 "NodeDeviceConfigurationProto", 

44 "NodeProto", 

45 "OperatorProto", 

46 "OperatorSetIdProto", 

47 "OperatorSetProto", 

48 "OperatorStatus", 

49 "OptionalProto", 

50 "SequenceProto", 

51 "SimpleShardedDimProto", 

52 "ShardedDimProto", 

53 "ShardingSpecProto", 

54 "SparseTensorProto", 

55 "StringStringEntryProto", 

56 "TensorAnnotation", 

57 "TensorProto", 

58 "TensorShapeProto", 

59 "TrainingInfoProto", 

60 "TypeProto", 

61 "ValueInfoProto", 

62 "Version", 

63 # Utility functions 

64 "convert_model_to_external_data", 

65 "load_external_data_for_model", 

66 "load_model_from_string", 

67 "load_model", 

68 "load_tensor_from_string", 

69 "load_tensor", 

70 "save_model", 

71 "save_tensor", 

72 "write_external_data_tensors", 

73] 

74# isort:skip_file 

75 

76import os 

77import typing 

78from typing import IO, Literal 

79 

80 

81from onnx import serialization 

82from onnx.onnx_cpp2py_export import ONNX_ML 

83from onnx.external_data_helper import ( 

84 load_external_data_for_model, 

85 write_external_data_tensors, 

86 convert_model_to_external_data, 

87) 

88from onnx.onnx_pb import ( 

89 AttributeProto, 

90 DeviceConfigurationProto, 

91 EXPERIMENTAL, 

92 FunctionProto, 

93 GraphProto, 

94 IntIntListEntryProto, 

95 IR_VERSION, 

96 IR_VERSION_2017_10_10, 

97 IR_VERSION_2017_10_30, 

98 IR_VERSION_2017_11_3, 

99 IR_VERSION_2019_1_22, 

100 IR_VERSION_2019_3_18, 

101 IR_VERSION_2019_9_19, 

102 IR_VERSION_2020_5_8, 

103 IR_VERSION_2021_7_30, 

104 IR_VERSION_2023_5_5, 

105 IR_VERSION_2024_3_25, 

106 ModelProto, 

107 NodeDeviceConfigurationProto, 

108 NodeProto, 

109 OperatorSetIdProto, 

110 OperatorStatus, 

111 STABLE, 

112 SimpleShardedDimProto, 

113 ShardedDimProto, 

114 ShardingSpecProto, 

115 SparseTensorProto, 

116 StringStringEntryProto, 

117 TensorAnnotation, 

118 TensorProto, 

119 TensorShapeProto, 

120 TrainingInfoProto, 

121 TypeProto, 

122 ValueInfoProto, 

123 Version, 

124) 

125from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto 

126from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto 

127import importlib.metadata 

128 

129# Import common subpackages so they're available when you 'import onnx' 

130from onnx import ( 

131 checker, 

132 compose, 

133 defs, 

134 gen_proto, 

135 helper, 

136 numpy_helper, 

137 parser, 

138 printer, 

139 shape_inference, 

140 utils, 

141 version_converter, 

142) 

143 

144if typing.TYPE_CHECKING: 

145 from collections.abc import Sequence 

146 

# Resolve the installed package version. The regular "onnx" distribution is
# tried first, then the "onnx-weekly" distribution; if neither is installed
# the version is reported as "unknown".
for _dist_name in ("onnx", "onnx-weekly"):
    try:
        __version__ = importlib.metadata.version(_dist_name)
        break
    except importlib.metadata.PackageNotFoundError:
        continue
else:
    __version__ = "unknown"
# Do not leave the loop variable behind as a module attribute.
del _dist_name

154 

155# Supported model formats that can be loaded from and saved to 

156# The literals are formats with built-in support. But we also allow users to 

157# register their own formats. So we allow str as well. 

158_SupportedFormat = Literal["protobuf", "textproto", "onnxtxt", "json"] | str # noqa: PYI051 

159# Default serialization format 

160_DEFAULT_FORMAT = "protobuf" 

161 

162 

def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes:
    """Return the full binary content of ``f``.

    Args:
        f: Either an object exposing a callable ``read`` attribute, whose
            ``read()`` result is returned directly, or a path that is opened
            in binary mode and read completely.

    Returns:
        The bytes read from ``f``.
    """
    reader = getattr(f, "read", None)
    if callable(reader):
        return reader()

    # Anything without a callable ``read`` is treated as a filesystem path.
    with open(typing.cast("str | os.PathLike", f), "rb") as source:
        return source.read()

171 

172 

def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None:
    """Write ``content`` to ``f``.

    Args:
        content: The bytes to write.
        f: Either an object exposing a callable ``write`` attribute, which
            receives ``content`` in a single call, or a path that is opened
            in binary write mode (``"wb"``).
    """
    writer = getattr(f, "write", None)
    if callable(writer):
        writer(content)
        return

    # Anything without a callable ``write`` is treated as a filesystem path.
    with open(typing.cast("str | os.PathLike", f), "wb") as sink:
        sink.write(content)

180 

181 

def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None:
    """Return the absolute filesystem path associated with ``f``, if any.

    Args:
        f: A path (``str`` or ``os.PathLike``), a file-like object, or None.
            For file-like objects the ``name`` attribute is consulted.

    Returns:
        The absolute path as a ``str``, or None when no path can be
        determined. That includes file objects whose ``name`` is not a path,
        such as files opened from an integer file descriptor or objects
        whose ``name`` is None.
    """
    if isinstance(f, (str, os.PathLike)):
        candidate = f
    else:
        candidate = getattr(f, "name", None)

    # ``name`` is not guaranteed to be a path: passing an int or None to
    # os.path.abspath raises TypeError, so such names are reported as
    # "no path" instead.
    if not isinstance(candidate, (str, bytes, os.PathLike)):
        return None

    # os.fsdecode leaves str paths untouched and converts bytes paths, so the
    # result is always a ``str`` as the return annotation promises.
    return os.path.abspath(os.fsdecode(candidate))

189 

190 

def _get_serializer(
    fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None
) -> serialization.ProtoSerializer:
    """Look up the serializer for ``f`` in the serialization registry.

    Resolution order:
        1. An explicit ``fmt`` is used as given.
        2. Otherwise the format is derived from the file extension of ``f``,
           when a file path can be determined for it.
        3. Otherwise the default format is used.
    """
    registry = serialization.registry
    if fmt is not None:
        return registry.get(fmt)

    inferred = None
    path = _get_file_path(f)
    if path is not None:
        extension = os.path.splitext(path)[1]
        inferred = registry.get_format_from_file_extension(extension)

    # A missing or falsy inferred format falls back to the default.
    return registry.get(inferred if inferred else _DEFAULT_FORMAT)

208 

209 

def load_model(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    load_external_data: bool = True,
) -> ModelProto:
    """Read a serialized ModelProto from ``f`` and return it.

    Args:
        f: A file-like object (anything with a ``read`` function) or a
            string/PathLike naming the file to read.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is available for ``f``; when it is
            omitted _and_ no path is available, 'protobuf' is used. Text
            formats are assumed to be encoded as "utf-8".
        load_external_data: Whether to also load external data. This only
            has an effect when a file path can be determined for ``f``; the
            data is then looked up relative to the directory containing the
            model. For data stored elsewhere, call
            :func:`load_external_data_for_model` with the right directory.

    Returns:
        The in-memory ModelProto.
    """
    serializer = _get_serializer(format, f)
    model = serializer.deserialize_proto(_load_bytes(f), ModelProto())

    if not load_external_data:
        return model

    model_filepath = _get_file_path(f)
    if model_filepath:
        # External data locations are resolved against the model's directory.
        load_external_data_for_model(model, os.path.dirname(model_filepath))
    return model

240 

241 

def load_tensor(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> TensorProto:
    """Read a serialized TensorProto from ``f`` and return it.

    Args:
        f: A file-like object (anything with a ``read`` function) or a
            string/PathLike naming the file to read.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is available for ``f``; when it is
            omitted _and_ no path is available, 'protobuf' is used. Text
            formats are assumed to be encoded as "utf-8".

    Returns:
        The in-memory TensorProto.
    """
    serializer = _get_serializer(format, f)
    payload = _load_bytes(f)
    return serializer.deserialize_proto(payload, TensorProto())

259 

260 

def load_model_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> ModelProto:
    """Deserialize a ModelProto from an in-memory payload.

    Args:
        s: The serialized ModelProto.
        format: The serialization format of ``s``. There is no file name to
            infer it from, so it defaults to 'protobuf'. Text formats are
            assumed to be encoded as "utf-8".

    Returns:
        The in-memory ModelProto.
    """
    serializer = _get_serializer(format)
    return serializer.deserialize_proto(s, ModelProto())

278 

279 

def load_tensor_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> TensorProto:
    """Deserialize a TensorProto from an in-memory payload.

    Args:
        s: The serialized TensorProto. ``str`` is accepted in addition to
            ``bytes`` to match :func:`load_model_from_string`, which feeds
            the same ``deserialize_proto`` call (presumably only meaningful
            for text formats; confirm against the selected serializer).
        format: The serialization format of ``s``. There is no file name to
            infer it from, so it defaults to 'protobuf'. Text formats are
            assumed to be encoded as "utf-8".

    Returns:
        The in-memory TensorProto.
    """
    serializer = _get_serializer(format)
    return serializer.deserialize_proto(s, TensorProto())

297 

298 

def save_model(
    proto: ModelProto | bytes,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Serialize a ModelProto to ``f``, optionally moving tensor raw data to external files first.

    Args:
        proto: An in-memory ModelProto. ``bytes`` are also accepted and are
            first deserialized with the default ('protobuf') serializer.
        f: A file-like object (anything with a ``write`` function), or a
            string or PathLike naming the output file.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is available for ``f``; when it is
            omitted _and_ no path is available, 'protobuf' is used. Text
            formats are assumed to be encoded as "utf-8".
        save_as_external_data: If true, save tensors to external file(s).
        all_tensors_to_one_file: Effective only if save_as_external_data is True.
            If true, save all tensors to one external file specified by location.
            If false, save each tensor to a file named with the tensor name.
        location: Effective only if save_as_external_data is True.
            The external file that all tensors are saved to, as a path
            relative to the model path. If not specified, the model name is used.
        size_threshold: Effective only if save_as_external_data is True.
            Only tensors whose data size is >= size_threshold are converted to
            external data. Use size_threshold=0 to convert every tensor with
            raw data.
        convert_attribute: Effective only if save_as_external_data is True.
            If true, convert all tensors to external data.
            If false, convert only non-attribute tensors to external data.
    """
    if isinstance(proto, bytes):
        default_serializer = _get_serializer(_DEFAULT_FORMAT)
        proto = default_serializer.deserialize_proto(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(
            proto,
            all_tensors_to_one_file,
            location,
            size_threshold,
            convert_attribute,
        )

    destination_path = _get_file_path(f)
    if destination_path is not None:
        # External tensor files are written relative to the model's directory.
        proto = write_external_data_tensors(proto, os.path.dirname(destination_path))

    payload = _get_serializer(format, destination_path).serialize_proto(proto)
    _save_bytes(payload, f)

350 

351 

def save_tensor(
    proto: TensorProto,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> None:
    """Serialize a TensorProto and write it to ``f``.

    Args:
        proto: An in-memory TensorProto.
        f: A file-like object (anything with a ``write`` function), or a
            string or PathLike naming the output file.
        format: The serialization format. If omitted, it is inferred from the
            file extension when a path is available for ``f``; when it is
            omitted _and_ no path is available, 'protobuf' is used. Text
            formats are assumed to be encoded as "utf-8".
    """
    serializer = _get_serializer(format, f)
    payload = serializer.serialize_proto(proto)
    _save_bytes(payload, f)

370 

371 

# Short legacy names for the model load/save functions, kept for backward
# compatibility.
save = save_model
load = load_model
load_from_string = load_model_from_string

376 

377 

def _model_proto_repr(self: ModelProto) -> str:
    """Build a compact one-line summary of a ModelProto.

    ``ir_version`` is always shown; every other field appears only when it is
    truthy. ``functions`` is reduced to a count and ``opset_import`` to a
    domain-to-version mapping.
    """
    fields = [f"ir_version={self.ir_version}"]
    if self.opset_import:
        fields.append(f"opset_import={_operator_set_protos_repr(self.opset_import)}")
    if self.domain:
        fields.append(f"domain='{self.domain}'")
    if self.producer_name:
        fields.append(f"producer_name='{self.producer_name}'")
    if self.producer_version:
        fields.append(f"producer_version='{self.producer_version}'")
    if self.graph:
        fields.append(f"graph={self.graph!r}")
    if self.functions:
        fields.append(f"functions=<{len(self.functions)} functions>")
    return f"ModelProto({', '.join(fields)})"

404 

405 

def _graph_proto_repr(self: GraphProto) -> str:
    """Build a compact one-line summary of a GraphProto.

    The graph name is always shown; each repeated field is reported only as
    an element count, and only when it is non-empty.
    """
    fields = [f"'{self.name}'"]
    if self.input:
        fields.append(f"input=<{len(self.input)} inputs>")
    if self.output:
        fields.append(f"output=<{len(self.output)} outputs>")
    if self.initializer:
        fields.append(f"initializer=<{len(self.initializer)} initializers>")
    if self.node:
        fields.append(f"node=<{len(self.node)} nodes>")
    if self.value_info:
        fields.append(f"value_info=<{len(self.value_info)} value_info>")
    return f"GraphProto({', '.join(fields)})"

428 

429 

def _function_proto_repr(self: FunctionProto) -> str:
    """Build a compact one-line summary of a FunctionProto.

    The function name is always shown; other fields appear only when truthy.
    ``input``, ``output`` and ``node`` are reported as counts, while
    ``attribute`` is formatted in full.
    """
    fields = [f"'{self.name}'"]
    if self.domain:
        fields.append(f"domain='{self.domain}'")
    if self.overload:
        fields.append(f"overload='{self.overload}'")
    if self.opset_import:
        fields.append(f"opset_import={_operator_set_protos_repr(self.opset_import)}")
    if self.input:
        fields.append(f"input=<{len(self.input)} inputs>")
    if self.output:
        fields.append(f"output=<{len(self.output)} outputs>")
    if self.attribute:
        fields.append(f"attribute={self.attribute}")
    if self.node:
        fields.append(f"node=<{len(self.node)} nodes>")
    return f"FunctionProto({', '.join(fields)})"

460 

461 

def _operator_set_protos_repr(protos: Sequence[OperatorSetIdProto]) -> str:
    """Render opset imports as the repr of a ``{domain: version}`` dict.

    If a domain occurs more than once, the last entry wins.
    """
    versions_by_domain = {}
    for entry in protos:
        versions_by_domain[entry.domain] = entry.version
    return repr(versions_by_domain)

465 

466 

# Replace __repr__ on these proto classes with the compact summary functions
# defined in this module, which is more efficient than the default repr.
FunctionProto.__repr__ = _function_proto_repr  # type: ignore[method-assign,assignment]
GraphProto.__repr__ = _graph_proto_repr  # type: ignore[method-assign,assignment]
ModelProto.__repr__ = _model_proto_repr  # type: ignore[method-assign,assignment]