1"""Helpers for WebSocket protocol versions 13 and 8.""" 
    2 
    3import functools 
    4import re 
    5from struct import Struct 
    6from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple 
    7 
    8from ..helpers import NO_EXTENSIONS 
    9from .models import WSHandshakeError 
    10 
    11UNPACK_LEN3 = Struct("!Q").unpack_from 
    12UNPACK_CLOSE_CODE = Struct("!H").unpack 
    13PACK_LEN1 = Struct("!BB").pack 
    14PACK_LEN2 = Struct("!BBH").pack 
    15PACK_LEN3 = Struct("!BBQ").pack 
    16PACK_CLOSE_CODE = Struct("!H").pack 
    17PACK_RANDBITS = Struct("!L").pack 
    18MSG_SIZE: Final[int] = 2**14 
    19MASK_LEN: Final[int] = 4 
    20 
    21WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 
    22 
    23 
    24# Used by _websocket_mask_python 
    25@functools.lru_cache 
    26def _xor_table() -> List[bytes]: 
    27    return [bytes(a ^ b for a in range(256)) for b in range(256)] 
    28 
    29 
    30def _websocket_mask_python(mask: bytes, data: bytearray) -> None: 
    31    """Websocket masking function. 
    32 
    33    `mask` is a `bytes` object of length 4; `data` is a `bytearray` 
    34    object of any length. The contents of `data` are masked with `mask`, 
    35    as specified in section 5.3 of RFC 6455. 
    36 
    37    Note that this function mutates the `data` argument. 
    38 
    39    This pure-python implementation may be replaced by an optimized 
    40    version when available. 
    41 
    42    """ 
    43    assert isinstance(data, bytearray), data 
    44    assert len(mask) == 4, mask 
    45 
    46    if data: 
    47        _XOR_TABLE = _xor_table() 
    48        a, b, c, d = (_XOR_TABLE[n] for n in mask) 
    49        data[::4] = data[::4].translate(a) 
    50        data[1::4] = data[1::4].translate(b) 
    51        data[2::4] = data[2::4].translate(c) 
    52        data[3::4] = data[3::4].translate(d) 
    53 
    54 
# Choose the masking implementation exported as ``websocket_mask``.
# Type checkers, and runs with C extensions disabled (NO_EXTENSIONS from
# ..helpers), get the pure-Python function; otherwise the compiled ``.mask``
# module is preferred when it can be imported.
if TYPE_CHECKING or NO_EXTENSIONS:  # pragma: no cover
    websocket_mask = _websocket_mask_python
else:
    try:
        from .mask import _websocket_mask_cython  # type: ignore[import-not-found]

        websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        # Compiled module not importable (e.g. not built for this install):
        # fall back to the pure-Python implementation.
        websocket_mask = _websocket_mask_python
    64 
    65 
    66_WS_EXT_RE: Final[Pattern[str]] = re.compile( 
    67    r"^(?:;\s*(?:" 
    68    r"(server_no_context_takeover)|" 
    69    r"(client_no_context_takeover)|" 
    70    r"(server_max_window_bits(?:=(\d+))?)|" 
    71    r"(client_max_window_bits(?:=(\d+))?)))*$" 
    72) 
    73 
    74_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") 
    75 
    76 
    77def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]: 
    78    if not extstr: 
    79        return 0, False 
    80 
    81    compress = 0 
    82    notakeover = False 
    83    for ext in _WS_EXT_RE_SPLIT.finditer(extstr): 
    84        defext = ext.group(1) 
    85        # Return compress = 15 when get `permessage-deflate` 
    86        if not defext: 
    87            compress = 15 
    88            break 
    89        match = _WS_EXT_RE.match(defext) 
    90        if match: 
    91            compress = 15 
    92            if isserver: 
    93                # Server never fail to detect compress handshake. 
    94                # Server does not need to send max wbit to client 
    95                if match.group(4): 
    96                    compress = int(match.group(4)) 
    97                    # Group3 must match if group4 matches 
    98                    # Compress wbit 8 does not support in zlib 
    99                    # If compress level not support, 
    100                    # CONTINUE to next extension 
    101                    if compress > 15 or compress < 9: 
    102                        compress = 0 
    103                        continue 
    104                if match.group(1): 
    105                    notakeover = True 
    106                # Ignore regex group 5 & 6 for client_max_window_bits 
    107                break 
    108            else: 
    109                if match.group(6): 
    110                    compress = int(match.group(6)) 
    111                    # Group5 must match if group6 matches 
    112                    # Compress wbit 8 does not support in zlib 
    113                    # If compress level not support, 
    114                    # FAIL the parse progress 
    115                    if compress > 15 or compress < 9: 
    116                        raise WSHandshakeError("Invalid window size") 
    117                if match.group(2): 
    118                    notakeover = True 
    119                # Ignore regex group 5 & 6 for client_max_window_bits 
    120                break 
    121        # Return Fail if client side and not match 
    122        elif not isserver: 
    123            raise WSHandshakeError("Extension for deflate not supported" + ext.group(1)) 
    124 
    125    return compress, notakeover 
    126 
    127 
    128def ws_ext_gen( 
    129    compress: int = 15, isserver: bool = False, server_notakeover: bool = False 
    130) -> str: 
    131    # client_notakeover=False not used for server 
    132    # compress wbit 8 does not support in zlib 
    133    if compress < 9 or compress > 15: 
    134        raise ValueError( 
    135            "Compress wbits must between 9 and 15, zlib does not support wbits=8" 
    136        ) 
    137    enabledext = ["permessage-deflate"] 
    138    if not isserver: 
    139        enabledext.append("client_max_window_bits") 
    140 
    141    if compress < 15: 
    142        enabledext.append("server_max_window_bits=" + str(compress)) 
    143    if server_notakeover: 
    144        enabledext.append("server_no_context_takeover") 
    145    # if client_notakeover: 
    146    #     enabledext.append('client_no_context_takeover') 
    147    return "; ".join(enabledext)