Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/client_ws.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

331 statements  

1"""WebSocket client for asyncio.""" 

2 

3import asyncio 

4import sys 

5from collections.abc import Callable 

6from types import TracebackType 

7from typing import Any, Generic, Literal, Optional, cast, overload 

8 

9import attr 

10 

11from ._websocket.reader import WebSocketDataQueue 

12from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError 

13from .client_reqrep import ClientResponse 

14from .helpers import calculate_timeout_when, set_result 

15from .http import ( 

16 WS_CLOSED_MESSAGE, 

17 WS_CLOSING_MESSAGE, 

18 WebSocketError, 

19 WSCloseCode, 

20 WSMessage, 

21 WSMessageDecodeText, 

22 WSMessageNoDecodeText, 

23 WSMsgType, 

24) 

25from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter 

26from .streams import EofStream 

27from .typedefs import ( 

28 DEFAULT_JSON_DECODER, 

29 DEFAULT_JSON_ENCODER, 

30 JSONBytesEncoder, 

31 JSONDecoder, 

32 JSONEncoder, 

33) 

34 

35if sys.version_info >= (3, 13): 

36 from typing import TypeVar 

37else: 

38 from typing_extensions import TypeVar 

39 

40if sys.version_info >= (3, 11): 

41 import asyncio as async_timeout 

42 from typing import Self 

43else: 

44 import async_timeout 

45 from typing_extensions import Self 

46 

# Type parameter of ClientWebSocketResponse describing how TEXT payloads are
# delivered: Literal[True] (the default) means they are decoded to str,
# Literal[False] means they are kept as bytes.  It is declared covariant
# because it only ever appears in return types, never in argument types.
_DecodeText = TypeVar(
    "_DecodeText",
    bound=bool,
    default=Literal[True],
    covariant=True,
)

50 

51 

@attr.s(slots=True, frozen=True)
class ClientWSTimeout:
    """Immutable pair of WebSocket client timeouts, in seconds.

    ``ws_receive`` is the default limit for a single ``receive()`` call and
    ``ws_close`` limits the wait for the server's CLOSE reply in ``close()``.
    ``None`` means no limit.
    """

    # Field order matters: it defines the positional __init__ signature.
    ws_receive = attr.ib(default=None, type=Optional[float])
    ws_close = attr.ib(default=None, type=Optional[float])

56 

57 

# Default client WebSocket timeouts (presumably applied by the session when the
# caller supplies none; confirm at call sites): receive() waits indefinitely,
# while the closing handshake gets 10 seconds.
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_close=10.0, ws_receive=None)

59 

60 

class ClientWebSocketResponse(Generic[_DecodeText]):
    """Client side of an established WebSocket connection.

    The type parameter records whether TEXT payloads are returned as ``str``
    (``Literal[True]``, the default) or kept as ``bytes`` (``Literal[False]``).

    Besides sending and receiving messages, this object implements:

    * the closing handshake (``close()``);
    * optional automatic PONG replies and PONG swallowing (``autoping``);
    * an optional heartbeat. After ``heartbeat`` seconds without received
      data a PING is sent. If nothing arrives within ``heartbeat / 2``
      seconds afterwards, the connection is closed with a
      ``ServerTimeoutError``.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: str | None,
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: float | None = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Wrap an upgraded connection.

        :param reader: queue the incoming messages are read from.
        :param writer: writer used for every outgoing frame.
        :param protocol: subprotocol exposed via the ``protocol`` property.
        :param response: HTTP response of the upgrade; closed when the
            WebSocket shuts down.
        :param timeout: receive/close timeouts.
        :param autoclose: complete the closing handshake automatically when
            the server sends CLOSE.
        :param autoping: answer PINGs with PONGs and swallow PONGs.
        :param loop: event loop used for timers and ping tasks.
        :param heartbeat: idle seconds before a PING is sent; ``None``
            disables the heartbeat.
        :param compress: value exposed via the ``compress`` property.
        :param client_notakeover: value exposed via ``client_notakeover``.
        """
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: int | None = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: asyncio.TimerHandle | None = None
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # How long to wait for a PONG after sending a heartbeat PING.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: asyncio.TimerHandle | None = None
        self._loop = loop
        self._waiting: bool = False
        self._close_wait: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        self._ping_task: asyncio.Task[None] | None = None
        self._need_heartbeat_reset = False
        self._heartbeat_reset_handle: asyncio.Handle | None = None

        self._reset_heartbeat()

    def _get_timeout_ceil_threshold(self) -> float:
        """Return the connector's timeout ceil threshold, or 5 without a connection."""
        conn = self._conn
        return conn._connector._timeout_ceil_threshold if conn is not None else 5

    def _cancel_heartbeat(self) -> None:
        """Cancel all heartbeat state.

        This covers the PONG timeout, any pending reset, the heartbeat timer
        and an in-flight ping task.
        """
        self._cancel_pong_response_cb()
        if self._heartbeat_reset_handle is not None:
            self._heartbeat_reset_handle.cancel()
            self._heartbeat_reset_handle = None
        self._need_heartbeat_reset = False
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending "no PONG received" timer, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _on_data_received(self) -> None:
        """Schedule a heartbeat reset because data arrived from the server.

        Does nothing when the heartbeat is disabled or a reset is already
        pending. It also does nothing once the connection is closing or
        closed: ``_set_closing()``/``_set_closed()`` have cancelled the
        heartbeat, and late data (e.g. frames read by ``close()`` while
        waiting for the server's CLOSE reply) must not re-arm it.
        """
        if (
            self._heartbeat is None
            or self._need_heartbeat_reset
            or self._closing
            or self._closed
        ):
            return
        loop = self._loop
        assert loop is not None
        # Coalesce multiple chunks received in the same loop tick into a single
        # heartbeat reset. Resetting immediately per chunk increases timer churn.
        self._need_heartbeat_reset = True
        self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)

    def _flush_heartbeat_reset(self) -> None:
        """Run the heartbeat reset scheduled by ``_on_data_received()``."""
        self._heartbeat_reset_handle = None
        if not self._need_heartbeat_reset:
            return
        self._reset_heartbeat()
        self._need_heartbeat_reset = False

    def _reset_heartbeat(self) -> None:
        """Move the heartbeat deadline forward and drop any pending PONG timeout."""
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        now = loop.time()
        when = calculate_timeout_when(
            now, self._heartbeat, self._get_timeout_ceil_threshold()
        )
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Heartbeat timer callback.

        Sends a PING and arms the PONG timeout. Skips the PING if a reset is
        pending, and reschedules itself if it fired before the current
        deadline.
        """
        self._heartbeat_cb = None

        # If heartbeat reset is pending (data is being received), skip sending
        # the ping and let the reset callback handle rescheduling the heartbeat.
        if self._need_heartbeat_reset:
            return

        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early (the deadline was moved forward by a
            # reset after this timer was scheduled), reschedule.
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        when = calculate_timeout_when(
            now, self._pong_heartbeat, self._get_timeout_ceil_threshold()
        )
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes.

        A failed ping closes the connection. The stored task reference is
        cleared only if it still refers to *this* task, so a newer in-flight
        ping stays reachable for ``_cancel_heartbeat()``.
        """
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        if self._ping_task is task:
            self._ping_task = None

    def _pong_not_received(self) -> None:
        """PONG timeout callback: treat the missing PONG as a connection failure."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the connection closed with ABNORMAL_CLOSURE, records ``exc``,
        closes the response and, if a ``receive()`` is waiting, feeds it an
        ERROR message carrying ``exc``. No-op if already closed.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """Whether the connection has been closed."""
        return self._closed

    @property
    def close_code(self) -> int | None:
        """Close code recorded for the connection, or ``None`` if not closed yet."""
        return self._close_code

    @property
    def protocol(self) -> str | None:
        """Subprotocol passed in at construction, if any."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression value passed in at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """``client_notakeover`` flag passed in at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Return extra info from the connection's transport.

        Returns ``default`` when there is no connection or no transport.
        """
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> BaseException | None:
        """Return the exception that closed the connection, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: int | None = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: int | None = None) -> None:
        """Send ``data`` UTF-8 encoded as a TEXT frame.

        :raises TypeError: if ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: int | None = None) -> None:
        """Send ``data`` as a BINARY frame.

        :raises TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def send_json_bytes(
        self,
        data: Any,
        compress: int | None = None,
        *,
        dumps: JSONBytesEncoder,
    ) -> None:
        """Send JSON data using a bytes-returning encoder as a binary frame.

        Use this when your JSON encoder (like orjson) returns bytes
        instead of str, avoiding the encode/decode overhead.
        """
        await self.send_bytes(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Perform the closing handshake.

        Sends a CLOSE frame with ``code``/``message``. Unless a close code is
        already recorded, it then waits up to ``ws_close`` seconds for the
        server's CLOSE reply, ignoring any other messages. Returns ``False``
        if the connection was already closed and ``True`` otherwise. Errors
        are recorded (see ``exception()``/``close_code``) rather than raised;
        cancellation propagates.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._close_wait

        if self._closed:
            return False

        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        if self._close_code:
            # The peer's close code is already known; no reply to wait for.
            self._response.close()
            return True

        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None
    ) -> WSMessageDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def receive(
        self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def receive(
        self, timeout: float | None = None
    ) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Return the next message, handling control frames internally.

        ``timeout`` overrides ``ws_receive`` for this call; a falsy value
        falls back to ``ws_receive``. With ``autoping``, PINGs are answered
        and PONGs are swallowed. A CLOSE frame starts the closing handshake
        when ``autoclose`` is set.

        Once the connection is closed, ``WS_CLOSED_MESSAGE`` is returned.
        Protocol and unexpected errors are returned as ERROR messages, while
        cancellation and timeouts propagate.

        :raises RuntimeError: if another ``receive()`` is already waiting.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                finally:
                    self._waiting = False
                    if self._close_wait:
                        # Wake up a close() that is waiting for us to stop reading.
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If it's not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None
    ) -> str: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None
    ) -> bytes: ...

    @overload
    async def receive_str(
        self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None
    ) -> str | bytes: ...

    async def receive_str(self, *, timeout: float | None = None) -> str | bytes:
        """Receive TEXT message.

        Returns str when decode_text=True (default), bytes when decode_text=False.

        :raises WSMessageTypeError: if the next message is not TEXT.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
        """Receive a BINARY message and return its payload.

        :raises WSMessageTypeError: if the next message is not BINARY.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return cast(bytes, msg.data)

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[True]]",
        *,
        loads: JSONDecoder = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[Literal[False]]",
        *,
        loads: Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    @overload
    async def receive_json(
        self: "ClientWebSocketResponse[_DecodeText]",
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = ...,
        timeout: float | None = None,
    ) -> Any: ...

    async def receive_json(
        self,
        *,
        loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER,
        timeout: float | None = None,
    ) -> Any:
        """Receive a TEXT message and return it parsed with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)  # type: ignore[arg-type]

    def __aiter__(self) -> Self:
        """Iterate over incoming messages with ``async for``."""
        return self

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[True]]",
    ) -> WSMessageDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[Literal[False]]",
    ) -> WSMessageNoDecodeText: ...

    @overload
    async def __anext__(
        self: "ClientWebSocketResponse[_DecodeText]",
    ) -> WSMessageDecodeText | WSMessageNoDecodeText: ...

    async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
        """Return the next message; stop iterating on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> Self:
        """Enter ``async with``; returns the response itself."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the WebSocket when leaving ``async with``."""
        await self.close()