Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/reader_py.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

284 statements  

1"""Reader for WebSocket protocol versions 13 and 8.""" 

2 

3import asyncio 

4import builtins 

5from collections import deque 

6from typing import Final 

7 

8from ..base_protocol import BaseProtocol 

9from ..compression_utils import ZLibDecompressor 

10from ..helpers import _EXC_SENTINEL, set_exception 

11from ..streams import EofStream 

12from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask 

13from .models import ( 

14 WS_DEFLATE_TRAILING, 

15 WebSocketError, 

16 WSCloseCode, 

17 WSMessage, 

18 WSMessageBinary, 

19 WSMessageClose, 

20 WSMessagePing, 

21 WSMessagePong, 

22 WSMessageText, 

23 WSMessageTextBytes, 

24 WSMsgType, 

25) 

26 

# Close codes defined by WSCloseCode; used when validating received CLOSE frames.
ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode}

# States for the reader, used to parse the WebSocket frame.
# Integer values are used so they can be cythonized.
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

# Module-level aliases of the enum members stored in queued messages.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can be cythonized to ints.
# OP_CODE_NOT_SET marks "no fragmented message in progress".
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Prebuilt return values for WebSocketReader.feed_data.
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Compression flag of the message being read; NOT_SET until a frame header
# has been parsed.
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

# Bound once so message tuples can be created through tuple.__new__ directly.
TUPLE_NEW = tuple.__new__

cython_int = int  # Typed to int in Python, but Cython will use a signed int in the pxd

58 

59 

class WebSocketDataQueue:
    """Buffer of parsed WebSocket messages with read-side flow control.

    Messages are appended by ``feed_data`` and consumed by ``read``. The
    protocol is asked to pause reading once the buffered size exceeds the
    threshold and to resume once the size drops back below it.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        """Create an empty queue.

        Args:
            protocol: Object whose reading is paused/resumed for flow control.
            limit: Base size limit; the pause threshold is twice this value.
            loop: Event loop used to create the reader's waiter future.
        """
        self._loop = loop
        self._protocol = protocol
        self._limit = limit * 2
        self._size = 0
        self._eof = False
        self._exception: type[BaseException] | BaseException | None = None
        self._waiter: asyncio.Future[None] | None = None
        self._buffer: deque[WSMessage] = deque()
        # Bound methods are cached so the hot paths skip the attribute lookup.
        self._put_buffer = self._buffer.append
        self._get_buffer = self._buffer.popleft

    def is_eof(self) -> bool:
        """Return whether end of stream (or an error) has been recorded."""
        return self._eof

    def exception(self) -> type[BaseException] | BaseException | None:
        """Return the stored exception, or ``None`` when there is none."""
        return self._exception

    def set_exception(
        self,
        exc: type[BaseException] | BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Record ``exc``, mark the queue as finished and fail a pending reader."""
        self._eof = True
        self._exception = exc
        pending = self._waiter
        if pending is None:
            return
        self._waiter = None
        set_exception(pending, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake the pending reader, if any, and forget its future."""
        pending = self._waiter
        if pending is not None:
            self._waiter = None
            if not pending.done():
                pending.set_result(None)

    def feed_eof(self) -> None:
        """Mark end of stream and wake the pending reader."""
        self._eof = True
        self._release_waiter()
        # Break cyclic references
        self._exception = None

    def feed_data(self, data: "WSMessage") -> None:
        """Append a message, wake the reader and pause reading when over the threshold."""
        self._size += data.size
        self._put_buffer(data)
        self._release_waiter()
        over_threshold = self._size > self._limit
        if over_threshold and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one if the buffer is empty.

        Raises the stored exception, or ``EofStream``, once the queue is
        finished and drained.
        """
        if not (self._buffer or self._eof):
            assert not self._waiter
            pending = self._loop.create_future()
            self._waiter = pending
            try:
                await pending
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # The reader gave up; drop the future so a later read can wait again.
                self._waiter = None
                raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message, resuming reading when below the threshold."""
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream
        message = self._get_buffer()
        self._size -= message.size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return message

139 

140 

class WebSocketReader:
    """Incremental parser that turns raw bytes into WebSocket messages.

    Bytes are supplied through ``feed_data``. Each completed frame is passed
    to ``_handle_frame``, which assembles fragmented messages and feeds the
    resulting message objects to ``queue``.
    """

    def __init__(
        self,
        queue: WebSocketDataQueue,
        max_msg_size: int,
        compress: bool = True,
        decode_text: bool = True,
    ) -> None:
        """Set up the parser state.

        Args:
            queue: Destination for parsed messages and for any parse error.
            max_msg_size: Message size limit in bytes; a falsy value disables
                the size checks.
            compress: When false, frames with the RSV1 bit set are rejected.
            decode_text: When true, TEXT payloads are decoded as UTF-8;
                otherwise they are delivered as raw bytes.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size
        self._decode_text = decode_text

        # First exception raised while parsing; once set, feed_data stops parsing.
        self._exc: Exception | None = None
        # Payload accumulated from the non-final frames of a fragmented message.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented message in progress, if any.
        self._opcode: int = OP_CODE_NOT_SET
        # Header fields of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Pieces of the current frame's payload received across feed calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Unparsed bytes left over from the previous feed call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: bytes | None = None
        # Payload bytes of the current frame that have not arrived yet.
        self._payload_bytes_to_read = 0
        # 7-bit length field from the frame header (126 and 127 select an
        # extended length).
        self._payload_len_flag = 0
        self._compressed: int = COMPRESSED_NOT_SET
        # Created lazily on the first compressed message.
        self._decompressobj: ZLibDecompressor | None = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end of stream to the queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
        """Parse incoming bytes and queue any complete messages.

        Returns:
            ``(True, data)`` (with ``data`` coerced to ``bytes``) without
            parsing if an earlier call already failed, ``EMPTY_FRAME_ERROR``
            if parsing raised during this call, otherwise ``EMPTY_FRAME``.
        """
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            # Remember the failure and hand it to the queue.
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: int | cython_int,  # Union intended: Cython pxd uses C int
        payload: bytes | bytearray,
        compressed: int | cython_int,  # Union intended: Cython pxd uses C int
    ) -> None:
        """Process one complete frame and queue the message it finishes, if any.

        Non-final data frames are accumulated in ``self._partial``; a final
        data frame is assembled, optionally decompressed and decoded, then fed
        to the queue. CLOSE, PING and PONG frames are queued directly.

        Args:
            fin: Whether this is the final frame of its message.
            opcode: Frame opcode.
            payload: Unmasked frame payload.
            compressed: Truthy when the message payload must be decompressed.

        Raises:
            WebSocketError: On an unexpected continuation or opcode, an
                invalid close frame or close code, invalid UTF-8, or a
                decompressed message larger than ``max_msg_size``.
        """
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # Validate continuation frames before processing
            if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Continuation frame for non started message",
                )

            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                # Final fragment: restore the opcode recorded from the first one.
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            # The previous frame was not final, so this one should have been
            # a continuation frame.
            elif has_partial:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: bytes | bytearray
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            # Decompression must be done after all packets have been
            # received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # XXX: It's possible that the zlib backend (isal is known to
                # do this, maybe others too?) will return max_length bytes,
                # but internally buffer more data such that the payload is
                # >max_length, so we return one extra byte and if we're able
                # to do that, then the message is too big.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            size = len(payload_merged)
            if opcode == OP_CODE_TEXT:
                if self._decode_text:
                    try:
                        text = payload_merged.decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc

                    # XXX: The Text and Binary messages here can be a performance
                    # bottleneck, so we use tuple.__new__ to improve performance.
                    # This is not type safe, but many tests should fail in
                    # test_client_ws_functional.py if this is wrong.
                    msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
                else:
                    # Return raw bytes for TEXT messages when decode_text=False
                    msg = TUPLE_NEW(
                        WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
                    )
            else:
                msg = TUPLE_NEW(
                    WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
                )

            self.queue.feed_data(msg)
        elif opcode == OP_CODE_CLOSE:
            payload_len = len(payload)
            if payload_len >= 2:
                # The first two payload bytes carry the close code; the rest
                # is the UTF-8 close reason.
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = WSMessageClose(
                    data=close_code, size=payload_len, extra=close_message
                )
            elif payload:
                # A single payload byte cannot hold a close code.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = WSMessageClose(data=0, size=payload_len, extra="")

            self.queue.feed_data(msg)
        elif opcode == OP_CODE_PING:
            self.queue.feed_data(
                WSMessagePing(data=bytes(payload), size=len(payload), extra="")
            )
        elif opcode == OP_CODE_PONG:
            self.queue.feed_data(
                WSMessagePong(data=bytes(payload), size=len(payload), extra="")
            )
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Parse every complete frame in ``data`` and dispatch it.

        Runs the frame state machine (header, payload length, mask, payload)
        over ``data`` prefixed with any bytes left from the previous call.
        Each completed frame is passed to ``_handle_frame``. Bytes of an
        incomplete header, length or mask are kept in ``self._tail``; bytes
        of an incomplete payload are kept in ``self._payload_fragments``.

        Raises:
            WebSocketError: On reserved bits, unknown opcodes, fragmented or
                oversized control frames, or a data frame that would reach
                ``max_msg_size``.
        """
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        # Second name for the same buffer (presumably typed differently in the
        # Cython pxd - TODO confirm).
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                # / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                # 1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                # 1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                # 1 bit, MUST be 0 unless negotiated otherwise
                #
                # Remove rsv1 from this test for deflate development
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode not in {
                    OP_CODE_CONTINUATION,
                    OP_CODE_TEXT,
                    OP_CODE_BINARY,
                    OP_CODE_CLOSE,
                    OP_CODE_PING,
                    OP_CODE_PONG,
                }:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Unexpected opcode={opcode!r}",
                    )

                # Opcodes above 0x7 are control frames.
                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Set compress status if last package is FIN
                # OR set compress status if this is first fragment
                # Raise error if not first fragment with rsv1 = 0x1
                # NOTE(review): self._frame_fin is updated for control frames
                # too, so a FIN control frame between fragments makes the next
                # continuation frame recompute this flag from its own RSV1 bit
                # - confirm this is intended.
                if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 126: the length is in the next 2 bytes, big-endian.
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 127: the length is in the next 8 bytes.
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                # Reject oversized data frames before buffering any payload
                # bytes. Control frames are capped at 125 bytes (checked in
                # READ_HEADER) so only text/binary/continuation need this.
                if self._max_msg_size and self._frame_opcode in {
                    OP_CODE_TEXT,
                    OP_CODE_BINARY,
                    OP_CODE_CONTINUATION,
                }:
                    projected_size = self._payload_bytes_to_read + len(self._partial)
                    # NOTE(review): ">=" rejects a message whose size equals
                    # the limit, while the decompressed-size check in
                    # _handle_frame uses ">" - confirm the difference is
                    # intended.
                    if projected_size >= self._max_msg_size:
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            f"Message size {projected_size} "
                            f"exceeds limit {self._max_msg_size}",
                        )

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    # Everything left in this buffer belongs to the frame.
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    # The frame ends inside this buffer.
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                # Non-zero when earlier feed calls already stored part of
                # this frame's payload.
                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: bytes | bytearray
                if had_fragments:
                    # We have to join the payload fragments to get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin, self._frame_opcode, payload, self._compressed
                )
                # Reset per-frame state and look for the next header.
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""