1from __future__ import annotations
2
3from typing import TYPE_CHECKING, BinaryIO
4
5if TYPE_CHECKING:
6 from dissect.cstruct.types import BaseType
7
8
class BitBuffer:
    """Implements a bit buffer that can read and write bit fields.

    Bit fields are packed into a storage unit of ``field_type.size`` bytes.
    When ``endian`` is ``"<"`` fields are allocated starting at the least
    significant bit of the unit; for any other value they are allocated
    starting at the most significant bit. A single bit field may not span
    two storage units.
    """

    def __init__(self, stream: BinaryIO, endian: str):
        self.stream = stream
        self.endian = endian

        # Type of the storage unit currently being read or written.
        self._type: type[BaseType] | None = None
        # Integer value of the current storage unit.
        self._buffer = 0
        # Number of bits of the current storage unit not yet consumed.
        self._remaining = 0

    def read(self, field_type: type[BaseType], bits: int) -> int:
        """Read a ``bits`` wide bit field stored in a unit of ``field_type``.

        A new storage unit is read from the stream when the current one is
        exhausted or when ``field_type`` differs from the current unit's type.

        Args:
            field_type: The storage type the bit field is declared with.
            bits: Width of the bit field in bits.

        Returns:
            The bit field as a non-negative integer below ``1 << bits``.

        Raises:
            ValueError: If ``field_type`` has no fixed size, or the field does
                not fit in the bits remaining in the current storage unit.
        """
        if self._remaining == 0 or self._type != field_type:
            if field_type.size is None:
                raise ValueError("Reading variable-length fields is unsupported")

            self._type = field_type
            self._remaining = field_type.size * 8
            self._buffer = field_type._read(self.stream)

            # If the type's reader returned raw bytes, interpret them as an
            # integer in the buffer's byte order.
            if isinstance(self._buffer, bytes):
                byteorder = "little" if self.endian == "<" else "big"
                self._buffer = int.from_bytes(self._buffer, byteorder)

        if bits > self._remaining:
            raise ValueError("Reading straddled bits is unsupported")

        mask = (1 << bits) - 1
        if self.endian == "<":
            # Consume from the least significant end and drop the used bits.
            value = self._buffer & mask
            self._buffer >>= bits
        else:
            # Consume from the most significant end; the buffer stays intact
            # and ``_remaining`` marks the current read position.
            value = (self._buffer >> (self._remaining - bits)) & mask
        self._remaining -= bits

        return value

    def write(self, field_type: type[BaseType], data: int, bits: int) -> None:
        """Write ``data`` as a ``bits`` wide bit field stored in a unit of ``field_type``.

        A new storage unit is started (after flushing any pending one) when
        the current unit is full or ``field_type`` differs from its type. The
        unit is written to the stream as soon as all of its bits are used;
        call :meth:`flush` to write a partially filled unit.

        Args:
            field_type: The storage type the bit field is declared with.
            data: Value of the bit field. For every field except the most
                significant one in its unit, only the low ``bits`` bits are
                kept.
            bits: Width of the bit field in bits.

        Raises:
            ValueError: If ``field_type`` has no fixed size, or the field does
                not fit in the bits remaining in the current storage unit.
        """
        if self._remaining == 0 or self._type != field_type:
            # Emit any partially filled unit before starting a new one
            # (flush() is a no-op when nothing is pending).
            self.flush()

            if field_type.size is None:
                raise ValueError("Writing variable-length fields is unsupported")

            self._remaining = field_type.size * 8
            self._type = field_type

        if self._type is None or self._type.size is None:
            raise ValueError("Invalid state")

        # Mirror read(): a field must fit in the current storage unit.
        if bits > self._remaining:
            raise ValueError("Writing straddled bits is unsupported")

        unit_bits = self._type.size * 8
        if self.endian == "<":
            shift = unit_bits - self._remaining
        else:
            shift = self._remaining - bits

        if shift + bits < unit_bits:
            # Truncate to the field width so oversized or negative values
            # cannot overwrite neighbouring fields. The most significant field
            # is left untouched so a negative value there still yields the same
            # (negative) buffer as before — presumably relied on for signed
            # storage types; TODO confirm against the type writers.
            data &= (1 << bits) - 1

        self._buffer |= data << shift
        self._remaining -= bits
        if self._remaining == 0:
            self.flush()

    def flush(self) -> None:
        """Write the pending storage unit, if any, to the stream and reset the state."""
        if self._type is not None:
            self._type._write(self.stream, self._buffer)
            self.reset()

    def reset(self) -> None:
        """Discard any pending storage unit without writing it."""
        self._type = None
        self._buffer = 0
        self._remaining = 0