Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/checker.py: 54%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4"""Graph utilities for checking whether an ONNX proto message is legal."""
6from __future__ import annotations
# Public API of this module: ``from onnx.checker import *`` exports exactly
# these names (the check_* wrappers, the shared default contexts, the
# ValidationError re-export, the native module alias ``C`` and the size limit).
__all__ = [
    "check_attribute",
    "check_function",
    "check_graph",
    "check_model",
    "check_node",
    "check_sparse_tensor",
    "check_tensor",
    "check_value_info",
    "DEFAULT_CONTEXT",
    "LEXICAL_SCOPE_CONTEXT",
    "ValidationError",
    "C",
    "MAXIMUM_PROTOBUF",
]
24import os
25from typing import TYPE_CHECKING
27import onnx.defs
28import onnx.onnx_cpp2py_export.checker as C # noqa: N812
29from onnx.onnx_pb import IR_VERSION
31if TYPE_CHECKING:
32 from google.protobuf.message import Message
# Size limit of a single serialized protobuf message: 2 GiB == 2**31 bytes.
MAXIMUM_PROTOBUF = 2**31

# NB: Please don't edit this context!
# Module-wide checker context used as the default ``ctx`` argument by most of
# the check_* wrappers in this module. It pins ``IR_VERSION`` and maps the
# default ("") domain to the opset reported by onnx.defs.onnx_opset_version().
DEFAULT_CONTEXT = C.CheckerContext()
DEFAULT_CONTEXT.ir_version = IR_VERSION
# TODO: Maybe ONNX-ML should also be defaulted?
DEFAULT_CONTEXT.opset_imports = {"": onnx.defs.onnx_opset_version()}

# Shared default lexical-scope context handed to the native checker by the
# wrappers that accept a ``lexical_scope_ctx`` argument.
LEXICAL_SCOPE_CONTEXT = C.LexicalScopeContext()
def _ensure_proto_type(proto: Message, proto_type: type[Message]) -> None:
    """Reject *proto* unless it is an instance of *proto_type*.

    Args:
        proto: The object that should be a protobuf message.
        proto_type: The expected protobuf message class.

    Raises:
        TypeError: If ``proto`` is not an instance of ``proto_type``.
    """
    if isinstance(proto, proto_type):
        return
    expected_name = proto_type.__name__
    raise TypeError(f"The proto message needs to be of type '{expected_name}'")
def check_value_info(
    value_info: onnx.ValueInfoProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a ``ValueInfoProto`` with the native ONNX checker.

    Args:
        value_info: The value-info message to check.
        ctx: Checker context; defaults to the module-wide ``DEFAULT_CONTEXT``.

    Returns:
        Whatever the native ``C.check_value_info`` returns.

    Raises:
        TypeError: If ``value_info`` is not an ``onnx.ValueInfoProto``.
        Validation failures are reported by the native checker (presumably
        as ``ValidationError`` -- confirm in the C++ binding).
    """
    _ensure_proto_type(value_info, onnx.ValueInfoProto)
    # The native checker consumes serialized bytes, not the Python message.
    payload = value_info.SerializeToString()
    return C.check_value_info(payload, ctx)
def check_tensor(
    tensor: onnx.TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a ``TensorProto`` with the native ONNX checker.

    Args:
        tensor: The tensor message to check.
        ctx: Checker context; defaults to the module-wide ``DEFAULT_CONTEXT``.

    Returns:
        Whatever the native ``C.check_tensor`` returns.

    Raises:
        TypeError: If ``tensor`` is not an ``onnx.TensorProto``.
    """
    _ensure_proto_type(tensor, onnx.TensorProto)
    # Hand the serialized form across the Python/C++ boundary.
    payload = tensor.SerializeToString()
    return C.check_tensor(payload, ctx)
def check_attribute(
    attr: onnx.AttributeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate an ``AttributeProto`` with the native ONNX checker.

    Args:
        attr: The attribute message to check.
        ctx: Checker context; defaults to the module-wide ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical-scope context forwarded to the native
            checker; defaults to ``LEXICAL_SCOPE_CONTEXT``.

    Returns:
        Whatever the native ``C.check_attribute`` returns.

    Raises:
        TypeError: If ``attr`` is not an ``onnx.AttributeProto``.
    """
    _ensure_proto_type(attr, onnx.AttributeProto)
    payload = attr.SerializeToString()
    return C.check_attribute(payload, ctx, lexical_scope_ctx)
def check_node(
    node: onnx.NodeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate a ``NodeProto`` with the native ONNX checker.

    Args:
        node: The node message to check.
        ctx: Checker context; defaults to the module-wide ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical-scope context forwarded to the native
            checker; defaults to ``LEXICAL_SCOPE_CONTEXT``.

    Returns:
        Whatever the native ``C.check_node`` returns.

    Raises:
        TypeError: If ``node`` is not an ``onnx.NodeProto``.
    """
    _ensure_proto_type(node, onnx.NodeProto)
    payload = node.SerializeToString()
    return C.check_node(payload, ctx, lexical_scope_ctx)
def check_function(
    function: onnx.FunctionProto,
    ctx: C.CheckerContext | None = None,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate a ``FunctionProto`` with the native ONNX checker.

    Unlike the other check_* wrappers, the default context is not
    ``DEFAULT_CONTEXT``. When ``ctx`` is omitted, a fresh context is derived
    from the function's own ``opset_import``:

    * ``ir_version`` comes from ``onnx.helper.find_min_ir_version_for``,
      called with unknown domains ignored.
    * ``opset_imports`` maps each imported domain to its version. If a domain
      is listed twice, the later entry wins.

    Args:
        function: The function message to check.
        ctx: Optional checker context. If ``None``, one is built as above.
        lexical_scope_ctx: Lexical-scope context forwarded to the native
            checker; defaults to ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``function`` is not an ``onnx.FunctionProto``.
    """
    _ensure_proto_type(function, onnx.FunctionProto)
    if ctx is not None:
        checker_ctx = ctx
    else:
        imports = function.opset_import
        checker_ctx = C.CheckerContext()
        checker_ctx.ir_version = onnx.helper.find_min_ir_version_for(
            imports, ignore_unknown=True
        )
        opsets = {}
        for entry in imports:
            opsets[entry.domain] = entry.version
        checker_ctx.opset_imports = opsets
    C.check_function(function.SerializeToString(), checker_ctx, lexical_scope_ctx)
def check_graph(
    graph: onnx.GraphProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate a ``GraphProto`` with the native ONNX checker.

    Args:
        graph: The graph message to check.
        ctx: Checker context; defaults to the module-wide ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical-scope context forwarded to the native
            checker; defaults to ``LEXICAL_SCOPE_CONTEXT``.

    Returns:
        Whatever the native ``C.check_graph`` returns.

    Raises:
        TypeError: If ``graph`` is not an ``onnx.GraphProto``.
    """
    _ensure_proto_type(graph, onnx.GraphProto)
    payload = graph.SerializeToString()
    return C.check_graph(payload, ctx, lexical_scope_ctx)
def check_sparse_tensor(
    sparse: onnx.SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a ``SparseTensorProto`` with the native ONNX checker.

    Args:
        sparse: The sparse-tensor message to check.
        ctx: Checker context; defaults to the module-wide ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``sparse`` is not an ``onnx.SparseTensorProto``.
    """
    _ensure_proto_type(sparse, onnx.SparseTensorProto)
    payload = sparse.SerializeToString()
    # Note: unlike most sibling wrappers, the native result is not returned.
    C.check_sparse_tensor(payload, ctx)
def check_model(
    model: onnx.ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
    check_custom_domain: bool = False,
) -> None:
    """Check the consistency of a model.

    An exception will be raised if the model's ir_version is not set
    properly or is higher than checker's ir_version, or if the model
    has duplicate keys in metadata_props.

    If IR version >= 3, the model must specify opset_import.
    If IR version < 3, the model cannot have any opset_import specified.

    Args:
        model: Model to check. If model is a path (``str`` or
            ``os.PathLike``), the native checker is given the path and loads
            the model itself. If the serialized model is 2 GiB or larger, the
            function must be called with the model path instead.
        full_check: If True, the function also runs shape inference check.
        skip_opset_compatibility_check: If True, the function skips the check for
            opset compatibility.
        check_custom_domain: If True, the function will check all domains. Otherwise
            only check built-in domains.

    Raises:
        ValueError: If ``model`` is a ``ModelProto`` or ``bytes`` whose
            serialized size is ``MAXIMUM_PROTOBUF`` (2 GiB) bytes or more.
    """
    # If model is a path instead of ModelProto / serialized bytes.
    if isinstance(model, (str, os.PathLike)):
        C.check_model_path(
            os.fspath(model),
            full_check,
            skip_opset_compatibility_check,
            check_custom_domain,
        )
        return

    # ``bytes`` is treated as an already-serialized ModelProto.
    protobuf_string = model if isinstance(model, bytes) else model.SerializeToString()
    # Protobuf requires a serialized message to be strictly smaller than
    # 2 GiB, so a payload of exactly MAXIMUM_PROTOBUF bytes is already too
    # large. Fail early here with an actionable message instead of letting
    # the native parser reject it.
    if len(protobuf_string) >= MAXIMUM_PROTOBUF:
        raise ValueError(
            "This protobuf of onnx model is too large (>2GiB). Call check_model with model path instead."
        )
    C.check_model(
        protobuf_string,
        full_check,
        skip_opset_compatibility_check,
        check_custom_domain,
    )
# Re-export of the native checker module's ``ValidationError`` class, so
# callers can write ``except onnx.checker.ValidationError`` without touching
# the private ``onnx_cpp2py_export`` module directly.
ValidationError = C.ValidationError