Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/http_websocket.py: 28%
378 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1"""WebSocket protocol versions 13 and 8."""
3import asyncio
4import functools
5import json
6import random
7import re
8import sys
9import zlib
10from enum import IntEnum
11from struct import Struct
12from typing import (
13 Any,
14 Callable,
15 List,
16 NamedTuple,
17 Optional,
18 Pattern,
19 Set,
20 Tuple,
21 Union,
22 cast,
23)
25from typing_extensions import Final
27from .base_protocol import BaseProtocol
28from .compression_utils import ZLibCompressor, ZLibDecompressor
29from .helpers import NO_EXTENSIONS
30from .streams import DataQueue
# Names exported by a star-import of this module.
__all__ = (
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    "WebSocketReader",
    "WebSocketWriter",
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
class WSCloseCode(IntEnum):
    """WebSocket close status codes known to this module.

    Every member is also collected, as a plain ``int``, into
    ``ALLOWED_CLOSE_CODES``.
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
# Plain-int values of every close code declared on WSCloseCode.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {member.value for member in WSCloseCode}
class WSMsgType(IntEnum):
    """Message types: wire opcodes plus aiohttp-internal markers.

    The aiohttp-specific values start at 0x100, so they cannot collide with
    an opcode read from the 4-bit opcode field of a frame header.
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102
# NOTE(review): presumably the fixed GUID used when computing the handshake
# accept key; it is exported but not referenced by the code visible here.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


# Pre-bound ``struct`` helpers; "!" selects network (big-endian) byte order.
UNPACK_LEN2 = Struct("!H").unpack_from  # 16-bit extended payload length
UNPACK_LEN3 = Struct("!Q").unpack_from  # 64-bit extended payload length
UNPACK_CLOSE_CODE = Struct("!H").unpack  # 16-bit close code
PACK_LEN1 = Struct("!BB").pack  # two header bytes, 7-bit length
PACK_LEN2 = Struct("!BBH").pack  # two header bytes + 16-bit length
PACK_LEN3 = Struct("!BBQ").pack  # two header bytes + 64-bit length
PACK_CLOSE_CODE = Struct("!H").pack  # 16-bit close code
# Unmasked payloads longer than this are written separately from the header.
MSG_SIZE: Final[int] = 2**14
# Default ``limit`` of WebSocketWriter: output bytes accumulated before draining.
DEFAULT_LIMIT: Final[int] = 2**16
class WSMessage(NamedTuple):
    """One WebSocket message: its type, its payload and an extra string."""

    type: WSMsgType
    # To type correctly, this would need some kind of tagged union for each type.
    data: Any
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Decode ``data`` with *loads* (``json.loads`` by default).

        .. versionadded:: 0.22
        """
        payload = self.data
        return loads(payload)
# Payload-free messages for the aiohttp-specific CLOSED and CLOSING types.
WS_CLOSED_MESSAGE = WSMessage(type=WSMsgType.CLOSED, data=None, extra=None)
WS_CLOSING_MESSAGE = WSMessage(type=WSMsgType.CLOSING, data=None, extra=None)
class WebSocketError(Exception):
    """Error raised while parsing incoming WebSocket data.

    ``code`` carries the close code associated with the failure; the
    human-readable text is stored as the second element of ``args`` and is
    what ``str(error)`` returns.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        text = self.args[1]
        return cast(str, text)
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error.

    Raised by ``ws_ext_parse`` when a client-side ``permessage-deflate``
    extension string cannot be accepted.
    """
# Byte order of the running interpreter's platform, as reported by ``sys``.
native_byteorder: Final[str] = sys.byteorder
129# Used by _websocket_mask_python
130@functools.lru_cache()
131def _xor_table() -> List[bytes]:
132 return [bytes(a ^ b for a in range(256)) for b in range(256)]
135def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
136 """Websocket masking function.
138 `mask` is a `bytes` object of length 4; `data` is a `bytearray`
139 object of any length. The contents of `data` are masked with `mask`,
140 as specified in section 5.3 of RFC 6455.
142 Note that this function mutates the `data` argument.
144 This pure-python implementation may be replaced by an optimized
145 version when available.
147 """
148 assert isinstance(data, bytearray), data
149 assert len(mask) == 4, mask
151 if data:
152 _XOR_TABLE = _xor_table()
153 a, b, c, d = (_XOR_TABLE[n] for n in mask)
154 data[::4] = data[::4].translate(a)
155 data[1::4] = data[1::4].translate(b)
156 data[2::4] = data[2::4].translate(c)
157 data[3::4] = data[3::4].translate(d)
# Select the masking implementation bound to ``_websocket_mask``: the
# pure-python function when NO_EXTENSIONS is set or the compiled module
# cannot be imported, the Cython function otherwise.
if NO_EXTENSIONS:  # pragma: no cover
    _websocket_mask = _websocket_mask_python
else:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import]

        _websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python
# Four-byte trailer: the reader appends it to a compressed message before
# inflating, and the writer strips it from the deflate output before sending.
_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])


# Matches the parameter list that follows ``permessage-deflate``.
# Capture groups, as consumed by ws_ext_parse:
#   1: server_no_context_takeover
#   2: client_no_context_takeover
#   3: server_max_window_bits[=N]    4: N
#   5: client_max_window_bits[=N]    6: N
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)

# Finds each ``permessage-deflate`` entry; group 1 is its raw parameter text
# (everything up to the next comma) or None when no parameters follow.
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Parse a ``permessage-deflate`` extension string.

    Returns ``(compress, notakeover)``: ``compress`` is the accepted window
    size (0 when no usable entry was found) and ``notakeover`` tells whether
    the relevant ``*_no_context_takeover`` parameter was present.

    On the server side unusable entries are skipped; on the client side
    they raise ``WSHandshakeError``.
    """
    if not extstr:
        return 0, False

    # The two sides read different capture groups of _WS_EXT_RE:
    # the server looks at the server_* parameters, the client at the
    # client_* ones.
    wbits_group = 4 if isserver else 6
    notakeover_group = 1 if isserver else 2

    compress = 0
    notakeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        if not params:
            # Bare `permessage-deflate`: use the default window size.
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            if isserver:
                # Server never fails the handshake; try the next entry.
                continue
            raise WSHandshakeError("Extension for deflate not supported" + offer.group(1))

        compress = 15
        wbits = parsed.group(wbits_group)
        if wbits:
            compress = int(wbits)
            # Compress wbit 8 does not support in zlib
            if not 9 <= compress <= 15:
                if not isserver:
                    raise WSHandshakeError("Invalid window size")
                compress = 0
                continue
        if parsed.group(notakeover_group):
            notakeover = True
        break

    return compress, notakeover
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build the ``permessage-deflate`` extension string.

    ``compress`` is the window size (9..15); values below 15 are announced
    as ``server_max_window_bits``.  A client (``isserver=False``) also adds
    ``client_max_window_bits``.  ``client_no_context_takeover`` is never
    emitted.

    Raises ``ValueError`` when ``compress`` is outside 9..15.
    """
    # compress wbit 8 does not support in zlib
    if not 9 <= compress <= 15:
        raise ValueError(
            "Compress wbits must between 9 and 15, " "zlib does not support wbits=8"
        )

    parts = ["permessage-deflate"]
    if not isserver:
        parts.append("client_max_window_bits")
    if compress < 15:
        parts.append("server_max_window_bits=" + str(compress))
    if server_notakeover:
        parts.append("server_no_context_takeover")
    return "; ".join(parts)
class WSParserState(IntEnum):
    """Stages of ``WebSocketReader.parse_frame``, in the order they run."""

    READ_HEADER = 1  # waiting for the two fixed header bytes
    READ_PAYLOAD_LENGTH = 2  # waiting for the extended length, if any
    READ_PAYLOAD_MASK = 3  # waiting for the 4-byte mask
    READ_PAYLOAD = 4  # collecting payload bytes
class WebSocketReader:
    """Incremental parser for incoming WebSocket data.

    Bytes handed to ``feed_data`` are split into frames by ``parse_frame``;
    complete messages are pushed onto ``queue`` as ``WSMessage`` objects.
    The first exception raised while parsing is stored, passed to
    ``queue.set_exception`` and makes later ``feed_data`` calls return
    immediately.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        """Set up an idle parser.

        ``queue`` receives the parsed messages.  ``max_msg_size`` bounds the
        size of an assembled message; a falsy value disables the check.
        With ``compress`` false, frames that have RSV1 set are rejected.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        # First parsing error, if any; once set, feed_data stops parsing.
        self._exc: Optional[BaseException] = None
        # Payload of the message being assembled from fragments.
        self._partial = bytearray()
        self._state = WSParserState.READ_HEADER

        # Opcode of the fragmented message in progress (None when idle).
        self._opcode: Optional[int] = None
        # Header fields and payload of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload = bytearray()

        # Bytes left over from the previous parse_frame call.
        self._tail = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        # Payload bytes still missing for the current frame.
        self._payload_length = 0
        # Raw 7-bit length field from the header (126/127 mean "extended").
        self._payload_length_flag = 0
        # RSV1 status of the current message; None before the first frame.
        self._compressed: Optional[bool] = None
        # Created lazily when the first compressed frame shows up.
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end-of-stream to the queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse ``data`` and queue the resulting messages.

        Returns ``(True, data)`` untouched when an earlier call already
        failed, ``(True, b"")`` when this call raised (the exception is stored
        and passed to the queue), otherwise the result of ``_feed_data``.
        NOTE(review): the tuple presumably means (stop flag, unprocessed
        bytes) for the caller — confirm against the protocol code.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b""

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Turn every complete frame in ``data`` into queued messages.

        Always returns ``(False, b"")``; protocol violations raise
        ``WebSocketError``.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    # Codes below 3000 must be one of the codes named in
                    # WSCloseCode; 3000 and above are accepted as-is.
                    if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            f"Invalid close code: {close_code}",
                        )
                    try:
                        close_message = payload[2:].decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A single payload byte cannot hold a close code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close frame: {fin} {opcode} {payload!r}",
                    )
                else:
                    # Empty close frame: reported with code 0 and empty text.
                    msg = WSMessage(WSMsgType.CLOSE, 0, "")

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )

            elif (
                opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
                and self._opcode is None
            ):
                # Neither a new data message nor part of one in progress.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )
            else:
                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        # First fragment: remember the real message type.
                        self._opcode = opcode
                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )
                else:
                    # previous frame was non finished
                    # we should get continuation opcode
                    if self._partial:
                        if opcode != WSMsgType.CONTINUATION:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "The opcode in non-fin frame is expected "
                                "to be zero, got {!r}".format(opcode),
                            )

                    if opcode == WSMsgType.CONTINUATION:
                        # Final fragment: restore the opcode of the first one.
                        assert self._opcode is not None
                        opcode = self._opcode
                        self._opcode = None

                    self._partial.extend(payload)
                    if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size
                            ),
                        )

                    # Decompression must be done after all fragments of the
                    # message have been received.
                    if compressed:
                        assert self._decompressobj is not None
                        self._partial.extend(_WS_DEFLATE_TRAILING)
                        payload_merged = self._decompressobj.decompress_sync(
                            self._partial, self._max_msg_size
                        )
                        # Leftover input means the output hit the size limit.
                        if self._decompressobj.unconsumed_tail:
                            left = len(self._decompressobj.unconsumed_tail)
                            raise WebSocketError(
                                WSCloseCode.MESSAGE_TOO_BIG,
                                "Decompressed message size {} exceeds limit {}".format(
                                    self._max_msg_size + left, self._max_msg_size
                                ),
                            )
                    else:
                        payload_merged = bytes(self._partial)

                    self._partial.clear()

                    if opcode == WSMsgType.TEXT:
                        try:
                            text = payload_merged.decode("utf-8")
                            self.queue.feed_data(
                                WSMessage(WSMsgType.TEXT, text, ""), len(text)
                            )
                        except UnicodeDecodeError as exc:
                            raise WebSocketError(
                                WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                            ) from exc
                    else:
                        self.queue.feed_data(
                            WSMessage(WSMsgType.BINARY, payload_merged, ""),
                            len(payload_merged),
                        )

        return False, b""

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Parse ``buf`` and return every frame completed by it.

        Each entry is ``(fin, opcode, payload, compressed)``.  Parser state
        is kept on the instance between calls: unconsumed header bytes are
        saved in ``self._tail`` and a partially received payload in
        ``self._frame_payload``.  Raises ``WebSocketError`` on malformed
        headers.
        """
        frames = []
        if self._tail:
            # Prepend what the previous call could not consume.
            buf, self._tail = self._tail + buf, b""

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos : start_pos + 2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xF

                    # frame-fin = %x0 ; more frames of this message follow
                    #           / %x1 ; final frame of this message
                    # frame-rsv1 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv2 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv3 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    #
                    # rsv1 is only tolerated when compression is enabled
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    # Opcodes above 0x7 are control frames.
                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received fragmented control frame",
                        )

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7F

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Control frame payload cannot be " "larger than 125 bytes",
                        )

                    # Set compress status if last package is FIN
                    # OR set compress status if this is first fragment
                    # Raise error if not first fragment with rsv1 = 0x1
                    # NOTE(review): ``_frame_fin`` is also updated by control
                    # frames, so a PING/PONG interleaved inside a fragmented
                    # compressed message looks like it makes the following
                    # continuation frame reset ``_compressed`` — confirm.
                    if self._frame_fin or self._compressed is None:
                        self._compressed = True if rsv1 else False
                    elif rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    self._frame_fin = bool(fin)
                    self._frame_opcode = opcode
                    self._has_mask = bool(has_mask)
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    # 16-bit extended length follows.
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos : start_pos + 2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                elif length > 126:
                    # The field is 7 bits wide, so this is 127:
                    # 64-bit extended length follows.
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos : start_pos + 8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                else:
                    # The 7-bit field is the length itself.
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD
                    )

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos : start_pos + 4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # Everything left in the buffer belongs to this frame.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    # The frame ends inside the buffer; more frames follow.
                    self._payload_length = 0
                    payload.extend(buf[start_pos : start_pos + length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    # Unmask only once the whole payload has arrived.
                    if self._has_mask:
                        assert self._frame_mask is not None
                        _websocket_mask(self._frame_mask, payload)

                    frames.append(
                        (self._frame_fin, self._frame_opcode, payload, self._compressed)
                    )

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        # Keep unparsed bytes for the next call.
        self._tail = buf[start_pos:]

        return frames
class WebSocketWriter:
    """Build WebSocket frames and write them to a transport."""

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: Any = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Store the connection objects and framing options.

        ``use_mask`` masks every outgoing payload, ``limit`` is the number of
        written bytes after which the protocol is drained, ``compress`` is
        the deflate window size (0 disables compression) and ``notakeover``
        selects a full flush after each compressed message.
        """
        # Connection objects.
        self.protocol = protocol
        self.transport = transport
        # Framing options.
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        # Bookkeeping.
        self._closing = False
        self._limit = limit
        self._output_size = 0
        self._compressobj: Any = None  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload."""
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # Only frames with opcode < 8 are compressed, whatever their size.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Per-call window size: use a throw-away compressor and
                # leave the shared one alone.
                deflater = ZLibCompressor(level=zlib.Z_BEST_SPEED, wbits=-compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = ZLibCompressor(
                        level=zlib.Z_BEST_SPEED, wbits=-self.compress
                    )
                deflater = self._compressobj

            message = await deflater.compress(message)
            flush_mode = zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            message += deflater.flush(flush_mode)
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            rsv |= 0x40

        msg_length = len(message)
        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0
        # FIN bit is always set: every message goes out as a single frame.
        first_byte = 0x80 | rsv | opcode

        if msg_length < 126:
            header = PACK_LEN1(first_byte, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)

        if use_mask:
            mask = self.randrange(0, 0xFFFFFFFF).to_bytes(4, "big")
            masked = bytearray(message)
            _websocket_mask(mask, masked)
            frame = header + mask + masked
            self._write(frame)
            self._output_size += len(frame)
        else:
            if len(message) > MSG_SIZE:
                # Large payload: skip the concatenation and write it as-is.
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + len(message)

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _write(self, data: bytes) -> None:
        """Write ``data`` to the transport unless it is gone or closing."""
        transport = self.transport
        if transport is None or transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        transport.write(data)

    async def pong(self, message: Union[bytes, str] = b"") -> None:
        """Send pong message."""
        payload = message.encode("utf-8") if isinstance(message, str) else message
        await self._send_frame(payload, WSMsgType.PONG)

    async def ping(self, message: Union[bytes, str] = b"") -> None:
        """Send ping message."""
        payload = message.encode("utf-8") if isinstance(message, str) else message
        await self._send_frame(payload, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send a frame over the websocket with message as its payload."""
        payload = message.encode("utf-8") if isinstance(message, str) else message
        opcode = WSMsgType.BINARY if binary else WSMsgType.TEXT
        await self._send_frame(payload, opcode, compress)

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message."""
        payload = message.encode("utf-8") if isinstance(message, str) else message
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + payload, opcode=WSMsgType.CLOSE
            )
        finally:
            # Mark as closing even when sending the close frame failed.
            self._closing = True