from __future__ import annotations

from typing import Union, Iterable, Optional, overload, Iterator, Any

import bitarray
import bitarray.util

from bitstring.exceptions import CreationError
6
7
def offset_slice_indices_lsb0(key: slice, length: int) -> slice:
    """Translate an LSB0 slice into the equivalent slice of MSB0 (storage) positions.

    LSB0 index ``i`` lives at storage position ``length - 1 - i``. The returned slice
    selects exactly the bits that ``key`` selects in LSB0 terms, ordered so that the
    first bit picked by ``key`` ends up as the least significant (last stored) bit of
    the result. For a list ``bits`` in storage order this means
    ``bits[result] == bits[::-1][key][::-1]``.

    :param key: Slice expressed in LSB0 indices; any non-zero step is supported.
    :param length: Length of the bitstring. It should already take account of any offset.
    :return: A slice that can be applied directly to the storage.
    :raises ValueError: If ``key.step`` is zero (raised by ``slice.indices``).
    """
    # First convert slice to all integers
    start, stop, step = key.indices(length)
    if step == 1:
        # Contiguous (possibly empty) range. For an empty slice the position still
        # matters, because slice assignment uses it as the insertion point; a
        # 'reversed' empty slice such as [5:2] inserts at its start, as in Python.
        stop = max(stop, start)
        return slice(length - stop, length - start, 1)
    selected = range(start, stop, step)
    if not selected:
        return slice(0, 0, step)
    # Storage positions of the last and first selected LSB0 indices respectively.
    new_start = length - 1 - selected[-1]
    last = length - 1 - selected[0]
    new_stop = last + 1 if step > 0 else last - 1
    # For a negative step the stop can be -1, which a slice would treat as the final
    # element, so use None ("run to the beginning") instead.
    return slice(new_start, None if new_stop < 0 else new_stop, step)
16
def offset_start_stop_lsb0(start: Optional[int], stop: Optional[int], length: int) -> tuple[int, int]:
    """Translate an LSB0 ``[start:stop]`` range into MSB0 (storage) positions.

    ``start`` and ``stop`` are normalised exactly as ``slice(start, stop).indices(length)``
    does (``None`` defaults, negative values counted from the end, clamping). The
    returned pair is meant for ``storage[new_start:new_stop]`` and selects the same bits.

    :param start: LSB0 start index, or None.
    :param stop: LSB0 stop index, or None.
    :param length: Length of the bitstring. It should already take account of any offset.
    :return: ``(new_start, new_stop)`` as a tuple of ints (not a slice object).
    """
    # First convert the bounds to concrete integers.
    start, stop, _ = slice(start, stop, None).indices(length)
    # LSB0 index i is storage position length - 1 - i, so [start, stop) maps to
    # [length - stop, length - start).
    return length - stop, length - start
24
25
class BitStore:
    """A light wrapper around bitarray that does the LSB0 stuff.

    The bits are held in ``_bitarray`` in storage (MSB0) order. Methods with an
    ``_lsb0`` suffix translate LSB0 indices/slices into storage positions before using
    them; methods with an ``_msb0`` suffix pass them through unchanged.

    ``slice_to_*`` and ``__iter__`` call ``self.getslice`` / ``self.getindex``, which are
    not defined in this class. Presumably they are attached elsewhere as aliases of the
    ``_msb0`` or ``_lsb0`` variants — TODO confirm.

    Slots:
        _bitarray: The underlying ``bitarray.bitarray``.
        modified_length: ``None``, or the number of leading bits of ``_bitarray`` that
            belong to this store (only ``frombuffer`` sets it).
        immutable: When true, ``copy`` returns ``self`` instead of a new store.
    """

    __slots__ = ('_bitarray', 'modified_length', 'immutable')

    def __init__(self, initializer: Union[int, bitarray.bitarray, str, None] = None,
                 immutable: bool = False) -> None:
        """Create a store from any initializer accepted by ``bitarray.bitarray``."""
        self._bitarray = bitarray.bitarray(initializer)
        self.immutable = immutable
        self.modified_length = None

    @classmethod
    def frombytes(cls, b: Union[bytes, bytearray, memoryview], /) -> BitStore:
        """Create a mutable store whose bits are read from the bytes of ``b``."""
        x = super().__new__(cls)
        x._bitarray = bitarray.bitarray()
        x._bitarray.frombytes(b)
        x.immutable = False
        x.modified_length = None
        return x

    @classmethod
    def frombuffer(cls, buffer, /, length: Optional[int] = None) -> BitStore:
        """Create an immutable store backed directly by ``buffer``'s memory.

        :param buffer: An object exposing the buffer protocol.
        :param length: If given, only the first ``length`` bits belong to the store.
        :raises CreationError: If ``length`` is negative or exceeds the buffer's bits.
        """
        x = super().__new__(cls)
        x._bitarray = bitarray.bitarray(buffer=buffer)
        x.immutable = True
        x.modified_length = length
        # Here 'modified' means it shouldn't be changed further, so setting, deleting etc. are disallowed.
        if x.modified_length is not None:
            if x.modified_length < 0:
                raise CreationError("Can't create bitstring with a negative length.")
            if x.modified_length > len(x._bitarray):
                raise CreationError(
                    f"Can't create bitstring with a length of {x.modified_length} from {len(x._bitarray)} bits of data.")
        return x

    def _effective_bitarray(self) -> bitarray.bitarray:
        """Return the bits that logically belong to this store.

        This is ``_bitarray`` itself when ``modified_length`` is None. Otherwise it is a
        new bitarray holding only the first ``modified_length`` bits, so buffer bits
        past the logical end never leak into comparisons, searches or copies.
        """
        if self.modified_length is None:
            return self._bitarray
        return self._bitarray[:self.modified_length]

    def _normalise_index(self, index: int) -> int:
        """Return ``index`` as a position in ``range(len(self))``.

        Negative values count back from ``len(self)``, as for sequences.

        :raises IndexError: If ``index`` lies outside the logical length.
        """
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError("bitarray index out of range")
        return index

    def setall(self, value: int, /) -> None:
        """Set every bit of the underlying bitarray to ``value``."""
        self._bitarray.setall(value)

    def tobytes(self) -> bytes:
        """Return the store's bits as bytes (bitarray zero-pads a partial final byte)."""
        return self._effective_bitarray().tobytes()

    def slice_to_uint(self, start: Optional[int] = None, end: Optional[int] = None) -> int:
        """Interpret ``getslice(start, end)`` as an unsigned integer."""
        return bitarray.util.ba2int(self.getslice(start, end)._bitarray, signed=False)

    def slice_to_int(self, start: Optional[int] = None, end: Optional[int] = None) -> int:
        """Interpret ``getslice(start, end)`` as a signed (two's complement) integer."""
        return bitarray.util.ba2int(self.getslice(start, end)._bitarray, signed=True)

    def slice_to_hex(self, start: Optional[int] = None, end: Optional[int] = None) -> str:
        """Return ``getslice(start, end)`` as a hexadecimal string."""
        return bitarray.util.ba2hex(self.getslice(start, end)._bitarray)

    def slice_to_bin(self, start: Optional[int] = None, end: Optional[int] = None) -> str:
        """Return ``getslice(start, end)`` as a string of '0'/'1' characters."""
        return self.getslice(start, end)._bitarray.to01()

    def slice_to_oct(self, start: Optional[int] = None, end: Optional[int] = None) -> str:
        """Return ``getslice(start, end)`` as an octal string."""
        return bitarray.util.ba2base(8, self.getslice(start, end)._bitarray)

    def __iadd__(self, other: BitStore, /) -> BitStore:
        """Append ``other``'s bits to this store in place."""
        self._bitarray += other._effective_bitarray()
        return self

    def __add__(self, other: BitStore, /) -> BitStore:
        """Return a new store holding this store's bits followed by ``other``'s."""
        bs = self._copy()
        bs += other
        return bs

    def __eq__(self, other: Any, /) -> bool:
        """Compare the logical bits of two stores; defer for non-BitStore operands."""
        if not isinstance(other, BitStore):
            return NotImplemented
        return self._effective_bitarray() == other._effective_bitarray()

    def __and__(self, other: BitStore, /) -> BitStore:
        """Return the bitwise AND of two equal-length stores as a new store."""
        return BitStore(self._effective_bitarray() & other._effective_bitarray())

    def __or__(self, other: BitStore, /) -> BitStore:
        """Return the bitwise OR of two equal-length stores as a new store."""
        return BitStore(self._effective_bitarray() | other._effective_bitarray())

    def __xor__(self, other: BitStore, /) -> BitStore:
        """Return the bitwise XOR of two equal-length stores as a new store."""
        return BitStore(self._effective_bitarray() ^ other._effective_bitarray())

    def __iand__(self, other: BitStore, /) -> BitStore:
        """AND ``other`` into this store in place."""
        self._bitarray &= other._effective_bitarray()
        return self

    def __ior__(self, other: BitStore, /) -> BitStore:
        """OR ``other`` into this store in place."""
        self._bitarray |= other._effective_bitarray()
        return self

    def __ixor__(self, other: BitStore, /) -> BitStore:
        """XOR ``other`` into this store in place."""
        self._bitarray ^= other._effective_bitarray()
        return self

    def find(self, bs: BitStore, start: int, end: int, bytealigned: bool = False) -> int:
        """Return the first MSB0 position of ``bs`` searched between ``start`` and ``end``.

        :param bytealigned: If true, only matches starting at a multiple of 8 count.
        :return: The position, or -1 if there is no match.
        """
        if not bytealigned:
            return self._bitarray.find(bs._effective_bitarray(), start, end)
        return next(self.findall_msb0(bs, start, end, bytealigned), -1)

    def rfind(self, bs: BitStore, start: int, end: int, bytealigned: bool = False) -> int:
        """Return the last MSB0 position of ``bs`` searched between ``start`` and ``end``.

        :param bytealigned: If true, only matches starting at a multiple of 8 count.
        :return: The position, or -1 if there is no match.
        """
        if not bytealigned:
            return self._bitarray.find(bs._effective_bitarray(), start, end, right=True)
        return next(self.rfindall_msb0(bs, start, end, bytealigned), -1)

    def findall_msb0(self, bs: BitStore, start: int, end: int, bytealigned: bool = False) -> Iterator[int]:
        """Lazily yield MSB0 match positions of ``bs`` from left to right."""
        positions = self._bitarray.itersearch(bs._effective_bitarray(), start, end)
        if bytealigned:
            positions = (p for p in positions if p % 8 == 0)
        yield from positions

    def rfindall_msb0(self, bs: BitStore, start: int, end: int, bytealigned: bool = False) -> Iterator[int]:
        """Lazily yield MSB0 match positions of ``bs`` from right to left."""
        positions = self._bitarray.itersearch(bs._effective_bitarray(), start, end, right=True)
        if bytealigned:
            positions = (p for p in positions if p % 8 == 0)
        yield from positions

    def count(self, value, /) -> int:
        """Return how many of the store's bits equal ``value``."""
        return self._effective_bitarray().count(value)

    def clear(self) -> None:
        """Remove all bits from the underlying bitarray."""
        self._bitarray.clear()

    def reverse(self) -> None:
        """Reverse the order of the underlying bits in place."""
        self._bitarray.reverse()

    def __iter__(self) -> Iterator[bool]:
        """Yield each bit via ``self.getindex`` for indices ``0 .. len(self) - 1``."""
        for i in range(len(self)):
            yield self.getindex(i)

    def _copy(self) -> BitStore:
        """Always creates a copy, even if instance is immutable."""
        return BitStore(self._effective_bitarray())

    def copy(self) -> BitStore:
        """Return ``self`` for immutable stores, otherwise a new mutable copy."""
        return self if self.immutable else self._copy()

    def __getitem__(self, item: Union[int, slice], /) -> Union[int, BitStore]:
        # Use getindex or getslice instead
        raise NotImplementedError

    def getindex_msb0(self, index: int, /) -> bool:
        """Return the bit at MSB0 ``index`` (negative values count from the end)."""
        if self.modified_length is not None:
            # Index relative to the logical length, not the larger underlying buffer.
            index = self._normalise_index(index)
        return bool(self._bitarray[index])

    def getslice_withstep_msb0(self, key: slice, /) -> BitStore:
        """Return a new store for an arbitrary (possibly stepped) MSB0 slice."""
        if self.modified_length is not None:
            start, stop, step = key.indices(self.modified_length)
            # A negative step can give stop == -1 ("before index 0"); as a slice bound
            # that would mean the last element, so use None instead.
            key = slice(start, None if stop < 0 else stop, step)
        return BitStore(self._bitarray[key])

    def getslice_withstep_lsb0(self, key: slice, /) -> BitStore:
        """Return a new store for an arbitrary (possibly stepped) LSB0 slice."""
        return BitStore(self._bitarray[offset_slice_indices_lsb0(key, len(self))])

    def getslice_msb0(self, start: Optional[int], stop: Optional[int], /) -> BitStore:
        """Return a new store holding MSB0 bits ``[start:stop]``."""
        if self.modified_length is not None:
            # Resolve None/negative bounds against the logical length.
            start, stop, _ = slice(start, stop, None).indices(self.modified_length)
        return BitStore(self._bitarray[start:stop])

    def getslice_lsb0(self, start: Optional[int], stop: Optional[int], /) -> BitStore:
        """Return a new store holding LSB0 bits ``[start:stop]``."""
        start, stop = offset_start_stop_lsb0(start, stop, len(self))
        return BitStore(self._bitarray[start:stop])

    def getindex_lsb0(self, index: int, /) -> bool:
        """Return the bit at LSB0 ``index`` (index 0 is the last stored bit)."""
        if self.modified_length is None:
            return bool(self._bitarray[-index - 1])
        # The logical last bit sits at modified_length - 1, not at the buffer's end.
        index = self._normalise_index(index)
        return bool(self._bitarray[self.modified_length - 1 - index])

    @overload
    def setitem_lsb0(self, key: int, value: int, /) -> None:
        ...

    @overload
    def setitem_lsb0(self, key: slice, value: BitStore, /) -> None:
        ...

    def setitem_lsb0(self, key: Union[int, slice], value: Union[int, BitStore], /) -> None:
        """Set a bit (int key) or a slice (slice key, BitStore value) in LSB0 terms."""
        if isinstance(key, slice):
            new_slice = offset_slice_indices_lsb0(key, len(self))
            self._bitarray[new_slice] = value._effective_bitarray()
        else:
            self._bitarray[-key - 1] = value

    def delitem_lsb0(self, key: Union[int, slice], /) -> None:
        """Delete a bit or slice addressed in LSB0 terms."""
        if isinstance(key, slice):
            del self._bitarray[offset_slice_indices_lsb0(key, len(self))]
        else:
            del self._bitarray[-key - 1]

    def invert_msb0(self, index: Optional[int] = None, /) -> None:
        """Invert the bit at MSB0 ``index``, or every bit when ``index`` is None."""
        if index is None:
            self._bitarray.invert()
        else:
            self._bitarray.invert(index)

    def invert_lsb0(self, index: Optional[int] = None, /) -> None:
        """Invert the bit at LSB0 ``index``, or every bit when ``index`` is None."""
        if index is None:
            self._bitarray.invert()
        else:
            self._bitarray.invert(-index - 1)

    def any_set(self) -> bool:
        """Return True if any of the store's bits is 1."""
        return self._effective_bitarray().any()

    def all_set(self) -> bool:
        """Return True if all of the store's bits are 1."""
        return self._effective_bitarray().all()

    def __len__(self) -> int:
        """Return the logical length (``modified_length`` when set)."""
        return self.modified_length if self.modified_length is not None else len(self._bitarray)

    def setitem_msb0(self, key, value, /):
        """Set MSB0 ``key`` (index or slice) to ``value`` (bit value or BitStore)."""
        if isinstance(value, BitStore):
            self._bitarray[key] = value._effective_bitarray()
        else:
            self._bitarray[key] = value

    def delitem_msb0(self, key, /):
        """Delete the MSB0 index or slice ``key``."""
        del self._bitarray[key]