Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dissect/cstruct/compiler.py: 84%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

251 statements  

1# Made in Japan 

2 

3from __future__ import annotations 

4 

5import io 

6import logging 

7from enum import Enum 

8from textwrap import dedent, indent 

9from typing import TYPE_CHECKING 

10 

11from dissect.cstruct.bitbuffer import BitBuffer 

12from dissect.cstruct.types import ( 

13 Array, 

14 BaseType, 

15 Char, 

16 CharArray, 

17 Flag, 

18 Int, 

19 Packed, 

20 Pointer, 

21 Structure, 

22 Union, 

23 Void, 

24 VoidArray, 

25 Wchar, 

26 WcharArray, 

27) 

28from dissect.cstruct.types.base import BaseArray 

29from dissect.cstruct.types.enum import EnumMetaType 

30from dissect.cstruct.types.packed import _struct 

31 

32if TYPE_CHECKING: 

33 from collections.abc import Iterator 

34 from types import MethodType 

35 

36 from dissect.cstruct.cstruct import cstruct 

37 from dissect.cstruct.types.structure import Field 

38 

# Types the compiler knows how to generate read code for; compile() falls back
# to the interpreted reader for anything else (see Compiler.compile).
SUPPORTED_TYPES = (
    Array,
    Char,
    CharArray,
    Enum,
    Flag,
    Int,
    Packed,
    Pointer,
    Structure,
    Void,
    Wchar,
    WcharArray,
    VoidArray,
)

log = logging.getLogger(__name__)

# Keep a reference to the builtin compile(), since the module-level compile()
# function below shadows it.
python_compile = compile

58 

59 

def compile(structure: type[Structure]) -> type[Structure]:
    """Compile the ``_read`` method of a structure class and return the class.

    Convenience wrapper around :class:`Compiler`. Note that this shadows the
    builtin ``compile`` within this module (the builtin is kept available as
    ``python_compile``).
    """
    compiler = Compiler(structure.cs)
    return compiler.compile(structure)

62 

63 

class Compiler:
    """Compiles a faster ``_read`` classmethod onto structure classes.

    Compilation is best-effort: failures are logged at debug level and the
    structure keeps its interpreted reader.
    """

    def __init__(self, cs: cstruct):
        self.cs = cs

    def compile(self, structure: type[Structure]) -> type[Structure]:
        """Attach a compiled ``_read`` to ``structure`` if possible.

        Unions are never compiled; any compilation error is swallowed so the
        caller always gets a usable (possibly uncompiled) class back.
        """
        if not issubclass(structure, Union):
            try:
                structure._read = self.compile_read(structure.__fields__, structure.__name__, structure.__align__)
                structure.__compiled__ = True
            except Exception as exc:
                # Silently ignore, we didn't compile unfortunately
                log.debug("Failed to compile %s", structure, exc_info=exc)

        return structure

    def compile_read(self, fields: list[Field], name: str | None = None, align: bool = False) -> MethodType:
        """Generate and compile a ``_read`` implementation for ``fields``."""
        generator = _ReadSourceGenerator(self.cs, fields, name, align)
        return generator.generate()

83 

84 

class _ReadSourceGenerator:
    """Generates the Python source of a compiled ``_read`` method.

    Walks the structure fields once, batching consecutive plainly-packable
    fields into a single ``struct`` unpack, and emitting dedicated read code
    for sub-structures, complex arrays and bit fields.
    """

    def __init__(self, cs: cstruct, fields: list[Field], name: str | None = None, align: bool = False):
        self.cs = cs
        self.fields = fields
        self.name = name
        self.align = align

        # Maps generated token names ("_0", "_1", ...) to their field, so the
        # field types can be injected into the exec() globals in generate()
        self.field_map: dict[str, Field] = {}
        self._token_id = 0

    def _map_field(self, field: Field) -> str:
        """Register ``field`` under a fresh token name and return the token."""
        token = f"_{self._token_id}"
        self.field_map[token] = field
        self._token_id += 1
        return token

    def generate(self) -> MethodType:
        """Compile the generated source and return it wrapped as a classmethod."""
        source = self.generate_source()
        symbols = {token: field.type for token, field in self.field_map.items()}

        code = python_compile(source, f"<compiled {self.name or 'anonymous'}._read>", "exec")
        exec(code, {"BitBuffer": BitBuffer, "_struct": _struct, **symbols}, d := {})
        # The exec locals contain exactly one entry: the _read function
        obj = d.popitem()[1]
        # Keep the generated source around for debugging/introspection
        obj.__source__ = source

        return classmethod(obj)

    def generate_source(self) -> str:
        """Render the full source text of the ``_read`` function.

        ``r`` collects field values, ``s`` collects per-field sizes, ``o`` is
        the stream offset the structure read started at.
        """
        preamble = """
        r = {}
        s = {}
        o = stream.tell()
        """

        if any(field.bits for field in self.fields):
            # NOTE: the literal above ends with a newline plus the same
            # indentation as its content lines, so this appended line shares
            # the common prefix and dedent() below strips all lines uniformly
            preamble += "bit_reader = BitBuffer(stream, cls.cs.endian)\n"

        read_code = "\n".join(self._generate_fields())

        outro = """
        obj = type.__call__(cls, **r)
        obj._sizes = s
        obj._values = r

        return obj
        """

        code = indent(dedent(preamble).lstrip() + read_code + dedent(outro), "    ")

        return f"def _read(cls, stream, context=None):\n{code}"

    def _generate_fields(self) -> Iterator[str]:
        """Yield source lines for reading every field.

        Consecutive simple fields accumulate in ``current_block`` and are
        flushed as one packed read; structures, complex arrays and bit fields
        flush the pending block and emit their own read code.
        """
        current_offset = 0
        current_block: list[Field] = []
        prev_was_bits = False
        prev_bits_type = None
        bits_remaining = 0
        bits_rollover = False

        def flush() -> Iterator[str]:
            # Emit the pending packed block (if any) and reset it
            if current_block:
                if self.align and current_block[0].offset is None:
                    yield f"stream.seek(-stream.tell() & ({current_block[0].alignment} - 1), {io.SEEK_CUR})"

                yield from self._generate_packed(current_block)
                current_block[:] = []

        def align_to_field(field: Field) -> Iterator[str]:
            nonlocal current_offset

            if field.offset is not None and field.offset != current_offset:
                # If a field has a set offset and it's not the same as the current tracked offset, seek to it
                yield f"stream.seek(o + {field.offset})"
                current_offset = field.offset

            if self.align and field.offset is None:
                yield f"stream.seek(-stream.tell() & ({field.alignment} - 1), {io.SEEK_CUR})"

        for field in self.fields:
            field_type = field.type

            # Enums compile as their underlying integer type
            if isinstance(field_type, EnumMetaType):
                field_type = field_type.type

            if not issubclass(field_type, SUPPORTED_TYPES):
                raise TypeError(f"Unsupported type for compiler: {field_type}")

            # Leaving a run of bit fields: discard any partially consumed unit
            if prev_was_bits and not field.bits:
                yield "bit_reader.reset()"
                prev_was_bits = False
                bits_remaining = 0

            try:
                size = len(field_type)
                is_dynamic = False
            except TypeError:
                # Dynamically sized type (len() is not defined for it)
                size = None
                is_dynamic = True

            # Sub structure
            if issubclass(field_type, Structure):
                yield from flush()
                yield from align_to_field(field)
                yield from self._generate_structure(field)

            # Array of structures and multi-dimensional arrays
            elif issubclass(field_type, (Array, CharArray, WcharArray)) and (
                issubclass(field_type.type, Structure) or issubclass(field_type.type, BaseArray) or is_dynamic
            ):
                yield from flush()
                yield from align_to_field(field)
                yield from self._generate_array(field)

            # Bit fields
            elif field.bits:
                if size is None:
                    raise TypeError(f"Unsupported type for bit field: {field_type}")

                if not prev_was_bits:
                    prev_bits_type = field_type
                    prev_was_bits = True

                # Starting a fresh storage unit (either the previous one is
                # exhausted or the underlying type changed)
                if bits_remaining == 0 or prev_bits_type != field_type:
                    bits_remaining = (size * 8) - field.bits
                    bits_rollover = True

                yield from flush()
                yield from align_to_field(field)
                yield from self._generate_bits(field)

            # Everything else - basic and composite types (and arrays of them)
            else:
                current_block.append(field)

            # Track the running offset; bit fields only advance it when a new
            # storage unit was started this iteration
            if current_offset is not None and size is not None and (not field.bits or bits_rollover):
                current_offset += size
                bits_rollover = False

        yield from flush()

        if self.align:
            # Trailing padding up to the structure's own alignment
            yield f"stream.seek(-stream.tell() & (cls.alignment - 1), {io.SEEK_CUR})"

    def _generate_structure(self, field: Field) -> Iterator[str]:
        """Yield code that delegates to the sub-structure's own ``_read``."""
        template = f"""
        _s = stream.tell()
        r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
        s["{field._name}"] = stream.tell() - _s
        """

        yield dedent(template)

    def _generate_array(self, field: Field) -> Iterator[str]:
        """Yield code that delegates to the array type's own ``_read``."""
        template = f"""
        _s = stream.tell()
        r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
        s["{field._name}"] = stream.tell() - _s
        """

        yield dedent(template)

    def _generate_bits(self, field: Field) -> Iterator[str]:
        """Yield code that reads ``field.bits`` bits through the BitBuffer."""
        lookup = self._map_field(field)
        read_type = "_t"
        field_type = field.type
        if issubclass(field_type, (Enum, Flag)):
            # Read the backing type, then construct the enum from it
            read_type += ".type"
            field_type = field_type.type

        if issubclass(field_type, Char):
            # Char bit fields are read as uint8
            field_type = field_type.cs.uint8
            lookup = "cls.cs.uint8"

        template = f"""
        _t = {lookup}
        r["{field._name}"] = type.__call__(_t, bit_reader.read({read_type}, {field.bits}))
        """

        yield dedent(template)

    def _generate_packed(self, fields: list[Field]) -> Iterator[str]:
        """Yield code that reads a run of simple fields in one buffered read.

        A single ``stream.read`` fills ``buf``; byte-oriented types are sliced
        straight out of ``buf`` while packable types come from a single
        ``struct`` unpack into ``data``.
        """
        info = list(_generate_struct_info(self.cs, fields, self.align))
        reads = []

        size = 0  # running byte offset into buf
        slice_index = 0  # running index into the unpacked data tuple
        for field, count, _ in info:
            if field is None:
                # Padding
                size += count
                continue

            field_type = field.type
            read_type = _get_read_type(self.cs, field_type)

            if issubclass(field_type, (Array, CharArray, WcharArray)):
                count = field_type.num_entries
                read_type = _get_read_type(self.cs, field_type.type)

                if issubclass(read_type, (Char, Wchar, Int)):
                    # Byte-oriented element type: slice raw bytes out of buf
                    count *= read_type.size
                    getter = f"buf[{size}:{size + count}]"
                else:
                    # Packed element type: take a slice of the unpacked tuple
                    getter = f"data[{slice_index}:{slice_index + count}]"
                    slice_index += count
            elif issubclass(read_type, (Char, Wchar, Int)):
                getter = f"buf[{size}:{size + read_type.size}]"
            else:
                getter = f"data[{slice_index}]"
                slice_index += 1

            if issubclass(read_type, (Wchar, Int)):
                # Types that parse bytes further down to their own type
                parser_template = "{type}({getter})"
            else:
                # All other types can be simply initialized
                parser_template = "type.__call__({type}, {getter})"

            # Create the final reading code
            if issubclass(field_type, Array):
                reads.append(f"_t = {self._map_field(field)}")
                reads.append("_et = _t.type")

                if issubclass(field_type.type, Int):
                    # Split the raw slice into per-element chunks
                    reads.append(f"_b = {getter}")
                    item_parser = parser_template.format(type="_et", getter=f"_b[i:i + {field_type.type.size}]")
                    list_comp = f"[{item_parser} for i in range(0, {count}, {field_type.type.size})]"
                elif issubclass(field_type.type, Pointer):
                    item_parser = "_et.__new__(_et, e, stream, r)"
                    list_comp = f"[{item_parser} for e in {getter}]"
                else:
                    item_parser = parser_template.format(type="_et", getter="e")
                    list_comp = f"[{item_parser} for e in {getter}]"

                parser = f"type.__call__(_t, {list_comp})"
            elif issubclass(field_type, CharArray):
                parser = f"type.__call__({self._map_field(field)}, {getter})"
            elif issubclass(field_type, Pointer):
                reads.append(f"_pt = {self._map_field(field)}")
                parser = f"_pt.__new__(_pt, {getter}, stream, r)"
            else:
                parser = parser_template.format(type=self._map_field(field), getter=getter)

            reads.append(f'r["{field._name}"] = {parser}')
            reads.append(f's["{field._name}"] = {field_type.size}')
            reads.append("")  # Generates a newline in the resulting code

            size += field_type.size

        fmt = _optimize_struct_fmt(info)
        if fmt == "x" or (len(fmt) == 2 and fmt[1] == "x"):
            # Everything is padding/raw bytes; no struct unpack needed
            unpack = ""
        else:
            unpack = f'data = _struct(cls.cs.endian, "{fmt}").unpack(buf)\n'

        template = f"""
        buf = stream.read({size})
        if len(buf) != {size}: raise EOFError()
        {unpack}
        """

        yield dedent(template) + "\n".join(reads)

347 

348 

def _generate_struct_info(cs: cstruct, fields: list[Field], align: bool = False) -> Iterator[tuple[Field, int, str]]:
    """Yield ``(field, count, packchar)`` tuples describing a packed block.

    Padding gaps (from explicit offsets or alignment) are yielded as
    ``(None, gap_size, "x")``. Byte-oriented types also use ``"x"`` since they
    are sliced directly from the buffer rather than unpacked.
    """
    if not fields:
        return

    # current_offset tracks explicit field offsets; imaginary_offset tracks
    # the relative position used for alignment padding
    current_offset = fields[0].offset
    imaginary_offset = 0
    for field in fields:
        # We moved -- probably due to alignment
        if field.offset is not None and (drift := field.offset - current_offset) > 0:
            yield None, drift, "x"
            current_offset += drift

        if align and field.offset is None and (drift := -imaginary_offset & (field.alignment - 1)) > 0:
            # Assume we started at a correctly aligned boundary
            yield None, drift, "x"
            imaginary_offset += drift

        count = 1
        read_type = _get_read_type(cs, field.type)

        # Drop voids
        if issubclass(read_type, (Void, VoidArray)):
            continue

        # Array of more complex types are handled elsewhere
        if issubclass(read_type, (Array, CharArray, WcharArray)):
            count = read_type.num_entries
            read_type = _get_read_type(cs, read_type.type)

        # Take the pack char for Packed
        if issubclass(read_type, Packed):
            yield field, count, read_type.packchar

        # Other types are byte based
        # We don't actually unpack anything here but slice directly out of the buffer
        elif issubclass(read_type, (Char, Wchar, Int)):
            yield field, count * read_type.size, "x"

        size = count * read_type.size
        imaginary_offset += size
        if current_offset is not None:
            current_offset += size

391 

392 

393def _optimize_struct_fmt(info: Iterator[tuple[Field, int, str]]) -> str: 

394 chars = [] 

395 

396 current_count = 0 

397 current_char = None 

398 

399 for _, count, char in info: 

400 if current_char is None: 

401 current_count = count 

402 current_char = char 

403 continue 

404 

405 if char != current_char: 

406 if current_count: 

407 chars.append((current_count, current_char)) 

408 current_count = count 

409 current_char = char 

410 else: 

411 current_count += count 

412 

413 if current_char is not None and current_count: 

414 chars.append((current_count, current_char)) 

415 

416 return "".join(f"{count if count > 1 else ''}{char}" for count, char in chars) 

417 

418 

def _get_read_type(cs: cstruct, type_: type[BaseType]) -> type[BaseType]:
    """Resolve the type that is actually read from the stream for ``type_``.

    Enums and flags read as their backing type; pointers read as the
    configured pointer type of the cstruct instance.
    """
    base = type_.type if issubclass(type_, (Enum, Flag)) else type_
    target = cs.pointer if issubclass(base, Pointer) else base
    return cs.resolve(target)