Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/shape_inference.py: 33%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
5"""onnx shape inference. Shape inference is not guaranteed to be
6complete.
8"""
10from __future__ import annotations
12import os
13from typing import TYPE_CHECKING
15import onnx
16import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812
17from onnx.onnx_pb import (
18 IR_VERSION,
19 AttributeProto,
20 FunctionProto,
21 ModelProto,
22 TypeProto,
23)
25if TYPE_CHECKING:
26 from collections.abc import Sequence
# Re-export the inference helper classes from the compiled extension module
# (imported above as ``C``) so callers can reach them through this module.
InferenceContext = C.InferenceContext
GraphInferencer = C.GraphInferencer
def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Run shape inference on an in-memory model and return the result.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, the provided values are invalid (or there is a bug in shape
    inference), and the result is unspecified.

    Arguments:
        model: A ModelProto, or an already-serialized model as bytes.
        check_type: Checks the type-equality for input and output.
        strict_mode: Stricter shape inference, it will throw errors if any;
            otherwise, simply stop if any error.
        data_prop: Enables data propagation for limited operators to perform
            shape computation.

    Returns:
        (ModelProto) model with inferred shape information

    Raises:
        TypeError: If ``model`` is neither a ModelProto nor bytes. Paths
            (str or os.PathLike) get a dedicated message pointing at
            infer_shapes_path().
    """
    # Reject unsupported inputs up front so the happy path reads linearly.
    if not isinstance(model, (ModelProto, bytes)):
        if isinstance(model, (str, os.PathLike)):
            raise TypeError(
                "infer_shapes only accepts ModelProto or bytes,"
                " For Model paths (str or os.PathLike), use infer_shapes_path()."
            )
        raise TypeError(
            f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
        )

    # The extension module works on serialized bytes; bytes input is passed
    # through untouched, a ModelProto is serialized first.
    if isinstance(model, bytes):
        serialized_model = model
    else:
        serialized_model = model.SerializeToString()

    inferred_bytes = C.infer_shapes(
        serialized_model, check_type, strict_mode, data_prop
    )
    return onnx.load_from_string(inferred_bytes)
def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Run shape inference on a model file and write the result to disk.

    This function is the same as :func:`infer_shapes` but supports >2GB models.
    The function outputs the inferred model to ``output_path``. The original
    model path is used if ``output_path`` is not specified (empty string).

    Arguments:
        model_path: Path of the model to run inference on (str or os.PathLike).
        output_path: Path the inferred model is written to; ``""`` means
            "use ``model_path``".
        check_type: Forwarded unchanged to the underlying implementation.
        strict_mode: Forwarded unchanged to the underlying implementation.
        data_prop: Forwarded unchanged to the underlying implementation.

    Raises:
        TypeError: If ``model_path`` is a ModelProto, or if either path is
            rejected by ``os.fspath``.
    """
    if isinstance(model_path, ModelProto):
        # The two literals are concatenated; the trailing space keeps the
        # message from reading "(String),you can use ...".
        raise TypeError(
            "infer_shapes_path only accepts model Path (String), "
            "you can use infer_shapes for the ModelProto."
        )
    try:
        model_path = os.fspath(model_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts model path as a string or PathLike, "
            f"incorrect model path type: {type(model_path)}"
        ) from exp
    try:
        output_path = os.fspath(output_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts output path as a string or PathLike, "
            f"incorrect output path type: {type(output_path)}"
        ) from exp

    # An empty output path means the inferred model is written to model_path.
    if output_path == "":
        output_path = model_path
    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """Infer the output types of a single node through its operator schema.

    All protobuf arguments are serialized to bytes, handed to
    ``schema._infer_node_outputs``, and the serialized results are parsed
    back into TypeProto objects.

    Arguments:
        schema: Operator schema whose ``_infer_node_outputs`` is invoked.
        node: The node to run inference for.
        input_types: Types keyed by value name. Every non-empty name in
            ``node.input`` must be present; other entries are passed along
            as well.
        input_data: Optional tensors keyed by value name; only entries named
            in ``node.input`` are forwarded.
        input_sparse_data: Optional sparse tensors keyed by value name; only
            entries named in ``node.input`` are forwarded.
        opset_imports: Optional opset list, converted to a domain -> version
            mapping. ``None`` yields an empty mapping.
        ir_version: Passed through unchanged.

    Returns:
        Mapping from each key returned by the schema call to its parsed
        TypeProto.

    Raises:
        KeyError: If a non-empty name in ``node.input`` is missing from
            ``input_types`` (the lookup is not guarded).
    """
    if input_data is None:
        input_data = {}
    if input_sparse_data is None:
        input_sparse_data = {}
    opset_versions = (
        {}
        if opset_imports is None
        else {entry.domain: entry.version for entry in opset_imports}
    )

    # Types of the node's own inputs go in first; a missing entry raises
    # KeyError here. Empty names are skipped.
    serialized_types = {}
    for name in node.input:
        if name != "":
            serialized_types[name] = input_types[name].SerializeToString()
    # input_types will also be used as outer_scope_value_types, so the
    # remaining entries are added too instead of filtering by node inputs.
    for name, type_proto in input_types.items():
        if name not in serialized_types:
            serialized_types[name] = type_proto.SerializeToString()

    serialized_data = {
        name: input_data[name].SerializeToString()
        for name in node.input
        if name in input_data
    }
    serialized_sparse_data = {
        name: input_sparse_data[name].SerializeToString()
        for name in node.input
        if name in input_sparse_data
    }

    raw_outputs = schema._infer_node_outputs(
        node.SerializeToString(),
        serialized_types,
        serialized_data,
        serialized_sparse_data,
        opset_versions,
        ir_version,
    )  # type: ignore[call-arg]

    inferred_types = {}
    for name, blob in raw_outputs.items():
        inferred_types[name] = onnx.TypeProto.FromString(blob)
    return inferred_types
def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """Apply type-and-shape inference to a function body.

    The function, its input types and its attribute values are serialized
    and passed to the extension module; each serialized result is parsed
    into a new TypeProto.

    Arguments:
        function: The function whose outputs are inferred.
        input_types: Input types handed to the inference call.
        attributes: Attribute values handed to the inference call.

    Returns:
        One TypeProto per entry returned by the inference call, in order.
    """
    serialized_function = function.SerializeToString()
    serialized_inputs = [type_proto.SerializeToString() for type_proto in input_types]
    serialized_attributes = [attr.SerializeToString() for attr in attributes]

    raw_outputs = C.infer_function_output_types(
        serialized_function,
        serialized_inputs,
        serialized_attributes,
    )

    output_types = []
    for blob in raw_outputs:
        parsed = onnx.TypeProto()
        parsed.ParseFromString(blob)
        output_types.append(parsed)
    return output_types
# Re-export InferenceError from the compiled extension module (``C``).
# NOTE(review): presumably the exception type raised when inference fails —
# confirm against the extension module.
InferenceError = C.InferenceError