1"""Helpers for WebSocket protocol versions 13 and 8."""
2
3import functools
4import re
5from struct import Struct
6from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
7
8from ..helpers import NO_EXTENSIONS
9from .models import WSHandshakeError
10
# Precompiled ``struct`` codecs; "!" selects network (big-endian) byte order.
# Two unsigned bytes, optionally followed by a 16- or 64-bit unsigned integer.
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
# Reads one 64-bit unsigned integer at a given offset inside a buffer.
UNPACK_LEN3 = Struct("!Q").unpack_from
# 16-bit unsigned close code, encode and decode.
PACK_CLOSE_CODE = Struct("!H").pack
UNPACK_CLOSE_CODE = Struct("!H").unpack
# Packs a single 32-bit unsigned integer.
PACK_RANDBITS = Struct("!L").pack

MSG_SIZE: Final[int] = 1 << 14  # 16384
MASK_LEN: Final[int] = 4

# The RFC 6455 handshake GUID (the value is the one the RFC specifies).
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
22
23
# Lookup tables consumed by _websocket_mask_python.
@functools.lru_cache
def _xor_table() -> List[bytes]:
    """Return 256 translate tables; table ``k`` maps byte ``x`` to ``x ^ k``.

    The tables are built on the first call and cached for later calls.
    """
    tables: List[bytes] = []
    for key in range(256):
        tables.append(bytes([key ^ value for value in range(256)]))
    return tables
28
29
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """Apply the RFC 6455 (section 5.3) masking transform to *data* in place.

    *mask* must be a 4-byte ``bytes`` key and *data* a ``bytearray`` of any
    length. Byte ``i`` of *data* is XORed with ``mask[i % 4]``. Nothing is
    returned; the caller's buffer is modified directly.

    This pure-Python implementation may be replaced by a faster one when
    available.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    rows = _xor_table()
    first, second, third, fourth = (rows[octet] for octet in mask)
    # Every 4th byte (one "lane") shares the same key byte, so a single
    # bytes.translate call per lane masks the whole buffer.
    for lane, row in enumerate((first, second, third, fourth)):
        data[lane::4] = data[lane::4].translate(row)
53
54
# Choose the implementation exported as ``websocket_mask``.
# Type checkers always see the pure-Python function, as do runs where
# NO_EXTENSIONS is truthy. Otherwise the ``.mask`` extension module
# (presumably Cython-built, judging by the name) is preferred, and the
# pure-Python function is used when that import fails.
if TYPE_CHECKING or NO_EXTENSIONS:  # pragma: no cover
    websocket_mask = _websocket_mask_python
else:
    try:
        from .mask import _websocket_mask_cython  # type: ignore[import-not-found]

        websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        # The accelerated module could not be imported; fall back.
        websocket_mask = _websocket_mask_python
64
65
# Parameter list of one ``permessage-deflate`` entry, i.e. the text that
# follows the extension token up to the next ``,``.
# Capture groups (referenced by number in ws_ext_parse):
#   1: server_no_context_takeover
#   2: client_no_context_takeover
#   3: server_max_window_bits[=N]    4: N
#   5: client_max_window_bits[=N]    6: N
# Optional whitespace is allowed around each ``;`` and at the end. Window
# sizes must be ASCII digits: a bare ``\d`` would also match other Unicode
# decimal digits, which ``int()`` silently accepts.
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^\s*(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=([0-9]+))?)|"
    r"(client_max_window_bits(?:=([0-9]+))?))\s*)*$"
)

# Finds each ``permessage-deflate`` occurrence. Group 1 is its parameter
# text up to the next ``,`` (None when there are no parameters).
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")


def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Parse a ``permessage-deflate`` extension header value.

    The ``permessage-deflate`` entries in *extstr* are examined in order.

    :param extstr: The extension header value, or ``None``/empty if absent.
    :param isserver: ``True`` on the server side (parsing the client's
        offers); ``False`` on the client side (parsing the server's reply).
    :returns: ``(compress, notakeover)``. ``compress`` is the window bits
        to use (9..15), or 0 when no usable entry was found. ``notakeover``
        is the ``server_no_context_takeover`` flag on the server side and
        the ``client_no_context_takeover`` flag on the client side.
    :raises WSHandshakeError: Client side only, when the reply's parameters
        cannot be parsed or its window size is outside 9..15.
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False
    for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
        defext = ext.group(1)
        # A bare ``permessage-deflate`` selects compress = 15.
        if not defext:
            compress = 15
            break
        match = _WS_EXT_RE.match(defext)
        if match:
            compress = 15
            if isserver:
                # The server never fails the handshake over compression.
                # It does not need to send the max window bits to the client.
                if match.group(4):
                    # Group 3 must match if group 4 matches.
                    compress = int(match.group(4))
                    # zlib does not support wbits=8. If the size is not
                    # supported, CONTINUE with the next offer.
                    if not 9 <= compress <= 15:
                        compress = 0
                        continue
                if match.group(1):
                    notakeover = True
                # Regex groups 5 & 6 (client_max_window_bits) are ignored.
                break
            else:
                if match.group(6):
                    # Group 5 must match if group 6 matches.
                    compress = int(match.group(6))
                    # zlib does not support wbits=8. If the size is not
                    # supported, FAIL the parse.
                    if not 9 <= compress <= 15:
                        raise WSHandshakeError("Invalid window size")
                if match.group(2):
                    notakeover = True
                # Regex groups 3 & 4 (server_max_window_bits) are ignored.
                break
        # On the client side, an unparsable reply is a handshake failure.
        elif not isserver:
            raise WSHandshakeError(
                f"Extension for deflate not supported: {defext!r}"
            )

    return compress, notakeover
126
127
def ws_ext_gen(
    compress: int = 15,
    isserver: bool = False,
    server_notakeover: bool = False,
    *,
    client_notakeover: bool = False,
) -> str:
    """Build a ``permessage-deflate`` extension header value.

    :param compress: Window bits, 9..15. A value below 15 is emitted as
        ``server_max_window_bits=<compress>``.
    :param isserver: ``False`` (client side) also emits a bare
        ``client_max_window_bits`` parameter; ``True`` omits it.
    :param server_notakeover: Emit ``server_no_context_takeover``.
    :param client_notakeover: Emit ``client_no_context_takeover``
        (keyword-only; off by default).
    :returns: The parameters joined with ``"; "``.
    :raises ValueError: If *compress* is outside 9..15. zlib does not
        support ``wbits=8``.
    """
    if not 9 <= compress <= 15:
        raise ValueError(
            "Compress wbits must be between 9 and 15, zlib does not support wbits=8"
        )
    enabledext = ["permessage-deflate"]
    if not isserver:
        enabledext.append("client_max_window_bits")
    if compress < 15:
        enabledext.append("server_max_window_bits=" + str(compress))
    if server_notakeover:
        enabledext.append("server_no_context_takeover")
    if client_notakeover:
        enabledext.append("client_no_context_takeover")
    return "; ".join(enabledext)