Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/client_ws.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

278 statements  

1"""WebSocket client for asyncio.""" 

2 

3import asyncio 

4import sys 

5from types import TracebackType 

6from typing import Any, Optional, Type, cast 

7 

8import attr 

9 

10from ._websocket.reader import WebSocketDataQueue 

11from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError 

12from .client_reqrep import ClientResponse 

13from .helpers import calculate_timeout_when, set_result 

14from .http import ( 

15 WS_CLOSED_MESSAGE, 

16 WS_CLOSING_MESSAGE, 

17 WebSocketError, 

18 WSCloseCode, 

19 WSMessage, 

20 WSMsgType, 

21) 

22from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter 

23from .streams import EofStream 

24from .typedefs import ( 

25 DEFAULT_JSON_DECODER, 

26 DEFAULT_JSON_ENCODER, 

27 JSONDecoder, 

28 JSONEncoder, 

29) 

30 

31if sys.version_info >= (3, 11): 

32 import asyncio as async_timeout 

33else: 

34 import async_timeout 

35 

36 

37@attr.s(frozen=True, slots=True) 

38class ClientWSTimeout: 

39 ws_receive = attr.ib(type=Optional[float], default=None) 

40 ws_close = attr.ib(type=Optional[float], default=None) 

41 

42 

43DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0) 

44 

45 

46class ClientWebSocketResponse: 

47 def __init__( 

48 self, 

49 reader: WebSocketDataQueue, 

50 writer: WebSocketWriter, 

51 protocol: Optional[str], 

52 response: ClientResponse, 

53 timeout: ClientWSTimeout, 

54 autoclose: bool, 

55 autoping: bool, 

56 loop: asyncio.AbstractEventLoop, 

57 *, 

58 heartbeat: Optional[float] = None, 

59 compress: int = 0, 

60 client_notakeover: bool = False, 

61 ) -> None: 

62 self._response = response 

63 self._conn = response.connection 

64 

65 self._writer = writer 

66 self._reader = reader 

67 self._protocol = protocol 

68 self._closed = False 

69 self._closing = False 

70 self._close_code: Optional[int] = None 

71 self._timeout = timeout 

72 self._autoclose = autoclose 

73 self._autoping = autoping 

74 self._heartbeat = heartbeat 

75 self._heartbeat_cb: Optional[asyncio.TimerHandle] = None 

76 self._heartbeat_when: float = 0.0 

77 if heartbeat is not None: 

78 self._pong_heartbeat = heartbeat / 2.0 

79 self._pong_response_cb: Optional[asyncio.TimerHandle] = None 

80 self._loop = loop 

81 self._waiting: bool = False 

82 self._close_wait: Optional[asyncio.Future[None]] = None 

83 self._exception: Optional[BaseException] = None 

84 self._compress = compress 

85 self._client_notakeover = client_notakeover 

86 self._ping_task: Optional[asyncio.Task[None]] = None 

87 

88 self._reset_heartbeat() 

89 

90 def _cancel_heartbeat(self) -> None: 

91 self._cancel_pong_response_cb() 

92 if self._heartbeat_cb is not None: 

93 self._heartbeat_cb.cancel() 

94 self._heartbeat_cb = None 

95 if self._ping_task is not None: 

96 self._ping_task.cancel() 

97 self._ping_task = None 

98 

99 def _cancel_pong_response_cb(self) -> None: 

100 if self._pong_response_cb is not None: 

101 self._pong_response_cb.cancel() 

102 self._pong_response_cb = None 

103 

104 def _reset_heartbeat(self) -> None: 

105 if self._heartbeat is None: 

106 return 

107 self._cancel_pong_response_cb() 

108 loop = self._loop 

109 assert loop is not None 

110 conn = self._conn 

111 timeout_ceil_threshold = ( 

112 conn._connector._timeout_ceil_threshold if conn is not None else 5 

113 ) 

114 now = loop.time() 

115 when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold) 

116 self._heartbeat_when = when 

117 if self._heartbeat_cb is None: 

118 # We do not cancel the previous heartbeat_cb here because 

119 # it generates a significant amount of TimerHandle churn 

120 # which causes asyncio to rebuild the heap frequently. 

121 # Instead _send_heartbeat() will reschedule the next 

122 # heartbeat if it fires too early. 

123 self._heartbeat_cb = loop.call_at(when, self._send_heartbeat) 

124 

125 def _send_heartbeat(self) -> None: 

126 self._heartbeat_cb = None 

127 loop = self._loop 

128 now = loop.time() 

129 if now < self._heartbeat_when: 

130 # Heartbeat fired too early, reschedule 

131 self._heartbeat_cb = loop.call_at( 

132 self._heartbeat_when, self._send_heartbeat 

133 ) 

134 return 

135 

136 conn = self._conn 

137 timeout_ceil_threshold = ( 

138 conn._connector._timeout_ceil_threshold if conn is not None else 5 

139 ) 

140 when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold) 

141 self._cancel_pong_response_cb() 

142 self._pong_response_cb = loop.call_at(when, self._pong_not_received) 

143 

144 coro = self._writer.send_frame(b"", WSMsgType.PING) 

145 if sys.version_info >= (3, 12): 

146 # Optimization for Python 3.12, try to send the ping 

147 # immediately to avoid having to schedule 

148 # the task on the event loop. 

149 ping_task = asyncio.Task(coro, loop=loop, eager_start=True) 

150 else: 

151 ping_task = loop.create_task(coro) 

152 

153 if not ping_task.done(): 

154 self._ping_task = ping_task 

155 ping_task.add_done_callback(self._ping_task_done) 

156 else: 

157 self._ping_task_done(ping_task) 

158 

159 def _ping_task_done(self, task: "asyncio.Task[None]") -> None: 

160 """Callback for when the ping task completes.""" 

161 if not task.cancelled() and (exc := task.exception()): 

162 self._handle_ping_pong_exception(exc) 

163 self._ping_task = None 

164 

165 def _pong_not_received(self) -> None: 

166 self._handle_ping_pong_exception( 

167 ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds") 

168 ) 

169 

170 def _handle_ping_pong_exception(self, exc: BaseException) -> None: 

171 """Handle exceptions raised during ping/pong processing.""" 

172 if self._closed: 

173 return 

174 self._set_closed() 

175 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

176 self._exception = exc 

177 self._response.close() 

178 if self._waiting and not self._closing: 

179 self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0) 

180 

181 def _set_closed(self) -> None: 

182 """Set the connection to closed. 

183 

184 Cancel any heartbeat timers and set the closed flag. 

185 """ 

186 self._closed = True 

187 self._cancel_heartbeat() 

188 

189 def _set_closing(self) -> None: 

190 """Set the connection to closing. 

191 

192 Cancel any heartbeat timers and set the closing flag. 

193 """ 

194 self._closing = True 

195 self._cancel_heartbeat() 

196 

197 @property 

198 def closed(self) -> bool: 

199 return self._closed 

200 

201 @property 

202 def close_code(self) -> Optional[int]: 

203 return self._close_code 

204 

205 @property 

206 def protocol(self) -> Optional[str]: 

207 return self._protocol 

208 

209 @property 

210 def compress(self) -> int: 

211 return self._compress 

212 

213 @property 

214 def client_notakeover(self) -> bool: 

215 return self._client_notakeover 

216 

217 def get_extra_info(self, name: str, default: Any = None) -> Any: 

218 """extra info from connection transport""" 

219 conn = self._response.connection 

220 if conn is None: 

221 return default 

222 transport = conn.transport 

223 if transport is None: 

224 return default 

225 return transport.get_extra_info(name, default) 

226 

227 def exception(self) -> Optional[BaseException]: 

228 return self._exception 

229 

230 async def ping(self, message: bytes = b"") -> None: 

231 await self._writer.send_frame(message, WSMsgType.PING) 

232 

233 async def pong(self, message: bytes = b"") -> None: 

234 await self._writer.send_frame(message, WSMsgType.PONG) 

235 

236 async def send_frame( 

237 self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None 

238 ) -> None: 

239 """Send a frame over the websocket.""" 

240 await self._writer.send_frame(message, opcode, compress) 

241 

242 async def send_str(self, data: str, compress: Optional[int] = None) -> None: 

243 if not isinstance(data, str): 

244 raise TypeError("data argument must be str (%r)" % type(data)) 

245 await self._writer.send_frame( 

246 data.encode("utf-8"), WSMsgType.TEXT, compress=compress 

247 ) 

248 

249 async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: 

250 if not isinstance(data, (bytes, bytearray, memoryview)): 

251 raise TypeError("data argument must be byte-ish (%r)" % type(data)) 

252 await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress) 

253 

254 async def send_json( 

255 self, 

256 data: Any, 

257 compress: Optional[int] = None, 

258 *, 

259 dumps: JSONEncoder = DEFAULT_JSON_ENCODER, 

260 ) -> None: 

261 await self.send_str(dumps(data), compress=compress) 

262 

263 async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: 

264 # we need to break `receive()` cycle first, 

265 # `close()` may be called from different task 

266 if self._waiting and not self._closing: 

267 assert self._loop is not None 

268 self._close_wait = self._loop.create_future() 

269 self._set_closing() 

270 self._reader.feed_data(WS_CLOSING_MESSAGE, 0) 

271 await self._close_wait 

272 

273 if self._closed: 

274 return False 

275 

276 self._set_closed() 

277 try: 

278 await self._writer.close(code, message) 

279 except asyncio.CancelledError: 

280 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

281 self._response.close() 

282 raise 

283 except Exception as exc: 

284 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

285 self._exception = exc 

286 self._response.close() 

287 return True 

288 

289 if self._close_code: 

290 self._response.close() 

291 return True 

292 

293 while True: 

294 try: 

295 async with async_timeout.timeout(self._timeout.ws_close): 

296 msg = await self._reader.read() 

297 except asyncio.CancelledError: 

298 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

299 self._response.close() 

300 raise 

301 except Exception as exc: 

302 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

303 self._exception = exc 

304 self._response.close() 

305 return True 

306 

307 if msg.type is WSMsgType.CLOSE: 

308 self._close_code = msg.data 

309 self._response.close() 

310 return True 

311 

312 async def receive(self, timeout: Optional[float] = None) -> WSMessage: 

313 receive_timeout = timeout or self._timeout.ws_receive 

314 

315 while True: 

316 if self._waiting: 

317 raise RuntimeError("Concurrent call to receive() is not allowed") 

318 

319 if self._closed: 

320 return WS_CLOSED_MESSAGE 

321 elif self._closing: 

322 await self.close() 

323 return WS_CLOSED_MESSAGE 

324 

325 try: 

326 self._waiting = True 

327 try: 

328 if receive_timeout: 

329 # Entering the context manager and creating 

330 # Timeout() object can take almost 50% of the 

331 # run time in this loop so we avoid it if 

332 # there is no read timeout. 

333 async with async_timeout.timeout(receive_timeout): 

334 msg = await self._reader.read() 

335 else: 

336 msg = await self._reader.read() 

337 self._reset_heartbeat() 

338 finally: 

339 self._waiting = False 

340 if self._close_wait: 

341 set_result(self._close_wait, None) 

342 except (asyncio.CancelledError, asyncio.TimeoutError): 

343 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

344 raise 

345 except EofStream: 

346 self._close_code = WSCloseCode.OK 

347 await self.close() 

348 return WSMessage(WSMsgType.CLOSED, None, None) 

349 except ClientError: 

350 # Likely ServerDisconnectedError when connection is lost 

351 self._set_closed() 

352 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

353 return WS_CLOSED_MESSAGE 

354 except WebSocketError as exc: 

355 self._close_code = exc.code 

356 await self.close(code=exc.code) 

357 return WSMessage(WSMsgType.ERROR, exc, None) 

358 except Exception as exc: 

359 self._exception = exc 

360 self._set_closing() 

361 self._close_code = WSCloseCode.ABNORMAL_CLOSURE 

362 await self.close() 

363 return WSMessage(WSMsgType.ERROR, exc, None) 

364 

365 if msg.type not in _INTERNAL_RECEIVE_TYPES: 

366 # If its not a close/closing/ping/pong message 

367 # we can return it immediately 

368 return msg 

369 

370 if msg.type is WSMsgType.CLOSE: 

371 self._set_closing() 

372 self._close_code = msg.data 

373 if not self._closed and self._autoclose: 

374 await self.close() 

375 elif msg.type is WSMsgType.CLOSING: 

376 self._set_closing() 

377 elif msg.type is WSMsgType.PING and self._autoping: 

378 await self.pong(msg.data) 

379 continue 

380 elif msg.type is WSMsgType.PONG and self._autoping: 

381 continue 

382 

383 return msg 

384 

385 async def receive_str(self, *, timeout: Optional[float] = None) -> str: 

386 msg = await self.receive(timeout) 

387 if msg.type is not WSMsgType.TEXT: 

388 raise WSMessageTypeError( 

389 f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT" 

390 ) 

391 return cast(str, msg.data) 

392 

393 async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: 

394 msg = await self.receive(timeout) 

395 if msg.type is not WSMsgType.BINARY: 

396 raise WSMessageTypeError( 

397 f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY" 

398 ) 

399 return cast(bytes, msg.data) 

400 

401 async def receive_json( 

402 self, 

403 *, 

404 loads: JSONDecoder = DEFAULT_JSON_DECODER, 

405 timeout: Optional[float] = None, 

406 ) -> Any: 

407 data = await self.receive_str(timeout=timeout) 

408 return loads(data) 

409 

410 def __aiter__(self) -> "ClientWebSocketResponse": 

411 return self 

412 

413 async def __anext__(self) -> WSMessage: 

414 msg = await self.receive() 

415 if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): 

416 raise StopAsyncIteration 

417 return msg 

418 

419 async def __aenter__(self) -> "ClientWebSocketResponse": 

420 return self 

421 

422 async def __aexit__( 

423 self, 

424 exc_type: Optional[Type[BaseException]], 

425 exc_val: Optional[BaseException], 

426 exc_tb: Optional[TracebackType], 

427 ) -> None: 

428 await self.close()