1# dialects/postgresql/bitstring.py 
    2# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7from __future__ import annotations 
    8 
    9import math 
    10from typing import Any 
    11from typing import cast 
    12from typing import Literal 
    13from typing import SupportsIndex 
    14 
    15 
    16class BitString(str): 
    17    """Represent a PostgreSQL bit string in python. 
    18 
    19    This object is used by the :class:`_postgresql.BIT` type when returning 
    20    values.   :class:`_postgresql.BitString` values may also be constructed 
    21    directly and used with :class:`_postgresql.BIT` columns:: 
    22 
    23        from sqlalchemy.dialects.postgresql import BitString 
    24 
    25        with engine.connect() as conn: 
    26            conn.execute(table.insert(), {"data": BitString("011001101")}) 
    27 
    28    .. versionadded:: 2.1 
    29 
    30    """ 
    31 
    32    _DIGITS = frozenset("01") 
    33 
    34    def __new__(cls, _value: str, _check: bool = True) -> BitString: 
    35        if isinstance(_value, BitString): 
    36            return _value 
    37        elif _check and cls._DIGITS.union(_value) > cls._DIGITS: 
    38            raise ValueError("BitString must only contain '0' and '1' chars") 
    39        else: 
    40            return super().__new__(cls, _value) 
    41 
    42    @classmethod 
    43    def from_int(cls, value: int, length: int) -> BitString: 
    44        """Returns a BitString consisting of the bits in the integer ``value``. 
    45        A ``ValueError`` is raised if ``value`` is not a non-negative integer. 
    46 
    47        If the provided ``value`` can not be represented in a bit string 
    48        of at most ``length``, a ``ValueError`` will be raised. The bitstring 
    49        will be padded on the left by ``'0'`` to bits to produce a 
    50        bitstring of the desired length. 
    51        """ 
    52        if value < 0: 
    53            raise ValueError("value must be non-negative") 
    54        if length < 0: 
    55            raise ValueError("length must be non-negative") 
    56 
    57        template_str = f"{{0:0{length}b}}" if length > 0 else "" 
    58        r = template_str.format(value) 
    59 
    60        if (length == 0 and value > 0) or len(r) > length: 
    61            raise ValueError( 
    62                f"Cannot encode {value} as a BitString of length {length}" 
    63            ) 
    64 
    65        return cls(r) 
    66 
    67    @classmethod 
    68    def from_bytes(cls, value: bytes, length: int = -1) -> BitString: 
    69        """Returns a ``BitString`` consisting of the bits in the given 
    70        ``value`` bytes. 
    71 
    72        If ``length`` is provided, then the length of the provided string 
    73        will be exactly ``length``, with ``'0'`` bits inserted at the left of 
    74        the string in order to produce a value of the required length. 
    75        If the bits obtained by omitting the leading ``'0'`` bits of ``value`` 
    76        cannot be represented in a string of this length a ``ValueError`` 
    77        will be raised. 
    78        """ 
    79        str_v: str = "".join(f"{int(c):08b}" for c in value) 
    80        if length >= 0: 
    81            str_v = str_v.lstrip("0") 
    82 
    83            if len(str_v) > length: 
    84                raise ValueError( 
    85                    f"Cannot encode {value!r} as a BitString of " 
    86                    f"length {length}" 
    87                ) 
    88            str_v = str_v.zfill(length) 
    89 
    90        return cls(str_v) 
    91 
    92    def get_bit(self, index: int) -> Literal["0", "1"]: 
    93        """Returns the value of the flag at the given 
    94        index:: 
    95 
    96            BitString("0101").get_flag(4) == "1" 
    97        """ 
    98        return cast(Literal["0", "1"], super().__getitem__(index)) 
    99 
    100    @property 
    101    def bit_length(self) -> int: 
    102        return len(self) 
    103 
    104    @property 
    105    def octet_length(self) -> int: 
    106        return math.ceil(len(self) / 8) 
    107 
    108    def has_bit(self, index: int) -> bool: 
    109        return self.get_bit(index) == "1" 
    110 
    111    def set_bit( 
    112        self, index: int, value: bool | int | Literal["0", "1"] 
    113    ) -> BitString: 
    114        """Set the bit at index to the given value. 
    115 
    116        If value is an int, then it is considered to be '1' iff nonzero. 
    117        """ 
    118        if index < 0 or index >= len(self): 
    119            raise IndexError("BitString index out of range") 
    120 
    121        if isinstance(value, (bool, int)): 
    122            value = "1" if value else "0" 
    123 
    124        if self.get_bit(index) == value: 
    125            return self 
    126 
    127        return BitString( 
    128            "".join([self[:index], value, self[index + 1 :]]), False 
    129        ) 
    130 
    131    def lstrip(self, char: str | None = None) -> BitString: 
    132        """Returns a copy of the BitString with leading characters removed. 
    133 
    134        If omitted or None, 'chars' defaults '0':: 
    135 
    136            BitString("00010101000").lstrip() == BitString("00010101") 
    137            BitString("11110101111").lstrip("1") == BitString("1111010") 
    138        """ 
    139        if char is None: 
    140            char = "0" 
    141        return BitString(super().lstrip(char), False) 
    142 
    143    def rstrip(self, char: str | None = "0") -> BitString: 
    144        """Returns a copy of the BitString with trailing characters removed. 
    145 
    146        If omitted or None, ``'char'`` defaults to "0":: 
    147 
    148            BitString("00010101000").rstrip() == BitString("10101000") 
    149            BitString("11110101111").rstrip("1") == BitString("10101111") 
    150        """ 
    151        if char is None: 
    152            char = "0" 
    153        return BitString(super().rstrip(char), False) 
    154 
    155    def strip(self, char: str | None = "0") -> BitString: 
    156        """Returns a copy of the BitString with both leading and trailing 
    157        characters removed. 
    158        If omitted or None, ``'char'`` defaults to ``"0"``:: 
    159 
    160            BitString("00010101000").rstrip() == BitString("10101") 
    161            BitString("11110101111").rstrip("1") == BitString("1010") 
    162        """ 
    163        if char is None: 
    164            char = "0" 
    165        return BitString(super().strip(char)) 
    166 
    167    def removeprefix(self, prefix: str, /) -> BitString: 
    168        return BitString(super().removeprefix(prefix), False) 
    169 
    170    def removesuffix(self, suffix: str, /) -> BitString: 
    171        return BitString(super().removesuffix(suffix), False) 
    172 
    173    def replace( 
    174        self, 
    175        old: str, 
    176        new: str, 
    177        count: SupportsIndex = -1, 
    178    ) -> BitString: 
    179        new = BitString(new) 
    180        return BitString(super().replace(old, new, count), False) 
    181 
    182    def split( 
    183        self, 
    184        sep: str | None = None, 
    185        maxsplit: SupportsIndex = -1, 
    186    ) -> list[str]: 
    187        return [BitString(word) for word in super().split(sep, maxsplit)] 
    188 
    189    def zfill(self, width: SupportsIndex) -> BitString: 
    190        return BitString(super().zfill(width), False) 
    191 
    192    def __repr__(self) -> str: 
    193        return f'BitString("{self.__str__()}")' 
    194 
    195    def __int__(self) -> int: 
    196        return int(self, 2) if self else 0 
    197 
    198    def to_bytes(self, length: int = -1) -> bytes: 
    199        return int(self).to_bytes( 
    200            length if length >= 0 else self.octet_length, byteorder="big" 
    201        ) 
    202 
    203    def __bytes__(self) -> bytes: 
    204        return self.to_bytes() 
    205 
    206    def __getitem__( 
    207        self, key: SupportsIndex | slice[Any, Any, Any] 
    208    ) -> BitString: 
    209        return BitString(super().__getitem__(key), False) 
    210 
    211    def __add__(self, o: str) -> BitString: 
    212        """Return self + o""" 
    213        if not isinstance(o, str): 
    214            raise TypeError( 
    215                f"Can only concatenate str (not '{type(self)}') to BitString" 
    216            ) 
    217        return BitString("".join([self, o])) 
    218 
    219    def __radd__(self, o: str) -> BitString: 
    220        if not isinstance(o, str): 
    221            raise TypeError( 
    222                f"Can only concatenate str (not '{type(self)}') to BitString" 
    223            ) 
    224        return BitString("".join([o, self])) 
    225 
    226    def __lshift__(self, amount: int) -> BitString: 
    227        """Shifts each the bitstring to the left by the given amount. 
    228        String length is preserved:: 
    229 
    230            BitString("000101") << 1 == BitString("001010") 
    231        """ 
    232        return BitString( 
    233            "".join([self, *("0" for _ in range(amount))])[-len(self) :], False 
    234        ) 
    235 
    236    def __rshift__(self, amount: int) -> BitString: 
    237        """Shifts each bit in the bitstring to the right by the given amount. 
    238        String length is preserved:: 
    239 
    240            BitString("101") >> 1 == BitString("010") 
    241        """ 
    242        return BitString(self[:-amount], False).zfill(width=len(self)) 
    243 
    244    def __invert__(self) -> BitString: 
    245        """Inverts (~) each bit in the 
    246        bitstring:: 
    247 
    248            ~BitString("01010") == BitString("10101") 
    249        """ 
    250        return BitString("".join("1" if x == "0" else "0" for x in self)) 
    251 
    252    def __and__(self, o: str) -> BitString: 
    253        """Performs a bitwise and (``&``) with the given operand. 
    254        A ``ValueError`` is raised if the operand is not the same length. 
    255 
    256        e.g.:: 
    257 
    258            BitString("011") & BitString("011") == BitString("010") 
    259        """ 
    260 
    261        if not isinstance(o, str): 
    262            return NotImplemented 
    263        o = BitString(o) 
    264        if len(self) != len(o): 
    265            raise ValueError("Operands must be the same length") 
    266 
    267        return BitString( 
    268            "".join( 
    269                "1" if (x == "1" and y == "1") else "0" 
    270                for x, y in zip(self, o) 
    271            ), 
    272            False, 
    273        ) 
    274 
    275    def __or__(self, o: str) -> BitString: 
    276        """Performs a bitwise or (``|``) with the given operand. 
    277        A ``ValueError`` is raised if the operand is not the same length. 
    278 
    279        e.g.:: 
    280 
    281            BitString("011") | BitString("010") == BitString("011") 
    282        """ 
    283        if not isinstance(o, str): 
    284            return NotImplemented 
    285 
    286        if len(self) != len(o): 
    287            raise ValueError("Operands must be the same length") 
    288 
    289        o = BitString(o) 
    290        return BitString( 
    291            "".join( 
    292                "1" if (x == "1" or y == "1") else "0" 
    293                for (x, y) in zip(self, o) 
    294            ), 
    295            False, 
    296        ) 
    297 
    298    def __xor__(self, o: str) -> BitString: 
    299        """Performs a bitwise xor (``^``) with the given operand. 
    300        A ``ValueError`` is raised if the operand is not the same length. 
    301 
    302        e.g.:: 
    303 
    304            BitString("011") ^ BitString("010") == BitString("001") 
    305        """ 
    306 
    307        if not isinstance(o, BitString): 
    308            return NotImplemented 
    309 
    310        if len(self) != len(o): 
    311            raise ValueError("Operands must be the same length") 
    312 
    313        return BitString( 
    314            "".join( 
    315                ( 
    316                    "1" 
    317                    if ((x == "1" and y == "0") or (x == "0" and y == "1")) 
    318                    else "0" 
    319                ) 
    320                for (x, y) in zip(self, o) 
    321            ), 
    322            False, 
    323        ) 
    324 
    325    __rand__ = __and__ 
    326    __ror__ = __or__ 
    327    __rxor__ = __xor__