1# Made in Japan
2
3from __future__ import annotations
4
5import io
6import logging
7from enum import Enum
8from textwrap import dedent, indent
9from typing import TYPE_CHECKING
10
11from dissect.cstruct.bitbuffer import BitBuffer
12from dissect.cstruct.types import (
13 Array,
14 BaseType,
15 Char,
16 CharArray,
17 Flag,
18 Int,
19 Packed,
20 Pointer,
21 Structure,
22 Union,
23 Void,
24 VoidArray,
25 Wchar,
26 WcharArray,
27)
28from dissect.cstruct.types.base import BaseArray
29from dissect.cstruct.types.enum import EnumMetaType
30from dissect.cstruct.types.packed import _struct
31
32if TYPE_CHECKING:
33 from collections.abc import Iterator
34 from types import MethodType
35
36 from dissect.cstruct.cstruct import cstruct
37 from dissect.cstruct.types.structure import Field
38
# Field types the compiler can generate read code for. _ReadSourceGenerator raises a TypeError for
# any other field type, which makes Compiler.compile leave that structure uncompiled.
SUPPORTED_TYPES = (
    Array,
    Char,
    CharArray,
    Enum,
    Flag,
    Int,
    Packed,
    Pointer,
    Structure,
    Void,
    Wchar,
    WcharArray,
    VoidArray,
)

log = logging.getLogger(__name__)

# Keep a handle on the builtin compile(), because this module defines its own compile() below
python_compile = compile
58
59
def compile(structure: type[Structure]) -> type[Structure]:
    """Compile ``structure`` with a :class:`Compiler` bound to the structure's own cstruct instance.

    Returns the same structure class, compiled when possible.
    """
    compiler = Compiler(structure.cs)
    return compiler.compile(structure)
62
63
class Compiler:
    """Replace the ``_read`` method of structures with generated, specialised code."""

    def __init__(self, cs: cstruct):
        # The cstruct instance passed on to the source generator
        self.cs = cs

    def compile(self, structure: type[Structure]) -> type[Structure]:
        """Try to compile ``structure`` in place and return it.

        Unions are returned untouched. When code generation raises, the error is logged at debug
        level and the structure is returned without a compiled reader.
        """
        if not issubclass(structure, Union):
            try:
                compiled_read = self.compile_read(structure.__fields__, structure.__name__, structure.__align__)
                structure._read = compiled_read
                structure.__compiled__ = True
            except Exception as exc:
                # Best effort: structures that can't be compiled simply stay uncompiled
                log.debug("Failed to compile %s", structure, exc_info=exc)

        return structure

    def compile_read(self, fields: list[Field], name: str | None = None, align: bool = False) -> MethodType:
        """Generate a ``_read`` classmethod that reads ``fields`` from a stream."""
        generator = _ReadSourceGenerator(self.cs, fields, name, align)
        return generator.generate()
83
84
85class _ReadSourceGenerator:
86 def __init__(self, cs: cstruct, fields: list[Field], name: str | None = None, align: bool = False):
87 self.cs = cs
88 self.fields = fields
89 self.name = name
90 self.align = align
91
92 self.field_map: dict[str, Field] = {}
93 self._token_id = 0
94
95 def _map_field(self, field: Field) -> str:
96 token = f"_{self._token_id}"
97 self.field_map[token] = field
98 self._token_id += 1
99 return token
100
101 def generate(self) -> MethodType:
102 source = self.generate_source()
103 symbols = {token: field.type for token, field in self.field_map.items()}
104
105 code = python_compile(source, f"<compiled {self.name or 'anonymous'}._read>", "exec")
106 exec(code, {"BitBuffer": BitBuffer, "_struct": _struct, **symbols}, d := {})
107 obj = d.popitem()[1]
108 obj.__source__ = source
109
110 return classmethod(obj)
111
112 def generate_source(self) -> str:
113 preamble = """
114 r = {}
115 s = {}
116 o = stream.tell()
117 """
118
119 if any(field.bits for field in self.fields):
120 preamble += "bit_reader = BitBuffer(stream, cls.cs.endian)\n"
121
122 read_code = "\n".join(self._generate_fields())
123
124 outro = """
125 obj = type.__call__(cls, **r)
126 obj._sizes = s
127 obj._values = r
128
129 return obj
130 """
131
132 code = indent(dedent(preamble).lstrip() + read_code + dedent(outro), " ")
133
134 return f"def _read(cls, stream, context=None):\n{code}"
135
136 def _generate_fields(self) -> Iterator[str]:
137 current_offset = 0
138 current_block: list[Field] = []
139 prev_was_bits = False
140 prev_bits_type = None
141 bits_remaining = 0
142 bits_rollover = False
143
144 def flush() -> Iterator[str]:
145 if current_block:
146 if self.align and current_block[0].offset is None:
147 yield f"stream.seek(-stream.tell() & ({current_block[0].alignment} - 1), {io.SEEK_CUR})"
148
149 yield from self._generate_packed(current_block)
150 current_block[:] = []
151
152 def align_to_field(field: Field) -> Iterator[str]:
153 nonlocal current_offset
154
155 if field.offset is not None and field.offset != current_offset:
156 # If a field has a set offset and it's not the same as the current tracked offset, seek to it
157 yield f"stream.seek(o + {field.offset})"
158 current_offset = field.offset
159
160 if self.align and field.offset is None:
161 yield f"stream.seek(-stream.tell() & ({field.alignment} - 1), {io.SEEK_CUR})"
162
163 for field in self.fields:
164 field_type = field.type
165
166 if isinstance(field_type, EnumMetaType):
167 field_type = field_type.type
168
169 if not issubclass(field_type, SUPPORTED_TYPES):
170 raise TypeError(f"Unsupported type for compiler: {field_type}")
171
172 if prev_was_bits and not field.bits:
173 yield "bit_reader.reset()"
174 prev_was_bits = False
175 bits_remaining = 0
176
177 try:
178 size = len(field_type)
179 is_dynamic = False
180 except TypeError:
181 size = None
182 is_dynamic = True
183
184 # Sub structure
185 if issubclass(field_type, Structure):
186 yield from flush()
187 yield from align_to_field(field)
188 yield from self._generate_structure(field)
189
190 # Array of structures and multi-dimensional arrays
191 elif issubclass(field_type, (Array, CharArray, WcharArray)) and (
192 issubclass(field_type.type, Structure) or issubclass(field_type.type, BaseArray) or is_dynamic
193 ):
194 yield from flush()
195 yield from align_to_field(field)
196 yield from self._generate_array(field)
197
198 # Bit fields
199 elif field.bits:
200 if size is None:
201 raise TypeError(f"Unsupported type for bit field: {field_type}")
202
203 if not prev_was_bits:
204 prev_bits_type = field_type
205 prev_was_bits = True
206
207 if bits_remaining == 0 or prev_bits_type != field_type:
208 bits_remaining = (size * 8) - field.bits
209 bits_rollover = True
210
211 yield from flush()
212 yield from align_to_field(field)
213 yield from self._generate_bits(field)
214
215 # Everything else - basic and composite types (and arrays of them)
216 else:
217 current_block.append(field)
218
219 if current_offset is not None and size is not None and (not field.bits or bits_rollover):
220 current_offset += size
221 bits_rollover = False
222
223 yield from flush()
224
225 if self.align:
226 yield f"stream.seek(-stream.tell() & (cls.alignment - 1), {io.SEEK_CUR})"
227
228 def _generate_structure(self, field: Field) -> Iterator[str]:
229 template = f"""
230 _s = stream.tell()
231 r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
232 s["{field._name}"] = stream.tell() - _s
233 """
234
235 yield dedent(template)
236
237 def _generate_array(self, field: Field) -> Iterator[str]:
238 template = f"""
239 _s = stream.tell()
240 r["{field._name}"] = {self._map_field(field)}._read(stream, context=r)
241 s["{field._name}"] = stream.tell() - _s
242 """
243
244 yield dedent(template)
245
246 def _generate_bits(self, field: Field) -> Iterator[str]:
247 lookup = self._map_field(field)
248 read_type = "_t"
249 field_type = field.type
250 if issubclass(field_type, (Enum, Flag)):
251 read_type += ".type"
252 field_type = field_type.type
253
254 if issubclass(field_type, Char):
255 field_type = field_type.cs.uint8
256 lookup = "cls.cs.uint8"
257
258 template = f"""
259 _t = {lookup}
260 r["{field._name}"] = type.__call__(_t, bit_reader.read({read_type}, {field.bits}))
261 """
262
263 yield dedent(template)
264
265 def _generate_packed(self, fields: list[Field]) -> Iterator[str]:
266 info = list(_generate_struct_info(self.cs, fields, self.align))
267 reads = []
268
269 size = 0
270 slice_index = 0
271 for field, count, _ in info:
272 if field is None:
273 # Padding
274 size += count
275 continue
276
277 field_type = field.type
278 read_type = _get_read_type(self.cs, field_type)
279
280 if issubclass(field_type, (Array, CharArray, WcharArray)):
281 count = field_type.num_entries
282 read_type = _get_read_type(self.cs, field_type.type)
283
284 if issubclass(read_type, (Char, Wchar, Int)):
285 count *= read_type.size
286 getter = f"buf[{size}:{size + count}]"
287 else:
288 getter = f"data[{slice_index}:{slice_index + count}]"
289 slice_index += count
290 elif issubclass(read_type, (Char, Wchar, Int)):
291 getter = f"buf[{size}:{size + read_type.size}]"
292 else:
293 getter = f"data[{slice_index}]"
294 slice_index += 1
295
296 if issubclass(read_type, (Wchar, Int)):
297 # Types that parse bytes further down to their own type
298 parser_template = "{type}({getter})"
299 else:
300 # All other types can be simply intialized
301 parser_template = "type.__call__({type}, {getter})"
302
303 # Create the final reading code
304 if issubclass(field_type, Array):
305 reads.append(f"_t = {self._map_field(field)}")
306 reads.append("_et = _t.type")
307
308 if issubclass(field_type.type, Int):
309 reads.append(f"_b = {getter}")
310 item_parser = parser_template.format(type="_et", getter=f"_b[i:i + {field_type.type.size}]")
311 list_comp = f"[{item_parser} for i in range(0, {count}, {field_type.type.size})]"
312 elif issubclass(field_type.type, Pointer):
313 item_parser = "_et.__new__(_et, e, stream, r)"
314 list_comp = f"[{item_parser} for e in {getter}]"
315 else:
316 item_parser = parser_template.format(type="_et", getter="e")
317 list_comp = f"[{item_parser} for e in {getter}]"
318
319 parser = f"type.__call__(_t, {list_comp})"
320 elif issubclass(field_type, CharArray):
321 parser = f"type.__call__({self._map_field(field)}, {getter})"
322 elif issubclass(field_type, Pointer):
323 reads.append(f"_pt = {self._map_field(field)}")
324 parser = f"_pt.__new__(_pt, {getter}, stream, r)"
325 else:
326 parser = parser_template.format(type=self._map_field(field), getter=getter)
327
328 reads.append(f'r["{field._name}"] = {parser}')
329 reads.append(f's["{field._name}"] = {field_type.size}')
330 reads.append("") # Generates a newline in the resulting code
331
332 size += field_type.size
333
334 fmt = _optimize_struct_fmt(info)
335 if fmt == "x" or (len(fmt) == 2 and fmt[1] == "x"):
336 unpack = ""
337 else:
338 unpack = f'data = _struct(cls.cs.endian, "{fmt}").unpack(buf)\n'
339
340 template = f"""
341 buf = stream.read({size})
342 if len(buf) != {size}: raise EOFError()
343 {unpack}
344 """
345
346 yield dedent(template) + "\n".join(reads)
347
348
def _generate_struct_info(cs: cstruct, fields: list[Field], align: bool = False) -> Iterator[tuple[Field, int, str]]:
    """Describe a block of packed fields as ``(field, count, format char)`` tuples.

    Padding is emitted as ``(None, count, "x")``. Packed types use their own pack character, while
    byte based types (Char, Wchar, Int) are reserved as ``"x"`` bytes because they are later sliced
    directly out of the raw buffer. Void fields produce no entry.
    """
    if not fields:
        return

    # Offset within the parent structure as declared by the fields (may be None)
    tracked = fields[0].offset
    # Offset relative to the start of this block, used for alignment when field offsets are unknown
    relative = 0

    for field in fields:
        if field.offset is not None:
            gap = field.offset - tracked
            if gap > 0:
                # We moved -- probably due to alignment
                yield None, gap, "x"
                tracked += gap

        if align and field.offset is None:
            gap = -relative & (field.alignment - 1)
            if gap > 0:
                # Assume the block started at a correctly aligned boundary
                yield None, gap, "x"
                relative += gap

        read_type = _get_read_type(cs, field.type)

        # Voids take no space and produce nothing
        if issubclass(read_type, (Void, VoidArray)):
            continue

        count = 1
        # Arrays are described by their element type, repeated num_entries times
        if issubclass(read_type, (Array, CharArray, WcharArray)):
            count = read_type.num_entries
            read_type = _get_read_type(cs, read_type.type)

        if issubclass(read_type, Packed):
            yield field, count, read_type.packchar
        elif issubclass(read_type, (Char, Wchar, Int)):
            # Not unpacked by struct; sliced directly out of the buffer instead
            yield field, count * read_type.size, "x"

        nbytes = count * read_type.size
        relative += nbytes
        if tracked is not None:
            tracked += nbytes
391
392
393def _optimize_struct_fmt(info: Iterator[tuple[Field, int, str]]) -> str:
394 chars = []
395
396 current_count = 0
397 current_char = None
398
399 for _, count, char in info:
400 if current_char is None:
401 current_count = count
402 current_char = char
403 continue
404
405 if char != current_char:
406 if current_count:
407 chars.append((current_count, current_char))
408 current_count = count
409 current_char = char
410 else:
411 current_count += count
412
413 if current_char is not None and current_count:
414 chars.append((current_count, current_char))
415
416 return "".join(f"{count if count > 1 else ''}{char}" for count, char in chars)
417
418
def _get_read_type(cs: cstruct, type_: type[BaseType]) -> type[BaseType]:
    """Return the resolved type that is actually used to read ``type_`` from a buffer.

    Enums and flags are read as their underlying type, and pointers as ``cs.pointer``.
    """
    resolved = type_.type if issubclass(type_, (Enum, Flag)) else type_

    if issubclass(resolved, Pointer):
        resolved = cs.pointer

    return cs.resolve(resolved)