Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/checker.py: 54%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

52 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4"""Graph utilities for checking whether an ONNX proto message is legal.""" 

5 

6from __future__ import annotations 

7 

# Public API of this module, grouped by role (order preserved).
__all__ = [
    # Per-proto checker entry points.
    "check_attribute",
    "check_function",
    "check_graph",
    "check_model",
    "check_node",
    "check_sparse_tensor",
    "check_tensor",
    "check_value_info",
    # Shared module-level checker contexts.
    "DEFAULT_CONTEXT",
    "LEXICAL_SCOPE_CONTEXT",
    # Exception type and the native bindings module.
    "ValidationError",
    "C",
    # Serialized-size limit for in-memory model checking.
    "MAXIMUM_PROTOBUF",
]

23 

24import os 

25from typing import TYPE_CHECKING 

26 

27import onnx.defs 

28import onnx.onnx_cpp2py_export.checker as C # noqa: N812 

29from onnx.onnx_pb import IR_VERSION 

30 

31if TYPE_CHECKING: 

32 from google.protobuf.message import Message 

33 

# A single serialized protobuf message is limited to 2 GiB; larger in-memory
# models must be checked through a file path instead.
MAXIMUM_PROTOBUF = 2 * 1024 * 1024 * 1024  # 2 GiB == 2**31 == 2147483648 bytes

36 

37 

# NB: Please don't edit this context!
# Default checker context used by the per-proto checkers: it imports the
# latest opset of the default ("") domain and targets this build's IR_VERSION.
DEFAULT_CONTEXT = C.CheckerContext()
# TODO: Maybe ONNX-ML should also be defaulted?
DEFAULT_CONTEXT.opset_imports = {"": onnx.defs.onnx_opset_version()}
DEFAULT_CONTEXT.ir_version = IR_VERSION

# Shared lexical-scope context used as the default `lexical_scope_ctx` argument.
LEXICAL_SCOPE_CONTEXT = C.LexicalScopeContext()

45 

46 

def _ensure_proto_type(proto: Message, proto_type: type[Message]) -> None:
    """Raise ``TypeError`` unless ``proto`` is an instance of ``proto_type``.

    Args:
        proto: The object the caller passed to a checker function.
        proto_type: The protobuf message class that is required.

    Raises:
        TypeError: If ``proto`` is not an instance of ``proto_type``. The
            message names both the expected type and the type received.
    """
    if not isinstance(proto, proto_type):
        # Report the received type too, so misuse (e.g. passing a GraphProto
        # where a NodeProto is expected) is diagnosable from the message alone.
        raise TypeError(
            f"The proto message needs to be of type '{proto_type.__name__}', "
            f"got '{type(proto).__name__}'"
        )

52 

53 

def check_value_info(
    value_info: onnx.ValueInfoProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a ``ValueInfoProto`` with the native checker.

    Args:
        value_info: The value info to check; must be an ``onnx.ValueInfoProto``.
        ctx: Checker context; defaults to the module-level ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``value_info`` is not an ``onnx.ValueInfoProto``.
    """
    _ensure_proto_type(value_info, onnx.ValueInfoProto)
    # The native checker consumes the serialized bytes, not the Python object.
    payload = value_info.SerializeToString()
    return C.check_value_info(payload, ctx)

59 

60 

def check_tensor(
    tensor: onnx.TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a ``TensorProto`` with the native checker.

    Args:
        tensor: The tensor to check; must be an ``onnx.TensorProto``.
        ctx: Checker context; defaults to the module-level ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``tensor`` is not an ``onnx.TensorProto``.
    """
    _ensure_proto_type(tensor, onnx.TensorProto)
    serialized_tensor = tensor.SerializeToString()
    return C.check_tensor(serialized_tensor, ctx)

66 

67 

def check_attribute(
    attr: onnx.AttributeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate an ``AttributeProto`` with the native checker.

    Args:
        attr: The attribute to check; must be an ``onnx.AttributeProto``.
        ctx: Checker context; defaults to the module-level ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical scope context; defaults to the module-level
            ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``attr`` is not an ``onnx.AttributeProto``.
    """
    _ensure_proto_type(attr, onnx.AttributeProto)
    serialized_attr = attr.SerializeToString()
    return C.check_attribute(serialized_attr, ctx, lexical_scope_ctx)

75 

76 

def check_node(
    node: onnx.NodeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate a ``NodeProto`` with the native checker.

    Args:
        node: The node to check; must be an ``onnx.NodeProto``.
        ctx: Checker context; defaults to the module-level ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical scope context; defaults to the module-level
            ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``node`` is not an ``onnx.NodeProto``.
    """
    _ensure_proto_type(node, onnx.NodeProto)
    node_bytes = node.SerializeToString()
    return C.check_node(node_bytes, ctx, lexical_scope_ctx)

84 

85 

def check_function(
    function: onnx.FunctionProto,
    ctx: C.CheckerContext | None = None,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate a ``FunctionProto`` with the native checker.

    Args:
        function: The function to check; must be an ``onnx.FunctionProto``.
        ctx: Checker context. When ``None``, a fresh context is derived from
            the function's own ``opset_import``: its ``ir_version`` comes from
            ``onnx.helper.find_min_ir_version_for`` (called with
            ``ignore_unknown=True``) and its ``opset_imports`` map each
            imported domain to its version.
        lexical_scope_ctx: Lexical scope context; defaults to the module-level
            ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``function`` is not an ``onnx.FunctionProto``.
    """
    _ensure_proto_type(function, onnx.FunctionProto)
    if ctx is None:
        # This module never imports ``onnx.helper``; import it explicitly rather
        # than relying on ``onnx/__init__`` having loaded it as a side effect.
        # A ``from`` import is used on purpose: ``import onnx.helper`` here would
        # make ``onnx`` a function-local name and break the reference above.
        from onnx.helper import find_min_ir_version_for

        ctx = C.CheckerContext()
        ctx.ir_version = find_min_ir_version_for(
            function.opset_import, ignore_unknown=True
        )
        ctx.opset_imports = {
            domain_version.domain: domain_version.version
            for domain_version in function.opset_import
        }
    C.check_function(function.SerializeToString(), ctx, lexical_scope_ctx)

102 

103 

def check_graph(
    graph: onnx.GraphProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Validate a ``GraphProto`` with the native checker.

    Args:
        graph: The graph to check; must be an ``onnx.GraphProto``.
        ctx: Checker context; defaults to the module-level ``DEFAULT_CONTEXT``.
        lexical_scope_ctx: Lexical scope context; defaults to the module-level
            ``LEXICAL_SCOPE_CONTEXT``.

    Raises:
        TypeError: If ``graph`` is not an ``onnx.GraphProto``.
    """
    _ensure_proto_type(graph, onnx.GraphProto)
    graph_bytes = graph.SerializeToString()
    return C.check_graph(graph_bytes, ctx, lexical_scope_ctx)

111 

112 

def check_sparse_tensor(
    sparse: onnx.SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Validate a ``SparseTensorProto`` with the native checker.

    Args:
        sparse: The sparse tensor to check; must be an
            ``onnx.SparseTensorProto``.
        ctx: Checker context; defaults to the module-level ``DEFAULT_CONTEXT``.

    Raises:
        TypeError: If ``sparse`` is not an ``onnx.SparseTensorProto``.
    """
    _ensure_proto_type(sparse, onnx.SparseTensorProto)
    serialized_sparse = sparse.SerializeToString()
    C.check_sparse_tensor(serialized_sparse, ctx)

118 

119 

def check_model(
    model: onnx.ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
    check_custom_domain: bool = False,
) -> None:
    """Check the consistency of a model.

    An exception is raised if the model's ir_version is not set properly or
    is higher than the checker's ir_version, or if the model has duplicate
    keys in metadata_props.

    With IR version >= 3 the model must specify opset_import; with IR version
    < 3 the model must not specify any opset_import.

    Args:
        model: The model to check, given as a ``ModelProto``, its serialized
            ``bytes``, or a filesystem path (``str`` / ``os.PathLike``). Paths
            are handed to the native checker directly; use a path for models
            whose serialized size exceeds 2GiB.
        full_check: If True, shape inference is also run as part of the check.
        skip_opset_compatibility_check: If True, the opset compatibility check
            is skipped.
        check_custom_domain: If True, all domains are checked; otherwise only
            built-in domains are.

    Raises:
        ValueError: If an in-memory model serializes to more than
            ``MAXIMUM_PROTOBUF`` bytes.
    """
    # The three options are forwarded identically to both native entry points.
    options = (full_check, skip_opset_compatibility_check, check_custom_domain)

    # Path input: let the native side load the file itself.
    if isinstance(model, (str, os.PathLike)):
        C.check_model_path(os.fspath(model), *options)
        return

    # In-memory input: accept pre-serialized bytes or serialize the proto.
    if isinstance(model, bytes):
        serialized = model
    else:
        serialized = model.SerializeToString()

    # Oversized payloads are rejected up front with a pointer to the path API.
    if len(serialized) > MAXIMUM_PROTOBUF:
        raise ValueError(
            "This protobuf of onnx model is too large (>2GiB). Call check_model with model path instead."
        )
    C.check_model(serialized, *options)

169 

170 

# Exception type from the native checker bindings (``C``), re-exported so that
# callers can refer to it as ``onnx.checker.ValidationError``.
ValidationError = C.ValidationError