1# Made in Japan
2
3from __future__ import annotations
4
5import io
6import logging
7from enum import Enum
8from textwrap import dedent, indent
9from typing import TYPE_CHECKING
10
11from dissect.cstruct.bitbuffer import BitBuffer
12from dissect.cstruct.types import (
13 Array,
14 BaseType,
15 Char,
16 CharArray,
17 Flag,
18 Int,
19 Packed,
20 Pointer,
21 Structure,
22 Union,
23 Void,
24 VoidArray,
25 Wchar,
26 WcharArray,
27)
28from dissect.cstruct.types.base import BaseArray
29from dissect.cstruct.types.enum import EnumMetaType
30from dissect.cstruct.types.packed import _struct
31
32if TYPE_CHECKING:
33 from collections.abc import Iterator
34 from types import MethodType
35
36 from dissect.cstruct.cstruct import cstruct
37 from dissect.cstruct.types.structure import Field
38
# Field types the compiler can generate read code for. A field of any other type makes
# ``_ReadSourceGenerator._generate_fields`` raise TypeError, which ``Compiler.compile`` catches,
# leaving the structure uncompiled. Types created by ``EnumMetaType`` are checked via their
# underlying ``.type`` before this test is applied.
# NOTE(review): ``Enum`` here is the standard library ``enum.Enum`` (see the imports), not a
# cstruct type.
SUPPORTED_TYPES = (
    Array,
    Char,
    CharArray,
    Enum,
    Flag,
    Int,
    Packed,
    Pointer,
    Structure,
    Void,
    Wchar,
    WcharArray,
    VoidArray,
)

log = logging.getLogger(__name__)

# Keep a handle on the builtin ``compile``: this module defines its own ``compile`` function
# below, which shadows the builtin for the rest of the module.
python_compile = compile
58
59
def compile(structure: type[Structure]) -> type[Structure]:
    """Compile ``structure`` using a :class:`Compiler` bound to ``structure.cs``.

    Returns:
        The structure class that was passed in (possibly with a compiled ``_read``).
    """
    compiler = Compiler(structure.cs)
    return compiler.compile(structure)
62
63
class Compiler:
    """Replace the ``_read`` classmethod of cstruct structures with generated code."""

    def __init__(self, cs: cstruct):
        # cstruct instance handed to the source generator
        self.cs = cs

    def compile(self, structure: type[Structure]) -> type[Structure]:
        """Install a compiled ``_read`` on ``structure`` and set ``__compiled__ = True``.

        Unions are returned untouched. Compilation is best-effort: any exception raised while
        generating or installing the reader is logged at debug level, and the structure is
        returned as-is.

        Returns:
            The same ``structure`` class that was passed in.
        """
        if not issubclass(structure, Union):
            try:
                structure._read = self.compile_read(
                    structure.__fields__,
                    structure.__name__,
                    structure.__align__,
                )
                structure.__compiled__ = True
            except Exception as error:
                # Silently ignore, we didn't compile unfortunately
                log.debug("Failed to compile %s", structure, exc_info=error)

        return structure

    def compile_read(self, fields: list[Field], name: str | None = None, align: bool = False) -> MethodType:
        """Build a ``_read`` classmethod for ``fields``.

        ``name`` is only used to label the generated code object. ``align`` enables the
        alignment seeks in the generated code.
        """
        generator = _ReadSourceGenerator(self.cs, fields, name, align)
        return generator.generate()
83
84
class _ReadSourceGenerator:
    """Generate and compile a specialised ``_read(cls, stream, context=None)`` for a field list.

    The generated function reads the fields from ``stream`` in declaration order. It collects
    field values by name in a dict ``r``, and the consumed byte size of every dynamically sized
    sub-structure/array in a dict ``s``. It then instantiates ``cls`` with ``r`` and attaches
    ``s`` as ``__dynamic_sizes__``.
    """

    def __init__(self, cs: cstruct, fields: list[Field], name: str | None = None, align: bool = False):
        self.cs = cs
        self.fields = fields
        self.name = name
        self.align = align

        # Generated identifiers (``_0``, ``_1``, ...) mapped to the field whose type they refer to;
        # ``generate`` injects them as globals when executing the generated source.
        self.field_map: dict[str, Field] = {}
        self._token_id = 0

    def _map_field(self, field: Field) -> str:
        """Register ``field`` and return a fresh identifier bound to its type in the generated code."""
        token = f"_{self._token_id}"
        self.field_map[token] = field
        self._token_id += 1
        return token

    def generate(self) -> MethodType:
        """Generate, compile and execute the ``_read`` source; return it wrapped in ``classmethod``.

        The generated source text is attached to the function as ``__source__`` for debugging.
        """
        source = self.generate_source()
        symbols = {token: field.type for token, field in self.field_map.items()}

        code = python_compile(source, f"<compiled {self.name or 'anonymous'}._read>", "exec")
        namespace: dict = {}
        exec(code, {"BitBuffer": BitBuffer, "_struct": _struct, **symbols}, namespace)
        func = namespace["_read"]
        func.__source__ = source

        return classmethod(func)

    def generate_source(self) -> str:
        """Return the complete Python source of the generated ``_read`` function."""
        # r: field values by name, s: sizes of dynamic fields, o: stream offset of the structure start
        preamble = ["r = {}", "s = {}", "o = stream.tell()"]
        if any(field.bits for field in self.fields):
            preamble.append("bit_reader = BitBuffer(stream, cls.cs.endian)")

        read_code = "\n".join(self._generate_fields())

        outro = """
        obj = type.__call__(cls, **r)
        obj.__dynamic_sizes__ = s

        return obj
        """

        body = "\n".join(preamble) + "\n" + read_code + dedent(outro)
        return f"def _read(cls, stream, context=None):\n{indent(body, '    ')}"

    def _generate_fields(self) -> Iterator[str]:
        """Yield source snippets that read every field, in declaration order.

        Consecutive fixed-size "simple" fields are batched and read with a single ``stream.read``
        (see ``_generate_packed``). Sub-structures, nested/dynamic arrays and bit fields are
        emitted individually, each flushing the pending batch first.

        Raises:
            TypeError: If a field type is not in ``SUPPORTED_TYPES``, or a bit field has no fixed size.
        """
        current_offset = 0
        current_block: list[Field] = []
        prev_was_bits = False
        prev_bits_type = None
        bits_remaining = 0
        bits_rollover = False

        def flush() -> Iterator[str]:
            # Emit the pending batch of simple fields (if any) and start a new, empty batch
            if current_block:
                if self.align and current_block[0].offset is None:
                    yield f"stream.seek(-stream.tell() & ({current_block[0].alignment} - 1), {io.SEEK_CUR})"

                yield from self._generate_packed(current_block)
                current_block[:] = []

        def align_to_field(field: Field) -> Iterator[str]:
            nonlocal current_offset

            if field.offset is not None and field.offset != current_offset:
                # If a field has a set offset and it's not the same as the current tracked offset, seek to it
                yield f"stream.seek(o + {field.offset})"
                current_offset = field.offset

            if self.align and field.offset is None:
                yield f"stream.seek(-stream.tell() & ({field.alignment} - 1), {io.SEEK_CUR})"

        for field in self.fields:
            field_type = field.type

            # Enum-like types are checked and sized through their underlying type
            if isinstance(field_type, EnumMetaType):
                field_type = field_type.type

            if not issubclass(field_type, SUPPORTED_TYPES):
                raise TypeError(f"Unsupported type for compiler: {field_type}")

            if prev_was_bits and not field.bits:
                yield "bit_reader.reset()"
                prev_was_bits = False
                bits_remaining = 0

            try:
                size = len(field_type)
                is_dynamic = False
            except TypeError:
                size = None
                is_dynamic = True

            # Sub structure
            if issubclass(field_type, Structure):
                yield from flush()
                yield from align_to_field(field)
                yield from self._generate_structure(field)

            # Array of structures and multi-dimensional arrays
            elif issubclass(field_type, (Array, CharArray, WcharArray)) and (
                issubclass(field_type.type, Structure) or issubclass(field_type.type, BaseArray) or is_dynamic
            ):
                yield from flush()
                yield from align_to_field(field)
                yield from self._generate_array(field)

            # Bit fields
            elif field.bits:
                if size is None:
                    raise TypeError(f"Unsupported type for bit field: {field_type}")

                if not prev_was_bits:
                    prev_bits_type = field_type
                    prev_was_bits = True

                # NOTE(review): bits_remaining is never decremented and prev_bits_type is only set when
                # a run of bit fields starts, so this rollover detection looks incomplete. It only
                # affects the offset bookkeeping below (and therefore which seeks are emitted). Left
                # unchanged pending confirmation against the Structure offset calculation.
                if bits_remaining == 0 or prev_bits_type != field_type:
                    bits_remaining = (size * 8) - field.bits
                    bits_rollover = True

                yield from flush()
                yield from align_to_field(field)
                yield from self._generate_bits(field)

            # Everything else - basic and composite types (and arrays of them)
            else:
                current_block.append(field)

            # Track the expected offset; bit fields only advance it when bits_rollover was set above
            if current_offset is not None and size is not None and (not field.bits or bits_rollover):
                current_offset += size
                bits_rollover = False

        yield from flush()

        if self.align:
            # Pad the stream to the alignment of the structure itself
            yield f"stream.seek(-stream.tell() & (cls.alignment - 1), {io.SEEK_CUR})"

    def _generate_nested_read(self, field: Field) -> Iterator[str]:
        """Yield source that delegates to the field type's own ``_read``.

        For dynamic types, the number of bytes consumed is also recorded in ``s``.
        """
        template = f"""
        {"_s = stream.tell()" if field.type.dynamic else ""}
        r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
        {f's["{field._name}"] = stream.tell() - _s' if field.type.dynamic else ""}
        """

        yield dedent(template)

    def _generate_structure(self, field: Field) -> Iterator[str]:
        """Yield the source that reads a nested structure field."""
        yield from self._generate_nested_read(field)

    def _generate_array(self, field: Field) -> Iterator[str]:
        """Yield the source that reads an array of structures, a nested array or a dynamic array."""
        yield from self._generate_nested_read(field)

    def _generate_bits(self, field: Field) -> Iterator[str]:
        """Yield the source that reads a single bit field through ``bit_reader``."""
        field_type = field.type
        read_type = "_t"
        if issubclass(field_type, (Enum, Flag)):
            # Read the bits as the underlying type, then construct the enum/flag type from the value
            read_type += ".type"
            field_type = field_type.type

        if issubclass(field_type, Char):
            # Char bit fields are read and stored as uint8
            lookup = "cls.cs.uint8"
        else:
            lookup = self._map_field(field)

        template = f"""
        _t = {lookup}
        r["{field._name}"] = type.__call__(_t, bit_reader.read({read_type}, {field.bits}))
        """

        yield dedent(template)

    def _generate_packed(self, fields: list[Field]) -> Iterator[str]:
        """Yield the source that reads a batch of consecutive fixed-size fields with one ``stream.read``.

        Byte-based types (Char/Wchar/Int, and arrays of them) and padding are sliced directly
        from ``buf``. Packed types are unpacked with a single ``struct`` call into the tuple
        ``data`` and indexed from there. A short read raises ``EOFError``.
        """
        info = list(_generate_struct_info(self.cs, fields, self.align))
        reads = []

        size = 0  # byte offset into ``buf``
        slice_index = 0  # index into the unpacked ``data`` tuple
        for field, count, _ in info:
            if field is None:
                # Padding
                size += count
                continue

            field_type = field.type
            read_type = _get_read_type(self.cs, field_type)

            if issubclass(field_type, (Array, CharArray, WcharArray)):
                count = field_type.num_entries
                read_type = _get_read_type(self.cs, field_type.type)

                if issubclass(read_type, (Char, Wchar, Int)):
                    count *= read_type.size
                    getter = f"buf[{size}:{size + count}]"
                else:
                    getter = f"data[{slice_index}:{slice_index + count}]"
                    slice_index += count
            elif issubclass(read_type, (Char, Wchar, Int)):
                getter = f"buf[{size}:{size + read_type.size}]"
            else:
                getter = f"data[{slice_index}]"
                slice_index += 1

            if issubclass(read_type, (Wchar, Int)):
                # Types that parse bytes further down to their own type
                parser_template = "{type}({getter})"
            else:
                # All other types can be simply initialized
                parser_template = "type.__call__({type}, {getter})"

            # Create the final reading code
            if issubclass(field_type, Array):
                reads.append(f"_t = {self._map_field(field)}")
                reads.append("_et = _t.type")

                if issubclass(field_type.type, Int):
                    reads.append(f"_b = {getter}")
                    item_parser = parser_template.format(type="_et", getter=f"_b[i:i + {field_type.type.size}]")
                    list_comp = f"[{item_parser} for i in range(0, {count}, {field_type.type.size})]"
                elif issubclass(field_type.type, Pointer):
                    item_parser = "_et.__new__(_et, e, stream, r)"
                    list_comp = f"[{item_parser} for e in {getter}]"
                else:
                    item_parser = parser_template.format(type="_et", getter="e")
                    list_comp = f"[{item_parser} for e in {getter}]"

                parser = f"type.__call__(_t, {list_comp})"
            elif issubclass(field_type, CharArray):
                parser = f"type.__call__({self._map_field(field)}, {getter})"
            elif issubclass(field_type, Pointer):
                reads.append(f"_pt = {self._map_field(field)}")
                parser = f"_pt.__new__(_pt, {getter}, stream, r)"
            else:
                parser = parser_template.format(type=self._map_field(field), getter=getter)

            reads.append(f'r["{field._name}"] = {parser}')
            reads.append("")  # Generates a newline in the resulting code

            size += field_type.size

        # ``data`` is needed as soon as any entry has a real (non-padding) pack character. Decide
        # from ``info`` rather than the optimized format string: a format such as "Bx" (packed
        # field followed by bytes) must still be unpacked, or ``data[...]`` would be undefined.
        if any(char != "x" for _, _, char in info):
            unpack = f'data = _struct(cls.cs.endian, "{_optimize_struct_fmt(info)}").unpack(buf)\n'
        else:
            unpack = ""

        template = f"""
        buf = stream.read({size})
        if len(buf) != {size}: raise EOFError()
        {unpack}
        """

        yield dedent(template) + "\n".join(reads)
345
346
def _generate_struct_info(cs: cstruct, fields: list[Field], align: bool = False) -> Iterator[tuple[Field, int, str]]:
    """Describe how a batch of fixed-size fields maps onto a single ``struct`` format.

    Yields ``(field, count, char)`` tuples in read order:

    - ``field`` is ``None`` for padding entries, which always use char ``"x"``.
    - Packed fields (and arrays of them) yield their pack character and element count.
    - Char/Wchar/Int fields (and arrays of them) yield ``"x"`` with their total byte size.
      They are sliced from the raw buffer instead of being unpacked.
    - Void fields yield nothing.
    """
    if not fields:
        return

    # Offset tracked from the first field's declared offset (None when that offset is unknown)
    tracked_offset = fields[0].offset
    # Offset relative to the start of this batch, used for alignment when offsets are unknown
    relative_offset = 0

    for field in fields:
        if field.offset is not None:
            gap = field.offset - tracked_offset
            if gap > 0:
                # We moved -- probably due to alignment
                yield None, gap, "x"
                tracked_offset += gap

        if align and field.offset is None:
            pad = -relative_offset & (field.alignment - 1)
            if pad > 0:
                # Assume we started at a correctly aligned boundary
                yield None, pad, "x"
                relative_offset += pad

        read_type = _get_read_type(cs, field.type)
        if issubclass(read_type, (Void, VoidArray)):
            continue

        # Arrays contribute their element type, repeated num_entries times
        count = 1
        if issubclass(read_type, (Array, CharArray, WcharArray)):
            count = read_type.num_entries
            read_type = _get_read_type(cs, read_type.type)

        if issubclass(read_type, Packed):
            yield field, count, read_type.packchar
        elif issubclass(read_type, (Char, Wchar, Int)):
            yield field, count * read_type.size, "x"

        consumed = count * read_type.size
        relative_offset += consumed
        if tracked_offset is not None:
            tracked_offset += consumed
389
390
def _optimize_struct_fmt(info: Iterator[tuple[Field, int, str]]) -> str:
    """Collapse ``(field, count, char)`` entries into a compact ``struct`` format string.

    Consecutive entries with the same character are merged and their counts summed. Runs whose
    total count is zero are dropped, and a count of 1 is written without a number (``"I"``
    rather than ``"1I"``).
    """
    runs: list[list] = []  # [total_count, char] per run of identical characters

    for _, count, char in info:
        if runs and runs[-1][1] == char:
            runs[-1][0] += count
        else:
            runs.append([count, char])

    return "".join(f"{total if total > 1 else ''}{char}" for total, char in runs if total)
415
416
def _get_read_type(cs: cstruct, type_: type[BaseType]) -> type[BaseType]:
    """Return the type that is physically read from the stream for ``type_``.

    Enum/Flag types map to their underlying ``.type`` and pointer types map to ``cs.pointer``.
    The result is passed through ``cs.resolve``.
    """
    base = type_.type if issubclass(type_, (Enum, Flag)) else type_
    return cs.resolve(cs.pointer if issubclass(base, Pointer) else base)