Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_ws.py: 20%
332 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
1import asyncio
2import base64
3import binascii
4import dataclasses
5import hashlib
6import json
7import sys
8from typing import Any, Final, Iterable, Optional, Tuple, cast
10from multidict import CIMultiDict
12from . import hdrs
13from .abc import AbstractStreamWriter
14from .helpers import call_later, set_result
15from .http import (
16 WS_CLOSED_MESSAGE,
17 WS_CLOSING_MESSAGE,
18 WS_KEY,
19 WebSocketError,
20 WebSocketReader,
21 WebSocketWriter,
22 WSCloseCode,
23 WSMessage,
24 WSMsgType as WSMsgType,
25 ws_ext_gen,
26 ws_ext_parse,
27)
28from .log import ws_logger
29from .streams import EofStream, FlowControlDataQueue
30from .typedefs import JSONDecoder, JSONEncoder
31from .web_exceptions import HTTPBadRequest, HTTPException
32from .web_request import BaseRequest
33from .web_response import StreamResponse
35if sys.version_info >= (3, 11):
36 import asyncio as async_timeout
37else:
38 import async_timeout
# Public names re-exported by this module.
__all__ = ("WebSocketResponse", "WebSocketReady", "WSMsgType")

# Once receive() has been called this many times on a connection that is
# already closed, it raises RuntimeError instead of returning the
# "closed" message again.
THRESHOLD_CONNLOST_ACCESS: Final[int] = 5
@dataclasses.dataclass(frozen=True)
class WebSocketReady:
    """Result returned by ``WebSocketResponse.can_prepare``.

    The instance is truthy exactly when ``ok`` is true, so callers may test
    the result directly in an ``if`` statement.
    """

    # True when the request passed the WebSocket handshake checks.
    ok: bool
    # Negotiated sub-protocol, or None when none was selected.
    protocol: Optional[str]

    def __bool__(self) -> bool:
        """Evaluate as the ``ok`` flag."""
        ok_flag = self.ok
        return ok_flag
class WebSocketResponse(StreamResponse):
    """Server-side WebSocket response.

    ``prepare()`` validates the client's opening handshake and sends the
    ``101`` response; afterwards messages are exchanged with the ``send_*`` /
    ``receive*`` methods (or by async iteration) until ``close()``.
    """

    __slots__ = (
        "_protocols",
        "_ws_protocol",
        "_writer",
        "_reader",
        "_closed",
        "_closing",
        "_conn_lost",
        "_close_code",
        "_loop",
        "_waiting",
        "_exception",
        "_timeout",
        "_receive_timeout",
        "_autoclose",
        "_autoping",
        "_heartbeat",
        "_heartbeat_cb",
        "_pong_heartbeat",
        "_pong_response_cb",
        "_ping_task",
        "_compress",
        "_max_msg_size",
    )

    def __init__(
        self,
        *,
        timeout: float = 10.0,
        receive_timeout: Optional[float] = None,
        autoclose: bool = True,
        autoping: bool = True,
        heartbeat: Optional[float] = None,
        protocols: Iterable[str] = (),
        compress: bool = True,
        max_msg_size: int = 4 * 1024 * 1024,
    ) -> None:
        """Configure the response; no I/O happens until ``prepare()``.

        :param timeout: seconds ``close()`` waits for the peer's CLOSE frame.
        :param receive_timeout: default timeout for ``receive()`` when the
            call itself passes none (``None`` waits indefinitely).
        :param autoclose: when a CLOSE frame is received, reply with
            ``close()`` automatically.
        :param autoping: answer PING frames with PONG and do not return
            PING/PONG frames from ``receive()``.
        :param heartbeat: if set, send a PING this many seconds after the last
            message read by ``receive()``; if no further message is read within
            ``heartbeat / 2`` seconds, the transport is closed.
        :param protocols: sub-protocols the server supports; the first one
            offered by the client that is in this collection is selected.
        :param compress: allow negotiating the compression extension offered
            in ``Sec-WebSocket-Extensions``.
        :param max_msg_size: passed to ``WebSocketReader`` as its message size
            limit.
        """
        super().__init__(status=101)
        self._length_check = False
        self._protocols = protocols
        self._ws_protocol: Optional[str] = None
        self._writer: Optional[WebSocketWriter] = None
        self._reader: Optional[FlowControlDataQueue[WSMessage]] = None
        self._closed = False
        self._closing = False
        self._conn_lost = 0
        self._close_code: Optional[int] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._waiting: Optional[asyncio.Future[bool]] = None
        self._exception: Optional[BaseException] = None
        self._timeout = timeout
        self._receive_timeout = receive_timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # Always initialise the slot: with __slots__ an unassigned attribute
        # raises AttributeError when read.
        self._pong_heartbeat: Optional[float] = (
            heartbeat / 2.0 if heartbeat is not None else None
        )
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        # Strong reference to the in-flight heartbeat ping task.
        self._ping_task: Optional["asyncio.Task[None]"] = None
        self._compress = compress
        self._max_msg_size = max_msg_size

    def _timeout_ceil_threshold(self) -> float:
        """Return the timer-rounding threshold passed to ``call_later``.

        Uses the request protocol's setting once a request is attached and
        falls back to 5 otherwise.
        """
        req = self._req
        if req is None:
            return 5
        return req._protocol._timeout_ceil_threshold

    def _cancel_heartbeat(self) -> None:
        """Cancel both the pending pong-timeout and the next-ping timers."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None

    def _reset_heartbeat(self) -> None:
        """Restart the heartbeat: cancel timers and schedule the next PING."""
        self._cancel_heartbeat()

        if self._heartbeat is not None:
            assert self._loop is not None
            self._heartbeat_cb = call_later(
                self._send_heartbeat,
                self._heartbeat,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold(),
            )

    def _send_heartbeat(self) -> None:
        """Send a PING and arm the timer that fires if no reply is read."""
        if self._heartbeat is None or self._closed:
            return
        assert self._loop is not None and self._writer is not None
        # Fire-and-forget a task is not perfect but acceptable for a ping;
        # otherwise the class would need a long-living heartbeat task.
        # The event loop keeps only weak references to tasks, so hold a strong
        # one until the ping finishes, and retrieve its outcome in the
        # done-callback so a failed write is not silently dropped.
        ping_task = self._loop.create_task(self._writer.ping())
        self._ping_task = ping_task
        ping_task.add_done_callback(self._ping_task_done)

        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
        self._pong_response_cb = call_later(
            self._pong_not_received,
            self._pong_heartbeat,
            self._loop,
            timeout_ceil_threshold=self._timeout_ceil_threshold(),
        )

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Release the heartbeat ping task and react to a failed ping."""
        if self._ping_task is task:
            self._ping_task = None
        if task.cancelled():
            return
        # Calling exception() marks it as retrieved (no asyncio warning).
        exc = task.exception()
        if exc is None or self._closed:
            return
        # The PING could not be written, so the connection is unusable;
        # shut it down the same way a missed pong does.
        if self._req is not None and self._req.transport is not None:
            self._closed = True
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            self._exception = exc

    def _pong_not_received(self) -> None:
        """Heartbeat timeout: mark closed and abort the transport."""
        if self._req is not None and self._req.transport is not None:
            self._closed = True
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            self._exception = asyncio.TimeoutError()

    async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
        """Perform the WebSocket handshake and send the ``101`` response.

        Returns the already-existing payload writer if called again.

        :raises HTTPBadRequest: the request is not a valid WebSocket upgrade.
        """
        # Check this before the handshake so a repeated prepare() is not
        # masked by handshake errors.
        if self._payload_writer is not None:
            return self._payload_writer

        protocol, writer = self._pre_start(request)
        payload_writer = await super().prepare(request)
        assert payload_writer is not None
        self._post_start(request, protocol, writer)
        await payload_writer.drain()
        return payload_writer

    def _handshake(
        self, request: BaseRequest
    ) -> Tuple["CIMultiDict[str]", str, bool, bool]:
        """Validate the client's opening handshake and build response headers.

        Returns ``(headers, protocol, compress, notakeover)`` where
        ``protocol`` is the selected sub-protocol or ``None``, and
        ``compress``/``notakeover`` come from ``ws_ext_parse`` (``0``/``False``
        when compression is disabled).

        :raises HTTPBadRequest: missing/invalid ``Upgrade`` or ``Connection``
            header, unsupported ``Sec-WebSocket-Version``, or malformed
            ``Sec-WebSocket-Key``.
        """
        headers = request.headers
        if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
            raise HTTPBadRequest(
                text=(
                    "No WebSocket UPGRADE hdr: {}\n Can "
                    '"Upgrade" only to "WebSocket".'
                ).format(headers.get(hdrs.UPGRADE))
            )

        if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
            raise HTTPBadRequest(
                text="No CONNECTION upgrade hdr: {}".format(
                    headers.get(hdrs.CONNECTION)
                )
            )

        # find common sub-protocol between client and server
        protocol = None
        if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
            req_protocols = [
                proto.strip()
                for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
            ]

            for proto in req_protocols:
                if proto in self._protocols:
                    protocol = proto
                    break
            else:
                # No overlap found: Return no protocol as per spec
                ws_logger.warning(
                    "Client protocols %r don’t overlap server-known ones %r",
                    req_protocols,
                    self._protocols,
                )

        # check supported version
        version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
        if version not in ("13", "8", "7"):
            raise HTTPBadRequest(text=f"Unsupported version: {version}")

        # check client handshake for validity: the key must decode to 16 bytes
        key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
        try:
            if not key or len(base64.b64decode(key)) != 16:
                raise HTTPBadRequest(text=f"Handshake error: {key!r}")
        except binascii.Error:
            raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None

        accept_val = base64.b64encode(
            hashlib.sha1(key.encode() + WS_KEY).digest()
        ).decode()
        response_headers = CIMultiDict(
            {
                hdrs.UPGRADE: "websocket",
                hdrs.CONNECTION: "upgrade",
                hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
            }
        )

        notakeover = False
        compress = 0
        if self._compress:
            extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
            # Server side always gets a result with no exception.
            # If something went wrong, just drop the compress extension.
            compress, notakeover = ws_ext_parse(extensions, isserver=True)
            if compress:
                enabledext = ws_ext_gen(
                    compress=compress, isserver=True, server_notakeover=notakeover
                )
                response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext

        if protocol:
            response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
        return (
            response_headers,
            protocol,
            compress,
            notakeover,
        )  # type: ignore[return-value]

    def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
        """Run the handshake and set status/headers before they are sent.

        Returns the selected protocol and a frame writer bound to the
        request's transport.
        """
        self._loop = request._loop

        headers, protocol, compress, notakeover = self._handshake(request)

        self.set_status(101)
        self.headers.update(headers)
        self.force_close()
        self._compress = compress
        transport = request._protocol.transport
        assert transport is not None
        writer = WebSocketWriter(
            request._protocol, transport, compress=compress, notakeover=notakeover
        )

        return protocol, writer

    def _post_start(
        self, request: BaseRequest, protocol: str, writer: WebSocketWriter
    ) -> None:
        """Attach the writer, start the heartbeat and install the WS parser."""
        self._ws_protocol = protocol
        self._writer = writer

        self._reset_heartbeat()

        loop = self._loop
        assert loop is not None
        self._reader = FlowControlDataQueue(request._protocol, 2**16, loop=loop)
        request.protocol.set_parser(
            WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
        )
        # disable HTTP keepalive for WebSocket
        request.protocol.keep_alive(False)

    def can_prepare(self, request: BaseRequest) -> WebSocketReady:
        """Check whether ``request`` could be upgraded, without sending anything.

        :raises RuntimeError: the response was already prepared.
        """
        if self._writer is not None:
            raise RuntimeError("Already started")
        try:
            _, protocol, _, _ = self._handshake(request)
        except HTTPException:
            return WebSocketReady(False, None)
        else:
            return WebSocketReady(True, protocol)

    @property
    def closed(self) -> bool:
        """Whether this side has closed (or aborted) the connection."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded for the connection, or ``None``."""
        return self._close_code

    @property
    def ws_protocol(self) -> Optional[str]:
        """Sub-protocol selected during the handshake, or ``None``."""
        return self._ws_protocol

    @property
    def compress(self) -> bool:
        """Compression setting (replaced by the negotiated value on prepare)."""
        return self._compress

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Get optional transport information.

        If no value associated with ``name`` is found, ``default`` is returned.
        """
        writer = self._writer
        if writer is None:
            return default
        transport = writer.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the exception recorded when the connection failed, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame. Raises RuntimeError before ``prepare()``."""
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.ping(message)

    async def pong(self, message: bytes = b"") -> None:
        """Send an (unsolicited) PONG frame. Raises RuntimeError before ``prepare()``."""
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.pong(message)

    async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
        """Send a text message.

        :raises RuntimeError: called before ``prepare()``.
        :raises TypeError: ``data`` is not a ``str``.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send(data, binary=False, compress=compress)

    async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
        """Send a binary message.

        :raises RuntimeError: called before ``prepare()``.
        :raises TypeError: ``data`` is not bytes, bytearray or memoryview.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send(data, binary=True, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[bool] = None,
        *,
        dumps: JSONEncoder = json.dumps,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a text message."""
        await self.send_str(dumps(data), compress=compress)

    async def write_eof(self) -> None:  # type: ignore[override]
        """Close the websocket once; later calls are no-ops.

        :raises RuntimeError: the response has not been started.
        """
        if self._eof_sent:
            return
        if self._payload_writer is None:
            raise RuntimeError("Response has not been started")

        await self.close()
        self._eof_sent = True

    async def close(
        self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
    ) -> bool:
        """Close websocket connection.

        Sends a CLOSE frame with ``code``/``message``. Unless the peer's CLOSE
        was already received, then waits up to ``timeout`` seconds (from the
        constructor) for it, discarding any other frames that arrive first.

        :param drain: drain the payload writer after sending the CLOSE frame.
        :returns: ``False`` if the connection was already closed, else ``True``.
        :raises RuntimeError: called before ``prepare()``.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")

        self._cancel_heartbeat()
        reader = self._reader
        assert reader is not None

        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting is not None and not self._closed:
            reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._waiting

        if self._closed:
            return False

        self._closed = True
        try:
            await self._writer.close(code, message)
            writer = self._payload_writer
            assert writer is not None
            if drain:
                await writer.drain()
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            raise
        except Exception as exc:
            self._exception = exc
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            return True

        # The peer's CLOSE has already been received; nothing to wait for.
        if self._closing:
            return True

        try:
            async with async_timeout.timeout(self._timeout):
                # Data frames may still arrive before the peer's CLOSE frame
                # (RFC 6455, section 5.5.1); skip them instead of treating the
                # first one as a failed closing handshake.
                while True:
                    msg = await reader.read()
                    if msg.type == WSMsgType.CLOSE:
                        self._set_code_close_transport(msg.data)
                        return True
        except asyncio.CancelledError:
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            raise
        except Exception as exc:
            # Includes the timeout expiring before a CLOSE frame arrived.
            self._exception = exc
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            return True

    def _set_code_close_transport(self, code: WSCloseCode) -> None:
        """Set the close code and close the transport."""
        self._close_code = code
        if self._req is not None and self._req.transport is not None:
            self._req.transport.close()

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for and return the next message.

        With ``autoping`` enabled, PING frames are answered and PING/PONG
        frames are skipped. After the connection is closed the "closed"
        message is returned, until ``THRESHOLD_CONNLOST_ACCESS`` such calls.

        :param timeout: overrides ``receive_timeout`` for this call.
        :raises RuntimeError: before ``prepare()``, on a concurrent call, or
            after repeated access to a closed connection.
        :raises asyncio.TimeoutError: the read timed out (transport closed).
        """
        if self._reader is None:
            raise RuntimeError("Call .prepare() first")

        loop = self._loop
        assert loop is not None
        while True:
            if self._waiting is not None:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                self._conn_lost += 1
                if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                    raise RuntimeError("WebSocket connection is closed.")
                return WS_CLOSED_MESSAGE
            elif self._closing:
                return WS_CLOSING_MESSAGE

            try:
                # close() from another task awaits this future to know that
                # the pending read has finished.
                self._waiting = loop.create_future()
                try:
                    async with async_timeout.timeout(timeout or self._receive_timeout):
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    waiter = self._waiting
                    set_result(waiter, True)
                    self._waiting = None
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._closing = True
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                # Could be closed while awaiting reader.
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    # The client is likely going to close the
                    # connection out from under us so we do not
                    # want to drain any pending writes as it will
                    # likely result writing to a broken pipe.
                    await self.close(drain=False)
            elif msg.type == WSMsgType.CLOSING:
                self._closing = True
            elif msg.type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type == WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive a message and require it to be TEXT (else TypeError)."""
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.TEXT:
            raise TypeError(
                "Received message {}:{!r} is not WSMsgType.TEXT".format(
                    msg.type, msg.data
                )
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive a message and require it to be BINARY (else TypeError)."""
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.BINARY:
            raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
        return cast(bytes, msg.data)

    async def receive_json(
        self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    async def write(self, data: bytes) -> None:
        """Raw writes are not supported on a websocket; always raises."""
        raise RuntimeError("Cannot call .write() for websocket")

    def __aiter__(self) -> "WebSocketResponse":
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message; stop on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    def _cancel(self, exc: BaseException) -> None:
        """Propagate ``exc`` to the reader so pending reads fail with it."""
        if self._reader is not None:
            self._reader.set_exception(exc)