Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/reader_py.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

282 statements  

1"""Reader for WebSocket protocol versions 13 and 8.""" 

2 

3import asyncio 

4import builtins 

5from collections import deque 

6from typing import Final 

7 

8from ..base_protocol import BaseProtocol 

9from ..compression_utils import ZLibDecompressor 

10from ..helpers import _EXC_SENTINEL, set_exception 

11from ..streams import EofStream 

12from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask 

13from .models import ( 

14 WS_DEFLATE_TRAILING, 

15 WebSocketError, 

16 WSCloseCode, 

17 WSMessage, 

18 WSMessageBinary, 

19 WSMessageClose, 

20 WSMessagePing, 

21 WSMessagePong, 

22 WSMessageText, 

23 WSMessageTextBytes, 

24 WSMsgType, 

25) 

26 

# Integer values of every WSCloseCode member; used to validate close codes
# below 3000 in received CLOSE frames.
ALLOWED_CLOSE_CODES: Final[set[int]] = set(map(int, WSCloseCode))

# Parser state machine states.  Plain integers (rather than an Enum) so that
# the Cython build can type them as C integers.
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

# Message types stored in the fast-path text/binary message tuples.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# Raw opcode values taken out of WSMsgType so the Cython build can compare
# them as C integers.  OP_CODE_NOT_SET marks "no fragmented message pending".
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Preallocated (error, remaining_data) results returned by feed_data().
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Tri-state compression flag for the data message being parsed.
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

# Cached constructor used to build message tuples without NamedTuple overhead.
TUPLE_NEW = tuple.__new__

# Plain ``int`` in Python; the Cython .pxd maps parameters annotated with it
# to a signed C int.
cython_int = int

58 

59 

class WebSocketDataQueue:
    """Buffer of parsed WebSocket messages with read back-pressure.

    The frame reader pushes messages in with :meth:`feed_data` and consumers
    pull them out with :meth:`read`.  Once the total ``size`` of buffered
    messages exceeds twice ``limit`` the protocol's reading is paused, and it
    is resumed when the buffered size drops below that threshold again.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._protocol = protocol
        self._loop = loop
        # High-water mark for pausing the transport: twice the given limit.
        self._limit = limit * 2
        self._size = 0
        self._eof = False
        self._exception: type[BaseException] | BaseException | None = None
        # Future a blocked read() is waiting on, if any.
        self._waiter: asyncio.Future[None] | None = None
        self._buffer: deque[WSMessage] = deque()
        # Bound deque methods cached for the hot path.
        self._put_buffer = self._buffer.append
        self._get_buffer = self._buffer.popleft

    def is_eof(self) -> bool:
        """Return True once EOF or an exception has been fed."""
        return self._eof

    def exception(self) -> type[BaseException] | BaseException | None:
        """Return the stored exception, or None."""
        return self._exception

    def set_exception(
        self,
        exc: type[BaseException] | BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue as finished with ``exc`` and fail a pending reader."""
        self._eof = True
        self._exception = exc
        waiter = self._waiter
        if waiter is None:
            return
        self._waiter = None
        set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake a pending read() (if any) by resolving its future."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def feed_eof(self) -> None:
        """Mark the end of the stream and wake a pending reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage") -> None:
        """Append a message, waking a reader and pausing reading if over limit."""
        size = data.size
        self._size += size
        self._put_buffer(data)
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting until one is available.

        Raises the stored exception, or EofStream, once the queue is
        finished and drained.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()
        assert not self._waiter
        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # Forget the abandoned future so a later read() can wait again.
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop one buffered message, resuming reading when below the limit."""
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream
        data = self._get_buffer()
        size = data.size
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data

139 

140 

class WebSocketReader:
    """Incremental WebSocket frame parser feeding a :class:`WebSocketDataQueue`.

    Raw transport bytes are passed to :meth:`feed_data`.  Frame headers and
    payloads are parsed by a state machine (READ_HEADER ->
    READ_PAYLOAD_LENGTH -> [READ_PAYLOAD_MASK] -> READ_PAYLOAD) that resumes
    across calls, and every complete frame is handed to :meth:`_handle_frame`,
    which assembles fragmented messages and pushes them onto ``queue``.
    After the first parse error, the error is forwarded to the queue and all
    further input is rejected.
    """

    def __init__(
        self,
        queue: WebSocketDataQueue,
        max_msg_size: int,
        compress: bool = True,
        decode_text: bool = True,
    ) -> None:
        self.queue = queue
        # A falsy value (0) disables every message size check below.
        self._max_msg_size = max_msg_size
        # When False, TEXT payloads are delivered as raw bytes.
        self._decode_text = decode_text

        # First parse error; once set, feed_data() rejects all input.
        self._exc: Exception | None = None
        # Payload accumulated from non-final fragments of the current message.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented data message in progress, or OP_CODE_NOT_SET.
        self._opcode: int = OP_CODE_NOT_SET
        # FIN bit and opcode of the frame currently being parsed.
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Payload chunks of the current frame when it spans several calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Unconsumed input carried over to the next feed_data() call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: bytes | None = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        # Compression flag of the current data message (from RSV1 of its
        # first frame).
        self._compressed: int = COMPRESSED_NOT_SET
        self._decompressobj: ZLibDecompressor | None = None
        # Whether RSV1 (per-message compression) may be set at all.
        self._compress = compress

    def feed_eof(self) -> None:
        """Signal end of stream to the message queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
        """Parse a chunk of transport data.

        Returns ``(True, data)`` if the reader had already failed,
        ``(True, b"")`` if parsing this chunk failed (the exception is stored
        and forwarded to the queue), and ``(False, b"")`` otherwise.
        """
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: int | cython_int,  # Union intended: Cython pxd uses C int
        payload: bytes | bytearray,
        compressed: int | cython_int,  # Union intended: Cython pxd uses C int
    ) -> None:
        """Turn one complete frame into a queued message.

        Text/binary/continuation frames are accumulated until the final
        fragment arrives; the assembled message is then size-checked,
        decompressed when ``compressed`` is set, decoded and queued.  Close,
        ping and pong frames are queued immediately.

        Raises WebSocketError on protocol violations, oversized messages and
        invalid UTF-8.
        """
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # Validate continuation frames before processing
            if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Continuation frame for non started message",
                )

            # Fragments of different data messages must not interleave
            # (RFC 6455 section 5.4): while a fragmented message is in
            # progress only continuation frames are allowed.  This is checked
            # via ``self._opcode`` rather than ``self._partial`` so that it
            # also catches messages whose earlier fragments were empty, and
            # non-final text/binary frames that would otherwise silently
            # overwrite the pending message's opcode.
            if opcode != OP_CODE_CONTINUATION and self._opcode != OP_CODE_NOT_SET:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                # Final fragment: the message type comes from its first frame.
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET

            assembled_payload: bytes | bytearray
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompress process must to be done after all packets
            # received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # XXX: It's possible that the zlib backend (isal is known to
                # do this, maybe others too?) will return max_length bytes,
                # but internally buffer more data such that the payload is
                # >max_length, so we return one extra byte and if we're able
                # to do that, then the message is too big.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            size = len(payload_merged)
            if opcode == OP_CODE_TEXT:
                if self._decode_text:
                    try:
                        text = payload_merged.decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                        ) from exc

                    # XXX: The Text and Binary messages here can be a performance
                    # bottleneck, so we use tuple.__new__ to improve performance.
                    # This is not type safe, but many tests should fail in
                    # test_client_ws_functional.py if this is wrong.
                    msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
                else:
                    # Return raw bytes for TEXT messages when decode_text=False
                    msg = TUPLE_NEW(
                        WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
                    )
            else:
                msg = TUPLE_NEW(
                    WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
                )

            self.queue.feed_data(msg)
        elif opcode == OP_CODE_CLOSE:
            payload_len = len(payload)
            if payload_len >= 2:
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = WSMessageClose(
                    data=close_code, size=payload_len, extra=close_message
                )
            elif payload:
                # A one-byte close payload cannot hold a close code.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = WSMessageClose(data=0, size=payload_len, extra="")

            self.queue.feed_data(msg)
        elif opcode == OP_CODE_PING:
            self.queue.feed_data(
                WSMessagePing(data=bytes(payload), size=len(payload), extra="")
            )
        elif opcode == OP_CODE_PONG:
            self.queue.feed_data(
                WSMessagePong(data=bytes(payload), size=len(payload), extra="")
            )
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Parse ``data`` and dispatch every complete frame it contains.

        Parsing resumes from the state left by the previous call.  Bytes that
        do not yet form a complete header field are kept in ``self._tail``;
        payload bytes of an unfinished frame are kept in
        ``self._payload_fragments``.
        """
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                # / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                # 1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                # 1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                # 1 bit, MUST be 0 unless negotiated otherwise
                #
                # RSV1 is the per-message compression bit, so it is only
                # accepted when compression is enabled (self._compress).
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                if opcode > 0x7:
                    # Control frames are never compressed (RFC 7692 section
                    # 6.1) and may be interleaved with the fragments of a
                    # data message, so they must not touch the compression
                    # state of the message in progress.
                    if rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            "Received frame with non-zero reserved bits",
                        )
                elif (
                    self._frame_fin and opcode != OP_CODE_CONTINUATION
                ) or self._compressed == COMPRESSED_NOT_SET:
                    # First frame of a new data message: its RSV1 bit decides
                    # whether the whole message is compressed.  A continuation
                    # frame never starts a message, even when the previous
                    # frame had FIN set (e.g. an interleaved ping).
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    # RSV1 is only allowed on the first frame of a message.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 2-byte big-endian extended payload length.
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 8-byte extended payload length.
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.  Empty chunks
                    # (header ending exactly at the end of ``data``) are
                    # skipped: the list is only cleared on the
                    # ``had_fragments`` path below, so stale empty entries
                    # would otherwise pile up.
                    if f_end_pos != f_start_pos:
                        self._payload_fragments.append(
                            data_cstr[f_start_pos:f_end_pos]
                        )
                    break

                payload: bytes | bytearray
                if had_fragments:
                    # We have to join the payload fragments get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        # Unmask in place.
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin,
                    self._frame_opcode,
                    payload,
                    # Only data frames carry the message's compression flag;
                    # control frames are never compressed.
                    (
                        self._compressed
                        if self._frame_opcode <= 0x7
                        else COMPRESSED_FALSE
                    ),
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""