Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/http_websocket.py: 27%
379 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:56 +0000
1"""WebSocket protocol versions 13 and 8."""
3import asyncio
4import collections
5import json
6import random
7import re
8import sys
9import zlib
10from enum import IntEnum
11from struct import Struct
12from typing import Any, Callable, List, Optional, Pattern, Set, Tuple, Union, cast
14from .base_protocol import BaseProtocol
15from .helpers import NO_EXTENSIONS
16from .streams import DataQueue
17from .typedefs import Final
# Names exported by ``from ... import *``.  Other module-level helpers
# (e.g. ws_ext_parse, ws_ext_gen, WSCloseCode constants) remain importable
# explicitly but are not part of this list.
__all__ = (
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    "WebSocketReader",
    "WebSocketWriter",
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
class WSCloseCode(IntEnum):
    """Status codes carried in the first two bytes of a CLOSE frame payload.

    Values follow RFC 6455 section 7.4 and the IANA WebSocket close code
    registry (1012-1014).
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014


# Close codes below 3000 that an incoming CLOSE frame may carry; the reader
# rejects any other code < 3000 and accepts codes >= 3000 unconditionally.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
class WSMsgType(IntEnum):
    """Type tag of a :class:`WSMessage`.

    ``CONTINUATION`` .. ``PONG`` reuse the RFC 6455 frame opcodes.
    ``CLOSING``, ``CLOSED`` and ``ERROR`` are aiohttp-specific values above
    0xFF, so they can never equal a 4-bit wire opcode.
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # Lower-case aliases: each name refers to the same member as above.
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
# Fixed GUID defined by RFC 6455 (section 1.3) for the opening handshake.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


# Network byte order (!) helpers for frame headers:
#   LEN2 - 16-bit extended payload length (length field == 126)
#   LEN3 - 64-bit extended payload length (length field == 127)
#   PACK_LEN1/2/3 - first two header bytes, optionally followed by LEN2/LEN3
UNPACK_LEN2 = Struct("!H").unpack_from
UNPACK_LEN3 = Struct("!Q").unpack_from
UNPACK_CLOSE_CODE = Struct("!H").unpack
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
# Unmasked payloads larger than this are written separately from the header
# instead of being concatenated with it.
MSG_SIZE: Final[int] = 2**14
# Default number of written bytes after which the writer awaits a drain.
DEFAULT_LIMIT: Final[int] = 2**16
_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"])


class WSMessage(_WSMessageBase):
    """A ``(type, data, extra)`` message produced by the WebSocket reader.

    As built by the reader in this module: TEXT messages carry a ``str``,
    BINARY/PING/PONG carry the raw payload bytes, and CLOSE carries the close
    code in ``data`` and the decoded reason in ``extra``.
    """

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Return parsed JSON data.

        *loads* is applied to ``data``; pass a custom decoder to override
        :func:`json.loads`.

        .. versionadded:: 0.22
        """
        return loads(self.data)
# Shared payload-less messages of the aiohttp-specific CLOSED / CLOSING types
# (presumably used by callers as connection-state markers - confirm at use sites).
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
class WebSocketError(Exception):
    """WebSocket protocol parser error.

    ``code`` is the close code (typically a :class:`WSCloseCode`) to report
    to the peer; ``str(exc)`` is the human-readable message.
    """

    def __init__(self, code: int, message: str) -> None:
        self.code = code
        super().__init__(code, message)

    def __str__(self) -> str:
        # args is (code, message); only the message is shown.
        return cast(str, self.args[1])
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error."""
120native_byteorder: Final[str] = sys.byteorder
123# Used by _websocket_mask_python
124_XOR_TABLE: Final[List[bytes]] = [bytes(a ^ b for a in range(256)) for b in range(256)]
127def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
128 """Websocket masking function.
130 `mask` is a `bytes` object of length 4; `data` is a `bytearray`
131 object of any length. The contents of `data` are masked with `mask`,
132 as specified in section 5.3 of RFC 6455.
134 Note that this function mutates the `data` argument.
136 This pure-python implementation may be replaced by an optimized
137 version when available.
139 """
140 assert isinstance(data, bytearray), data
141 assert len(mask) == 4, mask
143 if data:
144 a, b, c, d = (_XOR_TABLE[n] for n in mask)
145 data[::4] = data[::4].translate(a)
146 data[1::4] = data[1::4].translate(b)
147 data[2::4] = data[2::4].translate(c)
148 data[3::4] = data[3::4].translate(d)
# Prefer the compiled `_websocket_mask_cython` from the optional `_websocket`
# extension module; fall back to the pure-Python implementation when
# extensions are disabled or the module cannot be imported.
if NO_EXTENSIONS:  # pragma: no cover
    _websocket_mask = _websocket_mask_python
else:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore[import]

        _websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python
161_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
164_WS_EXT_RE: Final[Pattern[str]] = re.compile(
165 r"^(?:;\s*(?:"
166 r"(server_no_context_takeover)|"
167 r"(client_no_context_takeover)|"
168 r"(server_max_window_bits(?:=(\d+))?)|"
169 r"(client_max_window_bits(?:=(\d+))?)))*$"
170)
172_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
175def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
176 if not extstr:
177 return 0, False
179 compress = 0
180 notakeover = False
181 for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
182 defext = ext.group(1)
183 # Return compress = 15 when get `permessage-deflate`
184 if not defext:
185 compress = 15
186 break
187 match = _WS_EXT_RE.match(defext)
188 if match:
189 compress = 15
190 if isserver:
191 # Server never fail to detect compress handshake.
192 # Server does not need to send max wbit to client
193 if match.group(4):
194 compress = int(match.group(4))
195 # Group3 must match if group4 matches
196 # Compress wbit 8 does not support in zlib
197 # If compress level not support,
198 # CONTINUE to next extension
199 if compress > 15 or compress < 9:
200 compress = 0
201 continue
202 if match.group(1):
203 notakeover = True
204 # Ignore regex group 5 & 6 for client_max_window_bits
205 break
206 else:
207 if match.group(6):
208 compress = int(match.group(6))
209 # Group5 must match if group6 matches
210 # Compress wbit 8 does not support in zlib
211 # If compress level not support,
212 # FAIL the parse progress
213 if compress > 15 or compress < 9:
214 raise WSHandshakeError("Invalid window size")
215 if match.group(2):
216 notakeover = True
217 # Ignore regex group 5 & 6 for client_max_window_bits
218 break
219 # Return Fail if client side and not match
220 elif not isserver:
221 raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
223 return compress, notakeover
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build a Sec-WebSocket-Extensions value offering/accepting permessage-deflate.

    Args:
        compress: deflate window bits (9..15); values below 15 are announced
            as ``server_max_window_bits``.
        isserver: when False (client offer), ``client_max_window_bits`` is
            added to tell the server the client accepts that parameter.
        server_notakeover: add ``server_no_context_takeover``.

    Raises:
        ValueError: if *compress* is outside 9..15.
    """
    # client_notakeover=False not used for server
    # zlib does not support wbits=8, hence the lower bound of 9.
    if compress < 9 or compress > 15:
        raise ValueError(
            "Compress wbits must be between 9 and 15, "
            "zlib does not support wbits=8"
        )
    enabledext = ["permessage-deflate"]
    if not isserver:
        enabledext.append("client_max_window_bits")

    if compress < 15:
        enabledext.append("server_max_window_bits=" + str(compress))
    if server_notakeover:
        enabledext.append("server_no_context_takeover")
    return "; ".join(enabledext)
class WSParserState(IntEnum):
    """States of the frame state machine in ``WebSocketReader.parse_frame``."""

    READ_HEADER = 1  # waiting for the 2-byte frame header
    READ_PAYLOAD_LENGTH = 2  # waiting for the 16/64-bit extended length, if any
    READ_PAYLOAD_MASK = 3  # waiting for the 4-byte masking key
    READ_PAYLOAD = 4  # accumulating payload bytes
class WebSocketReader:
    """Incremental RFC 6455 frame parser that feeds complete messages to a queue.

    Raw bytes are pushed in through :meth:`feed_data`.  CLOSE, PING and PONG
    frames are queued as soon as they are parsed; TEXT/BINARY messages are
    queued once their final fragment has arrived (inflated first when the
    message was sent with the RSV1 "compressed" bit).  The first exception
    raised while parsing is stored, forwarded to ``queue.set_exception`` and
    makes every later :meth:`feed_data` call a no-op.
    """

    def __init__(
        self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        # A falsy value (0) disables the message-size limit.
        self._max_msg_size = max_msg_size

        self._exc: Optional[BaseException] = None
        # Payload of the data message currently being assembled.
        self._partial = bytearray()
        self._state = WSParserState.READ_HEADER

        # TEXT/BINARY opcode of a fragmented message in progress, else None.
        self._opcode: Optional[int] = None
        self._frame_fin = False
        self._frame_opcode: Optional[int] = None
        self._frame_payload = bytearray()

        # Unparsed bytes (incomplete header, extended length or mask) kept
        # for the next parse_frame() call.
        self._tail = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_length = 0
        self._payload_length_flag = 0
        # RSV1 ("compressed") bit of the current data message; None until the
        # first data frame has been parsed.
        self._compressed: Optional[bool] = None
        self._decompressobj: Any = None  # zlib.decompressobj actually
        # Whether RSV1 may be set at all (permessage-deflate allowed).
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end-of-stream to the message queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse *data* and push any complete messages onto the queue.

        Returns ``(False, b"")`` after consuming (or buffering) all of *data*;
        ``(True, data)`` unchanged when a previous call already failed; and
        ``(True, b"")`` after recording a new parse error on the queue.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b""

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Dispatch every frame parsed from *data*; see :meth:`feed_data`."""
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                # One raw-deflate decompressor is created lazily and reused
                # for every later compressed message on this reader.
                self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
            if opcode == WSMsgType.CLOSE:
                self._handle_close_frame(fin, opcode, payload)
            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ""), len(payload)
                )
            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ""), len(payload)
                )
            else:
                self._handle_data_frame(fin, opcode, payload, compressed)

        return False, b""

    def _handle_close_frame(
        self, fin: bool, opcode: Optional[int], payload: bytearray
    ) -> None:
        """Validate a CLOSE frame and queue it as a ``WSMsgType.CLOSE`` message.

        The queued message's ``data`` is the close code (0 for an empty
        payload) and ``extra`` the UTF-8 decoded close reason.

        Raises:
            WebSocketError: PROTOCOL_ERROR for a 1-byte payload or a code
                below 3000 that is not in ALLOWED_CLOSE_CODES; INVALID_TEXT
                when the reason is not valid UTF-8.
        """
        if len(payload) >= 2:
            close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
            if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close code: {close_code}",
                )
            try:
                close_message = payload[2:].decode("utf-8")
            except UnicodeDecodeError as exc:
                raise WebSocketError(
                    WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                ) from exc
            msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
        elif payload:
            # A single byte cannot hold the 2-byte close code.
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR,
                f"Invalid close frame: {fin} {opcode} {payload!r}",
            )
        else:
            msg = WSMessage(WSMsgType.CLOSE, 0, "")

        self.queue.feed_data(msg, 0)

    def _handle_data_frame(
        self,
        fin: bool,
        opcode: Optional[int],
        payload: bytearray,
        compressed: Optional[bool],
    ) -> None:
        """Accumulate a TEXT/BINARY/CONTINUATION frame; queue the message on FIN.

        Raises:
            WebSocketError: PROTOCOL_ERROR for a reserved opcode, a
                CONTINUATION with no fragmented message in progress, or a new
                TEXT/BINARY frame while one is in progress (RFC 6455 5.4);
                MESSAGE_TOO_BIG when the assembled payload reaches
                ``max_msg_size`` or inflates beyond it; INVALID_TEXT for a
                TEXT message that is not valid UTF-8.
        """
        if opcode == WSMsgType.CONTINUATION:
            if self._opcode is None:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
                )
            # A continuation inherits the type of the message it continues.
            opcode = self._opcode
        elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY):
            # Control opcodes were dispatched earlier, so this is a reserved
            # opcode (0x3-0x7, 0xB-0xF): always a protocol error, even while
            # a fragmented message is in progress.
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )
        elif self._opcode is not None:
            # Previous frame was non finished: fragments of different
            # messages must not be interleaved.
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR,
                "The opcode in non-fin frame is expected "
                "to be zero, got {!r}".format(opcode),
            )

        self._partial.extend(payload)
        if self._max_msg_size and len(self._partial) >= self._max_msg_size:
            raise WebSocketError(
                WSCloseCode.MESSAGE_TOO_BIG,
                "Message size {} exceeds limit {}".format(
                    len(self._partial), self._max_msg_size
                ),
            )

        if not fin:
            # More fragments follow; remember which message they belong to.
            self._opcode = opcode
            return

        self._opcode = None

        # Decompression must happen only after all fragments were received.
        if compressed:
            self._partial.extend(_WS_DEFLATE_TRAILING)
            payload_merged = self._decompressobj.decompress(
                self._partial, self._max_msg_size
            )
            # Input left unconsumed means the output hit max_msg_size.
            if self._decompressobj.unconsumed_tail:
                left = len(self._decompressobj.unconsumed_tail)
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    "Decompressed message size {} exceeds limit {}".format(
                        self._max_msg_size + left, self._max_msg_size
                    ),
                )
        else:
            payload_merged = bytes(self._partial)

        self._partial.clear()

        if opcode == WSMsgType.TEXT:
            try:
                text = payload_merged.decode("utf-8")
            except UnicodeDecodeError as exc:
                raise WebSocketError(
                    WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                ) from exc
            self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text))
        else:
            self.queue.feed_data(
                WSMessage(WSMsgType.BINARY, payload_merged, ""),
                len(payload_merged),
            )

    def parse_frame(
        self, buf: bytes
    ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
        """Parse every complete frame available in *buf*.

        *buf* is appended to any bytes left over from the previous call.
        Returns ``(fin, opcode, payload, compressed)`` tuples with the payload
        already unmasked.  State for an incomplete frame is kept on ``self``
        (header/length/mask bytes in ``_tail``, payload bytes in
        ``_frame_payload``) so the next call resumes where this one stopped.

        Raises:
            WebSocketError: PROTOCOL_ERROR for unnegotiated reserved bits,
                RSV1 on a control frame or continuation (RFC 7692 6.1), or a
                fragmented or > 125-byte control frame.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b""

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos : start_pos + 2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xF

                    # frame-rsv1/2/3: 1 bit each, MUST be 0 unless negotiated
                    # otherwise.  Only RSV1 can be negotiated (per-message
                    # deflate), and only when self._compress allows it.
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )

                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received fragmented control frame",
                        )

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7F

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Control frame payload cannot be " "larger than 125 bytes",
                        )

                    # RSV1 marks a compressed message and is carried only by
                    # the first frame of a data message (RFC 7692 6.1).
                    # Control frames may be interleaved between fragments, so
                    # they must neither carry RSV1 nor reset the state of the
                    # message in progress.
                    if opcode > 0x7:
                        if rsv1:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                    elif opcode == WSMsgType.CONTINUATION:
                        if rsv1:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                "Received frame with non-zero reserved bits",
                            )
                        if self._compressed is None:
                            self._compressed = False
                    else:
                        self._compressed = bool(rsv1)

                    self._frame_fin = bool(fin)
                    self._frame_opcode = opcode
                    self._has_mask = bool(has_mask)
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length (126 -> 16-bit, 127 -> 64-bit extension)
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos : start_pos + 2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                elif length > 126:
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos : start_pos + 8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD
                        )
                    else:
                        break
                else:
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD
                    )

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos : start_pos + 4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # Everything left in buf belongs to this frame.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos : start_pos + length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        assert self._frame_mask is not None
                        _websocket_mask(self._frame_mask, payload)

                    frames.append(
                        (self._frame_fin, self._frame_opcode, payload, self._compressed)
                    )

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        self._tail = buf[start_pos:]

        return frames
class WebSocketWriter:
    """Frame and send WebSocket messages over an asyncio transport.

    Every frame is written with FIN set.  With ``use_mask`` each frame is
    masked with a fresh 4-byte key drawn from ``random.randrange``.  With
    ``compress`` (deflate window bits; 0 disables compression) data frames
    are deflated and flagged with RSV1.  Once more than ``limit`` bytes have
    been written since the last drain, the sender awaits
    ``protocol._drain_helper()``.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        # NOTE: the default Random() is created once at definition time and
        # shared by all writers that do not pass their own.
        random: Any = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        self._output_size = 0
        self._compressobj: Any = None  # actually compressobj

    async def _send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        *compress*, when given, deflates this frame only, with that many
        window bits and a throw-away compressor; otherwise the
        connection-wide ``self.compress`` setting applies.  Only data frames
        (opcode < 8) are ever compressed.

        Raises:
            ConnectionResetError: if the writer is closing and *opcode* is a
                data frame, or if the transport is missing or closing.
        """
        # `opcode & CLOSE` (0x8) is non-zero for every control opcode (CLOSE,
        # PING, PONG), so control frames may still be sent while closing.
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ConnectionResetError("Cannot write to closing transport")

        rsv = 0

        # Control frames are never compressed, whatever their size.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Do not set self._compressobj if compressing is for this frame
                compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = zlib.compressobj(
                        level=zlib.Z_BEST_SPEED, wbits=-self.compress
                    )
                compressobj = self._compressobj

            message = compressobj.compress(message)
            # Z_FULL_FLUSH additionally resets the compression state, so with
            # no-context-takeover the next message cannot reference this one.
            message = message + compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
            )
            # RFC 7692: drop the 0x00 0x00 0xFF 0xFF tail left by the flush.
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            rsv = rsv | 0x40  # RSV1: per-message compressed

        msg_length = len(message)

        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0

        # Payload length encoding (RFC 6455 5.2): 7-bit value, or the marker
        # 126 + 16-bit length, or the marker 127 + 64-bit length.
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
        if use_mask:
            mask = self.randrange(0, 0xFFFFFFFF).to_bytes(4, "big")
            message = bytearray(message)
            _websocket_mask(mask, message)
            self._write(header + mask + message)
            self._output_size += len(header) + len(mask) + len(message)
        else:
            # Large payloads are written separately to avoid copying them
            # just to prepend the header.
            if len(message) > MSG_SIZE:
                self._write(header)
                self._write(message)
            else:
                self._write(header + message)

            self._output_size += len(header) + len(message)

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    def _write(self, data: bytes) -> None:
        """Write *data* to the transport, refusing once it is closed/closing."""
        if self.transport is None or self.transport.is_closing():
            raise ConnectionResetError("Cannot write to closing transport")
        self.transport.write(data)

    async def pong(self, message: bytes = b"") -> None:
        """Send pong message (a ``str`` is UTF-8 encoded first)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: bytes = b"") -> None:
        """Send ping message (a ``str`` is UTF-8 encoded first)."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        await self._send_frame(message, WSMsgType.PING)

    async def send(
        self,
        message: Union[str, bytes],
        binary: bool = False,
        compress: Optional[int] = None,
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        A ``str`` is UTF-8 encoded first.  The frame is BINARY when *binary*
        is true and TEXT otherwise; *compress* is passed to ``_send_frame``.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        if binary:
            await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code: int = 1000, message: bytes = b"") -> None:
        """Close the websocket, sending the specified code and message.

        The writer is marked as closing even if sending the CLOSE frame
        fails, after which only control frames may be sent.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True