Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/client_ws.py: 21%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""WebSocket client for asyncio."""
3import asyncio
4import sys
5from types import TracebackType
6from typing import Any, Final, Optional, Type
8from ._websocket.reader import WebSocketDataQueue
9from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
10from .client_reqrep import ClientResponse
11from .helpers import calculate_timeout_when, frozen_dataclass_decorator, set_result
12from .http import (
13 WS_CLOSED_MESSAGE,
14 WS_CLOSING_MESSAGE,
15 WebSocketError,
16 WSCloseCode,
17 WSMessage,
18 WSMsgType,
19)
20from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError
21from .streams import EofStream
22from .typedefs import (
23 DEFAULT_JSON_DECODER,
24 DEFAULT_JSON_ENCODER,
25 JSONDecoder,
26 JSONEncoder,
27)
29if sys.version_info >= (3, 11):
30 import asyncio as async_timeout
31else:
32 import async_timeout
@frozen_dataclass_decorator
class ClientWSTimeout:
    """Timeouts, in seconds, used by a client websocket.

    A value of ``None`` means "no timeout".
    """

    # Default timeout applied by ClientWebSocketResponse.receive() when the
    # caller passes no (or a falsy) per-call timeout; falsy means wait forever.
    ws_receive: Optional[float] = None
    # Timeout for waiting on the peer's CLOSE frame inside close(), after our
    # own CLOSE frame has been sent.
    ws_close: Optional[float] = None
# Default client websocket timeouts: receive() waits indefinitely, while
# close() waits at most 10 seconds for the peer's CLOSE frame.
DEFAULT_WS_CLIENT_TIMEOUT: Final[ClientWSTimeout] = ClientWSTimeout(
    ws_close=10.0,
    ws_receive=None,
)
class ClientWebSocketResponse:
    """Client side of an established WebSocket connection.

    Reads incoming messages from a ``WebSocketDataQueue``, sends frames
    through a ``WebSocketWriter`` and drives the closing handshake. When a
    ``heartbeat`` interval is configured it also sends periodic PINGs and
    marks the connection abnormally closed if no message arrives in time.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: Optional[str],
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: Optional[float] = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Wrap an upgraded HTTP response as a websocket.

        :param reader: queue that incoming messages are read from.
        :param writer: writer used to send frames.
        :param protocol: negotiated subprotocol, if any.
        :param response: the upgraded HTTP response; it is closed when the
            websocket is closed.
        :param timeout: receive/close timeouts.
        :param autoclose: if true, ``receive()`` answers a CLOSE frame by
            calling ``close()``.
        :param autoping: if true, ``receive()`` answers PING frames with a
            PONG and swallows PONG frames instead of returning them.
        :param loop: event loop used for heartbeat timers and futures.
        :param heartbeat: seconds between heartbeat PINGs; ``None`` disables
            the heartbeat.
        :param compress: stored and exposed through the ``compress`` property.
        :param client_notakeover: stored and exposed through the
            ``client_notakeover`` property.
        """
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: Optional[int] = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # Earliest loop time at which the next heartbeat PING may be sent.
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # How long to wait for traffic after a heartbeat PING before
            # declaring the connection dead.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        self._loop = loop
        # True while a receive() call is blocked reading from the reader.
        self._waiting: bool = False
        # Future used by close() to wait for a blocked receive() to unwind.
        self._close_wait: Optional[asyncio.Future[None]] = None
        self._exception: Optional[BaseException] = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        # In-flight heartbeat PING send, kept so it can be cancelled on close.
        self._ping_task: Optional[asyncio.Task[None]] = None

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel the PONG timeout, the heartbeat timer and any pending PING."""
        self._cancel_pong_response_cb()
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending PONG-timeout timer, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _reset_heartbeat(self) -> None:
        """Push the next heartbeat PING ``heartbeat`` seconds into the future.

        Called on construction and after every message read by ``receive()``.
        Also cancels any pending PONG timeout. Does nothing when the
        heartbeat is disabled.
        """
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        # Threshold forwarded to calculate_timeout_when(); falls back to 5
        # when there is no underlying connection.
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a heartbeat PING and arm the PONG timeout.

        If the timer fired before ``_heartbeat_when`` (because the deadline
        was pushed back by ``_reset_heartbeat``), it is simply rescheduled.
        """
        self._heartbeat_cb = None
        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        # Unless a message is received (which calls _reset_heartbeat) before
        # `when`, _pong_not_received will close the connection abnormally.
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes.

        A failed PING is treated as an abnormal closure. The stored
        ``_ping_task`` reference is only cleared when it is *this* task: a
        ping that completed eagerly (or an older one finishing late) must not
        drop the reference to a different ping that is still in flight,
        otherwise ``_cancel_heartbeat()`` could no longer cancel it.
        """
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        if self._ping_task is task:
            self._ping_task = None

    def _pong_not_received(self) -> None:
        """Timer callback: no message arrived in time after a heartbeat PING."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the websocket closed with ``ABNORMAL_CLOSURE``, records ``exc``
        for ``exception()``, closes the HTTP response and, if a ``receive()``
        is currently blocked (and no ``close()`` is in progress), wakes it by
        feeding a ``WSMessageError`` carrying ``exc``.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessageError(data=exc, extra=None))

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """True once the websocket has been marked closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded for this connection, or None if none yet."""
        return self._close_code

    @property
    def protocol(self) -> Optional[str]:
        """Subprotocol passed in at construction, if any."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression value passed in at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """``client_notakeover`` flag passed in at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Return extra info ``name`` from the connection's transport.

        Returns ``default`` when there is no connection or no transport.
        """
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the last exception recorded on this websocket, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send ``data`` UTF-8 encoded as a TEXT frame.

        :raises TypeError: if ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send ``data`` as a BINARY frame.

        :raises TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Perform the closing handshake.

        Sends a CLOSE frame with ``code``/``message`` and, unless a close code
        is already recorded, waits up to ``timeout.ws_close`` for the peer's
        CLOSE frame. The HTTP response is closed on every path that gets past
        the already-closed check.

        :returns: False if the websocket was already closed, True otherwise.
        :raises asyncio.CancelledError: re-raised after marking the close as
            abnormal.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE)
            # Resolved by receive()'s `finally` once it stops waiting.
            await self._close_wait

        if self._closed:
            return False

        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # A close code is already known (e.g. the peer's CLOSE was received
        # by receive()), so there is no reply left to wait for.
        if self._close_code:
            self._response.close()
            return True

        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            # Anything other than the peer's CLOSE is discarded.
            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Return the next message from the peer.

        A falsy ``timeout`` (None or 0) falls back to ``timeout.ws_receive``;
        if that is also falsy the read waits indefinitely. With ``autoping``
        enabled, PING frames are answered and PING/PONG frames are not
        returned. Once the websocket is closed, ``WS_CLOSED_MESSAGE`` is
        returned.

        :raises RuntimeError: if another ``receive()`` is already waiting.
        :raises asyncio.CancelledError, asyncio.TimeoutError: re-raised after
            recording ``ABNORMAL_CLOSURE`` as the close code.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    self._waiting = False
                    # Unblock a close() that is waiting for us to stop reading.
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WS_CLOSED_MESSAGE
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessageError(data=exc)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessageError(data=exc)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                # Could be closed elsewhere while awaiting reader
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive the next message and return its text payload.

        :raises WSMessageTypeError: if the message is not a TEXT message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return msg.data

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive the next message and return its binary payload.

        :raises WSMessageTypeError: if the message is not a BINARY message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return msg.data

    async def receive_json(
        self,
        *,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        timeout: Optional[float] = None,
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    def __aiter__(self) -> "ClientWebSocketResponse":
        return self

    async def __anext__(self) -> WSMessage:
        """Yield messages until a CLOSE, CLOSING or CLOSED message is seen."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> "ClientWebSocketResponse":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the websocket when leaving an ``async with`` block."""
        await self.close()