1from __future__ import annotations
2
3import bitarray
4import bitarray.util
5from bitstring.exceptions import CreationError
6from typing import Union, Iterable, Optional, overload, Iterator, Any
7from bitstring.helpers import offset_slice_indices_lsb0
8
# Refuse to run on any bitarray older than major version 3. The original check only
# rejected "2.x", letting 0.x / 1.x through. Comparing the major component as a string
# avoids an int() parse that could fail on an unusual version string.
if bitarray.__version__.split(".", 1)[0] in ("0", "1", "2"):
    raise ImportError(f"bitstring version 4.3 requires bitarray version 3 or higher. Found version {bitarray.__version__}.")
11
12
class _BitStore:
    """A light wrapper around bitarray that does the LSB0 stuff.

    The bits are held in ``_bitarray``. ``modified_length`` is normally None. When
    :meth:`frombuffer` is given an explicit ``length``, it records how many leading
    bits of ``_bitarray`` belong to the store, and the read operations below only
    look at those bits. ``immutable`` records whether :meth:`copy` may return the
    instance itself instead of a copy.

    NOTE(review): ``getindex`` (used by ``__iter__``) is not defined in this class.
    Presumably it is bound elsewhere to ``getindex_msb0`` or ``getindex_lsb0``
    depending on the bit-numbering mode - TODO confirm.
    """

    __slots__ = ('_bitarray', 'modified_length', 'immutable')

    def __init__(self, initializer: Union[bitarray.bitarray, None] = None,
                 immutable: bool = False) -> None:
        """Create a store holding a copy of ``initializer`` (empty if None).

        :raises TypeError: if ``initializer`` is a str; use :meth:`from_bin` instead.
        """
        if isinstance(initializer, str):
            # An assert here would vanish under ``python -O`` and silently let a str
            # through to bitarray's own string parsing.
            raise TypeError("Use _BitStore.from_bin() to create a _BitStore from a binary string.")
        self._bitarray = bitarray.bitarray(initializer)
        self.immutable = immutable
        self.modified_length = None

    @classmethod
    def _wrap(cls, ba: bitarray.bitarray, immutable: bool = False,
              modified_length: Optional[int] = None) -> _BitStore:
        """Build an instance around ``ba`` without copying it or running __init__."""
        x = super().__new__(cls)
        x._bitarray = ba
        x.immutable = immutable
        x.modified_length = modified_length
        return x

    @classmethod
    def from_zeros(cls, i: int) -> _BitStore:
        """Return a mutable store of ``i`` zero bits."""
        return cls._wrap(bitarray.util.zeros(i))

    @classmethod
    def from_bin(cls, s: str) -> _BitStore:
        """Return a mutable store parsed from the binary string ``s``."""
        return cls._wrap(bitarray.bitarray(s))

    @classmethod
    def from_bytes(cls, b: Union[bytes, bytearray, memoryview], /) -> _BitStore:
        """Return a mutable store holding the bits of ``b``."""
        ba = bitarray.bitarray()
        ba.frombytes(b)
        return cls._wrap(ba)

    @classmethod
    def frombuffer(cls, buffer, /, length: Optional[int] = None) -> _BitStore:
        """Return an immutable store over ``buffer``, optionally limited to ``length`` bits.

        :raises CreationError: if ``length`` is negative or longer than the buffer.
        """
        # bitarray's ``buffer=`` argument uses the buffer's memory rather than copying it.
        ba = bitarray.bitarray(buffer=buffer)
        if length is not None:
            if length < 0:
                raise CreationError("Can't create bitstring with a negative length.")
            if length > len(ba):
                raise CreationError(
                    f"Can't create bitstring with a length of {length} from {len(ba)} bits of data.")
        # Here 'modified' means it shouldn't be changed further, so setting, deleting etc. are disallowed.
        return cls._wrap(ba, immutable=True, modified_length=length)

    @classmethod
    def join(cls, bitstores: Iterable[_BitStore], /) -> _BitStore:
        """Return a new mutable store that concatenates ``bitstores`` in order."""
        ba = bitarray.bitarray()
        for b in bitstores:
            ba += b.tobitarray()
        return cls._wrap(ba)

    @staticmethod
    def using_rust_core() -> bool:
        return False

    def tobitarray(self) -> bitarray.bitarray:
        """Return exactly ``len(self)`` bits as a bitarray.

        When ``modified_length`` is set, this is a fresh copy of the live bits.
        Otherwise it is the underlying bitarray itself, not a copy.
        """
        if self.modified_length is not None:
            return self._bitarray[:self.modified_length]
        return self._bitarray

    def to_bytes(self) -> bytes:
        return self.tobitarray().tobytes()

    def to_u(self) -> int:
        return bitarray.util.ba2int(self.tobitarray(), signed=False)

    def to_i(self) -> int:
        return bitarray.util.ba2int(self.tobitarray(), signed=True)

    def to_hex(self) -> str:
        return bitarray.util.ba2hex(self.tobitarray())

    def to_bin(self) -> str:
        return self.tobitarray().to01()

    def to_oct(self) -> str:
        return bitarray.util.ba2base(8, self.tobitarray())

    def __imul__(self, n: int, /) -> _BitStore:
        self._bitarray *= n
        return self

    def __ilshift__(self, n: int, /) -> _BitStore:
        # In-place operators must return the object to rebind. Returning None would
        # make ``store <<= n`` replace the store with None.
        self._bitarray <<= n
        return self

    def __irshift__(self, n: int, /) -> _BitStore:
        self._bitarray >>= n
        return self

    def __iadd__(self, other: _BitStore, /) -> _BitStore:
        self._bitarray += other.tobitarray()
        return self

    def __add__(self, other: _BitStore, /) -> _BitStore:
        bs = self._mutable_copy()
        bs += other
        return bs

    def __eq__(self, other: Any, /) -> bool:
        """Compare the live bits of two stores; other types are not handled here."""
        if not isinstance(other, _BitStore):
            return NotImplemented
        # Checking lengths first avoids copying truncated stores that can't be equal.
        return len(self) == len(other) and self.tobitarray() == other.tobitarray()

    def __and__(self, other: _BitStore, /) -> _BitStore:
        return _BitStore(self.tobitarray() & other.tobitarray())

    def __or__(self, other: _BitStore, /) -> _BitStore:
        return _BitStore(self.tobitarray() | other.tobitarray())

    def __xor__(self, other: _BitStore, /) -> _BitStore:
        return _BitStore(self.tobitarray() ^ other.tobitarray())

    def __iand__(self, other: _BitStore, /) -> _BitStore:
        self._bitarray &= other.tobitarray()
        return self

    def __ior__(self, other: _BitStore, /) -> _BitStore:
        self._bitarray |= other.tobitarray()
        return self

    def __ixor__(self, other: _BitStore, /) -> _BitStore:
        self._bitarray ^= other.tobitarray()
        return self

    def __invert__(self) -> _BitStore:
        return _BitStore(~self.tobitarray())

    def find(self, bs: _BitStore, start: int, end: int, bytealigned: bool = False) -> int | None:
        """Return the first position of ``bs`` in ``[start, end)``, or None if absent."""
        if not bytealigned:
            pos = self._bitarray.find(bs.tobitarray(), start, end)
            return None if pos == -1 else pos
        return next(self.findall_msb0(bs, start, end, bytealigned), None)

    def rfind(self, bs: _BitStore, start: int, end: int, bytealigned: bool = False) -> int | None:
        """Return the last position of ``bs`` in ``[start, end)``, or None if absent."""
        if not bytealigned:
            pos = self._bitarray.find(bs.tobitarray(), start, end, right=True)
            return None if pos == -1 else pos
        return next(self.rfindall_msb0(bs, start, end, bytealigned), None)

    def findall_msb0(self, bs: _BitStore, start: int, end: int, bytealigned: bool = False) -> Iterator[int]:
        """Yield every position of ``bs`` in ``[start, end)`` in increasing order."""
        if bytealigned and len(bs) % 8 == 0:
            # Special case, looking for whole bytes on whole byte boundaries
            bytes_ = bs.to_bytes()
            # Round up start byte to next byte, and round end byte down.
            # We're only looking for whole bytes, so can ignore bits at either end.
            start_byte = (start + 7) // 8
            end_byte = end // 8
            b = self._bitarray[start_byte * 8: end_byte * 8].tobytes()
            byte_pos = 0
            bytes_to_search = end_byte - start_byte
            while byte_pos < bytes_to_search:
                byte_pos = b.find(bytes_, byte_pos)
                if byte_pos == -1:
                    break
                yield (byte_pos + start_byte) * 8
                byte_pos += 1
            return
        # General case
        matches = self._bitarray.search(bs.tobitarray(), start, end)
        if bytealigned:
            yield from (p for p in matches if p % 8 == 0)
        else:
            yield from matches

    def rfindall_msb0(self, bs: _BitStore, start: int, end: int, bytealigned: bool = False) -> Iterator[int]:
        """Yield every position of ``bs`` in ``[start, end)`` in decreasing order."""
        matches = self._bitarray.search(bs.tobitarray(), start, end, right=True)
        if bytealigned:
            yield from (p for p in matches if p % 8 == 0)
        else:
            yield from matches

    def count(self, value, /) -> int:
        """Count occurrences of ``value`` among the live bits."""
        if self.modified_length is not None:
            return self._bitarray.count(value, 0, self.modified_length)
        return self._bitarray.count(value)

    def clear(self) -> None:
        self._bitarray.clear()

    def reverse(self) -> None:
        self._bitarray.reverse()

    def __iter__(self) -> Iterator[bool]:
        for i in range(len(self)):
            yield self.getindex(i)

    def _mutable_copy(self) -> _BitStore:
        """Always creates a copy, even if instance is immutable."""
        # The constructor copies, so sharing tobitarray()'s result is safe.
        return _BitStore(self.tobitarray(), immutable=False)

    def as_immutable(self) -> _BitStore:
        """Return an immutable copy holding exactly ``len(self)`` bits."""
        return _BitStore(self.tobitarray(), immutable=True)

    def copy(self) -> _BitStore:
        return self if self.immutable else self._mutable_copy()

    def __getitem__(self, item: Union[int, slice], /) -> Union[int, _BitStore]:
        # Use getindex or getslice instead
        raise NotImplementedError

    def _normalise_index(self, index: int) -> int:
        """Map a possibly negative index onto ``range(len(self))`` for truncated stores.

        This only matters when ``modified_length`` is set. In that case the underlying
        bitarray is longer than the store: it would resolve negative indices against
        the wrong end and accept indices past ``len(self)``.

        :raises IndexError: if ``index`` is out of range.
        """
        if self.modified_length is None:
            return index
        length = self.modified_length
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError("bitarray index out of range")
        return index

    def getindex_msb0(self, index: int, /) -> bool:
        return bool(self._bitarray[self._normalise_index(index)])

    def getslice_withstep_msb0(self, key: slice, /) -> _BitStore:
        if self.modified_length is not None:
            start, stop, step = key.indices(self.modified_length)
            # With a negative step, indices() can return stop == -1 ("run past bit 0").
            # Put back into a slice, -1 would instead mean "last bit of the whole
            # buffer", so translate it to None.
            key = slice(start, None if stop < 0 else stop, step)
        return _BitStore(self._bitarray[key])

    def getslice_withstep_lsb0(self, key: slice, /) -> _BitStore:
        key = offset_slice_indices_lsb0(key, len(self))
        return _BitStore(self._bitarray[key])

    def getslice_msb0(self, start: Optional[int], stop: Optional[int], /) -> _BitStore:
        if self.modified_length is not None:
            # Clamp the bounds to the live bits so the buffer's tail is never included.
            key = slice(*slice(start, stop, None).indices(self.modified_length))
            start = key.start
            stop = key.stop
        return _BitStore(self._bitarray[start:stop])

    def getslice_lsb0(self, start: Optional[int], stop: Optional[int], /) -> _BitStore:
        s = offset_slice_indices_lsb0(slice(start, stop, None), len(self))
        return _BitStore(self._bitarray[s.start:s.stop])

    def getindex_lsb0(self, index: int, /) -> bool:
        if self.modified_length is None:
            # LSB0 index i is MSB0 position len-1-i; negative Python indexing does this.
            return bool(self._bitarray[-index - 1])
        # Truncated store: count from the end of the live bits, not the buffer.
        return bool(self._bitarray[self.modified_length - 1 - self._normalise_index(index)])

    @overload
    def setitem_lsb0(self, key: int, value: int, /) -> None:
        ...

    @overload
    def setitem_lsb0(self, key: slice, value: _BitStore, /) -> None:
        ...

    def setitem_lsb0(self, key: Union[int, slice], value: Union[int, _BitStore], /) -> None:
        if isinstance(key, slice):
            new_slice = offset_slice_indices_lsb0(key, len(self))
            self._bitarray.__setitem__(new_slice, value._bitarray)
        else:
            self._bitarray.__setitem__(-key - 1, value)

    def delitem_lsb0(self, key: Union[int, slice], /) -> None:
        if isinstance(key, slice):
            new_slice = offset_slice_indices_lsb0(key, len(self))
            self._bitarray.__delitem__(new_slice)
        else:
            self._bitarray.__delitem__(-key - 1)

    def invert_msb0(self, index: Optional[int] = None, /) -> None:
        if index is not None:
            self._bitarray.invert(index)
        else:
            self._bitarray.invert()

    def invert_lsb0(self, index: Optional[int] = None, /) -> None:
        if index is not None:
            self._bitarray.invert(-index - 1)
        else:
            self._bitarray.invert()

    def extend_left(self, other: _BitStore, /) -> None:
        """Prepend the live bits of ``other``."""
        self._bitarray = other.tobitarray() + self._bitarray

    def any(self) -> bool:
        """Return True if any live bit is set."""
        if self.modified_length is not None:
            return self._bitarray.count(1, 0, self.modified_length) > 0
        return self._bitarray.any()

    def all(self) -> bool:
        """Return True if every live bit is set (True for an empty store)."""
        if self.modified_length is not None:
            return self._bitarray.count(0, 0, self.modified_length) == 0
        return self._bitarray.all()

    def __len__(self) -> int:
        return self.modified_length if self.modified_length is not None else len(self._bitarray)

    def setitem_msb0(self, key, value, /):
        if isinstance(value, _BitStore):
            self._bitarray.__setitem__(key, value._bitarray)
        else:
            self._bitarray.__setitem__(key, value)

    def delitem_msb0(self, key, /):
        self._bitarray.__delitem__(key)
326
# The bitarray backend uses one class for both the const and the mutable store
# types; whether a given store may be modified is tracked by its ``immutable``
# attribute rather than by its type.
ConstBitStore = MutableBitStore = _BitStore