Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_ws.py: 20%
318 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1import asyncio
2import base64
3import binascii
4import dataclasses
5import hashlib
6import json
7from typing import Any, Iterable, Optional, Tuple, cast
9import async_timeout
10from multidict import CIMultiDict
11from typing_extensions import Final
13from . import hdrs
14from .abc import AbstractStreamWriter
15from .helpers import call_later, set_result
16from .http import (
17 WS_CLOSED_MESSAGE,
18 WS_CLOSING_MESSAGE,
19 WS_KEY,
20 WebSocketError,
21 WebSocketReader,
22 WebSocketWriter,
23 WSCloseCode,
24 WSMessage,
25 WSMsgType as WSMsgType,
26 ws_ext_gen,
27 ws_ext_parse,
28)
29from .log import ws_logger
30from .streams import EofStream, FlowControlDataQueue
31from .typedefs import JSONDecoder, JSONEncoder
32from .web_exceptions import HTTPBadRequest, HTTPException
33from .web_request import BaseRequest
34from .web_response import StreamResponse
# Public names exported by ``from <module> import *``.
__all__ = ("WebSocketResponse", "WebSocketReady", "WSMsgType")

# Number of ``receive()`` calls on an already-closed websocket at which
# ``receive()`` stops returning ``WS_CLOSED_MESSAGE`` and raises RuntimeError.
THRESHOLD_CONNLOST_ACCESS: Final[int] = 5
@dataclasses.dataclass(frozen=True)
class WebSocketReady:
    """Immutable result of a websocket handshake pre-check.

    Instances are truthy exactly when ``ok`` is true, so the object can be
    used directly in an ``if`` statement.
    """

    # Whether the handshake check succeeded.
    ok: bool
    # Sub-protocol chosen during the check, or None when there is none.
    protocol: Optional[str]

    def __bool__(self) -> bool:
        """Return the ``ok`` flag as the truth value of the instance."""
        return self.ok
class WebSocketResponse(StreamResponse):
    """Server-side websocket built on top of a streaming HTTP response.

    The object starts as a ``101`` response.  ``prepare()`` performs the
    upgrade handshake against the incoming request and installs a websocket
    reader/writer pair; after that the ``send_*``/``receive*`` methods
    exchange frames and ``close()`` runs the closing handshake.
    """

    __slots__ = (
        "_protocols",
        "_ws_protocol",
        "_writer",
        "_reader",
        "_closed",
        "_closing",
        "_conn_lost",
        "_close_code",
        "_loop",
        "_waiting",
        "_exception",
        "_timeout",
        "_receive_timeout",
        "_autoclose",
        "_autoping",
        "_heartbeat",
        "_heartbeat_cb",
        "_pong_heartbeat",
        "_pong_response_cb",
        "_compress",
        "_max_msg_size",
    )

    def __init__(
        self,
        *,
        timeout: float = 10.0,
        receive_timeout: Optional[float] = None,
        autoclose: bool = True,
        autoping: bool = True,
        heartbeat: Optional[float] = None,
        protocols: Iterable[str] = (),
        compress: bool = True,
        max_msg_size: int = 4 * 1024 * 1024,
    ) -> None:
        """Create an unprepared websocket response.

        Args:
            timeout: How long ``close()`` waits for the peer's close frame.
            receive_timeout: Default timeout for ``receive()`` when the call
                does not pass its own.
            autoclose: When true, ``receive()`` calls ``close()`` after
                reading a CLOSE frame.
            autoping: When true, ``receive()`` answers PING frames with a
                pong and silently skips PONG frames.
            heartbeat: Delay between automatic pings; ``None`` disables the
                heartbeat.  The pong is awaited for half of this delay.
            protocols: Sub-protocols the server is willing to speak.
            compress: Whether to negotiate the compression extension.
            max_msg_size: Size limit handed to the websocket reader.
        """
        super().__init__(status=101)
        self._length_check = False
        self._protocols = protocols
        self._ws_protocol: Optional[str] = None
        self._writer: Optional[WebSocketWriter] = None
        self._reader: Optional[FlowControlDataQueue[WSMessage]] = None
        self._closed = False
        self._closing = False
        self._conn_lost = 0
        self._close_code: Optional[int] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._waiting: Optional[asyncio.Future[bool]] = None
        self._exception: Optional[BaseException] = None
        self._timeout = timeout
        self._receive_timeout = receive_timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # Always assign the slot so that reading it can never raise
        # AttributeError; the value is only used while a heartbeat is set.
        self._pong_heartbeat = heartbeat / 2.0 if heartbeat is not None else 0.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        self._compress = compress
        self._max_msg_size = max_msg_size

    def _cancel_heartbeat(self) -> None:
        """Cancel the pending pong-timeout and heartbeat timers, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None

    def _reset_heartbeat(self) -> None:
        """Restart the heartbeat timer from now (no-op without a heartbeat)."""
        self._cancel_heartbeat()

        if self._heartbeat is not None:
            assert self._loop is not None
            self._heartbeat_cb = call_later(
                self._send_heartbeat,
                self._heartbeat,
                self._loop,
                timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold
                if self._req is not None
                else 5,
            )

    def _send_heartbeat(self) -> None:
        """Timer callback: send a ping and arm the pong-timeout timer."""
        if self._heartbeat is not None and not self._closed:
            assert self._loop is not None
            # fire-and-forget a task is not perfect but maybe ok for
            # sending ping. Otherwise we need a long-living heartbeat
            # task in the class.
            self._loop.create_task(self._writer.ping())  # type: ignore[union-attr]

            if self._pong_response_cb is not None:
                self._pong_response_cb.cancel()
            self._pong_response_cb = call_later(
                self._pong_not_received,
                self._pong_heartbeat,
                self._loop,
                timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold
                if self._req is not None
                else 5,
            )

    def _pong_not_received(self) -> None:
        """Timer callback: the pong timer expired, so drop the connection."""
        if self._req is not None and self._req.transport is not None:
            self._closed = True
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = asyncio.TimeoutError()
            self._req.transport.close()

    async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
        """Perform the websocket handshake and start the response.

        Calling it again after a successful start returns the existing
        payload writer without repeating the handshake.
        """
        # make pre-check to don't hide it by do_handshake() exceptions
        if self._payload_writer is not None:
            return self._payload_writer

        protocol, writer = self._pre_start(request)
        payload_writer = await super().prepare(request)
        assert payload_writer is not None
        self._post_start(request, protocol, writer)
        await payload_writer.drain()
        return payload_writer

    def _handshake(
        self, request: BaseRequest
    ) -> Tuple["CIMultiDict[str]", str, bool, bool]:
        """Validate the upgrade request and build the handshake reply.

        Returns:
            ``(response_headers, protocol, compress, notakeover)`` where
            ``protocol`` is the first client-offered sub-protocol the server
            also knows (``None`` if there is none) and ``compress`` /
            ``notakeover`` come from parsing the extensions header.

        Raises:
            HTTPBadRequest: if the Upgrade or Connection header is wrong, the
                websocket version is not 13, 8 or 7, or the key is missing or
                does not base64-decode to 16 bytes.
        """
        headers = request.headers
        if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
            raise HTTPBadRequest(
                text=(
                    "No WebSocket UPGRADE hdr: {}\n Can "
                    '"Upgrade" only to "WebSocket".'
                ).format(headers.get(hdrs.UPGRADE))
            )

        if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
            raise HTTPBadRequest(
                text="No CONNECTION upgrade hdr: {}".format(
                    headers.get(hdrs.CONNECTION)
                )
            )

        # find common sub-protocol between client and server
        protocol = None
        if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
            req_protocols = [
                str(proto.strip())
                for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
            ]

            for proto in req_protocols:
                if proto in self._protocols:
                    protocol = proto
                    break
            else:
                # No overlap found: Return no protocol as per spec
                ws_logger.warning(
                    "Client protocols %r don’t overlap server-known ones %r",
                    req_protocols,
                    self._protocols,
                )

        # check supported version
        version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
        if version not in ("13", "8", "7"):
            raise HTTPBadRequest(text=f"Unsupported version: {version}")

        # check client handshake for validity
        key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
        try:
            if not key or len(base64.b64decode(key)) != 16:
                raise HTTPBadRequest(text=f"Handshake error: {key!r}")
        except binascii.Error:
            raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None

        # The accept value is base64(SHA-1(key + WS_KEY)).
        accept_val = base64.b64encode(
            hashlib.sha1(key.encode() + WS_KEY).digest()
        ).decode()
        response_headers = CIMultiDict(
            {
                hdrs.UPGRADE: "websocket",
                hdrs.CONNECTION: "upgrade",
                hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
            }
        )

        notakeover = False
        compress = 0
        if self._compress:
            extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
            # On the server side parsing always returns without raising;
            # if something is wrong the compress extension is just dropped.
            compress, notakeover = ws_ext_parse(extensions, isserver=True)
            if compress:
                enabledext = ws_ext_gen(
                    compress=compress, isserver=True, server_notakeover=notakeover
                )
                response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext

        if protocol:
            response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
        return (
            response_headers,
            protocol,
            compress,
            notakeover,
        )  # type: ignore[return-value]

    def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
        """Run the handshake, set up response headers and build the writer.

        Returns the negotiated sub-protocol and a ``WebSocketWriter`` bound
        to the request's transport.
        """
        self._loop = request._loop

        headers, protocol, compress, notakeover = self._handshake(request)

        self.set_status(101)
        self.headers.update(headers)
        self.force_close()
        # From here on ``_compress`` holds the negotiated value, not the
        # constructor flag.
        self._compress = compress
        transport = request._protocol.transport
        assert transport is not None
        writer = WebSocketWriter(
            request._protocol, transport, compress=compress, notakeover=notakeover
        )

        return protocol, writer

    def _post_start(
        self, request: BaseRequest, protocol: str, writer: WebSocketWriter
    ) -> None:
        """Store the negotiated state and switch the connection to websocket.

        Starts the heartbeat, creates the incoming message queue and installs
        a ``WebSocketReader`` as the protocol's parser.
        """
        self._ws_protocol = protocol
        self._writer = writer

        self._reset_heartbeat()

        loop = self._loop
        assert loop is not None
        self._reader = FlowControlDataQueue(request._protocol, 2**16, loop=loop)
        request.protocol.set_parser(
            WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
        )
        # disable HTTP keepalive for WebSocket
        request.protocol.keep_alive(False)

    def can_prepare(self, request: BaseRequest) -> WebSocketReady:
        """Check whether ``request`` can be upgraded, without side effects.

        Returns a falsy ``WebSocketReady`` if the handshake raises an
        ``HTTPException``; otherwise a truthy one carrying the sub-protocol.

        Raises:
            RuntimeError: if the websocket has already been started.
        """
        if self._writer is not None:
            raise RuntimeError("Already started")
        try:
            _, protocol, _, _ = self._handshake(request)
        except HTTPException:
            return WebSocketReady(False, None)
        else:
            return WebSocketReady(True, protocol)

    @property
    def closed(self) -> bool:
        """Whether the websocket has been closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded so far, or ``None`` if there is none yet."""
        return self._close_code

    @property
    def ws_protocol(self) -> Optional[str]:
        """Sub-protocol selected by the handshake, if any."""
        return self._ws_protocol

    @property
    def compress(self) -> bool:
        """Compression setting; after the handshake, the negotiated value."""
        return self._compress

    def exception(self) -> Optional[BaseException]:
        """Return the exception recorded for this websocket, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame.

        Raises:
            RuntimeError: if ``prepare()`` has not been called.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.ping(message)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame (may be unsolicited).

        Raises:
            RuntimeError: if ``prepare()`` has not been called.
        """
        # unsolicited pong
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.pong(message)

    async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
        """Send ``data`` as a text frame.

        Raises:
            RuntimeError: if ``prepare()`` has not been called.
            TypeError: if ``data`` is not a ``str``.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send(data, binary=False, compress=compress)

    async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
        """Send ``data`` as a binary frame.

        Raises:
            RuntimeError: if ``prepare()`` has not been called.
            TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send(data, binary=True, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[bool] = None,
        *,
        dumps: JSONEncoder = json.dumps,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a text frame."""
        await self.send_str(dumps(data), compress=compress)

    async def write_eof(self) -> None:  # type: ignore[override]
        """Finish the response by closing the websocket (idempotent).

        Raises:
            RuntimeError: if the response has not been started.
        """
        if self._eof_sent:
            return
        if self._payload_writer is None:
            raise RuntimeError("Response has not been started")

        await self.close()
        self._eof_sent = True

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Run the closing handshake.

        Sends a close frame and, unless the peer's close frame was already
        seen, waits up to ``timeout`` (constructor argument) for it.

        Returns:
            ``True`` if this call closed the websocket, ``False`` if it was
            already closed.

        Raises:
            RuntimeError: if ``prepare()`` has not been called.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")

        self._cancel_heartbeat()
        reader = self._reader
        assert reader is not None

        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting is not None and not self._closed:
            reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._waiting

        if not self._closed:
            self._closed = True
            try:
                await self._writer.close(code, message)
                writer = self._payload_writer
                assert writer is not None
                await writer.drain()
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                return True

            # The peer's close frame has already been received.
            if self._closing:
                return True

            reader = self._reader
            assert reader is not None
            try:
                async with async_timeout.timeout(self._timeout):
                    # Frames the peer sent before it saw our close frame may
                    # still be queued.  Skip them and keep waiting for the
                    # close frame for as long as the timeout allows, instead
                    # of reporting a timeout that did not actually happen.
                    while True:
                        msg = await reader.read()
                        if msg.type == WSMsgType.CLOSE:
                            self._close_code = msg.data
                            return True
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except Exception as exc:
                # Includes the timeout raised when no close frame arrives.
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                return True
        else:
            return False

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for the next message.

        With ``autoping`` enabled PING frames are answered and PONG frames
        skipped, so neither is returned.  Read errors are reported as ERROR
        messages rather than raised, except for cancellation and timeout,
        which propagate.

        Args:
            timeout: Per-call timeout; a falsy value falls back to the
                ``receive_timeout`` given to the constructor.

        Raises:
            RuntimeError: if ``prepare()`` has not been called, if another
                ``receive()`` is already waiting, or if the closed websocket
                has been read ``THRESHOLD_CONNLOST_ACCESS`` times.
        """
        if self._reader is None:
            raise RuntimeError("Call .prepare() first")

        loop = self._loop
        assert loop is not None
        while True:
            if self._waiting is not None:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                self._conn_lost += 1
                if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                    raise RuntimeError("WebSocket connection is closed.")
                return WS_CLOSED_MESSAGE
            elif self._closing:
                return WS_CLOSING_MESSAGE

            try:
                # ``close()`` awaits this future to know the read has ended.
                self._waiting = loop.create_future()
                try:
                    async with async_timeout.timeout(timeout or self._receive_timeout):
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    waiter = self._waiting
                    set_result(waiter, True)
                    self._waiting = None
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._closing = True
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type == WSMsgType.CLOSING:
                self._closing = True
            elif msg.type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type == WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive a message and return its text payload.

        Raises:
            TypeError: if the received message is not a TEXT message.
        """
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.TEXT:
            raise TypeError(
                "Received message {}:{!r} is not WSMsgType.TEXT".format(
                    msg.type, msg.data
                )
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive a message and return its binary payload.

        Raises:
            TypeError: if the received message is not a BINARY message.
        """
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.BINARY:
            raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
        return cast(bytes, msg.data)

    async def receive_json(
        self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
    ) -> Any:
        """Receive a text message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    async def write(self, data: bytes) -> None:
        """Always raise: raw writes are not allowed on a websocket."""
        raise RuntimeError("Cannot call .write() for websocket")

    def __aiter__(self) -> "WebSocketResponse":
        """Return ``self`` so the websocket can be used with ``async for``."""
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message; stop on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    def _cancel(self, exc: BaseException) -> None:
        """Propagate ``exc`` to the message queue, if one exists."""
        if self._reader is not None:
            self._reader.set_exception(exc)