1from __future__ import annotations
2
3import sys
4from typing import Any, BinaryIO, ClassVar
5
6from dissect.cstruct.types.base import EOF, BaseArray, BaseType
7
8
class WcharArray(str, BaseArray):
    """Array of wide characters, surfaced to callers as a Python ``str``.

    The on-disk representation is UTF-16; the byte order is picked from
    ``Wchar.__encoding_map__`` using the owning cstruct's ``endian`` setting.
    """

    __slots__ = ()

    @classmethod
    def __default__(cls) -> WcharArray:
        """Build the default value for this array type.

        Dynamic and NUL-terminated arrays default to the empty string; a
        fixed-size array defaults to ``num_entries`` NUL characters.
        """
        if cls.dynamic or cls.null_terminated:
            length = 0
        else:
            length = cls.num_entries
        return type.__call__(cls, "\x00" * length)

    @classmethod
    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> WcharArray:
        """Read via the generic array machinery, then rewrap as this type."""
        decoded = super()._read(stream, context)
        return type.__call__(cls, decoded)

    @classmethod
    def _write(cls, stream: BinaryIO, data: str) -> int:
        """Encode ``data`` as UTF-16 and write it to ``stream``.

        A single NUL character is appended first when the array is
        NUL-terminated. No padding or truncation to ``num_entries`` is done.

        Returns:
            Whatever ``stream.write`` returns for the encoded bytes.
        """
        payload = data + "\x00" if cls.null_terminated else data
        codec = Wchar.__encoding_map__[cls.cs.endian]
        return stream.write(payload.encode(codec))
27
28
class Wchar(str, BaseType):
    """Wide character backed by 2-byte UTF-16 units, surfaced as a Python ``str``.

    The codec is chosen from ``__encoding_map__`` by the owning cstruct's
    ``endian`` setting.
    """

    ArrayType = WcharArray

    __slots__ = ()
    # Byte-order character -> Python codec name. The keys are the same
    # characters Python's ``struct`` module uses; "@" and "=" follow the host
    # byte order reported by ``sys.byteorder`` ("little" or "big").
    __encoding_map__: ClassVar[dict[str, str]] = {
        "@": "utf-16-le" if sys.byteorder == "little" else "utf-16-be",
        "=": "utf-16-le" if sys.byteorder == "little" else "utf-16-be",
        "<": "utf-16-le",
        ">": "utf-16-be",
        "!": "utf-16-be",
    }

    @classmethod
    def __default__(cls) -> Wchar:
        """Return a single NUL character as the default value."""
        nul = "\x00"
        return type.__call__(cls, nul)

    @classmethod
    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Wchar:
        """Read exactly one wide character (two bytes)."""
        # A lone wide character is just an array of length one.
        return cls._read_array(stream, 1, context)

    @classmethod
    def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Wchar:
        """Read ``count`` 2-byte units and decode them into one string.

        Passing ``EOF`` as ``count`` consumes the remainder of ``stream``.

        Raises:
            EOFError: fewer than ``count * 2`` bytes were available.
        """
        if count == 0:
            return type.__call__(cls, "")

        if count == EOF:
            raw = stream.read(-1)
        else:
            expected = count * 2
            raw = stream.read(expected)
            if len(raw) != expected:
                raise EOFError(f"Read {len(raw)} bytes, but expected {expected}")

        return type.__call__(cls, raw.decode(cls.__encoding_map__[cls.cs.endian]))

    @classmethod
    def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Wchar:
        """Read 2-byte units until a ``b"\\x00\\x00"`` unit; the terminator is dropped.

        Raises:
            EOFError: the stream ran out before a full 2-byte unit could be read.
        """
        units = []
        # The terminator is always two bytes long, so testing for it before the
        # length check is equivalent to checking length first.
        while (unit := stream.read(2)) != b"\x00\x00":
            if len(unit) != 2:
                raise EOFError(f"Read {len(unit)} bytes, but expected 2")
            units.append(unit)

        return type.__call__(cls, b"".join(units).decode(cls.__encoding_map__[cls.cs.endian]))

    @classmethod
    def _write(cls, stream: BinaryIO, data: str) -> int:
        """Encode ``data`` with the endian-selected UTF-16 codec and write it out."""
        codec = cls.__encoding_map__[cls.cs.endian]
        return stream.write(data.encode(codec))