Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/reader_py.py: 17%
Shortcuts on this page:
r m x: toggle line displays
j k: next/prev highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"""Reader for WebSocket protocol versions 13 and 8."""
3import asyncio
4import builtins
5from collections import deque
6from typing import Final
8from ..base_protocol import BaseProtocol
9from ..compression_utils import ZLibDecompressor
10from ..helpers import _EXC_SENTINEL, set_exception
11from ..streams import EofStream
12from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
13from .models import (
14 WS_DEFLATE_TRAILING,
15 WebSocketError,
16 WSCloseCode,
17 WSMessage,
18 WSMessageBinary,
19 WSMessageClose,
20 WSMessagePing,
21 WSMessagePong,
22 WSMessageText,
23 WSMessageTextBytes,
24 WSMsgType,
25)
# Close codes (below 3000) that may legitimately appear in a received close
# frame; codes >= 3000 are accepted without consulting this set.
ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode}

# States of the reader's frame-parsing state machine.
# Integer values are used so they can be cythonized.
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can be cythonized to ints.
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Preallocated return values of WebSocketReader.feed_data:
# (error occurred, data handed back to the caller).
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Compression status of the data message currently being parsed.
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

# Used to build message tuples directly, bypassing the message classes'
# own constructors on the hot path.
TUPLE_NEW = tuple.__new__

cython_int = int  # Typed as int in Python; Cython will use a signed C int (see the pxd)
class WebSocketDataQueue:
    """Buffer of parsed WebSocket messages that applies flow control.

    The reader pushes messages in with ``feed_data`` and consumers pull
    them out with ``read``.  The summed ``size`` of buffered messages is
    tracked: once it rises above twice the configured ``limit`` the
    underlying protocol is paused, and it is resumed when the buffered
    size falls back below that mark.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._protocol = protocol
        self._loop = loop
        # High-water mark for pausing/resuming the protocol.
        self._limit = limit * 2
        self._size = 0
        self._eof = False
        self._exception: type[BaseException] | BaseException | None = None
        # Future awaited by a single pending read() while the buffer is empty.
        self._waiter: asyncio.Future[None] | None = None
        self._buffer: deque[WSMessage] = deque()
        # Pre-bound deque methods to avoid attribute lookups on the hot path.
        self._put_buffer = self._buffer.append
        self._get_buffer = self._buffer.popleft

    def is_eof(self) -> bool:
        """Return True once EOF was fed or an exception was set."""
        return self._eof

    def exception(self) -> type[BaseException] | BaseException | None:
        """Return the stored exception, or None if there is none."""
        return self._exception

    def set_exception(
        self,
        exc: type[BaseException] | BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue as finished with *exc* and fail a pending reader."""
        self._eof = True
        self._exception = exc
        waiter = self._waiter
        if waiter is None:
            return
        self._waiter = None
        set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake up a pending read(), if there is one."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def feed_eof(self) -> None:
        """Signal end of stream and wake up a pending reader."""
        self._eof = True
        self._release_waiter()
        # Drop the stored exception so it cannot keep a reference cycle alive.
        self._exception = None

    def feed_data(self, data: "WSMessage") -> None:
        """Append *data*; pause the protocol when the buffer grows too large."""
        self._size += data.size
        self._put_buffer(data)
        self._release_waiter()
        protocol = self._protocol
        if self._size > self._limit and not protocol._reading_paused:
            protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting while the buffer is empty.

        Once the stream has ended and the buffer is drained, the stored
        exception is raised, or ``EofStream`` if there is none.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()
        assert not self._waiter
        waiter = self._waiter = self._loop.create_future()
        try:
            await waiter
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message, resuming the protocol when below the limit."""
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream
        message = self._get_buffer()
        self._size -= message.size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return message
141class WebSocketReader:
142 def __init__(
143 self,
144 queue: WebSocketDataQueue,
145 max_msg_size: int,
146 compress: bool = True,
147 decode_text: bool = True,
148 ) -> None:
149 self.queue = queue
150 self._max_msg_size = max_msg_size
151 self._decode_text = decode_text
153 self._exc: Exception | None = None
154 self._partial = bytearray()
155 self._state = READ_HEADER
157 self._opcode: int = OP_CODE_NOT_SET
158 self._frame_fin = False
159 self._frame_opcode: int = OP_CODE_NOT_SET
160 self._payload_fragments: list[bytes] = []
161 self._frame_payload_len = 0
163 self._tail: bytes = b""
164 self._has_mask = False
165 self._frame_mask: bytes | None = None
166 self._payload_bytes_to_read = 0
167 self._payload_len_flag = 0
168 self._compressed: int = COMPRESSED_NOT_SET
169 self._decompressobj: ZLibDecompressor | None = None
170 self._compress = compress
172 def feed_eof(self) -> None:
173 self.queue.feed_eof()
175 # data can be bytearray on Windows because proactor event loop uses bytearray
176 # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
177 # coerce data to bytes if it is not
178 def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
179 if type(data) is not bytes:
180 data = bytes(data)
182 if self._exc is not None:
183 return True, data
185 try:
186 self._feed_data(data)
187 except Exception as exc:
188 self._exc = exc
189 set_exception(self.queue, exc)
190 return EMPTY_FRAME_ERROR
192 return EMPTY_FRAME
194 def _handle_frame(
195 self,
196 fin: bool,
197 opcode: int | cython_int, # Union intended: Cython pxd uses C int
198 payload: bytes | bytearray,
199 compressed: int | cython_int, # Union intended: Cython pxd uses C int
200 ) -> None:
201 msg: WSMessage
202 if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
203 # Validate continuation frames before processing
204 if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
205 raise WebSocketError(
206 WSCloseCode.PROTOCOL_ERROR,
207 "Continuation frame for non started message",
208 )
210 # load text/binary
211 if not fin:
212 # got partial frame payload
213 if opcode != OP_CODE_CONTINUATION:
214 self._opcode = opcode
215 self._partial += payload
216 if self._max_msg_size and len(self._partial) >= self._max_msg_size:
217 raise WebSocketError(
218 WSCloseCode.MESSAGE_TOO_BIG,
219 f"Message size {len(self._partial)} "
220 f"exceeds limit {self._max_msg_size}",
221 )
222 return
224 has_partial = bool(self._partial)
225 if opcode == OP_CODE_CONTINUATION:
226 opcode = self._opcode
227 self._opcode = OP_CODE_NOT_SET
228 # previous frame was non finished
229 # we should get continuation opcode
230 elif has_partial:
231 raise WebSocketError(
232 WSCloseCode.PROTOCOL_ERROR,
233 "The opcode in non-fin frame is expected "
234 f"to be zero, got {opcode!r}",
235 )
237 assembled_payload: bytes | bytearray
238 if has_partial:
239 assembled_payload = self._partial + payload
240 self._partial.clear()
241 else:
242 assembled_payload = payload
244 if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
245 raise WebSocketError(
246 WSCloseCode.MESSAGE_TOO_BIG,
247 f"Message size {len(assembled_payload)} "
248 f"exceeds limit {self._max_msg_size}",
249 )
251 # Decompress process must to be done after all packets
252 # received.
253 if compressed:
254 if not self._decompressobj:
255 self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
256 # XXX: It's possible that the zlib backend (isal is known to
257 # do this, maybe others too?) will return max_length bytes,
258 # but internally buffer more data such that the payload is
259 # >max_length, so we return one extra byte and if we're able
260 # to do that, then the message is too big.
261 payload_merged = self._decompressobj.decompress_sync(
262 assembled_payload + WS_DEFLATE_TRAILING,
263 (
264 self._max_msg_size + 1
265 if self._max_msg_size
266 else self._max_msg_size
267 ),
268 )
269 if self._max_msg_size and len(payload_merged) > self._max_msg_size:
270 raise WebSocketError(
271 WSCloseCode.MESSAGE_TOO_BIG,
272 f"Decompressed message exceeds size limit {self._max_msg_size}",
273 )
274 elif type(assembled_payload) is bytes:
275 payload_merged = assembled_payload
276 else:
277 payload_merged = bytes(assembled_payload)
279 size = len(payload_merged)
280 if opcode == OP_CODE_TEXT:
281 if self._decode_text:
282 try:
283 text = payload_merged.decode("utf-8")
284 except UnicodeDecodeError as exc:
285 raise WebSocketError(
286 WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
287 ) from exc
289 # XXX: The Text and Binary messages here can be a performance
290 # bottleneck, so we use tuple.__new__ to improve performance.
291 # This is not type safe, but many tests should fail in
292 # test_client_ws_functional.py if this is wrong.
293 msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
294 else:
295 # Return raw bytes for TEXT messages when decode_text=False
296 msg = TUPLE_NEW(
297 WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
298 )
299 else:
300 msg = TUPLE_NEW(
301 WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
302 )
304 self.queue.feed_data(msg)
305 elif opcode == OP_CODE_CLOSE:
306 payload_len = len(payload)
307 if payload_len >= 2:
308 close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
309 if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
310 raise WebSocketError(
311 WSCloseCode.PROTOCOL_ERROR,
312 f"Invalid close code: {close_code}",
313 )
314 try:
315 close_message = payload[2:].decode("utf-8")
316 except UnicodeDecodeError as exc:
317 raise WebSocketError(
318 WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
319 ) from exc
320 msg = WSMessageClose(
321 data=close_code, size=payload_len, extra=close_message
322 )
323 elif payload:
324 raise WebSocketError(
325 WSCloseCode.PROTOCOL_ERROR,
326 f"Invalid close frame: {fin} {opcode} {payload!r}",
327 )
328 else:
329 msg = WSMessageClose(data=0, size=payload_len, extra="")
331 self.queue.feed_data(msg)
332 elif opcode == OP_CODE_PING:
333 self.queue.feed_data(
334 WSMessagePing(data=bytes(payload), size=len(payload), extra="")
335 )
336 elif opcode == OP_CODE_PONG:
337 self.queue.feed_data(
338 WSMessagePong(data=bytes(payload), size=len(payload), extra="")
339 )
340 else:
341 raise WebSocketError(
342 WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
343 )
345 def _feed_data(self, data: bytes) -> None:
346 """Return the next frame from the socket."""
347 if self._tail:
348 data, self._tail = self._tail + data, b""
350 start_pos: int = 0
351 data_len = len(data)
352 data_cstr = data
354 while True:
355 # read header
356 if self._state == READ_HEADER:
357 if data_len - start_pos < 2:
358 break
359 first_byte = data_cstr[start_pos]
360 second_byte = data_cstr[start_pos + 1]
361 start_pos += 2
363 fin = (first_byte >> 7) & 1
364 rsv1 = (first_byte >> 6) & 1
365 rsv2 = (first_byte >> 5) & 1
366 rsv3 = (first_byte >> 4) & 1
367 opcode = first_byte & 0xF
369 # frame-fin = %x0 ; more frames of this message follow
370 # / %x1 ; final frame of this message
371 # frame-rsv1 = %x0 ;
372 # 1 bit, MUST be 0 unless negotiated otherwise
373 # frame-rsv2 = %x0 ;
374 # 1 bit, MUST be 0 unless negotiated otherwise
375 # frame-rsv3 = %x0 ;
376 # 1 bit, MUST be 0 unless negotiated otherwise
377 #
378 # Remove rsv1 from this test for deflate development
379 if rsv2 or rsv3 or (rsv1 and not self._compress):
380 raise WebSocketError(
381 WSCloseCode.PROTOCOL_ERROR,
382 "Received frame with non-zero reserved bits",
383 )
385 if opcode > 0x7 and fin == 0:
386 raise WebSocketError(
387 WSCloseCode.PROTOCOL_ERROR,
388 "Received fragmented control frame",
389 )
391 has_mask = (second_byte >> 7) & 1
392 length = second_byte & 0x7F
394 # Control frames MUST have a payload
395 # length of 125 bytes or less
396 if opcode > 0x7 and length > 125:
397 raise WebSocketError(
398 WSCloseCode.PROTOCOL_ERROR,
399 "Control frame payload cannot be larger than 125 bytes",
400 )
402 # Set compress status if last package is FIN
403 # OR set compress status if this is first fragment
404 # Raise error if not first fragment with rsv1 = 0x1
405 if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
406 self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
407 elif rsv1:
408 raise WebSocketError(
409 WSCloseCode.PROTOCOL_ERROR,
410 "Received frame with non-zero reserved bits",
411 )
413 self._frame_fin = bool(fin)
414 self._frame_opcode = opcode
415 self._has_mask = bool(has_mask)
416 self._payload_len_flag = length
417 self._state = READ_PAYLOAD_LENGTH
419 # read payload length
420 if self._state == READ_PAYLOAD_LENGTH:
421 len_flag = self._payload_len_flag
422 if len_flag == 126:
423 if data_len - start_pos < 2:
424 break
425 first_byte = data_cstr[start_pos]
426 second_byte = data_cstr[start_pos + 1]
427 start_pos += 2
428 self._payload_bytes_to_read = first_byte << 8 | second_byte
429 elif len_flag > 126:
430 if data_len - start_pos < 8:
431 break
432 self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
433 start_pos += 8
434 else:
435 self._payload_bytes_to_read = len_flag
437 self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
439 # read payload mask
440 if self._state == READ_PAYLOAD_MASK:
441 if data_len - start_pos < 4:
442 break
443 self._frame_mask = data_cstr[start_pos : start_pos + 4]
444 start_pos += 4
445 self._state = READ_PAYLOAD
447 if self._state == READ_PAYLOAD:
448 chunk_len = data_len - start_pos
449 if self._payload_bytes_to_read >= chunk_len:
450 f_end_pos = data_len
451 self._payload_bytes_to_read -= chunk_len
452 else:
453 f_end_pos = start_pos + self._payload_bytes_to_read
454 self._payload_bytes_to_read = 0
456 had_fragments = self._frame_payload_len
457 self._frame_payload_len += f_end_pos - start_pos
458 f_start_pos = start_pos
459 start_pos = f_end_pos
461 if self._payload_bytes_to_read != 0:
462 # If we don't have a complete frame, we need to save the
463 # data for the next call to feed_data.
464 self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
465 break
467 payload: bytes | bytearray
468 if had_fragments:
469 # We have to join the payload fragments get the payload
470 self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
471 if self._has_mask:
472 assert self._frame_mask is not None
473 payload_bytearray = bytearray(b"".join(self._payload_fragments))
474 websocket_mask(self._frame_mask, payload_bytearray)
475 payload = payload_bytearray
476 else:
477 payload = b"".join(self._payload_fragments)
478 self._payload_fragments.clear()
479 elif self._has_mask:
480 assert self._frame_mask is not None
481 payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
482 if type(payload_bytearray) is not bytearray: # pragma: no branch
483 # Cython will do the conversion for us
484 # but we need to do it for Python and we
485 # will always get here in Python
486 payload_bytearray = bytearray(payload_bytearray)
487 websocket_mask(self._frame_mask, payload_bytearray)
488 payload = payload_bytearray
489 else:
490 payload = data_cstr[f_start_pos:f_end_pos]
492 self._handle_frame(
493 self._frame_fin, self._frame_opcode, payload, self._compressed
494 )
495 self._frame_payload_len = 0
496 self._state = READ_HEADER
498 # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
499 self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""