1"""WebSocket client for asyncio."""
2
3import asyncio
4import sys
5from types import TracebackType
6from typing import Any, Optional, Type, cast
7
8import attr
9
10from ._websocket.reader import WebSocketDataQueue
11from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
12from .client_reqrep import ClientResponse
13from .helpers import calculate_timeout_when, set_result
14from .http import (
15 WS_CLOSED_MESSAGE,
16 WS_CLOSING_MESSAGE,
17 WebSocketError,
18 WSCloseCode,
19 WSMessage,
20 WSMsgType,
21)
22from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
23from .streams import EofStream
24from .typedefs import (
25 DEFAULT_JSON_DECODER,
26 DEFAULT_JSON_ENCODER,
27 JSONDecoder,
28 JSONEncoder,
29)
30
31if sys.version_info >= (3, 11):
32 import asyncio as async_timeout
33else:
34 import async_timeout
35
36
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ClientWSTimeout:
    """Immutable pair of WebSocket client timeouts, in seconds.

    ``ws_receive`` limits a single ``receive()`` call and ``ws_close`` limits
    how long ``close()`` waits for the peer's close frame.  ``None`` means
    no limit.
    """

    ws_receive: Optional[float] = None
    ws_close: Optional[float] = None
41
42
# Default client timeouts: no per-receive limit, and at most 10 seconds of
# waiting for the peer's close frame during close().
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_close=10.0, ws_receive=None)
44
45
class ClientWebSocketResponse:
    """Client side of a WebSocket connection.

    Wraps a ``WebSocketDataQueue`` (incoming messages) and a
    ``WebSocketWriter`` (outgoing frames) on top of a ``ClientResponse``.
    It provides:

    * send helpers for text, binary, JSON, PING/PONG and raw frames;
    * ``receive()``, which can automatically answer PINGs (``autoping``)
      and close the socket when a CLOSE frame arrives (``autoclose``);
    * the client half of the closing handshake (``close()``);
    * an optional heartbeat, which sends a PING and fails the connection
      if ``receive()`` reads nothing before the PONG deadline.

    Supports ``async for`` (iteration stops on CLOSE/CLOSING/CLOSED
    messages) and ``async with`` (``close()`` is awaited on exit).
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: Optional[str],
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: Optional[float] = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Store the connection state and schedule the first heartbeat.

        Args:
            reader: Source of incoming ``WSMessage`` objects.  Its
                ``feed_data`` is also used to inject internal messages
                (CLOSING, ERROR) that wake a blocked ``receive()``.
            writer: Sink for outgoing frames, including the close frame.
            protocol: Subprotocol exposed through the ``protocol`` property.
            response: HTTP response carrying the connection.  It is closed
                when the WebSocket finishes closing or fails.  Its
                connection supplies the transport for ``get_extra_info``
                and the connector's timeout ceiling threshold for
                heartbeat timers.
            timeout: ``ws_receive`` is the default limit for ``receive()``.
                ``ws_close`` limits how long ``close()`` waits for the
                peer's CLOSE frame.
            autoclose: If true, ``receive()`` calls ``close()`` when a CLOSE
                frame arrives.
            autoping: If true, ``receive()`` answers PING with PONG and
                does not return PING/PONG messages to the caller.
            loop: Event loop used for heartbeat timers, the ping task and
                the close-wait future.
            heartbeat: Seconds between heartbeat PINGs, or ``None`` to
                disable the heartbeat.  The PONG deadline is half of it.
            compress: Exposed through the ``compress`` property; not
                otherwise used by this class.
            client_notakeover: Exposed through the ``client_notakeover``
                property; not otherwise used by this class.
        """
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: Optional[int] = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # loop.time() value at which the next heartbeat PING is due.
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # Only defined when a heartbeat is configured; it is read only
            # from heartbeat callbacks, which are never scheduled otherwise.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        self._loop = loop
        # True while receive() is blocked on the reader.  It is used to
        # reject concurrent receive() calls and tells close() that it must
        # wake the reader first.
        self._waiting: bool = False
        # Created by close() while waiting for a blocked receive() to exit;
        # resolved in receive()'s ``finally`` block.
        self._close_wait: Optional[asyncio.Future[None]] = None
        self._exception: Optional[BaseException] = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        # In-flight heartbeat PING send, kept so it can be cancelled.
        self._ping_task: Optional[asyncio.Task[None]] = None

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel the PONG deadline, the heartbeat timer and any PING task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending PONG-deadline timer, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _reset_heartbeat(self) -> None:
        """Push the next heartbeat deadline to ``heartbeat`` seconds from now.

        Called on construction and after every message ``receive()`` reads.
        It also cancels any pending PONG deadline.  Does nothing when no
        heartbeat is configured.
        """
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        # Threshold passed to calculate_timeout_when; taken from the
        # connector, or 5 when there is no connection.
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a PING and arm the PONG deadline.

        The timer may fire before ``_heartbeat_when``, because
        ``_reset_heartbeat`` moves the deadline without rescheduling.  In
        that case it is re-armed for the current deadline and nothing is
        sent.  After a PING is sent the next heartbeat is not scheduled
        here.  ``_heartbeat_cb`` stays ``None`` until the next
        ``_reset_heartbeat()`` call schedules it.
        """
        self._heartbeat_cb = None
        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        # If nothing is read by receive() before this deadline,
        # _pong_not_received() fails the connection.
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        # An eagerly started send may already be finished.  Handle it now;
        # otherwise keep a reference so _cancel_heartbeat() can cancel it.
        if not ping_task.done():
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes.

        A failed PING send (but not a cancellation) is handled like a lost
        connection.  ``task.exception()`` raises on a cancelled task, so
        ``cancelled()`` is checked first.
        """
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """PONG-deadline callback: fail the connection with a timeout error."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the socket closed with ABNORMAL_CLOSURE, records ``exc`` for
        ``exception()`` and closes the response.  If a ``receive()`` call is
        blocked and ``close()`` is not in progress, an ERROR message carrying
        ``exc`` is fed to the reader to wake it.  Does nothing if the socket
        is already closed.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """True once the connection has been closed or has failed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded so far, or None.

        This is the peer's CLOSE code, or a code set on error.  It may be
        set before ``closed`` becomes true, e.g. after a receive timeout.
        """
        return self._close_code

    @property
    def protocol(self) -> Optional[str]:
        """Subprotocol passed at construction."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression setting passed at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """``client_notakeover`` flag passed at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Return transport extra info ``name`` from the current connection.

        Returns ``default`` when the response has no connection or the
        connection has no transport.
        """
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the exception recorded by close, receive or heartbeat, or None."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame with ``message`` as payload."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame with ``message`` as payload."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket.

        ``opcode`` and ``compress`` are forwarded unchanged to the writer.
        """
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send ``data`` as a UTF-8 encoded TEXT frame.

        Raises:
            TypeError: If ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send ``data`` as a BINARY frame.

        Raises:
            TypeError: If ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Run the client half of the closing handshake.

        Steps:

        1. If another task is blocked in ``receive()``, it is woken first
           by injecting a CLOSING message, and this call waits until that
           ``receive()`` has exited.
        2. A close frame with ``code`` and ``message`` is sent.
        3. Unless a close code was already recorded, messages are read
           (and discarded) until the peer's CLOSE arrives.  This wait is
           limited by ``timeout.ws_close``.

        Send or read errors record ABNORMAL_CLOSURE and store the
        exception.

        Returns:
            False if the socket was already closed, otherwise True (also on
            error).

        Raises:
            asyncio.CancelledError: Re-raised after closing the response if
                cancelled while sending or waiting.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._close_wait

        # Already closed, e.g. by an earlier close(), a heartbeat failure,
        # or a lost connection detected in receive().
        if self._closed:
            return False

        # Mark closed before awaiting so that re-entrant close() calls
        # return False at the check above.
        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # A close code is already known (the peer's CLOSE was received, or
        # receive() recorded one on error), so no CLOSE reply is awaited.
        if self._close_code:
            self._response.close()
            return True

        # Wait for the peer's CLOSE, discarding any other messages.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Return the next message from the peer.

        ``timeout`` overrides ``timeout.ws_receive`` for this call.  A falsy
        ``timeout`` (None or 0) falls back to that default, and if the
        effective value is also falsy the call waits without a limit.

        Handling of internal message types:

        * With ``autoping``, PING is answered with PONG and PING/PONG are
          skipped rather than returned.
        * A CLOSE message marks the socket as closing and records its code.
          With ``autoclose`` it also triggers ``close()``.  The message is
          still returned.
        * Once the socket is closed, ``WS_CLOSED_MESSAGE`` is returned.

        Reader errors are mapped to return values:

        * ``EofStream``: a CLOSED message, after closing.
        * ``ClientError``: ``WS_CLOSED_MESSAGE``.
        * ``WebSocketError``: an ERROR message, after closing with its code.
        * Any other ``Exception``: an ERROR message, after closing.

        Raises:
            RuntimeError: If another ``receive()`` call is already waiting.
            asyncio.CancelledError, asyncio.TimeoutError: Re-raised after
                recording ABNORMAL_CLOSURE.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    self._waiting = False
                    # Let a close() running in another task proceed now
                    # that this call is no longer reading.
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Only the close code is recorded; the connection is not
                # closed here.
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                # NOTE(review): unlike the other error paths this does not
                # call self._response.close(); confirm that is intentional.
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                # e.g. WS_CLOSING_MESSAGE injected by close() to wake us.
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive the next message and return its data as text.

        Raises:
            WSMessageTypeError: If the message is not a TEXT message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive the next message and return its data as bytes.

        Raises:
            WSMessageTypeError: If the message is not a BINARY message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return cast(bytes, msg.data)

    async def receive_json(
        self,
        *,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        timeout: Optional[float] = None,
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    def __aiter__(self) -> "ClientWebSocketResponse":
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message; stop on CLOSE, CLOSING or CLOSED.

        ERROR messages are yielded to the caller, not treated as the end.
        """
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> "ClientWebSocketResponse":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the connection on exit; exceptions are not suppressed."""
        await self.close()