Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_ws.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import base64
3import binascii
4import hashlib
5import json
6import sys
7from typing import Any, Final, Iterable, Optional, Tuple, Union
9from multidict import CIMultiDict
11from . import hdrs
12from ._websocket.reader import WebSocketDataQueue
13from ._websocket.writer import DEFAULT_LIMIT
14from .abc import AbstractStreamWriter
15from .client_exceptions import WSMessageTypeError
16from .helpers import (
17 calculate_timeout_when,
18 frozen_dataclass_decorator,
19 set_exception,
20 set_result,
21)
22from .http import (
23 WS_CLOSED_MESSAGE,
24 WS_CLOSING_MESSAGE,
25 WS_KEY,
26 WebSocketError,
27 WebSocketReader,
28 WebSocketWriter,
29 WSCloseCode,
30 WSMessage,
31 WSMsgType,
32 ws_ext_gen,
33 ws_ext_parse,
34)
35from .http_websocket import _INTERNAL_RECEIVE_TYPES, WSMessageError
36from .log import ws_logger
37from .streams import EofStream
38from .typedefs import JSONDecoder, JSONEncoder
39from .web_exceptions import HTTPBadRequest, HTTPException
40from .web_request import BaseRequest
41from .web_response import StreamResponse
# asyncio gained a native ``timeout()`` context manager in Python 3.11;
# older interpreters fall back to the ``async_timeout`` backport.  Both
# are bound to the same name so the rest of the module can simply call
# ``async_timeout.timeout(...)``.
if sys.version_info < (3, 11):
    import async_timeout
else:
    import asyncio as async_timeout

# Names exported by ``from ... import *``.
__all__ = ("WebSocketResponse", "WebSocketReady", "WSMsgType")

# Number of ``receive()`` calls tolerated on an already-closed websocket
# before ``receive()`` starts raising ``RuntimeError`` instead of
# returning the "closed" sentinel message.
THRESHOLD_CONNLOST_ACCESS: Final[int] = 5
@frozen_dataclass_decorator
class WebSocketReady:
    """Result of ``WebSocketResponse.can_prepare()``.

    Instances are truthy exactly when the handshake check succeeded.
    """

    # True when the request passed the WebSocket handshake validation.
    ok: bool
    # Sub-protocol selected during the handshake check, or None.
    protocol: Optional[str]

    def __bool__(self) -> bool:
        """Return the value of ``ok`` so the object can be used in ``if``."""
        return self.ok
class WebSocketResponse(StreamResponse):
    """Server-side WebSocket response.

    The object is created as a ``101`` stream response.  ``prepare()``
    validates the client's opening handshake, sends the upgrade headers and
    installs a frame writer and a message reader on the connection.  After
    that, messages are exchanged through ``send_*()`` / ``receive*()`` (or
    ``async for``), and the connection is shut down with ``close()``.
    """

    # NOTE(review): presumably consulted by StreamResponse; its use is not
    # visible in this module.
    _length_check: bool = False
    # Sub-protocol agreed on during the handshake (set by _post_start()).
    _ws_protocol: Optional[str] = None
    # Frame writer; stays None until prepare() has completed the handshake.
    _writer: Optional[WebSocketWriter] = None
    # Queue of parsed incoming messages; created by _post_start().
    _reader: Optional[WebSocketDataQueue] = None
    # Set by _set_closed(): close() was started or a ping/pong failure occurred.
    _closed: bool = False
    # Set once the peer started closing, a CLOSING marker was received, or
    # the connection was cancelled by the protocol.
    _closing: bool = False
    # Counts receive() calls made after the websocket was closed.
    _conn_lost: int = 0
    # Close code reported through the ``close_code`` property.
    _close_code: Optional[int] = None
    # Event loop taken from the request in _pre_start().
    _loop: Optional[asyncio.AbstractEventLoop] = None
    # True while receive() is awaiting the reader.
    _waiting: bool = False
    # Future close() awaits until a concurrently pending receive() returns.
    _close_wait: Optional[asyncio.Future[None]] = None
    # Last exception recorded by close()/receive()/heartbeat handling.
    _exception: Optional[BaseException] = None
    # Loop time at which the next heartbeat PING is due.
    _heartbeat_when: float = 0.0
    # Timer handle for the scheduled _send_heartbeat() call.
    _heartbeat_cb: Optional[asyncio.TimerHandle] = None
    # Timer handle for the scheduled _pong_not_received() call.
    _pong_response_cb: Optional[asyncio.TimerHandle] = None
    # Task sending the heartbeat PING when it could not complete immediately.
    _ping_task: Optional[asyncio.Task[None]] = None

    def __init__(
        self,
        *,
        timeout: float = 10.0,
        receive_timeout: Optional[float] = None,
        autoclose: bool = True,
        autoping: bool = True,
        heartbeat: Optional[float] = None,
        protocols: Iterable[str] = (),
        compress: bool = True,
        max_msg_size: int = 4 * 1024 * 1024,
        writer_limit: int = DEFAULT_LIMIT,
    ) -> None:
        """Store the websocket configuration; no I/O happens here.

        :param timeout: seconds ``close()`` waits for the peer's CLOSE frame.
        :param receive_timeout: default timeout for ``receive()`` when the
            call does not pass its own.
        :param autoclose: when true, ``receive()`` calls ``close()`` itself
            after a CLOSE frame arrives.
        :param autoping: when true, ``receive()`` answers PING with PONG and
            does not return PING/PONG messages to the caller.
        :param heartbeat: interval in seconds between heartbeat PINGs;
            ``None`` disables the heartbeat.  A PONG is awaited for half of
            this interval.
        :param protocols: sub-protocols the server is willing to speak.
        :param compress: whether to negotiate the compression extension.
        :param max_msg_size: limit handed to the ``WebSocketReader``.
        :param writer_limit: limit handed to the ``WebSocketWriter``.
        """
        super().__init__(status=101)
        self._protocols = protocols
        self._timeout = timeout
        self._receive_timeout = receive_timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        if heartbeat is not None:
            # Only defined when the heartbeat is enabled; it is read solely
            # from the heartbeat callbacks.
            self._pong_heartbeat = heartbeat / 2.0
        # Replaced by the negotiated compression value in _pre_start().
        self._compress: Union[bool, int] = compress
        self._max_msg_size = max_msg_size
        self._writer_limit = writer_limit

    def _cancel_heartbeat(self) -> None:
        """Cancel the PONG timer, the heartbeat timer and any ping task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending "no PONG received" timer, if there is one."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _reset_heartbeat(self) -> None:
        """Move the next heartbeat deadline ``heartbeat`` seconds ahead.

        Does nothing when the heartbeat is disabled.  Any pending PONG
        timer is cancelled.
        """
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        req = self._req
        timeout_ceil_threshold = (
            req._protocol._timeout_ceil_threshold if req is not None else 5
        )
        loop = self._loop
        assert loop is not None
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a heartbeat PING and arm the PONG timer.

        If the deadline was pushed back since the timer was scheduled, the
        callback only reschedules itself for the current deadline.
        """
        self._heartbeat_cb = None
        loop = self._loop
        assert loop is not None and self._writer is not None
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        req = self._req
        timeout_ceil_threshold = (
            req._protocol._timeout_ceil_threshold if req is not None else 5
        )
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            # Keep a reference so _cancel_heartbeat() can cancel the send.
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes."""
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """Timer callback: the peer did not answer the heartbeat PING."""
        # Nothing to report once the request's transport is already gone.
        if self._req is not None and self._req.transport is not None:
            self._handle_ping_pong_exception(
                asyncio.TimeoutError(
                    f"No PONG received after {self._pong_heartbeat} seconds"
                )
            )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing."""
        if self._closed:
            return
        self._set_closed()
        self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
        self._exception = exc
        if self._waiting and not self._closing and self._reader is not None:
            # Hand the error to the receive() call that is currently
            # awaiting the reader.
            self._reader.feed_data(WSMessageError(data=exc, extra=None))

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
        """Perform the opening handshake and start the websocket.

        Calling it again after a successful start returns the existing
        payload writer without repeating the handshake.

        :raises HTTPBadRequest: if the request is not a valid WebSocket
            upgrade (raised by ``_handshake()``).
        """
        # make pre-check to don't hide it by do_handshake() exceptions
        if self._payload_writer is not None:
            return self._payload_writer

        protocol, writer = self._pre_start(request)
        payload_writer = await super().prepare(request)
        assert payload_writer is not None
        self._post_start(request, protocol, writer)
        await payload_writer.drain()
        return payload_writer

    def _handshake(
        self, request: BaseRequest
    ) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]:
        """Validate the client's opening handshake and build the reply.

        :returns: ``(response_headers, protocol, compress, notakeover)``
            where ``protocol`` is the selected sub-protocol (or ``None``)
            and ``compress`` / ``notakeover`` are the values produced by
            ``ws_ext_parse()`` (``0`` / ``False`` when compression is
            disabled on this response).
        :raises HTTPBadRequest: on a missing/invalid ``Upgrade`` or
            ``Connection`` header, an unsupported
            ``Sec-WebSocket-Version`` or an invalid ``Sec-WebSocket-Key``.
        """
        headers = request.headers
        if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
            raise HTTPBadRequest(
                text=(
                    "No WebSocket UPGRADE hdr: {}\n Can "
                    '"Upgrade" only to "WebSocket".'
                ).format(headers.get(hdrs.UPGRADE))
            )

        if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
            raise HTTPBadRequest(
                text="No CONNECTION upgrade hdr: {}".format(
                    headers.get(hdrs.CONNECTION)
                )
            )

        # find common sub-protocol between client and server
        protocol: Optional[str] = None
        if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
            req_protocols = [
                str(proto.strip())
                for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
            ]

            # The first client-listed protocol that the server knows wins.
            for proto in req_protocols:
                if proto in self._protocols:
                    protocol = proto
                    break
            else:
                # No overlap found: Return no protocol as per spec
                ws_logger.warning(
                    "%s: Client protocols %r don’t overlap server-known ones %r",
                    request.remote,
                    req_protocols,
                    self._protocols,
                )

        # check supported version
        version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
        if version not in ("13", "8", "7"):
            raise HTTPBadRequest(text=f"Unsupported version: {version}")

        # check client handshake for validity
        key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
        try:
            if not key or len(base64.b64decode(key)) != 16:
                raise HTTPBadRequest(text=f"Handshake error: {key!r}")
        except (binascii.Error, ValueError):
            # binascii.Error covers malformed base64.  b64decode() raises a
            # plain ValueError when a str argument contains non-ASCII
            # characters; the key is client-controlled, so that case must
            # also be reported as a bad request instead of escaping as an
            # unhandled error.
            raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None

        accept_val = base64.b64encode(
            hashlib.sha1(key.encode() + WS_KEY).digest()
        ).decode()
        response_headers = CIMultiDict(
            {
                hdrs.UPGRADE: "websocket",
                hdrs.CONNECTION: "upgrade",
                hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
            }
        )

        notakeover = False
        compress = 0
        if self._compress:
            extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
            # Server side always get return with no exception.
            # If something happened, just drop compress extension
            compress, notakeover = ws_ext_parse(extensions, isserver=True)
            if compress:
                enabledext = ws_ext_gen(
                    compress=compress, isserver=True, server_notakeover=notakeover
                )
                response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext

        if protocol:
            response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
        return (
            response_headers,
            protocol,
            compress,
            notakeover,
        )

    def _pre_start(self, request: BaseRequest) -> Tuple[Optional[str], WebSocketWriter]:
        """Run the handshake and build the frame writer.

        Sets the response status and headers and forces the connection to
        be closed after the response.  The writer is returned rather than
        stored; ``_post_start()`` installs it.

        :returns: ``(protocol, writer)``.
        """
        self._loop = request._loop

        headers, protocol, compress, notakeover = self._handshake(request)

        self.set_status(101)
        self.headers.update(headers)
        self.force_close()
        # From here on ``_compress`` holds the negotiated value.
        self._compress = compress
        transport = request._protocol.transport
        assert transport is not None
        writer = WebSocketWriter(
            request._protocol,
            transport,
            compress=compress,
            notakeover=notakeover,
            limit=self._writer_limit,
        )

        return protocol, writer

    def _post_start(
        self, request: BaseRequest, protocol: Optional[str], writer: WebSocketWriter
    ) -> None:
        """Install the writer and reader once the upgrade response is sent."""
        self._ws_protocol = protocol
        self._writer = writer

        self._reset_heartbeat()

        loop = self._loop
        assert loop is not None
        self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
        request.protocol.set_parser(
            WebSocketReader(
                self._reader, self._max_msg_size, compress=bool(self._compress)
            )
        )
        # disable HTTP keepalive for WebSocket
        request.protocol.keep_alive(False)

    def can_prepare(self, request: BaseRequest) -> WebSocketReady:
        """Check whether ``request`` can be upgraded, without starting.

        :returns: a truthy ``WebSocketReady`` carrying the selected
            sub-protocol when the handshake is valid, otherwise a falsy one.
        :raises RuntimeError: if the websocket was already started.
        """
        if self._writer is not None:
            raise RuntimeError("Already started")
        try:
            _, protocol, _, _ = self._handshake(request)
        except HTTPException:
            return WebSocketReady(False, None)
        else:
            return WebSocketReady(True, protocol)

    @property
    def closed(self) -> bool:
        """True once the websocket has been marked closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded for the connection, or None if none yet."""
        return self._close_code

    @property
    def ws_protocol(self) -> Optional[str]:
        """Sub-protocol selected during the handshake, or None."""
        return self._ws_protocol

    @property
    def compress(self) -> Union[int, bool]:
        """Constructor ``compress`` flag, or the negotiated value after start."""
        return self._compress

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Get optional transport information.

        If no value associated with ``name`` is found, ``default`` is returned.
        ``default`` is also returned when the websocket has not been started.
        """
        writer = self._writer
        if writer is None:
            return default
        return writer.transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the last exception recorded for the connection, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame carrying ``message``.

        :raises RuntimeError: if ``prepare()`` has not been called.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame carrying ``message``.

        :raises RuntimeError: if ``prepare()`` has not been called.
        """
        # unsolicited pong
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket.

        ``compress`` is forwarded unchanged to the frame writer.

        :raises RuntimeError: if ``prepare()`` has not been called.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send ``data`` as a UTF-8 encoded TEXT frame.

        :raises RuntimeError: if ``prepare()`` has not been called.
        :raises TypeError: if ``data`` is not a ``str``.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send ``data`` as a BINARY frame.

        :raises RuntimeError: if ``prepare()`` has not been called.
        :raises TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = json.dumps,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it via ``send_str()``."""
        await self.send_str(dumps(data), compress=compress)

    async def write_eof(self) -> None:  # type: ignore[override]
        """Finish the response by closing the websocket.

        Does nothing if EOF was already sent.

        :raises RuntimeError: if the response has not been started.
        """
        if self._eof_sent:
            return
        if self._payload_writer is None:
            raise RuntimeError("Response has not been started")

        await self.close()
        self._eof_sent = True

    async def close(
        self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
    ) -> bool:
        """Close websocket connection.

        Sends a CLOSE frame, then waits up to the configured ``timeout``
        for the peer's CLOSE frame before closing the transport.

        :param code: close code sent to the peer.
        :param message: payload sent with the CLOSE frame.
        :param drain: whether to drain the payload writer after sending
            the CLOSE frame.
        :returns: ``False`` if the websocket was already closed, ``True``
            otherwise (including when the close failed and the error was
            stored for ``exception()``).
        :raises RuntimeError: if ``prepare()`` has not been called.
        """
        if self._writer is None:
            raise RuntimeError("Call .prepare() first")

        if self._closed:
            return False
        self._set_closed()

        try:
            await self._writer.close(code, message)
            writer = self._payload_writer
            assert writer is not None
            if drain:
                await writer.drain()
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # Cancellation and timeouts propagate to the caller after the
            # transport has been closed.
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            raise
        except Exception as exc:
            self._exception = exc
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            return True

        reader = self._reader
        assert reader is not None
        # we need to break `receive()` cycle before we can call
        # `reader.read()` as `close()` may be called from different task
        if self._waiting:
            assert self._loop is not None
            assert self._close_wait is None
            self._close_wait = self._loop.create_future()
            reader.feed_data(WS_CLOSING_MESSAGE)
            await self._close_wait

        if self._closing:
            # The closing state is already set, so there is no peer CLOSE
            # frame to wait for.
            self._close_transport()
            return True

        try:
            async with async_timeout.timeout(self._timeout):
                # Discard incoming messages until the peer's CLOSE arrives.
                while True:
                    msg = await reader.read()
                    if msg.type is WSMsgType.CLOSE:
                        self._set_code_close_transport(msg.data)
                        return True
        except asyncio.CancelledError:
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            raise
        except Exception as exc:
            # Includes the timeout raised when no CLOSE frame arrives in time.
            self._exception = exc
            self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
            return True

    def _set_closing(self, code: int) -> None:
        """Set the close code and mark the connection as closing."""
        self._closing = True
        self._close_code = code
        self._cancel_heartbeat()

    def _set_code_close_transport(self, code: int) -> None:
        """Set the close code and close the transport."""
        self._close_code = code
        self._close_transport()

    def _close_transport(self) -> None:
        """Close the transport."""
        if self._req is not None and self._req.transport is not None:
            self._req.transport.close()

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for and return the next incoming message.

        CLOSE and CLOSING messages update the closing state (and trigger
        ``close()`` when ``autoclose`` is set).  With ``autoping`` set,
        PING is answered with PONG and PING/PONG are not returned.  Read
        errors other than a timeout are returned as error messages rather
        than raised.

        :param timeout: per-call timeout in seconds; a falsy value falls
            back to the ``receive_timeout`` given to the constructor.
        :raises RuntimeError: if ``prepare()`` has not been called, if
            another ``receive()`` is already waiting, or after repeated
            calls on a closed websocket.
        :raises asyncio.TimeoutError: if no message arrives in time.
        """
        if self._reader is None:
            raise RuntimeError("Call .prepare() first")

        receive_timeout = timeout or self._receive_timeout
        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                self._conn_lost += 1
                if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                    raise RuntimeError("WebSocket connection is closed.")
                return WS_CLOSED_MESSAGE
            elif self._closing:
                return WS_CLOSING_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    self._waiting = False
                    if self._close_wait:
                        # Release a close() call that is waiting for this
                        # receive() to leave the reader.
                        set_result(self._close_wait, None)
            except asyncio.TimeoutError:
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessageError(data=exc)
            except Exception as exc:
                self._exception = exc
                self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
                await self.close()
                return WSMessageError(data=exc)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing(msg.data)
                # Could be closed while awaiting reader.
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    # The client is likely going to close the
                    # connection out from under us so we do not
                    # want to drain any pending writes as it will
                    # likely result writing to a broken pipe.
                    await self.close(drain=False)
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing(WSCloseCode.OK)
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive the next message and return its text payload.

        :raises WSMessageTypeError: if the message is not a TEXT message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return msg.data

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive the next message and return its binary payload.

        :raises WSMessageTypeError: if the message is not a BINARY message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return msg.data

    async def receive_json(
        self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    async def write(
        self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        """Always raise: raw writes are not allowed on a websocket.

        :raises RuntimeError: unconditionally.
        """
        raise RuntimeError("Cannot call .write() for websocket")

    def __aiter__(self) -> "WebSocketResponse":
        """Return ``self``; the response is its own async iterator."""
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message, ending iteration on a close message."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    def _cancel(self, exc: BaseException) -> None:
        """Mark the websocket as closing and fail the reader with ``exc``."""
        # web_protocol calls this from connection_lost
        # or when the server is shutting down.
        self._closing = True
        self._cancel_heartbeat()
        if self._reader is not None:
            set_exception(self._reader, exc)