Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/shape_inference.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

58 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4 

5"""onnx shape inference. Shape inference is not guaranteed to be 

6complete. 

7 

8""" 

9 

10from __future__ import annotations 

11 

12import os 

13from typing import TYPE_CHECKING 

14 

15import onnx 

16import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812 

17from onnx.onnx_pb import ( 

18 IR_VERSION, 

19 AttributeProto, 

20 FunctionProto, 

21 ModelProto, 

22 TypeProto, 

23) 

24 

25if TYPE_CHECKING: 

26 from collections.abc import Sequence 

27 

# Expose the C++ extension's graph-inferencer and inference-context classes
# as module-level names of this module.
GraphInferencer, InferenceContext = C.GraphInferencer, C.InferenceContext

30 

31 

def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Run ONNX shape inference over an in-memory model.

    The inferred shapes are recorded in the ``value_info`` field of the graph
    of the returned model. Should an inferred value disagree with a value that
    is already present in the graph, the existing value is invalid (or shape
    inference has a bug) and the outcome is unspecified.

    Arguments:
        model: A ModelProto, or the bytes of an already-serialized ModelProto.
        check_type: Checks the type-equality for input and output.
        strict_mode: Stricter shape inference, it will throw errors if any;
            otherwise, simply stop if any error.
        data_prop: Enables data propagation for limited operators to perform
            shape computation.

    Returns:
        (ModelProto) model with inferred shape information, parsed from the
        bytes produced by the C++ extension.

    Raises:
        TypeError: If ``model`` is a path (str or os.PathLike) -- use
            :func:`infer_shapes_path` for that -- or any other type that is
            neither a ModelProto nor bytes.
    """
    # Paths get a dedicated message pointing at the file-based entry point.
    if isinstance(model, (str, os.PathLike)):
        raise TypeError(
            "infer_shapes only accepts ModelProto or bytes,"
            " For Model paths (str or os.PathLike), use infer_shapes_path()."
        )
    if not isinstance(model, (ModelProto, bytes)):
        raise TypeError(
            f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
        )

    # The C++ extension consumes and produces serialized protobuf bytes.
    if isinstance(model, ModelProto):
        serialized_model = model.SerializeToString()
    else:
        serialized_model = model
    inferred_bytes = C.infer_shapes(
        serialized_model, check_type, strict_mode, data_prop
    )
    return onnx.load_from_string(inferred_bytes)

71 

72 

def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Take model path for shape_inference.

    This function is the same as :func:`infer_shapes` but supports >2GB models.
    The function outputs the inferred model to the ``output_path``. The
    original model path is used if not specified.

    Arguments:
        model_path: Path (str or os.PathLike) of the model to read.
        output_path: Path (str or os.PathLike) to write the inferred model to.
            The default ``""`` means "overwrite ``model_path``".
        check_type: Checks the type-equality for input and output.
        strict_mode: Stricter shape inference, it will throw errors if any;
            otherwise, simply stop if any error.
        data_prop: Enables data propagation for limited operators to perform
            shape computation.

    Raises:
        TypeError: If ``model_path`` is a ModelProto (use
            :func:`infer_shapes` instead), or if either path cannot be
            converted with :func:`os.fspath`.
    """
    if isinstance(model_path, ModelProto):
        # The two literals are concatenated, so the separator needs an
        # explicit trailing space; without it the message read
        # "...(String),you can use...".
        raise TypeError(
            "infer_shapes_path only accepts model Path (String), "
            "you can use infer_shapes for the ModelProto."
        )
    # Normalize PathLike objects to plain paths for the C++ extension; chain
    # the original error so the offending type is visible.
    try:
        model_path = os.fspath(model_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts model path as a string or PathLike, "
            f"incorrect model path type: {type(model_path)}"
        ) from exp
    try:
        output_path = os.fspath(output_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts output path as a string or PathLike, "
            f"incorrect output path type: {type(output_path)}"
        ) from exp

    # An empty output path means "write the result back over the input file".
    if output_path == "":
        output_path = model_path
    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)

109 

110 

def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """Run ``schema``'s type-and-shape inference for a single ``node``.

    Every protobuf argument is serialized and handed to
    ``schema._infer_node_outputs``. The serialized results are then parsed
    back into TypeProto objects.

    Arguments:
        schema: The operator schema whose inference routine is invoked.
        node: The node whose outputs are inferred.
        input_types: Types keyed by value name. Every non-empty name in
            ``node.input`` must be present. All entries are forwarded, not only
            the node's own inputs, because they also serve as the outer-scope
            value types.
        input_data: Optional TensorProto values keyed by value name. Only
            entries whose name appears in ``node.input`` are forwarded.
        input_sparse_data: Same as ``input_data``, for SparseTensorProto values.
        opset_imports: Optional opset imports, forwarded as a
            ``{domain: version}`` mapping (empty when omitted).
        ir_version: IR version forwarded to the inference routine.

    Returns:
        A mapping from the names reported by the schema's inference
        (presumably the node's output names -- not verified here) to the
        inferred TypeProto.

    Raises:
        KeyError: If a non-empty name in ``node.input`` has no entry in
            ``input_types``.
    """
    tensor_data = input_data if input_data is not None else {}
    sparse_data = input_sparse_data if input_sparse_data is not None else {}
    opset_versions = (
        {}
        if opset_imports is None
        else {opset.domain: opset.version for opset in opset_imports}
    )

    # The node's own inputs go first. The lookup deliberately raises KeyError
    # when one is missing. An empty name marks an omitted optional input and
    # is skipped.
    serialized_types: dict[str, bytes] = {}
    for name in node.input:
        if name != "":
            serialized_types[name] = input_types[name].SerializeToString()
    # Then every remaining type, so it is also visible as an outer-scope value type.
    for name, type_proto in input_types.items():
        if name in serialized_types:
            continue
        serialized_types[name] = type_proto.SerializeToString()

    def _forward_for_node(protos):
        # Serialize only the entries that the node actually consumes.
        return {
            name: protos[name].SerializeToString()
            for name in node.input
            if name in protos
        }

    serialized_data = _forward_for_node(tensor_data)
    serialized_sparse = _forward_for_node(sparse_data)

    raw_outputs = schema._infer_node_outputs(
        node.SerializeToString(),
        serialized_types,
        serialized_data,
        serialized_sparse,
        opset_versions,
        ir_version,
    )  # type: ignore[call-arg]
    return {
        name: onnx.TypeProto.FromString(blob) for name, blob in raw_outputs.items()
    }

157 

158 

def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """Apply type-and-shape inference to a function body.

    The inference uses the given input types and the given attribute values.
    Every proto is serialized for the C++ extension, and the serialized types
    it returns are parsed back into TypeProto objects.

    Arguments:
        function: The FunctionProto whose body is inferred.
        input_types: Types of the function's inputs (presumably in the order
            of the function's declared inputs -- confirm against the C++ side).
        attributes: Attribute values supplied for this application of the
            function.

    Returns:
        The inferred output types, in the order the extension returns them.
    """
    serialized_function = function.SerializeToString()
    serialized_inputs = [type_proto.SerializeToString() for type_proto in input_types]
    serialized_attributes = [attribute.SerializeToString() for attribute in attributes]
    raw_types = C.infer_function_output_types(
        serialized_function,
        serialized_inputs,
        serialized_attributes,
    )
    # Each returned entry is one serialized TypeProto.
    return [onnx.TypeProto.FromString(raw) for raw in raw_types]

179 

180 

# Exception class exposed by the C++ shape-inference extension, re-exported
# under this module's namespace. Presumably this is what inference raises in
# strict_mode (see the infer_shapes docstring) -- confirm against the
# extension.
InferenceError = C.InferenceError