Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/gen_proto.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
3# Copyright (c) ONNX Project Contributors
4#
5# SPDX-License-Identifier: Apache-2.0
6from __future__ import annotations
8import argparse
9import glob
10import os
11import re
12import subprocess
13from textwrap import dedent
14from typing import TYPE_CHECKING
16if TYPE_CHECKING:
17 from collections.abc import Iterable
# Banner written at the very top of every generated .proto / .proto3 file.
autogen_header = """\
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//


"""

# Appended after the translated template when generating with lite=True, so the
# output targets the protobuf-lite runtime.
LITE_OPTION = """

// For using protobuf-lite
option optimize_for = LITE_RUNTIME;

"""

# Package name the templates use unchanged; any other package name makes the
# generator rename output files and rewrite `import "*.proto";` lines.
DEFAULT_PACKAGE_NAME = "onnx"
# Directive lines recognised by process_ifs. Each must be alone on its line,
# written as a `//` comment so the template stays valid protobuf.
IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")


def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Resolve ``// #if ONNX-ML`` / ``// #else`` / ``// #endif`` sections.

    The directive lines themselves are always dropped. Lines outside any
    section are kept. Lines in the ``#if`` branch are kept only when *onnx_ml*
    is true, and lines in the ``#else`` branch only when it is false. Sections
    cannot be nested.

    Args:
        lines: Lines of an ``.in.proto`` template.
        onnx_ml: Whether to keep the ONNX-ML branch.

    Yields:
        The surviving lines, unchanged and in order.

    Raises:
        ValueError: On a nested ``#if``, an ``#else`` outside an ``#if`` branch,
            an ``#endif`` with no open section, or a section still open at the
            end of *lines*. Explicit exceptions are used rather than ``assert``
            so the checks also run under ``python -O``.
    """
    # 0 = outside a section, 1 = inside the #if branch, 2 = inside the #else branch.
    in_if = 0
    for lineno, line in enumerate(lines, start=1):
        if IF_ONNX_ML_REGEX.match(line):
            if in_if != 0:
                raise ValueError(
                    f"line {lineno}: nested '#if ONNX-ML' sections are not supported"
                )
            in_if = 1
        elif ELSE_ONNX_ML_REGEX.match(line):
            if in_if != 1:
                raise ValueError(
                    f"line {lineno}: '#else' without a matching '#if ONNX-ML'"
                )
            in_if = 2
        elif ENDIF_ONNX_ML_REGEX.match(line):
            if in_if == 0:
                raise ValueError(
                    f"line {lineno}: '#endif' without a matching '#if ONNX-ML'"
                )
            in_if = 0
        elif (
            in_if == 0
            or (in_if == 1 and onnx_ml)
            or (in_if == 2 and not onnx_ml)  # noqa: PLR2004
        ):
            yield line
    # Without this check an unterminated section would silently swallow every
    # remaining line of the template.
    if in_if != 0:
        raise ValueError("unterminated '#if ONNX-ML' section at end of input")
# `<indent>import "<name>.proto";` -- group 1 is the indentation, group 2 the
# imported file name without the .proto extension.
IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
# Matches an include name ending in "-ml"; group 1 is the part before "-ml".
# Callers must use fullmatch so only a *trailing* "-ml" counts.
ML_REGEX = re.compile(r"(.*)\-ml")


def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Apply *package_name* to template lines.

    Every ``{PACKAGE_NAME}`` placeholder is replaced with *package_name*.

    When *package_name* differs from ``DEFAULT_PACKAGE_NAME``, lines of the
    form ``import "X.proto";`` are rewritten to import the renamed file:
    ``X_<package_name>.proto``, or ``Y_<package_name>-ml.proto`` when X is
    ``Y-ml``. Such rewritten import lines keep their leading indentation,
    lose any trailing whitespace, and are not placeholder-substituted.

    Args:
        lines: Template lines.
        package_name: Target protobuf package name.

    Yields:
        The processed lines, in order.
    """
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    for line in lines:
        m = IMPORT_REGEX.match(line) if need_rename else None
        if m is None:
            yield PACKAGE_NAME_REGEX.sub(package_name, line)
            continue
        indent, include_name = m.group(1), m.group(2)
        # fullmatch, not match: ML_REGEX is not end-anchored, so a plain match
        # would accept e.g. "foo-mlx" and silently drop everything after "-ml".
        ml = ML_REGEX.fullmatch(include_name)
        if ml:
            include_name = f"{ml.group(1)}_{package_name}-ml"
        else:
            include_name = f"{include_name}_{package_name}"
        yield f'{indent}import "{include_name}.proto";'
# `syntax = "proto2";` declaration; group 1 is its leading whitespace.
PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
# A field line starting with the `optional` keyword; group 1 is the indentation
# and group 2 the rest of the line after the keyword.
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")


def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Turn proto2 template lines into proto3 lines.

    Each line is checked against three rewrites in order, and the first one
    that matches is applied:

    * ``syntax = "proto2";``  ->  ``syntax = "proto3";``
    * ``optional <rest>``     ->  ``<rest>`` (keyword removed)
    * ``import "X.proto";``   ->  ``import "X.proto3";``

    Leading indentation is kept on rewritten lines. Lines matching none of the
    patterns are yielded untouched.
    """
    rewrites = (
        (PROTO_SYNTAX_REGEX, lambda hit: f'{hit.group(1)}syntax = "proto3";'),
        (OPTIONAL_REGEX, lambda hit: f"{hit.group(1)}{hit.group(2)}"),
        (IMPORT_REGEX, lambda hit: f'{hit.group(1)}import "{hit.group(2)}.proto3";'),
    )
    for text in lines:
        for pattern, rewrite in rewrites:
            hit = pattern.match(text)
            if hit:
                yield rewrite(hit)
                break
        else:
            # No rewrite applies: pass the line through as-is.
            yield text
def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Run protoc on *proto3_path*, emitting C++ and Python sources.

    Equivalent to ``protoc <proto3_path> -I <include_path> --cpp_out <cpp_out>
    --python_out <python_out>``. Because ``check=True`` is used,
    ``subprocess.CalledProcessError`` is raised if protoc exits non-zero.
    """
    print(f"Generate pb3 code using {protoc_path}")
    command = [
        protoc_path,
        proto3_path,
        "-I",
        include_path,
        "--cpp_out",
        cpp_out,
        "--python_out",
        python_out,
    ]
    subprocess.run(command, check=True)  # noqa: S603
def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Render an ``.in.proto`` template as proto2 or proto3 text.

    Steps, in order:

    1. Split *source* into lines.
    2. Resolve the ONNX-ML conditional sections (``process_ifs``).
    3. Apply the package name (``process_package_name``).
    4. For ``proto == 3`` only, rewrite the lines to proto3 syntax
       (``convert_to_proto3``).

    The resulting lines are joined with ``os.linesep``.

    Args:
        source: Raw template text.
        proto: Target syntax version; must be 2 or 3.
        onnx_ml: Whether to keep the ``#if ONNX-ML`` branches.
        package_name: Protobuf package name to substitute.

    Returns:
        The translated file contents.

    Raises:
        ValueError: If *proto* is neither 2 nor 3.
    """
    # Validate up front with a real exception. An ``assert`` is stripped under
    # ``python -O``, which would let an unsupported version silently emit proto2.
    if proto not in (2, 3):
        raise ValueError(f"proto must be 2 or 3, got {proto!r}")
    lines: Iterable[str] = source.splitlines()
    lines = process_ifs(lines, onnx_ml=onnx_ml)
    lines = process_package_name(lines, package_name=package_name)
    if proto == 3:  # noqa: PLR2004
        lines = convert_to_proto3(lines)
    return os.linesep.join(lines)
def qualify(f: str, pardir: str | None = None) -> str:
    """Return *f* joined onto *pardir*.

    If *pardir* is omitted (``None``), the directory containing this script is
    used, resolved with ``os.path.realpath``.
    """
    base = os.path.realpath(os.path.dirname(__file__)) if pardir is None else pardir
    return os.path.join(base, f)
def _write_proto_file(
    path: str,
    source: str,
    proto: int,
    onnx_ml: bool,
    package_name: str,
    lite: bool,
) -> None:
    """Write one generated proto file: banner, translated template, optional lite option.

    Args:
        path: Destination file; created or truncated.
        source: Raw ``.in.proto`` template text.
        proto: Syntax version passed to ``translate`` (2 or 3).
        onnx_ml: Whether the ``#if ONNX-ML`` branches are kept.
        package_name: Package name substituted into the template.
        lite: Whether to append ``LITE_OPTION``.
    """
    print(f"Writing {path}")
    # newline="" writes the os.linesep separators produced by translate() verbatim.
    with open(path, "w", newline="", encoding="utf-8") as fout:
        fout.write(autogen_header)
        fout.write(
            translate(source, proto=proto, onnx_ml=onnx_ml, package_name=package_name)
        )
        if lite:
            fout.write(LITE_OPTION)


def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate all output files for one ``<stem>.in.proto`` template.

    The template is read from this script's directory. The following are
    written into *output*:

    * ``<proto_base>.proto`` (proto2) and ``<proto_base>.proto3`` (proto3).
      ``proto_base`` is *stem*, plus ``_<package_name>`` when renaming,
      plus ``-ml`` for ML output.
    * If *protoc_path* is set, protoc is run on the ``.proto3`` file, and the
      ``<proto_base>.proto3.*`` files left next to it are then deleted.
    * If *package_name* is not the default, a ``<stem>[-ml].pb.h`` header
      that includes ``<proto_base>.pb.h``.
    * A ``<stem>_pb.py`` module (dashes replaced by underscores) that
      star-imports the matching ``*_pb2`` module.

    Args:
        stem: Template base name, e.g. ``"onnx"``.
        package_name: Protobuf package name to substitute into the template.
        output: Destination directory.
        do_onnx_ml: Keep the ``#if ONNX-ML`` branches and use ``-ml`` file
            names. ``-ml`` names are never used for onnx-data.
        lite: Append the protobuf-lite ``optimize_for`` option to both files.
        protoc_path: Optional protoc executable used to validate the proto3
            output.
    """
    proto_in = qualify(f"{stem}.in.proto")
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    # import_ml keeps the caller's ML flag for template processing (#if ONNX-ML
    # sections, -ml imports), independent of the output-name decision below.
    import_ml = do_onnx_ml
    # onnx-data.proto is the same with and without ONNX-ML, so never emit an
    # onnx-data-ml variant. Test the stem, not the full template path: a
    # checkout directory whose name contains "onnx-data" must not disable the
    # -ml output for every stem.
    if "onnx-data" in stem:
        do_onnx_ml = False
    rename_suffix = f"_{package_name}" if need_rename else ""
    ml_suffix = "-ml" if do_onnx_ml else ""
    proto_base = f"{stem}{rename_suffix}{ml_suffix}"
    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()
    _write_proto_file(proto, source, 2, import_ml, package_name, lite)
    _write_proto_file(proto3, source, 3, import_ml, package_name, lite)

    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        base_dir = os.path.dirname(proto3_dir)
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        # Escape the literal prefix so glob metacharacters ("[", "*", "?") in the
        # directory or package name are not interpreted as a pattern.
        leftover_glob = (
            glob.escape(os.path.join(proto3_dir, f"{proto_base}.proto3.")) + "*"
        )
        for pb3_file in glob.glob(leftover_glob):
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        header_name = f"{stem}-ml.pb.h" if do_onnx_ml else f"{stem}.pb.h"
        proto_header = qualify(header_name, pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate the py mapping module. "-" is invalid in Python module names, so
    # it is replaced by "_". The pb2 name derived from proto_base equals the
    # former per-case names: "<stem>_pb2", "<stem>_ml_pb2", or
    # "<stem>_<pkg>[_ml]_pb2", all with dashes replaced.
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)
    pb2_module = os.path.splitext(os.path.basename(pb2_py))[0]

    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\
                # This file is generated by 'gen_proto.py'. DO NOT EDIT!


                from .{pb2_module} import * # noqa
                """
            )
        )
def main() -> None:
    """Command-line entry point.

    Parses the options, creates the output directory if it does not exist,
    and runs ``convert`` once for every requested stem.
    """
    script_dir = os.path.realpath(os.path.dirname(__file__))
    parser = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    parser.add_argument(
        "-p",
        "--package",
        default="onnx",
        help="package name in the generated proto files (default: %(default)s)",
    )
    parser.add_argument("-m", "--ml", action="store_true", help="ML mode")
    parser.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=script_dir,
        help="output directory (default: %(default)s)",
    )
    parser.add_argument(
        "--protoc_path",
        default="",
        help="path to protoc for proto3 file validation",
    )
    parser.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    options = parser.parse_args()

    output_dir = options.output
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    shared_kwargs = {
        "package_name": options.package,
        "output": output_dir,
        "do_onnx_ml": options.ml,
        "lite": options.lite,
        "protoc_path": options.protoc_path,
    }
    for stem in options.stems:
        convert(stem, **shared_kwargs)


if __name__ == "__main__":
    main()