Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/gen_proto.py: 19%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

150 statements  

1#!/usr/bin/env python 

2 

3# Copyright (c) ONNX Project Contributors 

4# 

5# SPDX-License-Identifier: Apache-2.0 

6from __future__ import annotations 

7 

8import argparse 

9import glob 

10import os 

11import re 

12import subprocess 

13from textwrap import dedent 

14from typing import TYPE_CHECKING 

15 

16if TYPE_CHECKING: 

17 from collections.abc import Iterable 

18 

# Banner written at the top of every generated .proto / .proto3 file.
autogen_header = (
    "//\n"
    "// WARNING: This file is automatically generated! Please edit onnx.in.proto.\n"
    "//\n"
    "\n"
    "\n"
)

# Snippet appended to the generated files when the "lite" option is requested.
LITE_OPTION = (
    "\n"
    "\n"
    "// For using protobuf-lite\n"
    "option optimize_for = LITE_RUNTIME;\n"
    "\n"
)

# Package name for which no renaming of files or imports is performed.
DEFAULT_PACKAGE_NAME = "onnx"

35 

# Marker comments recognised by process_ifs(); each must occupy a whole line.
IF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#if\s+ONNX-ML\s*$")
ENDIF_ONNX_ML_REGEX = re.compile(r"\s*//\s*#endif\s*$")
ELSE_ONNX_ML_REGEX = re.compile(r"\s*//\s*#else\s*$")


def process_ifs(lines: Iterable[str], onnx_ml: bool) -> Iterable[str]:
    """Resolve ``// #if ONNX-ML`` / ``// #else`` / ``// #endif`` sections.

    Marker lines themselves are never emitted. Lines outside any section are
    always emitted; lines in the ``#if`` part are emitted only when *onnx_ml*
    is true, and lines in the ``#else`` part only when it is false.

    Args:
        lines: Input lines, one string per line.
        onnx_ml: Which branch of each conditional section to keep.

    Yields:
        The surviving lines, unchanged and in their original order.

    Raises:
        AssertionError: If the markers are nested or out of order.
    """
    # 0 = outside any section, 1 = inside the #if part, 2 = inside the #else part.
    state = 0
    for text in lines:
        if IF_ONNX_ML_REGEX.match(text):
            assert state == 0
            state = 1
            continue
        if ELSE_ONNX_ML_REGEX.match(text):
            assert state == 1
            state = 2
            continue
        if ENDIF_ONNX_ML_REGEX.match(text):
            assert state in (1, 2)
            state = 0
            continue

        keep = (
            state == 0
            or (state == 1 and onnx_ml)
            or (state == 2 and not onnx_ml)  # noqa: PLR2004
        )
        if keep:
            yield text

60 

61 

# `import "<name>.proto";` line: group 1 is the indentation, group 2 the name.
IMPORT_REGEX = re.compile(r'(\s*)import\s*"([^"]*)\.proto";\s*$')
# Literal `{PACKAGE_NAME}` placeholder in the template.
PACKAGE_NAME_REGEX = re.compile(r"\{PACKAGE_NAME\}")
# Import name containing "-ml"; group 1 is everything before the last "-ml".
ML_REGEX = re.compile(r"(.*)\-ml")


def process_package_name(lines: Iterable[str], package_name: str) -> Iterable[str]:
    """Insert the package name into the template lines.

    When *package_name* differs from ``DEFAULT_PACKAGE_NAME``, every
    ``import "<name>.proto";`` line is re-emitted so that it points at
    ``<name>_<package_name>.proto`` (or ``<prefix>_<package_name>-ml.proto``
    when the name matches ``ML_REGEX``). All other lines have the
    ``{PACKAGE_NAME}`` placeholder replaced by *package_name*.

    Args:
        lines: Input lines, one string per line.
        package_name: Package name to substitute.

    Yields:
        The rewritten lines in their original order.
    """
    rename_imports = package_name != DEFAULT_PACKAGE_NAME
    for text in lines:
        found = IMPORT_REGEX.match(text) if rename_imports else None
        if not found:
            yield PACKAGE_NAME_REGEX.sub(package_name, text)
            continue

        indent = found.group(1)
        target = found.group(2)
        ml_match = ML_REGEX.match(target)
        if ml_match:
            # Keep the "-ml" marker at the very end of the renamed import.
            renamed = f"{ml_match.group(1)}_{package_name}-ml"
        else:
            renamed = f"{target}_{package_name}"
        yield f'{indent}import "{renamed}.proto";'

81 

82 

# `syntax = "proto2";` line: group 1 is the indentation.
PROTO_SYNTAX_REGEX = re.compile(r'(\s*)syntax\s*=\s*"proto2"\s*;\s*$')
# Line starting with `optional`: group 1 is the indentation, group 2 the rest.
OPTIONAL_REGEX = re.compile(r"(\s*)optional\s(.*)$")


def _proto3_line(text: str) -> str:
    """Return the proto3 form of a single proto2 line."""
    # Set the syntax specifier.
    syntax_match = PROTO_SYNTAX_REGEX.match(text)
    if syntax_match:
        return f'{syntax_match.group(1)}syntax = "proto3";'

    # Remove the `optional` keyword, keeping indentation and the declaration.
    optional_match = OPTIONAL_REGEX.match(text)
    if optional_match:
        return f"{optional_match.group(1)}{optional_match.group(2)}"

    # Point imports at the .proto3 sibling of the imported file.
    import_match = IMPORT_REGEX.match(text)
    if import_match:
        return f'{import_match.group(1)}import "{import_match.group(2)}.proto3";'

    return text


def convert_to_proto3(lines: Iterable[str]) -> Iterable[str]:
    """Rewrite proto2 lines as proto3.

    The ``syntax`` specifier is switched to ``"proto3"``, a leading
    ``optional`` keyword is removed, and ``import "<name>.proto";`` lines are
    redirected to ``<name>.proto3``. Every other line passes through as-is.

    Args:
        lines: Input lines, one string per line.

    Yields:
        One converted line per input line, in the original order.
    """
    for text in lines:
        yield _proto3_line(text)

108 

109 

def gen_proto3_code(
    protoc_path: str, proto3_path: str, include_path: str, cpp_out: str, python_out: str
) -> None:
    """Run ``protoc`` on a proto3 file, emitting C++ and Python code.

    Args:
        protoc_path: Path of the ``protoc`` executable to invoke.
        proto3_path: The proto3 file to compile.
        include_path: Directory passed to protoc via ``-I``.
        cpp_out: Directory passed via ``--cpp_out``.
        python_out: Directory passed via ``--python_out``.

    Raises:
        subprocess.CalledProcessError: If protoc exits with a non-zero status.
    """
    print(f"Generate pb3 code using {protoc_path}")
    command = [
        protoc_path,
        proto3_path,
        "-I",
        include_path,
        "--cpp_out",
        cpp_out,
        "--python_out",
        python_out,
    ]
    # Argument list (no shell) so the paths are passed through verbatim.
    subprocess.run(command, check=True)  # noqa: S603

117 

118 

def translate(source: str, proto: int, onnx_ml: bool, package_name: str) -> str:
    """Turn the ``.in.proto`` template text into a concrete proto definition.

    The text is split into lines, conditional ONNX-ML sections are resolved,
    the package name is filled in, and for ``proto == 3`` the result is
    converted to proto3 syntax.

    Args:
        source: Full text of the template file.
        proto: Target protobuf syntax version, either 2 or 3.
        onnx_ml: Whether to keep the ONNX-ML branches of conditional sections.
        package_name: Package name to substitute into the template.

    Returns:
        The translated lines joined with ``os.linesep``.

    Raises:
        AssertionError: If *proto* is neither 2 nor 3.
    """
    pipeline: Iterable[str] = process_package_name(
        process_ifs(source.splitlines(), onnx_ml=onnx_ml),
        package_name=package_name,
    )
    if proto != 3:  # noqa: PLR2004
        assert proto == 2  # noqa: PLR2004
    else:
        pipeline = convert_to_proto3(pipeline)
    return os.linesep.join(pipeline)

128 

129 

def qualify(f: str, pardir: str | None = None) -> str:
    """Join file name *f* onto a parent directory.

    Args:
        f: File name (or relative path) to qualify.
        pardir: Directory to prepend. When ``None``, the resolved directory
            containing this script is used.

    Returns:
        ``os.path.join`` of the chosen directory and *f*.
    """
    base = os.path.realpath(os.path.dirname(__file__)) if pardir is None else pardir
    return os.path.join(base, f)

134 

135 

def convert(
    stem: str,
    package_name: str,
    output: str,
    do_onnx_ml: bool = False,
    lite: bool = False,
    protoc_path: str = "",
) -> None:
    """Generate the .proto, .proto3 and Python shim files for one template.

    Reads ``<stem>.in.proto`` from the directory of this script and writes
    into *output*:

    * ``<proto_base>.proto`` and ``<proto_base>.proto3``, each starting with
      ``autogen_header`` and, when *lite* is set, ending with ``LITE_OPTION``;
    * when the package is renamed, a ``<stem>[-ml].pb.h`` header that simply
      includes ``<proto_base>.pb.h``;
    * ``<stem>_pb.py`` (dashes replaced by underscores), which star-imports
      the matching ``*_pb2`` module.

    Args:
        stem: Template name without the ``.in.proto`` suffix.
        package_name: Package name for the generated files; a value other than
            ``DEFAULT_PACKAGE_NAME`` adds ``_<package_name>`` to file names.
        output: Directory that receives the generated files.
        do_onnx_ml: Generate the ONNX-ML variant. For ``onnx-data`` stems the
            ``-ml`` file-name suffix is suppressed, but the flag is still
            passed to ``translate``.
        lite: Append the protobuf-lite option to the generated proto files.
        protoc_path: When non-empty, run this ``protoc`` on the proto3 file and
            then delete files matching ``<proto_base>.proto3.*``.
    """
    proto_in = qualify(f"{stem}.in.proto")
    need_rename = package_name != DEFAULT_PACKAGE_NAME
    # Having a separate variable for import_ml ensures that the import statements for the generated
    # proto files can be set separately from the ONNX_ML environment variable setting.
    import_ml = do_onnx_ml
    # We do not want to generate the onnx-data-ml.proto files for onnx-data.in.proto,
    # as there is no change between onnx-data.proto and the ML version.
    # The check is made on the stem, not on the full input path: the path also
    # contains the script's directory, and a directory name that happens to
    # include "onnx-data" would otherwise switch off ML naming for every stem.
    if "onnx-data" in stem:
        do_onnx_ml = False

    ml_suffix = "-ml" if do_onnx_ml else ""
    if need_rename:
        proto_base = f"{stem}_{package_name}{ml_suffix}"
    else:
        proto_base = f"{stem}{ml_suffix}"
    proto = qualify(f"{proto_base}.proto", pardir=output)
    proto3 = qualify(f"{proto_base}.proto3", pardir=output)

    print(f"Processing {proto_in}")
    with open(proto_in, encoding="utf-8") as fin:
        source = fin.read()

    # The proto2 and proto3 files differ only in the syntax version handed to
    # translate(); newline="" keeps the os.linesep endings it produces intact.
    for path, syntax in ((proto, 2), (proto3, 3)):
        print(f"Writing {path}")
        with open(path, "w", newline="", encoding="utf-8") as fout:
            fout.write(autogen_header)
            fout.write(
                translate(
                    source, proto=syntax, onnx_ml=import_ml, package_name=package_name
                )
            )
            if lite:
                fout.write(LITE_OPTION)

    if protoc_path:
        proto3_dir = os.path.dirname(proto3)
        base_dir = os.path.dirname(proto3_dir)
        gen_proto3_code(protoc_path, proto3, base_dir, base_dir, base_dir)
        # Remove whatever matches "<proto_base>.proto3.*" next to the proto3 file.
        for pb3_file in glob.glob(os.path.join(proto3_dir, f"{proto_base}.proto3.*")):
            print(f"Removing {pb3_file}")
            os.remove(pb3_file)

    if need_rename:
        # Forwarding header under the un-renamed name.
        proto_header = qualify(f"{stem}{ml_suffix}.pb.h", pardir=output)
        print(f"Writing {proto_header}")
        with open(proto_header, "w", newline="", encoding="utf-8") as fout:
            fout.write("#pragma once\n")
            fout.write(f'#include "{proto_base}.pb.h"\n')

    # Generate py mapping
    # "-" is invalid in python module name, replaces '-' with '_'
    pb_py = qualify(f"{stem.replace('-', '_')}_pb.py", pardir=output)
    # proto_base already carries the package and "-ml" parts, so a single
    # expression covers the renamed, ML and plain cases.
    pb2_py = qualify(f"{proto_base.replace('-', '_')}_pb2.py", pardir=output)
    pb2_module = os.path.splitext(os.path.basename(pb2_py))[0]

    print(f"generating {pb_py}")
    with open(pb_py, "w", encoding="utf-8") as f:
        f.write(
            dedent(
                f"""\
                # This file is generated by 'gen_proto.py'. DO NOT EDIT!


                from .{pb2_module} import * # noqa
                """
            )
        )

221 

222 

def main() -> None:
    """Parse the command line and convert every requested ``.in.proto`` stem.

    The output directory is created when it does not exist yet, then
    ``convert`` is called once per stem with the parsed options.
    """
    script_dir = os.path.realpath(os.path.dirname(__file__))

    cli = argparse.ArgumentParser(
        description="Generates .proto file variations from .in.proto"
    )
    cli.add_argument(
        "-p",
        "--package",
        default="onnx",
        help="package name in the generated proto files (default: %(default)s)",
    )
    cli.add_argument("-m", "--ml", action="store_true", help="ML mode")
    cli.add_argument(
        "-l",
        "--lite",
        action="store_true",
        help="generate lite proto to use with protobuf-lite",
    )
    cli.add_argument(
        "-o",
        "--output",
        default=script_dir,
        help="output directory (default: %(default)s)",
    )
    cli.add_argument(
        "--protoc_path", default="", help="path to protoc for proto3 file validation"
    )
    cli.add_argument(
        "stems",
        nargs="*",
        default=["onnx", "onnx-operators", "onnx-data"],
        help="list of .in.proto file stems (default: %(default)s)",
    )
    options = cli.parse_args()

    out_dir = options.output
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    for stem in options.stems:
        convert(
            stem,
            package_name=options.package,
            output=out_dir,
            do_onnx_ml=options.ml,
            lite=options.lite,
            protoc_path=options.protoc_path,
        )


if __name__ == "__main__":
    main()