1"""WebSocket protocol versions 13 and 8."""
2
3import asyncio
4import functools
5import json
6import random
7import re
8import sys
9import zlib
10from enum import IntEnum
11from struct import Struct
12from typing import (
13 Any,
14 Callable,
15 Final,
16 List,
17 NamedTuple,
18 Optional,
19 Pattern,
20 Set,
21 Tuple,
22 Union,
23 cast,
24)
25
26from .base_protocol import BaseProtocol
27from .compression_utils import ZLibCompressor, ZLibDecompressor
28from .helpers import NO_EXTENSIONS, set_exception
29from .streams import DataQueue
30
# Names exported by a star-import of this module.
__all__ = (
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    "WebSocketReader",
    "WebSocketWriter",
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
42
43
class WSCloseCode(IntEnum):
    """Status codes carried in the first two bytes of a close frame payload.

    The reader rejects a received close code that is below 3000 and is not
    one of these members.
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
58
59
# Plain-int values of every WSCloseCode member; the reader checks received
# close codes against this set.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng chooses to use as the threshold to release the GIL.

WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
70
71
class WSMsgType(IntEnum):
    """Message types: frame opcodes plus aiohttp-specific markers.

    The lowercase names at the end are aliases of the uppercase members
    (they are assigned the same values).
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # lowercase aliases of the members above
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
94
95
# NOTE(review): presumably the fixed GUID used by the WebSocket opening
# handshake; it is only exported here, not used in this file -- confirm
# against the handshake code.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


# Pre-compiled struct helpers; "!" selects network (big-endian) byte order.
# Extended payload lengths: unsigned 16-bit ("H") and unsigned 64-bit ("Q").
UNPACK_LEN2 = Struct("!H").unpack_from
UNPACK_LEN3 = Struct("!Q").unpack_from
# Close code: unsigned 16-bit value at the start of a close frame payload.
UNPACK_CLOSE_CODE = Struct("!H").unpack
# Frame headers: two flag bytes, optionally followed by a 16- or 64-bit length.
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
# Unmasked payloads longer than this are written to the transport separately
# from their header instead of being concatenated with it.
MSG_SIZE: Final[int] = 2**14
# Default for the writer's ``limit``: bytes written before it drains.
DEFAULT_LIMIT: Final[int] = 2**16
108
109
class WSMessage(NamedTuple):
    """A single message produced by the WebSocket reader."""

    # Kind of message; decides how ``data`` should be interpreted.
    type: WSMsgType
    # To type correctly, this would need some kind of tagged union for each type.
    data: Any
    # Close reason text for CLOSE messages; the reader passes "" for the other
    # messages it builds, and the module-level sentinels use None.
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Return parsed JSON data.

        ``loads`` is called with ``self.data`` and its result is returned
        unchanged; it defaults to ``json.loads``.

        .. versionadded:: 0.22
        """
        return loads(self.data)
122
123
# Pre-built messages for the aiohttp-specific CLOSED and CLOSING types;
# they carry no data and no extra text.
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
126
127
class WebSocketError(Exception):
    """Raised when incoming WebSocket data violates the protocol.

    ``code`` holds the close code describing the failure. The exception
    ``args`` are ``(code, message)``, and ``str()`` yields only the message.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        # args is (code, message); show just the human-readable part.
        text = self.args[1]
        return cast(str, text)
137
138
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error.

    Raised by ``ws_ext_parse`` when a client cannot accept the
    ``permessage-deflate`` parameters it was given.
    """
141
142
# Byte order of the running platform ("little" or "big"), from sys.byteorder.
native_byteorder: Final[str] = sys.byteorder
144
145
146# Used by _websocket_mask_python
147@functools.lru_cache
148def _xor_table() -> List[bytes]:
149 return [bytes(a ^ b for a in range(256)) for b in range(256)]
150
151
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """Apply the WebSocket XOR mask to ``data`` in place.

    ``mask`` must be a 4-byte ``bytes`` object and ``data`` a ``bytearray``
    of any length. Byte ``i`` of ``data`` is XORed with ``mask[i % 4]``, as
    specified in section 5.3 of RFC 6455.

    Nothing is returned; the ``data`` argument itself is mutated.

    This pure-python implementation may be replaced by an optimized
    version when available.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    tables = _xor_table()
    # Every 4th byte shares the same mask byte, so each strided slice can be
    # translated in a single pass through the matching XOR table.
    for offset, key in enumerate(mask):
        data[offset::4] = data[offset::4].translate(tables[key])
175
176
# Select the masking implementation: the pure-Python one when extensions are
# disabled or the compiled module cannot be imported, otherwise the
# compiled one.
if NO_EXTENSIONS:  # pragma: no cover
    _websocket_mask = _websocket_mask_python
else:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import-not-found]

        _websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python
186
# Four bytes the writer strips from the end of each compressed message and
# the reader appends again before decompressing.
_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])


# Matches the parameter list that follows ``permessage-deflate``.
# Groups: 1 server_no_context_takeover, 2 client_no_context_takeover,
# 3/4 server_max_window_bits and its value, 5/6 client_max_window_bits and
# its value.
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)

# Finds each ``permessage-deflate`` entry; group 1 is everything after it up
# to the next comma (None when nothing follows).
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
199
200
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Extract ``permessage-deflate`` settings from an extensions string.

    Returns ``(compress, notakeover)``. ``compress`` is the window size in
    bits, or ``0`` when no usable ``permessage-deflate`` entry was found.
    ``notakeover`` is true when the entry carried the no-context-takeover
    parameter for our own side (``server_...`` when ``isserver`` is true,
    ``client_...`` otherwise).

    A server skips entries it cannot use and tries the next one. A client
    raises ``WSHandshakeError`` for unrecognised parameters or for a window
    size outside 9..15 (zlib does not support a window size of 8).
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        if not params:
            # Bare ``permessage-deflate``: use the default window size.
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            if isserver:
                # The server never fails the handshake over an entry it
                # does not understand; it just looks at the next one.
                continue
            raise WSHandshakeError("Extension for deflate not supported" + params)

        # Only the parameters describing our own side matter; the groups for
        # the other side are ignored.
        if isserver:
            no_takeover_flag, window_bits = parsed.group(1, 4)
        else:
            no_takeover_flag, window_bits = parsed.group(2, 6)

        compress = int(window_bits) if window_bits else 15
        if compress > 15 or compress < 9:
            if not isserver:
                raise WSHandshakeError("Invalid window size")
            # Unsupported window size: continue to the next entry.
            compress = 0
            continue

        if no_takeover_flag:
            notakeover = True
        break

    return compress, notakeover
250
251
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build a ``permessage-deflate`` extension string.

    ``compress`` is the window size in bits and must be within 9..15,
    otherwise ``ValueError`` is raised (zlib does not support a window size
    of 8). A client (``isserver`` false) also advertises
    ``client_max_window_bits``. ``server_max_window_bits`` is added only for
    window sizes below 15, and ``server_no_context_takeover`` only when
    ``server_notakeover`` is set. No client-side no-takeover parameter is
    ever emitted.
    """
    if compress > 15 or compress < 9:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )

    # (include?, token) pairs, in the order they appear in the output.
    optional_params = (
        (not isserver, "client_max_window_bits"),
        (compress < 15, f"server_max_window_bits={compress!s}"),
        (server_notakeover, "server_no_context_takeover"),
    )
    tokens = ["permessage-deflate"]
    tokens.extend(token for enabled, token in optional_params if enabled)
    return "; ".join(tokens)
272
273
class WSParserState(IntEnum):
    """Stages of the frame parser, listed in the order a frame is read."""

    READ_HEADER = 1
    READ_PAYLOAD_LENGTH = 2
    READ_PAYLOAD_MASK = 3
    READ_PAYLOAD = 4
279
280
class WebSocketReader:
    """Incremental WebSocket parser that feeds ``WSMessage`` objects to a queue.

    Raw bytes are passed to ``feed_data``. ``parse_frame`` splits them into
    frames (keeping state between calls, so frames may arrive in pieces) and
    ``_feed_data`` turns the frames into messages, reassembling fragmented
    messages and decompressing them when the compression flag is set.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        """Create a reader.

        Args:
            queue: Receives the parsed messages, EOF and parse errors.
            max_msg_size: Size limit for an assembled message; a falsy
                value disables the check.
            compress: Whether frames with the RSV1 bit set are accepted.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        # First exception raised while parsing; once set, feed_data stops parsing.
        self._exc: Optional[BaseException] = None
        # Payload accumulated so far for the message being assembled.
        self._partial = bytearray()
        self._state = WSParserState.READ_HEADER

        # Opcode of the first fragment of the message being assembled.
        self._opcode: Optional[int] = None
        # Header fields and payload of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload = bytearray()

        # Bytes left unconsumed by the previous parse_frame call.
        self._tail = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        # Payload bytes still missing for the current frame.
        self._payload_length = 0
        # Raw 7-bit length field from the header (126/127 mean "extended").
        self._payload_length_flag = 0
        # Whether the message currently being received has RSV1 set.
        self._compressed: Optional[bool] = None
        # Created lazily on the first compressed frame.
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end-of-stream to the queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse ``data`` and queue the resulting messages.

        Returns ``(True, data)`` without parsing if an earlier call failed.
        If parsing raises, the exception is remembered, passed to the queue
        via ``set_exception`` and ``(True, b"")`` is returned. Otherwise the
        result of ``_feed_data`` is returned.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return True, b""

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Turn the frames contained in ``data`` into queued messages.

        Returns ``(False, b"")``.

        Raises:
            WebSocketError: on an invalid close frame, an unexpected opcode,
                an oversized message or invalid UTF-8 text.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    # Codes below 3000 must be one of the known close codes.
                    if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            f"Invalid close code: {close_code}",
                        )
                    try:
                        close_message = payload[2:].decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A one-byte payload cannot hold a close code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close frame: {fin} {opcode} {payload!r}",
                    )
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, "")

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )

            elif (
                opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
                and self._opcode is None
            ):
                # Neither a new data message nor part of one in progress.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )
            else:
                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        # Remember the message type from its first fragment.
                        self._opcode = opcode
                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )
                else:
                    # previous frame was non finished
                    # we should get continuation opcode
                    if self._partial:
                        if opcode != WSMsgType.CONTINUATION:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "The opcode in non-fin frame is expected "
                                "to be zero, got {!r}".format(opcode),
                            )

                    if opcode == WSMsgType.CONTINUATION:
                        # The final fragment takes the type of the first one.
                        assert self._opcode is not None
                        opcode = self._opcode
                        self._opcode = None

                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )

                    # Decompression must be done after all fragments of the
                    # message have been received.
                    if compressed:
                        assert self._decompressobj is not None
                        self._partial.extend(_WS_DEFLATE_TRAILING)
                        payload_merged = self._decompressobj.decompress_sync(
                            self._partial, self._max_msg_size
                        )
                        # Input left over means the output limit was reached.
                        if self._decompressobj.unconsumed_tail:
                            left = len(self._decompressobj.unconsumed_tail)
                            raise WebSocketError(
                                WSCloseCode.MESSAGE_TOO_BIG,
                                "Decompressed message size {} exceeds limit {}".format(
                                    self._max_msg_size + left, self._max_msg_size
                                ),
                            )
                    else:
                        payload_merged = bytes(self._partial)

                    self._partial.clear()

                    if opcode == WSMsgType.TEXT:
                        try:
                            text = payload_merged.decode("utf-8")
                            self.queue.feed_data(
                                WSMessage(WSMsgType.TEXT, text, ""), len(text)
                            )
                        except UnicodeDecodeError as exc:
                            raise WebSocketError(
                                WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                            ) from exc
                    else:
                        self.queue.feed_data(
                            WSMessage(WSMsgType.BINARY, payload_merged, ""),
                            len(payload_merged),
                        )

        return False, b""

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Parse every frame that ``buf`` completes.

        Parsing resumes from the state left by the previous call, with any
        leftover bytes from that call prepended to ``buf``. Each returned
        tuple is ``(fin, opcode, payload, compressed)``; the payload is
        already unmasked. Bytes of an unfinished frame are kept for the
        next call.

        Raises:
            WebSocketError: on reserved bits that are not allowed, a
                fragmented control frame or an oversized control frame.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos : start_pos + 2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xF

                    # frame-fin = %x0 ; more frames of this message follow
                    #           / %x1 ; final frame of this message
                    # frame-rsv1 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv2 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv3 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    #
                    # Remove rsv1 from this test for deflate development
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    # Opcodes above 0x7 are control frames.
                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received fragmented control frame",
                        )

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7F

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Control frame payload cannot be " "larger than 125 bytes",
                        )

                    # Set compress status if last package is FIN
                    # OR set compress status if this is first fragment
                    # Raise error if not first fragment with rsv1 = 0x1
                    if self._frame_fin or self._compressed is None:
                        self._compressed = True if rsv1 else False
                    elif rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    self._frame_fin = bool(fin)
                    self._frame_opcode = opcode
                    self._has_mask = bool(has_mask)
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    # Not enough bytes for a header yet.
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    # 16-bit extended length follows.
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos : start_pos + 2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                elif length > 126:
                    # 64-bit extended length follows.
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos : start_pos + 8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                else:
                    # The 7-bit field is the length itself.
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD
                    )

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos : start_pos + 4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # The rest of the buffer belongs to this frame.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    # The frame ends inside the buffer.
                    self._payload_length = 0
                    payload.extend(buf[start_pos : start_pos + length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        assert self._frame_mask is not None
                        # Unmasks the payload in place.
                        _websocket_mask(self._frame_mask, payload)

                    frames.append(
                        (self._frame_fin, self._frame_opcode, payload, self._compressed)
                    )

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        # Keep whatever was not consumed for the next call.
        self._tail = buf[start_pos:]

        return frames
598
599
class WebSocketWriter:
    """Build WebSocket frames and write them to a transport.

    Every frame is sent with the FIN bit set. Data frames (opcode below 8)
    can be deflate-compressed, and when ``use_mask`` is set each payload is
    masked with a fresh random 4-byte key.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Create a writer.

        Args:
            protocol: Provides ``_drain_helper()``, awaited once more than
                ``limit`` bytes have been written since the last drain.
            transport: Destination of the encoded frames.
            use_mask: Mask every outgoing payload.
            limit: Written-byte threshold that triggers a drain.
            random: Source of masking keys. The default instance is created
                once, at definition time, and shared by all writers.
            compress: Window bits for connection-level compression of data
                frames; ``0`` disables it.
            notakeover: Flush with ``Z_FULL_FLUSH`` instead of
                ``Z_SYNC_FLUSH`` after each compressed message.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        self._output_size = 0
        self._compressobj: Any = None  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        Args:
            message: Raw payload bytes.
            opcode: Frame opcode; only opcodes below 8 are compressed.
            compress: Window bits for a one-off compressor used for this
                frame only. When falsy, the connection-level compressor is
                used if one is configured.

        Raises:
            ConnectionResetError: if the writer is closing and the opcode
                does not have the 0x8 bit set, or if the transport is
                missing or closing.
        """
        # The bit test lets CLOSE (0x8), PING (0x9) and PONG (0xA) through.
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # Every data frame is compressed, regardless of its size.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Do not set self._compress if compressing is for this frame
                compressobj = self._make_compress_obj(compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = self._make_compress_obj(self.compress)
                compressobj = self._compressobj

            message = await compressobj.compress(message)
            # It's critical that we do not return control to the event
            # loop until we have finished sending all the compressed
            # data. Otherwise we could end up mixing compressed frames
            # if there are multiple coroutines compressing data.
            message += compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            )
            # The trailing marker is not transmitted; the reader adds it back.
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            # RSV1 marks the frame as compressed.
            rsv = rsv | 0x40

        msg_length = len(message)

        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0

        # 0x80 in the first byte is the FIN bit. The length is encoded in
        # 7 bits, or as 126/127 followed by a 16-/64-bit extended length.
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
        if use_mask:
            # randrange() excludes its upper bound, so the bound has to be
            # 2**32 for every possible 4-byte masking key to be reachable.
            mask_int = self.randrange(0, 1 << 32)
            mask = mask_int.to_bytes(4, "big")
            message = bytearray(message)
            _websocket_mask(mask, message)
            self._write(header + mask + message)
            self._output_size += len(header) + len(mask) + msg_length
        else:
            if msg_length > MSG_SIZE:
                # Avoid copying a large payload just to prepend the header.
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + msg_length

        # It is safe to return control to the event loop when using compression
        # after this point as we have already sent or buffered all the data.

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
        """Create a compressor using ``compress`` window bits.

        The window bits are passed negated to ``ZLibCompressor``.
        """
        return ZLibCompressor(
            level=zlib.Z_BEST_SPEED,
            wbits=-compress,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    def _write(self, data: bytes) -> None:
        """Write ``data`` to the transport.

        Raises:
            ConnectionResetError: if the transport is missing or closing.
        """
        if self.transport is None or self.transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        self.transport.write(data)

    async def pong(self, message: Union[bytes, str] = b"") -> None:
        """Send pong message."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: Union[bytes, str] = b"") -> None:
        """Send ping message."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        ``str`` messages are UTF-8 encoded first. The frame is BINARY when
        ``binary`` is true and TEXT otherwise; ``compress`` is forwarded to
        ``_send_frame``.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        if binary:
            await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message.

        The writer is marked as closing even if sending the frame fails.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True