Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/client_ws.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

330 statements  

1"""WebSocket client for asyncio.""" 

2 

3import asyncio 

4import sys 

5from collections.abc import Callable 

6from types import TracebackType 

7from typing import Any, Final, Generic, Literal, overload 

8 

9from ._websocket.reader import WebSocketDataQueue 

10from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError 

11from .client_reqrep import ClientResponse 

12from .helpers import calculate_timeout_when, frozen_dataclass_decorator, set_result 

13from .http import ( 

14 WS_CLOSED_MESSAGE, 

15 WS_CLOSING_MESSAGE, 

16 WebSocketError, 

17 WSCloseCode, 

18 WSMessageDecodeText, 

19 WSMessageNoDecodeText, 

20 WSMsgType, 

21) 

22from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError 

23from .streams import EofStream 

24from .typedefs import ( 

25 DEFAULT_JSON_DECODER, 

26 DEFAULT_JSON_ENCODER, 

27 JSONBytesEncoder, 

28 JSONDecoder, 

29 JSONEncoder, 

30) 

31 

32if sys.version_info >= (3, 13): 

33 from typing import TypeVar 

34else: 

35 from typing_extensions import TypeVar 

36 

37if sys.version_info >= (3, 11): 

38 import asyncio as async_timeout 

39 from typing import Self 

40else: 

41 import async_timeout 

42 from typing_extensions import Self 

43 

# TypeVar selecting whether TEXT messages are decoded to str (True) or kept
# as bytes (False). It is covariant because it only affects return types, not
# input types. ``default=Literal[True]`` means an unparameterized
# ``ClientWebSocketResponse`` is treated as decoding TEXT payloads to str.
_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True])

47 

48 

@frozen_dataclass_decorator
class ClientWSTimeout:
    """Timeouts, in seconds, for client WebSocket operations.

    A value of ``None`` means the corresponding wait is not time-limited.
    """

    # Limit applied to each read in ``receive()`` when no per-call timeout
    # is passed.
    ws_receive: float | None = None
    # Limit applied to each read in ``close()`` while waiting for the peer's
    # CLOSE frame.
    ws_close: float | None = None

53 

54 

# Client WebSocket timeouts used when none are supplied: reads in receive()
# are unbounded, and close() waits at most 10 seconds per read for the
# peer's CLOSE frame.
DEFAULT_WS_CLIENT_TIMEOUT: Final[ClientWSTimeout] = ClientWSTimeout(
    ws_close=10.0,
    ws_receive=None,
)

58 

59 

class ClientWebSocketResponse(Generic[_DecodeText]):
    """Client side of an established WebSocket connection.

    Wraps the frame reader, the frame writer and the ``ClientResponse`` the
    connection was established with. It implements:

    * receiving messages, optionally answering PING/PONG automatically;
    * sending frames;
    * the closing handshake;
    * an optional heartbeat that sends PING after a period without
      received data and fails the connection if no PONG (or other data)
      arrives in time.

    The ``_DecodeText`` type parameter is purely a typing aid. With
    ``Literal[True]`` (the default), TEXT messages are typed as carrying
    ``str``; with ``Literal[False]``, as carrying ``bytes``.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: str | None,
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: float | None = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Initialize connection state and arm the heartbeat timer if enabled.

        ``heartbeat`` is the number of seconds without received data after
        which a PING is sent. The PONG deadline is half of that. ``None``
        disables the heartbeat.

        ``autoclose`` makes ``receive()`` call ``close()`` when a CLOSE frame
        arrives.

        ``autoping`` makes ``receive()`` answer PING with PONG and swallow
        PONG frames instead of returning them.
        """
        self._response = response
        # Cached so the heartbeat code can read the connector's
        # timeout-ceil threshold without going through the response.
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: int | None = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: asyncio.TimerHandle | None = None
        # Absolute loop time at which the next PING is due.
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # Only defined when the heartbeat is enabled. It is read solely on
            # heartbeat paths (_send_heartbeat / _pong_not_received).
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: asyncio.TimerHandle | None = None
        self._loop = loop
        # True while a receive() call is blocked on the reader.
        self._waiting: bool = False
        # Future that close() awaits while waking up a blocked receive().
        self._close_wait: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        self._ping_task: asyncio.Task[None] | None = None
        self._need_heartbeat_reset = False
        self._heartbeat_reset_handle: asyncio.Handle | None = None

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel all pending heartbeat callbacks and any in-flight ping task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_reset_handle is not None:
            self._heartbeat_reset_handle.cancel()
            self._heartbeat_reset_handle = None
        self._need_heartbeat_reset = False
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending "PONG not received" timeout, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _on_data_received(self) -> None:
        """Note that data arrived so the heartbeat deadline can be pushed back.

        NOTE(review): the caller is not visible in this file. Presumably the
        reader/protocol invokes this once per received chunk; confirm.
        """
        # Nothing to do without a heartbeat, or if a reset is already queued
        # for this loop iteration.
        if self._heartbeat is None or self._need_heartbeat_reset:
            return
        loop = self._loop
        assert loop is not None
        # Coalesce multiple chunks received in the same loop tick into a single
        # heartbeat reset. Resetting immediately per chunk increases timer churn.
        self._need_heartbeat_reset = True
        self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)

    def _flush_heartbeat_reset(self) -> None:
        """Perform the coalesced heartbeat reset queued by _on_data_received."""
        self._heartbeat_reset_handle = None
        if not self._need_heartbeat_reset:
            return
        self._reset_heartbeat()
        self._need_heartbeat_reset = False

    def _reset_heartbeat(self) -> None:
        """Move the next PING deadline ``heartbeat`` seconds past now.

        Any pending PONG timeout is cancelled as well. The caller has just
        seen activity from the peer, or is initializing.
        """
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        # NOTE(review): calculate_timeout_when() lives in helpers. Presumably
        # it rounds deadlines at or above this threshold up to whole seconds;
        # 5 is the fallback when there is no connection. Confirm there.
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a PING and arm the PONG timeout if it is time."""
        self._heartbeat_cb = None

        # If heartbeat reset is pending (data is being received), skip sending
        # the ping and let the reset callback handle rescheduling the heartbeat.
        if self._need_heartbeat_reset:
            return

        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            # (the deadline was pushed back by _reset_heartbeat after this
            # timer was created).
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        # If nothing arrives within _pong_heartbeat seconds (which would
        # trigger _reset_heartbeat and cancel this timer), fail the connection.
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            # Keep a reference so _cancel_heartbeat() can cancel it.
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            # An eagerly started task may already have finished.
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes.

        A failure to send the PING (other than cancellation) is treated
        like any other ping/pong failure and closes the connection.
        """
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """Timer callback: the peer did not answer the heartbeat PING in time."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the connection closed with ABNORMAL_CLOSURE, records ``exc``
        (see ``exception()``) and closes the response. If a ``receive()``
        call is blocked and no close is in progress, it is woken up with a
        ``WSMessageError`` carrying ``exc``. This is a no-op if the
        connection is already closed.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessageError(data=exc, extra=None))

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """True once the connection has been closed (locally or due to errors)."""
        return self._closed

    @property
    def close_code(self) -> int | None:
        """The close code recorded for this connection, or None if not closed yet."""
        return self._close_code

    @property
    def protocol(self) -> str | None:
        """The subprotocol passed in at construction, or None."""
        return self._protocol

    @property
    def compress(self) -> int:
        """The ``compress`` value passed in at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """The ``client_notakeover`` flag passed in at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """extra info from connection transport

        Return the transport's extra info ``name``. Return ``default`` when
        there is no connection or no transport.
        """
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> BaseException | None:
        """Return the exception recorded by a failed operation, if any.

        It is set by ping/pong failures, by errors while sending or awaiting
        the close handshake, and by unexpected errors in ``receive()``.
        """
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: int | None = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: int | None = None) -> None:
        """Send ``data`` UTF-8 encoded as a TEXT frame.

        Raises TypeError if ``data`` is not a str.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: int | None = None) -> None:
        """Send ``data`` as a BINARY frame.

        Raises TypeError unless ``data`` is bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def send_json_bytes(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONBytesEncoder,
    ) -> None:
        """Send JSON data using a bytes-returning encoder as a binary frame.

        Use this when your JSON encoder (like orjson) returns bytes
        instead of str, avoiding the encode/decode overhead.
        """
        await self.send_bytes(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Perform the closing handshake.

        Sends a CLOSE frame with ``code``/``message``. Unless a close code
        is already known, it then reads until the peer's CLOSE frame
        arrives. Each read is bounded by ``ws_close`` seconds.

        Returns False if the connection was already closed, True otherwise.
        Errors while sending or waiting are recorded (see ``exception()``)
        with ABNORMAL_CLOSURE rather than raised. Cancellation is re-raised
        after closing the response.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            # receive() resolves this future in its `finally` once it has been
            # woken by the CLOSING message fed below.
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE)
            await self._close_wait

        if self._closed:
            return False

        # Mark closed before awaiting the writer so that concurrent or
        # re-entrant close() calls return False above.
        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # A close code is already set (e.g. receive() saw the peer's CLOSE
        # frame), so there is no CLOSE reply to wait for.
        if self._close_code:
            self._response.close()
            return True

        # Drain incoming messages until the peer's CLOSE frame arrives; the
        # ws_close timeout is applied to each read separately.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None
    ) -> WSMessageDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def receive(
        self, timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Return the next message from the peer.

        Data messages are returned as soon as they are read. CLOSE and
        CLOSING messages update the connection state and are returned.
        With ``autoclose``, a CLOSE also triggers ``close()``. With
        ``autoping``, PING is answered with PONG, PONG is swallowed, and
        reading continues.

        On a closed or closing connection, or at end of stream,
        ``WS_CLOSED_MESSAGE`` is returned. Protocol and unexpected errors
        are returned as ``WSMessageError`` after closing.

        Raises RuntimeError if another ``receive()`` is already in
        progress. Timeout and cancellation set ABNORMAL_CLOSURE and
        propagate without closing the connection.
        """
        # A falsy per-call timeout (None or 0) falls back to the configured
        # ws_receive timeout.
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                finally:
                    self._waiting = False
                    # Let a close() running in another task proceed.
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                # Reader reached end of stream: treat as a normal closure.
                self._close_code = WSCloseCode.OK
                await self.close()
                return WS_CLOSED_MESSAGE
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                # Protocol-level error: close using the error's own code.
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessageError(data=exc)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessageError(data=exc)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                # Could be closed elsewhere while awaiting reader
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            # CLOSE/CLOSING, or PING/PONG when autoping is disabled.
            return msg

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None
    ) -> str: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None
    ) -> bytes: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None
    ) -> str | bytes: ...

    async def receive_str(self, *, timeout: float | None = None) -> str | bytes:
        """Receive TEXT message.

        Returns str when decode_text=True (default), bytes when decode_text=False.
        Raises WSMessageTypeError if the next message is not TEXT.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return msg.data

    async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
        """Receive a BINARY message and return its payload.

        Raises WSMessageTypeError if the next message is not BINARY.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return msg.data

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[True]]",
        *,
        loads: JSONDecoder = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[False]]",
        *,
        loads: Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[_DecodeText]",
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    async def receive_json(
        self,
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER,
        timeout: float | None = None,
    ) -> Any:
        """Receive a TEXT message and return ``loads`` applied to its payload.

        Raises WSMessageTypeError (via ``receive_str``) if the next message
        is not TEXT.
        """
        data = await self.receive_str(timeout=timeout)
        # data is str or bytes depending on _DecodeText; the overloads pair
        # each case with a matching ``loads`` signature.
        return loads(data)  # type: ignore[arg-type]

    def __aiter__(self) -> Self:
        return self

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[True]]",
    ) -> WSMessageDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[False]]",
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[_DecodeText]",
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Return the next message; stop iterating on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the connection when leaving the context, whether or not it raised."""
        await self.close()