Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/http_websocket.py: 30%
388 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
1"""WebSocket protocol versions 13 and 8."""
3import asyncio
4import functools
5import json
6import random
7import re
8import sys
9import zlib
10from enum import IntEnum
11from struct import Struct
12from typing import (
13 Any,
14 Callable,
15 Final,
16 List,
17 NamedTuple,
18 Optional,
19 Pattern,
20 Set,
21 Tuple,
22 Union,
23 cast,
24)
26from .base_protocol import BaseProtocol
27from .compression_utils import ZLibCompressor, ZLibDecompressor
28from .helpers import NO_EXTENSIONS
29from .streams import DataQueue
# Names exported by ``from <this module> import *``.
__all__ = (
    # Sentinel messages and the handshake GUID.
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    # Frame reader and writer.
    "WebSocketReader",
    "WebSocketWriter",
    # Message container, parser error and enumerations.
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
class WSCloseCode(IntEnum):
    """Status codes carried in WebSocket CLOSE frames.

    Every member is also accepted as a received close code by the reader
    (see ``ALLOWED_CLOSE_CODES``).
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
# Plain-int view of every WSCloseCode member. The reader rejects a received
# close code below 3000 unless it appears here.
ALLOWED_CLOSE_CODES: Final[Set[int]] = set(map(int, WSCloseCode))

# Low latency matters for websockets: peers generally expect messages to be
# sent and received promptly. Handing compression work to an executor is a
# significant source of latency and overhead when chunks are small, so a
# chunk size larger than the default is used to reduce executor calls.
# 5 KiB was chosen because it is also the threshold python-zlib-ng chose for
# releasing the GIL.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
class WSMsgType(IntEnum):
    """Message types produced by the reader and used by the writer.

    Members below 0x100 are WebSocket frame opcodes. CLOSING, CLOSED and
    ERROR are aiohttp-internal values. They cannot come from the wire,
    because a frame opcode is only four bits wide.
    """

    # Opcodes defined by the WebSocket protocol.
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp-internal pseudo types.
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # Lower-case aliases (enum aliases: same value, same member).
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
# Fixed GUID from RFC 6455 used in the opening handshake.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

# Network-order (big-endian) unsigned 16- and 64-bit integers.
_UINT16 = Struct("!H")
_UINT64 = Struct("!Q")

UNPACK_LEN2 = _UINT16.unpack_from  # 16-bit extended payload length
UNPACK_LEN3 = _UINT64.unpack_from  # 64-bit extended payload length
UNPACK_CLOSE_CODE = _UINT16.unpack  # exactly two bytes: the close status code
PACK_LEN1 = Struct("!BB").pack  # two header bytes, 7-bit length inline
PACK_LEN2 = Struct("!BBH").pack  # header bytes + 16-bit extended length
PACK_LEN3 = Struct("!BBQ").pack  # header bytes + 64-bit extended length
PACK_CLOSE_CODE = _UINT16.pack

# Unmasked payloads larger than this are written separately from their header.
MSG_SIZE: Final[int] = 1 << 14
# Default number of buffered output bytes before the writer awaits a drain.
DEFAULT_LIMIT: Final[int] = 1 << 16
class WSMessage(NamedTuple):
    """One message delivered by the WebSocket reader."""

    type: WSMsgType
    # Payload whose Python type depends on ``type``: ``str`` for TEXT, the
    # payload bytes for BINARY/PING/PONG, the int close code for CLOSE.
    # Typing it precisely would need a tagged union per message type.
    data: Any
    # Secondary value, e.g. the close reason text for CLOSE messages.
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Decode ``data`` with *loads* (``json.loads`` by default).

        .. versionadded:: 0.22
        """
        decoder = loads
        return decoder(self.data)
# Shared sentinel messages for the CLOSED / CLOSING pseudo states.
WS_CLOSED_MESSAGE = WSMessage(type=WSMsgType.CLOSED, data=None, extra=None)
WS_CLOSING_MESSAGE = WSMessage(type=WSMsgType.CLOSING, data=None, extra=None)
class WebSocketError(Exception):
    """Protocol violation detected while parsing WebSocket frames.

    ``code`` is a close status code (a ``WSCloseCode`` value when raised by
    this module). ``str()`` of the error yields only the message text.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        message = self.args[1]
        return cast(str, message)
class WSHandshakeError(Exception):
    """WebSocket handshake failure, e.g. an unusable permessage-deflate offer."""
# Host byte order ("little" or "big"), captured once at import time.
native_byteorder: Final[str] = sys.byteorder
146# Used by _websocket_mask_python
147@functools.lru_cache
148def _xor_table() -> List[bytes]:
149 return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """XOR *data* in place with the 4-byte *mask* (RFC 6455, section 5.3).

    Byte ``i`` of *data* is XOR-ed with ``mask[i % 4]``. XOR is its own
    inverse, so the same call masks and unmasks. *data* must be a
    ``bytearray`` (it is mutated) and *mask* exactly four bytes.

    Pure-Python fallback; an optimized extension may replace it.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    xor_tables = _xor_table()
    lane_tables = [xor_tables[key_byte] for key_byte in mask]
    t0, t1, t2, t3 = lane_tables
    # Every fourth byte shares the same mask byte, so translate each lane
    # (stride-4 slice) with that byte's table.
    for offset, table in enumerate((t0, t1, t2, t3)):
        data[offset::4] = data[offset::4].translate(table)
# Prefer the compiled masking helper; fall back to the pure-Python one when
# extensions are disabled or the extension module cannot be imported.
if not NO_EXTENSIONS:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import-not-found]
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python
    else:
        _websocket_mask = _websocket_mask_cython
else:  # pragma: no cover
    _websocket_mask = _websocket_mask_python
187_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
190_WS_EXT_RE: Final[Pattern[str]] = re.compile(
191 r"^(?:;\s*(?:"
192 r"(server_no_context_takeover)|"
193 r"(client_no_context_takeover)|"
194 r"(server_max_window_bits(?:=(\d+))?)|"
195 r"(client_max_window_bits(?:=(\d+))?)))*$"
196)
198_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Negotiate permessage-deflate from a WebSocket extensions string.

    Offers are examined in order and the first usable one decides the
    result, returned as ``(window_bits, no_context_takeover)``. A bare
    ``permessage-deflate`` yields ``(15, False)``. ``(0, False)`` means
    compression was not negotiated.

    Server side (``isserver=True``): reads ``server_max_window_bits`` and
    ``server_no_context_takeover``. Offers that do not parse, or that ask
    for a window outside 9..15, are skipped.

    Client side: reads ``client_max_window_bits`` and
    ``client_no_context_takeover``. Such offers raise
    :class:`WSHandshakeError` instead.
    """
    if not extstr:
        return 0, False

    # Regex groups of _WS_EXT_RE relevant to this side of the connection.
    bits_group, notakeover_group = (4, 1) if isserver else (6, 2)

    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        if not params:
            return 15, False

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            if isserver:
                continue
            raise WSHandshakeError(
                "Extension for deflate not supported" + offer.group(1)
            )

        window_bits = 15
        requested = parsed.group(bits_group)
        if requested:
            window_bits = int(requested)
            # zlib cannot use wbits=8, so only 9..15 is usable.
            if not 9 <= window_bits <= 15:
                if isserver:
                    continue
                raise WSHandshakeError("Invalid window size")

        return window_bits, bool(parsed.group(notakeover_group))

    return 0, False
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build a permessage-deflate extension value.

    :param compress: deflate window bits, 9..15. A value below 15 is
        advertised as ``server_max_window_bits=<compress>``.
    :param isserver: ``True`` when building the server's reply. Clients
        additionally offer ``client_max_window_bits``.
    :param server_notakeover: append ``server_no_context_takeover``.
    :returns: the parameters joined with ``"; "``.
    :raises ValueError: if *compress* is outside 9..15.
    """
    # wbits=8 is rejected because zlib does not support it (see message).
    if compress < 9 or compress > 15:
        raise ValueError(
            "Compress wbits must be between 9 and 15, zlib does not support wbits=8"
        )
    params = ["permessage-deflate"]
    if not isserver:
        params.append("client_max_window_bits")

    if compress < 15:
        params.append("server_max_window_bits=" + str(compress))
    if server_notakeover:
        params.append("server_no_context_takeover")
    # client_no_context_takeover is never requested by this function.
    return "; ".join(params)
class WSParserState(IntEnum):
    """Stages of the incremental frame parser in ``WebSocketReader``."""

    READ_HEADER = 1  # the two fixed header bytes
    READ_PAYLOAD_LENGTH = 2  # optional 16- or 64-bit extended length
    READ_PAYLOAD_MASK = 3  # four-byte masking key (masked frames only)
    READ_PAYLOAD = 4  # the frame body
class WebSocketReader:
    """Incremental WebSocket frame parser.

    Raw bytes are passed to :meth:`feed_data`. Every complete message is
    pushed onto ``queue`` as a :class:`WSMessage`. The first protocol error
    is stored and set on the queue as an exception.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        """
        :param queue: receives parsed messages (and any parse error).
        :param max_msg_size: limit for an assembled or decompressed message;
            0 disables the check.
        :param compress: whether incoming frames may set RSV1
            (permessage-deflate).
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        # First error raised while parsing; once set, input is no longer parsed.
        self._exc: Optional[BaseException] = None
        # Payload of the fragmented data message collected so far.
        self._partial = bytearray()
        self._state = WSParserState.READ_HEADER

        # Opcode of the fragmented data message in progress, or None.
        self._opcode: Optional[int] = None
        # FIN bit / opcode / payload of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload = bytearray()

        # Unparsed bytes kept from the previous parse_frame() call.
        self._tail = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        # Payload bytes of the current frame still to be read.
        self._payload_length = 0
        # Raw 7-bit length field (126 / 127 mean an extended length follows).
        self._payload_length_flag = 0
        # RSV1 state of the data message in progress (None before any data frame).
        self._compressed: Optional[bool] = None
        # FIN bit of the last *data* frame; control frames never change it.
        self._data_frame_fin = False
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Signal end of input to the queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse *data* and push complete messages onto the queue.

        Returns ``(False, b"")`` on success. On a parse error the exception
        is stored and set on the queue, and ``(True, b"")`` is returned.
        Every later call returns ``(True, data)`` without parsing.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b""

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Turn the frames parsed from *data* into queued messages.

        :raises WebSocketError: on protocol violations, invalid UTF-8, or
            messages exceeding ``max_msg_size``.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            f"Invalid close code: {close_code}",
                        )
                    try:
                        close_message = payload[2:].decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A one-byte close payload cannot hold a status code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close frame: {fin} {opcode} {payload!r}",
                    )
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, "")

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )

            elif (
                opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
                and self._opcode is None
            ):
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )
            else:
                # While a fragmented message is open, only CONTINUATION
                # frames may carry data: fragments of different messages
                # must not interleave. Checking self._opcode rather than
                # self._partial also catches an empty first fragment.
                if opcode != WSMsgType.CONTINUATION and self._opcode is not None:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "The opcode in non-fin frame is expected "
                        "to be zero, got {!r}".format(opcode),
                    )

                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        self._opcode = opcode
                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )
                else:
                    if opcode == WSMsgType.CONTINUATION:
                        assert self._opcode is not None
                        opcode = self._opcode
                        self._opcode = None

                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )

                    # Decompression must happen only after every fragment of
                    # the message has been received.
                    if compressed:
                        assert self._decompressobj is not None
                        self._partial.extend(_WS_DEFLATE_TRAILING)
                        payload_merged = self._decompressobj.decompress_sync(
                            self._partial, self._max_msg_size
                        )
                        # Input left unconsumed means the output hit the limit.
                        if self._decompressobj.unconsumed_tail:
                            left = len(self._decompressobj.unconsumed_tail)
                            raise WebSocketError(
                                WSCloseCode.MESSAGE_TOO_BIG,
                                "Decompressed message size {} exceeds limit {}".format(
                                    self._max_msg_size + left, self._max_msg_size
                                ),
                            )
                    else:
                        payload_merged = bytes(self._partial)

                    self._partial.clear()

                    if opcode == WSMsgType.TEXT:
                        try:
                            text = payload_merged.decode("utf-8")
                            self.queue.feed_data(
                                WSMessage(WSMsgType.TEXT, text, ""), len(text)
                            )
                        except UnicodeDecodeError as exc:
                            raise WebSocketError(
                                WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                            ) from exc
                    else:
                        self.queue.feed_data(
                            WSMessage(WSMsgType.BINARY, payload_merged, ""),
                            len(payload_merged),
                        )

        return False, b""

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Parse every complete frame available in *buf*.

        Returns ``(fin, opcode, payload, compressed)`` tuples with payloads
        already unmasked. Incomplete trailing input is kept in
        ``self._tail`` and prepended to the next call's *buf*.

        :raises WebSocketError: on malformed headers.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos : start_pos + 2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xF

                    # RSV1..RSV3 MUST be 0 unless an extension defines them;
                    # RSV1 is permitted only when compression is enabled.
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received fragmented control frame",
                        )

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7F

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Control frame payload cannot be " "larger than 125 bytes",
                        )

                    # RSV1 marks a compressed message and is only valid on
                    # the first frame of a data message (RFC 7692). Control
                    # frames may be interleaved between fragments, so they
                    # must not touch the data-message state. Otherwise a PING
                    # inside a compressed fragmented message would make its
                    # final fragment look uncompressed.
                    if opcode > 0x7:
                        if rsv1:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                    else:
                        if self._data_frame_fin or self._compressed is None:
                            # First frame of a new data message.
                            self._compressed = True if rsv1 else False
                        elif rsv1:
                            # Continuation fragments must not set RSV1.
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                        self._data_frame_fin = bool(fin)

                    self._frame_fin = bool(fin)
                    self._frame_opcode = opcode
                    self._has_mask = bool(has_mask)
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos : start_pos + 2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                elif length > 126:
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos : start_pos + 8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                else:
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD
                    )

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos : start_pos + 4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # The rest of buf belongs to this frame (maybe not all of it yet).
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos : start_pos + length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        assert self._frame_mask is not None
                        _websocket_mask(self._frame_mask, payload)

                    frames.append(
                        (self._frame_fin, self._frame_opcode, payload, self._compressed)
                    )

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        self._tail = buf[start_pos:]

        return frames
class WebSocketWriter:
    """Serialize WebSocket frames and write them to ``transport``.

    Every frame is sent with FIN set (messages are never fragmented). Frames
    can optionally be masked and data frames optionally compressed with
    permessage-deflate.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: Any = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """
        :param protocol: its ``_drain_helper()`` is awaited once more than
            *limit* bytes have been written.
        :param transport: destination of the encoded frames.
        :param use_mask: mask every outgoing frame with a random 4-byte key.
        :param limit: output byte count that triggers a drain.
        :param random: object whose ``randrange`` produces masking keys.
            NOTE(review): the default ``random.Random()`` instance is created
            once at import time and shared by all writers. It is not a
            cryptographically strong source, while RFC 6455 asks for strong
            entropy for masking keys.
        :param compress: window bits for compressing data frames; 0 disables
            compression unless a per-frame value is passed to ``send``.
        :param notakeover: flush with ``Z_FULL_FLUSH`` instead of
            ``Z_SYNC_FLUSH`` after each compressed message.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        # Set by close(); afterwards only opcodes with bit 0x8 may be sent.
        self._closing = False
        self._limit = limit
        # Bytes written since the last drain.
        self._output_size = 0
        # Shared compressor for self.compress, created lazily.
        self._compressobj: Any = None  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Encode *message* as one FIN frame with *opcode* and write it.

        :param compress: per-frame window bits. When given, a fresh
            compressor is used for this frame only; otherwise the shared
            compressor for ``self.compress`` is used (if enabled).
        :raises ConnectionResetError: if closing, or if the transport is
            gone or closing.
        """
        # `opcode & CLOSE` is non-zero for CLOSE, PING and PONG (0x8, 0x9,
        # 0xA), so control frames are still allowed after close() was called.
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # Only data frames (opcode < 8) are compressed, regardless of size.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Per-frame compressor: does not touch self._compressobj.
                compressobj = self._make_compress_obj(compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = self._make_compress_obj(self.compress)
                compressobj = self._compressobj

            message = await compressobj.compress(message)
            # It's critical that we do not return control to the event
            # loop until we have finished sending all the compressed
            # data. Otherwise we could end up mixing compressed frames
            # if there are multiple coroutines compressing data.
            message += compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            )
            # The 0x00 0x00 0xFF 0xFF flush tail is not sent; the reader adds it back.
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            # RSV1 flags the frame as compressed.
            rsv = rsv | 0x40

        msg_length = len(message)

        use_mask = self.use_mask
        if use_mask:
            mask_bit = 0x80
        else:
            mask_bit = 0

        # 0x80 sets FIN. The length goes inline (<126), or as a 16-bit (126)
        # or 64-bit (127) extended length.
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
        if use_mask:
            # NOTE(review): with random.Random the stop value is exclusive,
            # so the key 0xFFFFFFFF is never produced.
            mask = self.randrange(0, 0xFFFFFFFF)
            mask = mask.to_bytes(4, "big")
            # Copy to a bytearray because the mask is applied in place.
            message = bytearray(message)
            _websocket_mask(mask, message)
            self._write(header + mask + message)
            self._output_size += len(header) + len(mask) + len(message)
        else:
            # Large payloads are written separately, presumably to avoid
            # copying them into a concatenated buffer.
            if len(message) > MSG_SIZE:
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + len(message)

        # It is safe to return control to the event loop when using compression
        # after this point as we have already sent or buffered all the data.

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
        """Create a fast compressor with window bits *compress*.

        wbits is passed negated; presumably this selects a raw deflate stream
        without a zlib header, as with zlib itself — confirm in
        compression_utils.
        """
        return ZLibCompressor(
            level=zlib.Z_BEST_SPEED,
            wbits=-compress,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    def _write(self, data: bytes) -> None:
        """Write *data* to the transport.

        :raises ConnectionResetError: if the transport is missing or closing.
        """
        if self.transport is None or self.transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        self.transport.write(data)

    async def pong(self, message: Union[bytes, str] = b"") -> None:
        """Send pong message (str payloads are UTF-8 encoded)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: Union[bytes, str] = b"") -> None:
        """Send ping message (str payloads are UTF-8 encoded)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send *message* as a TEXT frame, or a BINARY frame if *binary*.

        str messages are UTF-8 encoded first. *compress* optionally sets
        per-frame compression window bits.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        if binary:
            await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Send a CLOSE frame carrying *code* and *message*.

        The writer is marked as closing even if sending fails, after which
        only control frames are accepted. The reason length is not checked
        here.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True