Coverage for /pythoncovmergedfiles/medio/medio/src/onnx/onnx/fuzz/fuzz_shape_inference.py: 70%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1###### Coverage stub
2import atexit
3import coverage
# Begin collecting coverage as soon as this stub executes; results are
# written to the '.coverage' data file when the interpreter exits.
cov = coverage.coverage(cover_pylib=True, data_file='.coverage')
cov.start()


# Register an exit handler that stops collection and saves the coverage data.
@atexit.register
def exit_handler():
    cov.stop()
    cov.save()
11####### End of coverage stub
12# Copyright (c) ONNX Project Contributors
13# SPDX-License-Identifier: Apache-2.0
14"""Atheris fuzz harness for onnx.shape_inference.
One of two input paths is exercised per iteration, selected by a
fuzzer-controlled toggle byte read from the *tail* of the input (so the head
remains a valid candidate for the raw-bytes path):
20* Raw bytes -> onnx.load_model_from_string -> infer_shapes
21 Catches protobuf parser bugs and bugs reachable only through crafted
22 serialized models the structured builder will not produce.
24* Structured -> helper.make_model from FuzzedDataProvider -> infer_shapes
25 Constructs ModelProto objects whose graphs include subgraph-bearing
26 ops (If / Loop / Scan) so the recursive visitor inside shape_inference
27 is reached on most iterations rather than only when the parser happens
28 to accept a random byte string.
30Both strict_mode values and both check_type values are sampled.
31"""
34import sys
36import atheris
38with atheris.instrument_imports():
39 import onnx
40 from onnx import TensorProto, helper, shape_inference
42# Elementwise unary ops with trivial shape inference rules. Useful as
43# filler nodes so generated graphs have non-trivial bodies that exercise
44# the per-op inference dispatch table.
45_UNARY = (
46 "Relu",
47 "Sigmoid",
48 "Tanh",
49 "Abs",
50 "Neg",
51 "Exp",
52 "Log",
53 "Sqrt",
54 "Identity",
55 "Floor",
56 "Ceil",
57)
59# Ops that carry one or more subgraph attributes. Each forces the recursive
60# shape-inference visitor to descend, which is the path the known DoS lives
61# on. Loop/Scan exercise different subgraph-context plumbing than If.
62_SUBGRAPH_OPS = ("If", "Loop", "Scan")
def _const_bool(name, value=True):
    """Return a Constant node that emits a scalar BOOL tensor named *name*.

    Args:
        name: Used both as the tensor name and as the node's output name.
        value: Boolean payload stored in the tensor (defaults to True).
    """
    scalar = helper.make_tensor(
        name=name,
        data_type=TensorProto.BOOL,
        dims=[],  # empty dims -> rank-0 (scalar) tensor
        vals=[value],
    )
    return helper.make_node("Constant", inputs=[], outputs=[name], value=scalar)
def _build_branch(fdp, depth, max_depth):
    """Build one self-contained GraphProto for use as a (sub)graph.

    The graph declares no inputs: it seeds itself with a FLOAT Constant, so
    it never relies on outer-scope captures that were not declared. When
    ``depth < max_depth`` and the fuzzer's next boolean is true, the graph
    additionally nests one of If / Loop / Scan (also fuzzer-chosen), built by
    recursing at ``depth + 1``; this nesting is what drives the recursion
    inside shape_inference. A fuzzer-chosen run of 0-4 unary ops is then
    chained onto the current tail tensor, and the final tensor becomes the
    graph's single output.

    Loop and Scan bodies are deliberately not signature-conformant; the
    recursive visitor descends before signature checks run, so the recursion
    path is still exercised even when inference ultimately fails.

    Args:
        fdp: FuzzedDataProvider supplying every random decision.
        depth: Nesting level of this graph; embedded in all tensor names.
        max_depth: Upper bound on nesting; no subgraph op is emitted at or
            beyond this depth.

    Returns:
        The graph produced by ``helper.make_graph``, named ``branch_<depth>``.
    """
    seed_name = f"s_{depth}"
    seed_value = helper.make_tensor(seed_name, TensorProto.FLOAT, [1], [0.0])
    graph_nodes = [
        helper.make_node("Constant", [], [seed_name], value=seed_value),
    ]
    tail = seed_name

    # Short-circuit matters: the boolean is consumed from the fuzzer only
    # while nesting is still permitted.
    if depth < max_depth and fdp.ConsumeBool():
        kind = _SUBGRAPH_OPS[fdp.ConsumeIntInRange(0, len(_SUBGRAPH_OPS) - 1)]
        nested_out = f"sub_{depth}"
        inner = _build_branch(fdp, depth + 1, max_depth)

        if kind == "If":
            cond_name = f"c_{depth}"
            graph_nodes.append(_const_bool(cond_name))
            # The else branch is a second, independently fuzzed subgraph.
            alternate = _build_branch(fdp, depth + 1, max_depth)
            carrier = helper.make_node(
                "If",
                [cond_name],
                [nested_out],
                then_branch=inner,
                else_branch=alternate,
            )
        elif kind == "Loop":
            trip_name = f"M_{depth}"
            trip_value = helper.make_tensor(trip_name, TensorProto.INT64, [], [1])
            graph_nodes.append(
                helper.make_node("Constant", [], [trip_name], value=trip_value)
            )
            cond_name = f"c_{depth}"
            graph_nodes.append(_const_bool(cond_name))
            carrier = helper.make_node(
                "Loop",
                [trip_name, cond_name],
                [nested_out],
                body=inner,
            )
        else:  # Scan
            carrier = helper.make_node(
                "Scan",
                [seed_name],
                [nested_out],
                body=inner,
                num_scan_inputs=1,
            )

        graph_nodes.append(carrier)
        tail = nested_out

    # Chain filler ops, each reading the previous tail and producing a new one.
    for step in range(fdp.ConsumeIntInRange(0, 4)):
        filler = _UNARY[fdp.ConsumeIntInRange(0, len(_UNARY) - 1)]
        produced = f"v_{depth}_{step}"
        graph_nodes.append(helper.make_node(filler, [tail], [produced]))
        tail = produced

    # Shape is left unspecified (None) on the declared output.
    result_info = helper.make_tensor_value_info(tail, TensorProto.FLOAT, None)
    return helper.make_graph(
        graph_nodes,
        f"branch_{depth}",
        inputs=[],
        outputs=[result_info],
    )
def _build_model(fdp):
    """Assemble a ModelProto from fuzzer-controlled decisions.

    The top-level graph is simply a branch at depth 0. The opset version is
    chosen by the fuzzer as well, so different shape-inference codepaths
    (per-opset schemas) are reached.

    Args:
        fdp: FuzzedDataProvider supplying every random decision.

    Returns:
        The model produced by ``helper.make_model`` for the default domain.
    """
    # Consumption order is part of the input format: depth cap first, then
    # the graph, then the opset version.
    depth_cap = fdp.ConsumeIntInRange(0, 80)
    top_graph = _build_branch(fdp, depth=0, max_depth=depth_cap)
    opset_version = fdp.ConsumeIntInRange(7, 27)
    default_domain_opset = helper.make_opsetid("", opset_version)
    return helper.make_model(top_graph, opset_imports=[default_domain_opset])
def TestOneInput(data):
    """Fuzz entry point: run shape inference on one fuzzer-supplied input.

    The last byte of ``data`` is a bitfield of toggles:

    * 0x01 -> ``strict_mode``
    * 0x02 -> ``check_type``
    * 0x04 -> use the structured builder instead of the raw-bytes parser
    * 0x08..0x80 are reserved for future toggles; mutations against them are
      harmless until claimed.

    On the structured path the toggle byte is sliced off before the rest is
    handed to FuzzedDataProvider. On the raw path the full ``data`` goes to
    the protobuf parser unchanged: seed models are complete serialized
    ModelProtos, so slicing the tail would truncate every seed. The trailing
    toggle byte simply becomes part of the raw input, which libFuzzer mutates
    freely anyway.

    Inputs shorter than two bytes are ignored. Always returns None.
    """
    if len(data) < 2:
        return

    flags = data[-1]
    inference_options = {
        "check_type": bool(flags & 0x02),
        "strict_mode": bool(flags & 0x01),
    }
    structured = bool(flags & 0x04)

    try:
        if structured:
            model = _build_model(atheris.FuzzedDataProvider(data[:-1]))
        else:
            model = onnx.load_model_from_string(data)
        shape_inference.infer_shapes(model, **inference_options)
    except Exception:
        # Malformed fuzz inputs raise a broad set of expected exceptions
        # (ValidationError, InferenceError, DecodeError, ValueError, ...).
        # Real bugs surface as crashes, hangs, or sanitizer reports.
        pass
def main():
    """Instrument loaded code, hand TestOneInput to Atheris, and start fuzzing."""
    atheris.instrument_all()
    atheris.Setup(
        sys.argv,
        TestOneInput,
        enable_python_coverage=True,
    )
    atheris.Fuzz()


if __name__ == "__main__":
    main()