Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/client_ws.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

276 statements  

1"""WebSocket client for asyncio.""" 

2 

3import asyncio 

4import sys 

5from types import TracebackType 

6from typing import Any, Final, Optional, Type 

7 

8from ._websocket.reader import WebSocketDataQueue 

9from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError 

10from .client_reqrep import ClientResponse 

11from .helpers import calculate_timeout_when, frozen_dataclass_decorator, set_result 

12from .http import ( 

13 WS_CLOSED_MESSAGE, 

14 WS_CLOSING_MESSAGE, 

15 WebSocketError, 

16 WSCloseCode, 

17 WSMessage, 

18 WSMsgType, 

19) 

20from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter, WSMessageError 

21from .streams import EofStream 

22from .typedefs import ( 

23 DEFAULT_JSON_DECODER, 

24 DEFAULT_JSON_ENCODER, 

25 JSONDecoder, 

26 JSONEncoder, 

27) 

28 

29if sys.version_info >= (3, 11): 

30 import asyncio as async_timeout 

31else: 

32 import async_timeout 

33 

34 

@frozen_dataclass_decorator
class ClientWSTimeout:
    """Immutable timeout settings for a client WebSocket connection.

    Values are in seconds; ``None`` leaves the corresponding wait unbounded.
    """

    # Timeout applied to each receive() read; a falsy value means no timeout.
    ws_receive: Optional[float] = None
    # Timeout applied to each read while close() waits for the peer's CLOSE.
    ws_close: Optional[float] = None

39 

40 

# Default client timeouts: ``ws_receive`` keeps its field default of ``None``
# (no receive timeout) and ``ws_close`` allows 10 seconds.
DEFAULT_WS_CLIENT_TIMEOUT: Final[ClientWSTimeout] = ClientWSTimeout(ws_close=10.0)

44 

45 

class ClientWebSocketResponse:
    """Client side of an established WebSocket connection.

    Wraps an incoming message queue (``reader``) and a frame writer
    (``writer``) tied to the ``ClientResponse`` the connection was upgraded
    from. Provides send/receive/close coroutines, async iteration, async
    context-manager support and optional automatic PING heartbeats.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: Optional[str],
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: Optional[float] = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Initialise connection state and arm the first heartbeat, if enabled.

        :param reader: queue that incoming WebSocket messages are read from.
        :param writer: writer used to send outgoing frames.
        :param protocol: negotiated subprotocol, exposed via ``protocol``.
        :param response: HTTP response the WebSocket was upgraded from; it is
            closed when the WebSocket shuts down.
        :param timeout: receive/close timeout settings.
        :param autoclose: when true, ``receive()`` calls ``close()`` after a
            CLOSE message arrives.
        :param autoping: when true, ``receive()`` answers PINGs with PONGs and
            skips PONG messages instead of returning them.
        :param loop: event loop used for heartbeat timers and the ping task.
        :param heartbeat: seconds between automatic PINGs; ``None`` disables
            the heartbeat.
        :param compress: value reported by the ``compress`` property.
        :param client_notakeover: value reported by ``client_notakeover``.
        """
        self._response = response
        # May be None; heartbeat code falls back to a default ceil threshold.
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        # True once the connection is fully closed (see _set_closed()).
        self._closed = False
        # True once a close is in progress (see _set_closing()).
        self._closing = False
        self._close_code: Optional[int] = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        # Timer that fires _send_heartbeat(); None when no timer is pending.
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # Loop time at which the next PING is due.
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # Time allowed for a PONG to arrive after each heartbeat PING.
            # NOTE(review): only defined when heartbeat is enabled; it is read
            # only from heartbeat callbacks, which are never scheduled otherwise.
            self._pong_heartbeat = heartbeat / 2.0
        # Timer that fires _pong_not_received() if no message arrives in time.
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        self._loop = loop
        # True while receive() is blocked reading from the reader.
        self._waiting: bool = False
        # Future close() awaits while a concurrent receive() unwinds.
        self._close_wait: Optional[asyncio.Future[None]] = None
        self._exception: Optional[BaseException] = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        # In-flight heartbeat PING send, if it did not finish synchronously.
        self._ping_task: Optional[asyncio.Task[None]] = None

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel the pong timeout, the pending heartbeat timer and any ping task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the pending "no PONG received" timer, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _reset_heartbeat(self) -> None:
        """Push the next PING deadline one heartbeat interval into the future.

        Called on construction and after every message ``receive()`` reads.
        Receiving a message also cancels any pending pong timeout. Does
        nothing when the heartbeat is disabled.
        """
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        # Threshold passed to calculate_timeout_when(); 5 is used when there
        # is no connection to read the connector's setting from.
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a heartbeat PING and arm the pong timeout.

        If the deadline was pushed back by ``_reset_heartbeat()`` after this
        timer was scheduled, re-arm the timer for the new deadline instead.
        """
        self._heartbeat_cb = None
        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        # If nothing arrives before `when`, _pong_not_received() fails the
        # connection (any received message cancels this via _reset_heartbeat()).
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            # Keep a reference so _cancel_heartbeat() can cancel the send.
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            # The eager start already finished the send; handle it now.
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes."""
        # A failed (not cancelled) PING send aborts the connection.
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """Timer callback: abort the connection because no PONG arrived in time."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing.

        Marks the connection closed with ``ABNORMAL_CLOSURE``, records *exc*
        for ``exception()`` and closes the response. If a ``receive()`` is
        currently blocked (and no close is in progress), it is woken by
        feeding it a ``WSMessageError`` carrying *exc*. Does nothing if the
        connection is already closed.
        """
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            self._reader.feed_data(WSMessageError(data=exc, extra=None))

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """True once the connection has been closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded for this connection, or None if not set yet."""
        return self._close_code

    @property
    def protocol(self) -> Optional[str]:
        """Subprotocol passed in at construction, if any."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression setting passed in at construction."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """``client_notakeover`` flag passed in at construction."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """extra info from connection transport"""
        # Return *default* when the connection or its transport is gone.
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the exception recorded when the connection failed, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame carrying *message*."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame carrying *message*."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send *data* as a UTF-8 encoded TEXT frame.

        :raises TypeError: if *data* is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send *data* as a BINARY frame.

        :raises TypeError: if *data* is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialise *data* with *dumps* and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Close the WebSocket, sending a CLOSE frame with *code* and *message*.

        If another task is blocked in ``receive()``, that call is interrupted
        first and this coroutine waits until it returns. After sending CLOSE,
        unless a close code is already recorded, reads and discards messages
        until the peer's CLOSE arrives. Each read is bounded by
        ``timeout.ws_close``.

        Errors while sending or waiting are stored (see ``exception()``) and
        the response is closed instead of raising. ``CancelledError`` is
        re-raised after the response is closed.

        :returns: False if the connection was already closed, else True.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            # Wakes the blocked receive(); its `finally` resolves _close_wait.
            self._reader.feed_data(WS_CLOSING_MESSAGE)
            await self._close_wait

        if self._closed:
            return False

        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # A close code recorded earlier (e.g. by receive() on a peer CLOSE or
        # an error) means there is no peer CLOSE reply left to wait for.
        if self._close_code:
            self._response.close()
            return True

        # Drain incoming messages until the peer's CLOSE; others are dropped.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for and return the next message.

        *timeout* overrides ``timeout.ws_receive`` for this call. A falsy
        value (None or 0) falls back to the configured value, and if that is
        also falsy no timeout is applied.

        With ``autoping``, PINGs are answered and PONGs skipped rather than
        returned. CLOSE/CLOSING messages mark the connection as closing and
        are returned. Once closed, ``WS_CLOSED_MESSAGE`` is returned. Reader
        errors are converted into a returned ``WSMessageError`` or
        ``WS_CLOSED_MESSAGE``.

        :raises RuntimeError: if another receive() is already in progress.
        :raises asyncio.CancelledError, asyncio.TimeoutError: re-raised after
            setting the close code to ``ABNORMAL_CLOSURE``.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                    # Any received message counts as liveness; delay next PING.
                    self._reset_heartbeat()
                finally:
                    self._waiting = False
                    # Release a close() running in another task that is
                    # waiting for this receive() to unwind.
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WS_CLOSED_MESSAGE
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessageError(data=exc)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessageError(data=exc)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                # Could be closed elsewhere while awaiting reader
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            # CLOSE/CLOSING, or PING/PONG when autoping is disabled.
            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive the next message and return its text payload.

        :raises WSMessageTypeError: if the message is not ``WSMsgType.TEXT``.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return msg.data

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive the next message and return its binary payload.

        :raises WSMessageTypeError: if the message is not ``WSMsgType.BINARY``.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return msg.data

    async def receive_json(
        self,
        *,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        timeout: Optional[float] = None,
    ) -> Any:
        """Receive a TEXT message and decode its payload with *loads*."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    def __aiter__(self) -> "ClientWebSocketResponse":
        """Return self; iteration yields messages until the connection closes."""
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message, stopping on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> "ClientWebSocketResponse":
        """Enter the async context; returns self."""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the WebSocket when leaving the async context."""
        await self.close()