Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/reader_py.py: 17%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Reader for WebSocket protocol versions 13 and 8."""
3import asyncio
4import builtins
5from collections import deque
6from typing import Deque, Final, Optional, Set, Tuple, Type, Union
8from ..base_protocol import BaseProtocol
9from ..compression_utils import ZLibDecompressor
10from ..helpers import _EXC_SENTINEL, set_exception
11from ..streams import EofStream
12from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
13from .models import (
14 WS_DEFLATE_TRAILING,
15 WebSocketError,
16 WSCloseCode,
17 WSMessage,
18 WSMessageBinary,
19 WSMessageClose,
20 WSMessagePing,
21 WSMessagePong,
22 WSMessageText,
23 WSMsgType,
24)
# Integer value of every member of WSCloseCode.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(close_code) for close_code in WSCloseCode}

# Parser states of WebSocketReader while it walks through a frame.
# Plain integers are used (instead of an enum) so they can be cythonized.
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# Raw integer values of the WSMsgType opcodes, unpacked so they can be
# cythonized to ints. OP_CODE_NOT_SET is a sentinel outside the opcode range.
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Pre-built (error_flag, tail) results handed back by WebSocketReader.feed_data.
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Tri-state compression flag of the message being parsed.
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

# Cached tuple.__new__, used to build message tuples without calling their
# constructors.
TUPLE_NEW = tuple.__new__

# Typed to int in Python; Cython will use a signed C int in the pxd.
cython_int = int
class WebSocketDataQueue:
    """Buffer of parsed WebSocket messages that applies back-pressure.

    Messages produced by the frame parser are appended here and handed out
    one at a time by :meth:`read`. The summed ``size`` of the buffered
    messages controls the underlying protocol: reading is paused once that
    sum exceeds twice ``limit`` and resumed when it falls back below it.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._size = 0
        self._protocol = protocol
        # High-water mark for the buffered size (twice the requested limit).
        self._limit = limit * 2
        self._loop = loop
        self._eof = False
        # Future awaited by a reader while the buffer is empty.
        self._waiter: Optional[asyncio.Future[None]] = None
        self._exception: Union[Type[BaseException], BaseException, None] = None
        self._buffer: Deque[WSMessage] = deque()
        # Bound deque methods cached for the hot path.
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        """Return True once EOF or an exception has been signalled."""
        return self._eof

    def exception(self) -> Optional[Union[Type[BaseException], BaseException]]:
        """Return the exception set on the queue, if any."""
        return self._exception

    def set_exception(
        self,
        exc: Union[Type[BaseException], BaseException],
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue as failed and propagate *exc* to a pending reader."""
        self._eof = True
        self._exception = exc
        pending = self._waiter
        if pending is None:
            return
        self._waiter = None
        set_exception(pending, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake the pending reader (if any) so it re-checks the buffer."""
        pending = self._waiter
        if pending is not None:
            self._waiter = None
            if not pending.done():
                pending.set_result(None)

    def feed_eof(self) -> None:
        """Signal end of stream and wake a pending reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage") -> None:
        """Append a parsed message, pausing the protocol above the limit."""
        self._size += data.size
        self._put_buffer(data)
        self._release_waiter()
        protocol = self._protocol
        if self._size > self._limit and not protocol._reading_paused:
            protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting until one is available.

        Raises the stored exception, or ``EofStream``, once the buffer is
        empty and the stream has ended.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()
        assert not self._waiter
        pending = self._waiter = self._loop.create_future()
        try:
            await pending
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop one message, resuming the protocol if the buffer drained."""
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream
        message = self._get_buffer()
        self._size -= message.size
        protocol = self._protocol
        if self._size < self._limit and protocol._reading_paused:
            protocol.resume_reading()
        return message
140class WebSocketReader:
141 def __init__(
142 self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
143 ) -> None:
144 self.queue = queue
145 self._max_msg_size = max_msg_size
147 self._exc: Optional[Exception] = None
148 self._partial = bytearray()
149 self._state = READ_HEADER
151 self._opcode: int = OP_CODE_NOT_SET
152 self._frame_fin = False
153 self._frame_opcode: int = OP_CODE_NOT_SET
154 self._payload_fragments: list[bytes] = []
155 self._frame_payload_len = 0
157 self._tail: bytes = b""
158 self._has_mask = False
159 self._frame_mask: Optional[bytes] = None
160 self._payload_bytes_to_read = 0
161 self._payload_len_flag = 0
162 self._compressed: int = COMPRESSED_NOT_SET
163 self._decompressobj: Optional[ZLibDecompressor] = None
164 self._compress = compress
166 def feed_eof(self) -> None:
167 self.queue.feed_eof()
169 # data can be bytearray on Windows because proactor event loop uses bytearray
170 # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
171 # coerce data to bytes if it is not
172 def feed_data(
173 self, data: Union[bytes, bytearray, memoryview]
174 ) -> Tuple[bool, bytes]:
175 if type(data) is not bytes:
176 data = bytes(data)
178 if self._exc is not None:
179 return True, data
181 try:
182 self._feed_data(data)
183 except Exception as exc:
184 self._exc = exc
185 set_exception(self.queue, exc)
186 return EMPTY_FRAME_ERROR
188 return EMPTY_FRAME
190 def _handle_frame(
191 self,
192 fin: bool,
193 opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
194 payload: Union[bytes, bytearray],
195 compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
196 ) -> None:
197 msg: WSMessage
198 if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
199 # load text/binary
200 if not fin:
201 # got partial frame payload
202 if opcode != OP_CODE_CONTINUATION:
203 self._opcode = opcode
204 self._partial += payload
205 if self._max_msg_size and len(self._partial) >= self._max_msg_size:
206 raise WebSocketError(
207 WSCloseCode.MESSAGE_TOO_BIG,
208 f"Message size {len(self._partial)} "
209 f"exceeds limit {self._max_msg_size}",
210 )
211 return
213 has_partial = bool(self._partial)
214 if opcode == OP_CODE_CONTINUATION:
215 if self._opcode == OP_CODE_NOT_SET:
216 raise WebSocketError(
217 WSCloseCode.PROTOCOL_ERROR,
218 "Continuation frame for non started message",
219 )
220 opcode = self._opcode
221 self._opcode = OP_CODE_NOT_SET
222 # previous frame was non finished
223 # we should get continuation opcode
224 elif has_partial:
225 raise WebSocketError(
226 WSCloseCode.PROTOCOL_ERROR,
227 "The opcode in non-fin frame is expected "
228 f"to be zero, got {opcode!r}",
229 )
231 assembled_payload: Union[bytes, bytearray]
232 if has_partial:
233 assembled_payload = self._partial + payload
234 self._partial.clear()
235 else:
236 assembled_payload = payload
238 if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
239 raise WebSocketError(
240 WSCloseCode.MESSAGE_TOO_BIG,
241 f"Message size {len(assembled_payload)} "
242 f"exceeds limit {self._max_msg_size}",
243 )
245 # Decompress process must to be done after all packets
246 # received.
247 if compressed:
248 if not self._decompressobj:
249 self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
250 # XXX: It's possible that the zlib backend (isal is known to
251 # do this, maybe others too?) will return max_length bytes,
252 # but internally buffer more data such that the payload is
253 # >max_length, so we return one extra byte and if we're able
254 # to do that, then the message is too big.
255 payload_merged = self._decompressobj.decompress_sync(
256 assembled_payload + WS_DEFLATE_TRAILING,
257 (
258 self._max_msg_size + 1
259 if self._max_msg_size
260 else self._max_msg_size
261 ),
262 )
263 if self._max_msg_size and len(payload_merged) > self._max_msg_size:
264 raise WebSocketError(
265 WSCloseCode.MESSAGE_TOO_BIG,
266 f"Decompressed message exceeds size limit {self._max_msg_size}",
267 )
268 elif type(assembled_payload) is bytes:
269 payload_merged = assembled_payload
270 else:
271 payload_merged = bytes(assembled_payload)
273 size = len(payload_merged)
274 if opcode == OP_CODE_TEXT:
275 try:
276 text = payload_merged.decode("utf-8")
277 except UnicodeDecodeError as exc:
278 raise WebSocketError(
279 WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
280 ) from exc
282 # XXX: The Text and Binary messages here can be a performance
283 # bottleneck, so we use tuple.__new__ to improve performance.
284 # This is not type safe, but many tests should fail in
285 # test_client_ws_functional.py if this is wrong.
286 msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
287 else:
288 msg = TUPLE_NEW(
289 WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
290 )
292 self.queue.feed_data(msg)
293 elif opcode == OP_CODE_CLOSE:
294 payload_len = len(payload)
295 if payload_len >= 2:
296 close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
297 if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
298 raise WebSocketError(
299 WSCloseCode.PROTOCOL_ERROR,
300 f"Invalid close code: {close_code}",
301 )
302 try:
303 close_message = payload[2:].decode("utf-8")
304 except UnicodeDecodeError as exc:
305 raise WebSocketError(
306 WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
307 ) from exc
308 msg = WSMessageClose(
309 data=close_code, size=payload_len, extra=close_message
310 )
311 elif payload:
312 raise WebSocketError(
313 WSCloseCode.PROTOCOL_ERROR,
314 f"Invalid close frame: {fin} {opcode} {payload!r}",
315 )
316 else:
317 msg = WSMessageClose(data=0, size=payload_len, extra="")
319 self.queue.feed_data(msg)
320 elif opcode == OP_CODE_PING:
321 self.queue.feed_data(
322 WSMessagePing(data=bytes(payload), size=len(payload), extra="")
323 )
324 elif opcode == OP_CODE_PONG:
325 self.queue.feed_data(
326 WSMessagePong(data=bytes(payload), size=len(payload), extra="")
327 )
328 else:
329 raise WebSocketError(
330 WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
331 )
333 def _feed_data(self, data: bytes) -> None:
334 """Return the next frame from the socket."""
335 if self._tail:
336 data, self._tail = self._tail + data, b""
338 start_pos: int = 0
339 data_len = len(data)
340 data_cstr = data
342 while True:
343 # read header
344 if self._state == READ_HEADER:
345 if data_len - start_pos < 2:
346 break
347 first_byte = data_cstr[start_pos]
348 second_byte = data_cstr[start_pos + 1]
349 start_pos += 2
351 fin = (first_byte >> 7) & 1
352 rsv1 = (first_byte >> 6) & 1
353 rsv2 = (first_byte >> 5) & 1
354 rsv3 = (first_byte >> 4) & 1
355 opcode = first_byte & 0xF
357 # frame-fin = %x0 ; more frames of this message follow
358 # / %x1 ; final frame of this message
359 # frame-rsv1 = %x0 ;
360 # 1 bit, MUST be 0 unless negotiated otherwise
361 # frame-rsv2 = %x0 ;
362 # 1 bit, MUST be 0 unless negotiated otherwise
363 # frame-rsv3 = %x0 ;
364 # 1 bit, MUST be 0 unless negotiated otherwise
365 #
366 # Remove rsv1 from this test for deflate development
367 if rsv2 or rsv3 or (rsv1 and not self._compress):
368 raise WebSocketError(
369 WSCloseCode.PROTOCOL_ERROR,
370 "Received frame with non-zero reserved bits",
371 )
373 if opcode > 0x7 and fin == 0:
374 raise WebSocketError(
375 WSCloseCode.PROTOCOL_ERROR,
376 "Received fragmented control frame",
377 )
379 has_mask = (second_byte >> 7) & 1
380 length = second_byte & 0x7F
382 # Control frames MUST have a payload
383 # length of 125 bytes or less
384 if opcode > 0x7 and length > 125:
385 raise WebSocketError(
386 WSCloseCode.PROTOCOL_ERROR,
387 "Control frame payload cannot be larger than 125 bytes",
388 )
390 # Set compress status if last package is FIN
391 # OR set compress status if this is first fragment
392 # Raise error if not first fragment with rsv1 = 0x1
393 if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
394 self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
395 elif rsv1:
396 raise WebSocketError(
397 WSCloseCode.PROTOCOL_ERROR,
398 "Received frame with non-zero reserved bits",
399 )
401 self._frame_fin = bool(fin)
402 self._frame_opcode = opcode
403 self._has_mask = bool(has_mask)
404 self._payload_len_flag = length
405 self._state = READ_PAYLOAD_LENGTH
407 # read payload length
408 if self._state == READ_PAYLOAD_LENGTH:
409 len_flag = self._payload_len_flag
410 if len_flag == 126:
411 if data_len - start_pos < 2:
412 break
413 first_byte = data_cstr[start_pos]
414 second_byte = data_cstr[start_pos + 1]
415 start_pos += 2
416 self._payload_bytes_to_read = first_byte << 8 | second_byte
417 elif len_flag > 126:
418 if data_len - start_pos < 8:
419 break
420 self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
421 start_pos += 8
422 else:
423 self._payload_bytes_to_read = len_flag
425 self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
427 # read payload mask
428 if self._state == READ_PAYLOAD_MASK:
429 if data_len - start_pos < 4:
430 break
431 self._frame_mask = data_cstr[start_pos : start_pos + 4]
432 start_pos += 4
433 self._state = READ_PAYLOAD
435 if self._state == READ_PAYLOAD:
436 chunk_len = data_len - start_pos
437 if self._payload_bytes_to_read >= chunk_len:
438 f_end_pos = data_len
439 self._payload_bytes_to_read -= chunk_len
440 else:
441 f_end_pos = start_pos + self._payload_bytes_to_read
442 self._payload_bytes_to_read = 0
444 had_fragments = self._frame_payload_len
445 self._frame_payload_len += f_end_pos - start_pos
446 f_start_pos = start_pos
447 start_pos = f_end_pos
449 if self._payload_bytes_to_read != 0:
450 # If we don't have a complete frame, we need to save the
451 # data for the next call to feed_data.
452 self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
453 break
455 payload: Union[bytes, bytearray]
456 if had_fragments:
457 # We have to join the payload fragments get the payload
458 self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
459 if self._has_mask:
460 assert self._frame_mask is not None
461 payload_bytearray = bytearray(b"".join(self._payload_fragments))
462 websocket_mask(self._frame_mask, payload_bytearray)
463 payload = payload_bytearray
464 else:
465 payload = b"".join(self._payload_fragments)
466 self._payload_fragments.clear()
467 elif self._has_mask:
468 assert self._frame_mask is not None
469 payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
470 if type(payload_bytearray) is not bytearray: # pragma: no branch
471 # Cython will do the conversion for us
472 # but we need to do it for Python and we
473 # will always get here in Python
474 payload_bytearray = bytearray(payload_bytearray)
475 websocket_mask(self._frame_mask, payload_bytearray)
476 payload = payload_bytearray
477 else:
478 payload = data_cstr[f_start_pos:f_end_pos]
480 self._handle_frame(
481 self._frame_fin, self._frame_opcode, payload, self._compressed
482 )
483 self._frame_payload_len = 0
484 self._state = READ_HEADER
486 # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
487 self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""