Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_ws.py: 20%

318 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:52 +0000

1import asyncio 

2import base64 

3import binascii 

4import dataclasses 

5import hashlib 

6import json 

7from typing import Any, Iterable, Optional, Tuple, cast 

8 

9import async_timeout 

10from multidict import CIMultiDict 

11from typing_extensions import Final 

12 

13from . import hdrs 

14from .abc import AbstractStreamWriter 

15from .helpers import call_later, set_result 

16from .http import ( 

17 WS_CLOSED_MESSAGE, 

18 WS_CLOSING_MESSAGE, 

19 WS_KEY, 

20 WebSocketError, 

21 WebSocketReader, 

22 WebSocketWriter, 

23 WSCloseCode, 

24 WSMessage, 

25 WSMsgType as WSMsgType, 

26 ws_ext_gen, 

27 ws_ext_parse, 

28) 

29from .log import ws_logger 

30from .streams import EofStream, FlowControlDataQueue 

31from .typedefs import JSONDecoder, JSONEncoder 

32from .web_exceptions import HTTPBadRequest, HTTPException 

33from .web_request import BaseRequest 

34from .web_response import StreamResponse 

35 

# Public names exposed by a star-import of this module.
__all__ = ("WebSocketResponse", "WebSocketReady", "WSMsgType")

# receive() on an already-closed socket returns WS_CLOSED_MESSAGE, but the
# THRESHOLD_CONNLOST_ACCESS-th such call raises RuntimeError instead.
THRESHOLD_CONNLOST_ACCESS: Final[int] = 5

43 

44 

@dataclasses.dataclass(frozen=True)
class WebSocketReady:
    """Immutable result of a WebSocket handshake pre-check.

    Instances are truthy exactly when ``ok`` is true, so the result can be
    used directly in an ``if`` statement.
    """

    # Whether the request's handshake would be accepted.
    ok: bool = dataclasses.field()
    # Sub-protocol agreed with the client, or None when none was selected.
    protocol: Optional[str] = dataclasses.field()

    def __bool__(self) -> bool:
        # Truthiness is delegated entirely to the ``ok`` flag.
        accepted = self.ok
        return accepted

52 

53 

class WebSocketResponse(StreamResponse):
    """Server-side WebSocket endpoint built on top of :class:`StreamResponse`.

    :meth:`prepare` validates the client's opening handshake, answers with a
    ``101`` response, and replaces the request protocol's parser with a
    :class:`WebSocketReader` that feeds a :class:`FlowControlDataQueue`.
    Frames are then read via :meth:`receive` (or its typed helpers, or
    ``async for``) and written through a :class:`WebSocketWriter`.

    When ``heartbeat`` is set, a PING is sent periodically and the transport
    is closed if no frame is received in time afterwards.
    """

    __slots__ = (
        "_protocols",
        "_ws_protocol",
        "_writer",
        "_reader",
        "_closed",
        "_closing",
        "_conn_lost",
        "_close_code",
        "_loop",
        "_waiting",
        "_exception",
        "_timeout",
        "_receive_timeout",
        "_autoclose",
        "_autoping",
        "_heartbeat",
        "_heartbeat_cb",
        "_pong_heartbeat",
        "_pong_response_cb",
        "_ping_task",
        "_compress",
        "_max_msg_size",
    )

    def __init__(
        self,
        *,
        timeout: float = 10.0,
        receive_timeout: Optional[float] = None,
        autoclose: bool = True,
        autoping: bool = True,
        heartbeat: Optional[float] = None,
        protocols: Iterable[str] = (),
        compress: bool = True,
        max_msg_size: int = 4 * 1024 * 1024,
    ) -> None:
        """Create an unprepared WebSocket response.

        :param timeout: seconds :meth:`close` waits for the peer's CLOSE frame.
        :param receive_timeout: default timeout for :meth:`receive` when the
            call itself does not pass one.
        :param autoclose: when true, :meth:`receive` answers a peer CLOSE frame
            by calling :meth:`close`.
        :param autoping: when true, :meth:`receive` replies to PING frames with
            PONG and swallows incoming PONG frames.
        :param heartbeat: interval in seconds between heartbeat PINGs, or None
            to disable the heartbeat.
        :param protocols: sub-protocols the server is willing to speak.
        :param compress: whether to negotiate the compression extension.
        :param max_msg_size: maximum message size handed to the frame reader.
        """
        super().__init__(status=101)
        self._length_check = False
        # Materialise once: the protocols are probed with ``in`` repeatedly and
        # the handshake may run twice (can_prepare() then prepare()), which a
        # one-shot iterator such as a generator would not survive.
        self._protocols = tuple(protocols)
        self._ws_protocol: Optional[str] = None
        self._writer: Optional[WebSocketWriter] = None
        self._reader: Optional[FlowControlDataQueue[WSMessage]] = None
        self._closed = False
        self._closing = False
        self._conn_lost = 0
        self._close_code: Optional[int] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._waiting: Optional[asyncio.Future[bool]] = None
        self._exception: Optional[BaseException] = None
        self._timeout = timeout
        self._receive_timeout = receive_timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # After each heartbeat PING the peer has half the heartbeat interval to
        # deliver some frame before the transport is dropped. Always assigned
        # so the slot is never left unset when the heartbeat is disabled.
        self._pong_heartbeat = heartbeat / 2.0 if heartbeat is not None else 0.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        # Strong reference to the in-flight heartbeat ping (see _send_heartbeat).
        self._ping_task: Optional["asyncio.Task[None]"] = None
        self._compress = compress
        self._max_msg_size = max_msg_size

    def _cancel_heartbeat(self) -> None:
        """Cancel both the pending PONG watchdog and the next heartbeat timer."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None

    def _heartbeat_ceil_threshold(self) -> float:
        """Return the ``timeout_ceil_threshold`` used for heartbeat timers.

        Taken from the request's protocol once a request is attached, 5
        otherwise.
        """
        if self._req is not None:
            return self._req._protocol._timeout_ceil_threshold
        return 5

    def _reset_heartbeat(self) -> None:
        """Restart the heartbeat timer; called whenever a frame is received."""
        self._cancel_heartbeat()

        if self._heartbeat is not None:
            assert self._loop is not None
            self._heartbeat_cb = call_later(
                self._send_heartbeat,
                self._heartbeat,
                self._loop,
                timeout_ceil_threshold=self._heartbeat_ceil_threshold(),
            )

    def _send_heartbeat(self) -> None:
        """Send a heartbeat PING and arm the watchdog for the peer's reply."""
        if self._heartbeat is not None and not self._closed:
            assert self._loop is not None
            # Fire-and-forget task: a long-living heartbeat task would be the
            # alternative. The event loop keeps only weak references to tasks,
            # so hold a strong one until the ping completes.
            ping_task = self._loop.create_task(
                self._writer.ping()  # type: ignore[union-attr]
            )
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)

            if self._pong_response_cb is not None:
                self._pong_response_cb.cancel()
            self._pong_response_cb = call_later(
                self._pong_not_received,
                self._pong_heartbeat,
                self._loop,
                timeout_ceil_threshold=self._heartbeat_ceil_threshold(),
            )

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Release the heartbeat ping task and consume its outcome.

        Retrieving the exception prevents "Task exception was never
        retrieved" reports. A failed ping needs no extra handling here: the
        watchdog armed alongside it closes the transport if nothing is
        received in time.
        """
        if not task.cancelled():
            task.exception()
        if self._ping_task is task:
            self._ping_task = None

    def _pong_not_received(self) -> None:
        """Watchdog: no frame arrived after a heartbeat PING; drop the link."""
        if self._req is not None and self._req.transport is not None:
            self._closed = True
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = asyncio.TimeoutError()
            self._req.transport.close()

    async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
        """Perform the handshake and switch the connection to WebSocket mode.

        Calling it again returns the already-created payload writer.

        :raises HTTPBadRequest: if the client's handshake is invalid.
        """
        # Pre-check so a repeated prepare() is not masked by handshake errors.
        if self._payload_writer is not None:
            return self._payload_writer

        protocol, writer = self._pre_start(request)
        payload_writer = await super().prepare(request)
        assert payload_writer is not None
        self._post_start(request, protocol, writer)
        await payload_writer.drain()
        return payload_writer

    def _handshake(
        self, request: BaseRequest
    ) -> Tuple["CIMultiDict[str]", str, bool, bool]:
        """Validate the client handshake and build the response headers.

        :returns: ``(response_headers, protocol, compress, notakeover)`` where
            ``protocol`` is the first client sub-protocol the server also
            supports, or None.
        :raises HTTPBadRequest: on a missing/invalid Upgrade or Connection
            header, an unsupported version, or a malformed key.
        """
        headers = request.headers
        if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
            raise HTTPBadRequest(
                text=(
                    "No WebSocket UPGRADE hdr: {}\n Can "
                    '"Upgrade" only to "WebSocket".'
                ).format(headers.get(hdrs.UPGRADE))
            )

        if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
            raise HTTPBadRequest(
                text="No CONNECTION upgrade hdr: {}".format(
                    headers.get(hdrs.CONNECTION)
                )
            )

        # find common sub-protocol between client and server
        protocol = None
        if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
            req_protocols = [
                str(proto.strip())
                for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
            ]

            for proto in req_protocols:
                if proto in self._protocols:
                    protocol = proto
                    break
            else:
                # No overlap found: Return no protocol as per spec
                ws_logger.warning(
                    "Client protocols %r don’t overlap server-known ones %r",
                    req_protocols,
                    self._protocols,
                )

        # check supported version
        version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
        if version not in ("13", "8", "7"):
            raise HTTPBadRequest(text=f"Unsupported version: {version}")

        # check client handshake for validity
        key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
        try:
            if not key or len(base64.b64decode(key)) != 16:
                raise HTTPBadRequest(text=f"Handshake error: {key!r}")
        except binascii.Error:
            raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None

        accept_val = base64.b64encode(
            hashlib.sha1(key.encode() + WS_KEY).digest()
        ).decode()
        response_headers = CIMultiDict(
            {
                hdrs.UPGRADE: "websocket",
                hdrs.CONNECTION: "upgrade",
                hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
            }
        )

        notakeover = False
        compress = 0
        if self._compress:
            extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
            # Server side always get return with no exception.
            # If something happened, just drop compress extension
            compress, notakeover = ws_ext_parse(extensions, isserver=True)
            if compress:
                enabledext = ws_ext_gen(
                    compress=compress, isserver=True, server_notakeover=notakeover
                )
                response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext

        if protocol:
            response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
        return (
            response_headers,
            protocol,
            compress,
            notakeover,
        )  # type: ignore[return-value]

    def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
        """Run the handshake, set 101 headers, and create the frame writer."""
        self._loop = request._loop

        headers, protocol, compress, notakeover = self._handshake(request)

        self.set_status(101)
        self.headers.update(headers)
        self.force_close()
        self._compress = compress
        transport = request._protocol.transport
        assert transport is not None
        writer = WebSocketWriter(
            request._protocol, transport, compress=compress, notakeover=notakeover
        )

        return protocol, writer

    def _post_start(
        self, request: BaseRequest, protocol: str, writer: WebSocketWriter
    ) -> None:
        """Install the writer, frame reader, and heartbeat after the 101 reply."""
        self._ws_protocol = protocol
        self._writer = writer

        self._reset_heartbeat()

        loop = self._loop
        assert loop is not None
        self._reader = FlowControlDataQueue(request._protocol, 2**16, loop=loop)
        request.protocol.set_parser(
            WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
        )
        # disable HTTP keepalive for WebSocket
        request.protocol.keep_alive(False)

    def can_prepare(self, request: BaseRequest) -> WebSocketReady:
        """Check the handshake without starting the response.

        :raises RuntimeError: if the response was already prepared.
        """
        if self._writer is not None:
            raise RuntimeError("Already started")
        try:
            _, protocol, _, _ = self._handshake(request)
        except HTTPException:
            return WebSocketReady(False, None)
        else:
            return WebSocketReady(True, protocol)

    @property
    def closed(self) -> bool:
        """True once the connection has been closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code of the connection, or None while still open."""
        return self._close_code

    @property
    def ws_protocol(self) -> Optional[str]:
        """Negotiated sub-protocol, or None."""
        return self._ws_protocol

    @property
    def compress(self) -> bool:
        """Compression setting (the negotiated value after prepare())."""
        return self._compress

    def exception(self) -> Optional[BaseException]:
        """Return the exception recorded on abnormal closure, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame. :raises RuntimeError: before prepare()."""
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.ping(message)

    async def pong(self, message: bytes = b"") -> None:
        """Send an unsolicited PONG frame. :raises RuntimeError: before prepare()."""
        # unsolicited pong
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.pong(message)

    async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
        """Send a TEXT message.

        :raises RuntimeError: before prepare().
        :raises TypeError: if ``data`` is not a str.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send(data, binary=False, compress=compress)

    async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
        """Send a BINARY message.

        :raises RuntimeError: before prepare().
        :raises TypeError: if ``data`` is not bytes, bytearray, or memoryview.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send(data, binary=True, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[bool] = None,
        *,
        dumps: JSONEncoder = json.dumps,
    ) -> None:
        """Serialise ``data`` with ``dumps`` and send it as a TEXT message."""
        await self.send_str(dumps(data), compress=compress)

    async def write_eof(self) -> None:  # type: ignore[override]
        """Close the WebSocket (once) in place of writing an HTTP EOF."""
        if self._eof_sent:
            return
        if self._payload_writer is None:
            raise RuntimeError("Response has not been started")

        await self.close()
        self._eof_sent = True

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Send a CLOSE frame and wait for the peer's CLOSE reply.

        Data frames the peer sends before its CLOSE reply are discarded. The
        wait is bounded by the ``timeout`` given to the constructor.

        :returns: False if the connection was already closed, True otherwise
            (also when the closing handshake failed; see :meth:`exception`
            and :attr:`close_code`).
        :raises RuntimeError: before prepare().
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")

        self._cancel_heartbeat()
        reader = self._reader
        assert reader is not None

        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting is not None and not self._closed:
            reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._waiting

        if self._closed:
            return False

        self._closed = True
        try:
            await self._writer.close(code, message)
            writer = self._payload_writer
            assert writer is not None
            await writer.drain()
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            return True

        if self._closing:
            # The peer's CLOSE was already received by receive().
            return True

        try:
            async with async_timeout.timeout(self._timeout):
                # The peer may legitimately still be sending data frames
                # before it answers our CLOSE (RFC 6455 §5.5.1), so skip them
                # instead of treating the first one as a failed handshake.
                while True:
                    msg = await reader.read()
                    if msg.type == WSMsgType.CLOSE:
                        self._close_code = msg.data
                        return True
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            raise
        except Exception as exc:
            # Includes asyncio.TimeoutError when no CLOSE arrived in time.
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            return True

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Return the next message from the peer.

        With ``autoping`` enabled, PING frames are answered and PING/PONG
        frames are not returned. On a closed connection, WS_CLOSED_MESSAGE is
        returned; once THRESHOLD_CONNLOST_ACCESS such calls have been made,
        RuntimeError is raised instead. While closing, WS_CLOSING_MESSAGE is
        returned.

        :param timeout: overrides the constructor's ``receive_timeout``.
        :raises RuntimeError: before prepare(), or on a concurrent call.
        """
        if self._reader is None:
            raise RuntimeError("Call .prepare() first")

        loop = self._loop
        assert loop is not None
        while True:
            if self._waiting is not None:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                self._conn_lost += 1
                if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                    raise RuntimeError("WebSocket connection is closed.")
                return WS_CLOSED_MESSAGE
            elif self._closing:
                return WS_CLOSING_MESSAGE

            try:
                # close() awaits this future to know receive() has stopped.
                self._waiting = loop.create_future()
                try:
                    async with async_timeout.timeout(timeout or self._receive_timeout):
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    waiter = self._waiting
                    set_result(waiter, True)
                    self._waiting = None
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._closing = True
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type == WSMsgType.CLOSING:
                self._closing = True
            elif msg.type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type == WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive a TEXT message. :raises TypeError: for any other type."""
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.TEXT:
            raise TypeError(
                "Received message {}:{!r} is not WSMsgType.TEXT".format(
                    msg.type, msg.data
                )
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive a BINARY message. :raises TypeError: for any other type."""
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.BINARY:
            raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
        return cast(bytes, msg.data)

    async def receive_json(
        self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    async def write(self, data: bytes) -> None:
        """Not supported for WebSockets; always raises RuntimeError."""
        raise RuntimeError("Cannot call .write() for websocket")

    def __aiter__(self) -> "WebSocketResponse":
        return self

    async def __anext__(self) -> WSMessage:
        """Yield messages until a CLOSE, CLOSING, or CLOSED message is seen."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    def _cancel(self, exc: BaseException) -> None:
        """Propagate ``exc`` to pending readers of the message queue."""
        if self._reader is not None:
            self._reader.set_exception(exc)