1from __future__ import annotations
2
3import functools
4from io import BytesIO
5from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar
6
7from dissect.cstruct.exceptions import ArraySizeError
8from dissect.cstruct.expression import Expression
9
10if TYPE_CHECKING:
11 from collections.abc import Callable
12
13 from typing_extensions import Self
14
15 from dissect.cstruct.cstruct import cstruct
16
17
# Sentinel passed as ``count`` to ``_read_array`` meaning "keep reading entries until the stream is exhausted".
EOF = -0xE0F  # Negative counts are illegal anyway, so abuse that for our EOF sentinel
19
20
class MetaType(type):
    """Base metaclass for cstruct type classes.

    Provides the class-level parsing/serialization API (``read``, ``reads``, ``write``,
    ``dumps``), array creation through ``Type[count]`` and the ``Type(data)`` parsing shorthand.
    Concrete types implement the ``_read``/``_write`` family of internal methods.
    """

    cs: cstruct
    """The cstruct instance this type class belongs to."""
    size: int | None
    """The size of the type in bytes. Can be ``None`` for dynamic sized types."""
    dynamic: bool
    """Whether or not the type is dynamically sized."""
    alignment: int | None
    """The alignment of the type in bytes. A value of ``None`` will be treated as 1-byte aligned."""

    # This must be the actual type, but since Array is a subclass of BaseType, we correct this at the bottom of the file
    ArrayType: type[BaseArray] = "Array"
    """The array type for this type class."""

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Adds support for ``TypeClass(bytes | file-like object)`` parsing syntax.

        When called with a single positional argument that is not already an instance of this
        type, the argument is parsed: objects with a ``read`` attribute are read from directly,
        and bytes-like objects are parsed from memory. Every other call falls through to
        regular instance construction.
        """
        # TODO: add support for Type(cs) API to create new bounded type classes, similar to the old API?
        if len(args) == 1 and not isinstance(args[0], cls):
            stream = args[0]

            if _is_readable_type(stream):
                return cls._read(stream)

            if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size:
                # Shortcut for char/bytes type: exact-size bytes are used as the value itself
                # instead of being parsed.
                return type.__call__(cls, *args, **kwargs)

            if _is_buffer_type(stream):
                return cls.reads(stream)

        return type.__call__(cls, *args, **kwargs)

    def __getitem__(cls, num_entries: int | Expression | None) -> type[BaseArray]:
        """Create a new array type with the given number of entries.

        Delegates to the owning cstruct instance's ``_make_array``.
        """
        return cls.cs._make_array(cls, num_entries)

    def __bool__(cls) -> bool:
        """Type class is always truthy."""
        return True

    def __len__(cls) -> int:
        """Return the byte size of the type.

        Raises:
            TypeError: If the type is dynamically sized (``size`` is ``None``).
        """
        if cls.size is None:
            raise TypeError("Dynamic size")

        return cls.size

    def __default__(cls) -> Self:  # type: ignore
        """Return the default value of this type (an instance constructed without arguments)."""
        return cls()

    def reads(cls, data: bytes | memoryview | bytearray) -> Self:  # type: ignore
        """Parse the given data from a bytes-like object.

        Args:
            data: Bytes-like object to parse.

        Returns:
            The parsed value of this type.
        """
        return cls._read(BytesIO(data))

    def read(cls, obj: BinaryIO | bytes | memoryview | bytearray) -> Self:  # type: ignore
        """Parse the given data.

        Args:
            obj: Data to parse. Can be a bytes-like object or a file-like object.

        Returns:
            The parsed value of this type.

        Raises:
            TypeError: If ``obj`` is neither bytes-like nor has a ``read`` attribute.
        """
        if _is_buffer_type(obj):
            return cls.reads(obj)

        if not _is_readable_type(obj):
            raise TypeError("Invalid object type")

        return cls._read(obj)

    def write(cls, stream: BinaryIO, value: Any) -> int:
        """Write a value to a writable file-like object.

        Args:
            stream: File-like object that supports writing.
            value: Value to write.

        Returns:
            The amount of bytes written.
        """
        return cls._write(stream, value)

    def dumps(cls, value: Any) -> bytes:
        """Dump a value to a byte string.

        Args:
            value: Value to dump.

        Returns:
            The raw bytes of this type.
        """
        out = BytesIO()
        cls._write(out, value)
        return out.getvalue()

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Internal function for reading value.

        Must be implemented per type.

        Args:
            stream: The stream to read from.
            context: Optional reading context.
        """
        raise NotImplementedError

    def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Self]:  # type: ignore
        """Internal function for reading array values.

        Allows type implementations to do optimized reading for their type.

        Args:
            stream: The stream to read from.
            count: The amount of values to read, or ``EOF`` to read until the stream is exhausted.
            context: Optional reading context.
        """
        if count == EOF:
            result = []
            while not _is_eof(stream):
                result.append(cls._read(stream, context))
            return result

        return [cls._read(stream, context) for _ in range(count)]

    def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:  # type: ignore
        """Internal function for reading null-terminated data.

        "Null" is type specific, so must be implemented per type.

        Args:
            stream: The stream to read from.
            context: Optional reading context.
        """
        raise NotImplementedError

    def _write(cls, stream: BinaryIO, data: Any) -> int:
        """Internal function for writing a value.

        Must be implemented per type.

        Args:
            stream: The stream to write to.
            data: The value to write.

        Returns:
            The amount of bytes written.
        """
        raise NotImplementedError

    def _write_array(cls, stream: BinaryIO, array: list[Self]) -> int:  # type: ignore
        """Internal function for writing arrays.

        Allows type implementations to do optimized writing for their type.

        Args:
            stream: The stream to write to.
            array: The array to write.

        Returns:
            The total amount of bytes written.
        """
        return sum(cls._write(stream, entry) for entry in array)

    def _write_0(cls, stream: BinaryIO, array: list[Self]) -> int:  # type: ignore
        """Internal function for writing null-terminated arrays.

        Writes the array followed by this type's default value, which acts as the terminator.
        Allows type implementations to do optimized writing for their type.

        Args:
            stream: The stream to write to.
            array: The array to write.

        Returns:
            The total amount of bytes written, including the terminator.
        """
        return cls._write_array(stream, [*array, cls.__default__()])
191
192
class _overload:
    """Descriptor to use on the ``write`` and ``dumps`` methods on cstruct types.

    Allows for calling these methods on both the type and instance. Accessed on the type,
    the wrapped function is bound to that type. Accessed on an instance, it is bound to the
    instance's type and the instance itself is passed as ``value``.

    Example:
        >>> int32.dumps(123)
        b'{\\x00\\x00\\x00'
        >>> int32(123).dumps()
        b'{\\x00\\x00\\x00'
    """

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func

    def __get__(self, instance: BaseType | None, owner: type[BaseType]) -> Callable[..., Any]:
        # The bound signature depends on the wrapped function (``write`` takes a stream and
        # returns an int, ``dumps`` returns bytes), so the result is typed generically.
        if instance is None:
            # ``Type.dumps(value)``: bind only the type.
            return functools.partial(self.func, owner)
        # ``value.dumps()``: bind the instance's type and pass the instance as ``value``.
        return functools.partial(self.func, instance.__class__, value=instance)
212
213
class BaseType(metaclass=MetaType):
    """Base class for cstruct type classes.

    ``write`` and ``dumps`` are exposed through ``_overload`` so they can be called on the
    type (``Type.dumps(value)``) as well as on an instance (``value.dumps()``).
    """

    write = _overload(MetaType.write)
    dumps = _overload(MetaType.dumps)

    def __len__(self) -> int:
        """Return the byte size of this value's type.

        Raises:
            TypeError: If the type is dynamically sized (``size`` is ``None``).
        """
        byte_size = self.__class__.size
        if byte_size is None:
            raise TypeError("Dynamic size")
        return byte_size
226
227
# Element type parameter for the list-backed ``Array`` class, restricted to cstruct types.
T = TypeVar("T", bound=BaseType)
229
230
class BaseArray(BaseType):
    """Implements a fixed or dynamically sized array type.

    Example:
        When using the default C-style parser, the following syntax is supported:

            x[3] -> 3 -> static length.
            x[] -> None -> null-terminated.
            x[expr] -> expr -> dynamic length.
    """

    type: ClassVar[type[BaseType]]  # element type of the array
    num_entries: ClassVar[int | Expression | None]  # static count, count expression, or None
    null_terminated: ClassVar[bool]  # when True, reading/writing uses the element type's _read_0/_write_0

    @classmethod
    def __default__(cls) -> BaseType:
        """Return a default array value.

        Arrays with an integer ``num_entries`` get that many default elements. Arrays whose
        size is an expression or ``None`` default to empty.
        """
        count = cls.num_entries if isinstance(cls.num_entries, int) else 0
        # Build a fresh default per element. ``[default] * count`` would make every entry the
        # same object, so mutating one element of a nested array/structure would change them all.
        return type.__call__(cls, [cls.type.__default__() for _ in range(count)])

    @classmethod
    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]:
        """Read the array entries from ``stream``.

        Args:
            stream: The stream to read from.
            context: Optional reading context, used to evaluate size expressions.

        Raises:
            TypeError: If ``num_entries`` is not an int, ``Expression`` or ``None``.
        """
        if cls.null_terminated:
            return cls.type._read_0(stream, context)

        num_entries = cls.num_entries
        if isinstance(num_entries, int):
            num = max(0, num_entries)
        elif num_entries is None:
            num = EOF
        elif isinstance(num_entries, Expression):
            try:
                num = max(0, num_entries.evaluate(cls.cs, context))
            except Exception:
                # A bare ``EOF`` expression cannot be evaluated to a number; it means
                # "read until the end of the stream". Any other evaluation error propagates.
                if num_entries.expression != "EOF":
                    raise
                num = EOF
        else:
            # Previously this fell through and failed with an UnboundLocalError on ``num``.
            raise TypeError(f"Unsupported array size specification: {num_entries!r}")

        return cls.type._read_array(stream, num, context)

    @classmethod
    def _write(cls, stream: BinaryIO, data: list[Any]) -> int:
        """Write the array entries to ``stream``.

        Args:
            stream: The stream to write to.
            data: The entries to write.

        Returns:
            The amount of bytes written.

        Raises:
            ArraySizeError: If the array is not dynamic and ``data`` does not have exactly
                ``num_entries`` entries.
        """
        if cls.null_terminated:
            return cls.type._write_0(stream, data)

        if not cls.dynamic and cls.num_entries != (actual_size := len(data)):
            raise ArraySizeError(f"Expected static array size {cls.num_entries}, got {actual_size} instead.")

        return cls.type._write_array(stream, data)
280
281
class Array(list[T], BaseArray):
    """List-backed array type; parsed entries are returned wrapped in an instance of this class."""

    @classmethod
    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[T]:
        entries = super()._read(stream, context)
        return cls(entries)
286
287
def _is_readable_type(value: object) -> bool:
    """Return ``True`` if ``value`` looks file-like, i.e. it has a ``read`` attribute."""
    try:
        getattr(value, "read")
    except AttributeError:
        return False
    return True
290
291
def _is_buffer_type(value: object) -> bool:
    """Return ``True`` if ``value`` is an in-memory bytes-like object (bytes, bytearray or memoryview)."""
    buffer_types = (bytes, bytearray, memoryview)
    return isinstance(value, buffer_types)
294
295
def _is_eof(stream: BinaryIO) -> bool:
    """Check if the stream has reached EOF.

    Probes by reading a single byte and comparing stream positions. If data was available,
    the original position is restored, so no data is consumed.
    """
    start = stream.tell()
    stream.read(1)
    at_end = stream.tell() == start
    if not at_end:
        stream.seek(start)
    return at_end
306
307
# ``MetaType.ArrayType`` is declared with the placeholder string "Array" because ``Array`` is a
# subclass of ``BaseType`` and does not exist yet when ``MetaType`` is defined; set the real class here.
MetaType.ArrayType = Array