Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/http_websocket.py: 28%
380 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
1"""WebSocket protocol versions 13 and 8."""
3import asyncio
4import functools
5import json
6import random
7import re
8import sys
9import zlib
10from enum import IntEnum
11from struct import Struct
12from typing import (
13 Any,
14 Callable,
15 Final,
16 List,
17 NamedTuple,
18 Optional,
19 Pattern,
20 Set,
21 Tuple,
22 Union,
23 cast,
24)
26from .base_protocol import BaseProtocol
27from .compression_utils import ZLibCompressor, ZLibDecompressor
28from .helpers import NO_EXTENSIONS
29from .streams import DataQueue
# Public names of this module (what ``from ... import *`` exposes).
# NOTE(review): WSHandshakeError, which ws_ext_parse raises, is not listed
# here -- confirm whether that omission is intentional.
__all__ = (
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    "WebSocketReader",
    "WebSocketWriter",
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
class WSCloseCode(IntEnum):
    """WebSocket close status codes (see RFC 6455, section 7.4).

    The integer values of all members form ``ALLOWED_CLOSE_CODES``, which
    the reader uses to validate close codes below 3000 in received close
    frames.
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
# Integer values of every WSCloseCode member. WebSocketReader rejects a
# received close code below 3000 unless it is in this set.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng chose to use as the threshold to release the GIL.
# (Passed to ZLibCompressor as ``max_sync_chunk_size``.)
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
class WSMsgType(IntEnum):
    """Message types handled by the reader and writer.

    The first group holds the 4-bit frame opcodes as they appear on the
    wire; parsed opcodes are compared against them. The second group holds
    aiohttp-specific values that lie outside the 4-bit opcode range, so they
    can never collide with a real opcode.
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102
# The fixed GUID that RFC 6455 defines for the opening-handshake key hash.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

# Network byte order (big-endian) helpers for frame headers:
#   LEN2 -> 16-bit extended payload length (length byte == 126)
#   LEN3 -> 64-bit extended payload length (length byte == 127)
#   PACK_LEN1/2/3 -> the two fixed header bytes, optionally followed by the
#                    16- or 64-bit extended length
#   CLOSE_CODE -> the 16-bit status code at the start of a close payload
UNPACK_LEN2 = Struct("!H").unpack_from
UNPACK_LEN3 = Struct("!Q").unpack_from
UNPACK_CLOSE_CODE = Struct("!H").unpack
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
# Unmasked payloads longer than this are written separately from their
# header instead of being concatenated into one buffer first.
MSG_SIZE: Final[int] = 2**14
# Default number of written bytes after which WebSocketWriter awaits the
# protocol's drain helper.
DEFAULT_LIMIT: Final[int] = 2**16
class WSMessage(NamedTuple):
    """A received WebSocket message, or one of the aiohttp sentinel messages.

    What ``data`` holds depends on ``type`` as produced by WebSocketReader:
    TEXT -> ``str``; BINARY -> ``bytes``; PING/PONG -> the raw payload
    (a ``bytearray``); CLOSE -> the integer close code, with the close
    reason in ``extra``. For the CLOSED/CLOSING sentinels both ``data``
    and ``extra`` are ``None``. Otherwise ``extra`` is ``""``.
    """

    type: WSMsgType
    # To type correctly, this would need some kind of tagged union for each type.
    data: Any
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Return parsed JSON data.

        ``loads`` is applied to ``self.data``; pass a custom decoder to
        replace :func:`json.loads`.

        .. versionadded:: 0.22
        """
        return loads(self.data)
# Ready-made sentinel messages for the aiohttp-specific CLOSED and CLOSING
# states. Neither carries ``data`` nor ``extra``.
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
class WebSocketError(Exception):
    """WebSocket protocol parser error.

    ``code`` is the close code associated with the failure; the reader
    passes WSCloseCode members. ``args`` is ``(code, message)``, and
    ``str(exc)`` yields only the human-readable message.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        # args[0] is the code; only the message is used for display.
        message = self.args[1]
        return cast(str, message)
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error.

    Raised, for example, by ws_ext_parse when a client receives a
    permessage-deflate response it cannot accept (unparseable parameters
    or an unsupported window size).
    """
# Byte order of the running platform ("little" or "big"), as sys.byteorder.
# NOTE(review): not referenced elsewhere in this part of the module;
# presumably kept for importers elsewhere -- confirm before removing.
native_byteorder: Final[str] = sys.byteorder
137# Used by _websocket_mask_python
138@functools.lru_cache
139def _xor_table() -> List[bytes]:
140 return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """Websocket masking function.

    XORs byte ``i`` of *data* with ``mask[i % 4]`` in place, as specified
    in section 5.3 of RFC 6455. *mask* must be a 4-byte ``bytes`` object
    and *data* a ``bytearray`` of any length. XOR is its own inverse, so
    the same call masks (writer) and unmasks (reader) a payload.

    Note that this function mutates the `data` argument and returns None.

    This pure-python implementation may be replaced by an optimized
    version when available.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    tables = _xor_table()
    # Unpacking into four names keeps the "exactly four mask bytes" failure
    # even when the asserts above are stripped by ``python -O``.
    key0, key1, key2, key3 = (tables[mask_byte] for mask_byte in mask)
    # Bytes data[lane::4] all share one mask byte, so one translate per lane
    # masks the whole payload.
    for lane, table in enumerate((key0, key1, key2, key3)):
        data[lane::4] = data[lane::4].translate(table)
# Pick the masking implementation: the compiled extension when it is allowed
# and importable, otherwise the pure-Python fallback.
_websocket_mask = _websocket_mask_python
if not NO_EXTENSIONS:  # pragma: no cover
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import-not-found]
    except ImportError:  # pragma: no cover
        pass
    else:
        _websocket_mask = _websocket_mask_cython
# The 4 bytes that end a sync-flushed deflate block. The writer strips them
# from compressed payloads, and the reader appends them again before
# decompressing a message.
_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])

# Matches the parameter list following "permessage-deflate" in one offer,
# e.g. "; client_max_window_bits; server_max_window_bits=10". Capture groups,
# as read by ws_ext_parse:
#   1: server_no_context_takeover
#   2: client_no_context_takeover
#   3: server_max_window_bits[=N]     4: its N, if given
#   5: client_max_window_bits[=N]     6: its N, if given
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)

# Finds each "permessage-deflate" offer in a comma-separated extensions
# value. Group 1 is the text up to the next comma, or None if nothing follows.
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Negotiate permessage-deflate from a WebSocket extensions value.

    Offers are examined left to right, and the first acceptable one wins.
    A bare ``permessage-deflate`` yields a 15-bit window. The server
    (``isserver=True``) honours ``server_max_window_bits`` and
    ``server_no_context_takeover``. Offers that do not parse, or that ask
    for a window outside 9..15, are skipped. The client honours
    ``client_max_window_bits`` and ``client_no_context_takeover``, and
    fails on anything it cannot accept.

    Returns:
        ``(wbits, notakeover)``: ``wbits`` is 0 when no offer was
        accepted, otherwise 9..15.

    Raises:
        WSHandshakeError: client side only, for unparseable parameters or
            an unsupported window size.
    """
    if not extstr:
        return 0, False

    wbits = 0
    notakeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        # Return compress = 15 when get `permessage-deflate`
        if not params:
            wbits = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            # Return Fail if client side and not match
            if not isserver:
                raise WSHandshakeError("Extension for deflate not supported" + params)
            continue

        # The server reads the server_* parameters (groups 4 and 1); the
        # client reads the client_* parameters (groups 6 and 2). The
        # remaining *_max_window_bits group is ignored.
        if isserver:
            window, takeover_flag = parsed.group(4), parsed.group(1)
        else:
            window, takeover_flag = parsed.group(6), parsed.group(2)

        wbits = int(window) if window else 15
        # zlib does not support an 8-bit window.
        if wbits > 15 or wbits < 9:
            if not isserver:
                raise WSHandshakeError("Invalid window size")
            # The server never fails here: it tries the next offer instead.
            wbits = 0
            continue

        if takeover_flag:
            notakeover = True
        break

    return wbits, notakeover
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build the ``permessage-deflate`` extension value to send.

    Args:
        compress: deflate window bits, 9..15. Values below 15 are advertised
            as ``server_max_window_bits=<compress>``.
        isserver: when False (a client offer), ``client_max_window_bits``
            is advertised too.
        server_notakeover: append ``server_no_context_takeover``.

    Raises:
        ValueError: if *compress* is outside 9..15 (zlib cannot use 8 bits).

    ``client_no_context_takeover`` is never generated.
    """
    if compress > 15 or compress < 9:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )
    # (parameter text, whether to include it), in output order.
    candidates = (
        ("permessage-deflate", True),
        ("client_max_window_bits", not isserver),
        ("server_max_window_bits=" + str(compress), compress < 15),
        ("server_no_context_takeover", server_notakeover),
    )
    return "; ".join(text for text, wanted in candidates if wanted)
class WSParserState(IntEnum):
    """Where WebSocketReader.parse_frame is within the current frame.

    States are visited in declaration order. READ_PAYLOAD_MASK is skipped
    when the frame header says the payload is not masked.
    """

    READ_HEADER = 1
    READ_PAYLOAD_LENGTH = 2
    READ_PAYLOAD_MASK = 3
    READ_PAYLOAD = 4
class WebSocketReader:
    """Incremental WebSocket frame parser that feeds messages to a queue.

    Raw bytes are pushed in with :meth:`feed_data`. Every complete control
    frame, and every reassembled (and, when negotiated, decompressed) data
    message, is delivered to ``queue`` as a :class:`WSMessage`. The first
    error raised while parsing is remembered and forwarded to the queue with
    ``set_exception``; after that, :meth:`feed_data` no longer parses input.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        # Limit for an assembled (and decompressed) message; 0 disables it.
        self._max_msg_size = max_msg_size

        self._exc: Optional[BaseException] = None
        # Payload collected from the fragments of the current data message.
        self._partial = bytearray()
        self._state = WSParserState.READ_HEADER

        # Opcode of the first fragment of an unfinished fragmented message.
        self._opcode: Optional[int] = None
        # FIN bit of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload = bytearray()

        # Unconsumed bytes carried over to the next parse_frame() call.
        self._tail = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_length = 0
        self._payload_length_flag = 0
        # RSV1 ("per-message compressed") of the first fragment of the
        # current data message; None until the first data frame is seen.
        self._compressed: Optional[bool] = None
        # Compression flag reported for the frame currently being parsed.
        self._frame_compressed: Optional[bool] = None
        # FIN bit of the most recent *data* frame. Control frames may be
        # interleaved between fragments of a data message, so they must not
        # change it; otherwise the next continuation fragment would be
        # mistaken for the first fragment of a new message.
        self._message_fin = False
        self._decompressobj: Optional[ZLibDecompressor] = None
        # Whether RSV1 (permessage-deflate) may be set on received frames.
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end-of-stream to the message queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse *data* and push every completed message onto the queue.

        Returns:
            ``(True, data)`` unchanged when an earlier call already failed;
            ``(True, b"")`` when this call fails (the exception is stored
            and set on the queue); ``(False, b"")`` otherwise.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b""

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Turn the frames completed by *data* into queued messages.

        Raises:
            WebSocketError: on protocol violations, invalid UTF-8 or
                messages exceeding ``max_msg_size``.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                # Created lazily, on the first compressed frame.
                self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    # Payload: 2-byte big-endian close code + UTF-8 reason.
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    # Codes >= 3000 are accepted without further checks.
                    if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            f"Invalid close code: {close_code}",
                        )
                    try:
                        close_message = payload[2:].decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A 1-byte close payload cannot hold a close code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close frame: {fin} {opcode} {payload!r}",
                    )
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, "")

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )

            elif (
                opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
                and self._opcode is None
            ):
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )
            else:
                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        self._opcode = opcode
                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )
                else:
                    # previous frame was non finished
                    # we should get continuation opcode
                    if self._partial:
                        if opcode != WSMsgType.CONTINUATION:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "The opcode in non-fin frame is expected "
                                "to be zero, got {!r}".format(opcode),
                            )

                    if opcode == WSMsgType.CONTINUATION:
                        # Final fragment: the message type comes from the
                        # first fragment.
                        assert self._opcode is not None
                        opcode = self._opcode
                        self._opcode = None

                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )

                    # Decompress process must to be done after all packets
                    # received.
                    if compressed:
                        assert self._decompressobj is not None
                        # Re-append the flush trailer the sender stripped.
                        self._partial.extend(_WS_DEFLATE_TRAILING)
                        payload_merged = self._decompressobj.decompress_sync(
                            self._partial, self._max_msg_size
                        )
                        # Input left over means the output hit the limit.
                        if self._decompressobj.unconsumed_tail:
                            left = len(self._decompressobj.unconsumed_tail)
                            raise WebSocketError(
                                WSCloseCode.MESSAGE_TOO_BIG,
                                "Decompressed message size {} exceeds limit {}".format(
                                    self._max_msg_size + left, self._max_msg_size
                                ),
                            )
                    else:
                        payload_merged = bytes(self._partial)

                    self._partial.clear()

                    if opcode == WSMsgType.TEXT:
                        try:
                            text = payload_merged.decode("utf-8")
                            self.queue.feed_data(
                                WSMessage(WSMsgType.TEXT, text, ""), len(text)
                            )
                        except UnicodeDecodeError as exc:
                            raise WebSocketError(
                                WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                            ) from exc
                    else:
                        self.queue.feed_data(
                            WSMessage(WSMsgType.BINARY, payload_merged, ""),
                            len(payload_merged),
                        )

        return False, b""

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Parse as many complete frames as possible from *buf*.

        Returns a list of ``(fin, opcode, payload, compressed)`` tuples with
        payloads already unmasked. Parsing state is kept across calls: an
        incomplete header, length or mask is saved in ``self._tail``, and a
        partially received payload accumulates in ``self._frame_payload``.

        Raises:
            WebSocketError: for reserved bits that are not allowed, fragmented
                or oversized control frames, and RSV1 on control frames or
                on non-first fragments.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos : start_pos + 2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xF

                    # frame-fin = %x0 ; more frames of this message follow
                    #           / %x1 ; final frame of this message
                    # frame-rsv1 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv2 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv3 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    #
                    # Remove rsv1 from this test for deflate development
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received fragmented control frame",
                        )

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7F

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Control frame payload cannot be " "larger than 125 bytes",
                        )

                    if opcode > 0x7:
                        # Control frames may arrive between the fragments of
                        # a data message. RFC 7692 forbids RSV1 on them, and
                        # they must leave the interrupted message's
                        # compression state untouched.
                        if rsv1:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                        self._frame_compressed = False
                    else:
                        # Set compress status if last data frame was FIN
                        # OR set compress status if this is first fragment
                        # Raise error if not first fragment with rsv1 = 0x1
                        if self._message_fin or self._compressed is None:
                            self._compressed = True if rsv1 else False
                        elif rsv1:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                        self._message_fin = bool(fin)
                        self._frame_compressed = self._compressed

                    self._frame_fin = bool(fin)
                    self._frame_opcode = opcode
                    self._has_mask = bool(has_mask)
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    # 16-bit extended payload length follows.
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos : start_pos + 2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                elif length > 126:
                    # 64-bit extended payload length follows.
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos : start_pos + 8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                else:
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD
                    )

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos : start_pos + 4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # Payload continues past this buffer: take everything.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos : start_pos + length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        assert self._frame_mask is not None
                        _websocket_mask(self._frame_mask, payload)

                    frames.append(
                        (
                            self._frame_fin,
                            self._frame_opcode,
                            payload,
                            self._frame_compressed,
                        )
                    )

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        self._tail = buf[start_pos:]

        return frames
class WebSocketWriter:
    """Serialize WebSocket frames and write them to an asyncio transport.

    Frames may be masked (``use_mask``) and/or compressed with
    permessage-deflate (``compress`` = window bits, 0 disables). Once more
    than ``limit`` bytes have been written since the last drain, the writer
    awaits ``protocol._drain_helper()`` (presumably transport flow control --
    see BaseProtocol).
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        # The default ``random`` instance is created once, when this function
        # is defined, so writers that do not pass their own share it.
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        # Bytes written since the last drain.
        self._output_size = 0
        # Compressor for ``self.compress``, created on first use and reused.
        self._compressobj: Any = None  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        Args:
            message: the plain (unmasked, uncompressed) payload.
            opcode: the frame opcode (a WSMsgType wire value).
            compress: window bits for a one-off compressor used only for
                this frame; when falsy, ``self.compress`` applies.

        Raises:
            ConnectionResetError: if close() was called and *opcode* is not a
                control opcode, or if the transport is closing.
        """
        # ``opcode & CLOSE`` (0x8) is non-zero for every control opcode
        # (0x8-0xF), so close, ping and pong frames are still allowed here.
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # Only compress larger packets (disabled)
        # Does small packet needs to be compressed?
        # if self.compress and opcode < 8 and len(message) > 124:
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Do not set self._compress if compressing is for this frame
                compressobj = self._make_compress_obj(compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = self._make_compress_obj(self.compress)
                compressobj = self._compressobj

            message = await compressobj.compress(message)
            # Its critical that we do not return control to the event
            # loop until we have finished sending all the compressed
            # data. Otherwise we could end up mixing compressed frames
            # if there are multiple coroutines compressing data.
            message += compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            )
            # The receiver re-appends this trailer before inflating.
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            # RSV1 marks the frame as per-message compressed.
            rsv = rsv | 0x40

        msg_length = len(message)

        use_mask = self.use_mask
        if use_mask:
            mask_bit = 0x80
        else:
            mask_bit = 0

        # 0x80 sets FIN: every frame written here is a complete message.
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
        if use_mask:
            # randrange's stop is exclusive: use 2**32 so that every 32-bit
            # masking key, including 0xFFFFFFFF, can be chosen.
            mask_int = self.randrange(0, 1 << 32)
            mask = mask_int.to_bytes(4, "big")
            message = bytearray(message)
            _websocket_mask(mask, message)
            self._write(header + mask + message)
            self._output_size += len(header) + len(mask) + msg_length
        else:
            if msg_length > MSG_SIZE:
                # Avoid copying a large payload just to prepend the header.
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + msg_length

        # It is safe to return control to the event loop when using compression
        # after this point as we have already sent or buffered all the data.

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
        """Create a fast raw-deflate compressor with *compress* window bits.

        Negative ``wbits`` selects a raw deflate stream without a zlib header.
        """
        return ZLibCompressor(
            level=zlib.Z_BEST_SPEED,
            wbits=-compress,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    def _write(self, data: bytes) -> None:
        """Write *data*, or raise ConnectionResetError if the transport is closing."""
        if self.transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        self.transport.write(data)

    async def pong(self, message: Union[bytes, str] = b"") -> None:
        """Send pong message (a ``str`` is UTF-8 encoded)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: Union[bytes, str] = b"") -> None:
        """Send ping message (a ``str`` is UTF-8 encoded)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        A ``str`` is UTF-8 encoded. *binary* selects a BINARY rather than a
        TEXT frame. *compress* overrides the compression window bits for
        this frame only.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        if binary:
            await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message.

        The writer is marked as closing even if sending the close frame
        fails; after that, only control frames can be sent.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True