1# dialects/postgresql/bitstring.py
2# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7from __future__ import annotations
8
9import math
10from typing import Any
11from typing import cast
12from typing import Literal
13from typing import SupportsIndex
14
15
class BitString(str):
    """Represent a PostgreSQL bit string in python.

    This object is used by the :class:`_postgresql.BIT` type when returning
    values. :class:`_postgresql.BitString` values may also be constructed
    directly and used with :class:`_postgresql.BIT` columns::

        from sqlalchemy.dialects.postgresql import BitString

        with engine.connect() as conn:
            conn.execute(table.insert(), {"data": BitString("011001101")})

    .. versionadded:: 2.1

    """

    _DIGITS = frozenset("01")

    def __new__(cls, _value: str, _check: bool = True) -> BitString:
        """Construct a ``BitString`` from a string of ``'0'`` / ``'1'``
        characters.

        An existing ``BitString`` is returned unchanged.  When ``_check`` is
        true (the default) a ``ValueError`` is raised if ``_value`` contains
        any character other than ``'0'`` or ``'1'``; internal callers pass
        ``False`` when the value is already known to be valid.
        """
        if isinstance(_value, BitString):
            return _value
        elif _check and not cls._DIGITS.issuperset(_value):
            raise ValueError("BitString must only contain '0' and '1' chars")
        else:
            return super().__new__(cls, _value)

    @classmethod
    def from_int(cls, value: int, length: int) -> BitString:
        """Returns a BitString consisting of the bits in the integer ``value``.
        A ``ValueError`` is raised if ``value`` is not a non-negative integer.

        If the provided ``value`` can not be represented in a bit string
        of at most ``length``, a ``ValueError`` will be raised. The bitstring
        will be padded on the left with ``'0'`` bits to produce a
        bitstring of the desired length.
        """
        if value < 0:
            raise ValueError("value must be non-negative")
        if length < 0:
            raise ValueError("length must be non-negative")

        # a zero-width format spec would still render at least one digit,
        # so the empty bitstring is special-cased
        r = format(value, f"0{length}b") if length > 0 else ""

        if (length == 0 and value > 0) or len(r) > length:
            raise ValueError(
                f"Cannot encode {value} as a BitString of length {length}"
            )

        return cls(r)

    @classmethod
    def from_bytes(cls, value: bytes, length: int = -1) -> BitString:
        """Returns a ``BitString`` consisting of the bits in the given
        ``value`` bytes.

        If ``length`` is provided, then the length of the returned string
        will be exactly ``length``, with ``'0'`` bits inserted at the left of
        the string in order to produce a value of the required length.
        If the bits obtained by omitting the leading ``'0'`` bits of ``value``
        cannot be represented in a string of this length a ``ValueError``
        will be raised.

        If ``length`` is omitted (or negative), every bit of every byte is
        kept, so the result is ``8 * len(value)`` bits long.
        """
        str_v: str = "".join(f"{int(c):08b}" for c in value)
        if length >= 0:
            str_v = str_v.lstrip("0")

            if len(str_v) > length:
                raise ValueError(
                    f"Cannot encode {value!r} as a BitString of "
                    f"length {length}"
                )
            str_v = str_v.zfill(length)

        return cls(str_v)

    def get_bit(self, index: int) -> Literal["0", "1"]:
        """Returns the value of the bit at the given
        index::

            BitString("0101").get_bit(3) == "1"

        Indexing follows ``str`` rules; an ``IndexError`` is raised if
        ``index`` is out of range.
        """
        # str.__getitem__ is used directly so a plain one-character str is
        # returned rather than a BitString
        return cast(Literal["0", "1"], super().__getitem__(index))

    @property
    def bit_length(self) -> int:
        """The number of bits in the bitstring."""
        return len(self)

    @property
    def octet_length(self) -> int:
        """The number of whole bytes needed to hold all bits of the
        bitstring."""
        return math.ceil(len(self) / 8)

    def has_bit(self, index: int) -> bool:
        """Returns ``True`` if the bit at the given index is ``'1'``."""
        return self.get_bit(index) == "1"

    def set_bit(
        self, index: int, value: bool | int | Literal["0", "1"]
    ) -> BitString:
        """Set the bit at index to the given value, returning the resulting
        ``BitString``.

        If value is an int, then it is considered to be '1' iff nonzero.

        An ``IndexError`` is raised if ``index`` is negative or not less than
        the length of the bitstring.  A ``ValueError`` is raised if ``value``
        is a string other than ``'0'`` or ``'1'``.
        """
        if index < 0 or index >= len(self):
            raise IndexError("BitString index out of range")

        if isinstance(value, (bool, int)):
            value = "1" if value else "0"
        elif value not in self._DIGITS:
            # the result below is built without re-validation, so an
            # arbitrary string must be rejected here
            raise ValueError("BitString must only contain '0' and '1' chars")

        if self.get_bit(index) == value:
            return self

        return BitString(
            "".join([self[:index], value, self[index + 1 :]]), False
        )

    def lstrip(self, char: str | None = None) -> BitString:
        """Returns a copy of the BitString with leading characters removed.

        If omitted or None, ``char`` defaults to ``"0"``::

            BitString("00010101000").lstrip() == BitString("10101000")
            BitString("11110101111").lstrip("1") == BitString("0101111")
        """
        if char is None:
            char = "0"
        return BitString(super().lstrip(char), False)

    def rstrip(self, char: str | None = "0") -> BitString:
        """Returns a copy of the BitString with trailing characters removed.

        If omitted or None, ``char`` defaults to ``"0"``::

            BitString("00010101000").rstrip() == BitString("00010101")
            BitString("11110101111").rstrip("1") == BitString("1111010")
        """
        if char is None:
            char = "0"
        return BitString(super().rstrip(char), False)

    def strip(self, char: str | None = "0") -> BitString:
        """Returns a copy of the BitString with both leading and trailing
        characters removed.

        If omitted or None, ``char`` defaults to ``"0"``::

            BitString("00010101000").strip() == BitString("10101")
            BitString("11110101111").strip("1") == BitString("010")
        """
        if char is None:
            char = "0"
        # stripping can only remove characters, so no re-validation is
        # needed (same as lstrip / rstrip)
        return BitString(super().strip(char), False)

    def removeprefix(self, prefix: str, /) -> BitString:
        """Returns a copy of the BitString without the given ``prefix``,
        if it starts with it."""
        return BitString(super().removeprefix(prefix), False)

    def removesuffix(self, suffix: str, /) -> BitString:
        """Returns a copy of the BitString without the given ``suffix``,
        if it ends with it."""
        return BitString(super().removesuffix(suffix), False)

    def replace(
        self,
        old: str,
        new: str,
        count: SupportsIndex = -1,
    ) -> BitString:
        """Returns a copy of the BitString with occurrences of ``old``
        replaced by ``new``.

        A ``ValueError`` is raised if ``new`` is not itself a valid bit
        string.
        """
        # validating the replacement is enough to keep the result valid
        new = BitString(new)
        return BitString(super().replace(old, new, count), False)

    def split(
        self,
        sep: str | None = None,
        maxsplit: SupportsIndex = -1,
    ) -> list[str]:
        """Splits the bitstring as ``str.split()`` does; every element of
        the returned list is a ``BitString``."""
        return [BitString(word) for word in super().split(sep, maxsplit)]

    def zfill(self, width: SupportsIndex) -> BitString:
        """Returns a copy of the BitString padded on the left with ``'0'``
        bits up to ``width`` characters."""
        return BitString(super().zfill(width), False)

    def __repr__(self) -> str:
        return f'BitString("{self.__str__()}")'

    def __int__(self) -> int:
        """Interpret the bits as an unsigned big-endian integer; the empty
        bitstring is ``0``."""
        return int(self, 2) if self else 0

    def to_bytes(self, length: int = -1) -> bytes:
        """Returns the integer value of the bitstring as big-endian bytes.

        If ``length`` is omitted (or negative), :attr:`octet_length` bytes
        are produced.  ``int.to_bytes()`` raises ``OverflowError`` if the
        value does not fit in ``length`` bytes.
        """
        return int(self).to_bytes(
            length if length >= 0 else self.octet_length, byteorder="big"
        )

    def __bytes__(self) -> bytes:
        return self.to_bytes()

    def __getitem__(
        self, key: SupportsIndex | slice[Any, Any, Any]
    ) -> BitString:
        """Index or slice the bitstring, returning a ``BitString``."""
        return BitString(super().__getitem__(key), False)

    def __add__(self, o: str) -> BitString:
        """Return self + o

        A ``TypeError`` is raised if ``o`` is not a string, and a
        ``ValueError`` if the concatenation is not a valid bit string.
        """
        if not isinstance(o, str):
            raise TypeError(
                f"Can only concatenate str (not '{type(o).__name__}') "
                "to BitString"
            )
        return BitString("".join([self, o]))

    def __radd__(self, o: str) -> BitString:
        """Return o + self

        A ``TypeError`` is raised if ``o`` is not a string, and a
        ``ValueError`` if the concatenation is not a valid bit string.
        """
        if not isinstance(o, str):
            raise TypeError(
                f"Can only concatenate str (not '{type(o).__name__}') "
                "to BitString"
            )
        return BitString("".join([o, self]))

    def __lshift__(self, amount: int) -> BitString:
        """Shifts each bit in the bitstring to the left by the given amount.
        String length is preserved::

            BitString("000101") << 1 == BitString("001010")

        A ``ValueError`` is raised if ``amount`` is negative.
        """
        if amount < 0:
            raise ValueError("negative shift count")

        width = len(self)
        if amount >= width:
            # every bit is shifted out; this also covers the empty bitstring
            return BitString("0" * width, False)

        kept = super().__getitem__(slice(amount, width))
        return BitString(kept + "0" * amount, False)

    def __rshift__(self, amount: int) -> BitString:
        """Shifts each bit in the bitstring to the right by the given amount.
        String length is preserved::

            BitString("101") >> 1 == BitString("010")

        A ``ValueError`` is raised if ``amount`` is negative.
        """
        if amount < 0:
            raise ValueError("negative shift count")

        width = len(self)
        if amount >= width:
            # every bit is shifted out; this also covers the empty bitstring
            return BitString("0" * width, False)

        # an explicit end index is used as ``[:-amount]`` would yield an
        # empty string for a shift of zero
        kept = super().__getitem__(slice(0, width - amount))
        return BitString("0" * amount + kept, False)

    def __invert__(self) -> BitString:
        """Inverts (~) each bit in the
        bitstring::

            ~BitString("01010") == BitString("10101")
        """
        return BitString(
            "".join("1" if x == "0" else "0" for x in self), False
        )

    def _coerce_operand(self, o: str) -> BitString:
        """Return ``o`` as a ``BitString``, raising ``ValueError`` if it is
        not a valid bit string of the same length as ``self``."""
        other = BitString(o)
        if len(self) != len(other):
            raise ValueError("Operands must be the same length")
        return other

    def __and__(self, o: str) -> BitString:
        """Performs a bitwise and (``&``) with the given operand.
        A ``ValueError`` is raised if the operand is not the same length.

        e.g.::

            BitString("011") & BitString("010") == BitString("010")
        """
        if not isinstance(o, str):
            return NotImplemented

        other = self._coerce_operand(o)
        return BitString(
            "".join(
                "1" if (x == "1" and y == "1") else "0"
                for x, y in zip(self, other)
            ),
            False,
        )

    def __or__(self, o: str) -> BitString:
        """Performs a bitwise or (``|``) with the given operand.
        A ``ValueError`` is raised if the operand is not the same length.

        e.g.::

            BitString("011") | BitString("010") == BitString("011")
        """
        if not isinstance(o, str):
            return NotImplemented

        other = self._coerce_operand(o)
        return BitString(
            "".join(
                "1" if (x == "1" or y == "1") else "0"
                for x, y in zip(self, other)
            ),
            False,
        )

    def __xor__(self, o: str) -> BitString:
        """Performs a bitwise xor (``^``) with the given operand.
        A ``ValueError`` is raised if the operand is not the same length.

        e.g.::

            BitString("011") ^ BitString("010") == BitString("001")
        """
        # plain strings are accepted, as with ``&`` and ``|``
        if not isinstance(o, str):
            return NotImplemented

        other = self._coerce_operand(o)
        return BitString(
            "".join("1" if x != y else "0" for x, y in zip(self, other)),
            False,
        )

    # and / or / xor are commutative, so the reflected forms are the same
    __rand__ = __and__
    __ror__ = __or__
    __rxor__ = __xor__