Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/client_ws.py: 22%

215 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-26 06:16 +0000

1"""WebSocket client for asyncio.""" 

2 

3import asyncio 

4import dataclasses 

5import sys 

6from typing import Any, Final, Optional, cast 

7 

8from .client_exceptions import ClientError 

9from .client_reqrep import ClientResponse 

10from .helpers import call_later, set_result 

11from .http import ( 

12 WS_CLOSED_MESSAGE, 

13 WS_CLOSING_MESSAGE, 

14 WebSocketError, 

15 WSCloseCode, 

16 WSMessage, 

17 WSMsgType, 

18) 

19from .http_websocket import WebSocketWriter # WSMessage 

20from .streams import EofStream, FlowControlDataQueue 

21from .typedefs import ( 

22 DEFAULT_JSON_DECODER, 

23 DEFAULT_JSON_ENCODER, 

24 JSONDecoder, 

25 JSONEncoder, 

26) 

27 

28if sys.version_info >= (3, 11): 

29 import asyncio as async_timeout 

30else: 

31 import async_timeout 

32 

33 

@dataclasses.dataclass(frozen=True)
class ClientWSTimeout:
    """Immutable pair of timeouts (in seconds) for client WebSocket operations.

    Each value is handed to ``asyncio.timeout`` / ``async_timeout.timeout``,
    where ``None`` means no deadline.
    """

    # Default deadline for ``receive()`` when the caller gives no timeout.
    ws_receive: Optional[float] = dataclasses.field(default=None)
    # Deadline for waiting on the peer's CLOSE frame inside ``close()``.
    ws_close: Optional[float] = dataclasses.field(default=None)

38 

39 

# Client defaults: no default deadline for receive(), and at most ten seconds
# to wait for the peer's CLOSE frame during close().
DEFAULT_WS_CLIENT_TIMEOUT: Final[ClientWSTimeout] = ClientWSTimeout(
    ws_close=10.0,
    ws_receive=None,
)

43 

44 

class ClientWebSocketResponse:
    """Client side of an established WebSocket connection.

    Combines the frame reader/writer produced by the handshake with the
    originating :class:`ClientResponse`. It provides the send / receive /
    close API, automatic PING/PONG handling and an optional heartbeat.
    """

    def __init__(
        self,
        reader: "FlowControlDataQueue[WSMessage]",
        writer: WebSocketWriter,
        protocol: Optional[str],
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: Optional[float] = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: Optional[int] = None
        self._timeout: ClientWSTimeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        if heartbeat is not None:
            # After each heartbeat ping, a message must arrive within half
            # the heartbeat interval or the connection is treated as dead.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        # Strong references to in-flight heartbeat pings. The event loop only
        # holds weak references to tasks, so an unreferenced task may be
        # garbage collected before it finishes writing.
        self._ping_tasks: "set[asyncio.Task[None]]" = set()
        self._loop = loop
        # Future that is set while receive() is blocked on the reader;
        # close() awaits it to break a receive() running in another task.
        self._waiting: Optional[asyncio.Future[bool]] = None
        self._exception: Optional[BaseException] = None
        self._compress = compress
        self._client_notakeover = client_notakeover

        self._reset_heartbeat()

    def _timeout_ceil_threshold(self) -> float:
        """Return the connector's timeout-ceil threshold, or 5 without a connection."""
        conn = self._conn
        return conn._connector._timeout_ceil_threshold if conn is not None else 5

    def _cancel_heartbeat(self) -> None:
        """Cancel both the pending pong watchdog and the next heartbeat ping."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None

    def _reset_heartbeat(self) -> None:
        """Restart the heartbeat timer; a no-op when heartbeat is disabled."""
        self._cancel_heartbeat()

        if self._heartbeat is not None:
            self._heartbeat_cb = call_later(
                self._send_heartbeat,
                self._heartbeat,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold(),
            )

    def _send_heartbeat(self) -> None:
        """Send a ping and arm the pong watchdog (heartbeat timer callback)."""
        if self._heartbeat is not None and not self._closed:
            # The ping is fire-and-forget rather than driven by a long-lived
            # heartbeat task. It is tracked in ``_ping_tasks`` so it cannot be
            # garbage collected mid-write, and its outcome is consumed by
            # ``_ping_task_done``.
            ping_task = self._loop.create_task(self._writer.ping())
            self._ping_tasks.add(ping_task)
            ping_task.add_done_callback(self._ping_task_done)

            if self._pong_response_cb is not None:
                self._pong_response_cb.cancel()
            self._pong_response_cb = call_later(
                self._pong_not_received,
                self._pong_heartbeat,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold(),
            )

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Forget a finished heartbeat ping and consume its exception, if any.

        A failed ping needs no handling here. Unless a message is received
        first (which cancels the watchdog via ``_reset_heartbeat``),
        ``_pong_not_received`` fires and closes the connection. Retrieving
        the exception stops asyncio from logging
        "Task exception was never retrieved".
        """
        self._ping_tasks.discard(task)
        if not task.cancelled():
            task.exception()

    def _pong_not_received(self) -> None:
        """Mark the connection dead after a heartbeat ping went unanswered."""
        if not self._closed:
            self._closed = True
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = asyncio.TimeoutError()
            self._response.close()

    @property
    def closed(self) -> bool:
        """True once the connection has been closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded for the connection, or None if not set yet."""
        return self._close_code

    @property
    def protocol(self) -> Optional[str]:
        """WebSocket subprotocol this object was created with, if any."""
        return self._protocol

    @property
    def compress(self) -> int:
        """Compression setting this object was created with."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """The ``client_notakeover`` flag this object was created with."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Return transport extra info *name*, or *default* without a transport."""
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the exception that ended the connection, if one was recorded."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame carrying *message*."""
        await self._writer.ping(message)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame carrying *message*."""
        await self._writer.pong(message)

    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send *data* as a text message.

        Raises:
            TypeError: if *data* is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send(data, binary=False, compress=compress)

    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send *data* as a binary message.

        Raises:
            TypeError: if *data* is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send(data, binary=True, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize *data* with *dumps* and send it as a text message."""
        await self.send_str(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Send a CLOSE frame and wait (up to ``ws_close``) for the peer's CLOSE.

        Returns:
            True if this call closed the connection; False if it was already
            closed.

        Raises:
            asyncio.CancelledError: re-raised after the response is closed.
        """
        # Break a receive() that is blocked in another task first: feed it a
        # CLOSING message and wait until it has released the reader.
        if self._waiting is not None and not self._closing:
            self._closing = True
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._waiting

        if self._closed:
            return False

        self._cancel_heartbeat()
        self._closed = True
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        # A close code is already known (e.g. the peer's CLOSE was received),
        # so there is no reply to wait for.
        if self._close_code:
            self._response.close()
            return True

        # Drain incoming frames until the peer's CLOSE arrives; anything else
        # is discarded.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type == WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for and return the next message.

        With autoping enabled, PING frames are answered with PONG and PONG
        frames are skipped. A CLOSE frame triggers ``close()`` when autoclose
        is enabled. Once the connection is closed, ``WS_CLOSED_MESSAGE`` is
        returned.

        Args:
            timeout: per-call timeout in seconds. A falsy value (None or 0)
                falls back to ``ClientWSTimeout.ws_receive``.

        Raises:
            RuntimeError: if another receive() is already in progress.
            asyncio.TimeoutError: if no message arrives in time.
        """
        while True:
            if self._waiting is not None:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = self._loop.create_future()
                try:
                    async with async_timeout.timeout(
                        timeout or self._timeout.ws_receive
                    ):
                        msg = await self._reader.read()
                    # Any incoming message proves the peer is alive.
                    self._reset_heartbeat()
                finally:
                    # Wake a close() that is waiting for this receive() to end.
                    waiter = self._waiting
                    self._waiting = None
                    set_result(waiter, True)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                self._closed = True
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._closing = True
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                # Could be closed elsewhere while awaiting reader
                if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                    await self.close()
            elif msg.type == WSMsgType.CLOSING:
                self._closing = True
            elif msg.type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type == WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive the next message and return its text payload.

        Raises:
            TypeError: if the message is not a TEXT message.
        """
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.TEXT:
            raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str")
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive the next message and return its binary payload.

        Raises:
            TypeError: if the message is not a BINARY message.
        """
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.BINARY:
            raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
        return cast(bytes, msg.data)

    async def receive_json(
        self,
        *,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        timeout: Optional[float] = None,
    ) -> Any:
        """Receive a TEXT message and decode it with *loads*."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    def __aiter__(self) -> "ClientWebSocketResponse":
        return self

    async def __anext__(self) -> WSMessage:
        """Yield messages until a CLOSE, CLOSING or CLOSED message is received."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg