Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/client_ws.py: 27%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""WebSocket client for asyncio."""
3import asyncio
4import sys
5from collections.abc import Callable
6from types import TracebackType
7from typing import Any, Generic, Literal, Optional, cast, overload
9import attr
11from ._websocket.reader import WebSocketDataQueue
12from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
13from .client_reqrep import ClientResponse
14from .helpers import calculate_timeout_when, set_result
15from .http import (
16 WS_CLOSED_MESSAGE,
17 WS_CLOSING_MESSAGE,
18 WebSocketError,
19 WSCloseCode,
20 WSMessage,
21 WSMessageDecodeText,
22 WSMessageNoDecodeText,
23 WSMsgType,
24)
25from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
26from .streams import EofStream
27from .typedefs import (
28 DEFAULT_JSON_DECODER,
29 DEFAULT_JSON_ENCODER,
30 JSONBytesEncoder,
31 JSONDecoder,
32 JSONEncoder,
33)
35if sys.version_info >= (3, 13):
36 from typing import TypeVar
37else:
38 from typing_extensions import TypeVar
40if sys.version_info >= (3, 11):
41 import asyncio as async_timeout
42 from typing import Self
43else:
44 import async_timeout
45 from typing_extensions import Self
47# TypeVar for whether text messages are decoded to str (True) or kept as bytes (False)
48# Covariant because it only affects return types, not input types
49_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True])
@attr.s(slots=True, frozen=True)
class ClientWSTimeout:
    """Immutable pair of WebSocket client timeouts.

    Each value is optional; ``None`` means no timeout is applied.
    """

    # Fallback timeout for ClientWebSocketResponse.receive() when the caller
    # passes no (or a falsy) per-call timeout.
    ws_receive = attr.ib(default=None, type=Optional[float])
    # Timeout used by ClientWebSocketResponse.close() for each read while it
    # waits for the peer's CLOSE frame.
    ws_close = attr.ib(default=None, type=Optional[float])
# Library defaults: no receive timeout, 10 second close timeout.
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_close=10.0, ws_receive=None)
class ClientWebSocketResponse(Generic[_DecodeText]):
    """Client-side handle for an established WebSocket connection.

    Incoming messages are read from ``reader`` and frames are written through
    ``writer``.  The ``_DecodeText`` type parameter records whether TEXT
    messages carry ``str`` (``Literal[True]``, the default) or ``bytes``
    (``Literal[False]``) payloads.

    When ``heartbeat`` is given, a PING frame is sent once ``heartbeat``
    seconds pass without the heartbeat being reset (see
    ``_reset_heartbeat``).  If no reset happens within ``heartbeat / 2``
    seconds after that PING, the connection is failed with
    ``WSCloseCode.ABNORMAL_CLOSURE`` and a ``ServerTimeoutError``.  Both
    deadlines are computed via ``calculate_timeout_when``.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: str | None,
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: float | None = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: int | None = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: asyncio.TimerHandle | None = None
        # Loop-time deadline of the next PING.  The scheduled timer may fire
        # earlier than this and then reschedules itself (_send_heartbeat).
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # Only defined when heartbeats are enabled; the code paths that
            # read it are only ever scheduled in that case.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: asyncio.TimerHandle | None = None
        self._loop = loop
        # True while receive() is blocked reading from the reader.
        self._waiting: bool = False
        # Future close() awaits until a blocked receive() has unwound.
        self._close_wait: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        self._ping_task: asyncio.Task[None] | None = None
        # True while a coalesced heartbeat reset is scheduled but not yet run.
        self._need_heartbeat_reset = False
        self._heartbeat_reset_handle: asyncio.Handle | None = None

        self._reset_heartbeat()

    def _timeout_ceil_threshold(self) -> float:
        """Return the connector's timeout ceil threshold (5 without a connection)."""
        conn = self._conn
        return conn._connector._timeout_ceil_threshold if conn is not None else 5

    def _cancel_heartbeat(self) -> None:
        """Cancel every heartbeat-related timer, pending reset and ping task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_reset_handle is not None:
            self._heartbeat_reset_handle.cancel()
            self._heartbeat_reset_handle = None
        self._need_heartbeat_reset = False
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending "no PONG received" timer, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _on_data_received(self) -> None:
        """Schedule a heartbeat reset for the current loop tick.

        Ignored when heartbeats are disabled, a reset is already pending, or
        the connection is closing/closed.  Once closing starts, the heartbeat
        has been cancelled on purpose and late data must not re-arm it.
        """
        if (
            self._heartbeat is None
            or self._need_heartbeat_reset
            or self._closing
            or self._closed
        ):
            return
        loop = self._loop
        assert loop is not None
        # Coalesce multiple chunks received in the same loop tick into a single
        # heartbeat reset. Resetting immediately per chunk increases timer churn.
        self._need_heartbeat_reset = True
        self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)

    def _flush_heartbeat_reset(self) -> None:
        """Run the heartbeat reset scheduled by ``_on_data_received``."""
        self._heartbeat_reset_handle = None
        if not self._need_heartbeat_reset:
            return
        self._reset_heartbeat()
        self._need_heartbeat_reset = False

    def _reset_heartbeat(self) -> None:
        """Push the next PING deadline out and cancel any pending PONG timeout.

        Does nothing when heartbeats are disabled or once the connection is
        closing/closed, because ``_set_closing``/``_set_closed`` cancel the
        heartbeat and it must stay cancelled.
        """
        if self._heartbeat is None or self._closing or self._closed:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        now = loop.time()
        when = calculate_timeout_when(
            now, self._heartbeat, self._timeout_ceil_threshold()
        )
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a PING and arm the PONG timeout."""
        self._heartbeat_cb = None

        # If heartbeat reset is pending (data is being received), skip sending
        # the ping and let the reset callback handle rescheduling the heartbeat.
        if self._need_heartbeat_reset:
            return

        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        when = calculate_timeout_when(
            now, self._pong_heartbeat, self._timeout_ceil_threshold()
        )
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes."""
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """Timer callback: fail the connection because the PONG is overdue."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the connection closed with ABNORMAL_CLOSURE, records ``exc``,
        closes the response and, if a ``receive()`` is blocked (and no close
        is in progress), wakes it with an ERROR message.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """Whether the connection has been marked closed."""
        return self._closed

    @property
    def close_code(self) -> int | None:
        """Close code recorded for the connection, or ``None`` if not set."""
        return self._close_code

    @property
    def protocol(self) -> str | None:
        """Subprotocol passed in at construction, if any."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression setting passed in at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """``client_notakeover`` flag passed in at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """extra info from connection transport"""
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> BaseException | None:
        """Return the exception recorded while receiving/closing, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame carrying ``message``."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame carrying ``message``."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: int | None = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: int | None = None) -> None:
        """Send ``data`` UTF-8 encoded as a TEXT frame.

        Raises:
            TypeError: if ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: int | None = None) -> None:
        """Send ``data`` as a BINARY frame.

        Raises:
            TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def send_json_bytes(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONBytesEncoder,
    ) -> None:
        """Send JSON data using a bytes-returning encoder as a binary frame.

        Use this when your JSON encoder (like orjson) returns bytes
        instead of str, avoiding the encode/decode overhead.
        """
        await self.send_bytes(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Perform the closing handshake.

        Sends a CLOSE frame with ``code``/``message`` and then waits for the
        peer's CLOSE frame (each read bounded by ``timeout.ws_close``).

        Returns:
            ``False`` if the connection was already closed, ``True`` otherwise
            (including when the handshake failed; the error is then available
            via ``exception()`` and the close code is ABNORMAL_CLOSURE).

        Raises:
            asyncio.CancelledError: re-raised after marking the close abnormal.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._close_wait

        if self._closed:
            return False

        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # A close code is already known (e.g. the peer's CLOSE was received
        # by receive()), so there is nothing left to wait for.
        if self._close_code:
            self._response.close()
            return True

        # Drain incoming messages until the peer's CLOSE frame arrives.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None
    ) -> WSMessageDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def receive(
        self, timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Wait for and return the next message.

        A falsy ``timeout`` falls back to ``ws_receive`` from the configured
        ``ClientWSTimeout``.  Messages whose type is not in
        ``_INTERNAL_RECEIVE_TYPES`` are returned as-is.  With ``autoping``,
        PINGs are answered and PONGs swallowed.  With ``autoclose``, a CLOSE
        from the peer triggers ``close()``.  Returns ``WS_CLOSED_MESSAGE``
        once the connection is closed.

        Raises:
            RuntimeError: if another ``receive()`` is already in progress.
            asyncio.TimeoutError: if the receive timeout expires.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                finally:
                    self._waiting = False
                    # Unblock a close() that is waiting for us to return.
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None
    ) -> str: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None
    ) -> bytes: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None
    ) -> str | bytes: ...

    async def receive_str(self, *, timeout: float | None = None) -> str | bytes:
        """Receive TEXT message.

        Returns str when decode_text=True (default), bytes when decode_text=False.

        Raises:
            WSMessageTypeError: if the next message is not TEXT.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
        """Receive a BINARY message and return its payload.

        Raises:
            WSMessageTypeError: if the next message is not BINARY.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return cast(bytes, msg.data)

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[True]]",
        *,
        loads: JSONDecoder = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[False]]",
        *,
        loads: Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[_DecodeText]",
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    async def receive_json(
        self,
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER,
        timeout: float | None = None,
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``.

        Raises:
            WSMessageTypeError: if the next message is not TEXT.
        """
        data = await self.receive_str(timeout=timeout)
        return loads(data)  # type: ignore[arg-type]

    def __aiter__(self) -> Self:
        return self

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[True]]",
    ) -> WSMessageDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[False]]",
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[_DecodeText]",
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Return the next message; stop on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()