Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/client_ws.py: 27%
Shortcuts on this page
r m x      toggle line displays
j k        next/prev highlighted chunk
0 (zero)   top of page
1 (one)    first highlighted chunk
1"""WebSocket client for asyncio."""
3import asyncio
4import sys
5from collections.abc import Callable
6from types import TracebackType
7from typing import Any, Final, Generic, Literal, overload
9from ._websocket.reader import WebSocketDataQueue
10from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
11from .client_reqrep import ClientResponse
12from .helpers import calculate_timeout_when, frozen_dataclass_decorator, set_result
13from .http import (
14 WS_CLOSED_MESSAGE,
15 WS_CLOSING_MESSAGE,
16 WebSocketError,
17 WSCloseCode,
18 WSMessageDecodeText,
19 WSMessageNoDecodeText,
20 WSMsgType,
21)
22from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError
23from .streams import EofStream
24from .typedefs import (
25 DEFAULT_JSON_DECODER,
26 DEFAULT_JSON_ENCODER,
27 JSONBytesEncoder,
28 JSONDecoder,
29 JSONEncoder,
30)
32if sys.version_info >= (3, 13):
33 from typing import TypeVar
34else:
35 from typing_extensions import TypeVar
37if sys.version_info >= (3, 11):
38 import asyncio as async_timeout
39 from typing import Self
40else:
41 import async_timeout
42 from typing_extensions import Self
# Type parameter of ClientWebSocketResponse that records how TEXT payloads are
# delivered: decoded to str (Literal[True], the default) or left as raw bytes
# (Literal[False]). It is covariant because it only ever appears in return
# types, never in argument types.
_DecodeText = TypeVar(
    "_DecodeText",
    bound=bool,
    default=Literal[True],
    covariant=True,
)
@frozen_dataclass_decorator
class ClientWSTimeout:
    """Timeouts, in seconds, applied by ClientWebSocketResponse.

    A value of ``None`` means "no timeout" for that operation.
    """

    # Limit for a single receive() call when the caller passes no (or a falsy)
    # per-call timeout; None lets receive() wait indefinitely.
    ws_receive: float | None = None
    # Limit for each wait on the peer's reply inside close().
    ws_close: float | None = None
# Default client WebSocket timeouts: receive() is unbounded, while close()
# allows 10 seconds for each wait on the peer's reply.
DEFAULT_WS_CLIENT_TIMEOUT: Final[ClientWSTimeout] = ClientWSTimeout(
    ws_close=10.0,
    ws_receive=None,
)
class ClientWebSocketResponse(Generic[_DecodeText]):
    """Client side of an established WebSocket connection.

    The type parameter records whether TEXT messages are returned decoded to
    ``str`` (``Literal[True]``, the default) or as raw ``bytes``
    (``Literal[False]``).

    When ``heartbeat`` is set, a PING is sent after ``heartbeat`` seconds
    without incoming data. If no data arrives within ``heartbeat / 2`` seconds
    after that PING, the connection is closed with ``ServerTimeoutError``.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: str | None,
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: float | None = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Wrap an established connection.

        :param reader: queue that incoming messages are read from; control
            messages (CLOSING, errors) are also fed into it by this class.
        :param writer: used to send frames and the closing handshake.
        :param protocol: subprotocol value exposed via :attr:`protocol`.
        :param response: the response the socket belongs to; it is closed
            when the WebSocket shuts down.
        :param timeout: receive/close timeouts.
        :param autoclose: whether :meth:`receive` answers a CLOSE frame by
            calling :meth:`close`.
        :param autoping: whether :meth:`receive` answers PINGs with PONGs
            and swallows PONGs instead of returning them.
        :param loop: event loop used for heartbeat timers and futures.
        :param heartbeat: seconds of inactivity before a PING is sent;
            ``None`` disables the heartbeat.
        :param compress: value exposed via :attr:`compress` (presumably the
            negotiated deflate setting; not used directly here).
        :param client_notakeover: value exposed via
            :attr:`client_notakeover` (not used directly here).
        """
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: int | None = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: asyncio.TimerHandle | None = None
        # Loop time at which the next PING is due; _send_heartbeat()
        # reschedules itself if its timer fires before this.
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # Grace period for a reply after a PING has been sent.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: asyncio.TimerHandle | None = None
        self._loop = loop
        # True while a receive() call is blocked on the reader.
        self._waiting: bool = False
        # Set by close() so that receive() can signal when it has stopped.
        self._close_wait: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        self._ping_task: asyncio.Task[None] | None = None
        # A heartbeat reset has been requested but not yet flushed.
        self._need_heartbeat_reset = False
        self._heartbeat_reset_handle: asyncio.Handle | None = None

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel every pending heartbeat timer, reset callback and PING task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_reset_handle is not None:
            self._heartbeat_reset_handle.cancel()
            self._heartbeat_reset_handle = None
        self._need_heartbeat_reset = False
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the "no PONG received" deadline, if one is armed."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _on_data_received(self) -> None:
        """Note that bytes arrived so that the heartbeat deadline is pushed back.

        The caller is not visible in this class; presumably the protocol or
        reader invokes it for each received chunk.
        """
        if (
            self._heartbeat is None
            or self._need_heartbeat_reset
            # Once closing/closed, _cancel_heartbeat() has already torn the
            # timers down. Re-arming them here (e.g. while close() reads the
            # peer's CLOSE reply) would later send a PING on a dead connection.
            or self._closing
            or self._closed
        ):
            return
        loop = self._loop
        assert loop is not None
        # Coalesce multiple chunks received in the same loop tick into a single
        # heartbeat reset. Resetting immediately per chunk increases timer churn.
        self._need_heartbeat_reset = True
        self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)

    def _flush_heartbeat_reset(self) -> None:
        """Apply the heartbeat reset requested by _on_data_received()."""
        self._heartbeat_reset_handle = None
        if not self._need_heartbeat_reset:
            return
        self._reset_heartbeat()
        self._need_heartbeat_reset = False

    def _reset_heartbeat(self) -> None:
        """Move the next PING deadline to ``heartbeat`` seconds from now.

        Also cancels any pending "no PONG" deadline, since data has arrived.
        """
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Heartbeat timer callback: send a PING and arm the PONG deadline."""
        self._heartbeat_cb = None

        # If heartbeat reset is pending (data is being received), skip sending
        # the ping and let the reset callback handle rescheduling the heartbeat.
        if self._need_heartbeat_reset:
            return

        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            # Finished eagerly: handle its outcome right away.
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes."""
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """PONG deadline callback: treat the connection as timed out."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the socket closed with ABNORMAL_CLOSURE, records ``exc`` and
        closes the response. If a receive() is blocked (and no close is in
        progress), it is woken with an error message carrying ``exc``.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessageError(data=exc, extra=None))

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """Whether the connection has been marked closed."""
        return self._closed

    @property
    def close_code(self) -> int | None:
        """Close code recorded for this connection, or ``None``."""
        return self._close_code

    @property
    def protocol(self) -> str | None:
        """Subprotocol passed in at construction."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression value passed in at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """``client_notakeover`` flag passed in at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """extra info from connection transport"""
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> BaseException | None:
        """Return the exception that terminated the connection, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame with optional payload."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame with optional payload."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: int | None = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: int | None = None) -> None:
        """Send ``data`` UTF-8 encoded as a TEXT frame.

        :raises TypeError: if ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: int | None = None) -> None:
        """Send ``data`` as a BINARY frame.

        :raises TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def send_json_bytes(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONBytesEncoder,
    ) -> None:
        """Send JSON data using a bytes-returning encoder as a binary frame.

        Use this when your JSON encoder (like orjson) returns bytes
        instead of str, avoiding the encode/decode overhead.
        """
        await self.send_bytes(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Perform the closing handshake and close the response.

        If another task is blocked in :meth:`receive`, a CLOSING message is
        fed to the reader first and this waits until that receive loop has
        stopped. Then a CLOSE frame with ``code``/``message`` is sent. If no
        close code has been recorded yet, this waits (bounded by
        ``timeout.ws_close``) for the peer's CLOSE frame.

        :returns: ``False`` if the connection was already closed, ``True``
            otherwise (including when the handshake failed; the error is
            then available from :meth:`exception`).
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE)
            await self._close_wait

        if self._closed:
            return False

        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # The peer's close code is already known (e.g. its CLOSE frame was
        # received by receive()), so there is no reply to wait for.
        if self._close_code:
            self._response.close()
            return True

        # Drain incoming messages until the peer's CLOSE frame arrives.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None
    ) -> WSMessageDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def receive(
        self, timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Wait for and return the next message.

        With ``autoping`` enabled, PINGs are answered and PONGs swallowed. A
        CLOSE frame is returned after (with ``autoclose``) completing the
        close handshake. Once closed, ``WS_CLOSED_MESSAGE`` is returned.
        Reader failures are turned into closed/error messages rather than
        raised, except cancellation and timeouts.

        :param timeout: per-call limit; a falsy value falls back to
            ``timeout.ws_receive``.
        :raises RuntimeError: if another receive() is already in progress.
        :raises asyncio.TimeoutError: if the receive timeout elapses.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                finally:
                    self._waiting = False
                    # Wake a close() from another task that is waiting for
                    # this receive loop to stop.
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WS_CLOSED_MESSAGE
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessageError(data=exc)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessageError(data=exc)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                # Could be closed elsewhere while awaiting reader
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None
    ) -> str: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None
    ) -> bytes: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None
    ) -> str | bytes: ...

    async def receive_str(self, *, timeout: float | None = None) -> str | bytes:
        """Receive TEXT message.

        Returns str when decode_text=True (default), bytes when decode_text=False.

        :raises WSMessageTypeError: if the next message is not TEXT.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return msg.data

    async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
        """Receive a BINARY message and return its payload.

        :raises WSMessageTypeError: if the next message is not BINARY.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return msg.data

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[True]]",
        *,
        loads: JSONDecoder = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[False]]",
        *,
        loads: Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[_DecodeText]",
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    async def receive_json(
        self,
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER,
        timeout: float | None = None,
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)  # type: ignore[arg-type]

    def __aiter__(self) -> Self:
        return self

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[True]]",
    ) -> WSMessageDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[False]]",
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[_DecodeText]",
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Return the next message; stop iterating on CLOSE/CLOSING/CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()