Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/reader_py.py: 17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

278 statements  

1"""Reader for WebSocket protocol versions 13 and 8.""" 

2 

3import asyncio 

4import builtins 

5from collections import deque 

6from typing import Deque, Final, Optional, Set, Tuple, Type, Union 

7 

8from ..base_protocol import BaseProtocol 

9from ..compression_utils import ZLibDecompressor 

10from ..helpers import _EXC_SENTINEL, set_exception 

11from ..streams import EofStream 

12from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask 

13from .models import ( 

14 WS_DEFLATE_TRAILING, 

15 WebSocketError, 

16 WSCloseCode, 

17 WSMessage, 

18 WSMessageBinary, 

19 WSMessageClose, 

20 WSMessagePing, 

21 WSMessagePong, 

22 WSMessageText, 

23 WSMsgType, 

24) 

25 

# Integer value of every WSCloseCode member.
ALLOWED_CLOSE_CODES: Final[Set[int]] = set(map(int, WSCloseCode))

# Frame-parser states. Plain integers are used so they can be cythonized.
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

# Module-level aliases of the two data message types.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can be cythonized to ints.
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Pre-built (error flag, remaining bytes) results returned by feed_data().
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Tri-state flag: is the message currently being received compressed?
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

TUPLE_NEW = tuple.__new__

cython_int = int  # Typed to int in Python, but Cython will use a signed int in the pxd

57 

58 

class WebSocketDataQueue:
    """Buffer of parsed WebSocket messages with read-side flow control.

    Messages are pushed in with ``feed_data`` and pulled out with ``read``.
    The protocol is asked to pause reading once the buffered size exceeds
    the limit, and to resume once it has dropped below it again.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._protocol = protocol
        self._loop = loop
        # The effective threshold is twice the requested limit.
        self._limit = limit * 2
        # Sum of ``size`` over the messages currently buffered.
        self._size = 0
        self._eof = False
        self._exception: Union[Type[BaseException], BaseException, None] = None
        # Future the pending ``read()`` call is waiting on, if any.
        self._waiter: Optional[asyncio.Future[None]] = None
        self._buffer: Deque[WSMessage] = deque()
        # Bound methods of the deque, looked up once.
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        """Return True once EOF or an exception has been recorded."""
        return self._eof

    def exception(self) -> Optional[Union[Type[BaseException], BaseException]]:
        """Return the stored exception, or None when there is none."""
        return self._exception

    def set_exception(
        self,
        exc: Union[Type[BaseException], BaseException],
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Record ``exc``, mark the queue as finished and fail a pending reader."""
        self._eof = True
        self._exception = exc
        waiter = self._waiter
        if waiter is None:
            return
        self._waiter = None
        set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Detach the pending reader future and complete it if still pending."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def feed_eof(self) -> None:
        """Mark the end of the stream and wake up a pending reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage") -> None:
        """Append ``data`` to the buffer and pause the protocol when over the limit."""
        size = data.size
        self._size += size
        self._put_buffer(data)
        self._release_waiter()
        if self._size <= self._limit or self._protocol._reading_paused:
            return
        self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one when the buffer is empty.

        Raises whatever ``_read_from_buffer`` raises once the queue is
        finished and drained.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()

        assert not self._waiter
        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # Forget the future so a later read() can install a new one.
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message, resuming the protocol when below the limit.

        With an empty buffer, raise the stored exception if there is one,
        otherwise ``EofStream``.
        """
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream

        data = self._get_buffer()
        size = data.size
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data

138 

139 

class WebSocketReader:
    """Incremental WebSocket frame parser.

    Raw bytes are pushed in through ``feed_data``. Complete frames are
    unmasked, reassembled into messages, decompressed when flagged as
    compressed, and handed to ``queue``.
    """

    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        """Initialise the parser state.

        queue: destination for parsed messages and for parse errors.
        max_msg_size: message size limit in bytes; a falsy value disables
            the size checks performed in this class.
        compress: when False, a frame with the RSV1 bit set is rejected.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        # First parsing error; once set, feed_data() no longer parses.
        self._exc: Optional[Exception] = None
        # Payload of the non-final fragments of the message being assembled.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented message in progress, if any.
        self._opcode: int = OP_CODE_NOT_SET
        # FIN bit of the most recent *data* frame; control frames do not
        # update it (see _feed_data).
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Pieces of the current frame's payload when the frame arrives
        # spread over several feed_data() calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Unparsed bytes left over from the previous _feed_data() call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        self._compressed: int = COMPRESSED_NOT_SET
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end-of-stream to the queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        """Parse ``data`` and report whether the reader is in an error state.

        Returns ``(True, data)`` without parsing when a previous call already
        failed, ``EMPTY_FRAME_ERROR`` when this call fails (the exception is
        stored and set on the queue), and ``EMPTY_FRAME`` otherwise.
        """
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: Union[int, cython_int],  # Union intended: Cython pxd uses C int
        payload: Union[bytes, bytearray],
        compressed: Union[int, cython_int],  # Union intended: Cython pxd uses C int
    ) -> None:
        """Turn one complete, unmasked frame into a queued message.

        Non-final data frames are only accumulated; a final data frame
        completes the message, which is size-checked, decompressed when
        ``compressed`` is truthy, and fed to the queue. Close, ping and pong
        frames are fed to the queue as their own message types.

        Raises WebSocketError for oversized messages, invalid UTF-8,
        invalid close frames, fragmentation errors and unknown opcodes.
        """
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    # Remember the message type announced by the first fragment.
                    self._opcode = opcode
                self._partial += payload
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                if self._opcode == OP_CODE_NOT_SET:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Continuation frame for non started message",
                    )
                # The final fragment takes the type of the first one.
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            # previous frame was non finished
            # we should get continuation opcode
            elif has_partial:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: Union[bytes, bytearray]
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompression must be done after all fragments have been
            # received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # XXX: It's possible that the zlib backend (isal is known to
                # do this, maybe others too?) will return max_length bytes,
                # but internally buffer more data such that the payload is
                # >max_length, so we return one extra byte and if we're able
                # to do that, then the message is too big.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            size = len(payload_merged)
            if opcode == OP_CODE_TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc

                # XXX: The Text and Binary messages here can be a performance
                # bottleneck, so we use tuple.__new__ to improve performance.
                # This is not type safe, but many tests should fail in
                # test_client_ws_functional.py if this is wrong.
                msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
            else:
                msg = TUPLE_NEW(
                    WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
                )

            self.queue.feed_data(msg)
        elif opcode == OP_CODE_CLOSE:
            payload_len = len(payload)
            if payload_len >= 2:
                # First two bytes hold the close code, the rest is the reason.
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = WSMessageClose(
                    data=close_code, size=payload_len, extra=close_message
                )
            elif payload:
                # A single byte cannot hold a close code.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = WSMessageClose(data=0, size=payload_len, extra="")

            self.queue.feed_data(msg)
        elif opcode == OP_CODE_PING:
            self.queue.feed_data(
                WSMessagePing(data=bytes(payload), size=len(payload), extra="")
            )
        elif opcode == OP_CODE_PONG:
            self.queue.feed_data(
                WSMessagePong(data=bytes(payload), size=len(payload), extra="")
            )
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Parse every complete frame available and dispatch it.

        ``data`` is appended to the bytes left over from the previous call.
        Parsing is a resumable state machine: when the buffer ends in the
        middle of a header, the state is kept and the unparsed bytes are
        stored in ``self._tail``; partial payload bytes are stored in
        ``self._payload_fragments``.

        Raises WebSocketError on protocol violations.
        """
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                #
                # Remove rsv1 from this test for deflate development
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                # Opcodes above 0x7 are control frames.
                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Set compress status if the previous data frame was FIN
                # OR set compress status if this is the first fragment.
                # Raise error if not first fragment with rsv1 = 0x1.
                #
                # Control frames may be injected between the fragments of a
                # data message, so they must not take part in this
                # bookkeeping: letting one overwrite _frame_fin or
                # _compressed would make the next continuation frame look
                # like the start of a new (uncompressed) message.
                if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                    # No fragmented message is in progress.
                    if opcode <= 0x7:
                        self._compressed = (
                            COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                        )
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode <= 0x7:
                    self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 16-bit big-endian extended length
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 8-byte extended length
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    # Everything left in the buffer belongs to this frame.
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                # Non-zero when earlier calls already stored part of this
                # frame's payload.
                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: Union[bytes, bytearray]
                if had_fragments:
                    # We have to join the payload fragments to get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                # _frame_fin only describes data frames. Control frames are
                # always final: a fragmented one was rejected when its
                # header was parsed.
                self._handle_frame(
                    self._frame_fin or self._frame_opcode > 0x7,
                    self._frame_opcode,
                    payload,
                    self._compressed,
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""